From: Mike Bayer Date: Sun, 6 Jan 2019 06:14:26 +0000 (-0500) Subject: Run black -l 79 against all source files X-Git-Tag: rel_1_3_0b2~46 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1e1a38e7801f410f244e4bbb44ec795ae152e04e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Run black -l 79 against all source files This is a straight reformat run using black as is, with no edits applied at all. The black run will format code consistently, however in some cases that are prevalent in SQLAlchemy code it produces too-long lines. The too-long lines will be resolved in the following commit that will resolve all remaining flake8 issues including shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues. Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9 --- diff --git a/examples/adjacency_list/adjacency_list.py b/examples/adjacency_list/adjacency_list.py index 3c91443235..f1628a6329 100644 --- a/examples/adjacency_list/adjacency_list.py +++ b/examples/adjacency_list/adjacency_list.py @@ -8,7 +8,7 @@ Base = declarative_base() class TreeNode(Base): - __tablename__ = 'tree' + __tablename__ = "tree" id = Column(Integer, primary_key=True) parent_id = Column(Integer, ForeignKey(id)) name = Column(String(50), nullable=False) @@ -17,15 +17,13 @@ class TreeNode(Base): "TreeNode", # cascade deletions cascade="all, delete-orphan", - # many to one + adjacency list - remote_side # is required to reference the 'remote' # column in the join condition. backref=backref("parent", remote_side=id), - # children will be represented as a dictionary # on the "name" attribute. - collection_class=attribute_mapped_collection('name'), + collection_class=attribute_mapped_collection("name"), ) def __init__(self, name, parent=None): @@ -36,20 +34,20 @@ class TreeNode(Base): return "TreeNode(name=%r, id=%r, parent_id=%r)" % ( self.name, self.id, - self.parent_id + self.parent_id, ) def dump(self, _indent=0): - return " " * _indent + repr(self) + \ - "\n" + \ - "".join([ - c.dump(_indent + 1) - for c in self.children.values() - ]) + return ( + " " * _indent + + repr(self) + + "\n" + + "".join([c.dump(_indent + 1) for c in self.children.values()]) + ) -if __name__ == '__main__': - engine = create_engine('sqlite://', echo=True) +if __name__ == "__main__": + engine = create_engine("sqlite://", echo=True) def msg(msg, *args): msg = msg % args @@ -63,14 +61,14 @@ if __name__ == '__main__': session = Session(engine) - node = TreeNode('rootnode') - TreeNode('node1', parent=node) - TreeNode('node3', parent=node) + node = TreeNode("rootnode") + TreeNode("node1", parent=node) + TreeNode("node3", parent=node) - node2 = TreeNode('node2') - TreeNode('subnode1', parent=node2) - node.children['node2'] = node2 - TreeNode('subnode2', parent=node.children['node2']) + node2 = TreeNode("node2") + TreeNode("subnode1", parent=node2) + node.children["node2"] = node2 + TreeNode("subnode2", parent=node.children["node2"]) msg("Created new tree structure:\n%s", node.dump()) @@ -81,28 +79,33 @@ if __name__ == '__main__': msg("Tree After Save:\n %s", node.dump()) - TreeNode('node4', parent=node) - TreeNode('subnode3', parent=node.children['node4']) - TreeNode('subnode4', parent=node.children['node4']) - TreeNode('subsubnode1', parent=node.children['node4'].children['subnode3']) + TreeNode("node4", parent=node) + TreeNode("subnode3", parent=node.children["node4"]) + TreeNode("subnode4", parent=node.children["node4"]) + TreeNode("subsubnode1", parent=node.children["node4"].children["subnode3"]) 
# remove node1 from the parent, which will trigger a delete # via the delete-orphan cascade. - del node.children['node1'] + del node.children["node1"] msg("Removed node1. flush + commit:") session.commit() msg("Tree after save:\n %s", node.dump()) - msg("Emptying out the session entirely, selecting tree on root, using " - "eager loading to join four levels deep.") + msg( + "Emptying out the session entirely, selecting tree on root, using " + "eager loading to join four levels deep." + ) session.expunge_all() - node = session.query(TreeNode).\ - options(joinedload_all("children", "children", - "children", "children")).\ - filter(TreeNode.name == "rootnode").\ - first() + node = ( + session.query(TreeNode) + .options( + joinedload_all("children", "children", "children", "children") + ) + .filter(TreeNode.name == "rootnode") + .first() + ) msg("Full Tree:\n%s", node.dump()) diff --git a/examples/association/basic_association.py b/examples/association/basic_association.py index 6714aa6816..52476f1842 100644 --- a/examples/association/basic_association.py +++ b/examples/association/basic_association.py @@ -12,8 +12,16 @@ of "items", with a particular price paid associated with each "item". from datetime import datetime -from sqlalchemy import (create_engine, Column, Integer, String, DateTime, - Float, ForeignKey, and_) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + DateTime, + Float, + ForeignKey, + and_, +) from sqlalchemy.orm import relationship, Session from sqlalchemy.ext.declarative import declarative_base @@ -21,20 +29,21 @@ Base = declarative_base() class Order(Base): - __tablename__ = 'order' + __tablename__ = "order" order_id = Column(Integer, primary_key=True) customer_name = Column(String(30), nullable=False) order_date = Column(DateTime, nullable=False, default=datetime.now()) - order_items = relationship("OrderItem", cascade="all, delete-orphan", - backref='order') + order_items = relationship( + "OrderItem", cascade="all, delete-orphan", backref="order" + ) def __init__(self, customer_name): self.customer_name = customer_name class Item(Base): - __tablename__ = 'item' + __tablename__ = "item" item_id = Column(Integer, primary_key=True) description = Column(String(30), nullable=False) price = Column(Float, nullable=False) @@ -44,41 +53,40 @@ class Item(Base): self.price = price def __repr__(self): - return 'Item(%r, %r)' % ( - self.description, self.price - ) + return "Item(%r, %r)" % (self.description, self.price) class OrderItem(Base): - __tablename__ = 'orderitem' - order_id = Column(Integer, ForeignKey('order.order_id'), primary_key=True) - item_id = Column(Integer, ForeignKey('item.item_id'), primary_key=True) + __tablename__ = "orderitem" + order_id = Column(Integer, ForeignKey("order.order_id"), primary_key=True) + item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) price = Column(Float, nullable=False) def __init__(self, item, price=None): self.item = item self.price = price or item.price - item = relationship(Item, lazy='joined') + + item = relationship(Item, lazy="joined") -if __name__ == '__main__': - engine = create_engine('sqlite://') +if __name__ == "__main__": + engine = create_engine("sqlite://") Base.metadata.create_all(engine) session = Session(engine) # create catalog tshirt, mug, hat, crowbar = ( - Item('SA T-Shirt', 10.99), - Item('SA Mug', 6.50), - Item('SA Hat', 8.99), - Item('MySQL Crowbar', 16.99) + Item("SA T-Shirt", 10.99), + Item("SA Mug", 6.50), + Item("SA Hat", 8.99), + Item("MySQL Crowbar", 16.99), ) 
session.add_all([tshirt, mug, hat, crowbar]) session.commit() # create an order - order = Order('john smith') + order = Order("john smith") # add three OrderItem associations to the Order and save order.order_items.append(OrderItem(mug)) @@ -88,13 +96,18 @@ if __name__ == '__main__': session.commit() # query the order, print items - order = session.query(Order).filter_by(customer_name='john smith').one() - print([(order_item.item.description, order_item.price) - for order_item in order.order_items]) + order = session.query(Order).filter_by(customer_name="john smith").one() + print( + [ + (order_item.item.description, order_item.price) + for order_item in order.order_items + ] + ) # print customers who bought 'MySQL Crowbar' on sale - q = session.query(Order).join('order_items', 'item') - q = q.filter(and_(Item.description == 'MySQL Crowbar', - Item.price > OrderItem.price)) + q = session.query(Order).join("order_items", "item") + q = q.filter( + and_(Item.description == "MySQL Crowbar", Item.price > OrderItem.price) + ) print([order.customer_name for order in q]) diff --git a/examples/association/dict_of_sets_with_default.py b/examples/association/dict_of_sets_with_default.py index fb9b6aa06c..7f668c0879 100644 --- a/examples/association/dict_of_sets_with_default.py +++ b/examples/association/dict_of_sets_with_default.py @@ -37,7 +37,9 @@ class A(Base): __tablename__ = "a" associations = relationship( "B", - collection_class=lambda: GenDefaultCollection(operator.attrgetter("key")) + collection_class=lambda: GenDefaultCollection( + operator.attrgetter("key") + ), ) collections = association_proxy("associations", "values") @@ -71,19 +73,15 @@ class C(Base): self.value = value -if __name__ == '__main__': - engine = create_engine('sqlite://', echo=True) +if __name__ == "__main__": + engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) # only "A" is referenced explicitly. Using "collections", # we deal with a dict of key/sets of integers directly. - session.add_all([ - A(collections={ - "1": set([1, 2, 3]), - }) - ]) + session.add_all([A(collections={"1": set([1, 2, 3])})]) session.commit() a1 = session.query(A).first() diff --git a/examples/association/proxied_association.py b/examples/association/proxied_association.py index 3393fdd1d1..46785c6e25 100644 --- a/examples/association/proxied_association.py +++ b/examples/association/proxied_association.py @@ -7,8 +7,15 @@ to ``OrderItem`` optional. 
from datetime import datetime -from sqlalchemy import (create_engine, Column, Integer, String, DateTime, - Float, ForeignKey) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + DateTime, + Float, + ForeignKey, +) from sqlalchemy.orm import relationship, Session from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.associationproxy import association_proxy @@ -17,13 +24,14 @@ Base = declarative_base() class Order(Base): - __tablename__ = 'order' + __tablename__ = "order" order_id = Column(Integer, primary_key=True) customer_name = Column(String(30), nullable=False) order_date = Column(DateTime, nullable=False, default=datetime.now()) - order_items = relationship("OrderItem", cascade="all, delete-orphan", - backref='order') + order_items = relationship( + "OrderItem", cascade="all, delete-orphan", backref="order" + ) items = association_proxy("order_items", "item") def __init__(self, customer_name): @@ -31,7 +39,7 @@ class Order(Base): class Item(Base): - __tablename__ = 'item' + __tablename__ = "item" item_id = Column(Integer, primary_key=True) description = Column(String(30), nullable=False) price = Column(Float, nullable=False) @@ -41,39 +49,40 @@ class Item(Base): self.price = price def __repr__(self): - return 'Item(%r, %r)' % (self.description, self.price) + return "Item(%r, %r)" % (self.description, self.price) class OrderItem(Base): - __tablename__ = 'orderitem' - order_id = Column(Integer, ForeignKey('order.order_id'), primary_key=True) - item_id = Column(Integer, ForeignKey('item.item_id'), primary_key=True) + __tablename__ = "orderitem" + order_id = Column(Integer, ForeignKey("order.order_id"), primary_key=True) + item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) price = Column(Float, nullable=False) def __init__(self, item, price=None): self.item = item self.price = price or item.price - item = relationship(Item, lazy='joined') + + item = relationship(Item, lazy="joined") -if __name__ == '__main__': - engine = create_engine('sqlite://') +if __name__ == "__main__": + engine = create_engine("sqlite://") Base.metadata.create_all(engine) session = Session(engine) # create catalog tshirt, mug, hat, crowbar = ( - Item('SA T-Shirt', 10.99), - Item('SA Mug', 6.50), - Item('SA Hat', 8.99), - Item('MySQL Crowbar', 16.99) + Item("SA T-Shirt", 10.99), + Item("SA Mug", 6.50), + Item("SA Hat", 8.99), + Item("MySQL Crowbar", 16.99), ) session.add_all([tshirt, mug, hat, crowbar]) session.commit() # create an order - order = Order('john smith') + order = Order("john smith") # add items via the association proxy. # the OrderItem is created automatically. 
@@ -87,19 +96,24 @@ if __name__ == '__main__': session.commit() # query the order, print items - order = session.query(Order).filter_by(customer_name='john smith').one() + order = session.query(Order).filter_by(customer_name="john smith").one() # print items based on the OrderItem collection directly - print([(assoc.item.description, assoc.price, assoc.item.price) - for assoc in order.order_items]) + print( + [ + (assoc.item.description, assoc.price, assoc.item.price) + for assoc in order.order_items + ] + ) # print items based on the "proxied" items collection - print([(item.description, item.price) - for item in order.items]) + print([(item.description, item.price) for item in order.items]) # print customers who bought 'MySQL Crowbar' on sale - orders = session.query(Order).\ - join('order_items', 'item').\ - filter(Item.description == 'MySQL Crowbar').\ - filter(Item.price > OrderItem.price) + orders = ( + session.query(Order) + .join("order_items", "item") + .filter(Item.description == "MySQL Crowbar") + .filter(Item.price > OrderItem.price) + ) print([o.customer_name for o in orders]) diff --git a/examples/custom_attributes/__init__.py b/examples/custom_attributes/__init__.py index cbc65dfed9..8d73d27e3c 100644 --- a/examples/custom_attributes/__init__.py +++ b/examples/custom_attributes/__init__.py @@ -4,4 +4,4 @@ system. .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/custom_attributes/active_column_defaults.py b/examples/custom_attributes/active_column_defaults.py index dd823e814e..f05a53173a 100644 --- a/examples/custom_attributes/active_column_defaults.py +++ b/examples/custom_attributes/active_column_defaults.py @@ -35,6 +35,7 @@ def default_listener(col_attr, default): user integrating this feature. """ + @event.listens_for(col_attr, "init_scalar", retval=True, propagate=True) def init_scalar(target, value, dict_): @@ -52,7 +53,8 @@ def default_listener(col_attr, default): # or can procure a connection from an Engine # or Session and actually run the SQL, if desired. raise NotImplementedError( - "Can't invoke pre-default for a SQL-level column default") + "Can't invoke pre-default for a SQL-level column default" + ) # set the value in the given dict_; this won't emit any further # attribute set events or create attribute "history", but the value @@ -63,7 +65,7 @@ def default_listener(col_attr, default): return value -if __name__ == '__main__': +if __name__ == "__main__": from sqlalchemy import Column, Integer, DateTime, create_engine from sqlalchemy.orm import Session @@ -72,10 +74,10 @@ if __name__ == '__main__': Base = declarative_base() - event.listen(Base, 'mapper_configured', configure_listener, propagate=True) + event.listen(Base, "mapper_configured", configure_listener, propagate=True) class Widget(Base): - __tablename__ = 'widget' + __tablename__ = "widget" id = Column(Integer, primary_key=True) @@ -96,8 +98,8 @@ if __name__ == '__main__': # Column-level default for the "timestamp" column will no longer fire # off. current_time = w1.timestamp - assert ( - current_time > datetime.datetime.now() - datetime.timedelta(seconds=5) + assert current_time > datetime.datetime.now() - datetime.timedelta( + seconds=5 ) # persist @@ -107,7 +109,7 @@ if __name__ == '__main__': # data is persisted. The timestamp is also the one we generated above; # e.g. the default wasn't re-invoked later. 
- assert ( - sess.query(Widget.radius, Widget.timestamp).first() == - (30, current_time) + assert sess.query(Widget.radius, Widget.timestamp).first() == ( + 30, + current_time, ) diff --git a/examples/custom_attributes/custom_management.py b/examples/custom_attributes/custom_management.py index 2199e01382..8123859060 100644 --- a/examples/custom_attributes/custom_management.py +++ b/examples/custom_attributes/custom_management.py @@ -9,31 +9,44 @@ descriptors with a user-defined system. """ -from sqlalchemy import create_engine, MetaData, Table, Column, Integer, Text,\ - ForeignKey +from sqlalchemy import ( + create_engine, + MetaData, + Table, + Column, + Integer, + Text, + ForeignKey, +) from sqlalchemy.orm import mapper, relationship, Session -from sqlalchemy.orm.attributes import set_attribute, get_attribute, \ - del_attribute +from sqlalchemy.orm.attributes import ( + set_attribute, + get_attribute, + del_attribute, +) from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.ext.instrumentation import InstrumentationManager + class MyClassState(InstrumentationManager): def get_instance_dict(self, class_, instance): return instance._goofy_dict def initialize_instance_dict(self, class_, instance): - instance.__dict__['_goofy_dict'] = {} + instance.__dict__["_goofy_dict"] = {} def install_state(self, class_, instance, state): - instance.__dict__['_goofy_dict']['state'] = state + instance.__dict__["_goofy_dict"]["state"] = state def state_getter(self, class_): def find(instance): - return instance.__dict__['_goofy_dict']['state'] + return instance.__dict__["_goofy_dict"]["state"] + return find + class MyClass(object): __sa_instrumentation_manager__ = MyClassState @@ -63,17 +76,23 @@ class MyClass(object): del self._goofy_dict[key] -if __name__ == '__main__': - engine = create_engine('sqlite://') +if __name__ == "__main__": + engine = create_engine("sqlite://") meta = MetaData() - table1 = Table('table1', meta, - Column('id', Integer, primary_key=True), - Column('name', Text)) - table2 = Table('table2', meta, - Column('id', Integer, primary_key=True), - Column('name', Text), - Column('t1id', Integer, ForeignKey('table1.id'))) + table1 = Table( + "table1", + meta, + Column("id", Integer, primary_key=True), + Column("name", Text), + ) + table2 = Table( + "table2", + meta, + Column("id", Integer, primary_key=True), + Column("name", Text), + Column("t1id", Integer, ForeignKey("table1.id")), + ) meta.create_all(engine) class A(MyClass): @@ -82,16 +101,14 @@ if __name__ == '__main__': class B(MyClass): pass - mapper(A, table1, properties={ - 'bs': relationship(B) - }) + mapper(A, table1, properties={"bs": relationship(B)}) mapper(B, table2) - a1 = A(name='a1', bs=[B(name='b1'), B(name='b2')]) + a1 = A(name="a1", bs=[B(name="b1"), B(name="b2")]) - assert a1.name == 'a1' - assert a1.bs[0].name == 'b1' + assert a1.name == "a1" + assert a1.bs[0].name == "b1" sess = Session(engine) sess.add(a1) @@ -100,8 +117,8 @@ if __name__ == '__main__': a1 = sess.query(A).get(a1.id) - assert a1.name == 'a1' - assert a1.bs[0].name == 'b1' + assert a1.name == "a1" + assert a1.bs[0].name == "b1" a1.bs.remove(a1.bs[0]) diff --git a/examples/custom_attributes/listen_for_events.py b/examples/custom_attributes/listen_for_events.py index 0aeebc1d13..e3ef4cbea8 100644 --- a/examples/custom_attributes/listen_for_events.py +++ b/examples/custom_attributes/listen_for_events.py @@ -5,6 +5,7 @@ and listen for change events. 
from sqlalchemy import event + def configure_listener(class_, key, inst): def append(instance, value, initiator): instance.receive_change_event("append", key, value, None) @@ -15,19 +16,18 @@ def configure_listener(class_, key, inst): def set_(instance, value, oldvalue, initiator): instance.receive_change_event("set", key, value, oldvalue) - event.listen(inst, 'append', append) - event.listen(inst, 'remove', remove) - event.listen(inst, 'set', set_) + event.listen(inst, "append", append) + event.listen(inst, "remove", remove) + event.listen(inst, "set", set_) -if __name__ == '__main__': +if __name__ == "__main__": from sqlalchemy import Column, Integer, String, ForeignKey from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base class Base(object): - def receive_change_event(self, verb, key, value, oldvalue): s = "Value '%s' %s on attribute '%s', " % (value, verb, key) if oldvalue: @@ -37,7 +37,7 @@ if __name__ == '__main__': Base = declarative_base(cls=Base) - event.listen(Base, 'attribute_instrument', configure_listener) + event.listen(Base, "attribute_instrument", configure_listener) class MyMappedClass(Base): __tablename__ = "mytable" @@ -61,9 +61,7 @@ if __name__ == '__main__': # classes are instrumented. Demonstrate the events ! - m1 = MyMappedClass(data='m1', related=Related(data='r1')) - m1.data = 'm1mod' - m1.related.mapped.append(MyMappedClass(data='m2')) + m1 = MyMappedClass(data="m1", related=Related(data="r1")) + m1.data = "m1mod" + m1.related.mapped.append(MyMappedClass(data="m2")) del m1.data - - diff --git a/examples/dogpile_caching/advanced.py b/examples/dogpile_caching/advanced.py index dc2ed0771e..8f395bd7b3 100644 --- a/examples/dogpile_caching/advanced.py +++ b/examples/dogpile_caching/advanced.py @@ -7,6 +7,7 @@ from .environment import Session from .model import Person, cache_address_bits from .caching_query import FromCache, RelationshipCache + def load_name_range(start, end, invalidate=False): """Load Person objects on a range of names. @@ -24,10 +25,14 @@ def load_name_range(start, end, invalidate=False): SQL that emits for unloaded Person objects as well as the distribution of data within the cache. """ - q = Session.query(Person).\ - filter(Person.name.between("person %.2d" % start, "person %.2d" % end)).\ - options(cache_address_bits).\ - options(FromCache("default", "name_range")) + q = ( + Session.query(Person) + .filter( + Person.name.between("person %.2d" % start, "person %.2d" % end) + ) + .options(cache_address_bits) + .options(FromCache("default", "name_range")) + ) # have the "addresses" collection cached separately # each lazyload of Person.addresses loads from cache. @@ -37,7 +42,7 @@ def load_name_range(start, end, invalidate=False): # be cached together. This issues a bigger SQL statement and caches # a single, larger value in the cache per person rather than two # separate ones. - #q = q.options(joinedload(Person.addresses)) + # q = q.options(joinedload(Person.addresses)) # if requested, invalidate the cache on current criterion. 
if invalidate: @@ -45,6 +50,7 @@ def load_name_range(start, end, invalidate=False): return q.all() + print("two through twelve, possibly from cache:\n") print(", ".join([p.name for p in load_name_range(2, 12)])) @@ -61,7 +67,9 @@ print(", ".join([p.name for p in load_name_range(25, 40, True)])) # illustrate the address loading from either cache/already # on the Person -print("\n\nPeople plus addresses, two through twelve, addresses possibly from cache") +print( + "\n\nPeople plus addresses, two through twelve, addresses possibly from cache" +) for p in load_name_range(2, 12): print(p.format_full()) @@ -71,5 +79,7 @@ print("\n\nPeople plus addresses, two through twelve, addresses from cache") for p in load_name_range(2, 12): print(p.format_full()) -print("\n\nIf this was the first run of advanced.py, try "\ - "a second run. Only one SQL statement will be emitted.") +print( + "\n\nIf this was the first run of advanced.py, try " + "a second run. Only one SQL statement will be emitted." +) diff --git a/examples/dogpile_caching/caching_query.py b/examples/dogpile_caching/caching_query.py index 6ad2dba4d3..060c146132 100644 --- a/examples/dogpile_caching/caching_query.py +++ b/examples/dogpile_caching/caching_query.py @@ -59,7 +59,7 @@ class CachingQuery(Query): """ super_ = super(CachingQuery, self) - if hasattr(self, '_cache_region'): + if hasattr(self, "_cache_region"): return self.get_value(createfunc=lambda: list(super_.__iter__())) else: return super_.__iter__() @@ -78,13 +78,11 @@ class CachingQuery(Query): """ super_ = super(CachingQuery, self) - if context.query is not self and hasattr(self, '_cache_region'): + if context.query is not self and hasattr(self, "_cache_region"): # special logic called when the Query._execute_and_instances() # method is called directly from the baked query return self.get_value( - createfunc=lambda: list( - super_._execute_and_instances(context) - ) + createfunc=lambda: list(super_._execute_and_instances(context)) ) else: return super_._execute_and_instances(context) @@ -105,8 +103,13 @@ class CachingQuery(Query): dogpile_region, cache_key = self._get_cache_plus_key() dogpile_region.delete(cache_key) - def get_value(self, merge=True, createfunc=None, - expiration_time=None, ignore_expiration=False): + def get_value( + self, + merge=True, + createfunc=None, + expiration_time=None, + ignore_expiration=False, + ): """Return the value from the cache for this query. Raise KeyError if no value present and no @@ -119,19 +122,20 @@ class CachingQuery(Query): # but is expired, return it anyway. This doesn't make sense # with createfunc, which says, if the value is expired, generate # a new value. 
- assert not ignore_expiration or not createfunc, \ - "Can't ignore expiration and also provide createfunc" + assert ( + not ignore_expiration or not createfunc + ), "Can't ignore expiration and also provide createfunc" if ignore_expiration or not createfunc: - cached_value = dogpile_region.get(cache_key, - expiration_time=expiration_time, - ignore_expiration=ignore_expiration) + cached_value = dogpile_region.get( + cache_key, + expiration_time=expiration_time, + ignore_expiration=ignore_expiration, + ) else: cached_value = dogpile_region.get_or_create( - cache_key, - createfunc, - expiration_time=expiration_time - ) + cache_key, createfunc, expiration_time=expiration_time + ) if cached_value is NO_VALUE: raise KeyError(cache_key) if merge: @@ -144,11 +148,14 @@ class CachingQuery(Query): dogpile_region, cache_key = self._get_cache_plus_key() dogpile_region.set(cache_key, value) + def query_callable(regions, query_cls=CachingQuery): def query(*arg, **kw): return query_cls(regions, *arg, **kw) + return query + def _key_from_query(query, qualifier=None): """Given a Query, create a cache key. @@ -168,9 +175,8 @@ def _key_from_query(query, qualifier=None): # here we return the key as a long string. our "key mangler" # set up with the region will boil it down to an md5. - return " ".join( - [str(compiled)] + - [str(params[k]) for k in sorted(params)]) + return " ".join([str(compiled)] + [str(params[k]) for k in sorted(params)]) + class FromCache(MapperOption): """Specifies that a Query should load results from a cache.""" @@ -198,6 +204,7 @@ class FromCache(MapperOption): """Process a Query during normal loading operation.""" query._cache_region = self + class RelationshipCache(MapperOption): """Specifies that a Query as called within a "lazy load" should load results from a cache.""" @@ -237,7 +244,9 @@ class RelationshipCache(MapperOption): for cls in mapper.class_.__mro__: if (cls, key) in self._relationship_options: - relationship_option = self._relationship_options[(cls, key)] + relationship_option = self._relationship_options[ + (cls, key) + ] query._cache_region = relationship_option break @@ -264,4 +273,3 @@ class RelationshipCache(MapperOption): """ return None - diff --git a/examples/dogpile_caching/environment.py b/examples/dogpile_caching/environment.py index 130dfdb2b7..13bd0a3100 100644 --- a/examples/dogpile_caching/environment.py +++ b/examples/dogpile_caching/environment.py @@ -10,6 +10,7 @@ from dogpile.cache.region import make_region import os from hashlib import md5 import sys + py2k = sys.version_info < (3, 0) if py2k: @@ -23,9 +24,7 @@ regions = {} # using a callable that will associate the dictionary # of regions with the Query. Session = scoped_session( - sessionmaker( - query_cls=caching_query.query_callable(regions) - ) + sessionmaker(query_cls=caching_query.query_callable(regions)) ) # global declarative base class. @@ -42,7 +41,7 @@ if not os.path.exists(root): os.makedirs(root) dbfile = os.path.join(root, "dogpile_demo.db") -engine = create_engine('sqlite:///%s' % dbfile, echo=True) +engine = create_engine("sqlite:///%s" % dbfile, echo=True) Session.configure(bind=engine) @@ -51,10 +50,11 @@ def md5_key_mangler(key): distill them into an md5 hash. """ - return md5(key.encode('ascii')).hexdigest() + return md5(key.encode("ascii")).hexdigest() + # configure the "default" cache region. 
-regions['default'] = make_region( +regions["default"] = make_region( # the "dbm" backend needs # string-encoded keys key_mangler=md5_key_mangler @@ -63,11 +63,9 @@ regions['default'] = make_region( # serialized persistence. Normally # memcached or similar is a better choice # for caching. - 'dogpile.cache.dbm', + "dogpile.cache.dbm", expiration_time=3600, - arguments={ - "filename": os.path.join(root, "cache.dbm") - } + arguments={"filename": os.path.join(root, "cache.dbm")}, ) # optional; call invalidate() on the region @@ -83,6 +81,7 @@ installed = False def bootstrap(): global installed from . import fixture_data + if not os.path.exists(dbfile): fixture_data.install() - installed = True \ No newline at end of file + installed = True diff --git a/examples/dogpile_caching/fixture_data.py b/examples/dogpile_caching/fixture_data.py index 4651718912..e301db2a46 100644 --- a/examples/dogpile_caching/fixture_data.py +++ b/examples/dogpile_caching/fixture_data.py @@ -12,13 +12,19 @@ def install(): Base.metadata.create_all(Session().bind) data = [ - ('Chicago', 'United States', ('60601', '60602', '60603', '60604')), - ('Montreal', 'Canada', ('H2S 3K9', 'H2B 1V4', 'H7G 2T8')), - ('Edmonton', 'Canada', ('T5J 1R9', 'T5J 1Z4', 'T5H 1P6')), - ('New York', 'United States', - ('10001', '10002', '10003', '10004', '10005', '10006')), - ('San Francisco', 'United States', - ('94102', '94103', '94104', '94105', '94107', '94108')) + ("Chicago", "United States", ("60601", "60602", "60603", "60604")), + ("Montreal", "Canada", ("H2S 3K9", "H2B 1V4", "H7G 2T8")), + ("Edmonton", "Canada", ("T5J 1R9", "T5J 1Z4", "T5H 1P6")), + ( + "New York", + "United States", + ("10001", "10002", "10003", "10004", "10005", "10006"), + ), + ( + "San Francisco", + "United States", + ("94102", "94103", "94104", "94105", "94107", "94108"), + ), ] countries = {} @@ -40,8 +46,9 @@ def install(): Address( street="street %.2d" % i, postal_code=all_post_codes[ - random.randint(0, len(all_post_codes) - 1)] - ) + random.randint(0, len(all_post_codes) - 1) + ], + ), ) Session.add(person) diff --git a/examples/dogpile_caching/helloworld.py b/examples/dogpile_caching/helloworld.py index 0dbde5eafa..eb565344e9 100644 --- a/examples/dogpile_caching/helloworld.py +++ b/examples/dogpile_caching/helloworld.py @@ -21,28 +21,34 @@ people = Session.query(Person).options(FromCache("default")).all() # Specifying a different query produces a different cache key, so # these results are independently cached. print("loading people two through twelve") -people_two_through_twelve = Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 02", "person 12")).\ - all() +people_two_through_twelve = ( + Session.query(Person) + .options(FromCache("default")) + .filter(Person.name.between("person 02", "person 12")) + .all() +) # the data is cached under string structure of the SQL statement, *plus* # the bind parameters of the query. So this query, having # different literal parameters under "Person.name.between()" than the # previous one, issues new SQL... print("loading people five through fifteen") -people_five_through_fifteen = Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 05", "person 15")).\ - all() +people_five_through_fifteen = ( + Session.query(Person) + .options(FromCache("default")) + .filter(Person.name.between("person 05", "person 15")) + .all() +) # ... 
but using the same params as are already cached, no SQL print("loading people two through twelve...again!") -people_two_through_twelve = Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 02", "person 12")).\ - all() +people_two_through_twelve = ( + Session.query(Person) + .options(FromCache("default")) + .filter(Person.name.between("person 02", "person 12")) + .all() +) # invalidate the cache for the three queries we've done. Recreate @@ -51,10 +57,9 @@ people_two_through_twelve = Session.query(Person).\ # same order, then call invalidate(). print("invalidating everything") Session.query(Person).options(FromCache("default")).invalidate() -Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 02", "person 12")).invalidate() -Session.query(Person).\ - options(FromCache("default", "people_on_range")).\ - filter(Person.name.between("person 05", "person 15")).invalidate() - +Session.query(Person).options(FromCache("default")).filter( + Person.name.between("person 02", "person 12") +).invalidate() +Session.query(Person).options(FromCache("default", "people_on_range")).filter( + Person.name.between("person 05", "person 15") +).invalidate() diff --git a/examples/dogpile_caching/local_session_caching.py b/examples/dogpile_caching/local_session_caching.py index 633252fc7e..358886bf0a 100644 --- a/examples/dogpile_caching/local_session_caching.py +++ b/examples/dogpile_caching/local_session_caching.py @@ -30,7 +30,7 @@ class ScopedSessionBackend(CacheBackend): """ def __init__(self, arguments): - self.scoped_session = arguments['scoped_session'] + self.scoped_session = arguments["scoped_session"] def get(self, key): return self._cache_dictionary.get(key, NO_VALUE) @@ -52,10 +52,11 @@ class ScopedSessionBackend(CacheBackend): sess._cache_dictionary = cache_dict = {} return cache_dict + register_backend("sqlalchemy.session", __name__, "ScopedSessionBackend") -if __name__ == '__main__': +if __name__ == "__main__": from .environment import Session, regions from .caching_query import FromCache from dogpile.cache import make_region @@ -63,20 +64,19 @@ if __name__ == '__main__': # set up a region based on the ScopedSessionBackend, # pointing to the scoped_session declared in the example # environment. 
- regions['local_session'] = make_region().configure( - 'sqlalchemy.session', - arguments={ - "scoped_session": Session - } + regions["local_session"] = make_region().configure( + "sqlalchemy.session", arguments={"scoped_session": Session} ) from .model import Person # query to load Person by name, with criterion # of "person 10" - q = Session.query(Person).\ - options(FromCache("local_session")).\ - filter(Person.name == "person 10") + q = ( + Session.query(Person) + .options(FromCache("local_session")) + .filter(Person.name == "person 10") + ) # load from DB person10 = q.one() diff --git a/examples/dogpile_caching/model.py b/examples/dogpile_caching/model.py index 3eb02108c5..f6a2598206 100644 --- a/examples/dogpile_caching/model.py +++ b/examples/dogpile_caching/model.py @@ -14,7 +14,7 @@ from .environment import Base, bootstrap class Country(Base): - __tablename__ = 'country' + __tablename__ = "country" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) @@ -24,11 +24,11 @@ class Country(Base): class City(Base): - __tablename__ = 'city' + __tablename__ = "city" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) - country_id = Column(Integer, ForeignKey('country.id'), nullable=False) + country_id = Column(Integer, ForeignKey("country.id"), nullable=False) country = relationship(Country) def __init__(self, name, country): @@ -37,11 +37,11 @@ class City(Base): class PostalCode(Base): - __tablename__ = 'postal_code' + __tablename__ = "postal_code" id = Column(Integer, primary_key=True) code = Column(String(10), nullable=False) - city_id = Column(Integer, ForeignKey('city.id'), nullable=False) + city_id = Column(Integer, ForeignKey("city.id"), nullable=False) city = relationship(City) @property @@ -54,12 +54,12 @@ class PostalCode(Base): class Address(Base): - __tablename__ = 'address' + __tablename__ = "address" id = Column(Integer, primary_key=True) - person_id = Column(Integer, ForeignKey('person.id'), nullable=False) + person_id = Column(Integer, ForeignKey("person.id"), nullable=False) street = Column(String(200), nullable=False) - postal_code_id = Column(Integer, ForeignKey('postal_code.id')) + postal_code_id = Column(Integer, ForeignKey("postal_code.id")) postal_code = relationship(PostalCode) @property @@ -71,15 +71,16 @@ class Address(Base): return self.postal_code.country def __str__(self): - return ( - "%s\t%s, %s\t%s" % ( - self.street, self.city.name, - self.postal_code.code, self.country.name) + return "%s\t%s, %s\t%s" % ( + self.street, + self.city.name, + self.postal_code.code, + self.country.name, ) class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) @@ -98,14 +99,14 @@ class Person(Base): def format_full(self): return "\t".join([str(x) for x in [self] + list(self.addresses)]) + # Caching options. A set of three RelationshipCache options # which can be applied to Query(), causing the "lazy load" # of these attributes to be loaded from cache. 
-cache_address_bits = RelationshipCache(PostalCode.city, "default").\ - and_( - RelationshipCache(City.country, "default") -).and_( - RelationshipCache(Address.postal_code, "default") +cache_address_bits = ( + RelationshipCache(PostalCode.city, "default") + .and_(RelationshipCache(City.country, "default")) + .and_(RelationshipCache(Address.postal_code, "default")) ) bootstrap() diff --git a/examples/dogpile_caching/relationship_caching.py b/examples/dogpile_caching/relationship_caching.py index 920d696f8b..76c7e767f7 100644 --- a/examples/dogpile_caching/relationship_caching.py +++ b/examples/dogpile_caching/relationship_caching.py @@ -12,7 +12,8 @@ from sqlalchemy.orm import joinedload import os for p in Session.query(Person).options( - joinedload(Person.addresses), cache_address_bits): + joinedload(Person.addresses), cache_address_bits +): print(p.format_full()) @@ -25,5 +26,5 @@ print( "related data is pulled from cache.\n" "To clear the cache, delete the file %r. \n" "This will cause a re-load of cities, postal codes and countries on " - "the next run.\n" - % os.path.join(root, 'cache.dbm')) + "the next run.\n" % os.path.join(root, "cache.dbm") +) diff --git a/examples/dynamic_dict/__init__.py b/examples/dynamic_dict/__init__.py index e592ea2005..ed31df062f 100644 --- a/examples/dynamic_dict/__init__.py +++ b/examples/dynamic_dict/__init__.py @@ -5,4 +5,4 @@ full collection at once. .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/dynamic_dict/dynamic_dict.py b/examples/dynamic_dict/dynamic_dict.py index 530674f2ea..62da55c387 100644 --- a/examples/dynamic_dict/dynamic_dict.py +++ b/examples/dynamic_dict/dynamic_dict.py @@ -14,7 +14,7 @@ class ProxyDict(object): return [x[0] for x in self.collection.values(descriptor)] def __getitem__(self, key): - x = self.collection.filter_by(**{self.keyname:key}).first() + x = self.collection.filter_by(**{self.keyname: key}).first() if x: return x else: @@ -28,43 +28,48 @@ class ProxyDict(object): pass self.collection.append(value) + from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import create_engine, Column, Integer, String, ForeignKey from sqlalchemy.orm import sessionmaker, relationship -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base = declarative_base(engine) + class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) name = Column(String(50)) - _collection = relationship("Child", lazy="dynamic", - cascade="all, delete-orphan") + _collection = relationship( + "Child", lazy="dynamic", cascade="all, delete-orphan" + ) @property def child_map(self): - return ProxyDict(self, '_collection', Child, 'key') + return ProxyDict(self, "_collection", Child, "key") + class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) key = Column(String(50)) - parent_id = Column(Integer, ForeignKey('parent.id')) + parent_id = Column(Integer, ForeignKey("parent.id")) def __repr__(self): return "Child(key=%r)" % self.key + Base.metadata.create_all() sess = sessionmaker()() -p1 = Parent(name='p1') +p1 = Parent(name="p1") sess.add(p1) print("\n---------begin setting nodes, autoflush occurs\n") -p1.child_map['k1'] = Child(key='k1') -p1.child_map['k2'] = Child(key='k2') +p1.child_map["k1"] = Child(key="k1") +p1.child_map["k2"] = Child(key="k2") # this will autoflush the current map. 
# ['k1', 'k2'] @@ -73,16 +78,15 @@ print(list(p1.child_map.keys())) # k1 print("\n---------print 'k1' node\n") -print(p1.child_map['k1']) +print(p1.child_map["k1"]) print("\n---------update 'k2' node - must find existing, and replace\n") -p1.child_map['k2'] = Child(key='k2') +p1.child_map["k2"] = Child(key="k2") print("\n---------print 'k2' key - flushes first\n") # k2 -print(p1.child_map['k2']) +print(p1.child_map["k2"]) print("\n---------print all child nodes\n") # [k1, k2b] print(sess.query(Child).all()) - diff --git a/examples/elementtree/__init__.py b/examples/elementtree/__init__.py index 66e9cfbbe3..82d00ff5ad 100644 --- a/examples/elementtree/__init__.py +++ b/examples/elementtree/__init__.py @@ -22,4 +22,4 @@ E.g.:: .. autosource:: :files: pickle.py, adjacency_list.py, optimized_al.py -""" \ No newline at end of file +""" diff --git a/examples/elementtree/adjacency_list.py b/examples/elementtree/adjacency_list.py index 5e27ba9cae..1f71612128 100644 --- a/examples/elementtree/adjacency_list.py +++ b/examples/elementtree/adjacency_list.py @@ -15,42 +15,63 @@ styles of persistence are identical, as is the structure of the main Document cl """ ################################# PART I - Imports/Coniguration #################################### -from sqlalchemy import (MetaData, Table, Column, Integer, String, ForeignKey, - Unicode, and_, create_engine) +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + String, + ForeignKey, + Unicode, + and_, + create_engine, +) from sqlalchemy.orm import mapper, relationship, Session, lazyload import sys, os, io, re from xml.etree import ElementTree -e = create_engine('sqlite://') +e = create_engine("sqlite://") meta = MetaData() ################################# PART II - Table Metadata ######################################### # stores a top level record of an XML document. -documents = Table('documents', meta, - Column('document_id', Integer, primary_key=True), - Column('filename', String(30), unique=True), - Column('element_id', Integer, ForeignKey('elements.element_id')) +documents = Table( + "documents", + meta, + Column("document_id", Integer, primary_key=True), + Column("filename", String(30), unique=True), + Column("element_id", Integer, ForeignKey("elements.element_id")), ) # stores XML nodes in an adjacency list model. This corresponds to # Element and SubElement objects. -elements = Table('elements', meta, - Column('element_id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('elements.element_id')), - Column('tag', Unicode(30), nullable=False), - Column('text', Unicode), - Column('tail', Unicode) - ) +elements = Table( + "elements", + meta, + Column("element_id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("elements.element_id")), + Column("tag", Unicode(30), nullable=False), + Column("text", Unicode), + Column("tail", Unicode), +) # stores attributes. This corresponds to the dictionary of attributes # stored by an Element or SubElement. 
-attributes = Table('attributes', meta, - Column('element_id', Integer, ForeignKey('elements.element_id'), primary_key=True), - Column('name', Unicode(100), nullable=False, primary_key=True), - Column('value', Unicode(255))) +attributes = Table( + "attributes", + meta, + Column( + "element_id", + Integer, + ForeignKey("elements.element_id"), + primary_key=True, + ), + Column("name", Unicode(100), nullable=False, primary_key=True), + Column("value", Unicode(255)), +) meta.create_all(e) @@ -68,6 +89,7 @@ class Document(object): self.element.write(buf) return buf.getvalue() + #################################### PART IV - Persistence Mapping ################################# # Node class. a non-public class which will represent @@ -78,6 +100,7 @@ class Document(object): class _Node(object): pass + # Attribute class. also internal, this will represent the key/value attributes stored for # a particular Node. class _Attribute(object): @@ -85,16 +108,25 @@ class _Attribute(object): self.name = name self.value = value + # setup mappers. Document will eagerly load a list of _Node objects. -mapper(Document, documents, properties={ - '_root':relationship(_Node, lazy='joined', cascade="all") -}) +mapper( + Document, + documents, + properties={"_root": relationship(_Node, lazy="joined", cascade="all")}, +) -mapper(_Node, elements, properties={ - 'children':relationship(_Node, cascade="all"), - # eagerly load attributes - 'attributes':relationship(_Attribute, lazy='joined', cascade="all, delete-orphan"), -}) +mapper( + _Node, + elements, + properties={ + "children": relationship(_Node, cascade="all"), + # eagerly load attributes + "attributes": relationship( + _Attribute, lazy="joined", cascade="all, delete-orphan" + ), + }, +) mapper(_Attribute, attributes) @@ -106,7 +138,7 @@ class ElementTreeMarshal(object): if document is None: return self - if hasattr(document, '_element'): + if hasattr(document, "_element"): return document._element def traverse(node, parent=None): @@ -132,7 +164,9 @@ class ElementTreeMarshal(object): n.text = str(node.text) n.tail = str(node.tail) n.children = [traverse(n2) for n2 in node] - n.attributes = [_Attribute(str(k), str(v)) for k, v in node.attrib.items()] + n.attributes = [ + _Attribute(str(k), str(v)) for k, v in node.attrib.items() + ] return n document._root = traverse(element.getroot()) @@ -142,6 +176,7 @@ class ElementTreeMarshal(object): del document._element document._root = [] + # override Document's "element" attribute with the marshaller. 
Document.element = ElementTreeMarshal() @@ -153,7 +188,7 @@ line = "\n--------------------------------------------------------" session = Session(e) # get ElementTree documents -for file in ('test.xml', 'test2.xml', 'test3.xml'): +for file in ("test.xml", "test2.xml", "test3.xml"): filename = os.path.join(os.path.dirname(__file__), file) doc = ElementTree.parse(filename) session.add(Document(file, doc)) @@ -170,10 +205,16 @@ print(document) ############################################ PART VI - Searching for Paths ######################### # manually search for a document which contains "/somefile/header/field1:hi" -d = session.query(Document).join('_root', aliased=True).filter(_Node.tag=='somefile').\ - join('children', aliased=True, from_joinpoint=True).filter(_Node.tag=='header').\ - join('children', aliased=True, from_joinpoint=True).filter( - and_(_Node.tag=='field1', _Node.text=='hi')).one() +d = ( + session.query(Document) + .join("_root", aliased=True) + .filter(_Node.tag == "somefile") + .join("children", aliased=True, from_joinpoint=True) + .filter(_Node.tag == "header") + .join("children", aliased=True, from_joinpoint=True) + .filter(and_(_Node.tag == "field1", _Node.text == "hi")) + .one() +) print(d) # generalize the above approach into an extremely impoverished xpath function: @@ -181,26 +222,39 @@ def find_document(path, compareto): j = documents prev_elements = None query = session.query(Document) - attribute = '_root' - for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)): + attribute = "_root" + for i, match in enumerate( + re.finditer(r"/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?", path) + ): (token, attrname, attrvalue) = match.group(1, 2, 3) - query = query.join(attribute, aliased=True, from_joinpoint=True).filter(_Node.tag==token) - attribute = 'children' + query = query.join( + attribute, aliased=True, from_joinpoint=True + ).filter(_Node.tag == token) + attribute = "children" if attrname: if attrvalue: - query = query.join('attributes', aliased=True, from_joinpoint=True).filter( - and_(_Attribute.name==attrname, _Attribute.value==attrvalue)) + query = query.join( + "attributes", aliased=True, from_joinpoint=True + ).filter( + and_( + _Attribute.name == attrname, + _Attribute.value == attrvalue, + ) + ) else: - query = query.join('attributes', aliased=True, from_joinpoint=True).filter( - _Attribute.name==attrname) - return query.options(lazyload('_root')).filter(_Node.text==compareto).all() + query = query.join( + "attributes", aliased=True, from_joinpoint=True + ).filter(_Attribute.name == attrname) + return ( + query.options(lazyload("_root")).filter(_Node.text == compareto).all() + ) + for path, compareto in ( - ('/somefile/header/field1', 'hi'), - ('/somefile/field1', 'hi'), - ('/somefile/header/field2', 'there'), - ('/somefile/header/field2[@attr=foo]', 'there') - ): + ("/somefile/header/field1", "hi"), + ("/somefile/field1", "hi"), + ("/somefile/header/field2", "there"), + ("/somefile/header/field2[@attr=foo]", "there"), +): print("\nDocuments containing '%s=%s':" % (path, compareto), line) print([d.filename for d in find_document(path, compareto)]) - diff --git a/examples/elementtree/optimized_al.py b/examples/elementtree/optimized_al.py index e13f5b0eed..8e9c48b96b 100644 --- a/examples/elementtree/optimized_al.py +++ b/examples/elementtree/optimized_al.py @@ -8,42 +8,63 @@ """ ##################### PART I - Imports/Configuration ######################### -from sqlalchemy import (MetaData, Table, Column, Integer, String, 
ForeignKey, - Unicode, and_, create_engine) +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + String, + ForeignKey, + Unicode, + and_, + create_engine, +) from sqlalchemy.orm import mapper, relationship, Session, lazyload import sys, os, io, re from xml.etree import ElementTree -e = create_engine('sqlite://', echo=True) +e = create_engine("sqlite://", echo=True) meta = MetaData() ####################### PART II - Table Metadata ############################# # stores a top level record of an XML document. -documents = Table('documents', meta, - Column('document_id', Integer, primary_key=True), - Column('filename', String(30), unique=True), +documents = Table( + "documents", + meta, + Column("document_id", Integer, primary_key=True), + Column("filename", String(30), unique=True), ) # stores XML nodes in an adjacency list model. This corresponds to # Element and SubElement objects. -elements = Table('elements', meta, - Column('element_id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('elements.element_id')), - Column('document_id', Integer, ForeignKey('documents.document_id')), - Column('tag', Unicode(30), nullable=False), - Column('text', Unicode), - Column('tail', Unicode) - ) +elements = Table( + "elements", + meta, + Column("element_id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("elements.element_id")), + Column("document_id", Integer, ForeignKey("documents.document_id")), + Column("tag", Unicode(30), nullable=False), + Column("text", Unicode), + Column("tail", Unicode), +) # stores attributes. This corresponds to the dictionary of attributes # stored by an Element or SubElement. -attributes = Table('attributes', meta, - Column('element_id', Integer, ForeignKey('elements.element_id'), primary_key=True), - Column('name', Unicode(100), nullable=False, primary_key=True), - Column('value', Unicode(255))) +attributes = Table( + "attributes", + meta, + Column( + "element_id", + Integer, + ForeignKey("elements.element_id"), + primary_key=True, + ), + Column("name", Unicode(100), nullable=False, primary_key=True), + Column("value", Unicode(255)), +) meta.create_all(e) @@ -61,6 +82,7 @@ class Document(object): self.element.write(buf) return buf.getvalue() + ########################## PART IV - Persistence Mapping ##################### # Node class. a non-public class which will represent @@ -71,6 +93,7 @@ class Document(object): class _Node(object): pass + # Attribute class. also internal, this will represent the key/value attributes stored for # a particular Node. class _Attribute(object): @@ -78,21 +101,36 @@ class _Attribute(object): self.name = name self.value = value + # setup mappers. Document will eagerly load a list of _Node objects. # they will be ordered in primary key/insert order, so that we can reconstruct # an ElementTree structure from the list. -mapper(Document, documents, properties={ - '_nodes':relationship(_Node, lazy='joined', cascade="all, delete-orphan") -}) +mapper( + Document, + documents, + properties={ + "_nodes": relationship( + _Node, lazy="joined", cascade="all, delete-orphan" + ) + }, +) # the _Node objects change the way they load so that a list of _Nodes will organize # themselves hierarchically using the ElementTreeMarshal. this depends on the ordering of # nodes being hierarchical as well; relationship() always applies at least ROWID/primary key # ordering to rows which will suffice. 
-mapper(_Node, elements, properties={ - 'children':relationship(_Node, lazy=None), # doesnt load; used only for the save relationship - 'attributes':relationship(_Attribute, lazy='joined', cascade="all, delete-orphan"), # eagerly load attributes -}) +mapper( + _Node, + elements, + properties={ + "children": relationship( + _Node, lazy=None + ), # doesnt load; used only for the save relationship + "attributes": relationship( + _Attribute, lazy="joined", cascade="all, delete-orphan" + ), # eagerly load attributes + }, +) mapper(_Attribute, attributes) @@ -104,7 +142,7 @@ class ElementTreeMarshal(object): if document is None: return self - if hasattr(document, '_element'): + if hasattr(document, "_element"): return document._element nodes = {} @@ -134,7 +172,9 @@ class ElementTreeMarshal(object): n.tail = str(node.tail) document._nodes.append(n) n.children = [traverse(n2) for n2 in node] - n.attributes = [_Attribute(str(k), str(v)) for k, v in node.attrib.items()] + n.attributes = [ + _Attribute(str(k), str(v)) for k, v in node.attrib.items() + ] return n traverse(element.getroot()) @@ -144,6 +184,7 @@ class ElementTreeMarshal(object): del document._element document._nodes = [] + # override Document's "element" attribute with the marshaller. Document.element = ElementTreeMarshal() @@ -155,7 +196,7 @@ line = "\n--------------------------------------------------------" session = Session(e) # get ElementTree documents -for file in ('test.xml', 'test2.xml', 'test3.xml'): +for file in ("test.xml", "test2.xml", "test3.xml"): filename = os.path.join(os.path.dirname(__file__), file) doc = ElementTree.parse(filename) session.add(Document(file, doc)) @@ -173,13 +214,16 @@ print(document) # manually search for a document which contains "/somefile/header/field1:hi" print("\nManual search for /somefile/header/field1=='hi':", line) -d = session.query(Document).join('_nodes', aliased=True).\ - filter(and_(_Node.parent_id==None, _Node.tag=='somefile')).\ - join('children', aliased=True, from_joinpoint=True).\ - filter(_Node.tag=='header').\ - join('children', aliased=True, from_joinpoint=True).\ - filter(and_(_Node.tag=='field1', _Node.text=='hi')).\ - one() +d = ( + session.query(Document) + .join("_nodes", aliased=True) + .filter(and_(_Node.parent_id == None, _Node.tag == "somefile")) + .join("children", aliased=True, from_joinpoint=True) + .filter(_Node.tag == "header") + .join("children", aliased=True, from_joinpoint=True) + .filter(and_(_Node.tag == "field1", _Node.text == "hi")) + .one() +) print(d) # generalize the above approach into an extremely impoverished xpath function: @@ -188,28 +232,39 @@ def find_document(path, compareto): prev_elements = None query = session.query(Document) first = True - for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)): + for i, match in enumerate( + re.finditer(r"/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?", path) + ): (token, attrname, attrvalue) = match.group(1, 2, 3) if first: - query = query.join('_nodes', aliased=True).filter(_Node.parent_id==None) + query = query.join("_nodes", aliased=True).filter( + _Node.parent_id == None + ) first = False else: - query = query.join('children', aliased=True, from_joinpoint=True) - query = query.filter(_Node.tag==token) + query = query.join("children", aliased=True, from_joinpoint=True) + query = query.filter(_Node.tag == token) if attrname: - query = query.join('attributes', aliased=True, from_joinpoint=True) + query = query.join("attributes", aliased=True, from_joinpoint=True) if attrvalue: - 
query = query.filter(and_(_Attribute.name==attrname, _Attribute.value==attrvalue)) + query = query.filter( + and_( + _Attribute.name == attrname, + _Attribute.value == attrvalue, + ) + ) else: - query = query.filter(_Attribute.name==attrname) - return query.options(lazyload('_nodes')).filter(_Node.text==compareto).all() + query = query.filter(_Attribute.name == attrname) + return ( + query.options(lazyload("_nodes")).filter(_Node.text == compareto).all() + ) + for path, compareto in ( - ('/somefile/header/field1', 'hi'), - ('/somefile/field1', 'hi'), - ('/somefile/header/field2', 'there'), - ('/somefile/header/field2[@attr=foo]', 'there') - ): + ("/somefile/header/field1", "hi"), + ("/somefile/field1", "hi"), + ("/somefile/header/field2", "there"), + ("/somefile/header/field2[@attr=foo]", "there"), +): print("\nDocuments containing '%s=%s':" % (path, compareto), line) print([d.filename for d in find_document(path, compareto)]) - diff --git a/examples/elementtree/pickle.py b/examples/elementtree/pickle.py index d40af275bd..a86fe30e56 100644 --- a/examples/elementtree/pickle.py +++ b/examples/elementtree/pickle.py @@ -6,15 +6,22 @@ structure in distinct rows using two additional mapped entities. Note that the styles of persistence are identical, as is the structure of the main Document class. """ -from sqlalchemy import (create_engine, MetaData, Table, Column, Integer, String, - PickleType) +from sqlalchemy import ( + create_engine, + MetaData, + Table, + Column, + Integer, + String, + PickleType, +) from sqlalchemy.orm import mapper, Session import sys, os from xml.etree import ElementTree -e = create_engine('sqlite://') +e = create_engine("sqlite://") meta = MetaData() # setup a comparator for the PickleType since it's a mutable @@ -22,12 +29,15 @@ meta = MetaData() def are_elements_equal(x, y): return x == y + # stores a top level record of an XML document. # the "element" column will store the ElementTree document as a BLOB. -documents = Table('documents', meta, - Column('document_id', Integer, primary_key=True), - Column('filename', String(30), unique=True), - Column('element', PickleType(comparator=are_elements_equal)) +documents = Table( + "documents", + meta, + Column("document_id", Integer, primary_key=True), + Column("filename", String(30), unique=True), + Column("element", PickleType(comparator=are_elements_equal)), ) meta.create_all(e) @@ -39,6 +49,7 @@ class Document(object): self.filename = name self.element = element + # setup mapper. mapper(Document, documents) @@ -58,4 +69,3 @@ document = session.query(Document).filter_by(filename="test.xml").first() # print document.element.write(sys.stdout) - diff --git a/examples/generic_associations/__init__.py b/examples/generic_associations/__init__.py index b6593b4f46..9d103b73e7 100644 --- a/examples/generic_associations/__init__.py +++ b/examples/generic_associations/__init__.py @@ -15,4 +15,4 @@ are modernized versions of recipes presented in the 2007 blog post .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/generic_associations/discriminator_on_association.py b/examples/generic_associations/discriminator_on_association.py index c3501eefbe..38f2370f3d 100644 --- a/examples/generic_associations/discriminator_on_association.py +++ b/examples/generic_associations/discriminator_on_association.py @@ -16,27 +16,31 @@ objects, but is also slightly more complex. 
""" from sqlalchemy.ext.declarative import as_declarative, declared_attr -from sqlalchemy import create_engine, Integer, Column, \ - String, ForeignKey +from sqlalchemy import create_engine, Integer, Column, String, ForeignKey from sqlalchemy.orm import Session, relationship, backref from sqlalchemy.ext.associationproxy import association_proxy + @as_declarative() class Base(object): """Base class which provides automated table name and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class AddressAssociation(Base): """Associates a collection of Address objects with a particular parent. """ + __tablename__ = "address_association" discriminator = Column(String) @@ -44,6 +48,7 @@ class AddressAssociation(Base): __mapper_args__ = {"polymorphic_on": discriminator} + class Address(Base): """The Address class. @@ -51,6 +56,7 @@ class Address(Base): single table. """ + association_id = Column(Integer, ForeignKey("address_association.id")) street = Column(String) city = Column(String) @@ -60,15 +66,20 @@ class Address(Base): parent = association_proxy("association", "parent") def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a relationship to the address_association table for each parent. """ + @declared_attr def address_association_id(cls): return Column(Integer, ForeignKey("address_association.id")) @@ -79,63 +90,62 @@ class HasAddresses(object): discriminator = name.lower() assoc_cls = type( - "%sAddressAssociation" % name, - (AddressAssociation, ), - dict( - __tablename__=None, - __mapper_args__={ - "polymorphic_identity": discriminator - } - ) - ) + "%sAddressAssociation" % name, + (AddressAssociation,), + dict( + __tablename__=None, + __mapper_args__={"polymorphic_identity": discriminator}, + ), + ) cls.addresses = association_proxy( - "address_association", "addresses", - creator=lambda addresses: assoc_cls(addresses=addresses) - ) - return relationship(assoc_cls, - backref=backref("parent", uselist=False)) + "address_association", + "addresses", + creator=lambda addresses: assoc_cls(addresses=addresses), + ) + return relationship( + assoc_cls, backref=backref("parent", uselist=False) + ) class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Address(street="2569 west elm", city="Detroit", zip="56785") + ], + ), + ] +) session.commit() for customer in 
session.query(Customer): for address in customer.addresses: print(address) - print(address.parent) \ No newline at end of file + print(address.parent) diff --git a/examples/generic_associations/generic_fk.py b/examples/generic_associations/generic_fk.py index 31d2c138d2..ded8f749d1 100644 --- a/examples/generic_associations/generic_fk.py +++ b/examples/generic_associations/generic_fk.py @@ -20,8 +20,7 @@ or "table_per_association" instead of this approach. """ from sqlalchemy.ext.declarative import as_declarative, declared_attr -from sqlalchemy import create_engine, Integer, Column, \ - String, and_ +from sqlalchemy import create_engine, Integer, Column, String, and_ from sqlalchemy.orm import Session, relationship, foreign, remote, backref from sqlalchemy import event @@ -32,11 +31,14 @@ class Base(object): and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class Address(Base): """The Address class. @@ -44,6 +46,7 @@ class Address(Base): single table. """ + street = Column(String) city = Column(String) zip = Column(String) @@ -66,9 +69,13 @@ class Address(Base): return getattr(self, "parent_%s" % self.discriminator) def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a relationship to @@ -76,63 +83,66 @@ class HasAddresses(object): """ + @event.listens_for(HasAddresses, "mapper_configured", propagate=True) def setup_listener(mapper, class_): name = class_.__name__ discriminator = name.lower() - class_.addresses = relationship(Address, - primaryjoin=and_( - class_.id == foreign(remote(Address.parent_id)), - Address.discriminator == discriminator - ), - backref=backref( - "parent_%s" % discriminator, - primaryjoin=remote(class_.id) == foreign(Address.parent_id) - ) - ) + class_.addresses = relationship( + Address, + primaryjoin=and_( + class_.id == foreign(remote(Address.parent_id)), + Address.discriminator == discriminator, + ), + backref=backref( + "parent_%s" % discriminator, + primaryjoin=remote(class_.id) == foreign(Address.parent_id), + ), + ) + @event.listens_for(class_.addresses, "append") def append_address(target, value, initiator): value.discriminator = discriminator + class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Address(street="2569 west elm", city="Detroit", zip="56785") + ], + ), + ] +) session.commit() for customer in session.query(Customer): for 
address in customer.addresses: print(address) - print(address.parent) \ No newline at end of file + print(address.parent) diff --git a/examples/generic_associations/table_per_association.py b/examples/generic_associations/table_per_association.py index d54d2f1fae..7de2469345 100644 --- a/examples/generic_associations/table_per_association.py +++ b/examples/generic_associations/table_per_association.py @@ -12,21 +12,31 @@ has no dependency on the system. """ from sqlalchemy.ext.declarative import as_declarative, declared_attr -from sqlalchemy import create_engine, Integer, Column, \ - String, ForeignKey, Table +from sqlalchemy import ( + create_engine, + Integer, + Column, + String, + ForeignKey, + Table, +) from sqlalchemy.orm import Session, relationship + @as_declarative() class Base(object): """Base class which provides automated table name and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class Address(Base): """The Address class. @@ -34,72 +44,79 @@ class Address(Base): single table. """ + street = Column(String) city = Column(String) zip = Column(String) def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a new address_association table for each parent. """ + @declared_attr def addresses(cls): address_association = Table( "%s_addresses" % cls.__tablename__, cls.metadata, - Column("address_id", ForeignKey("address.id"), - primary_key=True), - Column("%s_id" % cls.__tablename__, - ForeignKey("%s.id" % cls.__tablename__), - primary_key=True), + Column("address_id", ForeignKey("address.id"), primary_key=True), + Column( + "%s_id" % cls.__tablename__, + ForeignKey("%s.id" % cls.__tablename__), + primary_key=True, + ), ) return relationship(Address, secondary=address_association) + class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Address(street="2569 west elm", city="Detroit", zip="56785") + ], + ), + ] +) session.commit() for customer in session.query(Customer): for address in customer.addresses: print(address) - # no parent here \ No newline at end of file + # no parent here diff --git a/examples/generic_associations/table_per_related.py b/examples/generic_associations/table_per_related.py index 51c9f1b262..9c5e0e1798 100644 --- a/examples/generic_associations/table_per_related.py +++ b/examples/generic_associations/table_per_related.py @@ -20,17 +20,21 @@ from 
sqlalchemy.ext.declarative import as_declarative, declared_attr from sqlalchemy import create_engine, Integer, Column, String, ForeignKey from sqlalchemy.orm import Session, relationship + @as_declarative() class Base(object): """Base class which provides automated table name and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class Address(object): """Define columns that will be present in each 'Address' table. @@ -40,74 +44,82 @@ class Address(object): should be set up using @declared_attr. """ + street = Column(String) city = Column(String) zip = Column(String) def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a new Address class for each parent. """ + @declared_attr def addresses(cls): cls.Address = type( "%sAddress" % cls.__name__, - (Address, Base,), + (Address, Base), dict( - __tablename__="%s_address" % - cls.__tablename__, - parent_id=Column(Integer, - ForeignKey("%s.id" % cls.__tablename__)), - parent=relationship(cls) - ) + __tablename__="%s_address" % cls.__tablename__, + parent_id=Column( + Integer, ForeignKey("%s.id" % cls.__tablename__) + ), + parent=relationship(cls), + ), ) return relationship(cls.Address) + class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Customer.Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Customer.Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Supplier.Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Customer.Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Customer.Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Supplier.Address( + street="2569 west elm", city="Detroit", zip="56785" + ) + ], + ), + ] +) session.commit() for customer in session.query(Customer): for address in customer.addresses: print(address) - print(address.parent) \ No newline at end of file + print(address.parent) diff --git a/examples/graphs/__init__.py b/examples/graphs/__init__.py index 57d41453b0..0f8fe58a7b 100644 --- a/examples/graphs/__init__.py +++ b/examples/graphs/__init__.py @@ -10,4 +10,4 @@ and querying for lower- and upper- neighbors are illustrated:: .. 
autosource:: -""" \ No newline at end of file +""" diff --git a/examples/graphs/directed_graph.py b/examples/graphs/directed_graph.py index 7bcfc56835..af85a4295f 100644 --- a/examples/graphs/directed_graph.py +++ b/examples/graphs/directed_graph.py @@ -1,7 +1,6 @@ """a directed graph example.""" -from sqlalchemy import Column, Integer, ForeignKey, \ - create_engine +from sqlalchemy import Column, Integer, ForeignKey, create_engine from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.ext.declarative import declarative_base @@ -9,7 +8,7 @@ Base = declarative_base() class Node(Base): - __tablename__ = 'node' + __tablename__ = "node" node_id = Column(Integer, primary_key=True) @@ -21,33 +20,26 @@ class Node(Base): class Edge(Base): - __tablename__ = 'edge' + __tablename__ = "edge" - lower_id = Column( - Integer, - ForeignKey('node.node_id'), - primary_key=True) + lower_id = Column(Integer, ForeignKey("node.node_id"), primary_key=True) - higher_id = Column( - Integer, - ForeignKey('node.node_id'), - primary_key=True) + higher_id = Column(Integer, ForeignKey("node.node_id"), primary_key=True) lower_node = relationship( - Node, - primaryjoin=lower_id == Node.node_id, - backref='lower_edges') + Node, primaryjoin=lower_id == Node.node_id, backref="lower_edges" + ) higher_node = relationship( - Node, - primaryjoin=higher_id == Node.node_id, - backref='higher_edges') + Node, primaryjoin=higher_id == Node.node_id, backref="higher_edges" + ) def __init__(self, n1, n2): self.lower_node = n1 self.higher_node = n2 -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = sessionmaker(engine)() @@ -80,4 +72,3 @@ assert [x for x in n3.higher_neighbors()] == [n6] assert [x for x in n3.lower_neighbors()] == [n1] assert [x for x in n2.lower_neighbors()] == [n1] assert [x for x in n2.higher_neighbors()] == [n1, n5, n7] - diff --git a/examples/inheritance/concrete.py b/examples/inheritance/concrete.py index 258f410250..2245aa4e0e 100644 --- a/examples/inheritance/concrete.py +++ b/examples/inheritance/concrete.py @@ -1,7 +1,14 @@ """Concrete-table (table-per-class) inheritance example.""" -from sqlalchemy import Column, Integer, String, \ - ForeignKey, create_engine, inspect, or_ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + create_engine, + inspect, + or_, +) from sqlalchemy.orm import relationship, Session, with_polymorphic from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import ConcreteBase @@ -11,107 +18,105 @@ Base = declarative_base() class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship( - "Person", - back_populates='company', - cascade='all, delete-orphan') + "Person", back_populates="company", cascade="all, delete-orphan" + ) def __repr__(self): return "Company %s" % self.name class Person(ConcreteBase, Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) name = Column(String(50)) company = relationship("Company", back_populates="employees") - __mapper_args__ = { - 'polymorphic_identity': 'person', - } + __mapper_args__ = {"polymorphic_identity": "person"} def __repr__(self): return "Ordinary person %s" % self.name class Engineer(Person): - __tablename__ = 'engineer' + __tablename__ = 
"engineer" id = Column(Integer, primary_key=True) name = Column(String(50)) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) status = Column(String(30)) engineer_name = Column(String(30)) primary_language = Column(String(30)) company = relationship("Company", back_populates="employees") - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - 'concrete': True - } + __mapper_args__ = {"polymorphic_identity": "engineer", "concrete": True} def __repr__(self): return ( "Engineer %s, status %s, engineer_name %s, " - "primary_language %s" % - ( - self.name, self.status, - self.engineer_name, self.primary_language) + "primary_language %s" + % ( + self.name, + self.status, + self.engineer_name, + self.primary_language, + ) ) class Manager(Person): - __tablename__ = 'manager' + __tablename__ = "manager" id = Column(Integer, primary_key=True) name = Column(String(50)) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) status = Column(String(30)) manager_name = Column(String(30)) company = relationship("Company", back_populates="employees") - __mapper_args__ = { - 'polymorphic_identity': 'manager', - 'concrete': True - } + __mapper_args__ = {"polymorphic_identity": "manager", "concrete": True} def __repr__(self): return "Manager %s, status %s, manager_name %s" % ( - self.name, self.status, self.manager_name) + self.name, + self.status, + self.manager_name, + ) -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -c = Company(name='company1', employees=[ - Manager( - name='pointy haired boss', - status='AAB', - manager_name='manager1'), - Engineer( - name='dilbert', - status='BBA', - engineer_name='engineer1', - primary_language='java'), - Person(name='joesmith'), - Engineer( - name='wally', - status='CGG', - engineer_name='engineer2', - primary_language='python'), - Manager( - name='jsmith', - status='ABA', - manager_name='manager2') -]) +c = Company( + name="company1", + employees=[ + Manager( + name="pointy haired boss", status="AAB", manager_name="manager1" + ), + Engineer( + name="dilbert", + status="BBA", + engineer_name="engineer1", + primary_language="java", + ), + Person(name="joesmith"), + Engineer( + name="wally", + status="CGG", + engineer_name="engineer2", + primary_language="python", + ), + Manager(name="jsmith", status="ABA", manager_name="manager2"), + ], +) session.add(c) session.commit() @@ -120,14 +125,15 @@ c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( - ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"] +) print("\n") -dilbert = session.query(Person).filter_by(name='dilbert').one() -dilbert2 = session.query(Engineer).filter_by(name='dilbert').one() +dilbert = session.query(Person).filter_by(name="dilbert").one() +dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 -dilbert.engineer_name = 'hes dilbert!' +dilbert.engineer_name = "hes dilbert!" session.commit() @@ -138,24 +144,28 @@ for e in c.employees: # query using with_polymorphic. eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(eng_manager). 
- filter( + session.query(eng_manager) + .filter( or_( - eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2' + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", ) - ).all() + ) + .all() ) # illustrate join from Company eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(Company). - join( - Company.employees.of_type(eng_manager) - ).filter( - or_(eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2') - ).all()) + session.query(Company) + .join(Company.employees.of_type(eng_manager)) + .filter( + or_( + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", + ) + ) + .all() +) session.commit() diff --git a/examples/inheritance/joined.py b/examples/inheritance/joined.py index f9322158ea..a3a61d7630 100644 --- a/examples/inheritance/joined.py +++ b/examples/inheritance/joined.py @@ -1,7 +1,14 @@ """Joined-table (table-per-subclass) inheritance example.""" -from sqlalchemy import Column, Integer, String, \ - ForeignKey, create_engine, inspect, or_ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + create_engine, + inspect, + or_, +) from sqlalchemy.orm import relationship, Session, with_polymorphic from sqlalchemy.ext.declarative import declarative_base @@ -9,31 +16,30 @@ Base = declarative_base() class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship( - "Person", - back_populates='company', - cascade='all, delete-orphan') + "Person", back_populates="company", cascade="all, delete-orphan" + ) def __repr__(self): return "Company %s" % self.name class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) name = Column(String(50)) type = Column(String(50)) company = relationship("Company", back_populates="employees") __mapper_args__ = { - 'polymorphic_identity': 'person', - 'polymorphic_on': type + "polymorphic_identity": "person", + "polymorphic_on": type, } def __repr__(self): @@ -41,67 +47,70 @@ class Person(Base): class Engineer(Person): - __tablename__ = 'engineer' - id = Column(ForeignKey('person.id'), primary_key=True) + __tablename__ = "engineer" + id = Column(ForeignKey("person.id"), primary_key=True) status = Column(String(30)) engineer_name = Column(String(30)) primary_language = Column(String(30)) - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - } + __mapper_args__ = {"polymorphic_identity": "engineer"} def __repr__(self): return ( "Engineer %s, status %s, engineer_name %s, " - "primary_language %s" % - ( - self.name, self.status, - self.engineer_name, self.primary_language) + "primary_language %s" + % ( + self.name, + self.status, + self.engineer_name, + self.primary_language, + ) ) class Manager(Person): - __tablename__ = 'manager' - id = Column(ForeignKey('person.id'), primary_key=True) + __tablename__ = "manager" + id = Column(ForeignKey("person.id"), primary_key=True) status = Column(String(30)) manager_name = Column(String(30)) - __mapper_args__ = { - 'polymorphic_identity': 'manager', - } + __mapper_args__ = {"polymorphic_identity": "manager"} def __repr__(self): return "Manager %s, status %s, manager_name %s" % ( - self.name, self.status, self.manager_name) + self.name, 
+ self.status, + self.manager_name, + ) -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -c = Company(name='company1', employees=[ - Manager( - name='pointy haired boss', - status='AAB', - manager_name='manager1'), - Engineer( - name='dilbert', - status='BBA', - engineer_name='engineer1', - primary_language='java'), - Person(name='joesmith'), - Engineer( - name='wally', - status='CGG', - engineer_name='engineer2', - primary_language='python'), - Manager( - name='jsmith', - status='ABA', - manager_name='manager2') -]) +c = Company( + name="company1", + employees=[ + Manager( + name="pointy haired boss", status="AAB", manager_name="manager1" + ), + Engineer( + name="dilbert", + status="BBA", + engineer_name="engineer1", + primary_language="java", + ), + Person(name="joesmith"), + Engineer( + name="wally", + status="CGG", + engineer_name="engineer2", + primary_language="python", + ), + Manager(name="jsmith", status="ABA", manager_name="manager2"), + ], +) session.add(c) session.commit() @@ -110,14 +119,15 @@ c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( - ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"] +) print("\n") -dilbert = session.query(Person).filter_by(name='dilbert').one() -dilbert2 = session.query(Engineer).filter_by(name='dilbert').one() +dilbert = session.query(Person).filter_by(name="dilbert").one() +dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 -dilbert.engineer_name = 'hes dilbert!' +dilbert.engineer_name = "hes dilbert!" session.commit() @@ -128,13 +138,14 @@ for e in c.employees: # query using with_polymorphic. eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(eng_manager). - filter( + session.query(eng_manager) + .filter( or_( - eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2' + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", ) - ).all() + ) + .all() ) # illustrate join from Company. @@ -144,12 +155,15 @@ print( # loading. eng_manager = with_polymorphic(Person, [Engineer, Manager], flat=True) print( - session.query(Company). 
- join( - Company.employees.of_type(eng_manager) - ).filter( - or_(eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2') - ).all()) + session.query(Company) + .join(Company.employees.of_type(eng_manager)) + .filter( + or_( + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", + ) + ) + .all() +) session.commit() diff --git a/examples/inheritance/single.py b/examples/inheritance/single.py index 56397f540e..46d690c943 100644 --- a/examples/inheritance/single.py +++ b/examples/inheritance/single.py @@ -1,7 +1,14 @@ """Single-table (table-per-hierarchy) inheritance example.""" -from sqlalchemy import Column, Integer, String, \ - ForeignKey, create_engine, inspect, or_ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + create_engine, + inspect, + or_, +) from sqlalchemy.orm import relationship, Session, with_polymorphic from sqlalchemy.ext.declarative import declarative_base, declared_attr @@ -9,31 +16,30 @@ Base = declarative_base() class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship( - "Person", - back_populates='company', - cascade='all, delete-orphan') + "Person", back_populates="company", cascade="all, delete-orphan" + ) def __repr__(self): return "Company %s" % self.name class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) name = Column(String(50)) type = Column(String(50)) company = relationship("Company", back_populates="employees") __mapper_args__ = { - 'polymorphic_identity': 'person', - 'polymorphic_on': type + "polymorphic_identity": "person", + "polymorphic_on": type, } def __repr__(self): @@ -50,19 +56,20 @@ class Engineer(Person): # declarative/inheritance.html#resolving-column-conflicts @declared_attr def status(cls): - return Person.__table__.c.get('status', Column(String(30))) + return Person.__table__.c.get("status", Column(String(30))) - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - } + __mapper_args__ = {"polymorphic_identity": "engineer"} def __repr__(self): return ( "Engineer %s, status %s, engineer_name %s, " - "primary_language %s" % - ( - self.name, self.status, - self.engineer_name, self.primary_language) + "primary_language %s" + % ( + self.name, + self.status, + self.engineer_name, + self.primary_language, + ) ) @@ -71,43 +78,45 @@ class Manager(Person): @declared_attr def status(cls): - return Person.__table__.c.get('status', Column(String(30))) + return Person.__table__.c.get("status", Column(String(30))) - __mapper_args__ = { - 'polymorphic_identity': 'manager', - } + __mapper_args__ = {"polymorphic_identity": "manager"} def __repr__(self): return "Manager %s, status %s, manager_name %s" % ( - self.name, self.status, self.manager_name) + self.name, + self.status, + self.manager_name, + ) -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -c = Company(name='company1', employees=[ - Manager( - name='pointy haired boss', - status='AAB', - manager_name='manager1'), - Engineer( - name='dilbert', - status='BBA', - engineer_name='engineer1', - primary_language='java'), - Person(name='joesmith'), - Engineer( - name='wally', - status='CGG', - engineer_name='engineer2', - 
primary_language='python'), - Manager( - name='jsmith', - status='ABA', - manager_name='manager2') -]) +c = Company( + name="company1", + employees=[ + Manager( + name="pointy haired boss", status="AAB", manager_name="manager1" + ), + Engineer( + name="dilbert", + status="BBA", + engineer_name="engineer1", + primary_language="java", + ), + Person(name="joesmith"), + Engineer( + name="wally", + status="CGG", + engineer_name="engineer2", + primary_language="python", + ), + Manager(name="jsmith", status="ABA", manager_name="manager2"), + ], +) session.add(c) session.commit() @@ -116,14 +125,15 @@ c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( - ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"] +) print("\n") -dilbert = session.query(Person).filter_by(name='dilbert').one() -dilbert2 = session.query(Engineer).filter_by(name='dilbert').one() +dilbert = session.query(Person).filter_by(name="dilbert").one() +dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 -dilbert.engineer_name = 'hes dilbert!' +dilbert.engineer_name = "hes dilbert!" session.commit() @@ -134,24 +144,28 @@ for e in c.employees: # query using with_polymorphic. eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(eng_manager). - filter( + session.query(eng_manager) + .filter( or_( - eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2' + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", ) - ).all() + ) + .all() ) # illustrate join from Company, eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(Company). - join( - Company.employees.of_type(eng_manager) - ).filter( - or_(eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2') - ).all()) + session.query(Company) + .join(Company.employees.of_type(eng_manager)) + .filter( + or_( + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", + ) + ) + .all() +) session.commit() diff --git a/examples/join_conditions/__init__.py b/examples/join_conditions/__init__.py index 3a561d0849..d67eb68e43 100644 --- a/examples/join_conditions/__init__.py +++ b/examples/join_conditions/__init__.py @@ -4,4 +4,4 @@ of join conditions. .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/join_conditions/cast.py b/examples/join_conditions/cast.py index 246bc1d57d..7ea7756899 100644 --- a/examples/join_conditions/cast.py +++ b/examples/join_conditions/cast.py @@ -35,6 +35,7 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() + class StringAsInt(TypeDecorator): """Coerce string->integer type. @@ -44,52 +45,55 @@ class StringAsInt(TypeDecorator): on the child during a flush. """ + impl = Integer + def process_bind_param(self, value, dialect): if value is not None: value = int(value) return value + class A(Base): """Parent. The referenced column is a string type.""" - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) a_id = Column(String) + class B(Base): """Child. 
The column we reference 'A' with is an integer.""" - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) a_id = Column(StringAsInt) - a = relationship("A", - # specify primaryjoin. The string form is optional - # here, but note that Declarative makes available all - # of the built-in functions we might need, including - # cast() and foreign(). - primaryjoin="cast(A.a_id, Integer) == foreign(B.a_id)", - backref="bs") + a = relationship( + "A", + # specify primaryjoin. The string form is optional + # here, but note that Declarative makes available all + # of the built-in functions we might need, including + # cast() and foreign(). + primaryjoin="cast(A.a_id, Integer) == foreign(B.a_id)", + backref="bs", + ) + # we demonstrate with SQLite, but the important part # is the CAST rendered in the SQL output. -e = create_engine('sqlite://', echo=True) +e = create_engine("sqlite://", echo=True) Base.metadata.create_all(e) s = Session(e) -s.add_all([ - A(a_id="1"), - A(a_id="2", bs=[B(), B()]), - A(a_id="3", bs=[B()]), -]) +s.add_all([A(a_id="1"), A(a_id="2", bs=[B(), B()]), A(a_id="3", bs=[B()])]) s.commit() b1 = s.query(B).filter_by(a_id="2").first() print(b1.a) a1 = s.query(A).filter_by(a_id="2").first() -print(a1.bs) \ No newline at end of file +print(a1.bs) diff --git a/examples/join_conditions/threeway.py b/examples/join_conditions/threeway.py index 13df0f349a..2570026374 100644 --- a/examples/join_conditions/threeway.py +++ b/examples/join_conditions/threeway.py @@ -39,46 +39,56 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() + class First(Base): - __tablename__ = 'first' + __tablename__ = "first" first_id = Column(Integer, primary_key=True) partition_key = Column(String) def __repr__(self): - return ("First(%s, %s)" % (self.first_id, self.partition_key)) + return "First(%s, %s)" % (self.first_id, self.partition_key) + class Second(Base): - __tablename__ = 'second' + __tablename__ = "second" first_id = Column(Integer, primary_key=True) other_id = Column(Integer, primary_key=True) + class Partitioned(Base): - __tablename__ = 'partitioned' + __tablename__ = "partitioned" other_id = Column(Integer, primary_key=True) partition_key = Column(String, primary_key=True) def __repr__(self): - return ("Partitioned(%s, %s)" % (self.other_id, self.partition_key)) + return "Partitioned(%s, %s)" % (self.other_id, self.partition_key) j = join(Partitioned, Second, Partitioned.other_id == Second.other_id) -partitioned_second = mapper(Partitioned, j, non_primary=True, properties={ +partitioned_second = mapper( + Partitioned, + j, + non_primary=True, + properties={ # note we need to disambiguate columns here - the join() # will provide them as j.c._ for access, # but they retain their real names in the mapping - "other_id": [j.c.partitioned_other_id, j.c.second_other_id], - }) + "other_id": [j.c.partitioned_other_id, j.c.second_other_id] + }, +) First.partitioned = relationship( - partitioned_second, - primaryjoin=and_( - First.partition_key == partitioned_second.c.partition_key, - First.first_id == foreign(partitioned_second.c.first_id) - ), innerjoin=True) + partitioned_second, + primaryjoin=and_( + First.partition_key == partitioned_second.c.partition_key, + First.first_id == foreign(partitioned_second.c.first_id), + ), + innerjoin=True, +) # when using any database other than SQLite, we will get a nested # join, e.g. "first JOIN (partitioned JOIN second ON ..) ON ..". 
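# [editor's note: illustration only, not part of the commit] the join
# rendered by query(First, Partitioned).join(First.partitioned) comes
# out roughly as follows (column lists elided):
#
#     SELECT ...
#     FROM first JOIN
#       (partitioned JOIN second
#          ON partitioned.other_id = second.other_id)
#     ON first.partition_key = partitioned.partition_key
#        AND first.first_id = second.first_id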
@@ -87,17 +97,19 @@ e = create_engine("sqlite://", echo=True) Base.metadata.create_all(e) s = Session(e) -s.add_all([ - First(first_id=1, partition_key='p1'), - First(first_id=2, partition_key='p1'), - First(first_id=3, partition_key='p2'), - Second(first_id=1, other_id=1), - Second(first_id=2, other_id=1), - Second(first_id=3, other_id=2), - Partitioned(partition_key='p1', other_id=1), - Partitioned(partition_key='p1', other_id=2), - Partitioned(partition_key='p2', other_id=2), -]) +s.add_all( + [ + First(first_id=1, partition_key="p1"), + First(first_id=2, partition_key="p1"), + First(first_id=3, partition_key="p2"), + Second(first_id=1, other_id=1), + Second(first_id=2, other_id=1), + Second(first_id=3, other_id=2), + Partitioned(partition_key="p1", other_id=1), + Partitioned(partition_key="p1", other_id=2), + Partitioned(partition_key="p2", other_id=2), + ] +) s.commit() for row in s.query(First, Partitioned).join(First.partitioned): diff --git a/examples/large_collection/large_collection.py b/examples/large_collection/large_collection.py index 82d2e554b1..eb014c6cb0 100644 --- a/examples/large_collection/large_collection.py +++ b/examples/large_collection/large_collection.py @@ -1,54 +1,76 @@ - -from sqlalchemy import (MetaData, Table, Column, Integer, String, ForeignKey, - create_engine) -from sqlalchemy.orm import (mapper, relationship, sessionmaker) +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + String, + ForeignKey, + create_engine, +) +from sqlalchemy.orm import mapper, relationship, sessionmaker meta = MetaData() -org_table = Table('organizations', meta, - Column('org_id', Integer, primary_key=True), - Column('org_name', String(50), nullable=False, key='name'), - mysql_engine='InnoDB') - -member_table = Table('members', meta, - Column('member_id', Integer, primary_key=True), - Column('member_name', String(50), nullable=False, key='name'), - Column('org_id', Integer, - ForeignKey('organizations.org_id', ondelete="CASCADE")), - mysql_engine='InnoDB') +org_table = Table( + "organizations", + meta, + Column("org_id", Integer, primary_key=True), + Column("org_name", String(50), nullable=False, key="name"), + mysql_engine="InnoDB", +) + +member_table = Table( + "members", + meta, + Column("member_id", Integer, primary_key=True), + Column("member_name", String(50), nullable=False, key="name"), + Column( + "org_id", + Integer, + ForeignKey("organizations.org_id", ondelete="CASCADE"), + ), + mysql_engine="InnoDB", +) class Organization(object): def __init__(self, name): self.name = name + class Member(object): def __init__(self, name): self.name = name -mapper(Organization, org_table, properties = { - 'members' : relationship(Member, - # Organization.members will be a Query object - no loading - # of the entire collection occurs unless requested - lazy="dynamic", - - # Member objects "belong" to their parent, are deleted when - # removed from the collection - cascade="all, delete-orphan", - - # "delete, delete-orphan" cascade does not load in objects on delete, - # allows ON DELETE CASCADE to handle it. 
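# [editor's aside: illustration only, not part of the commit] the
# combination of lazy="dynamic" and passive_deletes=True configured here
# means the collection behaves like a Query, and child rows are removed
# by the database rather than loaded and deleted by the ORM, e.g.:
#
#     org = sess.query(Organization).first()
#     org.members.count()  # emits SELECT count(*) ... WHERE org_id = ?
#     sess.delete(org)
#     sess.commit()        # DELETEs only the parent row; ON DELETE
#                          # CASCADE removes the member rows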
- # this only works with a database that supports ON DELETE CASCADE - - # *not* sqlite or MySQL with MyISAM - passive_deletes=True, - ) -}) + +mapper( + Organization, + org_table, + properties={ + "members": relationship( + Member, + # Organization.members will be a Query object - no loading + # of the entire collection occurs unless requested + lazy="dynamic", + # Member objects "belong" to their parent, are deleted when + # removed from the collection + cascade="all, delete-orphan", + # "delete, delete-orphan" cascade does not load in objects on delete, + # allows ON DELETE CASCADE to handle it. + # this only works with a database that supports ON DELETE CASCADE - + # *not* sqlite or MySQL with MyISAM + passive_deletes=True, + ) + }, +) mapper(Member, member_table) -if __name__ == '__main__': - engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True) +if __name__ == "__main__": + engine = create_engine( + "postgresql://scott:tiger@localhost/test", echo=True + ) meta.create_all(engine) # expire_on_commit=False means the session contents @@ -56,10 +78,10 @@ if __name__ == '__main__': sess = sessionmaker(engine, expire_on_commit=False)() # create org with some members - org = Organization('org one') - org.members.append(Member('member one')) - org.members.append(Member('member two')) - org.members.append(Member('member three')) + org = Organization("org one") + org.members.append(Member("member one")) + org.members.append(Member("member two")) + org.members.append(Member("member three")) sess.add(org) @@ -69,14 +91,14 @@ if __name__ == '__main__': # the 'members' collection is a Query. it issues # SQL as needed to load subsets of the collection. print("-------------------------\nload subset of members\n") - members = org.members.filter(member_table.c.name.like('%member t%')).all() + members = org.members.filter(member_table.c.name.like("%member t%")).all() print(members) # new Members can be appended without any # SQL being emitted to load the full collection - org.members.append(Member('member four')) - org.members.append(Member('member five')) - org.members.append(Member('member six')) + org.members.append(Member("member four")) + org.members.append(Member("member five")) + org.members.append(Member("member six")) print("-------------------------\nflush two - save 3 more members\n") sess.commit() @@ -85,7 +107,9 @@ if __name__ == '__main__': # SQL is only emitted for the head row - the Member rows # disappear automatically without the need for additional SQL. sess.delete(org) - print("-------------------------\nflush three - delete org, delete members in one statement\n") + print( + "-------------------------\nflush three - delete org, delete members in one statement\n" + ) sess.commit() print("-------------------------\nno Member rows should remain:\n") @@ -93,4 +117,4 @@ if __name__ == '__main__': sess.close() print("------------------------\ndone. dropping tables.") - meta.drop_all(engine) \ No newline at end of file + meta.drop_all(engine) diff --git a/examples/materialized_paths/materialized_paths.py b/examples/materialized_paths/materialized_paths.py index 4ded90f7ec..45ae0c1932 100644 --- a/examples/materialized_paths/materialized_paths.py +++ b/examples/materialized_paths/materialized_paths.py @@ -44,21 +44,35 @@ class Node(Base): # To find the descendants of this node, we look for nodes whose path # starts with this node's path. 
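# [editor's note: illustration only, not part of the commit] e.g. for a
# node whose path is "1.3", the primaryjoin below renders along the
# lines of
#
#     descendant.path LIKE '1.3' || '.%'
#
# so "1.3.4" and "1.3.5.9" match while "1.30.2" does not, since the
# separator "." is part of the pattern.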
descendants = relationship( - "Node", viewonly=True, order_by=path, - primaryjoin=remote(foreign(path)).like(path.concat(".%"))) + "Node", + viewonly=True, + order_by=path, + primaryjoin=remote(foreign(path)).like(path.concat(".%")), + ) # Finding the ancestors is a little bit trickier. We need to create a fake # secondary table since this behaves like a many-to-many join. - secondary = select([ - id.label("id"), - func.unnest(cast(func.string_to_array( - func.regexp_replace(path, r"\.?\d+$", ""), "."), - ARRAY(Integer))).label("ancestor_id") - ]).alias() - ancestors = relationship("Node", viewonly=True, secondary=secondary, - primaryjoin=id == secondary.c.id, - secondaryjoin=secondary.c.ancestor_id == id, - order_by=path) + secondary = select( + [ + id.label("id"), + func.unnest( + cast( + func.string_to_array( + func.regexp_replace(path, r"\.?\d+$", ""), "." + ), + ARRAY(Integer), + ) + ).label("ancestor_id"), + ] + ).alias() + ancestors = relationship( + "Node", + viewonly=True, + secondary=secondary, + primaryjoin=id == secondary.c.id, + secondaryjoin=secondary.c.ancestor_id == id, + order_by=path, + ) @property def depth(self): @@ -70,38 +84,44 @@ class Node(Base): def __str__(self): root_depth = self.depth s = [str(self.id)] - s.extend(((n.depth - root_depth) * " " + str(n.id)) - for n in self.descendants) + s.extend( + ((n.depth - root_depth) * " " + str(n.id)) + for n in self.descendants + ) return "\n".join(s) def move_to(self, new_parent): new_path = new_parent.path + "." + str(self.id) for n in self.descendants: - n.path = new_path + n.path[len(self.path):] + n.path = new_path + n.path[len(self.path) :] self.path = new_path if __name__ == "__main__": - engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True) + engine = create_engine( + "postgresql://scott:tiger@localhost/test", echo=True + ) Base.metadata.create_all(engine) session = Session(engine) print("-" * 80) print("create a tree") - session.add_all([ - Node(id=1, path="1"), - Node(id=2, path="1.2"), - Node(id=3, path="1.3"), - Node(id=4, path="1.3.4"), - Node(id=5, path="1.3.5"), - Node(id=6, path="1.3.6"), - Node(id=7, path="1.7"), - Node(id=8, path="1.7.8"), - Node(id=9, path="1.7.9"), - Node(id=10, path="1.7.9.10"), - Node(id=11, path="1.7.11"), - ]) + session.add_all( + [ + Node(id=1, path="1"), + Node(id=2, path="1.2"), + Node(id=3, path="1.3"), + Node(id=4, path="1.3.4"), + Node(id=5, path="1.3.5"), + Node(id=6, path="1.3.6"), + Node(id=7, path="1.7"), + Node(id=8, path="1.7.8"), + Node(id=9, path="1.7.9"), + Node(id=10, path="1.7.9.10"), + Node(id=11, path="1.7.11"), + ] + ) session.flush() print(str(session.query(Node).get(1))) diff --git a/examples/nested_sets/__init__.py b/examples/nested_sets/__init__.py index 3e73bb13e9..5fdfbcedc0 100644 --- a/examples/nested_sets/__init__.py +++ b/examples/nested_sets/__init__.py @@ -3,4 +3,4 @@ pattern for hierarchical data using the SQLAlchemy ORM. .. 
autosource:: -""" \ No newline at end of file +""" diff --git a/examples/nested_sets/nested_sets.py b/examples/nested_sets/nested_sets.py index c64b15b61d..705a3d279e 100644 --- a/examples/nested_sets/nested_sets.py +++ b/examples/nested_sets/nested_sets.py @@ -4,19 +4,27 @@ http://www.intelligententerprise.com/001020/celko.jhtml """ -from sqlalchemy import (create_engine, Column, Integer, String, select, case, - func) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + select, + case, + func, +) from sqlalchemy.orm import Session, aliased from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import event Base = declarative_base() + class Employee(Base): - __tablename__ = 'personnel' + __tablename__ = "personnel" __mapper_args__ = { - 'batch': False # allows extension to fire for each - # instance before going to the next. + "batch": False # allows extension to fire for each + # instance before going to the next. } parent = None @@ -29,6 +37,7 @@ class Employee(Base): def __repr__(self): return "Employee(%s, %d, %d)" % (self.emp, self.left, self.right) + @event.listens_for(Employee, "before_insert") def before_insert(mapper, connection, instance): if not instance.parent: @@ -37,23 +46,31 @@ def before_insert(mapper, connection, instance): else: personnel = mapper.mapped_table right_most_sibling = connection.scalar( - select([personnel.c.rgt]). - where(personnel.c.emp == instance.parent.emp) + select([personnel.c.rgt]).where( + personnel.c.emp == instance.parent.emp + ) ) connection.execute( - personnel.update( - personnel.c.rgt >= right_most_sibling).values( - lft=case( - [(personnel.c.lft > right_most_sibling, - personnel.c.lft + 2)], - else_=personnel.c.lft - ), - rgt=case( - [(personnel.c.rgt >= right_most_sibling, - personnel.c.rgt + 2)], - else_=personnel.c.rgt - ) + personnel.update(personnel.c.rgt >= right_most_sibling).values( + lft=case( + [ + ( + personnel.c.lft > right_most_sibling, + personnel.c.lft + 2, + ) + ], + else_=personnel.c.lft, + ), + rgt=case( + [ + ( + personnel.c.rgt >= right_most_sibling, + personnel.c.rgt + 2, + ) + ], + else_=personnel.c.rgt, + ), ) ) instance.left = right_most_sibling @@ -62,18 +79,19 @@ def before_insert(mapper, connection, instance): # before_update() would be needed to support moving of nodes # after_delete() would be needed to support removal of nodes. -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(bind=engine) -albert = Employee(emp='Albert') -bert = Employee(emp='Bert') -chuck = Employee(emp='Chuck') -donna = Employee(emp='Donna') -eddie = Employee(emp='Eddie') -fred = Employee(emp='Fred') +albert = Employee(emp="Albert") +bert = Employee(emp="Bert") +chuck = Employee(emp="Chuck") +donna = Employee(emp="Donna") +eddie = Employee(emp="Eddie") +fred = Employee(emp="Fred") bert.parent = albert chuck.parent = albert @@ -90,22 +108,28 @@ print(session.query(Employee).all()) # 1. Find an employee and all their supervisors, no matter how deep the tree. ealias = aliased(Employee) -print(session.query(Employee).\ - filter(ealias.left.between(Employee.left, Employee.right)).\ - filter(ealias.emp == 'Eddie').all()) - -#2. Find the employee and all their subordinates. +print( + session.query(Employee) + .filter(ealias.left.between(Employee.left, Employee.right)) + .filter(ealias.emp == "Eddie") + .all() +) + +# 2. Find the employee and all their subordinates. 
# (This query has a nice symmetry with the first query.) -print(session.query(Employee).\ - filter(Employee.left.between(ealias.left, ealias.right)).\ - filter(ealias.emp == 'Chuck').all()) - -#3. Find the level of each node, so you can print the tree +print( + session.query(Employee) + .filter(Employee.left.between(ealias.left, ealias.right)) + .filter(ealias.emp == "Chuck") + .all() +) + +# 3. Find the level of each node, so you can print the tree # as an indented listing. -for indentation, employee in session.query( - func.count(Employee.emp).label('indentation') - 1, ealias).\ - filter(ealias.left.between(Employee.left, Employee.right)).\ - group_by(ealias.emp).\ - order_by(ealias.left): +for indentation, employee in ( + session.query(func.count(Employee.emp).label("indentation") - 1, ealias) + .filter(ealias.left.between(Employee.left, Employee.right)) + .group_by(ealias.emp) + .order_by(ealias.left) +): print(" " * indentation + str(employee)) - diff --git a/examples/performance/__init__.py b/examples/performance/__init__.py index 6264ae9f70..b66199f3c3 100644 --- a/examples/performance/__init__.py +++ b/examples/performance/__init__.py @@ -255,7 +255,8 @@ class Profiler(object): def profile(cls, fn): if cls.name is None: raise ValueError( - "Need to call Profile.init(, ) first.") + "Need to call Profile.init(, ) first." + ) cls.tests.append(fn) return fn @@ -270,7 +271,8 @@ class Profiler(object): def setup_once(cls, fn): if cls._setup_once is not None: raise ValueError( - "setup_once function already set to %s" % cls._setup_once) + "setup_once function already set to %s" % cls._setup_once + ) cls._setup_once = staticmethod(fn) return fn @@ -298,7 +300,7 @@ class Profiler(object): finally: pr.disable() - stats = pstats.Stats(pr).sort_stats('cumulative') + stats = pstats.Stats(pr).sort_stats("cumulative") self.stats.append(TestResult(self, fn, stats=stats)) return result @@ -326,7 +328,8 @@ class Profiler(object): if cls.name is None: parser.add_argument( - "name", choices=cls._suite_names(), help="suite to run") + "name", choices=cls._suite_names(), help="suite to run" + ) if len(sys.argv) > 1: potential_name = sys.argv[1] @@ -335,35 +338,44 @@ class Profiler(object): except ImportError: pass - parser.add_argument( - "--test", type=str, - help="run specific test name" - ) + parser.add_argument("--test", type=str, help="run specific test name") parser.add_argument( - '--dburl', type=str, default="sqlite:///profile.db", - help="database URL, default sqlite:///profile.db" + "--dburl", + type=str, + default="sqlite:///profile.db", + help="database URL, default sqlite:///profile.db", ) parser.add_argument( - '--num', type=int, default=cls.num, + "--num", + type=int, + default=cls.num, help="Number of iterations/items/etc for tests; " - "default is %d module-specific" % cls.num + "default is %d module-specific" % cls.num, ) parser.add_argument( - '--profile', action='store_true', - help='run profiling and dump call counts') + "--profile", + action="store_true", + help="run profiling and dump call counts", + ) parser.add_argument( - '--dump', action='store_true', - help='dump full call profile (implies --profile)') + "--dump", + action="store_true", + help="dump full call profile (implies --profile)", + ) parser.add_argument( - '--callers', action='store_true', - help='print callers as well (implies --dump)') + "--callers", + action="store_true", + help="print callers as well (implies --dump)", + ) parser.add_argument( - '--runsnake', action='store_true', - help='invoke runsnakerun 
(implies --profile)') + "--runsnake", + action="store_true", + help="invoke runsnakerun (implies --profile)", + ) parser.add_argument( - '--echo', action='store_true', - help="Echo SQL output") + "--echo", action="store_true", help="Echo SQL output" + ) args = parser.parse_args() args.dump = args.dump or args.callers @@ -378,7 +390,7 @@ class Profiler(object): def _suite_names(cls): suites = [] for file_ in os.listdir(os.path.dirname(__file__)): - match = re.match(r'^([a-z].*).py$', file_) + match = re.match(r"^([a-z].*).py$", file_) if match: suites.append(match.group(1)) return suites @@ -398,7 +410,10 @@ class TestResult(object): def _summary(self): summary = "%s : %s (%d iterations)" % ( - self.test.__name__, self.test.__doc__, self.profile.num) + self.test.__name__, + self.test.__doc__, + self.profile.num, + ) if self.total_time: summary += "; total time %f sec" % self.total_time if self.stats: @@ -412,7 +427,7 @@ class TestResult(object): self._dump() def _dump(self): - self.stats.sort_stats('time', 'calls') + self.stats.sort_stats("time", "calls") self.stats.print_stats() if self.profile.callers: self.stats.print_callers() @@ -424,5 +439,3 @@ class TestResult(object): os.system("runsnake %s" % filename) finally: os.remove(filename) - - diff --git a/examples/performance/__main__.py b/examples/performance/__main__.py index 5e05143bf2..945458651a 100644 --- a/examples/performance/__main__.py +++ b/examples/performance/__main__.py @@ -2,6 +2,5 @@ from . import Profiler -if __name__ == '__main__': +if __name__ == "__main__": Profiler.main() - diff --git a/examples/performance/bulk_inserts.py b/examples/performance/bulk_inserts.py index 9c3cff5b22..52f0f32e67 100644 --- a/examples/performance/bulk_inserts.py +++ b/examples/performance/bulk_inserts.py @@ -36,12 +36,15 @@ def test_flush_no_pk(n): """Individual INSERT statements via the ORM, calling upon last row id""" session = Session(bind=engine) for chunk in range(0, n, 1000): - session.add_all([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i) - for i in range(chunk, chunk + 1000) - ]) + session.add_all( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(chunk, chunk + 1000) + ] + ) session.flush() session.commit() @@ -50,13 +53,16 @@ def test_flush_no_pk(n): def test_bulk_save_return_pks(n): """Individual INSERT statements in "bulk", but calling upon last row id""" session = Session(bind=engine) - session.bulk_save_objects([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i - ) - for i in range(n) - ], return_defaults=True) + session.bulk_save_objects( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(n) + ], + return_defaults=True, + ) session.commit() @@ -65,13 +71,16 @@ def test_flush_pk_given(n): """Batched INSERT statements via the ORM, PKs already defined""" session = Session(bind=engine) for chunk in range(0, n, 1000): - session.add_all([ - Customer( - id=i + 1, - name='customer name %d' % i, - description='customer description %d' % i) - for i in range(chunk, chunk + 1000) - ]) + session.add_all( + [ + Customer( + id=i + 1, + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(chunk, chunk + 1000) + ] + ) session.flush() session.commit() @@ -80,13 +89,15 @@ def test_flush_pk_given(n): def test_bulk_save(n): """Batched INSERT statements via the ORM in "bulk", discarding PKs.""" session = 
Session(bind=engine) - session.bulk_save_objects([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i - ) - for i in range(n) - ]) + session.bulk_save_objects( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(n) + ] + ) session.commit() @@ -94,13 +105,16 @@ def test_bulk_save(n): def test_bulk_insert_mappings(n): """Batched INSERT statements via the ORM "bulk", using dictionaries.""" session = Session(bind=engine) - session.bulk_insert_mappings(Customer, [ - dict( - name='customer name %d' % i, - description='customer description %d' % i - ) - for i in range(n) - ]) + session.bulk_insert_mappings( + Customer, + [ + dict( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(n) + ], + ) session.commit() @@ -112,11 +126,12 @@ def test_core_insert(n): Customer.__table__.insert(), [ dict( - name='customer name %d' % i, - description='customer description %d' % i + name="customer name %d" % i, + description="customer description %d" % i, ) for i in range(n) - ]) + ], + ) @Profiler.profile @@ -125,30 +140,30 @@ def test_dbapi_raw(n): conn = engine.pool._creator() cursor = conn.cursor() - compiled = Customer.__table__.insert().values( - name=bindparam('name'), - description=bindparam('description')).\ - compile(dialect=engine.dialect) + compiled = ( + Customer.__table__.insert() + .values(name=bindparam("name"), description=bindparam("description")) + .compile(dialect=engine.dialect) + ) if compiled.positional: args = ( - ('customer name %d' % i, 'customer description %d' % i) - for i in range(n)) + ("customer name %d" % i, "customer description %d" % i) + for i in range(n) + ) else: args = ( dict( - name='customer name %d' % i, - description='customer description %d' % i + name="customer name %d" % i, + description="customer description %d" % i, ) for i in range(n) ) - cursor.executemany( - str(compiled), - list(args) - ) + cursor.executemany(str(compiled), list(args)) conn.commit() conn.close() -if __name__ == '__main__': + +if __name__ == "__main__": Profiler.main() diff --git a/examples/performance/bulk_updates.py b/examples/performance/bulk_updates.py index 9522e4bf5a..ebb7000686 100644 --- a/examples/performance/bulk_updates.py +++ b/examples/performance/bulk_updates.py @@ -32,12 +32,16 @@ def setup_database(dburl, echo, num): s = Session(engine) for chunk in range(0, num, 10000): - s.bulk_insert_mappings(Customer, [ - { - 'name': 'customer name %d' % i, - 'description': 'customer description %d' % i - } for i in range(chunk, chunk + 10000) - ]) + s.bulk_insert_mappings( + Customer, + [ + { + "name": "customer name %d" % i, + "description": "customer description %d" % i, + } + for i in range(chunk, chunk + 10000) + ], + ) s.commit() @@ -46,8 +50,11 @@ def test_orm_flush(n): """UPDATE statements via the ORM flush process.""" session = Session(bind=engine) for chunk in range(0, n, 1000): - customers = session.query(Customer).\ - filter(Customer.id.between(chunk, chunk + 1000)).all() + customers = ( + session.query(Customer) + .filter(Customer.id.between(chunk, chunk + 1000)) + .all() + ) for customer in customers: customer.description += "updated" session.flush() diff --git a/examples/performance/large_resultsets.py b/examples/performance/large_resultsets.py index c13683040a..ad1c231941 100644 --- a/examples/performance/large_resultsets.py +++ b/examples/performance/large_resultsets.py @@ -46,9 +46,12 @@ def setup_database(dburl, echo, num): 
Customer.__table__.insert(), params=[ { - 'name': 'customer name %d' % i, - 'description': 'customer description %d' % i - } for i in range(chunk, chunk + 10000)]) + "name": "customer name %d" % i, + "description": "customer description %d" % i, + } + for i in range(chunk, chunk + 10000) + ], + ) s.commit() @@ -74,8 +77,9 @@ def test_orm_bundles(n): """Load lightweight "bundle" objects using the ORM.""" sess = Session(engine) - bundle = Bundle('customer', - Customer.id, Customer.name, Customer.description) + bundle = Bundle( + "customer", Customer.id, Customer.name, Customer.description + ) for row in sess.query(bundle).yield_per(10000).limit(n): pass @@ -85,9 +89,11 @@ def test_orm_columns(n): """Load individual columns into named tuples using the ORM.""" sess = Session(engine) - for row in sess.query( - Customer.id, Customer.name, - Customer.description).yield_per(10000).limit(n): + for row in ( + sess.query(Customer.id, Customer.name, Customer.description) + .yield_per(10000) + .limit(n) + ): pass @@ -98,7 +104,7 @@ def test_core_fetchall(n): with engine.connect() as conn: result = conn.execute(Customer.__table__.select().limit(n)).fetchall() for row in result: - data = row['id'], row['name'], row['description'] + data = row["id"], row["name"], row["description"] @Profiler.profile @@ -106,14 +112,15 @@ def test_core_fetchmany_w_streaming(n): """Load Core result rows using fetchmany/streaming.""" with engine.connect() as conn: - result = conn.execution_options(stream_results=True).\ - execute(Customer.__table__.select().limit(n)) + result = conn.execution_options(stream_results=True).execute( + Customer.__table__.select().limit(n) + ) while True: chunk = result.fetchmany(10000) if not chunk: break for row in chunk: - data = row['id'], row['name'], row['description'] + data = row["id"], row["name"], row["description"] @Profiler.profile @@ -127,7 +134,7 @@ def test_core_fetchmany(n): if not chunk: break for row in chunk: - data = row['id'], row['name'], row['description'] + data = row["id"], row["name"], row["description"] @Profiler.profile @@ -145,10 +152,13 @@ def test_dbapi_fetchall_no_object(n): def _test_dbapi_raw(n, make_objects): - compiled = Customer.__table__.select().limit(n).\ - compile( - dialect=engine.dialect, - compile_kwargs={"literal_binds": True}) + compiled = ( + Customer.__table__.select() + .limit(n) + .compile( + dialect=engine.dialect, compile_kwargs={"literal_binds": True} + ) + ) if make_objects: # because if you're going to roll your own, you're probably @@ -170,7 +180,8 @@ def _test_dbapi_raw(n, make_objects): for row in cursor.fetchall(): # ensure that we fully fetch! customer = SimpleCustomer( - id=row[0], name=row[1], description=row[2]) + id=row[0], name=row[1], description=row[2] + ) else: for row in cursor.fetchall(): # ensure that we fully fetch! @@ -178,5 +189,6 @@ def _test_dbapi_raw(n, make_objects): conn.close() -if __name__ == '__main__': + +if __name__ == "__main__": Profiler.main() diff --git a/examples/performance/short_selects.py b/examples/performance/short_selects.py index 6f64aa63e3..4a8d401ad0 100644 --- a/examples/performance/short_selects.py +++ b/examples/performance/short_selects.py @@ -6,8 +6,14 @@ record by primary key from . 
import Profiler from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, Integer, String, create_engine, \ - bindparam, select +from sqlalchemy import ( + Column, + Integer, + String, + create_engine, + bindparam, + select, +) from sqlalchemy.orm import Session, deferred from sqlalchemy.ext import baked import random @@ -29,6 +35,7 @@ class Customer(Base): y = deferred(Column(Integer)) z = deferred(Column(Integer)) + Profiler.init("short_selects", num=10000) @@ -39,16 +46,20 @@ def setup_database(dburl, echo, num): Base.metadata.drop_all(engine) Base.metadata.create_all(engine) sess = Session(engine) - sess.add_all([ - Customer( - id=i, name='c%d' % i, description="c%d" % i, - q=i * 10, - p=i * 20, - x=i * 30, - y=i * 40, - ) - for i in ids - ]) + sess.add_all( + [ + Customer( + id=i, + name="c%d" % i, + description="c%d" % i, + q=i * 10, + p=i * 20, + x=i * 30, + y=i * 40, + ) + for i in ids + ] + ) sess.commit() @@ -65,9 +76,9 @@ def test_orm_query_cols_only(n): """test an ORM query of only the entity columns.""" session = Session(bind=engine) for id_ in random.sample(ids, n): - session.query( - Customer.id, Customer.name, Customer.description - ).filter(Customer.id == id_).one() + session.query(Customer.id, Customer.name, Customer.description).filter( + Customer.id == id_ + ).one() @Profiler.profile @@ -77,7 +88,7 @@ def test_baked_query(n): s = Session(bind=engine) for id_ in random.sample(ids, n): q = bakery(lambda s: s.query(Customer)) - q += lambda q: q.filter(Customer.id == bindparam('id')) + q += lambda q: q.filter(Customer.id == bindparam("id")) q(s).params(id=id_).one() @@ -88,9 +99,9 @@ def test_baked_query_cols_only(n): s = Session(bind=engine) for id_ in random.sample(ids, n): q = bakery( - lambda s: s.query( - Customer.id, Customer.name, Customer.description)) - q += lambda q: q.filter(Customer.id == bindparam('id')) + lambda s: s.query(Customer.id, Customer.name, Customer.description) + ) + q += lambda q: q.filter(Customer.id == bindparam("id")) q(s).params(id=id_).one() @@ -109,7 +120,7 @@ def test_core_new_stmt_each_time(n): def test_core_reuse_stmt(n): """test core, reusing the same statement (but recompiling each time).""" - stmt = select([Customer.__table__]).where(Customer.id == bindparam('id')) + stmt = select([Customer.__table__]).where(Customer.id == bindparam("id")) with engine.connect() as conn: for id_ in random.sample(ids, n): @@ -122,13 +133,14 @@ def test_core_reuse_stmt_compiled_cache(n): """test core, reusing the same statement + compiled cache.""" compiled_cache = {} - stmt = select([Customer.__table__]).where(Customer.id == bindparam('id')) - with engine.connect().\ - execution_options(compiled_cache=compiled_cache) as conn: + stmt = select([Customer.__table__]).where(Customer.id == bindparam("id")) + with engine.connect().execution_options( + compiled_cache=compiled_cache + ) as conn: for id_ in random.sample(ids, n): row = conn.execute(stmt, id=id_).first() tuple(row) -if __name__ == '__main__': +if __name__ == "__main__": Profiler.main() diff --git a/examples/performance/single_inserts.py b/examples/performance/single_inserts.py index cfce903004..79e34dfe6b 100644 --- a/examples/performance/single_inserts.py +++ b/examples/performance/single_inserts.py @@ -28,7 +28,7 @@ Profiler.init("single_inserts", num=10000) def setup_database(dburl, echo, num): global engine engine = create_engine(dburl, echo=echo) - if engine.dialect.name == 'sqlite': + if engine.dialect.name == "sqlite": engine.pool = 
pool.StaticPool(creator=engine.pool._creator) Base.metadata.drop_all(engine) Base.metadata.create_all(engine) @@ -42,8 +42,9 @@ def test_orm_commit(n): session = Session(bind=engine) session.add( Customer( - name='customer name %d' % i, - description='customer description %d' % i) + name="customer name %d" % i, + description="customer description %d" % i, + ) ) session.commit() @@ -54,11 +55,14 @@ def test_bulk_save(n): for i in range(n): session = Session(bind=engine) - session.bulk_save_objects([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i - )]) + session.bulk_save_objects( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + ] + ) session.commit() @@ -68,11 +72,15 @@ def test_bulk_insert_dictionaries(n): for i in range(n): session = Session(bind=engine) - session.bulk_insert_mappings(Customer, [ - dict( - name='customer name %d' % i, - description='customer description %d' % i - )]) + session.bulk_insert_mappings( + Customer, + [ + dict( + name="customer name %d" % i, + description="customer description %d" % i, + ) + ], + ) session.commit() @@ -85,9 +93,9 @@ def test_core(n): conn.execute( Customer.__table__.insert(), dict( - name='customer name %d' % i, - description='customer description %d' % i - ) + name="customer name %d" % i, + description="customer description %d" % i, + ), ) @@ -102,9 +110,9 @@ def test_core_query_caching(n): conn.execution_options(compiled_cache=cache).execute( ins, dict( - name='customer name %d' % i, - description='customer description %d' % i - ) + name="customer name %d" % i, + description="customer description %d" % i, + ), ) @@ -123,20 +131,22 @@ def test_dbapi_raw_w_pool(n): def _test_dbapi_raw(n, connect): - compiled = Customer.__table__.insert().values( - name=bindparam('name'), - description=bindparam('description')).\ - compile(dialect=engine.dialect) + compiled = ( + Customer.__table__.insert() + .values(name=bindparam("name"), description=bindparam("description")) + .compile(dialect=engine.dialect) + ) if compiled.positional: args = ( - ('customer name %d' % i, 'customer description %d' % i) - for i in range(n)) + ("customer name %d" % i, "customer description %d" % i) + for i in range(n) + ) else: args = ( dict( - name='customer name %d' % i, - description='customer description %d' % i + name="customer name %d" % i, + description="customer description %d" % i, ) for i in range(n) ) @@ -162,5 +172,5 @@ def _test_dbapi_raw(n, connect): conn.close() -if __name__ == '__main__': +if __name__ == "__main__": Profiler.main() diff --git a/examples/postgis/__init__.py b/examples/postgis/__init__.py index 250d9ce876..66ae65d3c0 100644 --- a/examples/postgis/__init__.py +++ b/examples/postgis/__init__.py @@ -36,4 +36,3 @@ E.g.:: .. 
autosource:: """ - diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py index ffea3d0189..508d633988 100644 --- a/examples/postgis/postgis.py +++ b/examples/postgis/postgis.py @@ -5,6 +5,7 @@ import binascii # Python datatypes + class GisElement(object): """Represents a geometry value.""" @@ -12,16 +13,21 @@ class GisElement(object): return self.desc def __repr__(self): - return "<%s at 0x%x; %r>" % (self.__class__.__name__, - id(self), self.desc) + return "<%s at 0x%x; %r>" % ( + self.__class__.__name__, + id(self), + self.desc, + ) + class BinaryGisElement(GisElement, expression.Function): """Represents a Geometry value expressed as binary.""" def __init__(self, data): self.data = data - expression.Function.__init__(self, "ST_GeomFromEWKB", data, - type_=Geometry(coerce_="binary")) + expression.Function.__init__( + self, "ST_GeomFromEWKB", data, type_=Geometry(coerce_="binary") + ) @property def desc(self): @@ -31,24 +37,26 @@ class BinaryGisElement(GisElement, expression.Function): def as_hex(self): return binascii.hexlify(self.data) + class TextualGisElement(GisElement, expression.Function): """Represents a Geometry value expressed as text.""" def __init__(self, desc, srid=-1): self.desc = desc - expression.Function.__init__(self, "ST_GeomFromText", desc, srid, - type_=Geometry) + expression.Function.__init__( + self, "ST_GeomFromText", desc, srid, type_=Geometry + ) # SQL datatypes. + class Geometry(UserDefinedType): """Base PostGIS Geometry column type.""" name = "GEOMETRY" - def __init__(self, dimension=None, srid=-1, - coerce_="text"): + def __init__(self, dimension=None, srid=-1, coerce_="text"): self.dimension = dimension self.srid = srid self.coerce = coerce_ @@ -58,11 +66,11 @@ class Geometry(UserDefinedType): # override the __eq__() operator def __eq__(self, other): - return self.op('~=')(other) + return self.op("~=")(other) # add a custom operator def intersects(self, other): - return self.op('&&')(other) + return self.op("&&")(other) # any number of GIS operators can be overridden/added here # using the techniques above. @@ -95,6 +103,7 @@ class Geometry(UserDefinedType): return value.desc else: return value + return process def result_processor(self, dialect, coltype): @@ -104,27 +113,35 @@ class Geometry(UserDefinedType): fac = BinaryGisElement else: assert False + def process(value): if value is not None: return fac(value) else: return value + return process def adapt(self, impltype): - return impltype(dimension=self.dimension, - srid=self.srid, coerce_=self.coerce) + return impltype( + dimension=self.dimension, srid=self.srid, coerce_=self.coerce + ) + # other datatypes can be added as needed. + class Point(Geometry): - name = 'POINT' + name = "POINT" + class Curve(Geometry): - name = 'CURVE' + name = "CURVE" + class LineString(Curve): - name = 'LINESTRING' + name = "LINESTRING" + # ... etc. @@ -135,6 +152,7 @@ class LineString(Curve): # versions don't appear to require these special steps anymore. However, # here we illustrate how to set up these features in any case. 
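# A minimal sketch of the Table-level DDL event hooks that setup_ddl_events()
# wires up below, assuming only a throwaway in-memory SQLite engine; the
# table and listener names here are illustrative, not part of this example:
#
#     from sqlalchemy import Column, Integer, MetaData, Table, create_engine
#     from sqlalchemy import event
#
#     metadata = MetaData()
#     t = Table("t", metadata, Column("id", Integer, primary_key=True))
#
#     @event.listens_for(Table, "before_create")
#     def _before_create(target, connection, **kw):
#         # fires once per Table, just before its CREATE TABLE is emitted
#         print("about to create", target.name)
#
#     engine = create_engine("sqlite://")
#     metadata.create_all(engine)  # triggers _before_create for table "t"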
+ def setup_ddl_events(): @event.listens_for(Table, "before_create") def before_create(target, connection, **kw): @@ -153,9 +171,10 @@ def setup_ddl_events(): dispatch("after-drop", target, connection) def dispatch(event, table, bind): - if event in ('before-create', 'before-drop'): - regular_cols = [c for c in table.c if not - isinstance(c.type, Geometry)] + if event in ("before-create", "before-drop"): + regular_cols = [ + c for c in table.c if not isinstance(c.type, Geometry) + ] gis_cols = set(table.c).difference(regular_cols) table.info["_saved_columns"] = table.c @@ -163,85 +182,129 @@ def setup_ddl_events(): # Geometry columns table.columns = expression.ColumnCollection(*regular_cols) - if event == 'before-drop': + if event == "before-drop": for c in gis_cols: bind.execute( - select([ + select( + [ func.DropGeometryColumn( - 'public', table.name, c.name)], - autocommit=True) - ) + "public", table.name, c.name + ) + ], + autocommit=True, + ) + ) - elif event == 'after-create': - table.columns = table.info.pop('_saved_columns') + elif event == "after-create": + table.columns = table.info.pop("_saved_columns") for c in table.c: if isinstance(c.type, Geometry): bind.execute( - select([ - func.AddGeometryColumn( - table.name, c.name, - c.type.srid, - c.type.name, - c.type.dimension)], - autocommit=True) + select( + [ + func.AddGeometryColumn( + table.name, + c.name, + c.type.srid, + c.type.name, + c.type.dimension, + ) + ], + autocommit=True, ) - elif event == 'after-drop': - table.columns = table.info.pop('_saved_columns') -setup_ddl_events() + ) + elif event == "after-drop": + table.columns = table.info.pop("_saved_columns") +setup_ddl_events() + # illustrate usage -if __name__ == '__main__': - from sqlalchemy import (create_engine, MetaData, Column, Integer, String, - func, select) +if __name__ == "__main__": + from sqlalchemy import ( + create_engine, + MetaData, + Column, + Integer, + String, + func, + select, + ) from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base - engine = create_engine('postgresql://scott:tiger@localhost/test', echo=True) + engine = create_engine( + "postgresql://scott:tiger@localhost/test", echo=True + ) metadata = MetaData(engine) Base = declarative_base(metadata=metadata) class Road(Base): - __tablename__ = 'roads' + __tablename__ = "roads" road_id = Column(Integer, primary_key=True) road_name = Column(String) road_geom = Column(Geometry(2)) - metadata.drop_all() metadata.create_all() session = sessionmaker(bind=engine)() # Add objects. We can use strings... 
- session.add_all([ - Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'), - Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'), - Road(road_name='Paul St', road_geom='LINESTRING(192783 228138,192612 229814)'), - Road(road_name='Graeme Ave', road_geom='LINESTRING(189412 252431,189631 259122)'), - Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'), - ]) + session.add_all( + [ + Road( + road_name="Jeff Rd", + road_geom="LINESTRING(191232 243118,191108 243242)", + ), + Road( + road_name="Geordie Rd", + road_geom="LINESTRING(189141 244158,189265 244817)", + ), + Road( + road_name="Paul St", + road_geom="LINESTRING(192783 228138,192612 229814)", + ), + Road( + road_name="Graeme Ave", + road_geom="LINESTRING(189412 252431,189631 259122)", + ), + Road( + road_name="Phil Tce", + road_geom="LINESTRING(190131 224148,190871 228134)", + ), + ] + ) # or use an explicit TextualGisElement (similar to saying func.GeomFromText()) - r = Road(road_name='Dave Cres', road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1)) + r = Road( + road_name="Dave Cres", + road_geom=TextualGisElement( + "LINESTRING(198231 263418,198213 268322)", -1 + ), + ) session.add(r) # pre flush, the TextualGisElement represents the string we sent. - assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)' + assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)" session.commit() # after flush and/or commit, all the TextualGisElements become PersistentGisElements. assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)" - r1 = session.query(Road).filter(Road.road_name == 'Graeme Ave').one() + r1 = session.query(Road).filter(Road.road_name == "Graeme Ave").one() # illustrate the overridden __eq__() operator. # strings come in as TextualGisElements - r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one() + r2 = ( + session.query(Road) + .filter(Road.road_geom == "LINESTRING(189412 252431,189631 259122)") + .one() + ) r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one() @@ -250,22 +313,29 @@ if __name__ == '__main__': # core usage just fine: road_table = Road.__table__ - stmt = select([road_table]).where(road_table.c.road_geom.intersects(r1.road_geom)) + stmt = select([road_table]).where( + road_table.c.road_geom.intersects(r1.road_geom) + ) print(session.execute(stmt).fetchall()) # TODO: for some reason the auto-generated labels have the internal replacement # strings exposed, even though PG doesn't complain # look up the hex binary version, using SQLAlchemy casts - as_binary = session.scalar(select([type_coerce(r.road_geom, Geometry(coerce_="binary"))])) - assert as_binary.as_hex == \ - '01020000000200000000000000b832084100000000e813104100000000283208410000000088601041' + as_binary = session.scalar( + select([type_coerce(r.road_geom, Geometry(coerce_="binary"))]) + ) + assert ( + as_binary.as_hex + == "01020000000200000000000000b832084100000000e813104100000000283208410000000088601041" + ) # back again, same method ! 
- as_text = session.scalar(select([type_coerce(as_binary, Geometry(coerce_="text"))])) + as_text = session.scalar( + select([type_coerce(as_binary, Geometry(coerce_="text"))]) + ) assert as_text.desc == "LINESTRING(198231 263418,198213 268322)" - session.rollback() metadata.drop_all() diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index 0e19b69f36..48a3dc9325 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -1,5 +1,13 @@ -from sqlalchemy import (create_engine, Table, Column, Integer, - String, ForeignKey, Float, DateTime) +from sqlalchemy import ( + create_engine, + Table, + Column, + Integer, + String, + ForeignKey, + Float, + DateTime, +) from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.sql import operators, visitors @@ -12,22 +20,24 @@ import datetime # causes the id_generator() to use the same connection as that # of an ongoing transaction within db1. echo = True -db1 = create_engine('sqlite://', echo=echo, pool_threadlocal=True) -db2 = create_engine('sqlite://', echo=echo) -db3 = create_engine('sqlite://', echo=echo) -db4 = create_engine('sqlite://', echo=echo) +db1 = create_engine("sqlite://", echo=echo, pool_threadlocal=True) +db2 = create_engine("sqlite://", echo=echo) +db3 = create_engine("sqlite://", echo=echo) +db4 = create_engine("sqlite://", echo=echo) # create session function. this binds the shard ids # to databases within a ShardedSession and returns it. create_session = sessionmaker(class_=ShardedSession) -create_session.configure(shards={ - 'north_america': db1, - 'asia': db2, - 'europe': db3, - 'south_america': db4 -}) +create_session.configure( + shards={ + "north_america": db1, + "asia": db2, + "europe": db3, + "south_america": db4, + } +) # mappings and tables @@ -40,9 +50,7 @@ Base = declarative_base() # #1. Any other method will do just as well; UUID, hilo, application-specific, # etc. -ids = Table( - 'ids', Base.metadata, - Column('nextid', Integer, nullable=False)) +ids = Table("ids", Base.metadata, Column("nextid", Integer, nullable=False)) def id_generator(ctx): @@ -52,6 +60,7 @@ def id_generator(ctx): conn.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1})) return nextid + # table setup. we'll store a lead table of continents/cities, and a secondary # table storing locations. a particular row will be placed in the database # whose shard id corresponds to the 'continent'. 
in this setup, secondary rows @@ -67,7 +76,7 @@ class WeatherLocation(Base): continent = Column(String(30), nullable=False) city = Column(String(50), nullable=False) - reports = relationship("Report", backref='location') + reports = relationship("Report", backref="location") def __init__(self, continent, city): self.continent = continent @@ -79,14 +88,17 @@ class Report(Base): id = Column(Integer, primary_key=True) location_id = Column( - 'location_id', Integer, ForeignKey('weather_locations.id')) - temperature = Column('temperature', Float) + "location_id", Integer, ForeignKey("weather_locations.id") + ) + temperature = Column("temperature", Float) report_time = Column( - 'report_time', DateTime, default=datetime.datetime.now) + "report_time", DateTime, default=datetime.datetime.now + ) def __init__(self, temperature): self.temperature = temperature + # create tables for db in (db1, db2, db3, db4): Base.metadata.drop_all(db) @@ -101,10 +113,10 @@ db1.execute(ids.insert(), nextid=1) # we'll use a straight mapping of a particular set of "country" # attributes to shard id. shard_lookup = { - 'North America': 'north_america', - 'Asia': 'asia', - 'Europe': 'europe', - 'South America': 'south_america' + "North America": "north_america", + "Asia": "asia", + "Europe": "europe", + "South America": "south_america", } @@ -139,7 +151,7 @@ def id_chooser(query, ident): # set things up. return [query.lazy_loaded_from.identity_token] else: - return ['north_america', 'asia', 'europe', 'south_america'] + return ["north_america", "asia", "europe", "south_america"] def query_chooser(query): @@ -168,7 +180,7 @@ def query_chooser(query): ids.extend(shard_lookup[v] for v in value) if len(ids) == 0: - return ['north_america', 'asia', 'europe', 'south_america'] + return ["north_america", "asia", "europe", "south_america"] else: return ids @@ -208,13 +220,16 @@ def _get_query_comparisons(query): def visit_binary(binary): # special handling for "col IN (params)" - if binary.left in clauses and \ - binary.operator == operators.in_op and \ - hasattr(binary.right, 'clauses'): + if ( + binary.left in clauses + and binary.operator == operators.in_op + and hasattr(binary.right, "clauses") + ): comparisons.append( ( - binary.left, binary.operator, - tuple(binds[bind] for bind in binary.right.clauses) + binary.left, + binary.operator, + tuple(binds[bind] for bind in binary.right.clauses), ) ) elif binary.left in clauses and binary.right in binds: @@ -232,29 +247,33 @@ def _get_query_comparisons(query): # into a list. if query._criterion is not None: visitors.traverse_depthfirst( - query._criterion, {}, - {'bindparam': visit_bindparam, - 'binary': visit_binary, - 'column': visit_column} + query._criterion, + {}, + { + "bindparam": visit_bindparam, + "binary": visit_binary, + "column": visit_column, + }, ) return comparisons + # further configure create_session to use these functions create_session.configure( shard_chooser=shard_chooser, id_chooser=id_chooser, - query_chooser=query_chooser + query_chooser=query_chooser, ) # save and load objects! 
-tokyo = WeatherLocation('Asia', 'Tokyo') -newyork = WeatherLocation('North America', 'New York') -toronto = WeatherLocation('North America', 'Toronto') -london = WeatherLocation('Europe', 'London') -dublin = WeatherLocation('Europe', 'Dublin') -brasilia = WeatherLocation('South America', 'Brasila') -quito = WeatherLocation('South America', 'Quito') +tokyo = WeatherLocation("Asia", "Tokyo") +newyork = WeatherLocation("North America", "New York") +toronto = WeatherLocation("North America", "Toronto") +london = WeatherLocation("Europe", "London") +dublin = WeatherLocation("Europe", "Dublin") +brasilia = WeatherLocation("South America", "Brasila") +quito = WeatherLocation("South America", "Quito") tokyo.reports.append(Report(80.0)) newyork.reports.append(Report(75)) @@ -271,12 +290,14 @@ assert t.city == tokyo.city assert t.reports[0].temperature == 80.0 north_american_cities = sess.query(WeatherLocation).filter( - WeatherLocation.continent == 'North America') -assert {c.city for c in north_american_cities} == {'New York', 'Toronto'} + WeatherLocation.continent == "North America" +) +assert {c.city for c in north_american_cities} == {"New York", "Toronto"} asia_and_europe = sess.query(WeatherLocation).filter( - WeatherLocation.continent.in_(['Europe', 'Asia'])) -assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'} + WeatherLocation.continent.in_(["Europe", "Asia"]) +) +assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"} # the Report class uses a simple integer primary key. So across two databases, # a primary key will be repeated. The "identity_token" tracks in memory @@ -284,8 +305,8 @@ assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'} newyork_report = newyork.reports[0] tokyo_report = tokyo.reports[0] -assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america") -assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia") +assert inspect(newyork_report).identity_key == (Report, (1,), "north_america") +assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") # the token representing the originating shard is also available directly diff --git a/examples/space_invaders/__init__.py b/examples/space_invaders/__init__.py index 8816045dc3..944f8bb466 100644 --- a/examples/space_invaders/__init__.py +++ b/examples/space_invaders/__init__.py @@ -21,4 +21,4 @@ enjoy! .. 
autosource:: -""" \ No newline at end of file +""" diff --git a/examples/space_invaders/space_invaders.py b/examples/space_invaders/space_invaders.py index 3ce280aece..d5437d8cf0 100644 --- a/examples/space_invaders/space_invaders.py +++ b/examples/space_invaders/space_invaders.py @@ -1,7 +1,6 @@ import sys from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import create_engine, Integer, Column, ForeignKey, \ - String, func +from sqlalchemy import create_engine, Integer, Column, ForeignKey, String, func from sqlalchemy.orm import relationship, Session, joinedload from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method import curses @@ -18,7 +17,7 @@ if _PY3: logging.basicConfig( filename="space_invaders.log", - format="%(asctime)s,%(msecs)03d %(levelname)-5.5s %(message)s" + format="%(asctime)s,%(msecs)03d %(levelname)-5.5s %(message)s", ) logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) @@ -47,7 +46,7 @@ COLOR_MAP = { "M": curses.COLOR_MAGENTA, "R": curses.COLOR_RED, "W": curses.COLOR_WHITE, - "Y": curses.COLOR_YELLOW + "Y": curses.COLOR_YELLOW, } @@ -56,7 +55,8 @@ class Glyph(Base): to be painted on the screen. """ - __tablename__ = 'glyph' + + __tablename__ = "glyph" id = Column(Integer, primary_key=True) name = Column(String) type = Column(String) @@ -68,11 +68,9 @@ class Glyph(Base): def __init__(self, name, img, alt=None): self.name = name - self.data, self.width, self.height = \ - self._encode_glyph(img) + self.data, self.width, self.height = self._encode_glyph(img) if alt is not None: - self.alt_data, alt_w, alt_h = \ - self._encode_glyph(alt) + self.alt_data, alt_w, alt_h = self._encode_glyph(alt) def _encode_glyph(self, img): """Receive a textual description of the glyph and @@ -80,7 +78,7 @@ class Glyph(Base): GlyphCoordinate.render(). """ - img = re.sub(r'^\n', "", textwrap.dedent(img)) + img = re.sub(r"^\n", "", textwrap.dedent(img)) color = "W" lines = [line.rstrip() for line in img.split("\n")] data = [] @@ -89,15 +87,15 @@ class Glyph(Base): line = list(line) while line: char = line.pop(0) - if char == '#': + if char == "#": color = line.pop(0) continue render_line.append((color, char)) data.append(render_line) width = max([len(rl) for rl in data]) data = "".join( - "".join("%s%s" % (color, char) for color, char in render_line) + - ("W " * (width - len(render_line))) + "".join("%s%s" % (color, char) for color, char in render_line) + + ("W " * (width - len(render_line))) for render_line in data ) return data, width, len(lines) @@ -121,9 +119,10 @@ class GlyphCoordinate(Base): score value. 
""" - __tablename__ = 'glyph_coordinate' + + __tablename__ = "glyph_coordinate" id = Column(Integer, primary_key=True) - glyph_id = Column(Integer, ForeignKey('glyph.id')) + glyph_id = Column(Integer, ForeignKey("glyph.id")) x = Column(Integer) y = Column(Integer) tick = Column(Integer) @@ -132,11 +131,9 @@ class GlyphCoordinate(Base): glyph = relationship(Glyph, innerjoin=True) def __init__( - self, - session, glyph_name, x, y, - tick=None, label=None, score=None): - self.glyph = session.query(Glyph).\ - filter_by(name=glyph_name).one() + self, session, glyph_name, x, y, tick=None, label=None, score=None + ): + self.glyph = session.query(Glyph).filter_by(name=glyph_name).one() self.x = x self.y = y self.tick = tick @@ -152,8 +149,7 @@ class GlyphCoordinate(Base): glyph = self.glyph data = glyph.glyph_for_state(self, state) for color, char in [ - (data[i], data[i + 1]) - for i in xrange(0, len(data), 2) + (data[i], data[i + 1]) for i in xrange(0, len(data), 2) ]: x = self.x + col @@ -163,7 +159,8 @@ class GlyphCoordinate(Base): y + VERT_PADDING, x + HORIZ_PADDING, char, - _COLOR_PAIRS[color]) + _COLOR_PAIRS[color], + ) col += 1 if col == glyph.width: col = 0 @@ -186,10 +183,7 @@ class GlyphCoordinate(Base): width = min(glyph.width, MAX_X - x) or 1 for y_a in xrange(self.y, self.y + glyph.height): y = y_a - window.addstr( - y + VERT_PADDING, - x + HORIZ_PADDING, - " " * width) + window.addstr(y + VERT_PADDING, x + HORIZ_PADDING, " " * width) if self.label: self._render_label(window, True) @@ -236,21 +230,22 @@ class GlyphCoordinate(Base): the given GlyphCoordinate.""" return ~( - (self.x + self.width < other.x) | - (self.x > other.x + other.width) + (self.x + self.width < other.x) | (self.x > other.x + other.width) ) & ~( - (self.y + self.height < other.y) | - (self.y > other.y + other.height) + (self.y + self.height < other.y) + | (self.y > other.y + other.height) ) class EnemyGlyph(Glyph): """Describe an enemy.""" + __mapper_args__ = {"polymorphic_identity": "enemy"} class ArmyGlyph(EnemyGlyph): """Describe an enemy that's part of the "army". 
""" + __mapper_args__ = {"polymorphic_identity": "army"} def glyph_for_state(self, coord, state): @@ -262,6 +257,7 @@ class ArmyGlyph(EnemyGlyph): class SaucerGlyph(EnemyGlyph): """Describe the enemy saucer flying overhead.""" + __mapper_args__ = {"polymorphic_identity": "saucer"} def glyph_for_state(self, coord, state): @@ -273,21 +269,25 @@ class SaucerGlyph(EnemyGlyph): class MessageGlyph(Glyph): """Describe a glyph for displaying a message.""" + __mapper_args__ = {"polymorphic_identity": "message"} class PlayerGlyph(Glyph): """Describe a glyph representing the player.""" + __mapper_args__ = {"polymorphic_identity": "player"} class MissileGlyph(Glyph): """Describe a glyph representing a missile.""" + __mapper_args__ = {"polymorphic_identity": "missile"} class SplatGlyph(Glyph): """Describe a glyph representing a "splat".""" + __mapper_args__ = {"polymorphic_identity": "splat"} def glyph_for_state(self, coord, state): @@ -302,36 +302,39 @@ def init_glyph(session): """Create the glyphs used during play.""" enemy1 = ArmyGlyph( - "enemy1", """ + "enemy1", + """ #W-#B^#R-#B^#W- #G| | """, """ #W>#B^#R-#B^#W< #G^ ^ - """ + """, ) enemy2 = ArmyGlyph( - "enemy2", """ + "enemy2", + """ #W*** #R<#C~~~#R> """, """ #W@@@ #R<#C---#R> - """ + """, ) enemy3 = ArmyGlyph( - "enemy3", """ + "enemy3", + """ #Y((--)) #M-~-~-~ """, """ #Y[[--]] #M~-~-~- - """ + """, ) saucer = SaucerGlyph( @@ -351,35 +354,49 @@ def init_glyph(session): #M| #M- #Y+++ #M- #M| - """ + """, ) - ship = PlayerGlyph("ship", """ + ship = PlayerGlyph( + "ship", + """ #Y^ #G===== - """) + """, + ) - missile = MissileGlyph("missile", """ + missile = MissileGlyph( + "missile", + """ | - """) + """, + ) start = MessageGlyph( "start_message", "J = move left; L = move right; SPACE = fire\n" - " #GPress any key to start") - lose = MessageGlyph("lose_message", - "#YY O U L O S E ! ! !") - win = MessageGlyph( - "win_message", - "#RL E V E L C L E A R E D ! ! !" + " #GPress any key to start", ) + lose = MessageGlyph("lose_message", "#YY O U L O S E ! ! !") + win = MessageGlyph("win_message", "#RL E V E L C L E A R E D ! ! 
!") paused = MessageGlyph( - "pause_message", - "#WP A U S E D\n#GPress P to continue") + "pause_message", "#WP A U S E D\n#GPress P to continue" + ) session.add_all( - [enemy1, enemy2, enemy3, ship, saucer, - missile, start, lose, win, - paused, splat1]) + [ + enemy1, + enemy2, + enemy3, + ship, + saucer, + missile, + start, + lose, + win, + paused, + splat1, + ] + ) def setup_curses(): @@ -392,7 +409,8 @@ def setup_curses(): WINDOW_HEIGHT + (VERT_PADDING * 2), WINDOW_WIDTH + (HORIZ_PADDING * 2), WINDOW_TOP - VERT_PADDING, - WINDOW_LEFT - HORIZ_PADDING) + WINDOW_LEFT - HORIZ_PADDING, + ) curses.start_color() global _COLOR_PAIRS @@ -416,24 +434,25 @@ def init_positions(session): session.add( GlyphCoordinate( - session, "ship", - WINDOW_WIDTH // 2 - 2, - WINDOW_HEIGHT - 4) + session, "ship", WINDOW_WIDTH // 2 - 2, WINDOW_HEIGHT - 4 + ) ) arrangement = ( - ("enemy3", 50), ("enemy2", 25), - ("enemy1", 10), ("enemy2", 25), - ("enemy1", 10)) + ("enemy3", 50), + ("enemy2", 25), + ("enemy1", 10), + ("enemy2", 25), + ("enemy1", 10), + ) for (ship_vert, (etype, score)) in zip( - xrange(5, 30, ENEMY_VERT_SPACING), arrangement): + xrange(5, 30, ENEMY_VERT_SPACING), arrangement + ): for ship_horiz in xrange(0, 50, 10): session.add( GlyphCoordinate( - session, etype, - ship_horiz, - ship_vert, - score=score) + session, etype, ship_horiz, ship_vert, score=score + ) ) @@ -442,12 +461,9 @@ def draw(session, window, state): database and render. """ - for gcoord in session.query(GlyphCoordinate).\ - options(joinedload("glyph")): + for gcoord in session.query(GlyphCoordinate).options(joinedload("glyph")): gcoord.render(window, state) - window.addstr( - 1, WINDOW_WIDTH - 5, - "Score: %.4d" % state['score']) + window.addstr(1, WINDOW_WIDTH - 5, "Score: %.4d" % state["score"]) window.move(0, 0) window.refresh() @@ -456,11 +472,11 @@ def check_win(session, state): """Return the number of army glyphs remaining - the player wins if this is zero.""" - return session.query( - func.count(GlyphCoordinate.id) - ).join( - GlyphCoordinate.glyph.of_type(ArmyGlyph) - ).scalar() + return ( + session.query(func.count(GlyphCoordinate.id)) + .join(GlyphCoordinate.glyph.of_type(ArmyGlyph)) + .scalar() + ) def check_lose(session, state): @@ -470,12 +486,14 @@ def check_lose(session, state): The player loses if this is non-zero.""" player = state["player"] - return session.query(GlyphCoordinate).join( - GlyphCoordinate.glyph.of_type(ArmyGlyph) - ).filter( - GlyphCoordinate.intersects(player) | - GlyphCoordinate.bottom_bound - ).count() + return ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph.of_type(ArmyGlyph)) + .filter( + GlyphCoordinate.intersects(player) | GlyphCoordinate.bottom_bound + ) + .count() + ) def render_message(session, window, msg, x, y): @@ -490,9 +508,11 @@ def render_message(session, window, msg, x, y): msg = GlyphCoordinate(session, msg, x, y) # clear existing glyphs which intersect - for gly in session.query(GlyphCoordinate).join( - GlyphCoordinate.glyph - ).filter(GlyphCoordinate.intersects(msg)): + for gly in ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph) + .filter(GlyphCoordinate.intersects(msg)) + ): gly.blank(window) # render @@ -551,12 +571,14 @@ def move_army(session, window, state): # get the lower/upper boundaries of the army # along the X axis. 
- min_x, max_x = session.query( - func.min(GlyphCoordinate.x), - func.max(GlyphCoordinate.x + GlyphCoordinate.width), - ).join( - GlyphCoordinate.glyph.of_type(ArmyGlyph) - ).first() + min_x, max_x = ( + session.query( + func.min(GlyphCoordinate.x), + func.max(GlyphCoordinate.x + GlyphCoordinate.width), + ) + .join(GlyphCoordinate.glyph.of_type(ArmyGlyph)) + .first() + ) if min_x is None or max_x is None: # no enemies @@ -603,27 +625,26 @@ def move_player(session, window, state): player.x -= 1 elif ch == FIRE_KEY and state["missile"] is None: state["missile"] = GlyphCoordinate( - session, - "missile", - player.x + 3, - player.y - 1) + session, "missile", player.x + 3, player.y - 1 + ) def move_missile(session, window, state): """Update the status of the current missile, if any.""" - if state["missile"] is None or \ - state["tick"] % 2 != 0: + if state["missile"] is None or state["tick"] % 2 != 0: return missile = state["missile"] # locate enemy glyphs which intersect with the # missile's current position; i.e. a hit - glyph = session.query(GlyphCoordinate).\ - join(GlyphCoordinate.glyph.of_type(EnemyGlyph)).\ - filter(GlyphCoordinate.intersects(missile)).\ - first() + glyph = ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph.of_type(EnemyGlyph)) + .filter(GlyphCoordinate.intersects(missile)) + .first() + ) missile.blank(window) if glyph or missile.top_bound: # missile is done state["missile"] = None if glyph: score(session, window, state, glyph) else: missile.y -= 1 missile.render(window, state) def move_saucer(session, window, state): """Update the status of the saucer.""" saucer_interval = 500 saucer_speed_interval = 4 - if state["saucer"] is None and \ - state["tick"] % saucer_interval != 0: + if state["saucer"] is None and state["tick"] % saucer_interval != 0: return if state["saucer"] is None: state["saucer"] = saucer = GlyphCoordinate( - session, - "saucer", -6, 1, - score=random.randrange(100, 600, 100)) + session, "saucer", -6, 1, score=random.randrange(100, 600, 100) + ) elif state["tick"] % saucer_speed_interval == 0: saucer = state["saucer"] saucer.blank(window) @@ -663,8 +682,9 @@ def update_splat(session, window, state): """Render splat animations.""" - for splat in session.query(GlyphCoordinate).\ - join(GlyphCoordinate.glyph.of_type(SplatGlyph)): + for splat in session.query(GlyphCoordinate).join( + GlyphCoordinate.glyph.of_type(SplatGlyph) + ): age = state["tick"] - splat.tick if age > 10: splat.blank(window) @@ -683,8 +703,13 @@ def score(session, window, state, glyph): state["score"] += glyph.score # render a splat !
GlyphCoordinate( - session, "splat1", glyph.x, glyph.y, - tick=state["tick"], label=str(glyph.score)) + session, + "splat1", + glyph.x, + glyph.y, + tick=state["tick"], + label=str(glyph.score), + ) def update_state(session, window, state): @@ -713,19 +738,23 @@ def start(session, window, state, continue_=False): init_positions(session) - player = session.query(GlyphCoordinate).join( - GlyphCoordinate.glyph.of_type(PlayerGlyph) - ).one() - state.update({ - "field_pos": 0, - "alt": False, - "tick": 0, - "missile": None, - "saucer": None, - "player": player, - "army_direction": 0, - "flip": False - }) + player = ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph.of_type(PlayerGlyph)) + .one() + ) + state.update( + { + "field_pos": 0, + "alt": False, + "tick": 0, + "missile": None, + "saucer": None, + "player": player, + "army_direction": 0, + "flip": False, + } + ) if not continue_: state["score"] = 0 @@ -748,7 +777,8 @@ def main(): while True: update_state(session, window, state) draw(session, window, state) - time.sleep(.01) + time.sleep(0.01) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/versioned_history/__init__.py b/examples/versioned_history/__init__.py index 7478450ac0..7670cd6132 100644 --- a/examples/versioned_history/__init__.py +++ b/examples/versioned_history/__init__.py @@ -60,4 +60,4 @@ can be applied:: .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/versioned_history/history_meta.py b/examples/versioned_history/history_meta.py index bad60a3982..749a6a5cab 100644 --- a/examples/versioned_history/history_meta.py +++ b/examples/versioned_history/history_meta.py @@ -30,7 +30,7 @@ def _history_mapper(local_mapper): getattr(local_mapper.class_, prop.key).impl.active_history = True super_mapper = local_mapper.inherits - super_history_mapper = getattr(cls, '__history_mapper__', None) + super_history_mapper = getattr(cls, "__history_mapper__", None) polymorphic_on = None super_fks = [] @@ -38,18 +38,20 @@ def _history_mapper(local_mapper): def _col_copy(col): orig = col col = col.copy() - orig.info['history_copy'] = col + orig.info["history_copy"] = col col.unique = False col.default = col.server_default = None col.autoincrement = False return col properties = util.OrderedDict() - if not super_mapper or \ - local_mapper.local_table is not super_mapper.local_table: + if ( + not super_mapper + or local_mapper.local_table is not super_mapper.local_table + ): cols = [] version_meta = {"version_meta": True} # add column.info to identify - # columns specific to versioning + # columns specific to versioning for column in local_mapper.local_table.c: if _is_versioning_col(column): @@ -57,12 +59,13 @@ def _history_mapper(local_mapper): col = _col_copy(column) - if super_mapper and \ - col_references_table(column, super_mapper.local_table): + if super_mapper and col_references_table( + column, super_mapper.local_table + ): super_fks.append( ( col.key, - list(super_history_mapper.local_table.primary_key)[0] + list(super_history_mapper.local_table.primary_key)[0], ) ) @@ -73,38 +76,48 @@ def _history_mapper(local_mapper): orig_prop = local_mapper.get_property_by_column(column) # carry over column re-mappings - if len(orig_prop.columns) > 1 or \ - orig_prop.columns[0].key != orig_prop.key: + if ( + len(orig_prop.columns) > 1 + or orig_prop.columns[0].key != orig_prop.key + ): properties[orig_prop.key] = tuple( - col.info['history_copy'] for col in orig_prop.columns) + col.info["history_copy"] for col in 
orig_prop.columns + ) if super_mapper: super_fks.append( - ( - 'version', super_history_mapper.local_table.c.version - ) + ("version", super_history_mapper.local_table.c.version) ) # "version" stores the integer version id. This column is # required. cols.append( Column( - 'version', Integer, primary_key=True, - autoincrement=False, info=version_meta)) + "version", + Integer, + primary_key=True, + autoincrement=False, + info=version_meta, + ) + ) # "changed" column stores the UTC timestamp of when the # history row was created. # This column is optional and can be omitted. - cols.append(Column( - 'changed', DateTime, - default=datetime.datetime.utcnow, - info=version_meta)) + cols.append( + Column( + "changed", + DateTime, + default=datetime.datetime.utcnow, + info=version_meta, + ) + ) if super_fks: cols.append(ForeignKeyConstraint(*zip(*super_fks))) table = Table( - local_mapper.local_table.name + '_history', + local_mapper.local_table.name + "_history", local_mapper.local_table.metadata, *cols, schema=local_mapper.local_table.schema @@ -122,9 +135,8 @@ def _history_mapper(local_mapper): bases = (super_history_mapper.class_,) if table is not None: - properties['changed'] = ( - (table.c.changed, ) + - tuple(super_history_mapper.attrs.changed.columns) + properties["changed"] = (table.c.changed,) + tuple( + super_history_mapper.attrs.changed.columns ) else: @@ -137,16 +149,17 @@ def _history_mapper(local_mapper): inherits=super_history_mapper, polymorphic_on=polymorphic_on, polymorphic_identity=local_mapper.polymorphic_identity, - properties=properties + properties=properties, ) cls.__history_mapper__ = m if not super_history_mapper: local_mapper.local_table.append_column( - Column('version', Integer, default=1, nullable=False) + Column("version", Integer, default=1, nullable=False) ) local_mapper.add_property( - "version", local_mapper.local_table.c.version) + "version", local_mapper.local_table.c.version + ) class Versioned(object): @@ -156,16 +169,17 @@ class Versioned(object): mp = mapper(cls, *arg, **kw) _history_mapper(mp) return mp + return map - __table_args__ = {'sqlite_autoincrement': True} + __table_args__ = {"sqlite_autoincrement": True} """Use sqlite_autoincrement to ensure unique integer values are used for new rows even for rows that have been deleted.""" def versioned_objects(iter): for obj in iter: - if hasattr(obj, '__history_mapper__'): + if hasattr(obj, "__history_mapper__"): yield obj @@ -181,8 +195,7 @@ def create_version(obj, session, deleted=False): obj_changed = False for om, hm in zip( - obj_mapper.iterate_to_root(), - history_mapper.iterate_to_root() + obj_mapper.iterate_to_root(), history_mapper.iterate_to_root() ): if hm.single: continue @@ -228,10 +241,12 @@ def create_version(obj, session, deleted=False): # not changed, but we have relationships.
OK # check those too for prop in obj_mapper.iterate_properties: - if isinstance(prop, RelationshipProperty) and \ - attributes.get_history( - obj, prop.key, - passive=attributes.PASSIVE_NO_INITIALIZE).has_changes(): + if ( + isinstance(prop, RelationshipProperty) + and attributes.get_history( + obj, prop.key, passive=attributes.PASSIVE_NO_INITIALIZE + ).has_changes() + ): for p in prop.local_columns: if p.foreign_keys: obj_changed = True @@ -242,7 +257,7 @@ def create_version(obj, session, deleted=False): if not obj_changed and not deleted: return - attr['version'] = obj.version + attr["version"] = obj.version hist = history_cls() for key, value in attr.items(): setattr(hist, key, value) @@ -251,7 +266,7 @@ def create_version(obj, session, deleted=False): def versioned_session(session): - @event.listens_for(session, 'before_flush') + @event.listens_for(session, "before_flush") def before_flush(session, flush_context, instances): for obj in versioned_objects(session.dirty): create_version(obj, session) diff --git a/examples/versioned_history/test_versioning.py b/examples/versioned_history/test_versioning.py index 37ef739366..3270ad5fda 100644 --- a/examples/versioned_history/test_versioning.py +++ b/examples/versioned_history/test_versioning.py @@ -4,10 +4,22 @@ module functions.""" from unittest import TestCase from sqlalchemy.ext.declarative import declarative_base from .history_meta import Versioned, versioned_session -from sqlalchemy import create_engine, Column, Integer, String, \ - ForeignKey, Boolean, select -from sqlalchemy.orm import clear_mappers, Session, deferred, relationship, \ - column_property +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + ForeignKey, + Boolean, + select, +) +from sqlalchemy.orm import ( + clear_mappers, + Session, + deferred, + relationship, + column_property, +) from sqlalchemy.testing import AssertsCompiledSQL, eq_, assert_raises, ne_ from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.orm import exc as orm_exc @@ -20,11 +32,11 @@ engine = None def setup_module(): global engine - engine = create_engine('sqlite://', echo=True) + engine = create_engine("sqlite://", echo=True) class TestVersioning(TestCase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def setUp(self): self.session = Session(engine) @@ -41,18 +53,18 @@ class TestVersioning(TestCase, AssertsCompiledSQL): def test_plain(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() - sc.name = 'sc1modified' + sc.name = "sc1modified" sess.commit() assert sc.version == 2 @@ -60,56 +72,60 @@ class TestVersioning(TestCase, AssertsCompiledSQL): SomeClassHistory = SomeClass.__history_mapper__.class_ eq_( - sess.query(SomeClassHistory).filter( - SomeClassHistory.version == 1).all(), - [SomeClassHistory(version=1, name='sc1')] + sess.query(SomeClassHistory) + .filter(SomeClassHistory.version == 1) + .all(), + [SomeClassHistory(version=1, name="sc1")], ) - sc.name = 'sc1modified2' + sc.name = "sc1modified2" eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1'), - SomeClassHistory(version=2, name='sc1modified') - ] + 
SomeClassHistory(version=1, name="sc1"), + SomeClassHistory(version=2, name="sc1modified"), + ], ) assert sc.version == 3 sess.commit() - sc.name = 'temp' - sc.name = 'sc1modified2' + sc.name = "temp" + sc.name = "sc1modified2" sess.commit() eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1'), - SomeClassHistory(version=2, name='sc1modified') - ] + SomeClassHistory(version=1, name="sc1"), + SomeClassHistory(version=2, name="sc1modified"), + ], ) sess.delete(sc) sess.commit() eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1'), - SomeClassHistory(version=2, name='sc1modified'), - SomeClassHistory(version=3, name='sc1modified2') - ] + SomeClassHistory(version=1, name="sc1"), + SomeClassHistory(version=2, name="sc1modified"), + SomeClassHistory(version=3, name="sc1modified2"), + ], ) def test_w_mapper_versioning(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -118,27 +134,24 @@ class TestVersioning(TestCase, AssertsCompiledSQL): self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() s2 = Session(sess.bind) sc2 = s2.query(SomeClass).first() - sc2.name = 'sc1modified' + sc2.name = "sc1modified" - sc.name = 'sc1modified_again' + sc.name = "sc1modified_again" sess.commit() eq_(sc.version, 2) - assert_raises( - orm_exc.StaleDataError, - s2.flush - ) + assert_raises(orm_exc.StaleDataError, s2.flush) def test_from_null(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -149,14 +162,14 @@ class TestVersioning(TestCase, AssertsCompiledSQL): sess.add(sc) sess.commit() - sc.name = 'sc1' + sc.name = "sc1" sess.commit() assert sc.version == 2 def test_insert_null(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) boole = Column(Boolean, default=False) @@ -176,9 +189,10 @@ class TestVersioning(TestCase, AssertsCompiledSQL): SomeClassHistory = SomeClass.__history_mapper__.class_ eq_( - sess.query(SomeClassHistory.boole).order_by( - SomeClassHistory.id).all(), - [(True, ), (None, )] + sess.query(SomeClassHistory.boole) + .order_by(SomeClassHistory.id) + .all(), + [(True,), (None,)], ) eq_(sc.version, 3) @@ -187,7 +201,7 @@ class TestVersioning(TestCase, AssertsCompiledSQL): """test versioning of unloaded, deferred columns.""" class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -195,15 +209,15 @@ class TestVersioning(TestCase, AssertsCompiledSQL): self.create_tables() sess = self.session - sc = SomeClass(name='sc1', data='somedata') + sc = SomeClass(name="sc1", data="somedata") sess.add(sc) sess.commit() sess.close() sc = sess.query(SomeClass).first() - assert 'data' not in sc.__dict__ + assert "data" not in sc.__dict__ - sc.name = 'sc1modified' + sc.name = "sc1modified" sess.commit() assert sc.version == 2 @@ -211,137 
+225,149 @@ class TestVersioning(TestCase, AssertsCompiledSQL): SomeClassHistory = SomeClass.__history_mapper__.class_ eq_( - sess.query(SomeClassHistory).filter( - SomeClassHistory.version == 1).all(), - [SomeClassHistory(version=1, name='sc1', data='somedata')] + sess.query(SomeClassHistory) + .filter(SomeClassHistory.version == 1) + .all(), + [SomeClassHistory(version=1, name="sc1", data="somedata")], ) def test_joined_inheritance(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base'} + "polymorphic_on": type, + "polymorphic_identity": "base", + } class SubClassSeparatePk(BaseClass): - __tablename__ = 'subtable1' + __tablename__ = "subtable1" id = column_property( - Column(Integer, primary_key=True), - BaseClass.id + Column(Integer, primary_key=True), BaseClass.id ) - base_id = Column(Integer, ForeignKey('basetable.id')) + base_id = Column(Integer, ForeignKey("basetable.id")) subdata1 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'sep'} + __mapper_args__ = {"polymorphic_identity": "sep"} class SubClassSamePk(BaseClass): - __tablename__ = 'subtable2' + __tablename__ = "subtable2" - id = Column( - Integer, ForeignKey('basetable.id'), primary_key=True) + id = Column(Integer, ForeignKey("basetable.id"), primary_key=True) subdata2 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'same'} + __mapper_args__ = {"polymorphic_identity": "same"} self.create_tables() sess = self.session - sep1 = SubClassSeparatePk(name='sep1', subdata1='sep1subdata') - base1 = BaseClass(name='base1') - same1 = SubClassSamePk(name='same1', subdata2='same1subdata') + sep1 = SubClassSeparatePk(name="sep1", subdata1="sep1subdata") + base1 = BaseClass(name="base1") + same1 = SubClassSamePk(name="same1", subdata2="same1subdata") sess.add_all([sep1, base1, same1]) sess.commit() - base1.name = 'base1mod' - same1.subdata2 = 'same1subdatamod' - sep1.name = 'sep1mod' + base1.name = "base1mod" + same1.subdata2 = "same1subdatamod" + sep1.name = "sep1mod" sess.commit() BaseClassHistory = BaseClass.__history_mapper__.class_ - SubClassSeparatePkHistory = \ + SubClassSeparatePkHistory = ( SubClassSeparatePk.__history_mapper__.class_ + ) SubClassSamePkHistory = SubClassSamePk.__history_mapper__.class_ eq_( sess.query(BaseClassHistory).order_by(BaseClassHistory.id).all(), [ SubClassSeparatePkHistory( - id=1, name='sep1', type='sep', version=1), - BaseClassHistory(id=2, name='base1', type='base', version=1), + id=1, name="sep1", type="sep", version=1 + ), + BaseClassHistory(id=2, name="base1", type="base", version=1), SubClassSamePkHistory( - id=3, name='same1', type='same', version=1) - ] + id=3, name="same1", type="same", version=1 + ), + ], ) - same1.subdata2 = 'same1subdatamod2' + same1.subdata2 = "same1subdatamod2" eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), [ SubClassSeparatePkHistory( - id=1, name='sep1', type='sep', version=1), - BaseClassHistory(id=2, name='base1', type='base', version=1), + id=1, name="sep1", type="sep", version=1 + ), + BaseClassHistory(id=2, name="base1", type="base", version=1), SubClassSamePkHistory( - id=3, name='same1', type='same', version=1), + id=3, name="same1", 
type="same", version=1 + ), SubClassSamePkHistory( - id=3, name='same1', type='same', version=2) - ] + id=3, name="same1", type="same", version=2 + ), + ], ) - base1.name = 'base1mod2' + base1.name = "base1mod2" eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), [ SubClassSeparatePkHistory( - id=1, name='sep1', type='sep', version=1), - BaseClassHistory(id=2, name='base1', type='base', version=1), + id=1, name="sep1", type="sep", version=1 + ), + BaseClassHistory(id=2, name="base1", type="base", version=1), BaseClassHistory( - id=2, name='base1mod', type='base', version=2), + id=2, name="base1mod", type="base", version=2 + ), SubClassSamePkHistory( - id=3, name='same1', type='same', version=1), + id=3, name="same1", type="same", version=1 + ), SubClassSamePkHistory( - id=3, name='same1', type='same', version=2) - ] + id=3, name="same1", type="same", version=2 + ), + ], ) def test_joined_inheritance_multilevel(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base'} + "polymorphic_on": type, + "polymorphic_identity": "base", + } class SubClass(BaseClass): - __tablename__ = 'subtable' + __tablename__ = "subtable" id = column_property( - Column(Integer, primary_key=True), - BaseClass.id + Column(Integer, primary_key=True), BaseClass.id ) - base_id = Column(Integer, ForeignKey('basetable.id')) + base_id = Column(Integer, ForeignKey("basetable.id")) subdata1 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'sub'} + __mapper_args__ = {"polymorphic_identity": "sub"} class SubSubClass(SubClass): - __tablename__ = 'subsubtable' + __tablename__ = "subsubtable" - id = Column(Integer, ForeignKey('subtable.id'), primary_key=True) + id = Column(Integer, ForeignKey("subtable.id"), primary_key=True) subdata2 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'subsub'} + __mapper_args__ = {"polymorphic_identity": "subsub"} self.create_tables() @@ -350,27 +376,18 @@ class TestVersioning(TestCase, AssertsCompiledSQL): q = sess.query(SubSubHistory) self.assert_compile( q, - - "SELECT " - "subsubtable_history.id AS subsubtable_history_id, " "subtable_history.id AS subtable_history_id, " "basetable_history.id AS basetable_history_id, " - "subsubtable_history.changed AS subsubtable_history_changed, " "subtable_history.changed AS subtable_history_changed, " "basetable_history.changed AS basetable_history_changed, " - "basetable_history.name AS basetable_history_name, " - "basetable_history.type AS basetable_history_type, " - "subsubtable_history.version AS subsubtable_history_version, " "subtable_history.version AS subtable_history_version, " "basetable_history.version AS basetable_history_version, " - - "subtable_history.base_id AS subtable_history_base_id, " "subtable_history.subdata1 AS subtable_history_subdata1, " "subsubtable_history.subdata2 AS subsubtable_history_subdata2 " @@ -380,64 +397,73 @@ class TestVersioning(TestCase, AssertsCompiledSQL): "AND basetable_history.version = subtable_history.version " "JOIN subsubtable_history ON subtable_history.id = " "subsubtable_history.id AND subtable_history.version = " - "subsubtable_history.version" + "subsubtable_history.version", ) - ssc = 
SubSubClass(name='ss1', subdata1='sd1', subdata2='sd2') + ssc = SubSubClass(name="ss1", subdata1="sd1", subdata2="sd2") sess.add(ssc) sess.commit() + eq_(sess.query(SubSubHistory).all(), []) + ssc.subdata1 = "sd11" + ssc.subdata2 = "sd22" + sess.commit() eq_( sess.query(SubSubHistory).all(), - [] + [ + SubSubHistory( + name="ss1", + subdata1="sd1", + subdata2="sd2", + type="subsub", + version=1, + ) + ], ) - ssc.subdata1 = 'sd11' - ssc.subdata2 = 'sd22' - sess.commit() eq_( - sess.query(SubSubHistory).all(), - [SubSubHistory(name='ss1', subdata1='sd1', - subdata2='sd2', type='subsub', version=1)] + ssc, + SubSubClass( + name="ss1", subdata1="sd11", subdata2="sd22", version=2 + ), ) - eq_(ssc, SubSubClass( - name='ss1', subdata1='sd11', - subdata2='sd22', version=2)) def test_joined_inheritance_changed(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base' + "polymorphic_on": type, + "polymorphic_identity": "base", } class SubClass(BaseClass): - __tablename__ = 'subtable' + __tablename__ = "subtable" - id = Column(Integer, ForeignKey('basetable.id'), primary_key=True) + id = Column(Integer, ForeignKey("basetable.id"), primary_key=True) - __mapper_args__ = {'polymorphic_identity': 'sep'} + __mapper_args__ = {"polymorphic_identity": "sep"} self.create_tables() BaseClassHistory = BaseClass.__history_mapper__.class_ SubClassHistory = SubClass.__history_mapper__.class_ sess = self.session - s1 = SubClass(name='s1') + s1 = SubClass(name="s1") sess.add(s1) sess.commit() - s1.name = 's2' + s1.name = "s2" sess.commit() actual_changed_base = sess.scalar( - select([BaseClass.__history_mapper__.local_table.c.changed])) + select([BaseClass.__history_mapper__.local_table.c.changed]) + ) actual_changed_sub = sess.scalar( - select([SubClass.__history_mapper__.local_table.c.changed])) + select([SubClass.__history_mapper__.local_table.c.changed]) + ) h1 = sess.query(BaseClassHistory).first() eq_(h1.changed, actual_changed_base) eq_(h1.changed, actual_changed_sub) @@ -448,53 +474,57 @@ class TestVersioning(TestCase, AssertsCompiledSQL): def test_single_inheritance(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(50)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base'} + "polymorphic_on": type, + "polymorphic_identity": "base", + } class SubClass(BaseClass): subname = Column(String(50), unique=True) - __mapper_args__ = {'polymorphic_identity': 'sub'} + __mapper_args__ = {"polymorphic_identity": "sub"} self.create_tables() sess = self.session - b1 = BaseClass(name='b1') - sc = SubClass(name='s1', subname='sc1') + b1 = BaseClass(name="b1") + sc = SubClass(name="s1", subname="sc1") sess.add_all([b1, sc]) sess.commit() - b1.name = 'b1modified' + b1.name = "b1modified" BaseClassHistory = BaseClass.__history_mapper__.class_ SubClassHistory = SubClass.__history_mapper__.class_ eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), - [BaseClassHistory(id=1, name='b1', type='base', version=1)] + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), + [BaseClassHistory(id=1, name="b1", type="base", version=1)], 
) - sc.name = 's1modified' - b1.name = 'b1modified2' + sc.name = "s1modified" + b1.name = "b1modified2" eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), [ - BaseClassHistory(id=1, name='b1', type='base', version=1), + BaseClassHistory(id=1, name="b1", type="base", version=1), BaseClassHistory( - id=1, name='b1modified', type='base', version=2), - SubClassHistory(id=2, name='s1', type='sub', version=1) - ] + id=1, name="b1modified", type="base", version=2 + ), + SubClassHistory(id=2, name="s1", type="sub", version=1), + ], ) # test the unique constraint on the subclass @@ -504,7 +534,7 @@ class TestVersioning(TestCase, AssertsCompiledSQL): def test_unique(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50), unique=True) @@ -512,40 +542,39 @@ class TestVersioning(TestCase, AssertsCompiledSQL): self.create_tables() sess = self.session - sc = SomeClass(name='sc1', data='sc1') + sc = SomeClass(name="sc1", data="sc1") sess.add(sc) sess.commit() - sc.data = 'sc1modified' + sc.data = "sc1modified" sess.commit() assert sc.version == 2 - sc.data = 'sc1modified2' + sc.data = "sc1modified2" sess.commit() assert sc.version == 3 def test_relationship(self): - class SomeRelated(self.Base, ComparableEntity): - __tablename__ = 'somerelated' + __tablename__ = "somerelated" id = Column(Integer, primary_key=True) class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) - related_id = Column(Integer, ForeignKey('somerelated.id')) - related = relationship("SomeRelated", backref='classes') + related_id = Column(Integer, ForeignKey("somerelated.id")) + related = relationship("SomeRelated", backref="classes") SomeClassHistory = SomeClass.__history_mapper__.class_ self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() @@ -558,36 +587,37 @@ class TestVersioning(TestCase, AssertsCompiledSQL): assert sc.version == 2 eq_( - sess.query(SomeClassHistory).filter( - SomeClassHistory.version == 1).all(), - [SomeClassHistory(version=1, name='sc1', related_id=None)] + sess.query(SomeClassHistory) + .filter(SomeClassHistory.version == 1) + .all(), + [SomeClassHistory(version=1, name="sc1", related_id=None)], ) sc.related = None eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1', related_id=None), - SomeClassHistory(version=2, name='sc1', related_id=sr1.id) - ] + SomeClassHistory(version=1, name="sc1", related_id=None), + SomeClassHistory(version=2, name="sc1", related_id=sr1.id), + ], ) assert sc.version == 3 def test_backref_relationship(self): - class SomeRelated(self.Base, ComparableEntity): - __tablename__ = 'somerelated' + __tablename__ = "somerelated" id = Column(Integer, primary_key=True) name = Column(String(50)) - related_id = Column(Integer, ForeignKey('sometable.id')) - related = relationship("SomeClass", backref='related') + related_id = Column(Integer, ForeignKey("sometable.id")) + related = relationship("SomeClass", backref="related") class SomeClass(Versioned, self.Base, 
ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) @@ -599,13 +629,13 @@ class TestVersioning(TestCase, AssertsCompiledSQL): assert sc.version == 1 - sr = SomeRelated(name='sr', related=sc) + sr = SomeRelated(name="sr", related=sc) sess.add(sr) sess.commit() assert sc.version == 1 - sr.name = 'sr2' + sr.name = "sr2" sess.commit() assert sc.version == 1 @@ -616,9 +646,8 @@ class TestVersioning(TestCase, AssertsCompiledSQL): assert sc.version == 1 def test_create_double_flush(self): - class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(30)) @@ -629,56 +658,56 @@ class TestVersioning(TestCase, AssertsCompiledSQL): sc = SomeClass() self.session.add(sc) self.session.flush() - sc.name = 'Foo' + sc.name = "Foo" self.session.flush() assert sc.version == 2 def test_mutate_plain_column(self): class Document(self.Base, Versioned): - __tablename__ = 'document' + __tablename__ = "document" id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String, nullable=True) - description_ = Column('description', String, nullable=True) + description_ = Column("description", String, nullable=True) self.create_tables() document = Document() self.session.add(document) - document.name = 'Foo' + document.name = "Foo" self.session.commit() - document.name = 'Bar' + document.name = "Bar" self.session.commit() DocumentHistory = Document.__history_mapper__.class_ v2 = self.session.query(Document).one() v1 = self.session.query(DocumentHistory).one() self.assertEqual(v1.id, v2.id) - self.assertEqual(v2.name, 'Bar') - self.assertEqual(v1.name, 'Foo') + self.assertEqual(v2.name, "Bar") + self.assertEqual(v1.name, "Foo") def test_mutate_named_column(self): class Document(self.Base, Versioned): - __tablename__ = 'document' + __tablename__ = "document" id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String, nullable=True) - description_ = Column('description', String, nullable=True) + description_ = Column("description", String, nullable=True) self.create_tables() document = Document() self.session.add(document) - document.description_ = 'Foo' + document.description_ = "Foo" self.session.commit() - document.description_ = 'Bar' + document.description_ = "Bar" self.session.commit() DocumentHistory = Document.__history_mapper__.class_ v2 = self.session.query(Document).one() v1 = self.session.query(DocumentHistory).one() self.assertEqual(v1.id, v2.id) - self.assertEqual(v2.description_, 'Bar') - self.assertEqual(v1.description_, 'Foo') + self.assertEqual(v2.description_, "Bar") + self.assertEqual(v1.description_, "Foo") def test_unique_identifiers_across_deletes(self): """Ensure unique integer values are used for the primary table. 
@@ -690,21 +719,21 @@ class TestVersioning(TestCase, AssertsCompiledSQL): """ class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() sess.delete(sc) sess.commit() - sc2 = SomeClass(name='sc2') + sc2 = SomeClass(name="sc2") sess.add(sc2) sess.commit() @@ -721,5 +750,5 @@ class TestVersioning(TestCase, AssertsCompiledSQL): ne_(sc2.id, scdeleted.id) # If previous assertion fails, this will also fail: - sc2.name = 'sc2 modified' + sc2.name = "sc2 modified" sess.commit() diff --git a/examples/versioned_rows/__init__.py b/examples/versioned_rows/__init__.py index 637e1aca69..e5016a7400 100644 --- a/examples/versioned_rows/__init__.py +++ b/examples/versioned_rows/__init__.py @@ -9,4 +9,4 @@ history row to a separate history table. .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/versioned_rows/versioned_map.py b/examples/versioned_rows/versioned_map.py index 6a5c86a3af..46bdbb7839 100644 --- a/examples/versioned_rows/versioned_map.py +++ b/examples/versioned_rows/versioned_map.py @@ -28,11 +28,17 @@ those additional values. """ -from sqlalchemy import Column, String, Integer, ForeignKey, \ - create_engine +from sqlalchemy import Column, String, Integer, ForeignKey, create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import attributes, relationship, backref, \ - sessionmaker, make_transient, validates, Session +from sqlalchemy.orm import ( + attributes, + relationship, + backref, + sessionmaker, + make_transient, + validates, + Session, +) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy import event @@ -45,8 +51,9 @@ def before_flush(session, flush_context, instances): """ for instance in session.dirty: - if hasattr(instance, 'new_version') and \ - session.is_modified(instance, passive=True): + if hasattr(instance, "new_version") and session.is_modified( + instance, passive=True + ): # make it transient instance.new_version(session) @@ -54,6 +61,7 @@ def before_flush(session, flush_context, instances): # re-add session.add(instance) + Base = declarative_base() @@ -67,7 +75,8 @@ class ConfigData(Base): string name mapped to a string/int value. """ - __tablename__ = 'config' + + __tablename__ = "config" id = Column(Integer, primary_key=True) """Primary key column of this ConfigData.""" @@ -76,7 +85,7 @@ class ConfigData(Base): "ConfigValueAssociation", collection_class=attribute_mapped_collection("name"), backref=backref("config_data"), - lazy="subquery" + lazy="subquery", ) """Dictionary-backed collection of ConfigValueAssociation objects, keyed to the name of the associated ConfigValue. @@ -97,7 +106,7 @@ class ConfigData(Base): def __init__(self, data): self.data = data - @validates('elements') + @validates("elements") def _associate_with_element(self, key, element): """Associate incoming ConfigValues with this ConfigData, if not already associated. @@ -117,11 +126,11 @@ class ConfigData(Base): # history of the 'elements' collection. 
# this is a tuple of groups: (added, unchanged, deleted) - hist = attributes.get_history(self, 'elements') + hist = attributes.get_history(self, "elements") # rewrite the 'elements' collection # from scratch, removing all history - attributes.set_committed_value(self, 'elements', {}) + attributes.set_committed_value(self, "elements", {}) # new elements in the "added" group # are moved to our new collection. @@ -133,7 +142,8 @@ class ConfigData(Base): # the old ones stay associated with the old ConfigData for elem in hist.unchanged: self.elements[elem.name] = ConfigValueAssociation( - elem.config_value) + elem.config_value + ) # we also need to expire changes on each ConfigValueAssociation # that is to remain associated with the old ConfigData. @@ -144,12 +154,12 @@ class ConfigData(Base): class ConfigValueAssociation(Base): """Relate ConfigData objects to associated ConfigValue objects.""" - __tablename__ = 'config_value_association' + __tablename__ = "config_value_association" - config_id = Column(ForeignKey('config.id'), primary_key=True) + config_id = Column(ForeignKey("config.id"), primary_key=True) """Reference the primary key of the ConfigData object.""" - config_value_id = Column(ForeignKey('config_value.id'), primary_key=True) + config_value_id = Column(ForeignKey("config_value.id"), primary_key=True) """Reference the primary key of the ConfigValue object.""" config_value = relationship("ConfigValue", lazy="joined", innerjoin=True) @@ -182,8 +192,7 @@ class ConfigValueAssociation(Base): """ if value != self.config_value.value: - self.config_data.elements[self.name] = \ - ConfigValueAssociation( + self.config_data.elements[self.name] = ConfigValueAssociation( ConfigValue(self.config_value.name, value) ) @@ -194,13 +203,14 @@ class ConfigValue(Base): ConfigValue is immutable. """ - __tablename__ = 'config_value' + + __tablename__ = "config_value" id = Column(Integer, primary_key=True) name = Column(String(50), nullable=False) originating_config_id = Column( - Integer, ForeignKey('config.id'), - nullable=False) + Integer, ForeignKey("config.id"), nullable=False + ) int_value = Column(Integer) string_value = Column(String(255)) @@ -221,7 +231,7 @@ class ConfigValue(Base): @property def value(self): - for k in ('int_value', 'string_value'): + for k in ("int_value", "string_value"): v = getattr(self, k) if v is not None: return v @@ -237,25 +247,23 @@ class ConfigValue(Base): self.string_value = str(value) self.int_value = None -if __name__ == '__main__': - engine = create_engine('sqlite://', echo=True) + +if __name__ == "__main__": + engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) Session = sessionmaker(engine) sess = Session() - config = ConfigData({ - 'user_name': 'twitter', - 'hash_id': '4fedffca37eaf', - 'x': 27, - 'y': 450 - }) + config = ConfigData( + {"user_name": "twitter", "hash_id": "4fedffca37eaf", "x": 27, "y": 450} + ) sess.add(config) sess.commit() version_one = config.id - config.data['user_name'] = 'yahoo' + config.data["user_name"] = "yahoo" sess.commit() version_two = config.id @@ -265,27 +273,29 @@ if __name__ == '__main__': # two versions have been created. 
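
The before_flush() hook above is the heart of this recipe: a dirty ConfigData is made transient, handed a fresh identity, and re-added, so the flush emits an INSERT rather than an UPDATE, and every committed change yields a new row with a new id (which is why version_one and version_two differ). A minimal, self-contained sketch of that make_transient technique, reduced to a single table; the VersionedThing class and the listener wiring here are illustrative stand-ins, not part of the example files:

from sqlalchemy import Column, Integer, String, create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, make_transient

Base = declarative_base()


class VersionedThing(Base):
    # hypothetical single-table stand-in for ConfigData
    __tablename__ = "versioned_thing"
    id = Column(Integer, primary_key=True)
    data = Column(String)


@event.listens_for(Session, "before_flush")
def _insert_instead_of_update(session, flush_context, instances):
    for obj in list(session.dirty):
        if isinstance(obj, VersionedThing) and session.is_modified(obj):
            make_transient(obj)  # erase the persistent identity
            obj.id = None  # let the INSERT generate a new primary key
            session.add(obj)  # re-add as a pending INSERT


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)

thing = VersionedThing(data="v1")
session.add(thing)
session.commit()

thing.data = "v2"  # would normally flush as an UPDATE...
session.commit()  # ...but is intercepted and becomes a second INSERT

rows = session.query(VersionedThing.data).order_by(VersionedThing.id).all()
assert rows == [("v1",), ("v2",)]

The assertions that follow verify the same two-version behavior on the ConfigData example itself.
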
assert config.data == { - 'user_name': 'yahoo', - 'hash_id': '4fedffca37eaf', - 'x': 27, - 'y': 450 + "user_name": "yahoo", + "hash_id": "4fedffca37eaf", + "x": 27, + "y": 450, } old_config = sess.query(ConfigData).get(version_one) assert old_config.data == { - 'user_name': 'twitter', - 'hash_id': '4fedffca37eaf', - 'x': 27, - 'y': 450 + "user_name": "twitter", + "hash_id": "4fedffca37eaf", + "x": 27, + "y": 450, } # the history of any key can be acquired using # the originating_config_id attribute - history = sess.query(ConfigValue).\ - filter(ConfigValue.name == 'user_name').\ - order_by(ConfigValue.originating_config_id).\ - all() + history = ( + sess.query(ConfigValue) + .filter(ConfigValue.name == "user_name") + .order_by(ConfigValue.originating_config_id) + .all() + ) assert [(h.value, h.originating_config_id) for h in history] == ( - [('twitter', version_one), ('yahoo', version_two)] + [("twitter", version_one), ("yahoo", version_two)] ) diff --git a/examples/versioned_rows/versioned_rows.py b/examples/versioned_rows/versioned_rows.py index ca896190de..03e1c35101 100644 --- a/examples/versioned_rows/versioned_rows.py +++ b/examples/versioned_rows/versioned_rows.py @@ -3,8 +3,13 @@ an UPDATE statement on a single row into an INSERT statement, so that a new row is inserted with the new data, keeping the old row intact. """ -from sqlalchemy.orm import sessionmaker, relationship, make_transient, \ - backref, Session +from sqlalchemy.orm import ( + sessionmaker, + relationship, + make_transient, + backref, + Session, +) from sqlalchemy import Column, ForeignKey, create_engine, Integer, String from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import attributes @@ -38,9 +43,10 @@ def before_flush(session, flush_context, instances): # re-add session.add(instance) + Base = declarative_base() -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Session = sessionmaker(engine) @@ -48,43 +54,44 @@ Session = sessionmaker(engine) class Example(Versioned, Base): - __tablename__ = 'example' + __tablename__ = "example" id = Column(Integer, primary_key=True) data = Column(String) + Base.metadata.create_all(engine) session = Session() -e1 = Example(data='e1') +e1 = Example(data="e1") session.add(e1) session.commit() -e1.data = 'e2' +e1.data = "e2" session.commit() assert session.query(Example.id, Example.data).order_by(Example.id).all() == ( - [(1, 'e1'), (2, 'e2')] + [(1, "e1"), (2, "e2")] ) # example 2, versioning with a parent class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) - child_id = Column(Integer, ForeignKey('child.id')) - child = relationship("Child", backref=backref('parent', uselist=False)) + child_id = Column(Integer, ForeignKey("child.id")) + child = relationship("Child", backref=backref("parent", uselist=False)) class Child(Versioned, Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) data = Column(String) def new_version(self, session): # expire parent's reference to us - session.expire(self.parent, ['child']) + session.expire(self.parent, ["child"]) # create new version Versioned.new_version(self, session) @@ -92,18 +99,19 @@ class Child(Versioned, Base): # re-add ourselves to the parent self.parent.child = self + Base.metadata.create_all(engine) session = Session() -p1 = Parent(child=Child(data='c1')) +p1 = Parent(child=Child(data="c1")) session.add(p1) session.commit() -p1.child.data = 'c2' +p1.child.data = 
"c2" session.commit() assert p1.child_id == 2 assert session.query(Child.id, Child.data).order_by(Child.id).all() == ( - [(1, 'c1'), (2, 'c2')] + [(1, "c1"), (2, "c2")] ) diff --git a/examples/versioned_rows/versioned_rows_w_versionid.py b/examples/versioned_rows/versioned_rows_w_versionid.py index 8445401c5f..5fd6f9fc41 100644 --- a/examples/versioned_rows/versioned_rows_w_versionid.py +++ b/examples/versioned_rows/versioned_rows_w_versionid.py @@ -6,10 +6,24 @@ This example adds a numerical version_id to the Versioned class as well as the ability to see which row is the most "current" vesion. """ -from sqlalchemy.orm import sessionmaker, relationship, make_transient, \ - backref, Session, column_property -from sqlalchemy import Column, ForeignKeyConstraint, create_engine, \ - Integer, String, Boolean, select, func +from sqlalchemy.orm import ( + sessionmaker, + relationship, + make_transient, + backref, + Session, + column_property, +) +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + create_engine, + Integer, + String, + Boolean, + select, + func, +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import attributes from sqlalchemy import event @@ -38,7 +52,8 @@ class Versioned(object): # optional - set previous version to have is_current_version=False old_id = self.id session.query(self.__class__).filter_by(id=old_id).update( - values=dict(is_current_version=False), synchronize_session=False) + values=dict(is_current_version=False), synchronize_session=False + ) # make us transient (removes persistent # identity). @@ -65,9 +80,10 @@ def before_flush(session, flush_context, instances): # re-add session.add(instance) + Base = declarative_base() -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Session = sessionmaker(engine) @@ -75,17 +91,18 @@ Session = sessionmaker(engine) class Example(Versioned, Base): - __tablename__ = 'example' + __tablename__ = "example" data = Column(String) + Base.metadata.create_all(engine) session = Session() -e1 = Example(id=1, data='e1') +e1 = Example(id=1, data="e1") session.add(e1) session.commit() -e1.data = 'e2' +e1.data = "e2" session.commit() assert session.query( @@ -93,36 +110,36 @@ assert session.query( Example.version_id, Example.is_current_version, Example.calc_is_current_version, - Example.data).order_by(Example.id, Example.version_id).all() == ( - [(1, 1, False, False, 'e1'), (1, 2, True, True, 'e2')] + Example.data, +).order_by(Example.id, Example.version_id).all() == ( + [(1, 1, False, False, "e1"), (1, 2, True, True, "e2")] ) # example 2, versioning with a parent class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) child_id = Column(Integer) child_version_id = Column(Integer) - child = relationship("Child", backref=backref('parent', uselist=False)) + child = relationship("Child", backref=backref("parent", uselist=False)) __table_args__ = ( ForeignKeyConstraint( - ['child_id', 'child_version_id'], - ['child.id', 'child.version_id'], + ["child_id", "child_version_id"], ["child.id", "child.version_id"] ), ) class Child(Versioned, Base): - __tablename__ = 'child' + __tablename__ = "child" data = Column(String) def new_version(self, session): # expire parent's reference to us - session.expire(self.parent, ['child']) + session.expire(self.parent, ["child"]) # create new version Versioned.new_version(self, session) @@ -131,15 +148,16 @@ class Child(Versioned, Base): # parent foreign key to be updated also 
self.parent.child = self + Base.metadata.create_all(engine) session = Session() -p1 = Parent(child=Child(id=1, data='c1')) +p1 = Parent(child=Child(id=1, data="c1")) session.add(p1) session.commit() -p1.child.data = 'c2' +p1.child.data = "c2" session.commit() assert p1.child_id == 1 @@ -150,6 +168,7 @@ assert session.query( Child.version_id, Child.is_current_version, Child.calc_is_current_version, - Child.data).order_by(Child.id, Child.version_id).all() == ( - [(1, 1, False, False, 'c1'), (1, 2, True, True, 'c2')] + Child.data, +).order_by(Child.id, Child.version_id).all() == ( + [(1, 1, False, False, "c1"), (1, 2, True, True, "c2")] ) diff --git a/examples/versioned_rows/versioned_update_old_row.py b/examples/versioned_rows/versioned_update_old_row.py index 0159d25671..17c82fdc37 100644 --- a/examples/versioned_rows/versioned_update_old_row.py +++ b/examples/versioned_rows/versioned_update_old_row.py @@ -6,12 +6,23 @@ to only the most recent version. """ from sqlalchemy import ( - create_engine, Integer, String, event, Column, DateTime, - inspect, literal + create_engine, + Integer, + String, + event, + Column, + DateTime, + inspect, + literal, ) from sqlalchemy.orm import ( - make_transient, Session, relationship, attributes, backref, - make_transient_to_detached, Query + make_transient, + Session, + relationship, + attributes, + backref, + make_transient_to_detached, + Query, ) from sqlalchemy.ext.declarative import declarative_base import datetime @@ -50,7 +61,8 @@ class VersionedStartEnd(object): # make the "old" version of us, which we will turn into an # UPDATE old_copy_of_us = self.__class__( - id=self.id, start=self.start, end=self.end) + id=self.id, start=self.start, end=self.end + ) # turn old_copy_of_us into an UPDATE make_transient_to_detached(old_copy_of_us) @@ -95,11 +107,11 @@ def before_compile(query): """ensure all queries for VersionedStartEnd include criteria """ for ent in query.column_descriptions: - entity = ent['entity'] + entity = ent["entity"] if entity is None: continue - insp = inspect(ent['entity']) - mapper = getattr(insp, 'mapper', None) + insp = inspect(ent["entity"]) + mapper = getattr(insp, "mapper", None) if mapper and issubclass(mapper.class_, VersionedStartEnd): query = query.enable_assertions(False).filter( # using a literal "now" because SQLite's "between" @@ -107,14 +119,14 @@ def before_compile(query): # ``func.now()`` and we'd be using PostgreSQL literal( current_time() + datetime.timedelta(seconds=1) - ).between(ent['entity'].start, ent['entity'].end) + ).between(ent["entity"].start, ent["entity"].end) ) return query class Parent(VersionedStartEnd, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) start = Column(DateTime, primary_key=True) end = Column(DateTime, primary_key=True) @@ -124,10 +136,7 @@ class Parent(VersionedStartEnd, Base): child = relationship( "Child", - primaryjoin=( - "Child.id == foreign(Parent.child_n)" - ), - + primaryjoin=("Child.id == foreign(Parent.child_n)"), # note the primaryjoin can also be: # # "and_(Child.id == foreign(Parent.child_n), " @@ -138,14 +147,13 @@ class Parent(VersionedStartEnd, Base): # as well, it just means the criteria will be present twice for most # parent->child load operations # - uselist=False, - backref=backref('parent', uselist=False) + backref=backref("parent", uselist=False), ) class Child(VersionedStartEnd, Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) start = Column(DateTime, 
primary_key=True) @@ -155,7 +163,7 @@ class Child(VersionedStartEnd, Base): def new_version(self, session): # expire parent's reference to us - session.expire(self.parent, ['child']) + session.expire(self.parent, ["child"]) # create new version VersionedStartEnd.new_version(self, session) @@ -163,6 +171,7 @@ class Child(VersionedStartEnd, Base): # re-add ourselves to the parent self.parent.child = self + times = [] @@ -185,27 +194,37 @@ def time_passes(s): assert times[-1] > times[-2] return times[-1] -e = create_engine("sqlite://", echo='debug') + +e = create_engine("sqlite://", echo="debug") Base.metadata.create_all(e) s = Session(e) now = time_passes(s) -c1 = Child(id=1, data='child 1') -p1 = Parent(id=1, data='c1', child=c1) +c1 = Child(id=1, data="child 1") +p1 = Parent(id=1, data="c1", child=c1) s.add(p1) s.commit() # assert raw DB data assert s.query(Parent.__table__).all() == [ - (1, times[0] - datetime.timedelta(days=3), - times[0] + datetime.timedelta(days=3), 'c1', 1) + ( + 1, + times[0] - datetime.timedelta(days=3), + times[0] + datetime.timedelta(days=3), + "c1", + 1, + ) ] assert s.query(Child.__table__).all() == [ - (1, times[0] - datetime.timedelta(days=3), - times[0] + datetime.timedelta(days=3), 'child 1') + ( + 1, + times[0] - datetime.timedelta(days=3), + times[0] + datetime.timedelta(days=3), + "child 1", + ) ] now = time_passes(s) @@ -214,7 +233,7 @@ p1_check = s.query(Parent).first() assert p1_check is p1 assert p1_check.child is c1 -p1.child.data = 'elvis presley' +p1.child.data = "elvis presley" s.commit() @@ -226,40 +245,51 @@ c2_check = p2_check.child assert p2_check.child is c1 # new data -assert c1.data == 'elvis presley' +assert c1.data == "elvis presley" # new end time assert c1.end == now + datetime.timedelta(days=2) # assert raw DB data assert s.query(Parent.__table__).all() == [ - (1, times[0] - datetime.timedelta(days=3), - times[0] + datetime.timedelta(days=3), 'c1', 1) + ( + 1, + times[0] - datetime.timedelta(days=3), + times[0] + datetime.timedelta(days=3), + "c1", + 1, + ) ] assert s.query(Child.__table__).order_by(Child.end).all() == [ - (1, times[0] - datetime.timedelta(days=3), times[1], 'child 1'), - (1, times[1], times[1] + datetime.timedelta(days=2), 'elvis presley') + (1, times[0] - datetime.timedelta(days=3), times[1], "child 1"), + (1, times[1], times[1] + datetime.timedelta(days=2), "elvis presley"), ] now = time_passes(s) -p1.data = 'c2 elvis presley' +p1.data = "c2 elvis presley" s.commit() # assert raw DB data. now there are two parent rows. 
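
Before those raw-row assertions, it may help to see the start/end windowing in isolation: every logical row is stored as several physical rows, each stamped with the interval during which it was the live version, and "current" queries filter on that interval (the before_compile hook above automates exactly this filter). A minimal sketch under invented names; the Fact class and its values are not part of the example:

import datetime

from sqlalchemy import Column, DateTime, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()


class Fact(Base):
    # hypothetical stand-in: one logical record, many timestamped versions
    __tablename__ = "fact"
    id = Column(Integer, primary_key=True)
    start = Column(DateTime, primary_key=True)
    end = Column(DateTime, primary_key=True)
    data = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
s = Session(engine)

now = datetime.datetime.now()
day = datetime.timedelta(days=1)

# two physical rows for logical row #1: an expired version and a current one
s.add_all(
    [
        Fact(id=1, start=now - 2 * day, end=now - day, data="old"),
        Fact(id=1, start=now - day, end=now + day, data="current"),
    ]
)
s.commit()

# "current" rows are simply those whose window contains the clock time
current = s.query(Fact).filter(Fact.start <= now, Fact.end > now).one()
assert current.data == "current"

The example's own raw-row assertions continue below.
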
assert s.query(Parent.__table__).order_by(Parent.end).all() == [ - (1, times[0] - datetime.timedelta(days=3), times[2], 'c1', 1), - (1, times[2], times[2] + datetime.timedelta(days=2), 'c2 elvis presley', 1) + (1, times[0] - datetime.timedelta(days=3), times[2], "c1", 1), + ( + 1, + times[2], + times[2] + datetime.timedelta(days=2), + "c2 elvis presley", + 1, + ), ] assert s.query(Child.__table__).order_by(Child.end).all() == [ - (1, times[0] - datetime.timedelta(days=3), times[1], 'child 1'), - (1, times[1], times[1] + datetime.timedelta(days=2), 'elvis presley') + (1, times[0] - datetime.timedelta(days=3), times[1], "child 1"), + (1, times[1], times[1] + datetime.timedelta(days=2), "elvis presley"), ] # add some more rows to test that these aren't coming back for # queries -s.add(Parent(id=2, data='unrelated', child=Child(id=2, data='unrelated'))) +s.add(Parent(id=2, data="unrelated", child=Child(id=2, data="unrelated"))) s.commit() @@ -274,6 +304,6 @@ c3_check = s.query(Child).filter(Child.parent == p3_check).one() assert c3_check is c1 # one child one parent.... -c3_check = s.query(Child).join(Parent.child).filter( - Parent.id == p3_check.id).one() - +c3_check = ( + s.query(Child).join(Parent.child).filter(Parent.id == p3_check.id).one() +) diff --git a/examples/vertical/__init__.py b/examples/vertical/__init__.py index 0b69f32ea5..0e09b9a55d 100644 --- a/examples/vertical/__init__.py +++ b/examples/vertical/__init__.py @@ -31,4 +31,4 @@ Example:: .. autosource:: -""" \ No newline at end of file +""" diff --git a/examples/vertical/dictlike-polymorphic.py b/examples/vertical/dictlike-polymorphic.py index 7147ac40be..c000ff3cf9 100644 --- a/examples/vertical/dictlike-polymorphic.py +++ b/examples/vertical/dictlike-polymorphic.py @@ -30,6 +30,7 @@ from sqlalchemy import event from sqlalchemy import literal_column from .dictlike import ProxiedDictMixin + class PolymorphicVerticalProperty(object): """A key/value pair with polymorphic value storage. @@ -69,6 +70,7 @@ class PolymorphicVerticalProperty(object): """A comparator for .value, builds a polymorphic comparison via CASE. 
""" + def __init__(self, cls): self.cls = cls @@ -77,40 +79,59 @@ class PolymorphicVerticalProperty(object): whens = [ ( literal_column("'%s'" % discriminator), - cast(getattr(self.cls, attribute), String) - ) for attribute, discriminator in pairs + cast(getattr(self.cls, attribute), String), + ) + for attribute, discriminator in pairs if attribute is not None ] return case(whens, self.cls.type, null()) + def __eq__(self, other): return self._case() == cast(other, String) + def __ne__(self, other): return self._case() != cast(other, String) def __repr__(self): - return '<%s %r=%r>' % (self.__class__.__name__, self.key, self.value) + return "<%s %r=%r>" % (self.__class__.__name__, self.key, self.value) + -@event.listens_for(PolymorphicVerticalProperty, "mapper_configured", propagate=True) +@event.listens_for( + PolymorphicVerticalProperty, "mapper_configured", propagate=True +) def on_new_class(mapper, cls_): """Look for Column objects with type info in them, and work up a lookup table.""" info_dict = {} - info_dict[type(None)] = (None, 'none') - info_dict['none'] = (None, 'none') + info_dict[type(None)] = (None, "none") + info_dict["none"] = (None, "none") for k in mapper.c.keys(): col = mapper.c[k] - if 'type' in col.info: - python_type, discriminator = col.info['type'] + if "type" in col.info: + python_type, discriminator = col.info["type"] info_dict[python_type] = (k, discriminator) info_dict[discriminator] = (k, discriminator) cls_.type_map = info_dict -if __name__ == '__main__': - from sqlalchemy import (Column, Integer, Unicode, - ForeignKey, UnicodeText, and_, or_, String, Boolean, cast, - null, case, create_engine) + +if __name__ == "__main__": + from sqlalchemy import ( + Column, + Integer, + Unicode, + ForeignKey, + UnicodeText, + and_, + or_, + String, + Boolean, + cast, + null, + case, + create_engine, + ) from sqlalchemy.orm import relationship, Session from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.ext.declarative import declarative_base @@ -118,36 +139,38 @@ if __name__ == '__main__': Base = declarative_base() - class AnimalFact(PolymorphicVerticalProperty, Base): """A fact about an animal.""" - __tablename__ = 'animal_fact' + __tablename__ = "animal_fact" - animal_id = Column(ForeignKey('animal.id'), primary_key=True) + animal_id = Column(ForeignKey("animal.id"), primary_key=True) key = Column(Unicode(64), primary_key=True) type = Column(Unicode(16)) # add information about storage for different types # in the info dictionary of Columns - int_value = Column(Integer, info={'type': (int, 'integer')}) - char_value = Column(UnicodeText, info={'type': (str, 'string')}) - boolean_value = Column(Boolean, info={'type': (bool, 'boolean')}) + int_value = Column(Integer, info={"type": (int, "integer")}) + char_value = Column(UnicodeText, info={"type": (str, "string")}) + boolean_value = Column(Boolean, info={"type": (bool, "boolean")}) class Animal(ProxiedDictMixin, Base): """an Animal""" - __tablename__ = 'animal' + __tablename__ = "animal" id = Column(Integer, primary_key=True) name = Column(Unicode(100)) - facts = relationship("AnimalFact", - collection_class=attribute_mapped_collection('key')) + facts = relationship( + "AnimalFact", collection_class=attribute_mapped_collection("key") + ) - _proxied = association_proxy("facts", "value", - creator= - lambda key, value: AnimalFact(key=key, value=value)) + _proxied = association_proxy( + "facts", + "value", + creator=lambda key, value: AnimalFact(key=key, value=value), + ) def __init__(self, name): 
self.name = name @@ -159,66 +182,66 @@ if __name__ == '__main__': def with_characteristic(self, key, value): return self.facts.any(key=key, value=value) - engine = create_engine('sqlite://', echo=True) + engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) - stoat = Animal('stoat') - stoat['color'] = 'red' - stoat['cuteness'] = 7 - stoat['weasel-like'] = True + stoat = Animal("stoat") + stoat["color"] = "red" + stoat["cuteness"] = 7 + stoat["weasel-like"] = True session.add(stoat) session.commit() - critter = session.query(Animal).filter(Animal.name == 'stoat').one() - print(critter['color']) - print(critter['cuteness']) + critter = session.query(Animal).filter(Animal.name == "stoat").one() + print(critter["color"]) + print(critter["cuteness"]) print("changing cuteness value and type:") - critter['cuteness'] = 'very cute' + critter["cuteness"] = "very cute" session.commit() - marten = Animal('marten') - marten['cuteness'] = 5 - marten['weasel-like'] = True - marten['poisonous'] = False + marten = Animal("marten") + marten["cuteness"] = 5 + marten["weasel-like"] = True + marten["poisonous"] = False session.add(marten) - shrew = Animal('shrew') - shrew['cuteness'] = 5 - shrew['weasel-like'] = False - shrew['poisonous'] = True + shrew = Animal("shrew") + shrew["cuteness"] = 5 + shrew["weasel-like"] = False + shrew["poisonous"] = True session.add(shrew) session.commit() - q = (session.query(Animal). - filter(Animal.facts.any( - and_(AnimalFact.key == 'weasel-like', - AnimalFact.value == True)))) - print('weasel-like animals', q.all()) - - q = (session.query(Animal). - filter(Animal.with_characteristic('weasel-like', True))) - print('weasel-like animals again', q.all()) - - q = (session.query(Animal). - filter(Animal.with_characteristic('poisonous', False))) - print('animals with poisonous=False', q.all()) - - q = (session.query(Animal). - filter(or_( - Animal.with_characteristic('poisonous', False), - ~Animal.facts.any(AnimalFact.key == 'poisonous') - ) - ) + q = session.query(Animal).filter( + Animal.facts.any( + and_(AnimalFact.key == "weasel-like", AnimalFact.value == True) ) - print('non-poisonous animals', q.all()) - - q = (session.query(Animal). - filter(Animal.facts.any(AnimalFact.value == 5))) - print('any animal with a .value of 5', q.all()) + ) + print("weasel-like animals", q.all()) + + q = session.query(Animal).filter( + Animal.with_characteristic("weasel-like", True) + ) + print("weasel-like animals again", q.all()) + + q = session.query(Animal).filter( + Animal.with_characteristic("poisonous", False) + ) + print("animals with poisonous=False", q.all()) + + q = session.query(Animal).filter( + or_( + Animal.with_characteristic("poisonous", False), + ~Animal.facts.any(AnimalFact.key == "poisonous"), + ) + ) + print("non-poisonous animals", q.all()) + q = session.query(Animal).filter(Animal.facts.any(AnimalFact.value == 5)) + print("any animal with a .value of 5", q.all()) diff --git a/examples/vertical/dictlike.py b/examples/vertical/dictlike.py index 08989d8c2a..f1f3642079 100644 --- a/examples/vertical/dictlike.py +++ b/examples/vertical/dictlike.py @@ -32,6 +32,7 @@ can be used with many common vertical schemas as-is or with minor adaptations. """ from __future__ import unicode_literals + class ProxiedDictMixin(object): """Adds obj[key] access to a mapped class. 
@@ -60,9 +61,16 @@ class ProxiedDictMixin(object): del self._proxied[key] -if __name__ == '__main__': - from sqlalchemy import (Column, Integer, Unicode, - ForeignKey, UnicodeText, and_, create_engine) +if __name__ == "__main__": + from sqlalchemy import ( + Column, + Integer, + Unicode, + ForeignKey, + UnicodeText, + and_, + create_engine, + ) from sqlalchemy.orm import relationship, Session from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.ext.declarative import declarative_base @@ -73,26 +81,29 @@ if __name__ == '__main__': class AnimalFact(Base): """A fact about an animal.""" - __tablename__ = 'animal_fact' + __tablename__ = "animal_fact" - animal_id = Column(ForeignKey('animal.id'), primary_key=True) + animal_id = Column(ForeignKey("animal.id"), primary_key=True) key = Column(Unicode(64), primary_key=True) value = Column(UnicodeText) class Animal(ProxiedDictMixin, Base): """an Animal""" - __tablename__ = 'animal' + __tablename__ = "animal" id = Column(Integer, primary_key=True) name = Column(Unicode(100)) - facts = relationship("AnimalFact", - collection_class=attribute_mapped_collection('key')) + facts = relationship( + "AnimalFact", collection_class=attribute_mapped_collection("key") + ) - _proxied = association_proxy("facts", "value", - creator= - lambda key, value: AnimalFact(key=key, value=value)) + _proxied = association_proxy( + "facts", + "value", + creator=lambda key, value: AnimalFact(key=key, value=value), + ) def __init__(self, name): self.name = name @@ -109,57 +120,56 @@ if __name__ == '__main__': session = Session(bind=engine) - stoat = Animal('stoat') - stoat['color'] = 'reddish' - stoat['cuteness'] = 'somewhat' + stoat = Animal("stoat") + stoat["color"] = "reddish" + stoat["cuteness"] = "somewhat" # dict-like assignment transparently creates entries in the # stoat.facts collection: - print(stoat.facts['color']) + print(stoat.facts["color"]) session.add(stoat) session.commit() - critter = session.query(Animal).filter(Animal.name == 'stoat').one() - print(critter['color']) - print(critter['cuteness']) + critter = session.query(Animal).filter(Animal.name == "stoat").one() + print(critter["color"]) + print(critter["cuteness"]) - critter['cuteness'] = 'very' + critter["cuteness"] = "very" - print('changing cuteness:') + print("changing cuteness:") - marten = Animal('marten') - marten['color'] = 'brown' - marten['cuteness'] = 'somewhat' + marten = Animal("marten") + marten["color"] = "brown" + marten["cuteness"] = "somewhat" session.add(marten) - shrew = Animal('shrew') - shrew['cuteness'] = 'somewhat' - shrew['poisonous-part'] = 'saliva' + shrew = Animal("shrew") + shrew["cuteness"] = "somewhat" + shrew["poisonous-part"] = "saliva" session.add(shrew) - loris = Animal('slow loris') - loris['cuteness'] = 'fairly' - loris['poisonous-part'] = 'elbows' + loris = Animal("slow loris") + loris["cuteness"] = "fairly" + loris["poisonous-part"] = "elbows" session.add(loris) - q = (session.query(Animal). 
- filter(Animal.facts.any( - and_(AnimalFact.key == 'color', - AnimalFact.value == 'reddish')))) - print('reddish animals', q.all()) + q = session.query(Animal).filter( + Animal.facts.any( + and_(AnimalFact.key == "color", AnimalFact.value == "reddish") + ) + ) + print("reddish animals", q.all()) - q = session.query(Animal).\ - filter(Animal.with_characteristic("color", 'brown')) - print('brown animals', q.all()) + q = session.query(Animal).filter( + Animal.with_characteristic("color", "brown") + ) + print("brown animals", q.all()) - q = session.query(Animal).\ - filter(~Animal.with_characteristic("poisonous-part", 'elbows')) - print('animals without poisonous-part == elbows', q.all()) + q = session.query(Animal).filter( + ~Animal.with_characteristic("poisonous-part", "elbows") + ) + print("animals without poisonous-part == elbows", q.all()) - q = (session.query(Animal). - filter(Animal.facts.any(value='somewhat'))) + q = session.query(Animal).filter(Animal.facts.any(value="somewhat")) print('any animal with any .value of "somewhat"', q.all()) - - - diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 6162ead5c6..171f230289 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -56,7 +56,7 @@ from .sql import ( union_all, update, within_group, - ) +) from .types import ( ARRAY, @@ -102,7 +102,7 @@ from .types import ( UnicodeText, VARBINARY, VARCHAR, - ) +) from .schema import ( @@ -123,14 +123,14 @@ from .schema import ( ThreadLocalMetaData, UniqueConstraint, DDL, - BLANK_SCHEMA + BLANK_SCHEMA, ) from .inspection import inspect from .engine import create_engine, engine_from_config -__version__ = '1.3.0b2' +__version__ = "1.3.0b2" def __go(lcls): @@ -141,8 +141,13 @@ def __go(lcls): import inspect as _inspect - __all__ = sorted(name for name, obj in lcls.items() - if not (name.startswith('_') or _inspect.ismodule(obj))) + __all__ = sorted( + name + for name, obj in lcls.items() + if not (name.startswith("_") or _inspect.ismodule(obj)) + ) _sa_util.dependencies.resolve_all("sqlalchemy") + + __go(locals()) diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py index 65be4c7d5c..209877e4a5 100644 --- a/lib/sqlalchemy/connectors/mxodbc.py +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -27,7 +27,7 @@ from . import Connector class MxODBCConnector(Connector): - driver = 'mxodbc' + driver = "mxodbc" supports_sane_multi_rowcount = False supports_unicode_statements = True @@ -41,12 +41,12 @@ class MxODBCConnector(Connector): # attribute of the same name, so this is normally only called once. 
cls._load_mx_exceptions() platform = sys.platform - if platform == 'win32': + if platform == "win32": from mx.ODBC import Windows as Module # this can be the string "linux2", and possibly others - elif 'linux' in platform: + elif "linux" in platform: from mx.ODBC import unixODBC as Module - elif platform == 'darwin': + elif platform == "darwin": from mx.ODBC import iODBC as Module else: raise ImportError("Unrecognized platform for mxODBC import") @@ -68,6 +68,7 @@ class MxODBCConnector(Connector): conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT conn.errorhandler = self._error_handler() + return connect def _error_handler(self): @@ -79,11 +80,12 @@ class MxODBCConnector(Connector): def error_handler(connection, cursor, errorclass, errorvalue): if issubclass(errorclass, MxOdbcWarning): errorclass.__bases__ = (Warning,) - warnings.warn(message=str(errorvalue), - category=errorclass, - stacklevel=2) + warnings.warn( + message=str(errorvalue), category=errorclass, stacklevel=2 + ) else: raise errorclass(errorvalue) + return error_handler def create_connect_args(self, url): @@ -101,11 +103,11 @@ class MxODBCConnector(Connector): not be populated. """ - opts = url.translate_connect_args(username='user') + opts = url.translate_connect_args(username="user") opts.update(url.query) - args = opts.pop('host') - opts.pop('port', None) - opts.pop('database', None) + args = opts.pop("host") + opts.pop("port", None) + opts.pop("database", None) return (args,), opts def is_disconnect(self, e, connection, cursor): @@ -114,7 +116,7 @@ class MxODBCConnector(Connector): if isinstance(e, self.dbapi.ProgrammingError): return "connection already closed" in str(e) elif isinstance(e, self.dbapi.Error): - return '[08S01]' in str(e) + return "[08S01]" in str(e) else: return False @@ -123,7 +125,7 @@ class MxODBCConnector(Connector): # of what we're doing here dbapi_con = connection.connection version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") # 18 == pyodbc.SQL_DBMS_VER for n in r.split(dbapi_con.getinfo(18)[1]): try: @@ -134,8 +136,9 @@ class MxODBCConnector(Connector): def _get_direct(self, context): if context: - native_odbc_execute = context.execution_options.\ - get('native_odbc_execute', 'auto') + native_odbc_execute = context.execution_options.get( + "native_odbc_execute", "auto" + ) # default to direct=True in all cases, is more generally # compatible especially with SQL Server return False if native_odbc_execute is True else True @@ -144,8 +147,8 @@ class MxODBCConnector(Connector): def do_executemany(self, cursor, statement, parameters, context=None): cursor.executemany( - statement, parameters, direct=self._get_direct(context)) + statement, parameters, direct=self._get_direct(context) + ) def do_execute(self, cursor, statement, parameters, context=None): - cursor.execute(statement, parameters, - direct=self._get_direct(context)) + cursor.execute(statement, parameters, direct=self._get_direct(context)) diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 41ba89de68..8f5eea89b8 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -13,7 +13,7 @@ import re class PyODBCConnector(Connector): - driver = 'pyodbc' + driver = "pyodbc" supports_sane_rowcount_returning = False supports_sane_multi_rowcount = False @@ -22,7 +22,7 @@ class PyODBCConnector(Connector): supports_unicode_binds = True supports_native_decimal = True - default_paramstyle = 'named' + 
default_paramstyle = "named" # for non-DSN connections, this *may* be used to # hold the desired driver name @@ -35,10 +35,10 @@ class PyODBCConnector(Connector): @classmethod def dbapi(cls): - return __import__('pyodbc') + return __import__("pyodbc") def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') + opts = url.translate_connect_args(username="user") opts.update(url.query) keys = opts @@ -46,52 +46,55 @@ class PyODBCConnector(Connector): query = url.query connect_args = {} - for param in ('ansi', 'unicode_results', 'autocommit'): + for param in ("ansi", "unicode_results", "autocommit"): if param in keys: connect_args[param] = util.asbool(keys.pop(param)) - if 'odbc_connect' in keys: - connectors = [util.unquote_plus(keys.pop('odbc_connect'))] + if "odbc_connect" in keys: + connectors = [util.unquote_plus(keys.pop("odbc_connect"))] else: + def check_quote(token): if ";" in str(token): token = "'%s'" % token return token - keys = dict( - (k, check_quote(v)) for k, v in keys.items() - ) + keys = dict((k, check_quote(v)) for k, v in keys.items()) - dsn_connection = 'dsn' in keys or \ - ('host' in keys and 'database' not in keys) + dsn_connection = "dsn" in keys or ( + "host" in keys and "database" not in keys + ) if dsn_connection: - connectors = ['dsn=%s' % (keys.pop('host', '') or - keys.pop('dsn', ''))] + connectors = [ + "dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", "")) + ] else: - port = '' - if 'port' in keys and 'port' not in query: - port = ',%d' % int(keys.pop('port')) + port = "" + if "port" in keys and "port" not in query: + port = ",%d" % int(keys.pop("port")) connectors = [] - driver = keys.pop('driver', self.pyodbc_driver_name) + driver = keys.pop("driver", self.pyodbc_driver_name) if driver is None: util.warn( "No driver name specified; " "this is expected by PyODBC when using " - "DSN-less connections") + "DSN-less connections" + ) else: connectors.append("DRIVER={%s}" % driver) connectors.extend( [ - 'Server=%s%s' % (keys.pop('host', ''), port), - 'Database=%s' % keys.pop('database', '') - ]) + "Server=%s%s" % (keys.pop("host", ""), port), + "Database=%s" % keys.pop("database", ""), + ] + ) user = keys.pop("user", None) if user: connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.pop('password', '')) + connectors.append("PWD=%s" % keys.pop("password", "")) else: connectors.append("Trusted_Connection=Yes") @@ -99,18 +102,20 @@ class PyODBCConnector(Connector): # convert textual data from your database encoding to your # client encoding. This should obviously be set to 'No' if # you query a cp1253 encoded database from a latin1 client... - if 'odbc_autotranslate' in keys: - connectors.append("AutoTranslate=%s" % - keys.pop("odbc_autotranslate")) + if "odbc_autotranslate" in keys: + connectors.append( + "AutoTranslate=%s" % keys.pop("odbc_autotranslate") + ) - connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()]) + connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()]) return [[";".join(connectors)], connect_args] def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.ProgrammingError): - return "The cursor's connection has been closed." in str(e) or \ - 'Attempt to use a closed connection.' in str(e) + return "The cursor's connection has been closed." in str( + e + ) or "Attempt to use a closed connection." 
in str(e) else: return False @@ -123,10 +128,7 @@ class PyODBCConnector(Connector): return self._parse_dbapi_version(self.dbapi.version) def _parse_dbapi_version(self, vers): - m = re.match( - r'(?:py.*-)?([\d\.]+)(?:-(\w+))?', - vers - ) + m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers) if not m: return () vers = tuple([int(x) for x in m.group(1).split(".")]) @@ -140,7 +142,7 @@ class PyODBCConnector(Connector): # queries. dbapi_con = connection.connection version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): try: version.append(int(n)) @@ -153,12 +155,11 @@ class PyODBCConnector(Connector): # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit = True else: connection.autocommit = False - super(PyODBCConnector, self).set_isolation_level(connection, - level) + super(PyODBCConnector, self).set_isolation_level(connection, level) diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py index 71decd9abc..003ecbed1c 100644 --- a/lib/sqlalchemy/connectors/zxJDBC.py +++ b/lib/sqlalchemy/connectors/zxJDBC.py @@ -10,15 +10,15 @@ from . import Connector class ZxJDBCConnector(Connector): - driver = 'zxjdbc' + driver = "zxjdbc" supports_sane_rowcount = False supports_sane_multi_rowcount = False supports_unicode_binds = True - supports_unicode_statements = sys.version > '2.5.0+' + supports_unicode_statements = sys.version > "2.5.0+" description_encoding = None - default_paramstyle = 'qmark' + default_paramstyle = "qmark" jdbc_db_name = None jdbc_driver_name = None @@ -26,6 +26,7 @@ class ZxJDBCConnector(Connector): @classmethod def dbapi(cls): from com.ziclix.python.sql import zxJDBC + return zxJDBC def _driver_kwargs(self): @@ -34,25 +35,31 @@ class ZxJDBCConnector(Connector): def _create_jdbc_url(self, url): """Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`""" - return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host, - url.port is not None - and ':%s' % url.port or '', - url.database) + return "jdbc:%s://%s%s/%s" % ( + self.jdbc_db_name, + url.host, + url.port is not None and ":%s" % url.port or "", + url.database, + ) def create_connect_args(self, url): opts = self._driver_kwargs() opts.update(url.query) return [ - [self._create_jdbc_url(url), - url.username, url.password, - self.jdbc_driver_name], - opts] + [ + self._create_jdbc_url(url), + url.username, + url.password, + self.jdbc_driver_name, + ], + opts, + ] def is_disconnect(self, e, connection, cursor): if not isinstance(e, self.dbapi.ProgrammingError): return False e = str(e) - return 'connection is closed' in e or 'cursor is closed' in e + return "connection is closed" in e or "cursor is closed" in e def _get_server_version_info(self, connection): # use connection.connection.dbversion, and parse appropriately diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 2cb252737f..d2d56a7ae2 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -11,6 +11,7 @@ compatibility with pre 0.6 versions. 
""" from ..dialects.sqlite import base as sqlite from ..dialects.postgresql import base as postgresql + postgres = postgresql from ..dialects.mysql import base as mysql from ..dialects.oracle import base as oracle @@ -20,11 +21,11 @@ from ..dialects.sybase import base as sybase __all__ = ( - 'firebird', - 'mssql', - 'mysql', - 'postgresql', - 'sqlite', - 'oracle', - 'sybase', + "firebird", + "mssql", + "mysql", + "postgresql", + "sqlite", + "oracle", + "sybase", ) diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 963babcb8c..65f30bb768 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -6,18 +6,19 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php __all__ = ( - 'firebird', - 'mssql', - 'mysql', - 'oracle', - 'postgresql', - 'sqlite', - 'sybase', + "firebird", + "mssql", + "mysql", + "oracle", + "postgresql", + "sqlite", + "sybase", ) from .. import util -_translates = {'postgres': 'postgresql'} +_translates = {"postgres": "postgresql"} + def _auto_fn(name): """default dialect importer. @@ -40,7 +41,7 @@ def _auto_fn(name): ) dialect = translated try: - module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects except ImportError: return None @@ -51,6 +52,7 @@ def _auto_fn(name): else: return None + registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) -plugins = util.PluginLoader("sqlalchemy.plugins") \ No newline at end of file +plugins = util.PluginLoader("sqlalchemy.plugins") diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py index c83db453be..510d623374 100644 --- a/lib/sqlalchemy/dialects/firebird/__init__.py +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -7,14 +7,35 @@ from . 
import base, kinterbasdb, fdb # noqa -from sqlalchemy.dialects.firebird.base import \ - SMALLINT, BIGINT, FLOAT, DATE, TIME, \ - TEXT, NUMERIC, TIMESTAMP, VARCHAR, CHAR, BLOB +from sqlalchemy.dialects.firebird.base import ( + SMALLINT, + BIGINT, + FLOAT, + DATE, + TIME, + TEXT, + NUMERIC, + TIMESTAMP, + VARCHAR, + CHAR, + BLOB, +) base.dialect = dialect = fdb.dialect __all__ = ( - 'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME', - 'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB', - 'dialect' + "SMALLINT", + "BIGINT", + "FLOAT", + "FLOAT", + "DATE", + "TIME", + "TEXT", + "NUMERIC", + "FLOAT", + "TIMESTAMP", + "VARCHAR", + "CHAR", + "BLOB", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 7b470c1899..1e9c778f3e 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -79,48 +79,254 @@ from sqlalchemy.engine import base, default, reflection from sqlalchemy.sql import compiler from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC, - SMALLINT, TEXT, TIME, TIMESTAMP, Integer) - - -RESERVED_WORDS = set([ - "active", "add", "admin", "after", "all", "alter", "and", "any", "as", - "asc", "ascending", "at", "auto", "avg", "before", "begin", "between", - "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char", - "character", "character_length", "char_length", "check", "close", - "collate", "column", "commit", "committed", "computed", "conditional", - "connect", "constraint", "containing", "count", "create", "cross", - "cstring", "current", "current_connection", "current_date", - "current_role", "current_time", "current_timestamp", - "current_transaction", "current_user", "cursor", "database", "date", - "day", "dec", "decimal", "declare", "default", "delete", "desc", - "descending", "disconnect", "distinct", "do", "domain", "double", - "drop", "else", "end", "entry_point", "escape", "exception", - "execute", "exists", "exit", "external", "extract", "fetch", "file", - "filter", "float", "for", "foreign", "from", "full", "function", - "gdscode", "generator", "gen_id", "global", "grant", "group", - "having", "hour", "if", "in", "inactive", "index", "inner", - "input_type", "insensitive", "insert", "int", "integer", "into", "is", - "isolation", "join", "key", "leading", "left", "length", "level", - "like", "long", "lower", "manual", "max", "maximum_segment", "merge", - "min", "minute", "module_name", "month", "names", "national", - "natural", "nchar", "no", "not", "null", "numeric", "octet_length", - "of", "on", "only", "open", "option", "or", "order", "outer", - "output_type", "overflow", "page", "pages", "page_size", "parameter", - "password", "plan", "position", "post_event", "precision", "primary", - "privileges", "procedure", "protected", "rdb$db_key", "read", "real", - "record_version", "recreate", "recursive", "references", "release", - "reserv", "reserving", "retain", "returning_values", "returns", - "revoke", "right", "rollback", "rows", "row_count", "savepoint", - "schema", "second", "segment", "select", "sensitive", "set", "shadow", - "shared", "singular", "size", "smallint", "snapshot", "some", "sort", - "sqlcode", "stability", "start", "starting", "starts", "statistics", - "sub_type", "sum", "suspend", "table", "then", "time", "timestamp", - "to", "trailing", "transaction", "trigger", "trim", "uncommitted", - "union", "unique", "update", "upper", "user", "using", "value", - 
"values", "varchar", "variable", "varying", "view", "wait", "when", - "where", "while", "with", "work", "write", "year", -]) +from sqlalchemy.types import ( + BIGINT, + BLOB, + DATE, + FLOAT, + INTEGER, + NUMERIC, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + Integer, +) + + +RESERVED_WORDS = set( + [ + "active", + "add", + "admin", + "after", + "all", + "alter", + "and", + "any", + "as", + "asc", + "ascending", + "at", + "auto", + "avg", + "before", + "begin", + "between", + "bigint", + "bit_length", + "blob", + "both", + "by", + "case", + "cast", + "char", + "character", + "character_length", + "char_length", + "check", + "close", + "collate", + "column", + "commit", + "committed", + "computed", + "conditional", + "connect", + "constraint", + "containing", + "count", + "create", + "cross", + "cstring", + "current", + "current_connection", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_transaction", + "current_user", + "cursor", + "database", + "date", + "day", + "dec", + "decimal", + "declare", + "default", + "delete", + "desc", + "descending", + "disconnect", + "distinct", + "do", + "domain", + "double", + "drop", + "else", + "end", + "entry_point", + "escape", + "exception", + "execute", + "exists", + "exit", + "external", + "extract", + "fetch", + "file", + "filter", + "float", + "for", + "foreign", + "from", + "full", + "function", + "gdscode", + "generator", + "gen_id", + "global", + "grant", + "group", + "having", + "hour", + "if", + "in", + "inactive", + "index", + "inner", + "input_type", + "insensitive", + "insert", + "int", + "integer", + "into", + "is", + "isolation", + "join", + "key", + "leading", + "left", + "length", + "level", + "like", + "long", + "lower", + "manual", + "max", + "maximum_segment", + "merge", + "min", + "minute", + "module_name", + "month", + "names", + "national", + "natural", + "nchar", + "no", + "not", + "null", + "numeric", + "octet_length", + "of", + "on", + "only", + "open", + "option", + "or", + "order", + "outer", + "output_type", + "overflow", + "page", + "pages", + "page_size", + "parameter", + "password", + "plan", + "position", + "post_event", + "precision", + "primary", + "privileges", + "procedure", + "protected", + "rdb$db_key", + "read", + "real", + "record_version", + "recreate", + "recursive", + "references", + "release", + "reserv", + "reserving", + "retain", + "returning_values", + "returns", + "revoke", + "right", + "rollback", + "rows", + "row_count", + "savepoint", + "schema", + "second", + "segment", + "select", + "sensitive", + "set", + "shadow", + "shared", + "singular", + "size", + "smallint", + "snapshot", + "some", + "sort", + "sqlcode", + "stability", + "start", + "starting", + "starts", + "statistics", + "sub_type", + "sum", + "suspend", + "table", + "then", + "time", + "timestamp", + "to", + "trailing", + "transaction", + "trigger", + "trim", + "uncommitted", + "union", + "unique", + "update", + "upper", + "user", + "using", + "value", + "values", + "varchar", + "variable", + "varying", + "view", + "wait", + "when", + "where", + "while", + "with", + "work", + "write", + "year", + ] +) class _StringType(sqltypes.String): @@ -133,7 +339,8 @@ class _StringType(sqltypes.String): class VARCHAR(_StringType, sqltypes.VARCHAR): """Firebird VARCHAR type""" - __visit_name__ = 'VARCHAR' + + __visit_name__ = "VARCHAR" def __init__(self, length=None, **kwargs): super(VARCHAR, self).__init__(length=length, **kwargs) @@ -141,7 +348,8 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): class 
CHAR(_StringType, sqltypes.CHAR): """Firebird CHAR type""" - __visit_name__ = 'CHAR' + + __visit_name__ = "CHAR" def __init__(self, length=None, **kwargs): super(CHAR, self).__init__(length=length, **kwargs) @@ -154,32 +362,33 @@ class _FBDateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day) else: return value + return process -colspecs = { - sqltypes.DateTime: _FBDateTime -} + +colspecs = {sqltypes.DateTime: _FBDateTime} ischema_names = { - 'SHORT': SMALLINT, - 'LONG': INTEGER, - 'QUAD': FLOAT, - 'FLOAT': FLOAT, - 'DATE': DATE, - 'TIME': TIME, - 'TEXT': TEXT, - 'INT64': BIGINT, - 'DOUBLE': FLOAT, - 'TIMESTAMP': TIMESTAMP, - 'VARYING': VARCHAR, - 'CSTRING': CHAR, - 'BLOB': BLOB, + "SHORT": SMALLINT, + "LONG": INTEGER, + "QUAD": FLOAT, + "FLOAT": FLOAT, + "DATE": DATE, + "TIME": TIME, + "TEXT": TEXT, + "INT64": BIGINT, + "DOUBLE": FLOAT, + "TIMESTAMP": TIMESTAMP, + "VARYING": VARCHAR, + "CSTRING": CHAR, + "BLOB": BLOB, } # TODO: date conversion types (should be implemented as _FBDateTime, # _FBDate, etc. as bind/result functionality is required) + class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_boolean(self, type_, **kw): return self.visit_SMALLINT(type_, **kw) @@ -194,11 +403,11 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): return "BLOB SUB_TYPE 0" def _extend_string(self, type_, basic): - charset = getattr(type_, 'charset', None) + charset = getattr(type_, "charset", None) if charset is None: return basic else: - return '%s CHARACTER SET %s' % (basic, charset) + return "%s CHARACTER SET %s" % (basic, charset) def visit_CHAR(self, type_, **kw): basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) @@ -207,8 +416,8 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if not type_.length: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) return self._extend_string(type_, basic) @@ -228,36 +437,42 @@ class FBCompiler(sql.compiler.SQLCompiler): return "CURRENT_TIMESTAMP" def visit_startswith_op_binary(self, binary, operator, **kw): - return '%s STARTING WITH %s' % ( + return "%s STARTING WITH %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.right._compiler_dispatch(self, **kw), + ) def visit_notstartswith_op_binary(self, binary, operator, **kw): - return '%s NOT STARTING WITH %s' % ( + return "%s NOT STARTING WITH %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.right._compiler_dispatch(self, **kw), + ) def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_alias(self, alias, asfrom=False, **kwargs): if self.dialect._version_two: - return super(FBCompiler, self).\ - visit_alias(alias, asfrom=asfrom, **kwargs) + return super(FBCompiler, self).visit_alias( + alias, asfrom=asfrom, **kwargs + ) else: # Override to not use the AS keyword which FB 1.5 does not like if asfrom: - alias_name = isinstance(alias.name, - expression._truncated_label) and \ - self._truncated_identifier("alias", - alias.name) or alias.name - - return self.process( - alias.original, asfrom=asfrom, **kwargs) + \ - " " + \ - self.preparer.format_alias(alias, alias_name) + alias_name = ( + 
isinstance(alias.name, expression._truncated_label) + and self._truncated_identifier("alias", alias.name) + or alias.name + ) + + return ( + self.process(alias.original, asfrom=asfrom, **kwargs) + + " " + + self.preparer.format_alias(alias, alias_name) + ) else: return self.process(alias.original, **kwargs) @@ -320,7 +535,7 @@ class FBCompiler(sql.compiler.SQLCompiler): for c in expression._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) class FBDDLCompiler(sql.compiler.DDLCompiler): @@ -333,27 +548,33 @@ class FBDDLCompiler(sql.compiler.DDLCompiler): # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html if create.element.start is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support START WITH") + "Firebird SEQUENCE doesn't support START WITH" + ) if create.element.increment is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support INCREMENT BY") + "Firebird SEQUENCE doesn't support INCREMENT BY" + ) if self.dialect._version_two: - return "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + return "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) else: - return "CREATE GENERATOR %s" % \ - self.preparer.format_sequence(create.element) + return "CREATE GENERATOR %s" % self.preparer.format_sequence( + create.element + ) def visit_drop_sequence(self, drop): """Generate a ``DROP GENERATOR`` statement for the sequence.""" if self.dialect._version_two: - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence( + drop.element + ) else: - return "DROP GENERATOR %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP GENERATOR %s" % self.preparer.format_sequence( + drop.element + ) class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): @@ -361,7 +582,8 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union( - ['_']) + ["_"] + ) def __init__(self, dialect): super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) @@ -372,16 +594,16 @@ class FBExecutionContext(default.DefaultExecutionContext): """Get the next value from the sequence using ``gen_id()``.""" return self._execute_scalar( - "SELECT gen_id(%s, 1) FROM rdb$database" % - self.dialect.identifier_preparer.format_sequence(seq), - type_ + "SELECT gen_id(%s, 1) FROM rdb$database" + % self.dialect.identifier_preparer.format_sequence(seq), + type_, ) class FBDialect(default.DefaultDialect): """Firebird dialect""" - name = 'firebird' + name = "firebird" max_identifier_length = 31 @@ -413,23 +635,23 @@ class FBDialect(default.DefaultDialect): def initialize(self, connection): super(FBDialect, self).initialize(connection) - self._version_two = ('firebird' in self.server_version_info and - self.server_version_info >= (2, ) - ) or \ - ('interbase' in self.server_version_info and - self.server_version_info >= (6, ) - ) + self._version_two = ( + "firebird" in self.server_version_info + and self.server_version_info >= (2,) + ) or ( + "interbase" in self.server_version_info + and self.server_version_info >= (6,) + ) if not self._version_two: # TODO: whatever other pre < 2.0 stuff goes here self.ischema_names = ischema_names.copy() - self.ischema_names['TIMESTAMP'] = sqltypes.DATE - self.colspecs = { - sqltypes.DateTime: sqltypes.DATE - } + self.ischema_names["TIMESTAMP"] = 
sqltypes.DATE + self.colspecs = {sqltypes.DateTime: sqltypes.DATE} - self.implicit_returning = self._version_two and \ - self.__dict__.get('implicit_returning', True) + self.implicit_returning = self._version_two and self.__dict__.get( + "implicit_returning", True + ) def normalize_name(self, name): # Remove trailing spaces: FB uses a CHAR() type, @@ -437,8 +659,9 @@ class FBDialect(default.DefaultDialect): name = name and name.rstrip() if name is None: return None - elif name.upper() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.upper() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.lower() elif name.lower() == name: return quoted_name(name, quote=True) @@ -448,8 +671,9 @@ class FBDialect(default.DefaultDialect): def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.upper() else: return name @@ -522,7 +746,7 @@ class FBDialect(default.DefaultDialect): rp = connection.execute(qry, [self.denormalize_name(view_name)]) row = rp.first() if row: - return row['view_source'] + return row["view_source"] else: return None @@ -538,13 +762,13 @@ class FBDialect(default.DefaultDialect): tablename = self.denormalize_name(table_name) # get primary key fields c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] - return {'constrained_columns': pkfields, 'name': None} + pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()] + return {"constrained_columns": pkfields, "name": None} @reflection.cache - def get_column_sequence(self, connection, - table_name, column_name, - schema=None, **kw): + def get_column_sequence( + self, connection, table_name, column_name, schema=None, **kw + ): tablename = self.denormalize_name(table_name) colname = self.denormalize_name(column_name) # Heuristic-query to determine the generator associated to a PK field @@ -567,7 +791,7 @@ class FBDialect(default.DefaultDialect): """ genr = connection.execute(genqry, [tablename, colname]).first() if genr is not None: - return dict(name=self.normalize_name(genr['fgenerator'])) + return dict(name=self.normalize_name(genr["fgenerator"])) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -595,7 +819,7 @@ class FBDialect(default.DefaultDialect): """ # get the PK, used to determine the eventual associated sequence pk_constraint = self.get_pk_constraint(connection, table_name) - pkey_cols = pk_constraint['constrained_columns'] + pkey_cols = pk_constraint["constrained_columns"] tablename = self.denormalize_name(table_name) # get all of the fields for this table @@ -605,26 +829,28 @@ class FBDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - name = self.normalize_name(row['fname']) - orig_colname = row['fname'] + name = self.normalize_name(row["fname"]) + orig_colname = row["fname"] # get the data type - colspec = row['ftype'].rstrip() + colspec = row["ftype"].rstrip() coltype = self.ischema_names.get(colspec) if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (colspec, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (colspec, name) + ) coltype = sqltypes.NULLTYPE - elif issubclass(coltype, Integer) and row['fprec'] != 0: + elif issubclass(coltype, 
Integer) and row["fprec"] != 0: coltype = NUMERIC( - precision=row['fprec'], - scale=row['fscale'] * -1) - elif colspec in ('VARYING', 'CSTRING'): - coltype = coltype(row['flen']) - elif colspec == 'TEXT': - coltype = TEXT(row['flen']) - elif colspec == 'BLOB': - if row['stype'] == 1: + precision=row["fprec"], scale=row["fscale"] * -1 + ) + elif colspec in ("VARYING", "CSTRING"): + coltype = coltype(row["flen"]) + elif colspec == "TEXT": + coltype = TEXT(row["flen"]) + elif colspec == "BLOB": + if row["stype"] == 1: coltype = TEXT() else: coltype = BLOB() @@ -633,36 +859,36 @@ class FBDialect(default.DefaultDialect): # does it have a default value? defvalue = None - if row['fdefault'] is not None: + if row["fdefault"] is not None: # the value comes down as "DEFAULT 'value'": there may be # more than one whitespace around the "DEFAULT" keyword # and it may also be lower case # (see also http://tracker.firebirdsql.org/browse/CORE-356) - defexpr = row['fdefault'].lstrip() - assert defexpr[:8].rstrip().upper() == \ - 'DEFAULT', "Unrecognized default value: %s" % \ - defexpr + defexpr = row["fdefault"].lstrip() + assert defexpr[:8].rstrip().upper() == "DEFAULT", ( + "Unrecognized default value: %s" % defexpr + ) defvalue = defexpr[8:].strip() - if defvalue == 'NULL': + if defvalue == "NULL": # Redundant defvalue = None col_d = { - 'name': name, - 'type': coltype, - 'nullable': not bool(row['null_flag']), - 'default': defvalue, - 'autoincrement': 'auto', + "name": name, + "type": coltype, + "nullable": not bool(row["null_flag"]), + "default": defvalue, + "autoincrement": "auto", } if orig_colname.lower() == orig_colname: - col_d['quote'] = True + col_d["quote"] = True # if the PK is a single field, try to see if its linked to # a sequence thru a trigger if len(pkey_cols) == 1 and name == pkey_cols[0]: seq_d = self.get_column_sequence(connection, tablename, name) if seq_d is not None: - col_d['sequence'] = seq_d + col_d["sequence"] = seq_d cols.append(col_d) return cols @@ -689,24 +915,26 @@ class FBDialect(default.DefaultDialect): tablename = self.denormalize_name(table_name) c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) - fks = util.defaultdict(lambda: { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] - }) + fks = util.defaultdict( + lambda: { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + } + ) for row in c: - cname = self.normalize_name(row['cname']) + cname = self.normalize_name(row["cname"]) fk = fks[cname] - if not fk['name']: - fk['name'] = cname - fk['referred_table'] = self.normalize_name(row['targetrname']) - fk['constrained_columns'].append( - self.normalize_name(row['fname'])) - fk['referred_columns'].append( - self.normalize_name(row['targetfname'])) + if not fk["name"]: + fk["name"] = cname + fk["referred_table"] = self.normalize_name(row["targetrname"]) + fk["constrained_columns"].append(self.normalize_name(row["fname"])) + fk["referred_columns"].append( + self.normalize_name(row["targetfname"]) + ) return list(fks.values()) @reflection.cache @@ -729,13 +957,14 @@ class FBDialect(default.DefaultDialect): indexes = util.defaultdict(dict) for row in c: - indexrec = indexes[row['index_name']] - if 'name' not in indexrec: - indexrec['name'] = self.normalize_name(row['index_name']) - indexrec['column_names'] = [] - indexrec['unique'] = bool(row['unique_flag']) - - indexrec['column_names'].append( - 
self.normalize_name(row['field_name'])) + indexrec = indexes[row["index_name"]] + if "name" not in indexrec: + indexrec["name"] = self.normalize_name(row["index_name"]) + indexrec["column_names"] = [] + indexrec["unique"] = bool(row["unique_flag"]) + + indexrec["column_names"].append( + self.normalize_name(row["field_name"]) + ) return list(indexes.values()) diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py index e8da6e1b73..5bf3d2c49b 100644 --- a/lib/sqlalchemy/dialects/firebird/fdb.py +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -73,25 +73,23 @@ from ... import util class FBDialect_fdb(FBDialect_kinterbasdb): - - def __init__(self, enable_rowcount=True, - retaining=False, **kwargs): + def __init__(self, enable_rowcount=True, retaining=False, **kwargs): super(FBDialect_fdb, self).__init__( - enable_rowcount=enable_rowcount, - retaining=retaining, **kwargs) + enable_rowcount=enable_rowcount, retaining=retaining, **kwargs + ) @classmethod def dbapi(cls): - return __import__('fdb') + return __import__("fdb") def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] opts.update(url.query) - util.coerce_kw_type(opts, 'type_conv', int) + util.coerce_kw_type(opts, "type_conv", int) return ([], opts) @@ -115,4 +113,5 @@ class FBDialect_fdb(FBDialect_kinterbasdb): return self._parse_version_info(version) + dialect = FBDialect_fdb diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py index dc88fc8499..6d71440966 100644 --- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -51,6 +51,7 @@ class _kinterbasdb_numeric(object): return str(value) else: return value + return process @@ -65,15 +66,16 @@ class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float): class FBExecutionContext_kinterbasdb(FBExecutionContext): @property def rowcount(self): - if self.execution_options.get('enable_rowcount', - self.dialect.enable_rowcount): + if self.execution_options.get( + "enable_rowcount", self.dialect.enable_rowcount + ): return self.cursor.rowcount else: return -1 class FBDialect_kinterbasdb(FBDialect): - driver = 'kinterbasdb' + driver = "kinterbasdb" supports_sane_rowcount = False supports_sane_multi_rowcount = False execution_ctx_cls = FBExecutionContext_kinterbasdb @@ -85,13 +87,17 @@ class FBDialect_kinterbasdb(FBDialect): { sqltypes.Numeric: _FBNumeric_kinterbasdb, sqltypes.Float: _FBFloat_kinterbasdb, - } - + }, ) - def __init__(self, type_conv=200, concurrency_level=1, - enable_rowcount=True, - retaining=False, **kwargs): + def __init__( + self, + type_conv=200, + concurrency_level=1, + enable_rowcount=True, + retaining=False, + **kwargs + ): super(FBDialect_kinterbasdb, self).__init__(**kwargs) self.enable_rowcount = enable_rowcount self.type_conv = type_conv @@ -102,7 +108,7 @@ class FBDialect_kinterbasdb(FBDialect): @classmethod def dbapi(cls): - return __import__('kinterbasdb') + return __import__("kinterbasdb") def do_execute(self, cursor, statement, parameters, context=None): # kinterbase does not accept a None, but wants an empty list @@ -116,28 +122,30 @@ class FBDialect_kinterbasdb(FBDialect): dbapi_connection.commit(self.retaining) def 
create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] opts.update(url.query) - util.coerce_kw_type(opts, 'type_conv', int) + util.coerce_kw_type(opts, "type_conv", int) - type_conv = opts.pop('type_conv', self.type_conv) - concurrency_level = opts.pop('concurrency_level', - self.concurrency_level) + type_conv = opts.pop("type_conv", self.type_conv) + concurrency_level = opts.pop( + "concurrency_level", self.concurrency_level + ) if self.dbapi is not None: - initialized = getattr(self.dbapi, 'initialized', None) + initialized = getattr(self.dbapi, "initialized", None) if initialized is None: # CVS rev 1.96 changed the name of the attribute: # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/ # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 - initialized = getattr(self.dbapi, '_initialized', False) + initialized = getattr(self.dbapi, "_initialized", False) if not initialized: - self.dbapi.init(type_conv=type_conv, - concurrency_level=concurrency_level) + self.dbapi.init( + type_conv=type_conv, concurrency_level=concurrency_level + ) return ([], opts) def _get_server_version_info(self, connection): @@ -160,25 +168,31 @@ class FBDialect_kinterbasdb(FBDialect): def _parse_version_info(self, version): m = match( - r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version) + r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version + ) if not m: raise AssertionError( - "Could not determine version from string '%s'" % version) + "Could not determine version from string '%s'" % version + ) if m.group(5) != None: - return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird']) + return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"]) else: - return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase']) + return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"]) def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): msg = str(e) - return ('Unable to complete network request to host' in msg or - 'Invalid connection state' in msg or - 'Invalid cursor state' in msg or - 'connection shutdown' in msg) + return ( + "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + or "connection shutdown" in msg + ) else: return False + dialect = FBDialect_kinterbasdb diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index 9c861e89db..88a94fcfb5 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -7,20 +7,74 @@ from . 
import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc # noqa -from .base import \ - INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \ - NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\ - DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \ - BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP, ROWVERSION, \ - MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, XML +from .base import ( + INTEGER, + BIGINT, + SMALLINT, + TINYINT, + VARCHAR, + NVARCHAR, + CHAR, + NCHAR, + TEXT, + NTEXT, + DECIMAL, + NUMERIC, + FLOAT, + DATETIME, + DATETIME2, + DATETIMEOFFSET, + DATE, + TIME, + SMALLDATETIME, + BINARY, + VARBINARY, + BIT, + REAL, + IMAGE, + TIMESTAMP, + ROWVERSION, + MONEY, + SMALLMONEY, + UNIQUEIDENTIFIER, + SQL_VARIANT, + XML, +) base.dialect = dialect = pyodbc.dialect __all__ = ( - 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR', - 'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME', - 'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME', - 'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP', 'ROWVERSION', - 'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'XML', 'dialect' + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "TEXT", + "NTEXT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DATETIME", + "DATETIME2", + "DATETIMEOFFSET", + "DATE", + "TIME", + "SMALLDATETIME", + "BINARY", + "VARBINARY", + "BIT", + "REAL", + "IMAGE", + "TIMESTAMP", + "ROWVERSION", + "MONEY", + "SMALLMONEY", + "UNIQUEIDENTIFIER", + "SQL_VARIANT", + "XML", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py index e5bb9ba57f..d985c3bb69 100644 --- a/lib/sqlalchemy/dialects/mssql/adodbapi.py +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -33,6 +33,7 @@ class MSDateTime_adodbapi(MSDateTime): if type(value) is datetime.date: return datetime.datetime(value.year, value.month, value.day) return value + return process @@ -41,18 +42,16 @@ class MSDialect_adodbapi(MSDialect): supports_sane_multi_rowcount = True supports_unicode = sys.maxunicode == 65535 supports_unicode_statements = True - driver = 'adodbapi' + driver = "adodbapi" @classmethod def import_dbapi(cls): import adodbapi as module + return module colspecs = util.update_copy( - MSDialect.colspecs, - { - sqltypes.DateTime: MSDateTime_adodbapi - } + MSDialect.colspecs, {sqltypes.DateTime: MSDateTime_adodbapi} ) def create_connect_args(self, url): @@ -61,14 +60,13 @@ class MSDialect_adodbapi(MSDialect): token = "'%s'" % token return token - keys = dict( - (k, check_quote(v)) for k, v in url.query.items() - ) + keys = dict((k, check_quote(v)) for k, v in url.query.items()) connectors = ["Provider=SQLOLEDB"] - if 'port' in keys: - connectors.append("Data Source=%s, %s" % - (keys.get("host"), keys.get("port"))) + if "port" in keys: + connectors.append( + "Data Source=%s, %s" % (keys.get("host"), keys.get("port")) + ) else: connectors.append("Data Source=%s" % keys.get("host")) connectors.append("Initial Catalog=%s" % keys.get("database")) @@ -81,7 +79,9 @@ class MSDialect_adodbapi(MSDialect): return [[";".join(connectors)], {}] def is_disconnect(self, e, connection, cursor): - return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \ - "'connection failure'" in str(e) + return isinstance( + e, self.dbapi.adodbapi.DatabaseError + ) and "'connection failure'" in str(e) + dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py 
b/lib/sqlalchemy/dialects/mssql/base.py index 9269225d31..161297015f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -655,9 +655,22 @@ from ...sql import compiler, expression, util as sql_util, quoted_name from ... import engine from ...engine import reflection, default from ... import types as sqltypes -from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ - FLOAT, DATETIME, DATE, BINARY, \ - TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR +from ...types import ( + INTEGER, + BIGINT, + SMALLINT, + DECIMAL, + NUMERIC, + FLOAT, + DATETIME, + DATE, + BINARY, + TEXT, + VARCHAR, + NVARCHAR, + CHAR, + NCHAR, +) from ...util import update_wrapper @@ -672,48 +685,202 @@ MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) RESERVED_WORDS = set( - ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', - 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', - 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', - 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', - 'containstable', 'continue', 'convert', 'create', 'cross', 'current', - 'current_date', 'current_time', 'current_timestamp', 'current_user', - 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', - 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', - 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', - 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', - 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', - 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', - 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', - 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', - 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', - 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', - 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', - 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', - 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', - 'reconfigure', 'references', 'replication', 'restore', 'restrict', - 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', - 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', - 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', - 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', - 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', - 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', - 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', - 'writetext', - ]) + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + 
"for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", + ] +) class REAL(sqltypes.REAL): - __visit_name__ = 'REAL' + __visit_name__ = "REAL" def __init__(self, **kw): # REAL is a synonym for FLOAT(24) on SQL server - kw['precision'] = 24 + kw["precision"] = 24 super(REAL, self).__init__(**kw) class TINYINT(sqltypes.Integer): - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" # MSSQL DATE/TIME types have varied behavior, sometimes returning @@ -721,14 +888,15 @@ class TINYINT(sqltypes.Integer): # filter bind parameters into datetime objects (required by pyodbc, # not sure about other dialects). 
-class _MSDate(sqltypes.Date): +class _MSDate(sqltypes.Date): def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process _reg = re.compile(r"(\d+)-(\d+)-(\d+)") @@ -741,18 +909,16 @@ class _MSDate(sqltypes.Date): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a date value" % (value, )) - return datetime.date(*[ - int(x or 0) - for x in m.groups() - ]) + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) else: return value + return process class TIME(sqltypes.TIME): - def __init__(self, precision=None, **kwargs): self.precision = precision super(TIME, self).__init__() @@ -763,10 +929,12 @@ class TIME(sqltypes.TIME): def process(value): if isinstance(value, datetime.datetime): value = datetime.datetime.combine( - self.__zero_date, value.time()) + self.__zero_date, value.time() + ) elif isinstance(value, datetime.time): value = datetime.datetime.combine(self.__zero_date, value) return value + return process _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") @@ -779,24 +947,26 @@ class TIME(sqltypes.TIME): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a time value" % (value, )) - return datetime.time(*[ - int(x or 0) - for x in m.groups()]) + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) else: return value + return process + + _MSTime = TIME class _DateTimeBase(object): - def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process @@ -805,11 +975,11 @@ class _MSDateTime(_DateTimeBase, sqltypes.DateTime): class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'SMALLDATETIME' + __visit_name__ = "SMALLDATETIME" class DATETIME2(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'DATETIME2' + __visit_name__ = "DATETIME2" def __init__(self, precision=None, **kw): super(DATETIME2, self).__init__(**kw) @@ -818,7 +988,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime): # TODO: is this not an Interval ? 
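# The regex-based result processors above exist because some MSSQL
# drivers hand DATE/TIME values back as strings rather than date
# objects; a rough sketch of the same parse, reusing the _reg pattern
# from _MSDate on a sample string:
import datetime
import re

_reg = re.compile(r"(\d+)-(\d+)-(\d+)")
m = _reg.match("2019-01-06")
assert datetime.date(*[int(x or 0) for x in m.groups()]) == datetime.date(
    2019, 1, 6
)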
class DATETIMEOFFSET(sqltypes.TypeEngine): - __visit_name__ = 'DATETIMEOFFSET' + __visit_name__ = "DATETIMEOFFSET" def __init__(self, precision=None, **kwargs): self.precision = precision @@ -847,7 +1017,7 @@ class TIMESTAMP(sqltypes._Binary): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" # expected by _Binary to be present length = None @@ -866,12 +1036,14 @@ class TIMESTAMP(sqltypes._Binary): def result_processor(self, dialect, coltype): super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) if self.convert_int: + def process(value): value = super_(value) if value is not None: # https://stackoverflow.com/a/30403242/34549 - value = int(codecs.encode(value, 'hex'), 16) + value = int(codecs.encode(value, "hex"), 16) return value + return process else: return super_ @@ -898,7 +1070,7 @@ class ROWVERSION(TIMESTAMP): """ - __visit_name__ = 'ROWVERSION' + __visit_name__ = "ROWVERSION" class NTEXT(sqltypes.UnicodeText): @@ -906,7 +1078,7 @@ class NTEXT(sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" - __visit_name__ = 'NTEXT' + __visit_name__ = "NTEXT" class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): @@ -925,11 +1097,12 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): """ - __visit_name__ = 'VARBINARY' + + __visit_name__ = "VARBINARY" class IMAGE(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' + __visit_name__ = "IMAGE" class XML(sqltypes.Text): @@ -943,19 +1116,20 @@ class XML(sqltypes.Text): .. versionadded:: 1.1.11 """ - __visit_name__ = 'XML' + + __visit_name__ = "XML" class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" class MONEY(sqltypes.TypeEngine): - __visit_name__ = 'MONEY' + __visit_name__ = "MONEY" class SMALLMONEY(sqltypes.TypeEngine): - __visit_name__ = 'SMALLMONEY' + __visit_name__ = "SMALLMONEY" class UNIQUEIDENTIFIER(sqltypes.TypeEngine): @@ -963,7 +1137,8 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class SQL_VARIANT(sqltypes.TypeEngine): - __visit_name__ = 'SQL_VARIANT' + __visit_name__ = "SQL_VARIANT" + # old names. 
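# A small worked example of the hex round-trip that
# TIMESTAMP(convert_int=True) above uses to turn the 8-byte rowversion
# binary into an integer (the byte value here is illustrative):
import codecs

raw = b"\x00\x00\x00\x00\x00\x00\x00\x2a"
assert int(codecs.encode(raw, "hex"), 16) == 42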
MSDateTime = _MSDateTime @@ -990,36 +1165,36 @@ MSUniqueIdentifier = UNIQUEIDENTIFIER MSVariant = SQL_VARIANT ischema_names = { - 'int': INTEGER, - 'bigint': BIGINT, - 'smallint': SMALLINT, - 'tinyint': TINYINT, - 'varchar': VARCHAR, - 'nvarchar': NVARCHAR, - 'char': CHAR, - 'nchar': NCHAR, - 'text': TEXT, - 'ntext': NTEXT, - 'decimal': DECIMAL, - 'numeric': NUMERIC, - 'float': FLOAT, - 'datetime': DATETIME, - 'datetime2': DATETIME2, - 'datetimeoffset': DATETIMEOFFSET, - 'date': DATE, - 'time': TIME, - 'smalldatetime': SMALLDATETIME, - 'binary': BINARY, - 'varbinary': VARBINARY, - 'bit': BIT, - 'real': REAL, - 'image': IMAGE, - 'xml': XML, - 'timestamp': TIMESTAMP, - 'money': MONEY, - 'smallmoney': SMALLMONEY, - 'uniqueidentifier': UNIQUEIDENTIFIER, - 'sql_variant': SQL_VARIANT, + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, } @@ -1030,8 +1205,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): """ - if getattr(type_, 'collation', None): - collation = 'COLLATE %s' % type_.collation + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation else: collation = None @@ -1041,15 +1216,14 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): if length: spec = spec + "(%s)" % length - return ' '.join([c for c in (spec, collation) - if c is not None]) + return " ".join([c for c in (spec, collation) if c is not None]) def visit_FLOAT(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is None: return "FLOAT" else: - return "FLOAT(%(precision)s)" % {'precision': precision} + return "FLOAT(%(precision)s)" % {"precision": precision} def visit_TINYINT(self, type_, **kw): return "TINYINT" @@ -1061,7 +1235,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "DATETIMEOFFSET" def visit_TIME(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "TIME(%s)" % precision else: @@ -1074,7 +1248,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "ROWVERSION" def visit_DATETIME2(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "DATETIME2(%s)" % precision else: @@ -1105,7 +1279,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("TEXT", type_) def visit_VARCHAR(self, type_, **kw): - return self._extend("VARCHAR", type_, length=type_.length or 'max') + return self._extend("VARCHAR", type_, length=type_.length or "max") def visit_CHAR(self, type_, **kw): return self._extend("CHAR", type_) @@ -1114,7 +1288,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("NCHAR", type_) def visit_NVARCHAR(self, type_, **kw): - return self._extend("NVARCHAR", type_, length=type_.length or 'max') + return self._extend("NVARCHAR", type_, length=type_.length or 
"max") def visit_date(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: @@ -1141,10 +1315,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "XML" def visit_VARBINARY(self, type_, **kw): - return self._extend( - "VARBINARY", - type_, - length=type_.length or 'max') + return self._extend("VARBINARY", type_, length=type_.length or "max") def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) @@ -1156,13 +1327,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "MONEY" def visit_SMALLMONEY(self, type_, **kw): - return 'SMALLMONEY' + return "SMALLMONEY" def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" def visit_SQL_VARIANT(self, type_, **kw): - return 'SQL_VARIANT' + return "SQL_VARIANT" class MSExecutionContext(default.DefaultExecutionContext): @@ -1186,41 +1357,44 @@ class MSExecutionContext(default.DefaultExecutionContext): insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = \ - seq_column.key in self.compiled_parameters[0] or \ - ( - self.compiled.statement.parameters and ( - ( - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters[0] - ) or ( - not - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters - ) + self._enable_identity_insert = seq_column.key in self.compiled_parameters[ + 0 + ] or ( + self.compiled.statement.parameters + and ( + ( + self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters[0] + ) + or ( + not self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters ) ) + ) else: self._enable_identity_insert = False - self._select_lastrowid = not self.compiled.inline and \ - insert_has_sequence and \ - not self.compiled.returning and \ - not self._enable_identity_insert and \ - not self.executemany + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_sequence + and not self.compiled.returning + and not self._enable_identity_insert + and not self.executemany + ) if self._enable_identity_insert: self.root_connection._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)), + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ), (), - self) + self, + ) def post_exec(self): """Disable IDENTITY_INSERT if enabled.""" @@ -1230,29 +1404,35 @@ class MSExecutionContext(default.DefaultExecutionContext): if self.dialect.use_scope_identity: conn._cursor_execute( self.cursor, - "SELECT scope_identity() AS lastrowid", (), self) + "SELECT scope_identity() AS lastrowid", + (), + self, + ) else: - conn._cursor_execute(self.cursor, - "SELECT @@identity AS lastrowid", - (), - self) + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) # fetchall() ensures the cursor is consumed without closing it row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) - if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + if ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: conn._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. 
format_table( - self.compiled.statement.table)), + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ), (), - self) + self, + ) def get_lastrowid(self): return self._lastrowid @@ -1262,9 +1442,12 @@ class MSExecutionContext(default.DefaultExecutionContext): try: self.cursor.execute( self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. format_table( - self.compiled.statement.table))) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ) + ) except Exception: pass @@ -1281,11 +1464,12 @@ class MSSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) def __init__(self, *args, **kwargs): self.tablealiases = {} @@ -1298,6 +1482,7 @@ class MSSQLCompiler(compiler.SQLCompiler): else: super_ = getattr(super(MSSQLCompiler, self), fn.__name__) return super_(*arg, **kw) + return decorate def visit_now_func(self, fn, **kw): @@ -1313,20 +1498,22 @@ class MSSQLCompiler(compiler.SQLCompiler): return "LEN%s" % self.function_argspec(fn, **kw) def visit_concat_op_binary(self, binary, operator, **kw): - return "%s + %s" % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "%s + %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def get_select_precolumns(self, select, **kw): """ MS-SQL puts TOP, it's version of LIMIT here """ @@ -1345,7 +1532,8 @@ class MSSQLCompiler(compiler.SQLCompiler): return s else: return compiler.SQLCompiler.get_select_precolumns( - self, select, **kw) + self, select, **kw + ) def get_from_hint_text(self, table, text): return text @@ -1363,20 +1551,21 @@ class MSSQLCompiler(compiler.SQLCompiler): """ if ( - ( - not select._simple_int_limit and - select._limit_clause is not None - ) or ( - select._offset_clause is not None and - not select._simple_int_offset or select._offset + (not select._simple_int_limit and select._limit_clause is not None) + or ( + select._offset_clause is not None + and not select._simple_int_offset + or select._offset ) - ) and not getattr(select, '_mssql_visit', None): + ) and not getattr(select, "_mssql_visit", None): # to use ROW_NUMBER(), an ORDER BY is required. 
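# Roughly the shape of the wrapped query this visit_select() emulation
# produces for LIMIT/OFFSET on SQL Server versions without native
# OFFSET support; table, column, and parameter names are illustrative:
#
#   SELECT anon.id, anon.name FROM (
#       SELECT t.id AS id, t.name AS name,
#              ROW_NUMBER() OVER (ORDER BY t.id) AS mssql_rn
#       FROM t
#   ) AS anon
#   WHERE anon.mssql_rn > :param_1 AND anon.mssql_rn <= :param_2 + :param_1
#
# The outer SELECT drops the mssql_rn column again, and the window
# function is what makes the ORDER BY checked below mandatory.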
if not select._order_by_clause.clauses: - raise exc.CompileError('MSSQL requires an order_by when ' - 'using an OFFSET or a non-simple ' - 'LIMIT clause') + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) _order_by_clauses = [ sql_util.unwrap_label_reference(elem) @@ -1385,24 +1574,31 @@ class MSSQLCompiler(compiler.SQLCompiler): limit_clause = select._limit_clause offset_clause = select._offset_clause - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._mssql_visit = True - select = select.column( - sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) - .label("mssql_rn")).order_by(None).alias() + select = ( + select.column( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) - mssql_rn = sql.column('mssql_rn') - limitselect = sql.select([c for c in select.c if - c.key != 'mssql_rn']) + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + [c for c in select.c if c.key != "mssql_rn"] + ) if offset_clause is not None: limitselect.append_whereclause(mssql_rn > offset_clause) if limit_clause is not None: limitselect.append_whereclause( - mssql_rn <= (limit_clause + offset_clause)) + mssql_rn <= (limit_clause + offset_clause) + ) else: - limitselect.append_whereclause( - mssql_rn <= (limit_clause)) + limitselect.append_whereclause(mssql_rn <= (limit_clause)) return self.process(limitselect, **kwargs) else: return compiler.SQLCompiler.visit_select(self, select, **kwargs) @@ -1422,35 +1618,38 @@ class MSSQLCompiler(compiler.SQLCompiler): @_with_legacy_schema_aliasing def visit_alias(self, alias, **kw): # translate for schema-qualified table aliases - kw['mssql_aliased'] = alias.original + kw["mssql_aliased"] = alias.original return super(MSSQLCompiler, self).visit_alias(alias, **kw) @_with_legacy_schema_aliasing def visit_column(self, column, add_to_result_map=None, **kw): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or \ - self.is_subquery(): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: converted = expression._corresponding_column_or_error( - t, column) + t, column + ) if add_to_result_map is not None: add_to_result_map( column.name, column.name, (column, column.name, column.key), - column.type + column.type, ) - return super(MSSQLCompiler, self).\ - visit_column(converted, **kw) + return super(MSSQLCompiler, self).visit_column(converted, **kw) return super(MSSQLCompiler, self).visit_column( - column, add_to_result_map=add_to_result_map, **kw) + column, add_to_result_map=add_to_result_map, **kw + ) def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: + if getattr(table, "schema", None) is not None: if table not in self.tablealiases: self.tablealiases[table] = table.alias() return self.tablealiases[table] @@ -1459,16 +1658,17 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART(%s, %s)' % \ - (field, self.process(extract.expr, **kw)) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) def visit_savepoint(self, savepoint_stmt): - return "SAVE TRANSACTION %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return 
"SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_rollback_to_savepoint(self, savepoint_stmt): - return ("ROLLBACK TRANSACTION %s" - % self.preparer.format_savepoint(savepoint_stmt)) + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where @@ -1481,10 +1681,11 @@ class MSSQLCompiler(compiler.SQLCompiler): and not isinstance(binary.right, expression.BindParameter) ): return self.process( - expression.BinaryExpression(binary.right, - binary.left, - binary.operator), - **kwargs) + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs + ) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def returning_clause(self, stmt, returning_cols): @@ -1497,12 +1698,13 @@ class MSSQLCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(target) columns = [ - self._label_select_column(None, adapter.traverse(c), - True, False, {}) + self._label_select_column( + None, adapter.traverse(c), True, False, {} + ) for c in expression._select_iterables(returning_cols) ] - return 'OUTPUT ' + ', '.join(columns) + return "OUTPUT " + ", ".join(columns) def get_cte_preamble(self, recursive): # SQL Server finds it too inconvenient to accept @@ -1515,13 +1717,14 @@ class MSSQLCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label(None) else: - return super(MSSQLCompiler, self).\ - label_select_column(select, column, asfrom) + return super(MSSQLCompiler, self).label_select_column( + select, column, asfrom + ) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which # SQLAlchemy doesn't use - return '' + return "" def order_by_clause(self, select, **kw): order_by = self.process(select._order_by_clause, **kw) @@ -1532,10 +1735,9 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the UPDATE..FROM clause specific to MSSQL. In MSSQL, if the UPDATE statement involves an alias of the table to @@ -1543,13 +1745,12 @@ class MSSQLCompiler(compiler.SQLCompiler): well. Otherwise, it is optional. Here, we add it regardless. """ - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1558,20 +1759,21 @@ class MSSQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. FROM clause specific to MSSQL. Yes, it has the FROM keyword twice. 
""" - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 WHERE 1!=1' + return "SELECT 1 WHERE 1!=1" class MSSQLStrictCompiler(MSSQLCompiler): @@ -1583,20 +1785,21 @@ class MSSQLStrictCompiler(MSSQLCompiler): binds are used. """ + ansi_bind_rules = True def visit_in_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_notin_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s NOT IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def render_literal_value(self, value, type_): @@ -1615,23 +1818,28 @@ class MSSQLStrictCompiler(MSSQLCompiler): # SQL Server wants single quotes around the date string. return "'" + str(value) + "'" else: - return super(MSSQLStrictCompiler, self).\ - render_literal_value(value, type_) + return super(MSSQLStrictCompiler, self).render_literal_value( + value, type_ + ) class MSDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): colspec = ( - self.preparer.format_column(column) + " " + self.preparer.format_column(column) + + " " + self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) ) if column.nullable is not None: - if not column.nullable or column.primary_key or \ - isinstance(column.default, sa_schema.Sequence) or \ - column.autoincrement is True: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + ): colspec += " NOT NULL" else: colspec += " NULL" @@ -1639,15 +1847,18 @@ class MSDDLCompiler(compiler.DDLCompiler): if column.table is None: raise exc.CompileError( "mssql requires Table-bound columns " - "in order to generate DDL") + "in order to generate DDL" + ) # install an IDENTITY Sequence if we either a sequence or an implicit # IDENTITY column if isinstance(column.default, sa_schema.Sequence): - if (column.default.start is not None or - column.default.increment is not None or - column is not column.table._autoincrement_column): + if ( + column.default.start is not None + or column.default.increment is not None + or column is not column.table._autoincrement_column + ): util.warn_deprecated( "Use of Sequence with SQL Server in order to affect the " "parameters of the IDENTITY value is deprecated, as " @@ -1655,18 +1866,23 @@ class MSDDLCompiler(compiler.DDLCompiler): "will correspond to an actual SQL Server " "CREATE SEQUENCE in " "a future release. Please use the mssql_identity_start " - "and mssql_identity_increment parameters.") + "and mssql_identity_increment parameters." 
+ ) if column.default.start == 0: start = 0 else: start = column.default.start or 1 - colspec += " IDENTITY(%s,%s)" % (start, - column.default.increment or 1) - elif column is column.table._autoincrement_column or \ - column.autoincrement is True: - start = column.dialect_options['mssql']['identity_start'] - increment = column.dialect_options['mssql']['identity_increment'] + colspec += " IDENTITY(%s,%s)" % ( + start, + column.default.increment or 1, + ) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ): + start = column.dialect_options["mssql"]["identity_start"] + increment = column.dialect_options["mssql"]["identity_increment"] colspec += " IDENTITY(%s,%s)" % (start, increment) else: default = self.get_column_default_string(column) @@ -1684,84 +1900,88 @@ class MSDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " # handle clustering option - clustered = index.dialect_options['mssql']['clustered'] + clustered = index.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table), - ', '.join( - self.sql_compiler.process(expr, - include_table=False, - literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) # handle other included columns - if index.dialect_options['mssql']['include']: - inclusions = [index.table.c[col] - if isinstance(col, util.string_types) else col - for col in - index.dialect_options['mssql']['include'] - ] + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] + if isinstance(col, util.string_types) + else col + for col in index.dialect_options["mssql"]["include"] + ] - text += " INCLUDE (%s)" \ - % ', '.join([preparer.quote(c.name) - for c in inclusions]) + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) return text def visit_drop_index(self, drop): return "\nDROP INDEX %s ON %s" % ( self._prepared_index_name(drop.element, include_schema=False), - self.preparer.format_table(drop.element.table) + self.preparer.format_table(drop.element.table), ) def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "PRIMARY KEY " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + 
constraint + ) text += "UNIQUE " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text @@ -1771,8 +1991,11 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(MSIdentifierPreparer, self).__init__( - dialect, initial_quote='[', - final_quote=']', quote_case_sensitive_collations=False) + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) def _escape_identifier(self, value): return value @@ -1783,7 +2006,9 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): dbname, owner = _schema_elements(schema) if dbname: result = "%s.%s" % ( - self.quote(dbname, force), self.quote(owner, force)) + self.quote(dbname, force), + self.quote(owner, force), + ) elif owner: result = self.quote(owner, force) else: @@ -1794,16 +2019,37 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def _db_plus_owner_listing(fn): def wrap(dialect, connection, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) def _db_plus_owner(fn): def wrap(dialect, connection, tablename, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - tablename, dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) @@ -1837,9 +2083,9 @@ def _schema_elements(schema): for token in re.split(r"(\[|\]|\.)", schema): if not token: continue - if token == '[': + if token == "[": bracket = True - elif token == ']': + elif token == "]": bracket = False elif not bracket and token == ".": push.append(symbol) @@ -1857,7 +2103,7 @@ def _schema_elements(schema): class MSDialect(default.DefaultDialect): - name = 'mssql' + name = "mssql" supports_default_values = True supports_empty_insert = False execution_ctx_cls = MSExecutionContext @@ -1871,9 +2117,9 @@ class MSDialect(default.DefaultDialect): sqltypes.Time: TIME, } - engine_config_types = default.DefaultDialect.engine_config_types.union([ - ('legacy_schema_aliasing', util.asbool), - ]) + engine_config_types = default.DefaultDialect.engine_config_types.union( + [("legacy_schema_aliasing", util.asbool)] + ) ischema_names = ischema_names @@ -1890,36 +2136,30 @@ class MSDialect(default.DefaultDialect): preparer = MSIdentifierPreparer construct_arguments = [ - (sa_schema.PrimaryKeyConstraint, { - "clustered": None - }), - (sa_schema.UniqueConstraint, { - "clustered": None - }), - (sa_schema.Index, { - "clustered": None, - "include": None - }), - (sa_schema.Column, { - "identity_start": 1, - "identity_increment": 1 - }) + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + (sa_schema.Index, {"clustered": None, "include": None}), + (sa_schema.Column, {"identity_start": 1, "identity_increment": 1}), ] - 
def __init__(self, - query_timeout=None, - use_scope_identity=True, - max_identifier_length=None, - schema_name="dbo", - isolation_level=None, - deprecate_large_types=None, - legacy_schema_aliasing=False, **opts): + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", + isolation_level=None, + deprecate_large_types=None, + legacy_schema_aliasing=False, + **opts + ): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = use_scope_identity - self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length + self.max_identifier_length = ( + int(max_identifier_length or 0) or self.max_identifier_length + ) self.deprecate_large_types = deprecate_large_types self.legacy_schema_aliasing = legacy_schema_aliasing @@ -1936,27 +2176,33 @@ class MSDialect(default.DefaultDialect): # SQL Server does not support RELEASE SAVEPOINT pass - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'SNAPSHOT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() - cursor.execute( - "SET TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level) cursor.close() def get_isolation_level(self, connection): if self.server_version_info < MS_2005_VERSION: raise NotImplementedError( - "Can't fetch isolation level prior to SQL Server 2005") + "Can't fetch isolation level prior to SQL Server 2005" + ) last_error = None @@ -1964,7 +2210,8 @@ class MSDialect(default.DefaultDialect): for view in views: cursor = connection.cursor() try: - cursor.execute(""" + cursor.execute( + """ SELECT CASE transaction_isolation_level WHEN 0 THEN NULL WHEN 1 THEN 'READ UNCOMMITTED' @@ -1974,7 +2221,9 @@ class MSDialect(default.DefaultDialect): WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL FROM %s where session_id = @@SPID - """ % view) + """ + % view + ) val = cursor.fetchone()[0] except self.dbapi.Error as err: # Python3 scoping rules @@ -1987,7 +2236,8 @@ class MSDialect(default.DefaultDialect): else: util.warn( "Could not fetch transaction isolation level, " - "tried views: %s; final error was: %s" % (views, last_error)) + "tried views: %s; final error was: %s" % (views, last_error) + ) raise NotImplementedError( "Can't fetch isolation level on this particular " @@ -2000,8 +2250,10 @@ class MSDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -2010,16 +2262,20 @@ class MSDialect(default.DefaultDialect): if self.server_version_info[0] not in list(range(8, 17)): util.warn( "Unrecognized server version info '%s'. Some SQL Server " - "features may not function properly." 
% - ".".join(str(x) for x in self.server_version_info)) - if self.server_version_info >= MS_2005_VERSION and \ - 'implicit_returning' not in self.__dict__: + "features may not function properly." + % ".".join(str(x) for x in self.server_version_info) + ) + if ( + self.server_version_info >= MS_2005_VERSION + and "implicit_returning" not in self.__dict__ + ): self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: - self.deprecate_large_types = \ + self.deprecate_large_types = ( self.server_version_info >= MS_2012_VERSION + ) def _get_default_schema_name(self, connection): if self.server_version_info < MS_2005_VERSION: @@ -2039,17 +2295,19 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name == tablename if owner: - whereclause = sql.and_(whereclause, - columns.c.table_schema == owner) + whereclause = sql.and_( + whereclause, columns.c.table_schema == owner + ) s = sql.select([columns], whereclause) c = connection.execute(s) return c.first() is not None @reflection.cache def get_schema_names(self, connection, **kw): - s = sql.select([ischema.schemata.c.schema_name], - order_by=[ischema.schemata.c.schema_name] - ) + s = sql.select( + [ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name], + ) schema_names = [r[0] for r in connection.execute(s)] return schema_names @@ -2057,12 +2315,13 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_table_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'BASE TABLE' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ), + order_by=[tables.c.table_name], ) table_names = [r[0] for r in connection.execute(s)] return table_names @@ -2071,12 +2330,12 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_view_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'VIEW' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, tables.c.table_type == "VIEW" + ), + order_by=[tables.c.table_name], ) view_names = [r[0] for r in connection.execute(s)] return view_names @@ -2090,30 +2349,33 @@ class MSDialect(default.DefaultDialect): return [] rp = connection.execute( - sql.text("select ind.index_id, ind.is_unique, ind.name " - "from sys.indexes as ind join sys.tables as tab on " - "ind.object_id=tab.object_id " - "join sys.schemas as sch on sch.schema_id=tab.schema_id " - "where tab.name = :tabname " - "and sch.name=:schname " - "and ind.is_primary_key=0 and ind.type != 0", - bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ], - typemap={ - 'name': sqltypes.Unicode() - } - ) + sql.text( + "select ind.index_id, ind.is_unique, ind.name " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0 and ind.type != 0", + bindparams=[ + 
sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], + typemap={"name": sqltypes.Unicode()}, + ) ) indexes = {} for row in rp: - indexes[row['index_id']] = { - 'name': row['name'], - 'unique': row['is_unique'] == 1, - 'column_names': [] + indexes[row["index_id"]] = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], } rp = connection.execute( sql.text( @@ -2127,24 +2389,29 @@ class MSDialect(default.DefaultDialect): "where tab.name=:tabname " "and sch.name=:schname", bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), ], - typemap={'name': sqltypes.Unicode()} - ), + typemap={"name": sqltypes.Unicode()}, + ) ) for row in rp: - if row['index_id'] in indexes: - indexes[row['index_id']]['column_names'].append(row['name']) + if row["index_id"] in indexes: + indexes[row["index_id"]]["column_names"].append(row["name"]) return list(indexes.values()) @reflection.cache @_db_plus_owner - def get_view_definition(self, connection, viewname, - dbname, owner, schema, **kw): + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): rp = connection.execute( sql.text( "select definition from sys.sql_modules as mod, " @@ -2155,11 +2422,15 @@ class MSDialect(default.DefaultDialect): "views.schema_id=sch.schema_id and " "views.name=:viewname and sch.name=:schname", bindparams=[ - sql.bindparam('viewname', viewname, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ] + sql.bindparam( + "viewname", + viewname, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], ) ) @@ -2173,12 +2444,15 @@ class MSDialect(default.DefaultDialect): # Get base columns columns = ischema.columns if owner: - whereclause = sql.and_(columns.c.table_name == tablename, - columns.c.table_schema == owner) + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) else: whereclause = columns.c.table_name == tablename - s = sql.select([columns], whereclause, - order_by=[columns.c.ordinal_position]) + s = sql.select( + [columns], whereclause, order_by=[columns.c.ordinal_position] + ) c = connection.execute(s) cols = [] @@ -2186,57 +2460,76 @@ class MSDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - (name, type, nullable, charlen, - numericprec, numericscale, default, collation) = ( + ( + name, + type, + nullable, + charlen, + numericprec, + numericscale, + default, + collation, + ) = ( row[columns.c.column_name], row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', + row[columns.c.is_nullable] == "YES", row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], row[columns.c.column_default], - row[columns.c.collation_name] + row[columns.c.collation_name], ) coltype = self.ischema_names.get(type, None) kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, - MSNText, MSBinary, MSVarBinary, - sqltypes.LargeBinary): + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + 
MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): if charlen == -1: charlen = None - kwargs['length'] = charlen + kwargs["length"] = charlen if collation: - kwargs['collation'] = collation + kwargs["collation"] = collation if coltype is None: util.warn( - "Did not recognize type '%s' of column '%s'" % - (type, name)) + "Did not recognize type '%s' of column '%s'" % (type, name) + ) coltype = sqltypes.NULLTYPE else: - if issubclass(coltype, sqltypes.Numeric) and \ - coltype is not MSReal: - kwargs['scale'] = numericscale - kwargs['precision'] = numericprec + if ( + issubclass(coltype, sqltypes.Numeric) + and coltype is not MSReal + ): + kwargs["scale"] = numericscale + kwargs["precision"] = numericprec coltype = coltype(**kwargs) cdict = { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': False, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": False, } cols.append(cdict) # autoincrement and identity colmap = {} for col in cols: - colmap[col['name']] = col + colmap[col["name"]] = col # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns @table_name = '%s', " - "@table_owner = '%s'" - % (tablename, owner)) + cursor = connection.execute( + "sp_columns @table_name = '%s', " + "@table_owner = '%s'" % (tablename, owner) + ) ic = None while True: row = cursor.fetchone() @@ -2245,10 +2538,10 @@ class MSDialect(default.DefaultDialect): (col_name, type_name) = row[3], row[5] if type_name.endswith("identity") and col_name in colmap: ic = col_name - colmap[col_name]['autoincrement'] = True - colmap[col_name]['dialect_options'] = { - 'mssql_identity_start': 1, - 'mssql_identity_increment': 1 + colmap[col_name]["autoincrement"] = True + colmap[col_name]["dialect_options"] = { + "mssql_identity_start": 1, + "mssql_identity_increment": 1, } break cursor.close() @@ -2262,64 +2555,74 @@ class MSDialect(default.DefaultDialect): row = cursor.first() if row is not None and row[0] is not None: - colmap[ic]['dialect_options'].update({ - 'mssql_identity_start': int(row[0]), - 'mssql_identity_increment': int(row[1]) - }) + colmap[ic]["dialect_options"].update( + { + "mssql_identity_start": int(row[0]), + "mssql_identity_increment": int(row[1]), + } + ) return cols @reflection.cache @_db_plus_owner - def get_pk_constraint(self, connection, tablename, - dbname, owner, schema, **kw): + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): pkeys = [] TC = ischema.constraints - C = ischema.key_constraints.alias('C') + C = ischema.key_constraints.alias("C") # Primary key constraints - s = sql.select([C.c.column_name, - TC.c.constraint_type, - C.c.constraint_name], - sql.and_(TC.c.constraint_name == C.c.constraint_name, - TC.c.table_schema == C.c.table_schema, - C.c.table_name == tablename, - C.c.table_schema == owner) - ) + s = sql.select( + [C.c.column_name, TC.c.constraint_type, C.c.constraint_name], + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) c = connection.execute(s) constraint_name = None for row in c: - if 'PRIMARY' in row[TC.c.constraint_type.name]: + if "PRIMARY" in row[TC.c.constraint_type.name]: pkeys.append(row[0]) if constraint_name is None: constraint_name = row[C.c.constraint_name.name] - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, 
"name": constraint_name} @reflection.cache @_db_plus_owner - def get_foreign_keys(self, connection, tablename, - dbname, owner, schema, **kw): + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): RR = ischema.ref_constraints - C = ischema.key_constraints.alias('C') - R = ischema.key_constraints.alias('R') + C = ischema.key_constraints.alias("C") + R = ischema.key_constraints.alias("R") # Foreign key constraints - s = sql.select([C.c.column_name, - R.c.table_schema, R.c.table_name, R.c.column_name, - RR.c.constraint_name, RR.c.match_option, - RR.c.update_rule, - RR.c.delete_rule], - sql.and_(C.c.table_name == tablename, - C.c.table_schema == owner, - RR.c.constraint_schema == C.c.table_schema, - C.c.constraint_name == RR.c.constraint_name, - R.c.constraint_name == - RR.c.unique_constraint_name, - R.c.constraint_schema == - RR.c.unique_constraint_schema, - C.c.ordinal_position == R.c.ordinal_position - ), - order_by=[RR.c.constraint_name, R.c.ordinal_position] - ) + s = sql.select( + [ + C.c.column_name, + R.c.table_schema, + R.c.table_name, + R.c.column_name, + RR.c.constraint_name, + RR.c.match_option, + RR.c.update_rule, + RR.c.delete_rule, + ], + sql.and_( + C.c.table_name == tablename, + C.c.table_schema == owner, + RR.c.constraint_schema == C.c.table_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + R.c.constraint_schema == RR.c.unique_constraint_schema, + C.c.ordinal_position == R.c.ordinal_position, + ), + order_by=[RR.c.constraint_name, R.c.ordinal_position], + ) # group rows by constraint ID, to handle multi-column FKs fkeys = [] @@ -2327,11 +2630,11 @@ class MSDialect(default.DefaultDialect): def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], } fkeys = util.defaultdict(fkey_rec) @@ -2340,17 +2643,18 @@ class MSDialect(default.DefaultDialect): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r rec = fkeys[rfknm] - rec['name'] = rfknm - if not rec['referred_table']: - rec['referred_table'] = rtbl + rec["name"] = rfknm + if not rec["referred_table"]: + rec["referred_table"] = rtbl if schema is not None or owner != rschema: if dbname: rschema = dbname + "." + rschema - rec['referred_schema'] = rschema + rec["referred_schema"] = rschema - local_cols, remote_cols = \ - rec['constrained_columns'],\ - rec['referred_columns'] + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) local_cols.append(scol) remote_cols.append(rcol) diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 3682fae481..c4ea8ab0cb 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -38,102 +38,122 @@ class _cast_on_2005(expression.ColumnElement): @compiles(_cast_on_2005) def _compile(element, compiler, **kw): from . 
import base - if compiler.dialect.server_version_info is None or \ - compiler.dialect.server_version_info < base.MS_2005_VERSION: + + if ( + compiler.dialect.server_version_info is None + or compiler.dialect.server_version_info < base.MS_2005_VERSION + ): return compiler.process(element.bindvalue, **kw) else: return compiler.process(cast(element.bindvalue, Unicode), **kw) -schemata = Table("SCHEMATA", ischema, - Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), - Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), - Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), - schema="INFORMATION_SCHEMA") - -tables = Table("TABLES", ischema, - Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column( - "TABLE_TYPE", String(convert_unicode=True), - key="table_type"), - schema="INFORMATION_SCHEMA") - -columns = Table("COLUMNS", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, key="column_name"), - Column("IS_NULLABLE", Integer, key="is_nullable"), - Column("DATA_TYPE", String, key="data_type"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - Column("CHARACTER_MAXIMUM_LENGTH", Integer, - key="character_maximum_length"), - Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), - Column("NUMERIC_SCALE", Integer, key="numeric_scale"), - Column("COLUMN_DEFAULT", Integer, key="column_default"), - Column("COLLATION_NAME", String, key="collation_name"), - schema="INFORMATION_SCHEMA") - -constraints = Table("TABLE_CONSTRAINTS", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - Column("CONSTRAINT_TYPE", String( - convert_unicode=True), key="constraint_type"), - schema="INFORMATION_SCHEMA") - -column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, - key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, - key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, - key="column_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - schema="INFORMATION_SCHEMA") - -key_constraints = Table("KEY_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, - key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, - key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, - key="column_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - Column("CONSTRAINT_SCHEMA", CoerceUnicode, - key="constraint_schema"), - Column("ORDINAL_POSITION", Integer, - key="ordinal_position"), - schema="INFORMATION_SCHEMA") - -ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, - Column("CONSTRAINT_CATALOG", CoerceUnicode, - key="constraint_catalog"), - Column("CONSTRAINT_SCHEMA", CoerceUnicode, - key="constraint_schema"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - # TODO: is CATLOG misspelled ? 
- Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, - key="unique_constraint_catalog"), - - Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, - key="unique_constraint_schema"), - Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, - key="unique_constraint_name"), - Column("MATCH_OPTION", String, key="match_option"), - Column("UPDATE_RULE", String, key="update_rule"), - Column("DELETE_RULE", String, key="delete_rule"), - schema="INFORMATION_SCHEMA") - -views = Table("VIEWS", ischema, - Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), - Column("CHECK_OPTION", String, key="check_option"), - Column("IS_UPDATABLE", String, key="is_updatable"), - schema="INFORMATION_SCHEMA") + +schemata = Table( + "SCHEMATA", + ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA", +) + +tables = Table( + "TABLES", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"), + schema="INFORMATION_SCHEMA", +) + +columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA", +) + +constraints = Table( + "TABLE_CONSTRAINTS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column( + "CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type" + ), + schema="INFORMATION_SCHEMA", +) + +column_constraints = Table( + "CONSTRAINT_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA", +) + +key_constraints = Table( + "KEY_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA", +) + +ref_constraints = Table( + "REFERENTIAL_CONSTRAINTS", + ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + 
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column( + "UNIQUE_CONSTRAINT_CATLOG", + CoerceUnicode, + key="unique_constraint_catalog", + ), + Column( + "UNIQUE_CONSTRAINT_SCHEMA", + CoerceUnicode, + key="unique_constraint_schema", + ), + Column( + "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name" + ), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA", +) + +views = Table( + "VIEWS", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA", +) diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index 8983a3b60a..3b9ea27077 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -46,10 +46,14 @@ of ``False`` will unconditionally use string-escaped parameters. from ... import types as sqltypes from ...connectors.mxodbc import MxODBCConnector from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc -from .base import (MSDialect, - MSSQLStrictCompiler, - VARBINARY, - _MSDateTime, _MSDate, _MSTime) +from .base import ( + MSDialect, + MSSQLStrictCompiler, + VARBINARY, + _MSDateTime, + _MSDate, + _MSTime, +) class _MSNumeric_mxodbc(_MSNumeric_pyodbc): @@ -64,6 +68,7 @@ class _MSDate_mxodbc(_MSDate): return "%s-%s-%s" % (value.year, value.month, value.day) else: return None + return process @@ -74,6 +79,7 @@ class _MSTime_mxodbc(_MSTime): return "%s:%s:%s" % (value.hour, value.minute, value.second) else: return None + return process @@ -98,6 +104,7 @@ class _VARBINARY_mxodbc(VARBINARY): else: # should pull from mx.ODBC.Manager.BinaryNull return dialect.dbapi.BinaryNull + return process @@ -107,6 +114,7 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): SELECT SCOPE_IDENTITY in cases where OUTPUT clause does not work (tables with insert triggers). """ + # todo - investigate whether the pyodbc execution context # is really only being used in cases where OUTPUT # won't work. 
@@ -136,4 +144,5 @@ class MSDialect_mxodbc(MxODBCConnector, MSDialect): super(MSDialect_mxodbc, self).__init__(**params) self.description_encoding = description_encoding + dialect = MSDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 8589c8b06a..847c003292 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -35,7 +35,6 @@ class _MSNumeric_pymssql(sqltypes.Numeric): class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): - def __init__(self, dialect): super(MSIdentifierPreparer_pymssql, self).__init__(dialect) # pymssql has the very unusual behavior that it uses pyformat @@ -45,47 +44,45 @@ class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): class MSDialect_pymssql(MSDialect): supports_native_decimal = True - driver = 'pymssql' + driver = "pymssql" preparer = MSIdentifierPreparer_pymssql colspecs = util.update_copy( MSDialect.colspecs, - { - sqltypes.Numeric: _MSNumeric_pymssql, - sqltypes.Float: sqltypes.Float, - } + {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float}, ) @classmethod def dbapi(cls): - module = __import__('pymssql') + module = __import__("pymssql") # pymmsql < 2.1.1 doesn't have a Binary method. we use string client_ver = tuple(int(x) for x in module.__version__.split(".")) if client_ver < (2, 1, 1): # TODO: monkeypatching here is less than ideal - module.Binary = lambda x: x if hasattr(x, 'decode') else str(x) + module.Binary = lambda x: x if hasattr(x, "decode") else str(x) - if client_ver < (1, ): - util.warn("The pymssql dialect expects at least " - "the 1.0 series of the pymssql DBAPI.") + if client_ver < (1,): + util.warn( + "The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI." + ) return module def _get_server_version_info(self, connection): vers = connection.scalar("select @@version") - m = re.match( - r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers) + m = re.match(r"Microsoft .*? 
- (\d+).(\d+).(\d+).(\d+)", vers) if m: return tuple(int(x) for x in m.group(1, 2, 3, 4)) else: return None def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') + opts = url.translate_connect_args(username="user") opts.update(url.query) - port = opts.pop('port', None) - if port and 'host' in opts: - opts['host'] = "%s:%s" % (opts['host'], port) + port = opts.pop("port", None) + if port and "host" in opts: + opts["host"] = "%s:%s" % (opts["host"], port) return [[], opts] def is_disconnect(self, e, connection, cursor): @@ -105,12 +102,13 @@ class MSDialect_pymssql(MSDialect): return False def set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit(True) else: connection.autocommit(False) - super(MSDialect_pymssql, self).set_isolation_level(connection, - level) + super(MSDialect_pymssql, self).set_isolation_level( + connection, level + ) dialect = MSDialect_pymssql diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 34f81d6e8b..db5573c2c7 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -132,15 +132,13 @@ class _ms_numeric_pyodbc(object): def bind_processor(self, dialect): - super_process = super(_ms_numeric_pyodbc, self).\ - bind_processor(dialect) + super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect) if not dialect._need_decimal_fix: return super_process def process(value): - if self.asdecimal and \ - isinstance(value, decimal.Decimal): + if self.asdecimal and isinstance(value, decimal.Decimal): adjusted = value.adjusted() if adjusted < 0: return self._small_dec_to_string(value) @@ -151,6 +149,7 @@ class _ms_numeric_pyodbc(object): return super_process(value) else: return value + return process # these routines needed for older versions of pyodbc. @@ -158,30 +157,31 @@ class _ms_numeric_pyodbc(object): def _small_dec_to_string(self, value): return "%s0.%s%s" % ( - (value < 0 and '-' or ''), - '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value.as_tuple()[1]])) + (value < 0 and "-" or ""), + "0" * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]]), + ) def _large_dec_to_string(self, value): _int = value.as_tuple()[1] - if 'E' in str(value): + if "E" in str(value): result = "%s%s%s" % ( - (value < 0 and '-' or ''), + (value < 0 and "-" or ""), "".join([str(s) for s in _int]), - "0" * (value.adjusted() - (len(_int) - 1))) + "0" * (value.adjusted() - (len(_int) - 1)), + ) else: if (len(_int) - 1) > value.adjusted(): result = "%s%s.%s" % ( - (value < 0 and '-' or ''), - "".join( - [str(s) for s in _int][0:value.adjusted() + 1]), - "".join( - [str(s) for s in _int][value.adjusted() + 1:])) + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + "".join([str(s) for s in _int][value.adjusted() + 1 :]), + ) else: result = "%s%s" % ( - (value < 0 and '-' or ''), - "".join( - [str(s) for s in _int][0:value.adjusted() + 1])) + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + ) return result @@ -212,6 +212,7 @@ class _ms_binary_pyodbc(object): else: # pyodbc-specific return dialect.dbapi.BinaryNull + return process @@ -243,9 +244,11 @@ class MSExecutionContext_pyodbc(MSExecutionContext): # don't embed the scope_identity select into an # "INSERT .. 
DEFAULT VALUES" - if self._select_lastrowid and \ - self.dialect.use_scope_identity and \ - len(self.parameters[0]): + if ( + self._select_lastrowid + and self.dialect.use_scope_identity + and len(self.parameters[0]) + ): self._embedded_scope_identity = True self.statement += "; select scope_identity()" @@ -281,26 +284,31 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): sqltypes.Numeric: _MSNumeric_pyodbc, sqltypes.Float: _MSFloat_pyodbc, BINARY: _BINARY_pyodbc, - # SQL Server dialect has a VARBINARY that is just to support # "deprecate_large_types" w/ VARBINARY(max), but also we must # handle the usual SQL standard VARBINARY VARBINARY: _VARBINARY_pyodbc, sqltypes.VARBINARY: _VARBINARY_pyodbc, sqltypes.LargeBinary: _VARBINARY_pyodbc, - } + }, ) - def __init__(self, description_encoding=None, fast_executemany=False, - **params): - if 'description_encoding' in params: - self.description_encoding = params.pop('description_encoding') + def __init__( + self, description_encoding=None, fast_executemany=False, **params + ): + if "description_encoding" in params: + self.description_encoding = params.pop("description_encoding") super(MSDialect_pyodbc, self).__init__(**params) - self.use_scope_identity = self.use_scope_identity and \ - self.dbapi and \ - hasattr(self.dbapi.Cursor, 'nextset') - self._need_decimal_fix = self.dbapi and \ - self._dbapi_version() < (2, 1, 8) + self.use_scope_identity = ( + self.use_scope_identity + and self.dbapi + and hasattr(self.dbapi.Cursor, "nextset") + ) + self._need_decimal_fix = self.dbapi and self._dbapi_version() < ( + 2, + 1, + 8, + ) self.fast_executemany = fast_executemany def _get_server_version_info(self, connection): @@ -308,16 +316,18 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # "Version of the instance of SQL Server, in the form # of 'major.minor.build.revision'" raw = connection.scalar( - "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)") + "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)" + ) except exc.DBAPIError: # SQL Server docs indicate this function isn't present prior to # 2008. Before we had the VARCHAR cast above, pyodbc would also # fail on this query. 
- return super(MSDialect_pyodbc, self).\ - _get_server_version_info(connection, allow_chars=False) + return super(MSDialect_pyodbc, self)._get_server_version_info( + connection, allow_chars=False + ) else: version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(raw): try: version.append(int(n)) @@ -329,17 +339,27 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): if self.fast_executemany: cursor.fast_executemany = True super(MSDialect_pyodbc, self).do_executemany( - cursor, statement, parameters, context=context) + cursor, statement, parameters, context=context + ) def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): for code in ( - '08S01', '01002', '08003', '08007', - '08S02', '08001', 'HYT00', 'HY010', - '10054'): + "08S01", + "01002", + "08003", + "08007", + "08S02", + "08001", + "HYT00", + "HY010", + "10054", + ): if code in str(e): return True return super(MSDialect_pyodbc, self).is_disconnect( - e, connection, cursor) + e, connection, cursor + ) + dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mssql/zxjdbc.py b/lib/sqlalchemy/dialects/mssql/zxjdbc.py index 3fb93b28a6..13fc46e190 100644 --- a/lib/sqlalchemy/dialects/mssql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mssql/zxjdbc.py @@ -44,26 +44,28 @@ class MSExecutionContext_zxjdbc(MSExecutionContext): self.cursor.nextset() self._lastrowid = int(row[0]) - if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + if ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: table = self.dialect.identifier_preparer.format_table( - self.compiled.statement.table) + self.compiled.statement.table + ) self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table) class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect): - jdbc_db_name = 'jtds:sqlserver' - jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver' + jdbc_db_name = "jtds:sqlserver" + jdbc_driver_name = "net.sourceforge.jtds.jdbc.Driver" execution_ctx_cls = MSExecutionContext_zxjdbc def _get_server_version_info(self, connection): return tuple( - int(x) - for x in connection.connection.dbversion.split('.') + int(x) for x in connection.connection.dbversion.split(".") ) + dialect = MSDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index de4e1fa41d..ffeb8f486b 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -5,18 +5,56 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from . import base, mysqldb, oursql, \ - pyodbc, zxjdbc, mysqlconnector, pymysql, \ - gaerdbms, cymysql +from . 
import ( + base, + mysqldb, + oursql, + pyodbc, + zxjdbc, + mysqlconnector, + pymysql, + gaerdbms, + cymysql, +) -from .base import \ - BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \ - DECIMAL, DOUBLE, ENUM, DECIMAL,\ - FLOAT, INTEGER, INTEGER, JSON, LONGBLOB, LONGTEXT, MEDIUMBLOB, \ - MEDIUMINT, MEDIUMTEXT, NCHAR, \ - NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \ - TINYBLOB, TINYINT, TINYTEXT,\ - VARBINARY, VARCHAR, YEAR +from .base import ( + BIGINT, + BINARY, + BIT, + BLOB, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + DOUBLE, + ENUM, + DECIMAL, + FLOAT, + INTEGER, + INTEGER, + JSON, + LONGBLOB, + LONGTEXT, + MEDIUMBLOB, + MEDIUMINT, + MEDIUMTEXT, + NCHAR, + NVARCHAR, + NUMERIC, + SET, + SMALLINT, + REAL, + TEXT, + TIME, + TIMESTAMP, + TINYBLOB, + TINYINT, + TINYTEXT, + VARBINARY, + VARCHAR, + YEAR, +) from .dml import insert, Insert @@ -25,10 +63,41 @@ base.dialect = dialect = mysqldb.dialect __all__ = ( - 'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', - 'DECIMAL', 'DOUBLE', 'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', - 'JSON', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT', - 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', - 'TIMESTAMP', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', - 'YEAR', 'dialect' + "BIGINT", + "BINARY", + "BIT", + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "DOUBLE", + "ENUM", + "DECIMAL", + "FLOAT", + "INTEGER", + "INTEGER", + "JSON", + "LONGBLOB", + "LONGTEXT", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "NCHAR", + "NVARCHAR", + "NUMERIC", + "SET", + "SMALLINT", + "REAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "VARBINARY", + "VARCHAR", + "YEAR", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 673d4b9ff8..7b0d0618c7 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -746,85 +746,340 @@ from ...engine import reflection from ...engine import default from ... import types as sqltypes from ...util import topological -from ...types import DATE, BOOLEAN, \ - BLOB, BINARY, VARBINARY +from ...types import DATE, BOOLEAN, BLOB, BINARY, VARBINARY from . 
import reflection as _reflection -from .types import BIGINT, BIT, CHAR, DECIMAL, DATETIME, \ - DOUBLE, FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, \ - MEDIUMTEXT, NCHAR, NUMERIC, NVARCHAR, REAL, SMALLINT, TEXT, TIME, \ - TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT, VARCHAR, YEAR -from .types import _StringType, _IntegerType, _NumericType, \ - _FloatType, _MatchType +from .types import ( + BIGINT, + BIT, + CHAR, + DECIMAL, + DATETIME, + DOUBLE, + FLOAT, + INTEGER, + LONGBLOB, + LONGTEXT, + MEDIUMBLOB, + MEDIUMINT, + MEDIUMTEXT, + NCHAR, + NUMERIC, + NVARCHAR, + REAL, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + TINYBLOB, + TINYINT, + TINYTEXT, + VARCHAR, + YEAR, +) +from .types import ( + _StringType, + _IntegerType, + _NumericType, + _FloatType, + _MatchType, +) from .enumerated import ENUM, SET from .json import JSON, JSONIndexType, JSONPathType RESERVED_WORDS = set( - ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', - 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', - 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', - 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', - 'current_user', 'cursor', 'database', 'databases', 'day_hour', - 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', - 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', - 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', - 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', - 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', - 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', - 'having', 'high_priority', 'hour_microsecond', 'hour_minute', - 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout', - 'insensitive', 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', - 'integer', 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', - 'kill', 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', - 'load', 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', - 'longtext', 'loop', 'low_priority', 'master_ssl_verify_server_cert', - 'match', 'mediumblob', 'mediumint', 'mediumtext', 'middleint', - 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', - 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', - 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', - 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', - 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', - 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', - 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', - 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', - 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', - 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', - 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', - 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', - 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', - 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', - 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', - - 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 - - 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 - - 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', - 
'read_only', 'read_write', # 5.1 - - 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue', - 'resignal', 'signal', 'slow', # 5.5 - - 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', - 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6 - - 'generated', 'optimizer_costs', 'stored', 'virtual', # 5.7 - - 'admin', 'cume_dist', 'empty', 'except', 'first_value', 'grouping', - 'function', 'groups', 'json_table', 'last_value', 'nth_value', - 'ntile', 'of', 'over', 'percent_rank', 'persist', 'persist_only', - 'rank', 'recursive', 'role', 'row', 'rows', 'row_number', 'system', - 'window', # 8.0 - ]) + [ + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_ssl_verify_server_cert", + "match", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "not", + "no_write_to_binlog", + "null", + "numeric", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "reads", + "read_only", + "read_write", + "real", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "restrict", + "return", + "revoke", + "right", + "rlike", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "smallint", + "spatial", + "specific", + "sql", + "sqlexception", + "sqlstate", + "sqlwarning", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "ssl", + "starting", + "straight_join", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "with", + "write", + "x509", + "xor", + "year_month", + "zerofill", # 5.0 + 
"columns", + "fields", + "privileges", + "soname", + "tables", # 4.1 + "accessible", + "linear", + "master_ssl_verify_server_cert", + "range", + "read_only", + "read_write", # 5.1 + "general", + "ignore_server_ids", + "master_heartbeat_period", + "maxvalue", + "resignal", + "signal", + "slow", # 5.5 + "get", + "io_after_gtids", + "io_before_gtids", + "master_bind", + "one_shot", + "partition", + "sql_after_gtids", + "sql_before_gtids", # 5.6 + "generated", + "optimizer_costs", + "stored", + "virtual", # 5.7 + "admin", + "cume_dist", + "empty", + "except", + "first_value", + "grouping", + "function", + "groups", + "json_table", + "last_value", + "nth_value", + "ntile", + "of", + "over", + "percent_rank", + "persist", + "persist_only", + "rank", + "recursive", + "role", + "row", + "rows", + "row_number", + "system", + "window", # 8.0 + ] +) AUTOCOMMIT_RE = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)", + re.I | re.UNICODE, +) SET_RE = re.compile( - r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', - re.I | re.UNICODE) + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) # old names @@ -870,52 +1125,50 @@ colspecs = { sqltypes.MatchType: _MatchType, sqltypes.JSON: JSON, sqltypes.JSON.JSONIndexType: JSONIndexType, - sqltypes.JSON.JSONPathType: JSONPathType - + sqltypes.JSON.JSONPathType: JSONPathType, } # Everything 3.23 through 5.1 excepting OpenGIS types. ischema_names = { - 'bigint': BIGINT, - 'binary': BINARY, - 'bit': BIT, - 'blob': BLOB, - 'boolean': BOOLEAN, - 'char': CHAR, - 'date': DATE, - 'datetime': DATETIME, - 'decimal': DECIMAL, - 'double': DOUBLE, - 'enum': ENUM, - 'fixed': DECIMAL, - 'float': FLOAT, - 'int': INTEGER, - 'integer': INTEGER, - 'json': JSON, - 'longblob': LONGBLOB, - 'longtext': LONGTEXT, - 'mediumblob': MEDIUMBLOB, - 'mediumint': MEDIUMINT, - 'mediumtext': MEDIUMTEXT, - 'nchar': NCHAR, - 'nvarchar': NVARCHAR, - 'numeric': NUMERIC, - 'set': SET, - 'smallint': SMALLINT, - 'text': TEXT, - 'time': TIME, - 'timestamp': TIMESTAMP, - 'tinyblob': TINYBLOB, - 'tinyint': TINYINT, - 'tinytext': TINYTEXT, - 'varbinary': VARBINARY, - 'varchar': VARCHAR, - 'year': YEAR, + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, } class MySQLExecutionContext(default.DefaultExecutionContext): - def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) @@ -932,7 +1185,7 @@ class MySQLCompiler(compiler.SQLCompiler): """Overridden from base SQLCompiler value""" extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update({'milliseconds': 'millisecond'}) + extract_map.update({"milliseconds": "millisecond"}) def visit_random_func(self, fn, **kw): return "rand%s" % self.function_argspec(fn) @@ -943,12 +1196,14 @@ class MySQLCompiler(compiler.SQLCompiler): def 
visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_on_duplicate_key_update(self, on_duplicate, **kw): if on_duplicate._parameter_ordering: @@ -958,7 +1213,8 @@ class MySQLCompiler(compiler.SQLCompiler): ] ordered_keys = set(parameter_ordering) cols = [ - self.statement.table.c[key] for key in parameter_ordering + self.statement.table.c[key] + for key in parameter_ordering if key in self.statement.table.c ] + [ c for c in self.statement.table.c if c.key not in ordered_keys @@ -979,9 +1235,11 @@ class MySQLCompiler(compiler.SQLCompiler): val = val._clone() val.type = column.type value_text = self.process(val.self_group(), use_schema=False) - elif isinstance(val, elements.ColumnClause) \ - and val.table is on_duplicate.inserted_alias: - value_text = 'VALUES(' + self.preparer.quote(column.name) + ')' + elif ( + isinstance(val, elements.ColumnClause) + and val.table is on_duplicate.inserted_alias + ): + value_text = "VALUES(" + self.preparer.quote(column.name) + ")" else: value_text = self.process(val.self_group(), use_schema=False) name_text = self.preparer.quote(column.name) @@ -990,22 +1248,27 @@ class MySQLCompiler(compiler.SQLCompiler): non_matching = set(on_duplicate.update) - set(c.key for c in cols) if non_matching: util.warn( - 'Additional column names not matching ' - "any column keys in table '%s': %s" % ( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( self.statement.table.name, - (', '.join("'%s'" % c for c in non_matching)) + (", ".join("'%s'" % c for c in non_matching)), ) ) - return 'ON DUPLICATE KEY UPDATE ' + ', '.join(clauses) + return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) def visit_concat_op_binary(self, binary, operator, **kw): - return "concat(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_match_op_binary(self, binary, operator, **kw): - return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \ - (self.process(binary.left, **kw), self.process(binary.right, **kw)) + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def get_from_hint_text(self, table, text): return text @@ -1016,26 +1279,35 @@ class MySQLCompiler(compiler.SQLCompiler): if isinstance(type_, sqltypes.TypeDecorator): return self.visit_typeclause(typeclause, type_.impl, **kw) elif isinstance(type_, sqltypes.Integer): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" else: - return 'SIGNED INTEGER' + return "SIGNED INTEGER" elif isinstance(type_, sqltypes.TIMESTAMP): - return 'DATETIME' - elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, - sqltypes.Date, sqltypes.Time)): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + sqltypes.Date, + sqltypes.Time, + ), + ): return self.dialect.type_compiler.process(type_) - elif isinstance(type_, sqltypes.String) \ - and not isinstance(type_, (ENUM, SET)): + elif isinstance(type_, sqltypes.String) and not 
isinstance( + type_, (ENUM, SET) + ): adapted = CHAR._adapt_string_for_cast(type_) return self.dialect.type_compiler.process(adapted) elif isinstance(type_, sqltypes._Binary): - return 'BINARY' + return "BINARY" elif isinstance(type_, sqltypes.JSON): return "JSON" elif isinstance(type_, sqltypes.NUMERIC): - return self.dialect.type_compiler.process( - type_).replace('NUMERIC', 'DECIMAL') + return self.dialect.type_compiler.process(type_).replace( + "NUMERIC", "DECIMAL" + ) else: return None @@ -1044,23 +1316,25 @@ class MySQLCompiler(compiler.SQLCompiler): if not self.dialect._supports_cast: util.warn( "Current MySQL version does not support " - "CAST; the CAST will be skipped.") + "CAST; the CAST will be skipped." + ) return self.process(cast.clause.self_group(), **kw) type_ = self.process(cast.typeclause) if type_ is None: util.warn( "Datatype %s does not support CAST on MySQL; " - "the CAST will be skipped." % - self.dialect.type_compiler.process(cast.typeclause.type)) + "the CAST will be skipped." + % self.dialect.type_compiler.process(cast.typeclause.type) + ) return self.process(cast.clause.self_group(), **kw) - return 'CAST(%s AS %s)' % (self.process(cast.clause, **kw), type_) + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) def render_literal_value(self, value, type_): value = super(MySQLCompiler, self).render_literal_value(value, type_) if self.dialect._backslash_escapes: - value = value.replace('\\', '\\\\') + value = value.replace("\\", "\\\\") return value # override native_boolean=False behavior here, as @@ -1096,12 +1370,15 @@ class MySQLCompiler(compiler.SQLCompiler): else: join_type = " INNER JOIN " - return ''.join( - (self.process(join.left, asfrom=True, **kwargs), - join_type, - self.process(join.right, asfrom=True, **kwargs), - " ON ", - self.process(join.onclause, **kwargs))) + return "".join( + ( + self.process(join.left, asfrom=True, **kwargs), + join_type, + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs), + ) + ) def for_update_clause(self, select, **kw): if select._for_update_arg.read: @@ -1118,11 +1395,13 @@ class MySQLCompiler(compiler.SQLCompiler): # The latter is more readable for offsets but we're stuck with the # former until we can refine dialects by server revision. 
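# An illustrative sketch of the rendered SQL (table and column names
# here are hypothetical): select([t]).limit(10).offset(5) emits
# "LIMIT 5, 10", while select([t]).offset(5) with no limit emits
# "LIMIT 5, 18446744073709551615", using 2**64 - 1 as the artificial
# upper bound applied in the method below.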
- limit_clause, offset_clause = select._limit_clause, \ - select._offset_clause + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) if limit_clause is None and offset_clause is None: - return '' + return "" elif offset_clause is not None: # As suggested by the MySQL docs, need to apply an # artificial limit if one wasn't provided @@ -1134,35 +1413,38 @@ class MySQLCompiler(compiler.SQLCompiler): # but also is consistent with the usage of the upper # bound as part of MySQL's "syntax" for OFFSET with # no LIMIT - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - "18446744073709551615") + "18446744073709551615", + ) else: - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - self.process(limit_clause, **kw)) + self.process(limit_clause, **kw), + ) else: # No offset provided, so just use the limit - return ' \n LIMIT %s' % (self.process(limit_clause, **kw),) + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) def update_limit_clause(self, update_stmt): - limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None) + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) if limit: return "LIMIT %s" % limit else: return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): - return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) - for t in [from_table] + list(extra_froms)) + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return ", ".join( + t._compiler_dispatch(self, asfrom=True, **kw) + for t in [from_table] + list(extra_froms) + ) - def update_from_clause(self, update_stmt, from_table, - extra_froms, from_hints, **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): return None - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1171,24 +1453,27 @@ class MySQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. 
USING clause specific to MySQL.""" - return "USING " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, element_types): return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " - "as _empty_set WHERE 1!=1" % { + "as _empty_set WHERE 1!=1" + % { "inner": ", ".join( "1 AS _in_%s" % idx - for idx, type_ in enumerate(element_types)), + for idx, type_ in enumerate(element_types) + ), "outer": ", ".join( - "_in_%s" % idx - for idx, type_ in enumerate(element_types)) + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), } ) @@ -1200,35 +1485,39 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec = [ self.preparer.format_column(column), self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ), ] is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP) if not column.nullable: - colspec.append('NOT NULL') + colspec.append("NOT NULL") # see: http://docs.sqlalchemy.org/en/latest/dialects/ # mysql.html#mysql_timestamp_null elif column.nullable and is_timestamp: - colspec.append('NULL') + colspec.append("NULL") default = self.get_column_default_string(column) if default is not None: - colspec.append('DEFAULT ' + default) + colspec.append("DEFAULT " + default) comment = column.comment if comment is not None: literal = self.sql_compiler.render_literal_value( - comment, sqltypes.String()) - colspec.append('COMMENT ' + literal) + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) - if column.table is not None \ - and column is column.table._autoincrement_column and \ - column.server_default is None: - colspec.append('AUTO_INCREMENT') + if ( + column.table is not None + and column is column.table._autoincrement_column + and column.server_default is None + ): + colspec.append("AUTO_INCREMENT") - return ' '.join(colspec) + return " ".join(colspec) def post_create_table(self, table): """Build table-level CREATE options like ENGINE and COLLATE.""" @@ -1236,76 +1525,94 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts = [] opts = dict( - ( - k[len(self.dialect.name) + 1:].upper(), - v - ) + (k[len(self.dialect.name) + 1 :].upper(), v) for k, v in table.kwargs.items() - if k.startswith('%s_' % self.dialect.name) + if k.startswith("%s_" % self.dialect.name) ) if table.comment is not None: - opts['COMMENT'] = table.comment + opts["COMMENT"] = table.comment partition_options = [ - 'PARTITION_BY', 'PARTITIONS', 'SUBPARTITIONS', - 'SUBPARTITION_BY' + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", ] nonpart_options = set(opts).difference(partition_options) part_options = set(opts).intersection(partition_options) - for opt in topological.sort([ - ('DEFAULT_CHARSET', 'COLLATE'), - ('DEFAULT_CHARACTER_SET', 'COLLATE'), - ], nonpart_options): + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) - - if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', - 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', - 'DEFAULT_CHARSET', - 'DEFAULT_COLLATE'): - opt = opt.replace('_', ' ') + arg, sqltypes.String() + ) - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT 
CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " table_opts.append(joiner.join((opt, arg))) - for opt in topological.sort([ - ('PARTITION_BY', 'PARTITIONS'), - ('PARTITION_BY', 'SUBPARTITION_BY'), - ('PARTITION_BY', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITION_BY'), - ('SUBPARTITION_BY', 'SUBPARTITIONS') - ], part_options): + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) + arg, sqltypes.String() + ) - opt = opt.replace('_', ' ') - joiner = ' ' + opt = opt.replace("_", " ") + joiner = " " table_opts.append(joiner.join((opt, arg))) - return ' '.join(table_opts) + return " ".join(table_opts) def visit_create_index(self, create, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer table = preparer.format_table(index.table) - columns = [self.sql_compiler.process(expr, include_table=False, - literal_binds=True) - for expr in index.expressions] + columns = [ + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ] name = self._prepared_index_name(index) @@ -1313,53 +1620,54 @@ class MySQLDDLCompiler(compiler.DDLCompiler): if index.unique: text += "UNIQUE " - index_prefix = index.kwargs.get('mysql_prefix', None) + index_prefix = index.kwargs.get("mysql_prefix", None) if index_prefix: - text += index_prefix + ' ' + text += index_prefix + " " text += "INDEX %s ON %s " % (name, table) - length = index.dialect_options['mysql']['length'] + length = index.dialect_options["mysql"]["length"] if length is not None: if isinstance(length, dict): # length value can be a (column_name --> integer value) # mapping specifying the prefix length for each column of the # index - columns = ', '.join( - '%s(%d)' % (expr, length[col.name]) if col.name in length - else - ( - '%s(%d)' % (expr, length[expr]) if expr in length - else '%s' % expr + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr ) for col, expr in zip(index.expressions, columns) ) else: # or can be an integer value specifying the same # prefix length for all columns of the index - columns = ', '.join( - '%s(%d)' % (col, length) - for col in columns + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns ) else: - columns = ', '.join(columns) - text += '(%s)' % columns + columns = ", ".join(columns) + text += "(%s)" % columns - parser = index.dialect_options['mysql']['with_parser'] + parser = index.dialect_options["mysql"]["with_parser"] if parser is not None: - text += " WITH PARSER %s" % (parser, ) + text += " WITH PARSER %s" % (parser,) - using = index.dialect_options['mysql']['using'] + using = index.dialect_options["mysql"]["using"] if using is not None: text += " USING %s" % (preparer.quote(using)) return text def 
visit_primary_key_constraint(self, constraint): - text = super(MySQLDDLCompiler, self).\ - visit_primary_key_constraint(constraint) - using = constraint.dialect_options['mysql']['using'] + text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( + constraint + ) + using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text @@ -1368,9 +1676,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s ON %s" % ( - self._prepared_index_name(index, - include_schema=False), - self.preparer.format_table(index.table)) + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) def visit_drop_constraint(self, drop): constraint = drop.element @@ -1386,29 +1694,33 @@ class MySQLDDLCompiler(compiler.DDLCompiler): else: qual = "" const = self.preparer.format_constraint(constraint) - return "ALTER TABLE %s DROP %s%s" % \ - (self.preparer.format_table(constraint.table), - qual, const) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) def define_constraint_match(self, constraint): if constraint.match is not None: raise exc.CompileError( "MySQL ignores the 'MATCH' keyword while at the same time " - "causes ON UPDATE/ON DELETE clauses to be ignored.") + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) return "" def visit_set_table_comment(self, create): return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_set_column_comment(self, create): return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), - self.get_column_specification(create.element) + self.get_column_specification(create.element), ) @@ -1420,9 +1732,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return spec if type_.unsigned: - spec += ' UNSIGNED' + spec += " UNSIGNED" if type_.zerofill: - spec += ' ZEROFILL' + spec += " ZEROFILL" return spec def _extend_string(self, type_, defaults, spec): @@ -1434,28 +1746,30 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def attr(name): return getattr(type_, name, defaults.get(name)) - if attr('charset'): - charset = 'CHARACTER SET %s' % attr('charset') - elif attr('ascii'): - charset = 'ASCII' - elif attr('unicode'): - charset = 'UNICODE' + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" else: charset = None - if attr('collation'): - collation = 'COLLATE %s' % type_.collation - elif attr('binary'): - collation = 'BINARY' + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" else: collation = None - if attr('national'): + if attr("national"): # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. 
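# Illustrative renderings (lengths are hypothetical): NVARCHAR(30)
# comes out as "NATIONAL VARCHAR(30)"; the charset portion is
# intentionally dropped in the NATIONAL branch below, since NATIONAL
# already implies the character set.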
- return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) @@ -1464,95 +1778,113 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "DOUBLE(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'DOUBLE') + return self._extend_numeric(type_, "DOUBLE") def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "REAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'REAL') + return self._extend_numeric(type_, "REAL") def visit_FLOAT(self, type_, **kw): - if self._mysql_type(type_) and \ - type_.scale is not None and \ - type_.precision is not None: + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): return self._extend_numeric( - type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) elif type_.precision is not None: - return self._extend_numeric(type_, - "FLOAT(%s)" % (type_.precision,)) + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) else: return self._extend_numeric(type_, "FLOAT") def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "INTEGER(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) 
else: return self._extend_numeric(type_, "INTEGER") def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "BIGINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "BIGINT") def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "MEDIUMINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "MEDIUMINT") def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "TINYINT(%s)" % type_.display_width) + return self._extend_numeric( + type_, "TINYINT(%s)" % type_.display_width + ) else: return self._extend_numeric(type_, "TINYINT") def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "SMALLINT(%(display_width)s)" % - {'display_width': type_.display_width} - ) + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "SMALLINT") @@ -1563,7 +1895,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "BIT" def visit_DATETIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" @@ -1572,13 +1904,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "DATE" def visit_TIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIME(%d)" % type_.fsp else: return "TIME" def visit_TIMESTAMP(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" @@ -1606,17 +1938,17 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if type_.length: - return self._extend_string( - type_, {}, "VARCHAR(%d)" % type_.length) + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_CHAR(self, type_, **kw): if type_.length: - return self._extend_string(type_, {}, "CHAR(%(length)s)" % - {'length': type_.length}) + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) else: return self._extend_string(type_, {}, "CHAR") @@ -1625,22 +1957,26 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): # of "NVARCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "VARCHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) else: raise exc.CompileError( - "NVARCHAR requires a length on dialect %s" % - self.dialect.name) + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". 
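# e.g. (length is illustrative): NCHAR(20) renders as
# "NATIONAL CHAR(20)", and a length-less NCHAR as "NATIONAL CHAR".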
if type_.length: return self._extend_string( - type_, {'national': True}, - "CHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) else: - return self._extend_string(type_, {'national': True}, "CHAR") + return self._extend_string(type_, {"national": True}, "CHAR") def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length @@ -1676,17 +2012,19 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): quoted_enums = [] for e in enumerated_values: quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend_string(type_, {}, "%s(%s)" % ( - name, ",".join(quoted_enums)) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) ) def visit_ENUM(self, type_, **kw): - return self._visit_enumerated_values("ENUM", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "ENUM", type_, type_._enumerated_values + ) def visit_SET(self, type_, **kw): - return self._visit_enumerated_values("SET", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "SET", type_, type_._enumerated_values + ) def visit_BOOLEAN(self, type, **kw): return "BOOL" @@ -1703,9 +2041,8 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): quote = '"' super(MySQLIdentifierPreparer, self).__init__( - dialect, - initial_quote=quote, - escape_quote=quote) + dialect, initial_quote=quote, escape_quote=quote + ) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" @@ -1719,7 +2056,7 @@ class MySQLDialect(default.DefaultDialect): Not used directly in application code. """ - name = 'mysql' + name = "mysql" supports_alter = True # MySQL has no true "boolean" type; we @@ -1738,7 +2075,7 @@ class MySQLDialect(default.DefaultDialect): supports_comments = True inline_comments = True - default_paramstyle = 'format' + default_paramstyle = "format" colspecs = colspecs cte_follows_insert = True @@ -1756,26 +2093,28 @@ class MySQLDialect(default.DefaultDialect): _server_ansiquotes = False construct_arguments = [ - (sa_schema.Table, { - "*": None - }), - (sql.Update, { - "limit": None - }), - (sa_schema.PrimaryKeyConstraint, { - "using": None - }), - (sa_schema.Index, { - "using": None, - "length": None, - "prefix": None, - "with_parser": None - }) + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), ] - def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): - kwargs.pop('use_ansiquotes', None) # legacy + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): + kwargs.pop("use_ansiquotes", None) # legacy default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = json_serializer @@ -1783,22 +2122,30 @@ class MySQLDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) def set_isolation_level(self, connection, level): - level = 
level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection self._set_isolation_level(connection, level) @@ -1807,8 +2154,8 @@ class MySQLDialect(default.DefaultDialect): if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) @@ -1818,9 +2165,9 @@ class MySQLDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() if self._is_mysql and self.server_version_info >= (5, 7, 20): - cursor.execute('SELECT @@transaction_isolation') + cursor.execute("SELECT @@transaction_isolation") else: - cursor.execute('SELECT @@tx_isolation') + cursor.execute("SELECT @@tx_isolation") val = cursor.fetchone()[0] cursor.close() if util.py3k and isinstance(val, bytes): @@ -1840,7 +2187,7 @@ class MySQLDialect(default.DefaultDialect): val = val.decode() version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(val): try: version.append(int(n)) @@ -1885,29 +2232,38 @@ class MySQLDialect(default.DefaultDialect): connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA PREPARE :xid"), xid=xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute(sql.text("XA COMMIT :xid"), xid=xid) def do_recover_twophase(self, connection): resultset = connection.execute("XA RECOVER") - return [row['data'][0:row['gtrid_length']] for row in resultset] + return [row["data"][0 : row["gtrid_length"]] for row in resultset] def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): - return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) elif isinstance( - e, (self.dbapi.InterfaceError, self.dbapi.InternalError)): + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): # if underlying connection is closed, # this is the error you get return "(0, '')" in str(e) @@ -1944,7 +2300,7 @@ class MySQLDialect(default.DefaultDialect): raise NotImplementedError() def _get_default_schema_name(self, connection): - return connection.execute('SELECT DATABASE()').scalar() + return connection.execute("SELECT DATABASE()").scalar() def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1957,15 +2313,19 @@ class MySQLDialect(default.DefaultDialect): 
# full_name = self.identifier_preparer.format_table(table, # use_schema=True) - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) st = "DESCRIBE %s" % full_name rs = None try: try: rs = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) have = rs.fetchone() is not None rs.close() return have @@ -1986,12 +2346,13 @@ class MySQLDialect(default.DefaultDialect): # if ansiquotes == True, build a new IdentifierPreparer # with the new setting self.identifier_preparer = self.preparer( - self, server_ansiquotes=self._server_ansiquotes) + self, server_ansiquotes=self._server_ansiquotes + ) default.DefaultDialect.initialize(self, connection) self._needs_correct_for_88718 = ( - not self._is_mariadb and self.server_version_info >= (8, ) + not self._is_mariadb and self.server_version_info >= (8,) ) self._warn_for_known_db_issues() @@ -2007,20 +2368,23 @@ class MySQLDialect(default.DefaultDialect): "additional issue prevents proper migrations of columns " "with CHECK constraints (MDEV-11114). Please upgrade to " "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " - "series, to avoid these issues." % (mdb_version, )) + "series, to avoid these issues." % (mdb_version,) + ) @property def _is_mariadb(self): - return 'MariaDB' in self.server_version_info + return "MariaDB" in self.server_version_info @property def _is_mysql(self): - return 'MariaDB' not in self.server_version_info + return "MariaDB" not in self.server_version_info @property def _is_mariadb_102(self): - return self._is_mariadb and \ - self._mariadb_normalized_version_info > (10, 2) + return self._is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) @property def _mariadb_normalized_version_info(self): @@ -2028,15 +2392,17 @@ class MySQLDialect(default.DefaultDialect): # the string "5.5"; now that we use @@version we no longer see this. 
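# Sketch of the slicing below, with a hypothetical version string: a
# MariaDB server reporting "5.5.5-10.2.9-MariaDB" parses to
# (5, 5, 5, 10, 2, 9, 'MariaDB'); the three entries preceding
# 'MariaDB', i.e. (10, 2, 9), become the normalized version tuple.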
if self._is_mariadb: - idx = self.server_version_info.index('MariaDB') - return self.server_version_info[idx - 3: idx] + idx = self.server_version_info.index("MariaDB") + return self.server_version_info[idx - 3 : idx] else: return self.server_version_info @property def _supports_cast(self): - return self.server_version_info is None or \ - self.server_version_info >= (4, 0, 2) + return ( + self.server_version_info is None + or self.server_version_info >= (4, 0, 2) + ) @reflection.cache def get_schema_names(self, connection, **kw): @@ -2054,18 +2420,23 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset if self.server_version_info < (5, 0, 2): rp = connection.execute( - "SHOW TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) - return [row[0] for - row in self._compat_fetchall(rp, charset=charset)] + "SHOW TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + return [ + row[0] for row in self._compat_fetchall(rp, charset=charset) + ] else: rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] == 'BASE TABLE'] + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -2077,72 +2448,77 @@ class MySQLDialect(default.DefaultDialect): return self.get_table_names(connection, schema) charset = self._connection_charset rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(schema)) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] in ('VIEW', 'SYSTEM VIEW')] + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.table_options @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.columns @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) for key in parsed_state.keys: - if key['type'] == 'PRIMARY': + if key["type"] == "PRIMARY": # There can be only one. 
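# Illustrative reflected value for a table defined with
# PRIMARY KEY (id): {'constrained_columns': ['id'], 'name': None};
# the parser returns None for the name, as MySQL's primary key
# index is always simply PRIMARY.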
- cols = [s[0] for s in key['columns']] - return {'constrained_columns': cols, 'name': None} - return {'constrained_columns': [], 'name': None} + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return {"constrained_columns": [], "name": None} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) default_schema = None fkeys = [] for spec in parsed_state.fk_constraints: - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and \ - spec['table'][-2] or schema + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema if not ref_schema: if default_schema is None: - default_schema = \ - connection.dialect.default_schema_name + default_schema = connection.dialect.default_schema_name if schema == default_schema: ref_schema = schema - loc_names = spec['local'] - ref_names = spec['foreign'] + loc_names = spec["local"] + ref_names = spec["foreign"] con_kw = {} - for opt in ('onupdate', 'ondelete'): + for opt in ("onupdate", "ondelete"): if spec.get(opt, False): con_kw[opt] = spec[opt] fkey_d = { - 'name': spec['name'], - 'constrained_columns': loc_names, - 'referred_schema': ref_schema, - 'referred_table': ref_name, - 'referred_columns': ref_names, - 'options': con_kw + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, } fkeys.append(fkey_d) @@ -2172,25 +2548,26 @@ class MySQLDialect(default.DefaultDialect): default_schema_name = connection.dialect.default_schema_name col_tuples = [ ( - lower(rec['referred_schema'] or default_schema_name), - lower(rec['referred_table']), - col_name + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, ) for rec in fkeys - for col_name in rec['referred_columns'] + for col_name in rec["referred_columns"] ] if col_tuples: correct_for_wrong_fk_case = connection.execute( - sql.text(""" + sql.text( + """ select table_schema, table_name, column_name from information_schema.columns where (table_schema, table_name, lower(column_name)) in :table_data; - """).bindparams( - sql.bindparam("table_data", expanding=True) - ), table_data=col_tuples + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + table_data=col_tuples, ) # in casing=0, table name and schema name come back in their @@ -2208,109 +2585,117 @@ class MySQLDialect(default.DefaultDialect): d[(lower(schema), lower(tname))][cname.lower()] = cname for fkey in fkeys: - fkey['referred_columns'] = [ + fkey["referred_columns"] = [ d[ ( lower( - fkey['referred_schema'] or - default_schema_name), - lower(fkey['referred_table']) + fkey["referred_schema"] or default_schema_name + ), + lower(fkey["referred_table"]), ) ][col.lower()] - for col in fkey['referred_columns'] + for col in fkey["referred_columns"] ] @reflection.cache - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ - {"name": spec['name'], "sqltext": spec['sqltext']} + {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] @reflection.cache def 
get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) - return {"text": parsed_state.table_options.get('mysql_comment', None)} + connection, table_name, schema, **kw + ) + return {"text": parsed_state.table_options.get("mysql_comment", None)} @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) indexes = [] for spec in parsed_state.keys: dialect_options = {} unique = False - flavor = spec['type'] - if flavor == 'PRIMARY': + flavor = spec["type"] + if flavor == "PRIMARY": continue - if flavor == 'UNIQUE': + if flavor == "UNIQUE": unique = True - elif flavor in ('FULLTEXT', 'SPATIAL'): + elif flavor in ("FULLTEXT", "SPATIAL"): dialect_options["mysql_prefix"] = flavor elif flavor is None: pass else: self.logger.info( - "Converting unknown KEY type %s to a plain KEY", flavor) + "Converting unknown KEY type %s to a plain KEY", flavor + ) pass - if spec['parser']: - dialect_options['mysql_with_parser'] = spec['parser'] + if spec["parser"]: + dialect_options["mysql_with_parser"] = spec["parser"] index_d = {} if dialect_options: index_d["dialect_options"] = dialect_options - index_d['name'] = spec['name'] - index_d['column_names'] = [s[0] for s in spec['columns']] - index_d['unique'] = unique + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + index_d["unique"] = unique if flavor: - index_d['type'] = flavor + index_d["type"] = flavor indexes.append(index_d) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ { - 'name': key['name'], - 'column_names': [col[0] for col in key['columns']], - 'duplicates_index': key['name'], + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], } for key in parsed_state.keys - if key['type'] == 'UNIQUE' + if key["type"] == "UNIQUE" ] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): charset = self._connection_charset - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, view_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) return sql - def _parsed_state_or_create(self, connection, table_name, - schema=None, **kw): + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): return self._setup_parser( connection, table_name, schema, - info_cache=kw.get('info_cache', None) + info_cache=kw.get("info_cache", None), ) @util.memoized_property @@ -2321,7 +2706,7 @@ class MySQLDialect(default.DefaultDialect): retrieved server version information first. 
""" - if (self.server_version_info < (4, 1) and self._server_ansiquotes): + if self.server_version_info < (4, 1) and self._server_ansiquotes: # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = self.preparer(self, server_ansiquotes=False) else: @@ -2332,14 +2717,19 @@ class MySQLDialect(default.DefaultDialect): def _setup_parser(self, connection, table_name, schema=None, **kw): charset = self._connection_charset parser = self._tabledef_parser - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) - if re.match(r'^CREATE (?:ALGORITHM)?.* VIEW', sql): + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql): # Adapt views to something table-like. - columns = self._describe_table(connection, None, charset, - full_name=full_name) + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) sql = parser._describe_to_create(table_name, columns) return parser.parse(sql, charset) @@ -2356,17 +2746,18 @@ class MySQLDialect(default.DefaultDialect): # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html charset = self._connection_charset - row = self._compat_first(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) + row = self._compat_first( + connection.execute("SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset, + ) if not row: cs = 0 else: # 4.0.15 returns OFF or ON according to [ticket:489] # 3.23 doesn't, 4.0.27 doesn't.. - if row[1] == 'OFF': + if row[1] == "OFF": cs = 0 - elif row[1] == 'ON': + elif row[1] == "ON": cs = 1 else: cs = int(row[1]) @@ -2384,7 +2775,7 @@ class MySQLDialect(default.DefaultDialect): pass else: charset = self._connection_charset - rs = connection.execute('SHOW COLLATION') + rs = connection.execute("SHOW COLLATION") for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations @@ -2392,33 +2783,36 @@ class MySQLDialect(default.DefaultDialect): def _detect_sql_mode(self, connection): row = self._compat_first( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=self._connection_charset) + charset=self._connection_charset, + ) if not row: util.warn( "Could not retrieve SQL_MODE; please ensure the " - "MySQL user has permissions to SHOW VARIABLES") - self._sql_mode = '' + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" else: - self._sql_mode = row[1] or '' + self._sql_mode = row[1] or "" def _detect_ansiquotes(self, connection): """Detect and adjust for the ANSI_QUOTES sql mode.""" mode = self._sql_mode if not mode: - mode = '' + mode = "" elif mode.isdigit(): mode_no = int(mode) - mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" - self._server_ansiquotes = 'ANSI_QUOTES' in mode + self._server_ansiquotes = "ANSI_QUOTES" in mode # as of MySQL 5.0.1 - self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode - def _show_create_table(self, connection, table, charset=None, - full_name=None): + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): """Run SHOW CREATE TABLE for a ``Table``.""" if full_name is None: @@ -2428,7 +2822,8 @@ class 
MySQLDialect(default.DefaultDialect): rp = None try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: if self._extract_error_code(e.orig) == 1146: raise exc.NoSuchTableError(full_name) @@ -2441,8 +2836,7 @@ class MySQLDialect(default.DefaultDialect): return sql - def _describe_table(self, connection, table, charset=None, - full_name=None): + def _describe_table(self, connection, table, charset=None, full_name=None): """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: @@ -2453,7 +2847,8 @@ class MySQLDialect(default.DefaultDialect): try: try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: code = self._extract_error_code(e.orig) if code == 1146: @@ -2486,11 +2881,11 @@ class _DecodingRowProxy(object): # seem to come up in DDL queries. _encoding_compat = { - 'koi8r': 'koi8_r', - 'koi8u': 'koi8_u', - 'utf16': 'utf-16-be', # MySQL's uft16 is always bigendian - 'utf8mb4': 'utf8', # real utf8 - 'eucjpms': 'ujis', + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "eucjpms": "ujis", } def __init__(self, rowproxy, charset): diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index d142905948..8a60608db7 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -18,7 +18,7 @@ import re from .mysqldb import MySQLDialect_mysqldb -from .base import (BIT, MySQLDialect) +from .base import BIT, MySQLDialect from ... import util @@ -34,27 +34,23 @@ class _cymysqlBIT(BIT): v = v << 8 | i return v return value + return process class MySQLDialect_cymysql(MySQLDialect_mysqldb): - driver = 'cymysql' + driver = "cymysql" description_encoding = None supports_sane_rowcount = True supports_sane_multi_rowcount = False supports_unicode_statements = True - colspecs = util.update_copy( - MySQLDialect.colspecs, - { - BIT: _cymysqlBIT, - } - ) + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) @classmethod def dbapi(cls): - return __import__('cymysql') + return __import__("cymysql") def _detect_charset(self, connection): return connection.connection.charset @@ -64,8 +60,13 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.OperationalError): - return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, # this is the error you get @@ -73,4 +74,5 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): else: return False + dialect = MySQLDialect_cymysql diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 130ef23477..5d59b2073e 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -6,7 +6,7 @@ from ...sql.base import _generative from ... import exc from ... 
import util -__all__ = ('Insert', 'insert') +__all__ = ("Insert", "insert") class Insert(StandardInsert): @@ -39,7 +39,7 @@ class Insert(StandardInsert): @util.memoized_property def inserted_alias(self): - return alias(self.table, name='inserted') + return alias(self.table, name="inserted") @_generative def on_duplicate_key_update(self, *args, **kw): @@ -87,27 +87,29 @@ class Insert(StandardInsert): """ if args and kw: raise exc.ArgumentError( - "Can't pass kwargs and positional arguments simultaneously") + "Can't pass kwargs and positional arguments simultaneously" + ) if args: if len(args) > 1: raise exc.ArgumentError( "Only a single dictionary or list of tuples " - "is accepted positionally.") + "is accepted positionally." + ) values = args[0] else: values = kw - inserted_alias = getattr(self, 'inserted_alias', None) + inserted_alias = getattr(self, "inserted_alias", None) self._post_values_clause = OnDuplicateClause(inserted_alias, values) return self -insert = public_factory(Insert, '.dialects.mysql.insert') +insert = public_factory(Insert, ".dialects.mysql.insert") class OnDuplicateClause(ClauseElement): - __visit_name__ = 'on_duplicate_key_update' + __visit_name__ = "on_duplicate_key_update" _parameter_ordering = None @@ -118,11 +120,12 @@ class OnDuplicateClause(ClauseElement): # Update._proces_colparams(), however we don't look for a special flag # in this case since we are not disambiguating from other use cases as # we are in Update.values(). - if isinstance(update, list) and \ - (update and isinstance(update[0], tuple)): + if isinstance(update, list) and ( + update and isinstance(update[0], tuple) + ): self._parameter_ordering = [key for key, value in update] update = dict(update) if not update or not isinstance(update, dict): - raise ValueError('update parameter must be a non-empty dictionary') + raise ValueError("update parameter must be a non-empty dictionary") self.update = update diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index f63d64e8f9..9586eff3ff 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -14,29 +14,30 @@ from ...sql import sqltypes class _EnumeratedValues(_StringType): def _init_values(self, values, kw): - self.quoting = kw.pop('quoting', 'auto') + self.quoting = kw.pop("quoting", "auto") - if self.quoting == 'auto' and len(values): + if self.quoting == "auto" and len(values): # What quoting character are we using? q = None for e in values: if len(e) == 0: - self.quoting = 'unquoted' + self.quoting = "unquoted" break elif q is None: q = e[0] if len(e) == 1 or e[0] != q or e[-1] != q: - self.quoting = 'unquoted' + self.quoting = "unquoted" break else: - self.quoting = 'quoted' + self.quoting = "quoted" - if self.quoting == 'quoted': + if self.quoting == "quoted": util.warn_deprecated( - 'Manually quoting %s value literals is deprecated. Supply ' - 'unquoted values and use the quoting= option in cases of ' - 'ambiguity.' % self.__class__.__name__) + "Manually quoting %s value literals is deprecated. Supply " + "unquoted values and use the quoting= option in cases of " + "ambiguity." 
% self.__class__.__name__ + ) values = self._strip_values(values) @@ -58,7 +59,7 @@ class _EnumeratedValues(_StringType): class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """MySQL ENUM type.""" - __visit_name__ = 'ENUM' + __visit_name__ = "ENUM" native_enum = True @@ -115,7 +116,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """ - kw.pop('strict', None) + kw.pop("strict", None) self._enum_init(enums, kw) _StringType.__init__(self, length=self.length, **kw) @@ -145,13 +146,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): def __repr__(self): return util.generic_repr( - self, to_inspect=[ENUM, _StringType, sqltypes.Enum]) + self, to_inspect=[ENUM, _StringType, sqltypes.Enum] + ) class SET(_EnumeratedValues): """MySQL SET type.""" - __visit_name__ = 'SET' + __visit_name__ = "SET" def __init__(self, *values, **kw): """Construct a SET. @@ -216,45 +218,43 @@ class SET(_EnumeratedValues): """ - self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False) + self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False) values, length = self._init_values(values, kw) self.values = tuple(values) - if not self.retrieve_as_bitwise and '' in values: + if not self.retrieve_as_bitwise and "" in values: raise exc.ArgumentError( "Can't use the blank value '' in a SET without " - "setting retrieve_as_bitwise=True") + "setting retrieve_as_bitwise=True" + ) if self.retrieve_as_bitwise: self._bitmap = dict( - (value, 2 ** idx) - for idx, value in enumerate(self.values) + (value, 2 ** idx) for idx, value in enumerate(self.values) ) self._bitmap.update( - (2 ** idx, value) - for idx, value in enumerate(self.values) + (2 ** idx, value) for idx, value in enumerate(self.values) ) - kw.setdefault('length', length) + kw.setdefault("length", length) super(SET, self).__init__(**kw) def column_expression(self, colexpr): if self.retrieve_as_bitwise: return sql.type_coerce( - sql.type_coerce(colexpr, sqltypes.Integer) + 0, - self + sql.type_coerce(colexpr, sqltypes.Integer) + 0, self ) else: return colexpr def result_processor(self, dialect, coltype): if self.retrieve_as_bitwise: + def process(value): if value is not None: value = int(value) - return set( - util.map_bits(self._bitmap.__getitem__, value) - ) + return set(util.map_bits(self._bitmap.__getitem__, value)) else: return None + else: super_convert = super(SET, self).result_processor(dialect, coltype) @@ -263,18 +263,20 @@ class SET(_EnumeratedValues): # MySQLdb returns a string, let's parse if super_convert: value = super_convert(value) - return set(re.findall(r'[^,]+', value)) + return set(re.findall(r"[^,]+", value)) else: # mysql-connector-python does a naive # split(",") which throws in an empty string if value is not None: - value.discard('') + value.discard("") return value + return process def bind_processor(self, dialect): super_convert = super(SET, self).bind_processor(dialect) if self.retrieve_as_bitwise: + def process(value): if value is None: return None @@ -288,24 +290,23 @@ class SET(_EnumeratedValues): for v in value: int_value |= self._bitmap[v] return int_value + else: def process(value): # accept strings and int (actually bitflag) values directly if value is not None and not isinstance( - value, util.int_types + util.string_types): + value, util.int_types + util.string_types + ): value = ",".join(value) if super_convert: return super_convert(value) else: return value + return process def adapt(self, impltype, **kw): - kw['retrieve_as_bitwise'] = 
self.retrieve_as_bitwise - return util.constructor_copy( - self, impltype, - *self.values, - **kw - ) + kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise + return util.constructor_copy(self, impltype, *self.values, **kw) diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py index 806e4c8745..117cd28a2f 100644 --- a/lib/sqlalchemy/dialects/mysql/gaerdbms.py +++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py @@ -44,11 +44,10 @@ from sqlalchemy.util import warn_deprecated def _is_dev_environment(): - return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/') + return os.environ.get("SERVER_SOFTWARE", "").startswith("Development/") class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): - @classmethod def dbapi(cls): @@ -69,12 +68,15 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): if _is_dev_environment(): from google.appengine.api import rdbms_mysqldb + return rdbms_mysqldb - elif apiproxy_stub_map.apiproxy.GetStub('rdbms'): + elif apiproxy_stub_map.apiproxy.GetStub("rdbms"): from google.storage.speckle.python.api import rdbms_apiproxy + return rdbms_apiproxy else: from google.storage.speckle.python.api import rdbms_googleapi + return rdbms_googleapi @classmethod @@ -87,8 +89,8 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): if not _is_dev_environment(): # 'dsn' and 'instance' are because we are skipping # the traditional google.api.rdbms wrapper - opts['dsn'] = '' - opts['instance'] = url.query['instance'] + opts["dsn"] = "" + opts["instance"] = url.query["instance"] return [], opts def _extract_error_code(self, exception): @@ -99,4 +101,5 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): if code: return int(code) + dialect = MySQLDialect_gaerdbms diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index 534fb989d0..162d48f73b 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -58,7 +58,6 @@ class _FormatTypeMixin(object): class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): if isinstance(value, int): value = "$[%s]" % value @@ -70,8 +69,10 @@ class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): def _format_value(self, value): return "$%s" % ( - "".join([ - "[%s]" % elem if isinstance(elem, int) - else '."%s"' % elem for elem in value - ]) + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) ) diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index e16b68bada..9c1502a14b 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -47,9 +47,13 @@ are contributed to SQLAlchemy. """ -from .base import (MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer, - BIT) +from .base import ( + MySQLDialect, + MySQLExecutionContext, + MySQLCompiler, + MySQLIdentifierPreparer, + BIT, +) from ... import util import re @@ -57,7 +61,6 @@ from ... 
import processors class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): - def get_lastrowid(self): return self.cursor.lastrowid @@ -65,21 +68,27 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): class MySQLCompiler_mysqlconnector(MySQLCompiler): def visit_mod_binary(self, binary, operator, **kw): if self.dialect._mysqlconnector_double_percents: - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) else: - return self.process(binary.left, **kw) + " % " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): if self.dialect._mysqlconnector_double_percents: - return text.replace('%', '%%') + return text.replace("%", "%%") else: return text def escape_literal_column(self, text): if self.dialect._mysqlconnector_double_percents: - return text.replace('%', '%%') + return text.replace("%", "%%") else: return text @@ -109,7 +118,7 @@ class _myconnpyBIT(BIT): class MySQLDialect_mysqlconnector(MySQLDialect): - driver = 'mysqlconnector' + driver = "mysqlconnector" supports_unicode_binds = True @@ -118,28 +127,22 @@ class MySQLDialect_mysqlconnector(MySQLDialect): supports_native_decimal = True - default_paramstyle = 'format' + default_paramstyle = "format" execution_ctx_cls = MySQLExecutionContext_mysqlconnector statement_compiler = MySQLCompiler_mysqlconnector preparer = MySQLIdentifierPreparer_mysqlconnector - colspecs = util.update_copy( - MySQLDialect.colspecs, - { - BIT: _myconnpyBIT, - } - ) + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) def __init__(self, *arg, **kw): super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw) # hack description encoding since mysqlconnector randomly # returns bytes or not - self._description_decoder = \ - processors.to_conditional_unicode_processor_factory( - self.description_encoding - ) + self._description_decoder = processors.to_conditional_unicode_processor_factory( + self.description_encoding + ) def _check_unicode_description(self, connection): # hack description encoding since mysqlconnector randomly @@ -158,6 +161,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect): @classmethod def dbapi(cls): from mysql import connector + return connector def do_ping(self, dbapi_connection): @@ -172,54 +176,52 @@ class MySQLDialect_mysqlconnector(MySQLDialect): return True def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') + opts = url.translate_connect_args(username="user") opts.update(url.query) - util.coerce_kw_type(opts, 'allow_local_infile', bool) - util.coerce_kw_type(opts, 'autocommit', bool) - util.coerce_kw_type(opts, 'buffered', bool) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connection_timeout', int) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'consume_results', bool) - util.coerce_kw_type(opts, 'force_ipv6', bool) - util.coerce_kw_type(opts, 'get_warnings', bool) - util.coerce_kw_type(opts, 'pool_reset_session', bool) - util.coerce_kw_type(opts, 'pool_size', int) - util.coerce_kw_type(opts, 'raise_on_warnings', bool) - util.coerce_kw_type(opts, 'raw', bool) - util.coerce_kw_type(opts, 'ssl_verify_cert', bool) - util.coerce_kw_type(opts, 'use_pure', bool) - util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, 
"allow_local_infile", bool) + util.coerce_kw_type(opts, "autocommit", bool) + util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connection_timeout", int) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "consume_results", bool) + util.coerce_kw_type(opts, "force_ipv6", bool) + util.coerce_kw_type(opts, "get_warnings", bool) + util.coerce_kw_type(opts, "pool_reset_session", bool) + util.coerce_kw_type(opts, "pool_size", int) + util.coerce_kw_type(opts, "raise_on_warnings", bool) + util.coerce_kw_type(opts, "raw", bool) + util.coerce_kw_type(opts, "ssl_verify_cert", bool) + util.coerce_kw_type(opts, "use_pure", bool) + util.coerce_kw_type(opts, "use_unicode", bool) # unfortunately, MySQL/connector python refuses to release a # cursor without reading fully, so non-buffered isn't an option - opts.setdefault('buffered', True) + opts.setdefault("buffered", True) # FOUND_ROWS must be set in ClientFlag to enable # supports_sane_rowcount. if self.dbapi is not None: try: from mysql.connector.constants import ClientFlag + client_flags = opts.get( - 'client_flags', ClientFlag.get_default()) + "client_flags", ClientFlag.get_default() + ) client_flags |= ClientFlag.FOUND_ROWS - opts['client_flags'] = client_flags + opts["client_flags"] = client_flags except Exception: pass return [[], opts] @util.memoized_property def _mysqlconnector_version_info(self): - if self.dbapi and hasattr(self.dbapi, '__version__'): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', - self.dbapi.__version__) + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) @util.memoized_property def _mysqlconnector_double_percents(self): @@ -235,9 +237,11 @@ class MySQLDialect_mysqlconnector(MySQLDialect): errnos = (2006, 2013, 2014, 2045, 2055, 2048) exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) if isinstance(e, exceptions): - return e.errno in errnos or \ - "MySQL Connection not available." in str(e) or \ - "Connection to MySQL is not available" in str(e) + return ( + e.errno in errnos + or "MySQL Connection not available." in str(e) + or "Connection to MySQL is not available" in str(e) + ) else: return False @@ -247,17 +251,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect): def _compat_fetchone(self, rp, charset=None): return rp.fetchone() - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'AUTOCOMMIT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) def _set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit = True else: connection.autocommit = False super(MySQLDialect_mysqlconnector, self)._set_isolation_level( - connection, level) + connection, level + ) dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index edac816feb..6d42f5c04e 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -45,8 +45,12 @@ The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. 
""" -from .base import (MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer) +from .base import ( + MySQLDialect, + MySQLExecutionContext, + MySQLCompiler, + MySQLIdentifierPreparer, +) from .base import TEXT from ... import sql from ... import util @@ -54,10 +58,9 @@ import re class MySQLExecutionContext_mysqldb(MySQLExecutionContext): - @property def rowcount(self): - if hasattr(self, '_rowcount'): + if hasattr(self, "_rowcount"): return self._rowcount else: return self.cursor.rowcount @@ -72,14 +75,14 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): class MySQLDialect_mysqldb(MySQLDialect): - driver = 'mysqldb' + driver = "mysqldb" supports_unicode_statements = True supports_sane_rowcount = True supports_sane_multi_rowcount = True supports_native_decimal = True - default_paramstyle = 'format' + default_paramstyle = "format" execution_ctx_cls = MySQLExecutionContext_mysqldb statement_compiler = MySQLCompiler_mysqldb preparer = MySQLIdentifierPreparer_mysqldb @@ -87,24 +90,23 @@ class MySQLDialect_mysqldb(MySQLDialect): def __init__(self, server_side_cursors=False, **kwargs): super(MySQLDialect_mysqldb, self).__init__(**kwargs) self.server_side_cursors = server_side_cursors - self._mysql_dbapi_version = self._parse_dbapi_version( - self.dbapi.__version__) if self.dbapi is not None \ - and hasattr(self.dbapi, '__version__') else (0, 0, 0) + self._mysql_dbapi_version = ( + self._parse_dbapi_version(self.dbapi.__version__) + if self.dbapi is not None and hasattr(self.dbapi, "__version__") + else (0, 0, 0) + ) def _parse_dbapi_version(self, version): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version) + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) else: return (0, 0, 0) @util.langhelpers.memoized_property def supports_server_side_cursors(self): try: - cursors = __import__('MySQLdb.cursors').cursors + cursors = __import__("MySQLdb.cursors").cursors self._sscursor = cursors.SSCursor return True except (ImportError, AttributeError): @@ -112,7 +114,7 @@ class MySQLDialect_mysqldb(MySQLDialect): @classmethod def dbapi(cls): - return __import__('MySQLdb') + return __import__("MySQLdb") def do_ping(self, dbapi_connection): try: @@ -135,67 +137,74 @@ class MySQLDialect_mysqldb(MySQLDialect): # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 # specific issue w/ the utf8mb4_bin collation and unicode returns - has_utf8mb4_bin = self.server_version_info > (5, ) and \ - connection.scalar( - "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation") - )) + has_utf8mb4_bin = self.server_version_info > ( + 5, + ) and connection.scalar( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ) if has_utf8mb4_bin: additional_tests = [ - sql.collate(sql.cast( - sql.literal_column( - "'test collated returns'"), - TEXT(charset='utf8mb4')), "utf8mb4_bin") + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) ] else: additional_tests = [] return super(MySQLDialect_mysqldb, self)._check_unicode_returns( - connection, additional_tests) + connection, additional_tests + ) def create_connect_args(self, url): - 
opts = url.translate_connect_args(database='db', username='user', - password='passwd') + opts = url.translate_connect_args( + database="db", username="user", password="passwd" + ) opts.update(url.query) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'read_timeout', int) - util.coerce_kw_type(opts, 'write_timeout', int) - util.coerce_kw_type(opts, 'client_flag', int) - util.coerce_kw_type(opts, 'local_infile', int) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "read_timeout", int) + util.coerce_kw_type(opts, "write_timeout", int) + util.coerce_kw_type(opts, "client_flag", int) + util.coerce_kw_type(opts, "local_infile", int) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. - util.coerce_kw_type(opts, 'use_unicode', bool) - util.coerce_kw_type(opts, 'charset', str) + util.coerce_kw_type(opts, "use_unicode", bool) + util.coerce_kw_type(opts, "charset", str) # Rich values 'cursorclass' and 'conv' are not supported via # query string. ssl = {} - keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher'] + keys = ["ssl_ca", "ssl_key", "ssl_cert", "ssl_capath", "ssl_cipher"] for key in keys: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) del opts[key] if ssl: - opts['ssl'] = ssl + opts["ssl"] = ssl # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. - client_flag = opts.get('client_flag', 0) + client_flag = opts.get("client_flag", 0) if self.dbapi is not None: try: CLIENT_FLAGS = __import__( - self.dbapi.__name__ + '.constants.CLIENT' + self.dbapi.__name__ + ".constants.CLIENT" ).constants.CLIENT client_flag |= CLIENT_FLAGS.FOUND_ROWS except (AttributeError, ImportError): self.supports_sane_rowcount = False - opts['client_flag'] = client_flag + opts["client_flag"] = client_flag return [[], opts] def _extract_error_code(self, exception): @@ -213,22 +222,30 @@ class MySQLDialect_mysqldb(MySQLDialect): "No 'character_set_name' can be detected with " "this MySQL-Python version; " "please upgrade to a recent version of MySQL-Python. " - "Assuming latin1.") - return 'latin1' + "Assuming latin1." + ) + return "latin1" else: return cset_name() - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'AUTOCOMMIT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) def _set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit(True) else: connection.autocommit(False) - super(MySQLDialect_mysqldb, self)._set_isolation_level(connection, - level) + super(MySQLDialect_mysqldb, self)._set_isolation_level( + connection, level + ) dialect = MySQLDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 67dbb7cf23..8ba353a311 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -24,7 +24,7 @@ handling. import re -from .base import (BIT, MySQLDialect, MySQLExecutionContext) +from .base import BIT, MySQLDialect, MySQLExecutionContext from ... 
import types as sqltypes, util @@ -36,14 +36,13 @@ class _oursqlBIT(BIT): class MySQLExecutionContext_oursql(MySQLExecutionContext): - @property def plain_query(self): - return self.execution_options.get('_oursql_plain_query', False) + return self.execution_options.get("_oursql_plain_query", False) class MySQLDialect_oursql(MySQLDialect): - driver = 'oursql' + driver = "oursql" if util.py2k: supports_unicode_binds = True @@ -56,16 +55,12 @@ class MySQLDialect_oursql(MySQLDialect): execution_ctx_cls = MySQLExecutionContext_oursql colspecs = util.update_copy( - MySQLDialect.colspecs, - { - sqltypes.Time: sqltypes.Time, - BIT: _oursqlBIT, - } + MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT} ) @classmethod def dbapi(cls): - return __import__('oursql') + return __import__("oursql") def do_execute(self, cursor, statement, parameters, context=None): """Provide an implementation of @@ -77,7 +72,7 @@ class MySQLDialect_oursql(MySQLDialect): cursor.execute(statement, parameters) def do_begin(self, connection): - connection.cursor().execute('BEGIN', plain_query=True) + connection.cursor().execute("BEGIN", plain_query=True) def _xa_query(self, connection, query, xid): if util.py2k: @@ -85,10 +80,12 @@ class MySQLDialect_oursql(MySQLDialect): else: charset = self._connection_charset arg = connection.connection._escape_string( - xid.encode(charset)).decode(charset) + xid.encode(charset) + ).decode(charset) arg = "'%s'" % arg - connection.execution_options( - _oursql_plain_query=True).execute(query % arg) + connection.execution_options(_oursql_plain_query=True).execute( + query % arg + ) # Because mysql is bad, these methods have to be # reimplemented to use _PlainQuery. Basically, some queries @@ -96,23 +93,25 @@ class MySQLDialect_oursql(MySQLDialect): # the parameterized query API, or refuse to be parameterized # in the first place. def do_begin_twophase(self, connection, xid): - self._xa_query(connection, 'XA BEGIN %s', xid) + self._xa_query(connection, "XA BEGIN %s", xid) def do_prepare_twophase(self, connection, xid): - self._xa_query(connection, 'XA END %s', xid) - self._xa_query(connection, 'XA PREPARE %s', xid) + self._xa_query(connection, "XA END %s", xid) + self._xa_query(connection, "XA PREPARE %s", xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: - self._xa_query(connection, 'XA END %s', xid) - self._xa_query(connection, 'XA ROLLBACK %s', xid) + self._xa_query(connection, "XA END %s", xid) + self._xa_query(connection, "XA ROLLBACK %s", xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) - self._xa_query(connection, 'XA COMMIT %s', xid) + self._xa_query(connection, "XA COMMIT %s", xid) # Q: why didn't we need all these "plain_query" overrides earlier ? # am i on a newer/older version of OurSQL ? 
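The XA overrides above route their statements through _xa_query(), which executes with _oursql_plain_query=True so that the driver never tries to prepare the XA commands server-side. A minimal sketch of the two-phase flow these methods serve, assuming a reachable MySQL server, the oursql driver installed, and an existing table t; URL, credentials and table are placeholders::

    from sqlalchemy import create_engine

    engine = create_engine("mysql+oursql://scott:tiger@localhost/test")
    conn = engine.connect()

    xa = conn.begin_twophase()                    # emits XA BEGIN '<xid>'
    conn.execute("INSERT INTO t (x) VALUES (1)")
    xa.prepare()                                  # emits XA END, then XA PREPARE
    xa.commit()                                   # emits XA COMMIT
    conn.close()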
@@ -121,7 +120,7 @@ class MySQLDialect_oursql(MySQLDialect): self, connection.connect().execution_options(_oursql_plain_query=True), table_name, - schema + schema, ) def get_table_options(self, connection, table_name, schema=None, **kw): @@ -154,7 +153,7 @@ class MySQLDialect_oursql(MySQLDialect): return MySQLDialect.get_table_names( self, connection.connect().execution_options(_oursql_plain_query=True), - schema + schema, ) def get_schema_names(self, connection, **kw): @@ -166,57 +165,69 @@ class MySQLDialect_oursql(MySQLDialect): def initialize(self, connection): return MySQLDialect.initialize( - self, - connection.execution_options(_oursql_plain_query=True) + self, connection.execution_options(_oursql_plain_query=True) ) - def _show_create_table(self, connection, table, charset=None, - full_name=None): + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): return MySQLDialect._show_create_table( self, - connection.contextual_connect(close_with_result=True). - execution_options(_oursql_plain_query=True), - table, charset, full_name + connection.contextual_connect( + close_with_result=True + ).execution_options(_oursql_plain_query=True), + table, + charset, + full_name, ) def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.ProgrammingError): - return e.errno is None and 'cursor' not in e.args[1] \ - and e.args[1].endswith('closed') + return ( + e.errno is None + and "cursor" not in e.args[1] + and e.args[1].endswith("closed") + ) else: return e.errno in (2006, 2013, 2014, 2045, 2055) def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') + opts = url.translate_connect_args( + database="db", username="user", password="passwd" + ) opts.update(url.query) - util.coerce_kw_type(opts, 'port', int) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'autoping', bool) - util.coerce_kw_type(opts, 'raise_on_warnings', bool) + util.coerce_kw_type(opts, "port", int) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "autoping", bool) + util.coerce_kw_type(opts, "raise_on_warnings", bool) - util.coerce_kw_type(opts, 'default_charset', bool) - if opts.pop('default_charset', False): - opts['charset'] = None + util.coerce_kw_type(opts, "default_charset", bool) + if opts.pop("default_charset", False): + opts["charset"] = None else: - util.coerce_kw_type(opts, 'charset', str) - opts['use_unicode'] = opts.get('use_unicode', True) - util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, "charset", str) + opts["use_unicode"] = opts.get("use_unicode", True) + util.coerce_kw_type(opts, "use_unicode", bool) # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. 
- opts.setdefault('found_rows', True) + opts.setdefault("found_rows", True) ssl = {} - for key in ['ssl_ca', 'ssl_key', 'ssl_cert', - 'ssl_capath', 'ssl_cipher']: + for key in [ + "ssl_ca", + "ssl_key", + "ssl_cert", + "ssl_capath", + "ssl_cipher", + ]: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) del opts[key] if ssl: - opts['ssl'] = ssl + opts["ssl"] = ssl return [[], opts] diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 5f176cef2c..94dbfff063 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -34,7 +34,7 @@ from ...util import langhelpers, py3k class MySQLDialect_pymysql(MySQLDialect_mysqldb): - driver = 'pymysql' + driver = "pymysql" description_encoding = None @@ -51,7 +51,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): @langhelpers.memoized_property def supports_server_side_cursors(self): try: - cursors = __import__('pymysql.cursors').cursors + cursors = __import__("pymysql.cursors").cursors self._sscursor = cursors.SSCursor return True except (ImportError, AttributeError): @@ -59,10 +59,12 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): @classmethod def dbapi(cls): - return __import__('pymysql') + return __import__("pymysql") def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_pymysql, self).is_disconnect(e, connection, cursor): + if super(MySQLDialect_pymysql, self).is_disconnect( + e, connection, cursor + ): return True elif isinstance(e, self.dbapi.Error): return "Already closed" in str(e) @@ -70,9 +72,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): return False if py3k: + def _extract_error_code(self, exception): if isinstance(exception.args[0], Exception): exception = exception.args[0] return exception.args[0] + dialect = MySQLDialect_pymysql diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 718754651a..91512857ee 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -29,7 +29,6 @@ import re class MySQLExecutionContext_pyodbc(MySQLExecutionContext): - def get_lastrowid(self): cursor = self.create_cursor() cursor.execute("SELECT LAST_INSERT_ID()") @@ -46,7 +45,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): def __init__(self, **kw): # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 - kw.setdefault('convert_unicode', True) + kw.setdefault("convert_unicode", True) super(MySQLDialect_pyodbc, self).__init__(**kw) def _detect_charset(self, connection): @@ -60,13 +59,15 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): # this can prefer the driver value. rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") opts = {row[0]: row[1] for row in self._compat_fetchall(rs)} - for key in ('character_set_connection', 'character_set'): + for key in ("character_set_connection", "character_set"): if opts.get(key, None): return opts[key] - util.warn("Could not detect the connection character set. " - "Assuming latin1.") - return 'latin1' + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." 
+ ) + return "latin1" def _extract_error_code(self, exception): m = re.compile(r"\((\d+)\)").search(str(exception.args)) @@ -76,4 +77,5 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): else: return None + dialect = MySQLDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index e88bc3f42f..d0513eb4d9 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -36,16 +36,16 @@ class MySQLTableDefinitionParser(object): def parse(self, show_create, charset): state = ReflectedState() state.charset = charset - for line in re.split(r'\r?\n', show_create): - if line.startswith(' ' + self.preparer.initial_quote): + for line in re.split(r"\r?\n", show_create): + if line.startswith(" " + self.preparer.initial_quote): self._parse_column(line, state) # a regular table options line - elif line.startswith(') '): + elif line.startswith(") "): self._parse_table_options(line, state) # an ANSI-mode table options line - elif line == ')': + elif line == ")": pass - elif line.startswith('CREATE '): + elif line.startswith("CREATE "): self._parse_table_name(line, state) # Not present in real reflection, but may be if # loading from a file. @@ -55,11 +55,11 @@ class MySQLTableDefinitionParser(object): type_, spec = self._parse_constraints(line) if type_ is None: util.warn("Unknown schema content: %r" % line) - elif type_ == 'key': + elif type_ == "key": state.keys.append(spec) - elif type_ == 'fk_constraint': + elif type_ == "fk_constraint": state.fk_constraints.append(spec) - elif type_ == 'ck_constraint': + elif type_ == "ck_constraint": state.ck_constraints.append(spec) else: pass @@ -78,39 +78,39 @@ class MySQLTableDefinitionParser(object): # convert columns into name, length pairs # NOTE: we may want to consider SHOW INDEX as the # format of indexes in MySQL becomes more complex - spec['columns'] = self._parse_keyexprs(spec['columns']) - if spec['version_sql']: - m2 = self._re_key_version_sql.match(spec['version_sql']) - if m2 and m2.groupdict()['parser']: - spec['parser'] = m2.groupdict()['parser'] - if spec['parser']: - spec['parser'] = self.preparer.unformat_identifiers( - spec['parser'])[0] - return 'key', spec + spec["columns"] = self._parse_keyexprs(spec["columns"]) + if spec["version_sql"]: + m2 = self._re_key_version_sql.match(spec["version_sql"]) + if m2 and m2.groupdict()["parser"]: + spec["parser"] = m2.groupdict()["parser"] + if spec["parser"]: + spec["parser"] = self.preparer.unformat_identifiers( + spec["parser"] + )[0] + return "key", spec # FOREIGN KEY CONSTRAINT m = self._re_fk_constraint.match(line) if m: spec = m.groupdict() - spec['table'] = \ - self.preparer.unformat_identifiers(spec['table']) - spec['local'] = [c[0] - for c in self._parse_keyexprs(spec['local'])] - spec['foreign'] = [c[0] - for c in self._parse_keyexprs(spec['foreign'])] - return 'fk_constraint', spec + spec["table"] = self.preparer.unformat_identifiers(spec["table"]) + spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])] + spec["foreign"] = [ + c[0] for c in self._parse_keyexprs(spec["foreign"]) + ] + return "fk_constraint", spec # CHECK constraint m = self._re_ck_constraint.match(line) if m: spec = m.groupdict() - return 'ck_constraint', spec + return "ck_constraint", spec # PARTITION and SUBPARTITION m = self._re_partition.match(line) if m: # Punt! - return 'partition', line + return "partition", line # No match. 
return (None, line) @@ -124,7 +124,7 @@ class MySQLTableDefinitionParser(object): regex, cleanup = self._pr_name m = regex.match(line) if m: - state.table_name = cleanup(m.group('name')) + state.table_name = cleanup(m.group("name")) def _parse_table_options(self, line, state): """Build a dictionary of all reflected table-level options. @@ -134,7 +134,7 @@ class MySQLTableDefinitionParser(object): options = {} - if not line or line == ')': + if not line or line == ")": pass else: @@ -143,17 +143,17 @@ class MySQLTableDefinitionParser(object): m = regex.search(rest_of_line) if not m: continue - directive, value = m.group('directive'), m.group('val') + directive, value = m.group("directive"), m.group("val") if cleanup: value = cleanup(value) options[directive.lower()] = value - rest_of_line = regex.sub('', rest_of_line) + rest_of_line = regex.sub("", rest_of_line) - for nope in ('auto_increment', 'data directory', 'index directory'): + for nope in ("auto_increment", "data directory", "index directory"): options.pop(nope, None) for opt, val in options.items(): - state.table_options['%s_%s' % (self.dialect.name, opt)] = val + state.table_options["%s_%s" % (self.dialect.name, opt)] = val def _parse_column(self, line, state): """Extract column details. @@ -167,29 +167,30 @@ class MySQLTableDefinitionParser(object): m = self._re_column.match(line) if m: spec = m.groupdict() - spec['full'] = True + spec["full"] = True else: m = self._re_column_loose.match(line) if m: spec = m.groupdict() - spec['full'] = False + spec["full"] = False if not spec: util.warn("Unknown column definition %r" % line) return - if not spec['full']: + if not spec["full"]: util.warn("Incomplete reflection of column definition %r" % line) - name, type_, args = spec['name'], spec['coltype'], spec['arg'] + name, type_, args = spec["name"], spec["coltype"], spec["arg"] try: col_type = self.dialect.ischema_names[type_] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (type_, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) col_type = sqltypes.NullType # Column type positional arguments eg. 
varchar(32) - if args is None or args == '': + if args is None or args == "": type_args = [] elif args[0] == "'" and args[-1] == "'": type_args = self._re_csv_str.findall(args) @@ -201,50 +202,51 @@ class MySQLTableDefinitionParser(object): if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)): if type_args: - type_kw['fsp'] = type_args.pop(0) + type_kw["fsp"] = type_args.pop(0) - for kw in ('unsigned', 'zerofill'): + for kw in ("unsigned", "zerofill"): if spec.get(kw, False): type_kw[kw] = True - for kw in ('charset', 'collate'): + for kw in ("charset", "collate"): if spec.get(kw, False): type_kw[kw] = spec[kw] if issubclass(col_type, _EnumeratedValues): type_args = _EnumeratedValues._strip_values(type_args) - if issubclass(col_type, SET) and '' in type_args: - type_kw['retrieve_as_bitwise'] = True + if issubclass(col_type, SET) and "" in type_args: + type_kw["retrieve_as_bitwise"] = True type_instance = col_type(*type_args, **type_kw) col_kw = {} # NOT NULL - col_kw['nullable'] = True + col_kw["nullable"] = True # this can be "NULL" in the case of TIMESTAMP - if spec.get('notnull', False) == 'NOT NULL': - col_kw['nullable'] = False + if spec.get("notnull", False) == "NOT NULL": + col_kw["nullable"] = False # AUTO_INCREMENT - if spec.get('autoincr', False): - col_kw['autoincrement'] = True + if spec.get("autoincr", False): + col_kw["autoincrement"] = True elif issubclass(col_type, sqltypes.Integer): - col_kw['autoincrement'] = False + col_kw["autoincrement"] = False # DEFAULT - default = spec.get('default', None) + default = spec.get("default", None) - if default == 'NULL': + if default == "NULL": # eliminates the need to deal with this later. default = None - comment = spec.get('comment', None) + comment = spec.get("comment", None) if comment is not None: comment = comment.replace("\\\\", "\\").replace("''", "'") - col_d = dict(name=name, type=type_instance, default=default, - comment=comment) + col_d = dict( + name=name, type=type_instance, default=default, comment=comment + ) col_d.update(col_kw) state.columns.append(col_d) @@ -262,36 +264,44 @@ class MySQLTableDefinitionParser(object): buffer = [] for row in columns: - (name, col_type, nullable, default, extra) = \ - [row[i] for i in (0, 1, 2, 4, 5)] + (name, col_type, nullable, default, extra) = [ + row[i] for i in (0, 1, 2, 4, 5) + ] - line = [' '] + line = [" "] line.append(self.preparer.quote_identifier(name)) line.append(col_type) if not nullable: - line.append('NOT NULL') + line.append("NOT NULL") if default: - if 'auto_increment' in default: + if "auto_increment" in default: pass - elif (col_type.startswith('timestamp') and - default.startswith('C')): - line.append('DEFAULT') + elif col_type.startswith("timestamp") and default.startswith( + "C" + ): + line.append("DEFAULT") line.append(default) - elif default == 'NULL': - line.append('DEFAULT') + elif default == "NULL": + line.append("DEFAULT") line.append(default) else: - line.append('DEFAULT') + line.append("DEFAULT") line.append("'%s'" % default.replace("'", "''")) if extra: line.append(extra) - buffer.append(' '.join(line)) - - return ''.join([('CREATE TABLE %s (\n' % - self.preparer.quote_identifier(table_name)), - ',\n'.join(buffer), - '\n) ']) + buffer.append(" ".join(line)) + + return "".join( + [ + ( + "CREATE TABLE %s (\n" + % self.preparer.quote_identifier(table_name) + ), + ",\n".join(buffer), + "\n) ", + ] + ) def _parse_keyexprs(self, identifiers): """Unpack '"col"(2),"col" ASC'-ish strings into components.""" @@ -306,29 +316,39 @@ class 
MySQLTableDefinitionParser(object): _final = self.preparer.final_quote - quotes = dict(zip(('iq', 'fq', 'esc_fq'), - [re.escape(s) for s in - (self.preparer.initial_quote, - _final, - self.preparer._escape_identifier(_final))])) + quotes = dict( + zip( + ("iq", "fq", "esc_fq"), + [ + re.escape(s) + for s in ( + self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final), + ) + ], + ) + ) self._pr_name = _pr_compile( - r'^CREATE (?:\w+ +)?TABLE +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, - self.preparer._unescape_identifier) + r"^CREATE (?:\w+ +)?TABLE +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes, + self.preparer._unescape_identifier, + ) # `col`,`col2`(32),`col3`(15) DESC # self._re_keyexprs = _re_compile( - r'(?:' - r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' - r'(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+' % quotes) + r"(?:" + r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)" + r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes + ) # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' - self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') + self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27") # 123 or 123,456 - self._re_csv_int = _re_compile(r'\d+') + self._re_csv_int = _re_compile(r"\d+") # `colname` [type opts] # (NOT NULL | NULL) @@ -356,43 +376,39 @@ class MySQLTableDefinitionParser(object): r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?" r"(?: +STORAGE +(?P<storage>\w+))?" r"(?: +(?P<extra>.*))?" - r",?$" - % quotes + r",?$" % quotes ) # Fallback, try to parse as little as possible self._re_column_loose = _re_compile( - r' ' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'(?P<coltype>\w+)' - r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' - r'.*?(?P<notnull>(?:NOT )NULL)?' - % quotes + r" " + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P<coltype>\w+)" + r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?" + r".*?(?P<notnull>(?:NOT )NULL)?" % quotes ) # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */ self._re_key = _re_compile( - r' ' - r'(?:(?P<type>\S+) )?KEY' - r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?' - r'(?: +USING +(?P<using_pre>\S+))?' - r' +\((?P<columns>.+?)\)' - r'(?: +USING +(?P<using_post>\S+))?' - r'(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?' - r'(?: +WITH PARSER +(?P<parser>\S+))?' - r'(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?' - r'(?: +/\*(?P<version_sql>.+)\*/ +)?' - r',?$' - % quotes + r" " + r"(?:(?P<type>\S+) )?KEY" + r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?" + r"(?: +USING +(?P<using_pre>\S+))?" + r" +\((?P<columns>.+?)\)" + r"(?: +USING +(?P<using_post>\S+))?" + r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?" + r"(?: +WITH PARSER +(?P<parser>\S+))?" + r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?" + r"(?: +/\*(?P<version_sql>.+)\*/ +)?" + r",?$" % quotes ) # https://forums.mysql.com/read.php?20,567102,567111#msg-567111 # It means if the MySQL version >= \d+, execute what's in the comment self._re_key_version_sql = _re_compile( - r'\!\d+ ' - r'(?: *WITH PARSER +(?P<parser>\S+) *)?' + r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
) # CONSTRAINT `name` FOREIGN KEY (`local_col`) @@ -402,20 +418,19 @@ class MySQLTableDefinitionParser(object): # # unique constraints come back as KEYs kw = quotes.copy() - kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION' + kw["on"] = "RESTRICT|CASCADE|SET NULL|NOACTION" self._re_fk_constraint = _re_compile( - r' ' - r'CONSTRAINT +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'FOREIGN KEY +' - r'\((?P<local>[^\)]+?)\) REFERENCES +' - r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s' - r'(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +' - r'\((?P<foreign>[^\)]+?)\)' - r'(?: +(?P<match>MATCH \w+))?' - r'(?: +ON DELETE (?P<ondelete>%(on)s))?' - r'(?: +ON UPDATE (?P<onupdate>%(on)s))?' - % kw + r" " + r"CONSTRAINT +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"FOREIGN KEY +" + r"\((?P<local>[^\)]+?)\) REFERENCES +" + r"(?P<table>
%(iq)s[^%(fq)s]+%(fq)s" + r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +" + r"\((?P[^\)]+?)\)" + r"(?: +(?PMATCH \w+))?" + r"(?: +ON DELETE (?P%(on)s))?" + r"(?: +ON UPDATE (?P%(on)s))?" % kw ) # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)' @@ -423,18 +438,17 @@ class MySQLTableDefinitionParser(object): # is returned on a line by itself, so to match without worrying # about parenthesis in the expresion we go to the end of the line self._re_ck_constraint = _re_compile( - r' ' - r'CONSTRAINT +' - r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'CHECK +' - r'\((?P.+)\),?' - % kw + r" " + r"CONSTRAINT +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"CHECK +" + r"\((?P.+)\),?" % kw ) # PARTITION # # punt! - self._re_partition = _re_compile(r'(?:.*)(?:SUB)?PARTITION(?:.*)') + self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)") # Table-level options (COLLATE, ENGINE, etc.) # Do the string options first, since they have quoted @@ -442,44 +456,68 @@ class MySQLTableDefinitionParser(object): for option in _options_of_type_string: self._add_option_string(option) - for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT', - 'AVG_ROW_LENGTH', 'CHARACTER SET', - 'DEFAULT CHARSET', 'CHECKSUM', - 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD', - 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT', - 'KEY_BLOCK_SIZE'): + for option in ( + "ENGINE", + "TYPE", + "AUTO_INCREMENT", + "AVG_ROW_LENGTH", + "CHARACTER SET", + "DEFAULT CHARSET", + "CHECKSUM", + "COLLATE", + "DELAY_KEY_WRITE", + "INSERT_METHOD", + "MAX_ROWS", + "MIN_ROWS", + "PACK_KEYS", + "ROW_FORMAT", + "KEY_BLOCK_SIZE", + ): self._add_option_word(option) - self._add_option_regex('UNION', r'\([^\)]+\)') - self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK') + self._add_option_regex("UNION", r"\([^\)]+\)") + self._add_option_regex("TABLESPACE", r".*? 
STORAGE DISK") self._add_option_regex( - 'RAID_TYPE', - r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + "RAID_TYPE", + r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+", + ) - _optional_equals = r'(?:\s*(?:=\s*)|\s+)' + _optional_equals = r"(?:\s*(?:=\s*)|\s+)" def _add_option_string(self, directive): - regex = (r'(?P%s)%s' - r"'(?P(?:[^']|'')*?)'(?!')" % - (re.escape(directive), self._optional_equals)) - self._pr_options.append(_pr_compile( - regex, lambda v: v.replace("\\\\", "\\").replace("''", "'") - )) + regex = r"(?P%s)%s" r"'(?P(?:[^']|'')*?)'(?!')" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append( + _pr_compile( + regex, lambda v: v.replace("\\\\", "\\").replace("''", "'") + ) + ) def _add_option_word(self, directive): - regex = (r'(?P%s)%s' - r'(?P\w+)' % - (re.escape(directive), self._optional_equals)) + regex = r"(?P%s)%s" r"(?P\w+)" % ( + re.escape(directive), + self._optional_equals, + ) self._pr_options.append(_pr_compile(regex)) def _add_option_regex(self, directive, regex): - regex = (r'(?P%s)%s' - r'(?P%s)' % - (re.escape(directive), self._optional_equals, regex)) + regex = r"(?P%s)%s" r"(?P%s)" % ( + re.escape(directive), + self._optional_equals, + regex, + ) self._pr_options.append(_pr_compile(regex)) -_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY', - 'PASSWORD', 'CONNECTION') + +_options_of_type_string = ( + "COMMENT", + "DATA DIRECTORY", + "INDEX DIRECTORY", + "PASSWORD", + "CONNECTION", +) def _pr_compile(regex, cleanup=None): diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index cb09a0841e..ad97a9bbe1 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -24,28 +24,30 @@ class _NumericType(object): super(_NumericType, self).__init__(**kw) def __repr__(self): - return util.generic_repr(self, - to_inspect=[_NumericType, sqltypes.Numeric]) + return util.generic_repr( + self, to_inspect=[_NumericType, sqltypes.Numeric] + ) class _FloatType(_NumericType, sqltypes.Float): def __init__(self, precision=None, scale=None, asdecimal=True, **kw): - if isinstance(self, (REAL, DOUBLE)) and \ - ( - (precision is None and scale is not None) or - (precision is not None and scale is None) + if isinstance(self, (REAL, DOUBLE)) and ( + (precision is None and scale is not None) + or (precision is not None and scale is None) ): raise exc.ArgumentError( "You must specify both precision and scale or omit " - "both altogether.") + "both altogether." 
+ ) super(_FloatType, self).__init__( - precision=precision, asdecimal=asdecimal, **kw) + precision=precision, asdecimal=asdecimal, **kw + ) self.scale = scale def __repr__(self): - return util.generic_repr(self, to_inspect=[_FloatType, - _NumericType, - sqltypes.Float]) + return util.generic_repr( + self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] + ) class _IntegerType(_NumericType, sqltypes.Integer): @@ -54,21 +56,28 @@ class _IntegerType(_NumericType, sqltypes.Integer): super(_IntegerType, self).__init__(**kw) def __repr__(self): - return util.generic_repr(self, to_inspect=[_IntegerType, - _NumericType, - sqltypes.Integer]) + return util.generic_repr( + self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] + ) class _StringType(sqltypes.String): """Base for MySQL string types.""" - def __init__(self, charset=None, collation=None, - ascii=False, binary=False, unicode=False, - national=False, **kw): + def __init__( + self, + charset=None, + collation=None, + ascii=False, + binary=False, + unicode=False, + national=False, + **kw + ): self.charset = charset # allow collate= or collation= - kw.setdefault('collation', kw.pop('collate', collation)) + kw.setdefault("collation", kw.pop("collate", collation)) self.ascii = ascii self.unicode = unicode @@ -77,8 +86,9 @@ class _StringType(sqltypes.String): super(_StringType, self).__init__(**kw) def __repr__(self): - return util.generic_repr(self, - to_inspect=[_StringType, sqltypes.String]) + return util.generic_repr( + self, to_inspect=[_StringType, sqltypes.String] + ) class _MatchType(sqltypes.Float, sqltypes.MatchType): @@ -88,11 +98,10 @@ class _MatchType(sqltypes.Float, sqltypes.MatchType): sqltypes.MatchType.__init__(self) - class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" - __visit_name__ = 'NUMERIC' + __visit_name__ = "NUMERIC" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a NUMERIC. @@ -110,14 +119,15 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC): numeric. """ - super(NUMERIC, self).__init__(precision=precision, - scale=scale, asdecimal=asdecimal, **kw) + super(NUMERIC, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class DECIMAL(_NumericType, sqltypes.DECIMAL): """MySQL DECIMAL type.""" - __visit_name__ = 'DECIMAL' + __visit_name__ = "DECIMAL" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DECIMAL. @@ -135,14 +145,15 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL): numeric. """ - super(DECIMAL, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(DECIMAL, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class DOUBLE(_FloatType): """MySQL DOUBLE type.""" - __visit_name__ = 'DOUBLE' + __visit_name__ = "DOUBLE" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DOUBLE. @@ -168,14 +179,15 @@ class DOUBLE(_FloatType): numeric. """ - super(DOUBLE, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(DOUBLE, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class REAL(_FloatType, sqltypes.REAL): """MySQL REAL type.""" - __visit_name__ = 'REAL' + __visit_name__ = "REAL" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a REAL. @@ -201,14 +213,15 @@ class REAL(_FloatType, sqltypes.REAL): numeric. 
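For illustration (values are arbitrary): per the check in ``_FloatType``,
precision and scale must be given together or both omitted::

    REAL(precision=12, scale=4)   # accepted; renders REAL(12, 4)
    REAL(scale=4)                 # raises ArgumentError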
""" - super(REAL, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(REAL, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class FLOAT(_FloatType, sqltypes.FLOAT): """MySQL FLOAT type.""" - __visit_name__ = 'FLOAT' + __visit_name__ = "FLOAT" def __init__(self, precision=None, scale=None, asdecimal=False, **kw): """Construct a FLOAT. @@ -226,8 +239,9 @@ class FLOAT(_FloatType, sqltypes.FLOAT): numeric. """ - super(FLOAT, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(FLOAT, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) def bind_processor(self, dialect): return None @@ -236,7 +250,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT): class INTEGER(_IntegerType, sqltypes.INTEGER): """MySQL INTEGER type.""" - __visit_name__ = 'INTEGER' + __visit_name__ = "INTEGER" def __init__(self, display_width=None, **kw): """Construct an INTEGER. @@ -257,7 +271,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER): class BIGINT(_IntegerType, sqltypes.BIGINT): """MySQL BIGINTEGER type.""" - __visit_name__ = 'BIGINT' + __visit_name__ = "BIGINT" def __init__(self, display_width=None, **kw): """Construct a BIGINTEGER. @@ -278,7 +292,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT): class MEDIUMINT(_IntegerType): """MySQL MEDIUMINTEGER type.""" - __visit_name__ = 'MEDIUMINT' + __visit_name__ = "MEDIUMINT" def __init__(self, display_width=None, **kw): """Construct a MEDIUMINTEGER @@ -299,7 +313,7 @@ class MEDIUMINT(_IntegerType): class TINYINT(_IntegerType): """MySQL TINYINT type.""" - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" def __init__(self, display_width=None, **kw): """Construct a TINYINT. @@ -320,7 +334,7 @@ class TINYINT(_IntegerType): class SMALLINT(_IntegerType, sqltypes.SMALLINT): """MySQL SMALLINTEGER type.""" - __visit_name__ = 'SMALLINT' + __visit_name__ = "SMALLINT" def __init__(self, display_width=None, **kw): """Construct a SMALLINTEGER. @@ -347,7 +361,7 @@ class BIT(sqltypes.TypeEngine): """ - __visit_name__ = 'BIT' + __visit_name__ = "BIT" def __init__(self, length=None): """Construct a BIT. @@ -374,13 +388,14 @@ class BIT(sqltypes.TypeEngine): v = v << 8 | i return v return value + return process class TIME(sqltypes.TIME): """MySQL TIME type. """ - __visit_name__ = 'TIME' + __visit_name__ = "TIME" def __init__(self, timezone=False, fsp=None): """Construct a MySQL TIME type. @@ -413,12 +428,15 @@ class TIME(sqltypes.TIME): microseconds = value.microseconds seconds = value.seconds minutes = seconds // 60 - return time(minutes // 60, - minutes % 60, - seconds - minutes * 60, - microsecond=microseconds) + return time( + minutes // 60, + minutes % 60, + seconds - minutes * 60, + microsecond=microseconds, + ) else: return None + return process @@ -427,7 +445,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" def __init__(self, timezone=False, fsp=None): """Construct a MySQL TIMESTAMP type. @@ -457,7 +475,7 @@ class DATETIME(sqltypes.DATETIME): """ - __visit_name__ = 'DATETIME' + __visit_name__ = "DATETIME" def __init__(self, timezone=False, fsp=None): """Construct a MySQL DATETIME type. 
@@ -485,7 +503,7 @@ class DATETIME(sqltypes.DATETIME): class YEAR(sqltypes.TypeEngine): """MySQL YEAR type, for single byte storage of years 1901-2155.""" - __visit_name__ = 'YEAR' + __visit_name__ = "YEAR" def __init__(self, display_width=None): self.display_width = display_width @@ -494,7 +512,7 @@ class YEAR(sqltypes.TypeEngine): class TEXT(_StringType, sqltypes.TEXT): """MySQL TEXT type, for text up to 2^16 characters.""" - __visit_name__ = 'TEXT' + __visit_name__ = "TEXT" def __init__(self, length=None, **kw): """Construct a TEXT. @@ -530,7 +548,7 @@ class TEXT(_StringType, sqltypes.TEXT): class TINYTEXT(_StringType): """MySQL TINYTEXT type, for text up to 2^8 characters.""" - __visit_name__ = 'TINYTEXT' + __visit_name__ = "TINYTEXT" def __init__(self, **kwargs): """Construct a TINYTEXT. @@ -562,7 +580,7 @@ class TINYTEXT(_StringType): class MEDIUMTEXT(_StringType): """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" - __visit_name__ = 'MEDIUMTEXT' + __visit_name__ = "MEDIUMTEXT" def __init__(self, **kwargs): """Construct a MEDIUMTEXT. @@ -594,7 +612,7 @@ class MEDIUMTEXT(_StringType): class LONGTEXT(_StringType): """MySQL LONGTEXT type, for text up to 2^32 characters.""" - __visit_name__ = 'LONGTEXT' + __visit_name__ = "LONGTEXT" def __init__(self, **kwargs): """Construct a LONGTEXT. @@ -626,7 +644,7 @@ class LONGTEXT(_StringType): class VARCHAR(_StringType, sqltypes.VARCHAR): """MySQL VARCHAR type, for variable-length character data.""" - __visit_name__ = 'VARCHAR' + __visit_name__ = "VARCHAR" def __init__(self, length=None, **kwargs): """Construct a VARCHAR. @@ -658,7 +676,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): class CHAR(_StringType, sqltypes.CHAR): """MySQL CHAR type, for fixed-length character data.""" - __visit_name__ = 'CHAR' + __visit_name__ = "CHAR" def __init__(self, length=None, **kwargs): """Construct a CHAR. @@ -690,7 +708,7 @@ class CHAR(_StringType, sqltypes.CHAR): ascii=type_.ascii, binary=type_.binary, unicode=type_.unicode, - national=False # not supported in CAST + national=False, # not supported in CAST ) else: return CHAR(length=type_.length) @@ -703,7 +721,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): character set. """ - __visit_name__ = 'NVARCHAR' + __visit_name__ = "NVARCHAR" def __init__(self, length=None, **kwargs): """Construct an NVARCHAR. @@ -718,7 +736,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): compatible with the national character set. """ - kwargs['national'] = True + kwargs["national"] = True super(NVARCHAR, self).__init__(length=length, **kwargs) @@ -729,7 +747,7 @@ class NCHAR(_StringType, sqltypes.NCHAR): character set. """ - __visit_name__ = 'NCHAR' + __visit_name__ = "NCHAR" def __init__(self, length=None, **kwargs): """Construct an NCHAR. @@ -744,23 +762,23 @@ class NCHAR(_StringType, sqltypes.NCHAR): compatible with the national character set. 
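For illustration (the length is arbitrary), ``NCHAR(30)`` renders
``NATIONAL CHAR(30)``, since ``national=True`` is forced in the
constructor below.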
""" - kwargs['national'] = True + kwargs["national"] = True super(NCHAR, self).__init__(length=length, **kwargs) class TINYBLOB(sqltypes._Binary): """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" - __visit_name__ = 'TINYBLOB' + __visit_name__ = "TINYBLOB" class MEDIUMBLOB(sqltypes._Binary): """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" - __visit_name__ = 'MEDIUMBLOB' + __visit_name__ = "MEDIUMBLOB" class LONGBLOB(sqltypes._Binary): """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" - __visit_name__ = 'LONGBLOB' + __visit_name__ = "LONGBLOB" diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py index 4aee2dbb74..d8ee437489 100644 --- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -37,6 +37,7 @@ from .base import BIT, MySQLDialect, MySQLExecutionContext class _ZxJDBCBit(BIT): def result_processor(self, dialect, coltype): """Converts boolean or byte arrays from MySQL Connector/J to longs.""" + def process(value): if value is None: return value @@ -44,9 +45,10 @@ class _ZxJDBCBit(BIT): return int(value) v = 0 for i in value: - v = v << 8 | (i & 0xff) + v = v << 8 | (i & 0xFF) value = v return value + return process @@ -60,17 +62,13 @@ class MySQLExecutionContext_zxjdbc(MySQLExecutionContext): class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): - jdbc_db_name = 'mysql' - jdbc_driver_name = 'com.mysql.jdbc.Driver' + jdbc_db_name = "mysql" + jdbc_driver_name = "com.mysql.jdbc.Driver" execution_ctx_cls = MySQLExecutionContext_zxjdbc colspecs = util.update_copy( - MySQLDialect.colspecs, - { - sqltypes.Time: sqltypes.Time, - BIT: _ZxJDBCBit - } + MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _ZxJDBCBit} ) def _detect_charset(self, connection): @@ -83,17 +81,19 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): # this can prefer the driver value. rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") opts = {row[0]: row[1] for row in self._compat_fetchall(rs)} - for key in ('character_set_connection', 'character_set'): + for key in ("character_set_connection", "character_set"): if opts.get(key, None): return opts[key] - util.warn("Could not detect the connection character set. " - "Assuming latin1.") - return 'latin1' + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." + ) + return "latin1" def _driver_kwargs(self): """return kw arg dict to be sent to connect().""" - return dict(characterEncoding='UTF-8', yearIsDateType='false') + return dict(characterEncoding="UTF-8", yearIsDateType="false") def _extract_error_code(self, exception): # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist @@ -106,7 +106,7 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): def _get_server_version_info(self, connection): dbapi_con = connection.connection version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(dbapi_con.dbversion): try: version.append(int(n)) @@ -114,4 +114,5 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): version.append(n) return tuple(version) + dialect = MySQLDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index e3d9fed2cc..1b9007fcc7 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -7,18 +7,51 @@ from . 
import base, cx_oracle, zxjdbc # noqa -from .base import \ - VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\ - BLOB, BFILE, BINARY_FLOAT, BINARY_DOUBLE, CLOB, NCLOB, TIMESTAMP, RAW,\ - FLOAT, DOUBLE_PRECISION, LONG, INTERVAL,\ - VARCHAR2, NVARCHAR2, ROWID +from .base import ( + VARCHAR, + NVARCHAR, + CHAR, + DATE, + NUMBER, + BLOB, + BFILE, + BINARY_FLOAT, + BINARY_DOUBLE, + CLOB, + NCLOB, + TIMESTAMP, + RAW, + FLOAT, + DOUBLE_PRECISION, + LONG, + INTERVAL, + VARCHAR2, + NVARCHAR2, + ROWID, +) base.dialect = dialect = cx_oracle.dialect __all__ = ( - 'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER', - 'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', - 'FLOAT', 'DOUBLE_PRECISION', 'BINARY_DOUBLE', 'BINARY_FLOAT', - 'LONG', 'dialect', 'INTERVAL', - 'VARCHAR2', 'NVARCHAR2', 'ROWID' + "VARCHAR", + "NVARCHAR", + "CHAR", + "DATE", + "NUMBER", + "BLOB", + "BFILE", + "CLOB", + "NCLOB", + "TIMESTAMP", + "RAW", + "FLOAT", + "DOUBLE_PRECISION", + "BINARY_DOUBLE", + "BINARY_FLOAT", + "LONG", + "dialect", + "INTERVAL", + "VARCHAR2", + "NVARCHAR2", + "ROWID", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index b5aea4386b..944fe21c36 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -353,49 +353,63 @@ from sqlalchemy.sql import compiler, visitors, expression, util as sql_util from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql.elements import quoted_name from sqlalchemy import types as sqltypes, schema as sa_schema -from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ - BLOB, CLOB, TIMESTAMP, FLOAT, INTEGER +from sqlalchemy.types import ( + VARCHAR, + NVARCHAR, + CHAR, + BLOB, + CLOB, + TIMESTAMP, + FLOAT, + INTEGER, +) from itertools import groupby -RESERVED_WORDS = \ - set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN ' - 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ' - 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ' - 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE ' - 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES ' - 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS ' - 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER ' - 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR ' - 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) +RESERVED_WORDS = set( + "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " + "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED " + "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE " + "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE " + "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES " + "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS " + "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER " + "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR " + "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split() +) -NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' - 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) +NO_ARG_FNS = set( + "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() +) class RAW(sqltypes._Binary): - __visit_name__ = 'RAW' + __visit_name__ = "RAW" + + OracleRaw = RAW class NCLOB(sqltypes.Text): - __visit_name__ = 'NCLOB' + __visit_name__ = "NCLOB" class VARCHAR2(VARCHAR): - __visit_name__ = 'VARCHAR2' + __visit_name__ = "VARCHAR2" + NVARCHAR2 = NVARCHAR class NUMBER(sqltypes.Numeric, sqltypes.Integer): - __visit_name__ = 'NUMBER' + 
__visit_name__ = "NUMBER" def __init__(self, precision=None, scale=None, asdecimal=None): if asdecimal is None: asdecimal = bool(scale and scale > 0) super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal) + precision=precision, scale=scale, asdecimal=asdecimal + ) def adapt(self, impltype): ret = super(NUMBER, self).adapt(impltype) @@ -412,23 +426,23 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): class DOUBLE_PRECISION(sqltypes.Float): - __visit_name__ = 'DOUBLE_PRECISION' + __visit_name__ = "DOUBLE_PRECISION" class BINARY_DOUBLE(sqltypes.Float): - __visit_name__ = 'BINARY_DOUBLE' + __visit_name__ = "BINARY_DOUBLE" class BINARY_FLOAT(sqltypes.Float): - __visit_name__ = 'BINARY_FLOAT' + __visit_name__ = "BINARY_FLOAT" class BFILE(sqltypes.LargeBinary): - __visit_name__ = 'BFILE' + __visit_name__ = "BFILE" class LONG(sqltypes.Text): - __visit_name__ = 'LONG' + __visit_name__ = "LONG" class DATE(sqltypes.DateTime): @@ -441,18 +455,17 @@ class DATE(sqltypes.DateTime): .. versionadded:: 0.9.4 """ - __visit_name__ = 'DATE' + + __visit_name__ = "DATE" def _compare_type_affinity(self, other): return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) class INTERVAL(sqltypes.TypeEngine): - __visit_name__ = 'INTERVAL' + __visit_name__ = "INTERVAL" - def __init__(self, - day_precision=None, - second_precision=None): + def __init__(self, day_precision=None, second_precision=None): """Construct an INTERVAL. Note that only DAY TO SECOND intervals are currently supported. @@ -471,8 +484,10 @@ class INTERVAL(sqltypes.TypeEngine): @classmethod def _adapt_from_generic_interval(cls, interval): - return INTERVAL(day_precision=interval.day_precision, - second_precision=interval.second_precision) + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) @property def _type_affinity(self): @@ -485,38 +500,40 @@ class ROWID(sqltypes.TypeEngine): When used in a cast() or similar, generates ROWID. 
""" - __visit_name__ = 'ROWID' + + __visit_name__ = "ROWID" class _OracleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMBER + colspecs = { sqltypes.Boolean: _OracleBoolean, sqltypes.Interval: INTERVAL, - sqltypes.DateTime: DATE + sqltypes.DateTime: DATE, } ischema_names = { - 'VARCHAR2': VARCHAR, - 'NVARCHAR2': NVARCHAR, - 'CHAR': CHAR, - 'DATE': DATE, - 'NUMBER': NUMBER, - 'BLOB': BLOB, - 'BFILE': BFILE, - 'CLOB': CLOB, - 'NCLOB': NCLOB, - 'TIMESTAMP': TIMESTAMP, - 'TIMESTAMP WITH TIME ZONE': TIMESTAMP, - 'INTERVAL DAY TO SECOND': INTERVAL, - 'RAW': RAW, - 'FLOAT': FLOAT, - 'DOUBLE PRECISION': DOUBLE_PRECISION, - 'LONG': LONG, - 'BINARY_DOUBLE': BINARY_DOUBLE, - 'BINARY_FLOAT': BINARY_FLOAT + "VARCHAR2": VARCHAR, + "NVARCHAR2": NVARCHAR, + "CHAR": CHAR, + "DATE": DATE, + "NUMBER": NUMBER, + "BLOB": BLOB, + "BFILE": BFILE, + "CLOB": CLOB, + "NCLOB": NCLOB, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP WITH TIME ZONE": TIMESTAMP, + "INTERVAL DAY TO SECOND": INTERVAL, + "RAW": RAW, + "FLOAT": FLOAT, + "DOUBLE PRECISION": DOUBLE_PRECISION, + "LONG": LONG, + "BINARY_DOUBLE": BINARY_DOUBLE, + "BINARY_FLOAT": BINARY_FLOAT, } @@ -540,12 +557,12 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_INTERVAL(self, type_, **kw): return "INTERVAL DAY%s TO SECOND%s" % ( - type_.day_precision is not None and - "(%d)" % type_.day_precision or - "", - type_.second_precision is not None and - "(%d)" % type_.second_precision or - "", + type_.day_precision is not None + and "(%d)" % type_.day_precision + or "", + type_.second_precision is not None + and "(%d)" % type_.second_precision + or "", ) def visit_LONG(self, type_, **kw): @@ -569,52 +586,53 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_FLOAT(self, type_, **kw): # don't support conversion between decimal/binary # precision yet - kw['no_precision'] = True + kw["no_precision"] = True return self._generate_numeric(type_, "FLOAT", **kw) def visit_NUMBER(self, type_, **kw): return self._generate_numeric(type_, "NUMBER", **kw) def _generate_numeric( - self, type_, name, precision=None, - scale=None, no_precision=False, **kw): + self, type_, name, precision=None, scale=None, no_precision=False, **kw + ): if precision is None: precision = type_.precision if scale is None: - scale = getattr(type_, 'scale', None) + scale = getattr(type_, "scale", None) if no_precision or precision is None: return name elif scale is None: n = "%(name)s(%(precision)s)" - return n % {'name': name, 'precision': precision} + return n % {"name": name, "precision": precision} else: n = "%(name)s(%(precision)s, %(scale)s)" - return n % {'name': name, 'precision': precision, 'scale': scale} + return n % {"name": name, "precision": precision, "scale": scale} def visit_string(self, type_, **kw): return self.visit_VARCHAR2(type_, **kw) def visit_VARCHAR2(self, type_, **kw): - return self._visit_varchar(type_, '', '2') + return self._visit_varchar(type_, "", "2") def visit_NVARCHAR2(self, type_, **kw): - return self._visit_varchar(type_, 'N', '2') + return self._visit_varchar(type_, "N", "2") + visit_NVARCHAR = visit_NVARCHAR2 def visit_VARCHAR(self, type_, **kw): - return self._visit_varchar(type_, '', '') + return self._visit_varchar(type_, "", "") def _visit_varchar(self, type_, n, num): if not type_.length: - return "%(n)sVARCHAR%(two)s" % {'two': num, 'n': n} + return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n} elif not n and self.dialect._supports_char_length: varchar = "VARCHAR%(two)s(%(length)s CHAR)" - return varchar % 
{'length': type_.length, 'two': num} + return varchar % {"length": type_.length, "two": num} else: varchar = "%(n)sVARCHAR%(two)s(%(length)s)" - return varchar % {'length': type_.length, 'two': num, 'n': n} + return varchar % {"length": type_.length, "two": num, "n": n} def visit_text(self, type_, **kw): return self.visit_CLOB(type_, **kw) @@ -636,7 +654,7 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_RAW(self, type_, **kw): if type_.length: - return "RAW(%(length)s)" % {'length': type_.length} + return "RAW(%(length)s)" % {"length": type_.length} else: return "RAW" @@ -652,9 +670,7 @@ class OracleCompiler(compiler.SQLCompiler): compound_keywords = util.update_copy( compiler.SQLCompiler.compound_keywords, - { - expression.CompoundSelect.EXCEPT: 'MINUS' - } + {expression.CompoundSelect.EXCEPT: "MINUS"}, ) def __init__(self, *args, **kwargs): @@ -663,8 +679,10 @@ class OracleCompiler(compiler.SQLCompiler): super(OracleCompiler, self).__init__(*args, **kwargs) def visit_mod_binary(self, binary, operator, **kw): - return "mod(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" @@ -673,22 +691,22 @@ class OracleCompiler(compiler.SQLCompiler): return "LENGTH" + self.function_argspec(fn, **kw) def visit_match_op_binary(self, binary, operator, **kw): - return "CONTAINS (%s, %s)" % (self.process(binary.left), - self.process(binary.right)) + return "CONTAINS (%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def get_cte_preamble(self, recursive): return "WITH" def get_select_hint_text(self, byfroms): - return " ".join( - "/*+ %s */" % text for table, text in byfroms.items() - ) + return " ".join("/*+ %s */" % text for table, text in byfroms.items()) def function_argspec(self, fn, **kw): if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: @@ -709,13 +727,16 @@ class OracleCompiler(compiler.SQLCompiler): if self.dialect.use_ansi: return compiler.SQLCompiler.visit_join(self, join, **kwargs) else: - kwargs['asfrom'] = True + kwargs["asfrom"] = True if isinstance(join.right, expression.FromGrouping): right = join.right.element else: right = join.right - return self.process(join.left, **kwargs) + \ - ", " + self.process(right, **kwargs) + return ( + self.process(join.left, **kwargs) + + ", " + + self.process(right, **kwargs) + ) def _get_nonansi_join_whereclause(self, froms): clauses = [] @@ -727,14 +748,20 @@ class OracleCompiler(compiler.SQLCompiler): # the join condition in the WHERE clause" - that is, # unconditionally regardless of operator or the other side def visit_binary(binary): - if isinstance(binary.left, expression.ColumnClause) \ - and join.right.is_derived_from(binary.left.table): + if isinstance( + binary.left, expression.ColumnClause + ) and join.right.is_derived_from(binary.left.table): binary.left = _OuterJoinColumn(binary.left) - elif isinstance(binary.right, expression.ColumnClause) \ - and join.right.is_derived_from(binary.right.table): + elif isinstance( + binary.right, expression.ColumnClause + ) and join.right.is_derived_from(binary.right.table): binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.cloned_traverse( - join.onclause, {}, {'binary': visit_binary})) + + clauses.append( + 
visitors.cloned_traverse( + join.onclause, {}, {"binary": visit_binary} + ) + ) else: clauses.append(join.onclause) @@ -757,8 +784,9 @@ class OracleCompiler(compiler.SQLCompiler): return self.process(vc.column, **kw) + "(+)" def visit_sequence(self, seq, **kw): - return (self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval") + return ( + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + ) def get_render_as_alias_suffix(self, alias_name_text): """Oracle doesn't like ``FROM table AS alias``""" @@ -770,7 +798,8 @@ class OracleCompiler(compiler.SQLCompiler): binds = [] for i, column in enumerate( - expression._select_iterables(returning_cols)): + expression._select_iterables(returning_cols) + ): if column.type._has_column_expression: col_expr = column.type.column_expression(column) else: @@ -779,19 +808,22 @@ class OracleCompiler(compiler.SQLCompiler): outparam = sql.outparam("ret_%d" % i, type_=column.type) self.binds[outparam.key] = outparam binds.append( - self.bindparam_string(self._truncate_bindparam(outparam))) - columns.append( - self.process(col_expr, within_columns_clause=False)) + self.bindparam_string(self._truncate_bindparam(outparam)) + ) + columns.append(self.process(col_expr, within_columns_clause=False)) self._add_to_result_map( - getattr(col_expr, 'name', col_expr.anon_label), - getattr(col_expr, 'name', col_expr.anon_label), - (column, getattr(column, 'name', None), - getattr(column, 'key', None)), - column.type + getattr(col_expr, "name", col_expr.anon_label), + getattr(col_expr, "name", col_expr.anon_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, ) - return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a @@ -804,10 +836,11 @@ class OracleCompiler(compiler.SQLCompiler): so tries to wrap it in a subquery with ``rownum`` criterion. 
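
        For example, ``select([t.c.x]).limit(10).offset(5)`` is wrapped
        roughly as follows (a sketch only; the real output uses anonymous
        labels and, with the default ``use_binds_for_limits``, bind
        parameters instead of the literals shown)::

            SELECT x FROM (
                SELECT x, ROWNUM AS ora_rn FROM (
                    SELECT t.x AS x FROM t
                ) WHERE ROWNUM <= 10 + 5
            ) WHERE ora_rn > 5
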
""" - if not getattr(select, '_oracle_visit', None): + if not getattr(select, "_oracle_visit", None): if not self.dialect.use_ansi: froms = self._display_froms_for_select( - select, kwargs.get('asfrom', False)) + select, kwargs.get("asfrom", False) + ) whereclause = self._get_nonansi_join_whereclause(froms) if whereclause is not None: select = select.where(whereclause) @@ -828,18 +861,20 @@ class OracleCompiler(compiler.SQLCompiler): # Outer select and "ROWNUM as ora_rn" can be dropped if # limit=0 - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._oracle_visit = True # Wrap the middle select and add the hint limitselect = sql.select([c for c in select.c]) - if limit_clause is not None and \ - self.dialect.optimize_limits and \ - select._simple_int_limit: + if ( + limit_clause is not None + and self.dialect.optimize_limits + and select._simple_int_limit + ): limitselect = limitselect.prefix_with( - "/*+ FIRST_ROWS(%d) */" % - select._limit) + "/*+ FIRST_ROWS(%d) */" % select._limit + ) limitselect._oracle_visit = True limitselect._is_wrapper = True @@ -855,8 +890,8 @@ class OracleCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(select) for_update.of = [ - adapter.traverse(elem) - for elem in for_update.of] + adapter.traverse(elem) for elem in for_update.of + ] # If needed, add the limiting clause if limit_clause is not None: @@ -873,7 +908,8 @@ class OracleCompiler(compiler.SQLCompiler): if offset_clause is not None: max_row = max_row + offset_clause limitselect.append_whereclause( - sql.literal_column("ROWNUM") <= max_row) + sql.literal_column("ROWNUM") <= max_row + ) # If needed, add the ora_rn, and wrap again with offset. if offset_clause is None: @@ -881,12 +917,14 @@ class OracleCompiler(compiler.SQLCompiler): select = limitselect else: limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) + sql.literal_column("ROWNUM").label("ora_rn") + ) limitselect._oracle_visit = True limitselect._is_wrapper = True offsetselect = sql.select( - [c for c in limitselect.c if c.key != 'ora_rn']) + [c for c in limitselect.c if c.key != "ora_rn"] + ) offsetselect._oracle_visit = True offsetselect._is_wrapper = True @@ -897,9 +935,11 @@ class OracleCompiler(compiler.SQLCompiler): if not self.dialect.use_binds_for_limits: offset_clause = sql.literal_column( - "%d" % select._offset) + "%d" % select._offset + ) offsetselect.append_whereclause( - sql.literal_column("ora_rn") > offset_clause) + sql.literal_column("ora_rn") > offset_clause + ) offsetselect._for_update_arg = for_update select = offsetselect @@ -910,18 +950,17 @@ class OracleCompiler(compiler.SQLCompiler): return "" def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM DUAL WHERE 1!=1' + return "SELECT 1 FROM DUAL WHERE 1!=1" def for_update_clause(self, select, **kw): if self.is_subquery(): return "" - tmp = ' FOR UPDATE' + tmp = " FOR UPDATE" if select._for_update_arg.of: - tmp += ' OF ' + ', '.join( - self.process(elem, **kw) for elem in - select._for_update_arg.of + tmp += " OF " + ", ".join( + self.process(elem, **kw) for elem in select._for_update_arg.of ) if select._for_update_arg.nowait: @@ -933,7 +972,6 @@ class OracleCompiler(compiler.SQLCompiler): class OracleDDLCompiler(compiler.DDLCompiler): - def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -947,7 +985,8 @@ class OracleDDLCompiler(compiler.DDLCompiler): "Oracle does not contain native UPDATE CASCADE " "functionality - 
onupdates will not be rendered for foreign " "keys. Consider using deferrable=True, initially='deferred' " - "or triggers.") + "or triggers." + ) return text @@ -958,75 +997,79 @@ class OracleDDLCompiler(compiler.DDLCompiler): text = "CREATE " if index.unique: text += "UNIQUE " - if index.dialect_options['oracle']['bitmap']: + if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), - ', '.join( + ", ".join( self.sql_compiler.process( - expr, - include_table=False, literal_binds=True) - for expr in index.expressions) + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), ) - if index.dialect_options['oracle']['compress'] is not False: - if index.dialect_options['oracle']['compress'] is True: + if index.dialect_options["oracle"]["compress"] is not False: + if index.dialect_options["oracle"]["compress"] is True: text += " COMPRESS" else: text += " COMPRESS %d" % ( - index.dialect_options['oracle']['compress'] + index.dialect_options["oracle"]["compress"] ) return text def post_create_table(self, table): table_opts = [] - opts = table.dialect_options['oracle'] + opts = table.dialect_options["oracle"] - if opts['on_commit']: - on_commit_options = opts['on_commit'].replace("_", " ").upper() - table_opts.append('\n ON COMMIT %s' % on_commit_options) + if opts["on_commit"]: + on_commit_options = opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) - if opts['compress']: - if opts['compress'] is True: + if opts["compress"]: + if opts["compress"] is True: table_opts.append("\n COMPRESS") else: - table_opts.append("\n COMPRESS FOR %s" % ( - opts['compress'] - )) + table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) - return ''.join(table_opts) + return "".join(table_opts) class OracleIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} - illegal_initial_characters = {str(dig) for dig in range(0, 10)} \ - .union(["_", "$"]) + illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union( + ["_", "$"] + ) def _bindparam_requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - ) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + ) def format_savepoint(self, savepoint): - name = savepoint.ident.lstrip('_') - return super( - OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + name = savepoint.ident.lstrip("_") + return super(OracleIdentifierPreparer, self).format_savepoint( + savepoint, name + ) class OracleExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( - "SELECT " + - self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval FROM DUAL", type_) + "SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", + type_, + ) class OracleDialect(default.DefaultDialect): - name = 'oracle' + name = "oracle" supports_alter = True supports_unicode_statements = False supports_unicode_binds = False @@ -1039,7 +1082,7 @@ class OracleDialect(default.DefaultDialect): sequences_optional = 
False postfetch_lastrowid = False - default_paramstyle = 'named' + default_paramstyle = "named" colspecs = colspecs ischema_names = ischema_names requires_name_normalize = True @@ -1054,29 +1097,27 @@ class OracleDialect(default.DefaultDialect): preparer = OracleIdentifierPreparer execution_ctx_cls = OracleExecutionContext - reflection_options = ('oracle_resolve_synonyms', ) + reflection_options = ("oracle_resolve_synonyms",) _use_nchar_for_unicode = False construct_arguments = [ - (sa_schema.Table, { - "resolve_synonyms": False, - "on_commit": None, - "compress": False - }), - (sa_schema.Index, { - "bitmap": False, - "compress": False - }) + ( + sa_schema.Table, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, + ), + (sa_schema.Index, {"bitmap": False, "compress": False}), ] - def __init__(self, - use_ansi=True, - optimize_limits=False, - use_binds_for_limits=True, - use_nchar_for_unicode=False, - exclude_tablespaces=('SYSTEM', 'SYSAUX', ), - **kwargs): + def __init__( + self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=True, + use_nchar_for_unicode=False, + exclude_tablespaces=("SYSTEM", "SYSAUX"), + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self._use_nchar_for_unicode = use_nchar_for_unicode self.use_ansi = use_ansi @@ -1087,8 +1128,7 @@ class OracleDialect(default.DefaultDialect): def initialize(self, connection): super(OracleDialect, self).initialize(connection) self.implicit_returning = self.__dict__.get( - 'implicit_returning', - self.server_version_info > (10, ) + "implicit_returning", self.server_version_info > (10,) ) if self._is_oracle_8: @@ -1098,18 +1138,15 @@ class OracleDialect(default.DefaultDialect): @property def _is_oracle_8(self): - return self.server_version_info and \ - self.server_version_info < (9, ) + return self.server_version_info and self.server_version_info < (9,) @property def _supports_table_compression(self): - return self.server_version_info and \ - self.server_version_info >= (10, 1, ) + return self.server_version_info and self.server_version_info >= (10, 1) @property def _supports_table_compress_for(self): - return self.server_version_info and \ - self.server_version_info >= (11, ) + return self.server_version_info and self.server_version_info >= (11,) @property def _supports_char_length(self): @@ -1123,31 +1160,38 @@ class OracleDialect(default.DefaultDialect): additional_tests = [ expression.cast( expression.literal_column("'test nvarchar2 returns'"), - sqltypes.NVARCHAR(60) - ), + sqltypes.NVARCHAR(60), + ) ] return super(OracleDialect, self)._check_unicode_returns( - connection, additional_tests) + connection, additional_tests + ) def has_table(self, connection, table_name, schema=None): if not schema: schema = self.default_schema_name cursor = connection.execute( - sql.text("SELECT table_name FROM all_tables " - "WHERE table_name = :name AND owner = :schema_name"), + sql.text( + "SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name" + ), name=self.denormalize_name(table_name), - schema_name=self.denormalize_name(schema)) + schema_name=self.denormalize_name(schema), + ) return cursor.first() is not None def has_sequence(self, connection, sequence_name, schema=None): if not schema: schema = self.default_schema_name cursor = connection.execute( - sql.text("SELECT sequence_name FROM all_sequences " - "WHERE sequence_name = :name AND " - "sequence_owner = :schema_name"), + sql.text( + "SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name 
AND " + "sequence_owner = :schema_name" + ), name=self.denormalize_name(sequence_name), - schema_name=self.denormalize_name(schema)) + schema_name=self.denormalize_name(schema), + ) return cursor.first() is not None def normalize_name(self, name): @@ -1156,8 +1200,9 @@ class OracleDialect(default.DefaultDialect): if util.py2k: if isinstance(name, str): name = name.decode(self.encoding) - if name.upper() == name and not \ - self.identifier_preparer._requires_quotes(name.lower()): + if name.upper() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.lower() elif name.lower() == name: return quoted_name(name, quote=True) @@ -1167,8 +1212,9 @@ class OracleDialect(default.DefaultDialect): def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and not \ - self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): name = name.upper() if util.py2k: if not self.supports_unicode_binds: @@ -1179,10 +1225,16 @@ class OracleDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return self.normalize_name( - connection.execute('SELECT USER FROM DUAL').scalar()) + connection.execute("SELECT USER FROM DUAL").scalar() + ) - def _resolve_synonym(self, connection, desired_owner=None, - desired_synonym=None, desired_table=None): + def _resolve_synonym( + self, + connection, + desired_owner=None, + desired_synonym=None, + desired_table=None, + ): """search for a local synonym matching the given desired owner/name. if desired_owner is None, attempts to locate a distinct owner. @@ -1191,19 +1243,21 @@ class OracleDialect(default.DefaultDialect): found. """ - q = "SELECT owner, table_owner, table_name, db_link, "\ + q = ( + "SELECT owner, table_owner, table_name, db_link, " "synonym_name FROM all_synonyms WHERE " + ) clauses = [] params = {} if desired_synonym: clauses.append("synonym_name = :synonym_name") - params['synonym_name'] = desired_synonym + params["synonym_name"] = desired_synonym if desired_owner: clauses.append("owner = :desired_owner") - params['desired_owner'] = desired_owner + params["desired_owner"] = desired_owner if desired_table: clauses.append("table_name = :tname") - params['tname'] = desired_table + params["tname"] = desired_table q += " AND ".join(clauses) @@ -1211,8 +1265,12 @@ class OracleDialect(default.DefaultDialect): if desired_owner: row = result.first() if row: - return (row['table_name'], row['table_owner'], - row['db_link'], row['synonym_name']) + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) else: return None, None, None, None else: @@ -1220,23 +1278,35 @@ class OracleDialect(default.DefaultDialect): if len(rows) > 1: raise AssertionError( "There are multiple tables visible to the schema, you " - "must specify owner") + "must specify owner" + ) elif len(rows) == 1: row = rows[0] - return (row['table_name'], row['table_owner'], - row['db_link'], row['synonym_name']) + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) else: return None, None, None, None @reflection.cache - def _prepare_reflection_args(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): + def _prepare_reflection_args( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): if resolve_synonyms: actual_name, owner, dblink, synonym = 
self._resolve_synonym( connection, desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name) + desired_synonym=self.denormalize_name(table_name), ) else: actual_name, owner, dblink, synonym = None, None, None, None @@ -1250,18 +1320,21 @@ class OracleDialect(default.DefaultDialect): # will need to hear from more users if we are doing # the right thing here. See [ticket:2619] owner = connection.scalar( - sql.text("SELECT username FROM user_db_links " - "WHERE db_link=:link"), link=dblink) + sql.text( + "SELECT username FROM user_db_links " "WHERE db_link=:link" + ), + link=dblink, + ) dblink = "@" + dblink elif not owner: owner = self.denormalize_name(schema or self.default_schema_name) - return (actual_name, owner, dblink or '', synonym) + return (actual_name, owner, dblink or "", synonym) @reflection.cache def get_schema_names(self, connection, **kw): s = "SELECT username FROM all_users ORDER BY username" - cursor = connection.execute(s,) + cursor = connection.execute(s) return [self.normalize_name(row[0]) for row in cursor] @reflection.cache @@ -1276,14 +1349,12 @@ class OracleDialect(default.DefaultDialect): if self.exclude_tablespaces: sql_str += ( "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " % ( - ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces]) - ) + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) ) sql_str += ( - "OWNER = :owner " - "AND IOT_NAME IS NULL " - "AND DURATION IS NULL") + "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + ) cursor = connection.execute(sql.text(sql_str), owner=schema) return [self.normalize_name(row[0]) for row in cursor] @@ -1296,14 +1367,14 @@ class OracleDialect(default.DefaultDialect): if self.exclude_tablespaces: sql_str += ( "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " % ( - ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces]) - ) + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) ) sql_str += ( "OWNER = :owner " "AND IOT_NAME IS NULL " - "AND DURATION IS NOT NULL") + "AND DURATION IS NOT NULL" + ) cursor = connection.execute(sql.text(sql_str), owner=schema) return [self.normalize_name(row[0]) for row in cursor] @@ -1319,14 +1390,18 @@ class OracleDialect(default.DefaultDialect): def get_table_options(self, connection, table_name, schema=None, **kw): options = {} - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) params = {"table_name": table_name} @@ -1336,14 +1411,16 @@ class OracleDialect(default.DefaultDialect): if self._supports_table_compress_for: columns.append("compress_for") - text = "SELECT %(columns)s "\ - "FROM ALL_TABLES%(dblink)s "\ + text = ( + "SELECT %(columns)s " + "FROM ALL_TABLES%(dblink)s " "WHERE table_name = :table_name" + ) if schema is not None: - params['owner'] = schema + params["owner"] = schema text += " AND owner = :owner " - text = text % {'dblink': dblink, 'columns': ", ".join(columns)} + text = text % 
{"dblink": dblink, "columns": ", ".join(columns)} result = connection.execute(sql.text(text), **params) @@ -1353,9 +1430,9 @@ class OracleDialect(default.DefaultDialect): if row: if "compression" in row and enabled.get(row.compression, False): if "compress_for" in row: - options['oracle_compress'] = row.compress_for + options["oracle_compress"] = row.compress_for else: - options['oracle_compress'] = True + options["oracle_compress"] = True return options @@ -1371,19 +1448,23 @@ class OracleDialect(default.DefaultDialect): """ - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) columns = [] if self._supports_char_length: - char_length_col = 'char_length' + char_length_col = "char_length" else: - char_length_col = 'data_length' + char_length_col = "data_length" params = {"table_name": table_name} text = """ @@ -1398,10 +1479,10 @@ class OracleDialect(default.DefaultDialect): WHERE col.table_name = :table_name """ if schema is not None: - params['owner'] = schema + params["owner"] = schema text += " AND col.owner = :owner " text += " ORDER BY col.column_id" - text = text % {'dblink': dblink, 'char_length_col': char_length_col} + text = text % {"dblink": dblink, "char_length_col": char_length_col} c = connection.execute(sql.text(text), **params) @@ -1412,54 +1493,67 @@ class OracleDialect(default.DefaultDialect): length = row[2] precision = row[3] scale = row[4] - nullable = row[5] == 'Y' + nullable = row[5] == "Y" default = row[6] comment = row[7] - if coltype == 'NUMBER': + if coltype == "NUMBER": if precision is None and scale == 0: coltype = INTEGER() else: coltype = NUMBER(precision, scale) - elif coltype == 'FLOAT': + elif coltype == "FLOAT": # TODO: support "precision" here as "binary_precision" coltype = FLOAT() - elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'): + elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR"): coltype = self.ischema_names.get(coltype)(length) - elif 'WITH TIME ZONE' in coltype: + elif "WITH TIME ZONE" in coltype: coltype = TIMESTAMP(timezone=True) else: - coltype = re.sub(r'\(\d+\)', '', coltype) + coltype = re.sub(r"\(\d+\)", "", coltype) try: coltype = self.ischema_names[coltype] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, colname)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (coltype, colname) + ) coltype = sqltypes.NULLTYPE cdict = { - 'name': colname, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': 'auto', - 'comment': comment, + "name": colname, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "comment": comment, } if orig_colname.lower() == orig_colname: - cdict['quote'] = True + cdict["quote"] = True columns.append(cdict) return columns @reflection.cache - def get_table_comment(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - - info_cache = kw.get('info_cache') - (table_name, schema, dblink, synonym) = \ - 
self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_table_comment( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) COMMENT_SQL = """ SELECT comments @@ -1471,67 +1565,90 @@ class OracleDialect(default.DefaultDialect): return {"text": c.scalar()} @reflection.cache - def get_indexes(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - - info_cache = kw.get('info_cache') - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_indexes( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) indexes = [] - params = {'table_name': table_name} - text = \ - "SELECT a.index_name, a.column_name, "\ - "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\ - "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ - "\nALL_INDEXES%(dblink)s b "\ - "\nWHERE "\ - "\na.index_name = b.index_name "\ - "\nAND a.table_owner = b.table_owner "\ - "\nAND a.table_name = b.table_name "\ + params = {"table_name": table_name} + text = ( + "SELECT a.index_name, a.column_name, " + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " + "\nFROM ALL_IND_COLUMNS%(dblink)s a, " + "\nALL_INDEXES%(dblink)s b " + "\nWHERE " + "\na.index_name = b.index_name " + "\nAND a.table_owner = b.table_owner " + "\nAND a.table_name = b.table_name " "\nAND a.table_name = :table_name " + ) if schema is not None: - params['schema'] = schema + params["schema"] = schema text += "AND a.table_owner = :schema " text += "ORDER BY a.index_name, a.column_position" - text = text % {'dblink': dblink} + text = text % {"dblink": dblink} q = sql.text(text) rp = connection.execute(q, **params) indexes = [] last_index_name = None pk_constraint = self.get_pk_constraint( - connection, table_name, schema, resolve_synonyms=resolve_synonyms, - dblink=dblink, info_cache=kw.get('info_cache')) - pkeys = pk_constraint['constrained_columns'] + connection, + table_name, + schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get("info_cache"), + ) + pkeys = pk_constraint["constrained_columns"] uniqueness = dict(NONUNIQUE=False, UNIQUE=True) enabled = dict(DISABLED=False, ENABLED=True) - oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) index = None for rset in rp: if rset.index_name != last_index_name: - index = dict(name=self.normalize_name(rset.index_name), - column_names=[], dialect_options={}) + index = dict( + name=self.normalize_name(rset.index_name), + column_names=[], + dialect_options={}, + ) indexes.append(index) - index['unique'] = uniqueness.get(rset.uniqueness, False) + index["unique"] = uniqueness.get(rset.uniqueness, False) - if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'): - index['dialect_options']['oracle_bitmap'] = True + if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): + 
index["dialect_options"]["oracle_bitmap"] = True if enabled.get(rset.compression, False): - index['dialect_options']['oracle_compress'] = rset.prefix_length + index["dialect_options"][ + "oracle_compress" + ] = rset.prefix_length # filter out Oracle SYS_NC names. could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): - index['column_names'].append( - self.normalize_name(rset.column_name)) + index["column_names"].append( + self.normalize_name(rset.column_name) + ) last_index_name = rset.index_name def upper_name_set(names): @@ -1539,18 +1656,21 @@ class OracleDialect(default.DefaultDialect): pk_names = upper_name_set(pkeys) if pk_names: + def is_pk_index(index): # don't include the primary key index - return upper_name_set(index['column_names']) == pk_names + return upper_name_set(index["column_names"]) == pk_names + indexes = [idx for idx in indexes if not is_pk_index(idx)] return indexes @reflection.cache - def _get_constraint_data(self, connection, table_name, schema=None, - dblink='', **kw): + def _get_constraint_data( + self, connection, table_name, schema=None, dblink="", **kw + ): - params = {'table_name': table_name} + params = {"table_name": table_name} text = ( "SELECT" @@ -1572,7 +1692,7 @@ class OracleDialect(default.DefaultDialect): ) if schema is not None: - params['owner'] = schema + params["owner"] = schema text += "\nAND ac.owner = :owner" text += ( @@ -1584,35 +1704,49 @@ class OracleDialect(default.DefaultDialect): "\nORDER BY ac.constraint_name, loc.position" ) - text = text % {'dblink': dblink} + text = text % {"dblink": dblink} rp = connection.execute(sql.text(text), **params) constraint_data = rp.fetchall() return constraint_data @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) pkeys = [] constraint_name = None constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) for row in constraint_data: - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - if cons_type == 'P': + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == "P": if constraint_name is None: constraint_name = self.normalize_name(cons_name) pkeys.append(local_column) - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -1626,74 +1760,94 @@ class OracleDialect(default.DefaultDialect): """ requested_schema = schema # to check later on - resolve_synonyms = 
kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [], - 'options': {}, + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, } fkeys = util.defaultdict(fkey_rec) for row in constraint_data: - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) cons_name = self.normalize_name(cons_name) - if cons_type == 'R': + if cons_type == "R": if remote_table is None: # ticket 363 util.warn( - ("Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?") % {'dblink': dblink}) + ( + "Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?" 
+ ) + % {"dblink": dblink} + ) continue rec = fkeys[cons_name] - rec['name'] = cons_name - local_cols, remote_cols = rec[ - 'constrained_columns'], rec['referred_columns'] + rec["name"] = cons_name + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) - if not rec['referred_table']: + if not rec["referred_table"]: if resolve_synonyms: - ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ - self._resolve_synonym( - connection, - desired_owner=self.denormalize_name( - remote_owner), - desired_table=self.denormalize_name( - remote_table) - ) + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table), + ) if ref_synonym: remote_table = self.normalize_name(ref_synonym) remote_owner = self.normalize_name( - ref_remote_owner) + ref_remote_owner + ) - rec['referred_table'] = remote_table + rec["referred_table"] = remote_table - if requested_schema is not None or \ - self.denormalize_name(remote_owner) != schema: - rec['referred_schema'] = remote_owner + if ( + requested_schema is not None + or self.denormalize_name(remote_owner) != schema + ): + rec["referred_schema"] = remote_owner - if row[9] != 'NO ACTION': - rec['options']['ondelete'] = row[9] + if row[9] != "NO ACTION": + rec["options"]["ondelete"] = row[9] local_cols.append(local_column) remote_cols.append(remote_column) @@ -1701,54 +1855,82 @@ class OracleDialect(default.DefaultDialect): return list(fkeys.values()) @reflection.cache - def get_unique_constraints(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) - unique_keys = filter(lambda x: x[1] == 'U', constraint_data) + unique_keys = filter(lambda x: x[1] == "U", constraint_data) uniques_group = groupby(unique_keys, lambda x: x[0]) - index_names = set([ix['name'] for ix in self.get_indexes(connection, table_name, schema=schema)]) + index_names = set( + [ + ix["name"] + for ix in self.get_indexes( + connection, table_name, schema=schema + ) + ] + ) return [ { - 'name': name, - 'column_names': cols, - 'duplicates_index': name if name in index_names else None + "name": name, + "column_names": cols, + "duplicates_index": name if name in index_names else None, } - for name, cols in - [ + for name, cols in [ [ self.normalize_name(i[0]), - [self.normalize_name(x[2]) for x in i[1]] - ] for i in uniques_group + [self.normalize_name(x[2]) for x in i[1]], + ] + for i in uniques_group ] ] @reflection.cache - def get_view_definition(self, connection, view_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - info_cache = 
kw.get('info_cache') - (view_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, view_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) - - params = {'view_name': view_name} + def get_view_definition( + self, + connection, + view_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + info_cache = kw.get("info_cache") + (view_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + view_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + params = {"view_name": view_name} text = "SELECT text FROM all_views WHERE view_name=:view_name" if schema is not None: text += " AND owner = :schema" - params['schema'] = schema + params["schema"] = schema rp = connection.execute(sql.text(text), **params).scalar() if rp: @@ -1759,34 +1941,41 @@ class OracleDialect(default.DefaultDialect): return None @reflection.cache - def get_check_constraints(self, connection, table_name, schema=None, - include_all=False, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_check_constraints( + self, connection, table_name, schema=None, include_all=False, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) - check_constraints = filter(lambda x: x[1] == 'C', constraint_data) + check_constraints = filter(lambda x: x[1] == "C", constraint_data) return [ - { - 'name': self.normalize_name(cons[0]), - 'sqltext': cons[8], - } - for cons in check_constraints if include_all or - not re.match(r'..+?. IS NOT NULL$', cons[8])] + {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} + for cons in check_constraints + if include_all or not re.match(r"..+?. 
IS NOT NULL$", cons[8]) + ] class _OuterJoinColumn(sql.ClauseElement): - __visit_name__ = 'outer_join_column' + __visit_name__ = "outer_join_column" def __init__(self, column): self.column = column diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index a00e7d95ec..91534c0da2 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -296,16 +296,13 @@ class _OracleInteger(sqltypes.Integer): def _cx_oracle_var(self, dialect, cursor): cx_Oracle = dialect.dbapi return cursor.var( - cx_Oracle.STRING, - 255, - arraysize=cursor.arraysize, - outconverter=int + cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int ) def _cx_oracle_outputtypehandler(self, dialect): - def handler(cursor, name, - default_type, size, precision, scale): + def handler(cursor, name, default_type, size, precision, scale): return self._cx_oracle_var(dialect, cursor) + return handler @@ -317,7 +314,8 @@ class _OracleNumeric(sqltypes.Numeric): return None elif self.asdecimal: processor = processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) def process(value): if isinstance(value, (int, float)): @@ -326,6 +324,7 @@ class _OracleNumeric(sqltypes.Numeric): return float(value) else: return value + return process else: return processors.to_float @@ -383,9 +382,10 @@ class _OracleNumeric(sqltypes.Numeric): type_ = cx_Oracle.NATIVE_FLOAT return cursor.var( - type_, 255, + type_, + 255, arraysize=cursor.arraysize, - outconverter=outconverter + outconverter=outconverter, ) return handler @@ -418,6 +418,7 @@ class _OracleDate(sqltypes.Date): return value.date() else: return value + return process @@ -467,6 +468,7 @@ class _OracleEnum(sqltypes.Enum): def process(value): raw_str = enum_proc(value) return raw_str + return process @@ -482,7 +484,8 @@ class _OracleBinary(sqltypes.LargeBinary): return None else: return super(_OracleBinary, self).result_processor( - dialect, coltype) + dialect, coltype + ) class _OracleInterval(oracle.INTERVAL): @@ -503,14 +506,18 @@ class OracleCompiler_cx_oracle(OracleCompiler): _oracle_cx_sql_compiler = True def bindparam_string(self, name, **kw): - quote = getattr(name, 'quote', None) - if quote is True or quote is not False and \ - self.preparer._bindparam_requires_quotes(name): - if kw.get('expanding', False): + quote = getattr(name, "quote", None) + if ( + quote is True + or quote is not False + and self.preparer._bindparam_requires_quotes(name) + ): + if kw.get("expanding", False): raise exc.CompileError( "Can't use expanding feature with parameter name " "%r on Oracle; it requires quoting which is not supported " - "in this context." % name) + "in this context." 
% name + ) quoted_name = '"%s"' % name self._quoted_bind_names[name] = quoted_name return OracleCompiler.bindparam_string(self, quoted_name, **kw) @@ -537,21 +544,22 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if bindparam.isoutparam: name = self.compiled.bind_names[bindparam] type_impl = bindparam.type.dialect_impl(self.dialect) - if hasattr(type_impl, '_cx_oracle_var'): + if hasattr(type_impl, "_cx_oracle_var"): self.out_parameters[name] = type_impl._cx_oracle_var( - self.dialect, self.cursor) + self.dialect, self.cursor + ) else: dbtype = type_impl.get_dbapi_type(self.dialect.dbapi) if dbtype is None: raise exc.InvalidRequestError( "Cannot create out parameter for parameter " "%r - its type %r is not supported by" - " cx_oracle" % - (bindparam.key, bindparam.type) + " cx_oracle" % (bindparam.key, bindparam.type) ) self.out_parameters[name] = self.cursor.var(dbtype) - self.parameters[0][quoted_bind_names.get(name, name)] = \ - self.out_parameters[name] + self.parameters[0][ + quoted_bind_names.get(name, name) + ] = self.out_parameters[name] def _generate_cursor_outputtype_handler(self): output_handlers = {} @@ -559,8 +567,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): for (keyname, name, objects, type_) in self.compiled._result_columns: handler = type_._cached_custom_processor( self.dialect, - 'cx_oracle_outputtypehandler', - self._get_cx_oracle_type_handler) + "cx_oracle_outputtypehandler", + self._get_cx_oracle_type_handler, + ) if handler: denormalized_name = self.dialect.denormalize_name(keyname) @@ -569,16 +578,18 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if output_handlers: default_handler = self._dbapi_connection.outputtypehandler - def output_type_handler(cursor, name, default_type, - size, precision, scale): + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): if name in output_handlers: return output_handlers[name]( - cursor, name, - default_type, size, precision, scale) + cursor, name, default_type, size, precision, scale + ) else: return default_handler( cursor, name, default_type, size, precision, scale ) + self.cursor.outputtypehandler = output_type_handler def _get_cx_oracle_type_handler(self, impl): @@ -598,7 +609,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): self.set_input_sizes( self.compiled._quoted_bind_names, - include_types=self.dialect._include_setinputsizes + include_types=self.dialect._include_setinputsizes, ) self._handle_out_parameters() @@ -615,9 +626,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): def get_result_proxy(self): if self.out_parameters and self.compiled.returning: returning_params = [ - self.dialect._returningval( - self.out_parameters["ret_%d" % i] - ) + self.dialect._returningval(self.out_parameters["ret_%d" % i]) for i in range(len(self.out_parameters)) ] return ReturningResultProxy(self, returning_params) @@ -625,8 +634,10 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): result = _result.ResultProxy(self) if self.out_parameters: - if self.compiled_parameters is not None and \ - len(self.compiled_parameters) == 1: + if ( + self.compiled_parameters is not None + and len(self.compiled_parameters) == 1 + ): result.out_parameters = out_parameters = {} for bind, name in self.compiled.bind_names.items(): @@ -634,22 +645,24 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): type = bind.type impl_type = type.dialect_impl(self.dialect) dbapi_type = impl_type.get_dbapi_type( - 
self.dialect.dbapi) - result_processor = impl_type.\ - result_processor(self.dialect, - dbapi_type) + self.dialect.dbapi + ) + result_processor = impl_type.result_processor( + self.dialect, dbapi_type + ) if result_processor is not None: - out_parameters[name] = \ - result_processor( - self.dialect._paramval( - self.out_parameters[name] - )) + out_parameters[name] = result_processor( + self.dialect._paramval( + self.out_parameters[name] + ) + ) else: out_parameters[name] = self.dialect._paramval( - self.out_parameters[name]) + self.out_parameters[name] + ) else: result.out_parameters = dict( - (k, self._dialect._paramval(v)) + (k, self._dialect._paramval(v)) for k, v in self.out_parameters.items() ) @@ -667,14 +680,11 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy): def _cursor_description(self): returning = self.context.compiled.returning return [ - (getattr(col, 'name', col.anon_label), None) - for col in returning + (getattr(col, "name", col.anon_label), None) for col in returning ] def _buffer_rows(self): - return collections.deque( - [tuple(self._returning_params)] - ) + return collections.deque([tuple(self._returning_params)]) class OracleDialect_cx_oracle(OracleDialect): @@ -696,7 +706,6 @@ class OracleDialect_cx_oracle(OracleDialect): oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE, sqltypes.Integer: _OracleInteger, oracle.NUMBER: _OracleNUMBER, - sqltypes.Date: _OracleDate, sqltypes.LargeBinary: _OracleBinary, sqltypes.Boolean: oracle._OracleBoolean, @@ -707,7 +716,6 @@ class OracleDialect_cx_oracle(OracleDialect): sqltypes.UnicodeText: _OracleUnicodeTextCLOB, sqltypes.CHAR: _OracleChar, sqltypes.Enum: _OracleEnum, - oracle.LONG: _OracleLong, oracle.RAW: _OracleRaw, sqltypes.Unicode: _OracleUnicodeStringCHAR, @@ -721,13 +729,15 @@ class OracleDialect_cx_oracle(OracleDialect): _cx_oracle_threaded = None - def __init__(self, - auto_convert_lobs=True, - coerce_to_unicode=True, - coerce_to_decimal=True, - arraysize=50, - threaded=None, - **kwargs): + def __init__( + self, + auto_convert_lobs=True, + coerce_to_unicode=True, + coerce_to_decimal=True, + arraysize=50, + threaded=None, + **kwargs + ): OracleDialect.__init__(self, **kwargs) self.arraysize = arraysize @@ -757,15 +767,23 @@ class OracleDialect_cx_oracle(OracleDialect): self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version) if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0): raise exc.InvalidRequestError( - "cx_Oracle version 5.2 and above are supported") + "cx_Oracle version 5.2 and above are supported" + ) self._has_native_int = hasattr(cx_Oracle, "NATIVE_INT") self._include_setinputsizes = { - cx_Oracle.NCLOB, cx_Oracle.CLOB, cx_Oracle.LOB, - cx_Oracle.NCHAR, cx_Oracle.FIXED_NCHAR, - cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP, - _OracleInteger, _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE + cx_Oracle.NCLOB, + cx_Oracle.CLOB, + cx_Oracle.LOB, + cx_Oracle.NCHAR, + cx_Oracle.FIXED_NCHAR, + cx_Oracle.BLOB, + cx_Oracle.FIXED_CHAR, + cx_Oracle.TIMESTAMP, + _OracleInteger, + _OracleBINARY_FLOAT, + _OracleBINARY_DOUBLE, } self._paramval = lambda value: value.getvalue() @@ -786,18 +804,19 @@ class OracleDialect_cx_oracle(OracleDialect): else: self._returningval = self._paramval - self._is_cx_oracle_6 = self.cx_oracle_ver >= (6, ) + self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,) def _pop_deprecated_kwargs(self, kwargs): - auto_setinputsizes = kwargs.pop('auto_setinputsizes', None) - exclude_setinputsizes = kwargs.pop('exclude_setinputsizes', None) + auto_setinputsizes = 
kwargs.pop("auto_setinputsizes", None) + exclude_setinputsizes = kwargs.pop("exclude_setinputsizes", None) if auto_setinputsizes or exclude_setinputsizes: util.warn_deprecated( "auto_setinputsizes and exclude_setinputsizes are deprecated. " "Modern cx_Oracle only requires that LOB types are part " "of this behavior, and these parameters no longer have any " - "effect.") - allow_twophase = kwargs.pop('allow_twophase', None) + "effect." + ) + allow_twophase = kwargs.pop("allow_twophase", None) if allow_twophase is not None: util.warn_deprecated( "allow_twophase is deprecated. The cx_Oracle dialect no " @@ -805,18 +824,16 @@ class OracleDialect_cx_oracle(OracleDialect): def _parse_cx_oracle_ver(self, version): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version) + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) else: return (0, 0, 0) @classmethod def dbapi(cls): import cx_Oracle + return cx_Oracle def initialize(self, connection): @@ -835,15 +852,18 @@ class OracleDialect_cx_oracle(OracleDialect): self._decimal_char = connection.scalar( "select value from nls_session_parameters " - "where parameter = 'NLS_NUMERIC_CHARACTERS'")[0] - if self._decimal_char != '.': + "where parameter = 'NLS_NUMERIC_CHARACTERS'" + )[0] + if self._decimal_char != ".": _detect_decimal = self._detect_decimal _to_decimal = self._to_decimal self._detect_decimal = lambda value: _detect_decimal( - value.replace(self._decimal_char, ".")) + value.replace(self._decimal_char, ".") + ) self._to_decimal = lambda value: _to_decimal( - value.replace(self._decimal_char, ".")) + value.replace(self._decimal_char, ".") + ) def _detect_decimal(self, value): if "."
in value: @@ -862,13 +882,16 @@ class OracleDialect_cx_oracle(OracleDialect): dialect = self cx_Oracle = dialect.dbapi - number_handler = _OracleNUMBER(asdecimal=True).\ - _cx_oracle_outputtypehandler(dialect) - float_handler = _OracleNUMBER(asdecimal=False).\ - _cx_oracle_outputtypehandler(dialect) + number_handler = _OracleNUMBER( + asdecimal=True + )._cx_oracle_outputtypehandler(dialect) + float_handler = _OracleNUMBER( + asdecimal=False + )._cx_oracle_outputtypehandler(dialect) - def output_type_handler(cursor, name, default_type, - size, precision, scale): + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): if default_type == cx_Oracle.NUMBER: if not dialect.coerce_to_decimal: return None @@ -879,7 +902,8 @@ class OracleDialect_cx_oracle(OracleDialect): cx_Oracle.STRING, 255, outconverter=dialect._detect_decimal, - arraysize=cursor.arraysize) + arraysize=cursor.arraysize, + ) elif precision and scale > 0: return number_handler( cursor, name, default_type, size, precision, scale @@ -890,43 +914,55 @@ class OracleDialect_cx_oracle(OracleDialect): ) # allow all strings to come back natively as Unicode - elif dialect.coerce_to_unicode and \ - default_type in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR): + elif dialect.coerce_to_unicode and default_type in ( + cx_Oracle.STRING, + cx_Oracle.FIXED_CHAR, + ): if compat.py2k: outconverter = processors.to_unicode_processor_factory( - dialect.encoding, None) - return cursor.var( - cx_Oracle.STRING, size, cursor.arraysize, - outconverter=outconverter + dialect.encoding, None ) - else: return cursor.var( - util.text_type, size, cursor.arraysize + cx_Oracle.STRING, + size, + cursor.arraysize, + outconverter=outconverter, ) + else: + return cursor.var(util.text_type, size, cursor.arraysize) elif dialect.auto_convert_lobs and default_type in ( - cx_Oracle.CLOB, cx_Oracle.NCLOB + cx_Oracle.CLOB, + cx_Oracle.NCLOB, ): if compat.py2k: outconverter = processors.to_unicode_processor_factory( - dialect.encoding, None) + dialect.encoding, None + ) return cursor.var( - default_type, size, cursor.arraysize, - outconverter=lambda value: outconverter(value.read()) + default_type, + size, + cursor.arraysize, + outconverter=lambda value: outconverter(value.read()), ) else: return cursor.var( - default_type, size, cursor.arraysize, - outconverter=lambda value: value.read() + default_type, + size, + cursor.arraysize, + outconverter=lambda value: value.read(), ) elif dialect.auto_convert_lobs and default_type in ( - cx_Oracle.BLOB, + cx_Oracle.BLOB, ): return cursor.var( - default_type, size, cursor.arraysize, - outconverter=lambda value: value.read() + default_type, + size, + cursor.arraysize, + outconverter=lambda value: value.read(), ) + return output_type_handler def on_connect(self): @@ -941,16 +977,17 @@ class OracleDialect_cx_oracle(OracleDialect): def create_connect_args(self, url): opts = dict(url.query) - for opt in ('use_ansi', 'auto_convert_lobs'): + for opt in ("use_ansi", "auto_convert_lobs"): if opt in opts: util.warn_deprecated( "cx_oracle dialect option %r should only be passed to " - "create_engine directly, not within the URL string" % opt) + "create_engine directly, not within the URL string" % opt + ) util.coerce_kw_type(opts, opt, bool) setattr(self, opt, opts.pop(opt)) database = url.database - service_name = opts.pop('service_name', None) + service_name = opts.pop("service_name", None) if database or service_name: # if we have a database, then we have a remote host port = url.port @@ -962,11 +999,12 @@ class 
OracleDialect_cx_oracle(OracleDialect): if database and service_name: raise exc.InvalidRequestError( '"service_name" option shouldn\'t ' - 'be used with a "database" part of the url') + 'be used with a "database" part of the url' + ) if database: - makedsn_kwargs = {'sid': database} + makedsn_kwargs = {"sid": database} if service_name: - makedsn_kwargs = {'service_name': service_name} + makedsn_kwargs = {"service_name": service_name} dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs) else: @@ -974,11 +1012,11 @@ class OracleDialect_cx_oracle(OracleDialect): dsn = url.host if dsn is not None: - opts['dsn'] = dsn + opts["dsn"] = dsn if url.password is not None: - opts['password'] = url.password + opts["password"] = url.password if url.username is not None: - opts['user'] = url.username + opts["user"] = url.username if self._cx_oracle_threaded is not None: opts.setdefault("threaded", self._cx_oracle_threaded) @@ -995,28 +1033,24 @@ class OracleDialect_cx_oracle(OracleDialect): else: return value - util.coerce_kw_type(opts, 'mode', convert_cx_oracle_constant) - util.coerce_kw_type(opts, 'threaded', bool) - util.coerce_kw_type(opts, 'events', bool) - util.coerce_kw_type(opts, 'purity', convert_cx_oracle_constant) + util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant) + util.coerce_kw_type(opts, "threaded", bool) + util.coerce_kw_type(opts, "events", bool) + util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant) return ([], opts) def _get_server_version_info(self, connection): - return tuple( - int(x) - for x in connection.connection.version.split('.') - ) + return tuple(int(x) for x in connection.connection.version.split(".")) def is_disconnect(self, e, connection, cursor): error, = e.args if isinstance( - e, - (self.dbapi.InterfaceError, self.dbapi.DatabaseError) + e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError) ) and "not connected" in str(e): return True - if hasattr(error, 'code'): + if hasattr(error, "code"): # ORA-00028: your session has been killed # ORA-03114: not connected to ORACLE # ORA-03113: end-of-file on communication channel @@ -1052,22 +1086,25 @@ class OracleDialect_cx_oracle(OracleDialect): def do_prepare_twophase(self, connection, xid): result = connection.connection.prepare() - connection.info['cx_oracle_prepared'] = result + connection.info["cx_oracle_prepared"] = result - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): self.do_rollback(connection.connection) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_commit(connection.connection) else: - oci_prepared = connection.info['cx_oracle_prepared'] + oci_prepared = connection.info["cx_oracle_prepared"] if oci_prepared: self.do_commit(connection.connection) def do_recover_twophase(self, connection): - connection.info.pop('cx_oracle_prepared', None) + connection.info.pop("cx_oracle_prepared", None) + dialect = OracleDialect_cx_oracle diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index aa2562573d..0a365f8b02 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -21,9 +21,11 @@ import re from sqlalchemy import sql, types as sqltypes, util from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector -from 
sqlalchemy.dialects.oracle.base import (OracleCompiler, - OracleDialect, - OracleExecutionContext) +from sqlalchemy.dialects.oracle.base import ( + OracleCompiler, + OracleDialect, + OracleExecutionContext, +) from sqlalchemy.engine import result as _result from sqlalchemy.sql import expression import collections @@ -32,92 +34,100 @@ SQLException = zxJDBC = None class _ZxJDBCDate(sqltypes.Date): - def result_processor(self, dialect, coltype): def process(value): if value is None: return None else: return value.date() + return process class _ZxJDBCNumeric(sqltypes.Numeric): - def result_processor(self, dialect, coltype): # XXX: does the dialect return Decimal or not??? # if it does (in all cases), we could use a None processor as well as # the to_float generic processor if self.asdecimal: + def process(value): if isinstance(value, decimal.Decimal): return value else: return decimal.Decimal(str(value)) + else: + def process(value): if isinstance(value, decimal.Decimal): return float(value) else: return value + return process class OracleCompiler_zxjdbc(OracleCompiler): - def returning_clause(self, stmt, returning_cols): self.returning_cols = list( - expression._select_iterables(returning_cols)) + expression._select_iterables(returning_cols) + ) # within_columns_clause=False so that labels (foo AS bar) don't render - columns = [self.process(c, within_columns_clause=False) - for c in self.returning_cols] + columns = [ + self.process(c, within_columns_clause=False) + for c in self.returning_cols + ] - if not hasattr(self, 'returning_parameters'): + if not hasattr(self, "returning_parameters"): self.returning_parameters = [] binds = [] for i, col in enumerate(self.returning_cols): - dbtype = col.type.dialect_impl( - self.dialect).get_dbapi_type(self.dialect.dbapi) + dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type( + self.dialect.dbapi + ) self.returning_parameters.append((i + 1, dbtype)) bindparam = sql.bindparam( - "ret_%d" % i, value=ReturningParam(dbtype)) + "ret_%d" % i, value=ReturningParam(dbtype) + ) self.binds[bindparam.key] = bindparam binds.append( - self.bindparam_string(self._truncate_bindparam(bindparam))) + self.bindparam_string(self._truncate_bindparam(bindparam)) + ) - return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) class OracleExecutionContext_zxjdbc(OracleExecutionContext): - def pre_exec(self): - if hasattr(self.compiled, 'returning_parameters'): + if hasattr(self.compiled, "returning_parameters"): # prepare a zxJDBC statement so we can grab its underlying # OraclePreparedStatement's getReturnResultSet later self.statement = self.cursor.prepare(self.statement) def get_result_proxy(self): - if hasattr(self.compiled, 'returning_parameters'): + if hasattr(self.compiled, "returning_parameters"): rrs = None try: try: rrs = self.statement.__statement__.getReturnResultSet() next(rrs) except SQLException as sqle: - msg = '%s [SQLCode: %d]' % ( - sqle.getMessage(), sqle.getErrorCode()) + msg = "%s [SQLCode: %d]" % ( + sqle.getMessage(), + sqle.getErrorCode(), + ) if sqle.getSQLState() is not None: - msg += ' [SQLState: %s]' % sqle.getSQLState() + msg += " [SQLState: %s]" % sqle.getSQLState() raise zxJDBC.Error(msg) else: row = tuple( - self.cursor.datahandler.getPyObject( - rrs, index, dbtype) - for index, dbtype in - self.compiled.returning_parameters) + self.cursor.datahandler.getPyObject(rrs, index, dbtype) + for index, dbtype in self.compiled.returning_parameters + ) return 
ReturningResultProxy(self, row) finally: if rrs is not None: @@ -146,7 +156,7 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy): def _cursor_description(self): ret = [] for c in self.context.compiled.returning_cols: - if hasattr(c, 'name'): + if hasattr(c, "name"): ret.append((c.name, c.type)) else: ret.append((c.anon_label, c.type)) @@ -178,23 +188,24 @@ class ReturningParam(object): def __repr__(self): kls = self.__class__ - return '<%s.%s object at 0x%x type=%s>' % ( - kls.__module__, kls.__name__, id(self), self.type) + return "<%s.%s object at 0x%x type=%s>" % ( + kls.__module__, + kls.__name__, + id(self), + self.type, + ) class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): - jdbc_db_name = 'oracle' - jdbc_driver_name = 'oracle.jdbc.OracleDriver' + jdbc_db_name = "oracle" + jdbc_driver_name = "oracle.jdbc.OracleDriver" statement_compiler = OracleCompiler_zxjdbc execution_ctx_cls = OracleExecutionContext_zxjdbc colspecs = util.update_copy( OracleDialect.colspecs, - { - sqltypes.Date: _ZxJDBCDate, - sqltypes.Numeric: _ZxJDBCNumeric - } + {sqltypes.Date: _ZxJDBCDate, sqltypes.Numeric: _ZxJDBCNumeric}, ) def __init__(self, *args, **kwargs): @@ -212,24 +223,31 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): statement.registerReturnParameter(index, object.type) elif dbtype is None: OracleDataHandler.setJDBCObject( - self, statement, index, object) + self, statement, index, object + ) else: OracleDataHandler.setJDBCObject( - self, statement, index, object, dbtype) + self, statement, index, object, dbtype + ) + self.DataHandler = OracleReturningDataHandler def initialize(self, connection): super(OracleDialect_zxjdbc, self).initialize(connection) - self.implicit_returning = \ - connection.connection.driverversion >= '10.2' + self.implicit_returning = connection.connection.driverversion >= "10.2" def _create_jdbc_url(self, url): - return 'jdbc:oracle:thin:@%s:%s:%s' % ( - url.host, url.port or 1521, url.database) + return "jdbc:oracle:thin:@%s:%s:%s" % ( + url.host, + url.port or 1521, + url.database, + ) def _get_server_version_info(self, connection): version = re.search( - r'Release ([\d\.]+)', connection.connection.dbversion).group(1) - return tuple(int(x) for x in version.split('.')) + r"Release ([\d\.]+)", connection.connection.dbversion + ).group(1) + return tuple(int(x) for x in version.split(".")) + dialect = OracleDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 84f7200285..9e65484fa1 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -5,33 +5,110 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from . import base, psycopg2, pg8000, pypostgresql, pygresql, \ - zxjdbc, psycopg2cffi # noqa +from . 
import ( + base, + psycopg2, + pg8000, + pypostgresql, + pygresql, + zxjdbc, + psycopg2cffi, +) # noqa -from .base import \ - INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ - INET, CIDR, UUID, BIT, MACADDR, MONEY, OID, REGCLASS, DOUBLE_PRECISION, \ - TIMESTAMP, TIME, DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, TSVECTOR, \ - DropEnumType, CreateEnumType +from .base import ( + INTEGER, + BIGINT, + SMALLINT, + VARCHAR, + CHAR, + TEXT, + NUMERIC, + FLOAT, + REAL, + INET, + CIDR, + UUID, + BIT, + MACADDR, + MONEY, + OID, + REGCLASS, + DOUBLE_PRECISION, + TIMESTAMP, + TIME, + DATE, + BYTEA, + BOOLEAN, + INTERVAL, + ENUM, + TSVECTOR, + DropEnumType, + CreateEnumType, +) from .hstore import HSTORE, hstore from .json import JSON, JSONB from .array import array, ARRAY, Any, All from .ext import aggregate_order_by, ExcludeConstraint, array_agg from .dml import insert, Insert -from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ - TSTZRANGE +from .ranges import ( + INT4RANGE, + INT8RANGE, + NUMRANGE, + DATERANGE, + TSRANGE, + TSTZRANGE, +) base.dialect = dialect = psycopg2.dialect __all__ = ( - 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', - 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'MONEY', 'OID', - 'REGCLASS', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', - 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE', - 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', - 'TSRANGE', 'TSTZRANGE', 'JSON', 'JSONB', 'Any', 'All', - 'DropEnumType', 'CreateEnumType', 'ExcludeConstraint', - 'aggregate_order_by', 'array_agg', 'insert', 'Insert' + "INTEGER", + "BIGINT", + "SMALLINT", + "VARCHAR", + "CHAR", + "TEXT", + "NUMERIC", + "FLOAT", + "REAL", + "INET", + "CIDR", + "UUID", + "BIT", + "MACADDR", + "MONEY", + "OID", + "REGCLASS", + "DOUBLE_PRECISION", + "TIMESTAMP", + "TIME", + "DATE", + "BYTEA", + "BOOLEAN", + "INTERVAL", + "ARRAY", + "ENUM", + "dialect", + "array", + "HSTORE", + "hstore", + "INT4RANGE", + "INT8RANGE", + "NUMRANGE", + "DATERANGE", + "TSRANGE", + "TSTZRANGE", + "JSON", + "JSONB", + "Any", + "All", + "DropEnumType", + "CreateEnumType", + "ExcludeConstraint", + "aggregate_order_by", + "array_agg", + "insert", + "Insert", ) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index b2674046e4..07167f9d0f 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -78,7 +78,8 @@ class array(expression.Tuple): :class:`.postgresql.ARRAY` """ - __visit_name__ = 'array' + + __visit_name__ = "array" def __init__(self, clauses, **kw): super(array, self).__init__(*clauses, **kw) @@ -90,18 +91,26 @@ class array(expression.Tuple): # a Slice object from that assert isinstance(obj, int) return expression.BindParameter( - None, obj, _compared_to_operator=operator, + None, + obj, + _compared_to_operator=operator, type_=type_, - _compared_to_type=self.type, unique=True) + _compared_to_type=self.type, + unique=True, + ) else: - return array([ - self._bind_param(operator, o, _assume_scalar=True, type_=type_) - for o in obj]) + return array( + [ + self._bind_param( + operator, o, _assume_scalar=True, type_=type_ + ) + for o in obj + ] + ) def self_group(self, against=None): - if (against in ( - operators.any_op, operators.all_op, operators.getitem)): + if against in (operators.any_op, operators.all_op, operators.getitem): return expression.Grouping(self) else: return self @@ -180,7 +189,8 @@ class 
ARRAY(sqltypes.ARRAY): elements of the argument array expression. """ return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean) + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) def overlap(self, other): """Boolean expression. Test if array has elements in common with @@ -190,8 +200,9 @@ class ARRAY(sqltypes.ARRAY): comparator_factory = Comparator - def __init__(self, item_type, as_tuple=False, dimensions=None, - zero_indexes=False): + def __init__( + self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + ): """Construct an ARRAY. E.g.:: @@ -228,8 +239,10 @@ class ARRAY(sqltypes.ARRAY): """ if isinstance(item_type, ARRAY): - raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype") + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) if isinstance(item_type, type): item_type = item_type() self.item_type = item_type @@ -251,11 +264,17 @@ class ARRAY(sqltypes.ARRAY): def _proc_array(self, arr, itemproc, dim, collection): if dim is None: arr = list(arr) - if dim == 1 or dim is None and ( + if ( + dim == 1 + or dim is None + and ( # this has to be (list, tuple), or at least # not hasattr('__iter__'), since Py3K strings # etc. have __iter__ - not arr or not isinstance(arr[0], (list, tuple))): + not arr + or not isinstance(arr[0], (list, tuple)) + ) + ): if itemproc: return collection(itemproc(x) for x in arr) else: @@ -263,30 +282,33 @@ class ARRAY(sqltypes.ARRAY): else: return collection( self._proc_array( - x, itemproc, + x, + itemproc, dim - 1 if dim is not None else None, - collection) + collection, + ) for x in arr ) def bind_processor(self, dialect): - item_proc = self.item_type.dialect_impl(dialect).\ - bind_processor(dialect) + item_proc = self.item_type.dialect_impl(dialect).bind_processor( + dialect + ) def process(value): if value is None: return value else: return self._proc_array( - value, - item_proc, - self.dimensions, - list) + value, item_proc, self.dimensions, list + ) + return process def result_processor(self, dialect, coltype): - item_proc = self.item_type.dialect_impl(dialect).\ - result_processor(dialect, coltype) + item_proc = self.item_type.dialect_impl(dialect).result_processor( + dialect, coltype + ) def process(value): if value is None: @@ -296,8 +318,11 @@ class ARRAY(sqltypes.ARRAY): value, item_proc, self.dimensions, - tuple if self.as_tuple else list) + tuple if self.as_tuple else list, + ) + return process + colspecs[sqltypes.ARRAY] = ARRAY -ischema_names['_array'] = ARRAY +ischema_names["_array"] = ARRAY diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index d68ab8ef58..11833da573 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -930,57 +930,164 @@ try: except ImportError: _python_UUID = None -from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ - CHAR, TEXT, FLOAT, NUMERIC, \ - DATE, BOOLEAN, REAL +from sqlalchemy.types import ( + INTEGER, + BIGINT, + SMALLINT, + VARCHAR, + CHAR, + TEXT, + FLOAT, + NUMERIC, + DATE, + BOOLEAN, + REAL, +) AUTOCOMMIT_REGEXP = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|' - 'IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|" + "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)", + re.I | re.UNICODE, +) RESERVED_WORDS = set( 
- ["all", "analyse", "analyze", "and", "any", "array", "as", "asc", - "asymmetric", "both", "case", "cast", "check", "collate", "column", - "constraint", "create", "current_catalog", "current_date", - "current_role", "current_time", "current_timestamp", "current_user", - "default", "deferrable", "desc", "distinct", "do", "else", "end", - "except", "false", "fetch", "for", "foreign", "from", "grant", "group", - "having", "in", "initially", "intersect", "into", "leading", "limit", - "localtime", "localtimestamp", "new", "not", "null", "of", "off", - "offset", "old", "on", "only", "or", "order", "placing", "primary", - "references", "returning", "select", "session_user", "some", "symmetric", - "table", "then", "to", "trailing", "true", "union", "unique", "user", - "using", "variadic", "when", "where", "window", "with", "authorization", - "between", "binary", "cross", "current_schema", "freeze", "full", - "ilike", "inner", "is", "isnull", "join", "left", "like", "natural", - "notnull", "outer", "over", "overlaps", "right", "similar", "verbose" - ]) + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", + ] +) _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + class BYTEA(sqltypes.LargeBinary): - __visit_name__ = 'BYTEA' + __visit_name__ = "BYTEA" class DOUBLE_PRECISION(sqltypes.Float): - __visit_name__ = 'DOUBLE_PRECISION' + __visit_name__ = "DOUBLE_PRECISION" class INET(sqltypes.TypeEngine): __visit_name__ = "INET" + + PGInet = INET class CIDR(sqltypes.TypeEngine): __visit_name__ = "CIDR" + + PGCidr = CIDR class MACADDR(sqltypes.TypeEngine): __visit_name__ = "MACADDR" + + PGMacAddr = MACADDR @@ -991,6 +1098,7 @@ class MONEY(sqltypes.TypeEngine): .. versionadded:: 1.2 """ + __visit_name__ = "MONEY" @@ -1001,6 +1109,7 @@ class OID(sqltypes.TypeEngine): .. versionadded:: 0.9.5 """ + __visit_name__ = "OID" @@ -1011,18 +1120,17 @@ class REGCLASS(sqltypes.TypeEngine): .. 
versionadded:: 1.2.7 """ + __visit_name__ = "REGCLASS" class TIMESTAMP(sqltypes.TIMESTAMP): - def __init__(self, timezone=False, precision=None): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision class TIME(sqltypes.TIME): - def __init__(self, timezone=False, precision=None): super(TIME, self).__init__(timezone=timezone) self.precision = precision @@ -1036,7 +1144,8 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): It is known to work on psycopg2 and not pg8000 or zxjdbc. """ - __visit_name__ = 'INTERVAL' + + __visit_name__ = "INTERVAL" native = True def __init__(self, precision=None, fields=None): @@ -1065,11 +1174,12 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): def python_type(self): return dt.timedelta + PGInterval = INTERVAL class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" def __init__(self, length=None, varying=False): if not varying: @@ -1080,6 +1190,7 @@ class BIT(sqltypes.TypeEngine): self.length = length self.varying = varying + PGBit = BIT @@ -1095,7 +1206,8 @@ class UUID(sqltypes.TypeEngine): It is known to work on psycopg2 and not pg8000. """ - __visit_name__ = 'UUID' + + __visit_name__ = "UUID" def __init__(self, as_uuid=False): """Construct a UUID type. @@ -1115,24 +1227,29 @@ class UUID(sqltypes.TypeEngine): def bind_processor(self, dialect): if self.as_uuid: + def process(value): if value is not None: value = util.text_type(value) return value + return process else: return None def result_processor(self, dialect, coltype): if self.as_uuid: + def process(value): if value is not None: value = _python_UUID(value) return value + return process else: return None + PGUuid = UUID @@ -1151,7 +1268,8 @@ class TSVECTOR(sqltypes.TypeEngine): :ref:`postgresql_match` """ - __visit_name__ = 'TSVECTOR' + + __visit_name__ = "TSVECTOR" class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): @@ -1273,12 +1391,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): """ kw.setdefault("validate_strings", impl.validate_strings) - kw.setdefault('name', impl.name) - kw.setdefault('schema', impl.schema) - kw.setdefault('inherit_schema', impl.inherit_schema) - kw.setdefault('metadata', impl.metadata) - kw.setdefault('_create_events', False) - kw.setdefault('values_callable', impl.values_callable) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) return cls(**kw) def create(self, bind=None, checkfirst=True): @@ -1300,9 +1418,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or \ - not bind.dialect.has_type( - bind, self.name, schema=self.schema): + if not checkfirst or not bind.dialect.has_type( + bind, self.name, schema=self.schema + ): bind.execute(CreateEnumType(self)) def drop(self, bind=None, checkfirst=True): @@ -1323,8 +1441,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or \ - bind.dialect.has_type(bind, self.name, schema=self.schema): + if not checkfirst or bind.dialect.has_type( + bind, self.name, schema=self.schema + ): bind.execute(DropEnumType(self)) def _check_for_name_in_memos(self, checkfirst, kw): @@ -1338,12 +1457,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): """ if not 
self.create_type: return True - if '_ddl_runner' in kw: - ddl_runner = kw['_ddl_runner'] - if '_pg_enums' in ddl_runner.memo: - pg_enums = ddl_runner.memo['_pg_enums'] + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + if "_pg_enums" in ddl_runner.memo: + pg_enums = ddl_runner.memo["_pg_enums"] else: - pg_enums = ddl_runner.memo['_pg_enums'] = set() + pg_enums = ddl_runner.memo["_pg_enums"] = set() present = (self.schema, self.name) in pg_enums pg_enums.add((self.schema, self.name)) return present @@ -1351,16 +1470,22 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): - if checkfirst or ( - not self.metadata and - not kw.get('_is_metadata_operation', False)) and \ - not self._check_for_name_in_memos(checkfirst, kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + and not self._check_for_name_in_memos(checkfirst, kw) + ): self.create(bind=bind, checkfirst=checkfirst) def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if not self.metadata and \ - not kw.get('_is_metadata_operation', False) and \ - not self._check_for_name_in_memos(checkfirst, kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): self.drop(bind=bind, checkfirst=checkfirst) def _on_metadata_create(self, target, bind, checkfirst=False, **kw): @@ -1371,49 +1496,46 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not self._check_for_name_in_memos(checkfirst, kw): self.drop(bind=bind, checkfirst=checkfirst) -colspecs = { - sqltypes.Interval: INTERVAL, - sqltypes.Enum: ENUM, -} + +colspecs = {sqltypes.Interval: INTERVAL, sqltypes.Enum: ENUM} ischema_names = { - 'integer': INTEGER, - 'bigint': BIGINT, - 'smallint': SMALLINT, - 'character varying': VARCHAR, - 'character': CHAR, + "integer": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "character varying": VARCHAR, + "character": CHAR, '"char"': sqltypes.String, - 'name': sqltypes.String, - 'text': TEXT, - 'numeric': NUMERIC, - 'float': FLOAT, - 'real': REAL, - 'inet': INET, - 'cidr': CIDR, - 'uuid': UUID, - 'bit': BIT, - 'bit varying': BIT, - 'macaddr': MACADDR, - 'money': MONEY, - 'oid': OID, - 'regclass': REGCLASS, - 'double precision': DOUBLE_PRECISION, - 'timestamp': TIMESTAMP, - 'timestamp with time zone': TIMESTAMP, - 'timestamp without time zone': TIMESTAMP, - 'time with time zone': TIME, - 'time without time zone': TIME, - 'date': DATE, - 'time': TIME, - 'bytea': BYTEA, - 'boolean': BOOLEAN, - 'interval': INTERVAL, - 'tsvector': TSVECTOR + "name": sqltypes.String, + "text": TEXT, + "numeric": NUMERIC, + "float": FLOAT, + "real": REAL, + "inet": INET, + "cidr": CIDR, + "uuid": UUID, + "bit": BIT, + "bit varying": BIT, + "macaddr": MACADDR, + "money": MONEY, + "oid": OID, + "regclass": REGCLASS, + "double precision": DOUBLE_PRECISION, + "timestamp": TIMESTAMP, + "timestamp with time zone": TIMESTAMP, + "timestamp without time zone": TIMESTAMP, + "time with time zone": TIME, + "time without time zone": TIME, + "date": DATE, + "time": TIME, + "bytea": BYTEA, + "boolean": BOOLEAN, + "interval": INTERVAL, + "tsvector": TSVECTOR, } class PGCompiler(compiler.SQLCompiler): - def visit_array(self, element, **kw): return "ARRAY[%s]" % self.visit_clauselist(element, **kw) @@ -1424,77 +1546,75 @@ class PGCompiler(compiler.SQLCompiler): ) def visit_json_getitem_op_binary(self, binary, operator, **kw): - 
kw['eager_grouping'] = True - return self._generate_generic_binary( - binary, " -> ", **kw - ) + kw["eager_grouping"] = True + return self._generate_generic_binary(binary, " -> ", **kw) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): - kw['eager_grouping'] = True - return self._generate_generic_binary( - binary, " #> ", **kw - ) + kw["eager_grouping"] = True + return self._generate_generic_binary(binary, " #> ", **kw) def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_aggregate_order_by(self, element, **kw): return "%s ORDER BY %s" % ( self.process(element.target, **kw), - self.process(element.order_by, **kw) + self.process(element.order_by, **kw), ) def visit_match_op_binary(self, binary, operator, **kw): if "postgresql_regconfig" in binary.modifiers: regconfig = self.render_literal_value( - binary.modifiers['postgresql_regconfig'], - sqltypes.STRINGTYPE) + binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE + ) if regconfig: return "%s @@ to_tsquery(%s, %s)" % ( self.process(binary.left, **kw), regconfig, - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) return "%s @@ to_tsquery(%s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s ILIKE %s' % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + return "%s ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT ILIKE %s' % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + return "%s NOT ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_empty_set_expr(self, element_types): # cast the empty set to the type we are comparing against. 
if # we are comparing against the null type, pick an arbitrary # datatype for the empty set - return 'SELECT %s WHERE 1!=1' % ( + return "SELECT %s WHERE 1!=1" % ( ", ".join( - "CAST(NULL AS %s)" % self.dialect.type_compiler.process( - INTEGER() if type_._isnull else type_, - ) for type_ in element_types or [INTEGER()] + "CAST(NULL AS %s)" + % self.dialect.type_compiler.process( + INTEGER() if type_._isnull else type_ + ) + for type_ in element_types or [INTEGER()] ), ) @@ -1502,7 +1622,7 @@ class PGCompiler(compiler.SQLCompiler): value = super(PGCompiler, self).render_literal_value(value, type_) if self.dialect._backslash_escapes: - value = value.replace('\\', '\\\\') + value = value.replace("\\", "\\\\") return value def visit_sequence(self, seq, **kw): @@ -1519,7 +1639,7 @@ class PGCompiler(compiler.SQLCompiler): return text def format_from_hint_text(self, sqltext, table, hint, iscrud): - if hint.upper() != 'ONLY': + if hint.upper() != "ONLY": raise exc.CompileError("Unrecognized hint: %r" % hint) return "ONLY " + sqltext @@ -1528,12 +1648,19 @@ class PGCompiler(compiler.SQLCompiler): if select._distinct is True: return "DISTINCT " elif isinstance(select._distinct, (list, tuple)): - return "DISTINCT ON (" + ', '.join( - [self.process(col, **kw) for col in select._distinct] - ) + ") " + return ( + "DISTINCT ON (" + + ", ".join( + [self.process(col, **kw) for col in select._distinct] + ) + + ") " + ) else: - return "DISTINCT ON (" + \ - self.process(select._distinct, **kw) + ") " + return ( + "DISTINCT ON (" + + self.process(select._distinct, **kw) + + ") " + ) else: return "" @@ -1551,8 +1678,9 @@ class PGCompiler(compiler.SQLCompiler): if select._for_update_arg.of: tables = util.OrderedSet( - c.table if isinstance(c, expression.ColumnClause) - else c for c in select._for_update_arg.of) + c.table if isinstance(c, expression.ColumnClause) else c + for c in select._for_update_arg.of + ) tmp += " OF " + ", ".join( self.process(table, ashint=True, use_schema=False, **kw) for table in tables @@ -1572,7 +1700,7 @@ class PGCompiler(compiler.SQLCompiler): for c in expression._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) def visit_substring_func(self, func, **kw): s = self.process(func.clauses.clauses[0], **kw) @@ -1586,24 +1714,24 @@ class PGCompiler(compiler.SQLCompiler): def _on_conflict_target(self, clause, **kw): if clause.constraint_target is not None: - target_text = 'ON CONSTRAINT %s' % clause.constraint_target + target_text = "ON CONSTRAINT %s" % clause.constraint_target elif clause.inferred_target_elements is not None: - target_text = '(%s)' % ', '.join( - (self.preparer.quote(c) - if isinstance(c, util.string_types) - else - self.process(c, include_table=False, use_schema=False)) + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, util.string_types) + else self.process(c, include_table=False, use_schema=False) + ) for c in clause.inferred_target_elements ) if clause.inferred_target_whereclause is not None: - target_text += ' WHERE %s' % \ - self.process( - clause.inferred_target_whereclause, - include_table=False, - use_schema=False - ) + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + ) else: - target_text = '' + target_text = "" return target_text @@ -1627,36 +1755,35 @@ class PGCompiler(compiler.SQLCompiler): set_parameters = dict(clause.update_values_to_set) # create a list of column assignment 
clauses as tuples - insert_statement = self.stack[-1]['selectable'] + insert_statement = self.stack[-1]["selectable"] cols = insert_statement.table.c for c in cols: col_key = c.key if col_key in set_parameters: value = set_parameters.pop(col_key) if elements._is_literal(value): - value = elements.BindParameter( - None, value, type_=c.type - ) + value = elements.BindParameter(None, value, type_=c.type) else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): value = value._clone() value.type = c.type value_text = self.process(value.self_group(), use_schema=False) - key_text = ( - self.preparer.quote(col_key) - ) - action_set_ops.append('%s = %s' % (key_text, value_text)) + key_text = self.preparer.quote(col_key) + action_set_ops.append("%s = %s" % (key_text, value_text)) # check for names that don't match columns if set_parameters: util.warn( "Additional column names not matching " - "any column keys in table '%s': %s" % ( + "any column keys in table '%s': %s" + % ( self.statement.table.name, - (", ".join("'%s'" % c for c in set_parameters)) + (", ".join("'%s'" % c for c in set_parameters)), ) ) for k, v in set_parameters.items(): @@ -1666,42 +1793,37 @@ class PGCompiler(compiler.SQLCompiler): else self.process(k, use_schema=False) ) value_text = self.process( - elements._literal_as_binds(v), - use_schema=False + elements._literal_as_binds(v), use_schema=False ) - action_set_ops.append('%s = %s' % (key_text, value_text)) + action_set_ops.append("%s = %s" % (key_text, value_text)) - action_text = ', '.join(action_set_ops) + action_text = ", ".join(action_set_ops) if clause.update_whereclause is not None: - action_text += ' WHERE %s' % \ - self.process( - clause.update_whereclause, - include_table=True, - use_schema=False - ) + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) - return 'ON CONFLICT %s DO UPDATE SET %s' % (target_text, action_text) + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. 
USING clause specific to PostgreSQL.""" - return "USING " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) class PGDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) @@ -1709,17 +1831,21 @@ class PGDDLCompiler(compiler.DDLCompiler): if isinstance(impl_type, sqltypes.TypeDecorator): impl_type = impl_type.impl - if column.primary_key and \ - column is column.table._autoincrement_column and \ - ( - self.dialect.supports_smallserial or - not isinstance(impl_type, sqltypes.SmallInteger) - ) and ( - column.default is None or - ( - isinstance(column.default, schema.Sequence) and - column.default.optional - )): + if ( + column.primary_key + and column is column.table._autoincrement_column + and ( + self.dialect.supports_smallserial + or not isinstance(impl_type, sqltypes.SmallInteger) + ) + and ( + column.default is None + or ( + isinstance(column.default, schema.Sequence) + and column.default.optional + ) + ) + ): if isinstance(impl_type, sqltypes.BigInteger): colspec += " BIGSERIAL" elif isinstance(impl_type, sqltypes.SmallInteger): @@ -1728,7 +1854,8 @@ class PGDDLCompiler(compiler.DDLCompiler): colspec += " SERIAL" else: colspec += " " + self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -1744,15 +1871,14 @@ class PGDDLCompiler(compiler.DDLCompiler): self.preparer.format_type(type_), ", ".join( self.sql_compiler.process(sql.literal(e), literal_binds=True) - for e in type_.enums) + for e in type_.enums + ), ) def visit_drop_enum_type(self, drop): type_ = drop.element - return "DROP TYPE %s" % ( - self.preparer.format_type(type_) - ) + return "DROP TYPE %s" % (self.preparer.format_type(type_)) def visit_create_index(self, create): preparer = self.preparer @@ -1764,46 +1890,53 @@ class PGDDLCompiler(compiler.DDLCompiler): text += "INDEX " if self.dialect._supports_create_index_concurrently: - concurrently = index.dialect_options['postgresql']['concurrently'] + concurrently = index.dialect_options["postgresql"]["concurrently"] if concurrently: text += "CONCURRENTLY " text += "%s ON %s " % ( - self._prepared_index_name(index, - include_schema=False), - preparer.format_table(index.table) + self._prepared_index_name(index, include_schema=False), + preparer.format_table(index.table), ) - using = index.dialect_options['postgresql']['using'] + using = index.dialect_options["postgresql"]["using"] if using: text += "USING %s " % preparer.quote(using) ops = index.dialect_options["postgresql"]["ops"] - text += "(%s)" \ - % ( - ', '.join([ - self.sql_compiler.process( - expr.self_group() - if not isinstance(expr, expression.ColumnClause) - else expr, - include_table=False, literal_binds=True) + - ( - (' ' + ops[expr.key]) - if hasattr(expr, 'key') - and expr.key in ops else '' - ) - for expr in index.expressions - ]) - ) + text += "(%s)" % ( + ", ".join( + [ + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, + literal_binds=True, + ) + + ( + (" " + ops[expr.key]) + if hasattr(expr, "key") and expr.key in ops + else "" + ) + for expr in index.expressions + ] + ) + ) - withclause = 
index.dialect_options['postgresql']['with'] + withclause = index.dialect_options["postgresql"]["with"] if withclause: - text += " WITH (%s)" % (', '.join( - ['%s = %s' % storage_parameter - for storage_parameter in withclause.items()])) + text += " WITH (%s)" % ( + ", ".join( + [ + "%s = %s" % storage_parameter + for storage_parameter in withclause.items() + ] + ) + ) - tablespace_name = index.dialect_options['postgresql']['tablespace'] + tablespace_name = index.dialect_options["postgresql"]["tablespace"] if tablespace_name: text += " TABLESPACE %s" % preparer.quote(tablespace_name) @@ -1812,8 +1945,8 @@ class PGDDLCompiler(compiler.DDLCompiler): if whereclause is not None: where_compiled = self.sql_compiler.process( - whereclause, include_table=False, - literal_binds=True) + whereclause, include_table=False, literal_binds=True + ) text += " WHERE " + where_compiled return text @@ -1823,7 +1956,7 @@ class PGDDLCompiler(compiler.DDLCompiler): text = "\nDROP INDEX " if self.dialect._supports_drop_index_concurrently: - concurrently = index.dialect_options['postgresql']['concurrently'] + concurrently = index.dialect_options["postgresql"]["concurrently"] if concurrently: text += "CONCURRENTLY " @@ -1833,55 +1966,59 @@ class PGDDLCompiler(compiler.DDLCompiler): def visit_exclude_constraint(self, constraint, **kw): text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) elements = [] for expr, name, op in constraint._render_exprs: - kw['include_table'] = False + kw["include_table"] = False elements.append( "%s WITH %s" % (self.sql_compiler.process(expr, **kw), op) ) - text += "EXCLUDE USING %s (%s)" % (constraint.using, - ', '.join(elements)) + text += "EXCLUDE USING %s (%s)" % ( + constraint.using, + ", ".join(elements), + ) if constraint.where is not None: - text += ' WHERE (%s)' % self.sql_compiler.process( - constraint.where, - literal_binds=True) + text += " WHERE (%s)" % self.sql_compiler.process( + constraint.where, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text def post_create_table(self, table): table_opts = [] - pg_opts = table.dialect_options['postgresql'] + pg_opts = table.dialect_options["postgresql"] - inherits = pg_opts.get('inherits') + inherits = pg_opts.get("inherits") if inherits is not None: if not isinstance(inherits, (list, tuple)): - inherits = (inherits, ) + inherits = (inherits,) table_opts.append( - '\n INHERITS ( ' + - ', '.join(self.preparer.quote(name) for name in inherits) + - ' )') + "\n INHERITS ( " + + ", ".join(self.preparer.quote(name) for name in inherits) + + " )" + ) - if pg_opts['partition_by']: - table_opts.append('\n PARTITION BY %s' % pg_opts['partition_by']) + if pg_opts["partition_by"]: + table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) - if pg_opts['with_oids'] is True: - table_opts.append('\n WITH OIDS') - elif pg_opts['with_oids'] is False: - table_opts.append('\n WITHOUT OIDS') + if pg_opts["with_oids"] is True: + table_opts.append("\n WITH OIDS") + elif pg_opts["with_oids"] is False: + table_opts.append("\n WITHOUT OIDS") - if pg_opts['on_commit']: - on_commit_options = pg_opts['on_commit'].replace("_", " ").upper() - table_opts.append('\n ON COMMIT %s' % on_commit_options) + if pg_opts["on_commit"]: + on_commit_options = pg_opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) - if 
pg_opts['tablespace']: - tablespace_name = pg_opts['tablespace'] + if pg_opts["tablespace"]: + tablespace_name = pg_opts["tablespace"] table_opts.append( - '\n TABLESPACE %s' % self.preparer.quote(tablespace_name) + "\n TABLESPACE %s" % self.preparer.quote(tablespace_name) ) - return ''.join(table_opts) + return "".join(table_opts) class PGTypeCompiler(compiler.GenericTypeCompiler): @@ -1910,7 +2047,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): if not type_.precision: return "FLOAT" else: - return "FLOAT(%(precision)s)" % {'precision': type_.precision} + return "FLOAT(%(precision)s)" % {"precision": type_.precision} def visit_DOUBLE_PRECISION(self, type_, **kw): return "DOUBLE PRECISION" @@ -1960,15 +2097,17 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( "(%d)" % type_.precision - if getattr(type_, 'precision', None) is not None else "", - (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( "(%d)" % type_.precision - if getattr(type_, 'precision', None) is not None else "", - (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_INTERVAL(self, type_, **kw): @@ -2002,13 +2141,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): # TODO: pass **kw? inner = self.process(type_.item_type) return re.sub( - r'((?: COLLATE.*)?)$', - (r'%s\1' % ( - "[]" * - (type_.dimensions if type_.dimensions is not None else 1) - )), + r"((?: COLLATE.*)?)$", + ( + r"%s\1" + % ( + "[]" + * (type_.dimensions if type_.dimensions is not None else 1) + ) + ), inner, - count=1 + count=1, ) @@ -2018,8 +2160,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): def _unquote_identifier(self, value): if value[0] == self.initial_quote: - value = value[1:-1].\ - replace(self.escape_to_quote, self.escape_quote) + value = value[1:-1].replace( + self.escape_to_quote, self.escape_quote + ) return value def format_type(self, type_, use_schema=True): @@ -2029,22 +2172,25 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): name = self.quote(type_.name) effective_schema = self.schema_for_object(type_) - if not self.omit_schema and use_schema and \ - effective_schema is not None: + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): name = self.quote_schema(effective_schema) + "." + name return name class PGInspector(reflection.Inspector): - def __init__(self, conn): reflection.Inspector.__init__(self, conn) def get_table_oid(self, table_name, schema=None): """Return the OID for the given table name.""" - return self.dialect.get_table_oid(self.bind, table_name, schema, - info_cache=self.info_cache) + return self.dialect.get_table_oid( + self.bind, table_name, schema, info_cache=self.info_cache + ) def get_enums(self, schema=None): """Return a list of ENUM objects. @@ -2080,7 +2226,7 @@ class PGInspector(reflection.Inspector): schema = schema or self.default_schema_name return self.dialect._get_foreign_table_names(self.bind, schema) - def get_view_names(self, schema=None, include=('plain', 'materialized')): + def get_view_names(self, schema=None, include=("plain", "materialized")): """Return all view names in `schema`. 
:param schema: Optional, retrieve names from a non-default schema. @@ -2094,9 +2240,9 @@ class PGInspector(reflection.Inspector): """ - return self.dialect.get_view_names(self.bind, schema, - info_cache=self.info_cache, - include=include) + return self.dialect.get_view_names( + self.bind, schema, info_cache=self.info_cache, include=include + ) class CreateEnumType(schema._CreateDropBase): @@ -2108,25 +2254,27 @@ class DropEnumType(schema._CreateDropBase): class PGExecutionContext(default.DefaultExecutionContext): - def fire_sequence(self, seq, type_): - return self._execute_scalar(( - "select nextval('%s')" % - self.dialect.identifier_preparer.format_sequence(seq)), type_) + return self._execute_scalar( + ( + "select nextval('%s')" + % self.dialect.identifier_preparer.format_sequence(seq) + ), + type_, + ) def get_insert_default(self, column): - if column.primary_key and \ - column is column.table._autoincrement_column: + if column.primary_key and column is column.table._autoincrement_column: if column.server_default and column.server_default.has_argument: # pre-execute passive defaults on primary key columns - return self._execute_scalar("select %s" % - column.server_default.arg, - column.type) + return self._execute_scalar( + "select %s" % column.server_default.arg, column.type + ) - elif (column.default is None or - (column.default.is_sequence and - column.default.optional)): + elif column.default is None or ( + column.default.is_sequence and column.default.optional + ): # execute the sequence associated with a SERIAL primary # key column. for non-primary-key SERIAL, the ID just @@ -2137,23 +2285,25 @@ class PGExecutionContext(default.DefaultExecutionContext): except AttributeError: tab = column.table.name col = column.name - tab = tab[0:29 + max(0, (29 - len(col)))] - col = col[0:29 + max(0, (29 - len(tab)))] + tab = tab[0 : 29 + max(0, (29 - len(col)))] + col = col[0 : 29 + max(0, (29 - len(tab)))] name = "%s_%s_seq" % (tab, col) column._postgresql_seq_name = seq_name = name if column.table is not None: effective_schema = self.connection.schema_for_object( - column.table) + column.table + ) else: effective_schema = None if effective_schema is not None: - exc = "select nextval('\"%s\".\"%s\"')" % \ - (effective_schema, seq_name) + exc = 'select nextval(\'"%s"."%s"\')' % ( + effective_schema, + seq_name, + ) else: - exc = "select nextval('\"%s\"')" % \ - (seq_name, ) + exc = "select nextval('\"%s\"')" % (seq_name,) return self._execute_scalar(exc, column.type) @@ -2164,7 +2314,7 @@ class PGExecutionContext(default.DefaultExecutionContext): class PGDialect(default.DefaultDialect): - name = 'postgresql' + name = "postgresql" supports_alter = True max_identifier_length = 63 supports_sane_rowcount = True @@ -2182,7 +2332,7 @@ class PGDialect(default.DefaultDialect): supports_default_values = True supports_empty_insert = False supports_multivalues_insert = True - default_paramstyle = 'pyformat' + default_paramstyle = "pyformat" ischema_names = ischema_names colspecs = colspecs @@ -2195,32 +2345,43 @@ class PGDialect(default.DefaultDialect): isolation_level = None construct_arguments = [ - (schema.Index, { - "using": False, - "where": None, - "ops": {}, - "concurrently": False, - "with": {}, - "tablespace": None - }), - (schema.Table, { - "ignore_search_path": False, - "tablespace": None, - "partition_by": None, - "with_oids": None, - "on_commit": None, - "inherits": None - }), + ( + schema.Index, + { + "using": False, + "where": None, + "ops": {}, + "concurrently": False, + "with": {}, + 
"tablespace": None, + }, + ), + ( + schema.Table, + { + "ignore_search_path": False, + "tablespace": None, + "partition_by": None, + "with_oids": None, + "on_commit": None, + "inherits": None, + }, + ), ] - reflection_options = ('postgresql_ignore_search_path', ) + reflection_options = ("postgresql_ignore_search_path",) _backslash_escapes = True _supports_create_index_concurrently = True _supports_drop_index_concurrently = True - def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_deserializer = json_deserializer @@ -2228,8 +2389,10 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - self.implicit_returning = self.server_version_info > (8, 2) and \ - self.__dict__.get('implicit_returning', True) + self.implicit_returning = self.server_version_info > ( + 8, + 2, + ) and self.__dict__.get("implicit_returning", True) self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: self.colspecs = self.colspecs.copy() @@ -2241,45 +2404,57 @@ class PGDialect(default.DefaultDialect): # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) - self._backslash_escapes = self.server_version_info < (8, 2) or \ - connection.scalar( - "show standard_conforming_strings" - ) == 'off' + self._backslash_escapes = ( + self.server_version_info < (8, 2) + or connection.scalar("show standard_conforming_strings") == "off" + ) - self._supports_create_index_concurrently = \ + self._supports_create_index_concurrently = ( self.server_version_info >= (8, 2) - self._supports_drop_index_concurrently = \ - self.server_version_info >= (9, 2) + ) + self._supports_drop_index_concurrently = self.server_version_info >= ( + 9, + 2, + ) def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. 
" - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute( "SET SESSION CHARACTERISTICS AS TRANSACTION " - "ISOLATION LEVEL %s" % level) + "ISOLATION LEVEL %s" % level + ) cursor.execute("COMMIT") cursor.close() def get_isolation_level(self, connection): cursor = connection.cursor() - cursor.execute('show transaction isolation level') + cursor.execute("show transaction isolation level") val = cursor.fetchone()[0] cursor.close() return val.upper() @@ -2290,8 +2465,9 @@ class PGDialect(default.DefaultDialect): def do_prepare_twophase(self, connection, xid): connection.execute("PREPARE TRANSACTION '%s'" % xid) - def do_rollback_twophase(self, connection, xid, - is_prepared=True, recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if is_prepared: if recover: # FIXME: ugly hack to get out of transaction @@ -2305,8 +2481,9 @@ class PGDialect(default.DefaultDialect): else: self.do_rollback(connection.connection) - def do_commit_twophase(self, connection, xid, - is_prepared=True, recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if is_prepared: if recover: connection.execute("ROLLBACK") @@ -2318,22 +2495,27 @@ class PGDialect(default.DefaultDialect): def do_recover_twophase(self, connection): resultset = connection.execute( - sql.text("SELECT gid FROM pg_prepared_xacts")) + sql.text("SELECT gid FROM pg_prepared_xacts") + ) return [row[0] for row in resultset] def _get_default_schema_name(self, connection): return connection.scalar("select current_schema()") def has_schema(self, connection, schema): - query = ("select nspname from pg_namespace " - "where lower(nspname)=:schema") + query = ( + "select nspname from pg_namespace " "where lower(nspname)=:schema" + ) cursor = connection.execute( sql.text( query, bindparams=[ sql.bindparam( - 'schema', util.text_type(schema.lower()), - type_=sqltypes.Unicode)] + "schema", + util.text_type(schema.lower()), + type_=sqltypes.Unicode, + ) + ], ) ) @@ -2349,8 +2531,12 @@ class PGDialect(default.DefaultDialect): "pg_catalog.pg_table_is_visible(c.oid) " "and relname=:name", bindparams=[ - sql.bindparam('name', util.text_type(table_name), - type_=sqltypes.Unicode)] + sql.bindparam( + "name", + util.text_type(table_name), + type_=sqltypes.Unicode, + ) + ], ) ) else: @@ -2360,12 +2546,17 @@ class PGDialect(default.DefaultDialect): "n.oid=c.relnamespace where n.nspname=:schema and " "relname=:name", bindparams=[ - sql.bindparam('name', - util.text_type(table_name), - type_=sqltypes.Unicode), - sql.bindparam('schema', - util.text_type(schema), - type_=sqltypes.Unicode)] + sql.bindparam( + "name", + util.text_type(table_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ], ) ) return bool(cursor.first()) @@ -2379,9 +2570,12 @@ class PGDialect(default.DefaultDialect): "n.nspname=current_schema() " "and relname=:name", bindparams=[ - sql.bindparam('name', util.text_type(sequence_name), - type_=sqltypes.Unicode) - ] + sql.bindparam( + "name", + util.text_type(sequence_name), + type_=sqltypes.Unicode, + ) + ], ) ) else: @@ -2391,12 +2585,17 @@ class PGDialect(default.DefaultDialect): "n.oid=c.relnamespace where relkind='S' and " "n.nspname=:schema and relname=:name", bindparams=[ - sql.bindparam('name', 
util.text_type(sequence_name), - type_=sqltypes.Unicode), - sql.bindparam('schema', - util.text_type(schema), - type_=sqltypes.Unicode) - ] + sql.bindparam( + "name", + util.text_type(sequence_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ], ) ) @@ -2423,13 +2622,15 @@ class PGDialect(default.DefaultDialect): """ query = sql.text(query) query = query.bindparams( - sql.bindparam('typname', - util.text_type(type_name), type_=sqltypes.Unicode), + sql.bindparam( + "typname", util.text_type(type_name), type_=sqltypes.Unicode + ) ) if schema is not None: query = query.bindparams( - sql.bindparam('nspname', - util.text_type(schema), type_=sqltypes.Unicode), + sql.bindparam( + "nspname", util.text_type(schema), type_=sqltypes.Unicode + ) ) cursor = connection.execute(query) return bool(cursor.scalar()) @@ -2437,12 +2638,14 @@ class PGDialect(default.DefaultDialect): def _get_server_version_info(self, connection): v = connection.execute("select version()").scalar() m = re.match( - r'.*(?:PostgreSQL|EnterpriseDB) ' - r'(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?', - v) + r".*(?:PostgreSQL|EnterpriseDB) " + r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?", + v, + ) if not m: raise AssertionError( - "Could not determine version from string '%s'" % v) + "Could not determine version from string '%s'" % v + ) return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) @reflection.cache @@ -2459,14 +2662,17 @@ class PGDialect(default.DefaultDialect): schema_where_clause = "n.nspname = :schema" else: schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - query = """ + query = ( + """ SELECT c.oid FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE (%s) AND c.relname = :table_name AND c.relkind in ('r', 'v', 'm', 'f', 'p') - """ % schema_where_clause + """ + % schema_where_clause + ) # Since we're binding to unicode, table_name and schema_name must be # unicode. 
table_name = util.text_type(table_name) @@ -2475,7 +2681,7 @@ class PGDialect(default.DefaultDialect): s = sql.text(query).bindparams(table_name=sqltypes.Unicode) s = s.columns(oid=sqltypes.Integer) if schema: - s = s.bindparams(sql.bindparam('schema', type_=sqltypes.Unicode)) + s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) c = connection.execute(s, table_name=table_name, schema=schema) table_oid = c.scalar() if table_oid is None: @@ -2485,75 +2691,88 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_schema_names(self, connection, **kw): result = connection.execute( - sql.text("SELECT nspname FROM pg_namespace " - "WHERE nspname NOT LIKE 'pg_%' " - "ORDER BY nspname" - ).columns(nspname=sqltypes.Unicode)) + sql.text( + "SELECT nspname FROM pg_namespace " + "WHERE nspname NOT LIKE 'pg_%' " + "ORDER BY nspname" + ).columns(nspname=sqltypes.Unicode) + ) return [name for name, in result] @reflection.cache def get_table_names(self, connection, schema=None, **kw): result = connection.execute( - sql.text("SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" - ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name) + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + ) return [name for name, in result] @reflection.cache def _get_foreign_table_names(self, connection, schema=None, **kw): result = connection.execute( - sql.text("SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind = 'f'" - ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name) + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind = 'f'" + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + ) return [name for name, in result] @reflection.cache def get_view_names( - self, connection, schema=None, - include=('plain', 'materialized'), **kw): + self, connection, schema=None, include=("plain", "materialized"), **kw + ): - include_kind = {'plain': 'v', 'materialized': 'm'} + include_kind = {"plain": "v", "materialized": "m"} try: kinds = [include_kind[i] for i in util.to_list(include)] except KeyError: raise ValueError( "include %r unknown, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" % (include,)) + "one or both of 'plain' and 'materialized'" % (include,) + ) if not kinds: raise ValueError( "empty include, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'") + "one or both of 'plain' and 'materialized'" + ) result = connection.execute( - sql.text("SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind IN (%s)" % - (", ".join("'%s'" % elem for elem in kinds)) - ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name) + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind IN (%s)" + % (", ".join("'%s'" % elem 
for elem in kinds)) + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + ) return [name for name, in result] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): view_def = connection.scalar( - sql.text("SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relname = :view_name " - "AND c.relkind IN ('v', 'm')" - ).columns(view_def=sqltypes.Unicode), + sql.text( + "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relname = :view_name " + "AND c.relkind IN ('v', 'm')" + ).columns(view_def=sqltypes.Unicode), schema=schema if schema is not None else self.default_schema_name, - view_name=view_name) + view_name=view_name, + ) return view_def @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -2571,13 +2790,11 @@ class PGDialect(default.DefaultDialect): AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum """ - s = sql.text(SQL_COLS, - bindparams=[ - sql.bindparam('table_oid', type_=sqltypes.Integer)], - typemap={ - 'attname': sqltypes.Unicode, - 'default': sqltypes.Unicode} - ) + s = sql.text( + SQL_COLS, + bindparams=[sql.bindparam("table_oid", type_=sqltypes.Integer)], + typemap={"attname": sqltypes.Unicode, "default": sqltypes.Unicode}, + ) c = connection.execute(s, table_oid=table_oid) rows = c.fetchall() @@ -2588,34 +2805,58 @@ class PGDialect(default.DefaultDialect): # dictionary with (name, ) if default search path or (schema, name) # as keys enums = dict( - ((rec['name'], ), rec) - if rec['visible'] else ((rec['schema'], rec['name']), rec) - for rec in self._load_enums(connection, schema='*') + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + for rec in self._load_enums(connection, schema="*") ) # format columns columns = [] - for name, format_type, default_, notnull, attnum, table_oid, \ - comment in rows: + for ( + name, + format_type, + default_, + notnull, + attnum, + table_oid, + comment, + ) in rows: column_info = self._get_column_info( - name, format_type, default_, notnull, domains, enums, - schema, comment) + name, + format_type, + default_, + notnull, + domains, + enums, + schema, + comment, + ) columns.append(column_info) return columns - def _get_column_info(self, name, format_type, default, - notnull, domains, enums, schema, comment): + def _get_column_info( + self, + name, + format_type, + default, + notnull, + domains, + enums, + schema, + comment, + ): def _handle_array_type(attype): return ( # strip '[]' from integer[], etc. - re.sub(r'\[\]$', '', attype), - attype.endswith('[]'), + re.sub(r"\[\]$", "", attype), + attype.endswith("[]"), ) # strip (*) from character varying(5), timestamp(5) # with time zone, geometry(POLYGON), etc. - attype = re.sub(r'\(.*\)', '', format_type) + attype = re.sub(r"\(.*\)", "", format_type) # strip '[]' from integer[], etc. 
and check if an array attype, is_array = _handle_array_type(attype) @@ -2625,50 +2866,52 @@ class PGDialect(default.DefaultDialect): nullable = not notnull - charlen = re.search(r'\(([\d,]+)\)', format_type) + charlen = re.search(r"\(([\d,]+)\)", format_type) if charlen: charlen = charlen.group(1) - args = re.search(r'\((.*)\)', format_type) + args = re.search(r"\((.*)\)", format_type) if args and args.group(1): - args = tuple(re.split(r'\s*,\s*', args.group(1))) + args = tuple(re.split(r"\s*,\s*", args.group(1))) else: args = () kwargs = {} - if attype == 'numeric': + if attype == "numeric": if charlen: - prec, scale = charlen.split(',') + prec, scale = charlen.split(",") args = (int(prec), int(scale)) else: args = () - elif attype == 'double precision': - args = (53, ) - elif attype == 'integer': + elif attype == "double precision": + args = (53,) + elif attype == "integer": args = () - elif attype in ('timestamp with time zone', - 'time with time zone'): - kwargs['timezone'] = True + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True if charlen: - kwargs['precision'] = int(charlen) + kwargs["precision"] = int(charlen) args = () - elif attype in ('timestamp without time zone', - 'time without time zone', 'time'): - kwargs['timezone'] = False + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False if charlen: - kwargs['precision'] = int(charlen) + kwargs["precision"] = int(charlen) args = () - elif attype == 'bit varying': - kwargs['varying'] = True + elif attype == "bit varying": + kwargs["varying"] = True if charlen: args = (int(charlen),) else: args = () - elif attype.startswith('interval'): - field_match = re.match(r'interval (.+)', attype, re.I) + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) if charlen: - kwargs['precision'] = int(charlen) + kwargs["precision"] = int(charlen) if field_match: - kwargs['fields'] = field_match.group(1) + kwargs["fields"] = field_match.group(1) attype = "interval" args = () elif charlen: @@ -2682,23 +2925,23 @@ class PGDialect(default.DefaultDialect): elif enum_or_domain_key in enums: enum = enums[enum_or_domain_key] coltype = ENUM - kwargs['name'] = enum['name'] - if not enum['visible']: - kwargs['schema'] = enum['schema'] - args = tuple(enum['labels']) + kwargs["name"] = enum["name"] + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) break elif enum_or_domain_key in domains: domain = domains[enum_or_domain_key] - attype = domain['attype'] + attype = domain["attype"] attype, is_array = _handle_array_type(attype) # strip quotes from case sensitive enum or domain names enum_or_domain_key = tuple(util.quoted_token_parser(attype)) # A table can't override whether the domain is nullable. - nullable = domain['nullable'] - if domain['default'] and not default: + nullable = domain["nullable"] + if domain["default"] and not default: # It can, however, override the default # value, but can't set it to null. 
- default = domain['default'] + default = domain["default"] continue else: coltype = None @@ -2707,10 +2950,11 @@ class PGDialect(default.DefaultDialect): if coltype: coltype = coltype(*args, **kwargs) if is_array: - coltype = self.ischema_names['_array'](coltype) + coltype = self.ischema_names["_array"](coltype) else: - util.warn("Did not recognize type '%s' of column '%s'" % - (attype, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (attype, name) + ) coltype = sqltypes.NULLTYPE # adjust the default value autoincrement = False @@ -2721,23 +2965,33 @@ class PGDialect(default.DefaultDialect): autoincrement = True # the default is related to a Sequence sch = schema - if '.' not in match.group(2) and sch is not None: + if "." not in match.group(2) and sch is not None: # unconditionally quote the schema name. this could # later be enhanced to obey quoting rules / # "quote schema" - default = match.group(1) + \ - ('"%s"' % sch) + '.' + \ - match.group(2) + match.group(3) + default = ( + match.group(1) + + ('"%s"' % sch) + + "." + + match.group(2) + + match.group(3) + ) - column_info = dict(name=name, type=coltype, nullable=nullable, - default=default, autoincrement=autoincrement, - comment=comment) + column_info = dict( + name=name, + type=coltype, + nullable=nullable, + default=default, + autoincrement=autoincrement, + comment=comment, + ) return column_info @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) if self.server_version_info < (8, 4): PK_SQL = """ @@ -2750,7 +3004,9 @@ class PGDialect(default.DefaultDialect): WHERE t.oid = :table_oid and ix.indisprimary = 't' ORDER BY a.attnum - """ % self._pg_index_any("a.attnum", "ix.indkey") + """ % self._pg_index_any( + "a.attnum", "ix.indkey" + ) else: # unnest() and generate_subscripts() both introduced in @@ -2766,7 +3022,7 @@ class PGDialect(default.DefaultDialect): WHERE a.attrelid = :table_oid ORDER BY k.ord """ - t = sql.text(PK_SQL, typemap={'attname': sqltypes.Unicode}) + t = sql.text(PK_SQL, typemap={"attname": sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) cols = [r[0] for r in c.fetchall()] @@ -2776,18 +3032,25 @@ class PGDialect(default.DefaultDialect): WHERE r.conrelid = :table_oid AND r.contype = 'p' ORDER BY 1 """ - t = sql.text(PK_CONS_SQL, typemap={'conname': sqltypes.Unicode}) + t = sql.text(PK_CONS_SQL, typemap={"conname": sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) name = c.scalar() - return {'constrained_columns': cols, 'name': name} + return {"constrained_columns": cols, "name": name} @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, - postgresql_ignore_search_path=False, **kw): + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + postgresql_ignore_search_path=False, + **kw + ): preparer = self.identifier_preparer - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) FK_SQL = """ SELECT r.conname, @@ -2805,34 +3068,35 @@ class PGDialect(default.DefaultDialect): """ # http://www.postgresql.org/docs/9.0/static/sql-createtable.html FK_REGEX = re.compile( - r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)' - 
r'[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?' - r'[\s]?(ON UPDATE ' - r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' - r'[\s]?(ON DELETE ' - r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' - r'[\s]?(DEFERRABLE|NOT DEFERRABLE)?' - r'[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?' + r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" + r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" + r"[\s]?(ON UPDATE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(ON DELETE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" + r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" ) - t = sql.text(FK_SQL, typemap={ - 'conname': sqltypes.Unicode, - 'condef': sqltypes.Unicode}) + t = sql.text( + FK_SQL, + typemap={"conname": sqltypes.Unicode, "condef": sqltypes.Unicode}, + ) c = connection.execute(t, table=table_oid) fkeys = [] for conname, condef, conschema in c.fetchall(): m = re.search(FK_REGEX, condef).groups() - constrained_columns, referred_schema, \ - referred_table, referred_columns, \ - _, match, _, onupdate, _, ondelete, \ - deferrable, _, initially = m + constrained_columns, referred_schema, referred_table, referred_columns, _, match, _, onupdate, _, ondelete, deferrable, _, initially = ( + m + ) if deferrable is not None: - deferrable = True if deferrable == 'DEFERRABLE' else False - constrained_columns = [preparer._unquote_identifier(x) - for x in re.split( - r'\s*,\s*', constrained_columns)] + deferrable = True if deferrable == "DEFERRABLE" else False + constrained_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s*", constrained_columns) + ] if postgresql_ignore_search_path: # when ignoring search path, we use the actual schema @@ -2845,30 +3109,30 @@ class PGDialect(default.DefaultDialect): # referred_schema is the schema that we regexp'ed from # pg_get_constraintdef(). If the schema is in the search # path, pg_get_constraintdef() will give us None. - referred_schema = \ - preparer._unquote_identifier(referred_schema) + referred_schema = preparer._unquote_identifier(referred_schema) elif schema is not None and schema == conschema: # If the actual schema matches the schema of the table # we're reflecting, then we will use that. referred_schema = schema referred_table = preparer._unquote_identifier(referred_table) - referred_columns = [preparer._unquote_identifier(x) - for x in - re.split(r'\s*,\s', referred_columns)] + referred_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s", referred_columns) + ] fkey_d = { - 'name': conname, - 'constrained_columns': constrained_columns, - 'referred_schema': referred_schema, - 'referred_table': referred_table, - 'referred_columns': referred_columns, - 'options': { - 'onupdate': onupdate, - 'ondelete': ondelete, - 'deferrable': deferrable, - 'initially': initially, - 'match': match - } + "name": conname, + "constrained_columns": constrained_columns, + "referred_schema": referred_schema, + "referred_table": referred_table, + "referred_columns": referred_columns, + "options": { + "onupdate": onupdate, + "ondelete": ondelete, + "deferrable": deferrable, + "initially": initially, + "match": match, + }, } fkeys.append(fkey_d) return fkeys @@ -2882,16 +3146,16 @@ class PGDialect(default.DefaultDialect): # for now. 
# regards, tom lane" return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) - for ind in range(0, 10) + "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) ) else: return "%s = ANY(%s)" % (col, compare_to) @reflection.cache def get_indexes(self, connection, table_name, schema, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) # cast indkey as varchar since it's an int2vector, # returned as a list by some drivers such as pypostgresql @@ -2925,9 +3189,10 @@ class PGDialect(default.DefaultDialect): # cast does not work in PG 8.2.4, does work in 8.3.0. # nothing in PG changelogs regarding this. "::varchar" if self.server_version_info >= (8, 3) else "", - "i.reloptions" if self.server_version_info >= (8, 2) + "i.reloptions" + if self.server_version_info >= (8, 2) else "NULL", - self._pg_index_any("a.attnum", "ix.indkey") + self._pg_index_any("a.attnum", "ix.indkey"), ) else: IDX_SQL = """ @@ -2960,76 +3225,93 @@ class PGDialect(default.DefaultDialect): i.relname """ - t = sql.text(IDX_SQL, typemap={ - 'relname': sqltypes.Unicode, - 'attname': sqltypes.Unicode}) + t = sql.text( + IDX_SQL, + typemap={"relname": sqltypes.Unicode, "attname": sqltypes.Unicode}, + ) c = connection.execute(t, table_oid=table_oid) indexes = defaultdict(lambda: defaultdict(dict)) sv_idx_name = None for row in c.fetchall(): - (idx_name, unique, expr, prd, col, - col_num, conrelid, idx_key, options, amname) = row + ( + idx_name, + unique, + expr, + prd, + col, + col_num, + conrelid, + idx_key, + options, + amname, + ) = row if expr: if idx_name != sv_idx_name: util.warn( "Skipped unsupported reflection of " - "expression-based index %s" - % idx_name) + "expression-based index %s" % idx_name + ) sv_idx_name = idx_name continue if prd and not idx_name == sv_idx_name: util.warn( "Predicate of partial index %s ignored during reflection" - % idx_name) + % idx_name + ) sv_idx_name = idx_name has_idx = idx_name in indexes index = indexes[idx_name] if col is not None: - index['cols'][col_num] = col + index["cols"][col_num] = col if not has_idx: - index['key'] = [int(k.strip()) for k in idx_key.split()] - index['unique'] = unique + index["key"] = [int(k.strip()) for k in idx_key.split()] + index["unique"] = unique if conrelid is not None: - index['duplicates_constraint'] = idx_name + index["duplicates_constraint"] = idx_name if options: - index['options'] = dict( - [option.split("=") for option in options]) + index["options"] = dict( + [option.split("=") for option in options] + ) # it *might* be nice to include that this is 'btree' in the # reflection info. But we don't want an Index object # to have a ``postgresql_using`` in it that is just the # default, so for the moment leaving this out. 
- if amname and amname != 'btree': - index['amname'] = amname + if amname and amname != "btree": + index["amname"] = amname result = [] for name, idx in indexes.items(): entry = { - 'name': name, - 'unique': idx['unique'], - 'column_names': [idx['cols'][i] for i in idx['key']] + "name": name, + "unique": idx["unique"], + "column_names": [idx["cols"][i] for i in idx["key"]], } - if 'duplicates_constraint' in idx: - entry['duplicates_constraint'] = idx['duplicates_constraint'] - if 'options' in idx: - entry.setdefault( - 'dialect_options', {})["postgresql_with"] = idx['options'] - if 'amname' in idx: - entry.setdefault( - 'dialect_options', {})["postgresql_using"] = idx['amname'] + if "duplicates_constraint" in idx: + entry["duplicates_constraint"] = idx["duplicates_constraint"] + if "options" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_with" + ] = idx["options"] + if "amname" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_using" + ] = idx["amname"] result.append(entry) return result @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) UNIQUE_SQL = """ SELECT @@ -3047,7 +3329,7 @@ class PGDialect(default.DefaultDialect): cons.contype = 'u' """ - t = sql.text(UNIQUE_SQL, typemap={'col_name': sqltypes.Unicode}) + t = sql.text(UNIQUE_SQL, typemap={"col_name": sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) uniques = defaultdict(lambda: defaultdict(dict)) @@ -3057,15 +3339,15 @@ class PGDialect(default.DefaultDialect): uc["cols"][row.col_num] = row.col_name return [ - {'name': name, - 'column_names': [uc["cols"][i] for i in uc["key"]]} + {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]} for name, uc in uniques.items() ] @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) COMMENT_SQL = """ SELECT @@ -3081,10 +3363,10 @@ class PGDialect(default.DefaultDialect): return {"text": c.scalar()} @reflection.cache - def get_check_constraints( - self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + def get_check_constraints(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) CHECK_SQL = """ SELECT @@ -3100,10 +3382,8 @@ class PGDialect(default.DefaultDialect): c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid) return [ - {'name': name, - 'sqltext': src[1:-1]} - for name, src in c.fetchall() - ] + {"name": name, "sqltext": src[1:-1]} for name, src in c.fetchall() + ] def _load_enums(self, connection, schema=None): schema = schema or self.default_schema_name @@ -3124,17 +3404,18 @@ class PGDialect(default.DefaultDialect): WHERE t.typtype = 'e' """ - if schema != '*': + if schema != "*": SQL_ENUMS += "AND n.nspname = :schema " # e.oid gives us label order within an enum SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' - s = sql.text(SQL_ENUMS, typemap={ - 
'attname': sqltypes.Unicode, - 'label': sqltypes.Unicode}) + s = sql.text( + SQL_ENUMS, + typemap={"attname": sqltypes.Unicode, "label": sqltypes.Unicode}, + ) - if schema != '*': + if schema != "*": s = s.bindparams(schema=schema) c = connection.execute(s) @@ -3142,15 +3423,15 @@ class PGDialect(default.DefaultDialect): enums = [] enum_by_name = {} for enum in c.fetchall(): - key = (enum['schema'], enum['name']) + key = (enum["schema"], enum["name"]) if key in enum_by_name: - enum_by_name[key]['labels'].append(enum['label']) + enum_by_name[key]["labels"].append(enum["label"]) else: enum_by_name[key] = enum_rec = { - 'name': enum['name'], - 'schema': enum['schema'], - 'visible': enum['visible'], - 'labels': [enum['label']], + "name": enum["name"], + "schema": enum["schema"], + "visible": enum["visible"], + "labels": [enum["label"]], } enums.append(enum_rec) return enums @@ -3169,26 +3450,26 @@ class PGDialect(default.DefaultDialect): WHERE t.typtype = 'd' """ - s = sql.text(SQL_DOMAINS, typemap={'attname': sqltypes.Unicode}) + s = sql.text(SQL_DOMAINS, typemap={"attname": sqltypes.Unicode}) c = connection.execute(s) domains = {} for domain in c.fetchall(): # strip (30) from character varying(30) - attype = re.search(r'([^\(]+)', domain['attype']).group(1) + attype = re.search(r"([^\(]+)", domain["attype"]).group(1) # 'visible' just means whether or not the domain is in a # schema that's on the search path -- or not overridden by # a schema with higher precedence. If it's not visible, # it will be prefixed with the schema-name when it's used. - if domain['visible']: - key = (domain['name'], ) + if domain["visible"]: + key = (domain["name"],) else: - key = (domain['schema'], domain['name']) + key = (domain["schema"], domain["name"]) domains[key] = { - 'attype': attype, - 'nullable': domain['nullable'], - 'default': domain['default'] + "attype": attype, + "nullable": domain["nullable"], + "default": domain["default"], } return domains diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 555a9006ce..825f132387 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -14,7 +14,7 @@ from ...sql.base import _generative from ... import util from . import ext -__all__ = ('Insert', 'insert') +__all__ = ("Insert", "insert") class Insert(StandardInsert): @@ -40,13 +40,17 @@ class Insert(StandardInsert): to use :attr:`.Insert.excluded` """ - return alias(self.table, name='excluded').columns + return alias(self.table, name="excluded").columns @_generative def on_conflict_do_update( - self, - constraint=None, index_elements=None, - index_where=None, set_=None, where=None): + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): """ Specifies a DO UPDATE SET action for ON CONFLICT clause. @@ -96,13 +100,14 @@ class Insert(StandardInsert): """ self._post_values_clause = OnConflictDoUpdate( - constraint, index_elements, index_where, set_, where) + constraint, index_elements, index_where, set_, where + ) return self @_generative def on_conflict_do_nothing( - self, - constraint=None, index_elements=None, index_where=None): + self, constraint=None, index_elements=None, index_where=None + ): """ Specifies a DO NOTHING action for ON CONFLICT clause. 
@@ -130,30 +135,29 @@ class Insert(StandardInsert): """ self._post_values_clause = OnConflictDoNothing( - constraint, index_elements, index_where) + constraint, index_elements, index_where + ) return self -insert = public_factory(Insert, '.dialects.postgresql.insert') + +insert = public_factory(Insert, ".dialects.postgresql.insert") class OnConflictClause(ClauseElement): - def __init__( - self, - constraint=None, - index_elements=None, - index_where=None): + def __init__(self, constraint=None, index_elements=None, index_where=None): if constraint is not None: - if not isinstance(constraint, util.string_types) and \ - isinstance(constraint, ( - schema.Index, schema.Constraint, - ext.ExcludeConstraint)): - constraint = getattr(constraint, 'name') or constraint + if not isinstance(constraint, util.string_types) and isinstance( + constraint, + (schema.Index, schema.Constraint, ext.ExcludeConstraint), + ): + constraint = getattr(constraint, "name") or constraint if constraint is not None: if index_elements is not None: raise ValueError( - "'constraint' and 'index_elements' are mutually exclusive") + "'constraint' and 'index_elements' are mutually exclusive" + ) if isinstance(constraint, util.string_types): self.constraint_target = constraint @@ -161,54 +165,61 @@ class OnConflictClause(ClauseElement): self.inferred_target_whereclause = None elif isinstance(constraint, schema.Index): index_elements = constraint.expressions - index_where = \ - constraint.dialect_options['postgresql'].get("where") + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) elif isinstance(constraint, ext.ExcludeConstraint): index_elements = constraint.columns index_where = constraint.where else: index_elements = constraint.columns - index_where = \ - constraint.dialect_options['postgresql'].get("where") + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) if index_elements is not None: self.constraint_target = None self.inferred_target_elements = index_elements self.inferred_target_whereclause = index_where elif constraint is None: - self.constraint_target = self.inferred_target_elements = \ - self.inferred_target_whereclause = None + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None class OnConflictDoNothing(OnConflictClause): - __visit_name__ = 'on_conflict_do_nothing' + __visit_name__ = "on_conflict_do_nothing" class OnConflictDoUpdate(OnConflictClause): - __visit_name__ = 'on_conflict_do_update' + __visit_name__ = "on_conflict_do_update" def __init__( - self, - constraint=None, - index_elements=None, - index_where=None, - set_=None, - where=None): + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): super(OnConflictDoUpdate, self).__init__( constraint=constraint, index_elements=index_elements, - index_where=index_where) + index_where=index_where, + ) - if self.inferred_target_elements is None and \ - self.constraint_target is None: + if ( + self.inferred_target_elements is None + and self.constraint_target is None + ): raise ValueError( "Either constraint or index_elements, " - "but not both, must be specified unless DO NOTHING") + "but not both, must be specified unless DO NOTHING" + ) - if (not isinstance(set_, dict) or not set_): + if not isinstance(set_, dict) or not set_: raise ValueError("set parameter must be a non-empty dictionary") self.update_values_to_set = [ - (key, value) - for key, value in set_.items() + (key, value) for key, value in set_.items() ] 
self.update_whereclause = where diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index a588eafddd..da0c6250c4 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -47,7 +47,7 @@ class aggregate_order_by(expression.ColumnElement): """ - __visit_name__ = 'aggregate_order_by' + __visit_name__ = "aggregate_order_by" def __init__(self, target, *order_by): self.target = elements._literal_as_binds(target) @@ -59,8 +59,8 @@ class aggregate_order_by(expression.ColumnElement): self.order_by = elements._literal_as_binds(order_by[0]) else: self.order_by = elements.ClauseList( - *order_by, - _literal_as_text=elements._literal_as_binds) + *order_by, _literal_as_text=elements._literal_as_binds + ) def self_group(self, against=None): return self @@ -87,7 +87,7 @@ class ExcludeConstraint(ColumnCollectionConstraint): static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE """ - __visit_name__ = 'exclude_constraint' + __visit_name__ = "exclude_constraint" where = None @@ -173,8 +173,7 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE expressions, operators = zip(*elements) for (expr, column, strname, add_element), operator in zip( - self._extract_col_expression_collection(expressions), - operators + self._extract_col_expression_collection(expressions), operators ): if add_element is not None: columns.append(add_element) @@ -187,32 +186,31 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE expr = expression._literal_as_text(expr) - render_exprs.append( - (expr, name, operator) - ) + render_exprs.append((expr, name, operator)) self._render_exprs = render_exprs ColumnCollectionConstraint.__init__( self, *columns, - name=kw.get('name'), - deferrable=kw.get('deferrable'), - initially=kw.get('initially') + name=kw.get("name"), + deferrable=kw.get("deferrable"), + initially=kw.get("initially") ) - self.using = kw.get('using', 'gist') - where = kw.get('where') + self.using = kw.get("using", "gist") + where = kw.get("where") if where is not None: self.where = expression._literal_as_text(where) def copy(self, **kw): - elements = [(col, self.operators[col]) - for col in self.columns.keys()] - c = self.__class__(*elements, - name=self.name, - deferrable=self.deferrable, - initially=self.initially, - where=self.where, - using=self.using) + elements = [(col, self.operators[col]) for col in self.columns.keys()] + c = self.__class__( + *elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + where=self.where, + using=self.using + ) c.dispatch._update(self.dispatch) return c @@ -226,5 +224,5 @@ def array_agg(*arg, **kw): .. versionadded:: 1.1 """ - kw['_default_array_type'] = ARRAY + kw["_default_array_type"] = ARRAY return functions.func.array_agg(*arg, **kw) diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index b6c9e7124c..e4bac692a3 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -14,38 +14,50 @@ from ...sql import functions as sqlfunc from ...sql import operators from ... 
import util -__all__ = ('HSTORE', 'hstore') +__all__ = ("HSTORE", "hstore") idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] GETITEM = operators.custom_op( - "->", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "->", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_KEY = operators.custom_op( - "?", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ALL = operators.custom_op( - "?&", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ANY = operators.custom_op( - "?|", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINS = operators.custom_op( - "@>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINED_BY = operators.custom_op( - "<@", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) @@ -122,7 +134,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): """ - __visit_name__ = 'HSTORE' + __visit_name__ = "HSTORE" hashable = False text_type = sqltypes.Text() @@ -139,7 +151,8 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): self.text_type = text_type class Comparator( - sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator): + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator + ): """Define comparison operations for :class:`.HSTORE`.""" def has_key(self, other): @@ -169,7 +182,8 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): keys of the argument jsonb expression. """ return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean) + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) def _setup_getitem(self, index): return GETITEM, index, self.type.text_type @@ -223,12 +237,15 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return _serialize_hstore(value).encode(encoding) else: return value + else: + def process(value): if isinstance(value, dict): return _serialize_hstore(value) else: return value + return process def result_processor(self, dialect, coltype): @@ -240,16 +257,19 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return _parse_hstore(value.decode(encoding)) else: return value + else: + def process(value): if value is not None: return _parse_hstore(value) else: return value + return process -ischema_names['hstore'] = HSTORE +ischema_names["hstore"] = HSTORE class hstore(sqlfunc.GenericFunction): @@ -279,43 +299,44 @@ class hstore(sqlfunc.GenericFunction): :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype. 
""" + type = HSTORE - name = 'hstore' + name = "hstore" class _HStoreDefinedFunction(sqlfunc.GenericFunction): type = sqltypes.Boolean - name = 'defined' + name = "defined" class _HStoreDeleteFunction(sqlfunc.GenericFunction): type = HSTORE - name = 'delete' + name = "delete" class _HStoreSliceFunction(sqlfunc.GenericFunction): type = HSTORE - name = 'slice' + name = "slice" class _HStoreKeysFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'akeys' + name = "akeys" class _HStoreValsFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'avals' + name = "avals" class _HStoreArrayFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'hstore_to_array' + name = "hstore_to_array" class _HStoreMatrixFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'hstore_to_matrix' + name = "hstore_to_matrix" # @@ -326,7 +347,8 @@ class _HStoreMatrixFunction(sqlfunc.GenericFunction): # My best guess at the parsing rules of hstore literals, since no formal # grammar is given. This is mostly reverse engineered from PG's input parser # behavior. -HSTORE_PAIR_RE = re.compile(r""" +HSTORE_PAIR_RE = re.compile( + r""" ( "(?P (\\ . | [^"])* )" # Quoted key ) @@ -335,11 +357,16 @@ HSTORE_PAIR_RE = re.compile(r""" (?P NULL ) # NULL value | "(?P (\\ . | [^"])* )" # Quoted value ) -""", re.VERBOSE) +""", + re.VERBOSE, +) -HSTORE_DELIMITER_RE = re.compile(r""" +HSTORE_DELIMITER_RE = re.compile( + r""" [ ]* , [ ]* -""", re.VERBOSE) +""", + re.VERBOSE, +) def _parse_error(hstore_str, pos): @@ -348,16 +375,19 @@ def _parse_error(hstore_str, pos): ctx = 20 hslen = len(hstore_str) - parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)] - residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)] + parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)] + residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)] if len(parsed_tail) > ctx: - parsed_tail = '[...]' + parsed_tail[1:] + parsed_tail = "[...]" + parsed_tail[1:] if len(residual) > ctx: - residual = residual[:-1] + '[...]' + residual = residual[:-1] + "[...]" return "After %r, could not parse residual at position %d: %r" % ( - parsed_tail, pos, residual) + parsed_tail, + pos, + residual, + ) def _parse_hstore(hstore_str): @@ -377,13 +407,15 @@ def _parse_hstore(hstore_str): pair_match = HSTORE_PAIR_RE.match(hstore_str) while pair_match is not None: - key = pair_match.group('key').replace(r'\"', '"').replace( - "\\\\", "\\") - if pair_match.group('value_null'): + key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\") + if pair_match.group("value_null"): value = None else: - value = pair_match.group('value').replace( - r'\"', '"').replace("\\\\", "\\") + value = ( + pair_match.group("value") + .replace(r"\"", '"') + .replace("\\\\", "\\") + ) result[key] = value pos += pair_match.end() @@ -405,16 +437,17 @@ def _serialize_hstore(val): both be strings (except None for values). """ + def esc(s, position): - if position == 'value' and s is None: - return 'NULL' + if position == "value" and s is None: + return "NULL" elif isinstance(s, util.string_types): - return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"') + return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"") else: - raise ValueError("%r in %s position is not a string." % - (s, position)) - - return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value')) - for k, v in val.items()) - + raise ValueError( + "%r in %s position is not a string." 
% (s, position) + ) + return ", ".join( + "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items() + ) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index e9256daf31..f9421de37c 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -12,44 +12,58 @@ from ...sql import operators from ...sql import elements from ... import util -__all__ = ('JSON', 'JSONB') +__all__ = ("JSON", "JSONB") idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] ASTEXT = operators.custom_op( - "->>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "->>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) JSONPATH_ASTEXT = operators.custom_op( - "#>>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "#>>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_KEY = operators.custom_op( - "?", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ALL = operators.custom_op( - "?&", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ANY = operators.custom_op( - "?|", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINS = operators.custom_op( - "@>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINED_BY = operators.custom_op( - "<@", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) @@ -59,7 +73,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType): def process(value): assert isinstance(value, util.collections_abc.Sequence) - tokens = [util.text_type(elem)for elem in value] + tokens = [util.text_type(elem) for elem in value] value = "{%s}" % (", ".join(tokens)) if super_proc: value = super_proc(value) @@ -72,7 +86,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType): def process(value): assert isinstance(value, util.collections_abc.Sequence) - tokens = [util.text_type(elem)for elem in value] + tokens = [util.text_type(elem) for elem in value] value = "{%s}" % (", ".join(tokens)) if super_proc: value = super_proc(value) @@ -80,6 +94,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType): return process + colspecs[sqltypes.JSON.JSONPathType] = JSONPathType @@ -203,16 +218,19 @@ class JSON(sqltypes.JSON): if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): return self.expr.left.operate( JSONPATH_ASTEXT, - self.expr.right, result_type=self.type.astext_type) + self.expr.right, + result_type=self.type.astext_type, + ) else: return self.expr.left.operate( - ASTEXT, self.expr.right, result_type=self.type.astext_type) + ASTEXT, self.expr.right, result_type=self.type.astext_type + ) comparator_factory = Comparator colspecs[sqltypes.JSON] = JSON -ischema_names['json'] = JSON +ischema_names["json"] = JSON class JSONB(JSON): @@ -259,7 +277,7 @@ class JSONB(JSON): """ - __visit_name__ = 'JSONB' + __visit_name__ = "JSONB" class Comparator(JSON.Comparator): 
"""Define comparison operations for :class:`.JSON`.""" @@ -291,8 +309,10 @@ class JSONB(JSON): keys of the argument jsonb expression. """ return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean) + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) comparator_factory = Comparator -ischema_names['jsonb'] = JSONB + +ischema_names["jsonb"] = JSONB diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 80929b8086..fef09e0ebb 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -69,8 +69,15 @@ import decimal from ... import processors from ... import types as sqltypes from .base import ( - PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, - _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID) + PGDialect, + PGCompiler, + PGIdentifierPreparer, + PGExecutionContext, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) import re from sqlalchemy.dialects.postgresql.json import JSON from ...sql.elements import quoted_name @@ -86,13 +93,15 @@ class _PGNumeric(sqltypes.Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 @@ -101,7 +110,8 @@ class _PGNumeric(sqltypes.Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGNumericNoBind(_PGNumeric): @@ -110,7 +120,6 @@ class _PGNumericNoBind(_PGNumeric): class _PGJSON(JSON): - def result_processor(self, dialect, coltype): if dialect._dbapi_version > (1, 10, 1): return None # Has native JSON @@ -121,18 +130,22 @@ class _PGJSON(JSON): class _PGUUID(UUID): def bind_processor(self, dialect): if not self.as_uuid: + def process(value): if value is not None: value = _python_UUID(value) return value + return process def result_processor(self, dialect, coltype): if not self.as_uuid: + def process(value): if value is not None: value = str(value) return value + return process @@ -142,36 +155,41 @@ class PGExecutionContext_pg8000(PGExecutionContext): class PGCompiler_pg8000(PGCompiler): def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy postgresql dialect " - "now automatically escapes '%' in text() " - "expressions to '%%'.") - return text.replace('%', '%%') + if "%%" in text: + util.warn( + "The SQLAlchemy postgresql dialect " + "now automatically escapes '%' in text() " + "expressions to '%%'." 
+ ) + return text.replace("%", "%%") class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + return value.replace("%", "%%") class PGDialect_pg8000(PGDialect): - driver = 'pg8000' + driver = "pg8000" supports_unicode_statements = True supports_unicode_binds = True - default_paramstyle = 'format' + default_paramstyle = "format" supports_sane_multi_rowcount = True execution_ctx_cls = PGExecutionContext_pg8000 statement_compiler = PGCompiler_pg8000 preparer = PGIdentifierPreparer_pg8000 - description_encoding = 'use_encoding' + description_encoding = "use_encoding" colspecs = util.update_copy( PGDialect.colspecs, @@ -180,8 +198,8 @@ class PGDialect_pg8000(PGDialect): sqltypes.Float: _PGNumeric, JSON: _PGJSON, sqltypes.JSON: _PGJSON, - UUID: _PGUUID - } + UUID: _PGUUID, + }, ) def __init__(self, client_encoding=None, **kwargs): @@ -194,22 +212,26 @@ class PGDialect_pg8000(PGDialect): @util.memoized_property def _dbapi_version(self): - if self.dbapi and hasattr(self.dbapi, '__version__'): + if self.dbapi and hasattr(self.dbapi, "__version__"): return tuple( [ - int(x) for x in re.findall( - r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)]) + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) else: return (99, 99, 99) @classmethod def dbapi(cls): - return __import__('pg8000') + return __import__("pg8000") def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) opts.update(url.query) return ([], opts) @@ -217,32 +239,33 @@ class PGDialect_pg8000(PGDialect): return "connection is closed" in str(e) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy possibly being present - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit = True elif level in self._isolation_lookup: connection.autocommit = False cursor = connection.cursor() cursor.execute( "SET SESSION CHARACTERISTICS AS TRANSACTION " - "ISOLATION LEVEL %s" % level) + "ISOLATION LEVEL %s" % level + ) cursor.execute("COMMIT") cursor.close() else: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. 
" - "Valid isolation levels for %s are %s or AUTOCOMMIT" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s or AUTOCOMMIT" + % (level, self.name, ", ".join(self._isolation_lookup)) ) def set_client_encoding(self, connection, client_encoding): # adjust for ConnectionFairy possibly being present - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection cursor = connection.cursor() @@ -251,18 +274,20 @@ class PGDialect_pg8000(PGDialect): cursor.close() def do_begin_twophase(self, connection, xid): - connection.connection.tpc_begin((0, xid, '')) + connection.connection.tpc_begin((0, xid, "")) def do_prepare_twophase(self, connection, xid): connection.connection.tpc_prepare() def do_rollback_twophase( - self, connection, xid, is_prepared=True, recover=False): - connection.connection.tpc_rollback((0, xid, '')) + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_rollback((0, xid, "")) def do_commit_twophase( - self, connection, xid, is_prepared=True, recover=False): - connection.connection.tpc_commit((0, xid, '')) + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_commit((0, xid, "")) def do_recover_twophase(self, connection): return [row[1] for row in connection.connection.tpc_recover()] @@ -272,24 +297,32 @@ class PGDialect_pg8000(PGDialect): def on_connect(conn): conn.py_types[quoted_name] = conn.py_types[util.text_type] + fns.append(on_connect) if self.client_encoding is not None: + def on_connect(conn): self.set_client_encoding(conn, self.client_encoding) + fns.append(on_connect) if self.isolation_level is not None: + def on_connect(conn): self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) if len(fns) > 0: + def on_connect(conn): for fn in fns: fn(conn) + return on_connect else: return None + dialect = PGDialect_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index baa0e00d52..2c27c69193 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -353,10 +353,17 @@ from ... import processors from ...engine import result as _result from ...sql import expression from ... 
import types as sqltypes -from .base import PGDialect, PGCompiler, \ - PGIdentifierPreparer, PGExecutionContext, \ - ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\ - _INT_TYPES, UUID +from .base import ( + PGDialect, + PGCompiler, + PGIdentifierPreparer, + PGExecutionContext, + ENUM, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) from .hstore import HSTORE from .json import JSON, JSONB @@ -366,7 +373,7 @@ except ImportError: _python_UUID = None -logger = logging.getLogger('sqlalchemy.dialects.postgresql') +logger = logging.getLogger("sqlalchemy.dialects.postgresql") class _PGNumeric(sqltypes.Numeric): @@ -377,14 +384,15 @@ class _PGNumeric(sqltypes.Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 @@ -393,7 +401,8 @@ class _PGNumeric(sqltypes.Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGEnum(ENUM): @@ -421,7 +430,6 @@ class _PGHStore(HSTORE): class _PGJSON(JSON): - def result_processor(self, dialect, coltype): if dialect._has_native_json: return None @@ -430,7 +438,6 @@ class _PGJSON(JSON): class _PGJSONB(JSONB): - def result_processor(self, dialect, coltype): if dialect._has_native_jsonb: return None @@ -447,14 +454,17 @@ class _PGUUID(UUID): if value is not None: value = _python_UUID(value) return value + return process def result_processor(self, dialect, coltype): if not self.as_uuid and dialect.use_native_uuid: + def process(value): if value is not None: value = str(value) return value + return process @@ -465,8 +475,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext): def create_server_side_cursor(self): # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], - hex(_server_side_id())[2:]) + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) return self._dbapi_connection.cursor(ident) def get_result_proxy(self): @@ -497,13 +506,13 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): class PGDialect_psycopg2(PGDialect): - driver = 'psycopg2' + driver = "psycopg2" if util.py2k: supports_unicode_statements = False supports_server_side_cursors = True - default_paramstyle = 'pyformat' + default_paramstyle = "pyformat" # set to true based on psycopg2 version supports_sane_multi_rowcount = False execution_ctx_cls = PGExecutionContext_psycopg2 @@ -516,16 +525,16 @@ class PGDialect_psycopg2(PGDialect): native_jsonb=(2, 5, 4), sane_multi_rowcount=(2, 0, 9), array_oid=(2, 4, 3), - hstore_adapter=(2, 4) + hstore_adapter=(2, 4), ) _has_native_hstore = False _has_native_json = False _has_native_jsonb = False - engine_config_types = PGDialect.engine_config_types.union([ - ('use_native_unicode', util.asbool), - ]) + engine_config_types = PGDialect.engine_config_types.union( + [("use_native_unicode", util.asbool)] + ) colspecs = util.update_copy( PGDialect.colspecs, @@ -537,15 +546,20 @@ class PGDialect_psycopg2(PGDialect): JSON: _PGJSON, sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, - UUID: 
_PGUUID - } + UUID: _PGUUID, + }, ) - def __init__(self, server_side_cursors=False, use_native_unicode=True, - client_encoding=None, - use_native_hstore=True, use_native_uuid=True, - use_batch_mode=False, - **kwargs): + def __init__( + self, + server_side_cursors=False, + use_native_unicode=True, + client_encoding=None, + use_native_hstore=True, + use_native_uuid=True, + use_batch_mode=False, + **kwargs + ): PGDialect.__init__(self, **kwargs) self.server_side_cursors = server_side_cursors self.use_native_unicode = use_native_unicode @@ -554,65 +568,70 @@ class PGDialect_psycopg2(PGDialect): self.supports_unicode_binds = use_native_unicode self.client_encoding = client_encoding self.psycopg2_batch_mode = use_batch_mode - if self.dbapi and hasattr(self.dbapi, '__version__'): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', - self.dbapi.__version__) + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) if m: self.psycopg2_version = tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + int(x) for x in m.group(1, 2, 3) if x is not None + ) def initialize(self, connection): super(PGDialect_psycopg2, self).initialize(connection) - self._has_native_hstore = self.use_native_hstore and \ - self._hstore_oids(connection.connection) \ - is not None - self._has_native_json = \ - self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_json'] - self._has_native_jsonb = \ - self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_jsonb'] + self._has_native_hstore = ( + self.use_native_hstore + and self._hstore_oids(connection.connection) is not None + ) + self._has_native_json = ( + self.psycopg2_version >= self.FEATURE_VERSION_MAP["native_json"] + ) + self._has_native_jsonb = ( + self.psycopg2_version >= self.FEATURE_VERSION_MAP["native_jsonb"] + ) # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 - self.supports_sane_multi_rowcount = \ - self.psycopg2_version >= \ - self.FEATURE_VERSION_MAP['sane_multi_rowcount'] and \ - not self.psycopg2_batch_mode + self.supports_sane_multi_rowcount = ( + self.psycopg2_version + >= self.FEATURE_VERSION_MAP["sane_multi_rowcount"] + and not self.psycopg2_batch_mode + ) @classmethod def dbapi(cls): import psycopg2 + return psycopg2 @classmethod def _psycopg2_extensions(cls): from psycopg2 import extensions + return extensions @classmethod def _psycopg2_extras(cls): from psycopg2 import extras + return extras @util.memoized_property def _isolation_lookup(self): extensions = self._psycopg2_extensions() return { - 'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT, - 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED, - 'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, - 'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ, - 'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE + "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT, + "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED, + "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ, + "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, } def set_isolation_level(self, connection, level): try: - level = self._isolation_lookup[level.replace('_', ' ')] + level = self._isolation_lookup[level.replace("_", " ")] except KeyError: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. 
" - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) connection.set_isolation_level(level) @@ -623,54 +642,72 @@ class PGDialect_psycopg2(PGDialect): fns = [] if self.client_encoding is not None: + def on_connect(conn): conn.set_client_encoding(self.client_encoding) + fns.append(on_connect) if self.isolation_level is not None: + def on_connect(conn): self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) if self.dbapi and self.use_native_uuid: + def on_connect(conn): extras.register_uuid(None, conn) + fns.append(on_connect) if self.dbapi and self.use_native_unicode: + def on_connect(conn): extensions.register_type(extensions.UNICODE, conn) extensions.register_type(extensions.UNICODEARRAY, conn) + fns.append(on_connect) if self.dbapi and self.use_native_hstore: + def on_connect(conn): hstore_oids = self._hstore_oids(conn) if hstore_oids is not None: oid, array_oid = hstore_oids - kw = {'oid': oid} + kw = {"oid": oid} if util.py2k: - kw['unicode'] = True - if self.psycopg2_version >= \ - self.FEATURE_VERSION_MAP['array_oid']: - kw['array_oid'] = array_oid + kw["unicode"] = True + if ( + self.psycopg2_version + >= self.FEATURE_VERSION_MAP["array_oid"] + ): + kw["array_oid"] = array_oid extras.register_hstore(conn, **kw) + fns.append(on_connect) if self.dbapi and self._json_deserializer: + def on_connect(conn): if self._has_native_json: extras.register_default_json( - conn, loads=self._json_deserializer) + conn, loads=self._json_deserializer + ) if self._has_native_jsonb: extras.register_default_jsonb( - conn, loads=self._json_deserializer) + conn, loads=self._json_deserializer + ) + fns.append(on_connect) if fns: + def on_connect(conn): for fn in fns: fn(conn) + return on_connect else: return None @@ -684,7 +721,7 @@ class PGDialect_psycopg2(PGDialect): @util.memoized_instancemethod def _hstore_oids(self, conn): - if self.psycopg2_version >= self.FEATURE_VERSION_MAP['hstore_adapter']: + if self.psycopg2_version >= self.FEATURE_VERSION_MAP["hstore_adapter"]: extras = self._psycopg2_extras() oids = extras.HstoreAdapter.get_oids(conn) if oids is not None and oids[0]: @@ -692,9 +729,9 @@ class PGDialect_psycopg2(PGDialect): return None def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) opts.update(url.query) return ([], opts) @@ -704,7 +741,7 @@ class PGDialect_psycopg2(PGDialect): # present on old psycopg2 versions. Also, # this flag doesn't actually help in a lot of disconnect # situations, so don't rely on it. - if getattr(connection, 'closed', False): + if getattr(connection, "closed", False): return True # checks based on strings. in the case that .closed @@ -713,28 +750,29 @@ class PGDialect_psycopg2(PGDialect): for msg in [ # these error messages from libpq: interfaces/libpq/fe-misc.c # and interfaces/libpq/fe-secure.c. 
- 'terminating connection', - 'closed the connection', - 'connection not open', - 'could not receive data from server', - 'could not send data to server', + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", # psycopg2 client errors, psycopg2/conenction.h, # psycopg2/cursor.h - 'connection already closed', - 'cursor already closed', + "connection already closed", + "cursor already closed", # not sure where this path is originally from, it may # be obsolete. It really says "losed", not "closed". - 'losed the connection unexpectedly', + "losed the connection unexpectedly", # these can occur in newer SSL - 'connection has been closed unexpectedly', - 'SSL SYSCALL error: Bad file descriptor', - 'SSL SYSCALL error: EOF detected', - 'SSL error: decryption failed or bad record mac', - 'SSL SYSCALL error: Operation timed out', + "connection has been closed unexpectedly", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Operation timed out", ]: idx = str_e.find(msg) if idx >= 0 and '"' not in str_e[:idx]: return True return False + dialect = PGDialect_psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py index a1141a90e4..7343bc9735 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -28,7 +28,7 @@ from .psycopg2 import PGDialect_psycopg2 class PGDialect_psycopg2cffi(PGDialect_psycopg2): - driver = 'psycopg2cffi' + driver = "psycopg2cffi" supports_unicode_statements = True # psycopg2cffi's first release is 2.5.0, but reports @@ -40,21 +40,21 @@ class PGDialect_psycopg2cffi(PGDialect_psycopg2): native_jsonb=(2, 7, 1), sane_multi_rowcount=(2, 4, 4), array_oid=(2, 4, 4), - hstore_adapter=(2, 4, 4) + hstore_adapter=(2, 4, 4), ) @classmethod def dbapi(cls): - return __import__('psycopg2cffi') + return __import__("psycopg2cffi") @classmethod def _psycopg2_extensions(cls): - root = __import__('psycopg2cffi', fromlist=['extensions']) + root = __import__("psycopg2cffi", fromlist=["extensions"]) return root.extensions @classmethod def _psycopg2_extras(cls): - root = __import__('psycopg2cffi', fromlist=['extras']) + root = __import__("psycopg2cffi", fromlist=["extras"]) return root.extras diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py index 304afca448..c7edb8fc3d 100644 --- a/lib/sqlalchemy/dialects/postgresql/pygresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -20,14 +20,20 @@ import re from ... 
import exc, processors, util from ...types import Numeric, JSON as Json from ...sql.elements import Null -from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \ - _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID +from .base import ( + PGDialect, + PGCompiler, + PGIdentifierPreparer, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) from .hstore import HSTORE from .json import JSON, JSONB class _PGNumeric(Numeric): - def bind_processor(self, dialect): return None @@ -37,14 +43,15 @@ class _PGNumeric(Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # PyGreSQL returns Decimal natively for 1700 (numeric) return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # PyGreSQL returns float natively for 701 (float8) @@ -53,19 +60,21 @@ class _PGNumeric(Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGHStore(HSTORE): - def bind_processor(self, dialect): if not dialect.has_native_hstore: return super(_PGHStore, self).bind_processor(dialect) hstore = dialect.dbapi.Hstore + def process(value): if isinstance(value, dict): return hstore(value) return value + return process def result_processor(self, dialect, coltype): @@ -74,7 +83,6 @@ class _PGHStore(HSTORE): class _PGJSON(JSON): - def bind_processor(self, dialect): if not dialect.has_native_json: return super(_PGJSON, self).bind_processor(dialect) @@ -84,7 +92,8 @@ class _PGJSON(JSON): if value is self.NULL: value = None elif isinstance(value, Null) or ( - value is None and self.none_as_null): + value is None and self.none_as_null + ): return None if value is None or isinstance(value, (dict, list)): return json(value) @@ -98,7 +107,6 @@ class _PGJSON(JSON): class _PGJSONB(JSONB): - def bind_processor(self, dialect): if not dialect.has_native_json: return super(_PGJSONB, self).bind_processor(dialect) @@ -108,7 +116,8 @@ class _PGJSONB(JSONB): if value is self.NULL: value = None elif isinstance(value, Null) or ( - value is None and self.none_as_null): + value is None and self.none_as_null + ): return None if value is None or isinstance(value, (dict, list)): return json(value) @@ -122,7 +131,6 @@ class _PGJSONB(JSONB): class _PGUUID(UUID): - def bind_processor(self, dialect): if not dialect.has_native_uuid: return super(_PGUUID, self).bind_processor(dialect) @@ -145,32 +153,35 @@ class _PGUUID(UUID): if not dialect.has_native_uuid: return super(_PGUUID, self).result_processor(dialect, coltype) if not self.as_uuid: + def process(value): if value is not None: return str(value) + return process class _PGCompiler(PGCompiler): - def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): - return text.replace('%', '%%') + return text.replace("%", "%%") class _PGIdentifierPreparer(PGIdentifierPreparer): - def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + return value.replace("%", "%%") class 
PGDialect_pygresql(PGDialect): - driver = 'pygresql' + driver = "pygresql" statement_compiler = _PGCompiler preparer = _PGIdentifierPreparer @@ -178,6 +189,7 @@ class PGDialect_pygresql(PGDialect): @classmethod def dbapi(cls): import pgdb + return pgdb colspecs = util.update_copy( @@ -189,14 +201,14 @@ class PGDialect_pygresql(PGDialect): JSON: _PGJSON, JSONB: _PGJSONB, UUID: _PGUUID, - } + }, ) def __init__(self, **kwargs): super(PGDialect_pygresql, self).__init__(**kwargs) try: version = self.dbapi.version - m = re.match(r'(\d+)\.(\d+)', version) + m = re.match(r"(\d+)\.(\d+)", version) version = (int(m.group(1)), int(m.group(2))) except (AttributeError, ValueError, TypeError): version = (0, 0) @@ -204,8 +216,10 @@ class PGDialect_pygresql(PGDialect): if version < (5, 0): has_native_hstore = has_native_json = has_native_uuid = False if version != (0, 0): - util.warn("PyGreSQL is only fully supported by SQLAlchemy" - " since version 5.0.") + util.warn( + "PyGreSQL is only fully supported by SQLAlchemy" + " since version 5.0." + ) else: self.supports_unicode_statements = True self.supports_unicode_binds = True @@ -215,10 +229,12 @@ class PGDialect_pygresql(PGDialect): self.has_native_uuid = has_native_uuid def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['host'] = '%s:%s' % ( - opts.get('host', '').rsplit(':', 1)[0], opts.pop('port')) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["host"] = "%s:%s" % ( + opts.get("host", "").rsplit(":", 1)[0], + opts.pop("port"), + ) opts.update(url.query) return [], opts diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index b633323b40..93bf653a45 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -37,12 +37,12 @@ class PGExecutionContext_pypostgresql(PGExecutionContext): class PGDialect_pypostgresql(PGDialect): - driver = 'pypostgresql' + driver = "pypostgresql" supports_unicode_statements = True supports_unicode_binds = True description_encoding = None - default_paramstyle = 'pyformat' + default_paramstyle = "pyformat" # requires trunk version to support sane rowcounts # TODO: use dbapi version information to set this flag appropriately @@ -54,22 +54,27 @@ class PGDialect_pypostgresql(PGDialect): PGDialect.colspecs, { sqltypes.Numeric: PGNumeric, - # prevents PGNumeric from being used sqltypes.Float: sqltypes.Float, - } + }, ) @classmethod def dbapi(cls): from postgresql.driver import dbapi20 + return dbapi20 _DBAPI_ERROR_NAMES = [ "Error", - "InterfaceError", "DatabaseError", "DataError", - "OperationalError", "IntegrityError", "InternalError", - "ProgrammingError", "NotSupportedError" + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", ] @util.memoized_property @@ -83,15 +88,16 @@ class PGDialect_pypostgresql(PGDialect): ) def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) else: - opts['port'] = 5432 + opts["port"] = 5432 opts.update(url.query) return ([], opts) def is_disconnect(self, e, connection, cursor): return "connection is closed" in str(e) + dialect = PGDialect_pypostgresql diff --git 
a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index eb2d86bbdc..62d1275a60 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -7,7 +7,7 @@ from .base import ischema_names from ... import types as sqltypes -__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE') +__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE") class RangeOperators(object): @@ -34,35 +34,36 @@ class RangeOperators(object): def __ne__(self, other): "Boolean expression. Returns true if two ranges are not equal" if other is None: - return super( - RangeOperators.comparator_factory, self).__ne__(other) + return super(RangeOperators.comparator_factory, self).__ne__( + other + ) else: - return self.expr.op('<>')(other) + return self.expr.op("<>")(other) def contains(self, other, **kw): """Boolean expression. Returns true if the right hand operand, which can be an element or a range, is contained within the column. """ - return self.expr.op('@>')(other) + return self.expr.op("@>")(other) def contained_by(self, other): """Boolean expression. Returns true if the column is contained within the right hand operand. """ - return self.expr.op('<@')(other) + return self.expr.op("<@")(other) def overlaps(self, other): """Boolean expression. Returns true if the column overlaps (has points in common with) the right hand operand. """ - return self.expr.op('&&')(other) + return self.expr.op("&&")(other) def strictly_left_of(self, other): """Boolean expression. Returns true if the column is strictly left of the right hand operand. """ - return self.expr.op('<<')(other) + return self.expr.op("<<")(other) __lshift__ = strictly_left_of @@ -70,7 +71,7 @@ class RangeOperators(object): """Boolean expression. Returns true if the column is strictly right of the right hand operand. """ - return self.expr.op('>>')(other) + return self.expr.op(">>")(other) __rshift__ = strictly_right_of @@ -78,26 +79,26 @@ class RangeOperators(object): """Boolean expression. Returns true if the range in the column does not extend right of the range in the operand. """ - return self.expr.op('&<')(other) + return self.expr.op("&<")(other) def not_extend_left_of(self, other): """Boolean expression. Returns true if the range in the column does not extend left of the range in the operand. """ - return self.expr.op('&>')(other) + return self.expr.op("&>")(other) def adjacent_to(self, other): """Boolean expression. Returns true if the range in the column is adjacent to the range in the operand. """ - return self.expr.op('-|-')(other) + return self.expr.op("-|-")(other) def __add__(self, other): """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contigous. 
""" - return self.expr.op('+')(other) + return self.expr.op("+")(other) class INT4RANGE(RangeOperators, sqltypes.TypeEngine): @@ -107,9 +108,10 @@ class INT4RANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'INT4RANGE' + __visit_name__ = "INT4RANGE" -ischema_names['int4range'] = INT4RANGE + +ischema_names["int4range"] = INT4RANGE class INT8RANGE(RangeOperators, sqltypes.TypeEngine): @@ -119,9 +121,10 @@ class INT8RANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'INT8RANGE' + __visit_name__ = "INT8RANGE" + -ischema_names['int8range'] = INT8RANGE +ischema_names["int8range"] = INT8RANGE class NUMRANGE(RangeOperators, sqltypes.TypeEngine): @@ -131,9 +134,10 @@ class NUMRANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'NUMRANGE' + __visit_name__ = "NUMRANGE" + -ischema_names['numrange'] = NUMRANGE +ischema_names["numrange"] = NUMRANGE class DATERANGE(RangeOperators, sqltypes.TypeEngine): @@ -143,9 +147,10 @@ class DATERANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'DATERANGE' + __visit_name__ = "DATERANGE" -ischema_names['daterange'] = DATERANGE + +ischema_names["daterange"] = DATERANGE class TSRANGE(RangeOperators, sqltypes.TypeEngine): @@ -155,9 +160,10 @@ class TSRANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'TSRANGE' + __visit_name__ = "TSRANGE" + -ischema_names['tsrange'] = TSRANGE +ischema_names["tsrange"] = TSRANGE class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): @@ -167,6 +173,7 @@ class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'TSTZRANGE' + __visit_name__ = "TSTZRANGE" + -ischema_names['tstzrange'] = TSTZRANGE +ischema_names["tstzrange"] = TSTZRANGE diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py index ef6e8f1f97..4d984443ae 100644 --- a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -19,7 +19,6 @@ from .base import PGDialect, PGExecutionContext class PGExecutionContext_zxjdbc(PGExecutionContext): - def create_cursor(self): cursor = self._dbapi_connection.cursor() cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) @@ -27,8 +26,8 @@ class PGExecutionContext_zxjdbc(PGExecutionContext): class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): - jdbc_db_name = 'postgresql' - jdbc_driver_name = 'org.postgresql.Driver' + jdbc_db_name = "postgresql" + jdbc_driver_name = "org.postgresql.Driver" execution_ctx_cls = PGExecutionContext_zxjdbc @@ -37,10 +36,12 @@ class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): def __init__(self, *args, **kwargs): super(PGDialect_zxjdbc, self).__init__(*args, **kwargs) from com.ziclix.python.sql.handler import PostgresqlDataHandler + self.DataHandler = PostgresqlDataHandler def _get_server_version_info(self, connection): - parts = connection.connection.dbversion.split('.') + parts = connection.connection.dbversion.split(".") return tuple(int(x) for x in parts) + dialect = PGDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py index a735815213..41f0175975 100644 --- a/lib/sqlalchemy/dialects/sqlite/__init__.py +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -8,14 +8,44 @@ from . 
import base, pysqlite, pysqlcipher # noqa from sqlalchemy.dialects.sqlite.base import ( - BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, JSON, REAL, - NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR + BLOB, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + FLOAT, + INTEGER, + JSON, + REAL, + NUMERIC, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + VARCHAR, ) # default dialect base.dialect = dialect = pysqlite.dialect -__all__ = ('BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', - 'FLOAT', 'INTEGER', 'JSON', 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', - 'TIMESTAMP', 'VARCHAR', 'REAL', 'dialect') +__all__ = ( + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "JSON", + "NUMERIC", + "SMALLINT", + "TEXT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "REAL", + "dialect", +) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c487af8981..cb9389af18 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -579,9 +579,20 @@ from ... import util from ...engine import default, reflection from ...sql import compiler -from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT, - INTEGER, REAL, NUMERIC, SMALLINT, TEXT, - TIMESTAMP, VARCHAR) +from ...types import ( + BLOB, + BOOLEAN, + CHAR, + DECIMAL, + FLOAT, + INTEGER, + REAL, + NUMERIC, + SMALLINT, + TEXT, + TIMESTAMP, + VARCHAR, +) from .json import JSON, JSONIndexType, JSONPathType @@ -610,10 +621,15 @@ class _DateTimeMixin(object): """ spec = self._storage_format % { - "year": 0, "month": 0, "day": 0, "hour": 0, - "minute": 0, "second": 0, "microsecond": 0 + "year": 0, + "month": 0, + "day": 0, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, } - return bool(re.search(r'[^0-9]', spec)) + return bool(re.search(r"[^0-9]", spec)) def adapt(self, cls, **kw): if issubclass(cls, _DateTimeMixin): @@ -628,6 +644,7 @@ class _DateTimeMixin(object): def process(value): return "'%s'" % bp(value) + return process @@ -671,13 +688,17 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): ) def __init__(self, *args, **kwargs): - truncate_microseconds = kwargs.pop('truncate_microseconds', False) + truncate_microseconds = kwargs.pop("truncate_microseconds", False) super(DATETIME, self).__init__(*args, **kwargs) if truncate_microseconds: - assert 'storage_format' not in kwargs, "You can specify only "\ + assert "storage_format" not in kwargs, ( + "You can specify only " "one of truncate_microseconds or storage_format." - assert 'regexp' not in kwargs, "You can specify only one of "\ + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " "truncate_microseconds or regexp." 
+ ) self._storage_format = ( "%(year)04d-%(month)02d-%(day)02d " "%(hour)02d:%(minute)02d:%(second)02d" @@ -693,33 +714,37 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): return None elif isinstance(value, datetime_datetime): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, - 'hour': value.hour, - 'minute': value.minute, - 'second': value.second, - 'microsecond': value.microsecond, + "year": value.year, + "month": value.month, + "day": value.day, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, } elif isinstance(value, datetime_date): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, - 'hour': 0, - 'minute': 0, - 'second': 0, - 'microsecond': 0, + "year": value.year, + "month": value.month, + "day": value.day, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, } else: - raise TypeError("SQLite DateTime type only accepts Python " - "datetime and date objects as input.") + raise TypeError( + "SQLite DateTime type only accepts Python " + "datetime and date objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.datetime) + self._reg, datetime.datetime + ) else: return processors.str_to_datetime @@ -768,19 +793,23 @@ class DATE(_DateTimeMixin, sqltypes.Date): return None elif isinstance(value, datetime_date): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, + "year": value.year, + "month": value.month, + "day": value.day, } else: - raise TypeError("SQLite Date type only accepts Python " - "date objects as input.") + raise TypeError( + "SQLite Date type only accepts Python " + "date objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.date) + self._reg, datetime.date + ) else: return processors.str_to_date @@ -820,13 +849,17 @@ class TIME(_DateTimeMixin, sqltypes.Time): _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" def __init__(self, *args, **kwargs): - truncate_microseconds = kwargs.pop('truncate_microseconds', False) + truncate_microseconds = kwargs.pop("truncate_microseconds", False) super(TIME, self).__init__(*args, **kwargs) if truncate_microseconds: - assert 'storage_format' not in kwargs, "You can specify only "\ + assert "storage_format" not in kwargs, ( + "You can specify only " "one of truncate_microseconds or storage_format." - assert 'regexp' not in kwargs, "You can specify only one of "\ + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " "truncate_microseconds or regexp." + ) self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" def bind_processor(self, dialect): @@ -838,23 +871,28 @@ class TIME(_DateTimeMixin, sqltypes.Time): return None elif isinstance(value, datetime_time): return format % { - 'hour': value.hour, - 'minute': value.minute, - 'second': value.second, - 'microsecond': value.microsecond, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, } else: - raise TypeError("SQLite Time type only accepts Python " - "time objects as input.") + raise TypeError( + "SQLite Time type only accepts Python " + "time objects as input." 
+ ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.time) + self._reg, datetime.time + ) else: return processors.str_to_time + colspecs = { sqltypes.Date: DATE, sqltypes.DateTime: DATETIME, @@ -865,31 +903,31 @@ colspecs = { } ischema_names = { - 'BIGINT': sqltypes.BIGINT, - 'BLOB': sqltypes.BLOB, - 'BOOL': sqltypes.BOOLEAN, - 'BOOLEAN': sqltypes.BOOLEAN, - 'CHAR': sqltypes.CHAR, - 'DATE': sqltypes.DATE, - 'DATE_CHAR': sqltypes.DATE, - 'DATETIME': sqltypes.DATETIME, - 'DATETIME_CHAR': sqltypes.DATETIME, - 'DOUBLE': sqltypes.FLOAT, - 'DECIMAL': sqltypes.DECIMAL, - 'FLOAT': sqltypes.FLOAT, - 'INT': sqltypes.INTEGER, - 'INTEGER': sqltypes.INTEGER, - 'JSON': JSON, - 'NUMERIC': sqltypes.NUMERIC, - 'REAL': sqltypes.REAL, - 'SMALLINT': sqltypes.SMALLINT, - 'TEXT': sqltypes.TEXT, - 'TIME': sqltypes.TIME, - 'TIME_CHAR': sqltypes.TIME, - 'TIMESTAMP': sqltypes.TIMESTAMP, - 'VARCHAR': sqltypes.VARCHAR, - 'NVARCHAR': sqltypes.NVARCHAR, - 'NCHAR': sqltypes.NCHAR, + "BIGINT": sqltypes.BIGINT, + "BLOB": sqltypes.BLOB, + "BOOL": sqltypes.BOOLEAN, + "BOOLEAN": sqltypes.BOOLEAN, + "CHAR": sqltypes.CHAR, + "DATE": sqltypes.DATE, + "DATE_CHAR": sqltypes.DATE, + "DATETIME": sqltypes.DATETIME, + "DATETIME_CHAR": sqltypes.DATETIME, + "DOUBLE": sqltypes.FLOAT, + "DECIMAL": sqltypes.DECIMAL, + "FLOAT": sqltypes.FLOAT, + "INT": sqltypes.INTEGER, + "INTEGER": sqltypes.INTEGER, + "JSON": JSON, + "NUMERIC": sqltypes.NUMERIC, + "REAL": sqltypes.REAL, + "SMALLINT": sqltypes.SMALLINT, + "TEXT": sqltypes.TEXT, + "TIME": sqltypes.TIME, + "TIME_CHAR": sqltypes.TIME, + "TIMESTAMP": sqltypes.TIMESTAMP, + "VARCHAR": sqltypes.VARCHAR, + "NVARCHAR": sqltypes.NVARCHAR, + "NCHAR": sqltypes.NCHAR, } @@ -897,17 +935,18 @@ class SQLiteCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'month': '%m', - 'day': '%d', - 'year': '%Y', - 'second': '%S', - 'hour': '%H', - 'doy': '%j', - 'minute': '%M', - 'epoch': '%s', - 'dow': '%w', - 'week': '%W', - }) + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", + }, + ) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" @@ -916,10 +955,10 @@ class SQLiteCompiler(compiler.SQLCompiler): return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_char_length_func(self, fn, **kw): return "length%s" % self.function_argspec(fn) @@ -934,11 +973,12 @@ class SQLiteCompiler(compiler.SQLCompiler): try: return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( self.extract_map[extract.field], - self.process(extract.expr, **kw) + self.process(extract.expr, **kw), ) except KeyError: raise exc.CompileError( - "%s is not a valid extract argument." % extract.field) + "%s is not a valid extract argument." 
% extract.field + ) def limit_clause(self, select, **kw): text = "" @@ -954,35 +994,41 @@ class SQLiteCompiler(compiler.SQLCompiler): def for_update_clause(self, select, **kw): # sqlite has no "FOR UPDATE" AFAICT - return '' + return "" def visit_is_distinct_from_binary(self, binary, operator, **kw): - return "%s IS NOT %s" % (self.process(binary.left), - self.process(binary.right)) + return "%s IS NOT %s" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_isnot_distinct_from_binary(self, binary, operator, **kw): - return "%s IS %s" % (self.process(binary.left), - self.process(binary.right)) + return "%s IS %s" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM (SELECT 1) WHERE 1!=1' + return "SELECT 1 FROM (SELECT 1) WHERE 1!=1" class SQLiteDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): coltype = self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: @@ -991,29 +1037,33 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if not column.nullable: colspec += " NOT NULL" - on_conflict_clause = column.dialect_options['sqlite'][ - 'on_conflict_not_null'] + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_not_null" + ] if on_conflict_clause is not None: colspec += " ON CONFLICT " + on_conflict_clause if column.primary_key: if ( - column.autoincrement is True and - len(column.table.primary_key.columns) != 1 + column.autoincrement is True + and len(column.table.primary_key.columns) != 1 ): raise exc.CompileError( "SQLite does not support autoincrement for " - "composite primary keys") + "composite primary keys" + ) - if (column.table.dialect_options['sqlite']['autoincrement'] and - len(column.table.primary_key.columns) == 1 and - issubclass( - column.type._type_affinity, sqltypes.Integer) and - not column.foreign_keys): + if ( + column.table.dialect_options["sqlite"]["autoincrement"] + and len(column.table.primary_key.columns) == 1 + and issubclass(column.type._type_affinity, sqltypes.Integer) + and not column.foreign_keys + ): colspec += " PRIMARY KEY" - on_conflict_clause = column.dialect_options['sqlite'][ - 'on_conflict_primary_key'] + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_primary_key" + ] if on_conflict_clause is not None: colspec += " ON CONFLICT " + on_conflict_clause @@ -1027,21 +1077,25 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): # with the column itself. 
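
# For context: the branch below suppresses the table-level PRIMARY KEY
# clause when a lone integer column already carries it inline, which is
# what lets SQLite treat the column as a rowid alias. A toy sketch of just
# that decision, with hypothetical names (the real check also consults the
# "sqlite_autoincrement" table option):

def render_pk_at_table_level(pk_cols):
    # pk_cols: list of (name, type_affinity, has_foreign_key) tuples.
    if len(pk_cols) == 1:
        _name, affinity, has_fk = pk_cols[0]
        if affinity == "INTEGER" and not has_fk:
            return False  # "PRIMARY KEY" was rendered with the column
    return True

assert render_pk_at_table_level([("id", "INTEGER", False)]) is False
assert render_pk_at_table_level([("a", "INTEGER", False), ("b", "TEXT", True)]) is True
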
if len(constraint.columns) == 1: c = list(constraint)[0] - if (c.primary_key and - c.table.dialect_options['sqlite']['autoincrement'] and - issubclass(c.type._type_affinity, sqltypes.Integer) and - not c.foreign_keys): + if ( + c.primary_key + and c.table.dialect_options["sqlite"]["autoincrement"] + and issubclass(c.type._type_affinity, sqltypes.Integer) + and not c.foreign_keys + ): return None - text = super( - SQLiteDDLCompiler, - self).visit_primary_key_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].\ - dialect_options['sqlite']['on_conflict_primary_key'] + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_primary_key" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1049,15 +1103,17 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_unique_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_unique_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_unique_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].\ - dialect_options['sqlite']['on_conflict_unique'] + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_unique" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1065,12 +1121,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_check_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_check_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_check_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1078,14 +1135,15 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_column_check_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_column_check_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_column_check_constraint( + constraint + ) - if constraint.dialect_options['sqlite']['on_conflict'] is not None: + if constraint.dialect_options["sqlite"]["on_conflict"] is not None: raise exc.CompileError( "SQLite does not support on conflict clause for " - "column check constraint") + "column check constraint" + ) return text @@ -1097,40 +1155,40 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if local_table.schema != remote_table.schema: return None else: - return super( - SQLiteDDLCompiler, - self).visit_foreign_key_constraint(constraint) + return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( + constraint + ) def define_constraint_remote_table(self, constraint, table, preparer): """Format the remote table clause of a CREATE CONSTRAINT clause.""" return preparer.format_table(table, use_schema=False) - def visit_create_index(self, create, include_schema=False, 
- include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=True), - preparer.format_table(index.table, - use_schema=False), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=False), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) whereclause = index.dialect_options["sqlite"]["where"] if whereclause is not None: where_compiled = self.sql_compiler.process( - whereclause, include_table=False, - literal_binds=True) + whereclause, include_table=False, literal_binds=True + ) text += " WHERE " + where_compiled return text @@ -1141,22 +1199,28 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): return self.visit_BLOB(type_) def visit_DATETIME(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) else: return "DATETIME_CHAR" def visit_DATE(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_DATE(type_) else: return "DATE_CHAR" def visit_TIME(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_TIME(type_) else: return "TIME_CHAR" @@ -1169,33 +1233,135 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', - 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', - 'conflict', 'constraint', 'create', 'cross', 'current_date', - 'current_time', 'current_timestamp', 'database', 'default', - 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', - 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', - 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', - 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', - 'into', 'is', 'isnull', 'join', 'key', 'left', 'like', 'limit', - 'match', 'natural', 'not', 'notnull', 'null', 'of', 'offset', 'on', - 'or', 'order', 'outer', 'plan', 'pragma', 'primary', 'query', - 'raise', 'references', 'reindex', 'rename', 'replace', 'restrict', - 'right', 'rollback', 'row', 'select', 'set', 'table', 'temp', - 'temporary', 'then', 'to', 'transaction', 'trigger', 'true', 'union', - 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual', - 'when', 'where', - ]) + reserved_words = set( + [ + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + 
"autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + ] + ) class SQLiteExecutionContext(default.DefaultExecutionContext): @util.memoized_property def _preserve_raw_colnames(self): - return not self.dialect._broken_dotted_colnames or \ - self.execution_options.get("sqlite_raw_colnames", False) + return ( + not self.dialect._broken_dotted_colnames + or self.execution_options.get("sqlite_raw_colnames", False) + ) def _translate_colname(self, colname): # TODO: detect SQLite version 3.10.0 or greater; @@ -1212,7 +1378,7 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): class SQLiteDialect(default.DefaultDialect): - name = 'sqlite' + name = "sqlite" supports_alter = False supports_unicode_statements = True supports_unicode_binds = True @@ -1221,7 +1387,7 @@ class SQLiteDialect(default.DefaultDialect): supports_cast = True supports_multivalues_insert = True - default_paramstyle = 'qmark' + default_paramstyle = "qmark" execution_ctx_cls = SQLiteExecutionContext statement_compiler = SQLiteCompiler ddl_compiler = SQLiteDDLCompiler @@ -1235,27 +1401,30 @@ class SQLiteDialect(default.DefaultDialect): supports_default_values = True construct_arguments = [ - (sa_schema.Table, { - "autoincrement": False - }), - (sa_schema.Index, { - "where": None, - }), - (sa_schema.Column, { - "on_conflict_primary_key": None, - "on_conflict_not_null": None, - "on_conflict_unique": None, - }), - (sa_schema.Constraint, { - "on_conflict": None, - }), + (sa_schema.Table, {"autoincrement": False}), + (sa_schema.Index, {"where": None}), + ( + sa_schema.Column, + { + "on_conflict_primary_key": None, + "on_conflict_not_null": None, + "on_conflict_unique": None, + }, + ), + (sa_schema.Constraint, {"on_conflict": None}), ] _broken_fk_pragma_quotes = False _broken_dotted_colnames = False - def __init__(self, isolation_level=None, native_datetime=False, - _json_serializer=None, _json_deserializer=None, **kwargs): + def __init__( + self, + isolation_level=None, + native_datetime=False, + _json_serializer=None, + _json_deserializer=None, + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = _json_serializer @@ -1269,35 +1438,42 @@ class SQLiteDialect(default.DefaultDialect): if self.dbapi is not None: self.supports_right_nested_joins = ( - 
self.dbapi.sqlite_version_info >= (3, 7, 16)) - self._broken_dotted_colnames = ( - self.dbapi.sqlite_version_info < (3, 10, 0) + self.dbapi.sqlite_version_info >= (3, 7, 16) + ) + self._broken_dotted_colnames = self.dbapi.sqlite_version_info < ( + 3, + 10, + 0, + ) + self.supports_default_values = self.dbapi.sqlite_version_info >= ( + 3, + 3, + 8, ) - self.supports_default_values = ( - self.dbapi.sqlite_version_info >= (3, 3, 8)) - self.supports_cast = ( - self.dbapi.sqlite_version_info >= (3, 2, 3)) + self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3) self.supports_multivalues_insert = ( # http://www.sqlite.org/releaselog/3_7_11.html - self.dbapi.sqlite_version_info >= (3, 7, 11)) + self.dbapi.sqlite_version_info + >= (3, 7, 11) + ) # see http://www.sqlalchemy.org/trac/ticket/2568 # as well as http://www.sqlite.org/src/info/600482d161 - self._broken_fk_pragma_quotes = ( - self.dbapi.sqlite_version_info < (3, 6, 14)) + self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < ( + 3, + 6, + 14, + ) - _isolation_lookup = { - 'READ UNCOMMITTED': 1, - 'SERIALIZABLE': 0, - } + _isolation_lookup = {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} def set_isolation_level(self, connection, level): try: - isolation_level = self._isolation_lookup[level.replace('_', ' ')] + isolation_level = self._isolation_lookup[level.replace("_", " ")] except KeyError: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) @@ -1305,7 +1481,7 @@ class SQLiteDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() - cursor.execute('PRAGMA read_uncommitted') + cursor.execute("PRAGMA read_uncommitted") res = cursor.fetchone() if res: value = res[0] @@ -1327,8 +1503,10 @@ class SQLiteDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -1344,44 +1522,51 @@ class SQLiteDialect(default.DefaultDialect): def get_table_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema + master = "%s.sqlite_master" % qschema else: master = "sqlite_master" - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) + s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( + master, + ) rs = connection.execute(s) return [row[0] for row in rs] @reflection.cache def get_temp_table_names(self, connection, **kw): - s = "SELECT name FROM sqlite_temp_master "\ + s = ( + "SELECT name FROM sqlite_temp_master " "WHERE type='table' ORDER BY name " + ) rs = connection.execute(s) return [row[0] for row in rs] @reflection.cache def get_temp_view_names(self, connection, **kw): - s = "SELECT name FROM sqlite_temp_master "\ + s = ( + "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " + ) rs = connection.execute(s) return [row[0] for row in rs] def has_table(self, connection, table_name, schema=None): info = self._get_table_pragma( - connection, "table_info", table_name, schema=schema) + connection, "table_info", table_name, schema=schema + ) return bool(info) @reflection.cache def 
get_view_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema + master = "%s.sqlite_master" % qschema else: master = "sqlite_master" - s = ("SELECT name FROM %s " - "WHERE type='view' ORDER BY name") % (master,) + s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( + master, + ) rs = connection.execute(s) return [row[0] for row in rs] @@ -1390,21 +1575,27 @@ class SQLiteDialect(default.DefaultDialect): def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema - s = ("SELECT sql FROM %s WHERE name = '%s'" - "AND type='view'") % (master, view_name) + master = "%s.sqlite_master" % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" "AND type='view'") % ( + master, + view_name, + ) rs = connection.execute(s) else: try: - s = ("SELECT sql FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE name = '%s' " - "AND type='view'") % view_name + s = ( + "SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'" + ) % view_name rs = connection.execute(s) except exc.DBAPIError: - s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " - "AND type='view'") % view_name + s = ( + "SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'" + ) % view_name rs = connection.execute(s) result = rs.fetchall() @@ -1414,15 +1605,24 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): info = self._get_table_pragma( - connection, "table_info", table_name, schema=schema) + connection, "table_info", table_name, schema=schema + ) columns = [] for row in info: (name, type_, nullable, default, primary_key) = ( - row[1], row[2].upper(), not row[3], row[4], row[5]) + row[1], + row[2].upper(), + not row[3], + row[4], + row[5], + ) - columns.append(self._get_column_info(name, type_, nullable, - default, primary_key)) + columns.append( + self._get_column_info( + name, type_, nullable, default, primary_key + ) + ) return columns def _get_column_info(self, name, type_, nullable, default, primary_key): @@ -1432,12 +1632,12 @@ class SQLiteDialect(default.DefaultDialect): default = util.text_type(default) return { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': 'auto', - 'primary_key': primary_key, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "primary_key": primary_key, } def _resolve_type_affinity(self, type_): @@ -1457,36 +1657,37 @@ class SQLiteDialect(default.DefaultDialect): DATE and DOUBLE). 
""" - match = re.match(r'([\w ]+)(\(.*?\))?', type_) + match = re.match(r"([\w ]+)(\(.*?\))?", type_) if match: coltype = match.group(1) args = match.group(2) else: - coltype = '' - args = '' + coltype = "" + args = "" if coltype in self.ischema_names: coltype = self.ischema_names[coltype] - elif 'INT' in coltype: + elif "INT" in coltype: coltype = sqltypes.INTEGER - elif 'CHAR' in coltype or 'CLOB' in coltype or 'TEXT' in coltype: + elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype: coltype = sqltypes.TEXT - elif 'BLOB' in coltype or not coltype: + elif "BLOB" in coltype or not coltype: coltype = sqltypes.NullType - elif 'REAL' in coltype or 'FLOA' in coltype or 'DOUB' in coltype: + elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype: coltype = sqltypes.REAL else: coltype = sqltypes.NUMERIC if args is not None: - args = re.findall(r'(\d+)', args) + args = re.findall(r"(\d+)", args) try: coltype = coltype(*[int(a) for a in args]) except TypeError: util.warn( "Could not instantiate type %s with " - "reflected arguments %s; using no arguments." % - (coltype, args)) + "reflected arguments %s; using no arguments." + % (coltype, args) + ) coltype = coltype() else: coltype = coltype() @@ -1498,58 +1699,59 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = None table_data = self._get_table_sql(connection, table_name, schema=schema) if table_data: - PK_PATTERN = r'CONSTRAINT (\w+) PRIMARY KEY' + PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY" result = re.search(PK_PATTERN, table_data, re.I) constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) pkeys = [] for col in cols: - if col['primary_key']: - pkeys.append(col['name']) + if col["primary_key"]: + pkeys.append(col["name"]) - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): # sqlite makes this *extremely difficult*. # First, use the pragma to get the actual FKs. pragma_fks = self._get_table_pragma( - connection, "foreign_key_list", - table_name, schema=schema + connection, "foreign_key_list", table_name, schema=schema ) fks = {} for row in pragma_fks: - (numerical_id, rtbl, lcol, rcol) = ( - row[0], row[2], row[3], row[4]) + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) if rcol is None: rcol = lcol if self._broken_fk_pragma_quotes: - rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl) if numerical_id in fks: fk = fks[numerical_id] else: fk = fks[numerical_id] = { - 'name': None, - 'constrained_columns': [], - 'referred_schema': schema, - 'referred_table': rtbl, - 'referred_columns': [], - 'options': {} + "name": None, + "constrained_columns": [], + "referred_schema": schema, + "referred_table": rtbl, + "referred_columns": [], + "options": {}, } fks[numerical_id] = fk - fk['constrained_columns'].append(lcol) - fk['referred_columns'].append(rcol) + fk["constrained_columns"].append(lcol) + fk["referred_columns"].append(rcol) def fk_sig(constrained_columns, referred_table, referred_columns): - return tuple(constrained_columns) + (referred_table,) + \ - tuple(referred_columns) + return ( + tuple(constrained_columns) + + (referred_table,) + + tuple(referred_columns) + ) # then, parse the actual SQL and attempt to find DDL that matches # the names as well. 
SQLite saves the DDL in whatever format @@ -1558,10 +1760,13 @@ class SQLiteDialect(default.DefaultDialect): keys_by_signature = dict( ( fk_sig( - fk['constrained_columns'], - fk['referred_table'], fk['referred_columns']), - fk - ) for fk in fks.values() + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ), + fk, + ) + for fk in fks.values() ) table_data = self._get_table_sql(connection, table_name, schema=schema) @@ -1571,55 +1776,66 @@ class SQLiteDialect(default.DefaultDialect): def parse_fks(): FK_PATTERN = ( - r'(?:CONSTRAINT (\w+) +)?' - r'FOREIGN KEY *\( *(.+?) *\) +' + r"(?:CONSTRAINT (\w+) +)?" + r"FOREIGN KEY *\( *(.+?) *\) +" r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *' - r'((?:ON (?:DELETE|UPDATE) ' - r'(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)' + r"((?:ON (?:DELETE|UPDATE) " + r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)" ) for match in re.finditer(FK_PATTERN, table_data, re.I): ( - constraint_name, constrained_columns, - referred_quoted_name, referred_name, - referred_columns, onupdatedelete) = \ - match.group(1, 2, 3, 4, 5, 6) + constraint_name, + constrained_columns, + referred_quoted_name, + referred_name, + referred_columns, + onupdatedelete, + ) = match.group(1, 2, 3, 4, 5, 6) constrained_columns = list( - self._find_cols_in_sig(constrained_columns)) + self._find_cols_in_sig(constrained_columns) + ) if not referred_columns: referred_columns = constrained_columns else: referred_columns = list( - self._find_cols_in_sig(referred_columns)) + self._find_cols_in_sig(referred_columns) + ) referred_name = referred_quoted_name or referred_name options = {} for token in re.split(r" *\bON\b *", onupdatedelete.upper()): if token.startswith("DELETE"): - options['ondelete'] = token[6:].strip() + options["ondelete"] = token[6:].strip() elif token.startswith("UPDATE"): options["onupdate"] = token[6:].strip() yield ( - constraint_name, constrained_columns, - referred_name, referred_columns, options) + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) + fkeys = [] for ( - constraint_name, constrained_columns, - referred_name, referred_columns, options) in parse_fks(): - sig = fk_sig( - constrained_columns, referred_name, referred_columns) + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) in parse_fks(): + sig = fk_sig(constrained_columns, referred_name, referred_columns) if sig not in keys_by_signature: util.warn( "WARNING: SQL-parsed foreign key constraint " "'%s' could not be located in PRAGMA " - "foreign_keys for table %s" % ( - sig, - table_name - )) + "foreign_keys for table %s" % (sig, table_name) + ) continue key = keys_by_signature.pop(sig) - key['name'] = constraint_name - key['options'] = options + key["name"] = constraint_name + key["options"] = options fkeys.append(key) # assume the remainders are the unnamed, inline constraints, just # use them as is as it's extremely difficult to parse inline @@ -1632,20 +1848,26 @@ class SQLiteDialect(default.DefaultDialect): yield match.group(1) or match.group(2) @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): auto_index_by_sig = {} for idx in self.get_indexes( - connection, table_name, schema=schema, - include_auto_indexes=True, **kw): - if not idx['name'].startswith("sqlite_autoindex"): + connection, + table_name, + schema=schema, + 
include_auto_indexes=True, + **kw + ): + if not idx["name"].startswith("sqlite_autoindex"): continue - sig = tuple(idx['column_names']) + sig = tuple(idx["column_names"]) auto_index_by_sig[sig] = idx table_data = self._get_table_sql( - connection, table_name, schema=schema, **kw) + connection, table_name, schema=schema, **kw + ) if not table_data: return [] @@ -1654,8 +1876,8 @@ class SQLiteDialect(default.DefaultDialect): def parse_uqs(): UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( - r'(?:(".+?")|([a-z0-9]+)) ' - r'+[a-z0-9_ ]+? +UNIQUE') + r'(?:(".+?")|([a-z0-9]+)) ' r"+[a-z0-9_ ]+? +UNIQUE" + ) for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): name, cols = match.group(1, 2) @@ -1666,34 +1888,29 @@ class SQLiteDialect(default.DefaultDialect): # are kind of the same thing :) for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): cols = list( - self._find_cols_in_sig(match.group(1) or match.group(2))) + self._find_cols_in_sig(match.group(1) or match.group(2)) + ) yield None, cols for name, cols in parse_uqs(): sig = tuple(cols) if sig in auto_index_by_sig: auto_index_by_sig.pop(sig) - parsed_constraint = { - 'name': name, - 'column_names': cols - } + parsed_constraint = {"name": name, "column_names": cols} unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. return unique_constraints @reflection.cache - def get_check_constraints(self, connection, table_name, - schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( - connection, table_name, schema=schema, **kw) + connection, table_name, schema=schema, **kw + ) if not table_data: return [] - CHECK_PATTERN = ( - r'(?:CONSTRAINT (\w+) +)?' - r'CHECK *\( *(.+) *\),? *' - ) + CHECK_PATTERN = r"(?:CONSTRAINT (\w+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] # NOTE: we aren't using re.S here because we actually are # taking advantage of each CHECK constraint being all on one @@ -1701,25 +1918,26 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. for match in re.finditer(CHECK_PATTERN, table_data, re.I): - check_constraints.append({ - 'sqltext': match.group(2), - 'name': match.group(1) - }) + check_constraints.append( + {"sqltext": match.group(2), "name": match.group(1)} + ) return check_constraints @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): pragma_indexes = self._get_table_pragma( - connection, "index_list", table_name, schema=schema) + connection, "index_list", table_name, schema=schema + ) indexes = [] - include_auto_indexes = kw.pop('include_auto_indexes', False) + include_auto_indexes = kw.pop("include_auto_indexes", False) for row in pragma_indexes: # ignore implicit primary key index. # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html - if (not include_auto_indexes and - row[1].startswith('sqlite_autoindex')): + if not include_auto_indexes and row[1].startswith( + "sqlite_autoindex" + ): continue indexes.append(dict(name=row[1], column_names=[], unique=row[2])) @@ -1727,34 +1945,38 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. 
for idx in indexes: pragma_index = self._get_table_pragma( - connection, "index_info", idx['name']) + connection, "index_info", idx["name"] + ) for row in pragma_index: - idx['column_names'].append(row[2]) + idx["column_names"].append(row[2]) return indexes @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): if schema: schema_expr = "%s." % ( - self.identifier_preparer.quote_identifier(schema)) + self.identifier_preparer.quote_identifier(schema) + ) else: schema_expr = "" try: - s = ("SELECT sql FROM " - " (SELECT * FROM %(schema)ssqlite_master UNION ALL " - " SELECT * FROM %(schema)ssqlite_temp_master) " - "WHERE name = '%(table)s' " - "AND type = 'table'" % { - "schema": schema_expr, - "table": table_name}) + s = ( + "SELECT sql FROM " + " (SELECT * FROM %(schema)ssqlite_master UNION ALL " + " SELECT * FROM %(schema)ssqlite_temp_master) " + "WHERE name = '%(table)s' " + "AND type = 'table'" + % {"schema": schema_expr, "table": table_name} + ) rs = connection.execute(s) except exc.DBAPIError: - s = ("SELECT sql FROM %(schema)ssqlite_master " - "WHERE name = '%(table)s' " - "AND type = 'table'" % { - "schema": schema_expr, - "table": table_name}) + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = '%(table)s' " + "AND type = 'table'" + % {"schema": schema_expr, "table": table_name} + ) rs = connection.execute(s) return rs.scalar() diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py index 90929fbd89..db185dd4d2 100644 --- a/lib/sqlalchemy/dialects/sqlite/json.py +++ b/lib/sqlalchemy/dialects/sqlite/json.py @@ -58,7 +58,6 @@ class _FormatTypeMixin(object): class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): if isinstance(value, int): value = "$[%s]" % value @@ -70,8 +69,10 @@ class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): def _format_value(self, value): return "$%s" % ( - "".join([ - "[%s]" % elem if isinstance(elem, int) - else '."%s"' % elem for elem in value - ]) + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) ) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 09f2b80093..fca425127c 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -82,9 +82,9 @@ from ... 
import pool class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): - driver = 'pysqlcipher' + driver = "pysqlcipher" - pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac') + pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") @classmethod def dbapi(cls): @@ -102,15 +102,13 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): return pool.SingletonThreadPool def connect(self, *cargs, **cparams): - passphrase = cparams.pop('passphrase', '') + passphrase = cparams.pop("passphrase", "") - pragmas = dict( - (key, cparams.pop(key, None)) for key in - self.pragmas - ) + pragmas = dict((key, cparams.pop(key, None)) for key in self.pragmas) - conn = super(SQLiteDialect_pysqlcipher, self).\ - connect(*cargs, **cparams) + conn = super(SQLiteDialect_pysqlcipher, self).connect( + *cargs, **cparams + ) conn.execute('pragma key="%s"' % passphrase) for prag, value in pragmas.items(): if value is not None: @@ -120,11 +118,17 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): def create_connect_args(self, url): super_url = _url.URL( - url.drivername, username=url.username, - host=url.host, database=url.database, query=url.query) - c_args, opts = super(SQLiteDialect_pysqlcipher, self).\ - create_connect_args(super_url) - opts['passphrase'] = url.password + url.drivername, + username=url.username, + host=url.host, + database=url.database, + query=url.query, + ) + c_args, opts = super( + SQLiteDialect_pysqlcipher, self + ).create_connect_args(super_url) + opts["passphrase"] = url.password return c_args, opts + dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 8809962df2..e78d76ae6c 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -301,20 +301,20 @@ class _SQLite_pysqliteDate(DATE): class SQLiteDialect_pysqlite(SQLiteDialect): - default_paramstyle = 'qmark' + default_paramstyle = "qmark" colspecs = util.update_copy( SQLiteDialect.colspecs, { sqltypes.Date: _SQLite_pysqliteDate, sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp, - } + }, ) if not util.py2k: description_encoding = None - driver = 'pysqlite' + driver = "pysqlite" def __init__(self, **kwargs): SQLiteDialect.__init__(self, **kwargs) @@ -323,10 +323,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): sqlite_ver = self.dbapi.version_info if sqlite_ver < (2, 1, 3): util.warn( - ("The installed version of pysqlite2 (%s) is out-dated " - "and will cause errors in some cases. Version 2.1.3 " - "or greater is recommended.") % - '.'.join([str(subver) for subver in sqlite_ver])) + ( + "The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended." 
+ ) + % ".".join([str(subver) for subver in sqlite_ver]) + ) @classmethod def dbapi(cls): @@ -341,7 +344,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): @classmethod def get_pool_class(cls, url): - if url.database and url.database != ':memory:': + if url.database and url.database != ":memory:": return pool.NullPool else: return pool.SingletonThreadPool @@ -356,22 +359,25 @@ class SQLiteDialect_pysqlite(SQLiteDialect): "Valid SQLite URL forms are:\n" " sqlite:///:memory: (or, sqlite://)\n" " sqlite:///relative/path/to/file.db\n" - " sqlite:////absolute/path/to/file.db" % (url,)) - filename = url.database or ':memory:' - if filename != ':memory:': + " sqlite:////absolute/path/to/file.db" % (url,) + ) + filename = url.database or ":memory:" + if filename != ":memory:": filename = os.path.abspath(filename) opts = url.query.copy() - util.coerce_kw_type(opts, 'timeout', float) - util.coerce_kw_type(opts, 'isolation_level', str) - util.coerce_kw_type(opts, 'detect_types', int) - util.coerce_kw_type(opts, 'check_same_thread', bool) - util.coerce_kw_type(opts, 'cached_statements', int) + util.coerce_kw_type(opts, "timeout", float) + util.coerce_kw_type(opts, "isolation_level", str) + util.coerce_kw_type(opts, "detect_types", int) + util.coerce_kw_type(opts, "check_same_thread", bool) + util.coerce_kw_type(opts, "cached_statements", int) return ([filename], opts) def is_disconnect(self, e, connection, cursor): - return isinstance(e, self.dbapi.ProgrammingError) and \ - "Cannot operate on a closed database." in str(e) + return isinstance( + e, self.dbapi.ProgrammingError + ) and "Cannot operate on a closed database." in str(e) + dialect = SQLiteDialect_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py index be434977fa..2f55d3bf6f 100644 --- a/lib/sqlalchemy/dialects/sybase/__init__.py +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -7,21 +7,61 @@ from . 
import base, pysybase, pyodbc # noqa -from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ - BIGINT, INT, INTEGER, SMALLINT, BINARY,\ - VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\ - IMAGE, BIT, MONEY, SMALLMONEY, TINYINT +from .base import ( + CHAR, + VARCHAR, + TIME, + NCHAR, + NVARCHAR, + TEXT, + DATE, + DATETIME, + FLOAT, + NUMERIC, + BIGINT, + INT, + INTEGER, + SMALLINT, + BINARY, + VARBINARY, + UNITEXT, + UNICHAR, + UNIVARCHAR, + IMAGE, + BIT, + MONEY, + SMALLMONEY, + TINYINT, +) # default dialect base.dialect = dialect = pyodbc.dialect __all__ = ( - 'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR', - 'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC', - 'BIGINT', 'INT', 'INTEGER', 'SMALLINT', 'BINARY', - 'VARBINARY', 'UNITEXT', 'UNICHAR', 'UNIVARCHAR', - 'IMAGE', 'BIT', 'MONEY', 'SMALLMONEY', 'TINYINT', - 'dialect' + "CHAR", + "VARCHAR", + "TIME", + "NCHAR", + "NVARCHAR", + "TEXT", + "DATE", + "DATETIME", + "FLOAT", + "NUMERIC", + "BIGINT", + "INT", + "INTEGER", + "SMALLINT", + "BINARY", + "VARBINARY", + "UNITEXT", + "UNICHAR", + "UNIVARCHAR", + "IMAGE", + "BIT", + "MONEY", + "SMALLMONEY", + "TINYINT", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 7dd9735733..1214a92794 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -31,70 +31,257 @@ from sqlalchemy.sql import operators as sql_operators from sqlalchemy import schema as sa_schema from sqlalchemy import util, sql, exc -from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ - BIGINT, INT, INTEGER, SMALLINT, BINARY,\ - VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ - UnicodeText, REAL - -RESERVED_WORDS = set([ - "add", "all", "alter", "and", - "any", "as", "asc", "backup", - "begin", "between", "bigint", "binary", - "bit", "bottom", "break", "by", - "call", "capability", "cascade", "case", - "cast", "char", "char_convert", "character", - "check", "checkpoint", "close", "comment", - "commit", "connect", "constraint", "contains", - "continue", "convert", "create", "cross", - "cube", "current", "current_timestamp", "current_user", - "cursor", "date", "dbspace", "deallocate", - "dec", "decimal", "declare", "default", - "delete", "deleting", "desc", "distinct", - "do", "double", "drop", "dynamic", - "else", "elseif", "encrypted", "end", - "endif", "escape", "except", "exception", - "exec", "execute", "existing", "exists", - "externlogin", "fetch", "first", "float", - "for", "force", "foreign", "forward", - "from", "full", "goto", "grant", - "group", "having", "holdlock", "identified", - "if", "in", "index", "index_lparen", - "inner", "inout", "insensitive", "insert", - "inserting", "install", "instead", "int", - "integer", "integrated", "intersect", "into", - "iq", "is", "isolation", "join", - "key", "lateral", "left", "like", - "lock", "login", "long", "match", - "membership", "message", "mode", "modify", - "natural", "new", "no", "noholdlock", - "not", "notify", "null", "numeric", - "of", "off", "on", "open", - "option", "options", "or", "order", - "others", "out", "outer", "over", - "passthrough", "precision", "prepare", "primary", - "print", "privileges", "proc", "procedure", - "publication", "raiserror", "readtext", "real", - "reference", "references", "release", "remote", - "remove", "rename", "reorganize", "resource", - "restore", "restrict", "return", "revoke", - "right", "rollback", "rollup", "save", - "savepoint", "scroll", 
"select", "sensitive", - "session", "set", "setuser", "share", - "smallint", "some", "sqlcode", "sqlstate", - "start", "stop", "subtrans", "subtransaction", - "synchronize", "syntax_error", "table", "temporary", - "then", "time", "timestamp", "tinyint", - "to", "top", "tran", "trigger", - "truncate", "tsequal", "unbounded", "union", - "unique", "unknown", "unsigned", "update", - "updating", "user", "using", "validate", - "values", "varbinary", "varchar", "variable", - "varying", "view", "wait", "waitfor", - "when", "where", "while", "window", - "with", "with_cube", "with_lparen", "with_rollup", - "within", "work", "writetext", -]) +from sqlalchemy.types import ( + CHAR, + VARCHAR, + TIME, + NCHAR, + NVARCHAR, + TEXT, + DATE, + DATETIME, + FLOAT, + NUMERIC, + BIGINT, + INT, + INTEGER, + SMALLINT, + BINARY, + VARBINARY, + DECIMAL, + TIMESTAMP, + Unicode, + UnicodeText, + REAL, +) + +RESERVED_WORDS = set( + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "backup", + "begin", + "between", + "bigint", + "binary", + "bit", + "bottom", + "break", + "by", + "call", + "capability", + "cascade", + "case", + "cast", + "char", + "char_convert", + "character", + "check", + "checkpoint", + "close", + "comment", + "commit", + "connect", + "constraint", + "contains", + "continue", + "convert", + "create", + "cross", + "cube", + "current", + "current_timestamp", + "current_user", + "cursor", + "date", + "dbspace", + "deallocate", + "dec", + "decimal", + "declare", + "default", + "delete", + "deleting", + "desc", + "distinct", + "do", + "double", + "drop", + "dynamic", + "else", + "elseif", + "encrypted", + "end", + "endif", + "escape", + "except", + "exception", + "exec", + "execute", + "existing", + "exists", + "externlogin", + "fetch", + "first", + "float", + "for", + "force", + "foreign", + "forward", + "from", + "full", + "goto", + "grant", + "group", + "having", + "holdlock", + "identified", + "if", + "in", + "index", + "index_lparen", + "inner", + "inout", + "insensitive", + "insert", + "inserting", + "install", + "instead", + "int", + "integer", + "integrated", + "intersect", + "into", + "iq", + "is", + "isolation", + "join", + "key", + "lateral", + "left", + "like", + "lock", + "login", + "long", + "match", + "membership", + "message", + "mode", + "modify", + "natural", + "new", + "no", + "noholdlock", + "not", + "notify", + "null", + "numeric", + "of", + "off", + "on", + "open", + "option", + "options", + "or", + "order", + "others", + "out", + "outer", + "over", + "passthrough", + "precision", + "prepare", + "primary", + "print", + "privileges", + "proc", + "procedure", + "publication", + "raiserror", + "readtext", + "real", + "reference", + "references", + "release", + "remote", + "remove", + "rename", + "reorganize", + "resource", + "restore", + "restrict", + "return", + "revoke", + "right", + "rollback", + "rollup", + "save", + "savepoint", + "scroll", + "select", + "sensitive", + "session", + "set", + "setuser", + "share", + "smallint", + "some", + "sqlcode", + "sqlstate", + "start", + "stop", + "subtrans", + "subtransaction", + "synchronize", + "syntax_error", + "table", + "temporary", + "then", + "time", + "timestamp", + "tinyint", + "to", + "top", + "tran", + "trigger", + "truncate", + "tsequal", + "unbounded", + "union", + "unique", + "unknown", + "unsigned", + "update", + "updating", + "user", + "using", + "validate", + "values", + "varbinary", + "varchar", + "variable", + "varying", + "view", + "wait", + "waitfor", + "when", + "where", + "while", + "window", + 
"with", + "with_cube", + "with_lparen", + "with_rollup", + "within", + "work", + "writetext", + ] +) class _SybaseUnitypeMixin(object): @@ -106,27 +293,28 @@ class _SybaseUnitypeMixin(object): return str(value) # decode("ucs-2") else: return None + return process class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode): - __visit_name__ = 'UNICHAR' + __visit_name__ = "UNICHAR" class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode): - __visit_name__ = 'UNIVARCHAR' + __visit_name__ = "UNIVARCHAR" class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText): - __visit_name__ = 'UNITEXT' + __visit_name__ = "UNITEXT" class TINYINT(sqltypes.Integer): - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" class MONEY(sqltypes.TypeEngine): @@ -142,7 +330,7 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class IMAGE(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' + __visit_name__ = "IMAGE" class SybaseTypeCompiler(compiler.GenericTypeCompiler): @@ -182,67 +370,66 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler): def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" -ischema_names = { - 'bigint': BIGINT, - 'int': INTEGER, - 'integer': INTEGER, - 'smallint': SMALLINT, - 'tinyint': TINYINT, - 'unsigned bigint': BIGINT, # TODO: unsigned flags - 'unsigned int': INTEGER, # TODO: unsigned flags - 'unsigned smallint': SMALLINT, # TODO: unsigned flags - 'numeric': NUMERIC, - 'decimal': DECIMAL, - 'dec': DECIMAL, - 'float': FLOAT, - 'double': NUMERIC, # TODO - 'double precision': NUMERIC, # TODO - 'real': REAL, - 'smallmoney': SMALLMONEY, - 'money': MONEY, - 'smalldatetime': DATETIME, - 'datetime': DATETIME, - 'date': DATE, - 'time': TIME, - 'char': CHAR, - 'character': CHAR, - 'varchar': VARCHAR, - 'character varying': VARCHAR, - 'char varying': VARCHAR, - 'unichar': UNICHAR, - 'unicode character': UNIVARCHAR, - 'nchar': NCHAR, - 'national char': NCHAR, - 'national character': NCHAR, - 'nvarchar': NVARCHAR, - 'nchar varying': NVARCHAR, - 'national char varying': NVARCHAR, - 'national character varying': NVARCHAR, - 'text': TEXT, - 'unitext': UNITEXT, - 'binary': BINARY, - 'varbinary': VARBINARY, - 'image': IMAGE, - 'bit': BIT, +ischema_names = { + "bigint": BIGINT, + "int": INTEGER, + "integer": INTEGER, + "smallint": SMALLINT, + "tinyint": TINYINT, + "unsigned bigint": BIGINT, # TODO: unsigned flags + "unsigned int": INTEGER, # TODO: unsigned flags + "unsigned smallint": SMALLINT, # TODO: unsigned flags + "numeric": NUMERIC, + "decimal": DECIMAL, + "dec": DECIMAL, + "float": FLOAT, + "double": NUMERIC, # TODO + "double precision": NUMERIC, # TODO + "real": REAL, + "smallmoney": SMALLMONEY, + "money": MONEY, + "smalldatetime": DATETIME, + "datetime": DATETIME, + "date": DATE, + "time": TIME, + "char": CHAR, + "character": CHAR, + "varchar": VARCHAR, + "character varying": VARCHAR, + "char varying": VARCHAR, + "unichar": UNICHAR, + "unicode character": UNIVARCHAR, + "nchar": NCHAR, + "national char": NCHAR, + "national character": NCHAR, + "nvarchar": NVARCHAR, + "nchar varying": NVARCHAR, + "national char varying": NVARCHAR, + "national character varying": NVARCHAR, + "text": TEXT, + "unitext": UNITEXT, + "binary": BINARY, + "varbinary": VARBINARY, + "image": IMAGE, + "bit": BIT, # not in documentation for ASE 15.7 - 'long varchar': TEXT, # TODO - 'timestamp': TIMESTAMP, - 'uniqueidentifier': UNIQUEIDENTIFIER, - + "long varchar": TEXT, # TODO + "timestamp": TIMESTAMP, + "uniqueidentifier": 
UNIQUEIDENTIFIER, } class SybaseInspector(reflection.Inspector): - def __init__(self, conn): reflection.Inspector.__init__(self, conn) def get_table_id(self, table_name, schema=None): """Return the table id from `table_name` and `schema`.""" - return self.dialect.get_table_id(self.bind, table_name, schema, - info_cache=self.info_cache) + return self.dialect.get_table_id( + self.bind, table_name, schema, info_cache=self.info_cache + ) class SybaseExecutionContext(default.DefaultExecutionContext): @@ -267,15 +454,17 @@ class SybaseExecutionContext(default.DefaultExecutionContext): insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = \ + self._enable_identity_insert = ( seq_column.key in self.compiled_parameters[0] + ) else: self._enable_identity_insert = False if self._enable_identity_insert: self.cursor.execute( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)) + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ) if self.isddl: # TODO: to enhance this, we can detect "ddl in tran" on the @@ -284,14 +473,16 @@ class SybaseExecutionContext(default.DefaultExecutionContext): if not self.should_autocommit: raise exc.InvalidRequestError( "The Sybase dialect only supports " - "DDL in 'autocommit' mode at this time.") + "DDL in 'autocommit' mode at this time." + ) self.root_connection.engine.logger.info( - "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") + "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')" + ) self.set_ddl_autocommit( - self.root_connection.connection.connection, - True) + self.root_connection.connection.connection, True + ) def post_exec(self): if self.isddl: @@ -299,9 +490,10 @@ class SybaseExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: self.cursor.execute( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. 
- format_table(self.compiled.statement.table) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) ) def get_lastrowid(self): @@ -317,11 +509,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, - { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond' - }) + {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"}, + ) def get_select_precolumns(self, select, **kw): s = select._distinct and "DISTINCT " or "" @@ -330,9 +519,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler): limit = select._limit if limit: # if select._limit == 1: - # s += "FIRST " + # s += "FIRST " # else: - # s += "TOP %s " % (select._limit,) + # s += "TOP %s " % (select._limit,) s += "TOP %s " % (limit,) offset = select._offset if offset: @@ -348,8 +537,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % ( - field, self.process(extract.expr, **kw)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) def visit_now_func(self, fn, **kw): return "GETDATE()" @@ -357,10 +545,10 @@ class SybaseSQLCompiler(compiler.SQLCompiler): def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" # which SQLAlchemy doesn't use - return '' + return "" def order_by_clause(self, select, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True order_by = self.process(select._order_by_clause, **kw) # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT @@ -369,8 +557,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler): else: return "" - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -379,34 +566,41 @@ class SybaseSQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. 
FROM clause specific to Sybase.""" - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process( - column.type, type_expression=column) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) if column.table is None: raise exc.CompileError( "The Sybase dialect requires Table-bound " - "columns in order to generate DDL") + "columns in order to generate DDL" + ) seq_col = column.table._autoincrement_column # install a IDENTITY Sequence if we have an implicit IDENTITY column if seq_col is column: - sequence = isinstance(column.default, sa_schema.Sequence) \ + sequence = ( + isinstance(column.default, sa_schema.Sequence) and column.default + ) if sequence: - start, increment = sequence.start or 1, \ - sequence.increment or 1 + start, increment = sequence.start or 1, sequence.increment or 1 else: start, increment = 1, 1 if (start, increment) == (1, 1): @@ -431,8 +625,7 @@ class SybaseDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s.%s" % ( self.preparer.quote_identifier(index.table.name), - self._prepared_index_name(drop.element, - include_schema=False) + self._prepared_index_name(drop.element, include_schema=False), ) @@ -441,7 +634,7 @@ class SybaseIdentifierPreparer(compiler.IdentifierPreparer): class SybaseDialect(default.DefaultDialect): - name = 'sybase' + name = "sybase" supports_unicode_statements = False supports_sane_rowcount = False supports_sane_multi_rowcount = False @@ -463,14 +656,18 @@ class SybaseDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return connection.scalar( - text("SELECT user_name() as user_name", - typemap={'user_name': Unicode}) + text( + "SELECT user_name() as user_name", + typemap={"user_name": Unicode}, + ) ) def initialize(self, connection): super(SybaseDialect, self).initialize(connection) - if self.server_version_info is not None and\ - self.server_version_info < (15, ): + if ( + self.server_version_info is not None + and self.server_version_info < (15,) + ): self.max_identifier_length = 30 else: self.max_identifier_length = 255 @@ -488,22 +685,24 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - TABLEID_SQL = text(""" + TABLEID_SQL = text( + """ SELECT o.id AS id FROM sysobjects o JOIN sysusers u ON o.uid=u.uid WHERE u.name = :schema_name AND o.name = :table_name AND o.type in ('U', 'V') - """) + """ + ) if util.py2k: if isinstance(schema, unicode): schema = schema.encode("ascii") if isinstance(table_name, unicode): table_name = table_name.encode("ascii") - result = connection.execute(TABLEID_SQL, - schema_name=schema, - table_name=table_name) + result = connection.execute( + TABLEID_SQL, schema_name=schema, table_name=table_name + ) table_id = result.scalar() if table_id is None: raise exc.NoSuchTableError(table_name) @@ -511,10 +710,12 @@ class SybaseDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - 
info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) - COLUMN_SQL = text(""" + COLUMN_SQL = text( + """ SELECT col.name AS name, t.name AS type, (col.status & 8) AS nullable, @@ -528,23 +729,47 @@ class SybaseDialect(default.DefaultDialect): WHERE col.usertype = t.usertype AND col.id = :table_id ORDER BY col.colid - """) + """ + ) results = connection.execute(COLUMN_SQL, table_id=table_id) columns = [] - for (name, type_, nullable, autoincrement, default, precision, scale, - length) in results: - col_info = self._get_column_info(name, type_, bool(nullable), - bool(autoincrement), - default, precision, scale, - length) + for ( + name, + type_, + nullable, + autoincrement, + default, + precision, + scale, + length, + ) in results: + col_info = self._get_column_info( + name, + type_, + bool(nullable), + bool(autoincrement), + default, + precision, + scale, + length, + ) columns.append(col_info) return columns - def _get_column_info(self, name, type_, nullable, autoincrement, default, - precision, scale, length): + def _get_column_info( + self, + name, + type_, + nullable, + autoincrement, + default, + precision, + scale, + length, + ): coltype = self.ischema_names.get(type_, None) @@ -565,8 +790,9 @@ class SybaseDialect(default.DefaultDialect): # if is_array: # coltype = ARRAY(coltype) else: - util.warn("Did not recognize type '%s' of column '%s'" % - (type_, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) coltype = sqltypes.NULLTYPE if default: @@ -575,15 +801,21 @@ class SybaseDialect(default.DefaultDialect): else: default = None - column_info = dict(name=name, type=coltype, nullable=nullable, - default=default, autoincrement=autoincrement) + column_info = dict( + name=name, + type=coltype, + nullable=nullable, + default=default, + autoincrement=autoincrement, + ) return column_info @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) table_cache = {} column_cache = {} @@ -591,11 +823,13 @@ class SybaseDialect(default.DefaultDialect): table_cache[table_id] = {"name": table_name, "schema": schema} - COLUMN_SQL = text(""" + COLUMN_SQL = text( + """ SELECT c.colid AS id, c.name AS name FROM syscolumns c WHERE c.id = :table_id - """) + """ + ) results = connection.execute(COLUMN_SQL, table_id=table_id) columns = {} @@ -603,7 +837,8 @@ class SybaseDialect(default.DefaultDialect): columns[col["id"]] = col["name"] column_cache[table_id] = columns - REFCONSTRAINT_SQL = text(""" + REFCONSTRAINT_SQL = text( + """ SELECT o.name AS name, r.reftabid AS reftable_id, r.keycnt AS 'count', r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3, @@ -621,15 +856,19 @@ class SybaseDialect(default.DefaultDialect): r.refkey16 AS refkey16 FROM sysreferences r JOIN sysobjects o on r.tableid = o.id WHERE r.tableid = :table_id - """) + """ + ) referential_constraints = connection.execute( - REFCONSTRAINT_SQL, table_id=table_id).fetchall() + REFCONSTRAINT_SQL, table_id=table_id + ).fetchall() - REFTABLE_SQL = text(""" + REFTABLE_SQL = text( + """ SELECT o.name AS name, u.name AS 'schema' FROM sysobjects o JOIN sysusers u ON o.uid = u.uid WHERE o.id = :table_id - """) + """ + ) for r in referential_constraints: reftable_id = r["reftable_id"] @@ -639,8 +878,10 @@ 
class SybaseDialect(default.DefaultDialect): reftable = c.fetchone() c.close() table_info = {"name": reftable["name"], "schema": None} - if (schema is not None or - reftable["schema"] != self.default_schema_name): + if ( + schema is not None + or reftable["schema"] != self.default_schema_name + ): table_info["schema"] = reftable["schema"] table_cache[reftable_id] = table_info @@ -664,7 +905,7 @@ class SybaseDialect(default.DefaultDialect): "referred_schema": reftable["schema"], "referred_table": reftable["name"], "referred_columns": referred_columns, - "name": r["name"] + "name": r["name"], } foreign_keys.append(fk_info) @@ -673,10 +914,12 @@ class SybaseDialect(default.DefaultDialect): @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) - INDEX_SQL = text(""" + INDEX_SQL = text( + """ SELECT object_name(i.id) AS table_name, i.keycnt AS 'count', i.name AS name, @@ -702,7 +945,8 @@ class SybaseDialect(default.DefaultDialect): AND o.id = :table_id AND (i.status & 2048) = 0 AND i.indid BETWEEN 1 AND 254 - """) + """ + ) results = connection.execute(INDEX_SQL, table_id=table_id) indexes = [] @@ -710,19 +954,23 @@ class SybaseDialect(default.DefaultDialect): column_names = [] for i in range(1, r["count"]): column_names.append(r["col_%i" % (i,)]) - index_info = {"name": r["name"], - "unique": bool(r["unique"]), - "column_names": column_names} + index_info = { + "name": r["name"], + "unique": bool(r["unique"]), + "column_names": column_names, + } indexes.append(index_info) return indexes @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) - PK_SQL = text(""" + PK_SQL = text( + """ SELECT object_name(i.id) AS table_name, i.keycnt AS 'count', i.name AS name, @@ -747,7 +995,8 @@ class SybaseDialect(default.DefaultDialect): AND o.id = :table_id AND (i.status & 2048) = 2048 AND i.indid BETWEEN 1 AND 254 - """) + """ + ) results = connection.execute(PK_SQL, table_id=table_id) pks = results.fetchone() @@ -757,8 +1006,10 @@ class SybaseDialect(default.DefaultDialect): if pks: for i in range(1, pks["count"] + 1): constrained_columns.append(pks["pk_%i" % (i,)]) - return {"constrained_columns": constrained_columns, - "name": pks["name"]} + return { + "constrained_columns": constrained_columns, + "name": pks["name"], + } else: return {"constrained_columns": [], "name": None} @@ -776,12 +1027,14 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - TABLE_SQL = text(""" + TABLE_SQL = text( + """ SELECT o.name AS name FROM sysobjects o JOIN sysusers u ON o.uid = u.uid WHERE u.name = :schema_name AND o.type = 'U' - """) + """ + ) if util.py2k: if isinstance(schema, unicode): @@ -796,12 +1049,14 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - VIEW_DEF_SQL = text(""" + VIEW_DEF_SQL = text( + """ SELECT c.text FROM syscomments c JOIN sysobjects o ON c.id = o.id WHERE o.name = :view_name AND o.type = 'V' - """) + """ + ) if util.py2k: if isinstance(view_name, unicode): @@ -816,12 +1071,14 @@ class SybaseDialect(default.DefaultDialect): if schema is 
None: schema = self.default_schema_name - VIEW_SQL = text(""" + VIEW_SQL = text( + """ SELECT o.name AS name FROM sysobjects o JOIN sysusers u ON o.uid = u.uid WHERE u.name = :schema_name AND o.type = 'V' - """) + """ + ) if util.py2k: if isinstance(schema, unicode): diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py index ddb6b7e211..eeceb359b7 100644 --- a/lib/sqlalchemy/dialects/sybase/mxodbc.py +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -30,4 +30,5 @@ class SybaseExecutionContext_mxodbc(SybaseExecutionContext): class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect): execution_ctx_cls = SybaseExecutionContext_mxodbc + dialect = SybaseDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index af6469dada..a4759428c8 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -34,8 +34,10 @@ Currently *not* supported are:: """ -from sqlalchemy.dialects.sybase.base import SybaseDialect,\ - SybaseExecutionContext +from sqlalchemy.dialects.sybase.base import ( + SybaseDialect, + SybaseExecutionContext, +) from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy import types as sqltypes, processors import decimal @@ -51,12 +53,10 @@ class _SybNumeric_pyodbc(sqltypes.Numeric): """ def bind_processor(self, dialect): - super_process = super(_SybNumeric_pyodbc, self).\ - bind_processor(dialect) + super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect) def process(value): - if self.asdecimal and \ - isinstance(value, decimal.Decimal): + if self.asdecimal and isinstance(value, decimal.Decimal): if value.adjusted() < -6: return processors.to_float(value) @@ -65,6 +65,7 @@ class _SybNumeric_pyodbc(sqltypes.Numeric): return super_process(value) else: return value + return process @@ -79,8 +80,7 @@ class SybaseExecutionContext_pyodbc(SybaseExecutionContext): class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): execution_ctx_cls = SybaseExecutionContext_pyodbc - colspecs = { - sqltypes.Numeric: _SybNumeric_pyodbc, - } + colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc} + dialect = SybaseDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index 2168d55727..09d2cf380f 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -22,8 +22,11 @@ kind at this time. 
""" from sqlalchemy import types as sqltypes, processors -from sqlalchemy.dialects.sybase.base import SybaseDialect, \ - SybaseExecutionContext, SybaseSQLCompiler +from sqlalchemy.dialects.sybase.base import ( + SybaseDialect, + SybaseExecutionContext, + SybaseSQLCompiler, +) class _SybNumeric(sqltypes.Numeric): @@ -35,7 +38,6 @@ class _SybNumeric(sqltypes.Numeric): class SybaseExecutionContext_pysybase(SybaseExecutionContext): - def set_ddl_autocommit(self, dbapi_connection, value): if value: # call commit() on the Sybase connection directly, @@ -58,24 +60,22 @@ class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): class SybaseDialect_pysybase(SybaseDialect): - driver = 'pysybase' + driver = "pysybase" execution_ctx_cls = SybaseExecutionContext_pysybase statement_compiler = SybaseSQLCompiler_pysybase - colspecs = { - sqltypes.Numeric: _SybNumeric, - sqltypes.Float: sqltypes.Float - } + colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float} @classmethod def dbapi(cls): import Sybase + return Sybase def create_connect_args(self, url): - opts = url.translate_connect_args(username='user', password='passwd') + opts = url.translate_connect_args(username="user", password="passwd") - return ([opts.pop('host')], opts) + return ([opts.pop("host")], opts) def do_executemany(self, cursor, statement, parameters, context=None): # calling python-sybase executemany yields: @@ -90,13 +90,17 @@ class SybaseDialect_pysybase(SybaseDialect): return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10) def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): msg = str(e) - return ('Unable to complete network request to host' in msg or - 'Invalid connection state' in msg or - 'Invalid cursor state' in msg) + return ( + "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + ) else: return False + dialect = SybaseDialect_pysybase diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 6342b3c216..590359c380 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -57,10 +57,9 @@ from .interfaces import ( Dialect, ExecutionContext, ExceptionContext, - # backwards compat Compiled, - TypeCompiler + TypeCompiler, ) from .base import ( @@ -82,9 +81,7 @@ from .result import ( RowProxy, ) -from .util import ( - connection_memoize -) +from .util import connection_memoize from . import util, strategies @@ -92,7 +89,7 @@ from . import util, strategies # backwards compat from ..sql import ddl -default_strategy = 'plain' +default_strategy = "plain" def create_engine(*args, **kwargs): @@ -460,12 +457,12 @@ def create_engine(*args, **kwargs): """ - strategy = kwargs.pop('strategy', default_strategy) + strategy = kwargs.pop("strategy", default_strategy) strategy = strategies.strategies[strategy] return strategy.create(*args, **kwargs) -def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): +def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): """Create a new Engine instance using a configuration dictionary. The dictionary is typically produced from a config file. 
@@ -497,16 +494,15 @@ def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): """ - options = dict((key[len(prefix):], configuration[key]) - for key in configuration - if key.startswith(prefix)) - options['_coerce_config'] = True + options = dict( + (key[len(prefix) :], configuration[key]) + for key in configuration + if key.startswith(prefix) + ) + options["_coerce_config"] = True options.update(kwargs) - url = options.pop('url') + url = options.pop("url") return create_engine(url, **options) -__all__ = ( - 'create_engine', - 'engine_from_config', -) +__all__ = ("create_engine", "engine_from_config") diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4a057ee596..75d03b7448 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -61,10 +61,16 @@ class Connection(Connectable): """ - def __init__(self, engine, connection=None, close_with_result=False, - _branch_from=None, _execution_options=None, - _dispatch=None, - _has_events=None): + def __init__( + self, + engine, + connection=None, + close_with_result=False, + _branch_from=None, + _execution_options=None, + _dispatch=None, + _has_events=None, + ): """Construct a new Connection. The constructor here is not public and is only called only by an @@ -86,8 +92,11 @@ class Connection(Connectable): self._has_events = _branch_from._has_events self.schema_for_object = _branch_from.schema_for_object else: - self.__connection = connection \ - if connection is not None else engine.raw_connection() + self.__connection = ( + connection + if connection is not None + else engine.raw_connection() + ) self.__transaction = None self.__savepoint_seq = 0 self.should_close_with_result = close_with_result @@ -101,7 +110,8 @@ class Connection(Connectable): # want to handle any of the engine's events in that case. 
self.dispatch = self.dispatch._join(engine.dispatch) self._has_events = _has_events or ( - _has_events is None and engine._has_events) + _has_events is None and engine._has_events + ) assert not _execution_options self._execution_options = engine._execution_options @@ -134,7 +144,8 @@ class Connection(Connectable): _branch_from=self, _execution_options=self._execution_options, _has_events=self._has_events, - _dispatch=self.dispatch) + _dispatch=self.dispatch, + ) @property def _root(self): @@ -322,8 +333,10 @@ class Connection(Connectable): def closed(self): """Return True if this connection is closed.""" - return '_Connection__connection' not in self.__dict__ \ + return ( + "_Connection__connection" not in self.__dict__ and not self.__can_reconnect + ) @property def invalidated(self): @@ -425,7 +438,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " - "transaction is rolled back") + "transaction is rolled back" + ) self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection @@ -437,14 +451,15 @@ class Connection(Connectable): # dialect initializer, where the connection is not wrapped in # _ConnectionFairy - return getattr(self.__connection, 'is_valid', False) + return getattr(self.__connection, "is_valid", False) @property def _still_open_and_connection_is_valid(self): - return \ - not self.closed and \ - not self.invalidated and \ - getattr(self.__connection, 'is_valid', False) + return ( + not self.closed + and not self.invalidated + and getattr(self.__connection, "is_valid", False) + ) @property def info(self): @@ -656,7 +671,8 @@ class Connection(Connectable): if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " - "is already in progress.") + "is already in progress." 
+ ) if xid is None: xid = self.engine.dialect.create_xid() self.__transaction = TwoPhaseTransaction(self, xid) @@ -705,8 +721,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None else: @@ -725,8 +743,10 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) finally: - if not self.__invalid and \ - self.connection._reset_agent is self.__transaction: + if ( + not self.__invalid + and self.connection._reset_agent is self.__transaction + ): self.connection._reset_agent = None self.__transaction = None @@ -738,7 +758,7 @@ class Connection(Connectable): if name is None: self.__savepoint_seq += 1 - name = 'sa_savepoint_%s' % self.__savepoint_seq + name = "sa_savepoint_%s" % self.__savepoint_seq if self._still_open_and_connection_is_valid: self.engine.dialect.do_savepoint(self, name) return name @@ -797,7 +817,8 @@ class Connection(Connectable): assert isinstance(self.__transaction, TwoPhaseTransaction) try: self.engine.dialect.do_rollback_twophase( - self, xid, is_prepared) + self, xid, is_prepared + ) finally: if self.connection._reset_agent is self.__transaction: self.connection._reset_agent = None @@ -950,16 +971,16 @@ class Connection(Connectable): def _execute_function(self, func, multiparams, params): """Execute a sql.FunctionElement object.""" - return self._execute_clauseelement(func.select(), - multiparams, params) + return self._execute_clauseelement(func.select(), multiparams, params) def _execute_default(self, default, multiparams, params): """Execute a schema.ColumnDefault object.""" if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - default, multiparams, params = \ - fn(self, default, multiparams, params) + default, multiparams, params = fn( + self, default, multiparams, params + ) try: try: @@ -972,8 +993,7 @@ class Connection(Connectable): conn = self._revalidate_connection() dialect = self.dialect - ctx = dialect.execution_ctx_cls._init_default( - dialect, self, conn) + ctx = dialect.execution_ctx_cls._init_default(dialect, self, conn) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @@ -982,8 +1002,9 @@ class Connection(Connectable): self.close() if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - default, multiparams, params, ret) + self.dispatch.after_execute( + self, default, multiparams, params, ret + ) return ret @@ -992,25 +1013,25 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - ddl, multiparams, params = \ - fn(self, ddl, multiparams, params) + ddl, multiparams, params = fn(self, ddl, multiparams, params) dialect = self.dialect compiled = ddl.compile( dialect=dialect, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_ddl, compiled, None, - compiled + compiled, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - ddl, multiparams, params, ret) + self.dispatch.after_execute(self, ddl, multiparams, params, ret) return 
ret def _execute_clauseelement(self, elem, multiparams, params): @@ -1018,8 +1039,7 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - elem, multiparams, params = \ - fn(self, elem, multiparams, params) + elem, multiparams, params = fn(self, elem, multiparams, params) distilled_params = _distill_params(multiparams, params) if distilled_params: @@ -1030,38 +1050,45 @@ class Connection(Connectable): keys = [] dialect = self.dialect - if 'compiled_cache' in self._execution_options: + if "compiled_cache" in self._execution_options: key = ( - dialect, elem, tuple(sorted(keys)), + dialect, + elem, + tuple(sorted(keys)), self.schema_for_object.hash_key, - len(distilled_params) > 1 + len(distilled_params) > 1, ) - compiled_sql = self._execution_options['compiled_cache'].get(key) + compiled_sql = self._execution_options["compiled_cache"].get(key) if compiled_sql is None: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None + if not self.schema_for_object.is_default + else None, ) - self._execution_options['compiled_cache'][key] = compiled_sql + self._execution_options["compiled_cache"][key] = compiled_sql else: compiled_sql = elem.compile( - dialect=dialect, column_keys=keys, + dialect=dialect, + column_keys=keys, inline=len(distilled_params) > 1, schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default else None) + if not self.schema_for_object.is_default + else None, + ) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_compiled, compiled_sql, distilled_params, - compiled_sql, distilled_params + compiled_sql, + distilled_params, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - elem, multiparams, params, ret) + self.dispatch.after_execute(self, elem, multiparams, params, ret) return ret def _execute_compiled(self, compiled, multiparams, params): @@ -1069,8 +1096,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - compiled, multiparams, params = \ - fn(self, compiled, multiparams, params) + compiled, multiparams, params = fn( + self, compiled, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1079,11 +1107,13 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_compiled, compiled, parameters, - compiled, parameters + compiled, + parameters, ) if self._has_events or self.engine._has_events: - self.dispatch.after_execute(self, - compiled, multiparams, params, ret) + self.dispatch.after_execute( + self, compiled, multiparams, params, ret + ) return ret def _execute_text(self, statement, multiparams, params): @@ -1091,8 +1121,9 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: for fn in self.dispatch.before_execute: - statement, multiparams, params = \ - fn(self, statement, multiparams, params) + statement, multiparams, params = fn( + self, statement, multiparams, params + ) dialect = self.dialect parameters = _distill_params(multiparams, params) @@ -1101,16 +1132,18 @@ class Connection(Connectable): dialect.execution_ctx_cls._init_statement, statement, parameters, - statement, parameters + statement, + parameters, ) if self._has_events or self.engine._has_events: - 
self.dispatch.after_execute(self, - statement, multiparams, params, ret) + self.dispatch.after_execute( + self, statement, multiparams, params, ret + ) return ret - def _execute_context(self, dialect, constructor, - statement, parameters, - *args): + def _execute_context( + self, dialect, constructor, statement, parameters, *args + ): """Create an :class:`.ExecutionContext` and execute, returning a :class:`.ResultProxy`.""" @@ -1127,31 +1160,36 @@ class Connection(Connectable): context = constructor(dialect, self, conn, *args) except BaseException as e: self._handle_dbapi_exception( - e, - util.text_type(statement), parameters, - None, None) + e, util.text_type(statement), parameters, None, None + ) if context.compiled: context.pre_exec() - cursor, statement, parameters = context.cursor, \ - context.statement, \ - context.parameters + cursor, statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) if not context.executemany: parameters = parameters[0] if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, context.executemany) + statement, parameters = fn( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info( - "%r", - sql_util._repr_params(parameters, batches=10) + "%r", sql_util._repr_params(parameters, batches=10) ) evt_handled = False @@ -1164,10 +1202,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_executemany( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) elif not parameters and context.no_parameters: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute_no_params: @@ -1176,9 +1212,8 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute_no_params( - cursor, - statement, - context) + cursor, statement, context + ) else: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute: @@ -1187,24 +1222,22 @@ class Connection(Connectable): break if not evt_handled: self.dialect.do_execute( - cursor, - statement, - parameters, - context) + cursor, statement, parameters, context + ) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - context.executemany) + self.dispatch.after_cursor_execute( + self, + cursor, + statement, + parameters, + context, + context.executemany, + ) if context.compiled: context.post_exec() @@ -1245,39 +1278,32 @@ class Connection(Connectable): """ if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = \ - fn(self, cursor, statement, parameters, - context, - False) + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) if self._echo: self.engine.logger.info(statement) self.engine.logger.info("%r", parameters) try: - for fn in () if not self.dialect._has_events \ - else self.dialect.dispatch.do_execute: + for fn in ( + () + if not self.dialect._has_events + else self.dialect.dispatch.do_execute + ): if fn(cursor, statement, parameters, context): break else: - self.dialect.do_execute( - cursor, - statement, - parameters, - context) + 
self.dialect.do_execute(cursor, statement, parameters, context) except BaseException as e: self._handle_dbapi_exception( - e, - statement, - parameters, - cursor, - context) + e, statement, parameters, cursor, context + ) if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute(self, cursor, - statement, - parameters, - context, - False) + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) def _safe_close_cursor(self, cursor): """Close the given cursor, catching exceptions @@ -1289,17 +1315,15 @@ class Connection(Connectable): except Exception: # log the error through the connection pool's logger. self.engine.pool.logger.error( - "Error closing cursor", exc_info=True) + "Error closing cursor", exc_info=True + ) _reentrant_error = False _is_disconnect = False - def _handle_dbapi_exception(self, - e, - statement, - parameters, - cursor, - context): + def _handle_dbapi_exception( + self, e, statement, parameters, cursor, context + ): exc_info = sys.exc_info() if context and context.exception is None: @@ -1309,15 +1333,14 @@ class Connection(Connectable): if not self._is_disconnect: self._is_disconnect = ( - isinstance(e, self.dialect.dbapi.Error) and - not self.closed and - self.dialect.is_disconnect( + isinstance(e, self.dialect.dbapi.Error) + and not self.closed + and self.dialect.is_disconnect( e, self.__connection if not self.invalidated else None, - cursor) - ) or ( - is_exit_exception and not self.closed - ) + cursor, + ) + ) or (is_exit_exception and not self.closed) if context: context.is_disconnect = self._is_disconnect @@ -1326,20 +1349,24 @@ class Connection(Connectable): if self._reentrant_error: util.raise_from_cause( - exc.DBAPIError.instance(statement, - parameters, - e, - self.dialect.dbapi.Error, - dialect=self.dialect), - exc_info + exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.dbapi.Error, + dialect=self.dialect, + ), + exc_info, ) self._reentrant_error = True try: # non-DBAPI error - if we already got a context, # or there's no string statement, don't wrap it - should_wrap = isinstance(e, self.dialect.dbapi.Error) or \ - (statement is not None - and context is None and not is_exit_exception) + should_wrap = isinstance(e, self.dialect.dbapi.Error) or ( + statement is not None + and context is None + and not is_exit_exception + ) if should_wrap: sqlalchemy_exception = exc.DBAPIError.instance( @@ -1348,30 +1375,37 @@ class Connection(Connectable): e, self.dialect.dbapi.Error, connection_invalidated=self._is_disconnect, - dialect=self.dialect) + dialect=self.dialect, + ) else: sqlalchemy_exception = None newraise = None - if (self._has_events or self.engine._has_events) and \ - not self._execution_options.get( - 'skip_user_error_events', False): + if ( + self._has_events or self.engine._has_events + ) and not self._execution_options.get( + "skip_user_error_events", False + ): # legacy dbapi_error event if should_wrap and context: - self.dispatch.dbapi_error(self, - cursor, - statement, - parameters, - context, - e) + self.dispatch.dbapi_error( + self, cursor, statement, parameters, context, e + ) # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self.engine, - self, cursor, statement, - parameters, context, self._is_disconnect, - invalidate_pool_on_disconnect) + e, + sqlalchemy_exception, + self.engine, + self, + cursor, + statement, + parameters, + context, + self._is_disconnect, + invalidate_pool_on_disconnect, + ) for fn in 
self.dispatch.handle_error: try: @@ -1388,13 +1422,15 @@ class Connection(Connectable): if self._is_disconnect != ctx.is_disconnect: self._is_disconnect = ctx.is_disconnect if sqlalchemy_exception: - sqlalchemy_exception.connection_invalidated = \ + sqlalchemy_exception.connection_invalidated = ( ctx.is_disconnect + ) # set up potentially user-defined value for # invalidate pool. - invalidate_pool_on_disconnect = \ + invalidate_pool_on_disconnect = ( ctx.invalidate_pool_on_disconnect + ) if should_wrap and context: context.handle_dbapi_exception(e) @@ -1408,10 +1444,7 @@ class Connection(Connectable): if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1441,7 +1474,8 @@ class Connection(Connectable): None, e, dialect.dbapi.Error, - connection_invalidated=is_disconnect) + connection_invalidated=is_disconnect, + ) else: sqlalchemy_exception = None @@ -1449,8 +1483,17 @@ class Connection(Connectable): if engine._has_events: ctx = ExceptionContextImpl( - e, sqlalchemy_exception, engine, None, None, None, - None, None, is_disconnect, True) + e, + sqlalchemy_exception, + engine, + None, + None, + None, + None, + None, + is_disconnect, + True, + ) for fn in engine.dispatch.handle_error: try: # handler returns an exception; @@ -1463,18 +1506,15 @@ class Connection(Connectable): newraise = _raised break - if sqlalchemy_exception and \ - is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = \ - is_disconnect = ctx.is_disconnect + if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = ( + is_disconnect + ) = ctx.is_disconnect if newraise: util.raise_from_cause(newraise, exc_info) elif should_wrap: - util.raise_from_cause( - sqlalchemy_exception, - exc_info - ) + util.raise_from_cause(sqlalchemy_exception, exc_info) else: util.reraise(*exc_info) @@ -1545,16 +1585,25 @@ class Connection(Connectable): return callable_(self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, **kwargs): - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" - def __init__(self, exception, sqlalchemy_exception, - engine, connection, cursor, statement, parameters, - context, is_disconnect, invalidate_pool_on_disconnect): + def __init__( + self, + exception, + sqlalchemy_exception, + engine, + connection, + cursor, + statement, + parameters, + context, + is_disconnect, + invalidate_pool_on_disconnect, + ): self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception @@ -1691,12 +1740,14 @@ class NestedTransaction(Transaction): def _do_rollback(self): if self.is_active: self.connection._rollback_to_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) def _do_commit(self): if self.is_active: self.connection._release_savepoint_impl( - self._savepoint, self._parent) + self._savepoint, self._parent + ) class TwoPhaseTransaction(Transaction): @@ -1771,10 +1822,16 @@ class Engine(Connectable, log.Identified): """ - def __init__(self, pool, dialect, url, - logging_name=None, echo=None, proxy=None, - execution_options=None - ): + def __init__( + self, + pool, + dialect, + url, + logging_name=None, + 
echo=None, + proxy=None, + execution_options=None, + ): self.pool = pool self.url = url self.dialect = dialect @@ -1805,8 +1862,7 @@ class Engine(Connectable, log.Identified): :meth:`.Engine.execution_options` """ - self._execution_options = \ - self._execution_options.union(opt) + self._execution_options = self._execution_options.union(opt) self.dispatch.set_engine_execution_options(self, opt) self.dialect.set_engine_execution_options(self, opt) @@ -1894,7 +1950,7 @@ class Engine(Connectable, log.Identified): echo = log.echo_property() def __repr__(self): - return 'Engine(%r)' % self.url + return "Engine(%r)" % self.url def dispose(self): """Dispose of the connection pool used by this :class:`.Engine`. @@ -1934,8 +1990,9 @@ class Engine(Connectable, log.Identified): else: yield connection - def _run_visitor(self, visitorcallable, element, - connection=None, **kwargs): + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): with self._optional_conn_ctx_manager(connection) as conn: conn._run_visitor(visitorcallable, element, **kwargs) @@ -2122,7 +2179,8 @@ class Engine(Connectable, log.Identified): self, self._wrap_pool_connect(self.pool.connect, None), close_with_result=close_with_result, - **kwargs) + **kwargs + ) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. @@ -2159,7 +2217,8 @@ class Engine(Connectable, log.Identified): except dialect.dbapi.Error as e: if connection is None: Connection._handle_dbapi_exception_noconnection( - e, dialect, self) + e, dialect, self + ) else: util.reraise(*sys.exc_info()) @@ -2185,7 +2244,8 @@ class Engine(Connectable, log.Identified): """ return self._wrap_pool_connect( - self.pool.unique_connection, _connection) + self.pool.unique_connection, _connection + ) class OptionEngine(Engine): @@ -2225,10 +2285,11 @@ class OptionEngine(Engine): pool = property(_get_pool, _set_pool) def _get_has_events(self): - return self._proxied._has_events or \ - self.__dict__.get('_has_events', False) + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) def _set_has_events(self, value): - self.__dict__['_has_events'] = value + self.__dict__["_has_events"] = value _has_events = property(_get_has_events, _set_has_events) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 028abc4c24..d7c2518fe0 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -24,13 +24,11 @@ import weakref from .. 
import event AUTOCOMMIT_REGEXP = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE +) # When we're handed literal SQL, ensure it's a SELECT query -SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) +SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) class DefaultDialect(interfaces.Dialect): @@ -68,16 +66,18 @@ class DefaultDialect(interfaces.Dialect): supports_simple_order_by_label = True - engine_config_types = util.immutabledict([ - ('convert_unicode', util.bool_or_str('force')), - ('pool_timeout', util.asint), - ('echo', util.bool_or_str('debug')), - ('echo_pool', util.bool_or_str('debug')), - ('pool_recycle', util.asint), - ('pool_size', util.asint), - ('max_overflow', util.asint), - ('pool_threadlocal', util.asbool), - ]) + engine_config_types = util.immutabledict( + [ + ("convert_unicode", util.bool_or_str("force")), + ("pool_timeout", util.asint), + ("echo", util.bool_or_str("debug")), + ("echo_pool", util.bool_or_str("debug")), + ("pool_recycle", util.asint), + ("pool_size", util.asint), + ("max_overflow", util.asint), + ("pool_threadlocal", util.asbool), + ] + ) # if the NUMERIC type # returns decimal.Decimal. @@ -93,9 +93,9 @@ class DefaultDialect(interfaces.Dialect): supports_unicode_statements = False supports_unicode_binds = False returns_unicode_strings = False - description_encoding = 'use_encoding' + description_encoding = "use_encoding" - name = 'default' + name = "default" # length at which to truncate # any identifier. @@ -111,7 +111,7 @@ class DefaultDialect(interfaces.Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True colspecs = {} - default_paramstyle = 'named' + default_paramstyle = "named" supports_default_values = False supports_empty_insert = True supports_multivalues_insert = False @@ -175,19 +175,26 @@ class DefaultDialect(interfaces.Dialect): """ - def __init__(self, convert_unicode=False, - encoding='utf-8', paramstyle=None, dbapi=None, - implicit_returning=None, - supports_right_nested_joins=None, - case_sensitive=True, - supports_native_boolean=None, - empty_in_strategy='static', - label_length=None, **kwargs): - - if not getattr(self, 'ported_sqla_06', True): + def __init__( + self, + convert_unicode=False, + encoding="utf-8", + paramstyle=None, + dbapi=None, + implicit_returning=None, + supports_right_nested_joins=None, + case_sensitive=True, + supports_native_boolean=None, + empty_in_strategy="static", + label_length=None, + **kwargs + ): + + if not getattr(self, "ported_sqla_06", True): util.warn( - "The %s dialect is not yet ported to the 0.6 format" % - self.name) + "The %s dialect is not yet ported to the 0.6 format" + % self.name + ) self.convert_unicode = convert_unicode self.encoding = encoding @@ -202,7 +209,7 @@ class DefaultDialect(interfaces.Dialect): self.paramstyle = self.default_paramstyle if implicit_returning is not None: self.implicit_returning = implicit_returning - self.positional = self.paramstyle in ('qmark', 'format', 'numeric') + self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) if supports_right_nested_joins is not None: @@ -212,33 +219,33 @@ class DefaultDialect(interfaces.Dialect): self.case_sensitive = case_sensitive self.empty_in_strategy = empty_in_strategy - if empty_in_strategy == 'static': + if empty_in_strategy == "static": self._use_static_in = 
True - elif empty_in_strategy in ('dynamic', 'dynamic_warn'): + elif empty_in_strategy in ("dynamic", "dynamic_warn"): self._use_static_in = False - self._warn_on_empty_in = empty_in_strategy == 'dynamic_warn' + self._warn_on_empty_in = empty_in_strategy == "dynamic_warn" else: raise exc.ArgumentError( "empty_in_strategy may be 'static', " - "'dynamic', or 'dynamic_warn'") + "'dynamic', or 'dynamic_warn'" + ) if label_length and label_length > self.max_identifier_length: raise exc.ArgumentError( "Label length of %d is greater than this dialect's" - " maximum identifier length of %d" % - (label_length, self.max_identifier_length)) + " maximum identifier length of %d" + % (label_length, self.max_identifier_length) + ) self.label_length = label_length - if self.description_encoding == 'use_encoding': - self._description_decoder = \ - processors.to_unicode_processor_factory( - encoding - ) + if self.description_encoding == "use_encoding": + self._description_decoder = processors.to_unicode_processor_factory( + encoding + ) elif self.description_encoding is not None: - self._description_decoder = \ - processors.to_unicode_processor_factory( - self.description_encoding - ) + self._description_decoder = processors.to_unicode_processor_factory( + self.description_encoding + ) self._encoder = codecs.getencoder(self.encoding) self._decoder = processors.to_unicode_processor_factory(self.encoding) @@ -256,30 +263,35 @@ class DefaultDialect(interfaces.Dialect): @classmethod def get_pool_class(cls, url): - return getattr(cls, 'poolclass', pool.QueuePool) + return getattr(cls, "poolclass", pool.QueuePool) def initialize(self, connection): try: - self.server_version_info = \ - self._get_server_version_info(connection) + self.server_version_info = self._get_server_version_info( + connection + ) except NotImplementedError: self.server_version_info = None try: - self.default_schema_name = \ - self._get_default_schema_name(connection) + self.default_schema_name = self._get_default_schema_name( + connection + ) except NotImplementedError: self.default_schema_name = None try: - self.default_isolation_level = \ - self.get_isolation_level(connection.connection) + self.default_isolation_level = self.get_isolation_level( + connection.connection + ) except NotImplementedError: self.default_isolation_level = None self.returns_unicode_strings = self._check_unicode_returns(connection) - if self.description_encoding is not None and \ - self._check_unicode_description(connection): + if ( + self.description_encoding is not None + and self._check_unicode_description(connection) + ): self._description_decoder = self.description_encoding = None self.do_rollback(connection.connection) @@ -311,7 +323,8 @@ class DefaultDialect(interfaces.Dialect): def check_unicode(test): statement = cast_to( - expression.select([test]).compile(dialect=self)) + expression.select([test]).compile(dialect=self) + ) try: cursor = connection.connection.cursor() connection._cursor_execute(cursor, statement, parameters) @@ -320,8 +333,10 @@ class DefaultDialect(interfaces.Dialect): except exc.DBAPIError as de: # note that _cursor_execute() will have closed the cursor # if an exception is thrown. 
- util.warn("Exception attempting to " - "detect unicode returns: %r" % de) + util.warn( + "Exception attempting to " + "detect unicode returns: %r" % de + ) return False else: return isinstance(row[0], util.text_type) @@ -330,13 +345,13 @@ class DefaultDialect(interfaces.Dialect): # detect plain VARCHAR expression.cast( expression.literal_column("'test plain returns'"), - sqltypes.VARCHAR(60) + sqltypes.VARCHAR(60), ), # detect if there's an NVARCHAR type with different behavior # available expression.cast( expression.literal_column("'test unicode returns'"), - sqltypes.Unicode(60) + sqltypes.Unicode(60), ), ] @@ -364,9 +379,9 @@ class DefaultDialect(interfaces.Dialect): try: cursor.execute( cast_to( - expression.select([ - expression.literal_column("'x'").label("some_label") - ]).compile(dialect=self) + expression.select( + [expression.literal_column("'x'").label("some_label")] + ).compile(dialect=self) ) ) return isinstance(cursor.description[0][0], util.text_type) @@ -385,10 +400,12 @@ class DefaultDialect(interfaces.Dialect): return sqltypes.adapt_type(typeobj, self.colspecs) def reflecttable( - self, connection, table, include_columns, exclude_columns, **opts): + self, connection, table, include_columns, exclude_columns, **opts + ): insp = reflection.Inspector.from_engine(connection) return insp.reflecttable( - table, include_columns, exclude_columns, **opts) + table, include_columns, exclude_columns, **opts + ) def get_pk_constraint(self, conn, table_name, schema=None, **kw): """Compatibility method, adapts the result of get_primary_keys() @@ -396,16 +413,16 @@ class DefaultDialect(interfaces.Dialect): """ return { - 'constrained_columns': - self.get_primary_keys(conn, table_name, - schema=schema, **kw) + "constrained_columns": self.get_primary_keys( + conn, table_name, schema=schema, **kw + ) } def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( - "Identifier '%s' exceeds maximum length of %d characters" % - (ident, self.max_identifier_length) + "Identifier '%s' exceeds maximum length of %d characters" + % (ident, self.max_identifier_length) ) def connect(self, *cargs, **cparams): @@ -417,16 +434,16 @@ class DefaultDialect(interfaces.Dialect): return [[], opts] def set_engine_execution_options(self, engine, opts): - if 'isolation_level' in opts: - isolation_level = opts['isolation_level'] + if "isolation_level" in opts: + isolation_level = opts["isolation_level"] @event.listens_for(engine, "engine_connect") def set_isolation(connection, branch): if not branch: self._set_connection_isolation(connection, isolation_level) - if 'schema_translate_map' in opts: - getter = schema._schema_getter(opts['schema_translate_map']) + if "schema_translate_map" in opts: + getter = schema._schema_getter(opts["schema_translate_map"]) engine.schema_for_object = getter @event.listens_for(engine, "engine_connect") @@ -434,11 +451,11 @@ class DefaultDialect(interfaces.Dialect): connection.schema_for_object = getter def set_connection_execution_options(self, connection, opts): - if 'isolation_level' in opts: - self._set_connection_isolation(connection, opts['isolation_level']) + if "isolation_level" in opts: + self._set_connection_isolation(connection, opts["isolation_level"]) - if 'schema_translate_map' in opts: - getter = schema._schema_getter(opts['schema_translate_map']) + if "schema_translate_map" in opts: + getter = schema._schema_getter(opts["schema_translate_map"]) connection.schema_for_object = getter def _set_connection_isolation(self, 
connection, level): @@ -447,10 +464,12 @@ class DefaultDialect(interfaces.Dialect): "Connection is already established with a Transaction; " "setting isolation_level may implicitly rollback or commit " "the existing transaction, or have no effect until " - "next transaction") + "next transaction" + ) self.set_isolation_level(connection.connection, level) - connection.connection._connection_record.\ - finalize_callback.append(self.reset_isolation_level) + connection.connection._connection_record.finalize_callback.append( + self.reset_isolation_level + ) def do_begin(self, dbapi_connection): pass @@ -593,8 +612,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self @classmethod - def _init_compiled(cls, dialect, connection, dbapi_connection, - compiled, parameters): + def _init_compiled( + cls, dialect, connection, dbapi_connection, compiled, parameters + ): """Initialize execution context for a Compiled construct.""" self = cls.__new__(cls) @@ -609,16 +629,20 @@ class DefaultExecutionContext(interfaces.ExecutionContext): assert compiled.can_execute self.execution_options = compiled.execution_options.union( - connection._execution_options) + connection._execution_options + ) self.result_column_struct = ( - compiled._result_columns, compiled._ordered_columns, - compiled._textual_ordered_columns) + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + ) self.unicode_statement = util.text_type(compiled) if not dialect.supports_unicode_statements: self.statement = self.unicode_statement.encode( - self.dialect.encoding) + self.dialect.encoding + ) else: self.statement = self.unicode_statement @@ -630,9 +654,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not parameters: self.compiled_parameters = [compiled.construct_params()] else: - self.compiled_parameters = \ - [compiled.construct_params(m, _group_number=grp) for - grp, m in enumerate(parameters)] + self.compiled_parameters = [ + compiled.construct_params(m, _group_number=grp) + for grp, m in enumerate(parameters) + ] self.executemany = len(parameters) > 1 @@ -642,7 +667,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( - compiled.returning and not compiled.statement._returning) + compiled.returning and not compiled.statement._returning + ) if self.compiled.insert_prefetch or self.compiled.update_prefetch: if self.executemany: @@ -680,7 +706,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect._encoder(key)[0], processors[key](compiled_params[key]) if key in processors - else compiled_params[key] + else compiled_params[key], ) for key in compiled_params ) @@ -690,7 +716,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): key, processors[key](compiled_params[key]) if key in processors - else compiled_params[key] + else compiled_params[key], ) for key in compiled_params ) @@ -708,14 +734,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ if self.executemany: raise exc.InvalidRequestError( - "'expanding' parameters can't be used with " - "executemany()") + "'expanding' parameters can't be used with " "executemany()" + ) if self.compiled.positional and self.compiled._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric' raise NotImplementedError( "'expanding' bind parameters not supported with " - "'numeric' paramstyle at this time.") + "'numeric' paramstyle at this 
time." + ) self._expanded_parameters = {} @@ -729,7 +756,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): to_update_sets = {} for name in ( - self.compiled.positiontup if compiled.positional + self.compiled.positiontup + if compiled.positional else self.compiled.binds ): parameter = self.compiled.binds[name] @@ -748,12 +776,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not values: to_update = to_update_sets[name] = [] - replacement_expressions[name] = ( - self.compiled.visit_empty_set_expr( - parameter._expanding_in_types - if parameter._expanding_in_types - else [parameter.type] - ) + replacement_expressions[ + name + ] = self.compiled.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] ) elif isinstance(values[0], (tuple, list)): @@ -763,15 +791,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for j, value in enumerate(tuple_element, 1) ] replacement_expressions[name] = ", ".join( - "(%s)" % ", ".join( - self.compiled.bindtemplate % { - "name": - to_update[i * len(tuple_element) + j][0] + "(%s)" + % ", ".join( + self.compiled.bindtemplate + % { + "name": to_update[ + i * len(tuple_element) + j + ][0] } for j, value in enumerate(tuple_element) ) for i, tuple_element in enumerate(values) - ) else: to_update = to_update_sets[name] = [ @@ -779,20 +809,21 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for i, value in enumerate(values, 1) ] replacement_expressions[name] = ", ".join( - self.compiled.bindtemplate % { - "name": key} + self.compiled.bindtemplate % {"name": key} for key, value in to_update ) compiled_params.update(to_update) processors.update( (key, processors[name]) - for key, value in to_update if name in processors + for key, value in to_update + if name in processors ) if compiled.positional: positiontup.extend(name for name, value in to_update) self._expanded_parameters[name] = [ - expand_key for expand_key, value in to_update] + expand_key for expand_key, value in to_update + ] elif compiled.positional: positiontup.append(name) @@ -800,15 +831,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return replacement_expressions[m.group(1)] self.statement = re.sub( - r"\[EXPANDING_(\S+)\]", - process_expanding, - self.statement + r"\[EXPANDING_(\S+)\]", process_expanding, self.statement ) return positiontup @classmethod - def _init_statement(cls, dialect, connection, dbapi_connection, - statement, parameters): + def _init_statement( + cls, dialect, connection, dbapi_connection, statement, parameters + ): """Initialize execution context for a string SQL statement.""" self = cls.__new__(cls) @@ -836,13 +866,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): for d in parameters ] or [{}] else: - self.parameters = [dialect.execute_sequence_format(p) - for p in parameters] + self.parameters = [ + dialect.execute_sequence_format(p) for p in parameters + ] self.executemany = len(parameters) > 1 - if not dialect.supports_unicode_statements and \ - isinstance(statement, util.text_type): + if not dialect.supports_unicode_statements and isinstance( + statement, util.text_type + ): self.unicode_statement = statement self.statement = dialect._encoder(statement)[0] else: @@ -890,11 +922,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @util.memoized_property def should_autocommit(self): - autocommit = self.execution_options.get('autocommit', - not self.compiled and - self.statement and - expression.PARSE_AUTOCOMMIT - 
or False) + autocommit = self.execution_options.get( + "autocommit", + not self.compiled + and self.statement + and expression.PARSE_AUTOCOMMIT + or False, + ) if autocommit is expression.PARSE_AUTOCOMMIT: return self.should_autocommit_text(self.unicode_statement) @@ -912,8 +946,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ conn = self.root_connection - if isinstance(stmt, util.text_type) and \ - not self.dialect.supports_unicode_statements: + if ( + isinstance(stmt, util.text_type) + and not self.dialect.supports_unicode_statements + ): stmt = self.dialect._encoder(stmt)[0] if self.dialect.positional: @@ -926,8 +962,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if type_ is not None: # apply type post processors to the result proc = type_._cached_result_processor( - self.dialect, - self.cursor.description[0][1] + self.dialect, self.cursor.description[0][1] ) if proc: return proc(r) @@ -945,22 +980,30 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return False if self.dialect.server_side_cursors: - use_server_side = \ - self.execution_options.get('stream_results', True) and ( - (self.compiled and isinstance(self.compiled.statement, - expression.Selectable) - or - ( - (not self.compiled or - isinstance(self.compiled.statement, - expression.TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match( - self.statement)) - ) + use_server_side = self.execution_options.get( + "stream_results", True + ) and ( + ( + self.compiled + and isinstance( + self.compiled.statement, expression.Selectable + ) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause + ) + ) + and self.statement + and SERVER_SIDE_CURSOR_RE.match(self.statement) + ) ) + ) else: - use_server_side = \ - self.execution_options.get('stream_results', False) + use_server_side = self.execution_options.get( + "stream_results", False + ) return use_server_side @@ -1039,11 +1082,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.dialect.supports_sane_multi_rowcount def _setup_crud_result_proxy(self): - if self.isinsert and \ - not self.executemany: - if not self._is_implicit_returning and \ - not self.compiled.inline and \ - self.dialect.postfetch_lastrowid: + if self.isinsert and not self.executemany: + if ( + not self._is_implicit_returning + and not self.compiled.inline + and self.dialect.postfetch_lastrowid + ): self._setup_ins_pk_from_lastrowid() @@ -1087,12 +1131,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if autoinc_col is not None: # apply type post processors to the lastrowid proc = autoinc_col.type._cached_result_processor( - self.dialect, None) + self.dialect, None + ) if proc is not None: lastrowid = proc(lastrowid) self.inserted_primary_key = [ - lastrowid if c is autoinc_col else - compiled_params.get(key_getter(c), None) + lastrowid + if c is autoinc_col + else compiled_params.get(key_getter(c), None) for c in table.primary_key ] else: @@ -1108,8 +1154,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): table = self.compiled.statement.table compiled_params = self.compiled_parameters[0] self.inserted_primary_key = [ - compiled_params.get(key_getter(c), None) - for c in table.primary_key + compiled_params.get(key_getter(c), None) for c in table.primary_key ] def _setup_ins_pk_from_implicit_returning(self, row): @@ -1129,11 +1174,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ] def lastrow_has_defaults(self): - return (self.isinsert or 
self.isupdate) and \ - bool(self.compiled.postfetch) + return (self.isinsert or self.isupdate) and bool( + self.compiled.postfetch + ) def set_input_sizes( - self, translate=None, include_types=None, exclude_types=None): + self, translate=None, include_types=None, exclude_types=None + ): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. @@ -1143,7 +1190,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): """ - if not hasattr(self.compiled, 'bind_names'): + if not hasattr(self.compiled, "bind_names"): return inputsizes = {} @@ -1153,12 +1200,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect_impl_cls = type(dialect_impl) dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi) - if dbtype is not None and ( - not exclude_types or dbtype not in exclude_types and - dialect_impl_cls not in exclude_types - ) and ( - not include_types or dbtype in include_types or - dialect_impl_cls in include_types + if ( + dbtype is not None + and ( + not exclude_types + or dbtype not in exclude_types + and dialect_impl_cls not in exclude_types + ) + and ( + not include_types + or dbtype in include_types + or dialect_impl_cls in include_types + ) ): inputsizes[bindparam] = dbtype else: @@ -1177,14 +1230,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if dbtype is not None: if key in self._expanded_parameters: positional_inputsizes.extend( - [dbtype] * len(self._expanded_parameters[key])) + [dbtype] * len(self._expanded_parameters[key]) + ) else: positional_inputsizes.append(dbtype) try: self.cursor.setinputsizes(*positional_inputsizes) except BaseException as e: self.root_connection._handle_dbapi_exception( - e, None, None, None, self) + e, None, None, None, self + ) else: keyword_inputsizes = {} for bindparam, key in self.compiled.bind_names.items(): @@ -1199,8 +1254,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): key = self.dialect._encoder(key)[0] if key in self._expanded_parameters: keyword_inputsizes.update( - (expand_key, dbtype) for expand_key - in self._expanded_parameters[key] + (expand_key, dbtype) + for expand_key in self._expanded_parameters[key] ) else: keyword_inputsizes[key] = dbtype @@ -1208,7 +1263,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor.setinputsizes(**keyword_inputsizes) except BaseException as e: self.root_connection._handle_dbapi_exception( - e, None, None, None, self) + e, None, None, None, self + ) def _exec_default(self, column, default, type_): if default.is_sequence: @@ -1290,10 +1346,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): except AttributeError: raise exc.InvalidRequestError( "get_current_parameters() can only be invoked in the " - "context of a Python side column default function") - if isolate_multiinsert_groups and \ - self.isinsert and \ - self.compiled.statement._has_multi_parameters: + "context of a Python side column default function" + ) + if ( + isolate_multiinsert_groups + and self.isinsert + and self.compiled.statement._has_multi_parameters + ): if column._is_multiparam_column: index = column.index + 1 d = {column.original.key: parameters[column.key]} @@ -1302,8 +1361,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): index = 0 keys = self.compiled.statement.parameters[0].keys() d.update( - (key, parameters["%s_m%d" % (key, index)]) - for key in keys + (key, parameters["%s_m%d" % (key, index)]) for key in 
keys ) return d else: @@ -1360,12 +1418,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - self.current_parameters = compiled_parameters = \ - self.compiled_parameters[0] + self.current_parameters = ( + compiled_parameters + ) = self.compiled_parameters[0] for c in self.compiled.insert_prefetch: - if c.default and \ - not c.default.is_sequence and c.default.is_scalar: + if c.default and not c.default.is_sequence and c.default.is_scalar: val = c.default.arg else: val = self.get_insert_default(c) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 9c3b24e9a8..e10e6e8844 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -198,7 +198,8 @@ class Dialect(object): pass def reflecttable( - self, connection, table, include_columns, exclude_columns): + self, connection, table, include_columns, exclude_columns + ): """Load table description from the database. Given a :class:`.Connection` and a @@ -367,7 +368,8 @@ class Dialect(object): raise NotImplementedError() def get_unique_constraints( - self, connection, table_name, schema=None, **kw): + self, connection, table_name, schema=None, **kw + ): r"""Return information about unique constraints in `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -389,8 +391,7 @@ class Dialect(object): raise NotImplementedError() - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): r"""Return information about check constraints in `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -412,8 +413,7 @@ class Dialect(object): raise NotImplementedError() - def get_table_comment( - self, connection, table_name, schema=None, **kw): + def get_table_comment(self, connection, table_name, schema=None, **kw): r"""Return the "comment" for the table identified by `table_name`. Given a string `table_name` and an optional string `schema`, return @@ -613,8 +613,9 @@ class Dialect(object): raise NotImplementedError() - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): """Rollback a two phase transaction on the given connection. :param connection: a :class:`.Connection`. @@ -627,8 +628,9 @@ class Dialect(object): raise NotImplementedError() - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): """Commit a two phase transaction on the given connection. @@ -664,8 +666,9 @@ class Dialect(object): raise NotImplementedError() - def do_execute_no_params(self, cursor, statement, parameters, - context=None): + def do_execute_no_params( + self, cursor, statement, parameters, context=None + ): """Provide an implementation of ``cursor.execute(statement)``. The parameter collection should not be sent. @@ -899,6 +902,7 @@ class CreateEnginePlugin(object): .. versionadded:: 1.1 """ + def __init__(self, url, kwargs): """Contruct a new :class:`.CreateEnginePlugin`. @@ -1129,20 +1133,24 @@ class Connectable(object): raise NotImplementedError() - @util.deprecated("0.7", - "Use the create() method on the given schema " - "object directly, i.e. 
:meth:`.Table.create`, " - ":meth:`.Index.create`, :meth:`.MetaData.create_all`") + @util.deprecated( + "0.7", + "Use the create() method on the given schema " + "object directly, i.e. :meth:`.Table.create`, " + ":meth:`.Index.create`, :meth:`.MetaData.create_all`", + ) def create(self, entity, **kwargs): """Emit CREATE statements for the given schema entity. """ raise NotImplementedError() - @util.deprecated("0.7", - "Use the drop() method on the given schema " - "object directly, i.e. :meth:`.Table.drop`, " - ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`") + @util.deprecated( + "0.7", + "Use the drop() method on the given schema " + "object directly, i.e. :meth:`.Table.drop`, " + ":meth:`.Index.drop`, :meth:`.MetaData.drop_all`", + ) def drop(self, entity, **kwargs): """Emit DROP statements for the given schema entity. """ @@ -1160,8 +1168,7 @@ class Connectable(object): """ raise NotImplementedError() - def _run_visitor(self, visitorcallable, element, - **kwargs): + def _run_visitor(self, visitorcallable, element, **kwargs): raise NotImplementedError() def _execute_clauseelement(self, elem, multiparams=None, params=None): diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 841bb4dfb4..9b5fa24595 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -37,17 +37,17 @@ from .base import Connectable @util.decorator def cache(fn, self, con, *args, **kw): - info_cache = kw.get('info_cache', None) + info_cache = kw.get("info_cache", None) if info_cache is None: return fn(self, con, *args, **kw) key = ( fn.__name__, tuple(a for a in args if isinstance(a, util.string_types)), - tuple((k, v) for k, v in kw.items() if - isinstance(v, - util.string_types + util.int_types + (float, ) - ) - ) + tuple( + (k, v) + for k, v in kw.items() + if isinstance(v, util.string_types + util.int_types + (float,)) + ), ) ret = info_cache.get(key) if ret is None: @@ -99,7 +99,7 @@ class Inspector(object): self.bind = bind # set the engine - if hasattr(bind, 'engine'): + if hasattr(bind, "engine"): self.engine = bind.engine else: self.engine = bind @@ -130,7 +130,7 @@ class Inspector(object): See the example at :class:`.Inspector`. """ - if hasattr(bind.dialect, 'inspector'): + if hasattr(bind.dialect, "inspector"): return bind.dialect.inspector(bind) return Inspector(bind) @@ -153,9 +153,10 @@ class Inspector(object): """Return all schema names. 
""" - if hasattr(self.dialect, 'get_schema_names'): - return self.dialect.get_schema_names(self.bind, - info_cache=self.info_cache) + if hasattr(self.dialect, "get_schema_names"): + return self.dialect.get_schema_names( + self.bind, info_cache=self.info_cache + ) return [] def get_table_names(self, schema=None, order_by=None): @@ -196,17 +197,18 @@ class Inspector(object): """ - if hasattr(self.dialect, 'get_table_names'): + if hasattr(self.dialect, "get_table_names"): tnames = self.dialect.get_table_names( - self.bind, schema, info_cache=self.info_cache) + self.bind, schema, info_cache=self.info_cache + ) else: tnames = self.engine.table_names(schema) - if order_by == 'foreign_key': + if order_by == "foreign_key": tuples = [] for tname in tnames: for fkey in self.get_foreign_keys(tname, schema): - if tname != fkey['referred_table']: - tuples.append((fkey['referred_table'], tname)) + if tname != fkey["referred_table"]: + tuples.append((fkey["referred_table"], tname)) tnames = list(topological.sort(tuples, tnames)) return tnames @@ -234,9 +236,10 @@ class Inspector(object): with an already-given :class:`.MetaData`. """ - if hasattr(self.dialect, 'get_table_names'): + if hasattr(self.dialect, "get_table_names"): tnames = self.dialect.get_table_names( - self.bind, schema, info_cache=self.info_cache) + self.bind, schema, info_cache=self.info_cache + ) else: tnames = self.engine.table_names(schema) @@ -246,20 +249,17 @@ class Inspector(object): fknames_for_table = {} for tname in tnames: fkeys = self.get_foreign_keys(tname, schema) - fknames_for_table[tname] = set( - [fk['name'] for fk in fkeys] - ) + fknames_for_table[tname] = set([fk["name"] for fk in fkeys]) for fkey in fkeys: - if tname != fkey['referred_table']: - tuples.add((fkey['referred_table'], tname)) + if tname != fkey["referred_table"]: + tuples.add((fkey["referred_table"], tname)) try: candidate_sort = list(topological.sort(tuples, tnames)) except exc.CircularDependencyError as err: for edge in err.edges: tuples.remove(edge) remaining_fkcs.update( - (edge[1], fkc) - for fkc in fknames_for_table[edge[1]] + (edge[1], fkc) for fkc in fknames_for_table[edge[1]] ) candidate_sort = list(topological.sort(tuples, tnames)) @@ -278,7 +278,8 @@ class Inspector(object): """ return self.dialect.get_temp_table_names( - self.bind, info_cache=self.info_cache) + self.bind, info_cache=self.info_cache + ) def get_temp_view_names(self): """return a list of temporary view names for the current bind. @@ -290,7 +291,8 @@ class Inspector(object): """ return self.dialect.get_temp_view_names( - self.bind, info_cache=self.info_cache) + self.bind, info_cache=self.info_cache + ) def get_table_options(self, table_name, schema=None, **kw): """Return a dictionary of options specified when the table of the @@ -306,10 +308,10 @@ class Inspector(object): use :class:`.quoted_name`. """ - if hasattr(self.dialect, 'get_table_options'): + if hasattr(self.dialect, "get_table_options"): return self.dialect.get_table_options( - self.bind, table_name, schema, - info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) return {} def get_view_names(self, schema=None): @@ -320,8 +322,9 @@ class Inspector(object): """ - return self.dialect.get_view_names(self.bind, schema, - info_cache=self.info_cache) + return self.dialect.get_view_names( + self.bind, schema, info_cache=self.info_cache + ) def get_view_definition(self, view_name, schema=None): """Return definition for `view_name`. 
@@ -332,7 +335,8 @@ class Inspector(object): """ return self.dialect.get_view_definition( - self.bind, view_name, schema, info_cache=self.info_cache) + self.bind, view_name, schema, info_cache=self.info_cache + ) def get_columns(self, table_name, schema=None, **kw): """Return information about columns in `table_name`. @@ -364,18 +368,21 @@ class Inspector(object): """ - col_defs = self.dialect.get_columns(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + col_defs = self.dialect.get_columns( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) for col_def in col_defs: # make this easy and only return instances for coltype - coltype = col_def['type'] + coltype = col_def["type"] if not isinstance(coltype, TypeEngine): - col_def['type'] = coltype() + col_def["type"] = coltype() return col_defs - @deprecated('0.7', 'Call to deprecated method get_primary_keys.' - ' Use get_pk_constraint instead.') + @deprecated( + "0.7", + "Call to deprecated method get_primary_keys." + " Use get_pk_constraint instead.", + ) def get_primary_keys(self, table_name, schema=None, **kw): """Return information about primary keys in `table_name`. @@ -383,9 +390,9 @@ class Inspector(object): primary key information as a list of column names. """ - return self.dialect.get_pk_constraint(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw)['constrained_columns'] + return self.dialect.get_pk_constraint( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + )["constrained_columns"] def get_pk_constraint(self, table_name, schema=None, **kw): """Return information about primary key constraint on `table_name`. @@ -407,9 +414,9 @@ class Inspector(object): use :class:`.quoted_name`. """ - return self.dialect.get_pk_constraint(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_pk_constraint( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_foreign_keys(self, table_name, schema=None, **kw): """Return information about foreign_keys in `table_name`. @@ -442,9 +449,9 @@ class Inspector(object): """ - return self.dialect.get_foreign_keys(self.bind, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_foreign_keys( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_indexes(self, table_name, schema=None, **kw): """Return information about indexes in `table_name`. @@ -476,9 +483,9 @@ class Inspector(object): """ - return self.dialect.get_indexes(self.bind, table_name, - schema, - info_cache=self.info_cache, **kw) + return self.dialect.get_indexes( + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_unique_constraints(self, table_name, schema=None, **kw): """Return information about unique constraints in `table_name`. @@ -504,7 +511,8 @@ class Inspector(object): """ return self.dialect.get_unique_constraints( - self.bind, table_name, schema, info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_table_comment(self, table_name, schema=None, **kw): """Return information about the table comment for ``table_name``. @@ -523,8 +531,8 @@ class Inspector(object): """ return self.dialect.get_table_comment( - self.bind, table_name, schema, info_cache=self.info_cache, - **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) def get_check_constraints(self, table_name, schema=None, **kw): """Return information about check constraints in `table_name`. 
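The constraint getters reformatted in the preceding hunks (get_pk_constraint(), get_foreign_keys(), get_indexes(), get_unique_constraints()) all return plain dicts or lists of dicts; the key names used below ("constrained_columns", "referred_table", "referred_columns") are the same ones the _reflect_pk() and _reflect_fk() hunks further down consume. A sketch against a hypothetical "address" table:

    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection

    insp = reflection.Inspector.from_engine(
        create_engine("sqlite:///example.db")  # hypothetical database
    )

    pk = insp.get_pk_constraint("address")
    print(pk["constrained_columns"], pk.get("name"))

    for fk in insp.get_foreign_keys("address"):
        print(fk["referred_table"], fk["constrained_columns"],
              fk["referred_columns"])
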
@@ -550,10 +558,12 @@ class Inspector(object): """ return self.dialect.get_check_constraints( - self.bind, table_name, schema, info_cache=self.info_cache, **kw) + self.bind, table_name, schema, info_cache=self.info_cache, **kw + ) - def reflecttable(self, table, include_columns, exclude_columns=(), - _extend_on=None): + def reflecttable( + self, table, include_columns, exclude_columns=(), _extend_on=None + ): """Given a Table object, load its internal constructs based on introspection. @@ -599,7 +609,8 @@ class Inspector(object): # reflect table options, like mysql_engine tbl_opts = self.get_table_options( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) if tbl_opts: # add additional kwargs to the Table if the dialect # returned them @@ -615,185 +626,251 @@ class Inspector(object): cols_by_orig_name = {} for col_d in self.get_columns( - table_name, schema, **table.dialect_kwargs): + table_name, schema, **table.dialect_kwargs + ): found_table = True self._reflect_column( - table, col_d, include_columns, - exclude_columns, cols_by_orig_name) + table, + col_d, + include_columns, + exclude_columns, + cols_by_orig_name, + ) if not found_table: raise exc.NoSuchTableError(table.name) self._reflect_pk( - table_name, schema, table, cols_by_orig_name, exclude_columns) + table_name, schema, table, cols_by_orig_name, exclude_columns + ) self._reflect_fk( - table_name, schema, table, cols_by_orig_name, - exclude_columns, _extend_on, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + exclude_columns, + _extend_on, + reflection_options, + ) self._reflect_indexes( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_unique_constraints( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_check_constraints( - table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options) + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) self._reflect_table_comment( table_name, schema, table, reflection_options ) def _reflect_column( - self, table, col_d, include_columns, - exclude_columns, cols_by_orig_name): + self, table, col_d, include_columns, exclude_columns, cols_by_orig_name + ): - orig_name = col_d['name'] + orig_name = col_d["name"] table.dispatch.column_reflect(self, table, col_d) # fetch name again as column_reflect is allowed to # change it - name = col_d['name'] - if (include_columns and name not in include_columns) \ - or (exclude_columns and name in exclude_columns): + name = col_d["name"] + if (include_columns and name not in include_columns) or ( + exclude_columns and name in exclude_columns + ): return - coltype = col_d['type'] + coltype = col_d["type"] col_kw = dict( (k, col_d[k]) for k in [ - 'nullable', 'autoincrement', 'quote', 'info', 'key', - 'comment'] + "nullable", + "autoincrement", + "quote", + "info", + "key", + "comment", + ] if k in col_d ) - if 'dialect_options' in col_d: - col_kw.update(col_d['dialect_options']) + if "dialect_options" in col_d: + col_kw.update(col_d["dialect_options"]) colargs = [] - if col_d.get('default') is not None: - default = 
col_d['default'] + if col_d.get("default") is not None: + default = col_d["default"] if isinstance(default, sql.elements.TextClause): default = sa_schema.DefaultClause(default, _reflected=True) elif not isinstance(default, sa_schema.FetchedValue): default = sa_schema.DefaultClause( - sql.text(col_d['default']), _reflected=True) + sql.text(col_d["default"]), _reflected=True + ) colargs.append(default) - if 'sequence' in col_d: + if "sequence" in col_d: self._reflect_col_sequence(col_d, colargs) - cols_by_orig_name[orig_name] = col = \ - sa_schema.Column(name, coltype, *colargs, **col_kw) + cols_by_orig_name[orig_name] = col = sa_schema.Column( + name, coltype, *colargs, **col_kw + ) if col.key in table.primary_key: col.primary_key = True table.append_column(col) def _reflect_col_sequence(self, col_d, colargs): - if 'sequence' in col_d: + if "sequence" in col_d: # TODO: mssql and sybase are using this. - seq = col_d['sequence'] - sequence = sa_schema.Sequence(seq['name'], 1, 1) - if 'start' in seq: - sequence.start = seq['start'] - if 'increment' in seq: - sequence.increment = seq['increment'] + seq = col_d["sequence"] + sequence = sa_schema.Sequence(seq["name"], 1, 1) + if "start" in seq: + sequence.start = seq["start"] + if "increment" in seq: + sequence.increment = seq["increment"] colargs.append(sequence) def _reflect_pk( - self, table_name, schema, table, - cols_by_orig_name, exclude_columns): + self, table_name, schema, table, cols_by_orig_name, exclude_columns + ): pk_cons = self.get_pk_constraint( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) if pk_cons: pk_cols = [ cols_by_orig_name[pk] - for pk in pk_cons['constrained_columns'] + for pk in pk_cons["constrained_columns"] if pk in cols_by_orig_name and pk not in exclude_columns ] # update pk constraint name - table.primary_key.name = pk_cons.get('name') + table.primary_key.name = pk_cons.get("name") # tell the PKConstraint to re-initialize # its column collection table.primary_key._reload(pk_cols) def _reflect_fk( - self, table_name, schema, table, cols_by_orig_name, - exclude_columns, _extend_on, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + exclude_columns, + _extend_on, + reflection_options, + ): fkeys = self.get_foreign_keys( - table_name, schema, **table.dialect_kwargs) + table_name, schema, **table.dialect_kwargs + ) for fkey_d in fkeys: - conname = fkey_d['name'] + conname = fkey_d["name"] # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback constrained_columns = [ - cols_by_orig_name[c].key - if c in cols_by_orig_name else c - for c in fkey_d['constrained_columns'] + cols_by_orig_name[c].key if c in cols_by_orig_name else c + for c in fkey_d["constrained_columns"] ] if exclude_columns and set(constrained_columns).intersection( - exclude_columns): + exclude_columns + ): continue - referred_schema = fkey_d['referred_schema'] - referred_table = fkey_d['referred_table'] - referred_columns = fkey_d['referred_columns'] + referred_schema = fkey_d["referred_schema"] + referred_table = fkey_d["referred_table"] + referred_columns = fkey_d["referred_columns"] refspec = [] if referred_schema is not None: - sa_schema.Table(referred_table, table.metadata, - autoload=True, schema=referred_schema, - autoload_with=self.bind, - _extend_on=_extend_on, - **reflection_options - ) + sa_schema.Table( + referred_table, + table.metadata, + autoload=True, + schema=referred_schema, + autoload_with=self.bind, 
+ _extend_on=_extend_on, + **reflection_options + ) for column in referred_columns: - refspec.append(".".join( - [referred_schema, referred_table, column])) + refspec.append( + ".".join([referred_schema, referred_table, column]) + ) else: - sa_schema.Table(referred_table, table.metadata, autoload=True, - autoload_with=self.bind, - schema=sa_schema.BLANK_SCHEMA, - _extend_on=_extend_on, - **reflection_options - ) + sa_schema.Table( + referred_table, + table.metadata, + autoload=True, + autoload_with=self.bind, + schema=sa_schema.BLANK_SCHEMA, + _extend_on=_extend_on, + **reflection_options + ) for column in referred_columns: refspec.append(".".join([referred_table, column])) - if 'options' in fkey_d: - options = fkey_d['options'] + if "options" in fkey_d: + options = fkey_d["options"] else: options = {} table.append_constraint( - sa_schema.ForeignKeyConstraint(constrained_columns, refspec, - conname, link_to_name=True, - **options)) + sa_schema.ForeignKeyConstraint( + constrained_columns, + refspec, + conname, + link_to_name=True, + **options + ) + ) def _reflect_indexes( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): # Indexes indexes = self.get_indexes(table_name, schema) for index_d in indexes: - name = index_d['name'] - columns = index_d['column_names'] - unique = index_d['unique'] - flavor = index_d.get('type', 'index') - dialect_options = index_d.get('dialect_options', {}) - - duplicates = index_d.get('duplicates_constraint') - if include_columns and \ - not set(columns).issubset(include_columns): + name = index_d["name"] + columns = index_d["column_names"] + unique = index_d["unique"] + flavor = index_d.get("type", "index") + dialect_options = index_d.get("dialect_options", {}) + + duplicates = index_d.get("duplicates_constraint") + if include_columns and not set(columns).issubset(include_columns): util.warn( - "Omitting %s key for (%s), key covers omitted columns." % - (flavor, ', '.join(columns))) + "Omitting %s key for (%s), key covers omitted columns." 
+ % (flavor, ", ".join(columns)) + ) continue if duplicates: continue @@ -802,26 +879,36 @@ class Inspector(object): idx_cols = [] for c in columns: try: - idx_col = cols_by_orig_name[c] \ - if c in cols_by_orig_name else table.c[c] + idx_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) except KeyError: util.warn( "%s key '%s' was not located in " - "columns for table '%s'" % ( - flavor, c, table_name - )) + "columns for table '%s'" % (flavor, c, table_name) + ) else: idx_cols.append(idx_col) sa_schema.Index( - name, *idx_cols, + name, + *idx_cols, _table=table, - **dict(list(dialect_options.items()) + [('unique', unique)]) + **dict(list(dialect_options.items()) + [("unique", unique)]) ) def _reflect_unique_constraints( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): # Unique Constraints try: @@ -831,15 +918,14 @@ class Inspector(object): return for const_d in constraints: - conname = const_d['name'] - columns = const_d['column_names'] - duplicates = const_d.get('duplicates_index') - if include_columns and \ - not set(columns).issubset(include_columns): + conname = const_d["name"] + columns = const_d["column_names"] + duplicates = const_d.get("duplicates_index") + if include_columns and not set(columns).issubset(include_columns): util.warn( "Omitting unique constraint key for (%s), " - "key covers omitted columns." % - ', '.join(columns)) + "key covers omitted columns." % ", ".join(columns) + ) continue if duplicates: continue @@ -848,20 +934,32 @@ class Inspector(object): constrained_cols = [] for c in columns: try: - constrained_col = cols_by_orig_name[c] \ - if c in cols_by_orig_name else table.c[c] + constrained_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) except KeyError: util.warn( "unique constraint key '%s' was not located in " - "columns for table '%s'" % (c, table_name)) + "columns for table '%s'" % (c, table_name) + ) else: constrained_cols.append(constrained_col) table.append_constraint( - sa_schema.UniqueConstraint(*constrained_cols, name=conname)) + sa_schema.UniqueConstraint(*constrained_cols, name=conname) + ) def _reflect_check_constraints( - self, table_name, schema, table, cols_by_orig_name, - include_columns, exclude_columns, reflection_options): + self, + table_name, + schema, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ): try: constraints = self.get_check_constraints(table_name, schema) except NotImplementedError: @@ -869,14 +967,14 @@ class Inspector(object): return for const_d in constraints: - table.append_constraint( - sa_schema.CheckConstraint(**const_d)) + table.append_constraint(sa_schema.CheckConstraint(**const_d)) def _reflect_table_comment( - self, table_name, schema, table, reflection_options): + self, table_name, schema, table, reflection_options + ): try: comment_dict = self.get_table_comment(table_name, schema) except NotImplementedError: return else: - table.comment = comment_dict.get('text', None) + table.comment = comment_dict.get("text", None) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index d4c8623757..5ad0d29098 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -27,20 +27,25 @@ try: # the extension is present. 
def rowproxy_reconstructor(cls, state): return safe_rowproxy_reconstructor(cls, state) + + except ImportError: + def rowproxy_reconstructor(cls, state): obj = cls.__new__(cls) obj.__setstate__(state) return obj + try: from sqlalchemy.cresultproxy import BaseRowProxy + _baserowproxy_usecext = True except ImportError: _baserowproxy_usecext = False class BaseRowProxy(object): - __slots__ = ('_parent', '_row', '_processors', '_keymap') + __slots__ = ("_parent", "_row", "_processors", "_keymap") def __init__(self, parent, row, processors, keymap): """RowProxy objects are constructed by ResultProxy objects.""" @@ -51,8 +56,10 @@ except ImportError: self._keymap = keymap def __reduce__(self): - return (rowproxy_reconstructor, - (self.__class__, self.__getstate__())) + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) def values(self): """Return the values represented by this RowProxy as a list.""" @@ -76,8 +83,9 @@ except ImportError: except TypeError: if isinstance(key, slice): l = [] - for processor, value in zip(self._processors[key], - self._row[key]): + for processor, value in zip( + self._processors[key], self._row[key] + ): if processor is None: l.append(value) else: @@ -88,7 +96,8 @@ except ImportError: if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj) + "result set column descriptions" % obj + ) if processor is not None: return processor(self._row[index]) else: @@ -110,29 +119,29 @@ class RowProxy(BaseRowProxy): mapped to the original Columns that produced this result set (for results that correspond to constructed SQL expressions). """ + __slots__ = () def __contains__(self, key): return self._parent._has_key(key) def __getstate__(self): - return { - '_parent': self._parent, - '_row': tuple(self) - } + return {"_parent": self._parent, "_row": tuple(self)} def __setstate__(self, state): - self._parent = parent = state['_parent'] - self._row = state['_row'] + self._parent = parent = state["_parent"] + self._row = state["_row"] self._processors = parent._processors self._keymap = parent._keymap __hash__ = None def _op(self, other, op): - return op(tuple(self), tuple(other)) \ - if isinstance(other, RowProxy) \ + return ( + op(tuple(self), tuple(other)) + if isinstance(other, RowProxy) else op(tuple(self), other) + ) def __lt__(self, other): return self._op(other, operator.lt) @@ -176,6 +185,7 @@ class RowProxy(BaseRowProxy): def itervalues(self): return iter(self) + try: # Register RowProxy with Sequence, # so sequence protocol is implemented @@ -189,8 +199,13 @@ class ResultMetaData(object): context.""" __slots__ = ( - '_keymap', 'case_sensitive', 'matched_on_name', - '_processors', 'keys', '_orig_processors') + "_keymap", + "case_sensitive", + "matched_on_name", + "_processors", + "keys", + "_orig_processors", + ) def __init__(self, parent, cursor_description): context = parent.context @@ -200,18 +215,25 @@ class ResultMetaData(object): self._orig_processors = None if context.result_column_struct: - result_columns, cols_are_ordered, textual_ordered = \ + result_columns, cols_are_ordered, textual_ordered = ( context.result_column_struct + ) num_ctx_cols = len(result_columns) else: - result_columns = cols_are_ordered = \ - num_ctx_cols = textual_ordered = False + result_columns = ( + cols_are_ordered + ) = num_ctx_cols = textual_ordered = False # merge cursor.description with the column info # present in the compiled structure, if any raw = self._merge_cursor_description( - context, 
cursor_description, result_columns, - num_ctx_cols, cols_are_ordered, textual_ordered) + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ) self._keymap = {} if not _baserowproxy_usecext: @@ -223,23 +245,20 @@ class ResultMetaData(object): len_raw = len(raw) - self._keymap.update([ - (elem[0], (elem[3], elem[4], elem[0])) - for elem in raw - ] + [ - (elem[0] - len_raw, (elem[3], elem[4], elem[0])) - for elem in raw - ]) + self._keymap.update( + [(elem[0], (elem[3], elem[4], elem[0])) for elem in raw] + + [ + (elem[0] - len_raw, (elem[3], elem[4], elem[0])) + for elem in raw + ] + ) # processors in key order for certain per-row # views like __iter__ and slices self._processors = [elem[3] for elem in raw] # keymap by primary string... - by_key = dict([ - (elem[2], (elem[3], elem[4], elem[0])) - for elem in raw - ]) + by_key = dict([(elem[2], (elem[3], elem[4], elem[0])) for elem in raw]) # for compiled SQL constructs, copy additional lookup keys into # the key lookup map, such as Column objects, labels, @@ -264,29 +283,38 @@ class ResultMetaData(object): # copy secondary elements from compiled columns # into self._keymap, write in the potentially "ambiguous" # element - self._keymap.update([ - (obj_elem, by_key[elem[2]]) - for elem in raw if elem[4] - for obj_elem in elem[4] - ]) + self._keymap.update( + [ + (obj_elem, by_key[elem[2]]) + for elem in raw + if elem[4] + for obj_elem in elem[4] + ] + ) # if we did a pure positional match, then reset the # original "expression element" back to the "unambiguous" # entry. This is a new behavior in 1.1 which impacts # TextAsFrom but also straight compiled SQL constructs. if not self.matched_on_name: - self._keymap.update([ - (elem[4][0], (elem[3], elem[4], elem[0])) - for elem in raw if elem[4] - ]) + self._keymap.update( + [ + (elem[4][0], (elem[3], elem[4], elem[0])) + for elem in raw + if elem[4] + ] + ) else: # no dupes - copy secondary elements from compiled # columns into self._keymap - self._keymap.update([ - (obj_elem, (elem[3], elem[4], elem[0])) - for elem in raw if elem[4] - for obj_elem in elem[4] - ]) + self._keymap.update( + [ + (obj_elem, (elem[3], elem[4], elem[0])) + for elem in raw + if elem[4] + for obj_elem in elem[4] + ] + ) # update keymap with primary string names taking # precedence @@ -294,14 +322,19 @@ class ResultMetaData(object): # update keymap with "translated" names (sqlite-only thing) if not num_ctx_cols and context._translate_colname: - self._keymap.update([ - (elem[5], self._keymap[elem[2]]) - for elem in raw if elem[5] - ]) + self._keymap.update( + [(elem[5], self._keymap[elem[2]]) for elem in raw if elem[5]] + ) def _merge_cursor_description( - self, context, cursor_description, result_columns, - num_ctx_cols, cols_are_ordered, textual_ordered): + self, + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ): """Merge a cursor.description with compiled result column information. 
There are at least four separate strategies used here, selected @@ -357,10 +390,12 @@ class ResultMetaData(object): case_sensitive = context.dialect.case_sensitive - if num_ctx_cols and \ - cols_are_ordered and \ - not textual_ordered and \ - num_ctx_cols == len(cursor_description): + if ( + num_ctx_cols + and cols_are_ordered + and not textual_ordered + and num_ctx_cols == len(cursor_description) + ): self.keys = [elem[0] for elem in result_columns] # pure positional 1-1 case; doesn't need to read # the names from cursor.description @@ -373,9 +408,9 @@ class ResultMetaData(object): type_, key, cursor_description[idx][1] ), obj, - None - ) for idx, (key, name, obj, type_) - in enumerate(result_columns) + None, + ) + for idx, (key, name, obj, type_) in enumerate(result_columns) ] else: # name-based or text-positional cases, where we need @@ -383,26 +418,32 @@ class ResultMetaData(object): if textual_ordered: # textual positional case raw_iterator = self._merge_textual_cols_by_position( - context, cursor_description, result_columns) + context, cursor_description, result_columns + ) elif num_ctx_cols: # compiled SQL with a mismatch of description cols # vs. compiled cols, or textual w/ unordered columns raw_iterator = self._merge_cols_by_name( - context, cursor_description, result_columns) + context, cursor_description, result_columns + ) else: # no compiled SQL, just a raw string raw_iterator = self._merge_cols_by_none( - context, cursor_description) + context, cursor_description + ) return [ ( - idx, colname, colname, + idx, + colname, + colname, context.get_result_processor( - mapped_type, colname, coltype), - obj, untranslated) - - for idx, colname, mapped_type, coltype, obj, untranslated - in raw_iterator + mapped_type, colname, coltype + ), + obj, + untranslated, + ) + for idx, colname, mapped_type, coltype, obj, untranslated in raw_iterator ] def _colnames_from_description(self, context, cursor_description): @@ -416,10 +457,14 @@ class ResultMetaData(object): dialect = context.dialect case_sensitive = dialect.case_sensitive translate_colname = context._translate_colname - description_decoder = dialect._description_decoder \ - if dialect.description_encoding else None - normalize_name = dialect.normalize_name \ - if dialect.requires_name_normalize else None + description_decoder = ( + dialect._description_decoder + if dialect.description_encoding + else None + ) + normalize_name = ( + dialect.normalize_name if dialect.requires_name_normalize else None + ) untranslated = None self.keys = [] @@ -444,20 +489,25 @@ class ResultMetaData(object): yield idx, colname, untranslated, coltype def _merge_textual_cols_by_position( - self, context, cursor_description, result_columns): + self, context, cursor_description, result_columns + ): dialect = context.dialect num_ctx_cols = len(result_columns) if result_columns else None if num_ctx_cols > len(cursor_description): util.warn( "Number of columns in textual SQL (%d) is " - "smaller than number of columns requested (%d)" % ( - num_ctx_cols, len(cursor_description) - )) + "smaller than number of columns requested (%d)" + % (num_ctx_cols, len(cursor_description)) + ) seen = set() - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): if idx < num_ctx_cols: ctx_rec = result_columns[idx] obj = ctx_rec[2] @@ -465,7 +515,8 @@ class ResultMetaData(object): if obj[0] in seen: raise 
exc.InvalidRequestError( "Duplicate column expression requested " - "in textual SQL: %r" % obj[0]) + "in textual SQL: %r" % obj[0] + ) seen.add(obj[0]) else: mapped_type = sqltypes.NULLTYPE @@ -479,8 +530,12 @@ class ResultMetaData(object): result_map = self._create_result_map(result_columns, case_sensitive) self.matched_on_name = True - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): try: ctx_rec = result_map[colname] except KeyError: @@ -493,8 +548,12 @@ class ResultMetaData(object): def _merge_cols_by_none(self, context, cursor_description): dialect = context.dialect - for idx, colname, untranslated, coltype in \ - self._colnames_from_description(context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated @classmethod @@ -525,27 +584,28 @@ class ResultMetaData(object): # or colummn('name') constructs to ColumnElements, or after a # pickle/unpickle roundtrip elif isinstance(key, expression.ColumnElement): - if key._label and ( - key._label - if self.case_sensitive - else key._label.lower()) in map: - result = map[key._label - if self.case_sensitive - else key._label.lower()] - elif hasattr(key, 'name') and ( - key.name - if self.case_sensitive - else key.name.lower()) in map: + if ( + key._label + and (key._label if self.case_sensitive else key._label.lower()) + in map + ): + result = map[ + key._label if self.case_sensitive else key._label.lower() + ] + elif ( + hasattr(key, "name") + and (key.name if self.case_sensitive else key.name.lower()) + in map + ): # match is only on name. - result = map[key.name - if self.case_sensitive - else key.name.lower()] + result = map[ + key.name if self.case_sensitive else key.name.lower() + ] # search extra hard to make sure this # isn't a column/label name overlap. # this check isn't currently available if the row # was unpickled. 
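The _key_fallback hunks around here show the most common mechanical change in this pass: backslash line continuations are replaced by parenthesized expressions, which black can then break (or join, when the result fits in 79 columns). A tiny generic illustration of the two styles, not code from SQLAlchemy:

    result = ("row", None)

    # pre-black: a backslash carries the line break
    ok = result is not None and \
        result[1] is not None

    # post-black: parentheses carry it, or the line is joined if it fits
    ok = (
        result is not None
        and result[1] is not None
    )
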
- if result is not None and \ - result[1] is not None: + if result is not None and result[1] is not None: for obj in result[1]: if key._compare_name_for_result(obj): break @@ -554,8 +614,9 @@ class ResultMetaData(object): if result is None: if raiseerr: raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" % - expression._string_or_unprintable(key)) + "Could not locate column in row for column '%s'" + % expression._string_or_unprintable(key) + ) else: return None else: @@ -580,34 +641,35 @@ class ResultMetaData(object): if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj) + "result set column descriptions" % obj + ) return operator.itemgetter(index) def __getstate__(self): return { - '_pickled_keymap': dict( + "_pickled_keymap": dict( (key, index) for key, (processor, obj, index) in self._keymap.items() if isinstance(key, util.string_types + util.int_types) ), - 'keys': self.keys, + "keys": self.keys, "case_sensitive": self.case_sensitive, - "matched_on_name": self.matched_on_name + "matched_on_name": self.matched_on_name, } def __setstate__(self, state): # the row has been processed at pickling time so we don't need any # processor anymore - self._processors = [None for _ in range(len(state['keys']))] + self._processors = [None for _ in range(len(state["keys"]))] self._keymap = keymap = {} - for key, index in state['_pickled_keymap'].items(): + for key, index in state["_pickled_keymap"].items(): # not preserving "obj" here, unfortunately our # proxy comparison fails with the unpickle keymap[key] = (None, None, index) - self.keys = state['keys'] - self.case_sensitive = state['case_sensitive'] - self.matched_on_name = state['matched_on_name'] + self.keys = state["keys"] + self.case_sensitive = state["case_sensitive"] + self.matched_on_name = state["matched_on_name"] class ResultProxy(object): @@ -643,8 +705,9 @@ class ResultProxy(object): self.dialect = context.dialect self.cursor = self._saved_cursor = context.cursor self.connection = context.root_connection - self._echo = self.connection._echo and \ - context.engine._should_log_debug() + self._echo = ( + self.connection._echo and context.engine._should_log_debug() + ) self._init_metadata() def _getter(self, key, raiseerr=True): @@ -666,18 +729,22 @@ class ResultProxy(object): def _init_metadata(self): cursor_description = self._cursor_description() if cursor_description is not None: - if self.context.compiled and \ - 'compiled_cache' in self.context.execution_options: + if ( + self.context.compiled + and "compiled_cache" in self.context.execution_options + ): if self.context.compiled._cached_metadata: self._metadata = self.context.compiled._cached_metadata else: - self._metadata = self.context.compiled._cached_metadata = \ - ResultMetaData(self, cursor_description) + self._metadata = ( + self.context.compiled._cached_metadata + ) = ResultMetaData(self, cursor_description) else: self._metadata = ResultMetaData(self, cursor_description) if self._echo: self.context.engine.logger.debug( - "Col %r", tuple(x[0] for x in cursor_description)) + "Col %r", tuple(x[0] for x in cursor_description) + ) def keys(self): """Return the current set of string keys for rows.""" @@ -731,7 +798,8 @@ class ResultProxy(object): return self.context.rowcount except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, self.cursor, self.context) + e, None, None, self.cursor, self.context + ) @property def lastrowid(self): @@ -753,8 +821,8 @@ 
class ResultProxy(object): return self._saved_cursor.lastrowid except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self._saved_cursor, self.context) + e, None, None, self._saved_cursor, self.context + ) @property def returns_rows(self): @@ -913,17 +981,18 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " - "expression construct.") + "Statement is not an insert() " "expression construct." + ) elif self.context._is_explicit_returning: raise exc.InvalidRequestError( "Can't call inserted_primary_key " "when returning() " - "is used.") + "is used." + ) return self.context.inserted_primary_key @@ -938,12 +1007,12 @@ class ResultProxy(object): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isupdate: raise exc.InvalidRequestError( - "Statement is not an update() " - "expression construct.") + "Statement is not an update() " "expression construct." + ) elif self.context.executemany: return self.context.compiled_parameters else: @@ -960,12 +1029,12 @@ class ResultProxy(object): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() " - "expression construct.") + "Statement is not an insert() " "expression construct." + ) elif self.context.executemany: return self.context.compiled_parameters else: @@ -1013,12 +1082,13 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( "Statement is not an insert() or update() " - "expression construct.") + "expression construct." + ) return self.context.postfetch_cols def prefetch_cols(self): @@ -1035,12 +1105,13 @@ class ResultProxy(object): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled " - "expression construct.") + "Statement is not a compiled " "expression construct." + ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( "Statement is not an insert() or update() " - "expression construct.") + "expression construct." + ) return self.context.prefetch_cols def supports_sane_rowcount(self): @@ -1086,7 +1157,7 @@ class ResultProxy(object): if self._metadata is None: raise exc.ResourceClosedError( "This result object does not return rows. " - "It has been closed automatically.", + "It has been closed automatically." ) elif self.closed: raise exc.ResourceClosedError("This result object is closed.") @@ -1106,8 +1177,9 @@ class ResultProxy(object): l.append(process_row(metadata, row, processors, keymap)) return l else: - return [process_row(metadata, row, processors, keymap) - for row in rows] + return [ + process_row(metadata, row, processors, keymap) for row in rows + ] def fetchall(self): """Fetch all rows, just like DB-API ``cursor.fetchall()``. 
@@ -1132,8 +1204,8 @@ class ResultProxy(object): return l except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def fetchmany(self, size=None): """Fetch many rows, just like DB-API @@ -1161,8 +1233,8 @@ class ResultProxy(object): return l except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def fetchone(self): """Fetch one row, just like DB-API ``cursor.fetchone()``. @@ -1190,8 +1262,8 @@ class ResultProxy(object): return None except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) def first(self): """Fetch the first row and then close the result set unconditionally. @@ -1209,8 +1281,8 @@ class ResultProxy(object): row = self._fetchone_impl() except BaseException as e: self.connection._handle_dbapi_exception( - e, None, None, - self.cursor, self.context) + e, None, None, self.cursor, self.context + ) try: if row is not None: @@ -1268,7 +1340,8 @@ class BufferedRowResultProxy(ResultProxy): def _init_metadata(self): self._max_row_buffer = self.context.execution_options.get( - 'max_row_buffer', None) + "max_row_buffer", None + ) self.__buffer_rows() super(BufferedRowResultProxy, self)._init_metadata() @@ -1284,13 +1357,13 @@ class BufferedRowResultProxy(ResultProxy): 50: 100, 100: 250, 250: 500, - 500: 1000 + 500: 1000, } def __buffer_rows(self): if self.cursor is None: return - size = getattr(self, '_bufsize', 1) + size = getattr(self, "_bufsize", 1) self.__rowbuffer = collections.deque(self.cursor.fetchmany(size)) self._bufsize = self.size_growth.get(size, size) if self._max_row_buffer is not None: @@ -1385,8 +1458,9 @@ class BufferedColumnRow(RowProxy): row[index] = processor(row[index]) index += 1 row = tuple(row) - super(BufferedColumnRow, self).__init__(parent, row, - processors, keymap) + super(BufferedColumnRow, self).__init__( + parent, row, processors, keymap + ) class BufferedColumnResultProxy(ResultProxy): diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index d4f5185de6..4aecb9537e 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -51,18 +51,20 @@ class DefaultEngineStrategy(EngineStrategy): plugins = u._instantiate_plugins(kwargs) - u.query.pop('plugin', None) - kwargs.pop('plugins', None) + u.query.pop("plugin", None) + kwargs.pop("plugins", None) entrypoint = u._get_entrypoint() dialect_cls = entrypoint.get_dialect_cls(u) - if kwargs.pop('_coerce_config', False): + if kwargs.pop("_coerce_config", False): + def pop_kwarg(key, default=None): value = kwargs.pop(key, default) if key in dialect_cls.engine_config_types: value = dialect_cls.engine_config_types[key](value) return value + else: pop_kwarg = kwargs.pop @@ -72,7 +74,7 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: dialect_args[k] = pop_kwarg(k) - dbapi = kwargs.pop('module', None) + dbapi = kwargs.pop("module", None) if dbapi is None: dbapi_args = {} for k in util.get_func_kwargs(dialect_cls.dbapi): @@ -80,7 +82,7 @@ class DefaultEngineStrategy(EngineStrategy): dbapi_args[k] = pop_kwarg(k) dbapi = dialect_cls.dbapi(**dbapi_args) - dialect_args['dbapi'] = dbapi + dialect_args["dbapi"] = dbapi for plugin in plugins: plugin.handle_dialect_kwargs(dialect_cls, dialect_args) @@ -90,41 +92,43 @@ class 
DefaultEngineStrategy(EngineStrategy): # assemble connection arguments (cargs, cparams) = dialect.create_connect_args(u) - cparams.update(pop_kwarg('connect_args', {})) + cparams.update(pop_kwarg("connect_args", {})) cargs = list(cargs) # allow mutability # look for existing pool or create - pool = pop_kwarg('pool', None) + pool = pop_kwarg("pool", None) if pool is None: + def connect(connection_record=None): if dialect._has_events: for fn in dialect.dispatch.do_connect: connection = fn( - dialect, connection_record, cargs, cparams) + dialect, connection_record, cargs, cparams + ) if connection is not None: return connection return dialect.connect(*cargs, **cparams) - creator = pop_kwarg('creator', connect) + creator = pop_kwarg("creator", connect) - poolclass = pop_kwarg('poolclass', None) + poolclass = pop_kwarg("poolclass", None) if poolclass is None: poolclass = dialect_cls.get_pool_class(u) - pool_args = { - 'dialect': dialect - } + pool_args = {"dialect": dialect} # consume pool arguments from kwargs, translating a few of # the arguments - translate = {'logging_name': 'pool_logging_name', - 'echo': 'echo_pool', - 'timeout': 'pool_timeout', - 'recycle': 'pool_recycle', - 'events': 'pool_events', - 'use_threadlocal': 'pool_threadlocal', - 'reset_on_return': 'pool_reset_on_return', - 'pre_ping': 'pool_pre_ping', - 'use_lifo': 'pool_use_lifo'} + translate = { + "logging_name": "pool_logging_name", + "echo": "echo_pool", + "timeout": "pool_timeout", + "recycle": "pool_recycle", + "events": "pool_events", + "use_threadlocal": "pool_threadlocal", + "reset_on_return": "pool_reset_on_return", + "pre_ping": "pool_pre_ping", + "use_lifo": "pool_use_lifo", + } for k in util.get_cls_kwargs(poolclass): tk = translate.get(k, k) if tk in kwargs: @@ -149,7 +153,7 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: engine_args[k] = pop_kwarg(k) - _initialize = kwargs.pop('_initialize', True) + _initialize = kwargs.pop("_initialize", True) # all kwargs should be consumed if kwargs: @@ -157,32 +161,40 @@ class DefaultEngineStrategy(EngineStrategy): "Invalid argument(s) %s sent to create_engine(), " "using configuration %s/%s/%s. Please check that the " "keyword arguments are appropriate for this combination " - "of components." % (','.join("'%s'" % k for k in kwargs), - dialect.__class__.__name__, - pool.__class__.__name__, - engineclass.__name__)) + "of components." 
+ % ( + ",".join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__, + ) + ) engine = engineclass(pool, dialect, u, **engine_args) if _initialize: do_on_connect = dialect.on_connect() if do_on_connect: + def on_connect(dbapi_connection, connection_record): conn = getattr( - dbapi_connection, '_sqla_unwrap', dbapi_connection) + dbapi_connection, "_sqla_unwrap", dbapi_connection + ) if conn is None: return do_on_connect(conn) - event.listen(pool, 'first_connect', on_connect) - event.listen(pool, 'connect', on_connect) + event.listen(pool, "first_connect", on_connect) + event.listen(pool, "connect", on_connect) def first_connect(dbapi_connection, connection_record): - c = base.Connection(engine, connection=dbapi_connection, - _has_events=False) + c = base.Connection( + engine, connection=dbapi_connection, _has_events=False + ) c._execution_options = util.immutabledict() dialect.initialize(c) - event.listen(pool, 'first_connect', first_connect, once=True) + + event.listen(pool, "first_connect", first_connect, once=True) dialect_cls.engine_created(engine) if entrypoint is not dialect_cls: @@ -197,18 +209,20 @@ class DefaultEngineStrategy(EngineStrategy): class PlainEngineStrategy(DefaultEngineStrategy): """Strategy for configuring a regular Engine.""" - name = 'plain' + name = "plain" engine_cls = base.Engine + PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): """Strategy for configuring an Engine with threadlocal behavior.""" - name = 'threadlocal' + name = "threadlocal" engine_cls = threadlocal.TLEngine + ThreadLocalEngineStrategy() @@ -220,7 +234,7 @@ class MockEngineStrategy(EngineStrategy): """ - name = 'mock' + name = "mock" def create(self, name_or_url, executor, **kwargs): # create url.URL object @@ -245,7 +259,7 @@ class MockEngineStrategy(EngineStrategy): self.execute = execute engine = property(lambda s: s) - dialect = property(attrgetter('_dialect')) + dialect = property(attrgetter("_dialect")) name = property(lambda s: s._dialect.name) schema_for_object = schema._schema_getter(None) @@ -258,29 +272,35 @@ class MockEngineStrategy(EngineStrategy): def compiler(self, statement, parameters, **kwargs): return self._dialect.compiler( - statement, parameters, engine=self, **kwargs) + statement, parameters, engine=self, **kwargs + ) def create(self, entity, **kwargs): - kwargs['checkfirst'] = False + kwargs["checkfirst"] = False from sqlalchemy.engine import ddl - ddl.SchemaGenerator( - self.dialect, self, **kwargs).traverse_single(entity) + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single( + entity + ) def drop(self, entity, **kwargs): - kwargs['checkfirst'] = False + kwargs["checkfirst"] = False from sqlalchemy.engine import ddl - ddl.SchemaDropper( - self.dialect, self, **kwargs).traverse_single(entity) - def _run_visitor(self, visitorcallable, element, - connection=None, - **kwargs): - kwargs['checkfirst'] = False - visitorcallable(self.dialect, self, - **kwargs).traverse_single(element) + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single( + entity + ) + + def _run_visitor( + self, visitorcallable, element, connection=None, **kwargs + ): + kwargs["checkfirst"] = False + visitorcallable(self.dialect, self, **kwargs).traverse_single( + element + ) def execute(self, object, *multiparams, **params): raise NotImplementedError() + MockEngineStrategy() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 0ec1f9613c..5b2bdabc09 100644 
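The strategies.py diff above also touches the engine strategy's pool-kwarg handling, where a translation table maps user-facing create_engine() names (pool_timeout, pool_recycle) onto the names the pool class constructor accepts. A simplified sketch of that idiom under assumed names; collect_pool_args is illustrative, not SQLAlchemy API:

    translate = {"timeout": "pool_timeout", "recycle": "pool_recycle"}

    def collect_pool_args(ctor_kwargs, user_kwargs):
        # map user-facing names onto the constructor's own names,
        # consuming them from the caller's kwargs as we go
        out = {}
        for k in ctor_kwargs:
            tk = translate.get(k, k)
            if tk in user_kwargs:
                out[k] = user_kwargs.pop(tk)
        return out

    assert collect_pool_args({"timeout"}, {"pool_timeout": 30}) == {
        "timeout": 30
    }
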
--- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -19,7 +19,6 @@ import weakref class TLConnection(base.Connection): - def __init__(self, *arg, **kw): super(TLConnection, self).__init__(*arg, **kw) self.__opencount = 0 @@ -43,6 +42,7 @@ class TLEngine(base.Engine): transactions. """ + _tl_connection_cls = TLConnection def __init__(self, *args, **kwargs): @@ -50,7 +50,7 @@ class TLEngine(base.Engine): self._connections = util.threading.local() def contextual_connect(self, **kw): - if not hasattr(self._connections, 'conn'): + if not hasattr(self._connections, "conn"): connection = None else: connection = self._connections.conn() @@ -60,29 +60,31 @@ class TLEngine(base.Engine): # or not connection.connection.is_valid: connection = self._tl_connection_cls( self, - self._wrap_pool_connect( - self.pool.connect, connection), - **kw) + self._wrap_pool_connect(self.pool.connect, connection), + **kw + ) self._connections.conn = weakref.ref(connection) return connection._increment_connect() def begin_twophase(self, xid=None): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_twophase(xid=xid)) + self.contextual_connect().begin_twophase(xid=xid) + ) return self def begin_nested(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_nested()) + self.contextual_connect().begin_nested() + ) return self def begin(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append(self.contextual_connect().begin()) return self @@ -97,21 +99,27 @@ class TLEngine(base.Engine): self.rollback() def prepare(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return self._connections.trans[-1].prepare() def commit(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.commit() def rollback(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.rollback() @@ -122,9 +130,11 @@ class TLEngine(base.Engine): @property def closed(self): - return not hasattr(self._connections, 'conn') or \ - self._connections.conn() is None or \ - self._connections.conn().closed + return ( + not hasattr(self._connections, "conn") + or self._connections.conn() is None + or self._connections.conn().closed + ) def close(self): if not self.closed: @@ -135,4 +145,4 @@ class TLEngine(base.Engine): self._connections.trans = [] def __repr__(self): - return 'TLEngine(%r)' % self.url + return "TLEngine(%r)" % self.url diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 1662efe209..e92e57b8e3 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -50,8 +50,16 @@ class URL(object): """ - def __init__(self, drivername, username=None, password=None, - host=None, port=None, database=None, query=None): + def __init__( + self, + drivername, + 
username=None, + password=None, + host=None, + port=None, + database=None, + query=None, + ): self.drivername = drivername self.username = username self.password_original = password @@ -68,26 +76,26 @@ class URL(object): if self.username is not None: s += _rfc_1738_quote(self.username) if self.password is not None: - s += ':' + ('***' if hide_password - else _rfc_1738_quote(self.password)) + s += ":" + ( + "***" if hide_password else _rfc_1738_quote(self.password) + ) s += "@" if self.host is not None: - if ':' in self.host: + if ":" in self.host: s += "[%s]" % self.host else: s += self.host if self.port is not None: - s += ':' + str(self.port) + s += ":" + str(self.port) if self.database is not None: - s += '/' + self.database + s += "/" + self.database if self.query: keys = list(self.query) keys.sort() - s += '?' + "&".join( - "%s=%s" % ( - k, - element - ) for k in keys for element in util.to_list(self.query[k]) + s += "?" + "&".join( + "%s=%s" % (k, element) + for k in keys + for element in util.to_list(self.query[k]) ) return s @@ -101,14 +109,15 @@ class URL(object): return hash(str(self)) def __eq__(self, other): - return \ - isinstance(other, URL) and \ - self.drivername == other.drivername and \ - self.username == other.username and \ - self.password == other.password and \ - self.host == other.host and \ - self.database == other.database and \ - self.query == other.query + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + ) @property def password(self): @@ -122,20 +131,20 @@ class URL(object): self.password_original = password def get_backend_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.drivername else: - return self.drivername.split('+')[0] + return self.drivername.split("+")[0] def get_driver_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.get_dialect().driver else: - return self.drivername.split('+')[1] + return self.drivername.split("+")[1] def _instantiate_plugins(self, kwargs): - plugin_names = util.to_list(self.query.get('plugin', ())) - plugin_names += kwargs.get('plugins', []) + plugin_names = util.to_list(self.query.get("plugin", ())) + plugin_names += kwargs.get("plugins", []) return [ plugins.load(plugin_name)(self, kwargs) @@ -149,17 +158,19 @@ class URL(object): returned class implements the get_dialect_cls() method. 
""" - if '+' not in self.drivername: + if "+" not in self.drivername: name = self.drivername else: - name = self.drivername.replace('+', '.') + name = self.drivername.replace("+", ".") cls = registry.load(name) # check for legacy dialects that # would return a module with 'dialect' as the # actual class - if hasattr(cls, 'dialect') and \ - isinstance(cls.dialect, type) and \ - issubclass(cls.dialect, Dialect): + if ( + hasattr(cls, "dialect") + and isinstance(cls.dialect, type) + and issubclass(cls.dialect, Dialect) + ): return cls.dialect else: return cls @@ -187,7 +198,7 @@ class URL(object): """ translated = {} - attribute_names = ['host', 'database', 'username', 'password', 'port'] + attribute_names = ["host", "database", "username", "password", "port"] for sname in attribute_names: if names: name = names.pop(0) @@ -214,7 +225,8 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): - pattern = re.compile(r''' + pattern = re.compile( + r""" (?P[\w\+]+):// (?: (?P[^:/]*) @@ -228,21 +240,23 @@ def _parse_rfc1738_args(name): (?::(?P[^/]*))? )? (?:/(?P.*))? - ''', re.X) + """, + re.X, + ) m = pattern.match(name) if m is not None: components = m.groupdict() - if components['database'] is not None: - tokens = components['database'].split('?', 2) - components['database'] = tokens[0] + if components["database"] is not None: + tokens = components["database"].split("?", 2) + components["database"] = tokens[0] if len(tokens) > 1: query = {} for key, value in util.parse_qsl(tokens[1]): if util.py2k: - key = key.encode('ascii') + key = key.encode("ascii") if key in query: query[key] = util.to_list(query[key]) query[key].append(value) @@ -252,26 +266,27 @@ def _parse_rfc1738_args(name): query = None else: query = None - components['query'] = query + components["query"] = query - if components['username'] is not None: - components['username'] = _rfc_1738_unquote(components['username']) + if components["username"] is not None: + components["username"] = _rfc_1738_unquote(components["username"]) - if components['password'] is not None: - components['password'] = _rfc_1738_unquote(components['password']) + if components["password"] is not None: + components["password"] = _rfc_1738_unquote(components["password"]) - ipv4host = components.pop('ipv4host') - ipv6host = components.pop('ipv6host') - components['host'] = ipv4host or ipv6host - name = components.pop('name') + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") return URL(name, **components) else: raise exc.ArgumentError( - "Could not parse rfc1738 URL from string '%s'" % name) + "Could not parse rfc1738 URL from string '%s'" % name + ) def _rfc_1738_quote(text): - return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text) + return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) def _rfc_1738_unquote(text): @@ -279,7 +294,7 @@ def _rfc_1738_unquote(text): def _parse_keyvalue_args(name): - m = re.match(r'(\w+)://(.*)', name) + m = re.match(r"(\w+)://(.*)", name) if m is not None: (name, args) = m.group(1, 2) opts = dict(util.parse_qsl(args)) diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 17bc9a3b47..76bb8f4b54 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -46,28 +46,34 @@ def py_fallback(): elif len(multiparams) == 1: zero = multiparams[0] if isinstance(zero, (list, tuple)): - if not zero or hasattr(zero[0], '__iter__') and \ - not hasattr(zero[0], 
'strip'): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): # execute(stmt, [{}, {}, {}, ...]) # execute(stmt, [(), (), (), ...]) return zero else: # execute(stmt, ("value", "value")) return [zero] - elif hasattr(zero, 'keys'): + elif hasattr(zero, "keys"): # execute(stmt, {"key":"value"}) return [zero] else: # execute(stmt, "value") return [[zero]] else: - if hasattr(multiparams[0], '__iter__') and \ - not hasattr(multiparams[0], 'strip'): + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): return multiparams else: return [multiparams] return locals() + + try: from sqlalchemy.cutils import _distill_params except ImportError: diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py index acfacc233e..f9e04503c6 100644 --- a/lib/sqlalchemy/event/api.py +++ b/lib/sqlalchemy/event/api.py @@ -14,8 +14,8 @@ from .. import util, exc from .base import _registrars from .registry import _EventKey -CANCEL = util.symbol('CANCEL') -NO_RETVAL = util.symbol('NO_RETVAL') +CANCEL = util.symbol("CANCEL") +NO_RETVAL = util.symbol("NO_RETVAL") def _event_key(target, identifier, fn): @@ -24,8 +24,9 @@ def _event_key(target, identifier, fn): if tgt is not None: return _EventKey(target, identifier, fn, tgt) else: - raise exc.InvalidRequestError("No such event '%s' for target '%s'" % - (identifier, target)) + raise exc.InvalidRequestError( + "No such event '%s' for target '%s'" % (identifier, target) + ) def listen(target, identifier, fn, *args, **kw): @@ -120,9 +121,11 @@ def listens_for(target, identifier, *args, **kw): :func:`.listen` - general description of event listening """ + def decorate(fn): listen(target, identifier, fn, *args, **kw) return fn + return decorate diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index c33ec82ff6..31a9f28ca4 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -41,7 +41,7 @@ import collections class RefCollection(util.MemoizedSlots): - __slots__ = 'ref', + __slots__ = ("ref",) def _memoized_attr_ref(self): return weakref.ref(self, registry._collection_gced) @@ -67,20 +67,27 @@ class _empty_collection(object): class _ClsLevelDispatch(RefCollection): """Class-level events on :class:`._Dispatch` classes.""" - __slots__ = ('name', 'arg_names', 'has_kw', - 'legacy_signatures', '_clslevel', '__weakref__') + __slots__ = ( + "name", + "arg_names", + "has_kw", + "legacy_signatures", + "_clslevel", + "__weakref__", + ) def __init__(self, parent_dispatch_cls, fn): self.name = fn.__name__ argspec = util.inspect_getargspec(fn) self.arg_names = argspec.args[1:] self.has_kw = bool(argspec.keywords) - self.legacy_signatures = list(reversed( - sorted( - getattr(fn, '_legacy_signatures', []), - key=lambda s: s[0] + self.legacy_signatures = list( + reversed( + sorted( + getattr(fn, "_legacy_signatures", []), key=lambda s: s[0] + ) ) - )) + ) fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn) self._clslevel = weakref.WeakKeyDictionary() @@ -102,15 +109,18 @@ class _ClsLevelDispatch(RefCollection): argdict = dict(zip(self.arg_names, args)) argdict.update(kw) return fn(**argdict) + return wrap_kw def insert(self, event_key, propagate): target = event_key.dispatch_target - assert isinstance(target, type), \ - "Class-level Event targets must be classes." - if not getattr(target, '_sa_propagate_class_events', True): + assert isinstance( + target, type + ), "Class-level Event targets must be classes." 
+ if not getattr(target, "_sa_propagate_class_events", True): raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target) + "Can't assign an event directly to the %s class" % target + ) stack = [target] while stack: cls = stack.pop(0) @@ -125,11 +135,13 @@ class _ClsLevelDispatch(RefCollection): def append(self, event_key, propagate): target = event_key.dispatch_target - assert isinstance(target, type), \ - "Class-level Event targets must be classes." - if not getattr(target, '_sa_propagate_class_events', True): + assert isinstance( + target, type + ), "Class-level Event targets must be classes." + if not getattr(target, "_sa_propagate_class_events", True): raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target) + "Can't assign an event directly to the %s class" % target + ) stack = [target] while stack: cls = stack.pop(0) @@ -143,7 +155,7 @@ class _ClsLevelDispatch(RefCollection): registry._stored_in_collection(event_key, self) def _assign_cls_collection(self, target): - if getattr(target, '_sa_propagate_class_events', True): + if getattr(target, "_sa_propagate_class_events", True): self._clslevel[target] = collections.deque() else: self._clslevel[target] = _empty_collection() @@ -154,11 +166,9 @@ class _ClsLevelDispatch(RefCollection): clslevel = self._clslevel[target] for cls in target.__mro__[1:]: if cls in self._clslevel: - clslevel.extend([ - fn for fn - in self._clslevel[cls] - if fn not in clslevel - ]) + clslevel.extend( + [fn for fn in self._clslevel[cls] if fn not in clslevel] + ) def remove(self, event_key): target = event_key.dispatch_target @@ -209,7 +219,7 @@ class _EmptyListener(_InstanceLevelDispatch): propagate = frozenset() listeners = () - __slots__ = 'parent', 'parent_listeners', 'name' + __slots__ = "parent", "parent_listeners", "name" def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: @@ -258,7 +268,7 @@ class _EmptyListener(_InstanceLevelDispatch): class _CompoundListener(_InstanceLevelDispatch): - __slots__ = '_exec_once_mutex', '_exec_once' + __slots__ = "_exec_once_mutex", "_exec_once" def _memoized_attr__exec_once_mutex(self): return threading.Lock() @@ -306,8 +316,13 @@ class _ListenerCollection(_CompoundListener): """ __slots__ = ( - 'parent_listeners', 'parent', 'name', 'listeners', - 'propagate', '__weakref__') + "parent_listeners", + "parent", + "name", + "listeners", + "propagate", + "__weakref__", + ) def __init__(self, parent, target_cls): if target_cls not in parent._clslevel: @@ -335,11 +350,13 @@ class _ListenerCollection(_CompoundListener): existing_listeners = self.listeners existing_listener_set = set(existing_listeners) self.propagate.update(other.propagate) - other_listeners = [l for l - in other.listeners - if l not in existing_listener_set - and not only_propagate or l in self.propagate - ] + other_listeners = [ + l + for l in other.listeners + if l not in existing_listener_set + and not only_propagate + or l in self.propagate + ] existing_listeners.extend(other_listeners) @@ -368,7 +385,7 @@ class _ListenerCollection(_CompoundListener): class _JoinedListener(_CompoundListener): - __slots__ = 'parent', 'name', 'local', 'parent_listeners' + __slots__ = "parent", "name", "local", "parent_listeners" def __init__(self, parent, name, local): self._exec_once = False diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 137aec2589..c750be70ac 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -26,7 +26,7 
@@ _registrars = util.defaultdict(list) def _is_event_name(name): - return not name.startswith('_') and name != 'dispatch' + return not name.startswith("_") and name != "dispatch" class _UnpickleDispatch(object): @@ -37,8 +37,8 @@ class _UnpickleDispatch(object): def __call__(self, _instance_cls): for cls in _instance_cls.__mro__: - if 'dispatch' in cls.__dict__: - return cls.__dict__['dispatch'].dispatch._for_class( + if "dispatch" in cls.__dict__: + return cls.__dict__["dispatch"].dispatch._for_class( _instance_cls ) else: @@ -67,7 +67,7 @@ class _Dispatch(object): # In one ORM edge case, an attribute is added to _Dispatch, # so __dict__ is used in just that case and potentially others. - __slots__ = '_parent', '_instance_cls', '__dict__', '_empty_listeners' + __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" _empty_listener_reg = weakref.WeakKeyDictionary() @@ -79,7 +79,9 @@ class _Dispatch(object): try: self._empty_listeners = self._empty_listener_reg[instance_cls] except KeyError: - self._empty_listeners = self._empty_listener_reg[instance_cls] = { + self._empty_listeners = self._empty_listener_reg[ + instance_cls + ] = { ls.name: _EmptyListener(ls, instance_cls) for ls in parent._event_descriptors } @@ -122,17 +124,18 @@ class _Dispatch(object): :class:`._Dispatch` objects. """ - if '_joined_dispatch_cls' not in self.__class__.__dict__: + if "_joined_dispatch_cls" not in self.__class__.__dict__: cls = type( "Joined%s" % self.__class__.__name__, - (_JoinedDispatcher, ), {'__slots__': self._event_names} + (_JoinedDispatcher,), + {"__slots__": self._event_names}, ) self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) def __reduce__(self): - return _UnpickleDispatch(), (self._instance_cls, ) + return _UnpickleDispatch(), (self._instance_cls,) def _update(self, other, only_propagate=True): """Populate from the listeners in another :class:`_Dispatch` @@ -140,8 +143,9 @@ class _Dispatch(object): for ls in other._event_descriptors: if isinstance(ls, _EmptyListener): continue - getattr(self, ls.name).\ - for_modify(self)._update(ls, only_propagate=only_propagate) + getattr(self, ls.name).for_modify(self)._update( + ls, only_propagate=only_propagate + ) def _clear(self): for ls in self._event_descriptors: @@ -164,14 +168,15 @@ def _create_dispatcher_class(cls, classname, bases, dict_): # there's all kinds of ways to do this, # i.e. make a Dispatch class that shares the '_listen' method # of the Event class, this is the straight monkeypatch. - if hasattr(cls, 'dispatch'): + if hasattr(cls, "dispatch"): dispatch_base = cls.dispatch.__class__ else: dispatch_base = _Dispatch event_names = [k for k in dict_ if _is_event_name(k)] - dispatch_cls = type("%sDispatch" % classname, - (dispatch_base, ), {'__slots__': event_names}) + dispatch_cls = type( + "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names} + ) dispatch_cls._event_names = event_names @@ -186,7 +191,7 @@ def _create_dispatcher_class(cls, classname, bases, dict_): setattr(dispatch_inst, ls.name, ls) dispatch_cls._event_names.append(ls.name) - if getattr(cls, '_dispatch_target', None): + if getattr(cls, "_dispatch_target", None): cls._dispatch_target.dispatch = dispatcher(cls) @@ -221,12 +226,14 @@ class Events(util.with_metaclass(_EventMeta, object)): # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. 
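The event/base.py hunks above reformat _create_dispatcher_class, whose underlying idiom is building a slotted dispatch class per set of event names at class-creation time via type(); black also normalizes the single-element base tuple from "(dispatch_base, )" to "(dispatch_base,)". A minimal sketch of the dynamic-class idiom with hypothetical names:

    event_names = ["before_execute", "after_execute"]

    MyDispatch = type(
        "MyDispatch",
        (object,),
        {"__slots__": tuple(event_names)},
    )

    d = MyDispatch()
    d.before_execute = object()   # declared in __slots__, so this works
    # d.unknown_event = object()  # would raise AttributeError
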
- if hasattr(target, 'dispatch'): + if hasattr(target, "dispatch"): if ( dispatch_is(cls.dispatch.__class__) or dispatch_is(type, cls.dispatch.__class__) - or (dispatch_is(_JoinedDispatcher) - and dispatch_parent_is(cls.dispatch.__class__)) + or ( + dispatch_is(_JoinedDispatcher) + and dispatch_parent_is(cls.dispatch.__class__) + ) ): return target @@ -246,7 +253,7 @@ class Events(util.with_metaclass(_EventMeta, object)): class _JoinedDispatcher(object): """Represent a connection between two _Dispatch objects.""" - __slots__ = 'local', 'parent', '_instance_cls' + __slots__ = "local", "parent", "_instance_cls" def __init__(self, local, parent): self.local = local @@ -281,5 +288,5 @@ class dispatcher(object): def __get__(self, obj, cls): if obj is None: return self.dispatch - obj.__dict__['dispatch'] = disp = self.dispatch._for_instance(obj) + obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj) return disp diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 1883070f4d..c30b922fd3 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -15,10 +15,11 @@ from .. import util def _legacy_signature(since, argnames, converter=None): def leg(fn): - if not hasattr(fn, '_legacy_signatures'): + if not hasattr(fn, "_legacy_signatures"): fn._legacy_signatures = [] fn._legacy_signatures.append((since, argnames, converter)) return fn + return leg @@ -30,15 +31,18 @@ def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): else: has_kw = False - if len(argnames) == len(argspec.args) \ - and has_kw is bool(argspec.keywords): + if len(argnames) == len(argspec.args) and has_kw is bool( + argspec.keywords + ): if conv: assert not has_kw def wrap_leg(*args): return fn(*conv(*args)) + else: + def wrap_leg(*args, **kw): argdict = dict(zip(dispatch_collection.arg_names, args)) args = [argdict[name] for name in argnames] @@ -46,16 +50,14 @@ def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): return fn(*args, **kw) else: return fn(*args) + return wrap_leg else: return fn def _indent(text, indent): - return "\n".join( - indent + line - for line in text.split("\n") - ) + return "\n".join(indent + line for line in text.split("\n")) def _standard_listen_example(dispatch_collection, sample_target, fn): @@ -64,10 +66,13 @@ def _standard_listen_example(dispatch_collection, sample_target, fn): "%(arg)s = kw['%(arg)s']" % {"arg": arg} for arg in dispatch_collection.arg_names[0:2] ), - " ") + " ", + ) if dispatch_collection.legacy_signatures: - current_since = max(since for since, args, conv - in dispatch_collection.legacy_signatures) + current_since = max( + since + for since, args, conv in dispatch_collection.legacy_signatures + ) else: current_since = None text = ( @@ -82,7 +87,6 @@ def _standard_listen_example(dispatch_collection, sample_target, fn): if len(dispatch_collection.arg_names) > 3: text += ( - "\n# named argument style (new in 0.9)\n" "@event.listens_for(" "%(sample_target)s, '%(event_name)s', named=True)\n" @@ -93,13 +97,14 @@ def _standard_listen_example(dispatch_collection, sample_target, fn): ) text %= { - "current_since": " (arguments as of %s)" % - current_since if current_since else "", + "current_since": " (arguments as of %s)" % current_since + if current_since + else "", "event_name": fn.__name__, "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", "named_event_arguments": ", ".join(dispatch_collection.arg_names), "example_kw_arg": example_kw_arg, - "sample_target": sample_target + "sample_target": 
sample_target, } return text @@ -113,13 +118,15 @@ def _legacy_listen_examples(dispatch_collection, sample_target, fn): "def receive_%(event_name)s(" "%(named_event_arguments)s%(has_kw_arguments)s):\n" " \"listen for the '%(event_name)s' event\"\n" - "\n # ... (event handling logic) ...\n" % { + "\n # ... (event handling logic) ...\n" + % { "since": since, "event_name": fn.__name__, "has_kw_arguments": " **kw" - if dispatch_collection.has_kw else "", + if dispatch_collection.has_kw + else "", "named_event_arguments": ", ".join(args), - "sample_target": sample_target + "sample_target": sample_target, } ) return text @@ -133,37 +140,34 @@ def _version_signature_changes(dispatch_collection): " arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n" " Listener functions which accept the previous argument \n" " signature(s) listed above will be automatically \n" - " adapted to the new signature." % { + " adapted to the new signature." + % { "since": since, "event_name": dispatch_collection.name, "named_event_arguments": ", ".join(dispatch_collection.arg_names), - "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "" + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", } ) def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn): - header = ".. container:: event_signatures\n\n"\ - " Example argument forms::\n"\ + header = ( + ".. container:: event_signatures\n\n" + " Example argument forms::\n" "\n" + ) sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj") - text = ( - header + - _indent( - _standard_listen_example( - dispatch_collection, sample_target, fn), - " " * 8) + text = header + _indent( + _standard_listen_example(dispatch_collection, sample_target, fn), + " " * 8, ) if dispatch_collection.legacy_signatures: text += _indent( - _legacy_listen_examples( - dispatch_collection, sample_target, fn), - " " * 8) + _legacy_listen_examples(dispatch_collection, sample_target, fn), + " " * 8, + ) text += _version_signature_changes(dispatch_collection) - return util.inject_docstring_text(fn.__doc__, - text, - 1 - ) + return util.inject_docstring_text(fn.__doc__, text, 1) diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index 8d4bada0b2..c862ae4033 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -141,11 +141,15 @@ class _EventKey(object): """ __slots__ = ( - 'target', 'identifier', 'fn', 'fn_key', 'fn_wrap', 'dispatch_target' + "target", + "identifier", + "fn", + "fn_key", + "fn_wrap", + "dispatch_target", ) - def __init__(self, target, identifier, - fn, dispatch_target, _fn_wrap=None): + def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None): self.target = target self.identifier = identifier self.fn = fn @@ -169,7 +173,7 @@ class _EventKey(object): self.identifier, self.fn, self.dispatch_target, - _fn_wrap=fn_wrap + _fn_wrap=fn_wrap, ) def with_dispatch_target(self, dispatch_target): @@ -181,15 +185,18 @@ class _EventKey(object): self.identifier, self.fn, dispatch_target, - _fn_wrap=self.fn_wrap + _fn_wrap=self.fn_wrap, ) def listen(self, *args, **kw): once = kw.pop("once", False) named = kw.pop("named", False) - target, identifier, fn = \ - self.dispatch_target, self.identifier, self._listen_fn + target, identifier, fn = ( + self.dispatch_target, + self.identifier, + self._listen_fn, + ) dispatch_collection = getattr(target.dispatch, identifier) @@ -198,8 +205,9 @@ class _EventKey(object): self = self.with_wrapper(adjusted_fn) if once: - 
self.with_wrapper( - util.only_once(self._listen_fn)).listen(*args, **kw) + self.with_wrapper(util.only_once(self._listen_fn)).listen( + *args, **kw + ) else: self.dispatch_target.dispatch._listen(self, *args, **kw) @@ -208,8 +216,8 @@ class _EventKey(object): if key not in _key_to_collection: raise exc.InvalidRequestError( - "No listeners found for event %s / %r / %s " % - (self.target, self.identifier, self.fn) + "No listeners found for event %s / %r / %s " + % (self.target, self.identifier, self.fn) ) dispatch_reg = _key_to_collection.pop(key) @@ -224,20 +232,26 @@ class _EventKey(object): """ return self._key in _key_to_collection - def base_listen(self, propagate=False, insert=False, - named=False, retval=None): + def base_listen( + self, propagate=False, insert=False, named=False, retval=None + ): - target, identifier, fn = \ - self.dispatch_target, self.identifier, self._listen_fn + target, identifier, fn = ( + self.dispatch_target, + self.identifier, + self._listen_fn, + ) dispatch_collection = getattr(target.dispatch, identifier) if insert: - dispatch_collection.\ - for_modify(target.dispatch).insert(self, propagate) + dispatch_collection.for_modify(target.dispatch).insert( + self, propagate + ) else: - dispatch_collection.\ - for_modify(target.dispatch).append(self, propagate) + dispatch_collection.for_modify(target.dispatch).append( + self, propagate + ) @property def _listen_fn(self): diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 3e97ea8968..fa62b77056 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -600,39 +600,53 @@ class ConnectionEvents(event.Events): @classmethod def _listen(cls, event_key, retval=False): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, \ - event_key._listen_fn + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) target._has_events = True if not retval: - if identifier == 'before_execute': + if identifier == "before_execute": orig_fn = fn - def wrap_before_execute(conn, clauseelement, - multiparams, params): + def wrap_before_execute( + conn, clauseelement, multiparams, params + ): orig_fn(conn, clauseelement, multiparams, params) return clauseelement, multiparams, params + fn = wrap_before_execute - elif identifier == 'before_cursor_execute': + elif identifier == "before_cursor_execute": orig_fn = fn - def wrap_before_cursor_execute(conn, cursor, statement, - parameters, context, - executemany): - orig_fn(conn, cursor, statement, - parameters, context, executemany) + def wrap_before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + orig_fn( + conn, + cursor, + statement, + parameters, + context, + executemany, + ) return statement, parameters + fn = wrap_before_cursor_execute - elif retval and \ - identifier not in ('before_execute', - 'before_cursor_execute', 'handle_error'): + elif retval and identifier not in ( + "before_execute", + "before_cursor_execute", + "handle_error", + ): raise exc.ArgumentError( "Only the 'before_execute', " "'before_cursor_execute' and 'handle_error' engine " "event listeners accept the 'retval=True' " - "argument.") + "argument." 
+ ) event_key.with_wrapper(fn).base_listen() def before_execute(self, conn, clauseelement, multiparams, params): @@ -677,8 +691,9 @@ class ConnectionEvents(event.Events): """ - def before_cursor_execute(self, conn, cursor, statement, - parameters, context, executemany): + def before_cursor_execute( + self, conn, cursor, statement, parameters, context, executemany + ): """Intercept low-level cursor execute() events before execution, receiving the string SQL statement and DBAPI-specific parameter list to be invoked against a cursor. @@ -718,8 +733,9 @@ class ConnectionEvents(event.Events): """ - def after_cursor_execute(self, conn, cursor, statement, - parameters, context, executemany): + def after_cursor_execute( + self, conn, cursor, statement, parameters, context, executemany + ): """Intercept low-level cursor execute() events after execution. :param conn: :class:`.Connection` object @@ -737,8 +753,9 @@ class ConnectionEvents(event.Events): """ - def dbapi_error(self, conn, cursor, statement, parameters, - context, exception): + def dbapi_error( + self, conn, cursor, statement, parameters, context, exception + ): """Intercept a raw DBAPI error. This event is called with the DBAPI exception instance @@ -1039,6 +1056,7 @@ class ConnectionEvents(event.Events): .. versionadded:: 1.0.5 """ + def begin(self, conn): """Intercept begin() events. @@ -1173,8 +1191,11 @@ class DialectEvents(event.Events): @classmethod def _listen(cls, event_key, retval=False): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key.fn, + ) target._has_events = True event_key.base_listen() @@ -1235,8 +1256,9 @@ class DialectEvents(event.Events): """ - def do_setinputsizes(self, - inputsizes, cursor, statement, parameters, context): + def do_setinputsizes( + self, inputsizes, cursor, statement, parameters, context + ): """Receive the setinputsizes dictionary for possible modification. This event is emitted in the case where the dialect makes use of the diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 40dcb7c55b..832c5ee524 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -22,7 +22,7 @@ class SQLAlchemyError(Exception): code = None def __init__(self, *arg, **kw): - code = kw.pop('code', None) + code = kw.pop("code", None) if code is not None: self.code = code super(SQLAlchemyError, self).__init__(*arg, **kw) @@ -33,7 +33,7 @@ class SQLAlchemyError(Exception): else: return ( "(Background on this error at: " - "http://sqlalche.me/e/%s)" % (self.code, ) + "http://sqlalche.me/e/%s)" % (self.code,) ) def _message(self): @@ -48,9 +48,7 @@ class SQLAlchemyError(Exception): message = self._message() if self.code: - message = ( - "%s %s" % (message, self._code_str()) - ) + message = "%s %s" % (message, self._code_str()) return message @@ -112,6 +110,7 @@ class CircularDependencyError(SQLAlchemyError): see :ref:`use_alter`. 
""" + def __init__(self, message, cycles, edges, msg=None, code=None): if msg is None: message += " (%s)" % ", ".join(repr(s) for s in cycles) @@ -122,8 +121,7 @@ class CircularDependencyError(SQLAlchemyError): self.edges = edges def __reduce__(self): - return self.__class__, (None, self.cycles, - self.edges, self.args[0]) + return self.__class__, (None, self.cycles, self.edges, self.args[0]) class CompileError(SQLAlchemyError): @@ -140,8 +138,9 @@ class UnsupportedCompilationError(CompileError): def __init__(self, compiler, element_type): super(UnsupportedCompilationError, self).__init__( - "Compiler %r can't render element of type %s" % - (compiler, element_type)) + "Compiler %r can't render element of type %s" + % (compiler, element_type) + ) class IdentifierError(SQLAlchemyError): @@ -158,6 +157,7 @@ class DisconnectionError(SQLAlchemyError): regarding the connection attempt. """ + invalidate_pool = False @@ -175,6 +175,7 @@ class InvalidatePoolError(DisconnectionError): .. versionadded:: 1.2 """ + invalidate_pool = True @@ -213,6 +214,7 @@ class NoReferencedTableError(NoReferenceError): located. """ + def __init__(self, message, tname): NoReferenceError.__init__(self, message) self.table_name = tname @@ -226,14 +228,17 @@ class NoReferencedColumnError(NoReferenceError): located. """ + def __init__(self, message, tname, cname): NoReferenceError.__init__(self, message) self.table_name = tname self.column_name = cname def __reduce__(self): - return self.__class__, (self.args[0], self.table_name, - self.column_name) + return ( + self.__class__, + (self.args[0], self.table_name, self.column_name), + ) class NoSuchTableError(InvalidRequestError): @@ -273,6 +278,7 @@ class DontWrapMixin(object): """ + # Moved to orm.exc; compatibility definition installed by orm import until 0.6 UnmappedColumnError = None @@ -310,8 +316,10 @@ class StatementError(SQLAlchemyError): self.detail.append(msg) def __reduce__(self): - return self.__class__, (self.args[0], self.statement, - self.params, self.orig) + return ( + self.__class__, + (self.args[0], self.statement, self.params, self.orig), + ) def __str__(self): from sqlalchemy.sql import util @@ -325,9 +333,7 @@ class StatementError(SQLAlchemyError): code_str = self._code_str() if code_str: details.append(code_str) - return ' '.join([ - "(%s)" % det for det in self.detail - ] + details) + return " ".join(["(%s)" % det for det in self.detail] + details) class DBAPIError(StatementError): @@ -353,18 +359,23 @@ class DBAPIError(StatementError): """ - code = 'dbapi' + code = "dbapi" @classmethod - def instance(cls, statement, params, - orig, dbapi_base_err, - connection_invalidated=False, - dialect=None): + def instance( + cls, + statement, + params, + orig, + dbapi_base_err, + connection_invalidated=False, + dialect=None, + ): # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. 
- if (isinstance(orig, BaseException) and - not isinstance(orig, Exception)) or \ - isinstance(orig, DontWrapMixin): + if ( + isinstance(orig, BaseException) and not isinstance(orig, Exception) + ) or isinstance(orig, DontWrapMixin): return orig if orig is not None: @@ -372,17 +383,28 @@ class DBAPIError(StatementError): # raise a StatementError if isinstance(orig, SQLAlchemyError) and statement: return StatementError( - "(%s.%s) %s" % - (orig.__class__.__module__, orig.__class__.__name__, - orig.args[0]), - statement, params, orig, code=orig.code + "(%s.%s) %s" + % ( + orig.__class__.__module__, + orig.__class__.__name__, + orig.args[0], + ), + statement, + params, + orig, + code=orig.code, ) elif not isinstance(orig, dbapi_base_err) and statement: return StatementError( - "(%s.%s) %s" % - (orig.__class__.__module__, orig.__class__.__name__, - orig), - statement, params, orig + "(%s.%s) %s" + % ( + orig.__class__.__module__, + orig.__class__.__name__, + orig, + ), + statement, + params, + orig, ) glob = globals() @@ -390,31 +412,42 @@ class DBAPIError(StatementError): name = super_.__name__ if dialect: name = dialect.dbapi_exception_translation_map.get( - name, name) + name, name + ) if name in glob and issubclass(glob[name], DBAPIError): cls = glob[name] break - return cls(statement, params, orig, connection_invalidated, - code=cls.code) + return cls( + statement, params, orig, connection_invalidated, code=cls.code + ) def __reduce__(self): - return self.__class__, (self.statement, self.params, - self.orig, self.connection_invalidated) + return ( + self.__class__, + ( + self.statement, + self.params, + self.orig, + self.connection_invalidated, + ), + ) - def __init__(self, statement, params, orig, connection_invalidated=False, - code=None): + def __init__( + self, statement, params, orig, connection_invalidated=False, code=None + ): try: text = str(orig) except Exception as e: - text = 'Error in str() of DB-API-generated exception: ' + str(e) + text = "Error in str() of DB-API-generated exception: " + str(e) StatementError.__init__( self, - '(%s.%s) %s' % ( - orig.__class__.__module__, orig.__class__.__name__, text, ), + "(%s.%s) %s" + % (orig.__class__.__module__, orig.__class__.__name__, text), statement, params, - orig, code=code + orig, + code=code, ) self.connection_invalidated = connection_invalidated @@ -466,8 +499,10 @@ class NotSupportedError(DatabaseError): code = "tw8g" + # Warnings + class SADeprecationWarning(DeprecationWarning): """Issued once per usage of a deprecated API.""" diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py index 9558b2a1f8..9fed09e2bf 100644 --- a/lib/sqlalchemy/ext/__init__.py +++ b/lib/sqlalchemy/ext/__init__.py @@ -8,4 +8,3 @@ from .. import util as _sa_util _sa_util.dependencies.resolve_all("sqlalchemy.ext") - diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index ff9433d4de..56b91ce0bb 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw): return AssociationProxy(target_collection, attr, **kw) -ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') +ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.AssociationProxy`. 
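
The association proxy hunks that follow are easier to review with a
concrete mapping in mind. The sketch below is illustrative only: the
User/Keyword classes are not part of this diff, and the default
creator behavior (calling the target class with one argument) is as
documented in this module:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(64))
        kw = relationship("Keyword")
        # proxy the "keyword" string attribute across the "kw" relationship
        keywords = association_proxy("kw", "keyword")


    class Keyword(Base):
        __tablename__ = "keyword"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("user.id"))
        keyword = Column(String(64))

        def __init__(self, keyword):
            self.keyword = keyword


    user = User(name="log")
    # appending a plain string invokes the default creator, Keyword("wendy")
    user.keywords.append("wendy")
    print(list(user.keywords))  # ['wendy']
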
@@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): is_attribute = False extension_type = ASSOCIATION_PROXY - def __init__(self, target_collection, attr, creator=None, - getset_factory=None, proxy_factory=None, - proxy_bulk_set=None, info=None, - cascade_scalar_deletes=False): + def __init__( + self, + target_collection, + attr, + creator=None, + getset_factory=None, + proxy_factory=None, + proxy_bulk_set=None, + info=None, + cascade_scalar_deletes=False, + ): """Construct a new :class:`.AssociationProxy`. The :func:`.association_proxy` function is provided as the usual @@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo): self.proxy_bulk_set = proxy_bulk_set self.cascade_scalar_deletes = cascade_scalar_deletes - self.key = '_%s_%s_%s' % ( - type(self).__name__, target_collection, id(self)) + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) if info: self.info = info @@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): setattr(o, attr, v) + else: + def setter(o, v): setattr(o, attr, v) + return getter, setter @@ -325,20 +340,21 @@ class AssociationProxyInstance(object): def for_proxy(cls, parent, owning_class, parent_instance): target_collection = parent.target_collection value_attr = parent.value_attr - prop = orm.class_mapper(owning_class).\ - get_property(target_collection) + prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. if not isinstance(prop, orm.RelationshipProperty): raise NotImplementedError( "association proxy to a non-relationship " - "intermediary is not supported") + "intermediary is not supported" + ) target_class = prop.mapper.class_ try: target_assoc = cls._cls_unwrap_target_assoc_proxy( - target_class, value_attr) + target_class, value_attr + ) except AttributeError: # the proxied attribute doesn't exist on the target class; # return an "ambiguous" instance that will work on a per-object @@ -353,8 +369,8 @@ class AssociationProxyInstance(object): @classmethod def _construct_for_assoc( - cls, target_assoc, parent, owning_class, - target_class, value_attr): + cls, target_assoc, parent, owning_class, target_class, value_attr + ): if target_assoc is not None: return ObjectAssociationProxyInstance( parent, owning_class, target_class, value_attr @@ -371,8 +387,9 @@ class AssociationProxyInstance(object): ) def _get_property(self): - return orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) @property def _comparator(self): @@ -388,7 +405,8 @@ class AssociationProxyInstance(object): @util.memoized_property def _unwrap_target_assoc_proxy(self): return self._cls_unwrap_target_assoc_proxy( - self.target_class, self.value_attr) + self.target_class, self.value_attr + ) @property def remote_attr(self): @@ -448,8 +466,11 @@ class AssociationProxyInstance(object): @util.memoized_property def _value_is_scalar(self): - return not self._get_property().\ - mapper.get_property(self.value_attr).uselist + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) @property def _target_is_object(self): @@ -468,12 +489,17 @@ class AssociationProxyInstance(object): def getter(target): return _getter(target) if target is not None else None 
+ if collection_class is dict: + def setter(o, k, v): return setattr(o, attr, v) + else: + def setter(o, v): return setattr(o, attr, v) + return getter, setter @property @@ -500,14 +526,18 @@ class AssociationProxyInstance(object): return proxy self.collection_class, proxy = self._new( - _lazy_collection(obj, self.target_collection)) + _lazy_collection(obj, self.target_collection) + ) setattr(obj, self.key, (id(obj), id(self), proxy)) return proxy def set(self, obj, values): if self.scalar: - creator = self.parent.creator \ - if self.parent.creator else self.target_class + creator = ( + self.parent.creator + if self.parent.creator + else self.target_class + ) target = getattr(obj, self.target_collection) if target is None: if values is None: @@ -535,35 +565,52 @@ class AssociationProxyInstance(object): delattr(obj, self.target_collection) def _new(self, lazy_collection): - creator = self.parent.creator if self.parent.creator else \ - self.target_class + creator = ( + self.parent.creator if self.parent.creator else self.target_class + ) collection_class = util.duck_type_collection(lazy_collection()) if self.parent.proxy_factory: - return collection_class, self.parent.proxy_factory( - lazy_collection, creator, self.value_attr, self) + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) if self.parent.getset_factory: - getter, setter = self.parent.getset_factory( - collection_class, self) + getter, setter = self.parent.getset_factory(collection_class, self) else: getter, setter = self.parent._default_getset(collection_class) if collection_class is list: - return collection_class, _AssociationList( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is dict: - return collection_class, _AssociationDict( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is set: - return collection_class, _AssociationSet( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ) else: raise exc.ArgumentError( - 'could not guess which interface to use for ' + "could not guess which interface to use for " 'collection_class "%s" backing "%s"; specify a ' - 'proxy_factory and proxy_bulk_set manually' % - (self.collection_class.__name__, self.target_collection)) + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class.__name__, self.target_collection) + ) def _set(self, proxy, values): if self.parent.proxy_bulk_set: @@ -576,16 +623,19 @@ class AssociationProxyInstance(object): proxy.update(values) else: raise exc.ArgumentError( - 'no proxy_bulk_set supplied for custom ' - 'collection_class implementation') + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) def _inflate(self, proxy): - creator = self.parent.creator and \ - self.parent.creator or self.target_class + creator = ( + self.parent.creator and self.parent.creator or self.target_class + ) if self.parent.getset_factory: getter, setter = self.parent.getset_factory( - self.collection_class, self) + self.collection_class, self + ) else: getter, setter = self.parent._default_getset(self.collection_class) @@ -594,12 +644,13 @@ class AssociationProxyInstance(object): proxy.setter = setter def 
_criterion_exists(self, criterion=None, **kwargs): - is_has = kwargs.pop('is_has', None) + is_has = kwargs.pop("is_has", None) target_assoc = self._unwrap_target_assoc_proxy if target_assoc is not None: inner = target_assoc._criterion_exists( - criterion=criterion, **kwargs) + criterion=criterion, **kwargs + ) return self._comparator._criterion_exists(inner) if self._target_is_object: @@ -631,15 +682,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - self.scalar and ( - not self._target_is_object or self._value_is_scalar) + self.scalar + and (not self._target_is_object or self._value_is_scalar) ): raise exc.InvalidRequestError( - "'any()' not implemented for scalar " - "attributes. Use has()." + "'any()' not implemented for scalar " "attributes. Use has()." ) return self._criterion_exists( - criterion=criterion, is_has=False, **kwargs) + criterion=criterion, is_has=False, **kwargs + ) def has(self, criterion=None, **kwargs): """Produce a proxied 'has' expression using EXISTS. @@ -651,14 +702,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - not self.scalar or ( - self._target_is_object and not self._value_is_scalar) + not self.scalar + or (self._target_is_object and not self._value_is_scalar) ): raise exc.InvalidRequestError( - "'has()' not implemented for collections. " - "Use any().") + "'has()' not implemented for collections. " "Use any()." + ) return self._criterion_exists( - criterion=criterion, is_has=True, **kwargs) + criterion=criterion, is_has=True, **kwargs + ) class AmbiguousAssociationProxyInstance(AssociationProxyInstance): @@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): "Association proxy %s.%s refers to an attribute '%s' that is not " "directly mapped on class %s; therefore this operation cannot " "proceed since we don't know what type of object is referred " - "towards" % ( - self.owning_class.__name__, self.target_collection, - self.value_attr, self.target_class - )) + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) def get(self, obj): self._ambiguous() @@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): return self def _populate_cache(self, instance_class): - prop = orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) if inspect(instance_class).mapper.isa(prop.mapper): target_class = instance_class try: target_assoc = self._cls_unwrap_target_assoc_proxy( - target_class, self.value_attr) + target_class, self.value_attr + ) except AttributeError: pass else: - self._lookup_cache[instance_class] = \ - self._construct_for_assoc( - target_assoc, self.parent, self.owning_class, - target_class, self.value_attr + self._lookup_cache[instance_class] = self._construct_for_assoc( + target_assoc, + self.parent, + self.owning_class, + target_class, + self.value_attr, ) class ObjectAssociationProxyInstance(AssociationProxyInstance): """an :class:`.AssociationProxyInstance` that has an object as a target. 
""" + _target_is_object = True _is_canonical = True @@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if target_assoc is not None: return self._comparator._criterion_exists( target_assoc.contains(obj) - if not target_assoc.scalar else target_assoc == obj + if not target_assoc.scalar + else target_assoc == obj ) - elif self._target_is_object and self.scalar and \ - not self._value_is_scalar: + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): return self._comparator.has( getattr(self.target_class, self.value_attr).contains(obj) ) - elif self._target_is_object and self.scalar and \ - self._value_is_scalar: + elif self._target_is_object and self.scalar and self._value_is_scalar: raise exc.InvalidRequestError( - "contains() doesn't apply to a scalar object endpoint; use ==") + "contains() doesn't apply to a scalar object endpoint; use ==" + ) else: return self._comparator._criterion_exists(**{self.value_attr: obj}) @@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if obj is None: return or_( self._comparator.has(**{self.value_attr: obj}), - self._comparator == None + self._comparator == None, ) else: return self._comparator.has(**{self.value_attr: obj}) @@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): # note the has() here will fail for collections; eq_() # is only allowed with a scalar. return self._comparator.has( - getattr(self.target_class, self.value_attr) != obj) + getattr(self.target_class, self.value_attr) != obj + ) class ColumnAssociationProxyInstance( - ColumnOperators, AssociationProxyInstance): + ColumnOperators, AssociationProxyInstance +): """an :class:`.AssociationProxyInstance` that has a database column as a target. 
""" + _target_is_object = False _is_canonical = True @@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance( self.remote_attr.operate(operator.eq, other) ) if other is None: - return or_( - expr, self._comparator == None - ) + return or_(expr, self._comparator == None) else: return expr @@ -824,11 +890,11 @@ class _lazy_collection(object): return getattr(self.parent, self.target) def __getstate__(self): - return {'obj': self.parent, 'target': self.target} + return {"obj": self.parent, "target": self.target} def __setstate__(self, state): - self.parent = state['obj'] - self.target = state['target'] + self.parent = state["obj"] + self.target = state["target"] class _AssociationCollection(object): @@ -874,11 +940,11 @@ class _AssociationCollection(object): __nonzero__ = __bool__ def __getstate__(self): - return {'parent': self.parent, 'lazy_collection': self.lazy_collection} + return {"parent": self.parent, "lazy_collection": self.lazy_collection} def __setstate__(self, state): - self.parent = state['parent'] - self.lazy_collection = state['lazy_collection'] + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] self.parent._inflate(self) @@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection): if len(value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " - "extended slice of size %s" % (len(value), - len(rng))) + "extended slice of size %s" % (len(value), len(rng)) + ) for i, item in zip(rng, value): self._set(self.col[i], item) @@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection): col.append(item) def count(self, value): - return sum([1 for _ in - util.itertools_filter(lambda v: v == value, iter(self))]) + return sum( + [ + 1 + for _ in util.itertools_filter( + lambda v: v == value, iter(self) + ) + ] + ) def extend(self, values): for v in values: @@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection): raise NotImplementedError def clear(self): - del self.col[0:len(self.col)] + del self.col[0 : len(self.col)] def __eq__(self, other): return list(self) == other @@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection): if not isinstance(n, int): return NotImplemented return list(self) * n + __rmul__ = __mul__ def __iadd__(self, iterable): @@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func -_NotProvided = util.symbol('_NotProvided') +_NotProvided = util.symbol("_NotProvided") class _AssociationDict(_AssociationCollection): @@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection): return self.col.keys() if util.py2k: + def iteritems(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection): def items(self): return [(k, self._get(self.col[k])) for k in self] + else: + def items(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection): def update(self, *a, **kw): if len(a) > 1: - raise TypeError('update expected at most 1 arguments, got %i' % - len(a)) + raise TypeError( + "update expected at 
most 1 arguments, got %i" % len(a) + ) elif len(a) == 1: seq_or_map = a[0] # discern dict from sequence - took the advice from # http://www.voidspace.org.uk/python/articles/duck_typing.shtml # still not perfect :( - if hasattr(seq_or_map, 'keys'): + if hasattr(seq_or_map, "keys"): for item in seq_or_map: self[item] = seq_or_map[item] else: @@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection): except ValueError: raise ValueError( "dictionary update sequence " - "requires 2-element tuples") + "requires 2-element tuples" + ) for key, value in kw: self[key] = value @@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(dict, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection): def pop(self): if not self.col: - raise KeyError('pop from an empty set') + raise KeyError("pop from an empty set") member = self.col.pop() return self._get(member) @@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(set, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index cafb3d61c8..747373a2a4 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -580,7 +580,8 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): def name_for_collection_relationship( - base, local_cls, referred_cls, constraint): + base, local_cls, referred_cls, constraint +): """Return the attribute name that should be used to refer from one class to another, for a collection reference. @@ -607,7 +608,8 @@ def name_for_collection_relationship( def generate_relationship( - base, direction, return_fn, attrname, local_cls, referred_cls, **kw): + base, direction, return_fn, attrname, local_cls, referred_cls, **kw +): r"""Generate a :func:`.relationship` or :func:`.backref` on behalf of two mapped classes. @@ -677,6 +679,7 @@ class AutomapBase(object): :ref:`automap_toplevel` """ + __abstract__ = True classes = None @@ -694,15 +697,16 @@ class AutomapBase(object): @classmethod def prepare( - cls, - engine=None, - reflect=False, - schema=None, - classname_for_table=classname_for_table, - collection_class=list, - name_for_scalar_relationship=name_for_scalar_relationship, - name_for_collection_relationship=name_for_collection_relationship, - generate_relationship=generate_relationship): + cls, + engine=None, + reflect=False, + schema=None, + classname_for_table=classname_for_table, + collection_class=list, + name_for_scalar_relationship=name_for_scalar_relationship, + name_for_collection_relationship=name_for_collection_relationship, + generate_relationship=generate_relationship, + ): """Extract mapped classes and relationships from the :class:`.MetaData` and perform mappings. 
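
prepare() above is the main entry point touched by this reformat; a
short usage sketch may help orient the hunks that follow. The SQLite
engine and the two CREATE TABLE statements are assumptions made for
the example and are not part of this change:

    from sqlalchemy import create_engine
    from sqlalchemy.ext.automap import automap_base
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    engine.execute("CREATE TABLE user (id INTEGER PRIMARY KEY, name TEXT)")
    engine.execute(
        "CREATE TABLE address ("
        "id INTEGER PRIMARY KEY, "
        "user_id INTEGER REFERENCES user(id), "
        "email TEXT)"
    )

    Base = automap_base()
    # reflect the tables and generate mapped classes; attribute names come
    # from the name_for_*_relationship hooks reformatted in this file
    Base.prepare(engine, reflect=True)

    User = Base.classes.user
    Address = Base.classes.address

    session = Session(engine)
    session.add(
        User(name="ed", address_collection=[Address(email="ed@example.com")])
    )
    session.commit()
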
@@ -752,15 +756,16 @@ class AutomapBase(object): engine, schema=schema, extend_existing=True, - autoload_replace=False + autoload_replace=False, ) _CONFIGURE_MUTEX.acquire() try: table_to_map_config = dict( (m.local_table, m) - for m in _DeferredMapperConfig. - classes_for_base(cls, sort=False) + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) ) many_to_many = [] @@ -774,30 +779,39 @@ class AutomapBase(object): elif table not in table_to_map_config: mapped_cls = type( classname_for_table(cls, table.name, table), - (cls, ), - {"__table__": table} + (cls,), + {"__table__": table}, ) map_config = _DeferredMapperConfig.config_for_cls( - mapped_cls) + mapped_cls + ) cls.classes[map_config.cls.__name__] = mapped_cls table_to_map_config[table] = map_config for map_config in table_to_map_config.values(): - _relationships_for_fks(cls, - map_config, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship) + _relationships_for_fks( + cls, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: - _m2m_relationship(cls, lcl_m2m, rem_m2m, m2m_const, table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship) + _m2m_relationship( + cls, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) for map_config in _DeferredMapperConfig.classes_for_base(cls): map_config.map() @@ -853,20 +867,27 @@ def automap_base(declarative_base=None, **kw): return type( Base.__name__, - (AutomapBase, Base,), - {"__abstract__": True, "classes": util.Properties({})} + (AutomapBase, Base), + {"__abstract__": True, "classes": util.Properties({})}, ) def _is_many_to_many(automap_base, table): - fk_constraints = [const for const in table.constraints - if isinstance(const, ForeignKeyConstraint)] + fk_constraints = [ + const + for const in table.constraints + if isinstance(const, ForeignKeyConstraint) + ] if len(fk_constraints) != 2: return None, None, None cols = sum( - [[fk.parent for fk in fk_constraint.elements] - for fk_constraint in fk_constraints], []) + [ + [fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints + ], + [], + ) if set(cols) != set(table.c): return None, None, None @@ -874,15 +895,19 @@ def _is_many_to_many(automap_base, table): return ( fk_constraints[0].elements[0].column.table, fk_constraints[1].elements[0].column.table, - fk_constraints + fk_constraints, ) -def _relationships_for_fks(automap_base, map_config, table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship): +def _relationships_for_fks( + automap_base, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): local_table = map_config.local_table local_cls = map_config.cls # derived from a weakref, may be None @@ -898,32 +923,33 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, referred_cls = referred_cfg.cls if local_cls is not referred_cls and issubclass( - local_cls, referred_cls): + local_cls, referred_cls + ): continue relationship_name = 
name_for_scalar_relationship( - automap_base, - local_cls, - referred_cls, constraint) + automap_base, local_cls, referred_cls, constraint + ) backref_name = name_for_collection_relationship( - automap_base, - referred_cls, - local_cls, - constraint + automap_base, referred_cls, local_cls, constraint ) o2m_kws = {} nullable = False not in {fk.parent.nullable for fk in fks} if not nullable: - o2m_kws['cascade'] = "all, delete-orphan" + o2m_kws["cascade"] = "all, delete-orphan" - if constraint.ondelete and \ - constraint.ondelete.lower() == "cascade": - o2m_kws['passive_deletes'] = True + if ( + constraint.ondelete + and constraint.ondelete.lower() == "cascade" + ): + o2m_kws["passive_deletes"] = True else: - if constraint.ondelete and \ - constraint.ondelete.lower() == "set null": - o2m_kws['passive_deletes'] = True + if ( + constraint.ondelete + and constraint.ondelete.lower() == "set null" + ): + o2m_kws["passive_deletes"] = True create_backref = backref_name not in referred_cfg.properties @@ -931,54 +957,65 @@ def _relationships_for_fks(automap_base, map_config, table_to_map_config, if create_backref: backref_obj = generate_relationship( automap_base, - interfaces.ONETOMANY, backref, - backref_name, referred_cls, local_cls, + interfaces.ONETOMANY, + backref, + backref_name, + referred_cls, + local_cls, collection_class=collection_class, - **o2m_kws) + **o2m_kws + ) else: backref_obj = None - rel = generate_relationship(automap_base, - interfaces.MANYTOONE, - relationship, - relationship_name, - local_cls, referred_cls, - foreign_keys=[ - fk.parent - for fk in constraint.elements], - backref=backref_obj, - remote_side=[ - fk.column - for fk in constraint.elements] - ) + rel = generate_relationship( + automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, + referred_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements], + ) if rel is not None: map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ - backref_name].back_populates = relationship_name + backref_name + ].back_populates = relationship_name elif create_backref: - rel = generate_relationship(automap_base, - interfaces.ONETOMANY, - relationship, - backref_name, - referred_cls, local_cls, - foreign_keys=[ - fk.parent - for fk in constraint.elements], - back_populates=relationship_name, - collection_class=collection_class, - **o2m_kws) + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws + ) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ - relationship_name].back_populates = backref_name - - -def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship): + relationship_name + ].back_populates = backref_name + + +def _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): map_config = table_to_map_config.get(lcl_m2m, None) referred_cfg = table_to_map_config.get(rem_m2m, None) @@ -989,14 +1026,10 @@ def 
_m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, referred_cls = referred_cfg.cls relationship_name = name_for_collection_relationship( - automap_base, - local_cls, - referred_cls, m2m_const[0]) + automap_base, local_cls, referred_cls, m2m_const[0] + ) backref_name = name_for_collection_relationship( - automap_base, - referred_cls, - local_cls, - m2m_const[1] + automap_base, referred_cls, local_cls, m2m_const[1] ) create_backref = backref_name not in referred_cfg.properties @@ -1008,48 +1041,56 @@ def _m2m_relationship(automap_base, lcl_m2m, rem_m2m, m2m_const, table, interfaces.MANYTOMANY, backref, backref_name, - referred_cls, local_cls, - collection_class=collection_class + referred_cls, + local_cls, + collection_class=collection_class, ) else: backref_obj = None - rel = generate_relationship(automap_base, - interfaces.MANYTOMANY, - relationship, - relationship_name, - local_cls, referred_cls, - secondary=table, - primaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[0].elements), - secondaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[1].elements), - backref=backref_obj, - collection_class=collection_class - ) + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + backref=backref_obj, + collection_class=collection_class, + ) if rel is not None: map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ - backref_name].back_populates = relationship_name + backref_name + ].back_populates = relationship_name elif create_backref: - rel = generate_relationship(automap_base, - interfaces.MANYTOMANY, - relationship, - backref_name, - referred_cls, local_cls, - secondary=table, - primaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[1].elements), - secondaryjoin=and_( - fk.column == fk.parent - for fk in m2m_const[0].elements), - back_populates=relationship_name, - collection_class=collection_class) + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + back_populates=relationship_name, + collection_class=collection_class, + ) if rel is not None: referred_cfg.properties[backref_name] = rel map_config.properties[ - relationship_name].back_populates = backref_name + relationship_name + ].back_populates = backref_name diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 5168791424..f55231a091 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -38,7 +38,8 @@ class Bakery(object): """ - __slots__ = 'cls', 'cache' + + __slots__ = "cls", "cache" def __init__(self, cls_, cache): self.cls = cls_ @@ -51,7 +52,7 @@ class Bakery(object): class BakedQuery(object): """A builder object for :class:`.query.Query` objects.""" - __slots__ = 'steps', '_bakery', '_cache_key', '_spoiled' + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" def __init__(self, bakery, initial_fn, args=()): self._cache_key = () @@ -148,7 +149,7 @@ class BakedQuery(object): """ if not full and not self._spoiled: _spoil_point = self._clone() - _spoil_point._cache_key 
+= ('_query_only', ) + _spoil_point._cache_key += ("_query_only",) self.steps = [_spoil_point._retrieve_baked_query] self._spoiled = True return self @@ -164,7 +165,7 @@ class BakedQuery(object): session will want to use. """ - return self._cache_key + (session._query_cls, ) + return self._cache_key + (session._query_cls,) def _with_lazyload_options(self, options, effective_path, cache_path=None): """Cloning version of _add_lazyload_options. @@ -201,16 +202,20 @@ class BakedQuery(object): key += cache_key self.add_criteria( - lambda q: q._with_current_path(effective_path). - _conditional_options(*options), - cache_path.path, key + lambda q: q._with_current_path( + effective_path + )._conditional_options(*options), + cache_path.path, + key, ) def _retrieve_baked_query(self, session): query = self._bakery.get(self._effective_key(session), None) if query is None: query = self._as_query(session) - self._bakery[self._effective_key(session)] = query.with_session(None) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) return query.with_session(session) def _bake(self, session): @@ -227,8 +232,12 @@ class BakedQuery(object): # so delete some compilation-use-only attributes that can take up # space for attr in ( - '_correlate', '_from_obj', '_mapper_adapter_map', - '_joinpath', '_joinpoint'): + "_correlate", + "_from_obj", + "_mapper_adapter_map", + "_joinpath", + "_joinpoint", + ): query.__dict__.pop(attr, None) self._bakery[self._effective_key(session)] = context return context @@ -276,11 +285,13 @@ class BakedQuery(object): session = query_or_session.session if session is None: raise sa_exc.ArgumentError( - "Given Query needs to be associated with a Session") + "Given Query needs to be associated with a Session" + ) else: raise TypeError( - "Query or Session object expected, got %r." % - type(query_or_session)) + "Query or Session object expected, got %r." + % type(query_or_session) + ) return self._as_query(session) def _as_query(self, session): @@ -299,10 +310,10 @@ class BakedQuery(object): a "baked" query so that we save on performance too. """ - context.attributes['baked_queries'] = baked_queries = [] + context.attributes["baked_queries"] = baked_queries = [] for k, v in list(context.attributes.items()): if isinstance(v, Query): - if 'subquery' in k: + if "subquery" in k: bk = BakedQuery(self._bakery, lambda *args: v) bk._cache_key = self._cache_key + k bk._bake(session) @@ -310,15 +321,17 @@ class BakedQuery(object): del context.attributes[k] def _unbake_subquery_loaders( - self, session, context, params, post_criteria): + self, session, context, params, post_criteria + ): """Retrieve subquery eager loaders stored by _bake_subquery_loaders and turn them back into Result objects that will iterate just like a Query object. """ for k, cache_key, query in context.attributes["baked_queries"]: - bk = BakedQuery(self._bakery, - lambda sess, q=query: q.with_session(sess)) + bk = BakedQuery( + self._bakery, lambda sess, q=query: q.with_session(sess) + ) bk._cache_key = cache_key q = bk.for_session(session) for fn in post_criteria: @@ -334,7 +347,8 @@ class Result(object): against a target :class:`.Session`, and is then invoked for results. 
""" - __slots__ = 'bq', 'session', '_params', '_post_criteria' + + __slots__ = "bq", "session", "_params", "_post_criteria" def __init__(self, bq, session): self.bq = bq @@ -350,7 +364,8 @@ class Result(object): elif len(args) > 0: raise sa_exc.ArgumentError( "params() takes zero or one positional argument, " - "which is a dictionary.") + "which is a dictionary." + ) self._params.update(kw) return self @@ -403,7 +418,8 @@ class Result(object): context.attributes = context.attributes.copy() bq._unbake_subquery_loaders( - self.session, context, self._params, self._post_criteria) + self.session, context, self._params, self._post_criteria + ) context.statement.use_labels = True if context.autoflush and not context.populate_existing: @@ -426,7 +442,7 @@ class Result(object): """ - col = func.count(literal_column('*')) + col = func.count(literal_column("*")) bq = self.bq.with_criteria(lambda q: q.from_self(col)) return bq.for_session(self.session).params(self._params).scalar() @@ -456,8 +472,10 @@ class Result(object): """ bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) ret = list( - bq.for_session(self.session).params(self._params). - _using_post_criteria(self._post_criteria)) + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ) if len(ret) > 0: return ret[0] else: @@ -473,7 +491,8 @@ class Result(object): ret = self.one_or_none() except orm_exc.MultipleResultsFound: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()") + "Multiple rows were found for one()" + ) else: if ret is None: raise orm_exc.NoResultFound("No row was found for one()") @@ -497,7 +516,8 @@ class Result(object): return None else: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one_or_none()") + "Multiple rows were found for one_or_none()" + ) def all(self): """Return all rows. @@ -533,13 +553,18 @@ class Result(object): # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set([ - _get_params[col].key for col, value in - zip(mapper.primary_key, primary_key_identity) - if value is None - ]) + nones = set( + [ + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + ] + ) _lcl_get_clause = sql_util.adapt_criterion_to_null( - _lcl_get_clause, nones) + _lcl_get_clause, nones + ) _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False) q._criterion = _lcl_get_clause @@ -556,16 +581,20 @@ class Result(object): # key so that if a race causes multiple calls to _get_clause, # we've cached on ours bq = bq._clone() - bq._cache_key += (_get_clause, ) + bq._cache_key += (_get_clause,) bq = bq.with_criteria( - setup, tuple(elem is None for elem in primary_key_identity)) + setup, tuple(elem is None for elem in primary_key_identity) + ) - params = dict([ - (_get_params[primary_key].key, id_val) - for id_val, primary_key - in zip(primary_key_identity, mapper.primary_key) - ]) + params = dict( + [ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + ] + ) result = list(bq.for_session(self.session).params(**params)) l = len(result) @@ -578,7 +607,8 @@ class Result(object): @util.deprecated( - "1.2", "Baked lazy loading is now the default implementation.") + "1.2", "Baked lazy loading is now the default implementation." +) def bake_lazy_loaders(): """Enable the use of baked queries for all lazyloaders systemwide. 
@@ -590,7 +620,8 @@ def bake_lazy_loaders(): @util.deprecated( - "1.2", "Baked lazy loading is now the default implementation.") + "1.2", "Baked lazy loading is now the default implementation." +) def unbake_lazy_loaders(): """Disable the use of baked queries for all lazyloaders systemwide. @@ -601,7 +632,8 @@ def unbake_lazy_loaders(): """ raise NotImplementedError( - "Baked lazy loading is now the default implementation") + "Baked lazy loading is now the default implementation" + ) @strategy_options.loader_option() @@ -615,20 +647,27 @@ def baked_lazyload(loadopt, attr): @baked_lazyload._add_unbound_fn @util.deprecated( - "1.2", "Baked lazy loading is now the default " - "implementation for lazy loading.") + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) def baked_lazyload(*keys): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.baked_lazyload, keys, False, {}) + strategy_options._UnboundLoad.baked_lazyload, keys, False, {} + ) @baked_lazyload._add_unbound_all_fn @util.deprecated( - "1.2", "Baked lazy loading is now the default " - "implementation for lazy loading.") + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) def baked_lazyload_all(*keys): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.baked_lazyload, keys, True, {}) + strategy_options._UnboundLoad.baked_lazyload, keys, True, {} + ) + baked_lazyload = baked_lazyload._unbound_fn baked_lazyload_all = baked_lazyload_all._unbound_all_fn diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 6a0909d361..220b2c057b 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -407,37 +407,44 @@ def compiles(class_, *specs): def decorate(fn): # get an existing @compiles handler - existing = class_.__dict__.get('_compiler_dispatcher', None) + existing = class_.__dict__.get("_compiler_dispatcher", None) # get the original handler. All ClauseElement classes have one # of these, but some TypeEngine classes will not. - existing_dispatch = getattr(class_, '_compiler_dispatch', None) + existing_dispatch = getattr(class_, "_compiler_dispatch", None) if not existing: existing = _dispatcher() if existing_dispatch: + def _wrap_existing_dispatch(element, compiler, **kw): try: return existing_dispatch(element, compiler, **kw) except exc.UnsupportedCompilationError: raise exc.CompileError( "%s construct has no default " - "compilation handler." % type(element)) - existing.specs['default'] = _wrap_existing_dispatch + "compilation handler." % type(element) + ) + + existing.specs["default"] = _wrap_existing_dispatch # TODO: why is the lambda needed ? 
- setattr(class_, '_compiler_dispatch', - lambda *arg, **kw: existing(*arg, **kw)) - setattr(class_, '_compiler_dispatcher', existing) + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) if specs: for s in specs: existing.specs[s] = fn else: - existing.specs['default'] = fn + existing.specs["default"] = fn return fn + return decorate @@ -445,7 +452,7 @@ def deregister(class_): """Remove all custom compilers associated with a given :class:`.ClauseElement` type.""" - if hasattr(class_, '_compiler_dispatcher'): + if hasattr(class_, "_compiler_dispatcher"): # regenerate default _compiler_dispatch visitors._generate_dispatch(class_) # remove custom directive @@ -461,10 +468,11 @@ class _dispatcher(object): fn = self.specs.get(compiler.dialect.name, None) if not fn: try: - fn = self.specs['default'] + fn = self.specs["default"] except KeyError: raise exc.CompileError( "%s construct has no default " - "compilation handler." % type(element)) + "compilation handler." % type(element) + ) return fn(element, compiler, **kw) diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py index cb81f51e5c..2b0a37884a 100644 --- a/lib/sqlalchemy/ext/declarative/__init__.py +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -5,14 +5,31 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .api import declarative_base, synonym_for, comparable_using, \ - instrument_declarative, ConcreteBase, AbstractConcreteBase, \ - DeclarativeMeta, DeferredReflection, has_inherited_table,\ - declared_attr, as_declarative +from .api import ( + declarative_base, + synonym_for, + comparable_using, + instrument_declarative, + ConcreteBase, + AbstractConcreteBase, + DeclarativeMeta, + DeferredReflection, + has_inherited_table, + declared_attr, + as_declarative, +) -__all__ = ['declarative_base', 'synonym_for', 'has_inherited_table', - 'comparable_using', 'instrument_declarative', 'declared_attr', - 'as_declarative', - 'ConcreteBase', 'AbstractConcreteBase', 'DeclarativeMeta', - 'DeferredReflection'] +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "comparable_using", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/lib/sqlalchemy/ext/declarative/api.py b/lib/sqlalchemy/ext/declarative/api.py index 865cd16f0f..987e921196 100644 --- a/lib/sqlalchemy/ext/declarative/api.py +++ b/lib/sqlalchemy/ext/declarative/api.py @@ -8,9 +8,13 @@ from ...schema import Table, MetaData, Column -from ...orm import synonym as _orm_synonym, \ - comparable_property,\ - interfaces, properties, attributes +from ...orm import ( + synonym as _orm_synonym, + comparable_property, + interfaces, + properties, + attributes, +) from ...orm.util import polymorphic_union from ...orm.base import _mapper_or_none from ...util import OrderedDict, hybridmethod, hybridproperty @@ -19,9 +23,13 @@ from ... 
import exc import weakref import re -from .base import _as_declarative, \ - _declarative_constructor,\ - _DeferredMapperConfig, _add_attribute, _del_attribute +from .base import ( + _as_declarative, + _declarative_constructor, + _DeferredMapperConfig, + _add_attribute, + _del_attribute, +) from .clsregistry import _class_resolver @@ -31,10 +39,10 @@ def instrument_declarative(cls, registry, metadata): MetaData object. """ - if '_decl_class_registry' in cls.__dict__: + if "_decl_class_registry" in cls.__dict__: raise exc.InvalidRequestError( - "Class %r already has been " - "instrumented declaratively" % cls) + "Class %r already has been " "instrumented declaratively" % cls + ) cls._decl_class_registry = registry cls.metadata = metadata _as_declarative(cls, cls.__name__, cls.__dict__) @@ -54,14 +62,14 @@ def has_inherited_table(cls): """ for class_ in cls.__mro__[1:]: - if getattr(class_, '__table__', None) is not None: + if getattr(class_, "__table__", None) is not None: return True return False class DeclarativeMeta(type): def __init__(cls, classname, bases, dict_): - if '_decl_class_registry' not in cls.__dict__: + if "_decl_class_registry" not in cls.__dict__: _as_declarative(cls, classname, cls.__dict__) type.__init__(cls, classname, bases, dict_) @@ -71,6 +79,7 @@ class DeclarativeMeta(type): def __delattr__(cls, key): _del_attribute(cls, key) + def synonym_for(name, map_column=False): """Decorator that produces an :func:`.orm.synonym` attribute in conjunction with a Python descriptor. @@ -104,8 +113,10 @@ def synonym_for(name, map_column=False): can be achieved with synonyms. """ + def decorate(fn): return _orm_synonym(name, map_column=map_column, descriptor=fn) + return decorate @@ -127,8 +138,10 @@ def comparable_using(comparator_factory): prop = comparable_property(MyComparatorType) """ + def decorate(fn): return comparable_property(comparator_factory, fn) + return decorate @@ -190,14 +203,16 @@ class declared_attr(interfaces._MappedAttribute, property): self._cascading = cascading def __get__(desc, self, cls): - reg = cls.__dict__.get('_sa_declared_attr_reg', None) + reg = cls.__dict__.get("_sa_declared_attr_reg", None) if reg is None: - if not re.match(r'^__.+__$', desc.fget.__name__) and \ - attributes.manager_of_class(cls) is None: + if ( + not re.match(r"^__.+__$", desc.fget.__name__) + and attributes.manager_of_class(cls) is None + ): util.warn( "Unmanaged access of declarative attribute %s from " - "non-mapped class %s" % - (desc.fget.__name__, cls.__name__)) + "non-mapped class %s" % (desc.fget.__name__, cls.__name__) + ) return desc.fget(cls) elif desc in reg: return reg[desc] @@ -283,10 +298,16 @@ class _stateful_declared_attr(declared_attr): return declared_attr(fn, **self.kw) -def declarative_base(bind=None, metadata=None, mapper=None, cls=object, - name='Base', constructor=_declarative_constructor, - class_registry=None, - metaclass=DeclarativeMeta): +def declarative_base( + bind=None, + metadata=None, + mapper=None, + cls=object, + name="Base", + constructor=_declarative_constructor, + class_registry=None, + metaclass=DeclarativeMeta, +): r"""Construct a base class for declarative class definitions. 
The new base class will be given a metaclass that produces @@ -357,16 +378,17 @@ def declarative_base(bind=None, metadata=None, mapper=None, cls=object, class_registry = weakref.WeakValueDictionary() bases = not isinstance(cls, tuple) and (cls,) or cls - class_dict = dict(_decl_class_registry=class_registry, - metadata=lcl_metadata) + class_dict = dict( + _decl_class_registry=class_registry, metadata=lcl_metadata + ) if isinstance(cls, type): - class_dict['__doc__'] = cls.__doc__ + class_dict["__doc__"] = cls.__doc__ if constructor: - class_dict['__init__'] = constructor + class_dict["__init__"] = constructor if mapper: - class_dict['__mapper_cls__'] = mapper + class_dict["__mapper_cls__"] = mapper return metaclass(name, bases, class_dict) @@ -401,9 +423,10 @@ def as_declarative(**kw): :func:`.declarative_base` """ + def decorate(cls): - kw['cls'] = cls - kw['name'] = cls.__name__ + kw["cls"] = cls + kw["name"] = cls.__name__ return declarative_base(**kw) return decorate @@ -456,10 +479,13 @@ class ConcreteBase(object): @classmethod def _create_polymorphic_union(cls, mappers): - return polymorphic_union(OrderedDict( - (mp.polymorphic_identity, mp.local_table) - for mp in mappers - ), 'type', 'pjoin') + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + "type", + "pjoin", + ) @classmethod def __declare_first__(cls): @@ -568,7 +594,7 @@ class AbstractConcreteBase(ConcreteBase): @classmethod def _sa_decl_prepare_nocascade(cls): - if getattr(cls, '__mapper__', None): + if getattr(cls, "__mapper__", None): return to_map = _DeferredMapperConfig.config_for_cls(cls) @@ -604,8 +630,9 @@ class AbstractConcreteBase(ConcreteBase): def mapper_args(): args = m_args() - args['polymorphic_on'] = pjoin.c.type + args["polymorphic_on"] = pjoin.c.type return args + to_map.mapper_args_fn = mapper_args m = to_map.map() @@ -684,6 +711,7 @@ class DeferredReflection(object): .. 
versionadded:: 0.8 """ + @classmethod def prepare(cls, engine): """Reflect all :class:`.Table` objects for all current @@ -696,8 +724,10 @@ class DeferredReflection(object): mapper = thingy.cls.__mapper__ metadata = mapper.class_.metadata for rel in mapper._props.values(): - if isinstance(rel, properties.RelationshipProperty) and \ - rel.secondary is not None: + if ( + isinstance(rel, properties.RelationshipProperty) + and rel.secondary is not None + ): if isinstance(rel.secondary, Table): cls._reflect_table(rel.secondary, engine) elif isinstance(rel.secondary, _class_resolver): @@ -711,6 +741,7 @@ class DeferredReflection(object): t1 = Table(key, metadata) cls._reflect_table(t1, engine) return t1 + return _resolve @classmethod @@ -724,10 +755,12 @@ class DeferredReflection(object): @classmethod def _reflect_table(cls, table, engine): - Table(table.name, - table.metadata, - extend_existing=True, - autoload_replace=False, - autoload=True, - autoload_with=engine, - schema=table.schema) + Table( + table.name, + table.metadata, + extend_existing=True, + autoload_replace=False, + autoload=True, + autoload_with=engine, + schema=table.schema, + ) diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index f27314b5e6..07778f733c 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -39,7 +39,7 @@ def _resolve_for_abstract_or_classical(cls): if cls is object: return None - if _get_immediate_cls_attr(cls, '__abstract__', strict=True): + if _get_immediate_cls_attr(cls, "__abstract__", strict=True): for sup in cls.__bases__: sup = _resolve_for_abstract_or_classical(sup) if sup is not None: @@ -59,7 +59,7 @@ def _dive_for_classically_mapped_class(cls): # if we are within a base hierarchy, don't # search at all for classical mappings - if hasattr(cls, '_decl_class_registry'): + if hasattr(cls, "_decl_class_registry"): return None manager = instrumentation.manager_of_class(cls) @@ -89,15 +89,19 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return None for base in cls.__mro__: - _is_declarative_inherits = hasattr(base, '_decl_class_registry') - _is_classicial_inherits = not _is_declarative_inherits and \ - _dive_for_classically_mapped_class(base) is not None + _is_declarative_inherits = hasattr(base, "_decl_class_registry") + _is_classicial_inherits = ( + not _is_declarative_inherits + and _dive_for_classically_mapped_class(base) is not None + ) if attrname in base.__dict__ and ( - base is cls or - ((base in cls.__bases__ if strict else True) + base is cls + or ( + (base in cls.__bases__ if strict else True) and not _is_declarative_inherits - and not _is_classicial_inherits) + and not _is_classicial_inherits + ) ): return getattr(base, attrname) else: @@ -108,9 +112,10 @@ def _as_declarative(cls, classname, dict_): global declared_attr, declarative_props if declared_attr is None: from .api import declared_attr + declarative_props = (declared_attr, util.classproperty) - if _get_immediate_cls_attr(cls, '__abstract__', strict=True): + if _get_immediate_cls_attr(cls, "__abstract__", strict=True): return _MapperConfig.setup_mapping(cls, classname, dict_) @@ -119,23 +124,23 @@ def _as_declarative(cls, classname, dict_): def _check_declared_props_nocascade(obj, name, cls): if isinstance(obj, declarative_props): - if getattr(obj, '_cascading', False): + if getattr(obj, "_cascading", False): util.warn( "@declared_attr.cascading is not supported on the %s " "attribute on class %s. 
This attribute invokes for " - "subclasses in any case." % (name, cls)) + "subclasses in any case." % (name, cls) + ) return True else: return False class _MapperConfig(object): - @classmethod def setup_mapping(cls, cls_, classname, dict_): defer_map = _get_immediate_cls_attr( - cls_, '_sa_decl_prepare_nocascade', strict=True) or \ - hasattr(cls_, '_sa_decl_prepare') + cls_, "_sa_decl_prepare_nocascade", strict=True + ) or hasattr(cls_, "_sa_decl_prepare") if defer_map: cfg_cls = _DeferredMapperConfig @@ -179,12 +184,14 @@ class _MapperConfig(object): self.map() def _setup_declared_events(self): - if _get_immediate_cls_attr(self.cls, '__declare_last__'): + if _get_immediate_cls_attr(self.cls, "__declare_last__"): + @event.listens_for(mapper, "after_configured") def after_configured(): self.cls.__declare_last__() - if _get_immediate_cls_attr(self.cls, '__declare_first__'): + if _get_immediate_cls_attr(self.cls, "__declare_first__"): + @event.listens_for(mapper, "before_configured") def before_configured(): self.cls.__declare_first__() @@ -198,59 +205,62 @@ class _MapperConfig(object): tablename = None for base in cls.__mro__: - class_mapped = base is not cls and \ - _declared_mapping_info(base) is not None and \ - not _get_immediate_cls_attr( - base, '_sa_decl_prepare_nocascade', strict=True) + class_mapped = ( + base is not cls + and _declared_mapping_info(base) is not None + and not _get_immediate_cls_attr( + base, "_sa_decl_prepare_nocascade", strict=True + ) + ) if not class_mapped and base is not cls: self._produce_column_copies(base) for name, obj in vars(base).items(): - if name == '__mapper_args__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not mapper_args_fn and ( - not class_mapped or - check_decl - ): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and (not class_mapped or check_decl): # don't even invoke __mapper_args__ until # after we've determined everything about the # mapped table. # make a copy of it so a class-level dictionary # is not overwritten when we update column-based # arguments. - mapper_args_fn = lambda: dict(cls.__mapper_args__) # noqa - elif name == '__tablename__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not tablename and ( - not class_mapped or - check_decl - ): + mapper_args_fn = lambda: dict( + cls.__mapper_args__ + ) # noqa + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): tablename = cls.__tablename__ - elif name == '__table_args__': - check_decl = \ - _check_declared_props_nocascade(obj, name, cls) - if not table_args and ( - not class_mapped or - check_decl - ): + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): table_args = cls.__table_args__ if not isinstance( - table_args, (tuple, dict, type(None))): + table_args, (tuple, dict, type(None)) + ): raise exc.ArgumentError( "__table_args__ value must be a tuple, " - "dict, or None") + "dict, or None" + ) if base is not cls: inherited_table_args = True elif class_mapped: if isinstance(obj, declarative_props): - util.warn("Regular (i.e. not __special__) " - "attribute '%s.%s' uses @declared_attr, " - "but owning class %s is mapped - " - "not applying to subclass %s." - % (base.__name__, name, base, cls)) + util.warn( + "Regular (i.e. 
not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls) + ) continue elif base is not cls: # we're a mixin, abstract base, or something that is @@ -263,7 +273,8 @@ class _MapperConfig(object): "Mapper properties (i.e. deferred," "column_property(), relationship(), etc.) must " "be declared as @declared_attr callables " - "on declarative mixin classes.") + "on declarative mixin classes." + ) elif isinstance(obj, declarative_props): oldclassprop = isinstance(obj, util.classproperty) if not oldclassprop and obj._cascading: @@ -278,15 +289,18 @@ class _MapperConfig(object): "Attribute '%s' on class %s cannot be " "processed due to " "@declared_attr.cascading; " - "skipping" % (name, cls)) - dict_[name] = column_copies[obj] = \ - ret = obj.__get__(obj, cls) + "skipping" % (name, cls) + ) + dict_[name] = column_copies[ + obj + ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: if oldclassprop: util.warn_deprecated( "Use of sqlalchemy.util.classproperty on " - "declarative classes is deprecated.") + "declarative classes is deprecated." + ) # access attribute using normal class access ret = getattr(cls, name) @@ -294,14 +308,20 @@ class _MapperConfig(object): # or similar. note there is no known case that # produces nested proxies, so we are only # looking one level deep right now. - if isinstance(ret, InspectionAttr) and \ - ret._is_internal_proxy and not isinstance( - ret.original_property, MapperProperty): + if ( + isinstance(ret, InspectionAttr) + and ret._is_internal_proxy + and not isinstance( + ret.original_property, MapperProperty + ) + ): ret = ret.descriptor dict_[name] = column_copies[obj] = ret - if isinstance(ret, (Column, MapperProperty)) and \ - ret.doc is None: + if ( + isinstance(ret, (Column, MapperProperty)) + and ret.doc is None + ): ret.doc = obj.__doc__ # here, the attribute is some other kind of property that # we assume is not part of the declarative mapping. @@ -321,8 +341,9 @@ class _MapperConfig(object): util.warn( "Attribute '%s' on class %s appears to be a non-schema " "'sqlalchemy.sql.column()' " - "object; this won't be part of the declarative mapping" % - (key, cls)) + "object; this won't be part of the declarative mapping" + % (key, cls) + ) def _produce_column_copies(self, base): cls = self.cls @@ -340,10 +361,11 @@ class _MapperConfig(object): raise exc.InvalidRequestError( "Columns with foreign keys to other columns " "must be declared as @declared_attr callables " - "on declarative mixin classes. ") + "on declarative mixin classes. " + ) elif name not in dict_ and not ( - '__table__' in dict_ and - (obj.name or name) in dict_['__table__'].c + "__table__" in dict_ + and (obj.name or name) in dict_["__table__"].c ): column_copies[obj] = copy_ = obj.copy() copy_._creation_order = obj._creation_order @@ -357,11 +379,12 @@ class _MapperConfig(object): our_stuff = self.properties late_mapped = _get_immediate_cls_attr( - cls, '_sa_decl_prepare_nocascade', strict=True) + cls, "_sa_decl_prepare_nocascade", strict=True + ) for k in list(dict_): - if k in ('__table__', '__tablename__', '__mapper_args__'): + if k in ("__table__", "__tablename__", "__mapper_args__"): continue value = dict_[k] @@ -371,29 +394,37 @@ class _MapperConfig(object): "Use of @declared_attr.cascading only applies to " "Declarative 'mixin' and 'abstract' classes. 
" "Currently, this flag is ignored on mapped class " - "%s" % self.cls) + "%s" % self.cls + ) value = getattr(cls, k) - elif isinstance(value, QueryableAttribute) and \ - value.class_ is not cls and \ - value.key != k: + elif ( + isinstance(value, QueryableAttribute) + and value.class_ is not cls + and value.key != k + ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = synonym(value.key) setattr(cls, k, value) - if (isinstance(value, tuple) and len(value) == 1 and - isinstance(value[0], (Column, MapperProperty))): - util.warn("Ignoring declarative-like tuple value of attribute " - "'%s': possibly a copy-and-paste error with a comma " - "accidentally placed at the end of the line?" % k) + if ( + isinstance(value, tuple) + and len(value) == 1 + and isinstance(value[0], (Column, MapperProperty)) + ): + util.warn( + "Ignoring declarative-like tuple value of attribute " + "'%s': possibly a copy-and-paste error with a comma " + "accidentally placed at the end of the line?" % k + ) continue elif not isinstance(value, (Column, MapperProperty)): # using @declared_attr for some object that # isn't Column/MapperProperty; remove from the dict_ # and place the evaluated value onto the class. - if not k.startswith('__'): + if not k.startswith("__"): dict_.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: @@ -402,7 +433,7 @@ class _MapperConfig(object): # we expect to see the name 'metadata' in some valid cases; # however at this point we see it's assigned to something trying # to be mapped, so raise for that. - elif k == 'metadata': + elif k == "metadata": raise exc.InvalidRequestError( "Attribute name 'metadata' is reserved " "for the MetaData instance when using a " @@ -423,8 +454,7 @@ class _MapperConfig(object): for key, c in list(our_stuff.items()): if isinstance(c, (ColumnProperty, CompositeProperty)): for col in c.columns: - if isinstance(col, Column) and \ - col.table is None: + if isinstance(col, Column) and col.table is None: _undefer_column_name(key, col) if not isinstance(c, CompositeProperty): name_to_prop_key[col.name].add(key) @@ -447,8 +477,8 @@ class _MapperConfig(object): "On class %r, Column object %r named " "directly multiple times, " "only one will be used: %s. 
" - "Consider using orm.synonym instead" % - (self.classname, name, (", ".join(sorted(keys)))) + "Consider using orm.synonym instead" + % (self.classname, name, (", ".join(sorted(keys)))) ) def _setup_table(self): @@ -459,15 +489,16 @@ class _MapperConfig(object): declared_columns = self.declared_columns declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order) + declared_columns, key=lambda c: c._creation_order + ) table = None - if hasattr(cls, '__table_cls__'): + if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: table_cls = Table - if '__table__' not in dict_: + if "__table__" not in dict_: if tablename is not None: args, table_kw = (), {} @@ -480,14 +511,16 @@ class _MapperConfig(object): else: args = table_args - autoload = dict_.get('__autoload__') + autoload = dict_.get("__autoload__") if autoload: - table_kw['autoload'] = True + table_kw["autoload"] = True cls.__table__ = table = table_cls( - tablename, cls.metadata, + tablename, + cls.metadata, *(tuple(declared_columns) + tuple(args)), - **table_kw) + **table_kw + ) else: table = cls.__table__ if declared_columns: @@ -512,21 +545,27 @@ class _MapperConfig(object): c = _resolve_for_abstract_or_classical(c) if c is None: continue - if _declared_mapping_info(c) is not None and \ - not _get_immediate_cls_attr( - c, '_sa_decl_prepare_nocascade', strict=True): + if _declared_mapping_info( + c + ) is not None and not _get_immediate_cls_attr( + c, "_sa_decl_prepare_nocascade", strict=True + ): inherits.append(c) if inherits: if len(inherits) > 1: raise exc.InvalidRequestError( - "Class %s has multiple mapped bases: %r" % (cls, inherits)) + "Class %s has multiple mapped bases: %r" % (cls, inherits) + ) self.inherits = inherits[0] else: self.inherits = None - if table is None and self.inherits is None and \ - not _get_immediate_cls_attr(cls, '__no_table__'): + if ( + table is None + and self.inherits is None + and not _get_immediate_cls_attr(cls, "__no_table__") + ): raise exc.InvalidRequestError( "Class %r does not have a __table__ or __tablename__ " @@ -553,8 +592,8 @@ class _MapperConfig(object): continue raise exc.ArgumentError( "Column '%s' on class %s conflicts with " - "existing column '%s'" % - (c, cls, inherited_table.c[c.name]) + "existing column '%s'" + % (c, cls, inherited_table.c[c.name]) ) if c.primary_key: raise exc.ArgumentError( @@ -562,8 +601,10 @@ class _MapperConfig(object): "class with no table." 
) inherited_table.append_column(c) - if inherited_mapped_table is not None and \ - inherited_mapped_table is not inherited_table: + if ( + inherited_mapped_table is not None + and inherited_mapped_table is not inherited_table + ): inherited_mapped_table._refresh_for_new_column(c) def _prepare_mapper_arguments(self): @@ -575,18 +616,19 @@ class _MapperConfig(object): # make sure that column copies are used rather # than the original columns from any mixins - for k in ('version_id_col', 'polymorphic_on',): + for k in ("version_id_col", "polymorphic_on"): if k in mapper_args: v = mapper_args[k] mapper_args[k] = self.column_copies.get(v, v) - assert 'inherits' not in mapper_args, \ - "Can't specify 'inherits' explicitly with declarative mappings" + assert ( + "inherits" not in mapper_args + ), "Can't specify 'inherits' explicitly with declarative mappings" if self.inherits: - mapper_args['inherits'] = self.inherits + mapper_args["inherits"] = self.inherits - if self.inherits and not mapper_args.get('concrete', False): + if self.inherits and not mapper_args.get("concrete", False): # single or joined inheritance # exclude any cols on the inherited table which are # not mapped on the parent class, to avoid @@ -594,16 +636,17 @@ class _MapperConfig(object): inherited_mapper = _declared_mapping_info(self.inherits) inherited_table = inherited_mapper.local_table - if 'exclude_properties' not in mapper_args: - mapper_args['exclude_properties'] = exclude_properties = \ - set( - [c.key for c in inherited_table.c - if c not in inherited_mapper._columntoproperty] - ).union( - inherited_mapper.exclude_properties or () - ) + if "exclude_properties" not in mapper_args: + mapper_args["exclude_properties"] = exclude_properties = set( + [ + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + ] + ).union(inherited_mapper.exclude_properties or ()) exclude_properties.difference_update( - [c.key for c in self.declared_columns]) + [c.key for c in self.declared_columns] + ) # look through columns in the current mapper that # are keyed to a propname different than the colname @@ -621,21 +664,20 @@ class _MapperConfig(object): # first. See [ticket:1892] for background. 
properties[k] = [col] + p.columns result_mapper_args = mapper_args.copy() - result_mapper_args['properties'] = properties + result_mapper_args["properties"] = properties self.mapper_args = result_mapper_args def map(self): self._prepare_mapper_arguments() - if hasattr(self.cls, '__mapper_cls__'): + if hasattr(self.cls, "__mapper_cls__"): mapper_cls = util.unbound_method_to_callable( - self.cls.__mapper_cls__) + self.cls.__mapper_cls__ + ) else: mapper_cls = mapper self.cls.__mapper__ = mp_ = mapper_cls( - self.cls, - self.local_table, - **self.mapper_args + self.cls, self.local_table, **self.mapper_args ) del self.cls._sa_declared_attr_reg return mp_ @@ -663,8 +705,7 @@ class _DeferredMapperConfig(_MapperConfig): @classmethod def has_cls(cls, class_): # 2.6 fails on weakref if class_ is an old style class - return isinstance(class_, type) and \ - weakref.ref(class_) in cls._configs + return isinstance(class_, type) and weakref.ref(class_) in cls._configs @classmethod def config_for_cls(cls, class_): @@ -673,18 +714,15 @@ class _DeferredMapperConfig(_MapperConfig): @classmethod def classes_for_base(cls, base_cls, sort=True): classes_for_base = [ - m for m, cls_ in - [(m, m.cls) for m in cls._configs.values()] + m + for m, cls_ in [(m, m.cls) for m in cls._configs.values()] if cls_ is not None and issubclass(cls_, base_cls) ] if not sort: return classes_for_base - all_m_by_cls = dict( - (m.cls, m) - for m in classes_for_base - ) + all_m_by_cls = dict((m.cls, m) for m in classes_for_base) tuples = [] for m_cls in all_m_by_cls: @@ -693,12 +731,7 @@ class _DeferredMapperConfig(_MapperConfig): for base_cls in m_cls.__bases__ if base_cls in all_m_by_cls ) - return list( - topological.sort( - tuples, - classes_for_base - ) - ) + return list(topological.sort(tuples, classes_for_base)) def map(self): self._configs.pop(self._cls, None) @@ -713,7 +746,7 @@ def _add_attribute(cls, key, value): """ - if '__mapper__' in cls.__dict__: + if "__mapper__" in cls.__dict__: if isinstance(value, Column): _undefer_column_name(key, value) cls.__table__.append_column(value) @@ -726,16 +759,14 @@ def _add_attribute(cls, key, value): cls.__mapper__.add_property(key, value) elif isinstance(value, MapperProperty): cls.__mapper__.add_property( - key, - clsregistry._deferred_relationship(cls, value) + key, clsregistry._deferred_relationship(cls, value) ) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = synonym(value.key) cls.__mapper__.add_property( - key, - clsregistry._deferred_relationship(cls, value) + key, clsregistry._deferred_relationship(cls, value) ) else: type.__setattr__(cls, key, value) @@ -746,15 +777,18 @@ def _add_attribute(cls, key, value): def _del_attribute(cls, key): - if '__mapper__' in cls.__dict__ and \ - key in cls.__dict__ and not cls.__mapper__._dispose_called: + if ( + "__mapper__" in cls.__dict__ + and key in cls.__dict__ + and not cls.__mapper__._dispose_called + ): value = cls.__dict__[key] if isinstance( - value, - (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) ): raise NotImplementedError( - "Can't un-map individual mapped attributes on a mapped class.") + "Can't un-map individual mapped attributes on a mapped class." 
+ ) else: type.__delattr__(cls, key) cls.__mapper__._expire_memoizations() @@ -776,10 +810,12 @@ def _declarative_constructor(self, **kwargs): for k in kwargs: if not hasattr(cls_, k): raise TypeError( - "%r is an invalid keyword argument for %s" % - (k, cls_.__name__)) + "%r is an invalid keyword argument for %s" % (k, cls_.__name__) + ) setattr(self, k, kwargs[k]) -_declarative_constructor.__name__ = '__init__' + + +_declarative_constructor.__name__ = "__init__" def _undefer_column_name(key, column): diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index e941b9ed32..c52ae4a2f1 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -10,8 +10,11 @@ This system allows specification of classes and expressions used in :func:`.relationship` using strings. """ -from ...orm.properties import ColumnProperty, RelationshipProperty, \ - SynonymProperty +from ...orm.properties import ( + ColumnProperty, + RelationshipProperty, + SynonymProperty, +) from ...schema import _get_table_key from ...orm import class_mapper, interfaces from ... import util @@ -35,17 +38,18 @@ def add_class(classname, cls): # class already exists. existing = cls._decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): - existing = \ - cls._decl_class_registry[classname] = \ - _MultipleClassMarker([cls, existing]) + existing = cls._decl_class_registry[ + classname + ] = _MultipleClassMarker([cls, existing]) else: cls._decl_class_registry[classname] = cls try: - root_module = cls._decl_class_registry['_sa_module_registry'] + root_module = cls._decl_class_registry["_sa_module_registry"] except KeyError: - cls._decl_class_registry['_sa_module_registry'] = \ - root_module = _ModuleMarker('_sa_module_registry', None) + cls._decl_class_registry[ + "_sa_module_registry" + ] = root_module = _ModuleMarker("_sa_module_registry", None) tokens = cls.__module__.split(".") @@ -71,12 +75,13 @@ class _MultipleClassMarker(object): """ - __slots__ = 'on_remove', 'contents', '__weakref__' + __slots__ = "on_remove", "contents", "__weakref__" def __init__(self, classes, on_remove=None): self.on_remove = on_remove - self.contents = set([ - weakref.ref(item, self._remove_item) for item in classes]) + self.contents = set( + [weakref.ref(item, self._remove_item) for item in classes] + ) _registries.add(self) def __iter__(self): @@ -85,10 +90,10 @@ class _MultipleClassMarker(object): def attempt_get(self, path, key): if len(self.contents) > 1: raise exc.InvalidRequestError( - "Multiple classes found for path \"%s\" " + 'Multiple classes found for path "%s" ' "in the registry of this declarative " - "base. Please use a fully module-qualified path." % - (".".join(path + [key])) + "base. Please use a fully module-qualified path." + % (".".join(path + [key])) ) else: ref = list(self.contents)[0] @@ -108,17 +113,19 @@ class _MultipleClassMarker(object): # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] - modules = set([ - cls.__module__ for cls in - [ref() for ref in self.contents] if cls is not None]) + modules = set( + [ + cls.__module__ + for cls in [ref() for ref in self.contents] + if cls is not None + ] + ) if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " "same class name and module name as %s.%s, and will " - "be replaced in the string-lookup table." 
% ( - item.__module__, - item.__name__ - ) + "be replaced in the string-lookup table." + % (item.__module__, item.__name__) ) self.contents.add(weakref.ref(item, self._remove_item)) @@ -129,7 +136,7 @@ class _ModuleMarker(object): """ - __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__' + __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" def __init__(self, name, parent): self.parent = parent @@ -170,13 +177,13 @@ class _ModuleMarker(object): existing = self.contents[name] existing.add_item(cls) else: - existing = self.contents[name] = \ - _MultipleClassMarker([cls], - on_remove=lambda: self._remove_item(name)) + existing = self.contents[name] = _MultipleClassMarker( + [cls], on_remove=lambda: self._remove_item(name) + ) class _ModNS(object): - __slots__ = '__parent', + __slots__ = ("__parent",) def __init__(self, parent): self.__parent = parent @@ -193,13 +200,14 @@ class _ModNS(object): else: assert isinstance(value, _MultipleClassMarker) return value.attempt_get(self.__parent.path, key) - raise AttributeError("Module %r has no mapped classes " - "registered under the name %r" % ( - self.__parent.name, key)) + raise AttributeError( + "Module %r has no mapped classes " + "registered under the name %r" % (self.__parent.name, key) + ) class _GetColumns(object): - __slots__ = 'cls', + __slots__ = ("cls",) def __init__(self, cls): self.cls = cls @@ -210,7 +218,8 @@ class _GetColumns(object): if key not in mp.all_orm_descriptors: raise exc.InvalidRequestError( "Class %r does not have a mapped column named %r" - % (self.cls, key)) + % (self.cls, key) + ) desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NOT_EXTENSION: @@ -221,24 +230,25 @@ class _GetColumns(object): raise exc.InvalidRequestError( "Property %r is not an instance of" " ColumnProperty (i.e. does not correspond" - " directly to a Column)." % key) + " directly to a Column)." % key + ) return getattr(self.cls, key) + inspection._inspects(_GetColumns)( - lambda target: inspection.inspect(target.cls)) + lambda target: inspection.inspect(target.cls) +) class _GetTable(object): - __slots__ = 'key', 'metadata' + __slots__ = "key", "metadata" def __init__(self, key, metadata): self.key = key self.metadata = metadata def __getattr__(self, key): - return self.metadata.tables[ - _get_table_key(key, self.key) - ] + return self.metadata.tables[_get_table_key(key, self.key)] def _determine_container(key, value): @@ -264,9 +274,11 @@ class _class_resolver(object): return cls.metadata.tables[key] elif key in cls.metadata._schemas: return _GetTable(key, cls.metadata) - elif '_sa_module_registry' in cls._decl_class_registry and \ - key in cls._decl_class_registry['_sa_module_registry']: - registry = cls._decl_class_registry['_sa_module_registry'] + elif ( + "_sa_module_registry" in cls._decl_class_registry + and key in cls._decl_class_registry["_sa_module_registry"] + ): + registry = cls._decl_class_registry["_sa_module_registry"] return registry.resolve_attr(key) elif self._resolvers: for resolv in self._resolvers: @@ -289,8 +301,8 @@ class _class_resolver(object): "When initializing mapper %s, expression %r failed to " "locate a name (%r). If this is a class name, consider " "adding this relationship() to the %r class after " - "both dependent classes have been defined." % - (self.prop.parent, self.arg, n.args[0], self.cls) + "both dependent classes have been defined." 
+ % (self.prop.parent, self.arg, n.args[0], self.cls) ) @@ -299,10 +311,11 @@ def _resolver(cls, prop): from sqlalchemy.orm import foreign, remote fallback = sqlalchemy.__dict__.copy() - fallback.update({'foreign': foreign, 'remote': remote}) + fallback.update({"foreign": foreign, "remote": remote}) def resolve_arg(arg): return _class_resolver(cls, prop, fallback, arg) + return resolve_arg @@ -311,18 +324,32 @@ def _deferred_relationship(cls, prop): if isinstance(prop, RelationshipProperty): resolve_arg = _resolver(cls, prop) - for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin', - 'secondary', '_user_defined_foreign_keys', 'remote_side'): + for attr in ( + "argument", + "order_by", + "primaryjoin", + "secondaryjoin", + "secondary", + "_user_defined_foreign_keys", + "remote_side", + ): v = getattr(prop, attr) if isinstance(v, util.string_types): setattr(prop, attr, resolve_arg(v)) if prop.backref and isinstance(prop.backref, tuple): key, kwargs = prop.backref - for attr in ('primaryjoin', 'secondaryjoin', 'secondary', - 'foreign_keys', 'remote_side', 'order_by'): - if attr in kwargs and isinstance(kwargs[attr], - util.string_types): + for attr in ( + "primaryjoin", + "secondaryjoin", + "secondary", + "foreign_keys", + "remote_side", + "order_by", + ): + if attr in kwargs and isinstance( + kwargs[attr], util.string_types + ): kwargs[attr] = resolve_arg(kwargs[attr]) return prop diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index f86e4fc93c..7248e5b4da 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -20,7 +20,7 @@ from .. import util from ..orm.session import Session from ..orm.query import Query -__all__ = ['ShardedSession', 'ShardedQuery'] +__all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): @@ -43,12 +43,10 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): def iter_for_shard(shard_id): - context.attributes['shard_id'] = context.identity_token = shard_id + context.attributes["shard_id"] = context.identity_token = shard_id result = self._connection_from_session( - mapper=self._bind_mapper(), - shard_id=shard_id).execute( - context.statement, - self._params) + mapper=self._bind_mapper(), shard_id=shard_id + ).execute(context.statement, self._params) return self.instances(result, context) if context.identity_token is not None: @@ -70,7 +68,8 @@ class ShardedQuery(Query): mapper=mapper, shard_id=shard_id, clause=stmt, - close_with_result=True) + close_with_result=True, + ) result = conn.execute(stmt, self._params) return result @@ -87,8 +86,13 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) def _identity_lookup( - self, mapper, primary_key_identity, identity_token=None, - lazy_loaded_from=None, **kw): + self, + mapper, + primary_key_identity, + identity_token=None, + lazy_loaded_from=None, + **kw + ): """override the default Query._identity_lookup method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). 
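A minimal usage sketch of how the sharding hooks touched in this file fit together, following the cross-shard identity lookup described in the docstring above; the SQLite engine URLs and the even/odd routing rule are illustrative assumptions, not part of this change::

    from sqlalchemy import create_engine
    from sqlalchemy.ext.horizontal_shard import ShardedSession

    shards = {
        "even": create_engine("sqlite:///even.db"),
        "odd": create_engine("sqlite:///odd.db"),
    }

    def shard_chooser(mapper, instance, clause=None):
        # route flushes by the instance's (assumed) integer primary key
        if instance is None or instance.id is None:
            return "even"
        return "even" if instance.id % 2 == 0 else "odd"

    def id_chooser(query, ident):
        # a bare primary key does not identify a shard, so return all
        # candidate shard ids; _identity_lookup() then searches each
        # identity token in turn
        return ["even", "odd"]

    def query_chooser(query):
        # without further context a query may need to run on every shard
        return ["even", "odd"]

    session = ShardedSession(
        shard_chooser=shard_chooser,
        id_chooser=id_chooser,
        query_chooser=query_chooser,
        shards=shards,
    )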
@@ -97,8 +101,10 @@ class ShardedQuery(Query): if identity_token is not None: return super(ShardedQuery, self)._identity_lookup( - mapper, primary_key_identity, - identity_token=identity_token, **kw + mapper, + primary_key_identity, + identity_token=identity_token, + **kw ) else: q = self.session.query(mapper) @@ -113,13 +119,13 @@ class ShardedQuery(Query): return None - def _get_impl( - self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): """Override the default Query._get_impl() method so that we emit a query to the DB for each possible identity token, if we don't have one already. """ + def _db_load_fn(query, primary_key_identity): # load from the database. The original db_load_fn will # use the given Query object to load from the DB, so our @@ -142,7 +148,8 @@ class ShardedQuery(Query): identity_token = self._shard_id return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token) + primary_key_identity, _db_load_fn, identity_token=identity_token + ) class ShardedResult(object): @@ -158,7 +165,7 @@ class ShardedResult(object): .. versionadded:: 1.3 """ - __slots__ = ('result_proxies', 'aggregate_rowcount',) + __slots__ = ("result_proxies", "aggregate_rowcount") def __init__(self, result_proxies, aggregate_rowcount): self.result_proxies = result_proxies @@ -168,9 +175,17 @@ class ShardedResult(object): def rowcount(self): return self.aggregate_rowcount + class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, - query_cls=ShardedQuery, **kwargs): + def __init__( + self, + shard_chooser, + id_chooser, + query_chooser, + shards=None, + query_cls=ShardedQuery, + **kwargs + ): """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -225,16 +240,16 @@ class ShardedSession(Session): return self.transaction.connection(mapper, shard_id=shard_id) else: return self.get_bind( - mapper, - shard_id=shard_id, - instance=instance + mapper, shard_id=shard_id, instance=instance ).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, - instance=None, clause=None, **kw): + def get_bind( + self, mapper, shard_id=None, instance=None, clause=None, **kw + ): if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause) + mapper, instance, clause=clause + ) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 95eecb93f9..d51a083da7 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -778,7 +778,7 @@ there's probably a whole lot of amazing things it can be used for. from .. import util from ..orm import attributes, interfaces -HYBRID_METHOD = util.symbol('HYBRID_METHOD') +HYBRID_METHOD = util.symbol("HYBRID_METHOD") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.hybrid_method`. @@ -791,7 +791,7 @@ HYBRID_METHOD = util.symbol('HYBRID_METHOD') """ -HYBRID_PROPERTY = util.symbol('HYBRID_PROPERTY') +HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.hybrid_method`. 
@@ -860,8 +860,14 @@ class hybrid_property(interfaces.InspectionAttrInfo): extension_type = HYBRID_PROPERTY def __init__( - self, fget, fset=None, fdel=None, - expr=None, custom_comparator=None, update_expr=None): + self, + fget, + fset=None, + fdel=None, + expr=None, + custom_comparator=None, + update_expr=None, + ): """Create a new :class:`.hybrid_property`. Usage is typically via decorator:: @@ -906,7 +912,8 @@ class hybrid_property(interfaces.InspectionAttrInfo): defaults = { key: value for key, value in self.__dict__.items() - if not key.startswith("_")} + if not key.startswith("_") + } defaults.update(**kw) return type(self)(**defaults) @@ -1078,9 +1085,9 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._get_expr(self.fget) def _get_expr(self, expr): - def _expr(cls): return ExprComparator(cls, expr(cls), self) + util.update_wrapper(_expr, expr) return self._get_comparator(_expr) @@ -1091,8 +1098,13 @@ class hybrid_property(interfaces.InspectionAttrInfo): def expr_comparator(owner): return proxy_attr( - owner, self.__name__, self, comparator(owner), - doc=comparator.__doc__ or self.__doc__) + owner, + self.__name__, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ) + return expr_comparator @@ -1108,7 +1120,7 @@ class Comparator(interfaces.PropComparator): def __clause_element__(self): expr = self.expression - if hasattr(expr, '__clause_element__'): + if hasattr(expr, "__clause_element__"): expr = expr.__clause_element__() return expr diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index 0bc2b65bb2..368e5b00ae 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -232,7 +232,7 @@ from ..orm.attributes import flag_modified from ..ext.hybrid import hybrid_property -__all__ = ['index_property'] +__all__ = ["index_property"] class index_property(hybrid_property): # noqa @@ -251,8 +251,14 @@ class index_property(hybrid_property): # noqa _NO_DEFAULT_ARGUMENT = object() def __init__( - self, attr_name, index, default=_NO_DEFAULT_ARGUMENT, - datatype=None, mutable=True, onebased=True): + self, + attr_name, + index, + default=_NO_DEFAULT_ARGUMENT, + datatype=None, + mutable=True, + onebased=True, + ): """Create a new :class:`.index_property`. :param attr_name: diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 30a0ab7d73..b2b8dd7c5d 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -28,15 +28,18 @@ see the example :ref:`examples_instrumentation`. """ from ..orm import instrumentation as orm_instrumentation from ..orm.instrumentation import ( - ClassManager, InstrumentationFactory, _default_state_getter, - _default_dict_getter, _default_manager_getter + ClassManager, + InstrumentationFactory, + _default_state_getter, + _default_dict_getter, + _default_manager_getter, ) from ..orm import attributes, collections, base as orm_base from .. import util from ..orm import exc as orm_exc import weakref -INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" """Attribute, elects custom instrumentation when present on a mapped class. 
Allows a class to specify a slightly or wildly different technique for @@ -66,6 +69,7 @@ def find_native_user_instrumentation_hook(cls): """Find user-specified instrumentation management for a class.""" return getattr(cls, INSTRUMENTATION_MANAGER, None) + instrumentation_finders = [find_native_user_instrumentation_hook] """An extensible sequence of callables which return instrumentation implementations @@ -89,6 +93,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): class managers. """ + _manager_finders = weakref.WeakKeyDictionary() _state_finders = weakref.WeakKeyDictionary() _dict_finders = weakref.WeakKeyDictionary() @@ -104,13 +109,15 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): return None, None def _check_conflicts(self, class_, factory): - existing_factories = self._collect_management_factories_for(class_).\ - difference([factory]) + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) if existing_factories: raise TypeError( "multiple instrumentation implementations specified " - "in %s inheritance hierarchy: %r" % ( - class_.__name__, list(existing_factories))) + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) def _extended_class_manager(self, class_, factory): manager = factory(class_) @@ -178,17 +185,20 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): if instance is None: raise AttributeError("None has no persistent state.") return self._state_finders.get( - instance.__class__, _default_state_getter)(instance) + instance.__class__, _default_state_getter + )(instance) def dict_of(self, instance): if instance is None: raise AttributeError("None has no persistent state.") return self._dict_finders.get( - instance.__class__, _default_dict_getter)(instance) + instance.__class__, _default_dict_getter + )(instance) -orm_instrumentation._instrumentation_factory = \ - _instrumentation_factory = ExtendedInstrumentationRegistry() +orm_instrumentation._instrumentation_factory = ( + _instrumentation_factory +) = ExtendedInstrumentationRegistry() orm_instrumentation.instrumentation_finders = instrumentation_finders @@ -222,14 +232,15 @@ class InstrumentationManager(object): pass def manage(self, class_, manager): - setattr(class_, '_default_class_manager', manager) + setattr(class_, "_default_class_manager", manager) def dispose(self, class_, manager): - delattr(class_, '_default_class_manager') + delattr(class_, "_default_class_manager") def manager_getter(self, class_): def get(cls): return cls._default_class_manager + return get def instrument_attribute(self, class_, key, inst): @@ -260,13 +271,13 @@ class InstrumentationManager(object): pass def install_state(self, class_, instance, state): - setattr(instance, '_default_state', state) + setattr(instance, "_default_state", state) def remove_state(self, class_, instance): - delattr(instance, '_default_state') + delattr(instance, "_default_state") def state_getter(self, class_): - return lambda instance: getattr(instance, '_default_state') + return lambda instance: getattr(instance, "_default_state") def dict_getter(self, class_): return lambda inst: self.get_instance_dict(class_, inst) @@ -314,15 +325,17 @@ class _ClassInstrumentationAdapter(ClassManager): def instrument_collection_class(self, key, collection_class): return self._adapted.instrument_collection_class( - self.class_, key, collection_class) + self.class_, key, collection_class + ) def initialize_collection(self, key, state, factory): - delegate 
= getattr(self._adapted, 'initialize_collection', None) + delegate = getattr(self._adapted, "initialize_collection", None) if delegate: return delegate(key, state, factory) else: - return ClassManager.initialize_collection(self, key, - state, factory) + return ClassManager.initialize_collection( + self, key, state, factory + ) def new_instance(self, state=None): instance = self.class_.__new__(self.class_) @@ -384,7 +397,7 @@ def _install_instrumented_lookups(): dict( instance_state=_instrumentation_factory.state_of, instance_dict=_instrumentation_factory.dict_of, - manager_of_class=_instrumentation_factory.manager_of_class + manager_of_class=_instrumentation_factory.manager_of_class, ) ) @@ -395,7 +408,7 @@ def _reinstall_default_lookups(): dict( instance_state=_default_state_getter, instance_dict=_default_dict_getter, - manager_of_class=_default_manager_getter + manager_of_class=_default_manager_getter, ) ) _instrumentation_factory._extended = False @@ -403,12 +416,15 @@ def _reinstall_default_lookups(): def _install_lookups(lookups): global instance_state, instance_dict, manager_of_class - instance_state = lookups['instance_state'] - instance_dict = lookups['instance_dict'] - manager_of_class = lookups['manager_of_class'] - orm_base.instance_state = attributes.instance_state = \ - orm_instrumentation.instance_state = instance_state - orm_base.instance_dict = attributes.instance_dict = \ - orm_instrumentation.instance_dict = instance_dict - orm_base.manager_of_class = attributes.manager_of_class = \ - orm_instrumentation.manager_of_class = manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + orm_base.instance_state = ( + attributes.instance_state + ) = orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = ( + attributes.instance_dict + ) = orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = ( + attributes.manager_of_class + ) = orm_instrumentation.manager_of_class = manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 014cef3cce..0f6ccdc333 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -502,27 +502,29 @@ class MutableBase(object): def pickle(state, state_dict): val = state.dict.get(key, None) if val is not None: - if 'ext.mutable.values' not in state_dict: - state_dict['ext.mutable.values'] = [] - state_dict['ext.mutable.values'].append(val) + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = [] + state_dict["ext.mutable.values"].append(val) def unpickle(state, state_dict): - if 'ext.mutable.values' in state_dict: - for val in state_dict['ext.mutable.values']: + if "ext.mutable.values" in state_dict: + for val in state_dict["ext.mutable.values"]: val._parents[state.obj()] = key - event.listen(parent_cls, 'load', load, - raw=True, propagate=True) - event.listen(parent_cls, 'refresh', load_attrs, - raw=True, propagate=True) - event.listen(parent_cls, 'refresh_flush', load_attrs, - raw=True, propagate=True) - event.listen(attribute, 'set', set, - raw=True, retval=True, propagate=True) - event.listen(parent_cls, 'pickle', pickle, - raw=True, propagate=True) - event.listen(parent_cls, 'unpickle', unpickle, - raw=True, propagate=True) + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", 
load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) class Mutable(MutableBase): @@ -572,7 +574,7 @@ class Mutable(MutableBase): if isinstance(prop.columns[0].type, sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) - event.listen(mapper, 'mapper_configured', listen_for_type) + event.listen(mapper, "mapper_configured", listen_for_type) @classmethod def as_mutable(cls, sqltype): @@ -613,9 +615,11 @@ class Mutable(MutableBase): # and we'll lose our ability to link that type back to the original. # so track our original type w/ columns if isinstance(sqltype, SchemaEventTarget): + @event.listens_for(sqltype, "before_parent_attach") def _add_column_memo(sqltyp, parent): - parent.info['_ext_mutable_orig_type'] = sqltyp + parent.info["_ext_mutable_orig_type"] = sqltyp + schema_event_check = True else: schema_event_check = False @@ -625,16 +629,14 @@ class Mutable(MutableBase): return for prop in mapper.column_attrs: if ( - schema_event_check and - hasattr(prop.expression, 'info') and - prop.expression.info.get('_ext_mutable_orig_type') - is sqltype - ) or ( - prop.columns[0].type is sqltype - ): + schema_event_check + and hasattr(prop.expression, "info") + and prop.expression.info.get("_ext_mutable_orig_type") + is sqltype + ) or (prop.columns[0].type is sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) - event.listen(mapper, 'mapper_configured', listen_for_type) + event.listen(mapper, "mapper_configured", listen_for_type) return sqltype @@ -659,21 +661,27 @@ class MutableComposite(MutableBase): prop = object_mapper(parent).get_property(key) for value, attr_name in zip( - self.__composite_values__(), - prop._attribute_keys): + self.__composite_values__(), prop._attribute_keys + ): setattr(parent, attr_name, value) def _setup_composite_listener(): def _listen_for_type(mapper, class_): for prop in mapper.iterate_properties: - if (hasattr(prop, 'composite_class') and - isinstance(prop.composite_class, type) and - issubclass(prop.composite_class, MutableComposite)): + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): prop.composite_class._listen_on_attribute( - getattr(class_, prop.key), False, class_) + getattr(class_, prop.key), False, class_ + ) + if not event.contains(Mapper, "mapper_configured", _listen_for_type): - event.listen(Mapper, 'mapper_configured', _listen_for_type) + event.listen(Mapper, "mapper_configured", _listen_for_type) + + _setup_composite_listener() @@ -947,4 +955,4 @@ class MutableSet(Mutable, set): self.update(state) def __reduce_ex__(self, proto): - return (self.__class__, (list(self), )) + return (self.__class__, (list(self),)) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 316742a674..2a8522120b 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -122,7 +122,7 @@ start numbering at 1 or some other integer, provide ``count_from=1``. from ..orm.collections import collection, collection_adapter from .. 
import util -__all__ = ['ordering_list'] +__all__ = ["ordering_list"] def ordering_list(attr, count_from=None, **kw): @@ -180,8 +180,9 @@ def count_from_n_factory(start): def f(index, collection): return index + start + try: - f.__name__ = 'count_from_%i' % start + f.__name__ = "count_from_%i" % start except TypeError: pass return f @@ -194,14 +195,14 @@ def _unsugar_count_from(**kw): ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. """ - count_from = kw.pop('count_from', None) - if kw.get('ordering_func', None) is None and count_from is not None: + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: if count_from == 0: - kw['ordering_func'] = count_from_0 + kw["ordering_func"] = count_from_0 elif count_from == 1: - kw['ordering_func'] = count_from_1 + kw["ordering_func"] = count_from_1 else: - kw['ordering_func'] = count_from_n_factory(count_from) + kw["ordering_func"] = count_from_n_factory(count_from) return kw @@ -214,8 +215,9 @@ class OrderingList(list): """ - def __init__(self, ordering_attr=None, ordering_func=None, - reorder_on_append=False): + def __init__( + self, ordering_attr=None, ordering_func=None, reorder_on_append=False + ): """A custom list that manages position information for its children. ``OrderingList`` is a ``collection_class`` list implementation that @@ -311,6 +313,7 @@ class OrderingList(list): """Append without any ordering behavior.""" super(OrderingList, self).append(entity) + _raw_append = collection.adds(1)(_raw_append) def insert(self, index, entity): @@ -361,8 +364,12 @@ class OrderingList(list): return _reconstitute, (self.__class__, self.__dict__, list(self)) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index 2fded51d19..3adcec34f4 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -64,7 +64,7 @@ from ..util import pickle, byte_buffer, b64encode, b64decode, text_type import re -__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] +__all__ = ["Serializer", "Deserializer", "dumps", "loads"] def Serializer(*args, **kw): @@ -79,13 +79,18 @@ def Serializer(*args, **kw): elif isinstance(obj, Mapper) and not obj.non_primary: id = "mapper:" + b64encode(pickle.dumps(obj.class_)) elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: - id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \ - ":" + obj.key + id = ( + "mapperprop:" + + b64encode(pickle.dumps(obj.parent.class_)) + + ":" + + obj.key + ) elif isinstance(obj, Table): id = "table:" + text_type(obj.key) elif isinstance(obj, Column) and isinstance(obj.table, Table): - id = "column:" + \ - text_type(obj.table.key) + ":" + text_type(obj.key) + id = ( + "column:" + text_type(obj.table.key) + ":" + text_type(obj.key) + ) elif isinstance(obj, Session): id = "session:" elif isinstance(obj, Engine): @@ -97,8 +102,10 @@ def Serializer(*args, **kw): pickler.persistent_id = persistent_id return pickler + our_ids = re.compile( - r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)') + r"(mapperprop|mapper|table|column|session|attribute|engine):(.*)" +) def Deserializer(file, 
metadata=None, scoped_session=None, engine=None): @@ -120,7 +127,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return None else: type_, args = m.group(1, 2) - if type_ == 'attribute': + if type_ == "attribute": key, clsarg = args.split(":") cls = pickle.loads(b64decode(clsarg)) return getattr(cls, key) @@ -128,13 +135,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): cls = pickle.loads(b64decode(args)) return class_mapper(cls) elif type_ == "mapperprop": - mapper, keyname = args.split(':') + mapper, keyname = args.split(":") cls = pickle.loads(b64decode(mapper)) return class_mapper(cls).attrs[keyname] elif type_ == "table": return metadata.tables[args] elif type_ == "column": - table, colname = args.split(':') + table, colname = args.split(":") return metadata.tables[table].c[colname] elif type_ == "session": return scoped_session() @@ -142,6 +149,7 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return get_engine() else: raise Exception("Unknown token: %s" % type_) + unpickler.persistent_load = persistent_load return unpickler diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 3a03e25073..7c2ff97c57 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -32,6 +32,7 @@ in a forwards-compatible way. """ from . import util, exc + _registrars = util.defaultdict(list) @@ -66,13 +67,11 @@ def inspect(subject, raiseerr=True): else: reg = ret = None - if raiseerr and ( - reg is None or ret is None - ): + if raiseerr and (reg is None or ret is None): raise exc.NoInspectionAvailable( "No inspection system is " - "available for object of type %s" % - type_) + "available for object of type %s" % type_ + ) return ret @@ -81,10 +80,11 @@ def _inspects(*types): for type_ in types: if type_ in _registrars: raise AssertionError( - "Type %s is already " - "registered" % type_) + "Type %s is already " "registered" % type_ + ) _registrars[type_] = fn_or_cls return fn_or_cls + return decorate diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py index 30698ea331..f352f7f263 100644 --- a/lib/sqlalchemy/interfaces.py +++ b/lib/sqlalchemy/interfaces.py @@ -80,17 +80,18 @@ class PoolListener(object): """ - listener = util.as_interface(listener, - methods=('connect', 'first_connect', - 'checkout', 'checkin')) - if hasattr(listener, 'connect'): - event.listen(self, 'connect', listener.connect) - if hasattr(listener, 'first_connect'): - event.listen(self, 'first_connect', listener.first_connect) - if hasattr(listener, 'checkout'): - event.listen(self, 'checkout', listener.checkout) - if hasattr(listener, 'checkin'): - event.listen(self, 'checkin', listener.checkin) + listener = util.as_interface( + listener, + methods=("connect", "first_connect", "checkout", "checkin"), + ) + if hasattr(listener, "connect"): + event.listen(self, "connect", listener.connect) + if hasattr(listener, "first_connect"): + event.listen(self, "first_connect", listener.first_connect) + if hasattr(listener, "checkout"): + event.listen(self, "checkout", listener.checkout) + if hasattr(listener, "checkin"): + event.listen(self, "checkin", listener.checkin) def connect(self, dbapi_con, con_record): """Called once for each new DB-API connection or Pool's ``creator()``. 
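The adapter above maps each legacy ``PoolListener`` method onto the modern event system; a brief sketch of the direct ``event.listens_for()`` equivalent, assuming an illustrative in-memory engine::

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "connect")
    def on_connect(dbapi_con, con_record):
        # fires once for each new DB-API connection
        print("connect:", dbapi_con)

    @event.listens_for(engine, "checkout")
    def on_checkout(dbapi_con, con_record, con_proxy):
        # fires each time a connection is checked out of the pool
        print("checkout:", dbapi_con)

    engine.connect().close()  # triggers both hooks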
@@ -187,27 +188,20 @@ class ConnectionProxy(object): @classmethod def _adapt_listener(cls, self, listener): - def adapt_execute(conn, clauseelement, multiparams, params): - def execute_wrapper(clauseelement, *multiparams, **params): return clauseelement, multiparams, params - return listener.execute(conn, execute_wrapper, - clauseelement, *multiparams, - **params) + return listener.execute( + conn, execute_wrapper, clauseelement, *multiparams, **params + ) - event.listen(self, 'before_execute', adapt_execute) + event.listen(self, "before_execute", adapt_execute) - def adapt_cursor_execute(conn, cursor, statement, - parameters, context, executemany): - - def execute_wrapper( - cursor, - statement, - parameters, - context, - ): + def adapt_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + def execute_wrapper(cursor, statement, parameters, context): return statement, parameters return listener.cursor_execute( @@ -217,46 +211,56 @@ class ConnectionProxy(object): parameters, context, executemany, - ) + ) - event.listen(self, 'before_cursor_execute', adapt_cursor_execute) + event.listen(self, "before_cursor_execute", adapt_cursor_execute) def do_nothing_callback(*arg, **kw): pass def adapt_listener(fn): - def go(conn, *arg, **kw): fn(conn, do_nothing_callback, *arg, **kw) return util.update_wrapper(go, fn) - event.listen(self, 'begin', adapt_listener(listener.begin)) - event.listen(self, 'rollback', - adapt_listener(listener.rollback)) - event.listen(self, 'commit', adapt_listener(listener.commit)) - event.listen(self, 'savepoint', - adapt_listener(listener.savepoint)) - event.listen(self, 'rollback_savepoint', - adapt_listener(listener.rollback_savepoint)) - event.listen(self, 'release_savepoint', - adapt_listener(listener.release_savepoint)) - event.listen(self, 'begin_twophase', - adapt_listener(listener.begin_twophase)) - event.listen(self, 'prepare_twophase', - adapt_listener(listener.prepare_twophase)) - event.listen(self, 'rollback_twophase', - adapt_listener(listener.rollback_twophase)) - event.listen(self, 'commit_twophase', - adapt_listener(listener.commit_twophase)) + event.listen(self, "begin", adapt_listener(listener.begin)) + event.listen(self, "rollback", adapt_listener(listener.rollback)) + event.listen(self, "commit", adapt_listener(listener.commit)) + event.listen(self, "savepoint", adapt_listener(listener.savepoint)) + event.listen( + self, + "rollback_savepoint", + adapt_listener(listener.rollback_savepoint), + ) + event.listen( + self, + "release_savepoint", + adapt_listener(listener.release_savepoint), + ) + event.listen( + self, "begin_twophase", adapt_listener(listener.begin_twophase) + ) + event.listen( + self, "prepare_twophase", adapt_listener(listener.prepare_twophase) + ) + event.listen( + self, + "rollback_twophase", + adapt_listener(listener.rollback_twophase), + ) + event.listen( + self, "commit_twophase", adapt_listener(listener.commit_twophase) + ) def execute(self, conn, execute, clauseelement, *multiparams, **params): """Intercept high level execute() events.""" return execute(clauseelement, *multiparams, **params) - def cursor_execute(self, execute, cursor, statement, parameters, - context, executemany): + def cursor_execute( + self, execute, cursor, statement, parameters, context, executemany + ): """Intercept low-level cursor execute() events.""" return execute(cursor, statement, parameters, context) diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index a79b21e174..6b0b2e90ef 100644 --- a/lib/sqlalchemy/log.py +++ 
b/lib/sqlalchemy/log.py @@ -24,15 +24,16 @@ import sys # set initial level to WARN. This so that # log statements don't occur in the absence of explicit # logging being enabled for 'sqlalchemy'. -rootlogger = logging.getLogger('sqlalchemy') +rootlogger = logging.getLogger("sqlalchemy") if rootlogger.level == logging.NOTSET: rootlogger.setLevel(logging.WARN) def _add_default_handler(logger): handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter( - '%(asctime)s %(levelname)s %(name)s %(message)s')) + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s") + ) logger.addHandler(handler) @@ -82,7 +83,7 @@ class InstanceLogger(object): None: logging.NOTSET, False: logging.NOTSET, True: logging.INFO, - 'debug': logging.DEBUG, + "debug": logging.DEBUG, } def __init__(self, echo, name): @@ -91,8 +92,7 @@ class InstanceLogger(object): # if echo flag is enabled and no handlers, # add a handler to the list - if self._echo_map[echo] <= logging.INFO \ - and not self.logger.handlers: + if self._echo_map[echo] <= logging.INFO and not self.logger.handlers: _add_default_handler(self.logger) # @@ -174,12 +174,16 @@ def instance_logger(instance, echoflag=None): """create a logger for an instance that implements :class:`.Identified`.""" if instance.logging_name: - name = "%s.%s.%s" % (instance.__class__.__module__, - instance.__class__.__name__, - instance.logging_name) + name = "%s.%s.%s" % ( + instance.__class__.__module__, + instance.__class__.__name__, + instance.logging_name, + ) else: - name = "%s.%s" % (instance.__class__.__module__, - instance.__class__.__name__) + name = "%s.%s" % ( + instance.__class__.__module__, + instance.__class__.__name__, + ) instance._echo = echoflag diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 1784ea21fb..8e7b4cee65 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -20,14 +20,9 @@ from .mapper import ( class_mapper, configure_mappers, reconstructor, - validates -) -from .interfaces import ( - EXT_CONTINUE, - EXT_STOP, - EXT_SKIP, - PropComparator, + validates, ) +from .interfaces import EXT_CONTINUE, EXT_STOP, EXT_SKIP, PropComparator from .deprecated_interfaces import ( MapperExtension, SessionExtension, @@ -50,20 +45,15 @@ from .descriptor_props import ( CompositeProperty, SynonymProperty, ) -from .relationships import ( - foreign, - remote, -) +from .relationships import foreign, remote from .session import ( Session, object_session, sessionmaker, make_transient, - make_transient_to_detached -) -from .scoping import ( - scoped_session + make_transient_to_detached, ) +from .scoping import scoped_session from . import mapper as mapperlib from .query import AliasOption, Query, Bundle from ..util.langhelpers import public_factory @@ -103,11 +93,12 @@ def create_session(bind=None, **kwargs): create_session(). """ - kwargs.setdefault('autoflush', False) - kwargs.setdefault('autocommit', True) - kwargs.setdefault('expire_on_commit', False) + kwargs.setdefault("autoflush", False) + kwargs.setdefault("autocommit", True) + kwargs.setdefault("expire_on_commit", False) return Session(bind=bind, **kwargs) + relationship = public_factory(RelationshipProperty, ".orm.relationship") @@ -133,7 +124,7 @@ def dynamic_loader(argument, **kw): on dynamic loading. 
""" - kw['lazy'] = 'dynamic' + kw["lazy"] = "dynamic" return relationship(argument, **kw) @@ -193,16 +184,21 @@ def query_expression(): prop.strategy_key = (("query_expression", True),) return prop + mapper = public_factory(Mapper, ".orm.mapper") synonym = public_factory(SynonymProperty, ".orm.synonym") -comparable_property = public_factory(ComparableProperty, - ".orm.comparable_property") +comparable_property = public_factory( + ComparableProperty, ".orm.comparable_property" +) -@_sa_util.deprecated("0.7", message=":func:`.compile_mappers` " - "is renamed to :func:`.configure_mappers`") +@_sa_util.deprecated( + "0.7", + message=":func:`.compile_mappers` " + "is renamed to :func:`.configure_mappers`", +) def compile_mappers(): """Initialize the inter-mapper relationships of all mappers that have been defined. @@ -243,6 +239,7 @@ def clear_mappers(): finally: mapperlib._CONFIGURE_MUTEX.release() + from . import strategy_options joinedload = strategy_options.joinedload._unbound_fn @@ -289,10 +286,14 @@ def __go(lcls): from . import loading import inspect as _inspect - __all__ = sorted(name for name, obj in lcls.items() - if not (name.startswith('_') or _inspect.ismodule(obj))) + __all__ = sorted( + name + for name, obj in lcls.items() + if not (name.startswith("_") or _inspect.ismodule(obj)) + ) _sa_util.dependencies.resolve_all("sqlalchemy.orm") _sa_util.dependencies.resolve_all("sqlalchemy.ext") + __go(locals()) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index b08c467413..1648c9ae10 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -20,19 +20,37 @@ from . import interfaces, collections, exc as orm_exc from .base import instance_state, instance_dict, manager_of_class -from .base import PASSIVE_NO_RESULT, ATTR_WAS_SET, ATTR_EMPTY, NO_VALUE,\ - NEVER_SET, NO_CHANGE, CALLABLES_OK, SQL_OK, RELATED_OBJECT_OK,\ - INIT_OK, NON_PERSISTENT_OK, LOAD_AGAINST_COMMITTED, PASSIVE_OFF,\ - PASSIVE_RETURN_NEVER_SET, PASSIVE_NO_INITIALIZE, PASSIVE_NO_FETCH,\ - PASSIVE_NO_FETCH_RELATED, PASSIVE_ONLY_PERSISTENT, NO_AUTOFLUSH, \ - NO_RAISE +from .base import ( + PASSIVE_NO_RESULT, + ATTR_WAS_SET, + ATTR_EMPTY, + NO_VALUE, + NEVER_SET, + NO_CHANGE, + CALLABLES_OK, + SQL_OK, + RELATED_OBJECT_OK, + INIT_OK, + NON_PERSISTENT_OK, + LOAD_AGAINST_COMMITTED, + PASSIVE_OFF, + PASSIVE_RETURN_NEVER_SET, + PASSIVE_NO_INITIALIZE, + PASSIVE_NO_FETCH, + PASSIVE_NO_FETCH_RELATED, + PASSIVE_ONLY_PERSISTENT, + NO_AUTOFLUSH, + NO_RAISE, +) from .base import state_str, instance_str @inspection._self_inspects -class QueryableAttribute(interfaces._MappedAttribute, - interfaces.InspectionAttr, - interfaces.PropComparator): +class QueryableAttribute( + interfaces._MappedAttribute, + interfaces.InspectionAttr, + interfaces.PropComparator, +): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` object. 
The actual :class:`.MapperProperty` is accessible @@ -53,9 +71,15 @@ class QueryableAttribute(interfaces._MappedAttribute, is_attribute = True - def __init__(self, class_, key, impl=None, - comparator=None, parententity=None, - of_type=None): + def __init__( + self, + class_, + key, + impl=None, + comparator=None, + parententity=None, + of_type=None, + ): self.class_ = class_ self.key = key self.impl = impl @@ -77,8 +101,9 @@ class QueryableAttribute(interfaces._MappedAttribute, return self.impl.supports_population def get_history(self, instance, passive=PASSIVE_OFF): - return self.impl.get_history(instance_state(instance), - instance_dict(instance), passive) + return self.impl.get_history( + instance_state(instance), instance_dict(instance), passive + ) def __selectable__(self): # TODO: conditionally attach this method based on clause_element ? @@ -159,11 +184,13 @@ class QueryableAttribute(interfaces._MappedAttribute, def adapt_to_entity(self, adapt_to_entity): assert not self._of_type - return self.__class__(adapt_to_entity.entity, - self.key, impl=self.impl, - comparator=self.comparator.adapt_to_entity( - adapt_to_entity), - parententity=adapt_to_entity) + return self.__class__( + adapt_to_entity.entity, + self.key, + impl=self.impl, + comparator=self.comparator.adapt_to_entity(adapt_to_entity), + parententity=adapt_to_entity, + ) def of_type(self, cls): return QueryableAttribute( @@ -172,7 +199,8 @@ class QueryableAttribute(interfaces._MappedAttribute, self.impl, self.comparator.of_type(cls), self._parententity, - of_type=cls) + of_type=cls, + ) def label(self, name): return self._query_clause_element().label(name) @@ -191,12 +219,14 @@ class QueryableAttribute(interfaces._MappedAttribute, return getattr(self.comparator, key) except AttributeError: raise AttributeError( - 'Neither %r object nor %r object associated with %s ' - 'has an attribute %r' % ( + "Neither %r object nor %r object associated with %s " + "has an attribute %r" + % ( type(self).__name__, type(self.comparator).__name__, self, - key) + key, + ) ) def __str__(self): @@ -226,8 +256,9 @@ class InstrumentedAttribute(QueryableAttribute): """ def __set__(self, instance, value): - self.impl.set(instance_state(instance), - instance_dict(instance), value, None) + self.impl.set( + instance_state(instance), instance_dict(instance), value, None + ) def __delete__(self, instance): self.impl.delete(instance_state(instance), instance_dict(instance)) @@ -260,10 +291,16 @@ def create_proxied_attribute(descriptor): """ - def __init__(self, class_, key, descriptor, - comparator, - adapt_to_entity=None, doc=None, - original_property=None): + def __init__( + self, + class_, + key, + descriptor, + comparator, + adapt_to_entity=None, + doc=None, + original_property=None, + ): self.class_ = class_ self.key = key self.descriptor = descriptor @@ -284,15 +321,18 @@ def create_proxied_attribute(descriptor): self._comparator = self._comparator() if self._adapt_to_entity: self._comparator = self._comparator.adapt_to_entity( - self._adapt_to_entity) + self._adapt_to_entity + ) return self._comparator def adapt_to_entity(self, adapt_to_entity): - return self.__class__(adapt_to_entity.entity, - self.key, - self.descriptor, - self._comparator, - adapt_to_entity) + return self.__class__( + adapt_to_entity.entity, + self.key, + self.descriptor, + self._comparator, + adapt_to_entity, + ) def __get__(self, instance, owner): if instance is None: @@ -314,21 +354,24 @@ def create_proxied_attribute(descriptor): return getattr(self.comparator, attribute) 
except AttributeError: raise AttributeError( - 'Neither %r object nor %r object associated with %s ' - 'has an attribute %r' % ( + "Neither %r object nor %r object associated with %s " + "has an attribute %r" + % ( type(descriptor).__name__, type(self.comparator).__name__, self, - attribute) + attribute, + ) ) - Proxy.__name__ = type(descriptor).__name__ + 'Proxy' + Proxy.__name__ = type(descriptor).__name__ + "Proxy" - util.monkeypatch_proxied_specials(Proxy, type(descriptor), - name='descriptor', - from_instance=descriptor) + util.monkeypatch_proxied_specials( + Proxy, type(descriptor), name="descriptor", from_instance=descriptor + ) return Proxy + OP_REMOVE = util.symbol("REMOVE") OP_APPEND = util.symbol("APPEND") OP_REPLACE = util.symbol("REPLACE") @@ -364,7 +407,7 @@ class Event(object): """ - __slots__ = 'impl', 'op', 'parent_token' + __slots__ = "impl", "op", "parent_token" def __init__(self, attribute_impl, op): self.impl = attribute_impl @@ -372,9 +415,11 @@ class Event(object): self.parent_token = self.impl.parent_token def __eq__(self, other): - return isinstance(other, Event) and \ - other.impl is self.impl and \ - other.op == self.op + return ( + isinstance(other, Event) + and other.impl is self.impl + and other.op == self.op + ) @property def key(self): @@ -387,12 +432,22 @@ class Event(object): class AttributeImpl(object): """internal implementation for instrumented attributes.""" - def __init__(self, class_, key, - callable_, dispatch, trackparent=False, extension=None, - compare_function=None, active_history=False, - parent_token=None, expire_missing=True, - send_modified_events=True, accepts_scalar_loader=None, - **kwargs): + def __init__( + self, + class_, + key, + callable_, + dispatch, + trackparent=False, + extension=None, + compare_function=None, + active_history=False, + parent_token=None, + expire_missing=True, + send_modified_events=True, + accepts_scalar_loader=None, + **kwargs + ): r"""Construct an AttributeImpl. \class_ @@ -471,9 +526,17 @@ class AttributeImpl(object): self._modified_token = Event(self, OP_MODIFIED) __slots__ = ( - 'class_', 'key', 'callable_', 'dispatch', 'trackparent', - 'parent_token', 'send_modified_events', 'is_equal', 'expire_missing', - '_modified_token', 'accepts_scalar_loader' + "class_", + "key", + "callable_", + "dispatch", + "trackparent", + "parent_token", + "send_modified_events", + "is_equal", + "expire_missing", + "_modified_token", + "accepts_scalar_loader", ) def __str__(self): @@ -508,8 +571,9 @@ class AttributeImpl(object): msg = "This AttributeImpl is not configured to track parents." assert self.trackparent, msg - return state.parents.get(id(self.parent_token), optimistic) \ - is not False + return ( + state.parents.get(id(self.parent_token), optimistic) is not False + ) def sethasparent(self, state, parent_state, value): """Set a boolean flag on the given item corresponding to @@ -527,8 +591,10 @@ class AttributeImpl(object): if id_ in state.parents: last_parent = state.parents[id_] - if last_parent is not False and \ - last_parent.key != parent_state.key: + if ( + last_parent is not False + and last_parent.key != parent_state.key + ): if last_parent.obj() is None: raise orm_exc.StaleDataError( @@ -536,10 +602,13 @@ class AttributeImpl(object): "state %s along attribute '%s', " "but the parent record " "has gone stale, can't be sure this " - "is the most recent parent." % - (state_str(state), - state_str(parent_state), - self.key)) + "is the most recent parent." 
+                    % (
+                        state_str(state),
+                        state_str(parent_state),
+                        self.key,
+                    )
+                )

                 return

@@ -588,8 +657,10 @@ class AttributeImpl(object):
         else:
             # if history present, don't load
             key = self.key
-            if key not in state.committed_state or \
-                    state.committed_state[key] is NEVER_SET:
+            if (
+                key not in state.committed_state
+                or state.committed_state[key] is NEVER_SET
+            ):
                 if not passive & CALLABLES_OK:
                     return PASSIVE_NO_RESULT

@@ -613,7 +684,8 @@ class AttributeImpl(object):
                     raise KeyError(
                         "Deferred loader for attribute "
                         "%r failed to populate "
-                        "correctly" % key)
+                        "correctly" % key
+                    )
                 elif value is not ATTR_EMPTY:
                     return self.set_committed_value(state, dict_, value)

@@ -627,15 +699,31 @@ class AttributeImpl(object):
         self.set(state, dict_, value, initiator, passive=passive)

     def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, dict_, None, initiator,
-                 passive=passive, check_old=value)
+        self.set(
+            state, dict_, None, initiator, passive=passive, check_old=value
+        )

     def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, dict_, None, initiator,
-                 passive=passive, check_old=value, pop=True)
+        self.set(
+            state,
+            dict_,
+            None,
+            initiator,
+            passive=passive,
+            check_old=value,
+            pop=True,
+        )

-    def set(self, state, dict_, value, initiator,
-            passive=PASSIVE_OFF, check_old=None, pop=False):
+    def set(
+        self,
+        state,
+        dict_,
+        value,
+        initiator,
+        passive=PASSIVE_OFF,
+        check_old=None,
+        pop=False,
+    ):
         raise NotImplementedError()

     def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
@@ -667,7 +755,7 @@ class ScalarAttributeImpl(AttributeImpl):
     collection = False
     dynamic = False

-    __slots__ = '_replace_token', '_append_token', '_remove_token'
+    __slots__ = "_replace_token", "_append_token", "_remove_token"

     def __init__(self, *arg, **kw):
         super(ScalarAttributeImpl, self).__init__(*arg, **kw)
@@ -685,10 +773,13 @@ class ScalarAttributeImpl(AttributeImpl):
             state._modified_event(dict_, self, old)

         existing = dict_.pop(self.key, NO_VALUE)
-        if existing is NO_VALUE and old is NO_VALUE and \
-                not state.expired and \
-                self.key not in state.expired_attributes:
-            raise AttributeError("%s object does not have a value" % self)
+        if (
+            existing is NO_VALUE
+            and old is NO_VALUE
+            and not state.expired
+            and self.key not in state.expired_attributes
+        ):
+            raise AttributeError("%s object does not have a value" % self)

     def get_history(self, state, dict_, passive=PASSIVE_OFF):
         if self.key in dict_:
@@ -702,23 +793,33 @@ class ScalarAttributeImpl(AttributeImpl):
         else:
             return History.from_scalar_attribute(self, state, current)

-    def set(self, state, dict_, value, initiator,
-            passive=PASSIVE_OFF, check_old=None, pop=False):
+    def set(
+        self,
+        state,
+        dict_,
+        value,
+        initiator,
+        passive=PASSIVE_OFF,
+        check_old=None,
+        pop=False,
+    ):
         if self.dispatch._active_history:
             old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET)
         else:
             old = dict_.get(self.key, NO_VALUE)

         if self.dispatch.set:
-            value = self.fire_replace_event(state, dict_,
-                                            value, old, initiator)
+            value = self.fire_replace_event(
+                state, dict_, value, old, initiator
+            )

         state._modified_event(dict_, self, old)
         dict_[self.key] = value

     def fire_replace_event(self, state, dict_, value, previous, initiator):
         for fn in self.dispatch.set:
             value = fn(
-                state, value, previous, initiator or self._replace_token)
+                state, value, previous, initiator or self._replace_token
+            )
         return value

     def fire_remove_event(self, state, dict_, value, initiator):
@@ -748,13 +849,20 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
     def delete(self, state, dict_):
         if self.dispatch._active_history:
             old = self.get(
-                state, dict_,
-                passive=PASSIVE_ONLY_PERSISTENT |
-                NO_AUTOFLUSH | LOAD_AGAINST_COMMITTED)
+                state,
+                dict_,
+                passive=PASSIVE_ONLY_PERSISTENT
+                | NO_AUTOFLUSH
+                | LOAD_AGAINST_COMMITTED,
+            )
         else:
             old = self.get(
-                state, dict_, passive=PASSIVE_NO_FETCH ^ INIT_OK |
-                LOAD_AGAINST_COMMITTED | NO_RAISE)
+                state,
+                dict_,
+                passive=PASSIVE_NO_FETCH ^ INIT_OK
+                | LOAD_AGAINST_COMMITTED
+                | NO_RAISE,
+            )

         self.fire_remove_event(state, dict_, old, self._remove_token)

@@ -763,8 +871,11 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         # if the attribute is expired, we currently have no way to tell
         # that an object-attribute was expired vs. not loaded.  So
         # for this test, we look to see if the object has a DB identity.
-        if existing is NO_VALUE and old is not PASSIVE_NO_RESULT and \
-                state.key is None:
+        if (
+            existing is NO_VALUE
+            and old is not PASSIVE_NO_RESULT
+            and state.key is None
+        ):
             raise AttributeError("%s object does not have a value" % self)

     def get_history(self, state, dict_, passive=PASSIVE_OFF):
@@ -788,50 +899,69 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             return []

         # can't use __hash__(), can't use __eq__() here
-        if current is not None and \
-                current is not PASSIVE_NO_RESULT and \
-                current is not NEVER_SET:
+        if (
+            current is not None
+            and current is not PASSIVE_NO_RESULT
+            and current is not NEVER_SET
+        ):
             ret = [(instance_state(current), current)]
         else:
             ret = [(None, None)]

         if self.key in state.committed_state:
             original = state.committed_state[self.key]
-            if original is not None and \
-                    original is not PASSIVE_NO_RESULT and \
-                    original is not NEVER_SET and \
-                    original is not current:
+            if (
+                original is not None
+                and original is not PASSIVE_NO_RESULT
+                and original is not NEVER_SET
+                and original is not current
+            ):
                 ret.append((instance_state(original), original))

         return ret

-    def set(self, state, dict_, value, initiator,
-            passive=PASSIVE_OFF, check_old=None, pop=False):
+    def set(
+        self,
+        state,
+        dict_,
+        value,
+        initiator,
+        passive=PASSIVE_OFF,
+        check_old=None,
+        pop=False,
+    ):
         """Set a value on the given InstanceState.
""" if self.dispatch._active_history: old = self.get( - state, dict_, - passive=PASSIVE_ONLY_PERSISTENT | - NO_AUTOFLUSH | LOAD_AGAINST_COMMITTED) + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED, + ) else: old = self.get( - state, dict_, passive=PASSIVE_NO_FETCH ^ INIT_OK | - LOAD_AGAINST_COMMITTED | NO_RAISE) + state, + dict_, + passive=PASSIVE_NO_FETCH ^ INIT_OK + | LOAD_AGAINST_COMMITTED + | NO_RAISE, + ) - if check_old is not None and \ - old is not PASSIVE_NO_RESULT and \ - check_old is not old: + if ( + check_old is not None + and old is not PASSIVE_NO_RESULT + and check_old is not old + ): if pop: return else: raise ValueError( - "Object %s not associated with %s on attribute '%s'" % ( - instance_str(check_old), - state_str(state), - self.key - )) + "Object %s not associated with %s on attribute '%s'" + % (instance_str(check_old), state_str(state), self.key) + ) value = self.fire_replace_event(state, dict_, value, old, initiator) dict_[self.key] = value @@ -847,13 +977,17 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): def fire_replace_event(self, state, dict_, value, previous, initiator): if self.trackparent: - if (previous is not value and - previous not in (None, PASSIVE_NO_RESULT, NEVER_SET)): + if previous is not value and previous not in ( + None, + PASSIVE_NO_RESULT, + NEVER_SET, + ): self.sethasparent(instance_state(previous), state, False) for fn in self.dispatch.set: value = fn( - state, value, previous, initiator or self._replace_token) + state, value, previous, initiator or self._replace_token + ) state._modified_event(dict_, self, previous) @@ -875,6 +1009,7 @@ class CollectionAttributeImpl(AttributeImpl): semantics to the orm layer independent of the user data implementation. 
""" + default_accepts_scalar_loader = False uses_objects = True supports_population = True @@ -882,21 +1017,37 @@ class CollectionAttributeImpl(AttributeImpl): dynamic = False __slots__ = ( - 'copy', 'collection_factory', '_append_token', '_remove_token', - '_bulk_replace_token', '_duck_typed_as' + "copy", + "collection_factory", + "_append_token", + "_remove_token", + "_bulk_replace_token", + "_duck_typed_as", ) - def __init__(self, class_, key, callable_, dispatch, - typecallable=None, trackparent=False, extension=None, - copy_function=None, compare_function=None, **kwargs): + def __init__( + self, + class_, + key, + callable_, + dispatch, + typecallable=None, + trackparent=False, + extension=None, + copy_function=None, + compare_function=None, + **kwargs + ): super(CollectionAttributeImpl, self).__init__( class_, key, - callable_, dispatch, + callable_, + dispatch, trackparent=trackparent, extension=extension, compare_function=compare_function, - **kwargs) + **kwargs + ) if copy_function is None: copy_function = self.__copy @@ -906,7 +1057,8 @@ class CollectionAttributeImpl(AttributeImpl): self._remove_token = Event(self, OP_REMOVE) self._bulk_replace_token = Event(self, OP_BULK_REPLACE) self._duck_typed_as = util.duck_type_collection( - self.collection_factory()) + self.collection_factory() + ) if getattr(self.collection_factory, "_sa_linker", None): @@ -935,35 +1087,42 @@ class CollectionAttributeImpl(AttributeImpl): return [] current = dict_[self.key] - current = getattr(current, '_sa_adapter') + current = getattr(current, "_sa_adapter") if self.key in state.committed_state: original = state.committed_state[self.key] if original not in (NO_VALUE, NEVER_SET): - current_states = [((c is not None) and - instance_state(c) or None, c) - for c in current] - original_states = [((c is not None) and - instance_state(c) or None, c) - for c in original] + current_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in current + ] + original_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in original + ] current_set = dict(current_states) original_set = dict(original_states) - return \ - [(s, o) for s, o in current_states - if s not in original_set] + \ - [(s, o) for s, o in current_states - if s in original_set] + \ - [(s, o) for s, o in original_states - if s not in current_set] + return ( + [ + (s, o) + for s, o in current_states + if s not in original_set + ] + + [(s, o) for s, o in current_states if s in original_set] + + [ + (s, o) + for s, o in original_states + if s not in current_set + ] + ) return [(instance_state(o), o) for o in current] def fire_append_event(self, state, dict_, value, initiator): for fn in self.dispatch.append: - value = fn( - state, value, initiator or self._append_token) + value = fn(state, value, initiator or self._append_token) state._modified_event(dict_, self, NEVER_SET, True) @@ -1015,7 +1174,8 @@ class CollectionAttributeImpl(AttributeImpl): def _initialize_collection(self, state): adapter, collection = state.manager.initialize_collection( - self.key, state, self.collection_factory) + self.key, state, self.collection_factory + ) self.dispatch.init_collection(state, collection, adapter) @@ -1025,8 +1185,9 @@ class CollectionAttributeImpl(AttributeImpl): collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NO_RESULT: value = self.fire_append_event(state, dict_, value, initiator) - assert self.key not in dict_, \ - "Collection was loaded during event handling." 
+ assert ( + self.key not in dict_ + ), "Collection was loaded during event handling." state._get_pending_mutation(self.key).append(value) else: collection.append_with_event(value, initiator) @@ -1035,8 +1196,9 @@ class CollectionAttributeImpl(AttributeImpl): collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NO_RESULT: self.fire_remove_event(state, dict_, value, initiator) - assert self.key not in dict_, \ - "Collection was loaded during event handling." + assert ( + self.key not in dict_ + ), "Collection was loaded during event handling." state._get_pending_mutation(self.key).remove(value) else: collection.remove_with_event(value, initiator) @@ -1050,8 +1212,16 @@ class CollectionAttributeImpl(AttributeImpl): except (ValueError, KeyError, IndexError): pass - def set(self, state, dict_, value, initiator=None, - passive=PASSIVE_OFF, pop=False, _adapt=True): + def set( + self, + state, + dict_, + value, + initiator=None, + passive=PASSIVE_OFF, + pop=False, + _adapt=True, + ): iterable = orig_iterable = value # pulling a new collection first so that an adaptation exception does @@ -1065,23 +1235,28 @@ class CollectionAttributeImpl(AttributeImpl): receiving_type = self._duck_typed_as if setting_type is not receiving_type: - given = iterable is None and 'None' or \ - iterable.__class__.__name__ + given = ( + iterable is None + and "None" + or iterable.__class__.__name__ + ) wanted = self._duck_typed_as.__name__ raise TypeError( - "Incompatible collection type: %s is not %s-like" % ( - given, wanted)) + "Incompatible collection type: %s is not %s-like" + % (given, wanted) + ) # If the object is an adapted collection, return the (iterable) # adapter. - if hasattr(iterable, '_sa_iterator'): + if hasattr(iterable, "_sa_iterator"): iterable = iterable._sa_iterator() elif setting_type is dict: if util.py3k: iterable = iterable.values() else: iterable = getattr( - iterable, 'itervalues', iterable.values)() + iterable, "itervalues", iterable.values + )() else: iterable = iter(iterable) new_values = list(iterable) @@ -1106,14 +1281,14 @@ class CollectionAttributeImpl(AttributeImpl): dict_[self.key] = user_data collections.bulk_replace( - new_values, old_collection, new_collection, - initiator=evt) + new_values, old_collection, new_collection, initiator=evt + ) del old._sa_adapter self.dispatch.dispose_collection(state, old, old_collection) def _invalidate_collection(self, collection): - adapter = getattr(collection, '_sa_adapter') + adapter = getattr(collection, "_sa_adapter") adapter.invalidated = True def set_committed_value(self, state, dict_, value): @@ -1143,8 +1318,9 @@ class CollectionAttributeImpl(AttributeImpl): return user_data - def get_collection(self, state, dict_, - user_data=None, passive=PASSIVE_OFF): + def get_collection( + self, state, dict_, user_data=None, passive=PASSIVE_OFF + ): """Retrieve the CollectionAdapter associated with the given state. Creates a new CollectionAdapter if one does not exist. @@ -1155,7 +1331,7 @@ class CollectionAttributeImpl(AttributeImpl): if user_data is PASSIVE_NO_RESULT: return user_data - return getattr(user_data, '_sa_adapter') + return getattr(user_data, "_sa_adapter") def backref_listeners(attribute, key, uselist): @@ -1177,24 +1353,29 @@ def backref_listeners(attribute, key, uselist): "Bidirectional attribute conflict detected: " 'Passing object %s to attribute "%s" ' 'triggers a modify event on attribute "%s" ' - 'via the backref "%s".' % ( + 'via the backref "%s".' 
+ % ( state_str(child_state), initiator.parent_token, child_impl.parent_token, - attribute.impl.parent_token + attribute.impl.parent_token, ) ) def emit_backref_from_scalar_set_event(state, child, oldchild, initiator): if oldchild is child: return child - if oldchild is not None and \ - oldchild is not PASSIVE_NO_RESULT and \ - oldchild is not NEVER_SET: + if ( + oldchild is not None + and oldchild is not PASSIVE_NO_RESULT + and oldchild is not NEVER_SET + ): # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. - old_state, old_dict = instance_state(oldchild),\ - instance_dict(oldchild) + old_state, old_dict = ( + instance_state(oldchild), + instance_dict(oldchild), + ) impl = old_state.manager[key].impl # tokens to test for a recursive loop. @@ -1204,69 +1385,90 @@ def backref_listeners(attribute, key, uselist): check_recursive_token = impl._remove_token if initiator is not check_recursive_token: - impl.pop(old_state, - old_dict, - state.obj(), - parent_impl._append_token, - passive=PASSIVE_NO_FETCH) + impl.pop( + old_state, + old_dict, + state.obj(), + parent_impl._append_token, + passive=PASSIVE_NO_FETCH, + ) if child is not None: - child_state, child_dict = instance_state(child),\ - instance_dict(child) + child_state, child_dict = ( + instance_state(child), + instance_dict(child), + ) child_impl = child_state.manager[key].impl - if initiator.parent_token is not parent_token and \ - initiator.parent_token is not child_impl.parent_token: + if ( + initiator.parent_token is not parent_token + and initiator.parent_token is not child_impl.parent_token + ): _acceptable_key_err(state, initiator, child_impl) # tokens to test for a recursive loop. check_append_token = child_impl._append_token - check_bulk_replace_token = child_impl._bulk_replace_token \ - if child_impl.collection else None + check_bulk_replace_token = ( + child_impl._bulk_replace_token + if child_impl.collection + else None + ) - if initiator is not check_append_token and \ - initiator is not check_bulk_replace_token: + if ( + initiator is not check_append_token + and initiator is not check_bulk_replace_token + ): child_impl.append( child_state, child_dict, state.obj(), initiator, - passive=PASSIVE_NO_FETCH) + passive=PASSIVE_NO_FETCH, + ) return child def emit_backref_from_collection_append_event(state, child, initiator): if child is None: return - child_state, child_dict = instance_state(child), \ - instance_dict(child) + child_state, child_dict = instance_state(child), instance_dict(child) child_impl = child_state.manager[key].impl - if initiator.parent_token is not parent_token and \ - initiator.parent_token is not child_impl.parent_token: + if ( + initiator.parent_token is not parent_token + and initiator.parent_token is not child_impl.parent_token + ): _acceptable_key_err(state, initiator, child_impl) # tokens to test for a recursive loop. 
check_append_token = child_impl._append_token - check_bulk_replace_token = child_impl._bulk_replace_token \ - if child_impl.collection else None + check_bulk_replace_token = ( + child_impl._bulk_replace_token if child_impl.collection else None + ) - if initiator is not check_append_token and \ - initiator is not check_bulk_replace_token: + if ( + initiator is not check_append_token + and initiator is not check_bulk_replace_token + ): child_impl.append( child_state, child_dict, state.obj(), initiator, - passive=PASSIVE_NO_FETCH) + passive=PASSIVE_NO_FETCH, + ) return child def emit_backref_from_collection_remove_event(state, child, initiator): - if child is not None and \ - child is not PASSIVE_NO_RESULT and \ - child is not NEVER_SET: - child_state, child_dict = instance_state(child),\ - instance_dict(child) + if ( + child is not None + and child is not PASSIVE_NO_RESULT + and child is not NEVER_SET + ): + child_state, child_dict = ( + instance_state(child), + instance_dict(child), + ) child_impl = child_state.manager[key].impl # tokens to test for a recursive loop. @@ -1276,47 +1478,64 @@ def backref_listeners(attribute, key, uselist): check_for_dupes_on_remove = uselist and not parent_impl.dynamic else: check_remove_token = child_impl._remove_token - check_replace_token = child_impl._bulk_replace_token \ - if child_impl.collection else None + check_replace_token = ( + child_impl._bulk_replace_token + if child_impl.collection + else None + ) check_for_dupes_on_remove = False - if initiator is not check_remove_token and \ - initiator is not check_replace_token: - - if not check_for_dupes_on_remove or \ - not util.has_dupes( - # when this event is called, the item is usually - # present in the list, except for a pop() operation. - state.dict[parent_impl.key], child): + if ( + initiator is not check_remove_token + and initiator is not check_replace_token + ): + + if not check_for_dupes_on_remove or not util.has_dupes( + # when this event is called, the item is usually + # present in the list, except for a pop() operation. 
+ state.dict[parent_impl.key], + child, + ): child_impl.pop( child_state, child_dict, state.obj(), initiator, - passive=PASSIVE_NO_FETCH) + passive=PASSIVE_NO_FETCH, + ) if uselist: - event.listen(attribute, "append", - emit_backref_from_collection_append_event, - retval=True, raw=True) + event.listen( + attribute, + "append", + emit_backref_from_collection_append_event, + retval=True, + raw=True, + ) else: - event.listen(attribute, "set", - emit_backref_from_scalar_set_event, - retval=True, raw=True) + event.listen( + attribute, + "set", + emit_backref_from_scalar_set_event, + retval=True, + raw=True, + ) # TODO: need coverage in test/orm/ of remove event - event.listen(attribute, "remove", - emit_backref_from_collection_remove_event, - retval=True, raw=True) + event.listen( + attribute, + "remove", + emit_backref_from_collection_remove_event, + retval=True, + raw=True, + ) -_NO_HISTORY = util.symbol('NO_HISTORY') -_NO_STATE_SYMBOLS = frozenset([ - id(PASSIVE_NO_RESULT), - id(NO_VALUE), - id(NEVER_SET)]) -History = util.namedtuple("History", [ - "added", "unchanged", "deleted" -]) +_NO_HISTORY = util.symbol("NO_HISTORY") +_NO_STATE_SYMBOLS = frozenset( + [id(PASSIVE_NO_RESULT), id(NO_VALUE), id(NEVER_SET)] +) + +History = util.namedtuple("History", ["added", "unchanged", "deleted"]) class History(History): @@ -1346,6 +1565,7 @@ class History(History): def __bool__(self): return self != HISTORY_BLANK + __nonzero__ = __bool__ def empty(self): @@ -1354,29 +1574,24 @@ class History(History): """ - return not bool( - (self.added or self.deleted) - or self.unchanged - ) + return not bool((self.added or self.deleted) or self.unchanged) def sum(self): """Return a collection of added + unchanged + deleted.""" - return (self.added or []) +\ - (self.unchanged or []) +\ - (self.deleted or []) + return ( + (self.added or []) + (self.unchanged or []) + (self.deleted or []) + ) def non_deleted(self): """Return a collection of added + unchanged.""" - return (self.added or []) +\ - (self.unchanged or []) + return (self.added or []) + (self.unchanged or []) def non_added(self): """Return a collection of unchanged + deleted.""" - return (self.unchanged or []) +\ - (self.deleted or []) + return (self.unchanged or []) + (self.deleted or []) def has_changes(self): """Return True if this :class:`.History` has changes.""" @@ -1385,15 +1600,18 @@ class History(History): def as_state(self): return History( - [(c is not None) - and instance_state(c) or None - for c in self.added], - [(c is not None) - and instance_state(c) or None - for c in self.unchanged], - [(c is not None) - and instance_state(c) or None - for c in self.deleted], + [ + (c is not None) and instance_state(c) or None + for c in self.added + ], + [ + (c is not None) and instance_state(c) or None + for c in self.unchanged + ], + [ + (c is not None) and instance_state(c) or None + for c in self.deleted + ], ) @classmethod @@ -1464,21 +1682,21 @@ class History(History): if current is NO_VALUE or current is NEVER_SET: return cls((), (), ()) - current = getattr(current, '_sa_adapter') + current = getattr(current, "_sa_adapter") if original in (NO_VALUE, NEVER_SET): return cls(list(current), (), ()) elif original is _NO_HISTORY: return cls((), list(current), ()) else: - current_states = [((c is not None) and instance_state(c) - or None, c) - for c in current - ] - original_states = [((c is not None) and instance_state(c) - or None, c) - for c in original - ] + current_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in current + ] 
+ original_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in original + ] current_set = dict(current_states) original_set = dict(original_states) @@ -1486,9 +1704,10 @@ class History(History): return cls( [o for s, o in current_states if s not in original_set], [o for s, o in current_states if s in original_set], - [o for s, o in original_states if s not in current_set] + [o for s, o in original_states if s not in current_set], ) + HISTORY_BLANK = History(None, None, None) @@ -1509,12 +1728,16 @@ def get_history(obj, key, passive=PASSIVE_OFF): """ if passive is True: - util.warn_deprecated("Passing True for 'passive' is deprecated. " - "Use attributes.PASSIVE_NO_INITIALIZE") + util.warn_deprecated( + "Passing True for 'passive' is deprecated. " + "Use attributes.PASSIVE_NO_INITIALIZE" + ) passive = PASSIVE_NO_INITIALIZE elif passive is False: - util.warn_deprecated("Passing False for 'passive' is " - "deprecated. Use attributes.PASSIVE_OFF") + util.warn_deprecated( + "Passing False for 'passive' is " + "deprecated. Use attributes.PASSIVE_OFF" + ) passive = PASSIVE_OFF return get_state_history(instance_state(obj), key, passive) @@ -1532,38 +1755,46 @@ def has_parent(cls, obj, key, optimistic=False): def register_attribute(class_, key, **kw): - comparator = kw.pop('comparator', None) - parententity = kw.pop('parententity', None) - doc = kw.pop('doc', None) - desc = register_descriptor(class_, key, - comparator, parententity, doc=doc) + comparator = kw.pop("comparator", None) + parententity = kw.pop("parententity", None) + doc = kw.pop("doc", None) + desc = register_descriptor(class_, key, comparator, parententity, doc=doc) register_attribute_impl(class_, key, **kw) return desc -def register_attribute_impl(class_, key, - uselist=False, callable_=None, - useobject=False, - impl_class=None, backref=None, **kw): +def register_attribute_impl( + class_, + key, + uselist=False, + callable_=None, + useobject=False, + impl_class=None, + backref=None, + **kw +): manager = manager_of_class(class_) if uselist: - factory = kw.pop('typecallable', None) + factory = kw.pop("typecallable", None) typecallable = manager.instrument_collection_class( - key, factory or list) + key, factory or list + ) else: - typecallable = kw.pop('typecallable', None) + typecallable = kw.pop("typecallable", None) dispatch = manager[key].dispatch if impl_class: impl = impl_class(class_, key, typecallable, dispatch, **kw) elif uselist: - impl = CollectionAttributeImpl(class_, key, callable_, dispatch, - typecallable=typecallable, **kw) + impl = CollectionAttributeImpl( + class_, key, callable_, dispatch, typecallable=typecallable, **kw + ) elif useobject: - impl = ScalarObjectAttributeImpl(class_, key, callable_, - dispatch, **kw) + impl = ScalarObjectAttributeImpl( + class_, key, callable_, dispatch, **kw + ) else: impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) @@ -1576,12 +1807,14 @@ def register_attribute_impl(class_, key, return manager[key] -def register_descriptor(class_, key, comparator=None, - parententity=None, doc=None): +def register_descriptor( + class_, key, comparator=None, parententity=None, doc=None +): manager = manager_of_class(class_) - descriptor = InstrumentedAttribute(class_, key, comparator=comparator, - parententity=parententity) + descriptor = InstrumentedAttribute( + class_, key, comparator=comparator, parententity=parententity + ) descriptor.__doc__ = doc diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index deddaa5a4e..abc572d9ad 100644 
--- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -15,66 +15,69 @@ from . import exc import operator PASSIVE_NO_RESULT = util.symbol( - 'PASSIVE_NO_RESULT', + "PASSIVE_NO_RESULT", """Symbol returned by a loader callable or other attribute/history retrieval operation when a value could not be determined, based on loader callable flags. - """ + """, ) ATTR_WAS_SET = util.symbol( - 'ATTR_WAS_SET', + "ATTR_WAS_SET", """Symbol returned by a loader callable to indicate the retrieved value, or values, were assigned to their attributes on the target object. - """ + """, ) ATTR_EMPTY = util.symbol( - 'ATTR_EMPTY', - """Symbol used internally to indicate an attribute had no callable.""" + "ATTR_EMPTY", + """Symbol used internally to indicate an attribute had no callable.""", ) NO_VALUE = util.symbol( - 'NO_VALUE', + "NO_VALUE", """Symbol which may be placed as the 'previous' value of an attribute, indicating no value was loaded for an attribute when it was modified, and flags indicated we were not to load it. - """ + """, ) NEVER_SET = util.symbol( - 'NEVER_SET', + "NEVER_SET", """Symbol which may be placed as the 'previous' value of an attribute indicating that the attribute had not been assigned to previously. - """ + """, ) NO_CHANGE = util.symbol( "NO_CHANGE", """No callables or SQL should be emitted on attribute access and no state should change - """, canonical=0 + """, + canonical=0, ) CALLABLES_OK = util.symbol( "CALLABLES_OK", """Loader callables can be fired off if a value is not present. - """, canonical=1 + """, + canonical=1, ) SQL_OK = util.symbol( "SQL_OK", """Loader callables can emit SQL at least on scalar value attributes.""", - canonical=2 + canonical=2, ) RELATED_OBJECT_OK = util.symbol( "RELATED_OBJECT_OK", """Callables can use SQL to load related objects as well as scalar value attributes. - """, canonical=4 + """, + canonical=4, ) INIT_OK = util.symbol( @@ -82,111 +85,116 @@ INIT_OK = util.symbol( """Attributes should be initialized with a blank value (None or an empty collection) upon get, if no other value can be obtained. - """, canonical=8 + """, + canonical=8, ) NON_PERSISTENT_OK = util.symbol( "NON_PERSISTENT_OK", """Callables can be emitted if the parent is not persistent.""", - canonical=16 + canonical=16, ) LOAD_AGAINST_COMMITTED = util.symbol( "LOAD_AGAINST_COMMITTED", """Callables should use committed values as primary/foreign keys during a load. 
- """, canonical=32 + """, + canonical=32, ) NO_AUTOFLUSH = util.symbol( "NO_AUTOFLUSH", """Loader callables should disable autoflush.""", - canonical=64 + canonical=64, ) NO_RAISE = util.symbol( "NO_RAISE", """Loader callables should not raise any assertions""", - canonical=128 + canonical=128, ) # pre-packaged sets of flags used as inputs PASSIVE_OFF = util.symbol( "PASSIVE_OFF", "Callables can be emitted in all cases.", - canonical=(RELATED_OBJECT_OK | NON_PERSISTENT_OK | - INIT_OK | CALLABLES_OK | SQL_OK) + canonical=( + RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK + ), ) PASSIVE_RETURN_NEVER_SET = util.symbol( "PASSIVE_RETURN_NEVER_SET", """PASSIVE_OFF ^ INIT_OK""", - canonical=PASSIVE_OFF ^ INIT_OK + canonical=PASSIVE_OFF ^ INIT_OK, ) PASSIVE_NO_INITIALIZE = util.symbol( "PASSIVE_NO_INITIALIZE", "PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK", - canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK + canonical=PASSIVE_RETURN_NEVER_SET ^ CALLABLES_OK, ) PASSIVE_NO_FETCH = util.symbol( - "PASSIVE_NO_FETCH", - "PASSIVE_OFF ^ SQL_OK", - canonical=PASSIVE_OFF ^ SQL_OK + "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK ) PASSIVE_NO_FETCH_RELATED = util.symbol( "PASSIVE_NO_FETCH_RELATED", "PASSIVE_OFF ^ RELATED_OBJECT_OK", - canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK + canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK, ) PASSIVE_ONLY_PERSISTENT = util.symbol( "PASSIVE_ONLY_PERSISTENT", "PASSIVE_OFF ^ NON_PERSISTENT_OK", - canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK + canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK, ) -DEFAULT_MANAGER_ATTR = '_sa_class_manager' -DEFAULT_STATE_ATTR = '_sa_instance_state' -_INSTRUMENTOR = ('mapper', 'instrumentor') +DEFAULT_MANAGER_ATTR = "_sa_class_manager" +DEFAULT_STATE_ATTR = "_sa_instance_state" +_INSTRUMENTOR = ("mapper", "instrumentor") -EXT_CONTINUE = util.symbol('EXT_CONTINUE') -EXT_STOP = util.symbol('EXT_STOP') -EXT_SKIP = util.symbol('EXT_SKIP') +EXT_CONTINUE = util.symbol("EXT_CONTINUE") +EXT_STOP = util.symbol("EXT_STOP") +EXT_SKIP = util.symbol("EXT_SKIP") ONETOMANY = util.symbol( - 'ONETOMANY', + "ONETOMANY", """Indicates the one-to-many direction for a :func:`.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """) + """, +) MANYTOONE = util.symbol( - 'MANYTOONE', + "MANYTOONE", """Indicates the many-to-one direction for a :func:`.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """) + """, +) MANYTOMANY = util.symbol( - 'MANYTOMANY', + "MANYTOMANY", """Indicates the many-to-many direction for a :func:`.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """) + """, +) NOT_EXTENSION = util.symbol( - 'NOT_EXTENSION', + "NOT_EXTENSION", """Symbol indicating an :class:`InspectionAttr` that's not part of sqlalchemy.ext. Is assigned to the :attr:`.InspectionAttr.extension_type` attibute. 
- """) + """, +) _never_set = frozenset([NEVER_SET]) @@ -207,6 +215,7 @@ def _generative(*assertions): assertion(self, fn.__name__) fn(self, *args[1:], **kw) return self + return generate @@ -215,9 +224,10 @@ def _generative(*assertions): def manager_of_class(cls): return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) -instance_dict = operator.attrgetter('__dict__') +instance_dict = operator.attrgetter("__dict__") def instance_str(instance): @@ -232,7 +242,7 @@ def state_str(state): if state is None: return "None" else: - return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj())) + return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj())) def state_class_str(state): @@ -243,7 +253,7 @@ def state_class_str(state): if state is None: return "None" else: - return '<%s>' % (state.class_.__name__, ) + return "<%s>" % (state.class_.__name__,) def attribute_str(instance, attribute): @@ -335,15 +345,15 @@ def _is_mapped_class(entity): """ insp = inspection.inspect(entity, False) - return insp is not None and \ - not insp.is_clause_element and \ - ( - insp.is_mapper or insp.is_aliased_class - ) + return ( + insp is not None + and not insp.is_clause_element + and (insp.is_mapper or insp.is_aliased_class) + ) def _attr_as_key(attr): - if hasattr(attr, 'key'): + if hasattr(attr, "key"): return attr.key else: return expression._column_as_key(attr) @@ -351,7 +361,7 @@ def _attr_as_key(attr): def _orm_columns(entity): insp = inspection.inspect(entity, False) - if hasattr(insp, 'selectable') and hasattr(insp.selectable, 'c'): + if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"): return [c for c in insp.selectable.c] else: return [entity] @@ -359,8 +369,7 @@ def _orm_columns(entity): def _is_aliased_class(entity): insp = inspection.inspect(entity, False) - return insp is not None and \ - getattr(insp, "is_aliased_class", False) + return insp is not None and getattr(insp, "is_aliased_class", False) def _entity_descriptor(entity, key): @@ -386,11 +395,11 @@ def _entity_descriptor(entity, key): return getattr(entity, key) except AttributeError: raise sa_exc.InvalidRequestError( - "Entity '%s' has no property '%s'" % - (description, key) + "Entity '%s' has no property '%s'" % (description, key) ) -_state_mapper = util.dottedgetter('manager.mapper') + +_state_mapper = util.dottedgetter("manager.mapper") @inspection._inspects(type) @@ -429,7 +438,8 @@ def class_mapper(class_, configure=True): if mapper is None: if not isinstance(class_, type): raise sa_exc.ArgumentError( - "Class object expected, got '%r'." % (class_, )) + "Class object expected, got '%r'." % (class_,) + ) raise exc.UnmappedClassError(class_) else: return mapper @@ -449,6 +459,7 @@ class InspectionAttr(object): here intact for forwards-compatibility. """ + __slots__ = () is_selectable = False @@ -551,4 +562,5 @@ class _MappedAttribute(object): attributes. """ + __slots__ = () diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 54c29bb5e8..be92917415 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -113,9 +113,13 @@ from . 
import base from sqlalchemy.util.compat import inspect_getargspec -__all__ = ['collection', 'collection_adapter', - 'mapped_collection', 'column_mapped_collection', - 'attribute_mapped_collection'] +__all__ = [ + "collection", + "collection_adapter", + "mapped_collection", + "column_mapped_collection", + "attribute_mapped_collection", +] __instrumentation_mutex = util.threading.Lock() @@ -172,10 +176,12 @@ class _SerializableColumnGetter(object): def __call__(self, value): state = base.instance_state(value) m = base._state_mapper(state) - key = [m._get_state_attr_by_column( - state, state.dict, - m.mapped_table.columns[k]) - for k in self.colkeys] + key = [ + m._get_state_attr_by_column( + state, state.dict, m.mapped_table.columns[k] + ) + for k in self.colkeys + ] if self.composite: return tuple(key) else: @@ -208,16 +214,15 @@ class _SerializableColumnGetterV2(_PlainColumnGetter): return None else: return c.table.key + colkeys = [(c.key, _table_key(c)) for c in cols] return _SerializableColumnGetterV2, (colkeys,) def _cols(self, mapper): cols = [] - metadata = getattr(mapper.local_table, 'metadata', None) + metadata = getattr(mapper.local_table, "metadata", None) for (ckey, tkey) in self.colkeys: - if tkey is None or \ - metadata is None or \ - tkey not in metadata: + if tkey is None or metadata is None or tkey not in metadata: cols.append(mapper.local_table.c[ckey]) else: cols.append(metadata.tables[tkey].c[ckey]) @@ -237,9 +242,10 @@ def column_mapped_collection(mapping_spec): after a session flush. """ - cols = [expression._only_column_elements(q, "mapping_spec") - for q in util.to_list(mapping_spec) - ] + cols = [ + expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec) + ] keyfunc = _PlainColumnGetter(cols) return lambda: MappedCollection(keyfunc) @@ -253,7 +259,7 @@ class _SerializableAttrGetter(object): return self.getter(target) def __reduce__(self): - return _SerializableAttrGetter, (self.name, ) + return _SerializableAttrGetter, (self.name,) def attribute_mapped_collection(attr_name): @@ -311,6 +317,7 @@ class collection(object): def popitem(self): ... """ + # Bundled as a class solely for ease of use: packaging, doc strings, # importability. @@ -355,7 +362,7 @@ class collection(object): promulgation to collection events. """ - fn._sa_instrument_role = 'appender' + fn._sa_instrument_role = "appender" return fn @staticmethod @@ -382,7 +389,7 @@ class collection(object): promulgation to collection events. """ - fn._sa_instrument_role = 'remover' + fn._sa_instrument_role = "remover" return fn @staticmethod @@ -396,7 +403,7 @@ class collection(object): def __iter__(self): ... """ - fn._sa_instrument_role = 'iterator' + fn._sa_instrument_role = "iterator" return fn @staticmethod @@ -435,7 +442,7 @@ class collection(object): and :meth:`.AttributeEvents.dispose_collection` handlers. """ - fn._sa_instrument_role = 'linker' + fn._sa_instrument_role = "linker" return fn link = linker @@ -472,7 +479,7 @@ class collection(object): validation on the values about to be assigned. """ - fn._sa_instrument_role = 'converter' + fn._sa_instrument_role = "converter" return fn @staticmethod @@ -491,9 +498,11 @@ class collection(object): def do_stuff(self, thing, entity=None): ... """ + def decorator(fn): - fn._sa_instrument_before = ('fire_append_event', arg) + fn._sa_instrument_before = ("fire_append_event", arg) return fn + return decorator @staticmethod @@ -511,10 +520,12 @@ class collection(object): def __setitem__(self, index, item): ... 
""" + def decorator(fn): - fn._sa_instrument_before = ('fire_append_event', arg) - fn._sa_instrument_after = 'fire_remove_event' + fn._sa_instrument_before = ("fire_append_event", arg) + fn._sa_instrument_after = "fire_remove_event" return fn + return decorator @staticmethod @@ -533,9 +544,11 @@ class collection(object): collection.removes_return. """ + def decorator(fn): - fn._sa_instrument_before = ('fire_remove_event', arg) + fn._sa_instrument_before = ("fire_remove_event", arg) return fn + return decorator @staticmethod @@ -553,13 +566,15 @@ class collection(object): collection.remove. """ + def decorator(fn): - fn._sa_instrument_after = 'fire_remove_event' + fn._sa_instrument_after = "fire_remove_event" return fn + return decorator -collection_adapter = operator.attrgetter('_sa_adapter') +collection_adapter = operator.attrgetter("_sa_adapter") """Fetch the :class:`.CollectionAdapter` for a collection.""" @@ -577,7 +592,13 @@ class CollectionAdapter(object): """ __slots__ = ( - 'attr', '_key', '_data', 'owner_state', '_converter', 'invalidated') + "attr", + "_key", + "_data", + "owner_state", + "_converter", + "invalidated", + ) def __init__(self, attr, owner_state, data): self.attr = attr @@ -676,9 +697,8 @@ class CollectionAdapter(object): if self.invalidated: self._warn_invalidated() return self.attr.fire_append_event( - self.owner_state, - self.owner_state.dict, - item, initiator) + self.owner_state, self.owner_state.dict, item, initiator + ) else: return item @@ -694,9 +714,8 @@ class CollectionAdapter(object): if self.invalidated: self._warn_invalidated() self.attr.fire_remove_event( - self.owner_state, - self.owner_state.dict, - item, initiator) + self.owner_state, self.owner_state.dict, item, initiator + ) def fire_pre_remove_event(self, initiator=None): """Notify that an entity is about to be removed from the collection. @@ -708,25 +727,26 @@ class CollectionAdapter(object): if self.invalidated: self._warn_invalidated() self.attr.fire_pre_remove_event( - self.owner_state, - self.owner_state.dict, - initiator=initiator) + self.owner_state, self.owner_state.dict, initiator=initiator + ) def __getstate__(self): - return {'key': self._key, - 'owner_state': self.owner_state, - 'owner_cls': self.owner_state.class_, - 'data': self.data, - 'invalidated': self.invalidated} + return { + "key": self._key, + "owner_state": self.owner_state, + "owner_cls": self.owner_state.class_, + "data": self.data, + "invalidated": self.invalidated, + } def __setstate__(self, d): - self._key = d['key'] - self.owner_state = d['owner_state'] - self._data = weakref.ref(d['data']) - self._converter = d['data']._sa_converter - d['data']._sa_adapter = self - self.invalidated = d['invalidated'] - self.attr = getattr(d['owner_cls'], self._key).impl + self._key = d["key"] + self.owner_state = d["owner_state"] + self._data = weakref.ref(d["data"]) + self._converter = d["data"]._sa_converter + d["data"]._sa_adapter = self + self.invalidated = d["invalidated"] + self.attr = getattr(d["owner_cls"], self._key).impl def bulk_replace(values, existing_adapter, new_adapter, initiator=None): @@ -796,7 +816,7 @@ def prepare_instrumentation(factory): # Instrument the class if needed. 
     if __instrumentation_mutex.acquire():
         try:
-            if getattr(cls, '_sa_instrumented', None) != id(cls):
+            if getattr(cls, "_sa_instrumented", None) != id(cls):
                 _instrument_class(cls)
         finally:
             __instrumentation_mutex.release()
@@ -829,10 +849,11 @@ def _instrument_class(cls):
     # In the normal call flow, a request for any of the 3 basic collection
     # types is transformed into one of our trivial subclasses
     # (e.g. InstrumentedList).  Catch anything else that sneaks in here...
-    if cls.__module__ == '__builtin__':
+    if cls.__module__ == "__builtin__":
         raise sa_exc.ArgumentError(
             "Can not instrument a built-in type. Use a "
-            "subclass, even a trivial one.")
+            "subclass, even a trivial one."
+        )

     roles, methods = _locate_roles_and_methods(cls)

@@ -858,25 +879,30 @@ def _locate_roles_and_methods(cls):
             continue

         # note role declarations
-        if hasattr(method, '_sa_instrument_role'):
+        if hasattr(method, "_sa_instrument_role"):
             role = method._sa_instrument_role
-            assert role in ('appender', 'remover', 'iterator',
-                            'linker', 'converter')
+            assert role in (
+                "appender",
+                "remover",
+                "iterator",
+                "linker",
+                "converter",
+            )
             roles.setdefault(role, name)

         # transfer instrumentation requests from decorated function
         # to the combined queue
         before, after = None, None
-        if hasattr(method, '_sa_instrument_before'):
+        if hasattr(method, "_sa_instrument_before"):
             op, argument = method._sa_instrument_before
-            assert op in ('fire_append_event', 'fire_remove_event')
+            assert op in ("fire_append_event", "fire_remove_event")
             before = op, argument
-        if hasattr(method, '_sa_instrument_after'):
+        if hasattr(method, "_sa_instrument_after"):
             op = method._sa_instrument_after
-            assert op in ('fire_append_event', 'fire_remove_event')
+            assert op in ("fire_append_event", "fire_remove_event")
             after = op
         if before:
-            methods[name] = before + (after, )
+            methods[name] = before + (after,)
         elif after:
             methods[name] = None, None, after
     return roles, methods
@@ -898,8 +924,11 @@ def _setup_canned_roles(cls, roles, methods):
     # apply ABC auto-decoration to methods that need it
     for method, decorator in decorators.items():
         fn = getattr(cls, method, None)
-        if (fn and method not in methods and
-                not hasattr(fn, '_sa_instrumented')):
+        if (
+            fn
+            and method not in methods
+            and not hasattr(fn, "_sa_instrumented")
+        ):
             setattr(cls, method, decorator(fn))


@@ -908,26 +937,31 @@ def _assert_required_roles(cls, roles, methods):
     needed

     """
-    if 'appender' not in roles or not hasattr(cls, roles['appender']):
+    if "appender" not in roles or not hasattr(cls, roles["appender"]):
         raise sa_exc.ArgumentError(
             "Type %s must elect an appender method to be "
-            "a collection class" % cls.__name__)
-    elif (roles['appender'] not in methods and
-          not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')):
-        methods[roles['appender']] = ('fire_append_event', 1, None)
-
-    if 'remover' not in roles or not hasattr(cls, roles['remover']):
+            "a collection class" % cls.__name__
+        )
+    elif roles["appender"] not in methods and not hasattr(
+        getattr(cls, roles["appender"]), "_sa_instrumented"
+    ):
+        methods[roles["appender"]] = ("fire_append_event", 1, None)
+
+    if "remover" not in roles or not hasattr(cls, roles["remover"]):
         raise sa_exc.ArgumentError(
             "Type %s must elect a remover method to be "
-            "a collection class" % cls.__name__)
-    elif (roles['remover'] not in methods and
-          not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')):
-        methods[roles['remover']] = ('fire_remove_event', 1, None)
-
-    if 'iterator' not in roles or not hasattr(cls, roles['iterator']):
+            "a collection class" % cls.__name__
+        )
+    elif roles["remover"] not in methods and not hasattr(
+        getattr(cls, roles["remover"]), "_sa_instrumented"
+    ):
+        methods[roles["remover"]] = ("fire_remove_event", 1, None)
+
+    if "iterator" not in roles or not hasattr(cls, roles["iterator"]):
         raise sa_exc.ArgumentError(
             "Type %s must elect an iterator method to be "
-            "a collection class" % cls.__name__)
+            "a collection class" % cls.__name__
+        )


 def _set_collection_attributes(cls, roles, methods):
@@ -936,16 +970,20 @@ def _set_collection_attributes(cls, roles, methods):

     """
     for method_name, (before, argument, after) in methods.items():
-        setattr(cls, method_name,
-                _instrument_membership_mutator(getattr(cls, method_name),
-                                               before, argument, after))
+        setattr(
+            cls,
+            method_name,
+            _instrument_membership_mutator(
+                getattr(cls, method_name), before, argument, after
+            ),
+        )
     # intern the role map
     for role, method_name in roles.items():
-        setattr(cls, '_sa_%s' % role, getattr(cls, method_name))
+        setattr(cls, "_sa_%s" % role, getattr(cls, method_name))

     cls._sa_adapter = None

-    if not hasattr(cls, '_sa_converter'):
+    if not hasattr(cls, "_sa_converter"):
         cls._sa_converter = None
     cls._sa_instrumented = id(cls)

@@ -972,7 +1010,8 @@ def _instrument_membership_mutator(method, before, argument, after):
         if pos_arg is None:
             if named_arg not in kw:
                 raise sa_exc.ArgumentError(
-                    "Missing argument %s" % argument)
+                    "Missing argument %s" % argument
+                )
             value = kw[named_arg]
         else:
             if len(args) > pos_arg:
@@ -981,9 +1020,10 @@ def _instrument_membership_mutator(method, before, argument, after):
                 value = kw[named_arg]
             else:
                 raise sa_exc.ArgumentError(
-                    "Missing argument %s" % argument)
+                    "Missing argument %s" % argument
+                )

-        initiator = kw.pop('_sa_initiator', None)
+        initiator = kw.pop("_sa_initiator", None)
         if initiator is False:
             executor = None
         else:
@@ -1055,6 +1095,7 @@ def _list_decorators():
         def append(self, item, _sa_initiator=None):
             item = __set(self, item, _sa_initiator)
             fn(self, item)
+
         _tidy(append)
         return append

@@ -1063,6 +1104,7 @@ def _list_decorators():
             __del(self, value, _sa_initiator)
             # testlib.pragma exempt:__eq__
             fn(self, value)
+
         _tidy(remove)
         return remove

@@ -1070,6 +1112,7 @@ def _list_decorators():
         def insert(self, index, value):
             value = __set(self, value)
             fn(self, index, value)
+
         _tidy(insert)
         return insert

@@ -1106,10 +1149,12 @@ def _list_decorators():
                     if len(value) != len(rng):
                         raise ValueError(
                             "attempt to assign sequence of size %s to "
-                            "extended slice of size %s" % (len(value),
-                                                           len(rng)))
+                            "extended slice of size %s"
+                            % (len(value), len(rng))
+                        )
                     for i, item in zip(rng, value):
                         self.__setitem__(i, item)
+
         _tidy(__setitem__)
         return __setitem__

@@ -1126,16 +1171,19 @@ def _list_decorators():
                 for item in self[index]:
                     __del(self, item)
                 fn(self, index)
+
         _tidy(__delitem__)
         return __delitem__

     if util.py2k:
+
         def __setslice__(fn):
             def __setslice__(self, start, end, values):
                 for value in self[start:end]:
                     __del(self, value)
                 values = [__set(self, value) for value in values]
                 fn(self, start, end, values)
+
             _tidy(__setslice__)
             return __setslice__

@@ -1144,6 +1192,7 @@ def _list_decorators():
                 for value in self[start:end]:
                     __del(self, value)
                 fn(self, start, end)
+
             _tidy(__delslice__)
             return __delslice__

@@ -1151,6 +1200,7 @@ def _list_decorators():
         def extend(self, iterable):
             for value in iterable:
                 self.append(value)
+
         _tidy(extend)
         return extend

@@ -1161,6 +1211,7 @@ def _list_decorators():
             for value in iterable:
                 self.append(value)
             return self
+
         _tidy(__iadd__)
return __iadd__ @@ -1170,15 +1221,18 @@ def _list_decorators(): item = fn(self, index) __del(self, item) return item + _tidy(pop) return pop if not util.py2k: + def clear(fn): def clear(self, index=-1): for item in self: __del(self, item) fn(self) + _tidy(clear) return clear @@ -1188,7 +1242,7 @@ def _list_decorators(): # desired. hard to imagine a use case for __imul__, though. l = locals().copy() - l.pop('_tidy') + l.pop("_tidy") return l @@ -1199,7 +1253,7 @@ def _dict_decorators(): fn._sa_instrumented = True fn.__doc__ = getattr(dict, fn.__name__).__doc__ - Unspecified = util.symbol('Unspecified') + Unspecified = util.symbol("Unspecified") def __setitem__(fn): def __setitem__(self, key, value, _sa_initiator=None): @@ -1207,6 +1261,7 @@ def _dict_decorators(): __del(self, self[key], _sa_initiator) value = __set(self, value, _sa_initiator) fn(self, key, value) + _tidy(__setitem__) return __setitem__ @@ -1215,6 +1270,7 @@ def _dict_decorators(): if key in self: __del(self, self[key], _sa_initiator) fn(self, key) + _tidy(__delitem__) return __delitem__ @@ -1223,6 +1279,7 @@ def _dict_decorators(): for key in self: __del(self, self[key]) fn(self) + _tidy(clear) return clear @@ -1237,6 +1294,7 @@ def _dict_decorators(): if _to_del: __del(self, item) return item + _tidy(pop) return pop @@ -1246,6 +1304,7 @@ def _dict_decorators(): item = fn(self) __del(self, item[1]) return item + _tidy(popitem) return popitem @@ -1256,16 +1315,16 @@ def _dict_decorators(): return default else: return self.__getitem__(key) + _tidy(setdefault) return setdefault def update(fn): def update(self, __other=Unspecified, **kw): if __other is not Unspecified: - if hasattr(__other, 'keys'): + if hasattr(__other, "keys"): for key in list(__other): - if (key not in self or - self[key] is not __other[key]): + if key not in self or self[key] is not __other[key]: self[key] = __other[key] else: for key, value in __other: @@ -1274,14 +1333,16 @@ def _dict_decorators(): for key in kw: if key not in self or self[key] is not kw[key]: self[key] = kw[key] + _tidy(update) return update l = locals().copy() - l.pop('_tidy') - l.pop('Unspecified') + l.pop("_tidy") + l.pop("Unspecified") return l + _set_binop_bases = (set, frozenset) @@ -1293,8 +1354,10 @@ def _set_binops_check_strict(self, obj): def _set_binops_check_loose(self, obj): """Allow anything set-like to participate in set binops.""" - return (isinstance(obj, _set_binop_bases + (self.__class__,)) or - util.duck_type_collection(obj) == set) + return ( + isinstance(obj, _set_binop_bases + (self.__class__,)) + or util.duck_type_collection(obj) == set + ) def _set_decorators(): @@ -1304,7 +1367,7 @@ def _set_decorators(): fn._sa_instrumented = True fn.__doc__ = getattr(set, fn.__name__).__doc__ - Unspecified = util.symbol('Unspecified') + Unspecified = util.symbol("Unspecified") def add(fn): def add(self, value, _sa_initiator=None): @@ -1312,6 +1375,7 @@ def _set_decorators(): value = __set(self, value, _sa_initiator) # testlib.pragma exempt:__hash__ fn(self, value) + _tidy(add) return add @@ -1322,6 +1386,7 @@ def _set_decorators(): __del(self, value, _sa_initiator) # testlib.pragma exempt:__hash__ fn(self, value) + _tidy(discard) return discard @@ -1332,6 +1397,7 @@ def _set_decorators(): __del(self, value, _sa_initiator) # testlib.pragma exempt:__hash__ fn(self, value) + _tidy(remove) return remove @@ -1343,6 +1409,7 @@ def _set_decorators(): # that will be popped before pop is called. 
__del(self, item) return item + _tidy(pop) return pop @@ -1350,6 +1417,7 @@ def _set_decorators(): def clear(self): for item in list(self): self.remove(item) + _tidy(clear) return clear @@ -1357,6 +1425,7 @@ def _set_decorators(): def update(self, value): for item in value: self.add(item) + _tidy(update) return update @@ -1367,6 +1436,7 @@ def _set_decorators(): for item in value: self.add(item) return self + _tidy(__ior__) return __ior__ @@ -1374,6 +1444,7 @@ def _set_decorators(): def difference_update(self, value): for item in value: self.discard(item) + _tidy(difference_update) return difference_update @@ -1384,6 +1455,7 @@ def _set_decorators(): for item in value: self.discard(item) return self + _tidy(__isub__) return __isub__ @@ -1396,6 +1468,7 @@ def _set_decorators(): self.remove(item) for item in add: self.add(item) + _tidy(intersection_update) return intersection_update @@ -1411,6 +1484,7 @@ def _set_decorators(): for item in add: self.add(item) return self + _tidy(__iand__) return __iand__ @@ -1423,6 +1497,7 @@ def _set_decorators(): self.remove(item) for item in add: self.add(item) + _tidy(symmetric_difference_update) return symmetric_difference_update @@ -1438,12 +1513,13 @@ def _set_decorators(): for item in add: self.add(item) return self + _tidy(__ixor__) return __ixor__ l = locals().copy() - l.pop('_tidy') - l.pop('Unspecified') + l.pop("_tidy") + l.pop("Unspecified") return l @@ -1467,18 +1543,17 @@ __canned_instrumentation = { __interfaces = { list: ( - {'appender': 'append', 'remover': 'remove', - 'iterator': '__iter__'}, _list_decorators() + {"appender": "append", "remover": "remove", "iterator": "__iter__"}, + _list_decorators(), + ), + set: ( + {"appender": "add", "remover": "remove", "iterator": "__iter__"}, + _set_decorators(), ), - - set: ({'appender': 'add', - 'remover': 'remove', - 'iterator': '__iter__'}, _set_decorators() - ), - # decorators are required for dicts and object collections. - dict: ({'iterator': 'values'}, _dict_decorators()) if util.py3k - else ({'iterator': 'itervalues'}, _dict_decorators()), + dict: ({"iterator": "values"}, _dict_decorators()) + if util.py3k + else ({"iterator": "itervalues"}, _dict_decorators()), } @@ -1529,10 +1604,11 @@ class MappedCollection(dict): "Can not remove '%s': collection holds '%s' for key '%s'. " "Possible cause: is the MappedCollection key function " "based on mutable properties or properties that only obtain " - "values after flush?" % - (value, self[key], key)) + "values after flush?" % (value, self[key], key) + ) self.__delitem__(key, _sa_initiator) + # ensure instrumentation is associated with # these built-in classes; if a user-defined class # subclasses these and uses @internally_instrumented, diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 960b9e5d59..cba4d2141a 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -10,8 +10,7 @@ """ from .. import sql, util, exc as sa_exc -from . import attributes, exc, sync, unitofwork, \ - util as mapperutil +from . import attributes, exc, sync, unitofwork, util as mapperutil from .interfaces import ONETOMANY, MANYTOONE, MANYTOMANY @@ -41,8 +40,8 @@ class DependencyProcessor(object): raise sa_exc.ArgumentError( "Can't build a DependencyProcessor for relationship %s. 
" "No target attributes to populate between parent and " - "child are present" % - self.prop) + "child are present" % self.prop + ) @classmethod def from_relationship(cls, prop): @@ -70,31 +69,28 @@ class DependencyProcessor(object): before_delete = unitofwork.ProcessAll(uow, self, True, True) parent_saves = unitofwork.SaveUpdateAll( - uow, - self.parent.primary_base_mapper + uow, self.parent.primary_base_mapper ) child_saves = unitofwork.SaveUpdateAll( - uow, - self.mapper.primary_base_mapper + uow, self.mapper.primary_base_mapper ) parent_deletes = unitofwork.DeleteAll( - uow, - self.parent.primary_base_mapper + uow, self.parent.primary_base_mapper ) child_deletes = unitofwork.DeleteAll( - uow, - self.mapper.primary_base_mapper + uow, self.mapper.primary_base_mapper ) - self.per_property_dependencies(uow, - parent_saves, - child_saves, - parent_deletes, - child_deletes, - after_save, - before_delete - ) + self.per_property_dependencies( + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ) def per_state_flush_actions(self, uow, states, isdelete): """establish actions and dependencies related to a flush. @@ -130,9 +126,7 @@ class DependencyProcessor(object): # child side is not part of the cycle, so we will link per-state # actions to the aggregate "saves", "deletes" actions - child_actions = [ - (child_saves, False), (child_deletes, True) - ] + child_actions = [(child_saves, False), (child_deletes, True)] child_in_cycles = False else: child_in_cycles = True @@ -140,15 +134,13 @@ class DependencyProcessor(object): # check if the "parent" side is part of the cycle if not isdelete: parent_saves = unitofwork.SaveUpdateAll( - uow, - self.parent.base_mapper) + uow, self.parent.base_mapper + ) parent_deletes = before_delete = None if parent_saves in uow.cycles: parent_in_cycles = True else: - parent_deletes = unitofwork.DeleteAll( - uow, - self.parent.base_mapper) + parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper) parent_saves = after_save = None if parent_deletes in uow.cycles: parent_in_cycles = True @@ -160,17 +152,18 @@ class DependencyProcessor(object): # by a preprocessor on this state/attribute. In the # case of deletes we may try to load missing items here as well. sum_ = state.manager[self.key].impl.get_all_pending( - state, state.dict, + state, + state.dict, self._passive_delete_flag if isdelete - else attributes.PASSIVE_NO_INITIALIZE) + else attributes.PASSIVE_NO_INITIALIZE, + ) if not sum_: continue if isdelete: - before_delete = unitofwork.ProcessState(uow, - self, True, state) + before_delete = unitofwork.ProcessState(uow, self, True, state) if parent_in_cycles: parent_deletes = unitofwork.DeleteState(uow, state) else: @@ -188,21 +181,28 @@ class DependencyProcessor(object): if deleted: child_action = ( unitofwork.DeleteState(uow, child_state), - True) + True, + ) else: child_action = ( unitofwork.SaveUpdateState(uow, child_state), - False) + False, + ) child_actions.append(child_action) # establish dependencies between our possibly per-state # parent action and our possibly per-state child action. 
for child_action, childisdelete in child_actions: - self.per_state_dependencies(uow, parent_saves, - parent_deletes, - child_action, - after_save, before_delete, - isdelete, childisdelete) + self.per_state_dependencies( + uow, + parent_saves, + parent_deletes, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ) def presort_deletes(self, uowcommit, states): return False @@ -228,76 +228,74 @@ class DependencyProcessor(object): # TODO: add a high speed method # to InstanceState which returns: attribute # has a non-None value, or had one - history = uowcommit.get_attribute_history( - s, - self.key, - passive) + history = uowcommit.get_attribute_history(s, self.key, passive) if history and not history.empty(): return True else: - return states and \ - not self.prop._is_self_referential and \ - self.mapper in uowcommit.mappers + return ( + states + and not self.prop._is_self_referential + and self.mapper in uowcommit.mappers + ) def _verify_canload(self, state): if self.prop.uselist and state is None: raise exc.FlushError( "Can't flush None value found in " - "collection %s" % (self.prop, )) - elif state is not None and \ - not self.mapper._canload( - state, allow_subtypes=not self.enable_typechecks): + "collection %s" % (self.prop,) + ) + elif state is not None and not self.mapper._canload( + state, allow_subtypes=not self.enable_typechecks + ): if self.mapper._canload(state, allow_subtypes=True): - raise exc.FlushError('Attempting to flush an item of type ' - '%(x)s as a member of collection ' - '"%(y)s". Expected an object of type ' - '%(z)s or a polymorphic subclass of ' - 'this type. If %(x)s is a subclass of ' - '%(z)s, configure mapper "%(zm)s" to ' - 'load this subtype polymorphically, or ' - 'set enable_typechecks=False to allow ' - 'any subtype to be accepted for flush. ' - % { - 'x': state.class_, - 'y': self.prop, - 'z': self.mapper.class_, - 'zm': self.mapper, - }) + raise exc.FlushError( + "Attempting to flush an item of type " + "%(x)s as a member of collection " + '"%(y)s". Expected an object of type ' + "%(z)s or a polymorphic subclass of " + "this type. If %(x)s is a subclass of " + '%(z)s, configure mapper "%(zm)s" to ' + "load this subtype polymorphically, or " + "set enable_typechecks=False to allow " + "any subtype to be accepted for flush. " + % { + "x": state.class_, + "y": self.prop, + "z": self.mapper.class_, + "zm": self.mapper, + } + ) else: raise exc.FlushError( - 'Attempting to flush an item of type ' - '%(x)s as a member of collection ' + "Attempting to flush an item of type " + "%(x)s as a member of collection " '"%(y)s". Expected an object of type ' - '%(z)s or a polymorphic subclass of ' - 'this type.' % { - 'x': state.class_, - 'y': self.prop, - 'z': self.mapper.class_, - }) - - def _synchronize(self, state, child, associationrow, - clearkeys, uowcommit): + "%(z)s or a polymorphic subclass of " + "this type." 
+ % { + "x": state.class_, + "y": self.prop, + "z": self.mapper.class_, + } + ) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): raise NotImplementedError() def _get_reversed_processed_set(self, uow): if not self.prop._reverse_property: return None - process_key = tuple(sorted( - [self.key] + - [p.key for p in self.prop._reverse_property] - )) - return uow.memo( - ('reverse_key', process_key), - set + process_key = tuple( + sorted([self.key] + [p.key for p in self.prop._reverse_property]) ) + return uow.memo(("reverse_key", process_key), set) def _post_update(self, state, uowcommit, related, is_m2o_delete=False): for x in related: if not is_m2o_delete or x is not None: uowcommit.register_post_update( - state, - [r for l, r in self.prop.synchronize_pairs] + state, [r for l, r in self.prop.synchronize_pairs] ) break @@ -309,114 +307,126 @@ class DependencyProcessor(object): class OneToManyDP(DependencyProcessor): - - def per_property_dependencies(self, uow, parent_saves, - child_saves, - parent_deletes, - child_deletes, - after_save, - before_delete, - ): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): if self.post_update: child_post_updates = unitofwork.PostUpdateAll( - uow, - self.mapper.primary_base_mapper, - False) + uow, self.mapper.primary_base_mapper, False + ) child_pre_updates = unitofwork.PostUpdateAll( - uow, - self.mapper.primary_base_mapper, - True) - - uow.dependencies.update([ - (child_saves, after_save), - (parent_saves, after_save), - (after_save, child_post_updates), - - (before_delete, child_pre_updates), - (child_pre_updates, parent_deletes), - (child_pre_updates, child_deletes), - - ]) + uow, self.mapper.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, child_post_updates), + (before_delete, child_pre_updates), + (child_pre_updates, parent_deletes), + (child_pre_updates, child_deletes), + ] + ) else: - uow.dependencies.update([ - (parent_saves, after_save), - (after_save, child_saves), - (after_save, child_deletes), - - (child_saves, parent_deletes), - (child_deletes, parent_deletes), - - (before_delete, child_saves), - (before_delete, child_deletes), - ]) - - def per_state_dependencies(self, uow, - save_parent, - delete_parent, - child_action, - after_save, before_delete, - isdelete, childisdelete): + uow.dependencies.update( + [ + (parent_saves, after_save), + (after_save, child_saves), + (after_save, child_deletes), + (child_saves, parent_deletes), + (child_deletes, parent_deletes), + (before_delete, child_saves), + (before_delete, child_deletes), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): if self.post_update: child_post_updates = unitofwork.PostUpdateAll( - uow, - self.mapper.primary_base_mapper, - False) + uow, self.mapper.primary_base_mapper, False + ) child_pre_updates = unitofwork.PostUpdateAll( - uow, - self.mapper.primary_base_mapper, - True) + uow, self.mapper.primary_base_mapper, True + ) # TODO: this whole block is not covered # by any tests if not isdelete: if childisdelete: - uow.dependencies.update([ - (child_action, after_save), - (after_save, child_post_updates), - ]) + uow.dependencies.update( + [ + (child_action, after_save), + (after_save, child_post_updates), + ] + ) else: - uow.dependencies.update([ - 
(save_parent, after_save), - (child_action, after_save), - (after_save, child_post_updates), - ]) + uow.dependencies.update( + [ + (save_parent, after_save), + (child_action, after_save), + (after_save, child_post_updates), + ] + ) else: if childisdelete: - uow.dependencies.update([ - (before_delete, child_pre_updates), - (child_pre_updates, delete_parent), - ]) + uow.dependencies.update( + [ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ] + ) else: - uow.dependencies.update([ - (before_delete, child_pre_updates), - (child_pre_updates, delete_parent), - ]) + uow.dependencies.update( + [ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ] + ) elif not isdelete: - uow.dependencies.update([ - (save_parent, after_save), - (after_save, child_action), - (save_parent, child_action) - ]) + uow.dependencies.update( + [ + (save_parent, after_save), + (after_save, child_action), + (save_parent, child_action), + ] + ) else: - uow.dependencies.update([ - (before_delete, child_action), - (child_action, delete_parent) - ]) + uow.dependencies.update( + [(before_delete, child_action), (child_action, delete_parent)] + ) def presort_deletes(self, uowcommit, states): # head object is being deleted, and we manage its list of # child objects the child objects have to have their # foreign key to the parent set to NULL - should_null_fks = not self.cascade.delete and \ - not self.passive_deletes == 'all' + should_null_fks = ( + not self.cascade.delete and not self.passive_deletes == "all" + ) for state in states: history = uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) if history: for child in history.deleted: if child is not None and self.hasparent(child) is False: @@ -429,13 +439,16 @@ class OneToManyDP(DependencyProcessor): for child in history.unchanged: if child is not None: uowcommit.register_object( - child, operation="delete", prop=self.prop) + child, operation="delete", prop=self.prop + ) def presort_saves(self, uowcommit, states): - children_added = uowcommit.memo(('children_added', self), set) + children_added = uowcommit.memo(("children_added", self), set) - should_null_fks = not self.cascade.delete_orphan and \ - not self.passive_deletes == 'all' + should_null_fks = ( + not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ) for state in states: pks_changed = self._pks_changed(uowcommit, state) @@ -445,34 +458,39 @@ class OneToManyDP(DependencyProcessor): else: passive = attributes.PASSIVE_OFF - history = uowcommit.get_attribute_history( - state, - self.key, - passive) + history = uowcommit.get_attribute_history(state, self.key, passive) if history: for child in history.added: if child is not None: - uowcommit.register_object(child, cancel_delete=True, - operation="add", - prop=self.prop) + uowcommit.register_object( + child, + cancel_delete=True, + operation="add", + prop=self.prop, + ) children_added.update(history.added) for child in history.deleted: if not self.cascade.delete_orphan: if should_null_fks: - uowcommit.register_object(child, isdelete=False, - operation='delete', - prop=self.prop) + uowcommit.register_object( + child, + isdelete=False, + operation="delete", + prop=self.prop, + ) elif self.hasparent(child) is False: uowcommit.register_object( - child, isdelete=True, - operation="delete", prop=self.prop) + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) for c, m, st_, dct_ in self.mapper.cascade_iterator( - 
'delete', child): - uowcommit.register_object( - st_, - isdelete=True) + "delete", child + ): + uowcommit.register_object(st_, isdelete=True) if pks_changed: if history: @@ -483,7 +501,8 @@ class OneToManyDP(DependencyProcessor): False, self.passive_updates, operation="pk change", - prop=self.prop) + prop=self.prop, + ) def process_deletes(self, uowcommit, states): # head object is being deleted, and we manage its list of @@ -492,39 +511,37 @@ class OneToManyDP(DependencyProcessor): # safely for any cascade but is unnecessary if delete cascade # is on. - if self.post_update or not self.passive_deletes == 'all': - children_added = uowcommit.memo(('children_added', self), set) + if self.post_update or not self.passive_deletes == "all": + children_added = uowcommit.memo(("children_added", self), set) for state in states: history = uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) if history: for child in history.deleted: - if child is not None and \ - self.hasparent(child) is False: + if ( + child is not None + and self.hasparent(child) is False + ): self._synchronize( - state, - child, - None, True, - uowcommit, False) + state, child, None, True, uowcommit, False + ) if self.post_update and child: self._post_update(child, uowcommit, [state]) if self.post_update or not self.cascade.delete: - for child in set(history.unchanged).\ - difference(children_added): + for child in set(history.unchanged).difference( + children_added + ): if child is not None: self._synchronize( - state, - child, - None, True, - uowcommit, False) + state, child, None, True, uowcommit, False + ) if self.post_update and child: - self._post_update(child, - uowcommit, - [state]) + self._post_update( + child, uowcommit, [state] + ) # technically, we can even remove each child from the # collection here too. but this would be a somewhat @@ -532,54 +549,66 @@ class OneToManyDP(DependencyProcessor): # if the old parent wasn't deleted but child was moved. 
def process_saves(self, uowcommit, states): - should_null_fks = not self.cascade.delete_orphan and \ - not self.passive_deletes == 'all' + should_null_fks = ( + not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ) for state in states: history = uowcommit.get_attribute_history( - state, - self.key, - attributes.PASSIVE_NO_INITIALIZE) + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) if history: for child in history.added: - self._synchronize(state, child, None, - False, uowcommit, False) + self._synchronize( + state, child, None, False, uowcommit, False + ) if child is not None and self.post_update: self._post_update(child, uowcommit, [state]) for child in history.deleted: - if should_null_fks and not self.cascade.delete_orphan and \ - not self.hasparent(child): - self._synchronize(state, child, None, True, - uowcommit, False) + if ( + should_null_fks + and not self.cascade.delete_orphan + and not self.hasparent(child) + ): + self._synchronize( + state, child, None, True, uowcommit, False + ) if self._pks_changed(uowcommit, state): for child in history.unchanged: - self._synchronize(state, child, None, - False, uowcommit, True) + self._synchronize( + state, child, None, False, uowcommit, True + ) - def _synchronize(self, state, child, - associationrow, clearkeys, uowcommit, - pks_changed): + def _synchronize( + self, state, child, associationrow, clearkeys, uowcommit, pks_changed + ): source = state dest = child self._verify_canload(child) - if dest is None or \ - (not self.post_update and uowcommit.is_deleted(dest)): + if dest is None or ( + not self.post_update and uowcommit.is_deleted(dest) + ): return if clearkeys: sync.clear(dest, self.mapper, self.prop.synchronize_pairs) else: - sync.populate(source, self.parent, dest, self.mapper, - self.prop.synchronize_pairs, uowcommit, - self.passive_updates and pks_changed) + sync.populate( + source, + self.parent, + dest, + self.mapper, + self.prop.synchronize_pairs, + uowcommit, + self.passive_updates and pks_changed, + ) def _pks_changed(self, uowcommit, state): return sync.source_modified( - uowcommit, - state, - self.parent, - self.prop.synchronize_pairs) + uowcommit, state, self.parent, self.prop.synchronize_pairs + ) class ManyToOneDP(DependencyProcessor): @@ -587,105 +616,110 @@ class ManyToOneDP(DependencyProcessor): DependencyProcessor.__init__(self, prop) self.mapper._dependency_processors.append(DetectKeySwitch(prop)) - def per_property_dependencies(self, uow, - parent_saves, - child_saves, - parent_deletes, - child_deletes, - after_save, - before_delete): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): if self.post_update: parent_post_updates = unitofwork.PostUpdateAll( - uow, - self.parent.primary_base_mapper, - False) + uow, self.parent.primary_base_mapper, False + ) parent_pre_updates = unitofwork.PostUpdateAll( - uow, - self.parent.primary_base_mapper, - True) - - uow.dependencies.update([ - (child_saves, after_save), - (parent_saves, after_save), - (after_save, parent_post_updates), - - (after_save, parent_pre_updates), - (before_delete, parent_pre_updates), - - (parent_pre_updates, child_deletes), - (parent_pre_updates, parent_deletes), - ]) + uow, self.parent.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, parent_post_updates), + (after_save, parent_pre_updates), + (before_delete, parent_pre_updates), + 
(parent_pre_updates, child_deletes), + (parent_pre_updates, parent_deletes), + ] + ) else: - uow.dependencies.update([ - (child_saves, after_save), - (after_save, parent_saves), - (parent_saves, child_deletes), - (parent_deletes, child_deletes) - ]) - - def per_state_dependencies(self, uow, - save_parent, - delete_parent, - child_action, - after_save, before_delete, - isdelete, childisdelete): + uow.dependencies.update( + [ + (child_saves, after_save), + (after_save, parent_saves), + (parent_saves, child_deletes), + (parent_deletes, child_deletes), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): if self.post_update: if not isdelete: parent_post_updates = unitofwork.PostUpdateAll( - uow, - self.parent.primary_base_mapper, - False) + uow, self.parent.primary_base_mapper, False + ) if childisdelete: - uow.dependencies.update([ - (after_save, parent_post_updates), - (parent_post_updates, child_action) - ]) + uow.dependencies.update( + [ + (after_save, parent_post_updates), + (parent_post_updates, child_action), + ] + ) else: - uow.dependencies.update([ - (save_parent, after_save), - (child_action, after_save), - - (after_save, parent_post_updates) - ]) + uow.dependencies.update( + [ + (save_parent, after_save), + (child_action, after_save), + (after_save, parent_post_updates), + ] + ) else: parent_pre_updates = unitofwork.PostUpdateAll( - uow, - self.parent.primary_base_mapper, - True) + uow, self.parent.primary_base_mapper, True + ) - uow.dependencies.update([ - (before_delete, parent_pre_updates), - (parent_pre_updates, delete_parent), - (parent_pre_updates, child_action) - ]) + uow.dependencies.update( + [ + (before_delete, parent_pre_updates), + (parent_pre_updates, delete_parent), + (parent_pre_updates, child_action), + ] + ) elif not isdelete: if not childisdelete: - uow.dependencies.update([ - (child_action, after_save), - (after_save, save_parent), - ]) + uow.dependencies.update( + [(child_action, after_save), (after_save, save_parent)] + ) else: - uow.dependencies.update([ - (after_save, save_parent), - ]) + uow.dependencies.update([(after_save, save_parent)]) else: if childisdelete: - uow.dependencies.update([ - (delete_parent, child_action) - ]) + uow.dependencies.update([(delete_parent, child_action)]) def presort_deletes(self, uowcommit, states): if self.cascade.delete or self.cascade.delete_orphan: for state in states: history = uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) if history: if self.cascade.delete_orphan: todelete = history.sum() @@ -695,36 +729,42 @@ class ManyToOneDP(DependencyProcessor): if child is None: continue uowcommit.register_object( - child, isdelete=True, - operation="delete", prop=self.prop) - t = self.mapper.cascade_iterator('delete', child) + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + t = self.mapper.cascade_iterator("delete", child) for c, m, st_, dct_ in t: - uowcommit.register_object( - st_, isdelete=True) + uowcommit.register_object(st_, isdelete=True) def presort_saves(self, uowcommit, states): for state in states: uowcommit.register_object(state, operation="add", prop=self.prop) if self.cascade.delete_orphan: history = uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) if history: for child in history.deleted: if self.hasparent(child) is 
False: uowcommit.register_object( - child, isdelete=True, - operation="delete", prop=self.prop) + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) - t = self.mapper.cascade_iterator('delete', child) + t = self.mapper.cascade_iterator("delete", child) for c, m, st_, dct_ in t: uowcommit.register_object(st_, isdelete=True) def process_deletes(self, uowcommit, states): - if self.post_update and \ - not self.cascade.delete_orphan and \ - not self.passive_deletes == 'all': + if ( + self.post_update + and not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ): # post_update means we have to update our # row to not reference the child object @@ -733,55 +773,70 @@ class ManyToOneDP(DependencyProcessor): self._synchronize(state, None, None, True, uowcommit) if state and self.post_update: history = uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) if history: self._post_update( - state, uowcommit, history.sum(), - is_m2o_delete=True) + state, uowcommit, history.sum(), is_m2o_delete=True + ) def process_saves(self, uowcommit, states): for state in states: history = uowcommit.get_attribute_history( - state, - self.key, - attributes.PASSIVE_NO_INITIALIZE) + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) if history: if history.added: for child in history.added: - self._synchronize(state, child, None, False, - uowcommit, "add") + self._synchronize( + state, child, None, False, uowcommit, "add" + ) elif history.deleted: self._synchronize( - state, None, None, True, uowcommit, "delete") + state, None, None, True, uowcommit, "delete" + ) if self.post_update: self._post_update(state, uowcommit, history.sum()) - def _synchronize(self, state, child, associationrow, - clearkeys, uowcommit, operation=None): - if state is None or \ - (not self.post_update and uowcommit.is_deleted(state)): + def _synchronize( + self, + state, + child, + associationrow, + clearkeys, + uowcommit, + operation=None, + ): + if state is None or ( + not self.post_update and uowcommit.is_deleted(state) + ): return - if operation is not None and \ - child is not None and \ - not uowcommit.session._contains_state(child): + if ( + operation is not None + and child is not None + and not uowcommit.session._contains_state(child) + ): util.warn( "Object of type %s not in session, %s " - "operation along '%s' won't proceed" % - (mapperutil.state_class_str(child), operation, self.prop)) + "operation along '%s' won't proceed" + % (mapperutil.state_class_str(child), operation, self.prop) + ) return if clearkeys or child is None: sync.clear(state, self.parent, self.prop.synchronize_pairs) else: self._verify_canload(child) - sync.populate(child, self.mapper, state, - self.parent, - self.prop.synchronize_pairs, - uowcommit, - False) + sync.populate( + child, + self.mapper, + state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + False, + ) class DetectKeySwitch(DependencyProcessor): @@ -801,20 +856,18 @@ class DetectKeySwitch(DependencyProcessor): if self.passive_updates: return else: - if False in (prop.passive_updates for - prop in self.prop._reverse_property): + if False in ( + prop.passive_updates + for prop in self.prop._reverse_property + ): return uow.register_preprocessor(self, False) def per_property_flush_actions(self, uow): - parent_saves = unitofwork.SaveUpdateAll( - uow, - self.parent.base_mapper) + parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper) after_save = 
unitofwork.ProcessAll(uow, self, False, False) - uow.dependencies.update([ - (parent_saves, after_save) - ]) + uow.dependencies.update([(parent_saves, after_save)]) def per_state_flush_actions(self, uow, states, isdelete): pass @@ -848,8 +901,7 @@ class DetectKeySwitch(DependencyProcessor): def _key_switchers(self, uow, states): switched, notswitched = uow.memo( - ('pk_switchers', self), - lambda: (set(), set()) + ("pk_switchers", self), lambda: (set(), set()) ) allstates = switched.union(notswitched) @@ -871,74 +923,86 @@ class DetectKeySwitch(DependencyProcessor): continue dict_ = state.dict related = state.get_impl(self.key).get( - state, dict_, passive=self._passive_update_flag) - if related is not attributes.PASSIVE_NO_RESULT and \ - related is not None: + state, dict_, passive=self._passive_update_flag + ) + if ( + related is not attributes.PASSIVE_NO_RESULT + and related is not None + ): related_state = attributes.instance_state(dict_[self.key]) if related_state in switchers: - uowcommit.register_object(state, - False, - self.passive_updates) + uowcommit.register_object( + state, False, self.passive_updates + ) sync.populate( related_state, - self.mapper, state, - self.parent, self.prop.synchronize_pairs, - uowcommit, self.passive_updates) + self.mapper, + state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + self.passive_updates, + ) def _pks_changed(self, uowcommit, state): return bool(state.key) and sync.source_modified( - uowcommit, state, self.mapper, self.prop.synchronize_pairs) + uowcommit, state, self.mapper, self.prop.synchronize_pairs + ) class ManyToManyDP(DependencyProcessor): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + + uow.dependencies.update( + [ + (parent_saves, after_save), + (child_saves, after_save), + (after_save, child_deletes), + # a rowswitch on the parent from deleted to saved + # can make this one occur, as the "save" may remove + # an element from the + # "deleted" list before we have a chance to + # process its child rows + (before_delete, parent_saves), + (before_delete, parent_deletes), + (before_delete, child_deletes), + (before_delete, child_saves), + ] + ) - def per_property_dependencies(self, uow, parent_saves, - child_saves, - parent_deletes, - child_deletes, - after_save, - before_delete - ): - - uow.dependencies.update([ - (parent_saves, after_save), - (child_saves, after_save), - (after_save, child_deletes), - - # a rowswitch on the parent from deleted to saved - # can make this one occur, as the "save" may remove - # an element from the - # "deleted" list before we have a chance to - # process its child rows - (before_delete, parent_saves), - - (before_delete, parent_deletes), - (before_delete, child_deletes), - (before_delete, child_saves), - ]) - - def per_state_dependencies(self, uow, - save_parent, - delete_parent, - child_action, - after_save, before_delete, - isdelete, childisdelete): + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): if not isdelete: if childisdelete: - uow.dependencies.update([ - (save_parent, after_save), - (after_save, child_action), - ]) + uow.dependencies.update( + [(save_parent, after_save), (after_save, child_action)] + ) else: - uow.dependencies.update([ - (save_parent, after_save), - (child_action, after_save), - ]) + uow.dependencies.update( + [(save_parent, after_save), (child_action, 
after_save)] + ) else: - uow.dependencies.update([ - (before_delete, child_action), - (before_delete, delete_parent) - ]) + uow.dependencies.update( + [(before_delete, child_action), (before_delete, delete_parent)] + ) def presort_deletes(self, uowcommit, states): # TODO: no tests fail if this whole @@ -949,9 +1013,8 @@ class ManyToManyDP(DependencyProcessor): # returns True for state in states: uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) def presort_saves(self, uowcommit, states): if not self.passive_updates: @@ -961,9 +1024,8 @@ class ManyToManyDP(DependencyProcessor): for state in states: if self._pks_changed(uowcommit, state): history = uowcommit.get_attribute_history( - state, - self.key, - attributes.PASSIVE_OFF) + state, self.key, attributes.PASSIVE_OFF + ) if not self.cascade.delete_orphan: return @@ -972,20 +1034,21 @@ class ManyToManyDP(DependencyProcessor): # if delete_orphan check is turned on. for state in states: history = uowcommit.get_attribute_history( - state, - self.key, - attributes.PASSIVE_NO_INITIALIZE) + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) if history: for child in history.deleted: if self.hasparent(child) is False: uowcommit.register_object( - child, isdelete=True, - operation="delete", prop=self.prop) + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) for c, m, st_, dct_ in self.mapper.cascade_iterator( - 'delete', - child): - uowcommit.register_object( - st_, isdelete=True) + "delete", child + ): + uowcommit.register_object(st_, isdelete=True) def process_deletes(self, uowcommit, states): secondary_delete = [] @@ -998,21 +1061,23 @@ class ManyToManyDP(DependencyProcessor): # this history should be cached already, as # we loaded it in preprocess_deletes history = uowcommit.get_attribute_history( - state, - self.key, - self._passive_delete_flag) + state, self.key, self._passive_delete_flag + ) if history: for child in history.non_added(): - if child is None or \ - (processed is not None and - (state, child) in processed): + if child is None or ( + processed is not None and (state, child) in processed + ): continue associationrow = {} if not self._synchronize( - state, - child, - associationrow, - False, uowcommit, "delete"): + state, + child, + associationrow, + False, + uowcommit, + "delete", + ): continue secondary_delete.append(associationrow) @@ -1021,8 +1086,9 @@ class ManyToManyDP(DependencyProcessor): if processed is not None: processed.update(tmp) - self._run_crud(uowcommit, secondary_insert, - secondary_update, secondary_delete) + self._run_crud( + uowcommit, secondary_insert, secondary_update, secondary_delete + ) def process_saves(self, uowcommit, states): secondary_delete = [] @@ -1033,110 +1099,133 @@ class ManyToManyDP(DependencyProcessor): tmp = set() for state in states: - need_cascade_pks = not self.passive_updates and \ - self._pks_changed(uowcommit, state) + need_cascade_pks = not self.passive_updates and self._pks_changed( + uowcommit, state + ) if need_cascade_pks: passive = attributes.PASSIVE_OFF else: passive = attributes.PASSIVE_NO_INITIALIZE - history = uowcommit.get_attribute_history(state, self.key, - passive) + history = uowcommit.get_attribute_history(state, self.key, passive) if history: for child in history.added: - if (processed is not None and - (state, child) in processed): + if processed is not None and (state, child) in processed: continue associationrow = {} - if not self._synchronize(state, - child, - 
associationrow, - False, uowcommit, "add"): + if not self._synchronize( + state, child, associationrow, False, uowcommit, "add" + ): continue secondary_insert.append(associationrow) for child in history.deleted: - if (processed is not None and - (state, child) in processed): + if processed is not None and (state, child) in processed: continue associationrow = {} - if not self._synchronize(state, - child, - associationrow, - False, uowcommit, "delete"): + if not self._synchronize( + state, + child, + associationrow, + False, + uowcommit, + "delete", + ): continue secondary_delete.append(associationrow) - tmp.update((c, state) - for c in history.added + history.deleted) + tmp.update((c, state) for c in history.added + history.deleted) if need_cascade_pks: for child in history.unchanged: associationrow = {} - sync.update(state, - self.parent, - associationrow, - "old_", - self.prop.synchronize_pairs) - sync.update(child, - self.mapper, - associationrow, - "old_", - self.prop.secondary_synchronize_pairs) + sync.update( + state, + self.parent, + associationrow, + "old_", + self.prop.synchronize_pairs, + ) + sync.update( + child, + self.mapper, + associationrow, + "old_", + self.prop.secondary_synchronize_pairs, + ) secondary_update.append(associationrow) if processed is not None: processed.update(tmp) - self._run_crud(uowcommit, secondary_insert, - secondary_update, secondary_delete) + self._run_crud( + uowcommit, secondary_insert, secondary_update, secondary_delete + ) - def _run_crud(self, uowcommit, secondary_insert, - secondary_update, secondary_delete): + def _run_crud( + self, uowcommit, secondary_insert, secondary_update, secondary_delete + ): connection = uowcommit.transaction.connection(self.mapper) if secondary_delete: associationrow = secondary_delete[0] - statement = self.secondary.delete(sql.and_(*[ - c == sql.bindparam(c.key, type_=c.type) - for c in self.secondary.c - if c.key in associationrow - ])) + statement = self.secondary.delete( + sql.and_( + *[ + c == sql.bindparam(c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ] + ) + ) result = connection.execute(statement, secondary_delete) - if result.supports_sane_multi_rowcount() and \ - result.rowcount != len(secondary_delete): + if result.supports_sane_multi_rowcount() and result.rowcount != len( + secondary_delete + ): raise exc.StaleDataError( "DELETE statement on table '%s' expected to delete " - "%d row(s); Only %d were matched." % - (self.secondary.description, len(secondary_delete), - result.rowcount) + "%d row(s); Only %d were matched." + % ( + self.secondary.description, + len(secondary_delete), + result.rowcount, + ) ) if secondary_update: associationrow = secondary_update[0] - statement = self.secondary.update(sql.and_(*[ - c == sql.bindparam("old_" + c.key, type_=c.type) - for c in self.secondary.c - if c.key in associationrow - ])) + statement = self.secondary.update( + sql.and_( + *[ + c == sql.bindparam("old_" + c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ] + ) + ) result = connection.execute(statement, secondary_update) - if result.supports_sane_multi_rowcount() and \ - result.rowcount != len(secondary_update): + if result.supports_sane_multi_rowcount() and result.rowcount != len( + secondary_update + ): raise exc.StaleDataError( "UPDATE statement on table '%s' expected to update " - "%d row(s); Only %d were matched." % - (self.secondary.description, len(secondary_update), - result.rowcount) + "%d row(s); Only %d were matched." 
+ % ( + self.secondary.description, + len(secondary_update), + result.rowcount, + ) ) if secondary_insert: statement = self.secondary.insert() connection.execute(statement, secondary_insert) - def _synchronize(self, state, child, associationrow, - clearkeys, uowcommit, operation): + def _synchronize( + self, state, child, associationrow, clearkeys, uowcommit, operation + ): # this checks for None if uselist=True self._verify_canload(child) @@ -1150,23 +1239,28 @@ class ManyToManyDP(DependencyProcessor): if not child.deleted: util.warn( "Object of type %s not in session, %s " - "operation along '%s' won't proceed" % - (mapperutil.state_class_str(child), operation, self.prop)) + "operation along '%s' won't proceed" + % (mapperutil.state_class_str(child), operation, self.prop) + ) return False - sync.populate_dict(state, self.parent, associationrow, - self.prop.synchronize_pairs) - sync.populate_dict(child, self.mapper, associationrow, - self.prop.secondary_synchronize_pairs) + sync.populate_dict( + state, self.parent, associationrow, self.prop.synchronize_pairs + ) + sync.populate_dict( + child, + self.mapper, + associationrow, + self.prop.secondary_synchronize_pairs, + ) return True def _pks_changed(self, uowcommit, state): return sync.source_modified( - uowcommit, - state, - self.parent, - self.prop.synchronize_pairs) + uowcommit, state, self.parent, self.prop.synchronize_pairs + ) + _direction_to_processor = { ONETOMANY: OneToManyDP, diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py index 426288e03f..6b51404d0c 100644 --- a/lib/sqlalchemy/orm/deprecated_interfaces.py +++ b/lib/sqlalchemy/orm/deprecated_interfaces.py @@ -58,23 +58,25 @@ class MapperExtension(object): @classmethod def _adapt_instrument_class(cls, self, listener): - cls._adapt_listener_methods(self, listener, ('instrument_class',)) + cls._adapt_listener_methods(self, listener, ("instrument_class",)) @classmethod def _adapt_listener(cls, self, listener): cls._adapt_listener_methods( - self, listener, + self, + listener, ( - 'init_instance', - 'init_failed', - 'reconstruct_instance', - 'before_insert', - 'after_insert', - 'before_update', - 'after_update', - 'before_delete', - 'after_delete' - )) + "init_instance", + "init_failed", + "reconstruct_instance", + "before_insert", + "after_insert", + "before_update", + "after_update", + "before_delete", + "after_delete", + ), + ) @classmethod def _adapt_listener_methods(cls, self, listener, methods): @@ -84,36 +86,75 @@ class MapperExtension(object): ls_meth = getattr(listener, meth) if not util.methods_equivalent(me_meth, ls_meth): - if meth == 'reconstruct_instance': + if meth == "reconstruct_instance": + def go(ls_meth): def reconstruct(instance, ctx): ls_meth(self, instance) + return reconstruct - event.listen(self.class_manager, 'load', - go(ls_meth), raw=False, propagate=True) - elif meth == 'init_instance': + + event.listen( + self.class_manager, + "load", + go(ls_meth), + raw=False, + propagate=True, + ) + elif meth == "init_instance": + def go(ls_meth): def init_instance(instance, args, kwargs): - ls_meth(self, self.class_, - self.class_manager.original_init, - instance, args, kwargs) + ls_meth( + self, + self.class_, + self.class_manager.original_init, + instance, + args, + kwargs, + ) + return init_instance - event.listen(self.class_manager, 'init', - go(ls_meth), raw=False, propagate=True) - elif meth == 'init_failed': + + event.listen( + self.class_manager, + "init", + go(ls_meth), + raw=False, + propagate=True, + ) 
+ elif meth == "init_failed": + def go(ls_meth): def init_failed(instance, args, kwargs): util.warn_exception( - ls_meth, self, self.class_, + ls_meth, + self, + self.class_, self.class_manager.original_init, - instance, args, kwargs) + instance, + args, + kwargs, + ) return init_failed - event.listen(self.class_manager, 'init_failure', - go(ls_meth), raw=False, propagate=True) + + event.listen( + self.class_manager, + "init_failure", + go(ls_meth), + raw=False, + propagate=True, + ) else: - event.listen(self, "%s" % meth, ls_meth, - raw=False, retval=True, propagate=True) + event.listen( + self, + "%s" % meth, + ls_meth, + raw=False, + retval=True, + propagate=True, + ) def instrument_class(self, mapper, class_): """Receive a class when the mapper is first constructed, and has @@ -302,16 +343,16 @@ class SessionExtension(object): @classmethod def _adapt_listener(cls, self, listener): for meth in [ - 'before_commit', - 'after_commit', - 'after_rollback', - 'before_flush', - 'after_flush', - 'after_flush_postexec', - 'after_begin', - 'after_attach', - 'after_bulk_update', - 'after_bulk_delete', + "before_commit", + "after_commit", + "after_rollback", + "before_flush", + "after_flush", + "after_flush_postexec", + "after_begin", + "after_attach", + "after_bulk_update", + "after_bulk_delete", ]: me_meth = getattr(SessionExtension, meth) ls_meth = getattr(listener, meth) @@ -450,15 +491,30 @@ class AttributeExtension(object): @classmethod def _adapt_listener(cls, self, listener): - event.listen(self, 'append', listener.append, - active_history=listener.active_history, - raw=True, retval=True) - event.listen(self, 'remove', listener.remove, - active_history=listener.active_history, - raw=True, retval=True) - event.listen(self, 'set', listener.set, - active_history=listener.active_history, - raw=True, retval=True) + event.listen( + self, + "append", + listener.append, + active_history=listener.active_history, + raw=True, + retval=True, + ) + event.listen( + self, + "remove", + listener.remove, + active_history=listener.active_history, + raw=True, + retval=True, + ) + event.listen( + self, + "set", + listener.set, + active_history=listener.active_history, + raw=True, + retval=True, + ) def append(self, state, value, initiator): """Receive a collection append event. 
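(For context on the hunks above: the MapperExtension adapter rewires each deprecated hook onto the event system via event.listen(). A minimal sketch of the equivalent direct registration, assuming a hypothetical mapped class "User"; event.listen() and the "before_insert" mapper event are the public API these shims target.)

    from sqlalchemy import event

    def log_insert(mapper, connection, target):
        # invoked just before the INSERT for a User instance is emitted
        print("inserting:", target)

    # direct equivalent of supplying a MapperExtension with a
    # before_insert() method; propagate=True mirrors the adapter above
    event.listen(User, "before_insert", log_insert, propagate=True)
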
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index fefd2d2a11..37517e84c9 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -37,9 +37,11 @@ class DescriptorProperty(MapperProperty): def __init__(self, key): self.key = key - if hasattr(prop, 'get_history'): - def get_history(self, state, dict_, - passive=attributes.PASSIVE_OFF): + if hasattr(prop, "get_history"): + + def get_history( + self, state, dict_, passive=attributes.PASSIVE_OFF + ): return prop.get_history(state, dict_, passive) if self.descriptor is None: @@ -48,6 +50,7 @@ class DescriptorProperty(MapperProperty): self.descriptor = desc if self.descriptor is None: + def fset(obj, value): setattr(obj, self.name, value) @@ -57,21 +60,16 @@ class DescriptorProperty(MapperProperty): def fget(obj): return getattr(obj, self.name) - self.descriptor = property( - fget=fget, - fset=fset, - fdel=fdel, - ) + self.descriptor = property(fget=fget, fset=fset, fdel=fdel) - proxy_attr = attributes.create_proxied_attribute( - self.descriptor)( - self.parent.class_, - self.key, - self.descriptor, - lambda: self._comparator_factory(mapper), - doc=self.doc, - original_property=self - ) + proxy_attr = attributes.create_proxied_attribute(self.descriptor)( + self.parent.class_, + self.key, + self.descriptor, + lambda: self._comparator_factory(mapper), + doc=self.doc, + original_property=self, + ) proxy_attr.impl = _ProxyImpl(self.key) mapper.class_manager.instrument_attribute(self.key, proxy_attr) @@ -149,13 +147,14 @@ class CompositeProperty(DescriptorProperty): self.attrs = attrs self.composite_class = class_ - self.active_history = kwargs.get('active_history', False) - self.deferred = kwargs.get('deferred', False) - self.group = kwargs.get('group', None) - self.comparator_factory = kwargs.pop('comparator_factory', - self.__class__.Comparator) - if 'info' in kwargs: - self.info = kwargs.pop('info') + self.active_history = kwargs.get("active_history", False) + self.deferred = kwargs.get("deferred", False) + self.group = kwargs.get("group", None) + self.comparator_factory = kwargs.pop( + "comparator_factory", self.__class__.Comparator + ) + if "info" in kwargs: + self.info = kwargs.pop("info") util.set_creation_order(self) self._create_descriptor() @@ -186,8 +185,7 @@ class CompositeProperty(DescriptorProperty): # attributes, retrieve their values. This # ensures they all load. values = [ - getattr(instance, key) - for key in self._attribute_keys + getattr(instance, key) for key in self._attribute_keys ] # current expected behavior here is that the composite is @@ -196,8 +194,7 @@ class CompositeProperty(DescriptorProperty): # if the composite were created unconditionally, # but that would be a behavioral change. 
if self.key not in dict_ and ( - state.key is not None or - not _none_set.issuperset(values) + state.key is not None or not _none_set.issuperset(values) ): dict_[self.key] = self.composite_class(*values) state.manager.dispatch.refresh(state, None, [self.key]) @@ -217,8 +214,8 @@ class CompositeProperty(DescriptorProperty): setattr(instance, key, None) else: for key, value in zip( - self._attribute_keys, - value.__composite_values__()): + self._attribute_keys, value.__composite_values__() + ): setattr(instance, key, value) def fdel(instance): @@ -234,18 +231,14 @@ class CompositeProperty(DescriptorProperty): @util.memoized_property def _comparable_elements(self): - return [ - getattr(self.parent.class_, prop.key) - for prop in self.props - ] + return [getattr(self.parent.class_, prop.key) for prop in self.props] @util.memoized_property def props(self): props = [] for attr in self.attrs: if isinstance(attr, str): - prop = self.parent.get_property( - attr, _configure_mappers=False) + prop = self.parent.get_property(attr, _configure_mappers=False) elif isinstance(attr, schema.Column): prop = self.parent._columntoproperty[attr] elif isinstance(attr, attributes.InstrumentedAttribute): @@ -254,7 +247,8 @@ class CompositeProperty(DescriptorProperty): raise sa_exc.ArgumentError( "Composite expects Column objects or mapped " "attributes/attribute names as arguments, got: %r" - % (attr,)) + % (attr,) + ) props.append(prop) return props @@ -271,9 +265,7 @@ class CompositeProperty(DescriptorProperty): prop.active_history = self.active_history if self.deferred: prop.deferred = self.deferred - prop.strategy_key = ( - ("deferred", True), - ("instrument", True)) + prop.strategy_key = (("deferred", True), ("instrument", True)) prop.group = self.group def _setup_event_handlers(self): @@ -299,8 +291,7 @@ class CompositeProperty(DescriptorProperty): return dict_[self.key] = self.composite_class( - *[state.dict[key] for key in - self._attribute_keys] + *[state.dict[key] for key in self._attribute_keys] ) def expire_handler(state, keys): @@ -317,24 +308,27 @@ class CompositeProperty(DescriptorProperty): state.dict.pop(self.key, None) - event.listen(self.parent, 'after_insert', - insert_update_handler, raw=True) - event.listen(self.parent, 'after_update', - insert_update_handler, raw=True) - event.listen(self.parent, 'load', - load_handler, raw=True, propagate=True) - event.listen(self.parent, 'refresh', - refresh_handler, raw=True, propagate=True) - event.listen(self.parent, 'expire', - expire_handler, raw=True, propagate=True) + event.listen( + self.parent, "after_insert", insert_update_handler, raw=True + ) + event.listen( + self.parent, "after_update", insert_update_handler, raw=True + ) + event.listen( + self.parent, "load", load_handler, raw=True, propagate=True + ) + event.listen( + self.parent, "refresh", refresh_handler, raw=True, propagate=True + ) + event.listen( + self.parent, "expire", expire_handler, raw=True, propagate=True + ) # TODO: need a deserialize hook here @util.memoized_property def _attribute_keys(self): - return [ - prop.key for prop in self.props - ] + return [prop.key for prop in self.props] def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): """Provided for userland code that uses attributes.get_history().""" @@ -363,12 +357,10 @@ class CompositeProperty(DescriptorProperty): return attributes.History( [self.composite_class(*added)], (), - [self.composite_class(*deleted)] + [self.composite_class(*deleted)], ) else: - return attributes.History( - (), 
[self.composite_class(*added)], () - ) + return attributes.History((), [self.composite_class(*added)], ()) def _comparator_factory(self, mapper): return self.comparator_factory(self, mapper) @@ -377,12 +369,15 @@ class CompositeProperty(DescriptorProperty): def __init__(self, property, expr): self.property = property super(CompositeProperty.CompositeBundle, self).__init__( - property.key, *expr) + property.key, *expr + ) def create_row_processor(self, query, procs, labels): def proc(row): return self.property.composite_class( - *[proc(row) for proc in procs]) + *[proc(row) for proc in procs] + ) + return proc class Comparator(PropComparator): @@ -412,11 +407,13 @@ class CompositeProperty(DescriptorProperty): def __clause_element__(self): return expression.ClauseList( - group=False, *self._comparable_elements) + group=False, *self._comparable_elements + ) def _query_clause_element(self): return CompositeProperty.CompositeBundle( - self.prop, self.__clause_element__()) + self.prop, self.__clause_element__() + ) def _bulk_update_tuples(self, value): if value is None: @@ -425,22 +422,18 @@ class CompositeProperty(DescriptorProperty): values = value.__composite_values__() else: raise sa_exc.ArgumentError( - "Can't UPDATE composite attribute %s to %r" % - (self.prop, value)) + "Can't UPDATE composite attribute %s to %r" + % (self.prop, value) + ) - return zip( - self._comparable_elements, - values - ) + return zip(self._comparable_elements, values) @util.memoized_property def _comparable_elements(self): if self._adapt_to_entity: return [ - getattr( - self._adapt_to_entity.entity, - prop.key - ) for prop in self.prop._comparable_elements + getattr(self._adapt_to_entity.entity, prop.key) + for prop in self.prop._comparable_elements ] else: return self.prop._comparable_elements @@ -451,8 +444,7 @@ class CompositeProperty(DescriptorProperty): else: values = other.__composite_values__() comparisons = [ - a == b - for a, b in zip(self.prop._comparable_elements, values) + a == b for a, b in zip(self.prop._comparable_elements, values) ] if self._adapt_to_entity: comparisons = [self.adapter(x) for x in comparisons] @@ -495,14 +487,16 @@ class ConcreteInheritedProperty(DescriptorProperty): def __init__(self): super(ConcreteInheritedProperty, self).__init__() + def warn(): - raise AttributeError("Concrete %s does not implement " - "attribute %r at the instance level. Add " - "this property explicitly to %s." % - (self.parent, self.key, self.parent)) + raise AttributeError( + "Concrete %s does not implement " + "attribute %r at the instance level. Add " + "this property explicitly to %s." + % (self.parent, self.key, self.parent) + ) class NoninheritedConcreteProp(object): - def __set__(s, obj, value): warn() @@ -513,15 +507,21 @@ class ConcreteInheritedProperty(DescriptorProperty): if obj is None: return self.descriptor warn() + self.descriptor = NoninheritedConcreteProp() @util.langhelpers.dependency_for("sqlalchemy.orm.properties", add_to_all=True) class SynonymProperty(DescriptorProperty): - - def __init__(self, name, map_column=None, - descriptor=None, comparator_factory=None, - doc=None, info=None): + def __init__( + self, + name, + map_column=None, + descriptor=None, + comparator_factory=None, + doc=None, + info=None, + ): """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior of another attribute. 
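(A minimal usage sketch of what this docstring describes, with illustrative names mirroring the documented example: the column is mapped under one name and exposed through a synonym.)

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import synonym

    Base = declarative_base()

    class MyClass(Base):
        __tablename__ = "my_table"
        id = Column(Integer, primary_key=True)
        job_status = Column(String(50))

        # "status" mirrors "job_status" both on instances and in query
        # expressions, e.g. query(MyClass).filter(MyClass.status == "x")
        status = synonym("job_status")
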
@@ -639,15 +639,13 @@ class SynonymProperty(DescriptorProperty): @util.memoized_property def _proxied_property(self): attr = getattr(self.parent.class_, self.name) - if not hasattr(attr, 'property') or not \ - isinstance(attr.property, MapperProperty): + if not hasattr(attr, "property") or not isinstance( + attr.property, MapperProperty + ): raise sa_exc.InvalidRequestError( """synonym() attribute "%s.%s" only supports """ - """ORM mapped attributes, got %r""" % ( - self.parent.class_.__name__, - self.name, - attr - ) + """ORM mapped attributes, got %r""" + % (self.parent.class_.__name__, self.name, attr) ) return attr.property @@ -671,23 +669,23 @@ class SynonymProperty(DescriptorProperty): raise sa_exc.ArgumentError( "Can't compile synonym '%s': no column on table " "'%s' named '%s'" - % (self.name, parent.mapped_table.description, self.key)) - elif parent.mapped_table.c[self.key] in \ - parent._columntoproperty and \ - parent._columntoproperty[ - parent.mapped_table.c[self.key] - ].key == self.name: + % (self.name, parent.mapped_table.description, self.key) + ) + elif ( + parent.mapped_table.c[self.key] in parent._columntoproperty + and parent._columntoproperty[ + parent.mapped_table.c[self.key] + ].key + == self.name + ): raise sa_exc.ArgumentError( "Can't call map_column=True for synonym %r=%r, " "a ColumnProperty already exists keyed to the name " - "%r for column %r" % - (self.key, self.name, self.name, self.key) + "%r for column %r" + % (self.key, self.name, self.name, self.key) ) p = properties.ColumnProperty(parent.mapped_table.c[self.key]) - parent._configure_property( - self.name, p, - init=init, - setparent=True) + parent._configure_property(self.name, p, init=init, setparent=True) p._mapped_by_synonym = self.key self.parent = parent @@ -698,7 +696,8 @@ class ComparableProperty(DescriptorProperty): """Instruments a Python property for use in query expressions.""" def __init__( - self, comparator_factory, descriptor=None, doc=None, info=None): + self, comparator_factory, descriptor=None, doc=None, info=None + ): """Provides a method of applying a :class:`.PropComparator` to any Python descriptor attribute. diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 087e7dcc64..e5c6b80b66 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -15,8 +15,13 @@ basic add/delete mutation. from .. import log, util, exc from ..sql import operators from . import ( - attributes, object_session, util as orm_util, strategies, - object_mapper, exc as orm_exc, properties + attributes, + object_session, + util as orm_util, + strategies, + object_mapper, + exc as orm_exc, + properties, ) from .query import Query @@ -30,7 +35,8 @@ class DynaLoader(strategies.AbstractRelationshipLoader): raise exc.InvalidRequestError( "On relationship %s, 'dynamic' loaders cannot be used with " "many-to-one/one-to-one relationships and/or " - "uselist=False." % self.parent_property) + "uselist=False." 
% self.parent_property + ) strategies._register_attribute( self.parent_property, mapper, @@ -49,11 +55,20 @@ class DynamicAttributeImpl(attributes.AttributeImpl): collection = False dynamic = True - def __init__(self, class_, key, typecallable, - dispatch, - target_mapper, order_by, query_class=None, **kw): - super(DynamicAttributeImpl, self).\ - __init__(class_, key, typecallable, dispatch, **kw) + def __init__( + self, + class_, + key, + typecallable, + dispatch, + target_mapper, + order_by, + query_class=None, + **kw + ): + super(DynamicAttributeImpl, self).__init__( + class_, key, typecallable, dispatch, **kw + ) self.target_mapper = target_mapper self.order_by = order_by if not query_class: @@ -66,15 +81,20 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def get(self, state, dict_, passive=attributes.PASSIVE_OFF): if not passive & attributes.SQL_OK: return self._get_collection_history( - state, attributes.PASSIVE_NO_INITIALIZE).added_items + state, attributes.PASSIVE_NO_INITIALIZE + ).added_items else: return self.query_class(self, state) - def get_collection(self, state, dict_, user_data=None, - passive=attributes.PASSIVE_NO_INITIALIZE): + def get_collection( + self, + state, + dict_, + user_data=None, + passive=attributes.PASSIVE_NO_INITIALIZE, + ): if not passive & attributes.SQL_OK: - return self._get_collection_history(state, - passive).added_items + return self._get_collection_history(state, passive).added_items else: history = self._get_collection_history(state, passive) return history.added_plus_unchanged @@ -87,8 +107,9 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def _remove_token(self): return attributes.Event(self, attributes.OP_REMOVE) - def fire_append_event(self, state, dict_, value, initiator, - collection_history=None): + def fire_append_event( + self, state, dict_, value, initiator, collection_history=None + ): if collection_history is None: collection_history = self._modified_event(state, dict_) @@ -100,8 +121,9 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.trackparent and value is not None: self.sethasparent(attributes.instance_state(value), state, True) - def fire_remove_event(self, state, dict_, value, initiator, - collection_history=None): + def fire_remove_event( + self, state, dict_, value, initiator, collection_history=None + ): if collection_history is None: collection_history = self._modified_event(state, dict_) @@ -118,18 +140,24 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.key not in state.committed_state: state.committed_state[self.key] = CollectionHistory(self, state) - state._modified_event(dict_, - self, - attributes.NEVER_SET) + state._modified_event(dict_, self, attributes.NEVER_SET) # this is a hack to allow the fixtures.ComparableEntity fixture # to work dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, dict_, value, initiator=None, - passive=attributes.PASSIVE_OFF, - check_old=None, pop=False, _adapt=True): + def set( + self, + state, + dict_, + value, + initiator=None, + passive=attributes.PASSIVE_OFF, + check_old=None, + pop=False, + _adapt=True, + ): if initiator and initiator.parent_token is self.parent_token: return @@ -146,7 +174,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl): old_collection = collection_history.added_items else: old_collection = old_collection.union( - collection_history.added_items) + collection_history.added_items + ) idset = util.IdentitySet constants = old_collection.intersection(new_values) @@ -155,33 +184,40 @@ 
class DynamicAttributeImpl(attributes.AttributeImpl): for member in new_values: if member in additions: - self.fire_append_event(state, dict_, member, None, - collection_history=collection_history) + self.fire_append_event( + state, + dict_, + member, + None, + collection_history=collection_history, + ) for member in removals: - self.fire_remove_event(state, dict_, member, None, - collection_history=collection_history) + self.fire_remove_event( + state, + dict_, + member, + None, + collection_history=collection_history, + ) def delete(self, *args, **kwargs): raise NotImplementedError() def set_committed_value(self, state, dict_, value): - raise NotImplementedError("Dynamic attributes don't support " - "collection population.") + raise NotImplementedError( + "Dynamic attributes don't support " "collection population." + ) def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): c = self._get_collection_history(state, passive) return c.as_history() - def get_all_pending(self, state, dict_, - passive=attributes.PASSIVE_NO_INITIALIZE): - c = self._get_collection_history( - state, passive) - return [ - (attributes.instance_state(x), x) - for x in - c.all_items - ] + def get_all_pending( + self, state, dict_, passive=attributes.PASSIVE_NO_INITIALIZE + ): + c = self._get_collection_history(state, passive) + return [(attributes.instance_state(x), x) for x in c.all_items] def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF): if self.key in state.committed_state: @@ -194,18 +230,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: return c - def append(self, state, dict_, value, initiator, - passive=attributes.PASSIVE_OFF): + def append( + self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF + ): if initiator is not self: self.fire_append_event(state, dict_, value, initiator) - def remove(self, state, dict_, value, initiator, - passive=attributes.PASSIVE_OFF): + def remove( + self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF + ): if initiator is not self: self.fire_remove_event(state, dict_, value, initiator) - def pop(self, state, dict_, value, initiator, - passive=attributes.PASSIVE_OFF): + def pop( + self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF + ): self.remove(state, dict_, value, initiator, passive=passive) @@ -229,30 +268,36 @@ class AppenderMixin(object): # doesn't fail, and secondary is then in _from_obj[1]. 
self._from_obj = (prop.mapper.selectable, prop.secondary) - self._criterion = prop._with_parent( - instance, - alias_secondary=False) + self._criterion = prop._with_parent(instance, alias_secondary=False) if self.attr.order_by: self._order_by = self.attr.order_by def session(self): sess = object_session(self.instance) - if sess is not None and self.autoflush and sess.autoflush \ - and self.instance in sess: + if ( + sess is not None + and self.autoflush + and sess.autoflush + and self.instance in sess + ): sess.flush() if not orm_util.has_identity(self.instance): return None else: return sess + session = property(session, lambda s, x: None) def __iter__(self): sess = self.session if sess is None: - return iter(self.attr._get_collection_history( - attributes.instance_state(self.instance), - attributes.PASSIVE_NO_INITIALIZE).added_items) + return iter( + self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE, + ).added_items + ) else: return iter(self._clone(sess)) @@ -261,16 +306,20 @@ class AppenderMixin(object): if sess is None: return self.attr._get_collection_history( attributes.instance_state(self.instance), - attributes.PASSIVE_NO_INITIALIZE).indexed(index) + attributes.PASSIVE_NO_INITIALIZE, + ).indexed(index) else: return self._clone(sess).__getitem__(index) def count(self): sess = self.session if sess is None: - return len(self.attr._get_collection_history( - attributes.instance_state(self.instance), - attributes.PASSIVE_NO_INITIALIZE).added_items) + return len( + self.attr._get_collection_history( + attributes.instance_state(self.instance), + attributes.PASSIVE_NO_INITIALIZE, + ).added_items + ) else: return self._clone(sess).count() @@ -285,8 +334,9 @@ class AppenderMixin(object): raise orm_exc.DetachedInstanceError( "Parent instance %s is not bound to a Session, and no " "contextual session is established; lazy load operation " - "of attribute '%s' cannot proceed" % ( - orm_util.instance_str(instance), self.attr.key)) + "of attribute '%s' cannot proceed" + % (orm_util.instance_str(instance), self.attr.key) + ) if self.query_class: query = self.query_class(self.attr.target_mapper, session=sess) @@ -303,17 +353,26 @@ class AppenderMixin(object): for item in iterator: self.attr.append( attributes.instance_state(self.instance), - attributes.instance_dict(self.instance), item, None) + attributes.instance_dict(self.instance), + item, + None, + ) def append(self, item): self.attr.append( attributes.instance_state(self.instance), - attributes.instance_dict(self.instance), item, None) + attributes.instance_dict(self.instance), + item, + None, + ) def remove(self, item): self.attr.remove( attributes.instance_state(self.instance), - attributes.instance_dict(self.instance), item, None) + attributes.instance_dict(self.instance), + item, + None, + ) class AppenderQuery(AppenderMixin, Query): @@ -322,8 +381,8 @@ class AppenderQuery(AppenderMixin, Query): def mixin_user_query(cls): """Return a new class with AppenderQuery functionality layered over.""" - name = 'Appender' + cls.__name__ - return type(name, (AppenderMixin, cls), {'query_class': cls}) + name = "Appender" + cls.__name__ + return type(name, (AppenderMixin, cls), {"query_class": cls}) class CollectionHistory(object): @@ -348,8 +407,11 @@ class CollectionHistory(object): @property def all_items(self): - return list(self.added_items.union( - self.unchanged_items).union(self.deleted_items)) + return list( + self.added_items.union(self.unchanged_items).union( + 
self.deleted_items + ) + ) def as_history(self): if self._reconcile_collection: @@ -357,14 +419,12 @@ class CollectionHistory(object): deleted = self.deleted_items.intersection(self.unchanged_items) unchanged = self.unchanged_items.difference(deleted) else: - added, unchanged, deleted = self.added_items,\ - self.unchanged_items,\ - self.deleted_items - return attributes.History( - list(added), - list(unchanged), - list(deleted), - ) + added, unchanged, deleted = ( + self.added_items, + self.unchanged_items, + self.deleted_items, + ) + return attributes.History(list(added), list(unchanged), list(deleted)) def indexed(self, index): return list(self.added_items)[index] diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 4abf08ab13..ac031d84f5 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -14,17 +14,40 @@ from .. import util class UnevaluatableError(Exception): pass -_straight_ops = set(getattr(operators, op) - for op in ('add', 'mul', 'sub', - 'div', - 'mod', 'truediv', - 'lt', 'le', 'ne', 'gt', 'ge', 'eq')) - -_notimplemented_ops = set(getattr(operators, op) - for op in ('like_op', 'notlike_op', 'ilike_op', - 'notilike_op', 'between_op', 'in_op', - 'notin_op', 'endswith_op', 'concat_op')) +_straight_ops = set( + getattr(operators, op) + for op in ( + "add", + "mul", + "sub", + "div", + "mod", + "truediv", + "lt", + "le", + "ne", + "gt", + "ge", + "eq", + ) +) + + +_notimplemented_ops = set( + getattr(operators, op) + for op in ( + "like_op", + "notlike_op", + "ilike_op", + "notilike_op", + "between_op", + "in_op", + "notin_op", + "endswith_op", + "concat_op", + ) +) class EvaluatorCompiler(object): @@ -35,7 +58,8 @@ class EvaluatorCompiler(object): meth = getattr(self, "visit_%s" % clause.__visit_name__, None) if not meth: raise UnevaluatableError( - "Cannot evaluate %s" % type(clause).__name__) + "Cannot evaluate %s" % type(clause).__name__ + ) return meth(clause) def visit_grouping(self, clause): @@ -51,28 +75,30 @@ class EvaluatorCompiler(object): return lambda obj: True def visit_column(self, clause): - if 'parentmapper' in clause._annotations: - parentmapper = clause._annotations['parentmapper'] + if "parentmapper" in clause._annotations: + parentmapper = clause._annotations["parentmapper"] if self.target_cls and not issubclass( - self.target_cls, parentmapper.class_): + self.target_cls, parentmapper.class_ + ): raise UnevaluatableError( - "Can't evaluate criteria against alternate class %s" % - parentmapper.class_ + "Can't evaluate criteria against alternate class %s" + % parentmapper.class_ ) key = parentmapper._columntoproperty[clause].key else: key = clause.key - if self.target_cls and \ - key in inspect(self.target_cls).column_attrs: + if ( + self.target_cls + and key in inspect(self.target_cls).column_attrs + ): util.warn( "Evaluating non-mapped column expression '%s' onto " "ORM instances; this is a deprecated use case. Please " "make use of the actual mapped columns in ORM-evaluated " - "UPDATE / DELETE expressions." % clause) - else: - raise UnevaluatableError( - "Cannot evaluate column: %s" % clause + "UPDATE / DELETE expressions." 
% clause ) + else: + raise UnevaluatableError("Cannot evaluate column: %s" % clause) get_corresponding_attr = operator.attrgetter(key) return lambda obj: get_corresponding_attr(obj) @@ -80,6 +106,7 @@ class EvaluatorCompiler(object): def visit_clauselist(self, clause): evaluators = list(map(self.process, clause.clauses)) if clause.operator is operators.or_: + def evaluate(obj): has_null = False for sub_evaluate in evaluators: @@ -90,7 +117,9 @@ class EvaluatorCompiler(object): if has_null: return None return False + elif clause.operator is operators.and_: + def evaluate(obj): for sub_evaluate in evaluators: value = sub_evaluate(obj) @@ -99,48 +128,60 @@ class EvaluatorCompiler(object): return None return False return True + else: raise UnevaluatableError( - "Cannot evaluate clauselist with operator %s" % - clause.operator) + "Cannot evaluate clauselist with operator %s" % clause.operator + ) return evaluate def visit_binary(self, clause): - eval_left, eval_right = list(map(self.process, - [clause.left, clause.right])) + eval_left, eval_right = list( + map(self.process, [clause.left, clause.right]) + ) operator = clause.operator if operator is operators.is_: + def evaluate(obj): return eval_left(obj) == eval_right(obj) + elif operator is operators.isnot: + def evaluate(obj): return eval_left(obj) != eval_right(obj) + elif operator in _straight_ops: + def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) if left_val is None or right_val is None: return None return operator(eval_left(obj), eval_right(obj)) + else: raise UnevaluatableError( - "Cannot evaluate %s with operator %s" % - (type(clause).__name__, clause.operator)) + "Cannot evaluate %s with operator %s" + % (type(clause).__name__, clause.operator) + ) return evaluate def visit_unary(self, clause): eval_inner = self.process(clause.element) if clause.operator is operators.inv: + def evaluate(obj): value = eval_inner(obj) if value is None: return None return not value + return evaluate raise UnevaluatableError( - "Cannot evaluate %s with operator %s" % - (type(clause).__name__, clause.operator)) + "Cannot evaluate %s with operator %s" + % (type(clause).__name__, clause.operator) + ) def visit_bindparam(self, clause): if clause.callable: diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c414f548ee..c2a2d15ee3 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -20,6 +20,7 @@ from .attributes import QueryableAttribute from .query import Query from sqlalchemy.util.compat import inspect_getargspec + class InstrumentationEvents(event.Events): """Events related to class instrumentation events. 
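A rough usage sketch for this event family (listener and base-class names are invented; this block is not part of the diff): instrumentation events are listened for on a class target, and per the propagate=True default visible in _listen() below, they also fire for subclasses:

    from sqlalchemy import event
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()


    # hypothetical listener; fires after SQLAlchemy applies
    # instrumentation to Base or, via propagation, any subclass
    @event.listens_for(Base, "class_instrument")
    def on_class_instrument(cls):
        print("instrumented class: %s" % cls.__name__)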
@@ -61,9 +62,11 @@ class InstrumentationEvents(event.Events): @classmethod def _listen(cls, event_key, propagate=True, **kw): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, \ - event_key._listen_fn + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) def listen(target_cls, *arg): listen_cls = target() @@ -74,16 +77,20 @@ class InstrumentationEvents(event.Events): def remove(ref): key = event.registry._EventKey( - None, identifier, listen, - instrumentation._instrumentation_factory) - getattr(instrumentation._instrumentation_factory.dispatch, - identifier).remove(key) + None, + identifier, + listen, + instrumentation._instrumentation_factory, + ) + getattr( + instrumentation._instrumentation_factory.dispatch, identifier + ).remove(key) target = weakref.ref(target.class_, remove) - event_key.\ - with_dispatch_target(instrumentation._instrumentation_factory).\ - with_wrapper(listen).base_listen(**kw) + event_key.with_dispatch_target( + instrumentation._instrumentation_factory + ).with_wrapper(listen).base_listen(**kw) @classmethod def _clear(cls): @@ -193,21 +200,24 @@ class InstanceEvents(event.Events): @classmethod def _listen(cls, event_key, raw=False, propagate=False, **kw): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, \ - event_key._listen_fn + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) if not raw: + def wrap(state, *arg, **kw): return fn(state.obj(), *arg, **kw) + event_key = event_key.with_wrapper(wrap) event_key.base_listen(propagate=propagate, **kw) if propagate: for mgr in target.subclass_managers(True): - event_key.with_dispatch_target(mgr).base_listen( - propagate=True) + event_key.with_dispatch_target(mgr).base_listen(propagate=True) @classmethod def _clear(cls): @@ -438,10 +448,13 @@ class _EventsHold(event.RefCollection): @classmethod def _listen( - cls, event_key, raw=False, propagate=False, - retval=False, **kw): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + cls, event_key, raw=False, propagate=False, retval=False, **kw + ): + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key.fn, + ) if target.class_ in target.all_holds: collection = target.all_holds[target.class_] @@ -460,12 +473,16 @@ class _EventsHold(event.RefCollection): if subject is not None: # we are already going through __subclasses__() # so leave generic propagate flag False - event_key.with_dispatch_target(subject).\ - listen(raw=raw, propagate=False, retval=retval, **kw) + event_key.with_dispatch_target(subject).listen( + raw=raw, propagate=False, retval=retval, **kw + ) def remove(self, event_key): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, event_key.fn + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key.fn, + ) if isinstance(target, _EventsHold): collection = target.all_holds[target.class_] @@ -483,8 +500,9 @@ class _EventsHold(event.RefCollection): # populate(), we rely upon _EventsHold for all event # assignment, instead of using the generic propagate # flag. 
- event_key.with_dispatch_target(subject).\ - listen(raw=raw, propagate=False, retval=retval) + event_key.with_dispatch_target(subject).listen( + raw=raw, propagate=False, retval=retval + ) class _InstanceEventsHold(_EventsHold): @@ -594,24 +612,31 @@ class MapperEvents(event.Events): @classmethod def _listen( - cls, event_key, raw=False, retval=False, propagate=False, **kw): - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, \ - event_key._listen_fn - - if identifier in ("before_configured", "after_configured") and \ - target is not mapperlib.Mapper: + cls, event_key, raw=False, retval=False, propagate=False, **kw + ): + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + + if ( + identifier in ("before_configured", "after_configured") + and target is not mapperlib.Mapper + ): util.warn( "'before_configured' and 'after_configured' ORM events " "only invoke with the mapper() function or Mapper class " - "as the target.") + "as the target." + ) if not raw or not retval: if not raw: meth = getattr(cls, identifier) try: - target_index = \ - inspect_getargspec(meth)[0].index('target') - 1 + target_index = ( + inspect_getargspec(meth)[0].index("target") - 1 + ) except ValueError: target_index = None @@ -624,12 +649,14 @@ class MapperEvents(event.Events): return interfaces.EXT_CONTINUE else: return fn(*arg, **kw) + event_key = event_key.with_wrapper(wrap) if propagate: for mapper in target.self_and_descendants: event_key.with_dispatch_target(mapper).base_listen( - propagate=True, **kw) + propagate=True, **kw + ) else: event_key.base_listen(**kw) @@ -1219,15 +1246,14 @@ class SessionEvents(event.Events): if isinstance(target, scoped_session): target = target.session_factory - if not isinstance(target, sessionmaker) and \ - ( - not isinstance(target, type) or - not issubclass(target, Session) + if not isinstance(target, sessionmaker) and ( + not isinstance(target, type) or not issubclass(target, Session) ): raise exc.ArgumentError( "Session event listen on a scoped_session " "requires that its creation callable " - "is associated with the Session class.") + "is associated with the Session class." + ) if isinstance(target, sessionmaker): return target.class_ @@ -1561,13 +1587,16 @@ class SessionEvents(event.Events): """ - @event._legacy_signature("0.9", - ["session", "query", "query_context", "result"], - lambda update_context: ( - update_context.session, - update_context.query, - update_context.context, - update_context.result)) + @event._legacy_signature( + "0.9", + ["session", "query", "query_context", "result"], + lambda update_context: ( + update_context.session, + update_context.query, + update_context.context, + update_context.result, + ), + ) def after_bulk_update(self, update_context): """Execute after a bulk update operation to the session. @@ -1587,13 +1616,16 @@ class SessionEvents(event.Events): """ - @event._legacy_signature("0.9", - ["session", "query", "query_context", "result"], - lambda delete_context: ( - delete_context.session, - delete_context.query, - delete_context.context, - delete_context.result)) + @event._legacy_signature( + "0.9", + ["session", "query", "query_context", "result"], + lambda delete_context: ( + delete_context.session, + delete_context.query, + delete_context.context, + delete_context.result, + ), + ) def after_bulk_delete(self, delete_context): """Execute after a bulk delete operation to the session. 
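The modern one-argument form of this hook can be consumed as in the sketch below (the listener name is invented); the attributes read here are the same objects the _legacy_signature adapter above unpacks into a tuple for pre-0.9 callers:

    from sqlalchemy import event
    from sqlalchemy.orm import Session


    @event.listens_for(Session, "after_bulk_delete")
    def receive_after_bulk_delete(delete_context):
        # delete_context carries .session, .query, .context and
        # .result; .result is the cursor result of the DELETE
        print("bulk delete matched %d rows"
              % delete_context.result.rowcount)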
@@ -1927,18 +1959,26 @@ class AttributeEvents(event.Events): return target @classmethod - def _listen(cls, event_key, active_history=False, - raw=False, retval=False, - propagate=False): - - target, identifier, fn = \ - event_key.dispatch_target, event_key.identifier, \ - event_key._listen_fn + def _listen( + cls, + event_key, + active_history=False, + raw=False, + retval=False, + propagate=False, + ): + + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) if active_history: target.dispatch._active_history = True if not raw or not retval: + def wrap(target, *arg): if not raw: target = target.obj() @@ -1951,6 +1991,7 @@ class AttributeEvents(event.Events): return value else: return fn(target, *arg) + event_key = event_key.with_wrapper(wrap) event_key.base_listen(propagate=propagate) @@ -1959,8 +2000,9 @@ class AttributeEvents(event.Events): manager = instrumentation.manager_of_class(target.class_) for mgr in manager.subclass_managers(True): - event_key.with_dispatch_target( - mgr[target.key]).base_listen(propagate=True) + event_key.with_dispatch_target(mgr[target.key]).base_listen( + propagate=True + ) def append(self, target, value, initiator): """Receive a collection append event. @@ -2315,11 +2357,11 @@ class QueryEvents(event.Events): """ @classmethod - def _listen( - cls, event_key, retval=False, **kw): + def _listen(cls, event_key, retval=False, **kw): fn = event_key._listen_fn if not retval: + def wrap(*arg, **kw): if not retval: query = arg[0] @@ -2327,6 +2369,7 @@ class QueryEvents(event.Events): return query else: return fn(*arg, **kw) + event_key = event_key.with_wrapper(wrap) event_key.base_listen(**kw) diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index eb4baa08df..f0aa02e99e 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -38,6 +38,7 @@ class StaleDataError(sa_exc.SQLAlchemyError): """ + ConcurrentModificationError = StaleDataError @@ -72,16 +73,19 @@ class UnmappedInstanceError(UnmappedError): try: base.class_mapper(type(obj)) name = _safe_cls_name(type(obj)) - msg = ("Class %r is mapped, but this instance lacks " - "instrumentation. This occurs when the instance " - "is created before sqlalchemy.orm.mapper(%s) " - "was called." % (name, name)) + msg = ( + "Class %r is mapped, but this instance lacks " + "instrumentation. This occurs when the instance " + "is created before sqlalchemy.orm.mapper(%s) " + "was called." % (name, name) + ) except UnmappedClassError: msg = _default_unmapped(type(obj)) if isinstance(obj, type): msg += ( - '; was a class (%s) supplied where an instance was ' - 'required?' % _safe_cls_name(obj)) + "; was a class (%s) supplied where an instance was " + "required?" % _safe_cls_name(obj) + ) UnmappedError.__init__(self, msg) def __reduce__(self): @@ -119,11 +123,14 @@ class ObjectDeletedError(sa_exc.InvalidRequestError): object. """ + @util.dependencies("sqlalchemy.orm.base") def __init__(self, base, state, msg=None): if not msg: - msg = "Instance '%s' has been deleted, or its "\ + msg = ( + "Instance '%s' has been deleted, or its " "row is otherwise not present." 
% base.state_str(state) + ) sa_exc.InvalidRequestError.__init__(self, msg) @@ -145,9 +152,9 @@ class MultipleResultsFound(sa_exc.InvalidRequestError): def _safe_cls_name(cls): try: - cls_name = '.'.join((cls.__module__, cls.__name__)) + cls_name = ".".join((cls.__module__, cls.__name__)) except AttributeError: - cls_name = getattr(cls, '__name__', None) + cls_name = getattr(cls, "__name__", None) if cls_name is None: cls_name = repr(cls) return cls_name diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index b03bb0a0d4..2487cdb233 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -11,6 +11,7 @@ from .. import util from .. import exc as sa_exc from . import util as orm_util + class IdentityMap(object): def __init__(self): self._dict = {} @@ -84,7 +85,6 @@ class IdentityMap(object): class WeakInstanceDict(IdentityMap): - def __getitem__(self, key): state = self._dict[key] o = state.obj() @@ -145,8 +145,9 @@ class WeakInstanceDict(IdentityMap): raise sa_exc.InvalidRequestError( "Can't attach instance " "%s; another instance with key %s is already " - "present in this session." % ( - orm_util.state_str(state), state.key)) + "present in this session." + % (orm_util.state_str(state), state.key) + ) else: return False self._dict[key] = state @@ -253,6 +254,7 @@ class StrongInstanceDict(IdentityMap): """ if util.py2k: + def itervalues(self): return self._dict.itervalues() @@ -282,8 +284,9 @@ class StrongInstanceDict(IdentityMap): def contains_state(self, state): return ( - state.key in self and - attributes.instance_state(self[state.key]) is state) + state.key in self + and attributes.instance_state(self[state.key]) is state + ) def replace(self, state): if state.key in self._dict: @@ -303,8 +306,9 @@ class StrongInstanceDict(IdentityMap): raise sa_exc.InvalidRequestError( "Can't attach instance " "%s; another instance with key %s is already " - "present in this session." % ( - orm_util.state_str(state), state.key)) + "present in this session." + % (orm_util.state_str(state), state.key) + ) return False else: self._dict[state.key] = state.obj() diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index d34326e0fd..fa29c32333 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -59,11 +59,15 @@ class ClassManager(dict): self.local_attrs = {} self.originals = {} - self._bases = [mgr for mgr in [ - manager_of_class(base) - for base in self.class_.__bases__ - if isinstance(base, type) - ] if mgr is not None] + self._bases = [ + mgr + for mgr in [ + manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ] + if mgr is not None + ] for base in self._bases: self.update(base) @@ -78,12 +82,13 @@ class ClassManager(dict): self.manage() self._instrument_init() - if '__del__' in class_.__dict__: - util.warn("__del__() method on class %s will " - "cause unreachable cycles and memory leaks, " - "as SQLAlchemy instrumentation often creates " - "reference cycles. Please remove this method." % - class_) + if "__del__" in class_.__dict__: + util.warn( + "__del__() method on class %s will " + "cause unreachable cycles and memory leaks, " + "as SQLAlchemy instrumentation often creates " + "reference cycles. Please remove this method." 
% class_ + ) def __hash__(self): return id(self) @@ -93,7 +98,7 @@ class ClassManager(dict): @property def is_mapped(self): - return 'mapper' in self.__dict__ + return "mapper" in self.__dict__ @_memoized_key_collection def _all_key_set(self): @@ -101,14 +106,19 @@ class ClassManager(dict): @_memoized_key_collection def _collection_impl_keys(self): - return frozenset([ - attr.key for attr in self.values() if attr.impl.collection]) + return frozenset( + [attr.key for attr in self.values() if attr.impl.collection] + ) @_memoized_key_collection def _scalar_loader_impls(self): - return frozenset([ - attr.impl for attr in - self.values() if attr.impl.accepts_scalar_loader]) + return frozenset( + [ + attr.impl + for attr in self.values() + if attr.impl.accepts_scalar_loader + ] + ) @util.memoized_property def mapper(self): @@ -174,11 +184,11 @@ class ClassManager(dict): # of such, since this adds method overhead. self.original_init = self.class_.__init__ self.new_init = _generate_init(self.class_, self) - self.install_member('__init__', self.new_init) + self.install_member("__init__", self.new_init) def _uninstrument_init(self): if self.new_init: - self.uninstall_member('__init__') + self.uninstall_member("__init__") self.new_init = None @util.memoized_property @@ -239,8 +249,9 @@ class ClassManager(dict): yield m def post_configure_attribute(self, key): - _instrumentation_factory.dispatch.\ - attribute_instrument(self.class_, key, self[key]) + _instrumentation_factory.dispatch.attribute_instrument( + self.class_, key, self[key] + ) def uninstrument_attribute(self, key, propagated=False): if key not in self: @@ -272,9 +283,10 @@ class ClassManager(dict): def install_descriptor(self, key, inst): if key in (self.STATE_ATTR, self.MANAGER_ATTR): - raise KeyError("%r: requested attribute name conflicts with " - "instrumentation attribute of the same name." % - key) + raise KeyError( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key + ) setattr(self.class_, key, inst) def uninstall_descriptor(self, key): @@ -282,9 +294,10 @@ class ClassManager(dict): def install_member(self, key, implementation): if key in (self.STATE_ATTR, self.MANAGER_ATTR): - raise KeyError("%r: requested attribute name conflicts with " - "instrumentation attribute of the same name." % - key) + raise KeyError( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key + ) self.originals.setdefault(key, getattr(self.class_, key, None)) setattr(self.class_, key, implementation) @@ -299,7 +312,8 @@ class ClassManager(dict): def initialize_collection(self, key, state, factory): user_data = factory() adapter = collections.CollectionAdapter( - self.get_impl(key), state, user_data) + self.get_impl(key), state, user_data + ) return adapter, user_data def is_instrumented(self, key, search=False): @@ -343,15 +357,15 @@ class ClassManager(dict): """ if hasattr(instance, self.STATE_ATTR): return False - elif self.class_ is not instance.__class__ and \ - self.is_mapped: + elif self.class_ is not instance.__class__ and self.is_mapped: # this will create a new ClassManager for the # subclass, without a mapper. This is likely a # user error situation but allow the object # to be constructed, so that it is usable # in a non-ORM context at least. 
-            return self._subclass_manager(instance.__class__).\
-                _new_state_if_none(instance)
+            return self._subclass_manager(
+                instance.__class__
+            )._new_state_if_none(instance)
         else:
             state = self._state_constructor(instance, self)
             self._state_setter(instance, state)
@@ -371,8 +385,11 @@ class ClassManager(dict):
     __nonzero__ = __bool__
 
     def __repr__(self):
-        return '<%s of %r at %x>' % (
-            self.__class__.__name__, self.class_, id(self))
+        return "<%s of %r at %x>" % (
+            self.__class__.__name__,
+            self.class_,
+            id(self),
+        )
 
 
 class _SerializeManager(object):
@@ -396,8 +413,8 @@ class _SerializeManager(object):
                 "Cannot deserialize object of type %r - "
                 "no mapper() has "
                 "been configured for this class within the current "
-                "Python process!" %
-                self.class_)
+                "Python process!" % self.class_,
+            )
         elif manager.is_mapped and not manager.mapper.configured:
             manager.mapper._configure_all()
@@ -447,6 +464,7 @@ class InstrumentationFactory(object):
         if ClassManager.MANAGER_ATTR in class_.__dict__:
             delattr(class_, ClassManager.MANAGER_ATTR)
 
+
 # this attribute is replaced by sqlalchemy.ext.instrumentation
 # when imported.
 _instrumentation_factory = InstrumentationFactory()
@@ -488,8 +506,9 @@ def is_instrumented(instance, key):
     applied directly to the class, i.e. no descriptors are required.
 
     """
-    return manager_of_class(instance.__class__).\
-        is_instrumented(key, search=True)
+    return manager_of_class(instance.__class__).is_instrumented(
+        key, search=True
+    )
 
 
 def _generate_init(class_, class_manager):
@@ -518,15 +537,15 @@ def __init__(%(apply_pos)s):
     func_text = func_body % func_vars
 
     if util.py2k:
-        func = getattr(original__init__, 'im_func', original__init__)
-        func_defaults = getattr(func, 'func_defaults', None)
+        func = getattr(original__init__, "im_func", original__init__)
+        func_defaults = getattr(func, "func_defaults", None)
     else:
-        func_defaults = getattr(original__init__, '__defaults__', None)
-        func_kw_defaults = getattr(original__init__, '__kwdefaults__', None)
+        func_defaults = getattr(original__init__, "__defaults__", None)
+        func_kw_defaults = getattr(original__init__, "__kwdefaults__", None)
 
     env = locals().copy()
     exec(func_text, env)
-    __init__ = env['__init__']
+    __init__ = env["__init__"]
     __init__.__doc__ = original__init__.__doc__
     __init__._sa_original_init = original__init__
 
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 80d0a63037..d7e70c5d74 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -22,8 +22,15 @@ from __future__ import absolute_import
 
 from .. import util
 from ..sql import operators
-from .base import (ONETOMANY, MANYTOONE, MANYTOMANY,
-                   EXT_CONTINUE, EXT_STOP, EXT_SKIP, NOT_EXTENSION)
+from .base import (
+    ONETOMANY,
+    MANYTOONE,
+    MANYTOMANY,
+    EXT_CONTINUE,
+    EXT_STOP,
+    EXT_SKIP,
+    NOT_EXTENSION,
+)
 from .base import InspectionAttr, InspectionAttrInfo, _MappedAttribute
 import collections
 from .. import inspect
@@ -33,21 +40,21 @@ from . 
import path_registry MapperExtension = SessionExtension = AttributeExtension = None __all__ = ( - 'AttributeExtension', - 'EXT_CONTINUE', - 'EXT_STOP', - 'EXT_SKIP', - 'ONETOMANY', - 'MANYTOMANY', - 'MANYTOONE', - 'NOT_EXTENSION', - 'LoaderStrategy', - 'MapperExtension', - 'MapperOption', - 'MapperProperty', - 'PropComparator', - 'SessionExtension', - 'StrategizedProperty', + "AttributeExtension", + "EXT_CONTINUE", + "EXT_STOP", + "EXT_SKIP", + "ONETOMANY", + "MANYTOMANY", + "MANYTOONE", + "NOT_EXTENSION", + "LoaderStrategy", + "MapperExtension", + "MapperOption", + "MapperProperty", + "PropComparator", + "SessionExtension", + "StrategizedProperty", ) @@ -64,8 +71,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): """ __slots__ = ( - '_configure_started', '_configure_finished', 'parent', 'key', - 'info' + "_configure_started", + "_configure_finished", + "parent", + "key", + "info", ) cascade = frozenset() @@ -118,15 +128,17 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): """ - def create_row_processor(self, context, path, - mapper, result, adapter, populators): + def create_row_processor( + self, context, path, mapper, result, adapter, populators + ): """Produce row processing functions and append to the given set of populators lists. """ - def cascade_iterator(self, type_, state, visited_instances=None, - halt_on=None): + def cascade_iterator( + self, type_, state, visited_instances=None, halt_on=None + ): """Iterate through instances related to the given instance for a particular 'cascade', starting with this MapperProperty. @@ -234,17 +246,28 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): """ - def merge(self, session, source_state, source_dict, dest_state, - dest_dict, load, _recursive, _resolve_conflict_map): + def merge( + self, + session, + source_state, + source_dict, + dest_state, + dest_dict, + load, + _recursive, + _resolve_conflict_map, + ): """Merge the attribute represented by this ``MapperProperty`` from source to destination object. 
""" def __repr__(self): - return '<%s at 0x%x; %s>' % ( + return "<%s at 0x%x; %s>" % ( self.__class__.__name__, - id(self), getattr(self, 'key', 'no key')) + id(self), + getattr(self, "key", "no key"), + ) class PropComparator(operators.ColumnOperators): @@ -335,7 +358,7 @@ class PropComparator(operators.ColumnOperators): """ - __slots__ = 'prop', 'property', '_parententity', '_adapt_to_entity' + __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" def __init__(self, prop, parentmapper, adapt_to_entity=None): self.prop = self.property = prop @@ -467,21 +490,27 @@ class StrategizedProperty(MapperProperty): """ __slots__ = ( - '_strategies', 'strategy', - '_wildcard_token', '_default_path_loader_key' + "_strategies", + "strategy", + "_wildcard_token", + "_default_path_loader_key", ) strategy_wildcard_key = None def _memoized_attr__wildcard_token(self): - return ("%s:%s" % ( - self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN), ) + return ( + "%s:%s" + % (self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN), + ) def _memoized_attr__default_path_loader_key(self): return ( "loader", - ("%s:%s" % ( - self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN), ) + ( + "%s:%s" + % (self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN), + ), ) def _get_context_loader(self, context, path): @@ -496,7 +525,7 @@ class StrategizedProperty(MapperProperty): for path_key in ( search_path._loader_key, search_path._wildcard_path_loader_key, - search_path._default_path_loader_key + search_path._default_path_loader_key, ): if path_key in context.attributes: load = context.attributes[path_key] @@ -509,12 +538,12 @@ class StrategizedProperty(MapperProperty): return self._strategies[key] except KeyError: cls = self._strategy_lookup(*key) - self._strategies[key] = self._strategies[ - cls] = strategy = cls(self, key) + self._strategies[key] = self._strategies[cls] = strategy = cls( + self, key + ) return strategy - def setup( - self, context, entity, path, adapter, **kwargs): + def setup(self, context, entity, path, adapter, **kwargs): loader = self._get_context_loader(context, path) if loader and loader.strategy: strat = self._get_strategy(loader.strategy) @@ -523,24 +552,26 @@ class StrategizedProperty(MapperProperty): strat.setup_query(context, entity, path, loader, adapter, **kwargs) def create_row_processor( - self, context, path, mapper, - result, adapter, populators): + self, context, path, mapper, result, adapter, populators + ): loader = self._get_context_loader(context, path) if loader and loader.strategy: strat = self._get_strategy(loader.strategy) else: strat = self.strategy strat.create_row_processor( - context, path, loader, - mapper, result, adapter, populators) + context, path, loader, mapper, result, adapter, populators + ) def do_init(self): self._strategies = {} self.strategy = self._get_strategy(self.strategy_key) def post_instrument_class(self, mapper): - if not self.parent.non_primary and \ - not mapper.class_manager._attr_has_impl(self.key): + if ( + not self.parent.non_primary + and not mapper.class_manager._attr_has_impl(self.key) + ): self.strategy.init_class_attribute(mapper) _all_strategies = collections.defaultdict(dict) @@ -550,12 +581,13 @@ class StrategizedProperty(MapperProperty): def decorate(dec_cls): # ensure each subclass of the strategy has its # own _strategy_keys collection - if '_strategy_keys' not in dec_cls.__dict__: + if "_strategy_keys" not in dec_cls.__dict__: dec_cls._strategy_keys = [] key = tuple(sorted(kw.items())) 
cls._all_strategies[cls][key] = dec_cls dec_cls._strategy_keys.append(key) return dec_cls + return decorate @classmethod @@ -671,8 +703,14 @@ class LoaderStrategy(object): """ - __slots__ = 'parent_property', 'is_class_level', 'parent', 'key', \ - 'strategy_key', 'strategy_opts' + __slots__ = ( + "parent_property", + "is_class_level", + "parent", + "key", + "strategy_key", + "strategy_opts", + ) def __init__(self, parent, strategy_key): self.parent_property = parent @@ -695,8 +733,9 @@ class LoaderStrategy(object): """ - def create_row_processor(self, context, path, loadopt, mapper, - result, adapter, populators): + def create_row_processor( + self, context, path, loadopt, mapper, result, adapter, populators + ): """Establish row processing functions for a given QueryContext. This method fulfills the contract specified by diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 0a6f8023aa..96eddcb326 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -37,32 +37,35 @@ def instances(query, cursor, context): filtered = query._has_mapper_entities - single_entity = not query._only_return_tuples and \ - len(query._entities) == 1 and \ - query._entities[0].supports_single_entity + single_entity = ( + not query._only_return_tuples + and len(query._entities) == 1 + and query._entities[0].supports_single_entity + ) if filtered: if single_entity: filter_fn = id else: + def filter_fn(row): return tuple( - id(item) - if ent.use_id_for_hash - else item + id(item) if ent.use_id_for_hash else item for ent, item in zip(query._entities, row) ) try: - (process, labels) = \ - list(zip(*[ - query_entity.row_processor(query, - context, cursor) - for query_entity in query._entities - ])) + (process, labels) = list( + zip( + *[ + query_entity.row_processor(query, context, cursor) + for query_entity in query._entities + ] + ) + ) if not single_entity: - keyed_tuple = util.lightweight_named_tuple('result', labels) + keyed_tuple = util.lightweight_named_tuple("result", labels) while True: context.partials = {} @@ -78,11 +81,12 @@ def instances(query, cursor, context): proc = process[0] rows = [proc(row) for row in fetch] else: - rows = [keyed_tuple([proc(row) for proc in process]) - for row in fetch] + rows = [ + keyed_tuple([proc(row) for proc in process]) + for row in fetch + ] - for path, post_load in \ - context.post_load_paths.items(): + for path, post_load in context.post_load_paths.items(): post_load.invoke(context, path) if filtered: @@ -113,19 +117,27 @@ def merge_result(querylib, query, iterator, load=True): single_entity = len(query._entities) == 1 if single_entity: if isinstance(query._entities[0], querylib._MapperEntity): - result = [session._merge( - attributes.instance_state(instance), - attributes.instance_dict(instance), - load=load, _recursive={}, _resolve_conflict_map={}) - for instance in iterator] + result = [ + session._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + for instance in iterator + ] else: result = list(iterator) else: - mapped_entities = [i for i, e in enumerate(query._entities) - if isinstance(e, querylib._MapperEntity)] + mapped_entities = [ + i + for i, e in enumerate(query._entities) + if isinstance(e, querylib._MapperEntity) + ] result = [] keys = [ent._label_name for ent in query._entities] - keyed_tuple = util.lightweight_named_tuple('result', keys) + keyed_tuple = util.lightweight_named_tuple("result", keys) for row in 
iterator: newrow = list(row) for i in mapped_entities: @@ -133,7 +145,10 @@ def merge_result(querylib, query, iterator, load=True): newrow[i] = session._merge( attributes.instance_state(newrow[i]), attributes.instance_dict(newrow[i]), - load=load, _recursive={}, _resolve_conflict_map={}) + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) result.append(keyed_tuple(newrow)) return iter(result) @@ -170,9 +185,9 @@ def get_from_identity(session, key, passive): return None -def load_on_ident(query, key, - refresh_state=None, with_for_update=None, - only_load_props=None): +def load_on_ident( + query, key, refresh_state=None, with_for_update=None, only_load_props=None +): """Load the given identity key from the database.""" if key is not None: @@ -182,16 +197,23 @@ def load_on_ident(query, key, ident = identity_token = None return load_on_pk_identity( - query, ident, refresh_state=refresh_state, + query, + ident, + refresh_state=refresh_state, with_for_update=with_for_update, only_load_props=only_load_props, - identity_token=identity_token + identity_token=identity_token, ) -def load_on_pk_identity(query, primary_key_identity, - refresh_state=None, with_for_update=None, - only_load_props=None, identity_token=None): +def load_on_pk_identity( + query, + primary_key_identity, + refresh_state=None, + with_for_update=None, + only_load_props=None, + identity_token=None, +): """Load the given primary key identity from the database.""" @@ -209,22 +231,28 @@ def load_on_pk_identity(query, primary_key_identity, # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set([ - _get_params[col].key for col, value in - zip(mapper.primary_key, primary_key_identity) - if value is None - ]) - _get_clause = sql_util.adapt_criterion_to_null( - _get_clause, nones) + nones = set( + [ + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + ] + ) + _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) _get_clause = q._adapt_clause(_get_clause, True, False) q._criterion = _get_clause - params = dict([ - (_get_params[primary_key].key, id_val) - for id_val, primary_key - in zip(primary_key_identity, mapper.primary_key) - ]) + params = dict( + [ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + ] + ) q._params = params @@ -243,7 +271,8 @@ def load_on_pk_identity(query, primary_key_identity, version_check=version_check, only_load_props=only_load_props, refresh_state=refresh_state, - identity_token=identity_token) + identity_token=identity_token, + ) q._order_by = None try: @@ -253,27 +282,31 @@ def load_on_pk_identity(query, primary_key_identity, def _setup_entity_query( - context, mapper, query_entity, - path, adapter, column_collection, - with_polymorphic=None, only_load_props=None, - polymorphic_discriminator=None, **kw): + context, + mapper, + query_entity, + path, + adapter, + column_collection, + with_polymorphic=None, + only_load_props=None, + polymorphic_discriminator=None, + **kw +): if with_polymorphic: poly_properties = mapper._iterate_polymorphic_properties( - with_polymorphic) + with_polymorphic + ) else: poly_properties = mapper._polymorphic_properties quick_populators = {} - path.set( - context.attributes, - "memoized_setups", - quick_populators) + path.set(context.attributes, "memoized_setups", quick_populators) for value in poly_properties: - if only_load_props and \ - value.key not in 
only_load_props: + if only_load_props and value.key not in only_load_props: continue value.setup( context, @@ -286,9 +319,10 @@ def _setup_entity_query( **kw ) - if polymorphic_discriminator is not None and \ - polymorphic_discriminator \ - is not mapper.polymorphic_on: + if ( + polymorphic_discriminator is not None + and polymorphic_discriminator is not mapper.polymorphic_on + ): if adapter: pd = adapter.columns[polymorphic_discriminator] @@ -298,10 +332,16 @@ def _setup_entity_query( def _instance_processor( - mapper, context, result, path, adapter, - only_load_props=None, refresh_state=None, - polymorphic_discriminator=None, - _polymorphic_from=None): + mapper, + context, + result, + path, + adapter, + only_load_props=None, + refresh_state=None, + polymorphic_discriminator=None, + _polymorphic_from=None, +): """Produce a mapper level row processor callable which processes rows into mapped instances.""" @@ -322,11 +362,11 @@ def _instance_processor( props = mapper._prop_set if only_load_props is not None: - props = props.intersection( - mapper._props[k] for k in only_load_props) + props = props.intersection(mapper._props[k] for k in only_load_props) quick_populators = path.get( - context.attributes, "memoized_setups", _none_set) + context.attributes, "memoized_setups", _none_set + ) for prop in props: if prop in quick_populators: @@ -334,7 +374,8 @@ def _instance_processor( col = quick_populators[prop] if col is _DEFER_FOR_STATE: populators["new"].append( - (prop.key, prop._deferred_column_loader)) + (prop.key, prop._deferred_column_loader) + ) elif col is _SET_DEFERRED_EXPIRED: # note that in this path, we are no longer # searching in the result to see if the column might @@ -366,14 +407,19 @@ def _instance_processor( # will iterate through all of its columns # to see if one fits prop.create_row_processor( - context, path, mapper, result, adapter, populators) + context, path, mapper, result, adapter, populators + ) else: prop.create_row_processor( - context, path, mapper, result, adapter, populators) + context, path, mapper, result, adapter, populators + ) propagate_options = context.propagate_options - load_path = context.query._current_path + path \ - if context.query._current_path.path else path + load_path = ( + context.query._current_path + path + if context.query._current_path.path + else path + ) session_identity_map = context.session.identity_map @@ -391,18 +437,18 @@ def _instance_processor( identity_token = context.identity_token if not refresh_state and _polymorphic_from is not None: - key = ('loader', path.path) - if ( - key in context.attributes and - context.attributes[key].strategy == - (('selectinload_polymorphic', True), ) + key = ("loader", path.path) + if key in context.attributes and context.attributes[key].strategy == ( + ("selectinload_polymorphic", True), ): selectin_load_via = mapper._should_selectin_load( - context.attributes[key].local_opts['entities'], - _polymorphic_from) + context.attributes[key].local_opts["entities"], + _polymorphic_from, + ) else: selectin_load_via = mapper._should_selectin_load( - None, _polymorphic_from) + None, _polymorphic_from + ) if selectin_load_via and selectin_load_via is not _polymorphic_from: # only_load_props goes w/ refresh_state only, and in a refresh @@ -413,9 +459,13 @@ def _instance_processor( callable_ = _load_subclass_via_in(context, path, selectin_load_via) PostLoad.callable_for_path( - context, load_path, selectin_load_via.mapper, + context, + load_path, + selectin_load_via.mapper, + selectin_load_via, + callable_, 
selectin_load_via, - callable_, selectin_load_via) + ) post_load = PostLoad.for_context(context, load_path, only_load_props) @@ -425,8 +475,9 @@ def _instance_processor( # super-rare condition; a refresh is being called # on a non-instance-key instance; this is meant to only # occur within a flush() - refresh_identity_key = \ - mapper._identity_key_from_state(refresh_state) + refresh_identity_key = mapper._identity_key_from_state( + refresh_state + ) else: refresh_identity_key = None @@ -452,7 +503,7 @@ def _instance_processor( identitykey = ( identity_class, tuple([row[column] for column in pk_cols]), - identity_token + identity_token, ) instance = session_identity_map.get(identitykey) @@ -507,8 +558,16 @@ def _instance_processor( state.load_path = load_path _populate_full( - context, row, state, dict_, isnew, load_path, - loaded_instance, populate_existing, populators) + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + populate_existing, + populators, + ) if isnew: if loaded_instance: @@ -518,7 +577,8 @@ def _instance_processor( loaded_as_persistent(context.session, state.obj()) elif refresh_evt: state.manager.dispatch.refresh( - state, context, only_load_props) + state, context, only_load_props + ) if populate_existing or state.modified: if refresh_state and only_load_props: @@ -542,13 +602,19 @@ def _instance_processor( # and add to the "context.partials" collection. to_load = _populate_partial( - context, row, state, dict_, isnew, load_path, - unloaded, populators) + context, + row, + state, + dict_, + isnew, + load_path, + unloaded, + populators, + ) if isnew: if refresh_evt: - state.manager.dispatch.refresh( - state, context, to_load) + state.manager.dispatch.refresh(state, context, to_load) state._commit(dict_, to_load) @@ -561,8 +627,14 @@ def _instance_processor( # if we are doing polymorphic, dispatch to a different _instance() # method specific to the subclass mapper _instance = _decorate_polymorphic_switch( - _instance, context, mapper, result, path, - polymorphic_discriminator, adapter) + _instance, + context, + mapper, + result, + path, + polymorphic_discriminator, + adapter, + ) return _instance @@ -581,14 +653,13 @@ def _load_subclass_via_in(context, path, entity): orig_query = context.query q2 = q._with_lazyload_options( - (enable_opt, ) + orig_query._with_options + (disable_opt, ), - path.parent, cache_path=path + (enable_opt,) + orig_query._with_options + (disable_opt,), + path.parent, + cache_path=path, ) if orig_query._populate_existing: - q2.add_criteria( - lambda q: q.populate_existing() - ) + q2.add_criteria(lambda q: q.populate_existing()) q2(context.session).params( primary_keys=[ @@ -601,8 +672,16 @@ def _load_subclass_via_in(context, path, entity): def _populate_full( - context, row, state, dict_, isnew, load_path, - loaded_instance, populate_existing, populators): + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + populate_existing, + populators, +): if isnew: # first time we are seeing a row with this identity. 
state.runid = context.runid @@ -650,8 +729,8 @@ def _populate_full( def _populate_partial( - context, row, state, dict_, isnew, load_path, - unloaded, populators): + context, row, state, dict_, isnew, load_path, unloaded, populators +): if not isnew: to_load = context.partials[state] @@ -693,19 +772,32 @@ def _validate_version_id(mapper, state, dict_, row, adapter): if adapter: version_id_col = adapter.columns[version_id_col] - if mapper._get_state_attr_by_column( - state, dict_, mapper.version_id_col) != row[version_id_col]: + if ( + mapper._get_state_attr_by_column(state, dict_, mapper.version_id_col) + != row[version_id_col] + ): raise orm_exc.StaleDataError( "Instance '%s' has version id '%s' which " "does not match database-loaded version id '%s'." - % (state_str(state), mapper._get_state_attr_by_column( - state, dict_, mapper.version_id_col), - row[version_id_col])) + % ( + state_str(state), + mapper._get_state_attr_by_column( + state, dict_, mapper.version_id_col + ), + row[version_id_col], + ) + ) def _decorate_polymorphic_switch( - instance_fn, context, mapper, result, path, - polymorphic_discriminator, adapter): + instance_fn, + context, + mapper, + result, + path, + polymorphic_discriminator, + adapter, +): if polymorphic_discriminator is not None: polymorphic_on = polymorphic_discriminator else: @@ -721,19 +813,22 @@ def _decorate_polymorphic_switch( sub_mapper = mapper.polymorphic_map[discriminator] except KeyError: raise AssertionError( - "No such polymorphic_identity %r is defined" % - discriminator) + "No such polymorphic_identity %r is defined" % discriminator + ) else: if sub_mapper is mapper: return None return _instance_processor( - sub_mapper, context, result, - path, adapter, _polymorphic_from=mapper) + sub_mapper, + context, + result, + path, + adapter, + _polymorphic_from=mapper, + ) - polymorphic_instances = util.PopulateDict( - configure_subclass_mapper - ) + polymorphic_instances = util.PopulateDict(configure_subclass_mapper) def polymorphic_instance(row): discriminator = row[polymorphic_on] @@ -742,6 +837,7 @@ def _decorate_polymorphic_switch( if _instance: return _instance(row) return instance_fn(row) + return polymorphic_instance @@ -749,7 +845,8 @@ class PostLoad(object): """Track loaders and states for "post load" operations. 
""" - __slots__ = 'loaders', 'states', 'load_keys' + + __slots__ = "loaders", "states", "load_keys" def __init__(self): self.loaders = {} @@ -770,8 +867,7 @@ class PostLoad(object): for token, limit_to_mapper, loader, arg, kw in self.loaders.values(): states = [ (state, overwrite) - for state, overwrite - in self.states.items() + for state, overwrite in self.states.items() if state.manager.mapper.isa(limit_to_mapper) ] if states: @@ -787,13 +883,15 @@ class PostLoad(object): @classmethod def path_exists(self, context, path, key): - return path.path in context.post_load_paths and \ - key in context.post_load_paths[path.path].loaders + return ( + path.path in context.post_load_paths + and key in context.post_load_paths[path.path].loaders + ) @classmethod def callable_for_path( - cls, context, path, limit_to_mapper, token, - loader_callable, *arg, **kw): + cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw + ): if path.path in context.post_load_paths: pl = context.post_load_paths[path.path] else: @@ -809,8 +907,8 @@ def load_scalar_attributes(mapper, state, attribute_names): if not session: raise orm_exc.DetachedInstanceError( "Instance %s is not bound to a Session; " - "attribute refresh operation cannot proceed" % - (state_str(state))) + "attribute refresh operation cannot proceed" % (state_str(state)) + ) has_key = bool(state.key) @@ -833,13 +931,12 @@ def load_scalar_attributes(mapper, state, attribute_names): statement = mapper._optimized_get_statement(state, attribute_names) if statement is not None: result = load_on_ident( - session.query(mapper). - options( - strategy_options.Load(mapper).undefer("*") - ).from_statement(statement), + session.query(mapper) + .options(strategy_options.Load(mapper).undefer("*")) + .from_statement(statement), None, only_load_props=attribute_names, - refresh_state=state + refresh_state=state, ) if result is False: @@ -850,30 +947,34 @@ def load_scalar_attributes(mapper, state, attribute_names): # object is becoming persistent but hasn't yet been assigned # an identity_key. # check here to ensure we have the attrs we need. - pk_attrs = [mapper._columntoproperty[col].key - for col in mapper.primary_key] + pk_attrs = [ + mapper._columntoproperty[col].key for col in mapper.primary_key + ] if state.expired_attributes.intersection(pk_attrs): raise sa_exc.InvalidRequestError( "Instance %s cannot be refreshed - it's not " " persistent and does not " - "contain a full primary key." % state_str(state)) + "contain a full primary key." 
% state_str(state) + ) identity_key = mapper._identity_key_from_state(state) - if (_none_set.issubset(identity_key) and - not mapper.allow_partial_pks) or \ - _none_set.issuperset(identity_key): + if ( + _none_set.issubset(identity_key) and not mapper.allow_partial_pks + ) or _none_set.issuperset(identity_key): util.warn_limited( "Instance %s to be refreshed doesn't " "contain a full primary key - can't be refreshed " "(and shouldn't be expired, either).", - state_str(state)) + state_str(state), + ) return result = load_on_ident( session.query(mapper), identity_key, refresh_state=state, - only_load_props=attribute_names) + only_load_props=attribute_names, + ) # if instance is pending, a refresh operation # may not complete (even if PK attributes are assigned) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index fa731f7298..ea88907889 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -26,12 +26,21 @@ from ..sql import expression, visitors, operators, util as sql_util from . import instrumentation, attributes, exc as orm_exc, loading from . import properties from . import util as orm_util -from .interfaces import MapperProperty, InspectionAttr, _MappedAttribute, \ - EXT_SKIP - - -from .base import _class_to_mapper, _state_mapper, class_mapper, \ - state_str, _INSTRUMENTOR +from .interfaces import ( + MapperProperty, + InspectionAttr, + _MappedAttribute, + EXT_SKIP, +) + + +from .base import ( + _class_to_mapper, + _state_mapper, + class_mapper, + state_str, + _INSTRUMENTOR, +) from .path_registry import PathRegistry import sys @@ -46,7 +55,7 @@ _memoized_configured_property = util.group_expirable_memoized_property() # a constant returned by _get_attr_by_column to indicate # this mapper is not handling an attribute for a particular # column -NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE') +NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE") # lock used to synchronize the "mapper configure" step _CONFIGURE_MUTEX = util.threading.RLock() @@ -90,38 +99,39 @@ class Mapper(InspectionAttr): _new_mappers = False _dispose_called = False - def __init__(self, - class_, - local_table=None, - properties=None, - primary_key=None, - non_primary=False, - inherits=None, - inherit_condition=None, - inherit_foreign_keys=None, - extension=None, - order_by=False, - always_refresh=False, - version_id_col=None, - version_id_generator=None, - polymorphic_on=None, - _polymorphic_map=None, - polymorphic_identity=None, - concrete=False, - with_polymorphic=None, - polymorphic_load=None, - allow_partial_pks=True, - batch=True, - column_prefix=None, - include_properties=None, - exclude_properties=None, - passive_updates=True, - passive_deletes=False, - confirm_deleted_rows=True, - eager_defaults=False, - legacy_is_orphan=False, - _compiled_cache_size=100, - ): + def __init__( + self, + class_, + local_table=None, + properties=None, + primary_key=None, + non_primary=False, + inherits=None, + inherit_condition=None, + inherit_foreign_keys=None, + extension=None, + order_by=False, + always_refresh=False, + version_id_col=None, + version_id_generator=None, + polymorphic_on=None, + _polymorphic_map=None, + polymorphic_identity=None, + concrete=False, + with_polymorphic=None, + polymorphic_load=None, + allow_partial_pks=True, + batch=True, + column_prefix=None, + include_properties=None, + exclude_properties=None, + passive_updates=True, + passive_deletes=False, + confirm_deleted_rows=True, + eager_defaults=False, + legacy_is_orphan=False, + _compiled_cache_size=100, + ): r"""Return a 
new :class:`~.Mapper` object. This function is typically used behind the scenes @@ -588,7 +598,7 @@ class Mapper(InspectionAttr): """ - self.class_ = util.assert_arg_type(class_, type, 'class_') + self.class_ = util.assert_arg_type(class_, type, "class_") self.class_manager = None @@ -600,7 +610,8 @@ class Mapper(InspectionAttr): util.warn_deprecated( "Mapper.order_by is deprecated." "Use Query.order_by() in order to affect the ordering of ORM " - "result sets.") + "result sets." + ) else: self.order_by = order_by @@ -631,7 +642,8 @@ class Mapper(InspectionAttr): self.eager_defaults = eager_defaults self.column_prefix = column_prefix self.polymorphic_on = expression._clause_element_as_expr( - polymorphic_on) + polymorphic_on + ) self._dependency_processors = [] self.validators = util.immutabledict() self.passive_updates = passive_updates @@ -974,14 +986,16 @@ class Mapper(InspectionAttr): self.inherits = class_mapper(self.inherits, configure=False) if not issubclass(self.class_, self.inherits.class_): raise sa_exc.ArgumentError( - "Class '%s' does not inherit from '%s'" % - (self.class_.__name__, self.inherits.class_.__name__)) + "Class '%s' does not inherit from '%s'" + % (self.class_.__name__, self.inherits.class_.__name__) + ) if self.non_primary != self.inherits.non_primary: np = not self.non_primary and "primary" or "non-primary" raise sa_exc.ArgumentError( "Inheritance of %s mapper for class '%s' is " - "only allowed from a %s mapper" % - (np, self.class_.__name__, np)) + "only allowed from a %s mapper" + % (np, self.class_.__name__, np) + ) # inherit_condition is optional. if self.local_table is None: self.local_table = self.inherits.local_table @@ -1000,18 +1014,19 @@ class Mapper(InspectionAttr): # full table which could pull in other stuff we don't # want (allows test/inheritance.InheritTest4 to pass) self.inherit_condition = sql_util.join_condition( - self.inherits.local_table, - self.local_table) + self.inherits.local_table, self.local_table + ) self.mapped_table = sql.join( self.inherits.mapped_table, self.local_table, - self.inherit_condition) + self.inherit_condition, + ) fks = util.to_set(self.inherit_foreign_keys) - self._inherits_equated_pairs = \ - sql_util.criterion_as_pairs( - self.mapped_table.onclause, - consider_as_foreign_keys=fks) + self._inherits_equated_pairs = sql_util.criterion_as_pairs( + self.mapped_table.onclause, + consider_as_foreign_keys=fks, + ) else: self.mapped_table = self.local_table @@ -1023,21 +1038,27 @@ class Mapper(InspectionAttr): if self.version_id_col is None: self.version_id_col = self.inherits.version_id_col self.version_id_generator = self.inherits.version_id_generator - elif self.inherits.version_id_col is not None and \ - self.version_id_col is not self.inherits.version_id_col: + elif ( + self.inherits.version_id_col is not None + and self.version_id_col is not self.inherits.version_id_col + ): util.warn( "Inheriting version_id_col '%s' does not match inherited " "version_id_col '%s' and will not automatically populate " "the inherited versioning column. " "version_id_col should only be specified on " - "the base-most mapper that includes versioning." % - (self.version_id_col.description, - self.inherits.version_id_col.description) + "the base-most mapper that includes versioning." 
+ % ( + self.version_id_col.description, + self.inherits.version_id_col.description, + ) ) - if self.order_by is False and \ - not self.concrete and \ - self.inherits.order_by is not False: + if ( + self.order_by is False + and not self.concrete + and self.inherits.order_by is not False + ): self.order_by = self.inherits.order_by self.polymorphic_map = self.inherits.polymorphic_map @@ -1045,8 +1066,9 @@ class Mapper(InspectionAttr): self.inherits._inheriting_mappers.append(self) self.base_mapper = self.inherits.base_mapper self.passive_updates = self.inherits.passive_updates - self.passive_deletes = self.inherits.passive_deletes or \ - self.passive_deletes + self.passive_deletes = ( + self.inherits.passive_deletes or self.passive_deletes + ) self._all_tables = self.inherits._all_tables if self.polymorphic_identity is not None: @@ -1054,25 +1076,30 @@ class Mapper(InspectionAttr): util.warn( "Reassigning polymorphic association for identity %r " "from %r to %r: Check for duplicate use of %r as " - "value for polymorphic_identity." % - (self.polymorphic_identity, - self.polymorphic_map[self.polymorphic_identity], - self, self.polymorphic_identity) + "value for polymorphic_identity." + % ( + self.polymorphic_identity, + self.polymorphic_map[self.polymorphic_identity], + self, + self.polymorphic_identity, + ) ) self.polymorphic_map[self.polymorphic_identity] = self if self.polymorphic_load and self.concrete: raise exc.ArgumentError( "polymorphic_load is not currently supported " - "with concrete table inheritance") - if self.polymorphic_load == 'inline': + "with concrete table inheritance" + ) + if self.polymorphic_load == "inline": self.inherits._add_with_polymorphic_subclass(self) - elif self.polymorphic_load == 'selectin': + elif self.polymorphic_load == "selectin": pass elif self.polymorphic_load is not None: raise sa_exc.ArgumentError( - "unknown argument for polymorphic_load: %r" % - self.polymorphic_load) + "unknown argument for polymorphic_load: %r" + % self.polymorphic_load + ) else: self._all_tables = set() @@ -1084,15 +1111,16 @@ class Mapper(InspectionAttr): if self.mapped_table is None: raise sa_exc.ArgumentError( - "Mapper '%s' does not have a mapped_table specified." - % self) + "Mapper '%s' does not have a mapped_table specified." % self + ) def _set_with_polymorphic(self, with_polymorphic): - if with_polymorphic == '*': - self.with_polymorphic = ('*', None) + if with_polymorphic == "*": + self.with_polymorphic = ("*", None) elif isinstance(with_polymorphic, (tuple, list)): if isinstance( - with_polymorphic[0], util.string_types + (tuple, list)): + with_polymorphic[0], util.string_types + (tuple, list) + ): self.with_polymorphic = with_polymorphic else: self.with_polymorphic = (with_polymorphic, None) @@ -1109,11 +1137,13 @@ class Mapper(InspectionAttr): "SELECT from a subquery that does not have an alias." 
) - if self.with_polymorphic and \ - isinstance(self.with_polymorphic[1], - expression.SelectBase): - self.with_polymorphic = (self.with_polymorphic[0], - self.with_polymorphic[1].alias()) + if self.with_polymorphic and isinstance( + self.with_polymorphic[1], expression.SelectBase + ): + self.with_polymorphic = ( + self.with_polymorphic[0], + self.with_polymorphic[1].alias(), + ) if self.configured: self._expire_memoizations() @@ -1122,12 +1152,9 @@ class Mapper(InspectionAttr): subcl = mapper.class_ if self.with_polymorphic is None: self._set_with_polymorphic((subcl,)) - elif self.with_polymorphic[0] != '*': + elif self.with_polymorphic[0] != "*": self._set_with_polymorphic( - ( - self.with_polymorphic[0] + (subcl, ), - self.with_polymorphic[1] - ) + (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1]) ) def _set_concrete_base(self, mapper): @@ -1152,9 +1179,9 @@ class Mapper(InspectionAttr): self._all_tables = self.inherits._all_tables for key, prop in mapper._props.items(): - if key not in self._props and \ - not self._should_exclude(key, key, local=False, - column=None): + if key not in self._props and not self._should_exclude( + key, key, local=False, column=None + ): self._adapt_inherited_property(key, prop, False) def _set_polymorphic_on(self, polymorphic_on): @@ -1166,8 +1193,13 @@ class Mapper(InspectionAttr): if self.inherits: self.dispatch._update(self.inherits.dispatch) super_extensions = set( - chain(*[m._deprecated_extensions - for m in self.inherits.iterate_to_root()])) + chain( + *[ + m._deprecated_extensions + for m in self.inherits.iterate_to_root() + ] + ) + ) else: super_extensions = set() @@ -1178,8 +1210,13 @@ class Mapper(InspectionAttr): def _configure_listeners(self): if self.inherits: super_extensions = set( - chain(*[m._deprecated_extensions - for m in self.inherits.iterate_to_root()])) + chain( + *[ + m._deprecated_extensions + for m in self.inherits.iterate_to_root() + ] + ) + ) else: super_extensions = set() @@ -1206,7 +1243,8 @@ class Mapper(InspectionAttr): raise sa_exc.InvalidRequestError( "Class %s has no primary mapper configured. Configure " "a primary mapper first before setting up a non primary " - "Mapper." % self.class_) + "Mapper." % self.class_ + ) self.class_manager = manager self._identity_class = manager.mapper._identity_class _mapper_registry[self] = True @@ -1219,12 +1257,13 @@ class Mapper(InspectionAttr): "Class '%s' already has a primary mapper defined. " "Use non_primary=True to " "create a non primary Mapper. clear_mappers() will " - "remove *all* current mappers from all classes." % - self.class_) + "remove *all* current mappers from all classes." + % self.class_ + ) # else: - # a ClassManager may already exist as - # ClassManager.instrument_attribute() creates - # new managers for each subclass if they don't yet exist. + # a ClassManager may already exist as + # ClassManager.instrument_attribute() creates + # new managers for each subclass if they don't yet exist. _mapper_registry[self] = True @@ -1239,33 +1278,35 @@ class Mapper(InspectionAttr): manager.mapper = self manager.deferred_scalar_loader = util.partial( - loading.load_scalar_attributes, self) + loading.load_scalar_attributes, self + ) # The remaining members can be added by any mapper, # e_name None or not. 
if manager.info.get(_INSTRUMENTOR, False): return - event.listen(manager, 'first_init', _event_on_first_init, raw=True) - event.listen(manager, 'init', _event_on_init, raw=True) + event.listen(manager, "first_init", _event_on_first_init, raw=True) + event.listen(manager, "init", _event_on_init, raw=True) for key, method in util.iterate_attributes(self.class_): - if key == '__init__' and hasattr(method, '_sa_original_init'): + if key == "__init__" and hasattr(method, "_sa_original_init"): method = method._sa_original_init if isinstance(method, types.MethodType): method = method.im_func if isinstance(method, types.FunctionType): - if hasattr(method, '__sa_reconstructor__'): + if hasattr(method, "__sa_reconstructor__"): self._reconstructor = method - event.listen(manager, 'load', _event_on_load, raw=True) - elif hasattr(method, '__sa_validators__'): + event.listen(manager, "load", _event_on_load, raw=True) + elif hasattr(method, "__sa_validators__"): validation_opts = method.__sa_validation_opts__ for name in method.__sa_validators__: if name in self.validators: raise sa_exc.InvalidRequestError( "A validation function for mapped " - "attribute %r on mapper %s already exists." % - (name, self)) + "attribute %r on mapper %s already exists." + % (name, self) + ) self.validators = self.validators.union( {name: (method, validation_opts)} ) @@ -1283,13 +1324,15 @@ class Mapper(InspectionAttr): self.configured = True self._dispose_called = True - if hasattr(self, '_configure_failed'): + if hasattr(self, "_configure_failed"): del self._configure_failed - if not self.non_primary and \ - self.class_manager is not None and \ - self.class_manager.is_mapped and \ - self.class_manager.mapper is self: + if ( + not self.non_primary + and self.class_manager is not None + and self.class_manager.is_mapped + and self.class_manager.mapper is self + ): instrumentation.unregister_class(self.class_) def _configure_pks(self): @@ -1298,9 +1341,9 @@ class Mapper(InspectionAttr): self._pks_by_table = {} self._cols_by_table = {} - all_cols = util.column_set(chain(*[ - col.proxy_set for col in - self._columntoproperty])) + all_cols = util.column_set( + chain(*[col.proxy_set for col in self._columntoproperty]) + ) pk_cols = util.column_set(c for c in all_cols if c.primary_key) @@ -1311,12 +1354,12 @@ class Mapper(InspectionAttr): if t.primary_key and pk_cols.issuperset(t.primary_key): # ordering is important since it determines the ordering of # mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = \ - util.ordered_column_set(t.primary_key).\ - intersection(pk_cols) - self._cols_by_table[t] = \ - util.ordered_column_set(t.c).\ - intersection(all_cols) + self._pks_by_table[t] = util.ordered_column_set( + t.primary_key + ).intersection(pk_cols) + self._cols_by_table[t] = util.ordered_column_set(t.c).intersection( + all_cols + ) # if explicit PK argument sent, add those columns to the # primary key mappings @@ -1327,22 +1370,30 @@ class Mapper(InspectionAttr): self._pks_by_table[k.table].add(k) # otherwise, see that we got a full PK for the mapped table - elif self.mapped_table not in self._pks_by_table or \ - len(self._pks_by_table[self.mapped_table]) == 0: + elif ( + self.mapped_table not in self._pks_by_table + or len(self._pks_by_table[self.mapped_table]) == 0 + ): raise sa_exc.ArgumentError( "Mapper %s could not assemble any primary " - "key columns for mapped table '%s'" % - (self, self.mapped_table.description)) - elif self.local_table not in self._pks_by_table and \ - isinstance(self.local_table, 
schema.Table): - util.warn("Could not assemble any primary " - "keys for locally mapped table '%s' - " - "no rows will be persisted in this Table." - % self.local_table.description) - - if self.inherits and \ - not self.concrete and \ - not self._primary_key_argument: + "key columns for mapped table '%s'" + % (self, self.mapped_table.description) + ) + elif self.local_table not in self._pks_by_table and isinstance( + self.local_table, schema.Table + ): + util.warn( + "Could not assemble any primary " + "keys for locally mapped table '%s' - " + "no rows will be persisted in this Table." + % self.local_table.description + ) + + if ( + self.inherits + and not self.concrete + and not self._primary_key_argument + ): # if inheriting, the "primary key" for this mapper is # that of the inheriting (unless concrete or explicit) self.primary_key = self.inherits.primary_key @@ -1351,19 +1402,24 @@ class Mapper(InspectionAttr): # reduce to the minimal set of columns if self._primary_key_argument: primary_key = sql_util.reduce_columns( - [self.mapped_table.corresponding_column(c) for c in - self._primary_key_argument], - ignore_nonexistent_tables=True) + [ + self.mapped_table.corresponding_column(c) + for c in self._primary_key_argument + ], + ignore_nonexistent_tables=True, + ) else: primary_key = sql_util.reduce_columns( self._pks_by_table[self.mapped_table], - ignore_nonexistent_tables=True) + ignore_nonexistent_tables=True, + ) if len(primary_key) == 0: raise sa_exc.ArgumentError( "Mapper %s could not assemble any primary " - "key columns for mapped table '%s'" % - (self, self.mapped_table.description)) + "key columns for mapped table '%s'" + % (self, self.mapped_table.description) + ) self.primary_key = tuple(primary_key) self._log("Identified primary key columns: %s", primary_key) @@ -1373,9 +1429,12 @@ class Mapper(InspectionAttr): self._readonly_props = set( self._columntoproperty[col] for col in self._columntoproperty - if self._columntoproperty[col] not in self._identity_key_props and - (not hasattr(col, 'table') or - col.table not in self._cols_by_table)) + if self._columntoproperty[col] not in self._identity_key_props + and ( + not hasattr(col, "table") + or col.table not in self._cols_by_table + ) + ) def _configure_properties(self): # Column and other ClauseElement objects which are mapped @@ -1397,9 +1456,9 @@ class Mapper(InspectionAttr): # pull properties from the inherited mapper if any. 
if self.inherits: for key, prop in self.inherits._props.items(): - if key not in self._props and \ - not self._should_exclude(key, key, local=False, - column=None): + if key not in self._props and not self._should_exclude( + key, key, local=False, column=None + ): self._adapt_inherited_property(key, prop, False) # create properties for each column in the mapped table, @@ -1408,12 +1467,13 @@ class Mapper(InspectionAttr): if column in self._columntoproperty: continue - column_key = (self.column_prefix or '') + column.key + column_key = (self.column_prefix or "") + column.key if self._should_exclude( - column.key, column_key, + column.key, + column_key, local=self.local_table.c.contains_column(column), - column=column + column=column, ): continue @@ -1423,10 +1483,9 @@ class Mapper(InspectionAttr): if column in mapper._columntoproperty: column_key = mapper._columntoproperty[column].key - self._configure_property(column_key, - column, - init=False, - setparent=True) + self._configure_property( + column_key, column, init=False, setparent=True + ) def _configure_polymorphic_setter(self, init=False): """Configure an attribute on the mapper representing the @@ -1453,7 +1512,8 @@ class Mapper(InspectionAttr): raise sa_exc.ArgumentError( "Can't determine polymorphic_on " "value '%s' - no attribute is " - "mapped to this name." % self.polymorphic_on) + "mapped to this name." % self.polymorphic_on + ) if self.polymorphic_on in self._columntoproperty: # polymorphic_on is a column that is already mapped @@ -1462,12 +1522,14 @@ class Mapper(InspectionAttr): elif isinstance(self.polymorphic_on, MapperProperty): # polymorphic_on is directly a MapperProperty, # ensure it's a ColumnProperty - if not isinstance(self.polymorphic_on, - properties.ColumnProperty): + if not isinstance( + self.polymorphic_on, properties.ColumnProperty + ): raise sa_exc.ArgumentError( "Only direct column-mapped " "property or SQL expression " - "can be passed for polymorphic_on") + "can be passed for polymorphic_on" + ) prop = self.polymorphic_on elif not expression._is_column(self.polymorphic_on): # polymorphic_on is not a Column and not a ColumnProperty; @@ -1484,7 +1546,8 @@ class Mapper(InspectionAttr): # 2. a totally standalone SQL expression which we'd # hope is compatible with this mapper's mapped_table col = self.mapped_table.corresponding_column( - self.polymorphic_on) + self.polymorphic_on + ) if col is None: # polymorphic_on doesn't derive from any # column/expression isn't present in the mapped @@ -1500,14 +1563,16 @@ class Mapper(InspectionAttr): instrument = False col = self.polymorphic_on if isinstance(col, schema.Column) and ( - self.with_polymorphic is None or - self.with_polymorphic[1]. - corresponding_column(col) is None): + self.with_polymorphic is None + or self.with_polymorphic[1].corresponding_column(col) + is None + ): raise sa_exc.InvalidRequestError( "Could not map polymorphic_on column " "'%s' to the mapped table - polymorphic " "loads will not function properly" - % col.description) + % col.description + ) else: # column/expression that polymorphic_on derives from # is present in our mapped table @@ -1518,16 +1583,15 @@ class Mapper(InspectionAttr): # polymorphic_union. # we'll make a separate ColumnProperty for it. 
instrument = True - key = getattr(col, 'key', None) + key = getattr(col, "key", None) if key: if self._should_exclude(col.key, col.key, False, col): raise sa_exc.InvalidRequestError( "Cannot exclude or override the " - "discriminator column %r" % - col.key) + "discriminator column %r" % col.key + ) else: - self.polymorphic_on = col = \ - col.label("_sa_polymorphic_on") + self.polymorphic_on = col = col.label("_sa_polymorphic_on") key = col.key prop = properties.ColumnProperty(col, _instrument=instrument) @@ -1551,43 +1615,51 @@ class Mapper(InspectionAttr): if self.mapped_table is mapper.mapped_table: self.polymorphic_on = mapper.polymorphic_on else: - self.polymorphic_on = \ - self.mapped_table.corresponding_column( - mapper.polymorphic_on) + self.polymorphic_on = self.mapped_table.corresponding_column( + mapper.polymorphic_on + ) # we can use the parent mapper's _set_polymorphic_identity # directly; it ensures the polymorphic_identity of the # instance's mapper is used so is portable to subclasses. if self.polymorphic_on is not None: - self._set_polymorphic_identity = \ + self._set_polymorphic_identity = ( mapper._set_polymorphic_identity - self._validate_polymorphic_identity = \ + ) + self._validate_polymorphic_identity = ( mapper._validate_polymorphic_identity + ) else: self._set_polymorphic_identity = None return if setter: + def _set_polymorphic_identity(state): dict_ = state.dict state.get_impl(polymorphic_key).set( - state, dict_, + state, + dict_, state.manager.mapper.polymorphic_identity, - None) + None, + ) def _validate_polymorphic_identity(mapper, state, dict_): - if polymorphic_key in dict_ and \ - dict_[polymorphic_key] not in \ - mapper._acceptable_polymorphic_identities: + if ( + polymorphic_key in dict_ + and dict_[polymorphic_key] + not in mapper._acceptable_polymorphic_identities + ): util.warn_limited( "Flushing object %s with " "incompatible polymorphic identity %r; the " "object may not refresh and/or load correctly", - (state_str(state), dict_[polymorphic_key]) + (state_str(state), dict_[polymorphic_key]), ) self._set_polymorphic_identity = _set_polymorphic_identity - self._validate_polymorphic_identity = \ + self._validate_polymorphic_identity = ( _validate_polymorphic_identity + ) else: self._set_polymorphic_identity = None @@ -1628,16 +1700,20 @@ class Mapper(InspectionAttr): # mapper and we don't map this. don't trip user-defined # descriptors that might have side effects when invoked. 
implementing_attribute = self.class_manager._get_class_attr_mro( - key, prop) - if implementing_attribute is prop or (isinstance( - implementing_attribute, - attributes.InstrumentedAttribute) and - implementing_attribute._parententity is prop.parent + key, prop + ) + if implementing_attribute is prop or ( + isinstance( + implementing_attribute, attributes.InstrumentedAttribute + ) + and implementing_attribute._parententity is prop.parent ): self._configure_property( key, properties.ConcreteInheritedProperty(), - init=init, setparent=True) + init=init, + setparent=True, + ) def _configure_property(self, key, prop, init=True, setparent=True): self._log("_configure_property(%s, %s)", key, prop.__class__.__name__) @@ -1659,7 +1735,8 @@ class Mapper(InspectionAttr): for m2 in path: m2.mapped_table._reset_exported() col = self.mapped_table.corresponding_column( - prop.columns[0]) + prop.columns[0] + ) break path.append(m) @@ -1670,26 +1747,30 @@ class Mapper(InspectionAttr): # column is coming in after _readonly_props was # initialized; check for 'readonly' - if hasattr(self, '_readonly_props') and \ - (not hasattr(col, 'table') or - col.table not in self._cols_by_table): + if hasattr(self, "_readonly_props") and ( + not hasattr(col, "table") + or col.table not in self._cols_by_table + ): self._readonly_props.add(prop) else: # if column is coming in after _cols_by_table was # initialized, ensure the col is in the right set - if hasattr(self, '_cols_by_table') and \ - col.table in self._cols_by_table and \ - col not in self._cols_by_table[col.table]: + if ( + hasattr(self, "_cols_by_table") + and col.table in self._cols_by_table + and col not in self._cols_by_table[col.table] + ): self._cols_by_table[col.table].add(col) # if this properties.ColumnProperty represents the "polymorphic # discriminator" column, mark it. We'll need this when rendering # columns in SELECT statements. 
- if not hasattr(prop, '_is_polymorphic_discriminator'): - prop._is_polymorphic_discriminator = \ - (col is self.polymorphic_on or - prop.columns[0] is self.polymorphic_on) + if not hasattr(prop, "_is_polymorphic_discriminator"): + prop._is_polymorphic_discriminator = ( + col is self.polymorphic_on + or prop.columns[0] is self.polymorphic_on + ) self.columns[key] = col for col in prop.columns + prop._orig_columns: @@ -1701,8 +1782,9 @@ class Mapper(InspectionAttr): if setparent: prop.set_parent(self, init) - if key in self._props and \ - getattr(self._props[key], '_mapped_by_synonym', False): + if key in self._props and getattr( + self._props[key], "_mapped_by_synonym", False + ): syn = self._props[key]._mapped_by_synonym raise sa_exc.ArgumentError( "Can't call map_column=True for synonym %r=%r, " @@ -1710,20 +1792,22 @@ class Mapper(InspectionAttr): "%r for column %r" % (syn, key, key, syn) ) - if key in self._props and \ - not isinstance(prop, properties.ColumnProperty) and \ - not isinstance( - self._props[key], - ( - properties.ColumnProperty, - properties.ConcreteInheritedProperty) - ): - util.warn("Property %s on %s being replaced with new " - "property %s; the old property will be discarded" % ( - self._props[key], - self, - prop, - )) + if ( + key in self._props + and not isinstance(prop, properties.ColumnProperty) + and not isinstance( + self._props[key], + ( + properties.ColumnProperty, + properties.ConcreteInheritedProperty, + ), + ) + ): + util.warn( + "Property %s on %s being replaced with new " + "property %s; the old property will be discarded" + % (self._props[key], self, prop) + ) oldprop = self._props[key] self._path_registry.pop(oldprop, None) @@ -1753,23 +1837,29 @@ class Mapper(InspectionAttr): if not expression._is_column(column): raise sa_exc.ArgumentError( "%s=%r is not an instance of MapperProperty or Column" - % (key, prop)) + % (key, prop) + ) prop = self._props.get(key, None) if isinstance(prop, properties.ColumnProperty): if ( - not self._inherits_equated_pairs or - (prop.columns[0], column) not in self._inherits_equated_pairs - ) and \ - not prop.columns[0].shares_lineage(column) and \ - prop.columns[0] is not self.version_id_col and \ - column is not self.version_id_col: + ( + not self._inherits_equated_pairs + or (prop.columns[0], column) + not in self._inherits_equated_pairs + ) + and not prop.columns[0].shares_lineage(column) + and prop.columns[0] is not self.version_id_col + and column is not self.version_id_col + ): warn_only = prop.parent is not self - msg = ("Implicitly combining column %s with column " - "%s under attribute '%s'. Please configure one " - "or more attributes for these same-named columns " - "explicitly." % (prop.columns[-1], column, key)) + msg = ( + "Implicitly combining column %s with column " + "%s under attribute '%s'. Please configure one " + "or more attributes for these same-named columns " + "explicitly." % (prop.columns[-1], column, key) + ) if warn_only: util.warn(msg) else: @@ -1779,11 +1869,14 @@ class Mapper(InspectionAttr): # mapper. 
make a copy and append our column to it
             prop = prop.copy()
             prop.columns.insert(0, column)
-            self._log("inserting column to existing list "
-                      "in properties.ColumnProperty %s" % (key))
+            self._log(
+                "inserting column to existing list "
+                "in properties.ColumnProperty %s" % (key)
+            )
             return prop
-        elif prop is None or isinstance(prop,
-                                        properties.ConcreteInheritedProperty):
+        elif prop is None or isinstance(
+            prop, properties.ConcreteInheritedProperty
+        ):
             mapped_column = []
             for c in columns:
                 mc = self.mapped_table.corresponding_column(c)
@@ -1802,7 +1895,8 @@
                     "column '%s' is not represented in the mapper's "
                     "table. Use the `column_property()` function to "
                     "force this column to be mapped as a read-only "
-                    "attribute." % (key, self, c))
+                    "attribute." % (key, self, c)
+                )
                 mapped_column.append(mc)
             return properties.ColumnProperty(*mapped_column)
         else:
@@ -1815,8 +1909,8 @@
                 "(including its availability as a foreign key), "
                 "use the 'include_properties' or 'exclude_properties' "
                 "mapper arguments to control specifically which table "
-                "columns get mapped." %
-                (key, self, column.key, prop))
+                "columns get mapped." % (key, self, column.key, prop)
+            )

     def _post_configure_properties(self):
         """Call the ``init()`` method on all ``MapperProperties``

@@ -1867,34 +1961,35 @@
     @property
     def _log_desc(self):
-        return "(" + self.class_.__name__ + \
-            "|" + \
-            (self.local_table is not None and
-                self.local_table.description or
-                str(self.local_table)) +\
-            (self.non_primary and
-                "|non-primary" or "") + ")"
+        return (
+            "("
+            + self.class_.__name__
+            + "|"
+            + (
+                self.local_table is not None
+                and self.local_table.description
+                or str(self.local_table)
+            )
+            + (self.non_primary and "|non-primary" or "")
+            + ")"
+        )

     def _log(self, msg, *args):
-        self.logger.info(
-            "%s " + msg, *((self._log_desc,) + args)
-        )
+        self.logger.info("%s " + msg, *((self._log_desc,) + args))

     def _log_debug(self, msg, *args):
-        self.logger.debug(
-            "%s " + msg, *((self._log_desc,) + args)
-        )
+        self.logger.debug("%s " + msg, *((self._log_desc,) + args))

     def __repr__(self):
-        return '<Mapper at 0x%x; %s>' % (
-            id(self), self.class_.__name__)
+        return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)

     def __str__(self):
         return "Mapper|%s|%s%s" % (
             self.class_.__name__,
-            self.local_table is not None and
-            self.local_table.description or None,
-            self.non_primary and "|non-primary" or ""
+            self.local_table is not None
+            and self.local_table.description
+            or None,
+            self.non_primary and "|non-primary" or "",
         )

     def _is_orphan(self, state):
@@ -1904,7 +1999,8 @@
             orphan_possible = True

             has_parent = attributes.manager_of_class(cls).has_parent(
-                state, key, optimistic=state.has_identity)
+                state, key, optimistic=state.has_identity
+            )

             if self.legacy_is_orphan and has_parent:
                 return False
@@ -1930,7 +2026,8 @@
             return self._props[key]
         except KeyError:
             raise sa_exc.InvalidRequestError(
-                "Mapper '%s' has no property '%s'" % (self, key))
+                "Mapper '%s' has no property '%s'" % (self, key)
+            )

     def get_property_by_column(self, column):
         """Given a :class:`.Column` object, return the

@@ -1953,7 +2050,7 @@
         selectable, if present. This helps some more legacy-ish
         mappings.

""" - if spec == '*': + if spec == "*": mappers = list(self.self_and_descendants) elif spec: mappers = set() @@ -1961,8 +2058,8 @@ class Mapper(InspectionAttr): m = _class_to_mapper(m) if not m.isa(self): raise sa_exc.InvalidRequestError( - "%r does not inherit from %r" % - (m, self)) + "%r does not inherit from %r" % (m, self) + ) if selectable is None: mappers.update(m.iterate_to_root()) @@ -1973,8 +2070,9 @@ class Mapper(InspectionAttr): mappers = [] if selectable is not None: - tables = set(sql_util.find_tables(selectable, - include_aliases=True)) + tables = set( + sql_util.find_tables(selectable, include_aliases=True) + ) mappers = [m for m in mappers if m.local_table in tables] return mappers @@ -1991,25 +2089,26 @@ class Mapper(InspectionAttr): if m.concrete: raise sa_exc.InvalidRequestError( "'with_polymorphic()' requires 'selectable' argument " - "when concrete-inheriting mappers are used.") + "when concrete-inheriting mappers are used." + ) elif not m.single: if innerjoin: - from_obj = from_obj.join(m.local_table, - m.inherit_condition) + from_obj = from_obj.join( + m.local_table, m.inherit_condition + ) else: - from_obj = from_obj.outerjoin(m.local_table, - m.inherit_condition) + from_obj = from_obj.outerjoin( + m.local_table, m.inherit_condition + ) return from_obj @_memoized_configured_property def _single_table_criterion(self): - if self.single and \ - self.inherits and \ - self.polymorphic_on is not None: + if self.single and self.inherits and self.polymorphic_on is not None: return self.polymorphic_on.in_( - m.polymorphic_identity - for m in self.self_and_descendants) + m.polymorphic_identity for m in self.self_and_descendants + ) else: return None @@ -2031,8 +2130,8 @@ class Mapper(InspectionAttr): return selectable else: return self._selectable_from_mappers( - self._mappers_from_spec(spec, selectable), - False) + self._mappers_from_spec(spec, selectable), False + ) with_polymorphic_mappers = _with_polymorphic_mappers """The list of :class:`.Mapper` objects included in the @@ -2046,9 +2145,8 @@ class Mapper(InspectionAttr): ( table, frozenset( - col for col in columns - if col.type.should_evaluate_none - ) + col for col in columns if col.type.should_evaluate_none + ), ) for table, columns in self._cols_by_table.items() ) @@ -2059,10 +2157,13 @@ class Mapper(InspectionAttr): ( table, frozenset( - col.key for col in columns - if not col.primary_key and - not col.server_default and not col.default - and not col.type.should_evaluate_none) + col.key + for col in columns + if not col.primary_key + and not col.server_default + and not col.default + and not col.type.should_evaluate_none + ), ) for table, columns in self._cols_by_table.items() ) @@ -2073,9 +2174,8 @@ class Mapper(InspectionAttr): ( table, dict( - (self._columntoproperty[col].key, col) - for col in columns - ) + (self._columntoproperty[col].key, col) for col in columns + ), ) for table, columns in self._cols_by_table.items() ) @@ -2083,10 +2183,7 @@ class Mapper(InspectionAttr): @_memoized_configured_property def _pk_keys_by_table(self): return dict( - ( - table, - frozenset([col.key for col in pks]) - ) + (table, frozenset([col.key for col in pks])) for table, pks in self._pks_by_table.items() ) @@ -2095,7 +2192,7 @@ class Mapper(InspectionAttr): return dict( ( table, - frozenset([self._columntoproperty[col].key for col in pks]) + frozenset([self._columntoproperty[col].key for col in pks]), ) for table, pks in self._pks_by_table.items() ) @@ -2105,9 +2202,13 @@ class Mapper(InspectionAttr): return dict( ( table, - 
frozenset([ - col.key for col in columns - if col.server_default is not None]) + frozenset( + [ + col.key + for col in columns + if col.server_default is not None + ] + ), ) for table, columns in self._cols_by_table.items() ) @@ -2119,11 +2220,9 @@ class Mapper(InspectionAttr): for table, columns in self._cols_by_table.items(): for col in columns: if ( - ( - col.server_default is not None or - col.server_onupdate is not None - ) and col in self._columntoproperty - ): + col.server_default is not None + or col.server_onupdate is not None + ) and col in self._columntoproperty: result.add(self._columntoproperty[col].key) return result @@ -2133,9 +2232,13 @@ class Mapper(InspectionAttr): return dict( ( table, - frozenset([ - col.key for col in columns - if col.server_onupdate is not None]) + frozenset( + [ + col.key + for col in columns + if col.server_onupdate is not None + ] + ), ) for table, columns in self._cols_by_table.items() ) @@ -2152,8 +2255,9 @@ class Mapper(InspectionAttr): """ return self._with_polymorphic_selectable - def _with_polymorphic_args(self, spec=None, selectable=False, - innerjoin=False): + def _with_polymorphic_args( + self, spec=None, selectable=False, innerjoin=False + ): if self.with_polymorphic: if not spec: spec = self.with_polymorphic[0] @@ -2165,13 +2269,15 @@ class Mapper(InspectionAttr): if selectable is not None: return mappers, selectable else: - return mappers, self._selectable_from_mappers(mappers, - innerjoin) + return mappers, self._selectable_from_mappers(mappers, innerjoin) @_memoized_configured_property def _polymorphic_properties(self): - return list(self._iterate_polymorphic_properties( - self._with_polymorphic_mappers)) + return list( + self._iterate_polymorphic_properties( + self._with_polymorphic_mappers + ) + ) def _iterate_polymorphic_properties(self, mappers=None): """Return an iterator of MapperProperty objects which will render into @@ -2187,14 +2293,17 @@ class Mapper(InspectionAttr): # from other mappers, as these are sometimes dependent on that # mapper's polymorphic selectable (which we don't want rendered) for c in util.unique_list( - chain(*[ - list(mapper.iterate_properties) for mapper in - [self] + mappers - ]) + chain( + *[ + list(mapper.iterate_properties) + for mapper in [self] + mappers + ] + ) ): - if getattr(c, '_is_polymorphic_discriminator', False) and \ - (self.polymorphic_on is None or - c.columns[0] is not self.polymorphic_on): + if getattr(c, "_is_polymorphic_discriminator", False) and ( + self.polymorphic_on is None + or c.columns[0] is not self.polymorphic_on + ): continue yield c @@ -2282,7 +2391,8 @@ class Mapper(InspectionAttr): """ return util.ImmutableProperties( - dict(self.class_manager._all_sqla_attributes())) + dict(self.class_manager._all_sqla_attributes()) + ) @_memoized_configured_property def synonyms(self): @@ -2351,10 +2461,11 @@ class Mapper(InspectionAttr): def _filter_properties(self, type_): if Mapper._new_mappers: configure_mappers() - return util.ImmutableProperties(util.OrderedDict( - (k, v) for k, v in self._props.items() - if isinstance(v, type_) - )) + return util.ImmutableProperties( + util.OrderedDict( + (k, v) for k, v in self._props.items() if isinstance(v, type_) + ) + ) @_memoized_configured_property def _get_clause(self): @@ -2363,10 +2474,14 @@ class Mapper(InspectionAttr): by primary key. 
""" - params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) - for primary_key in self.primary_key] - return sql.and_(*[k == v for (k, v) in params]), \ - util.column_dict(params) + params = [ + (primary_key, sql.bindparam(None, type_=primary_key.type)) + for primary_key in self.primary_key + ] + return ( + sql.and_(*[k == v for (k, v) in params]), + util.column_dict(params), + ) @_memoized_configured_property def _equivalent_columns(self): @@ -2401,18 +2516,24 @@ class Mapper(InspectionAttr): result[binary.right].add(binary.left) else: result[binary.right] = util.column_set((binary.left,)) + for mapper in self.base_mapper.self_and_descendants: if mapper.inherit_condition is not None: visitors.traverse( - mapper.inherit_condition, {}, - {'binary': visit_binary}) + mapper.inherit_condition, {}, {"binary": visit_binary} + ) return result def _is_userland_descriptor(self, obj): - if isinstance(obj, (_MappedAttribute, - instrumentation.ClassManager, - expression.ColumnElement)): + if isinstance( + obj, + ( + _MappedAttribute, + instrumentation.ClassManager, + expression.ColumnElement, + ), + ): return False else: return True @@ -2429,26 +2550,29 @@ class Mapper(InspectionAttr): # check for class-bound attributes and/or descriptors, # either local or from an inherited class if local: - if self.class_.__dict__.get(assigned_name, None) is not None \ - and self._is_userland_descriptor( - self.class_.__dict__[assigned_name]): + if self.class_.__dict__.get( + assigned_name, None + ) is not None and self._is_userland_descriptor( + self.class_.__dict__[assigned_name] + ): return True else: attr = self.class_manager._get_class_attr_mro(assigned_name, None) if attr is not None and self._is_userland_descriptor(attr): return True - if self.include_properties is not None and \ - name not in self.include_properties and \ - (column is None or column not in self.include_properties): + if ( + self.include_properties is not None + and name not in self.include_properties + and (column is None or column not in self.include_properties) + ): self._log("not including property %s" % (name)) return True - if self.exclude_properties is not None and \ - ( - name in self.exclude_properties or - (column is not None and column in self.exclude_properties) - ): + if self.exclude_properties is not None and ( + name in self.exclude_properties + or (column is not None and column in self.exclude_properties) + ): self._log("excluding property %s" % (name)) return True @@ -2545,8 +2669,11 @@ class Mapper(InspectionAttr): if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] - return self._identity_class, \ - tuple(row[column] for column in pk_cols), identity_token + return ( + self._identity_class, + tuple(row[column] for column in pk_cols), + identity_token, + ) def identity_key_from_primary_key(self, primary_key, identity_token=None): """Return an identity-map key for use in storing/retrieving an @@ -2574,14 +2701,20 @@ class Mapper(InspectionAttr): return self._identity_key_from_state(state, attributes.PASSIVE_OFF) def _identity_key_from_state( - self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET): + self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET + ): dict_ = state.dict manager = state.manager - return self._identity_class, tuple([ - manager[prop.key]. 
- impl.get(state, dict_, passive) - for prop in self._identity_key_props - ]), state.identity_token + return ( + self._identity_class, + tuple( + [ + manager[prop.key].impl.get(state, dict_, passive) + for prop in self._identity_key_props + ] + ), + state.identity_token, + ) def primary_key_from_instance(self, instance): """Return the list of primary key values for the given @@ -2595,7 +2728,8 @@ class Mapper(InspectionAttr): """ state = attributes.instance_state(instance) identity_key = self._identity_key_from_state( - state, attributes.PASSIVE_OFF) + state, attributes.PASSIVE_OFF + ) return identity_key[1] @_memoized_configured_property @@ -2621,8 +2755,8 @@ class Mapper(InspectionAttr): return {prop.key for prop in self._all_pk_props} def _get_state_attr_by_column( - self, state, dict_, column, - passive=attributes.PASSIVE_RETURN_NEVER_SET): + self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET + ): prop = self._columntoproperty[column] return state.manager[prop.key].impl.get(state, dict_, passive=passive) @@ -2638,15 +2772,17 @@ class Mapper(InspectionAttr): state = attributes.instance_state(obj) dict_ = attributes.instance_dict(obj) return self._get_committed_state_attr_by_column( - state, dict_, column, passive=attributes.PASSIVE_OFF) + state, dict_, column, passive=attributes.PASSIVE_OFF + ) def _get_committed_state_attr_by_column( - self, state, dict_, column, - passive=attributes.PASSIVE_RETURN_NEVER_SET): + self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET + ): prop = self._columntoproperty[column] - return state.manager[prop.key].impl.\ - get_committed_value(state, dict_, passive=passive) + return state.manager[prop.key].impl.get_committed_value( + state, dict_, passive=passive + ) def _optimized_get_statement(self, state, attribute_names): """assemble a WHERE clause which retrieves a given state by primary @@ -2660,11 +2796,15 @@ class Mapper(InspectionAttr): """ props = self._props - tables = set(chain( - *[sql_util.find_tables(c, check_columns=True) - for key in attribute_names - for c in props[key].columns] - )) + tables = set( + chain( + *[ + sql_util.find_tables(c, check_columns=True) + for key in attribute_names + for c in props[key].columns + ] + ) + ) if self.base_mapper.local_table in tables: return None @@ -2680,22 +2820,28 @@ class Mapper(InspectionAttr): if leftcol.table not in tables: leftval = self._get_committed_state_attr_by_column( - state, state.dict, + state, + state.dict, leftcol, - passive=attributes.PASSIVE_NO_INITIALIZE) + passive=attributes.PASSIVE_NO_INITIALIZE, + ) if leftval in orm_util._none_set: raise ColumnsNotAvailable() - binary.left = sql.bindparam(None, leftval, - type_=binary.right.type) + binary.left = sql.bindparam( + None, leftval, type_=binary.right.type + ) elif rightcol.table not in tables: rightval = self._get_committed_state_attr_by_column( - state, state.dict, + state, + state.dict, rightcol, - passive=attributes.PASSIVE_NO_INITIALIZE) + passive=attributes.PASSIVE_NO_INITIALIZE, + ) if rightval in orm_util._none_set: raise ColumnsNotAvailable() - binary.right = sql.bindparam(None, rightval, - type_=binary.right.type) + binary.right = sql.bindparam( + None, rightval, type_=binary.right.type + ) allconds = [] @@ -2704,15 +2850,17 @@ class Mapper(InspectionAttr): for mapper in reversed(list(self.iterate_to_root())): if mapper.local_table in tables: start = True - elif not isinstance(mapper.local_table, - expression.TableClause): + elif not isinstance( + mapper.local_table, expression.TableClause 
+ ): return None if start and not mapper.single: - allconds.append(visitors.cloned_traverse( - mapper.inherit_condition, - {}, - {'binary': visit_binary} - ) + allconds.append( + visitors.cloned_traverse( + mapper.inherit_condition, + {}, + {"binary": visit_binary}, + ) ) except ColumnsNotAvailable: return None @@ -2730,8 +2878,7 @@ class Mapper(InspectionAttr): for m in self.iterate_to_root(): yield m - if m is not prev and prev not in \ - m._with_polymorphic_mappers: + if m is not prev and prev not in m._with_polymorphic_mappers: break prev = m @@ -2743,7 +2890,7 @@ class Mapper(InspectionAttr): # common case, takes place for all polymorphic loads mapper = polymorphic_from for m in self._iterate_to_target_viawpoly(mapper): - if m.polymorphic_load == 'selectin': + if m.polymorphic_load == "selectin": return m else: # uncommon case, selectin load options were used @@ -2752,15 +2899,17 @@ class Mapper(InspectionAttr): for entity in enabled_via_opt.union([polymorphic_from]): mapper = entity.mapper for m in self._iterate_to_target_viawpoly(mapper): - if m.polymorphic_load == 'selectin' or \ - m in enabled_via_opt_mappers: + if ( + m.polymorphic_load == "selectin" + or m in enabled_via_opt_mappers + ): return enabled_via_opt_mappers.get(m, m) return None @util.dependencies( - "sqlalchemy.ext.baked", - "sqlalchemy.orm.strategy_options") + "sqlalchemy.ext.baked", "sqlalchemy.orm.strategy_options" + ) def _subclass_load_via_in(self, baked, strategy_options, entity): """Assemble a BakedQuery that can load the columns local to this subclass as a SELECT with IN. @@ -2768,10 +2917,8 @@ class Mapper(InspectionAttr): """ assert self.inherits - polymorphic_prop = self._columntoproperty[ - self.polymorphic_on] - keep_props = set( - [polymorphic_prop] + self._identity_key_props) + polymorphic_prop = self._columntoproperty[self.polymorphic_on] + keep_props = set([polymorphic_prop] + self._identity_key_props) disable_opt = strategy_options.Load(entity) enable_opt = strategy_options.Load(entity) @@ -2781,16 +2928,14 @@ class Mapper(InspectionAttr): # "enable" options, to turn on the properties that we want to # load by default (subject to options from the query) enable_opt.set_generic_strategy( - (prop.key, ), - dict(prop.strategy_key) + (prop.key,), dict(prop.strategy_key) ) else: # "disable" options, to turn off the properties from the # superclass that we *don't* want to load, applied after # the options from the query to override them disable_opt.set_generic_strategy( - (prop.key, ), - {"do_nothing": True} + (prop.key,), {"do_nothing": True} ) if len(self.primary_key) > 1: @@ -2802,22 +2947,21 @@ class Mapper(InspectionAttr): assert entity.mapper is self q = baked.BakedQuery( self._compiled_cache, - lambda session: session.query(entity). 
- select_entity_from(entity.selectable)._adapt_all_clauses(), - (self, ) + lambda session: session.query(entity) + .select_entity_from(entity.selectable) + ._adapt_all_clauses(), + (self,), ) q.spoil() else: q = baked.BakedQuery( self._compiled_cache, lambda session: session.query(self), - (self, ) + (self,), ) q += lambda q: q.filter( - in_expr.in_( - sql.bindparam('primary_keys', expanding=True) - ) + in_expr.in_(sql.bindparam("primary_keys", expanding=True)) ).order_by(*self.primary_key) return q, enable_opt, disable_opt @@ -2856,8 +3000,9 @@ class Mapper(InspectionAttr): assert state.mapper.isa(self) - visitables = deque([(deque(state.mapper._props.values()), prp, - state, state.dict)]) + visitables = deque( + [(deque(state.mapper._props.values()), prp, state, state.dict)] + ) while visitables: iterator, item_type, parent_state, parent_dict = visitables[-1] @@ -2869,21 +3014,28 @@ class Mapper(InspectionAttr): prop = iterator.popleft() if type_ not in prop.cascade: continue - queue = deque(prop.cascade_iterator( - type_, parent_state, parent_dict, - visited_states, halt_on)) + queue = deque( + prop.cascade_iterator( + type_, + parent_state, + parent_dict, + visited_states, + halt_on, + ) + ) if queue: visitables.append((queue, mpp, None, None)) elif item_type is mpp: - instance, instance_mapper, corresponding_state, \ - corresponding_dict = iterator.popleft() - yield instance, instance_mapper, \ - corresponding_state, corresponding_dict + instance, instance_mapper, corresponding_state, corresponding_dict = ( + iterator.popleft() + ) + yield instance, instance_mapper, corresponding_state, corresponding_dict visitables.append( ( deque(instance_mapper._props.values()), - prp, corresponding_state, - corresponding_dict + prp, + corresponding_state, + corresponding_dict, ) ) @@ -2903,10 +3055,9 @@ class Mapper(InspectionAttr): for table, mapper in table_to_mapper.items(): super_ = mapper.inherits if super_: - extra_dependencies.extend([ - (super_table, table) - for super_table in super_.tables - ]) + extra_dependencies.extend( + [(super_table, table) for super_table in super_.tables] + ) def skip(fk): # attempt to skip dependencies that are not @@ -2916,22 +3067,27 @@ class Mapper(InspectionAttr): # not what we mean to sort on here. 
parent = table_to_mapper.get(fk.parent.table) dep = table_to_mapper.get(fk.column.table) - if parent is not None and \ - dep is not None and \ - dep is not parent and \ - dep.inherit_condition is not None: + if ( + parent is not None + and dep is not None + and dep is not parent + and dep.inherit_condition is not None + ): cols = set(sql_util._find_columns(dep.inherit_condition)) if parent.inherit_condition is not None: - cols = cols.union(sql_util._find_columns( - parent.inherit_condition)) + cols = cols.union( + sql_util._find_columns(parent.inherit_condition) + ) return fk.parent not in cols and fk.column not in cols else: return fk.parent not in cols return False - sorted_ = sql_util.sort_tables(table_to_mapper, - skip_fn=skip, - extra_dependencies=extra_dependencies) + sorted_ = sql_util.sort_tables( + table_to_mapper, + skip_fn=skip, + extra_dependencies=extra_dependencies, + ) ret = util.OrderedDict() for t in sorted_: @@ -2955,12 +3111,12 @@ class Mapper(InspectionAttr): for table in self._sorted_tables: cols = set(table.c) for m in self.iterate_to_root(): - if m._inherits_equated_pairs and \ - cols.intersection( - util.reduce(set.union, - [l.proxy_set for l, r in - m._inherits_equated_pairs]) - ): + if m._inherits_equated_pairs and cols.intersection( + util.reduce( + set.union, + [l.proxy_set for l, r in m._inherits_equated_pairs], + ) + ): result[table].append((m, m._inherits_equated_pairs)) return result @@ -3034,13 +3190,14 @@ def configure_mappers(): if run_configure is EXT_SKIP: continue - if getattr(mapper, '_configure_failed', False): + if getattr(mapper, "_configure_failed", False): e = sa_exc.InvalidRequestError( "One or more mappers failed to initialize - " "can't proceed with initialization of other " "mappers. Triggering mapper: '%s'. " "Original exception was: %s" - % (mapper, mapper._configure_failed)) + % (mapper, mapper._configure_failed) + ) e._configure_failed = mapper._configure_failed raise e @@ -3049,10 +3206,11 @@ def configure_mappers(): mapper._post_configure_properties() mapper._expire_memoizations() mapper.dispatch.mapper_configured( - mapper, mapper.class_) + mapper, mapper.class_ + ) except Exception: exc = sys.exc_info()[1] - if not hasattr(exc, '_configure_failed'): + if not hasattr(exc, "_configure_failed"): mapper._configure_failed = exc raise @@ -3127,16 +3285,17 @@ def validates(*names, **kw): :ref:`simple_validators` - usage examples for :func:`.validates` """ - include_removes = kw.pop('include_removes', False) - include_backrefs = kw.pop('include_backrefs', True) + include_removes = kw.pop("include_removes", False) + include_backrefs = kw.pop("include_backrefs", True) def wrap(fn): fn.__sa_validators__ = names fn.__sa_validation_opts__ = { "include_removes": include_removes, - "include_backrefs": include_backrefs + "include_backrefs": include_backrefs, } return fn + return wrap @@ -3180,7 +3339,7 @@ def _event_on_init(state, args, kwargs): class _ColumnMapping(dict): """Error reporting helper for mapper._columntoproperty.""" - __slots__ = 'mapper', + __slots__ = ("mapper",) def __init__(self, mapper): self.mapper = mapper @@ -3190,8 +3349,10 @@ class _ColumnMapping(dict): if prop: raise orm_exc.UnmappedColumnError( "Column '%s.%s' is not available, due to " - "conflicting property '%s':%r" % ( - column.table.name, column.name, column.key, prop)) + "conflicting property '%s':%r" + % (column.table.name, column.name, column.key, prop) + ) raise orm_exc.UnmappedColumnError( - "No column %s is configured on mapper %s..." 
% - (column, self.mapper)) + "No column %s is configured on mapper %s..." + % (column, self.mapper) + ) diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index bb4e2eda5f..f33c209cc7 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -56,8 +56,7 @@ class PathRegistry(object): is_root = False def __eq__(self, other): - return other is not None and \ - self.path == other.path + return other is not None and self.path == other.path def set(self, attributes, key, value): log.debug("set '%s' on path '%s' to '%s'", key, self, value) @@ -87,11 +86,8 @@ class PathRegistry(object): yield path[i], path[i + 1] def contains_mapper(self, mapper): - for path_mapper in [ - self.path[i] for i in range(0, len(self.path), 2) - ]: - if path_mapper.is_mapper and \ - path_mapper.isa(mapper): + for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]: + if path_mapper.is_mapper and path_mapper.isa(mapper): return True else: return False @@ -100,40 +96,49 @@ class PathRegistry(object): return (key, self.path) in attributes def __reduce__(self): - return _unreduce_path, (self.serialize(), ) + return _unreduce_path, (self.serialize(),) def serialize(self): path = self.path - return list(zip( - [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], - [path[i].key for i in range(1, len(path), 2)] + [None] - )) + return list( + zip( + [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], + [path[i].key for i in range(1, len(path), 2)] + [None], + ) + ) @classmethod def deserialize(cls, path): if path is None: return None - p = tuple(chain(*[(class_mapper(mcls), - class_mapper(mcls).attrs[key] - if key is not None else None) - for mcls, key in path])) + p = tuple( + chain( + *[ + ( + class_mapper(mcls), + class_mapper(mcls).attrs[key] + if key is not None + else None, + ) + for mcls, key in path + ] + ) + ) if p and p[-1] is None: p = p[0:-1] return cls.coerce(p) @classmethod def per_mapper(cls, mapper): - return EntityRegistry( - cls.root, mapper - ) + return EntityRegistry(cls.root, mapper) @classmethod def coerce(cls, raw): return util.reduce(lambda prev, next: prev[next], raw, cls.root) def token(self, token): - if token.endswith(':' + _WILDCARD_TOKEN): + if token.endswith(":" + _WILDCARD_TOKEN): return TokenRegistry(self, token) elif token.endswith(":" + _DEFAULT_TOKEN): return TokenRegistry(self.root, token) @@ -141,12 +146,10 @@ class PathRegistry(object): raise exc.ArgumentError("invalid token: %s" % token) def __add__(self, other): - return util.reduce( - lambda prev, next: prev[next], - other.path, self) + return util.reduce(lambda prev, next: prev[next], other.path, self) def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self.path, ) + return "%s(%r)" % (self.__class__.__name__, self.path) class RootRegistry(PathRegistry): @@ -154,6 +157,7 @@ class RootRegistry(PathRegistry): paths are maintained per-root-mapper. 
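To make the serialize()/deserialize() reshuffling above easier to follow: a PathRegistry path interleaves entities and attribute entries as (mapper, prop, mapper, prop, ...), so the even slots are entities and the odd slots are keys. A plain-Python illustration with placeholder strings, not SQLAlchemy API:

    path = ("UserMapper", "addresses", "AddressMapper", "email")
    entities = [path[i] for i in range(0, len(path), 2)]
    keys = [path[i] for i in range(1, len(path), 2)] + [None]
    # zip() truncates to the shorter list, pairing each entity with the
    # key that follows it (or None for a path ending on an entity)
    print(list(zip(entities, keys)))
    # [('UserMapper', 'addresses'), ('AddressMapper', 'email')]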
""" + path = () has_entity = False is_aliased_class = False @@ -162,6 +166,7 @@ class RootRegistry(PathRegistry): def __getitem__(self, entity): return entity._path_registry + PathRegistry.root = RootRegistry() @@ -194,8 +199,10 @@ class PropRegistry(PathRegistry): if not insp.is_aliased_class or insp._use_mapper_path: parent = parent.parent[prop.parent] elif insp.is_aliased_class and insp.with_polymorphic_mappers: - if prop.parent is not insp.mapper and \ - prop.parent in insp.with_polymorphic_mappers: + if ( + prop.parent is not insp.mapper + and prop.parent in insp.with_polymorphic_mappers + ): subclass_entity = parent[-1]._entity_for_mapper(prop.parent) parent = parent.parent[subclass_entity] @@ -205,15 +212,13 @@ class PropRegistry(PathRegistry): self._wildcard_path_loader_key = ( "loader", - self.parent.path + self.prop._wildcard_token + self.parent.path + self.prop._wildcard_token, ) self._default_path_loader_key = self.prop._default_path_loader_key self._loader_key = ("loader", self.path) def __str__(self): - return " -> ".join( - str(elem) for elem in self.path - ) + return " -> ".join(str(elem) for elem in self.path) @util.memoized_property def has_entity(self): @@ -235,9 +240,7 @@ class PropRegistry(PathRegistry): if isinstance(entity, (int, slice)): return self.path[entity] else: - return EntityRegistry( - self, entity - ) + return EntityRegistry(self, entity) class EntityRegistry(PathRegistry, dict): @@ -258,6 +261,7 @@ class EntityRegistry(PathRegistry, dict): def __bool__(self): return True + __nonzero__ = __bool__ def __getitem__(self, entity): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7f9b7db0ce..dc86a60e54 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -25,8 +25,13 @@ from . 
import loading def _bulk_insert( - mapper, mappings, session_transaction, isstates, return_defaults, - render_nulls): + mapper, + mappings, + session_transaction, + isstates, + return_defaults, + render_nulls, +): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -34,7 +39,8 @@ def _bulk_insert( if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " - "not supported in bulk_insert()") + "not supported in bulk_insert()" + ) if isstates: if return_defaults: @@ -51,22 +57,33 @@ def _bulk_insert( continue records = ( - (None, state_dict, params, mapper, - connection, value_params, has_all_pks, has_all_defaults) - for - state, state_dict, params, mp, - conn, value_params, has_all_pks, - has_all_defaults in _collect_insert_commands(table, ( - (None, mapping, mapper, connection) - for mapping in mappings), - bulk=True, return_defaults=return_defaults, - render_nulls=render_nulls + ( + None, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands( + table, + ((None, mapping, mapper, connection) for mapping in mappings), + bulk=True, + return_defaults=return_defaults, + render_nulls=render_nulls, ) ) - _emit_insert_statements(base_mapper, None, - cached_connections, - super_mapper, table, records, - bookkeeping=return_defaults) + _emit_insert_statements( + base_mapper, + None, + cached_connections, + super_mapper, + table, + records, + bookkeeping=return_defaults, + ) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -74,12 +91,13 @@ def _bulk_insert( for state, dict_ in states: state.key = ( identity_cls, - tuple([dict_[key] for key in identity_props]) + tuple([dict_[key] for key in identity_props]), ) -def _bulk_update(mapper, mappings, session_transaction, - isstates, update_changed_only): +def _bulk_update( + mapper, mappings, session_transaction, isstates, update_changed_only +): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -91,9 +109,8 @@ def _bulk_update(mapper, mappings, session_transaction, def _changed_dict(mapper, state): return dict( (k, v) - for k, v in state.dict.items() if k in state.committed_state or k - in search_keys - + for k, v in state.dict.items() + if k in state.committed_state or k in search_keys ) if isstates: @@ -107,7 +124,8 @@ def _bulk_update(mapper, mappings, session_transaction, if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " - "not supported in bulk_update()") + "not supported in bulk_update()" + ) connection = session_transaction.connection(base_mapper) @@ -115,21 +133,38 @@ def _bulk_update(mapper, mappings, session_transaction, if not mapper.isa(super_mapper): continue - records = _collect_update_commands(None, table, ( - (None, mapping, mapper, connection, - (mapping[mapper._version_id_prop.key] - if mapper._version_id_prop else None)) - for mapping in mappings - ), bulk=True) + records = _collect_update_commands( + None, + table, + ( + ( + None, + mapping, + mapper, + connection, + ( + mapping[mapper._version_id_prop.key] + if mapper._version_id_prop + else None + ), + ) + for mapping in mappings + ), + bulk=True, + ) - _emit_update_statements(base_mapper, None, - cached_connections, - super_mapper, table, records, - bookkeeping=False) + 
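The public entry points that funnel into _bulk_insert() and _bulk_update() are Session.bulk_insert_mappings() and Session.bulk_update_mappings(); a self-contained sketch with a hypothetical User model, for context:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    # plain dicts; no ORM objects or identity-map bookkeeping involved
    session.bulk_insert_mappings(
        User, [{"id": 1, "name": "ed"}, {"id": 2, "name": "wendy"}]
    )
    # primary-key keys select the rows to update
    session.bulk_update_mappings(User, [{"id": 1, "name": "edward"}])
    session.commit()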
_emit_update_statements( + base_mapper, + None, + cached_connections, + super_mapper, + table, + records, + bookkeeping=False, + ) -def save_obj( - base_mapper, states, uowtransaction, single=False): +def save_obj(base_mapper, states, uowtransaction, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -150,19 +185,21 @@ def save_obj( states_to_insert = [] cached_connections = _cached_connection_dict(base_mapper) - for (state, dict_, mapper, connection, - has_identity, - row_switch, update_version_id) in _organize_states_for_save( - base_mapper, states, uowtransaction - ): + for ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) in _organize_states_for_save(base_mapper, states, uowtransaction): if has_identity or row_switch: states_to_update.append( (state, dict_, mapper, connection, update_version_id) ) else: - states_to_insert.append( - (state, dict_, mapper, connection) - ) + states_to_insert.append((state, dict_, mapper, connection)) for table, mapper in base_mapper._sorted_tables.items(): if table not in mapper._pks_by_table: @@ -170,18 +207,30 @@ def save_obj( insert = _collect_insert_commands(table, states_to_insert) update = _collect_update_commands( - uowtransaction, table, states_to_update) + uowtransaction, table, states_to_update + ) - _emit_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + ) - _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, insert) + _emit_insert_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + insert, + ) _finalize_insert_update_commands( - base_mapper, uowtransaction, + base_mapper, + uowtransaction, chain( ( (state, state_dict, mapper, connection, False) @@ -189,10 +238,9 @@ def save_obj( ), ( (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update - ) - ) + for state, state_dict, mapper, connection, update_version_id in states_to_update + ), + ), ) @@ -203,9 +251,9 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): """ cached_connections = _cached_connection_dict(base_mapper) - states_to_update = list(_organize_states_for_post_update( - base_mapper, - states, uowtransaction)) + states_to_update = list( + _organize_states_for_post_update(base_mapper, states, uowtransaction) + ) for table, mapper in base_mapper._sorted_tables.items(): if table not in mapper._pks_by_table: @@ -213,25 +261,32 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): update = ( ( - state, state_dict, sub_mapper, connection, + state, + state_dict, + sub_mapper, + connection, mapper._get_committed_state_attr_by_column( state, state_dict, mapper.version_id_col - ) if mapper.version_id_col is not None else None + ) + if mapper.version_id_col is not None + else None, ) - for - state, state_dict, sub_mapper, connection in states_to_update + for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table ) update = _collect_post_update_commands( - base_mapper, uowtransaction, - table, update, - post_update_cols + base_mapper, uowtransaction, table, update, post_update_cols ) - _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_post_update_statements( + 
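For context on what reaches _emit_post_update_statements(): the post-update path exists for mutually dependent rows, typically enabled with relationship(post_update=True), which issues a second UPDATE for the foreign key after both rows are INSERTed. A hedged sketch with hypothetical Widget/Entry classes:

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = "widget"
        id = Column(Integer, primary_key=True)
        favorite_entry_id = Column(Integer, ForeignKey("entry.id"))
        # break the circular dependency: set this FK in a second UPDATE
        favorite_entry = relationship(
            "Entry", foreign_keys=favorite_entry_id, post_update=True
        )

    class Entry(Base):
        __tablename__ = "entry"
        id = Column(Integer, primary_key=True)
        widget_id = Column(Integer, ForeignKey("widget.id"))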
base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + ) def delete_obj(base_mapper, states, uowtransaction): @@ -244,10 +299,9 @@ def delete_obj(base_mapper, states, uowtransaction): cached_connections = _cached_connection_dict(base_mapper) - states_to_delete = list(_organize_states_for_delete( - base_mapper, - states, - uowtransaction)) + states_to_delete = list( + _organize_states_for_delete(base_mapper, states, uowtransaction) + ) table_to_mapper = base_mapper._sorted_tables @@ -258,14 +312,26 @@ def delete_obj(base_mapper, states, uowtransaction): elif mapper.inherits and mapper.passive_deletes: continue - delete = _collect_delete_commands(base_mapper, uowtransaction, - table, states_to_delete) + delete = _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete + ) - _emit_delete_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, delete) + _emit_delete_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + delete, + ) - for state, state_dict, mapper, connection, \ - update_version_id in states_to_delete: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: mapper.dispatch.after_delete(mapper, connection, state) @@ -282,8 +348,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): """ for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, states + ): has_identity = bool(state.key) @@ -304,25 +370,29 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): # no instance_key attached to it), and another instance # with the same identity key already exists as persistent. # convert to an UPDATE if so. - if not has_identity and \ - instance_key in uowtransaction.session.identity_map: - instance = \ - uowtransaction.session.identity_map[instance_key] + if ( + not has_identity + and instance_key in uowtransaction.session.identity_map + ): + instance = uowtransaction.session.identity_map[instance_key] existing = attributes.instance_state(instance) if not uowtransaction.was_already_deleted(existing): if not uowtransaction.is_deleted(existing): raise orm_exc.FlushError( "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, - state_str(existing))) + "with persistent instance %s" + % (state_str(state), instance_key, state_str(existing)) + ) base_mapper._log_debug( "detected row switch for identity %s. 
" "will update %s, remove %s from " - "transaction", instance_key, - state_str(state), state_str(existing)) + "transaction", + instance_key, + state_str(state), + state_str(existing), + ) # remove the "delete" flag from the existing element uowtransaction.remove_state_actions(existing) @@ -332,14 +402,21 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): update_version_id = mapper._get_committed_state_attr_by_column( row_switch if row_switch else state, row_switch.dict if row_switch else dict_, - mapper.version_id_col) + mapper.version_id_col, + ) - yield (state, dict_, mapper, connection, - has_identity, row_switch, update_version_id) + yield ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) -def _organize_states_for_post_update(base_mapper, states, - uowtransaction): +def _organize_states_for_post_update(base_mapper, states, uowtransaction): """Make an initial pass across a set of states for UPDATE corresponding to post_update. @@ -360,26 +437,28 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): """ for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, states + ): mapper.dispatch.before_delete(mapper, connection, state) if mapper.version_id_col is not None: - update_version_id = \ - mapper._get_committed_state_attr_by_column( - state, dict_, - mapper.version_id_col) + update_version_id = mapper._get_committed_state_attr_by_column( + state, dict_, mapper.version_id_col + ) else: update_version_id = None - yield ( - state, dict_, mapper, connection, update_version_id) + yield (state, dict_, mapper, connection, update_version_id) def _collect_insert_commands( - table, states_to_insert, - bulk=False, return_defaults=False, render_nulls=False): + table, + states_to_insert, + bulk=False, + return_defaults=False, + render_nulls=False, +): """Identify sets of values to use in INSERT statements for a list of states. 
@@ -400,10 +479,16 @@ def _collect_insert_commands( col = propkey_to_col[propkey] if value is None and col not in eval_none and not render_nulls: continue - elif not bulk and hasattr(value, '__clause_element__') or \ - isinstance(value, sql.ClauseElement): - value_params[col.key] = value.__clause_element__() \ - if hasattr(value, '__clause_element__') else value + elif ( + not bulk + and hasattr(value, "__clause_element__") + or isinstance(value, sql.ClauseElement) + ): + value_params[col.key] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) else: params[col.key] = value @@ -414,8 +499,11 @@ def _collect_insert_commands( # which might be worth removing, as it should not be necessary # and also produces confusion, given that "missing" and None # now have distinct meanings - for colkey in mapper._insert_cols_as_none[table].\ - difference(params).difference(value_params): + for colkey in ( + mapper._insert_cols_as_none[table] + .difference(params) + .difference(value_params) + ): params[colkey] = None if not bulk or return_defaults: @@ -424,28 +512,38 @@ def _collect_insert_commands( has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + has_all_defaults = mapper._server_default_cols[table].issubset( + params + ) else: has_all_defaults = True else: has_all_defaults = has_all_pks = True - if mapper.version_id_generator is not False \ - and mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator(None) + if ( + mapper.version_id_generator is not False + and mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = mapper.version_id_generator( + None + ) yield ( - state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults) + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) def _collect_update_commands( - uowtransaction, table, states_to_update, - bulk=False): + uowtransaction, table, states_to_update, bulk=False +): """Identify sets of values to use in UPDATE statements for a list of states. 
@@ -457,8 +555,13 @@ def _collect_update_commands( """ - for state, state_dict, mapper, connection, \ - update_version_id in states_to_update: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: if table not in mapper._pks_by_table: continue @@ -474,36 +577,48 @@ def _collect_update_commands( # look at mapper attribute keys for pk params = dict( (propkey_to_col[propkey].key, state_dict[propkey]) - for propkey in - set(propkey_to_col).intersection(state_dict).difference( - mapper._pk_attr_keys_by_table[table]) + for propkey in set(propkey_to_col) + .intersection(state_dict) + .difference(mapper._pk_attr_keys_by_table[table]) ) has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( - state.committed_state): + state.committed_state + ): value = state_dict[propkey] col = propkey_to_col[propkey] - if hasattr(value, '__clause_element__') or \ - isinstance(value, sql.ClauseElement): - value_params[col] = value.__clause_element__() \ - if hasattr(value, '__clause_element__') else value + if hasattr(value, "__clause_element__") or isinstance( + value, sql.ClauseElement + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) # guard against values that generate non-__nonzero__ # objects for __eq__() - elif state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]) is not True: + elif ( + state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey] + ) + is not True + ): params[col.key] = value if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_onupdate_default_cols[table].\ - issubset(params) + has_all_defaults = mapper._server_onupdate_default_cols[ + table + ].issubset(params) else: has_all_defaults = True - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): if not bulk and not (params or value_params): # HACK: check for history in other tables, in case the @@ -511,10 +626,9 @@ def _collect_update_commands( # where the version_id_col is. This logic was lost # from 0.9 -> 1.0.0 and restored in 1.0.6. for prop in mapper._columntoproperty.values(): - history = ( - state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE)) + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) if history.added: break else: @@ -525,8 +639,9 @@ def _collect_update_commands( no_params = not params and not value_params params[col._label] = update_version_id - if (bulk or col.key not in params) and \ - mapper.version_id_generator is not False: + if ( + bulk or col.key not in params + ) and mapper.version_id_generator is not False: val = mapper.version_id_generator(update_version_id) params[col.key] = val elif mapper.version_id_generator is False and no_params: @@ -545,9 +660,9 @@ def _collect_update_commands( # look at mapper attribute keys for pk pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) - for propkey in - set(propkey_to_col). 
- intersection(mapper._pk_attr_keys_by_table[table]) + for propkey in set(propkey_to_col).intersection( + mapper._pk_attr_keys_by_table[table] + ) ) else: pk_params = {} @@ -555,12 +670,15 @@ def _collect_update_commands( propkey = mapper._columntoproperty[col].key history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) + state, state_dict, attributes.PASSIVE_OFF + ) if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: + if ( + not history.deleted + or ("pk_cascaded", state, col) + in uowtransaction.attributes + ): pk_params[col._label] = history.added[0] params.pop(col.key, None) else: @@ -573,24 +691,38 @@ def _collect_update_commands( if pk_params[col._label] is None: raise orm_exc.FlushError( "Can't update table %s using NULL for primary " - "key value on column %s" % (table, col)) + "key value on column %s" % (table, col) + ) if params or value_params: params.update(pk_params) yield ( - state, state_dict, params, mapper, - connection, value_params, has_all_defaults, has_all_pks) + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) -def _collect_post_update_commands(base_mapper, uowtransaction, table, - states_to_update, post_update_cols): +def _collect_post_update_commands( + base_mapper, uowtransaction, table, states_to_update, post_update_cols +): """Identify sets of values to use in UPDATE statements for a list of states within a post_update operation. """ - for state, state_dict, mapper, connection, \ - update_version_id in states_to_update: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: # assert table in mapper._pks_by_table @@ -600,100 +732,128 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, for col in mapper._cols_by_table[table]: if col in pks: - params[col._label] = \ - mapper._get_state_attr_by_column( - state, - state_dict, col, passive=attributes.PASSIVE_OFF) + params[col._label] = mapper._get_state_attr_by_column( + state, state_dict, col, passive=attributes.PASSIVE_OFF + ) elif col in post_update_cols or col.onupdate is not None: prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) if history.added: value = history.added[0] params[col.key] = value hasdata = True if hasdata: - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): col = mapper.version_id_col params[col._label] = update_version_id - if bool(state.key) and col.key not in params and \ - mapper.version_id_generator is not False: + if ( + bool(state.key) + and col.key not in params + and mapper.version_id_generator is not False + ): val = mapper.version_id_generator(update_version_id) params[col.key] = val yield state, state_dict, mapper, connection, params -def _collect_delete_commands(base_mapper, uowtransaction, table, - states_to_delete): +def _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete +): """Identify values to use in DELETE statements for a list of states to be deleted.""" - for state, state_dict, mapper, connection, \ - update_version_id in states_to_delete: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in 
states_to_delete: if table not in mapper._pks_by_table: continue params = {} for col in mapper._pks_by_table[table]: - params[col.key] = \ - value = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, col) + params[ + col.key + ] = value = mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) if value is None: raise orm_exc.FlushError( "Can't delete from table %s " "using NULL for primary " - "key value on column %s" % (table, col)) + "key value on column %s" % (table, col) + ) - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): params[mapper.version_id_col.key] = update_version_id yield params, connection -def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update, - bookkeeping=True): +def _emit_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + bookkeeping=True, +): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" - needs_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + clause.clauses.append( + col == sql.bindparam(col._label, type_=col.type) + ) if needs_version_id: clause.clauses.append( - mapper.version_id_col == sql.bindparam( + mapper.version_id_col + == sql.bindparam( mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + type_=mapper.version_id_col.type, + ) + ) stmt = table.update(clause) return stmt - cached_stmt = base_mapper._memo(('update', table), update_stmt) - - for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ - records in groupby( - update, - lambda rec: ( - rec[4], # connection - set(rec[2]), # set of parameter keys - bool(rec[5]), # whether or not we have "value" parameters - rec[6], # has_all_defaults - rec[7] # has all pks - ) + cached_stmt = base_mapper._memo(("update", table), update_stmt) + + for ( + (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), + records, + ) in groupby( + update, + lambda rec: ( + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]), # whether or not we have "value" parameters + rec[6], # has_all_defaults + rec[7], # has all pks + ), ): rows = 0 records = list(records) @@ -704,8 +864,11 @@ def _emit_update_statements(base_mapper, uowtransaction, if not has_all_pks: statement = statement.return_defaults() return_defaults = True - elif bookkeeping and not has_all_defaults and \ - mapper.base_mapper.eager_defaults: + elif ( + bookkeeping + and not has_all_defaults + and mapper.base_mapper.eager_defaults + ): statement = statement.return_defaults() return_defaults = True elif mapper.version_id_col is not None: @@ -718,17 +881,24 @@ def _emit_update_statements(base_mapper, uowtransaction, else connection.dialect.supports_sane_rowcount_returning ) - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) allow_multirow = has_all_defaults and not needs_version_id if hasvalue: - for 
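The groupby() idiom being reformatted in _emit_update_statements() is worth spelling out: consecutive records whose (connection, parameter-key-set, ...) signatures match are folded into a single executemany() call. A standalone plain-Python illustration, not SQLAlchemy API:

    from itertools import groupby

    records = [
        ("conn1", {"id": 1, "name": "a"}),
        ("conn1", {"id": 2, "name": "b"}),
        ("conn2", {"id": 3, "name": "c"}),
    ]
    # groupby() folds *consecutive* records with equal keys, so the two
    # conn1 rows become a single multi-parameter batch
    for (conn, keys), recs in groupby(
        records, lambda rec: (rec[0], frozenset(rec[1]))
    ):
        multiparams = [params for _conn, params in recs]
        print(conn, sorted(keys), len(multiparams))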
state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_defaults, has_all_pks in records: - c = connection.execute( - statement.values(value_params), - params) + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute(statement.values(value_params), params) if bookkeeping: _postfetch( mapper, @@ -738,17 +908,26 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) rows += c.rowcount check_rowcount = assert_singlerow else: if not allow_multirow: check_rowcount = assert_singlerow - for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults, \ - has_all_pks in records: - c = cached_connections[connection].\ - execute(statement, params) + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = cached_connections[connection].execute( + statement, params + ) # TODO: why with bookkeeping=False? if bookkeeping: @@ -760,24 +939,32 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) rows += c.rowcount else: multiparams = [rec[2] for rec in records] check_rowcount = assert_multirow or ( - assert_singlerow and - len(multiparams) == 1 + assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute( + statement, multiparams + ) rows += c.rowcount - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_defaults, has_all_pks in records: + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: if bookkeeping: _postfetch( mapper, @@ -787,59 +974,85 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(records), rows)) + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description) + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." 
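The rowcount bookkeeping above backs the ORM's optimistic versioning: when a mapper has a version_id_col, an UPDATE that matches no row raises StaleDataError. A minimal configuration sketch with a hypothetical Account model:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Account(Base):
        __tablename__ = "account"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        # bumped on every flush; the UPDATE's WHERE clause pins the old
        # value, so a concurrent change makes rowcount come back as 0
        version_id = Column(Integer, nullable=False)
        __mapper_args__ = {"version_id_col": version_id}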
+ % c.dialect.dialect_description + ) -def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert, - bookkeeping=True): +def _emit_insert_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + insert, + bookkeeping=True, +): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - cached_stmt = base_mapper._memo(('insert', table), table.insert) - - for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby( - insert, - lambda rec: ( - rec[4], # connection - set(rec[2]), # parameter keys - bool(rec[5]), # whether we have "value" parameters - rec[6], - rec[7])): + cached_stmt = base_mapper._memo(("insert", table), table.insert) + + for ( + (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + records, + ) in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7], + ), + ): statement = cached_stmt - if not bookkeeping or \ - ( - has_all_defaults - or not base_mapper.eager_defaults - or not connection.dialect.implicit_returning - ) and has_all_pks and not hasvalue: + if ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not connection.dialect.implicit_returning + ) + and has_all_pks + and not hasvalue + ): records = list(records) multiparams = [rec[2] for rec in records] - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute(statement, multiparams) if bookkeeping: - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + ) in zip(records, c.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -849,7 +1062,8 @@ def _emit_insert_statements(base_mapper, uowtransaction, state_dict, c, last_inserted_params, - value_params) + value_params, + ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -859,24 +1073,33 @@ def _emit_insert_statements(base_mapper, uowtransaction, elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) - for state, state_dict, params, mapper_rec, \ - connection, value_params, \ - has_all_pks, has_all_defaults in records: + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: if value_params: result = connection.execute( - statement.values(value_params), - params) + statement.values(value_params), params + ) else: - result = cached_connections[connection].\ - execute(statement, params) + result = cached_connections[connection].execute( + statement, params + ) primary_key = result.context.inserted_primary_key if primary_key is not None: # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): + for pk, col in zip( + primary_key, mapper._pks_by_table[table] + ): prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: state_dict[prop.key] = pk @@ -890,31 +1113,39 @@ def _emit_insert_statements(base_mapper, uowtransaction, state_dict, result, result.context.compiled_parameters[0], - value_params) + value_params, + ) else: _postfetch_bulk_save(mapper_rec, 
state_dict, table) -def _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): +def _emit_post_update_statements( + base_mapper, uowtransaction, cached_connections, mapper, table, update +): """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" - needs_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + clause.clauses.append( + col == sql.bindparam(col._label, type_=col.type) + ) if needs_version_id: clause.clauses.append( - mapper.version_id_col == sql.bindparam( + mapper.version_id_col + == sql.bindparam( mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + type_=mapper.version_id_col.type, + ) + ) stmt = table.update(clause) @@ -923,17 +1154,15 @@ def _emit_post_update_statements(base_mapper, uowtransaction, return stmt - statement = base_mapper._memo(('post_update', table), update_stmt) + statement = base_mapper._memo(("post_update", table), update_stmt) # execute each UPDATE in the order according to the original # list of states to guarantee row access order, but # also group them into common (connection, cols) sets # to support executemany(). for key, records in groupby( - update, lambda rec: ( - rec[3], # connection - set(rec[4]), # parameter keys - ) + update, + lambda rec: (rec[3], set(rec[4])), # connection # parameter keys ): rows = 0 @@ -945,84 +1174,96 @@ def _emit_post_update_statements(base_mapper, uowtransaction, if mapper.version_id_col is None else connection.dialect.supports_sane_rowcount_returning ) - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) allow_multirow = not needs_version_id or assert_multirow - if not allow_multirow: check_rowcount = assert_singlerow - for state, state_dict, mapper_rec, \ - connection, params in records: - c = cached_connections[connection].\ - execute(statement, params) + for state, state_dict, mapper_rec, connection, params in records: + c = cached_connections[connection].execute(statement, params) _postfetch_post_update( - mapper_rec, uowtransaction, table, state, state_dict, - c, c.context.compiled_parameters[0]) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) rows += c.rowcount else: multiparams = [ - params for - state, state_dict, mapper_rec, conn, params in records] + params + for state, state_dict, mapper_rec, conn, params in records + ] check_rowcount = assert_multirow or ( - assert_singlerow and - len(multiparams) == 1 + assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute(statement, multiparams) rows += c.rowcount - for state, state_dict, mapper_rec, \ - connection, params in records: + for state, state_dict, mapper_rec, connection, params in records: _postfetch_post_update( - mapper_rec, uowtransaction, table, state, state_dict, - c, c.context.compiled_parameters[0]) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) if 
check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(records), rows)) + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description) + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) -def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, delete): +def _emit_delete_statements( + base_mapper, uowtransaction, cached_connections, mapper, table, delete +): """Emit DELETE statements corresponding to value lists collected by _collect_delete_commands().""" - need_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + need_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def delete_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) + col == sql.bindparam(col.key, type_=col.type) + ) if need_version_id: clause.clauses.append( - mapper.version_id_col == - sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col.key, type_=mapper.version_id_col.type ) ) return table.delete(clause) - statement = base_mapper._memo(('delete', table), delete_stmt) - for connection, recs in groupby( - delete, - lambda rec: rec[1] # connection - ): + statement = base_mapper._memo(("delete", table), delete_stmt) + for connection, recs in groupby(delete, lambda rec: rec[1]): # connection del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -1049,9 +1290,10 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, else: util.warn( "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." % - connection.dialect.dialect_description, - stacklevel=12) + "- versioning cannot be verified." + % connection.dialect.dialect_description, + stacklevel=12, + ) connection.execute(statement, del_objects) else: c = connection.execute(statement, del_objects) @@ -1061,23 +1303,26 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, rows_matched = c.rowcount - if base_mapper.confirm_deleted_rows and \ - rows_matched > -1 and expected != rows_matched: + if ( + base_mapper.confirm_deleted_rows + and rows_matched > -1 + and expected != rows_matched + ): if only_warn: util.warn( "DELETE statement on table '%s' expected to " "delete %d row(s); %d were matched. Please set " "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." % - (table.description, expected, rows_matched) + "configuration to prevent this warning." + % (table.description, expected, rows_matched) ) else: raise orm_exc.StaleDataError( "DELETE statement on table '%s' expected to " "delete %d row(s); %d were matched. Please set " "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." % - (table.description, expected, rows_matched) + "configuration to prevent this warning." 
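The message above points at a mapper-level switch; a minimal sketch of turning the check off for a hypothetical Gadget model, useful when a DBAPI cannot report reliable rowcounts:

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Gadget(Base):
        __tablename__ = "gadget"
        id = Column(Integer, primary_key=True)
        # skip the matched-rowcount assertion for DELETEs of these rows
        __mapper_args__ = {"confirm_deleted_rows": False}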
+ % (table.description, expected, rows_matched) ) @@ -1091,13 +1336,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if mapper._readonly_props: readonly = state.unmodified_intersection( [ - p.key for p in mapper._readonly_props + p.key + for p in mapper._readonly_props if ( - p.expire_on_flush and - (not p.deferred or p.key in state.dict) - ) or ( - not p.expire_on_flush and - not p.deferred and p.key not in state.dict + p.expire_on_flush + and (not p.deferred or p.key in state.dict) + ) + or ( + not p.expire_on_flush + and not p.deferred + and p.key not in state.dict ) ] ) @@ -1112,11 +1360,14 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if base_mapper.eager_defaults: toload_now.extend( state._unloaded_non_object.intersection( - mapper._server_default_plus_onupdate_propkeys) + mapper._server_default_plus_onupdate_propkeys + ) ) - if mapper.version_id_col is not None and \ - mapper.version_id_generator is False: + if ( + mapper.version_id_col is not None + and mapper.version_id_generator is False + ): if mapper._version_id_prop.key in state.unloaded: toload_now.extend([mapper._version_id_prop.key]) @@ -1124,8 +1375,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): state.key = base_mapper._identity_key_from_state(state) loading.load_on_ident( uowtransaction.session.query(mapper), - state.key, refresh_state=state, - only_load_props=toload_now) + state.key, + refresh_state=state, + only_load_props=toload_now, + ) # call after_XXX extensions if not has_identity: @@ -1133,23 +1386,29 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): else: mapper.dispatch.after_update(mapper, connection, state) - if mapper.version_id_generator is False and \ - mapper.version_id_col is not None: + if ( + mapper.version_id_generator is False + and mapper.version_id_col is not None + ): if state_dict[mapper._version_id_prop.key] is None: raise orm_exc.FlushError( - "Instance does not contain a non-NULL version value") + "Instance does not contain a non-NULL version value" + ) -def _postfetch_post_update(mapper, uowtransaction, table, - state, dict_, result, params): +def _postfetch_post_update( + mapper, uowtransaction, table, state, dict_, result, params +): if uowtransaction.is_deleted(state): return prefetch_cols = result.context.compiled.prefetch postfetch_cols = result.context.compiled.postfetch - if mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1164,18 +1423,23 @@ def _postfetch_post_update(mapper, uowtransaction, table, if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs) + state, uowtransaction, load_evt_attrs + ) if postfetch_cols: - state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) -def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): +def _postfetch( + mapper, uowtransaction, table, state, dict_, result, params, value_params +): """Expire attributes in 
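_postfetch() is what keeps server-generated values consistent after a flush: attributes backed by server defaults are either expired for a later load or, with eager_defaults, fetched via RETURNING where the dialect supports it. A configuration sketch with a hypothetical Event model:

    from sqlalchemy import Column, DateTime, Integer, func
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Event(Base):
        __tablename__ = "event"
        id = Column(Integer, primary_key=True)
        created = Column(DateTime, server_default=func.now())
        # fetch server defaults in the same round trip when possible,
        # instead of expiring the attribute
        __mapper_args__ = {"eager_defaults": True}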
need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -1184,8 +1448,10 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.compiled.postfetch returning_cols = result.context.compiled.returning - if mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1219,23 +1485,32 @@ def _postfetch(mapper, uowtransaction, table, if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs) + state, uowtransaction, load_evt_attrs + ) if postfetch_cols: - state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) def _postfetch_bulk_save(mapper, dict_, table): @@ -1255,8 +1530,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): # organize individual states with the connection # to use for update if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable + connection_callable = uowtransaction.session.connection_callable else: connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None @@ -1275,7 +1549,8 @@ def _cached_connection_dict(base_mapper): return util.PopulateDict( lambda conn: conn.execution_options( compiled_cache=base_mapper._compiled_cache - )) + ) + ) def _sort_states(states): @@ -1287,9 +1562,12 @@ def _sort_states(states): except TypeError as err: raise sa_exc.InvalidRequestError( "Could not sort objects by primary key; primary key " - "values must be sortable in Python (was: %s)" % err) - return sorted(pending, key=operator.attrgetter("insert_order")) + \ - persistent_sorted + "values must be sortable in Python (was: %s)" % err + ) + return ( + sorted(pending, key=operator.attrgetter("insert_order")) + + persistent_sorted + ) class BulkUD(object): @@ -1302,21 +1580,22 @@ class BulkUD(object): def _validate_query_state(self): for attr, methname, notset, op in ( - ('_limit', 'limit()', None, operator.is_), - ('_offset', 'offset()', None, operator.is_), - ('_order_by', 'order_by()', False, operator.is_), - ('_group_by', 'group_by()', False, operator.is_), - ('_distinct', 'distinct()', False, operator.is_), + ("_limit", "limit()", None, operator.is_), + ("_offset", "offset()", None, operator.is_), + ("_order_by", "order_by()", False, operator.is_), + ("_group_by", "group_by()", False, operator.is_), + ("_distinct", "distinct()", False, operator.is_), ( - '_from_obj', - 'join(), outerjoin(), select_from(), or from_self()', - (), operator.eq) + "_from_obj", + "join(), outerjoin(), select_from(), or 
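_validate_query_state() is the guard that makes bulk update/delete refuse a Query already carrying LIMIT, OFFSET, ORDER BY, and so on. A sketch of the observable behavior, assuming the hypothetical User/session from the bulk-mappings example earlier:

    from sqlalchemy.exc import InvalidRequestError

    try:
        session.query(User).limit(5).delete(synchronize_session=False)
    except InvalidRequestError as err:
        print(err)  # Can't call Query.update() or Query.delete() ...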
from_self()", + (), + operator.eq, + ), ): if not op(getattr(self.query, attr), notset): raise sa_exc.InvalidRequestError( "Can't call Query.update() or Query.delete() " - "when %s has been called" % - (methname, ) + "when %s has been called" % (methname,) ) @property @@ -1330,8 +1609,8 @@ class BulkUD(object): except KeyError: raise sa_exc.ArgumentError( "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) - for x in lookup)))) + "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + ) else: return klass(*arg) @@ -1400,9 +1679,9 @@ class BulkEvaluate(BulkUD): try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) if query.whereclause is not None: - eval_condition = evaluator_compiler.process( - query.whereclause) + eval_condition = evaluator_compiler.process(query.whereclause) else: + def eval_condition(obj): return True @@ -1411,15 +1690,20 @@ class BulkEvaluate(BulkUD): except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( 'Could not evaluate current criteria in Python: "%s". ' - 'Specify \'fetch\' or False for the ' - 'synchronize_session parameter.' % err) + "Specify 'fetch' or False for the " + "synchronize_session parameter." % err + ) # TODO: detect when the where clause is a trivial primary key match self.matched_objects = [ - obj for (cls, pk, identity_token), obj in - query.session.identity_map.items() - if issubclass(cls, target_cls) and - eval_condition(obj)] + obj + for ( + cls, + pk, + identity_token, + ), obj in query.session.identity_map.items() + if issubclass(cls, target_cls) and eval_condition(obj) + ] class BulkFetch(BulkUD): @@ -1430,11 +1714,11 @@ class BulkFetch(BulkUD): session = query.session context = query._compile_context() select_stmt = context.statement.with_only_columns( - self.primary_table.primary_key) + self.primary_table.primary_key + ) self.matched_rows = session.execute( - select_stmt, - mapper=self.mapper, - params=query._params).fetchall() + select_stmt, mapper=self.mapper, params=query._params + ).fetchall() class BulkUpdate(BulkUD): @@ -1447,18 +1731,26 @@ class BulkUpdate(BulkUD): @classmethod def factory(cls, query, synchronize_session, values, update_kwargs): - return BulkUD._factory({ - "evaluate": BulkUpdateEvaluate, - "fetch": BulkUpdateFetch, - False: BulkUpdate - }, synchronize_session, query, values, update_kwargs) + return BulkUD._factory( + { + "evaluate": BulkUpdateEvaluate, + "fetch": BulkUpdateFetch, + False: BulkUpdate, + }, + synchronize_session, + query, + values, + update_kwargs, + ) @property def _resolved_values(self): values = [] for k, v in ( - self.values.items() if hasattr(self.values, 'items') - else self.values): + self.values.items() + if hasattr(self.values, "items") + else self.values + ): if self.mapper: if isinstance(k, util.string_types): desc = _entity_descriptor(self.mapper, k) @@ -1478,7 +1770,7 @@ class BulkUpdate(BulkUD): if isinstance(k, attributes.QueryableAttribute): values.append((k.key, v)) continue - elif hasattr(k, '__clause_element__'): + elif hasattr(k, "__clause_element__"): k = k.__clause_element__() if self.mapper and isinstance(k, expression.ColumnElement): @@ -1490,18 +1782,22 @@ class BulkUpdate(BulkUD): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k) + "Invalid expression type: %r" % k + ) return values def _do_exec(self): values = self._resolved_values - if not self.update_kwargs.get('preserve_parameter_order', False): + if not 
self.update_kwargs.get("preserve_parameter_order", False): values = dict(values) - update_stmt = sql.update(self.primary_table, - self.context.whereclause, values, - **self.update_kwargs) + update_stmt = sql.update( + self.primary_table, + self.context.whereclause, + values, + **self.update_kwargs + ) self._execute_stmt(update_stmt) @@ -1518,15 +1814,18 @@ class BulkDelete(BulkUD): @classmethod def factory(cls, query, synchronize_session): - return BulkUD._factory({ - "evaluate": BulkDeleteEvaluate, - "fetch": BulkDeleteFetch, - False: BulkDelete - }, synchronize_session, query) + return BulkUD._factory( + { + "evaluate": BulkDeleteEvaluate, + "fetch": BulkDeleteFetch, + False: BulkDelete, + }, + synchronize_session, + query, + ) def _do_exec(self): - delete_stmt = sql.delete(self.primary_table, - self.context.whereclause) + delete_stmt = sql.delete(self.primary_table, self.context.whereclause) self._execute_stmt(delete_stmt) @@ -1544,32 +1843,33 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): values = self._resolved_values_keys_as_propnames for key, value in values: self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + expression._literal_as_binds(value) + ) def _do_post_synchronize(self): session = self.query.session states = set() evaluated_keys = list(self.value_evaluators.keys()) for obj in self.matched_objects: - state, dict_ = attributes.instance_state(obj),\ - attributes.instance_dict(obj) + state, dict_ = ( + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) # only evaluate unmodified attributes - to_evaluate = state.unmodified.intersection( - evaluated_keys) + to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: dict_[key] = self.value_evaluators[key](obj) - state.manager.dispatch.refresh( - state, None, to_evaluate) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) # expire attributes with pending changes # (there was no autoflush, so they are overwritten) - state._expire_attributes(dict_, - set(evaluated_keys). 
- difference(to_evaluate)) + state._expire_attributes( + dict_, set(evaluated_keys).difference(to_evaluate) + ) states.add(state) session._register_altered(states) @@ -1580,8 +1880,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): def _do_post_synchronize(self): self.query.session._remove_newly_deleted( - [attributes.instance_state(obj) - for obj in self.matched_objects]) + [attributes.instance_state(obj) for obj in self.matched_objects] + ) class BulkUpdateFetch(BulkFetch, BulkUpdate): @@ -1592,15 +1892,18 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): session = self.query.session target_mapper = self.query._mapper_zero() - states = set([ - attributes.instance_state(session.identity_map[identity_key]) - for identity_key in [ - target_mapper.identity_key_from_primary_key( - list(primary_key)) - for primary_key in self.matched_rows + states = set( + [ + attributes.instance_state(session.identity_map[identity_key]) + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key) + ) + for primary_key in self.matched_rows + ] + if identity_key in session.identity_map ] - if identity_key in session.identity_map - ]) + ) values = self._resolved_values_keys_as_propnames attrib = set(k for k, v in values) @@ -1622,10 +1925,13 @@ class BulkDeleteFetch(BulkFetch, BulkDelete): # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key)) + list(primary_key) + ) if identity_key in session.identity_map: session._remove_newly_deleted( - [attributes.instance_state( - session.identity_map[identity_key] - )] + [ + attributes.instance_state( + session.identity_map[identity_key] + ) + ] ) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ca47fe7eaa..a39cd87036 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -20,7 +20,7 @@ from .util import _orm_full_deannotate from .interfaces import PropComparator, StrategizedProperty -__all__ = ['ColumnProperty'] +__all__ = ["ColumnProperty"] @log.class_logger @@ -31,14 +31,27 @@ class ColumnProperty(StrategizedProperty): """ - strategy_wildcard_key = 'column' + strategy_wildcard_key = "column" __slots__ = ( - '_orig_columns', 'columns', 'group', 'deferred', - 'instrument', 'comparator_factory', 'descriptor', 'extension', - 'active_history', 'expire_on_flush', 'info', 'doc', - 'strategy_key', '_creation_order', '_is_polymorphic_discriminator', - '_mapped_by_synonym', '_deferred_column_loader') + "_orig_columns", + "columns", + "group", + "deferred", + "instrument", + "comparator_factory", + "descriptor", + "extension", + "active_history", + "expire_on_flush", + "info", + "doc", + "strategy_key", + "_creation_order", + "_is_polymorphic_discriminator", + "_mapped_by_synonym", + "_deferred_column_loader", + ) def __init__(self, *columns, **kwargs): r"""Provide a column-level property for use with a Mapper. 
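As a usage counterpart to the kwargs handled in __init__ below, a short sketch of column_property() with the deferred/group options, using a hypothetical Person model:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import column_property

    Base = declarative_base()

    class Person(Base):
        __tablename__ = "person"
        id = Column(Integer, primary_key=True)
        first = Column(String(50))
        last = Column(String(50))
        # read-only SQL expression, loaded on first access as part of
        # the "names" deferral group
        full_name = column_property(
            first + " " + last, deferred=True, group="names"
        )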
@@ -117,26 +130,28 @@ class ColumnProperty(StrategizedProperty): """ super(ColumnProperty, self).__init__() self._orig_columns = [expression._labeled(c) for c in columns] - self.columns = [expression._labeled(_orm_full_deannotate(c)) - for c in columns] - self.group = kwargs.pop('group', None) - self.deferred = kwargs.pop('deferred', False) - self.instrument = kwargs.pop('_instrument', True) - self.comparator_factory = kwargs.pop('comparator_factory', - self.__class__.Comparator) - self.descriptor = kwargs.pop('descriptor', None) - self.extension = kwargs.pop('extension', None) - self.active_history = kwargs.pop('active_history', False) - self.expire_on_flush = kwargs.pop('expire_on_flush', True) - - if 'info' in kwargs: - self.info = kwargs.pop('info') - - if 'doc' in kwargs: - self.doc = kwargs.pop('doc') + self.columns = [ + expression._labeled(_orm_full_deannotate(c)) for c in columns + ] + self.group = kwargs.pop("group", None) + self.deferred = kwargs.pop("deferred", False) + self.instrument = kwargs.pop("_instrument", True) + self.comparator_factory = kwargs.pop( + "comparator_factory", self.__class__.Comparator + ) + self.descriptor = kwargs.pop("descriptor", None) + self.extension = kwargs.pop("extension", None) + self.active_history = kwargs.pop("active_history", False) + self.expire_on_flush = kwargs.pop("expire_on_flush", True) + + if "info" in kwargs: + self.info = kwargs.pop("info") + + if "doc" in kwargs: + self.doc = kwargs.pop("doc") else: for col in reversed(self.columns): - doc = getattr(col, 'doc', None) + doc = getattr(col, "doc", None) if doc is not None: self.doc = doc break @@ -145,22 +160,24 @@ class ColumnProperty(StrategizedProperty): if kwargs: raise TypeError( - "%s received unexpected keyword argument(s): %s" % ( - self.__class__.__name__, - ', '.join(sorted(kwargs.keys())))) + "%s received unexpected keyword argument(s): %s" + % (self.__class__.__name__, ", ".join(sorted(kwargs.keys()))) + ) util.set_creation_order(self) self.strategy_key = ( ("deferred", self.deferred), - ("instrument", self.instrument) + ("instrument", self.instrument), ) @util.dependencies("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") def _memoized_attr__deferred_column_loader(self, state, strategies): return state.InstanceState._instance_level_callable_processor( self.parent.class_manager, - strategies.LoadDeferredColumns(self.key), self.key) + strategies.LoadDeferredColumns(self.key), + self.key, + ) def __clause_element__(self): """Allow the ColumnProperty to work in expression before it is turned @@ -185,34 +202,50 @@ class ColumnProperty(StrategizedProperty): self.key, comparator=self.comparator_factory(self, mapper), parententity=mapper, - doc=self.doc + doc=self.doc, ) def do_init(self): super(ColumnProperty, self).do_init() - if len(self.columns) > 1 and \ - set(self.parent.primary_key).issuperset(self.columns): + if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( + self.columns + ): util.warn( - ("On mapper %s, primary key column '%s' is being combined " - "with distinct primary key column '%s' in attribute '%s'. " - "Use explicit properties to give each column its own mapped " - "attribute name.") % (self.parent, self.columns[1], - self.columns[0], self.key)) + ( + "On mapper %s, primary key column '%s' is being combined " + "with distinct primary key column '%s' in attribute '%s'. " + "Use explicit properties to give each column its own mapped " + "attribute name." 
+ ) + % (self.parent, self.columns[1], self.columns[0], self.key) + ) def copy(self): return ColumnProperty( deferred=self.deferred, group=self.group, active_history=self.active_history, - *self.columns) + *self.columns + ) - def _getcommitted(self, state, dict_, column, - passive=attributes.PASSIVE_OFF): - return state.get_impl(self.key).\ - get_committed_value(state, dict_, passive=passive) + def _getcommitted( + self, state, dict_, column, passive=attributes.PASSIVE_OFF + ): + return state.get_impl(self.key).get_committed_value( + state, dict_, passive=passive + ) - def merge(self, session, source_state, source_dict, dest_state, - dest_dict, load, _recursive, _resolve_conflict_map): + def merge( + self, + session, + source_state, + source_dict, + dest_state, + dest_dict, + load, + _recursive, + _resolve_conflict_map, + ): if not self.instrument: return elif self.key in source_dict: @@ -225,7 +258,8 @@ class ColumnProperty(StrategizedProperty): impl.set(dest_state, dest_dict, value, None) elif dest_state.has_identity and self.key not in dest_dict: dest_state._expire_attributes( - dest_dict, [self.key], no_loader=True) + dest_dict, [self.key], no_loader=True + ) class Comparator(util.MemoizedSlots, PropComparator): """Produce boolean, comparison, and other operators for @@ -246,7 +280,7 @@ class ColumnProperty(StrategizedProperty): """ - __slots__ = '__clause_element__', 'info' + __slots__ = "__clause_element__", "info" def _memoized_method___clause_element__(self): if self.adapter: @@ -254,9 +288,12 @@ class ColumnProperty(StrategizedProperty): else: # no adapter, so we aren't aliased # assert self._parententity is self._parentmapper - return self.prop.columns[0]._annotate({ - "parententity": self._parententity, - "parentmapper": self._parententity}) + return self.prop.columns[0]._annotate( + { + "parententity": self._parententity, + "parentmapper": self._parententity, + } + ) def _memoized_attr_info(self): ce = self.__clause_element__() diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index febf627b4a..4a55a32478 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -22,26 +22,37 @@ database to return iterable result sets. from itertools import chain from . import ( - attributes, interfaces, object_mapper, persistence, - exc as orm_exc, loading + attributes, + interfaces, + object_mapper, + persistence, + exc as orm_exc, + loading, +) +from .base import ( + _entity_descriptor, + _is_aliased_class, + _is_mapped_class, + _orm_columns, + _generative, + InspectionAttr, ) -from .base import _entity_descriptor, _is_aliased_class, \ - _is_mapped_class, _orm_columns, _generative, InspectionAttr from .path_registry import PathRegistry from .util import ( - AliasedClass, ORMAdapter, join as orm_join, with_parent, aliased, - _entity_corresponds_to + AliasedClass, + ORMAdapter, + join as orm_join, + with_parent, + aliased, + _entity_corresponds_to, ) from .. import sql, util, log, exc as sa_exc, inspect, inspection from ..sql.expression import _interpret_as_from -from ..sql import ( - util as sql_util, - expression, visitors -) +from ..sql import util as sql_util, expression, visitors from ..sql.base import ColumnCollection from . 
import properties -__all__ = ['Query', 'QueryContext', 'aliased'] +__all__ = ["Query", "QueryContext", "aliased"] _path_registry = PathRegistry.root @@ -192,16 +203,20 @@ class Query(object): for entity in ent.entities: if entity not in d: ext_info = inspect(entity) - if not ext_info.is_aliased_class and \ - ext_info.mapper.with_polymorphic: - if ext_info.mapper.mapped_table not in \ - self._polymorphic_adapters: + if ( + not ext_info.is_aliased_class + and ext_info.mapper.with_polymorphic + ): + if ( + ext_info.mapper.mapped_table + not in self._polymorphic_adapters + ): self._mapper_loads_polymorphically_with( ext_info.mapper, sql_util.ColumnAdapter( ext_info.selectable, - ext_info.mapper._equivalent_columns - ) + ext_info.mapper._equivalent_columns, + ), ) aliased_adapter = None elif ext_info.is_aliased_class: @@ -209,10 +224,7 @@ class Query(object): else: aliased_adapter = None - d[entity] = ( - ext_info, - aliased_adapter - ) + d[entity] = (ext_info, aliased_adapter) ent.setup_entity(*d[entity]) def _mapper_loads_polymorphically_with(self, mapper, adapter): @@ -227,18 +239,21 @@ class Query(object): for from_obj in obj: info = inspect(from_obj) - if hasattr(info, 'mapper') and \ - (info.is_mapper or info.is_aliased_class): + if hasattr(info, "mapper") and ( + info.is_mapper or info.is_aliased_class + ): self._select_from_entity = info if set_base_alias and not info.is_aliased_class: raise sa_exc.ArgumentError( "A selectable (FromClause) instance is " - "expected when the base alias is being set.") + "expected when the base alias is being set." + ) fa.append(info.selectable) elif not info.is_selectable: raise sa_exc.ArgumentError( "argument is not a mapped class, mapper, " - "aliased(), or FromClause instance.") + "aliased(), or FromClause instance." 
+ ) else: if isinstance(from_obj, expression.SelectBase): from_obj = from_obj.alias() @@ -248,16 +263,21 @@ class Query(object): self._from_obj = tuple(fa) - if set_base_alias and \ - len(self._from_obj) == 1 and \ - isinstance(select_from_alias, expression.Alias): + if ( + set_base_alias + and len(self._from_obj) == 1 + and isinstance(select_from_alias, expression.Alias) + ): equivs = self.__all_equivs() self._from_obj_alias = sql_util.ColumnAdapter( - self._from_obj[0], equivs) - elif set_base_alias and \ - len(self._from_obj) == 1 and \ - hasattr(info, "mapper") and \ - info.is_aliased_class: + self._from_obj[0], equivs + ) + elif ( + set_base_alias + and len(self._from_obj) == 1 + and hasattr(info, "mapper") + and info.is_aliased_class + ): self._from_obj_alias = info._adapter def _reset_polymorphic_adapter(self, mapper): @@ -268,14 +288,14 @@ class Query(object): def _adapt_polymorphic_element(self, element): if "parententity" in element._annotations: - search = element._annotations['parententity'] + search = element._annotations["parententity"] alias = self._polymorphic_adapters.get(search, None) if alias: return alias.adapt_clause(element) if isinstance(element, expression.FromClause): search = element - elif hasattr(element, 'table'): + elif hasattr(element, "table"): search = element.table else: return None @@ -287,8 +307,8 @@ class Query(object): def _adapt_col_list(self, cols): return [ self._adapt_clause( - expression._literal_as_label_reference(o), - True, True) + expression._literal_as_label_reference(o), True, True + ) for o in cols ] @@ -312,11 +332,7 @@ class Query(object): if as_filter and self._filter_aliases: for fa in self._filter_aliases.visitor_iterator: - adapters.append( - ( - orm_only, fa.replace - ) - ) + adapters.append((orm_only, fa.replace)) if self._from_obj_alias: # for the "from obj" alias, apply extra rule to the @@ -326,16 +342,12 @@ class Query(object): adapters.append( ( orm_only if self._orm_only_from_obj_alias else False, - self._from_obj_alias.replace + self._from_obj_alias.replace, ) ) if self._polymorphic_adapters: - adapters.append( - ( - orm_only, self._adapt_polymorphic_element - ) - ) + adapters.append((orm_only, self._adapt_polymorphic_element)) if not adapters: return clause @@ -344,19 +356,17 @@ class Query(object): for _orm_only, adapter in adapters: # if 'orm only', look for ORM annotations # in the element before adapting. 
- if not _orm_only or \ - '_orm_adapt' in elem._annotations or \ - "parententity" in elem._annotations: + if ( + not _orm_only + or "_orm_adapt" in elem._annotations + or "parententity" in elem._annotations + ): e = adapter(elem) if e is not None: return e - return visitors.replacement_traverse( - clause, - {}, - replace - ) + return visitors.replacement_traverse(clause, {}, replace) def _query_entity_zero(self): """Return the first QueryEntity.""" @@ -371,9 +381,11 @@ class Query(object): with the first QueryEntity, or alternatively the 'select from' entity if specified.""" - return self._select_from_entity \ - if self._select_from_entity is not None \ + return ( + self._select_from_entity + if self._select_from_entity is not None else self._query_entity_zero().entity_zero + ) @property def _mapper_entities(self): @@ -382,10 +394,7 @@ class Query(object): yield ent def _joinpoint_zero(self): - return self._joinpoint.get( - '_joinpoint_entity', - self._entity_zero() - ) + return self._joinpoint.get("_joinpoint_entity", self._entity_zero()) def _bind_mapper(self): ezero = self._entity_zero() @@ -400,14 +409,15 @@ class Query(object): if self._entities != [self._primary_entity]: raise sa_exc.InvalidRequestError( "%s() can only be used against " - "a single mapped class." % methname) + "a single mapped class." % methname + ) return self._primary_entity.entity_zero def _only_entity_zero(self, rationale=None): if len(self._entities) > 1: raise sa_exc.InvalidRequestError( - rationale or - "This operation requires a Query " + rationale + or "This operation requires a Query " "against a single mapper." ) return self._entity_zero() @@ -420,7 +430,8 @@ class Query(object): def _get_condition(self): return self._no_criterion_condition( - "get", order_by=False, distinct=False) + "get", order_by=False, distinct=False + ) def _get_existing_condition(self): self._no_criterion_assertion("get", order_by=False, distinct=False) @@ -428,14 +439,20 @@ class Query(object): def _no_criterion_assertion(self, meth, order_by=True, distinct=True): if not self._enable_assertions: return - if self._criterion is not None or \ - self._statement is not None or self._from_obj or \ - self._limit is not None or self._offset is not None or \ - self._group_by or (order_by and self._order_by) or \ - (distinct and self._distinct): + if ( + self._criterion is not None + or self._statement is not None + or self._from_obj + or self._limit is not None + or self._offset is not None + or self._group_by + or (order_by and self._order_by) + or (distinct and self._distinct) + ): raise sa_exc.InvalidRequestError( "Query.%s() being called on a " - "Query with existing criterion. " % meth) + "Query with existing criterion. " % meth + ) def _no_criterion_condition(self, meth, order_by=True, distinct=True): self._no_criterion_assertion(meth, order_by, distinct) @@ -450,7 +467,8 @@ class Query(object): if self._order_by: raise sa_exc.InvalidRequestError( "Query.%s() being called on a " - "Query with existing criterion. " % meth) + "Query with existing criterion. " % meth + ) self._no_criterion_condition(meth) def _no_statement_condition(self, meth): @@ -458,8 +476,12 @@ class Query(object): return if self._statement is not None: raise sa_exc.InvalidRequestError( - ("Query.%s() being called on a Query with an existing full " - "statement - can't apply criterion.") % meth) + ( + "Query.%s() being called on a Query with an existing full " + "statement - can't apply criterion." 
+ ) + % meth + ) def _no_limit_offset(self, meth): if not self._enable_assertions: @@ -470,15 +492,17 @@ class Query(object): "or OFFSET applied. To modify the row-limited results of a " " Query, call from_self() first. " "Otherwise, call %s() before limit() or offset() " - "are applied." - % (meth, meth) + "are applied." % (meth, meth) ) - def _get_options(self, populate_existing=None, - version_check=None, - only_load_props=None, - refresh_state=None, - identity_token=None): + def _get_options( + self, + populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None, + identity_token=None, + ): if populate_existing: self._populate_existing = populate_existing if version_check: @@ -507,8 +531,7 @@ class Query(object): """ - stmt = self._compile_context(labels=self._with_labels).\ - statement + stmt = self._compile_context(labels=self._with_labels).statement if self._params: stmt = stmt.params(self._params) @@ -602,8 +625,9 @@ class Query(object): :meth:`.HasCTE.cte` """ - return self.enable_eagerloads(False).\ - statement.cte(name=name, recursive=recursive) + return self.enable_eagerloads(False).statement.cte( + name=name, recursive=recursive + ) def label(self, name): """Return the full SELECT statement represented by this @@ -678,7 +702,8 @@ class Query(object): "compatible with %s eager loading. Please " "specify lazyload('*') or query.enable_eagerloads(False) in " "order to " - "proceed with query.yield_per()." % message) + "proceed with query.yield_per()." % message + ) @_generative() def with_labels(self): @@ -752,10 +777,9 @@ class Query(object): self._current_path = path @_generative(_no_clauseelement_condition) - def with_polymorphic(self, - cls_or_mappers, - selectable=None, - polymorphic_on=None): + def with_polymorphic( + self, cls_or_mappers, selectable=None, polymorphic_on=None + ): """Load columns for inheriting classes. :meth:`.Query.with_polymorphic` applies transformations @@ -783,13 +807,16 @@ class Query(object): if not self._primary_entity: raise sa_exc.InvalidRequestError( - "No primary mapper set up for this Query.") + "No primary mapper set up for this Query." + ) entity = self._entities[0]._clone() self._entities = [entity] + self._entities[1:] - entity.set_with_polymorphic(self, - cls_or_mappers, - selectable=selectable, - polymorphic_on=polymorphic_on) + entity.set_with_polymorphic( + self, + cls_or_mappers, + selectable=selectable, + polymorphic_on=polymorphic_on, + ) @_generative() def yield_per(self, count): @@ -858,8 +885,8 @@ class Query(object): """ self._yield_per = count self._execution_options = self._execution_options.union( - {"stream_results": True, - "max_row_buffer": count}) + {"stream_results": True, "max_row_buffer": count} + ) def get(self, ident): """Return an instance based on the given primary key identifier, @@ -918,12 +945,16 @@ class Query(object): :return: The object instance, or ``None``. """ - return self._get_impl( - ident, loading.load_on_pk_identity) - - def _identity_lookup(self, mapper, primary_key_identity, - identity_token=None, passive=attributes.PASSIVE_OFF, - lazy_loaded_from=None): + return self._get_impl(ident, loading.load_on_pk_identity) + + def _identity_lookup( + self, + mapper, + primary_key_identity, + identity_token=None, + passive=attributes.PASSIVE_OFF, + lazy_loaded_from=None, + ): """Locate an object in the identity map. 
Given a primary key identity, constructs an identity key and then @@ -966,14 +997,13 @@ class Query(object): """ key = mapper.identity_key_from_primary_key( - primary_key_identity, identity_token=identity_token) - return loading.get_from_identity( - self.session, key, passive) + primary_key_identity, identity_token=identity_token + ) + return loading.get_from_identity(self.session, key, passive) - def _get_impl( - self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): # convert composite types to individual args - if hasattr(primary_key_identity, '__composite_values__'): + if hasattr(primary_key_identity, "__composite_values__"): primary_key_identity = primary_key_identity.__composite_values__() primary_key_identity = util.to_list(primary_key_identity) @@ -983,16 +1013,19 @@ class Query(object): if len(primary_key_identity) != len(mapper.primary_key): raise sa_exc.InvalidRequestError( "Incorrect number of values in identifier to formulate " - "primary key for query.get(); primary key columns are %s" % - ','.join("'%s'" % c for c in mapper.primary_key)) + "primary key for query.get(); primary key columns are %s" + % ",".join("'%s'" % c for c in mapper.primary_key) + ) - if not self._populate_existing and \ - not mapper.always_refresh and \ - self._for_update_arg is None: + if ( + not self._populate_existing + and not mapper.always_refresh + and self._for_update_arg is None + ): instance = self._identity_lookup( - mapper, primary_key_identity, - identity_token=identity_token) + mapper, primary_key_identity, identity_token=identity_token + ) if instance is not None: self._get_existing_condition() @@ -1106,17 +1139,20 @@ class Query(object): mapper = object_mapper(instance) for prop in mapper.iterate_properties: - if isinstance(prop, properties.RelationshipProperty) and \ - prop.mapper is entity_zero.mapper: + if ( + isinstance(prop, properties.RelationshipProperty) + and prop.mapper is entity_zero.mapper + ): property = prop break else: raise sa_exc.InvalidRequestError( "Could not locate a property which relates instances " - "of class '%s' to instances of class '%s'" % - ( + "of class '%s' to instances of class '%s'" + % ( entity_zero.mapper.class_.__name__, - instance.__class__.__name__) + instance.__class__.__name__, + ) ) return self.filter(with_parent(instance, property, entity_zero.entity)) @@ -1323,8 +1359,11 @@ class Query(object): those being selected. 
""" - fromclause = self.with_labels().enable_eagerloads(False).\ - statement.correlate(None) + fromclause = ( + self.with_labels() + .enable_eagerloads(False) + .statement.correlate(None) + ) q = self._from_selectable(fromclause) q._enable_single_crit = False q._select_from_entity = self._entity_zero() @@ -1339,12 +1378,18 @@ class Query(object): @_generative() def _from_selectable(self, fromclause): for attr in ( - '_statement', '_criterion', - '_order_by', '_group_by', - '_limit', '_offset', - '_joinpath', '_joinpoint', - '_distinct', '_having', - '_prefixes', '_suffixes' + "_statement", + "_criterion", + "_order_by", + "_group_by", + "_limit", + "_offset", + "_joinpath", + "_joinpoint", + "_distinct", + "_having", + "_prefixes", + "_suffixes", ): self.__dict__.pop(attr, None) self._set_select_from([fromclause], True) @@ -1369,6 +1414,7 @@ class Query(object): if not q._yield_per: q._yield_per = 10 return iter(q) + _values = values def value(self, column): @@ -1420,10 +1466,11 @@ class Query(object): # given arg is a FROM clause self._set_entity_selectables(self._entities[l:]) - @util.pending_deprecation("0.7", - ":meth:`.add_column` is superseded " - "by :meth:`.add_columns`", - False) + @util.pending_deprecation( + "0.7", + ":meth:`.add_column` is superseded " "by :meth:`.add_columns`", + False, + ) def add_column(self, column): """Add a column expression to the list of result columns to be returned. @@ -1454,8 +1501,8 @@ class Query(object): # most MapperOptions write to the '_attributes' dictionary, # so copy that as well self._attributes = self._attributes.copy() - if '_unbound_load_dedupes' not in self._attributes: - self._attributes['_unbound_load_dedupes'] = set() + if "_unbound_load_dedupes" not in self._attributes: + self._attributes["_unbound_load_dedupes"] = set() opts = tuple(util.flatten_iterator(args)) self._with_options = self._with_options + opts if conditional: @@ -1487,7 +1534,7 @@ class Query(object): return fn(self) @_generative() - def with_hint(self, selectable, text, dialect_name='*'): + def with_hint(self, selectable, text, dialect_name="*"): """Add an indexing or other executional context hint for the given entity or selectable to this :class:`.Query`. @@ -1508,7 +1555,7 @@ class Query(object): self._with_hints += ((selectable, text, dialect_name),) - def with_statement_hint(self, text, dialect_name='*'): + def with_statement_hint(self, text, dialect_name="*"): """add a statement hint to this :class:`.Select`. This method is similar to :meth:`.Select.with_hint` except that @@ -1570,8 +1617,14 @@ class Query(object): self._for_update_arg = LockmodeArg.parse_legacy_query(mode) @_generative() - def with_for_update(self, read=False, nowait=False, of=None, - skip_locked=False, key_share=False): + def with_for_update( + self, + read=False, + nowait=False, + of=None, + skip_locked=False, + key_share=False, + ): """return a new :class:`.Query` with the specified options for the ``FOR UPDATE`` clause. @@ -1599,9 +1652,13 @@ class Query(object): full argument and behavioral description. 
""" - self._for_update_arg = LockmodeArg(read=read, nowait=nowait, of=of, - skip_locked=skip_locked, - key_share=key_share) + self._for_update_arg = LockmodeArg( + read=read, + nowait=nowait, + of=of, + skip_locked=skip_locked, + key_share=key_share, + ) @_generative() def params(self, *args, **kwargs): @@ -1619,7 +1676,8 @@ class Query(object): elif len(args) > 0: raise sa_exc.ArgumentError( "params() takes zero or one positional argument, " - "which is a dictionary.") + "which is a dictionary." + ) self._params = self._params.copy() self._params.update(kwargs) @@ -1683,8 +1741,10 @@ class Query(object): """ - clauses = [_entity_descriptor(self._joinpoint_zero(), key) == value - for key, value in kwargs.items()] + clauses = [ + _entity_descriptor(self._joinpoint_zero(), key) == value + for key, value in kwargs.items() + ] return self.filter(sql.and_(*clauses)) @_generative(_no_statement_condition, _no_limit_offset) @@ -1704,7 +1764,7 @@ class Query(object): if len(criterion) == 1: if criterion[0] is False: - if '_order_by' in self.__dict__: + if "_order_by" in self.__dict__: self._order_by = False return if criterion[0] is None: @@ -1765,11 +1825,13 @@ class Query(object): criterion = expression._expression_literal_as_text(criterion) - if criterion is not None and \ - not isinstance(criterion, sql.ClauseElement): + if criterion is not None and not isinstance( + criterion, sql.ClauseElement + ): raise sa_exc.ArgumentError( "having() argument must be of type " - "sqlalchemy.sql.ClauseElement or string") + "sqlalchemy.sql.ClauseElement or string" + ) criterion = self._adapt_clause(criterion, True, True) @@ -2122,17 +2184,23 @@ class Query(object): SQLAlchemy versions was the primary ORM-level joining interface. """ - aliased, from_joinpoint, isouter, full = kwargs.pop('aliased', False),\ - kwargs.pop('from_joinpoint', False),\ - kwargs.pop('isouter', False),\ - kwargs.pop('full', False) + aliased, from_joinpoint, isouter, full = ( + kwargs.pop("aliased", False), + kwargs.pop("from_joinpoint", False), + kwargs.pop("isouter", False), + kwargs.pop("full", False), + ) if kwargs: - raise TypeError("unknown arguments: %s" % - ', '.join(sorted(kwargs))) - return self._join(props, - outerjoin=isouter, full=full, - create_aliases=aliased, - from_joinpoint=from_joinpoint) + raise TypeError( + "unknown arguments: %s" % ", ".join(sorted(kwargs)) + ) + return self._join( + props, + outerjoin=isouter, + full=full, + create_aliases=aliased, + from_joinpoint=from_joinpoint, + ) def outerjoin(self, *props, **kwargs): """Create a left outer join against this ``Query`` object's criterion @@ -2141,25 +2209,32 @@ class Query(object): Usage is the same as the ``join()`` method. 
""" - aliased, from_joinpoint, full = kwargs.pop('aliased', False), \ - kwargs.pop('from_joinpoint', False), \ - kwargs.pop('full', False) + aliased, from_joinpoint, full = ( + kwargs.pop("aliased", False), + kwargs.pop("from_joinpoint", False), + kwargs.pop("full", False), + ) if kwargs: - raise TypeError("unknown arguments: %s" % - ', '.join(sorted(kwargs))) - return self._join(props, - outerjoin=True, full=full, create_aliases=aliased, - from_joinpoint=from_joinpoint) + raise TypeError( + "unknown arguments: %s" % ", ".join(sorted(kwargs)) + ) + return self._join( + props, + outerjoin=True, + full=full, + create_aliases=aliased, + from_joinpoint=from_joinpoint, + ) def _update_joinpoint(self, jp): self._joinpoint = jp # copy backwards to the root of the _joinpath # dict, so that no existing dict in the path is mutated - while 'prev' in jp: - f, prev = jp['prev'] + while "prev" in jp: + f, prev = jp["prev"] prev = prev.copy() prev[f] = jp - jp['prev'] = (f, prev) + jp["prev"] = (f, prev) jp = prev self._joinpath = jp @@ -2173,11 +2248,16 @@ class Query(object): if not from_joinpoint: self._reset_joinpoint() - if len(keys) == 2 and \ - isinstance(keys[0], (expression.FromClause, - type, AliasedClass)) and \ - isinstance(keys[1], (str, expression.ClauseElement, - interfaces.PropComparator)): + if ( + len(keys) == 2 + and isinstance( + keys[0], (expression.FromClause, type, AliasedClass) + ) + and isinstance( + keys[1], + (str, expression.ClauseElement, interfaces.PropComparator), + ) + ): # detect 2-arg form of join and # convert to a tuple. keys = (keys,) @@ -2202,20 +2282,22 @@ class Query(object): # is a little bit of legacy behavior still at work here # which means they might be in either order. if isinstance( - arg1, (interfaces.PropComparator, util.string_types)): + arg1, (interfaces.PropComparator, util.string_types) + ): right, onclause = arg2, arg1 else: right, onclause = arg1, arg2 if onclause is None: r_info = inspect(right) - if not r_info.is_selectable and not hasattr(r_info, 'mapper'): + if not r_info.is_selectable and not hasattr(r_info, "mapper"): raise sa_exc.ArgumentError( "Expected mapped entity or " - "selectable/table as join target") + "selectable/table as join target" + ) if isinstance(onclause, interfaces.PropComparator): - of_type = getattr(onclause, '_of_type', None) + of_type = getattr(onclause, "_of_type", None) else: of_type = None @@ -2234,12 +2316,13 @@ class Query(object): # to work with the aliased=True flag, which is also something # that probably shouldn't exist on join() due to its high # complexity/usefulness ratio - elif from_joinpoint and \ - isinstance(onclause, interfaces.PropComparator): + elif from_joinpoint and isinstance( + onclause, interfaces.PropComparator + ): jp0 = self._joinpoint_zero() info = inspect(jp0) - if getattr(info, 'mapper', None) is onclause._parententity: + if getattr(info, "mapper", None) is onclause._parententity: onclause = _entity_descriptor(jp0, onclause.key) if isinstance(onclause, interfaces.PropComparator): @@ -2256,8 +2339,7 @@ class Query(object): alias = self._polymorphic_adapters.get(left, None) # could be None or could be ColumnAdapter also - if isinstance(alias, ORMAdapter) and \ - alias.mapper.isa(left): + if isinstance(alias, ORMAdapter) and alias.mapper.isa(left): left = alias.aliased_class onclause = getattr(left, onclause.key) @@ -2278,14 +2360,15 @@ class Query(object): # and then mutate the child, which might be # shared by a different query object. 
jp = self._joinpoint[edge].copy() - jp['prev'] = (edge, self._joinpoint) + jp["prev"] = (edge, self._joinpoint) self._update_joinpoint(jp) # warn only on the last element of the list if idx == len(keylist) - 1: util.warn( "Pathed join target %s has already " - "been joined to; skipping" % prop) + "been joined to; skipping" % prop + ) continue else: # no descriptor/property given; we will need to figure out @@ -2295,13 +2378,12 @@ class Query(object): # figure out the final "left" and "right" sides and create an # ORMJoin to add to our _from_obj tuple self._join_left_to_right( - left, right, onclause, prop, create_aliases, - outerjoin, full + left, right, onclause, prop, create_aliases, outerjoin, full ) def _join_left_to_right( - self, left, right, onclause, prop, - create_aliases, outerjoin, full): + self, left, right, onclause, prop, create_aliases, outerjoin, full + ): """given raw "left", "right", "onclause" parameters consumed from a particular key within _join(), add a real ORMJoin object to our _from_obj list (or augment an existing one) @@ -2315,15 +2397,17 @@ class Query(object): # figure out the best "left" side based on our existing froms / # entities assert prop is None - left, replace_from_obj_index, use_entity_index = \ - self._join_determine_implicit_left_side(left, right, onclause) + left, replace_from_obj_index, use_entity_index = self._join_determine_implicit_left_side( + left, right, onclause + ) else: # left is given via a relationship/name. Determine where in our # "froms" list it should be spliced/appended as well as what # existing entity it corresponds to. assert prop is not None - replace_from_obj_index, use_entity_index = \ - self._join_place_explicit_left_side(left) + replace_from_obj_index, use_entity_index = self._join_place_explicit_left_side( + left + ) # this should never happen because we would not have found a place # to join on @@ -2333,7 +2417,7 @@ class Query(object): # a lot of things can be wrong with it. handle all that and # get back the new effective "right" side r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, create_aliases, + left, right, onclause, prop, create_aliases ) if replace_from_obj_index is not None: @@ -2342,11 +2426,18 @@ class Query(object): left_clause = self._from_obj[replace_from_obj_index] self._from_obj = ( - self._from_obj[:replace_from_obj_index] + - (orm_join( - left_clause, right, - onclause, isouter=outerjoin, full=full), ) + - self._from_obj[replace_from_obj_index + 1:]) + self._from_obj[:replace_from_obj_index] + + ( + orm_join( + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + ), + ) + + self._from_obj[replace_from_obj_index + 1 :] + ) else: # add a new element to the self._from_obj list @@ -2358,8 +2449,8 @@ class Query(object): self._from_obj = self._from_obj + ( orm_join( - left_clause, right, onclause, - isouter=outerjoin, full=full), + left_clause, right, onclause, isouter=outerjoin, full=full + ), ) def _join_determine_implicit_left_side(self, left, right, onclause): @@ -2388,8 +2479,8 @@ class Query(object): # join has to connect to one of those FROMs. indexes = sql_util.find_left_clause_to_join_from( - self._from_obj, - r_info.selectable, onclause) + self._from_obj, r_info.selectable, onclause + ) if len(indexes) == 1: replace_from_obj_index = indexes[0] @@ -2399,12 +2490,13 @@ class Query(object): "Can't determine which FROM clause to join " "from, there are multiple FROMS which can " "join to this entity. 
Try adding an explicit ON clause " - "to help resolve the ambiguity.") + "to help resolve the ambiguity." + ) else: raise sa_exc.InvalidRequestError( "Don't know how to join to %s; please use " "an ON clause to more clearly establish the left " - "side of this join" % (right, ) + "side of this join" % (right,) ) elif self._entities: @@ -2430,7 +2522,8 @@ class Query(object): all_clauses = list(potential.keys()) indexes = sql_util.find_left_clause_to_join_from( - all_clauses, r_info.selectable, onclause) + all_clauses, r_info.selectable, onclause + ) if len(indexes) == 1: use_entity_index, left = potential[all_clauses[indexes[0]]] @@ -2439,18 +2532,20 @@ class Query(object): "Can't determine which FROM clause to join " "from, there are multiple FROMS which can " "join to this entity. Try adding an explicit ON clause " - "to help resolve the ambiguity.") + "to help resolve the ambiguity." + ) else: raise sa_exc.InvalidRequestError( "Don't know how to join to %s; please use " "an ON clause to more clearly establish the left " - "side of this join" % (right, ) + "side of this join" % (right,) ) else: raise sa_exc.InvalidRequestError( "No entities to join from; please use " "select_from() to establish the left " - "entity/selectable of this join") + "entity/selectable of this join" + ) return left, replace_from_obj_index, use_entity_index @@ -2484,13 +2579,15 @@ class Query(object): l_info = inspect(left) if self._from_obj: indexes = sql_util.find_left_clause_that_matches_given( - self._from_obj, l_info.selectable) + self._from_obj, l_info.selectable + ) if len(indexes) > 1: raise sa_exc.InvalidRequestError( "Can't identify which entity in which to assign the " "left side of this join. Please use a more specific " - "ON clause.") + "ON clause." + ) # have an index, means the left side is already present in # an existing FROM in the self._from_obj tuple @@ -2504,8 +2601,11 @@ class Query(object): # self._from_obj tuple. Determine if this left side matches up # with existing mapper entities, in which case we want to apply the # aliasing / adaptation rules present on that entity if any - if replace_from_obj_index is None and \ - self._entities and hasattr(l_info, 'mapper'): + if ( + replace_from_obj_index is None + and self._entities + and hasattr(l_info, "mapper") + ): for idx, ent in enumerate(self._entities): # TODO: should we be checking for multiple mapper entities # matching? @@ -2516,7 +2616,8 @@ class Query(object): return replace_from_obj_index, use_entity_index def _join_check_and_adapt_right_side( - self, left, right, onclause, prop, create_aliases): + self, left, right, onclause, prop, create_aliases + ): """transform the "right" side of the join as well as the onclause according to polymorphic mapping translations, aliasing on the query or on the join, special cases where the right and left side have @@ -2533,30 +2634,37 @@ class Query(object): # if the target is a joined inheritance mapping, # be more liberal about auto-aliasing. 
if right_mapper and ( - right_mapper.with_polymorphic or - isinstance(right_mapper.mapped_table, expression.Join) + right_mapper.with_polymorphic + or isinstance(right_mapper.mapped_table, expression.Join) ): for from_obj in self._from_obj or [l_info.selectable]: if sql_util.selectables_overlap( - l_info.selectable, from_obj) and \ - sql_util.selectables_overlap( - from_obj, r_info.selectable): + l_info.selectable, from_obj + ) and sql_util.selectables_overlap( + from_obj, r_info.selectable + ): overlap = True break - if (overlap or not create_aliases) and \ - l_info.selectable is r_info.selectable: + if ( + overlap or not create_aliases + ) and l_info.selectable is r_info.selectable: raise sa_exc.InvalidRequestError( - "Can't join table/selectable '%s' to itself" % - l_info.selectable) + "Can't join table/selectable '%s' to itself" + % l_info.selectable + ) - right_mapper, right_selectable, right_is_aliased = \ - getattr(r_info, 'mapper', None), \ - r_info.selectable, \ - getattr(r_info, 'is_aliased_class', False) + right_mapper, right_selectable, right_is_aliased = ( + getattr(r_info, "mapper", None), + r_info.selectable, + getattr(r_info, "is_aliased_class", False), + ) - if right_mapper and prop and \ - not right_mapper.common_parent(prop.mapper): + if ( + right_mapper + and prop + and not right_mapper.common_parent(prop.mapper) + ): raise sa_exc.InvalidRequestError( "Join target %s does not correspond to " "the right side of join condition %s" % (right, onclause) @@ -2564,8 +2672,8 @@ class Query(object): # _join_entities is used as a hint for single-table inheritance # purposes at the moment - if hasattr(r_info, 'mapper'): - self._join_entities += (r_info, ) + if hasattr(r_info, "mapper"): + self._join_entities += (r_info,) if not right_mapper and prop: right_mapper = prop.mapper @@ -2579,12 +2687,14 @@ class Query(object): right = self._adapt_clause(right, True, False) if right_mapper and right is right_selectable: - if not right_selectable.is_derived_from( - right_mapper.mapped_table): + if not right_selectable.is_derived_from(right_mapper.mapped_table): raise sa_exc.InvalidRequestError( - "Selectable '%s' is not derived from '%s'" % - (right_selectable.description, - right_mapper.mapped_table.description)) + "Selectable '%s' is not derived from '%s'" + % ( + right_selectable.description, + right_mapper.mapped_table.description, + ) + ) if isinstance(right_selectable, expression.SelectBase): # TODO: this isn't even covered now! 
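
For reference, the guard reformatted above ("Can't join table/selectable '%s' to itself") fires when a self-referential join does not give its right side a distinct alias. A minimal sketch follows; the Node mapping is hypothetical and illustrative only, not part of this patch:

# Illustrative sketch (hypothetical Node model), not part of the diff:
# aliased() gives the right side of a self-referential join its own
# selectable, which is what the guard above requires.
from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, aliased

Base = declarative_base()


class Node(Base):
    __tablename__ = "node"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("node.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)

parent = aliased(Node)
# two-argument join(target, onclause), as handled by _join() above
print(session.query(Node).join(parent, Node.parent_id == parent.id))
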
@@ -2593,16 +2703,20 @@ class Query(object): right = aliased(right_mapper, right_selectable) - aliased_entity = right_mapper and \ - not right_is_aliased and \ - ( - right_mapper.with_polymorphic and isinstance( - right_mapper._with_polymorphic_selectable, - expression.Alias) or overlap + aliased_entity = ( + right_mapper + and not right_is_aliased + and ( + right_mapper.with_polymorphic + and isinstance( + right_mapper._with_polymorphic_selectable, expression.Alias + ) + or overlap # test for overlap: # orm/inheritance/relationships.py # SelfReferentialM2MTest ) + ) if not need_adapter and (create_aliases or aliased_entity): right = aliased(right, flat=True) @@ -2614,9 +2728,11 @@ class Query(object): if need_adapter: self._filter_aliases = ORMAdapter( right, - equivalents=right_mapper and - right_mapper._equivalent_columns or {}, - chain_to=self._filter_aliases) + equivalents=right_mapper + and right_mapper._equivalent_columns + or {}, + chain_to=self._filter_aliases, + ) # if the onclause is a ClauseElement, adapt it with any # adapters that are in place right now @@ -2631,20 +2747,21 @@ class Query(object): self._mapper_loads_polymorphically_with( right_mapper, ORMAdapter( - right, - equivalents=right_mapper._equivalent_columns - ) + right, equivalents=right_mapper._equivalent_columns + ), ) # if joining on a MapperProperty path, # track the path to prevent redundant joins if not create_aliases and prop: - self._update_joinpoint({ - '_joinpoint_entity': right, - 'prev': ((left, right, prop.key), self._joinpoint) - }) + self._update_joinpoint( + { + "_joinpoint_entity": right, + "prev": ((left, right, prop.key), self._joinpoint), + } + ) else: - self._joinpoint = {'_joinpoint_entity': right} + self._joinpoint = {"_joinpoint_entity": right} return right, inspect(right), onclause @@ -2821,27 +2938,30 @@ class Query(object): if isinstance(item, slice): start, stop, step = util.decode_slice(item) - if isinstance(stop, int) and \ - isinstance(start, int) and \ - stop - start <= 0: + if ( + isinstance(stop, int) + and isinstance(start, int) + and stop - start <= 0 + ): return [] # perhaps we should execute a count() here so that we # can still use LIMIT/OFFSET ? - elif (isinstance(start, int) and start < 0) \ - or (isinstance(stop, int) and stop < 0): + elif (isinstance(start, int) and start < 0) or ( + isinstance(stop, int) and stop < 0 + ): return list(self)[item] res = self.slice(start, stop) if step is not None: - return list(res)[None:None:item.step] + return list(res)[None : None : item.step] else: return list(res) else: if item == -1: return list(self)[-1] else: - return list(self[item:item + 1])[0] + return list(self[item : item + 1])[0] @_generative(_no_statement_condition) def slice(self, start, stop): @@ -3014,12 +3134,13 @@ class Query(object): """ statement = expression._expression_literal_as_text(statement) - if not isinstance(statement, - (expression.TextClause, - expression.SelectBase)): + if not isinstance( + statement, (expression.TextClause, expression.SelectBase) + ): raise sa_exc.ArgumentError( "from_statement accepts text(), select(), " - "and union() objects only.") + "and union() objects only." + ) self._statement = statement @@ -3082,7 +3203,8 @@ class Query(object): return None else: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one_or_none()") + "Multiple rows were found for one_or_none()" + ) def one(self): """Return exactly one result or raise an exception. 
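
For reference, the surrounding one_or_none()/one() hunks only rewrap error messages; the result contract is unchanged: one_or_none() returns None for zero rows and raises MultipleResultsFound for more than one, while one() additionally raises NoResultFound for zero rows. A minimal sketch follows; the User mapping is hypothetical and illustrative only, not part of this patch:

# Illustrative sketch (hypothetical User model), not part of the diff.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

Base = declarative_base()


class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)

q = session.query(User).filter_by(name="ed")
assert q.one_or_none() is None  # zero rows -> None
try:
    q.one()  # zero rows -> NoResultFound
except NoResultFound:
    pass

session.add_all([User(name="ed"), User(name="ed")])
try:
    q.one_or_none()  # two rows -> MultipleResultsFound
except MultipleResultsFound:
    pass
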
@@ -3106,7 +3228,8 @@ class Query(object): ret = self.one_or_none() except orm_exc.MultipleResultsFound: raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()") + "Multiple rows were found for one()" + ) else: if ret is None: raise orm_exc.NoResultFound("No row was found for one()") @@ -3149,8 +3272,11 @@ class Query(object): def __str__(self): context = self._compile_context() try: - bind = self._get_bind_args( - context, self.session.get_bind) if self.session else None + bind = ( + self._get_bind_args(context, self.session.get_bind) + if self.session + else None + ) except sa_exc.UnboundExecutionError: bind = None return str(context.statement.compile(bind)) @@ -3163,24 +3289,22 @@ class Query(object): def _execute_and_instances(self, querycontext): conn = self._get_bind_args( - querycontext, - self._connection_from_session, - close_with_result=True) + querycontext, self._connection_from_session, close_with_result=True + ) result = conn.execute(querycontext.statement, self._params) return loading.instances(querycontext.query, result, querycontext) def _execute_crud(self, stmt, mapper): conn = self._connection_from_session( - mapper=mapper, clause=stmt, close_with_result=True) + mapper=mapper, clause=stmt, close_with_result=True + ) return conn.execute(stmt, self._params) def _get_bind_args(self, querycontext, fn, **kw): return fn( - mapper=self._bind_mapper(), - clause=querycontext.statement, - **kw + mapper=self._bind_mapper(), clause=querycontext.statement, **kw ) @property @@ -3225,21 +3349,23 @@ class Query(object): return [ { - 'name': ent._label_name, - 'type': ent.type, - 'aliased': getattr(insp_ent, 'is_aliased_class', False), - 'expr': ent.expr, - 'entity': - getattr(insp_ent, "entity", None) - if ent.entity_zero is not None - and not insp_ent.is_clause_element - else None + "name": ent._label_name, + "type": ent.type, + "aliased": getattr(insp_ent, "is_aliased_class", False), + "expr": ent.expr, + "entity": getattr(insp_ent, "entity", None) + if ent.entity_zero is not None + and not insp_ent.is_clause_element + else None, } for ent, insp_ent in [ ( _ent, - (inspect(_ent.entity_zero) - if _ent.entity_zero is not None else None) + ( + inspect(_ent.entity_zero) + if _ent.entity_zero is not None + else None + ), ) for _ent in self._entities ] @@ -3290,21 +3416,23 @@ class Query(object): @property def _select_args(self): return { - 'limit': self._limit, - 'offset': self._offset, - 'distinct': self._distinct, - 'prefixes': self._prefixes, - 'suffixes': self._suffixes, - 'group_by': self._group_by or None, - 'having': self._having + "limit": self._limit, + "offset": self._offset, + "distinct": self._distinct, + "prefixes": self._prefixes, + "suffixes": self._suffixes, + "group_by": self._group_by or None, + "having": self._having, } @property def _should_nest_selectable(self): kwargs = self._select_args - return (kwargs.get('limit') is not None or - kwargs.get('offset') is not None or - kwargs.get('distinct', False)) + return ( + kwargs.get("limit") is not None + or kwargs.get("offset") is not None + or kwargs.get("distinct", False) + ) def exists(self): """A convenience method that turns a query into an EXISTS subquery @@ -3343,9 +3471,12 @@ class Query(object): # omitting the FROM clause from a query(X) (#2818); # .with_only_columns() after we have a core select() so that # we get just "SELECT 1" without any entities. - return sql.exists(self.enable_eagerloads(False).add_columns('1'). - with_labels(). 
- statement.with_only_columns([1])) + return sql.exists( + self.enable_eagerloads(False) + .add_columns("1") + .with_labels() + .statement.with_only_columns([1]) + ) def count(self): r"""Return a count of rows this Query would return. @@ -3384,10 +3515,10 @@ class Query(object): session.query(func.count(distinct(User.name))) """ - col = sql.func.count(sql.literal_column('*')) + col = sql.func.count(sql.literal_column("*")) return self.from_self(col).scalar() - def delete(self, synchronize_session='evaluate'): + def delete(self, synchronize_session="evaluate"): r"""Perform a bulk delete query. Deletes rows matched by this query from the database. @@ -3506,12 +3637,11 @@ class Query(object): """ - delete_op = persistence.BulkDelete.factory( - self, synchronize_session) + delete_op = persistence.BulkDelete.factory(self, synchronize_session) delete_op.exec_() return delete_op.rowcount - def update(self, values, synchronize_session='evaluate', update_args=None): + def update(self, values, synchronize_session="evaluate", update_args=None): r"""Perform a bulk update query. Updates rows matched by this query in the database. @@ -3640,7 +3770,8 @@ class Query(object): update_args = update_args or {} update_op = persistence.BulkUpdate.factory( - self, synchronize_session, values, update_args) + self, synchronize_session, values, update_args + ) update_op.exec_() return update_op.rowcount @@ -3682,11 +3813,12 @@ class Query(object): raise sa_exc.InvalidRequestError( "No column-based properties specified for " "refresh operation. Use session.expire() " - "to reload collections and related items.") + "to reload collections and related items." + ) else: raise sa_exc.InvalidRequestError( - "Query contains no columns with which to " - "SELECT from.") + "Query contains no columns with which to " "SELECT from." 
+ ) if context.multi_row_eager_loaders and self._should_nest_selectable: context.statement = self._compound_eager_statement(context) @@ -3701,11 +3833,9 @@ class Query(object): # then append eager joins onto that if context.order_by: - order_by_col_expr = \ - sql_util.expand_column_list_from_order_by( - context.primary_columns, - context.order_by - ) + order_by_col_expr = sql_util.expand_column_list_from_order_by( + context.primary_columns, context.order_by + ) else: context.order_by = None order_by_col_expr = [] @@ -3738,15 +3868,17 @@ class Query(object): context.adapter = sql_util.ColumnAdapter(inner, equivs) statement = sql.select( - [inner] + context.secondary_columns, - use_labels=context.labels) + [inner] + context.secondary_columns, use_labels=context.labels + ) # Oracle however does not allow FOR UPDATE on the subquery, # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL # we expect that all elements of the row are locked, so also put it # on the outside (except in the case of PG when OF is used) - if context._for_update_arg is not None and \ - context._for_update_arg.of is None: + if ( + context._for_update_arg is not None + and context._for_update_arg.of is None + ): statement._for_update_arg = context._for_update_arg from_clause = inner @@ -3755,16 +3887,14 @@ class Query(object): # giving us a marker as to where the "splice point" of # the join should be from_clause = sql_util.splice_joins( - from_clause, - eager_join, eager_join.stop_on) + from_clause, eager_join, eager_join.stop_on + ) statement.append_from(from_clause) if context.order_by: statement.append_order_by( - *context.adapter.copy_and_process( - context.order_by - ) + *context.adapter.copy_and_process(context.order_by) ) statement.append_order_by(*context.eager_order_by) @@ -3775,16 +3905,13 @@ class Query(object): context.order_by = None if self._distinct is True and context.order_by: - context.primary_columns += \ - sql_util.expand_column_list_from_order_by( - context.primary_columns, - context.order_by - ) + context.primary_columns += sql_util.expand_column_list_from_order_by( + context.primary_columns, context.order_by + ) context.froms += tuple(context.eager_joins.values()) statement = sql.select( - context.primary_columns + - context.secondary_columns, + context.primary_columns + context.secondary_columns, context.whereclause, from_obj=context.froms, use_labels=context.labels, @@ -3815,8 +3942,10 @@ class Query(object): """ search = set(self._mapper_adapter_map.values()) - if self._select_from_entity and \ - self._select_from_entity not in self._mapper_adapter_map: + if ( + self._select_from_entity + and self._select_from_entity not in self._mapper_adapter_map + ): insp = inspect(self._select_from_entity) if insp.is_aliased_class: adapter = insp._adapter @@ -3833,8 +3962,8 @@ class Query(object): single_crit = adapter.traverse(single_crit) single_crit = self._adapt_clause(single_crit, False, False) context.whereclause = sql.and_( - sql.True_._ifnone(context.whereclause), - single_crit) + sql.True_._ifnone(context.whereclause), single_crit + ) from ..sql.selectable import ForUpdateArg @@ -3856,7 +3985,8 @@ class LockmodeArg(ForUpdateArg): read = False else: raise sa_exc.ArgumentError( - "Unknown with_lockmode argument: %r" % mode) + "Unknown with_lockmode argument: %r" % mode + ) return LockmodeArg(read=read, nowait=nowait) @@ -3867,8 +3997,9 @@ class _QueryEntity(object): def __new__(cls, *args, **kwargs): if cls is _QueryEntity: entity = args[1] - if not isinstance(entity, util.string_types) and 
\
-                    _is_mapped_class(entity):
+            if not isinstance(entity, util.string_types) and _is_mapped_class(
+                entity
+            ):
                 cls = _MapperEntity
             elif isinstance(entity, Bundle):
                 cls = _BundleEntity
@@ -3903,8 +4034,7 @@ class _MapperEntity(_QueryEntity):
         self.selectable = ext_info.selectable
         self.is_aliased_class = ext_info.is_aliased_class
         self._with_polymorphic = ext_info.with_polymorphic_mappers
-        self._polymorphic_discriminator = \
-            ext_info.polymorphic_on
+        self._polymorphic_discriminator = ext_info.polymorphic_on
         self.entity_zero = ext_info
         if ext_info.is_aliased_class:
             self._label_name = self.entity_zero.name
@@ -3912,8 +4042,9 @@
             self._label_name = self.mapper.class_.__name__
         self.path = self.entity_zero._path_registry

-    def set_with_polymorphic(self, query, cls_or_mappers,
-                             selectable, polymorphic_on):
+    def set_with_polymorphic(
+        self, query, cls_or_mappers, selectable, polymorphic_on
+    ):
         """Receive an update from a call to query.with_polymorphic().

         Note the newer style of using a free standing with_polymorphic()
@@ -3924,8 +4055,7 @@
         if self.is_aliased_class:
             # TODO: invalidrequest ?
             raise NotImplementedError(
-                "Can't use with_polymorphic() against "
-                "an Aliased object"
+                "Can't use with_polymorphic() against " "an Aliased object"
             )

         if cls_or_mappers is None:
@@ -3933,14 +4063,16 @@
             return

         mappers, from_obj = self.mapper._with_polymorphic_args(
-            cls_or_mappers, selectable)
+            cls_or_mappers, selectable
+        )
         self._with_polymorphic = mappers
         self._polymorphic_discriminator = polymorphic_on

         self.selectable = from_obj
         query._mapper_loads_polymorphically_with(
-            self.mapper, sql_util.ColumnAdapter(
-                from_obj, self.mapper._equivalent_columns))
+            self.mapper,
+            sql_util.ColumnAdapter(from_obj, self.mapper._equivalent_columns),
+        )

     @property
     def type(self):
@@ -3989,8 +4121,8 @@
         # require row aliasing unconditionally.
         if not adapter and self.mapper._requires_row_aliasing:
             adapter = sql_util.ColumnAdapter(
-                self.selectable,
-                self.mapper._equivalent_columns)
+                self.selectable, self.mapper._equivalent_columns
+            )

         if query._primary_entity is self:
             only_load_props = query._only_load_props
@@ -4006,7 +4138,7 @@
             adapter,
             only_load_props=only_load_props,
             refresh_state=refresh_state,
-            polymorphic_discriminator=self._polymorphic_discriminator
+            polymorphic_discriminator=self._polymorphic_discriminator,
         )

         return _instance, self._label_name
@@ -4023,17 +4155,19 @@
         # apply adaptation to the mapper's order_by if needed.
if adapter: context.order_by = adapter.adapt_list( - util.to_list( - context.order_by - ) + util.to_list(context.order_by) ) loading._setup_entity_query( - context, self.mapper, self, - self.path, adapter, context.primary_columns, + context, + self.mapper, + self, + self.path, + adapter, + context.primary_columns, with_polymorphic=self._with_polymorphic, only_load_props=query._only_load_props, - polymorphic_discriminator=self._polymorphic_discriminator + polymorphic_discriminator=self._polymorphic_discriminator, ) def __str__(self): @@ -4091,9 +4225,10 @@ class Bundle(InspectionAttr): self.name = self._label = name self.exprs = exprs self.c = self.columns = ColumnCollection() - self.columns.update((getattr(col, "key", col._label), col) - for col in exprs) - self.single_entity = kw.pop('single_entity', self.single_entity) + self.columns.update( + (getattr(col, "key", col._label), col) for col in exprs + ) + self.single_entity = kw.pop("single_entity", self.single_entity) columns = None """A namespace of SQL expressions referred to by this :class:`.Bundle`. @@ -4152,10 +4287,11 @@ class Bundle(InspectionAttr): :ref:`bundles` - includes an example of subclassing. """ - keyed_tuple = util.lightweight_named_tuple('result', labels) + keyed_tuple = util.lightweight_named_tuple("result", labels) def proc(row): return keyed_tuple([proc(row) for proc in procs]) + return proc @@ -4235,8 +4371,10 @@ class _BundleEntity(_QueryEntity): def row_processor(self, query, context, result): procs, labels = zip( - *[ent.row_processor(query, context, result) - for ent in self._entities] + *[ + ent.row_processor(query, context, result) + for ent in self._entities + ] ) proc = self.bundle.create_row_processor(query, procs, labels) @@ -4259,11 +4397,10 @@ class _ColumnEntity(_QueryEntity): search_entities = False check_column = True _entity = None - elif isinstance(column, ( - attributes.QueryableAttribute, - interfaces.PropComparator - )): - _entity = getattr(column, '_parententity', None) + elif isinstance( + column, (attributes.QueryableAttribute, interfaces.PropComparator) + ): + _entity = getattr(column, "_parententity", None) if _entity is not None: search_entities = False self._label_name = column.key @@ -4274,7 +4411,7 @@ class _ColumnEntity(_QueryEntity): return if not isinstance(column, sql.ColumnElement): - if hasattr(column, '_select_iterable'): + if hasattr(column, "_select_iterable"): # break out an object like Table into # individual columns for c in column._select_iterable: @@ -4286,10 +4423,10 @@ class _ColumnEntity(_QueryEntity): raise sa_exc.InvalidRequestError( "SQL expression, column, or mapped entity " - "expected - got '%r'" % (column, ) + "expected - got '%r'" % (column,) ) elif not check_column: - self._label_name = getattr(column, 'key', None) + self._label_name = getattr(column, "key", None) search_entities = True self.type = type_ = column.type @@ -4301,7 +4438,7 @@ class _ColumnEntity(_QueryEntity): # if the expression's identity has been changed # due to adaption. 
- if not column._label and not getattr(column, 'is_literal', False): + if not column._label and not getattr(column, "is_literal", False): column = column.label(self._label_name) query._entities.append(self) @@ -4328,23 +4465,29 @@ class _ColumnEntity(_QueryEntity): self._from_entities = set(self.entities) else: all_elements = [ - elem for elem in sql_util.surface_column_elements( - column, include_scalar_selects=False) - if 'parententity' in elem._annotations + elem + for elem in sql_util.surface_column_elements( + column, include_scalar_selects=False + ) + if "parententity" in elem._annotations ] - self.entities = util.unique_list([ - elem._annotations['parententity'] - for elem in all_elements - if 'parententity' in elem._annotations - ]) - - self._from_entities = set([ - elem._annotations['parententity'] - for elem in all_elements - if 'parententity' in elem._annotations - and actual_froms.intersection(elem._from_objects) - ]) + self.entities = util.unique_list( + [ + elem._annotations["parententity"] + for elem in all_elements + if "parententity" in elem._annotations + ] + ) + + self._from_entities = set( + [ + elem._annotations["parententity"] + for elem in all_elements + if "parententity" in elem._annotations + and actual_froms.intersection(elem._from_objects) + ] + ) if self.entities: self.entity_zero = self.entities[0] self.mapper = self.entity_zero.mapper @@ -4373,7 +4516,7 @@ class _ColumnEntity(_QueryEntity): c.entities = self.entities def setup_entity(self, ext_info, aliased_adapter): - if 'selectable' not in self.__dict__: + if "selectable" not in self.__dict__: self.selectable = ext_info.selectable if self.actual_froms.intersection(ext_info.selectable._from_objects): @@ -4386,12 +4529,13 @@ class _ColumnEntity(_QueryEntity): # TODO: polymorphic subclasses ? 
return entity is self.entity_zero else: - return not _is_aliased_class(self.entity_zero) and \ - entity.common_parent(self.entity_zero) + return not _is_aliased_class( + self.entity_zero + ) and entity.common_parent(self.entity_zero) def row_processor(self, query, context, result): - if ('fetch_column', self) in context.attributes: - column = context.attributes[('fetch_column', self)] + if ("fetch_column", self) in context.attributes: + column = context.attributes[("fetch_column", self)] else: column = query._adapt_clause(self.column, False, True) @@ -4417,7 +4561,7 @@ class _ColumnEntity(_QueryEntity): context.froms += tuple(self.froms) context.primary_columns.append(column) - context.attributes[('fetch_column', self)] = column + context.attributes[("fetch_column", self)] = column def __str__(self): return str(self.column) @@ -4425,22 +4569,44 @@ class _ColumnEntity(_QueryEntity): class QueryContext(object): __slots__ = ( - 'multi_row_eager_loaders', 'adapter', 'froms', 'for_update', - 'query', 'session', 'autoflush', 'populate_existing', - 'invoke_all_eagers', 'version_check', 'refresh_state', - 'primary_columns', 'secondary_columns', 'eager_order_by', - 'eager_joins', 'create_eager_joins', 'propagate_options', - 'attributes', 'statement', 'from_clause', 'whereclause', - 'order_by', 'labels', '_for_update_arg', 'runid', 'partials', - 'post_load_paths', 'identity_token' + "multi_row_eager_loaders", + "adapter", + "froms", + "for_update", + "query", + "session", + "autoflush", + "populate_existing", + "invoke_all_eagers", + "version_check", + "refresh_state", + "primary_columns", + "secondary_columns", + "eager_order_by", + "eager_joins", + "create_eager_joins", + "propagate_options", + "attributes", + "statement", + "from_clause", + "whereclause", + "order_by", + "labels", + "_for_update_arg", + "runid", + "partials", + "post_load_paths", + "identity_token", ) def __init__(self, query): if query._statement is not None: - if isinstance(query._statement, expression.SelectBase) and \ - not query._statement._textual and \ - not query._statement.use_labels: + if ( + isinstance(query._statement, expression.SelectBase) + and not query._statement._textual + and not query._statement.use_labels + ): self.statement = query._statement.apply_labels() else: self.statement = query._statement @@ -4466,8 +4632,9 @@ class QueryContext(object): self.eager_order_by = [] self.eager_joins = {} self.create_eager_joins = [] - self.propagate_options = set(o for o in query._with_options if - o.propagate_to_loaders) + self.propagate_options = set( + o for o in query._with_options if o.propagate_to_loaders + ) self.attributes = query._attributes.copy() if self.refresh_state is not None: self.identity_token = query._refresh_identity_token @@ -4476,7 +4643,6 @@ class QueryContext(object): class AliasOption(interfaces.MapperOption): - def __init__(self, alias): r"""Return a :class:`.MapperOption` that will indicate to the :class:`.Query` that the main table has been aliased. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index e7896c4234..e89d1542fa 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -22,14 +22,23 @@ from . import dependency from . 
import attributes from ..sql.util import ( ClauseAdapter, - join_condition, _shallow_annotate, visit_binary_product, - _deep_deannotate, selectables_overlap, adapt_criterion_to_null + join_condition, + _shallow_annotate, + visit_binary_product, + _deep_deannotate, + selectables_overlap, + adapt_criterion_to_null, ) from .base import state_str from ..sql import operators, expression, visitors -from .interfaces import (MANYTOMANY, MANYTOONE, ONETOMANY, - StrategizedProperty, PropComparator) +from .interfaces import ( + MANYTOMANY, + MANYTOONE, + ONETOMANY, + StrategizedProperty, + PropComparator, +) from ..inspection import inspect from . import mapper as mapperlib import collections @@ -51,8 +60,9 @@ def remote(expr): :func:`.foreign` """ - return _annotate_columns(expression._clause_element_as_expr(expr), - {"remote": True}) + return _annotate_columns( + expression._clause_element_as_expr(expr), {"remote": True} + ) def foreign(expr): @@ -72,8 +82,9 @@ def foreign(expr): """ - return _annotate_columns(expression._clause_element_as_expr(expr), - {"foreign": True}) + return _annotate_columns( + expression._clause_element_as_expr(expr), {"foreign": True} + ) @log.class_logger @@ -90,36 +101,46 @@ class RelationshipProperty(StrategizedProperty): """ - strategy_wildcard_key = 'relationship' + strategy_wildcard_key = "relationship" _dependency_processor = None - def __init__(self, argument, - secondary=None, primaryjoin=None, - secondaryjoin=None, - foreign_keys=None, - uselist=None, - order_by=False, - backref=None, - back_populates=None, - post_update=False, - cascade=False, extension=None, - viewonly=False, lazy="select", - collection_class=None, passive_deletes=False, - passive_updates=True, remote_side=None, - enable_typechecks=True, join_depth=None, - comparator_factory=None, - single_parent=False, innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=False, - cascade_backrefs=True, - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None): + def __init__( + self, + argument, + secondary=None, + primaryjoin=None, + secondaryjoin=None, + foreign_keys=None, + uselist=None, + order_by=False, + backref=None, + back_populates=None, + post_update=False, + cascade=False, + extension=None, + viewonly=False, + lazy="select", + collection_class=None, + passive_deletes=False, + passive_updates=True, + remote_side=None, + enable_typechecks=True, + join_depth=None, + comparator_factory=None, + single_parent=False, + innerjoin=False, + distinct_target_key=None, + doc=None, + active_history=False, + cascade_backrefs=True, + load_on_pending=False, + bake_queries=True, + _local_remote_pairs=None, + query_class=None, + info=None, + omit_join=None, + ): """Provide a relationship between two mapped classes. This corresponds to a parent-child or associative table relationship. 
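
remote() and foreign(), reformatted above, attach the same
"remote"/"foreign" column annotations that JoinCondition otherwise
infers from ForeignKey metadata, which is how an ambiguous
self-referential primaryjoin is disambiguated. A minimal sketch under
the usual declarative assumptions ("Node" is illustrative only):

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import foreign, relationship, remote

    Base = declarative_base()

    class Node(Base):
        __tablename__ = "node"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("node.id"))

        # equivalent to remote_side=id: the related row's "id" is the
        # remote column, while "parent_id" holds the foreign key value
        parent = relationship(
            "Node",
            primaryjoin=remote(id) == foreign(parent_id),
        )
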
@@ -858,20 +879,22 @@ class RelationshipProperty(StrategizedProperty): self.extension = extension self.bake_queries = bake_queries self.load_on_pending = load_on_pending - self.comparator_factory = comparator_factory or \ - RelationshipProperty.Comparator + self.comparator_factory = ( + comparator_factory or RelationshipProperty.Comparator + ) self.comparator = self.comparator_factory(self, None) util.set_creation_order(self) if info is not None: self.info = info - self.strategy_key = (("lazy", self.lazy), ) + self.strategy_key = (("lazy", self.lazy),) self._reverse_property = set() - self.cascade = cascade if cascade is not False \ - else "save-update, merge" + self.cascade = ( + cascade if cascade is not False else "save-update, merge" + ) self.order_by = order_by @@ -881,7 +904,8 @@ class RelationshipProperty(StrategizedProperty): if backref: raise sa_exc.ArgumentError( "backref and back_populates keyword arguments " - "are mutually exclusive") + "are mutually exclusive" + ) self.backref = None else: self.backref = backref @@ -919,7 +943,8 @@ class RelationshipProperty(StrategizedProperty): _of_type = None def __init__( - self, prop, parentmapper, adapt_to_entity=None, of_type=None): + self, prop, parentmapper, adapt_to_entity=None, of_type=None + ): """Construction of :class:`.RelationshipProperty.Comparator` is internal to the ORM's attribute mechanics. @@ -931,9 +956,12 @@ class RelationshipProperty(StrategizedProperty): self._of_type = of_type def adapt_to_entity(self, adapt_to_entity): - return self.__class__(self.property, self._parententity, - adapt_to_entity=adapt_to_entity, - of_type=self._of_type) + return self.__class__( + self.property, + self._parententity, + adapt_to_entity=adapt_to_entity, + of_type=self._of_type, + ) @util.memoized_property def mapper(self): @@ -963,11 +991,11 @@ class RelationshipProperty(StrategizedProperty): else: of_type = None - pj, sj, source, dest, \ - secondary, target_adapter = self.property._create_joins( - source_selectable=adapt_from, - source_polymorphic=True, - of_type=of_type) + pj, sj, source, dest, secondary, target_adapter = self.property._create_joins( + source_selectable=adapt_from, + source_polymorphic=True, + of_type=of_type, + ) if sj is not None: return pj & sj else: @@ -983,17 +1011,20 @@ class RelationshipProperty(StrategizedProperty): self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, - of_type=cls) + of_type=cls, + ) def in_(self, other): """Produce an IN clause - this is not implemented for :func:`~.orm.relationship`-based attributes at this time. """ - raise NotImplementedError('in_() not yet supported for ' - 'relationships. For a simple ' - 'many-to-one, use in_() against ' - 'the set of foreign key values.') + raise NotImplementedError( + "in_() not yet supported for " + "relationships. For a simple " + "many-to-one, use in_() against " + "the set of foreign key values." + ) __hash__ = None @@ -1038,24 +1069,32 @@ class RelationshipProperty(StrategizedProperty): if self.property.direction in [ONETOMANY, MANYTOMANY]: return ~self._criterion_exists() else: - return _orm_annotate(self.property._optimized_compare( - None, adapt_source=self.adapter)) + return _orm_annotate( + self.property._optimized_compare( + None, adapt_source=self.adapter + ) + ) elif self.property.uselist: raise sa_exc.InvalidRequestError( "Can't compare a collection to an object or collection; " - "use contains() to test for membership.") + "use contains() to test for membership." 
+ ) else: return _orm_annotate( self.property._optimized_compare( - other, adapt_source=self.adapter)) + other, adapt_source=self.adapter + ) + ) def _criterion_exists(self, criterion=None, **kwargs): - if getattr(self, '_of_type', None): + if getattr(self, "_of_type", None): info = inspect(self._of_type) - target_mapper, to_selectable, is_aliased_class = \ - info.mapper, info.selectable, info.is_aliased_class - if self.property._is_self_referential and not \ - is_aliased_class: + target_mapper, to_selectable, is_aliased_class = ( + info.mapper, + info.selectable, + info.is_aliased_class, + ) + if self.property._is_self_referential and not is_aliased_class: to_selectable = to_selectable.alias() single_crit = target_mapper._single_table_criterion @@ -1073,11 +1112,11 @@ class RelationshipProperty(StrategizedProperty): else: source_selectable = None - pj, sj, source, dest, secondary, target_adapter = \ - self.property._create_joins( - dest_polymorphic=True, - dest_selectable=to_selectable, - source_selectable=source_selectable) + pj, sj, source, dest, secondary, target_adapter = self.property._create_joins( + dest_polymorphic=True, + dest_selectable=to_selectable, + source_selectable=source_selectable, + ) for k in kwargs: crit = getattr(self.property.mapper.class_, k) == kwargs[k] @@ -1094,8 +1133,11 @@ class RelationshipProperty(StrategizedProperty): else: j = _orm_annotate(pj, exclude=self.property.remote_side) - if criterion is not None and target_adapter and not \ - is_aliased_class: + if ( + criterion is not None + and target_adapter + and not is_aliased_class + ): # limit this adapter to annotated only? criterion = target_adapter.traverse(criterion) @@ -1106,16 +1148,19 @@ class RelationshipProperty(StrategizedProperty): # to anything in the enclosing query. if criterion is not None: criterion = criterion._annotate( - {'no_replacement_traverse': True}) + {"no_replacement_traverse": True} + ) crit = j & sql.True_._ifnone(criterion) if secondary is not None: - ex = sql.exists([1], crit, from_obj=[dest, secondary]).\ - correlate_except(dest, secondary) + ex = sql.exists( + [1], crit, from_obj=[dest, secondary] + ).correlate_except(dest, secondary) else: - ex = sql.exists([1], crit, from_obj=dest).\ - correlate_except(dest) + ex = sql.exists([1], crit, from_obj=dest).correlate_except( + dest + ) return ex def any(self, criterion=None, **kwargs): @@ -1197,8 +1242,8 @@ class RelationshipProperty(StrategizedProperty): """ if self.property.uselist: raise sa_exc.InvalidRequestError( - "'has()' not implemented for collections. " - "Use any().") + "'has()' not implemented for collections. " "Use any()." + ) return self._criterion_exists(criterion, **kwargs) def contains(self, other, **kwargs): @@ -1260,13 +1305,16 @@ class RelationshipProperty(StrategizedProperty): if not self.property.uselist: raise sa_exc.InvalidRequestError( "'contains' not implemented for scalar " - "attributes. Use ==") + "attributes. 
Use ==" + ) clause = self.property._optimized_compare( - other, adapt_source=self.adapter) + other, adapt_source=self.adapter + ) if self.property.secondaryjoin is not None: - clause.negation_clause = \ - self.__negated_contains_or_equals(other) + clause.negation_clause = self.__negated_contains_or_equals( + other + ) return clause @@ -1277,10 +1325,11 @@ class RelationshipProperty(StrategizedProperty): def state_bindparam(x, state, col): dict_ = state.dict return sql.bindparam( - x, unique=True, + x, + unique=True, callable_=self.property._get_attr_w_warn_on_none( self.property.mapper, state, dict_, col - ) + ), ) def adapt(col): @@ -1290,19 +1339,26 @@ class RelationshipProperty(StrategizedProperty): return col if self.property._use_get: - return sql.and_(*[ - sql.or_( - adapt(x) != state_bindparam(adapt(x), state, y), - adapt(x) == None) - for (x, y) in self.property.local_remote_pairs]) - - criterion = sql.and_(*[ - x == y for (x, y) in - zip( - self.property.mapper.primary_key, - self.property.mapper.primary_key_from_instance(other) - ) - ]) + return sql.and_( + *[ + sql.or_( + adapt(x) + != state_bindparam(adapt(x), state, y), + adapt(x) == None, + ) + for (x, y) in self.property.local_remote_pairs + ] + ) + + criterion = sql.and_( + *[ + x == y + for (x, y) in zip( + self.property.mapper.primary_key, + self.property.mapper.primary_key_from_instance(other), + ) + ] + ) return ~self._criterion_exists(criterion) @@ -1347,8 +1403,11 @@ class RelationshipProperty(StrategizedProperty): """ if isinstance(other, (util.NoneType, expression.Null)): if self.property.direction == MANYTOONE: - return _orm_annotate(~self.property._optimized_compare( - None, adapt_source=self.adapter)) + return _orm_annotate( + ~self.property._optimized_compare( + None, adapt_source=self.adapter + ) + ) else: return self._criterion_exists() @@ -1356,7 +1415,8 @@ class RelationshipProperty(StrategizedProperty): raise sa_exc.InvalidRequestError( "Can't compare a collection" " to an object or collection; use " - "contains() to test for membership.") + "contains() to test for membership." 
+ ) else: return _orm_annotate(self.__negated_contains_or_equals(other)) @@ -1374,12 +1434,19 @@ class RelationshipProperty(StrategizedProperty): if insp.is_aliased_class: adapt_source = insp._adapter.adapt_clause return self._optimized_compare( - instance, value_is_parent=True, adapt_source=adapt_source, - alias_secondary=alias_secondary) + instance, + value_is_parent=True, + adapt_source=adapt_source, + alias_secondary=alias_secondary, + ) - def _optimized_compare(self, state, value_is_parent=False, - adapt_source=None, - alias_secondary=True): + def _optimized_compare( + self, + state, + value_is_parent=False, + adapt_source=None, + alias_secondary=True, + ): if state is not None: state = attributes.instance_state(state) @@ -1387,17 +1454,19 @@ class RelationshipProperty(StrategizedProperty): if state is None: return self._lazy_none_clause( - reverse_direction, - adapt_source=adapt_source) + reverse_direction, adapt_source=adapt_source + ) if not reverse_direction: - criterion, bind_to_col = \ - self._lazy_strategy._lazywhere, \ - self._lazy_strategy._bind_to_col + criterion, bind_to_col = ( + self._lazy_strategy._lazywhere, + self._lazy_strategy._bind_to_col, + ) else: - criterion, bind_to_col = \ - self._lazy_strategy._rev_lazywhere, \ - self._lazy_strategy._rev_bind_to_col + criterion, bind_to_col = ( + self._lazy_strategy._rev_lazywhere, + self._lazy_strategy._rev_bind_to_col, + ) if reverse_direction: mapper = self.mapper @@ -1409,16 +1478,20 @@ class RelationshipProperty(StrategizedProperty): def visit_bindparam(bindparam): if bindparam._identifying_key in bind_to_col: bindparam.callable = self._get_attr_w_warn_on_none( - mapper, state, dict_, - bind_to_col[bindparam._identifying_key]) + mapper, + state, + dict_, + bind_to_col[bindparam._identifying_key], + ) if self.secondary is not None and alias_secondary: - criterion = ClauseAdapter( - self.secondary.alias()).\ - traverse(criterion) + criterion = ClauseAdapter(self.secondary.alias()).traverse( + criterion + ) criterion = visitors.cloned_traverse( - criterion, {}, {'bindparam': visit_bindparam}) + criterion, {}, {"bindparam": visit_bindparam} + ) if adapt_source: criterion = adapt_source(criterion) @@ -1483,25 +1556,27 @@ class RelationshipProperty(StrategizedProperty): # only if we can't get a value now due to detachment do we return # the last known value current_value = mapper._get_state_attr_by_column( - state, dict_, column, + state, + dict_, + column, passive=attributes.PASSIVE_RETURN_NEVER_SET if state.persistent - else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK) + else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK, + ) if current_value is attributes.NEVER_SET: if not existing_is_available: raise sa_exc.InvalidRequestError( "Can't resolve value for column %s on object " - "%s; no value has been set for this column" % ( - column, state_str(state)) + "%s; no value has been set for this column" + % (column, state_str(state)) ) elif current_value is attributes.PASSIVE_NO_RESULT: if not existing_is_available: raise sa_exc.InvalidRequestError( "Can't resolve value for column %s on object " "%s; the object is detached and the value was " - "expired" % ( - column, state_str(state)) + "expired" % (column, state_str(state)) ) else: to_return = current_value @@ -1510,19 +1585,23 @@ class RelationshipProperty(StrategizedProperty): "Got None for value of column %s; this is unsupported " "for a relationship comparison and will not " "currently produce an IS comparison " - "(but may in a future release)" % column) + "(but may 
in a future release)" % column + ) return to_return + return _go def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): if not reverse_direction: - criterion, bind_to_col = \ - self._lazy_strategy._lazywhere, \ - self._lazy_strategy._bind_to_col + criterion, bind_to_col = ( + self._lazy_strategy._lazywhere, + self._lazy_strategy._bind_to_col, + ) else: - criterion, bind_to_col = \ - self._lazy_strategy._rev_lazywhere, \ - self._lazy_strategy._rev_bind_to_col + criterion, bind_to_col = ( + self._lazy_strategy._rev_lazywhere, + self._lazy_strategy._rev_bind_to_col, + ) criterion = adapt_criterion_to_null(criterion, bind_to_col) @@ -1533,13 +1612,17 @@ class RelationshipProperty(StrategizedProperty): def __str__(self): return str(self.parent.class_.__name__) + "." + self.key - def merge(self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, _recursive, _resolve_conflict_map): + def merge( + self, + session, + source_state, + source_dict, + dest_state, + dest_dict, + load, + _recursive, + _resolve_conflict_map, + ): if load: for r in self._reverse_property: @@ -1553,9 +1636,10 @@ class RelationshipProperty(StrategizedProperty): return if self.uselist: - instances = source_state.get_impl(self.key).\ - get(source_state, source_dict) - if hasattr(instances, '_sa_adapter'): + instances = source_state.get_impl(self.key).get( + source_state, source_dict + ) + if hasattr(instances, "_sa_adapter"): # convert collections to adapters to get a true iterator instances = instances._sa_adapter @@ -1573,21 +1657,25 @@ class RelationshipProperty(StrategizedProperty): current_dict = attributes.instance_dict(current) _recursive[(current_state, self)] = True obj = session._merge( - current_state, current_dict, - load=load, _recursive=_recursive, - _resolve_conflict_map=_resolve_conflict_map) + current_state, + current_dict, + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) if obj is not None: dest_list.append(obj) if not load: - coll = attributes.init_state_collection(dest_state, - dest_dict, self.key) + coll = attributes.init_state_collection( + dest_state, dest_dict, self.key + ) for c in dest_list: coll.append_without_event(c) else: dest_state.get_impl(self.key).set( - dest_state, dest_dict, dest_list, - _adapt=False) + dest_state, dest_dict, dest_list, _adapt=False + ) else: current = source_dict[self.key] if current is not None: @@ -1595,20 +1683,25 @@ class RelationshipProperty(StrategizedProperty): current_dict = attributes.instance_dict(current) _recursive[(current_state, self)] = True obj = session._merge( - current_state, current_dict, - load=load, _recursive=_recursive, - _resolve_conflict_map=_resolve_conflict_map) + current_state, + current_dict, + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) else: obj = None if not load: dest_dict[self.key] = obj else: - dest_state.get_impl(self.key).set(dest_state, - dest_dict, obj, None) + dest_state.get_impl(self.key).set( + dest_state, dest_dict, obj, None + ) - def _value_as_iterable(self, state, dict_, key, - passive=attributes.PASSIVE_OFF): + def _value_as_iterable( + self, state, dict_, key, passive=attributes.PASSIVE_OFF + ): """Return a list of tuples (state, obj) for the given key. 
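
merge() above routes both collection and scalar relationship values
through Session._merge, recording (state, property) keys in _recursive
to guard against cycles. An illustrative sketch of the behavior it
supports, assuming mapped "User"/"Address" classes related under the
default "save-update, merge" cascade and a configured session:

    # a detached object graph: merging the parent also merges each
    # related Address, because "merge" is in the default cascade
    detached = User(id=5, name="ed")
    detached.addresses = [Address(id=1, email="ed@example.com")]

    merged = session.merge(detached)  # returns the in-session copy
    assert merged is not detached
    assert merged.addresses[0].email == "ed@example.com"
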
@@ -1619,34 +1712,36 @@ class RelationshipProperty(StrategizedProperty): x = impl.get(state, dict_, passive=passive) if x is attributes.PASSIVE_NO_RESULT or x is None: return [] - elif hasattr(impl, 'get_collection'): + elif hasattr(impl, "get_collection"): return [ - (attributes.instance_state(o), o) for o in - impl.get_collection(state, dict_, x, passive=passive) + (attributes.instance_state(o), o) + for o in impl.get_collection(state, dict_, x, passive=passive) ] else: return [(attributes.instance_state(x), x)] - def cascade_iterator(self, type_, state, dict_, - visited_states, halt_on=None): + def cascade_iterator( + self, type_, state, dict_, visited_states, halt_on=None + ): # assert type_ in self._cascade # only actively lazy load on the 'delete' cascade - if type_ != 'delete' or self.passive_deletes: + if type_ != "delete" or self.passive_deletes: passive = attributes.PASSIVE_NO_INITIALIZE else: passive = attributes.PASSIVE_OFF - if type_ == 'save-update': - tuples = state.manager[self.key].impl.\ - get_all_pending(state, dict_) + if type_ == "save-update": + tuples = state.manager[self.key].impl.get_all_pending(state, dict_) else: - tuples = self._value_as_iterable(state, dict_, self.key, - passive=passive) + tuples = self._value_as_iterable( + state, dict_, self.key, passive=passive + ) - skip_pending = type_ == 'refresh-expire' and 'delete-orphan' \ - not in self._cascade + skip_pending = ( + type_ == "refresh-expire" and "delete-orphan" not in self._cascade + ) for instance_state, c in tuples: if instance_state in visited_states: @@ -1670,13 +1765,12 @@ class RelationshipProperty(StrategizedProperty): instance_mapper = instance_state.manager.mapper if not instance_mapper.isa(self.mapper.class_manager.mapper): - raise AssertionError("Attribute '%s' on class '%s' " - "doesn't handle objects " - "of type '%s'" % ( - self.key, - self.parent.class_, - c.__class__ - )) + raise AssertionError( + "Attribute '%s' on class '%s' " + "doesn't handle objects " + "of type '%s'" + % (self.key, self.parent.class_, c.__class__) + ) visited_states.add(instance_state) @@ -1689,18 +1783,22 @@ class RelationshipProperty(StrategizedProperty): if not other.mapper.common_parent(self.parent): raise sa_exc.ArgumentError( - 'reverse_property %r on ' - 'relationship %s references relationship %s, which ' - 'does not reference mapper %s' % - (key, self, other, self.parent)) + "reverse_property %r on " + "relationship %s references relationship %s, which " + "does not reference mapper %s" + % (key, self, other, self.parent) + ) - if self.direction in (ONETOMANY, MANYTOONE) and self.direction \ - == other.direction: + if ( + self.direction in (ONETOMANY, MANYTOONE) + and self.direction == other.direction + ): raise sa_exc.ArgumentError( - '%s and back-reference %s are ' - 'both of the same direction %r. Did you mean to ' - 'set remote_side on the many-to-one side ?' % - (other, self, self.direction)) + "%s and back-reference %s are " + "both of the same direction %r. Did you mean to " + "set remote_side on the many-to-one side ?" + % (other, self, self.direction) + ) @util.memoized_property def mapper(self): @@ -1710,22 +1808,23 @@ class RelationshipProperty(StrategizedProperty): This is a lazy-initializing static attribute. 
""" - if util.callable(self.argument) and \ - not isinstance(self.argument, (type, mapperlib.Mapper)): + if util.callable(self.argument) and not isinstance( + self.argument, (type, mapperlib.Mapper) + ): argument = self.argument() else: argument = self.argument if isinstance(argument, type): - mapper_ = mapperlib.class_mapper(argument, - configure=False) + mapper_ = mapperlib.class_mapper(argument, configure=False) elif isinstance(self.argument, mapperlib.Mapper): mapper_ = argument else: raise sa_exc.ArgumentError( "relationship '%s' expects " "a class or a mapper argument (received: %s)" - % (self.key, type(argument))) + % (self.key, type(argument)) + ) return mapper_ @util.memoized_property @@ -1759,8 +1858,12 @@ class RelationshipProperty(StrategizedProperty): # deferred initialization. This technique is used # by declarative "string configs" and some recipes. for attr in ( - 'order_by', 'primaryjoin', 'secondaryjoin', - 'secondary', '_user_defined_foreign_keys', 'remote_side', + "order_by", + "primaryjoin", + "secondaryjoin", + "secondary", + "_user_defined_foreign_keys", + "remote_side", ): attr_value = getattr(self, attr) if util.callable(attr_value): @@ -1768,11 +1871,15 @@ class RelationshipProperty(StrategizedProperty): # remove "annotations" which are present if mapped class # descriptors are used to create the join expression. - for attr in 'primaryjoin', 'secondaryjoin': + for attr in "primaryjoin", "secondaryjoin": val = getattr(self, attr) if val is not None: - setattr(self, attr, _orm_deannotate( - expression._only_column_elements(val, attr)) + setattr( + self, + attr, + _orm_deannotate( + expression._only_column_elements(val, attr) + ), ) # ensure expressions in self.order_by, foreign_keys, @@ -1780,21 +1887,18 @@ class RelationshipProperty(StrategizedProperty): if self.order_by is not False and self.order_by is not None: self.order_by = [ expression._only_column_elements(x, "order_by") - for x in - util.to_list(self.order_by)] - - self._user_defined_foreign_keys = \ - util.column_set( - expression._only_column_elements(x, "foreign_keys") - for x in util.to_column_set( - self._user_defined_foreign_keys - )) - - self.remote_side = \ - util.column_set( - expression._only_column_elements(x, "remote_side") - for x in - util.to_column_set(self.remote_side)) + for x in util.to_list(self.order_by) + ] + + self._user_defined_foreign_keys = util.column_set( + expression._only_column_elements(x, "foreign_keys") + for x in util.to_column_set(self._user_defined_foreign_keys) + ) + + self.remote_side = util.column_set( + expression._only_column_elements(x, "remote_side") + for x in util.to_column_set(self.remote_side) + ) self.target = self.mapper.mapped_table @@ -1815,7 +1919,7 @@ class RelationshipProperty(StrategizedProperty): self_referential=self._is_self_referential, prop=self, support_sync=not self.viewonly, - can_be_synced_fn=self._columns_are_mapped + can_be_synced_fn=self._columns_are_mapped, ) self.primaryjoin = jc.primaryjoin self.secondaryjoin = jc.secondaryjoin @@ -1832,16 +1936,20 @@ class RelationshipProperty(StrategizedProperty): inheritance conflicts.""" if self.parent.non_primary and not mapperlib.class_mapper( - self.parent.class_, - configure=False).has_property(self.key): + self.parent.class_, configure=False + ).has_property(self.key): raise sa_exc.ArgumentError( "Attempting to assign a new " "relationship '%s' to a non-primary mapper on " "class '%s'. New relationships can only be added " "to the primary mapper, i.e. 
the very first mapper " - "created for class '%s' " % - (self.key, self.parent.class_.__name__, - self.parent.class_.__name__)) + "created for class '%s' " + % ( + self.key, + self.parent.class_.__name__, + self.parent.class_.__name__, + ) + ) def _get_cascade(self): """Return the current cascade setting for this @@ -1851,7 +1959,7 @@ class RelationshipProperty(StrategizedProperty): def _set_cascade(self, cascade): cascade = CascadeOptions(cascade) - if 'mapper' in self.__dict__: + if "mapper" in self.__dict__: self._check_cascade_settings(cascade) self._cascade = cascade @@ -1861,27 +1969,31 @@ class RelationshipProperty(StrategizedProperty): cascade = property(_get_cascade, _set_cascade) def _check_cascade_settings(self, cascade): - if cascade.delete_orphan and not self.single_parent \ - and (self.direction is MANYTOMANY or self.direction - is MANYTOONE): + if ( + cascade.delete_orphan + and not self.single_parent + and (self.direction is MANYTOMANY or self.direction is MANYTOONE) + ): raise sa_exc.ArgumentError( - 'On %s, delete-orphan cascade is not supported ' - 'on a many-to-many or many-to-one relationship ' - 'when single_parent is not set. Set ' - 'single_parent=True on the relationship().' - % self) + "On %s, delete-orphan cascade is not supported " + "on a many-to-many or many-to-one relationship " + "when single_parent is not set. Set " + "single_parent=True on the relationship()." % self + ) if self.direction is MANYTOONE and self.passive_deletes: - util.warn("On %s, 'passive_deletes' is normally configured " - "on one-to-many, one-to-one, many-to-many " - "relationships only." - % self) - - if self.passive_deletes == 'all' and \ - ("delete" in cascade or - "delete-orphan" in cascade): + util.warn( + "On %s, 'passive_deletes' is normally configured " + "on one-to-many, one-to-one, many-to-many " + "relationships only." 
% self + ) + + if self.passive_deletes == "all" and ( + "delete" in cascade or "delete-orphan" in cascade + ): raise sa_exc.ArgumentError( "On %s, can't set passive_deletes='all' in conjunction " - "with 'delete' or 'delete-orphan' cascade" % self) + "with 'delete' or 'delete-orphan' cascade" % self + ) if cascade.delete_orphan: self.mapper.primary_mapper()._delete_orphans.append( @@ -1894,8 +2006,10 @@ class RelationshipProperty(StrategizedProperty): """ - return self.key in mapper.relationships and \ - mapper.relationships[self.key] is self + return ( + self.key in mapper.relationships + and mapper.relationships[self.key] is self + ) def _columns_are_mapped(self, *cols): """Return True if all columns in the given collection are @@ -1903,11 +2017,14 @@ class RelationshipProperty(StrategizedProperty): """ for c in cols: - if self.secondary is not None \ - and self.secondary.c.contains_column(c): + if ( + self.secondary is not None + and self.secondary.c.contains_column(c) + ): continue - if not self.parent.mapped_table.c.contains_column(c) and \ - not self.target.c.contains_column(c): + if not self.parent.mapped_table.c.contains_column( + c + ) and not self.target.c.contains_column(c): return False return True @@ -1925,15 +2042,17 @@ class RelationshipProperty(StrategizedProperty): mapper = self.mapper.primary_mapper() if not mapper.concrete: - check = set(mapper.iterate_to_root()).\ - union(mapper.self_and_descendants) + check = set(mapper.iterate_to_root()).union( + mapper.self_and_descendants + ) for m in check: if m.has_property(backref_key) and not m.concrete: raise sa_exc.ArgumentError( "Error creating backref " "'%s' on relationship '%s': property of that " - "name exists on mapper '%s'" % - (backref_key, self, m)) + "name exists on mapper '%s'" + % (backref_key, self, m) + ) # determine primaryjoin/secondaryjoin for the # backref. Use the one we had, so that @@ -1944,35 +2063,42 @@ class RelationshipProperty(StrategizedProperty): # secondaryjoin. use the annotated # pj/sj on the _join_condition. pj = kwargs.pop( - 'primaryjoin', - self._join_condition.secondaryjoin_minus_local) + "primaryjoin", + self._join_condition.secondaryjoin_minus_local, + ) sj = kwargs.pop( - 'secondaryjoin', - self._join_condition.primaryjoin_minus_local) + "secondaryjoin", + self._join_condition.primaryjoin_minus_local, + ) else: pj = kwargs.pop( - 'primaryjoin', - self._join_condition.primaryjoin_reverse_remote) - sj = kwargs.pop('secondaryjoin', None) + "primaryjoin", + self._join_condition.primaryjoin_reverse_remote, + ) + sj = kwargs.pop("secondaryjoin", None) if sj: raise sa_exc.InvalidRequestError( "Can't assign 'secondaryjoin' on a backref " "against a non-secondary relationship." 
) - foreign_keys = kwargs.pop('foreign_keys', - self._user_defined_foreign_keys) + foreign_keys = kwargs.pop( + "foreign_keys", self._user_defined_foreign_keys + ) parent = self.parent.primary_mapper() - kwargs.setdefault('viewonly', self.viewonly) - kwargs.setdefault('post_update', self.post_update) - kwargs.setdefault('passive_updates', self.passive_updates) + kwargs.setdefault("viewonly", self.viewonly) + kwargs.setdefault("post_update", self.post_update) + kwargs.setdefault("passive_updates", self.passive_updates) self.back_populates = backref_key relationship = RelationshipProperty( - parent, self.secondary, - pj, sj, + parent, + self.secondary, + pj, + sj, foreign_keys=foreign_keys, back_populates=self.key, - **kwargs) + **kwargs + ) mapper._configure_property(backref_key, relationship) if self.back_populates: @@ -1982,8 +2108,9 @@ class RelationshipProperty(StrategizedProperty): if self.uselist is None: self.uselist = self.direction is not MANYTOONE if not self.viewonly: - self._dependency_processor = \ - dependency.DependencyProcessor.from_relationship(self) + self._dependency_processor = dependency.DependencyProcessor.from_relationship( + self + ) @util.memoized_property def _use_get(self): @@ -1997,9 +2124,14 @@ class RelationshipProperty(StrategizedProperty): def _is_self_referential(self): return self.mapper.common_parent(self.parent) - def _create_joins(self, source_polymorphic=False, - source_selectable=None, dest_polymorphic=False, - dest_selectable=None, of_type=None): + def _create_joins( + self, + source_polymorphic=False, + source_selectable=None, + dest_polymorphic=False, + dest_selectable=None, + of_type=None, + ): if source_selectable is None: if source_polymorphic and self.parent.with_polymorphic: source_selectable = self.parent._with_polymorphic_selectable @@ -2023,16 +2155,21 @@ class RelationshipProperty(StrategizedProperty): single_crit = dest_mapper._single_table_criterion aliased = aliased or (source_selectable is not None) - primaryjoin, secondaryjoin, secondary, target_adapter, dest_selectable = \ - self._join_condition.join_targets( - source_selectable, dest_selectable, aliased, single_crit - ) + primaryjoin, secondaryjoin, secondary, target_adapter, dest_selectable = self._join_condition.join_targets( + source_selectable, dest_selectable, aliased, single_crit + ) if source_selectable is None: source_selectable = self.parent.local_table if dest_selectable is None: dest_selectable = self.mapper.local_table - return (primaryjoin, secondaryjoin, source_selectable, - dest_selectable, secondary, target_adapter) + return ( + primaryjoin, + secondaryjoin, + source_selectable, + dest_selectable, + secondary, + target_adapter, + ) def _annotate_columns(element, annotations): @@ -2048,24 +2185,25 @@ def _annotate_columns(element, annotations): class JoinCondition(object): - def __init__(self, - parent_selectable, - child_selectable, - parent_local_selectable, - child_local_selectable, - primaryjoin=None, - secondary=None, - secondaryjoin=None, - parent_equivalents=None, - child_equivalents=None, - consider_as_foreign_keys=None, - local_remote_pairs=None, - remote_side=None, - self_referential=False, - prop=None, - support_sync=True, - can_be_synced_fn=lambda *c: True - ): + def __init__( + self, + parent_selectable, + child_selectable, + parent_local_selectable, + child_local_selectable, + primaryjoin=None, + secondary=None, + secondaryjoin=None, + parent_equivalents=None, + child_equivalents=None, + consider_as_foreign_keys=None, + local_remote_pairs=None, + 
remote_side=None, + self_referential=False, + prop=None, + support_sync=True, + can_be_synced_fn=lambda *c: True, + ): self.parent_selectable = parent_selectable self.parent_local_selectable = parent_local_selectable self.child_selectable = child_selectable @@ -2100,27 +2238,41 @@ class JoinCondition(object): if self.prop is None: return log = self.prop.logger - log.info('%s setup primary join %s', self.prop, - self.primaryjoin) - log.info('%s setup secondary join %s', self.prop, - self.secondaryjoin) - log.info('%s synchronize pairs [%s]', self.prop, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.synchronize_pairs)) - log.info('%s secondary synchronize pairs [%s]', self.prop, - ','.join('(%s => %s)' % (l, r) for (l, r) in - self.secondary_synchronize_pairs or [])) - log.info('%s local/remote pairs [%s]', self.prop, - ','.join('(%s / %s)' % (l, r) for (l, r) in - self.local_remote_pairs)) - log.info('%s remote columns [%s]', self.prop, - ','.join('%s' % col for col in self.remote_columns) - ) - log.info('%s local columns [%s]', self.prop, - ','.join('%s' % col for col in self.local_columns) - ) - log.info('%s relationship direction %s', self.prop, - self.direction) + log.info("%s setup primary join %s", self.prop, self.primaryjoin) + log.info("%s setup secondary join %s", self.prop, self.secondaryjoin) + log.info( + "%s synchronize pairs [%s]", + self.prop, + ",".join( + "(%s => %s)" % (l, r) for (l, r) in self.synchronize_pairs + ), + ) + log.info( + "%s secondary synchronize pairs [%s]", + self.prop, + ",".join( + "(%s => %s)" % (l, r) + for (l, r) in self.secondary_synchronize_pairs or [] + ), + ) + log.info( + "%s local/remote pairs [%s]", + self.prop, + ",".join( + "(%s / %s)" % (l, r) for (l, r) in self.local_remote_pairs + ), + ) + log.info( + "%s remote columns [%s]", + self.prop, + ",".join("%s" % col for col in self.remote_columns), + ) + log.info( + "%s local columns [%s]", + self.prop, + ",".join("%s" % col for col in self.local_columns), + ) + log.info("%s relationship direction %s", self.prop, self.direction) def _sanitize_joins(self): """remove the parententity annotation from our join conditions which @@ -2133,10 +2285,12 @@ class JoinCondition(object): """ self.primaryjoin = _deep_deannotate( - self.primaryjoin, values=("parententity",)) + self.primaryjoin, values=("parententity",) + ) if self.secondaryjoin is not None: self.secondaryjoin = _deep_deannotate( - self.secondaryjoin, values=("parententity",)) + self.secondaryjoin, values=("parententity",) + ) def _determine_joins(self): """Determine the 'primaryjoin' and 'secondaryjoin' attributes, @@ -2150,7 +2304,8 @@ class JoinCondition(object): raise sa_exc.ArgumentError( "Property %s specified with secondary " "join condition but " - "no secondary argument" % self.prop) + "no secondary argument" % self.prop + ) # find a join between the given mapper's mapped table and # the given table. 
will try the mapper's local table first @@ -2161,30 +2316,27 @@ class JoinCondition(object): consider_as_foreign_keys = self.consider_as_foreign_keys or None if self.secondary is not None: if self.secondaryjoin is None: - self.secondaryjoin = \ - join_condition( - self.child_selectable, - self.secondary, - a_subset=self.child_local_selectable, - consider_as_foreign_keys=consider_as_foreign_keys - ) + self.secondaryjoin = join_condition( + self.child_selectable, + self.secondary, + a_subset=self.child_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) if self.primaryjoin is None: - self.primaryjoin = \ - join_condition( - self.parent_selectable, - self.secondary, - a_subset=self.parent_local_selectable, - consider_as_foreign_keys=consider_as_foreign_keys - ) + self.primaryjoin = join_condition( + self.parent_selectable, + self.secondary, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) else: if self.primaryjoin is None: - self.primaryjoin = \ - join_condition( - self.parent_selectable, - self.child_selectable, - a_subset=self.parent_local_selectable, - consider_as_foreign_keys=consider_as_foreign_keys - ) + self.primaryjoin = join_condition( + self.parent_selectable, + self.child_selectable, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) except sa_exc.NoForeignKeysError: if self.secondary is not None: raise sa_exc.NoForeignKeysError( @@ -2195,7 +2347,8 @@ class JoinCondition(object): "Ensure that referencing columns are associated " "with a ForeignKey or ForeignKeyConstraint, or " "specify 'primaryjoin' and 'secondaryjoin' " - "expressions." % (self.prop, self.secondary)) + "expressions." % (self.prop, self.secondary) + ) else: raise sa_exc.NoForeignKeysError( "Could not determine join " @@ -2204,7 +2357,8 @@ class JoinCondition(object): "linking these tables. " "Ensure that referencing columns are associated " "with a ForeignKey or ForeignKeyConstraint, or " - "specify a 'primaryjoin' expression." % self.prop) + "specify a 'primaryjoin' expression." % self.prop + ) except sa_exc.AmbiguousForeignKeysError: if self.secondary is not None: raise sa_exc.AmbiguousForeignKeysError( @@ -2216,8 +2370,8 @@ class JoinCondition(object): "argument, providing a list of those columns which " "should be counted as containing a foreign key " "reference from the secondary table to each of the " - "parent and child tables." - % (self.prop, self.secondary)) + "parent and child tables." % (self.prop, self.secondary) + ) else: raise sa_exc.AmbiguousForeignKeysError( "Could not determine join " @@ -2226,8 +2380,8 @@ class JoinCondition(object): "paths linking the tables. Specify the " "'foreign_keys' argument, providing a list of those " "columns which should be counted as containing a " - "foreign key reference to the parent table." - % self.prop) + "foreign key reference to the parent table." 
% self.prop + ) @property def primaryjoin_minus_local(self): @@ -2235,8 +2389,7 @@ class JoinCondition(object): @property def secondaryjoin_minus_local(self): - return _deep_deannotate(self.secondaryjoin, - values=("local", "remote")) + return _deep_deannotate(self.secondaryjoin, values=("local", "remote")) @util.memoized_property def primaryjoin_reverse_remote(self): @@ -2250,24 +2403,26 @@ class JoinCondition(object): """ if self._has_remote_annotations: + def replace(element): if "remote" in element._annotations: v = element._annotations.copy() - del v['remote'] - v['local'] = True + del v["remote"] + v["local"] = True return element._with_annotations(v) elif "local" in element._annotations: v = element._annotations.copy() - del v['local'] - v['remote'] = True + del v["local"] + v["remote"] = True return element._with_annotations(v) - return visitors.replacement_traverse( - self.primaryjoin, {}, replace) + + return visitors.replacement_traverse(self.primaryjoin, {}, replace) else: if self._has_foreign_annotations: # TODO: coverage - return _deep_deannotate(self.primaryjoin, - values=("local", "remote")) + return _deep_deannotate( + self.primaryjoin, values=("local", "remote") + ) else: return _deep_deannotate(self.primaryjoin) @@ -2304,16 +2459,13 @@ class JoinCondition(object): def check_fk(col): if col in self.consider_as_foreign_keys: return col._annotate({"foreign": True}) + self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, - {}, - check_fk + self.primaryjoin, {}, check_fk ) if self.secondaryjoin is not None: self.secondaryjoin = visitors.replacement_traverse( - self.secondaryjoin, - {}, - check_fk + self.secondaryjoin, {}, check_fk ) def _annotate_present_fks(self): @@ -2323,8 +2475,7 @@ class JoinCondition(object): secondarycols = set() def is_foreign(a, b): - if isinstance(a, schema.Column) and \ - isinstance(b, schema.Column): + if isinstance(a, schema.Column) and isinstance(b, schema.Column): if a.references(b): return a elif b.references(a): @@ -2337,31 +2488,30 @@ class JoinCondition(object): return b def visit_binary(binary): - if not isinstance(binary.left, sql.ColumnElement) or \ - not isinstance(binary.right, sql.ColumnElement): + if not isinstance( + binary.left, sql.ColumnElement + ) or not isinstance(binary.right, sql.ColumnElement): return - if "foreign" not in binary.left._annotations and \ - "foreign" not in binary.right._annotations: + if ( + "foreign" not in binary.left._annotations + and "foreign" not in binary.right._annotations + ): col = is_foreign(binary.left, binary.right) if col is not None: if col.compare(binary.left): - binary.left = binary.left._annotate( - {"foreign": True}) + binary.left = binary.left._annotate({"foreign": True}) elif col.compare(binary.right): binary.right = binary.right._annotate( - {"foreign": True}) + {"foreign": True} + ) self.primaryjoin = visitors.cloned_traverse( - self.primaryjoin, - {}, - {"binary": visit_binary} + self.primaryjoin, {}, {"binary": visit_binary} ) if self.secondaryjoin is not None: self.secondaryjoin = visitors.cloned_traverse( - self.secondaryjoin, - {}, - {"binary": visit_binary} + self.secondaryjoin, {}, {"binary": visit_binary} ) def _refers_to_parent_table(self): @@ -2376,26 +2526,24 @@ class JoinCondition(object): def visit_binary(binary): c, f = binary.left, binary.right if ( - isinstance(c, expression.ColumnClause) and - isinstance(f, expression.ColumnClause) and - pt.is_derived_from(c.table) and - pt.is_derived_from(f.table) and - mt.is_derived_from(c.table) and - 
mt.is_derived_from(f.table) + isinstance(c, expression.ColumnClause) + and isinstance(f, expression.ColumnClause) + and pt.is_derived_from(c.table) + and pt.is_derived_from(f.table) + and mt.is_derived_from(c.table) + and mt.is_derived_from(f.table) ): result[0] = True - visitors.traverse( - self.primaryjoin, - {}, - {"binary": visit_binary} - ) + + visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary}) return result[0] def _tables_overlap(self): """Return True if parent/child tables have some overlap.""" return selectables_overlap( - self.parent_selectable, self.child_selectable) + self.parent_selectable, self.child_selectable + ) def _annotate_remote(self): """Annotate the primaryjoin and secondaryjoin @@ -2411,7 +2559,9 @@ class JoinCondition(object): elif self._local_remote_pairs or self._remote_side: self._annotate_remote_from_args() elif self._refers_to_parent_table(): - self._annotate_selfref(lambda col: "foreign" in col._annotations, False) + self._annotate_selfref( + lambda col: "foreign" in col._annotations, False + ) elif self._tables_overlap(): self._annotate_remote_with_overlap() else: @@ -2422,35 +2572,40 @@ class JoinCondition(object): when 'secondary' is present. """ + def repl(element): if self.secondary.c.contains_column(element): return element._annotate({"remote": True}) + self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) + self.primaryjoin, {}, repl + ) self.secondaryjoin = visitors.replacement_traverse( - self.secondaryjoin, {}, repl) + self.secondaryjoin, {}, repl + ) def _annotate_selfref(self, fn, remote_side_given): """annotate 'remote' in primaryjoin, secondaryjoin when the relationship is detected as self-referential. """ + def visit_binary(binary): equated = binary.left.compare(binary.right) - if isinstance(binary.left, expression.ColumnClause) and \ - isinstance(binary.right, expression.ColumnClause): + if isinstance(binary.left, expression.ColumnClause) and isinstance( + binary.right, expression.ColumnClause + ): # assume one to many - FKs are "remote" if fn(binary.left): binary.left = binary.left._annotate({"remote": True}) if fn(binary.right) and not equated: - binary.right = binary.right._annotate( - {"remote": True}) + binary.right = binary.right._annotate({"remote": True}) elif not remote_side_given: self._warn_non_column_elements() self.primaryjoin = visitors.cloned_traverse( - self.primaryjoin, {}, - {"binary": visit_binary}) + self.primaryjoin, {}, {"binary": visit_binary} + ) def _annotate_remote_from_args(self): """annotate 'remote' in primaryjoin, secondaryjoin @@ -2463,7 +2618,8 @@ class JoinCondition(object): raise sa_exc.ArgumentError( "remote_side argument is redundant " "against more detailed _local_remote_side " - "argument.") + "argument." + ) remote_side = [r for (l, r) in self._local_remote_pairs] else: @@ -2472,11 +2628,14 @@ class JoinCondition(object): if self._refers_to_parent_table(): self._annotate_selfref(lambda col: col in remote_side, True) else: + def repl(element): if element in remote_side: return element._annotate({"remote": True}) + self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) + self.primaryjoin, {}, repl + ) def _annotate_remote_with_overlap(self): """annotate 'remote' in primaryjoin, secondaryjoin @@ -2485,26 +2644,36 @@ class JoinCondition(object): relationship. 
""" + def visit_binary(binary): - binary.left, binary.right = proc_left_right(binary.left, - binary.right) - binary.right, binary.left = proc_left_right(binary.right, - binary.left) + binary.left, binary.right = proc_left_right( + binary.left, binary.right + ) + binary.right, binary.left = proc_left_right( + binary.right, binary.left + ) - check_entities = self.prop is not None and \ - self.prop.mapper is not self.prop.parent + check_entities = ( + self.prop is not None and self.prop.mapper is not self.prop.parent + ) def proc_left_right(left, right): - if isinstance(left, expression.ColumnClause) and \ - isinstance(right, expression.ColumnClause): - if self.child_selectable.c.contains_column(right) and \ - self.parent_selectable.c.contains_column(left): + if isinstance(left, expression.ColumnClause) and isinstance( + right, expression.ColumnClause + ): + if self.child_selectable.c.contains_column( + right + ) and self.parent_selectable.c.contains_column(left): right = right._annotate({"remote": True}) - elif check_entities and \ - right._annotations.get('parentmapper') is self.prop.mapper: + elif ( + check_entities + and right._annotations.get("parentmapper") is self.prop.mapper + ): right = right._annotate({"remote": True}) - elif check_entities and \ - left._annotations.get('parentmapper') is self.prop.mapper: + elif ( + check_entities + and left._annotations.get("parentmapper") is self.prop.mapper + ): left = left._annotate({"remote": True}) else: self._warn_non_column_elements() @@ -2512,8 +2681,8 @@ class JoinCondition(object): return left, right self.primaryjoin = visitors.cloned_traverse( - self.primaryjoin, {}, - {"binary": visit_binary}) + self.primaryjoin, {}, {"binary": visit_binary} + ) def _annotate_remote_distinct_selectables(self): """annotate 'remote' in primaryjoin, secondaryjoin @@ -2521,22 +2690,23 @@ class JoinCondition(object): separate. """ + def repl(element): - if self.child_selectable.c.contains_column(element) and \ - (not self.parent_local_selectable.c. - contains_column(element) or - self.child_local_selectable.c. - contains_column(element)): + if self.child_selectable.c.contains_column(element) and ( + not self.parent_local_selectable.c.contains_column(element) + or self.child_local_selectable.c.contains_column(element) + ): return element._annotate({"remote": True}) + self.primaryjoin = visitors.replacement_traverse( - self.primaryjoin, {}, repl) + self.primaryjoin, {}, repl + ) def _warn_non_column_elements(self): util.warn( "Non-simple column elements in primary " "join condition for property %s - consider using " - "remote() annotations to mark the remote side." - % self.prop + "remote() annotations to mark the remote side." 
% self.prop ) def _annotate_local(self): @@ -2554,15 +2724,16 @@ class JoinCondition(object): return if self._local_remote_pairs: - local_side = util.column_set([l for (l, r) - in self._local_remote_pairs]) + local_side = util.column_set( + [l for (l, r) in self._local_remote_pairs] + ) else: local_side = util.column_set(self.parent_selectable.c) def locals_(elem): - if "remote" not in elem._annotations and \ - elem in local_side: + if "remote" not in elem._annotations and elem in local_side: return elem._annotate({"local": True}) + self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, locals_ ) @@ -2576,6 +2747,7 @@ class JoinCondition(object): return elem._annotate({"parentmapper": self.prop.mapper}) elif "local" in elem._annotations: return elem._annotate({"parentmapper": self.prop.parent}) + self.primaryjoin = visitors.replacement_traverse( self.primaryjoin, {}, parentmappers_ ) @@ -2583,14 +2755,15 @@ class JoinCondition(object): def _check_remote_side(self): if not self.local_remote_pairs: raise sa_exc.ArgumentError( - 'Relationship %s could ' - 'not determine any unambiguous local/remote column ' - 'pairs based on join condition and remote_side ' - 'arguments. ' - 'Consider using the remote() annotation to ' - 'accurately mark those elements of the join ' - 'condition that are on the remote side of ' - 'the relationship.' % (self.prop, )) + "Relationship %s could " + "not determine any unambiguous local/remote column " + "pairs based on join condition and remote_side " + "arguments. " + "Consider using the remote() annotation to " + "accurately mark those elements of the join " + "condition that are on the remote side of " + "the relationship." % (self.prop,) + ) def _check_foreign_cols(self, join_condition, primary): """Check the foreign key columns collected and emit error @@ -2599,7 +2772,8 @@ class JoinCondition(object): can_sync = False foreign_cols = self._gather_columns_with_annotation( - join_condition, "foreign") + join_condition, "foreign" + ) has_foreign = bool(foreign_cols) @@ -2608,42 +2782,53 @@ class JoinCondition(object): else: can_sync = bool(self.secondary_synchronize_pairs) - if self.support_sync and can_sync or \ - (not self.support_sync and has_foreign): + if ( + self.support_sync + and can_sync + or (not self.support_sync and has_foreign) + ): return # from here below is just determining the best error message # to report. Check for a join condition using any operator # (not just ==), perhaps they need to turn on "viewonly=True". if self.support_sync and has_foreign and not can_sync: - err = "Could not locate any simple equality expressions "\ - "involving locally mapped foreign key columns for "\ - "%s join condition "\ - "'%s' on relationship %s." % ( - primary and 'primary' or 'secondary', + err = ( + "Could not locate any simple equality expressions " + "involving locally mapped foreign key columns for " + "%s join condition " + "'%s' on relationship %s." + % ( + primary and "primary" or "secondary", join_condition, - self.prop + self.prop, ) - err += \ - " Ensure that referencing columns are associated "\ - "with a ForeignKey or ForeignKeyConstraint, or are "\ - "annotated in the join condition with the foreign() "\ - "annotation. To allow comparison operators other than "\ + ) + err += ( + " Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or are " + "annotated in the join condition with the foreign() " + "annotation. 
To allow comparison operators other than " "'==', the relationship can be marked as viewonly=True." + ) raise sa_exc.ArgumentError(err) else: - err = "Could not locate any relevant foreign key columns "\ - "for %s join condition '%s' on relationship %s." % ( - primary and 'primary' or 'secondary', + err = ( + "Could not locate any relevant foreign key columns " + "for %s join condition '%s' on relationship %s." + % ( + primary and "primary" or "secondary", join_condition, - self.prop + self.prop, ) - err += \ - ' Ensure that referencing columns are associated '\ - 'with a ForeignKey or ForeignKeyConstraint, or are '\ - 'annotated in the join condition with the foreign() '\ - 'annotation.' + ) + err += ( + " Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or are " + "annotated in the join condition with the foreign() " + "annotation." + ) raise sa_exc.ArgumentError(err) def _determine_direction(self): @@ -2658,13 +2843,11 @@ class JoinCondition(object): targetcols = util.column_set(self.child_selectable.c) # fk collection which suggests ONETOMANY. - onetomany_fk = targetcols.intersection( - self.foreign_key_columns) + onetomany_fk = targetcols.intersection(self.foreign_key_columns) # fk collection which suggests MANYTOONE. - manytoone_fk = parentcols.intersection( - self.foreign_key_columns) + manytoone_fk = parentcols.intersection(self.foreign_key_columns) if onetomany_fk and manytoone_fk: # fks on both sides. test for overlap of local/remote @@ -2676,15 +2859,20 @@ class JoinCondition(object): # 1. columns that are both remote and FK suggest # onetomany. onetomany_local = self._gather_columns_with_annotation( - self.primaryjoin, "remote", "foreign") + self.primaryjoin, "remote", "foreign" + ) # 2. columns that are FK but are not remote (e.g. local) # suggest manytoone. - manytoone_local = set([c for c in - self._gather_columns_with_annotation( - self.primaryjoin, - "foreign") - if "remote" not in c._annotations]) + manytoone_local = set( + [ + c + for c in self._gather_columns_with_annotation( + self.primaryjoin, "foreign" + ) + if "remote" not in c._annotations + ] + ) # 3. if both collections are present, remove columns that # refer to themselves. This is for the case of @@ -2713,7 +2901,8 @@ class JoinCondition(object): "Ensure that only those columns referring " "to a parent column are marked as foreign, " "either via the foreign() annotation or " - "via the foreign_keys argument." % self.prop) + "via the foreign_keys argument." % self.prop + ) elif onetomany_fk: self.direction = ONETOMANY elif manytoone_fk: @@ -2723,7 +2912,8 @@ class JoinCondition(object): "Can't determine relationship " "direction for relationship '%s' - foreign " "key columns are present in neither the parent " - "nor the child's mapped tables" % self.prop) + "nor the child's mapped tables" % self.prop + ) def _deannotate_pairs(self, collection): """provide deannotation for the various lists of @@ -2732,8 +2922,7 @@ class JoinCondition(object): original columns mapped. 
""" - return [(x._deannotate(), y._deannotate()) - for x, y in collection] + return [(x._deannotate(), y._deannotate()) for x, y in collection] def _setup_pairs(self): sync_pairs = [] @@ -2742,25 +2931,31 @@ class JoinCondition(object): def go(joincond, collection): def visit_binary(binary, left, right): - if "remote" in right._annotations and \ - "remote" not in left._annotations and \ - self.can_be_synced_fn(left): + if ( + "remote" in right._annotations + and "remote" not in left._annotations + and self.can_be_synced_fn(left) + ): lrp.add((left, right)) - elif "remote" in left._annotations and \ - "remote" not in right._annotations and \ - self.can_be_synced_fn(right): + elif ( + "remote" in left._annotations + and "remote" not in right._annotations + and self.can_be_synced_fn(right) + ): lrp.add((right, left)) - if binary.operator is operators.eq and \ - self.can_be_synced_fn(left, right): + if binary.operator is operators.eq and self.can_be_synced_fn( + left, right + ): if "foreign" in right._annotations: collection.append((left, right)) elif "foreign" in left._annotations: collection.append((right, left)) + visit_binary_product(visit_binary, joincond) for joincond, collection in [ (self.primaryjoin, sync_pairs), - (self.secondaryjoin, secondary_sync_pairs) + (self.secondaryjoin, secondary_sync_pairs), ]: if joincond is None: continue @@ -2768,8 +2963,9 @@ class JoinCondition(object): self.local_remote_pairs = self._deannotate_pairs(lrp) self.synchronize_pairs = self._deannotate_pairs(sync_pairs) - self.secondary_synchronize_pairs = \ - self._deannotate_pairs(secondary_sync_pairs) + self.secondary_synchronize_pairs = self._deannotate_pairs( + secondary_sync_pairs + ) _track_overlapping_sync_targets = weakref.WeakKeyDictionary() @@ -2797,20 +2993,23 @@ class JoinCondition(object): continue if to_ not in self._track_overlapping_sync_targets: - self._track_overlapping_sync_targets[to_] = \ - weakref.WeakKeyDictionary({self.prop: from_}) + self._track_overlapping_sync_targets[ + to_ + ] = weakref.WeakKeyDictionary({self.prop: from_}) else: other_props = [] prop_to_from = self._track_overlapping_sync_targets[to_] for pr, fr_ in prop_to_from.items(): - if pr.mapper in mapperlib._mapper_registry and \ - ( - self.prop._persists_for(pr.parent) or - pr._persists_for(self.prop.parent) - ) and \ - fr_ is not from_ and \ - pr not in self.prop._reverse_property: + if ( + pr.mapper in mapperlib._mapper_registry + and ( + self.prop._persists_for(pr.parent) + or pr._persists_for(self.prop.parent) + ) + and fr_ is not from_ + and pr not in self.prop._reverse_property + ): other_props.append((pr, fr_)) @@ -2821,12 +3020,15 @@ class JoinCondition(object): "Consider applying " "viewonly=True to read-only relationships, or provide " "a primaryjoin condition marking writable columns " - "with the foreign() annotation." % ( + "with the foreign() annotation." 
+ % ( self.prop, - from_, to_, + from_, + to_, ", ".join( "'%s' (copies %s to %s)" % (pr, fr_, to_) - for (pr, fr_) in other_props) + for (pr, fr_) in other_props + ), ) ) self._track_overlapping_sync_targets[to_][self.prop] = from_ @@ -2845,27 +3047,29 @@ class JoinCondition(object): def _gather_join_annotations(self, annotation): s = set( - self._gather_columns_with_annotation( - self.primaryjoin, annotation) + self._gather_columns_with_annotation(self.primaryjoin, annotation) ) if self.secondaryjoin is not None: s.update( self._gather_columns_with_annotation( - self.secondaryjoin, annotation) + self.secondaryjoin, annotation + ) ) return {x._deannotate() for x in s} def _gather_columns_with_annotation(self, clause, *annotation): annotation = set(annotation) - return set([ - col for col in visitors.iterate(clause, {}) - if annotation.issubset(col._annotations) - ]) - - def join_targets(self, source_selectable, - dest_selectable, - aliased, - single_crit=None): + return set( + [ + col + for col in visitors.iterate(clause, {}) + if annotation.issubset(col._annotations) + ] + ) + + def join_targets( + self, source_selectable, dest_selectable, aliased, single_crit=None + ): """Given a source and destination selectable, create a join between them. @@ -2881,11 +3085,14 @@ class JoinCondition(object): # its internal structure remains fixed # regardless of context. dest_selectable = _shallow_annotate( - dest_selectable, - {'no_replacement_traverse': True}) + dest_selectable, {"no_replacement_traverse": True} + ) - primaryjoin, secondaryjoin, secondary = self.primaryjoin, \ - self.secondaryjoin, self.secondary + primaryjoin, secondaryjoin, secondary = ( + self.primaryjoin, + self.secondaryjoin, + self.secondary, + ) # adjust the join condition for single table inheritance, # in the case that the join is to a subclass @@ -2902,28 +3109,31 @@ class JoinCondition(object): if secondary is not None: secondary = secondary.alias(flat=True) primary_aliasizer = ClauseAdapter(secondary) - secondary_aliasizer = \ - ClauseAdapter(dest_selectable, - equivalents=self.child_equivalents).\ - chain(primary_aliasizer) + secondary_aliasizer = ClauseAdapter( + dest_selectable, equivalents=self.child_equivalents + ).chain(primary_aliasizer) if source_selectable is not None: - primary_aliasizer = \ - ClauseAdapter(secondary).\ - chain(ClauseAdapter( + primary_aliasizer = ClauseAdapter(secondary).chain( + ClauseAdapter( source_selectable, - equivalents=self.parent_equivalents)) - secondaryjoin = \ - secondary_aliasizer.traverse(secondaryjoin) + equivalents=self.parent_equivalents, + ) + ) + secondaryjoin = secondary_aliasizer.traverse(secondaryjoin) else: primary_aliasizer = ClauseAdapter( dest_selectable, exclude_fn=_ColInAnnotations("local"), - equivalents=self.child_equivalents) + equivalents=self.child_equivalents, + ) if source_selectable is not None: primary_aliasizer.chain( - ClauseAdapter(source_selectable, - exclude_fn=_ColInAnnotations("remote"), - equivalents=self.parent_equivalents)) + ClauseAdapter( + source_selectable, + exclude_fn=_ColInAnnotations("remote"), + equivalents=self.parent_equivalents, + ) + ) secondary_aliasizer = None primaryjoin = primary_aliasizer.traverse(primaryjoin) @@ -2931,8 +3141,13 @@ class JoinCondition(object): target_adapter.exclude_fn = None else: target_adapter = None - return primaryjoin, secondaryjoin, secondary, \ - target_adapter, dest_selectable + return ( + primaryjoin, + secondaryjoin, + secondary, + target_adapter, + dest_selectable, + ) def create_lazy_clause(self, 
reverse_direction=False): binds = util.column_dict() @@ -2955,28 +3170,32 @@ class JoinCondition(object): def col_to_bind(col): if ( - (not reverse_direction and 'local' in col._annotations) or - reverse_direction and ( - (has_secondary and col in lookup) or - (not has_secondary and 'remote' in col._annotations) + (not reverse_direction and "local" in col._annotations) + or reverse_direction + and ( + (has_secondary and col in lookup) + or (not has_secondary and "remote" in col._annotations) ) ): if col not in binds: binds[col] = sql.bindparam( - None, None, type_=col.type, unique=True) + None, None, type_=col.type, unique=True + ) return binds[col] return None lazywhere = self.primaryjoin if self.secondaryjoin is None or not reverse_direction: lazywhere = visitors.replacement_traverse( - lazywhere, {}, col_to_bind) + lazywhere, {}, col_to_bind + ) if self.secondaryjoin is not None: secondaryjoin = self.secondaryjoin if reverse_direction: secondaryjoin = visitors.replacement_traverse( - secondaryjoin, {}, col_to_bind) + secondaryjoin, {}, col_to_bind + ) lazywhere = sql.and_(lazywhere, secondaryjoin) bind_to_col = {binds[col].key: col for col in binds} diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 2e16872f99..2eeaf5b6d3 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -11,7 +11,7 @@ from . import class_mapper, exc as orm_exc from .session import Session -__all__ = ['scoped_session'] +__all__ = ["scoped_session"] class scoped_session(object): @@ -65,7 +65,8 @@ class scoped_session(object): if self.registry.has(): raise sa_exc.InvalidRequestError( "Scoped session is already present; " - "no new arguments may be specified.") + "no new arguments may be specified." + ) else: sess = self.session_factory(**kw) self.registry.set(sess) @@ -99,9 +100,11 @@ class scoped_session(object): """ if self.registry.has(): - warn('At least one scoped session is already present. ' - ' configure() can not affect sessions that have ' - 'already been created.') + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) self.session_factory.configure(**kwargs) @@ -129,6 +132,7 @@ class scoped_session(object): a class. 
""" + class query(object): def __get__(s, instance, owner): try: @@ -142,8 +146,10 @@ class scoped_session(object): return self.registry().query(mapper) except orm_exc.UnmappedClassError: return None + return query() + ScopedSession = scoped_session """Old name for backwards compatibility.""" @@ -151,8 +157,10 @@ ScopedSession = scoped_session def instrument(name): def do(self, *args, **kwargs): return getattr(self.registry(), name)(*args, **kwargs) + return do + for meth in Session.public_methods: setattr(scoped_session, meth, instrument(meth)) @@ -166,16 +174,28 @@ def makeprop(name): return property(get, set) -for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', - 'is_active', 'autoflush', 'no_autoflush', 'info', - 'autocommit'): + +for prop in ( + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + "autocommit", +): setattr(scoped_session, prop, makeprop(prop)) def clslevel(name): def do(cls, *args, **kwargs): return getattr(Session, name)(*args, **kwargs) + return classmethod(do) -for prop in ('close_all', 'object_session', 'identity_key'): + +for prop in ("close_all", "object_session", "identity_key"): setattr(scoped_session, prop, clslevel(prop)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index b1993118dd..a3edacc194 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -10,15 +10,17 @@ import weakref from .. import util, sql, engine, exc as sa_exc from ..sql import util as sql_util, expression -from . import ( - SessionExtension, attributes, exc, query, - loading, identity -) +from . import SessionExtension, attributes, exc, query, loading, identity from ..inspection import inspect from .base import ( - object_mapper, class_mapper, - _class_to_mapper, _state_mapper, object_state, - _none_set, state_str, instance_str + object_mapper, + class_mapper, + _class_to_mapper, + _state_mapper, + object_state, + _none_set, + state_str, + instance_str, ) import itertools from . import persistence @@ -26,8 +28,7 @@ from .unitofwork import UOWTransaction from . import state as statelib import sys -__all__ = ['Session', 'SessionTransaction', - 'SessionExtension', 'sessionmaker'] +__all__ = ["Session", "SessionTransaction", "SessionExtension", "sessionmaker"] _sessions = weakref.WeakValueDictionary() """Weak-referencing dictionary of :class:`.Session` objects. 
@@ -77,11 +78,11 @@ class _SessionClassMethods(object): return object_session(instance) -ACTIVE = util.symbol('ACTIVE') -PREPARED = util.symbol('PREPARED') -COMMITTED = util.symbol('COMMITTED') -DEACTIVE = util.symbol('DEACTIVE') -CLOSED = util.symbol('CLOSED') +ACTIVE = util.symbol("ACTIVE") +PREPARED = util.symbol("PREPARED") +COMMITTED = util.symbol("COMMITTED") +DEACTIVE = util.symbol("DEACTIVE") +CLOSED = util.symbol("CLOSED") class SessionTransaction(object): @@ -212,7 +213,8 @@ class SessionTransaction(object): if not parent and nested: raise sa_exc.InvalidRequestError( "Can't start a SAVEPOINT transaction when no existing " - "transaction is in progress") + "transaction is in progress" + ) if self.session._enable_transaction_accounting: self._take_snapshot() @@ -249,10 +251,13 @@ class SessionTransaction(object): def is_active(self): return self.session is not None and self._state is ACTIVE - def _assert_active(self, prepared_ok=False, - rollback_ok=False, - deactive_ok=False, - closed_msg="This transaction is closed"): + def _assert_active( + self, + prepared_ok=False, + rollback_ok=False, + deactive_ok=False, + closed_msg="This transaction is closed", + ): if self._state is COMMITTED: raise sa_exc.InvalidRequestError( "This session is in 'committed' state; no further " @@ -295,21 +300,21 @@ class SessionTransaction(object): def _begin(self, nested=False): self._assert_active() - return SessionTransaction( - self.session, self, nested=nested) + return SessionTransaction(self.session, self, nested=nested) def _iterate_self_and_parents(self, upto=None): current = self result = () while current: - result += (current, ) + result += (current,) if current._parent is upto: break elif current._parent is None: raise sa_exc.InvalidRequestError( - "Transaction %s is not on the active transaction list" % ( - upto)) + "Transaction %s is not on the active transaction list" + % (upto) + ) else: current = current._parent @@ -376,7 +381,8 @@ class SessionTransaction(object): s._expire(s.dict, self.session.identity_map._modified) statelib.InstanceState._detach_states( - list(self._deleted), self.session) + list(self._deleted), self.session + ) self._deleted.clear() elif self.nested: self._parent._new.update(self._new) @@ -391,7 +397,8 @@ class SessionTransaction(object): if execution_options: util.warn( "Connection is already established for the " - "given bind; execution_options ignored") + "given bind; execution_options ignored" + ) return self._connections[bind][0] if self._parent: @@ -404,7 +411,8 @@ class SessionTransaction(object): if conn.engine in self._connections: raise sa_exc.InvalidRequestError( "Session already has a Connection associated for the " - "given Connection's Engine") + "given Connection's Engine" + ) else: conn = bind.contextual_connect() @@ -418,8 +426,11 @@ class SessionTransaction(object): else: transaction = conn.begin() - self._connections[conn] = self._connections[conn.engine] = \ - (conn, transaction, conn is not bind) + self._connections[conn] = self._connections[conn.engine] = ( + conn, + transaction, + conn is not bind, + ) self.session.dispatch.after_begin(self.session, self, conn) return conn @@ -427,7 +438,8 @@ class SessionTransaction(object): if self._parent is not None or not self.session.twophase: raise sa_exc.InvalidRequestError( "'twophase' mode not enabled, or not root transaction; " - "can't prepare.") + "can't prepare." 
+ ) self._prepare_impl() def _prepare_impl(self): @@ -449,7 +461,8 @@ class SessionTransaction(object): raise exc.FlushError( "Over 100 subsequent flushes have occurred within " "session.commit() - is an after_flush() hook " - "creating new objects?") + "creating new objects?" + ) if self._parent is None and self.session.twophase: try: @@ -504,7 +517,8 @@ class SessionTransaction(object): transaction._state = DEACTIVE if self.session._enable_transaction_accounting: transaction._restore_snapshot( - dirty_only=transaction.nested) + dirty_only=transaction.nested + ) boundary = transaction break else: @@ -512,15 +526,19 @@ class SessionTransaction(object): sess = self.session - if not rollback_err and sess._enable_transaction_accounting and \ - not sess._is_clean(): + if ( + not rollback_err + and sess._enable_transaction_accounting + and not sess._is_clean() + ): # if items were added, deleted, or mutated # here, we need to re-restore the snapshot util.warn( "Session's state has been changed on " "a non-active transaction - this state " - "will be discarded.") + "will be discarded." + ) boundary._restore_snapshot(dirty_only=boundary.nested) self.close() @@ -535,12 +553,12 @@ class SessionTransaction(object): return self._parent - def close(self, invalidate=False): self.session.transaction = self._parent if self._parent is None: - for connection, transaction, autoclose in \ - set(self._connections.values()): + for connection, transaction, autoclose in set( + self._connections.values() + ): if invalidate: connection.invalidate() if autoclose: @@ -583,21 +601,49 @@ class Session(_SessionClassMethods): """ public_methods = ( - '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', - 'close', 'commit', 'connection', 'delete', 'execute', 'expire', - 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', - 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings', - 'bulk_update_mappings', - 'merge', 'query', 'refresh', 'rollback', - 'scalar') - - def __init__(self, bind=None, autoflush=True, expire_on_commit=True, - _enable_transaction_accounting=True, - autocommit=False, twophase=False, - weak_identity_map=True, binds=None, extension=None, - enable_baked_queries=True, - info=None, - query_cls=query.Query): + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get_bind", + "is_modified", + "bulk_save_objects", + "bulk_insert_mappings", + "bulk_update_mappings", + "merge", + "query", + "refresh", + "rollback", + "scalar", + ) + + def __init__( + self, + bind=None, + autoflush=True, + expire_on_commit=True, + _enable_transaction_accounting=True, + autocommit=False, + twophase=False, + weak_identity_map=True, + binds=None, + extension=None, + enable_baked_queries=True, + info=None, + query_cls=query.Query, + ): r"""Construct a new Session. See also the :class:`.sessionmaker` function which is used to @@ -753,12 +799,13 @@ class Session(_SessionClassMethods): "weak_identity_map=False is deprecated. " "See the documentation on 'Session Referencing Behavior' " "for an event-based approach to maintaining strong identity " - "references.") + "references." 
+ ) self._identity_cls = identity.StrongInstanceDict self.identity_map = self._identity_cls() - self._new = {} # InstanceState->object, strong refs object + self._new = {} # InstanceState->object, strong refs object self._deleted = {} # same self.bind = bind self.__binds = {} @@ -861,15 +908,14 @@ class Session(_SessionClassMethods): """ if self.transaction is not None: if subtransactions or nested: - self.transaction = self.transaction._begin( - nested=nested) + self.transaction = self.transaction._begin(nested=nested) else: raise sa_exc.InvalidRequestError( "A transaction is already begun. Use " - "subtransactions=True to allow subtransactions.") + "subtransactions=True to allow subtransactions." + ) else: - self.transaction = SessionTransaction( - self, nested=nested) + self.transaction = SessionTransaction(self, nested=nested) return self.transaction # needed for __enter__/__exit__ hook def begin_nested(self): @@ -972,11 +1018,15 @@ class Session(_SessionClassMethods): self.transaction.prepare() - def connection(self, mapper=None, clause=None, - bind=None, - close_with_result=False, - execution_options=None, - **kw): + def connection( + self, + mapper=None, + clause=None, + bind=None, + close_with_result=False, + execution_options=None, + **kw + ): r"""Return a :class:`.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1041,14 +1091,17 @@ class Session(_SessionClassMethods): if bind is None: bind = self.get_bind(mapper, clause=clause, **kw) - return self._connection_for_bind(bind, - close_with_result=close_with_result, - execution_options=execution_options) + return self._connection_for_bind( + bind, + close_with_result=close_with_result, + execution_options=execution_options, + ) def _connection_for_bind(self, engine, execution_options=None, **kw): if self.transaction is not None: return self.transaction._connection_for_bind( - engine, execution_options) + engine, execution_options + ) else: conn = engine.contextual_connect(**kw) if execution_options: @@ -1183,14 +1236,16 @@ class Session(_SessionClassMethods): if bind is None: bind = self.get_bind(mapper, clause=clause, **kw) - return self._connection_for_bind( - bind, close_with_result=True).execute(clause, params or {}) + return self._connection_for_bind(bind, close_with_result=True).execute( + clause, params or {} + ) def scalar(self, clause, params=None, mapper=None, bind=None, **kw): """Like :meth:`~.Session.execute` but return a scalar result.""" return self.execute( - clause, params=params, mapper=mapper, bind=bind, **kw).scalar() + clause, params=params, mapper=mapper, bind=bind, **kw + ).scalar() def close(self): """Close this Session. 
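As a usage note for the connection() / execute() paths reformatted above: execute() resolves a bind via get_bind() and then runs the statement on a connection that participates in the session's transaction. A brief sketch, not part of this diff, against an in-memory engine:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    session = Session(bind=engine)

    # scalar() is execute() plus .scalar() on the result; execute()
    # resolves a bind via get_bind() and runs the statement on a
    # connection participating in the session's transaction.
    print(session.scalar("SELECT 1 + 1"))  # 2

    # connection() hands back that same transactional connection.
    conn = session.connection()
    print(conn.execute("SELECT 2").scalar())

    session.close()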
@@ -1256,9 +1311,7 @@ class Session(_SessionClassMethods): self._new = {} self._deleted = {} - statelib.InstanceState._detach_states( - all_states, self - ) + statelib.InstanceState._detach_states(all_states, self) def _add_bind(self, key, bind): try: @@ -1266,7 +1319,8 @@ class Session(_SessionClassMethods): except sa_exc.NoInspectionAvailable: if not isinstance(key, type): raise sa_exc.ArgumentError( - "Not an acceptable bind target: %s" % key) + "Not an acceptable bind target: %s" % key + ) else: self.__binds[key] = bind else: @@ -1278,7 +1332,8 @@ class Session(_SessionClassMethods): self.__binds[selectable] = bind else: raise sa_exc.ArgumentError( - "Not an acceptable bind target: %s" % key) + "Not an acceptable bind target: %s" % key + ) def bind_mapper(self, mapper, bind): """Associate a :class:`.Mapper` or arbitrary Python class with a @@ -1408,7 +1463,8 @@ class Session(_SessionClassMethods): raise sa_exc.UnboundExecutionError( "This session is not bound to a single Engine or " "Connection, and no context was provided to locate " - "a binding.") + "a binding." + ) if mapper is not None: try: @@ -1443,13 +1499,14 @@ class Session(_SessionClassMethods): context = [] if mapper is not None: - context.append('mapper %s' % mapper) + context.append("mapper %s" % mapper) if clause is not None: - context.append('SQL expression') + context.append("SQL expression") raise sa_exc.UnboundExecutionError( - "Could not locate a bind configured on %s or this Session" % ( - ', '.join(context))) + "Could not locate a bind configured on %s or this Session" + % (", ".join(context)) + ) def query(self, *entities, **kwargs): """Return a new :class:`.Query` object corresponding to this @@ -1499,12 +1556,17 @@ class Session(_SessionClassMethods): e.add_detail( "raised as a result of Query-invoked autoflush; " "consider using a session.no_autoflush block if this " - "flush is occurring prematurely") + "flush is occurring prematurely" + ) util.raise_from_cause(e) def refresh( - self, instance, attribute_names=None, with_for_update=None, - lockmode=None): + self, + instance, + attribute_names=None, + with_for_update=None, + lockmode=None, + ): """Expire and refresh the attributes on the given instance. A query will be issued to the database and all attributes will be @@ -1560,7 +1622,8 @@ class Session(_SessionClassMethods): raise sa_exc.ArgumentError( "with_for_update should be the boolean value " "True, or a dictionary with options. " - "A blank dictionary is ambiguous.") + "A blank dictionary is ambiguous." + ) if lockmode: with_for_update = query.LockmodeArg.parse_legacy_query(lockmode) @@ -1572,14 +1635,19 @@ class Session(_SessionClassMethods): else: with_for_update = None - if loading.load_on_ident( + if ( + loading.load_on_ident( self.query(object_mapper(instance)), - state.key, refresh_state=state, + state.key, + refresh_state=state, with_for_update=with_for_update, - only_load_props=attribute_names) is None: + only_load_props=attribute_names, + ) + is None + ): raise sa_exc.InvalidRequestError( - "Could not refresh instance '%s'" % - instance_str(instance)) + "Could not refresh instance '%s'" % instance_str(instance) + ) def expire_all(self): """Expires all persistent instances within this Session. 
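To illustrate the refresh() and expiration machinery above, here is a short sketch; the User mapping is hypothetical and the example is not part of this diff:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session = Session(bind=engine)
    user = User(name="ed")
    session.add(user)
    session.commit()

    # Expire one attribute; it is re-loaded on next access.
    session.expire(user, ["name"])

    # refresh() expires and immediately re-SELECTs the row.  Per the
    # validation above, with_for_update may be True or a dict of
    # options; an empty dict is rejected as ambiguous.
    session.refresh(user)

    session.expire_all()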
@@ -1662,8 +1730,9 @@ class Session(_SessionClassMethods): else: # pre-fetch the full cascade since the expire is going to # remove associations - cascaded = list(state.manager.mapper.cascade_iterator( - 'refresh-expire', state)) + cascaded = list( + state.manager.mapper.cascade_iterator("refresh-expire", state) + ) self._conditional_expire(state) for o, m, st_, dct_ in cascaded: self._conditional_expire(st_) @@ -1677,8 +1746,11 @@ class Session(_SessionClassMethods): self._new.pop(state) state._detach(self) - @util.deprecated("0.7", "The non-weak-referencing identity map " - "feature is no longer needed.") + @util.deprecated( + "0.7", + "The non-weak-referencing identity map " + "feature is no longer needed.", + ) def prune(self): """Remove unreferenced instances cached in the identity map. @@ -1705,14 +1777,13 @@ class Session(_SessionClassMethods): raise exc.UnmappedInstanceError(instance) if state.session_id is not self.hash_key: raise sa_exc.InvalidRequestError( - "Instance %s is not present in this Session" % - state_str(state)) + "Instance %s is not present in this Session" % state_str(state) + ) - cascaded = list(state.manager.mapper.cascade_iterator( - 'expunge', state)) - self._expunge_states( - [state] + [st_ for o, m, st_, dct_ in cascaded] + cascaded = list( + state.manager.mapper.cascade_iterator("expunge", state) ) + self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded]) def _expunge_states(self, states, to_transient=False): for state in states: @@ -1726,7 +1797,8 @@ class Session(_SessionClassMethods): # in the transaction snapshot self.transaction._deleted.pop(state, None) statelib.InstanceState._detach_states( - states, self, to_transient=to_transient) + states, self, to_transient=to_transient + ) def _register_newly_persistent(self, states): pending_to_persistent = self.dispatch.pending_to_persistent or None @@ -1739,9 +1811,11 @@ class Session(_SessionClassMethods): instance_key = mapper._identity_key_from_state(state) - if _none_set.intersection(instance_key[1]) and \ - not mapper.allow_partial_pks or \ - _none_set.issuperset(instance_key[1]): + if ( + _none_set.intersection(instance_key[1]) + and not mapper.allow_partial_pks + or _none_set.issuperset(instance_key[1]) + ): raise exc.FlushError( "Instance %s has a NULL identity key. 
If this is an " "auto-generated value, check that the database table " @@ -1765,15 +1839,16 @@ class Session(_SessionClassMethods): else: orig_key = state.key self.transaction._key_switches[state] = ( - orig_key, instance_key) + orig_key, + instance_key, + ) state.key = instance_key self.identity_map.replace(state) state._orphaned_outside_of_session = False statelib.InstanceState._commit_all_states( - ((state, state.dict) for state in states), - self.identity_map + ((state, state.dict) for state in states), self.identity_map ) self._register_altered(states) @@ -1849,9 +1924,8 @@ class Session(_SessionClassMethods): mapper = _state_mapper(state) for o, m, st_, dct_ in mapper.cascade_iterator( - 'save-update', - state, - halt_on=self._contains_state): + "save-update", state, halt_on=self._contains_state + ): self._save_or_update_impl(st_) def delete(self, instance): @@ -1875,8 +1949,8 @@ class Session(_SessionClassMethods): if state.key is None: if head: raise sa_exc.InvalidRequestError( - "Instance '%s' is not persisted" % - state_str(state)) + "Instance '%s' is not persisted" % state_str(state) + ) else: return @@ -1894,8 +1968,9 @@ class Session(_SessionClassMethods): # grab the cascades before adding the item to the deleted list # so that autoflush does not delete the item # the strong reference to the instance itself is significant here - cascade_states = list(state.manager.mapper.cascade_iterator( - 'delete', state)) + cascade_states = list( + state.manager.mapper.cascade_iterator("delete", state) + ) self._deleted[state] = obj @@ -1975,13 +2050,21 @@ class Session(_SessionClassMethods): return self._merge( attributes.instance_state(instance), attributes.instance_dict(instance), - load=load, _recursive=_recursive, - _resolve_conflict_map=_resolve_conflict_map) + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) finally: self.autoflush = autoflush - def _merge(self, state, state_dict, load=True, _recursive=None, - _resolve_conflict_map=None): + def _merge( + self, + state, + state_dict, + load=True, + _recursive=None, + _resolve_conflict_map=None, + ): mapper = _state_mapper(state) if state in _recursive: return _recursive[state] @@ -1995,11 +2078,15 @@ class Session(_SessionClassMethods): "merge() with load=False option does not support " "objects transient (i.e. unpersisted) objects. flush() " "all changes on mapped instances before merging with " - "load=False.") + "load=False." + ) key = mapper._identity_key_from_state(state) key_is_persistent = attributes.NEVER_SET not in key[1] and ( - not _none_set.intersection(key[1]) or - (mapper.allow_partial_pks and not _none_set.issuperset(key[1])) + not _none_set.intersection(key[1]) + or ( + mapper.allow_partial_pks + and not _none_set.issuperset(key[1]) + ) ) else: key_is_persistent = True @@ -2022,7 +2109,8 @@ class Session(_SessionClassMethods): raise sa_exc.InvalidRequestError( "merge() with load=False option does not support " "objects marked as 'dirty'. flush() all changes on " - "mapped instances before merging with load=False.") + "mapped instances before merging with load=False." 
+ ) merged = mapper.class_manager.new_instance() merged_state = attributes.instance_state(merged) merged_state.key = key @@ -2054,17 +2142,21 @@ class Session(_SessionClassMethods): state, state_dict, mapper.version_id_col, - passive=attributes.PASSIVE_NO_INITIALIZE) + passive=attributes.PASSIVE_NO_INITIALIZE, + ) merged_version = mapper._get_state_attr_by_column( merged_state, merged_dict, mapper.version_id_col, - passive=attributes.PASSIVE_NO_INITIALIZE) + passive=attributes.PASSIVE_NO_INITIALIZE, + ) - if existing_version is not attributes.PASSIVE_NO_RESULT and \ - merged_version is not attributes.PASSIVE_NO_RESULT and \ - existing_version != merged_version: + if ( + existing_version is not attributes.PASSIVE_NO_RESULT + and merged_version is not attributes.PASSIVE_NO_RESULT + and existing_version != merged_version + ): raise exc.StaleDataError( "Version id '%s' on merged state %s " "does not match existing version '%s'. " @@ -2073,8 +2165,9 @@ class Session(_SessionClassMethods): % ( existing_version, state_str(merged_state), - merged_version - )) + merged_version, + ) + ) merged_state.load_path = state.load_path merged_state.load_options = state.load_options @@ -2087,9 +2180,16 @@ class Session(_SessionClassMethods): merged_state._copy_callables(state) for prop in mapper.iterate_properties: - prop.merge(self, state, state_dict, - merged_state, merged_dict, - load, _recursive, _resolve_conflict_map) + prop.merge( + self, + state, + state_dict, + merged_state, + merged_dict, + load, + _recursive, + _resolve_conflict_map, + ) if not load: # remove any history @@ -2102,14 +2202,16 @@ class Session(_SessionClassMethods): def _validate_persistent(self, state): if not self.identity_map.contains_state(state): raise sa_exc.InvalidRequestError( - "Instance '%s' is not persistent within this Session" % - state_str(state)) + "Instance '%s' is not persistent within this Session" + % state_str(state) + ) def _save_impl(self, state): if state.key is not None: raise sa_exc.InvalidRequestError( "Object '%s' already has an identity - " - "it can't be registered as pending" % state_str(state)) + "it can't be registered as pending" % state_str(state) + ) obj = state.obj() to_attach = self._before_attach(state, obj) @@ -2122,8 +2224,8 @@ class Session(_SessionClassMethods): def _update_impl(self, state, revert_deletion=False): if state.key is None: raise sa_exc.InvalidRequestError( - "Instance '%s' is not persisted" % - state_str(state)) + "Instance '%s' is not persisted" % state_str(state) + ) if state._deleted: if revert_deletion: @@ -2135,8 +2237,7 @@ class Session(_SessionClassMethods): "Instance '%s' has been deleted. " "Use the make_transient() " "function to send this object back " - "to the transient state." % - state_str(state) + "to the transient state." 
% state_str(state) ) obj = state.obj() @@ -2234,8 +2335,9 @@ class Session(_SessionClassMethods): if state.session_id and state.session_id in _sessions: raise sa_exc.InvalidRequestError( "Object '%s' is already attached to session '%s' " - "(this is '%s')" % (state_str(state), - state.session_id, self.hash_key)) + "(this is '%s')" + % (state_str(state), state.session_id, self.hash_key) + ) self.dispatch.before_attach(self, obj) @@ -2271,7 +2373,8 @@ class Session(_SessionClassMethods): """ return iter( - list(self._new.values()) + list(self.identity_map.values())) + list(self._new.values()) + list(self.identity_map.values()) + ) def _contains_state(self, state): return state in self._new or self.identity_map.contains_state(state) @@ -2319,13 +2422,15 @@ class Session(_SessionClassMethods): "Usage of the '%s' operation is not currently supported " "within the execution stage of the flush process. " "Results may not be consistent. Consider using alternative " - "event listeners or connection-level operations instead." - % method) + "event listeners or connection-level operations instead." % method + ) def _is_clean(self): - return not self.identity_map.check_modified() and \ - not self._deleted and \ - not self._new + return ( + not self.identity_map.check_modified() + and not self._deleted + and not self._new + ) def _flush(self, objects=None): @@ -2375,12 +2480,16 @@ class Session(_SessionClassMethods): is_persistent_orphan = is_orphan and state.has_identity - if is_orphan and not is_persistent_orphan and \ - state._orphaned_outside_of_session: + if ( + is_orphan + and not is_persistent_orphan + and state._orphaned_outside_of_session + ): self._expunge_states([state]) else: _reg = flush_context.register_object( - state, isdelete=is_persistent_orphan) + state, isdelete=is_persistent_orphan + ) assert _reg, "Failed to add object to the flush context!" processed.add(state) @@ -2397,7 +2506,8 @@ class Session(_SessionClassMethods): return flush_context.transaction = transaction = self.begin( - subtransactions=True) + subtransactions=True + ) try: self._warn_on_events = True try: @@ -2413,16 +2523,20 @@ class Session(_SessionClassMethods): len_ = len(self.identity_map._modified) statelib.InstanceState._commit_all_states( - [(state, state.dict) for state in - self.identity_map._modified], - instance_dict=self.identity_map) - util.warn("Attribute history events accumulated on %d " - "previously clean instances " - "within inner-flush event handlers have been " - "reset, and will not result in database updates. " - "Consider using set_committed_value() within " - "inner-flush event handlers to avoid this warning." - % len_) + [ + (state, state.dict) + for state in self.identity_map._modified + ], + instance_dict=self.identity_map, + ) + util.warn( + "Attribute history events accumulated on %d " + "previously clean instances " + "within inner-flush event handlers have been " + "reset, and will not result in database updates. " + "Consider using set_committed_value() within " + "inner-flush event handlers to avoid this warning." % len_ + ) # useful assertions: # if not objects: @@ -2440,8 +2554,12 @@ class Session(_SessionClassMethods): transaction.rollback(_capture_exception=True) def bulk_save_objects( - self, objects, return_defaults=False, update_changed_only=True, - preserve_order=True): + self, + objects, + return_defaults=False, + update_changed_only=True, + preserve_order=True, + ): """Perform a bulk save of the given list of objects. 
The bulk save feature allows mapped objects to be used as the @@ -2520,6 +2638,7 @@ class Session(_SessionClassMethods): :meth:`.Session.bulk_update_mappings` """ + def key(state): return (state.mapper, state.key is not None) @@ -2527,15 +2646,20 @@ class Session(_SessionClassMethods): if not preserve_order: obj_states = sorted(obj_states, key=key) - for (mapper, isupdate), states in itertools.groupby( - obj_states, key - ): + for (mapper, isupdate), states in itertools.groupby(obj_states, key): self._bulk_save_mappings( - mapper, states, isupdate, True, - return_defaults, update_changed_only, False) + mapper, + states, + isupdate, + True, + return_defaults, + update_changed_only, + False, + ) def bulk_insert_mappings( - self, mapper, mappings, return_defaults=False, render_nulls=False): + self, mapper, mappings, return_defaults=False, render_nulls=False + ): """Perform a bulk insert of the given list of mapping dictionaries. The bulk insert feature allows plain Python dictionaries to be used as @@ -2622,8 +2746,14 @@ class Session(_SessionClassMethods): """ self._bulk_save_mappings( - mapper, mappings, False, False, - return_defaults, False, render_nulls) + mapper, + mappings, + False, + False, + return_defaults, + False, + render_nulls, + ) def bulk_update_mappings(self, mapper, mappings): """Perform a bulk update of the given list of mapping dictionaries. @@ -2673,25 +2803,41 @@ class Session(_SessionClassMethods): """ self._bulk_save_mappings( - mapper, mappings, True, False, False, False, False) + mapper, mappings, True, False, False, False, False + ) def _bulk_save_mappings( - self, mapper, mappings, isupdate, isstates, - return_defaults, update_changed_only, render_nulls): + self, + mapper, + mappings, + isupdate, + isstates, + return_defaults, + update_changed_only, + render_nulls, + ): mapper = _class_to_mapper(mapper) self._flushing = True - transaction = self.begin( - subtransactions=True) + transaction = self.begin(subtransactions=True) try: if isupdate: persistence._bulk_update( - mapper, mappings, transaction, - isstates, update_changed_only) + mapper, + mappings, + transaction, + isstates, + update_changed_only, + ) else: persistence._bulk_insert( - mapper, mappings, transaction, - isstates, return_defaults, render_nulls) + mapper, + mappings, + transaction, + isstates, + return_defaults, + render_nulls, + ) transaction.commit() except: @@ -2700,8 +2846,7 @@ class Session(_SessionClassMethods): finally: self._flushing = False - def is_modified(self, instance, include_collections=True, - passive=True): + def is_modified(self, instance, include_collections=True, passive=True): r"""Return ``True`` if the given instance has locally modified attributes. 
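A short sketch of how the bulk persistence methods reformatted above are typically called, continuing the hypothetical User mapping from the earlier sketch (illustrative only, not part of this diff):

    # bulk_insert_mappings() bypasses most unit-of-work machinery and
    # can emit multi-row INSERT statements.
    session.bulk_insert_mappings(
        User,
        [{"id": 2, "name": "wendy"}, {"id": 3, "name": "mary"}],
    )

    # bulk_update_mappings() matches on primary key and UPDATEs the
    # remaining keys in each dictionary.
    session.bulk_update_mappings(User, [{"id": 2, "name": "wendy2"}])

    # bulk_save_objects() accepts ORM objects but still skips most
    # events, cascades and attribute history.
    session.bulk_save_objects([User(id=4, name="fred")])
    session.commit()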
@@ -2775,16 +2920,15 @@ class Session(_SessionClassMethods): dict_ = state.dict for attr in state.manager.attributes: - if \ - ( - not include_collections and - hasattr(attr.impl, 'get_collection') - ) or not hasattr(attr.impl, 'get_history'): + if ( + not include_collections + and hasattr(attr.impl, "get_collection") + ) or not hasattr(attr.impl, "get_history"): continue - (added, unchanged, deleted) = \ - attr.impl.get_history(state, dict_, - passive=attributes.NO_CHANGE) + (added, unchanged, deleted) = attr.impl.get_history( + state, dict_, passive=attributes.NO_CHANGE + ) if added or deleted: return True @@ -2898,9 +3042,12 @@ class Session(_SessionClassMethods): """ return util.IdentitySet( - [state.obj() - for state in self._dirty_states - if state not in self._deleted]) + [ + state.obj() + for state in self._dirty_states + if state not in self._deleted + ] + ) @property def deleted(self): @@ -2961,10 +3108,16 @@ class sessionmaker(_SessionClassMethods): """ - def __init__(self, bind=None, class_=Session, autoflush=True, - autocommit=False, - expire_on_commit=True, - info=None, **kw): + def __init__( + self, + bind=None, + class_=Session, + autoflush=True, + autocommit=False, + expire_on_commit=True, + info=None, + **kw + ): r"""Construct a new :class:`.sessionmaker`. All arguments here except for ``class_`` correspond to arguments @@ -2992,12 +3145,12 @@ class sessionmaker(_SessionClassMethods): constructor of newly created :class:`.Session` objects. """ - kw['bind'] = bind - kw['autoflush'] = autoflush - kw['autocommit'] = autocommit - kw['expire_on_commit'] = expire_on_commit + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["autocommit"] = autocommit + kw["expire_on_commit"] = expire_on_commit if info is not None: - kw['info'] = info + kw["info"] = info self.kw = kw # make our own subclass of the given class, so that # events can be associated with it specifically. @@ -3015,10 +3168,10 @@ class sessionmaker(_SessionClassMethods): """ for k, v in self.kw.items(): - if k == 'info' and 'info' in local_kw: + if k == "info" and "info" in local_kw: d = v.copy() - d.update(local_kw['info']) - local_kw['info'] = d + d.update(local_kw["info"]) + local_kw["info"] = d else: local_kw.setdefault(k, v) return self.class_(**local_kw) @@ -3038,7 +3191,7 @@ class sessionmaker(_SessionClassMethods): return "%s(class_=%r, %s)" % ( self.__class__.__name__, self.class_.__name__, - ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()) + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), ) @@ -3139,8 +3292,7 @@ def make_transient_to_detached(instance): """ state = attributes.instance_state(instance) if state.session_id or state.key: - raise sa_exc.InvalidRequestError( - "Given object must be transient") + raise sa_exc.InvalidRequestError("Given object must be transient") state.key = state.mapper._identity_key_from_state(state) if state._deleted: del state._deleted diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 944dc8177f..c36d8817b8 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -18,8 +18,16 @@ from .. import inspection from .. import exc as sa_exc from . import exc as orm_exc, interfaces from .path_registry import PathRegistry -from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \ - NO_VALUE, PASSIVE_NO_INITIALIZE, INIT_OK, PASSIVE_OFF +from .base import ( + PASSIVE_NO_RESULT, + SQL_OK, + NEVER_SET, + ATTR_WAS_SET, + NO_VALUE, + PASSIVE_NO_INITIALIZE, + INIT_OK, + PASSIVE_OFF, +) from . 
import base @@ -106,10 +114,7 @@ class InstanceState(interfaces.InspectionAttrInfo): """ return util.ImmutableProperties( - dict( - (key, AttributeState(self, key)) - for key in self.manager - ) + dict((key, AttributeState(self, key)) for key in self.manager) ) @property @@ -121,8 +126,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is None and \ - not self._attached + return self.key is None and not self._attached @property def pending(self): @@ -134,8 +138,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is None and \ - self._attached + return self.key is None and self._attached @property def deleted(self): @@ -164,8 +167,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is not None and \ - self._attached and self._deleted + return self.key is not None and self._attached and self._deleted @property def was_deleted(self): @@ -210,8 +212,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is not None and \ - self._attached and not self._deleted + return self.key is not None and self._attached and not self._deleted @property def detached(self): @@ -227,8 +228,10 @@ class InstanceState(interfaces.InspectionAttrInfo): @property @util.dependencies("sqlalchemy.orm.session") def _attached(self, sessionlib): - return self.session_id is not None and \ - self.session_id in sessionlib._sessions + return ( + self.session_id is not None + and self.session_id in sessionlib._sessions + ) def _track_last_known_value(self, key): """Track the last known value of a particular key after expiration @@ -323,14 +326,14 @@ class InstanceState(interfaces.InspectionAttrInfo): @classmethod def _detach_states(self, states, session, to_transient=False): - persistent_to_detached = \ + persistent_to_detached = ( session.dispatch.persistent_to_detached or None - deleted_to_detached = \ - session.dispatch.deleted_to_detached or None - pending_to_transient = \ - session.dispatch.pending_to_transient or None - persistent_to_transient = \ + ) + deleted_to_detached = session.dispatch.deleted_to_detached or None + pending_to_transient = session.dispatch.pending_to_transient or None + persistent_to_transient = ( session.dispatch.persistent_to_transient or None + ) for state in states: deleted = state._deleted @@ -448,23 +451,33 @@ class InstanceState(interfaces.InspectionAttrInfo): return self._pending_mutations[key] def __getstate__(self): - state_dict = {'instance': self.obj()} + state_dict = {"instance": self.obj()} state_dict.update( - (k, self.__dict__[k]) for k in ( - 'committed_state', '_pending_mutations', 'modified', - 'expired', 'callables', 'key', 'parents', 'load_options', - 'class_', 'expired_attributes', 'info' - ) if k in self.__dict__ + (k, self.__dict__[k]) + for k in ( + "committed_state", + "_pending_mutations", + "modified", + "expired", + "callables", + "key", + "parents", + "load_options", + "class_", + "expired_attributes", + "info", + ) + if k in self.__dict__ ) if self.load_path: - state_dict['load_path'] = self.load_path.serialize() + state_dict["load_path"] = self.load_path.serialize() - state_dict['manager'] = self.manager._serialize(self, state_dict) + state_dict["manager"] = self.manager._serialize(self, state_dict) return state_dict def __setstate__(self, state_dict): - inst = state_dict['instance'] + inst = state_dict["instance"] if inst is not None: self.obj = 
weakref.ref(inst, self._cleanup) self.class_ = inst.__class__ @@ -473,20 +486,20 @@ class InstanceState(interfaces.InspectionAttrInfo): # due to storage of state in "parents". "class_" # also new. self.obj = None - self.class_ = state_dict['class_'] - - self.committed_state = state_dict.get('committed_state', {}) - self._pending_mutations = state_dict.get('_pending_mutations', {}) - self.parents = state_dict.get('parents', {}) - self.modified = state_dict.get('modified', False) - self.expired = state_dict.get('expired', False) - if 'info' in state_dict: - self.info.update(state_dict['info']) - if 'callables' in state_dict: - self.callables = state_dict['callables'] + self.class_ = state_dict["class_"] + + self.committed_state = state_dict.get("committed_state", {}) + self._pending_mutations = state_dict.get("_pending_mutations", {}) + self.parents = state_dict.get("parents", {}) + self.modified = state_dict.get("modified", False) + self.expired = state_dict.get("expired", False) + if "info" in state_dict: + self.info.update(state_dict["info"]) + if "callables" in state_dict: + self.callables = state_dict["callables"] try: - self.expired_attributes = state_dict['expired_attributes'] + self.expired_attributes = state_dict["expired_attributes"] except KeyError: self.expired_attributes = set() # 0.9 and earlier compat @@ -495,30 +508,31 @@ class InstanceState(interfaces.InspectionAttrInfo): self.expired_attributes.add(k) del self.callables[k] else: - if 'expired_attributes' in state_dict: - self.expired_attributes = state_dict['expired_attributes'] + if "expired_attributes" in state_dict: + self.expired_attributes = state_dict["expired_attributes"] else: self.expired_attributes = set() - self.__dict__.update([ - (k, state_dict[k]) for k in ( - 'key', 'load_options' - ) if k in state_dict - ]) + self.__dict__.update( + [ + (k, state_dict[k]) + for k in ("key", "load_options") + if k in state_dict + ] + ) if self.key: try: self.identity_token = self.key[2] except IndexError: # 1.1 and earlier compat before identity_token assert len(self.key) == 2 - self.key = self.key + (None, ) + self.key = self.key + (None,) self.identity_token = None - if 'load_path' in state_dict: - self.load_path = PathRegistry.\ - deserialize(state_dict['load_path']) + if "load_path" in state_dict: + self.load_path = PathRegistry.deserialize(state_dict["load_path"]) - state_dict['manager'](self, inst, state_dict) + state_dict["manager"](self, inst, state_dict) def _reset(self, dict_, key): """Remove the given attribute and any @@ -532,25 +546,29 @@ class InstanceState(interfaces.InspectionAttrInfo): self.callables.pop(key, None) def _copy_callables(self, from_): - if 'callables' in from_.__dict__: + if "callables" in from_.__dict__: self.callables = dict(from_.callables) @classmethod def _instance_level_callable_processor(cls, manager, fn, key): impl = manager[key].impl if impl.collection: + def _set_callable(state, dict_, row): - if 'callables' not in state.__dict__: + if "callables" not in state.__dict__: state.callables = {} old = dict_.pop(key, None) if old is not None: impl._invalidate_collection(old) state.callables[key] = fn + else: + def _set_callable(state, dict_, row): - if 'callables' not in state.__dict__: + if "callables" not in state.__dict__: state.callables = {} state.callables[key] = fn + return _set_callable def _expire(self, dict_, modified_set): @@ -563,15 +581,18 @@ class InstanceState(interfaces.InspectionAttrInfo): self._strong_obj = None - if '_pending_mutations' in self.__dict__: - del 
self.__dict__['_pending_mutations'] + if "_pending_mutations" in self.__dict__: + del self.__dict__["_pending_mutations"] - if 'parents' in self.__dict__: - del self.__dict__['parents'] + if "parents" in self.__dict__: + del self.__dict__["parents"] self.expired_attributes.update( - [impl.key for impl in self.manager._scalar_loader_impls - if impl.expire_missing or impl.key in dict_] + [ + impl.key + for impl in self.manager._scalar_loader_impls + if impl.expire_missing or impl.key in dict_ + ] ) if self.callables: @@ -584,8 +605,7 @@ class InstanceState(interfaces.InspectionAttrInfo): if self._last_known_values: self._last_known_values.update( - (k, dict_[k]) for k in self._last_known_values - if k in dict_ + (k, dict_[k]) for k in self._last_known_values if k in dict_ ) for key in self.manager._all_key_set.intersection(dict_): @@ -594,17 +614,14 @@ class InstanceState(interfaces.InspectionAttrInfo): self.manager.dispatch.expire(self, None) def _expire_attributes(self, dict_, attribute_names, no_loader=False): - pending = self.__dict__.get('_pending_mutations', None) + pending = self.__dict__.get("_pending_mutations", None) callables = self.callables for key in attribute_names: impl = self.manager[key].impl if impl.accepts_scalar_loader: - if no_loader and ( - impl.callable_ or - key in callables - ): + if no_loader and (impl.callable_ or key in callables): continue self.expired_attributes.add(key) @@ -614,8 +631,11 @@ class InstanceState(interfaces.InspectionAttrInfo): if impl.collection and old is not NO_VALUE: impl._invalidate_collection(old) - if self._last_known_values and key in self._last_known_values \ - and old is not NO_VALUE: + if ( + self._last_known_values + and key in self._last_known_values + and old is not NO_VALUE + ): self._last_known_values[key] = old self.committed_state.pop(key, None) @@ -634,8 +654,7 @@ class InstanceState(interfaces.InspectionAttrInfo): if not passive & SQL_OK: return PASSIVE_NO_RESULT - toload = self.expired_attributes.\ - intersection(self.unmodified) + toload = self.expired_attributes.intersection(self.unmodified) self.manager.deferred_scalar_loader(self, toload) @@ -656,9 +675,11 @@ class InstanceState(interfaces.InspectionAttrInfo): def unmodified_intersection(self, keys): """Return self.unmodified.intersection(keys).""" - - return set(keys).intersection(self.manager).\ - difference(self.committed_state) + return ( + set(keys) + .intersection(self.manager) + .difference(self.committed_state) + ) @property def unloaded(self): @@ -668,9 +689,11 @@ class InstanceState(interfaces.InspectionAttrInfo): was never populated or modified. 
""" - return set(self.manager).\ - difference(self.committed_state).\ - difference(self.dict) + return ( + set(self.manager) + .difference(self.committed_state) + .difference(self.dict) + ) @property def unloaded_expirable(self): @@ -681,13 +704,16 @@ class InstanceState(interfaces.InspectionAttrInfo): """ return self.unloaded.intersection( - attr for attr in self.manager - if self.manager[attr].impl.expire_missing) + attr + for attr in self.manager + if self.manager[attr].impl.expire_missing + ) @property def _unloaded_non_object(self): return self.unloaded.intersection( - attr for attr in self.manager + attr + for attr in self.manager if self.manager[attr].impl.accepts_scalar_loader ) @@ -695,14 +721,16 @@ class InstanceState(interfaces.InspectionAttrInfo): return None def _modified_event( - self, dict_, attr, previous, collection=False, is_userland=False): + self, dict_, attr, previous, collection=False, is_userland=False + ): if attr: if not attr.send_modified_events: return if is_userland and attr.key not in dict_: raise sa_exc.InvalidRequestError( "Can't flag attribute '%s' modified; it's not present in " - "the object state" % attr.key) + "the object state" % attr.key + ) if attr.key not in self.committed_state or is_userland: if collection: if previous is NEVER_SET: @@ -718,8 +746,7 @@ class InstanceState(interfaces.InspectionAttrInfo): # assert self._strong_obj is None or self.modified - if (self.session_id and self._strong_obj is None) \ - or not self.modified: + if (self.session_id and self._strong_obj is None) or not self.modified: self.modified = True instance_dict = self._instance_dict() if instance_dict: @@ -737,10 +764,8 @@ class InstanceState(interfaces.InspectionAttrInfo): "Can't emit change event for attribute '%s' - " "parent object of type %s has been garbage " "collected." - % ( - self.manager[attr.key], - base.state_class_str(self) - )) + % (self.manager[attr.key], base.state_class_str(self)) + ) def _commit(self, dict_, keys): """Commit attributes. @@ -758,17 +783,18 @@ class InstanceState(interfaces.InspectionAttrInfo): self.expired = False self.expired_attributes.difference_update( - set(keys).intersection(dict_)) + set(keys).intersection(dict_) + ) # the per-keys commit removes object-level callables, # while that of commit_all does not. it's not clear # if this behavior has a clear rationale, however tests do # ensure this is what it does. if self.callables: - for key in set(self.callables).\ - intersection(keys).\ - intersection(dict_): - del self.callables[key] + for key in ( + set(self.callables).intersection(keys).intersection(dict_) + ): + del self.callables[key] def _commit_all(self, dict_, instance_dict=None): """commit all attributes unconditionally. 
@@ -797,8 +823,8 @@ class InstanceState(interfaces.InspectionAttrInfo): state.committed_state.clear() - if '_pending_mutations' in state_dict: - del state_dict['_pending_mutations'] + if "_pending_mutations" in state_dict: + del state_dict["_pending_mutations"] state.expired_attributes.difference_update(dict_) @@ -848,7 +874,8 @@ class AttributeState(object): """ return self.state.manager[self.key].__get__( - self.state.obj(), self.state.class_) + self.state.obj(), self.state.class_ + ) @property def history(self): @@ -866,8 +893,7 @@ class AttributeState(object): :func:`.attributes.get_history` - underlying function """ - return self.state.get_history(self.key, - PASSIVE_NO_INITIALIZE) + return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE) def load_history(self): """Return the current pre-flush change history for @@ -885,8 +911,7 @@ class AttributeState(object): .. versionadded:: 0.9.0 """ - return self.state.get_history(self.key, - PASSIVE_OFF ^ INIT_OK) + return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK) class PendingCollection(object): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 47791f9b96..5c972b26b8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -13,22 +13,27 @@ from .. import util, log, event from ..sql import util as sql_util, visitors from .. import sql from . import ( - attributes, interfaces, exc as orm_exc, loading, - unitofwork, util as orm_util, query + attributes, + interfaces, + exc as orm_exc, + loading, + unitofwork, + util as orm_util, + query, ) from .state import InstanceState from .util import _none_set, aliased from . import properties -from .interfaces import ( - LoaderStrategy, StrategizedProperty -) +from .interfaces import LoaderStrategy, StrategizedProperty from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE from .session import _state_session import itertools def _register_attribute( - prop, mapper, useobject, + prop, + mapper, + useobject, compare_function=None, typecallable=None, callable_=None, @@ -51,8 +56,8 @@ def _register_attribute( fn, opts = prop.parent.validators[prop.key] listen_hooks.append( lambda desc, prop: orm_util._validator_events( - desc, - prop.key, fn, **opts) + desc, prop.key, fn, **opts + ) ) if useobject: @@ -65,9 +70,7 @@ def _register_attribute( if backref: listen_hooks.append( lambda desc, prop: attributes.backref_listeners( - desc, - backref, - uselist + desc, backref, uselist ) ) @@ -83,8 +86,9 @@ def _register_attribute( # on mappers not already being set up so we have to check each one. for m in mapper.self_and_descendants: - if prop is m._props.get(prop.key) and \ - not m.class_manager._attr_has_impl(prop.key): + if prop is m._props.get( + prop.key + ) and not m.class_manager._attr_has_impl(prop.key): desc = attributes.register_attribute_impl( m.class_, @@ -94,9 +98,11 @@ def _register_attribute( compare_function=compare_function, useobject=useobject, extension=attribute_ext, - trackparent=useobject and ( - prop.single_parent or - prop.direction is interfaces.ONETOMANY), + trackparent=useobject + and ( + prop.single_parent + or prop.direction is interfaces.ONETOMANY + ), typecallable=typecallable, callable_=callable_, active_history=active_history, @@ -118,23 +124,31 @@ class UninstrumentedColumnLoader(LoaderStrategy): if the argument is against the with_polymorphic selectable. 
""" - __slots__ = 'columns', + + __slots__ = ("columns",) def __init__(self, parent, strategy_key): super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key) self.columns = self.parent_property.columns def setup_query( - self, context, entity, path, loadopt, adapter, - column_collection=None, **kwargs): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection=None, + **kwargs + ): for c in self.columns: if adapter: c = adapter.columns[c] column_collection.append(c) def create_row_processor( - self, context, path, loadopt, - mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): pass @@ -143,16 +157,24 @@ class UninstrumentedColumnLoader(LoaderStrategy): class ColumnLoader(LoaderStrategy): """Provide loading behavior for a :class:`.ColumnProperty`.""" - __slots__ = 'columns', 'is_composite' + __slots__ = "columns", "is_composite" def __init__(self, parent, strategy_key): super(ColumnLoader, self).__init__(parent, strategy_key) self.columns = self.parent_property.columns - self.is_composite = hasattr(self.parent_property, 'composite_class') + self.is_composite = hasattr(self.parent_property, "composite_class") def setup_query( - self, context, entity, path, loadopt, - adapter, column_collection, memoized_populators, **kwargs): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kwargs + ): for c in self.columns: if adapter: @@ -168,19 +190,23 @@ class ColumnLoader(LoaderStrategy): self.is_class_level = True coltype = self.columns[0].type # TODO: check all columns ? check for foreign key as well? - active_history = self.parent_property.active_history or \ - self.columns[0].primary_key or \ - mapper.version_id_col in set(self.columns) + active_history = ( + self.parent_property.active_history + or self.columns[0].primary_key + or mapper.version_id_col in set(self.columns) + ) _register_attribute( - self.parent_property, mapper, useobject=False, + self.parent_property, + mapper, + useobject=False, compare_function=coltype.compare_values, - active_history=active_history + active_history=active_history, ) def create_row_processor( - self, context, path, - loadopt, mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): # look through list of columns represented here # to see which, if any, is present in the row. for col in self.columns: @@ -201,8 +227,16 @@ class ExpressionColumnLoader(ColumnLoader): super(ExpressionColumnLoader, self).__init__(parent, strategy_key) def setup_query( - self, context, entity, path, loadopt, - adapter, column_collection, memoized_populators, **kwargs): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kwargs + ): if loadopt and "expression" in loadopt.local_opts: columns = [loadopt.local_opts["expression"]] @@ -218,8 +252,8 @@ class ExpressionColumnLoader(ColumnLoader): memoized_populators[self.parent_property] = fetch def create_row_processor( - self, context, path, - loadopt, mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): # look through list of columns represented here # to see which, if any, is present in the row. 
if loadopt and "expression" in loadopt.local_opts: @@ -239,9 +273,11 @@ class ExpressionColumnLoader(ColumnLoader): self.is_class_level = True _register_attribute( - self.parent_property, mapper, useobject=False, + self.parent_property, + mapper, + useobject=False, compare_function=self.columns[0].type.compare_values, - accepts_scalar_loader=False + accepts_scalar_loader=False, ) @@ -251,27 +287,29 @@ class ExpressionColumnLoader(ColumnLoader): class DeferredColumnLoader(LoaderStrategy): """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" - __slots__ = 'columns', 'group' + __slots__ = "columns", "group" def __init__(self, parent, strategy_key): super(DeferredColumnLoader, self).__init__(parent, strategy_key) - if hasattr(self.parent_property, 'composite_class'): - raise NotImplementedError("Deferred loading for composite " - "types not implemented yet") + if hasattr(self.parent_property, "composite_class"): + raise NotImplementedError( + "Deferred loading for composite " "types not implemented yet" + ) self.columns = self.parent_property.columns self.group = self.parent_property.group def create_row_processor( - self, context, path, loadopt, - mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): # this path currently does not check the result # for the column; this is because in most cases we are # working just with the setup_query() directive which does # not support this, and the behavior here should be consistent. if not self.is_class_level: - set_deferred_for_local_state = \ + set_deferred_for_local_state = ( self.parent_property._deferred_column_loader + ) populators["new"].append((self.key, set_deferred_for_local_state)) else: populators["expire"].append((self.key, False)) @@ -280,41 +318,56 @@ class DeferredColumnLoader(LoaderStrategy): self.is_class_level = True _register_attribute( - self.parent_property, mapper, useobject=False, + self.parent_property, + mapper, + useobject=False, compare_function=self.columns[0].type.compare_values, callable_=self._load_for_state, - expire_missing=False + expire_missing=False, ) def setup_query( - self, context, entity, path, loadopt, - adapter, column_collection, memoized_populators, - only_load_props=None, **kw): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + only_load_props=None, + **kw + ): if ( ( - loadopt and - 'undefer_pks' in loadopt.local_opts and - set(self.columns).intersection( - self.parent._should_undefer_in_wildcard) - ) - or - ( - loadopt and - self.group and - loadopt.local_opts.get('undefer_group_%s' % self.group, False) + loadopt + and "undefer_pks" in loadopt.local_opts + and set(self.columns).intersection( + self.parent._should_undefer_in_wildcard + ) ) - or - ( - only_load_props and self.key in only_load_props + or ( + loadopt + and self.group + and loadopt.local_opts.get( + "undefer_group_%s" % self.group, False + ) ) + or (only_load_props and self.key in only_load_props) ): self.parent_property._get_strategy( (("deferred", False), ("instrument", True)) ).setup_query( - context, entity, - path, loadopt, adapter, - column_collection, memoized_populators, **kw) + context, + entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kw + ) elif self.is_class_level: memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED else: @@ -331,11 +384,11 @@ class DeferredColumnLoader(LoaderStrategy): if self.group: toload = [ - p.key for p in - 
localparent.iterate_properties - if isinstance(p, StrategizedProperty) and - isinstance(p.strategy, DeferredColumnLoader) and - p.group == self.group + p.key + for p in localparent.iterate_properties + if isinstance(p, StrategizedProperty) + and isinstance(p.strategy, DeferredColumnLoader) + and p.group == self.group ] else: toload = [self.key] @@ -347,14 +400,17 @@ class DeferredColumnLoader(LoaderStrategy): if session is None: raise orm_exc.DetachedInstanceError( "Parent instance %s is not bound to a Session; " - "deferred load operation of attribute '%s' cannot proceed" % - (orm_util.state_str(state), self.key) + "deferred load operation of attribute '%s' cannot proceed" + % (orm_util.state_str(state), self.key) ) query = session.query(localparent) - if loading.load_on_ident( - query, state.key, - only_load_props=group, refresh_state=state) is None: + if ( + loading.load_on_ident( + query, state.key, only_load_props=group, refresh_state=state + ) + is None + ): raise orm_exc.ObjectDeletedError(state) return attributes.ATTR_WAS_SET @@ -378,7 +434,7 @@ class LoadDeferredColumns(object): class AbstractRelationshipLoader(LoaderStrategy): """LoaderStratgies which deal with related objects.""" - __slots__ = 'mapper', 'target', 'uselist' + __slots__ = "mapper", "target", "uselist" def __init__(self, parent, strategy_key): super(AbstractRelationshipLoader, self).__init__(parent, strategy_key) @@ -414,19 +470,21 @@ class NoLoader(AbstractRelationshipLoader): self.is_class_level = True _register_attribute( - self.parent_property, mapper, + self.parent_property, + mapper, useobject=True, typecallable=self.parent_property.collection_class, ) def create_row_processor( - self, context, path, loadopt, mapper, - result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): def invoke_no_load(state, dict_, row): if self.uselist: state.manager.get_impl(self.key).initialize(state, dict_) else: dict_[self.key] = None + populators["new"].append((self.key, invoke_no_load)) @@ -443,10 +501,18 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): """ __slots__ = ( - '_lazywhere', '_rev_lazywhere', 'use_get', '_bind_to_col', - '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns', - '_simple_lazy_clause', '_raise_always', '_raise_on_sql', - '_bakery') + "_lazywhere", + "_rev_lazywhere", + "use_get", + "_bind_to_col", + "_equated_columns", + "_rev_bind_to_col", + "_rev_equated_columns", + "_simple_lazy_clause", + "_raise_always", + "_raise_on_sql", + "_bakery", + ) def __init__(self, parent, strategy_key): super(LazyLoader, self).__init__(parent, strategy_key) @@ -454,25 +520,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" join_condition = self.parent_property._join_condition - self._lazywhere, \ - self._bind_to_col, \ - self._equated_columns = join_condition.create_lazy_clause() + self._lazywhere, self._bind_to_col, self._equated_columns = ( + join_condition.create_lazy_clause() + ) - self._rev_lazywhere, \ - self._rev_bind_to_col, \ - self._rev_equated_columns = join_condition.create_lazy_clause( - reverse_direction=True) + self._rev_lazywhere, self._rev_bind_to_col, self._rev_equated_columns = join_condition.create_lazy_clause( + reverse_direction=True + ) self.logger.info("%s lazy loading clause %s", self, self._lazywhere) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. 
then we can just use mapper.get() - self.use_get = not self.uselist and \ - self.mapper._get_clause[0].compare( - self._lazywhere, - use_proxies=True, - equivalents=self.mapper._equivalent_columns - ) + self.use_get = not self.uselist and self.mapper._get_clause[0].compare( + self._lazywhere, + use_proxies=True, + equivalents=self.mapper._equivalent_columns, + ) if self.use_get: for col in list(self._equated_columns): @@ -480,16 +544,17 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): for c in self.mapper._equivalent_columns[col]: self._equated_columns[c] = self._equated_columns[col] - self.logger.info("%s will use query.get() to " - "optimize instance loads", self) + self.logger.info( + "%s will use query.get() to " "optimize instance loads", self + ) def init_class_attribute(self, mapper): self.is_class_level = True active_history = ( - self.parent_property.active_history or - self.parent_property.direction is not interfaces.MANYTOONE or - not self.use_get + self.parent_property.active_history + or self.parent_property.direction is not interfaces.MANYTOONE + or not self.use_get ) # MANYTOONE currently only needs the @@ -504,28 +569,29 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): useobject=True, callable_=self._load_for_state, typecallable=self.parent_property.collection_class, - active_history=active_history + active_history=active_history, ) def _memoized_attr__simple_lazy_clause(self): - criterion, bind_to_col = ( - self._lazywhere, - self._bind_to_col - ) + criterion, bind_to_col = (self._lazywhere, self._bind_to_col) params = [] def visit_bindparam(bindparam): bindparam.unique = False if bindparam._identifying_key in bind_to_col: - params.append(( - bindparam.key, bind_to_col[bindparam._identifying_key], - None)) + params.append( + ( + bindparam.key, + bind_to_col[bindparam._identifying_key], + None, + ) + ) elif bindparam.callable is None: params.append((bindparam.key, None, bindparam.value)) criterion = visitors.cloned_traverse( - criterion, {}, {'bindparam': visit_bindparam} + criterion, {}, {"bindparam": visit_bindparam} ) return criterion, params @@ -535,7 +601,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): if state is None: return sql_util.adapt_criterion_to_null( - criterion, [key for key, ident, value in param_keys]) + criterion, [key for key, ident, value in param_keys] + ) mapper = self.parent_property.parent @@ -550,10 +617,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): if ident is not None: if passive and passive & attributes.LOAD_AGAINST_COMMITTED: value = mapper._get_committed_state_attr_by_column( - state, dict_, ident, passive) + state, dict_, ident, passive + ) else: value = mapper._get_state_attr_by_column( - state, dict_, ident, passive) + state, dict_, ident, passive + ) params[key] = value @@ -567,21 +636,19 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): def _load_for_state(self, state, passive): if not state.key and ( - ( - not self.parent_property.load_on_pending - and not state._load_pending - ) - or not state.session_id + ( + not self.parent_property.load_on_pending + and not state._load_pending + ) + or not state.session_id ): return attributes.ATTR_EMPTY pending = not state.key primary_key_identity = None - if ( - (not passive & attributes.SQL_OK and not self.use_get) - or - (not passive & attributes.NON_PERSISTENT_OK and pending) + if (not passive & attributes.SQL_OK and not self.use_get) or ( + not passive & attributes.NON_PERSISTENT_OK and 
pending ): return attributes.PASSIVE_NO_RESULT @@ -595,17 +662,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): raise orm_exc.DetachedInstanceError( "Parent instance %s is not bound to a Session; " - "lazy load operation of attribute '%s' cannot proceed" % - (orm_util.state_str(state), self.key) + "lazy load operation of attribute '%s' cannot proceed" + % (orm_util.state_str(state), self.key) ) # if we have a simple primary key load, check the # identity map without generating a Query at all if self.use_get: primary_key_identity = self._get_ident_for_use_get( - session, - state, - passive + session, state, passive ) if attributes.PASSIVE_NO_RESULT in primary_key_identity: return attributes.PASSIVE_NO_RESULT @@ -620,18 +685,23 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): # does this, including how it decides what the correct # identity_token would be for this identity. instance = session.query()._identity_lookup( - self.mapper, primary_key_identity, passive=passive, - lazy_loaded_from=state + self.mapper, + primary_key_identity, + passive=passive, + lazy_loaded_from=state, ) if instance is not None: return instance - elif not passive & attributes.SQL_OK or \ - not passive & attributes.RELATED_OBJECT_OK: + elif ( + not passive & attributes.SQL_OK + or not passive & attributes.RELATED_OBJECT_OK + ): return attributes.PASSIVE_NO_RESULT return self._emit_lazyload( - session, state, primary_key_identity, passive) + session, state, primary_key_identity, passive + ) def _get_ident_for_use_get(self, session, state, passive): instance_mapper = state.manager.mapper @@ -644,11 +714,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): dict_ = state.dict return [ - get_attr( - state, - dict_, - self._equated_columns[pk], - passive=passive) + get_attr(state, dict_, self._equated_columns[pk], passive=passive) for pk in self.mapper.primary_key ] @@ -656,11 +722,10 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): def _memoized_attr__bakery(self, baked): return baked.bakery(size=50) - @util.dependencies( - "sqlalchemy.orm.strategy_options") + @util.dependencies("sqlalchemy.orm.strategy_options") def _emit_lazyload( - self, strategy_options, session, state, - primary_key_identity, passive): + self, strategy_options, session, state, primary_key_identity, passive + ): # emit lazy load now using BakedQuery, to cut way down on the overhead # of generating queries. # there are two big things we are trying to guard against here: @@ -688,15 +753,18 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): q.add_criteria( lambda q: q._adapt_all_clauses()._with_invoke_all_eagers(False), - self.parent_property) + self.parent_property, + ) if not self.parent_property.bake_queries: q.spoil(full=True) if self.parent_property.secondary is not None: q.add_criteria( - lambda q: - q.select_from(self.mapper, self.parent_property.secondary)) + lambda q: q.select_from( + self.mapper, self.parent_property.secondary + ) + ) pending = not state.key @@ -712,35 +780,38 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): # is usually a throwaway object. 
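For context: _emit_lazyload above builds its SELECT through the sqlalchemy.ext.baked extension, so the query-construction lambdas are cached (the bakery of size 50 is memoized on the strategy earlier in this class). A rough standalone sketch of that extension under the 1.3-era API; the User model and the surrounding setup are invented for illustration:

    from sqlalchemy import Column, Integer, String, bindparam, create_engine
    from sqlalchemy.ext import baked
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    bakery = baked.bakery(size=50)

    def get_user(session, ident):
        # the lambdas run only on a cache miss; afterwards the cached
        # query is reused and only the bound parameter changes
        bq = bakery(lambda s: s.query(User))
        bq += lambda q: q.filter(User.id == bindparam("ident"))
        return bq(session).params(ident=ident).first()

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(User(id=1, name="someuser"))
    session.commit()
    assert get_user(session, 1).name == "someuser"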
effective_path = state.load_path[self.parent_property] - q._add_lazyload_options( - state.load_options, effective_path - ) + q._add_lazyload_options(state.load_options, effective_path) if self.use_get: if self._raise_on_sql: self._invoke_raise_load(state, passive, "raise_on_sql") - return q(session).\ - with_post_criteria(lambda q: q._set_lazyload_from(state)).\ - _load_on_pk_identity( - session.query(self.mapper), - primary_key_identity) + return ( + q(session) + .with_post_criteria(lambda q: q._set_lazyload_from(state)) + ._load_on_pk_identity( + session.query(self.mapper), primary_key_identity + ) + ) if self.parent_property.order_by: q.add_criteria( - lambda q: - q.order_by(*util.to_list(self.parent_property.order_by))) + lambda q: q.order_by( + *util.to_list(self.parent_property.order_by) + ) + ) for rev in self.parent_property._reverse_property: # reverse props that are MANYTOONE are loading *this* # object from get(), so don't need to eager out to those. - if rev.direction is interfaces.MANYTOONE and \ - rev._use_get and \ - not isinstance(rev.strategy, LazyLoader): + if ( + rev.direction is interfaces.MANYTOONE + and rev._use_get + and not isinstance(rev.strategy, LazyLoader) + ): q.add_criteria( - lambda q: - q.options( + lambda q: q.options( strategy_options.Load.for_existing_path( q._current_path[rev.parent] ).lazyload(rev.key) @@ -750,8 +821,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): lazy_clause, params = self._generate_lazy_clause(state, passive) if pending: - if util.has_intersection( - orm_util._none_set, params.values()): + if util.has_intersection(orm_util._none_set, params.values()): return None elif util.has_intersection(orm_util._never_set, params.values()): @@ -769,9 +839,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): q._params = params return q - result = q(session).\ - with_post_criteria(lambda q: q._set_lazyload_from(state)).\ - with_post_criteria(set_default_params).all() + result = ( + q(session) + .with_post_criteria(lambda q: q._set_lazyload_from(state)) + .with_post_criteria(set_default_params) + .all() + ) if self.uselist: return result else: @@ -781,15 +854,16 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): util.warn( "Multiple rows returned with " "uselist=False for lazily-loaded attribute '%s' " - % self.parent_property) + % self.parent_property + ) return result[0] else: return None def create_row_processor( - self, context, path, loadopt, - mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): key = self.key if not self.is_class_level: @@ -802,11 +876,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): # attribute - "eager" attributes always have a # class-level lazyloader installed. 
set_lazy_callable = InstanceState._instance_level_callable_processor( - mapper.class_manager, - LoadLazyAttribute(key, self), key) + mapper.class_manager, LoadLazyAttribute(key, self), key + ) populators["new"].append((self.key, set_lazy_callable)) elif context.populate_existing or mapper.always_refresh: + def reset_for_lazy_callable(state, dict_, row): # we are the primary manager for this attribute on # this class - reset its @@ -842,19 +917,26 @@ class ImmediateLoader(AbstractRelationshipLoader): __slots__ = () def init_class_attribute(self, mapper): - self.parent_property.\ - _get_strategy((("lazy", "select"),)).\ - init_class_attribute(mapper) + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) def setup_query( - self, context, entity, - path, loadopt, adapter, column_collection=None, - parentmapper=None, **kwargs): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection=None, + parentmapper=None, + **kwargs + ): pass def create_row_processor( - self, context, path, loadopt, - mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): def load_immediate(state, dict_, row): state.get_impl(self.key).get(state, dict_) @@ -864,22 +946,28 @@ class ImmediateLoader(AbstractRelationshipLoader): @log.class_logger @properties.RelationshipProperty.strategy_for(lazy="subquery") class SubqueryLoader(AbstractRelationshipLoader): - __slots__ = 'join_depth', + __slots__ = ("join_depth",) def __init__(self, parent, strategy_key): super(SubqueryLoader, self).__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth def init_class_attribute(self, mapper): - self.parent_property.\ - _get_strategy((("lazy", "select"),)).\ - init_class_attribute(mapper) + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) def setup_query( - self, context, entity, - path, loadopt, adapter, - column_collection=None, - parentmapper=None, **kwargs): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection=None, + parentmapper=None, + **kwargs + ): if not context.query._enable_eagerloads: return @@ -891,16 +979,16 @@ class SubqueryLoader(AbstractRelationshipLoader): # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. 
with_poly_info = path.get( - context.attributes, - "path_with_polymorphic", None) + context.attributes, "path_with_polymorphic", None + ) if with_poly_info is not None: effective_entity = with_poly_info.entity else: effective_entity = self.mapper subq_path = context.attributes.get( - ('subquery_path', None), - orm_util.PathRegistry.root) + ("subquery_path", None), orm_util.PathRegistry.root + ) subq_path = subq_path + path @@ -909,27 +997,33 @@ class SubqueryLoader(AbstractRelationshipLoader): if not path.contains(context.attributes, "loader"): if self.join_depth: if ( - (context.query._current_path.length - if context.query._current_path else 0) + - path.length + ( + context.query._current_path.length + if context.query._current_path + else 0 + ) + + path.length ) / 2 > self.join_depth: return elif subq_path.contains_mapper(self.mapper): return - leftmost_mapper, leftmost_attr, leftmost_relationship = \ - self._get_leftmost(subq_path) + leftmost_mapper, leftmost_attr, leftmost_relationship = self._get_leftmost( + subq_path + ) orig_query = context.attributes.get( - ("orig_query", SubqueryLoader), - context.query) + ("orig_query", SubqueryLoader), context.query + ) # generate a new Query from the original, then # produce a subquery from it. left_alias = self._generate_from_original_query( - orig_query, leftmost_mapper, - leftmost_attr, leftmost_relationship, - entity.entity_zero + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + entity.entity_zero, ) # generate another Query that will join the @@ -940,17 +1034,18 @@ class SubqueryLoader(AbstractRelationshipLoader): q = orig_query.session.query(effective_entity) q._attributes = { ("orig_query", SubqueryLoader): orig_query, - ('subquery_path', None): subq_path + ("subquery_path", None): subq_path, } q = q._set_enable_single_crit(False) - to_join, local_attr, parent_alias = \ - self._prep_for_joins(left_alias, subq_path) + to_join, local_attr, parent_alias = self._prep_for_joins( + left_alias, subq_path + ) q = q.order_by(*local_attr) q = q.add_columns(*local_attr) q = self._apply_joins( - q, to_join, left_alias, - parent_alias, effective_entity) + q, to_join, left_alias, parent_alias, effective_entity + ) q = self._setup_options(q, subq_path, orig_query, effective_entity) q = self._setup_outermost_orderby(q) @@ -964,21 +1059,20 @@ class SubqueryLoader(AbstractRelationshipLoader): subq_mapper = orm_util._class_to_mapper(subq_path[0]) # determine attributes of the leftmost mapper - if self.parent.isa(subq_mapper) and \ - self.parent_property is subq_path[1]: - leftmost_mapper, leftmost_prop = \ - self.parent, self.parent_property + if ( + self.parent.isa(subq_mapper) + and self.parent_property is subq_path[1] + ): + leftmost_mapper, leftmost_prop = self.parent, self.parent_property else: - leftmost_mapper, leftmost_prop = \ - subq_mapper, \ - subq_path[1] + leftmost_mapper, leftmost_prop = subq_mapper, subq_path[1] leftmost_cols = leftmost_prop.local_columns leftmost_attr = [ getattr( - subq_path[0].entity, - leftmost_mapper._columntoproperty[c].key) + subq_path[0].entity, leftmost_mapper._columntoproperty[c].key + ) for c in leftmost_cols ] @@ -986,8 +1080,11 @@ class SubqueryLoader(AbstractRelationshipLoader): def _generate_from_original_query( self, - orig_query, leftmost_mapper, - leftmost_attr, leftmost_relationship, orig_entity + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + orig_entity, ): # reformat the original query # to look only for significant columns @@ -999,11 +1096,16 @@ 
class SubqueryLoader(AbstractRelationshipLoader): # all entities mentioned in things like WHERE, JOIN, etc. if not q._from_obj: q._set_select_from( - list(set([ - ent['entity'] for ent in orig_query.column_descriptions - if ent['entity'] is not None - ])), - False + list( + set( + [ + ent["entity"] + for ent in orig_query.column_descriptions + if ent["entity"] is not None + ] + ) + ), + False, ) # select from the identity columns of the outer (specifically, these @@ -1037,8 +1139,8 @@ class SubqueryLoader(AbstractRelationshipLoader): embed_q = q.with_labels().subquery() left_alias = orm_util.AliasedClass( - leftmost_mapper, embed_q, - use_mapper_path=True) + leftmost_mapper, embed_q, use_mapper_path=True + ) return left_alias def _prep_for_joins(self, left_alias, subq_path): @@ -1077,8 +1179,8 @@ class SubqueryLoader(AbstractRelationshipLoader): # alias a plain mapper as we may be # joining multiple times parent_alias = orm_util.AliasedClass( - info.entity, - use_mapper_path=True) + info.entity, use_mapper_path=True + ) local_cols = self.parent_property.local_columns @@ -1089,8 +1191,8 @@ class SubqueryLoader(AbstractRelationshipLoader): return to_join, local_attr, parent_alias def _apply_joins( - self, q, to_join, left_alias, parent_alias, - effective_entity): + self, q, to_join, left_alias, parent_alias, effective_entity + ): ltj = len(to_join) if ltj == 1: @@ -1100,7 +1202,9 @@ class SubqueryLoader(AbstractRelationshipLoader): elif ltj == 2: to_join = [ getattr(left_alias, to_join[0][1]).of_type(parent_alias), - getattr(parent_alias, to_join[-1][1]).of_type(effective_entity) + getattr(parent_alias, to_join[-1][1]).of_type( + effective_entity + ), ] elif ltj > 2: middle = [ @@ -1108,8 +1212,9 @@ class SubqueryLoader(AbstractRelationshipLoader): orm_util.AliasedClass(item[0]) if not inspect(item[0]).is_aliased_class else item[0].entity, - item[1] - ) for item in to_join[1:-1] + item[1], + ) + for item in to_join[1:-1] ] inner = [] @@ -1123,11 +1228,15 @@ class SubqueryLoader(AbstractRelationshipLoader): inner.append(attr) - to_join = [ - getattr(left_alias, to_join[0][1]).of_type(inner[0].parent) - ] + inner + [ - getattr(parent_alias, to_join[-1][1]).of_type(effective_entity) - ] + to_join = ( + [getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)] + + inner + + [ + getattr(parent_alias, to_join[-1][1]).of_type( + effective_entity + ) + ] + ) for attr in to_join: q = q.join(attr, from_joinpoint=True) @@ -1151,13 +1260,9 @@ class SubqueryLoader(AbstractRelationshipLoader): # this really only picks up the "secondary" table # right now. eagerjoin = q._from_obj[0] - eager_order_by = \ - eagerjoin._target_adapter.\ - copy_and_process( - util.to_list( - self.parent_property.order_by - ) - ) + eager_order_by = eagerjoin._target_adapter.copy_and_process( + util.to_list(self.parent_property.order_by) + ) q = q.order_by(*eager_order_by) return q @@ -1167,6 +1272,7 @@ class SubqueryLoader(AbstractRelationshipLoader): first moment a value is needed. 
""" + _data = None def __init__(self, subq): @@ -1180,10 +1286,7 @@ class SubqueryLoader(AbstractRelationshipLoader): def _load(self): self._data = dict( (k, [vv[0] for vv in v]) - for k, v in itertools.groupby( - self.subq, - lambda x: x[1:] - ) + for k, v in itertools.groupby(self.subq, lambda x: x[1:]) ) def loader(self, state, dict_, row): @@ -1191,17 +1294,17 @@ class SubqueryLoader(AbstractRelationshipLoader): self._load() def create_row_processor( - self, context, path, loadopt, - mapper, result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " - "population - eager loading cannot be applied." % - self) + "population - eager loading cannot be applied." % self + ) path = path[self.parent_property] - subq = path.get(context.attributes, 'subquery') + subq = path.get(context.attributes, "subquery") if subq is None: return @@ -1220,65 +1323,67 @@ class SubqueryLoader(AbstractRelationshipLoader): collections = path.get(context.attributes, "collections") if collections is None: collections = self._SubqCollections(subq) - path.set(context.attributes, 'collections', collections) + path.set(context.attributes, "collections", collections) if adapter: local_cols = [adapter.columns[c] for c in local_cols] if self.uselist: self._create_collection_loader( - context, collections, local_cols, populators) + context, collections, local_cols, populators + ) else: self._create_scalar_loader( - context, collections, local_cols, populators) + context, collections, local_cols, populators + ) def _create_collection_loader( - self, context, collections, local_cols, populators): + self, context, collections, local_cols, populators + ): def load_collection_from_subq(state, dict_, row): collection = collections.get( - tuple([row[col] for col in local_cols]), - () + tuple([row[col] for col in local_cols]), () + ) + state.get_impl(self.key).set_committed_value( + state, dict_, collection ) - state.get_impl(self.key).\ - set_committed_value(state, dict_, collection) def load_collection_from_subq_existing_row(state, dict_, row): if self.key not in dict_: load_collection_from_subq(state, dict_, row) - populators["new"].append( - (self.key, load_collection_from_subq)) + populators["new"].append((self.key, load_collection_from_subq)) populators["existing"].append( - (self.key, load_collection_from_subq_existing_row)) + (self.key, load_collection_from_subq_existing_row) + ) if context.invoke_all_eagers: populators["eager"].append((self.key, collections.loader)) def _create_scalar_loader( - self, context, collections, local_cols, populators): + self, context, collections, local_cols, populators + ): def load_scalar_from_subq(state, dict_, row): collection = collections.get( - tuple([row[col] for col in local_cols]), - (None,) + tuple([row[col] for col in local_cols]), (None,) ) if len(collection) > 1: util.warn( "Multiple rows returned with " - "uselist=False for eagerly-loaded attribute '%s' " - % self) + "uselist=False for eagerly-loaded attribute '%s' " % self + ) scalar = collection[0] - state.get_impl(self.key).\ - set_committed_value(state, dict_, scalar) + state.get_impl(self.key).set_committed_value(state, dict_, scalar) def load_scalar_from_subq_existing_row(state, dict_, row): if self.key not in dict_: load_scalar_from_subq(state, dict_, row) - populators["new"].append( - (self.key, load_scalar_from_subq)) + 
populators["new"].append((self.key, load_scalar_from_subq)) populators["existing"].append( - (self.key, load_scalar_from_subq_existing_row)) + (self.key, load_scalar_from_subq_existing_row) + ) if context.invoke_all_eagers: populators["eager"].append((self.key, collections.loader)) @@ -1292,7 +1397,7 @@ class JoinedLoader(AbstractRelationshipLoader): """ - __slots__ = 'join_depth', '_aliased_class_pool' + __slots__ = "join_depth", "_aliased_class_pool" def __init__(self, parent, strategy_key): super(JoinedLoader, self).__init__(parent, strategy_key) @@ -1300,14 +1405,22 @@ class JoinedLoader(AbstractRelationshipLoader): self._aliased_class_pool = [] def init_class_attribute(self, mapper): - self.parent_property.\ - _get_strategy((("lazy", "select"),)).init_class_attribute(mapper) + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) def setup_query( - self, context, entity, path, loadopt, adapter, - column_collection=None, parentmapper=None, - chained_from_outerjoin=False, - **kwargs): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection=None, + parentmapper=None, + chained_from_outerjoin=False, + **kwargs + ): """Add a left outer join to the statement that's being constructed.""" if not context.query._enable_eagerloads: @@ -1319,15 +1432,16 @@ class JoinedLoader(AbstractRelationshipLoader): with_polymorphic = None - user_defined_adapter = self._init_user_defined_eager_proc( - loadopt, context) if loadopt else False + user_defined_adapter = ( + self._init_user_defined_eager_proc(loadopt, context) + if loadopt + else False + ) if user_defined_adapter is not False: - clauses, adapter, add_to_collection = \ - self._setup_query_on_user_defined_adapter( - context, entity, path, adapter, - user_defined_adapter - ) + clauses, adapter, add_to_collection = self._setup_query_on_user_defined_adapter( + context, entity, path, adapter, user_defined_adapter + ) else: # if not via query option, check for # a cycle @@ -1338,16 +1452,19 @@ class JoinedLoader(AbstractRelationshipLoader): elif path.contains_mapper(self.mapper): return - clauses, adapter, add_to_collection, chained_from_outerjoin = \ - self._generate_row_adapter( - context, entity, path, loadopt, adapter, - column_collection, parentmapper, chained_from_outerjoin - ) + clauses, adapter, add_to_collection, chained_from_outerjoin = self._generate_row_adapter( + context, + entity, + path, + loadopt, + adapter, + column_collection, + parentmapper, + chained_from_outerjoin, + ) with_poly_info = path.get( - context.attributes, - "path_with_polymorphic", - None + context.attributes, "path_with_polymorphic", None ) if with_poly_info is not None: with_polymorphic = with_poly_info.with_polymorphic_mappers @@ -1357,14 +1474,20 @@ class JoinedLoader(AbstractRelationshipLoader): path = path[self.mapper] loading._setup_entity_query( - context, self.mapper, entity, - path, clauses, add_to_collection, + context, + self.mapper, + entity, + path, + clauses, + add_to_collection, with_polymorphic=with_polymorphic, parentmapper=self.mapper, - chained_from_outerjoin=chained_from_outerjoin) + chained_from_outerjoin=chained_from_outerjoin, + ) - if with_poly_info is not None and \ - None in set(context.secondary_columns): + if with_poly_info is not None and None in set( + context.secondary_columns + ): raise sa_exc.InvalidRequestError( "Detected unaliased columns when generating joined " "load. 
Make sure to use aliased=True or flat=True " @@ -1383,8 +1506,8 @@ class JoinedLoader(AbstractRelationshipLoader): # the option applies. check if the "user_defined_eager_row_processor" # has been built up. adapter = path.get( - context.attributes, - "user_defined_eager_row_processor", False) + context.attributes, "user_defined_eager_row_processor", False + ) if adapter is not False: # just return it return adapter @@ -1394,38 +1517,39 @@ class JoinedLoader(AbstractRelationshipLoader): root_mapper, prop = path[-2:] - #from .mapper import Mapper - #from .interfaces import MapperProperty - #assert isinstance(root_mapper, Mapper) - #assert isinstance(prop, MapperProperty) + # from .mapper import Mapper + # from .interfaces import MapperProperty + # assert isinstance(root_mapper, Mapper) + # assert isinstance(prop, MapperProperty) if alias is not None: if isinstance(alias, str): alias = prop.target.alias(alias) adapter = sql_util.ColumnAdapter( - alias, - equivalents=prop.mapper._equivalent_columns) + alias, equivalents=prop.mapper._equivalent_columns + ) else: if path.contains(context.attributes, "path_with_polymorphic"): with_poly_info = path.get( - context.attributes, - "path_with_polymorphic") + context.attributes, "path_with_polymorphic" + ) adapter = orm_util.ORMAdapter( with_poly_info.entity, - equivalents=prop.mapper._equivalent_columns) + equivalents=prop.mapper._equivalent_columns, + ) else: adapter = context.query._polymorphic_adapters.get( - prop.mapper, None) + prop.mapper, None + ) path.set( - context.attributes, - "user_defined_eager_row_processor", - adapter) + context.attributes, "user_defined_eager_row_processor", adapter + ) return adapter def _setup_query_on_user_defined_adapter( - self, context, entity, - path, adapter, user_defined_adapter): + self, context, entity, path, adapter, user_defined_adapter + ): # apply some more wrapping to the "user defined adapter" # if we are setting up the query for SQL render. @@ -1434,13 +1558,17 @@ class JoinedLoader(AbstractRelationshipLoader): if adapter and user_defined_adapter: user_defined_adapter = user_defined_adapter.wrap(adapter) path.set( - context.attributes, "user_defined_eager_row_processor", - user_defined_adapter) + context.attributes, + "user_defined_eager_row_processor", + user_defined_adapter, + ) elif adapter: user_defined_adapter = adapter path.set( - context.attributes, "user_defined_eager_row_processor", - user_defined_adapter) + context.attributes, + "user_defined_eager_row_processor", + user_defined_adapter, + ) add_to_collection = context.primary_columns return user_defined_adapter, adapter, add_to_collection @@ -1450,7 +1578,7 @@ class JoinedLoader(AbstractRelationshipLoader): # we need one unique AliasedClass per query per appearance of our # entity in the query. - key = ('joinedloader_ac', self) + key = ("joinedloader_ac", self) if key not in context.attributes: context.attributes[key] = idx = 0 else: @@ -1458,9 +1586,8 @@ class JoinedLoader(AbstractRelationshipLoader): if idx >= len(self._aliased_class_pool): to_adapt = orm_util.AliasedClass( - self.mapper, - flat=True, - use_mapper_path=True) + self.mapper, flat=True, use_mapper_path=True + ) # load up the .columns collection on the Alias() before # the object becomes shared among threads. this prevents # races for column identities. 
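The hunk that follows shows black's treatment of signatures that overflow 79 columns: every parameter moves to its own line with a trailing comma after the last named one, while a bare **kwargs stays uncommaed (a trailing comma there is not valid Python 2 syntax, which this codebase still supported). A small sketch with invented names:

    def setup_query(
        context,
        entity,
        path,
        loadopt,
        adapter,
        column_collection=None,
        **kwargs
    ):
        return column_collection, kwargs

    cols, extra = setup_query(None, None, None, None, None, memoize=True)
    assert cols is None and extra == {"memoize": True}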
@@ -1471,13 +1598,18 @@ class JoinedLoader(AbstractRelationshipLoader): return self._aliased_class_pool[idx] def _generate_row_adapter( - self, - context, entity, path, loadopt, adapter, - column_collection, parentmapper, chained_from_outerjoin): + self, + context, + entity, + path, + loadopt, + adapter, + column_collection, + parentmapper, + chained_from_outerjoin, + ): with_poly_info = path.get( - context.attributes, - "path_with_polymorphic", - None + context.attributes, "path_with_polymorphic", None ) if with_poly_info: to_adapt = with_poly_info.entity @@ -1489,8 +1621,9 @@ class JoinedLoader(AbstractRelationshipLoader): orm_util.ORMAdapter, to_adapt, equivalents=self.mapper._equivalent_columns, - adapt_required=True, allow_label_resolve=False, - anonymize_labels=True + adapt_required=True, + allow_label_resolve=False, + anonymize_labels=True, ) assert clauses.aliased_class is not None @@ -1499,8 +1632,7 @@ class JoinedLoader(AbstractRelationshipLoader): context.multi_row_eager_loaders = True innerjoin = ( - loadopt.local_opts.get( - 'innerjoin', self.parent_property.innerjoin) + loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin) if loadopt is not None else self.parent_property.innerjoin ) @@ -1512,9 +1644,15 @@ class JoinedLoader(AbstractRelationshipLoader): context.create_eager_joins.append( ( - self._create_eager_join, context, - entity, path, adapter, - parentmapper, clauses, innerjoin, chained_from_outerjoin + self._create_eager_join, + context, + entity, + path, + adapter, + parentmapper, + clauses, + innerjoin, + chained_from_outerjoin, ) ) @@ -1524,9 +1662,16 @@ class JoinedLoader(AbstractRelationshipLoader): return clauses, adapter, add_to_collection, chained_from_outerjoin def _create_eager_join( - self, context, entity, - path, adapter, parentmapper, - clauses, innerjoin, chained_from_outerjoin): + self, + context, + entity, + path, + adapter, + parentmapper, + clauses, + innerjoin, + chained_from_outerjoin, + ): if parentmapper is None: localparent = entity.mapper @@ -1536,16 +1681,21 @@ class JoinedLoader(AbstractRelationshipLoader): # whether or not the Query will wrap the selectable in a subquery, # and then attach eager load joins to that (i.e., in the case of # LIMIT/OFFSET etc.) - should_nest_selectable = context.multi_row_eager_loaders and \ - context.query._should_nest_selectable + should_nest_selectable = ( + context.multi_row_eager_loaders + and context.query._should_nest_selectable + ) entity_key = None - if entity not in context.eager_joins and \ - not should_nest_selectable and \ - context.from_clause: + if ( + entity not in context.eager_joins + and not should_nest_selectable + and context.from_clause + ): indexes = sql_util.find_left_clause_that_matches_given( - context.from_clause, entity.selectable) + context.from_clause, entity.selectable + ) if len(indexes) > 1: # for the eager load case, I can't reproduce this right @@ -1553,7 +1703,8 @@ class JoinedLoader(AbstractRelationshipLoader): raise sa_exc.InvalidRequestError( "Can't identify which entity in which to joined eager " "load from. Please use an exact match when specifying " - "the join path.") + "the join path." + ) if indexes: clause = context.from_clause[indexes[0]] @@ -1569,29 +1720,27 @@ class JoinedLoader(AbstractRelationshipLoader): towrap = context.eager_joins.setdefault(entity_key, default_towrap) if adapter: - if getattr(adapter, 'aliased_class', None): + if getattr(adapter, "aliased_class", None): # joining from an adapted entity. 
The adapted entity # might be a "with_polymorphic", so resolve that to our # specific mapper's entity before looking for our attribute # name on it. - efm = inspect(adapter.aliased_class).\ - _entity_for_mapper( - localparent - if localparent.isa(self.parent) else self.parent) + efm = inspect(adapter.aliased_class)._entity_for_mapper( + localparent + if localparent.isa(self.parent) + else self.parent + ) # look for our attribute on the adapted entity, else fall back # to our straight property - onclause = getattr( - efm.entity, self.key, - self.parent_property) + onclause = getattr(efm.entity, self.key, self.parent_property) else: onclause = getattr( orm_util.AliasedClass( - self.parent, - adapter.selectable, - use_mapper_path=True + self.parent, adapter.selectable, use_mapper_path=True ), - self.key, self.parent_property + self.key, + self.parent_property, ) else: @@ -1600,9 +1749,10 @@ class JoinedLoader(AbstractRelationshipLoader): assert clauses.aliased_class is not None attach_on_outside = ( - not chained_from_outerjoin or - not innerjoin or innerjoin == 'unnested' or - entity.entity_zero.represents_outer_join + not chained_from_outerjoin + or not innerjoin + or innerjoin == "unnested" + or entity.entity_zero.represents_outer_join ) if attach_on_outside: @@ -1611,16 +1761,17 @@ class JoinedLoader(AbstractRelationshipLoader): towrap, clauses.aliased_class, onclause, - isouter=not innerjoin or - entity.entity_zero.represents_outer_join or - ( - chained_from_outerjoin and isinstance(towrap, sql.Join) - ), _left_memo=self.parent, _right_memo=self.mapper + isouter=not innerjoin + or entity.entity_zero.represents_outer_join + or (chained_from_outerjoin and isinstance(towrap, sql.Join)), + _left_memo=self.parent, + _right_memo=self.mapper, ) else: # all other cases are innerjoin=='nested' approach eagerjoin = self._splice_nested_inner_join( - path, towrap, clauses, onclause) + path, towrap, clauses, onclause + ) context.eager_joins[entity_key] = eagerjoin @@ -1636,22 +1787,21 @@ class JoinedLoader(AbstractRelationshipLoader): # This has the effect # of "undefering" those columns. 
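The innerjoin branching above is driven either by the relationship() setting or by a per-query loader option; under the 1.3-era API the option looks roughly like the following (models and data invented):

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, joinedload, relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = "parent"
        id = Column(Integer, primary_key=True)
        children = relationship("Child")

    class Child(Base):
        __tablename__ = "child"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("parent.id"))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    # innerjoin=True asks JoinedLoader for a plain JOIN instead of the
    # default LEFT OUTER JOIN; "nested"/"unnested" tune how chained
    # eager joins are spliced, as in _splice_nested_inner_join below.
    q = session.query(Parent).options(
        joinedload(Parent.children, innerjoin=True)
    )
    print(q)  # the rendered SQL contains an inner JOIN to child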
for col in sql_util._find_columns( - self.parent_property.primaryjoin): + self.parent_property.primaryjoin + ): if localparent.mapped_table.c.contains_column(col): if adapter: col = adapter.columns[col] context.primary_columns.append(col) if self.parent_property.order_by: - context.eager_order_by += eagerjoin._target_adapter.\ - copy_and_process( - util.to_list( - self.parent_property.order_by - ) - ) + context.eager_order_by += eagerjoin._target_adapter.copy_and_process( + util.to_list(self.parent_property.order_by) + ) def _splice_nested_inner_join( - self, path, join_obj, clauses, onclause, splicing=False): + self, path, join_obj, clauses, onclause, splicing=False + ): if splicing is False: # first call is always handed a join object @@ -1664,28 +1814,31 @@ class JoinedLoader(AbstractRelationshipLoader): elif not isinstance(join_obj, orm_util._ORMJoin): if path[-2] is splicing: return orm_util._ORMJoin( - join_obj, clauses.aliased_class, - onclause, isouter=False, + join_obj, + clauses.aliased_class, + onclause, + isouter=False, _left_memo=splicing, - _right_memo=path[-1].mapper + _right_memo=path[-1].mapper, ) else: # only here if splicing == True return None target_join = self._splice_nested_inner_join( - path, join_obj.right, clauses, - onclause, join_obj._right_memo) + path, join_obj.right, clauses, onclause, join_obj._right_memo + ) if target_join is None: right_splice = False target_join = self._splice_nested_inner_join( - path, join_obj.left, clauses, - onclause, join_obj._left_memo) + path, join_obj.left, clauses, onclause, join_obj._left_memo + ) if target_join is None: # should only return None when recursively called, # e.g. splicing==True - assert splicing is not False, \ - "assertion failed attempting to produce joined eager loads" + assert ( + splicing is not False + ), "assertion failed attempting to produce joined eager loads" return None else: right_splice = True @@ -1698,21 +1851,30 @@ class JoinedLoader(AbstractRelationshipLoader): eagerjoin = join_obj._splice_into_center(target_join) else: eagerjoin = orm_util._ORMJoin( - join_obj.left, target_join, - join_obj.onclause, isouter=join_obj.isouter, - _left_memo=join_obj._left_memo) + join_obj.left, + target_join, + join_obj.onclause, + isouter=join_obj.isouter, + _left_memo=join_obj._left_memo, + ) else: eagerjoin = orm_util._ORMJoin( - target_join, join_obj.right, - join_obj.onclause, isouter=join_obj.isouter, - _right_memo=join_obj._right_memo) + target_join, + join_obj.right, + join_obj.onclause, + isouter=join_obj.isouter, + _right_memo=join_obj._right_memo, + ) eagerjoin._target_adapter = target_join._target_adapter return eagerjoin def _create_eager_adapter(self, context, result, adapter, path, loadopt): - user_defined_adapter = self._init_user_defined_eager_proc( - loadopt, context) if loadopt else False + user_defined_adapter = ( + self._init_user_defined_eager_proc(loadopt, context) + if loadopt + else False + ) if user_defined_adapter is not False: decorator = user_defined_adapter @@ -1736,21 +1898,19 @@ class JoinedLoader(AbstractRelationshipLoader): return False def create_row_processor( - self, context, path, loadopt, mapper, - result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " - "population - eager loading cannot be applied." % - self + "population - eager loading cannot be applied." 
% self ) our_path = path[self.parent_property] eager_adapter = self._create_eager_adapter( - context, - result, - adapter, our_path, loadopt) + context, result, adapter, our_path, loadopt + ) if eager_adapter is not False: key = self.key @@ -1760,25 +1920,28 @@ class JoinedLoader(AbstractRelationshipLoader): context, result, our_path[self.mapper], - eager_adapter) + eager_adapter, + ) if not self.uselist: self._create_scalar_loader(context, key, _instance, populators) else: self._create_collection_loader( - context, key, _instance, populators) + context, key, _instance, populators + ) else: - self.parent_property._get_strategy((("lazy", "select"),)).\ - create_row_processor( - context, path, loadopt, - mapper, result, adapter, populators) + self.parent_property._get_strategy( + (("lazy", "select"),) + ).create_row_processor( + context, path, loadopt, mapper, result, adapter, populators + ) def _create_collection_loader(self, context, key, _instance, populators): def load_collection_from_joined_new_row(state, dict_, row): - collection = attributes.init_state_collection( - state, dict_, key) - result_list = util.UniqueAppender(collection, - 'append_without_event') + collection = attributes.init_state_collection(state, dict_, key) + result_list = util.UniqueAppender( + collection, "append_without_event" + ) context.attributes[(state, key)] = result_list inst = _instance(row) if inst is not None: @@ -1793,10 +1956,11 @@ class JoinedLoader(AbstractRelationshipLoader): # is used; the same instance may be present in two # distinct sets of result columns collection = attributes.init_state_collection( - state, dict_, key) + state, dict_, key + ) result_list = util.UniqueAppender( - collection, - 'append_without_event') + collection, "append_without_event" + ) context.attributes[(state, key)] = result_list inst = _instance(row) if inst is not None: @@ -1805,12 +1969,16 @@ class JoinedLoader(AbstractRelationshipLoader): def load_collection_from_joined_exec(state, dict_, row): _instance(row) - populators["new"].append((self.key, load_collection_from_joined_new_row)) + populators["new"].append( + (self.key, load_collection_from_joined_new_row) + ) populators["existing"].append( - (self.key, load_collection_from_joined_existing_row)) + (self.key, load_collection_from_joined_existing_row) + ) if context.invoke_all_eagers: populators["eager"].append( - (self.key, load_collection_from_joined_exec)) + (self.key, load_collection_from_joined_exec) + ) def _create_scalar_loader(self, context, key, _instance, populators): def load_scalar_from_joined_new_row(state, dict_, row): @@ -1829,7 +1997,8 @@ class JoinedLoader(AbstractRelationshipLoader): util.warn( "Multiple rows returned with " "uselist=False for eagerly-loaded attribute '%s' " - % self) + % self + ) else: # this case is when one row has multiple loads of the # same entity (e.g. 
via aliasing), one has an attribute @@ -1841,17 +2010,25 @@ class JoinedLoader(AbstractRelationshipLoader): populators["new"].append((self.key, load_scalar_from_joined_new_row)) populators["existing"].append( - (self.key, load_scalar_from_joined_existing_row)) + (self.key, load_scalar_from_joined_existing_row) + ) if context.invoke_all_eagers: - populators["eager"].append((self.key, load_scalar_from_joined_exec)) + populators["eager"].append( + (self.key, load_scalar_from_joined_exec) + ) @log.class_logger @properties.RelationshipProperty.strategy_for(lazy="selectin") class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): __slots__ = ( - 'join_depth', 'omit_join', '_parent_alias', '_in_expr', - '_pk_cols', '_zero_idx', '_bakery' + "join_depth", + "omit_join", + "_parent_alias", + "_in_expr", + "_pk_cols", + "_zero_idx", + "_bakery", ) _chunksize = 500 @@ -1864,11 +2041,12 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): self.omit_join = self.parent_property.omit_join else: lazyloader = self.parent_property._get_strategy( - (("lazy", "select"),)) + (("lazy", "select"),) + ) self.omit_join = self.parent._get_clause[0].compare( lazyloader._rev_lazywhere, use_proxies=True, - equivalents=self.parent._equivalent_columns + equivalents=self.parent._equivalent_columns, ) if self.omit_join: self._init_for_omit_join() @@ -1886,8 +2064,8 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) self._pk_cols = fk_cols = [ - pk_to_fk[col] - for col in self.parent.primary_key if col in pk_to_fk] + pk_to_fk[col] for col in self.parent.primary_key if col in pk_to_fk + ] if len(fk_cols) > 1: self._in_expr = sql.tuple_(*fk_cols) self._zero_idx = False @@ -1899,7 +2077,8 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): self._parent_alias = aliased(self.parent.class_) pa_insp = inspect(self._parent_alias) self._pk_cols = pk_cols = [ - pa_insp._adapt_element(col) for col in self.parent.primary_key] + pa_insp._adapt_element(col) for col in self.parent.primary_key + ] if len(pk_cols) > 1: self._in_expr = sql.tuple_(*pk_cols) self._zero_idx = False @@ -1908,26 +2087,26 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): self._zero_idx = True def init_class_attribute(self, mapper): - self.parent_property.\ - _get_strategy((("lazy", "select"),)).\ - init_class_attribute(mapper) + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) @util.dependencies("sqlalchemy.ext.baked") def _memoized_attr__bakery(self, baked): return baked.bakery(size=50) def create_row_processor( - self, context, path, loadopt, mapper, - result, adapter, populators): + self, context, path, loadopt, mapper, result, adapter, populators + ): if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " - "population - eager loading cannot be applied." % - self + "population - eager loading cannot be applied." % self ) selectin_path = ( - context.query._current_path or orm_util.PathRegistry.root) + path + context.query._current_path or orm_util.PathRegistry.root + ) + path if not orm_util._entity_isa(path[-1], self.parent): return @@ -1941,8 +2120,8 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. 
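SelectInLoader, introduced above with _chunksize = 500, loads related rows by batching the parents' primary keys into IN (...) lists of at most 500 per SELECT; the slicing a little further on (our_states[0 : self._chunksize], with the spaces around the slice colon that black adds for non-trivial operands) walks the pending states in those batches. A toy sketch of just the batching arithmetic (data invented):

    def chunks(keys, size=500):
        while keys:
            batch, keys = keys[0:size], keys[size:]
            yield batch

    pks = list(range(1234))
    assert [len(b) for b in chunks(pks)] == [500, 500, 234]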
with_poly_info = path_w_prop.get( - context.attributes, - "path_with_polymorphic", None) + context.attributes, "path_with_polymorphic", None + ) if with_poly_info is not None: effective_entity = with_poly_info.entity @@ -1957,19 +2136,24 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): return loading.PostLoad.callable_for_path( - context, selectin_path, self.parent, self.key, - self._load_for_path, effective_entity) + context, + selectin_path, + self.parent, + self.key, + self._load_for_path, + effective_entity, + ) @util.dependencies("sqlalchemy.ext.baked") def _load_for_path( - self, baked, context, path, states, load_only, effective_entity): + self, baked, context, path, states, load_only, effective_entity + ): if load_only and self.key not in load_only: return our_states = [ - (state.key[1], state, overwrite) - for state, overwrite in states + (state.key[1], state, overwrite) for state, overwrite in states ] pk_cols = self._pk_cols @@ -1984,17 +2168,15 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): # parent entity and do not need adaption. insp = inspect(effective_entity) if insp.is_aliased_class: - pk_cols = [ - insp._adapt_element(col) - for col in pk_cols - ] + pk_cols = [insp._adapt_element(col) for col in pk_cols] in_expr = insp._adapt_element(in_expr) pk_cols = [insp._adapt_element(col) for col in pk_cols] q = self._bakery( lambda session: session.query( - query.Bundle("pk", *pk_cols), effective_entity, - ), self + query.Bundle("pk", *pk_cols), effective_entity + ), + self, ) if self.omit_join: @@ -2012,60 +2194,53 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): q.add_criteria( lambda q: q.select_from(pa).join( getattr(pa, self.parent_property.key).of_type( - effective_entity) + effective_entity + ) ) ) q.add_criteria( lambda q: q.filter( - in_expr.in_( - sql.bindparam("primary_keys", expanding=True)) - ).order_by(*pk_cols)) + in_expr.in_(sql.bindparam("primary_keys", expanding=True)) + ).order_by(*pk_cols) + ) orig_query = context.query q._add_lazyload_options( - orig_query._with_options, - path[self.parent_property] + orig_query._with_options, path[self.parent_property] ) if orig_query._populate_existing: - q.add_criteria( - lambda q: q.populate_existing() - ) + q.add_criteria(lambda q: q.populate_existing()) if self.parent_property.order_by: if self.omit_join: eager_order_by = self.parent_property.order_by if insp.is_aliased_class: eager_order_by = [ - insp._adapt_element(elem) for elem in - eager_order_by + insp._adapt_element(elem) for elem in eager_order_by ] - q.add_criteria( - lambda q: q.order_by(*eager_order_by) - ) + q.add_criteria(lambda q: q.order_by(*eager_order_by)) else: + def _setup_outermost_orderby(q): # imitate the same method that subquery eager loading uses, # looking for the adapted "secondary" table eagerjoin = q._from_obj[0] - eager_order_by = \ - eagerjoin._target_adapter.\ - copy_and_process( - util.to_list(self.parent_property.order_by) - ) + eager_order_by = eagerjoin._target_adapter.copy_and_process( + util.to_list(self.parent_property.order_by) + ) return q.order_by(*eager_order_by) - q.add_criteria( - _setup_outermost_orderby - ) + + q.add_criteria(_setup_outermost_orderby) uselist = self.uselist _empty_result = () if uselist else None while our_states: - chunk = our_states[0:self._chunksize] - our_states = our_states[self._chunksize:] + chunk = our_states[0 : self._chunksize] + our_states = our_states[self._chunksize :] data = { k: [vv[1] for vv in v] @@ -2073,9 +2248,10 @@ 
class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): q(context.session).params( primary_keys=[ key[0] if self._zero_idx else key - for key, state, overwrite in chunk] + for key, state, overwrite in chunk + ] ), - lambda x: x[0] + lambda x: x[0], ) } @@ -2091,13 +2267,15 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): util.warn( "Multiple rows returned with " "uselist=False for eagerly-loaded " - "attribute '%s' " - % self) + "attribute '%s' " % self + ) state.get_impl(self.key).set_committed_value( - state, state.dict, collection[0]) + state, state.dict, collection[0] + ) else: state.get_impl(self.key).set_committed_value( - state, state.dict, collection) + state, state.dict, collection + ) def single_parent_validator(desc, prop): @@ -2108,8 +2286,8 @@ def single_parent_validator(desc, prop): raise sa_exc.InvalidRequestError( "Instance %s is already associated with an instance " "of %s via its %s attribute, and is only allowed a " - "single parent." % - (orm_util.instance_str(value), state.class_, prop) + "single parent." + % (orm_util.instance_str(value), state.class_, prop) ) return value @@ -2120,8 +2298,6 @@ def single_parent_validator(desc, prop): return _do_check(state, value, oldvalue, initiator) event.listen( - desc, 'append', append, raw=True, retval=True, - active_history=True) - event.listen( - desc, 'set', set_, raw=True, retval=True, - active_history=True) + desc, "append", append, raw=True, retval=True, active_history=True + ) + event.listen(desc, "set", set_, raw=True, retval=True, active_history=True) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index f0d2091101..b2f6bcb11c 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -13,11 +13,19 @@ from .attributes import QueryableAttribute from .. import util from ..sql.base import _generative, Generative from .. import exc as sa_exc, inspect -from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class, \ - InspectionAttr +from .base import ( + _is_aliased_class, + _class_to_mapper, + _is_mapped_class, + InspectionAttr, +) from . 
import util as orm_util -from .path_registry import PathRegistry, TokenRegistry, \ - _WILDCARD_TOKEN, _DEFAULT_TOKEN +from .path_registry import ( + PathRegistry, + TokenRegistry, + _WILDCARD_TOKEN, + _DEFAULT_TOKEN, +) class Load(Generative, MapperOption): @@ -94,12 +102,14 @@ class Load(Generative, MapperOption): if ( # means loader_path and path are unrelated, # this does not need to be part of a cache key - chopped is None + chopped + is None ) or ( # means no additional path with loader_path + path # and the endpoint isn't using of_type so isn't modified # into an alias or other unsafe entity - not chopped and not obj._of_type + not chopped + and not obj._of_type ): continue @@ -124,12 +134,18 @@ class Load(Generative, MapperOption): serialized.append( ( - tuple(serialized_path) + - (obj.strategy or ()) + - (tuple([ - (key, obj.local_opts[key]) - for key in sorted(obj.local_opts) - ]) if obj.local_opts else ()) + tuple(serialized_path) + + (obj.strategy or ()) + + ( + tuple( + [ + (key, obj.local_opts[key]) + for key in sorted(obj.local_opts) + ] + ) + if obj.local_opts + else () + ) ) ) if not serialized: @@ -170,12 +186,13 @@ class Load(Generative, MapperOption): if raiseerr and not path.has_entity: if isinstance(path, TokenRegistry): raise sa_exc.ArgumentError( - "Wildcard token cannot be followed by another entity") + "Wildcard token cannot be followed by another entity" + ) else: raise sa_exc.ArgumentError( "Attribute '%s' of entity '%s' does not " - "refer to a mapped entity" % - (path.prop.key, path.parent.entity) + "refer to a mapped entity" + % (path.prop.key, path.parent.entity) ) if isinstance(attr, util.string_types): @@ -201,8 +218,7 @@ class Load(Generative, MapperOption): if raiseerr: raise sa_exc.ArgumentError( "Can't find property named '%s' on the " - "mapped entity %s in this Query. " % ( - attr, ent) + "mapped entity %s in this Query. " % (attr, ent) ) else: return None @@ -215,7 +231,8 @@ class Load(Generative, MapperOption): if raiseerr: raise sa_exc.ArgumentError( "Attribute '%s' does not " - "link from element '%s'" % (attr, path.entity)) + "link from element '%s'" % (attr, path.entity) + ) else: return None else: @@ -225,22 +242,26 @@ class Load(Generative, MapperOption): if raiseerr: raise sa_exc.ArgumentError( "Attribute '%s' does not " - "link from element '%s'" % (attr, path.entity)) + "link from element '%s'" % (attr, path.entity) + ) else: return None - if getattr(attr, '_of_type', None): + if getattr(attr, "_of_type", None): ac = attr._of_type ext_info = of_type_info = inspect(ac) existing = path.entity_path[prop].get( - self.context, "path_with_polymorphic") + self.context, "path_with_polymorphic" + ) if not ext_info.is_aliased_class: ac = orm_util.with_polymorphic( ext_info.mapper.base_mapper, - ext_info.mapper, aliased=True, + ext_info.mapper, + aliased=True, _use_mapper_path=True, - _existing_alias=existing) + _existing_alias=existing, + ) ext_info = inspect(ac) elif not ext_info.with_polymorphic_mappers: ext_info = orm_util.AliasedInsp( @@ -253,11 +274,12 @@ class Load(Generative, MapperOption): ext_info._base_alias, ext_info._use_mapper_path, ext_info._adapt_on_names, - ext_info.represents_outer_join + ext_info.represents_outer_join, ) path.entity_path[prop].set( - self.context, "path_with_polymorphic", ext_info) + self.context, "path_with_polymorphic", ext_info + ) # the path here will go into the context dictionary and # needs to match up to how the class graph is traversed. 
@@ -280,7 +302,7 @@ class Load(Generative, MapperOption): return path def __str__(self): - return "Load(strategy=%r)" % (self.strategy, ) + return "Load(strategy=%r)" % (self.strategy,) def _coerce_strat(self, strategy): if strategy is not None: @@ -289,7 +311,8 @@ class Load(Generative, MapperOption): @_generative def set_relationship_strategy( - self, attr, strategy, propagate_to_loaders=True): + self, attr, strategy, propagate_to_loaders=True + ): strategy = self._coerce_strat(strategy) self.is_class_strategy = False @@ -365,12 +388,18 @@ class Load(Generative, MapperOption): if effective_path.is_token: for path in effective_path.generate_for_superclasses(): self._set_for_path( - self.context, path, replace=True, - merge_opts=self.is_opts_only) + self.context, + path, + replace=True, + merge_opts=self.is_opts_only, + ) else: self._set_for_path( - self.context, effective_path, replace=True, - merge_opts=self.is_opts_only) + self.context, + effective_path, + replace=True, + merge_opts=self.is_opts_only, + ) def __getstate__(self): d = self.__dict__.copy() @@ -389,21 +418,26 @@ class Load(Generative, MapperOption): # TODO: this is approximated from the _UnboundLoad # version and probably has issues, not fully covered. - if i == 0 and c_token.endswith(':' + _DEFAULT_TOKEN): + if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN): return to_chop - elif c_token != 'relationship:%s' % (_WILDCARD_TOKEN,) and \ - c_token != p_token.key: + elif ( + c_token != "relationship:%s" % (_WILDCARD_TOKEN,) + and c_token != p_token.key + ): return None if c_token is p_token: continue - elif isinstance(c_token, InspectionAttr) and \ - c_token.is_mapper and p_token.is_mapper and \ - c_token.isa(p_token): + elif ( + isinstance(c_token, InspectionAttr) + and c_token.is_mapper + and p_token.is_mapper + and c_token.isa(p_token) + ): continue else: return None - return to_chop[i + 1:] + return to_chop[i + 1 :] class _UnboundLoad(Load): @@ -431,9 +465,7 @@ class _UnboundLoad(Load): if local_elem is not val_elem: break else: - opt = val._bind_loader( - [path.path[0]], - None, None, False) + opt = val._bind_loader([path.path[0]], None, None, False) if opt: c_key = opt._generate_cache_key(path) if c_key is False: @@ -449,26 +481,29 @@ class _UnboundLoad(Load): self._to_bind.append(self) def _generate_path(self, path, attr, wildcard_key): - if wildcard_key and isinstance(attr, util.string_types) and \ - attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN): + if ( + wildcard_key + and isinstance(attr, util.string_types) + and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN) + ): if attr == _DEFAULT_TOKEN: self.propagate_to_loaders = False attr = "%s:%s" % (wildcard_key, attr) if path and _is_mapped_class(path[-1]) and not self.is_class_strategy: path = path[0:-1] if attr: - path = path + (attr, ) + path = path + (attr,) self.path = path return path def __getstate__(self): d = self.__dict__.copy() - d['path'] = self._serialize_path(self.path, filter_aliased_class=True) + d["path"] = self._serialize_path(self.path, filter_aliased_class=True) return d def __setstate__(self, state): ret = [] - for key in state['path']: + for key in state["path"]: if isinstance(key, tuple): if len(key) == 2: # support legacy @@ -482,17 +517,20 @@ class _UnboundLoad(Load): ret.append(prop) else: ret.append(key) - state['path'] = tuple(ret) + state["path"] = tuple(ret) self.__dict__ = state def _process(self, query, raiseerr): - dedupes = query._attributes['_unbound_load_dedupes'] + dedupes = query._attributes["_unbound_load_dedupes"] for val in self._to_bind: 
if val not in dedupes: dedupes.add(val) val._bind_loader( [ent.entity_zero for ent in query._mapper_entities], - query._current_path, query._attributes, raiseerr) + query._current_path, + query._attributes, + raiseerr, + ) @classmethod def _from_keys(cls, meth, keys, chained, kw): @@ -502,13 +540,14 @@ class _UnboundLoad(Load): if isinstance(key, util.string_types): # coerce fooload('*') into "default loader strategy" if key == _WILDCARD_TOKEN: - return (_DEFAULT_TOKEN, ) + return (_DEFAULT_TOKEN,) # coerce fooload(".*") into "wildcard on default entity" elif key.startswith("." + _WILDCARD_TOKEN): key = key[1:] return key.split(".") else: return (key,) + all_tokens = [token for key in keys for token in _split_key(key)] for token in all_tokens[0:-1]: @@ -526,21 +565,24 @@ class _UnboundLoad(Load): def _chop_path(self, to_chop, path): i = -1 for i, (c_token, (p_entity, p_prop)) in enumerate( - zip(to_chop, path.pairs())): + zip(to_chop, path.pairs()) + ): if isinstance(c_token, util.string_types): - if i == 0 and c_token.endswith(':' + _DEFAULT_TOKEN): + if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN): return to_chop - elif c_token != 'relationship:%s' % ( - _WILDCARD_TOKEN,) and c_token != p_prop.key: + elif ( + c_token != "relationship:%s" % (_WILDCARD_TOKEN,) + and c_token != p_prop.key + ): return None elif isinstance(c_token, PropComparator): - if c_token.property is not p_prop or \ - ( - c_token._parententity is not p_entity and ( - not c_token._parententity.is_mapper or - not c_token._parententity.isa(p_entity) - ) - ): + if c_token.property is not p_prop or ( + c_token._parententity is not p_entity + and ( + not c_token._parententity.is_mapper + or not c_token._parententity.isa(p_entity) + ) + ): return None else: i += 1 @@ -551,15 +593,16 @@ class _UnboundLoad(Load): ret = [] for token in path: if isinstance(token, QueryableAttribute): - if filter_aliased_class and token._of_type and \ - inspect(token._of_type).is_aliased_class: - ret.append( - (token._parentmapper.class_, - token.key, None)) + if ( + filter_aliased_class + and token._of_type + and inspect(token._of_type).is_aliased_class + ): + ret.append((token._parentmapper.class_, token.key, None)) else: ret.append( - (token._parentmapper.class_, token.key, - token._of_type)) + (token._parentmapper.class_, token.key, token._of_type) + ) elif isinstance(token, PropComparator): ret.append((token._parentmapper.class_, token.key, None)) else: @@ -605,7 +648,7 @@ class _UnboundLoad(Load): start_path = self.path if self.is_class_strategy and current_path: - start_path += (entities[0], ) + start_path += (entities[0],) # _current_path implies we're in a # secondary load with an existing path @@ -621,23 +664,20 @@ class _UnboundLoad(Load): token = start_path[0] if isinstance(token, util.string_types): - entity = self._find_entity_basestring( - entities, token, raiseerr) + entity = self._find_entity_basestring(entities, token, raiseerr) elif isinstance(token, PropComparator): prop = token.property entity = self._find_entity_prop_comparator( - entities, - prop.key, - token._parententity, - raiseerr) + entities, prop.key, token._parententity, raiseerr + ) elif self.is_class_strategy and _is_mapped_class(token): entity = inspect(token) if entity not in entities: entity = None else: raise sa_exc.ArgumentError( - "mapper option expects " - "string key or list of attributes") + "mapper option expects " "string key or list of attributes" + ) if not entity: return @@ -663,7 +703,8 @@ class _UnboundLoad(Load): if not loader.is_class_strategy: 
for token in start_path: if not loader._generate_path( - loader.path, token, None, raiseerr): + loader.path, token, None, raiseerr + ): return loader.local_opts.update(self.local_opts) @@ -680,14 +721,18 @@ class _UnboundLoad(Load): if effective_path.is_token: for path in effective_path.generate_for_superclasses(): loader._set_for_path( - context, path, + context, + path, replace=not self._is_chain_link, - merge_opts=self.is_opts_only) + merge_opts=self.is_opts_only, + ) else: loader._set_for_path( - context, effective_path, + context, + effective_path, replace=not self._is_chain_link, - merge_opts=self.is_opts_only) + merge_opts=self.is_opts_only, + ) return loader @@ -704,28 +749,27 @@ class _UnboundLoad(Load): if not list(entities): raise sa_exc.ArgumentError( "Query has only expression-based entities - " - "can't find property named '%s'." - % (token, ) + "can't find property named '%s'." % (token,) ) else: raise sa_exc.ArgumentError( "Can't find property '%s' on any entity " "specified in this Query. Note the full path " "from root (%s) to target entity must be specified." - % (token, ",".join(str(x) for - x in entities)) + % (token, ",".join(str(x) for x in entities)) ) else: return None def _find_entity_basestring(self, entities, token, raiseerr): - if token.endswith(':' + _WILDCARD_TOKEN): + if token.endswith(":" + _WILDCARD_TOKEN): if len(list(entities)) != 1: if raiseerr: raise sa_exc.ArgumentError( "Wildcard loader can only be used with exactly " "one entity. Use Load(ent) to specify " - "specific entities.") + "specific entities." + ) elif token.endswith(_DEFAULT_TOKEN): raiseerr = False @@ -738,8 +782,7 @@ class _UnboundLoad(Load): if raiseerr: raise sa_exc.ArgumentError( "Query has only expression-based entities - " - "can't find property named '%s'." - % (token, ) + "can't find property named '%s'." % (token,) ) else: return None @@ -766,7 +809,9 @@ class loader_option(object): See :func:`.orm.%(name)s` for usage examples. -""" % {"name": self.name} +""" % { + "name": self.name + } fn.__doc__ = fn_doc return self @@ -783,7 +828,9 @@ See :func:`.orm.%(name)s` for usage examples. 
%(name)s("someattribute").%(name)s("anotherattribute") ) -""" % {"name": self.name} +""" % { + "name": self.name + } return self @@ -840,23 +887,22 @@ def contains_eager(loadopt, attr, alias=None): info = inspect(alias) alias = info.selectable - elif getattr(attr, '_of_type', None): + elif getattr(attr, "_of_type", None): ot = inspect(attr._of_type) alias = ot.selectable cloned = loadopt.set_relationship_strategy( - attr, - {"lazy": "joined"}, - propagate_to_loaders=False + attr, {"lazy": "joined"}, propagate_to_loaders=False ) - cloned.local_opts['eager_from_alias'] = alias + cloned.local_opts["eager_from_alias"] = alias return cloned @contains_eager._add_unbound_fn def contains_eager(*keys, **kw): return _UnboundLoad()._from_keys( - _UnboundLoad.contains_eager, keys, True, kw) + _UnboundLoad.contains_eager, keys, True, kw + ) @loader_option() @@ -894,12 +940,11 @@ def load_only(loadopt, *attrs): """ cloned = loadopt.set_column_strategy( - attrs, - {"deferred": False, "instrument": True} + attrs, {"deferred": False, "instrument": True} + ) + cloned.set_column_strategy( + "*", {"deferred": True, "instrument": True}, {"undefer_pks": True} ) - cloned.set_column_strategy("*", - {"deferred": True, "instrument": True}, - {"undefer_pks": True}) return cloned @@ -996,20 +1041,18 @@ def joinedload(loadopt, attr, innerjoin=None): """ loader = loadopt.set_relationship_strategy(attr, {"lazy": "joined"}) if innerjoin is not None: - loader.local_opts['innerjoin'] = innerjoin + loader.local_opts["innerjoin"] = innerjoin return loader @joinedload._add_unbound_fn def joinedload(*keys, **kw): - return _UnboundLoad._from_keys( - _UnboundLoad.joinedload, keys, False, kw) + return _UnboundLoad._from_keys(_UnboundLoad.joinedload, keys, False, kw) @joinedload._add_unbound_all_fn def joinedload_all(*keys, **kw): - return _UnboundLoad._from_keys( - _UnboundLoad.joinedload, keys, True, kw) + return _UnboundLoad._from_keys(_UnboundLoad.joinedload, keys, True, kw) @loader_option() @@ -1152,8 +1195,7 @@ def immediateload(loadopt, attr): @immediateload._add_unbound_fn def immediateload(*keys): - return _UnboundLoad._from_keys( - _UnboundLoad.immediateload, keys, False, {}) + return _UnboundLoad._from_keys(_UnboundLoad.immediateload, keys, False, {}) @loader_option() @@ -1213,7 +1255,8 @@ def raiseload(loadopt, attr, sql_only=False): """ return loadopt.set_relationship_strategy( - attr, {"lazy": "raise_on_sql" if sql_only else "raise"}) + attr, {"lazy": "raise_on_sql" if sql_only else "raise"} + ) @raiseload._add_unbound_fn @@ -1251,10 +1294,7 @@ def defaultload(loadopt, attr): :ref:`deferred_loading_w_multiple` """ - return loadopt.set_relationship_strategy( - attr, - None - ) + return loadopt.set_relationship_strategy(attr, None) @defaultload._add_unbound_fn @@ -1315,15 +1355,15 @@ def defer(loadopt, key): """ return loadopt.set_column_strategy( - (key, ), - {"deferred": True, "instrument": True} + (key,), {"deferred": True, "instrument": True} ) @defer._add_unbound_fn def defer(key, *addl_attrs): return _UnboundLoad._from_keys( - _UnboundLoad.defer, (key, ) + addl_attrs, False, {}) + _UnboundLoad.defer, (key,) + addl_attrs, False, {} + ) @loader_option() @@ -1362,15 +1402,15 @@ def undefer(loadopt, key): """ return loadopt.set_column_strategy( - (key, ), - {"deferred": False, "instrument": True} + (key,), {"deferred": False, "instrument": True} ) @undefer._add_unbound_fn def undefer(key, *addl_attrs): return _UnboundLoad._from_keys( - _UnboundLoad.undefer, (key, ) + addl_attrs, False, {}) + _UnboundLoad.undefer, 
(key,) + addl_attrs, False, {} + ) @loader_option() @@ -1405,10 +1445,7 @@ def undefer_group(loadopt, name): """ return loadopt.set_column_strategy( - "*", - None, - {"undefer_group_%s" % name: True}, - opts_only=True + "*", None, {"undefer_group_%s" % name: True}, opts_only=True ) @@ -1448,21 +1485,18 @@ def with_expression(loadopt, key, expression): """ - expression = sql_expr._labeled( - _orm_full_deannotate(expression)) + expression = sql_expr._labeled(_orm_full_deannotate(expression)) return loadopt.set_column_strategy( - (key, ), - {"query_expression": True}, - opts={"expression": expression} + (key,), {"query_expression": True}, opts={"expression": expression} ) @with_expression._add_unbound_fn def with_expression(key, expression): return _UnboundLoad._from_keys( - _UnboundLoad.with_expression, (key, ), - False, {"expression": expression}) + _UnboundLoad.with_expression, (key,), False, {"expression": expression} + ) @loader_option() @@ -1483,7 +1517,11 @@ def selectin_polymorphic(loadopt, classes): """ loadopt.set_class_strategy( {"selectinload_polymorphic": True}, - opts={"entities": tuple(sorted((inspect(cls) for cls in classes), key=id))} + opts={ + "entities": tuple( + sorted((inspect(cls) for cls in classes), key=id) + ) + }, ) return loadopt @@ -1492,8 +1530,6 @@ def selectin_polymorphic(loadopt, classes): def selectin_polymorphic(base_cls, classes): ul = _UnboundLoad() ul.is_class_strategy = True - ul.path = (inspect(base_cls), ) - ul.selectin_polymorphic( - classes - ) + ul.path = (inspect(base_cls),) + ul.selectin_polymorphic(classes) return ul diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 08a66a8db4..0cd488cbd4 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -13,8 +13,15 @@ between instances based on join conditions. from . import exc, util as orm_util, attributes -def populate(source, source_mapper, dest, dest_mapper, - synchronize_pairs, uowcommit, flag_cascaded_pks): +def populate( + source, + source_mapper, + dest, + dest_mapper, + synchronize_pairs, + uowcommit, + flag_cascaded_pks, +): source_dict = source.dict dest_dict = dest.dict @@ -22,8 +29,9 @@ def populate(source, source_mapper, dest, dest_mapper, try: # inline of source_mapper._get_state_attr_by_column prop = source_mapper._columntoproperty[l] - value = source.manager[prop.key].impl.get(source, source_dict, - attributes.PASSIVE_OFF) + value = source.manager[prop.key].impl.get( + source, source_dict, attributes.PASSIVE_OFF + ) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) @@ -39,14 +47,16 @@ def populate(source, source_mapper, dest, dest_mapper, # how often this logic is invoked for memory/performance # reasons, since we only need this info for a primary key # destination. 
- if flag_cascaded_pks and l.primary_key and \ - r.primary_key and \ - r.references(l): + if ( + flag_cascaded_pks + and l.primary_key + and r.primary_key + and r.references(l) + ): uowcommit.attributes[("pk_cascaded", dest, r)] = True -def bulk_populate_inherit_keys( - source_dict, source_mapper, synchronize_pairs): +def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs): # a simplified version of populate() used by bulk insert mode for l, r in synchronize_pairs: try: @@ -64,14 +74,15 @@ def bulk_populate_inherit_keys( def clear(dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: - if r.primary_key and \ - dest_mapper._get_state_attr_by_column( - dest, dest.dict, r) not in orm_util._none_set: + if ( + r.primary_key + and dest_mapper._get_state_attr_by_column(dest, dest.dict, r) + not in orm_util._none_set + ): raise AssertionError( "Dependency rule tried to blank-out primary key " - "column '%s' on instance '%s'" % - (r, orm_util.state_str(dest)) + "column '%s' on instance '%s'" % (r, orm_util.state_str(dest)) ) try: dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) @@ -83,9 +94,11 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): for l, r in synchronize_pairs: try: oldvalue = source_mapper._get_committed_attr_by_column( - source.obj(), l) + source.obj(), l + ) value = source_mapper._get_state_attr_by_column( - source, source.dict, l, passive=attributes.PASSIVE_OFF) + source, source.dict, l, passive=attributes.PASSIVE_OFF + ) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) dest[r.key] = value @@ -96,7 +109,8 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs): for l, r in synchronize_pairs: try: value = source_mapper._get_state_attr_by_column( - source, source.dict, l, passive=attributes.PASSIVE_OFF) + source, source.dict, l, passive=attributes.PASSIVE_OFF + ) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) @@ -114,27 +128,31 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs): except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) history = uowcommit.get_attribute_history( - source, prop.key, attributes.PASSIVE_NO_INITIALIZE) + source, prop.key, attributes.PASSIVE_NO_INITIALIZE + ) if bool(history.deleted): return True else: return False -def _raise_col_to_prop(isdest, source_mapper, source_column, - dest_mapper, dest_column): +def _raise_col_to_prop( + isdest, source_mapper, source_column, dest_mapper, dest_column +): if isdest: raise exc.UnmappedColumnError( "Can't execute sync rule for " "destination column '%s'; mapper '%s' does not map " "this column. Try using an explicit `foreign_keys` " "collection which does not include this column (or use " - "a viewonly=True relation)." % (dest_column, dest_mapper)) + "a viewonly=True relation)." % (dest_column, dest_mapper) + ) else: raise exc.UnmappedColumnError( "Can't execute sync rule for " "source column '%s'; mapper '%s' does not map this " "column. Try using an explicit `foreign_keys` " "collection which does not include destination column " - "'%s' (or use a viewonly=True relation)." % - (source_column, source_mapper, dest_column)) + "'%s' (or use a viewonly=True relation)." 
+ % (source_column, source_mapper, dest_column) + ) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index a83a99d78d..545811bb42 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -41,9 +41,11 @@ def track_cascade_events(descriptor, prop): prop = state.manager.mapper._props[key] item_state = attributes.instance_state(item) - if prop._cascade.save_update and \ - (prop.cascade_backrefs or key == initiator.key) and \ - not sess._contains_state(item_state): + if ( + prop._cascade.save_update + and (prop.cascade_backrefs or key == initiator.key) + and not sess._contains_state(item_state) + ): sess._save_or_update_state(item_state) return item @@ -59,12 +61,15 @@ def track_cascade_events(descriptor, prop): sess._flush_warning( "collection remove" if prop.uselist - else "related attribute delete") + else "related attribute delete" + ) - if item is not None and \ - item is not attributes.NEVER_SET and \ - item is not attributes.PASSIVE_NO_RESULT and \ - prop._cascade.delete_orphan: + if ( + item is not None + and item is not attributes.NEVER_SET + and item is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): # expunge pending orphans item_state = attributes.instance_state(item) @@ -93,26 +98,31 @@ def track_cascade_events(descriptor, prop): prop = state.manager.mapper._props[key] if newvalue is not None: newvalue_state = attributes.instance_state(newvalue) - if prop._cascade.save_update and \ - (prop.cascade_backrefs or key == initiator.key) and \ - not sess._contains_state(newvalue_state): + if ( + prop._cascade.save_update + and (prop.cascade_backrefs or key == initiator.key) + and not sess._contains_state(newvalue_state) + ): sess._save_or_update_state(newvalue_state) - if oldvalue is not None and \ - oldvalue is not attributes.NEVER_SET and \ - oldvalue is not attributes.PASSIVE_NO_RESULT and \ - prop._cascade.delete_orphan: + if ( + oldvalue is not None + and oldvalue is not attributes.NEVER_SET + and oldvalue is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): # possible to reach here with attributes.NEVER_SET ? 
oldvalue_state = attributes.instance_state(oldvalue) - if oldvalue_state in sess._new and \ - prop.mapper._is_orphan(oldvalue_state): + if oldvalue_state in sess._new and prop.mapper._is_orphan( + oldvalue_state + ): sess.expunge(oldvalue) return newvalue - event.listen(descriptor, 'append', append, raw=True, retval=True) - event.listen(descriptor, 'remove', remove, raw=True, retval=True) - event.listen(descriptor, 'set', set_, raw=True, retval=True) + event.listen(descriptor, "append", append, raw=True, retval=True) + event.listen(descriptor, "remove", remove, raw=True, retval=True) + event.listen(descriptor, "set", set_, raw=True, retval=True) class UOWTransaction(object): @@ -197,8 +207,9 @@ class UOWTransaction(object): self.states[state] = (isdelete, True) - def get_attribute_history(self, state, key, - passive=attributes.PASSIVE_NO_INITIALIZE): + def get_attribute_history( + self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE + ): """facade to attributes.get_state_history(), including caching of results.""" @@ -213,12 +224,16 @@ class UOWTransaction(object): # if the cached lookup was "passive" and now # we want non-passive, do a non-passive lookup and re-cache - if not cached_passive & attributes.SQL_OK \ - and passive & attributes.SQL_OK: + if ( + not cached_passive & attributes.SQL_OK + and passive & attributes.SQL_OK + ): impl = state.manager[key].impl - history = impl.get_history(state, state.dict, - attributes.PASSIVE_OFF | - attributes.LOAD_AGAINST_COMMITTED) + history = impl.get_history( + state, + state.dict, + attributes.PASSIVE_OFF | attributes.LOAD_AGAINST_COMMITTED, + ) if history and impl.uses_objects: state_history = history.as_state() else: @@ -228,14 +243,14 @@ class UOWTransaction(object): impl = state.manager[key].impl # TODO: store the history as (state, object) tuples # so we don't have to keep converting here - history = impl.get_history(state, state.dict, passive | - attributes.LOAD_AGAINST_COMMITTED) + history = impl.get_history( + state, state.dict, passive | attributes.LOAD_AGAINST_COMMITTED + ) if history and impl.uses_objects: state_history = history.as_state() else: state_history = history - self.attributes[hashkey] = (history, state_history, - passive) + self.attributes[hashkey] = (history, state_history, passive) return state_history @@ -247,17 +262,25 @@ class UOWTransaction(object): if key not in self.presort_actions: self.presort_actions[key] = Preprocess(processor, fromparent) - def register_object(self, state, isdelete=False, - listonly=False, cancel_delete=False, - operation=None, prop=None): + def register_object( + self, + state, + isdelete=False, + listonly=False, + cancel_delete=False, + operation=None, + prop=None, + ): if not self.session._contains_state(state): # this condition is normal when objects are registered # as part of a relationship cascade operation. it should # not occur for the top-level register from Session.flush(). if not state.deleted and operation is not None: - util.warn("Object of type %s not in session, %s operation " - "along '%s' will not proceed" % - (orm_util.state_class_str(state), operation, prop)) + util.warn( + "Object of type %s not in session, %s operation " + "along '%s' will not proceed" + % (orm_util.state_class_str(state), operation, prop) + ) return False if state not in self.states: @@ -340,24 +363,26 @@ class UOWTransaction(object): # see if the graph of mapper dependencies has cycles. 
self.cycles = cycles = topological.find_cycles( - self.dependencies, - list(self.postsort_actions.values())) + self.dependencies, list(self.postsort_actions.values()) + ) if cycles: # if yes, break the per-mapper actions into # per-state actions convert = dict( - (rec, set(rec.per_state_flush_actions(self))) - for rec in cycles + (rec, set(rec.per_state_flush_actions(self))) for rec in cycles ) # rewrite the existing dependencies to point to # the per-state actions for those per-mapper actions # that were broken up. for edge in list(self.dependencies): - if None in edge or \ - edge[0].disabled or edge[1].disabled or \ - cycles.issuperset(edge): + if ( + None in edge + or edge[0].disabled + or edge[1].disabled + or cycles.issuperset(edge) + ): self.dependencies.remove(edge) elif edge[0] in cycles: self.dependencies.remove(edge) @@ -368,10 +393,9 @@ class UOWTransaction(object): for dep in convert[edge[1]]: self.dependencies.add((edge[0], dep)) - return set([a for a in self.postsort_actions.values() - if not a.disabled - ] - ).difference(cycles) + return set( + [a for a in self.postsort_actions.values() if not a.disabled] + ).difference(cycles) def execute(self): postsort_actions = self._generate_actions() @@ -386,15 +410,13 @@ class UOWTransaction(object): # execute if self.cycles: for set_ in topological.sort_as_subsets( - self.dependencies, - postsort_actions): + self.dependencies, postsort_actions + ): while set_: n = set_.pop() n.execute_aggregate(self, set_) else: - for rec in topological.sort( - self.dependencies, - postsort_actions): + for rec in topological.sort(self.dependencies, postsort_actions): rec.execute(self) def finalize_flush_changes(self): @@ -410,8 +432,7 @@ class UOWTransaction(object): states = set(self.states) isdel = set( - s for (s, (isdelete, listonly)) in self.states.items() - if isdelete + s for (s, (isdelete, listonly)) in self.states.items() if isdelete ) other = states.difference(isdel) if isdel: @@ -424,8 +445,8 @@ class IterateMappersMixin(object): def _mappers(self, uow): if self.fromparent: return iter( - m for m in - self.dependency_processor.parent.self_and_descendants + m + for m in self.dependency_processor.parent.self_and_descendants if uow._mapper_for_dep[(m, self.dependency_processor)] ) else: @@ -434,8 +455,10 @@ class IterateMappersMixin(object): class Preprocess(IterateMappersMixin): __slots__ = ( - 'dependency_processor', 'fromparent', 'processed', - 'setup_flush_actions' + "dependency_processor", + "fromparent", + "processed", + "setup_flush_actions", ) def __init__(self, dependency_processor, fromparent): @@ -464,12 +487,14 @@ class Preprocess(IterateMappersMixin): self.dependency_processor.presort_saves(uow, save_states) self.processed.update(save_states) - if (delete_states or save_states): + if delete_states or save_states: if not self.setup_flush_actions and ( - self.dependency_processor. - prop_has_changes(uow, delete_states, True) or - self.dependency_processor. 
- prop_has_changes(uow, save_states, False) + self.dependency_processor.prop_has_changes( + uow, delete_states, True + ) + or self.dependency_processor.prop_has_changes( + uow, save_states, False + ) ): self.dependency_processor.per_property_flush_actions(uow) self.setup_flush_actions = True @@ -479,16 +504,14 @@ class Preprocess(IterateMappersMixin): class PostSortRec(object): - __slots__ = 'disabled', + __slots__ = ("disabled",) def __new__(cls, uow, *args): - key = (cls, ) + args + key = (cls,) + args if key in uow.postsort_actions: return uow.postsort_actions[key] else: - uow.postsort_actions[key] = \ - ret = \ - object.__new__(cls) + uow.postsort_actions[key] = ret = object.__new__(cls) ret.disabled = False return ret @@ -497,14 +520,15 @@ class PostSortRec(object): class ProcessAll(IterateMappersMixin, PostSortRec): - __slots__ = 'dependency_processor', 'isdelete', 'fromparent' + __slots__ = "dependency_processor", "isdelete", "fromparent" def __init__(self, uow, dependency_processor, isdelete, fromparent): self.dependency_processor = dependency_processor self.isdelete = isdelete self.fromparent = fromparent - uow.deps[dependency_processor.parent.base_mapper].\ - add(dependency_processor) + uow.deps[dependency_processor.parent.base_mapper].add( + dependency_processor + ) def execute(self, uow): states = self._elements(uow) @@ -524,7 +548,7 @@ class ProcessAll(IterateMappersMixin, PostSortRec): return "%s(%s, isdelete=%s)" % ( self.__class__.__name__, self.dependency_processor, - self.isdelete + self.isdelete, ) def _elements(self, uow): @@ -536,7 +560,7 @@ class ProcessAll(IterateMappersMixin, PostSortRec): class PostUpdateAll(PostSortRec): - __slots__ = 'mapper', 'isdelete' + __slots__ = "mapper", "isdelete" def __init__(self, uow, mapper, isdelete): self.mapper = mapper @@ -550,22 +574,23 @@ class PostUpdateAll(PostSortRec): class SaveUpdateAll(PostSortRec): - __slots__ = 'mapper', + __slots__ = ("mapper",) def __init__(self, uow, mapper): self.mapper = mapper assert mapper is mapper.base_mapper def execute(self, uow): - persistence.save_obj(self.mapper, - uow.states_for_mapper_hierarchy( - self.mapper, False, False), - uow - ) + persistence.save_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, False, False), + uow, + ) def per_state_flush_actions(self, uow): - states = list(uow.states_for_mapper_hierarchy( - self.mapper, False, False)) + states = list( + uow.states_for_mapper_hierarchy(self.mapper, False, False) + ) base_mapper = self.mapper.base_mapper delete_all = DeleteAll(uow, base_mapper) for state in states: @@ -580,29 +605,27 @@ class SaveUpdateAll(PostSortRec): dep.per_state_flush_actions(uow, states_for_prop, False) def __repr__(self): - return "%s(%s)" % ( - self.__class__.__name__, - self.mapper - ) + return "%s(%s)" % (self.__class__.__name__, self.mapper) class DeleteAll(PostSortRec): - __slots__ = 'mapper', + __slots__ = ("mapper",) def __init__(self, uow, mapper): self.mapper = mapper assert mapper is mapper.base_mapper def execute(self, uow): - persistence.delete_obj(self.mapper, - uow.states_for_mapper_hierarchy( - self.mapper, True, False), - uow - ) + persistence.delete_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, True, False), + uow, + ) def per_state_flush_actions(self, uow): - states = list(uow.states_for_mapper_hierarchy( - self.mapper, True, False)) + states = list( + uow.states_for_mapper_hierarchy(self.mapper, True, False) + ) base_mapper = self.mapper.base_mapper save_all = SaveUpdateAll(uow, base_mapper) for state 
in states: @@ -617,14 +640,11 @@ class DeleteAll(PostSortRec): dep.per_state_flush_actions(uow, states_for_prop, True) def __repr__(self): - return "%s(%s)" % ( - self.__class__.__name__, - self.mapper - ) + return "%s(%s)" % (self.__class__.__name__, self.mapper) class ProcessState(PostSortRec): - __slots__ = 'dependency_processor', 'isdelete', 'state' + __slots__ = "dependency_processor", "isdelete", "state" def __init__(self, uow, dependency_processor, isdelete, state): self.dependency_processor = dependency_processor @@ -635,10 +655,13 @@ class ProcessState(PostSortRec): cls_ = self.__class__ dependency_processor = self.dependency_processor isdelete = self.isdelete - our_recs = [r for r in recs - if r.__class__ is cls_ and - r.dependency_processor is dependency_processor and - r.isdelete is isdelete] + our_recs = [ + r + for r in recs + if r.__class__ is cls_ + and r.dependency_processor is dependency_processor + and r.isdelete is isdelete + ] recs.difference_update(our_recs) states = [self.state] + [r.state for r in our_recs] if isdelete: @@ -651,12 +674,12 @@ class ProcessState(PostSortRec): self.__class__.__name__, self.dependency_processor, orm_util.state_str(self.state), - self.isdelete + self.isdelete, ) class SaveUpdateState(PostSortRec): - __slots__ = 'state', 'mapper' + __slots__ = "state", "mapper" def __init__(self, uow, state): self.state = state @@ -665,24 +688,23 @@ class SaveUpdateState(PostSortRec): def execute_aggregate(self, uow, recs): cls_ = self.__class__ mapper = self.mapper - our_recs = [r for r in recs - if r.__class__ is cls_ and - r.mapper is mapper] + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] recs.difference_update(our_recs) - persistence.save_obj(mapper, - [self.state] + - [r.state for r in our_recs], - uow) + persistence.save_obj( + mapper, [self.state] + [r.state for r in our_recs], uow + ) def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, - orm_util.state_str(self.state) + orm_util.state_str(self.state), ) class DeleteState(PostSortRec): - __slots__ = 'state', 'mapper' + __slots__ = "state", "mapper" def __init__(self, uow, state): self.state = state @@ -691,17 +713,17 @@ class DeleteState(PostSortRec): def execute_aggregate(self, uow, recs): cls_ = self.__class__ mapper = self.mapper - our_recs = [r for r in recs - if r.__class__ is cls_ and - r.mapper is mapper] + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] recs.difference_update(our_recs) states = [self.state] + [r.state for r in our_recs] - persistence.delete_obj(mapper, - [s for s in states if uow.states[s][0]], - uow) + persistence.delete_obj( + mapper, [s for s in states if uow.states[s][0]], uow + ) def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, - orm_util.state_str(self.state) + orm_util.state_str(self.state), ) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 43709a58cc..a1b0cd5dad 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -12,27 +12,51 @@ from .interfaces import PropComparator, MapperProperty from . 
import attributes import re -from .base import instance_str, state_str, state_class_str, attribute_str, \ - state_attribute_str, object_mapper, object_state, _none_set, _never_set +from .base import ( + instance_str, + state_str, + state_class_str, + attribute_str, + state_attribute_str, + object_mapper, + object_state, + _none_set, + _never_set, +) from .base import class_mapper, _class_to_mapper from .base import InspectionAttr from .path_registry import PathRegistry -all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", - "expunge", "save-update", "refresh-expire", - "none")) +all_cascades = frozenset( + ( + "delete", + "delete-orphan", + "all", + "merge", + "expunge", + "save-update", + "refresh-expire", + "none", + ) +) class CascadeOptions(frozenset): """Keeps track of the options sent to relationship().cascade""" - _add_w_all_cascades = all_cascades.difference([ - 'all', 'none', 'delete-orphan']) + _add_w_all_cascades = all_cascades.difference( + ["all", "none", "delete-orphan"] + ) _allowed_cascades = all_cascades __slots__ = ( - 'save_update', 'delete', 'refresh_expire', 'merge', - 'expunge', 'delete_orphan') + "save_update", + "delete", + "refresh_expire", + "merge", + "expunge", + "delete_orphan", + ) def __new__(cls, value_list): if isinstance(value_list, util.string_types) or value_list is None: @@ -40,60 +64,62 @@ class CascadeOptions(frozenset): values = set(value_list) if values.difference(cls._allowed_cascades): raise sa_exc.ArgumentError( - "Invalid cascade option(s): %s" % - ", ".join([repr(x) for x in - sorted(values.difference(cls._allowed_cascades))])) + "Invalid cascade option(s): %s" + % ", ".join( + [ + repr(x) + for x in sorted( + values.difference(cls._allowed_cascades) + ) + ] + ) + ) if "all" in values: values.update(cls._add_w_all_cascades) if "none" in values: values.clear() - values.discard('all') + values.discard("all") self = frozenset.__new__(CascadeOptions, values) - self.save_update = 'save-update' in values - self.delete = 'delete' in values - self.refresh_expire = 'refresh-expire' in values - self.merge = 'merge' in values - self.expunge = 'expunge' in values + self.save_update = "save-update" in values + self.delete = "delete" in values + self.refresh_expire = "refresh-expire" in values + self.merge = "merge" in values + self.expunge = "expunge" in values self.delete_orphan = "delete-orphan" in values if self.delete_orphan and not self.delete: - util.warn("The 'delete-orphan' cascade " - "option requires 'delete'.") + util.warn( + "The 'delete-orphan' cascade " "option requires 'delete'." + ) return self def __repr__(self): - return "CascadeOptions(%r)" % ( - ",".join([x for x in sorted(self)]) - ) + return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)])) @classmethod def from_string(cls, arg): - values = [ - c for c - in re.split(r'\s*,\s*', arg or "") - if c - ] + values = [c for c in re.split(r"\s*,\s*", arg or "") if c] return cls(values) -def _validator_events( - desc, key, validator, include_removes, include_backrefs): +def _validator_events(desc, key, validator, include_removes, include_backrefs): """Runs a validation method on an attribute value to be set or appended. 
""" if not include_backrefs: + def detect_is_backref(state, initiator): impl = state.manager[key].impl return initiator.impl is not impl if include_removes: + def append(state, value, initiator): - if ( - initiator.op is not attributes.OP_BULK_REPLACE and - (include_backrefs or not detect_is_backref(state, initiator)) + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) ): return validator(state.obj(), key, value, False) else: @@ -103,7 +129,8 @@ def _validator_events( if include_backrefs or not detect_is_backref(state, initiator): obj = state.obj() values[:] = [ - validator(obj, key, value, False) for value in values] + validator(obj, key, value, False) for value in values + ] def set_(state, value, oldvalue, initiator): if include_backrefs or not detect_is_backref(state, initiator): @@ -116,10 +143,10 @@ def _validator_events( validator(state.obj(), key, value, True) else: + def append(state, value, initiator): - if ( - initiator.op is not attributes.OP_BULK_REPLACE and - (include_backrefs or not detect_is_backref(state, initiator)) + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) ): return validator(state.obj(), key, value) else: @@ -128,8 +155,7 @@ def _validator_events( def bulk_set(state, values, initiator): if include_backrefs or not detect_is_backref(state, initiator): obj = state.obj() - values[:] = [ - validator(obj, key, value) for value in values] + values[:] = [validator(obj, key, value) for value in values] def set_(state, value, oldvalue, initiator): if include_backrefs or not detect_is_backref(state, initiator): @@ -137,15 +163,16 @@ def _validator_events( else: return value - event.listen(desc, 'append', append, raw=True, retval=True) - event.listen(desc, 'bulk_replace', bulk_set, raw=True) - event.listen(desc, 'set', set_, raw=True, retval=True) + event.listen(desc, "append", append, raw=True, retval=True) + event.listen(desc, "bulk_replace", bulk_set, raw=True) + event.listen(desc, "set", set_, raw=True, retval=True) if include_removes: event.listen(desc, "remove", remove, raw=True, retval=True) -def polymorphic_union(table_map, typecolname, - aliasname='p_union', cast_nulls=True): +def polymorphic_union( + table_map, typecolname, aliasname="p_union", cast_nulls=True +): """Create a ``UNION`` statement used by a polymorphic mapper. See :ref:`concrete_inheritance` for an example of how @@ -197,14 +224,22 @@ def polymorphic_union(table_map, typecolname, for type, table in table_map.items(): if typecolname is not None: result.append( - sql.select([col(name, table) for name in colnames] + - [sql.literal_column( - sql_util._quote_ddl_expr(type)). 
- label(typecolname)], - from_obj=[table])) + sql.select( + [col(name, table) for name in colnames] + + [ + sql.literal_column( + sql_util._quote_ddl_expr(type) + ).label(typecolname) + ], + from_obj=[table], + ) + ) else: - result.append(sql.select([col(name, table) for name in colnames], - from_obj=[table])) + result.append( + sql.select( + [col(name, table) for name in colnames], from_obj=[table] + ) + ) return sql.union_all(*result).alias(aliasname) @@ -284,25 +319,29 @@ first() class_, ident = args else: raise sa_exc.ArgumentError( - "expected up to three positional arguments, " - "got %s" % largs) + "expected up to three positional arguments, " "got %s" % largs + ) identity_token = kwargs.pop("identity_token", None) if kwargs: - raise sa_exc.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs)) + raise sa_exc.ArgumentError( + "unknown keyword arguments: %s" % ", ".join(kwargs) + ) mapper = class_mapper(class_) if row is None: return mapper.identity_key_from_primary_key( - util.to_list(ident), identity_token=identity_token) + util.to_list(ident), identity_token=identity_token + ) else: return mapper.identity_key_from_row( - row, identity_token=identity_token) + row, identity_token=identity_token + ) else: instance = kwargs.pop("instance") if kwargs: - raise sa_exc.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs.keys)) + raise sa_exc.ArgumentError( + "unknown keyword arguments: %s" % ", ".join(kwargs.keys) + ) mapper = object_mapper(instance) return mapper.identity_key_from_instance(instance) @@ -313,9 +352,15 @@ class ORMAdapter(sql_util.ColumnAdapter): """ - def __init__(self, entity, equivalents=None, adapt_required=False, - chain_to=None, allow_label_resolve=True, - anonymize_labels=False): + def __init__( + self, + entity, + equivalents=None, + adapt_required=False, + chain_to=None, + allow_label_resolve=True, + anonymize_labels=False, + ): info = inspection.inspect(entity) self.mapper = info.mapper @@ -327,15 +372,18 @@ class ORMAdapter(sql_util.ColumnAdapter): self.aliased_class = None sql_util.ColumnAdapter.__init__( - self, selectable, equivalents, chain_to, + self, + selectable, + equivalents, + chain_to, adapt_required=adapt_required, allow_label_resolve=allow_label_resolve, anonymize_labels=anonymize_labels, - include_fn=self._include_fn + include_fn=self._include_fn, ) def _include_fn(self, elem): - entity = elem._annotations.get('parentmapper', None) + entity = elem._annotations.get("parentmapper", None) return not entity or entity.isa(self.mapper) @@ -380,20 +428,25 @@ class AliasedClass(object): """ - def __init__(self, cls, alias=None, - name=None, - flat=False, - adapt_on_names=False, - # TODO: None for default here? - with_polymorphic_mappers=(), - with_polymorphic_discriminator=None, - base_alias=None, - use_mapper_path=False, - represents_outer_join=False): + def __init__( + self, + cls, + alias=None, + name=None, + flat=False, + adapt_on_names=False, + # TODO: None for default here? 
+        with_polymorphic_mappers=(),
+        with_polymorphic_discriminator=None,
+        base_alias=None,
+        use_mapper_path=False,
+        represents_outer_join=False,
+    ):
         mapper = _class_to_mapper(cls)
         if alias is None:
             alias = mapper._with_polymorphic_selectable.alias(
-                name=name, flat=flat)
+                name=name, flat=flat
+            )
 
         self._aliased_insp = AliasedInsp(
             self,
@@ -409,14 +462,14 @@ class AliasedClass(object):
             base_alias,
             use_mapper_path,
             adapt_on_names,
-            represents_outer_join
+            represents_outer_join,
         )
 
-        self.__name__ = 'AliasedClass_%s' % mapper.class_.__name__
+        self.__name__ = "AliasedClass_%s" % mapper.class_.__name__
 
     def __getattr__(self, key):
         try:
-            _aliased_insp = self.__dict__['_aliased_insp']
+            _aliased_insp = self.__dict__["_aliased_insp"]
         except KeyError:
             raise AttributeError()
         else:
@@ -434,13 +487,13 @@ class AliasedClass(object):
                 ret = attr.adapt_to_entity(_aliased_insp)
                 setattr(self, key, ret)
             return ret
-        elif hasattr(attr, 'func_code'):
+        elif hasattr(attr, "func_code"):
             is_method = getattr(_aliased_insp._target, key, None)
             if is_method and is_method.__self__ is not None:
                 return util.types.MethodType(attr.__func__, self, self)
             else:
                 return None
-        elif hasattr(attr, '__get__'):
+        elif hasattr(attr, "__get__"):
             ret = attr.__get__(None, self)
             if isinstance(ret, PropComparator):
                 return ret.adapt_to_entity(_aliased_insp)
@@ -450,8 +503,10 @@ class AliasedClass(object):
         return attr
 
     def __repr__(self):
-        return '<AliasedClass at 0x%x; %s>' % (
-            id(self), self._aliased_insp._target.__name__)
+        return "<AliasedClass at 0x%x; %s>" % (
+            id(self),
+            self._aliased_insp._target.__name__,
+        )
 
 
 class AliasedInsp(InspectionAttr):
@@ -490,10 +545,19 @@ class AliasedInsp(InspectionAttr):
 
     """
 
-    def __init__(self, entity, mapper, selectable, name,
-                 with_polymorphic_mappers, polymorphic_on,
-                 _base_alias, _use_mapper_path, adapt_on_names,
-                 represents_outer_join):
+    def __init__(
+        self,
+        entity,
+        mapper,
+        selectable,
+        name,
+        with_polymorphic_mappers,
+        polymorphic_on,
+        _base_alias,
+        _use_mapper_path,
+        adapt_on_names,
+        represents_outer_join,
+    ):
         self.entity = entity
         self.mapper = mapper
         self.selectable = selectable
@@ -505,18 +569,28 @@ class AliasedInsp(InspectionAttr):
         self.represents_outer_join = represents_outer_join
 
         self._adapter = sql_util.ColumnAdapter(
-            selectable, equivalents=mapper._equivalent_columns,
-            adapt_on_names=adapt_on_names, anonymize_labels=True)
+            selectable,
+            equivalents=mapper._equivalent_columns,
+            adapt_on_names=adapt_on_names,
+            anonymize_labels=True,
+        )
 
         self._adapt_on_names = adapt_on_names
         self._target = mapper.class_
 
         for poly in self.with_polymorphic_mappers:
             if poly is not mapper:
-                setattr(self.entity, poly.class_.__name__,
-                        AliasedClass(poly.class_, selectable, base_alias=self,
-                                     adapt_on_names=adapt_on_names,
-                                     use_mapper_path=_use_mapper_path))
+                setattr(
+                    self.entity,
+                    poly.class_.__name__,
+                    AliasedClass(
+                        poly.class_,
+                        selectable,
+                        base_alias=self,
+                        adapt_on_names=adapt_on_names,
+                        use_mapper_path=_use_mapper_path,
+                    ),
+                )
 
     is_aliased_class = True
     "always returns True"
@@ -536,39 +610,35 @@ class AliasedInsp(InspectionAttr):
 
     def __getstate__(self):
         return {
-            'entity': self.entity,
-            'mapper': self.mapper,
-            'alias': self.selectable,
-            'name': self.name,
-            'adapt_on_names': self._adapt_on_names,
-            'with_polymorphic_mappers':
-                self.with_polymorphic_mappers,
-            'with_polymorphic_discriminator':
-                self.polymorphic_on,
-            'base_alias': self._base_alias,
-            'use_mapper_path': self._use_mapper_path,
-            'represents_outer_join': self.represents_outer_join
+            "entity": self.entity,
+            "mapper": self.mapper,
+            "alias": self.selectable,
+            "name": self.name,
+            "adapt_on_names": self._adapt_on_names,
+            "with_polymorphic_mappers": self.with_polymorphic_mappers,
+            "with_polymorphic_discriminator": self.polymorphic_on,
+            "base_alias": self._base_alias,
+            "use_mapper_path": self._use_mapper_path,
+            "represents_outer_join": self.represents_outer_join,
         }
 
     def __setstate__(self, state):
         self.__init__(
-            state['entity'],
-            state['mapper'],
-            state['alias'],
-            state['name'],
-            state['with_polymorphic_mappers'],
-            state['with_polymorphic_discriminator'],
-            state['base_alias'],
-            state['use_mapper_path'],
-            state['adapt_on_names'],
-            state['represents_outer_join']
+            state["entity"],
+            state["mapper"],
+            state["alias"],
+            state["name"],
+            state["with_polymorphic_mappers"],
+            state["with_polymorphic_discriminator"],
+            state["base_alias"],
+            state["use_mapper_path"],
+            state["adapt_on_names"],
+            state["represents_outer_join"],
        )
 
     def _adapt_element(self, elem):
-        return self._adapter.traverse(elem).\
-            _annotate({
-                'parententity': self,
-                'parentmapper': self.mapper}
+        return self._adapter.traverse(elem)._annotate(
+            {"parententity": self, "parentmapper": self.mapper}
         )
 
     def _entity_for_mapper(self, mapper):
@@ -578,12 +648,12 @@ class AliasedInsp(InspectionAttr):
                 return self
             else:
                 return getattr(
-                    self.entity, mapper.class_.__name__)._aliased_insp
+                    self.entity, mapper.class_.__name__
+                )._aliased_insp
         elif mapper.isa(self.mapper):
             return self
         else:
-            assert False, "mapper %s doesn't correspond to %s" % (
-                mapper, self)
+            assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
 
     @util.memoized_property
     def _memoized_values(self):
@@ -599,11 +669,15 @@ class AliasedInsp(InspectionAttr):
     def __repr__(self):
         if self.with_polymorphic_mappers:
             with_poly = "(%s)" % ", ".join(
-                mp.class_.__name__ for mp in self.with_polymorphic_mappers)
+                mp.class_.__name__ for mp in self.with_polymorphic_mappers
+            )
         else:
             with_poly = ""
-        return '<AliasedInsp at 0x%x; %s%s>' % (
-            id(self), self.class_.__name__, with_poly)
+        return "<AliasedInsp at 0x%x; %s%s>" % (
+            id(self),
+            self.class_.__name__,
+            with_poly,
+        )
 
 
 inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
@@ -700,15 +774,26 @@ def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False):
         )
         return element.alias(name, flat=flat)
     else:
-        return AliasedClass(element, alias=alias, flat=flat,
-                            name=name, adapt_on_names=adapt_on_names)
+        return AliasedClass(
+            element,
+            alias=alias,
+            flat=flat,
+            name=name,
+            adapt_on_names=adapt_on_names,
+        )
 
 
-def with_polymorphic(base, classes, selectable=False,
-                     flat=False,
-                     polymorphic_on=None, aliased=False,
-                     innerjoin=False, _use_mapper_path=False,
-                     _existing_alias=None):
+def with_polymorphic(
+    base,
+    classes,
+    selectable=False,
+    flat=False,
+    polymorphic_on=None,
+    aliased=False,
+    innerjoin=False,
+    _use_mapper_path=False,
+    _existing_alias=None,
+):
     """Produce an :class:`.AliasedClass` construct which specifies
     columns for descendant mappers of the given base.
@@ -777,24 +862,26 @@ def with_polymorphic(base, classes, selectable=False, if _existing_alias: assert _existing_alias.mapper is primary_mapper classes = util.to_set(classes) - new_classes = set([ - mp.class_ for mp in - _existing_alias.with_polymorphic_mappers]) + new_classes = set( + [mp.class_ for mp in _existing_alias.with_polymorphic_mappers] + ) if classes == new_classes: return _existing_alias else: classes = classes.union(new_classes) - mappers, selectable = primary_mapper.\ - _with_polymorphic_args(classes, selectable, - innerjoin=innerjoin) + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) if aliased or flat: selectable = selectable.alias(flat=flat) - return AliasedClass(base, - selectable, - with_polymorphic_mappers=mappers, - with_polymorphic_discriminator=polymorphic_on, - use_mapper_path=_use_mapper_path, - represents_outer_join=not innerjoin) + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + represents_outer_join=not innerjoin, + ) def _orm_annotate(element, exclude=None): @@ -804,7 +891,7 @@ def _orm_annotate(element, exclude=None): Elements within the exclude collection will be cloned but not annotated. """ - return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude) + return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude) def _orm_deannotate(element): @@ -816,9 +903,9 @@ def _orm_deannotate(element): """ - return sql_util._deep_deannotate(element, - values=("_orm_adapt", "parententity") - ) + return sql_util._deep_deannotate( + element, values=("_orm_adapt", "parententity") + ) def _orm_full_deannotate(element): @@ -831,12 +918,18 @@ class _ORMJoin(expression.Join): __visit_name__ = expression.Join.__visit_name__ def __init__( - self, - left, right, onclause=None, isouter=False, - full=False, _left_memo=None, _right_memo=None): + self, + left, + right, + onclause=None, + isouter=False, + full=False, + _left_memo=None, + _right_memo=None, + ): left_info = inspection.inspect(left) - left_orm_info = getattr(left, '_joined_from_info', left_info) + left_orm_info = getattr(left, "_joined_from_info", left_info) right_info = inspection.inspect(right) adapt_to = right_info.selectable @@ -859,19 +952,18 @@ class _ORMJoin(expression.Join): prop = None if prop: - if sql_util.clause_is_present( - on_selectable, left_info.selectable): + if sql_util.clause_is_present(on_selectable, left_info.selectable): adapt_from = on_selectable else: adapt_from = left_info.selectable - pj, sj, source, dest, \ - secondary, target_adapter = prop._create_joins( - source_selectable=adapt_from, - dest_selectable=adapt_to, - source_polymorphic=True, - dest_polymorphic=True, - of_type=right_info.mapper) + pj, sj, source, dest, secondary, target_adapter = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + dest_polymorphic=True, + of_type=right_info.mapper, + ) if sj is not None: if isouter: @@ -887,8 +979,11 @@ class _ORMJoin(expression.Join): expression.Join.__init__(self, left, right, onclause, isouter, full) - if not prop and getattr(right_info, 'mapper', None) \ - and right_info.mapper.single: + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single + ): # if single inheritance target and we are using a manual # or implicit ON clause, augment it the same way we'd augment the # WHERE. 
@@ -911,33 +1006,39 @@ class _ORMJoin(expression.Join):
         assert self.right is leftmost

         left = _ORMJoin(
-            self.left, other.left,
-            self.onclause, isouter=self.isouter,
+            self.left,
+            other.left,
+            self.onclause,
+            isouter=self.isouter,
             _left_memo=self._left_memo,
-            _right_memo=other._left_memo
+            _right_memo=other._left_memo,
         )

         return _ORMJoin(
             left,
             other.right,
-            other.onclause, isouter=other.isouter,
-            _right_memo=other._right_memo
+            other.onclause,
+            isouter=other.isouter,
+            _right_memo=other._right_memo,
         )

     def join(
-            self, right, onclause=None,
-            isouter=False, full=False, join_to_left=None):
+        self,
+        right,
+        onclause=None,
+        isouter=False,
+        full=False,
+        join_to_left=None,
+    ):
         return _ORMJoin(self, right, onclause, isouter, full)

-    def outerjoin(
-            self, right, onclause=None,
-            full=False, join_to_left=None):
+    def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
         return _ORMJoin(self, right, onclause, True, full=full)


 def join(
-        left, right, onclause=None, isouter=False,
-        full=False, join_to_left=None):
+    left, right, onclause=None, isouter=False, full=False, join_to_left=None
+):
     r"""Produce an inner join between left and right clauses.

     :func:`.orm.join` is an extension to the core join interface
@@ -1085,8 +1186,9 @@ def _entity_isa(given, mapper):
     """
     if given.is_aliased_class:
-        return mapper in given.with_polymorphic_mappers or \
-            given.mapper.isa(mapper)
+        return mapper in given.with_polymorphic_mappers or given.mapper.isa(
+            mapper
+        )
     elif given.with_polymorphic_mappers:
         return mapper in given.with_polymorphic_mappers
     else:
@@ -1126,5 +1228,7 @@ def randomize_unitofwork():
     from sqlalchemy.orm import unitofwork, session, mapper, dependency
     from sqlalchemy.util import topological
     from sqlalchemy.testing.util import RandomSet
-    topological.set = unitofwork.set = session.set = mapper.set = \
-        dependency.set = RandomSet
+
+    topological.set = (
+        unitofwork.set
+    ) = session.set = mapper.set = dependency.set = RandomSet
diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py
index f2f0350518..2aa6eeeb7b 100644
--- a/lib/sqlalchemy/pool/__init__.py
+++ b/lib/sqlalchemy/pool/__init__.py
@@ -20,7 +20,12 @@ SQLAlchemy connection pool.
 from .base import _refs  # noqa
 from .base import Pool  # noqa
 from .impl import (  # noqa
-    QueuePool, StaticPool, NullPool, AssertionPool, SingletonThreadPool)
+    QueuePool,
+    StaticPool,
+    NullPool,
+    AssertionPool,
+    SingletonThreadPool,
+)
 from .dbapi_proxy import manage, clear_managers  # noqa
 from .base import reset_rollback, reset_commit, reset_none  # noqa
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 442d3b64a5..382e740c6b 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -18,9 +18,9 @@ from .. import exc, log, event, interfaces, util
 from ..util import threading

-reset_rollback = util.symbol('reset_rollback')
-reset_commit = util.symbol('reset_commit')
-reset_none = util.symbol('reset_none')
+reset_rollback = util.symbol("reset_rollback")
+reset_commit = util.symbol("reset_commit")
+reset_none = util.symbol("reset_none")


 class _ConnDialect(object):
@@ -46,7 +46,8 @@ class _ConnDialect(object):
     def do_ping(self, dbapi_connection):
         raise NotImplementedError(
             "The ping feature requires that a dialect is "
-            "passed to the connection pool.")
+            "passed to the connection pool."
+ ) class Pool(log.Identified): @@ -55,16 +56,20 @@ class Pool(log.Identified): _dialect = _ConnDialect() - def __init__(self, - creator, recycle=-1, echo=None, - use_threadlocal=False, - logging_name=None, - reset_on_return=True, - listeners=None, - events=None, - dialect=None, - pre_ping=False, - _dispatch=None): + def __init__( + self, + creator, + recycle=-1, + echo=None, + use_threadlocal=False, + logging_name=None, + reset_on_return=True, + listeners=None, + events=None, + dialect=None, + pre_ping=False, + _dispatch=None, + ): """ Construct a Pool. @@ -200,16 +205,16 @@ class Pool(log.Identified): self._invalidate_time = 0 self._use_threadlocal = use_threadlocal self._pre_ping = pre_ping - if reset_on_return in ('rollback', True, reset_rollback): + if reset_on_return in ("rollback", True, reset_rollback): self._reset_on_return = reset_rollback - elif reset_on_return in ('none', None, False, reset_none): + elif reset_on_return in ("none", None, False, reset_none): self._reset_on_return = reset_none - elif reset_on_return in ('commit', reset_commit): + elif reset_on_return in ("commit", reset_commit): self._reset_on_return = reset_commit else: raise exc.ArgumentError( - "Invalid value for 'reset_on_return': %r" - % reset_on_return) + "Invalid value for 'reset_on_return': %r" % reset_on_return + ) self.echo = echo @@ -223,17 +228,18 @@ class Pool(log.Identified): if listeners: util.warn_deprecated( "The 'listeners' argument to Pool (and " - "create_engine()) is deprecated. Use event.listen().") + "create_engine()) is deprecated. Use event.listen()." + ) for l in listeners: self.add_listener(l) @property def _creator(self): - return self.__dict__['_creator'] + return self.__dict__["_creator"] @_creator.setter def _creator(self, creator): - self.__dict__['_creator'] = creator + self.__dict__["_creator"] = creator self._invoke_creator = self._should_wrap_creator(creator) def _should_wrap_creator(self, creator): @@ -252,7 +258,7 @@ class Pool(log.Identified): # look for the exact arg signature that DefaultStrategy # sends us - if (argspec[0], argspec[3]) == (['connection_record'], (None,)): + if (argspec[0], argspec[3]) == (["connection_record"], (None,)): return creator # or just a single positional elif positionals == 1: @@ -268,11 +274,13 @@ class Pool(log.Identified): try: self._dialect.do_close(connection) except Exception: - self.logger.error("Exception closing connection %r", - connection, exc_info=True) + self.logger.error( + "Exception closing connection %r", connection, exc_info=True + ) @util.deprecated( - 2.7, "Pool.add_listener is deprecated. Use event.listen()") + 2.7, "Pool.add_listener is deprecated. Use event.listen()" + ) def add_listener(self, listener): """Add a :class:`.PoolListener`-like object to this pool. 
@@ -315,7 +323,7 @@ class Pool(log.Identified): rec = getattr(connection, "_connection_record", None) if not rec or self._invalidate_time < rec.starttime: self._invalidate_time = time.time() - if _checkin and getattr(connection, 'is_valid', False): + if _checkin and getattr(connection, "is_valid", False): connection.invalidate(exception) def recreate(self): @@ -491,15 +499,14 @@ class _ConnectionRecord(object): fairy = _ConnectionFairy(dbapi_connection, rec, echo) rec.fairy_ref = weakref.ref( fairy, - lambda ref: _finalize_fairy and - _finalize_fairy( - None, - rec, pool, ref, echo) + lambda ref: _finalize_fairy + and _finalize_fairy(None, rec, pool, ref, echo), ) _refs.add(rec) if echo: - pool.logger.debug("Connection %r checked out from pool", - dbapi_connection) + pool.logger.debug( + "Connection %r checked out from pool", dbapi_connection + ) return fairy def _checkin_failed(self, err): @@ -563,12 +570,16 @@ class _ConnectionRecord(object): self.__pool.logger.info( "%sInvalidate connection %r (reason: %s:%s)", "Soft " if soft else "", - self.connection, e.__class__.__name__, e) + self.connection, + e.__class__.__name__, + e, + ) else: self.__pool.logger.info( "%sInvalidate connection %r", "Soft " if soft else "", - self.connection) + self.connection, + ) if soft: self._soft_invalidate_time = time.time() else: @@ -580,24 +591,26 @@ class _ConnectionRecord(object): if self.connection is None: self.info.clear() self.__connect() - elif self.__pool._recycle > -1 and \ - time.time() - self.starttime > self.__pool._recycle: + elif ( + self.__pool._recycle > -1 + and time.time() - self.starttime > self.__pool._recycle + ): self.__pool.logger.info( - "Connection %r exceeded timeout; recycling", - self.connection) + "Connection %r exceeded timeout; recycling", self.connection + ) recycle = True elif self.__pool._invalidate_time > self.starttime: self.__pool.logger.info( - "Connection %r invalidated due to pool invalidation; " + - "recycling", - self.connection + "Connection %r invalidated due to pool invalidation; " + + "recycling", + self.connection, ) recycle = True elif self._soft_invalidate_time > self.starttime: self.__pool.logger.info( - "Connection %r invalidated due to local soft invalidation; " + - "recycling", - self.connection + "Connection %r invalidated due to local soft invalidation; " + + "recycling", + self.connection, ) recycle = True @@ -631,15 +644,16 @@ class _ConnectionRecord(object): raise else: if first_connect_check: - pool.dispatch.first_connect.\ - for_modify(pool.dispatch).\ - exec_once(self.connection, self) + pool.dispatch.first_connect.for_modify( + pool.dispatch + ).exec_once(self.connection, self) if pool.dispatch.connect: pool.dispatch.connect(self.connection, self) -def _finalize_fairy(connection, connection_record, - pool, ref, echo, fairy=None): +def _finalize_fairy( + connection, connection_record, pool, ref, echo, fairy=None +): """Cleanup for a :class:`._ConnectionFairy` whether or not it's already been garbage collected. 
@@ -654,12 +668,14 @@ def _finalize_fairy(connection, connection_record, if connection is not None: if connection_record and echo: - pool.logger.debug("Connection %r being returned to pool", - connection) + pool.logger.debug( + "Connection %r being returned to pool", connection + ) try: fairy = fairy or _ConnectionFairy( - connection, connection_record, echo) + connection, connection_record, echo + ) assert fairy.connection is connection fairy._reset(pool) @@ -670,7 +686,8 @@ def _finalize_fairy(connection, connection_record, pool._close_connection(connection) except BaseException as e: pool.logger.error( - "Exception during reset or similar", exc_info=True) + "Exception during reset or similar", exc_info=True + ) if connection_record: connection_record.invalidate(e=e) if not isinstance(e, Exception): @@ -752,8 +769,9 @@ class _ConnectionFairy(object): raise exc.InvalidRequestError("This connection is closed") fairy._counter += 1 - if (not pool.dispatch.checkout and not pool._pre_ping) or \ - fairy._counter != 1: + if ( + not pool.dispatch.checkout and not pool._pre_ping + ) or fairy._counter != 1: return fairy # Pool listeners can trigger a reconnection on checkout, as well @@ -767,38 +785,45 @@ class _ConnectionFairy(object): if pool._pre_ping: if fairy._echo: pool.logger.debug( - "Pool pre-ping on connection %s", - fairy.connection) + "Pool pre-ping on connection %s", fairy.connection + ) result = pool._dialect.do_ping(fairy.connection) if not result: if fairy._echo: pool.logger.debug( "Pool pre-ping on connection %s failed, " - "will invalidate pool", fairy.connection) + "will invalidate pool", + fairy.connection, + ) raise exc.InvalidatePoolError() - pool.dispatch.checkout(fairy.connection, - fairy._connection_record, - fairy) + pool.dispatch.checkout( + fairy.connection, fairy._connection_record, fairy + ) return fairy except exc.DisconnectionError as e: if e.invalidate_pool: pool.logger.info( "Disconnection detected on checkout, " "invalidating all pooled connections prior to " - "current timestamp (reason: %r)", e) + "current timestamp (reason: %r)", + e, + ) fairy._connection_record.invalidate(e) pool._invalidate(fairy, e, _checkin=False) else: pool.logger.info( "Disconnection detected on checkout, " "invalidating individual connection %s (reason: %r)", - fairy.connection, e) + fairy.connection, + e, + ) fairy._connection_record.invalidate(e) try: - fairy.connection = \ + fairy.connection = ( fairy._connection_record.get_connection() + ) except Exception as err: with util.safe_reraise(): fairy._connection_record._checkin_failed(err) @@ -813,8 +838,14 @@ class _ConnectionFairy(object): return _ConnectionFairy._checkout(self._pool, fairy=self) def _checkin(self): - _finalize_fairy(self.connection, self._connection_record, - self._pool, None, self._echo, fairy=self) + _finalize_fairy( + self.connection, + self._connection_record, + self._pool, + None, + self._echo, + fairy=self, + ) self.connection = None self._connection_record = None @@ -825,20 +856,22 @@ class _ConnectionFairy(object): pool.dispatch.reset(self, self._connection_record) if pool._reset_on_return is reset_rollback: if self._echo: - pool.logger.debug("Connection %s rollback-on-return%s", - self.connection, - ", via agent" - if self._reset_agent else "") + pool.logger.debug( + "Connection %s rollback-on-return%s", + self.connection, + ", via agent" if self._reset_agent else "", + ) if self._reset_agent: self._reset_agent.rollback() else: pool._dialect.do_rollback(self) elif pool._reset_on_return is reset_commit: if 
self._echo: - pool.logger.debug("Connection %s commit-on-return%s", - self.connection, - ", via agent" - if self._reset_agent else "") + pool.logger.debug( + "Connection %s commit-on-return%s", + self.connection, + ", via agent" if self._reset_agent else "", + ) if self._reset_agent: self._reset_agent.commit() else: @@ -964,5 +997,3 @@ class _ConnectionFairy(object): self._counter -= 1 if self._counter == 0: self._checkin() - - diff --git a/lib/sqlalchemy/pool/dbapi_proxy.py b/lib/sqlalchemy/pool/dbapi_proxy.py index aa439bd239..425c4a1145 100644 --- a/lib/sqlalchemy/pool/dbapi_proxy.py +++ b/lib/sqlalchemy/pool/dbapi_proxy.py @@ -101,9 +101,10 @@ class _DBProxy(object): self._create_pool_mutex.acquire() try: if key not in self.pools: - kw.pop('sa_pool_key', None) + kw.pop("sa_pool_key", None) pool = self.poolclass( - lambda: self.module.connect(*args, **kw), **self.kw) + lambda: self.module.connect(*args, **kw), **self.kw + ) self.pools[key] = pool return pool else: @@ -138,9 +139,6 @@ class _DBProxy(object): def _serialize(self, *args, **kw): if "sa_pool_key" in kw: - return kw['sa_pool_key'] + return kw["sa_pool_key"] - return tuple( - list(args) + - [(k, kw[k]) for k in sorted(kw)] - ) + return tuple(list(args) + [(k, kw[k]) for k in sorted(kw)]) diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 3058d62472..6159f6a5b3 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -30,8 +30,15 @@ class QueuePool(Pool): """ - def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30, use_lifo=False, - **kw): + def __init__( + self, + creator, + pool_size=5, + max_overflow=10, + timeout=30, + use_lifo=False, + **kw + ): r""" Construct a QueuePool. @@ -117,8 +124,10 @@ class QueuePool(Pool): else: raise exc.TimeoutError( "QueuePool limit of size %d overflow %d reached, " - "connection timed out, timeout %d" % - (self.size(), self.overflow(), self._timeout), code="3o7r") + "connection timed out, timeout %d" + % (self.size(), self.overflow(), self._timeout), + code="3o7r", + ) if self._inc_overflow(): try: @@ -150,15 +159,19 @@ class QueuePool(Pool): def recreate(self): self.logger.info("Pool recreating") - return self.__class__(self._creator, pool_size=self._pool.maxsize, - max_overflow=self._max_overflow, - timeout=self._timeout, - recycle=self._recycle, echo=self.echo, - logging_name=self._orig_logging_name, - use_threadlocal=self._use_threadlocal, - reset_on_return=self._reset_on_return, - _dispatch=self.dispatch, - dialect=self._dialect) + return self.__class__( + self._creator, + pool_size=self._pool.maxsize, + max_overflow=self._max_overflow, + timeout=self._timeout, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) def dispose(self): while True: @@ -172,12 +185,17 @@ class QueuePool(Pool): self.logger.info("Pool disposed. 
%s", self.status()) def status(self): - return "Pool size: %d Connections in pool: %d "\ - "Current Overflow: %d Current Checked out "\ - "connections: %d" % (self.size(), - self.checkedin(), - self.overflow(), - self.checkedout()) + return ( + "Pool size: %d Connections in pool: %d " + "Current Overflow: %d Current Checked out " + "connections: %d" + % ( + self.size(), + self.checkedin(), + self.overflow(), + self.checkedout(), + ) + ) def size(self): return self._pool.maxsize @@ -221,14 +239,16 @@ class NullPool(Pool): def recreate(self): self.logger.info("Pool recreating") - return self.__class__(self._creator, - recycle=self._recycle, - echo=self.echo, - logging_name=self._orig_logging_name, - use_threadlocal=self._use_threadlocal, - reset_on_return=self._reset_on_return, - _dispatch=self.dispatch, - dialect=self._dialect) + return self.__class__( + self._creator, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) def dispose(self): pass @@ -266,7 +286,7 @@ class SingletonThreadPool(Pool): """ def __init__(self, creator, pool_size=5, **kw): - kw['use_threadlocal'] = True + kw["use_threadlocal"] = True Pool.__init__(self, creator, **kw) self._conn = threading.local() self._all_conns = set() @@ -274,15 +294,17 @@ class SingletonThreadPool(Pool): def recreate(self): self.logger.info("Pool recreating") - return self.__class__(self._creator, - pool_size=self.size, - recycle=self._recycle, - echo=self.echo, - logging_name=self._orig_logging_name, - use_threadlocal=self._use_threadlocal, - reset_on_return=self._reset_on_return, - _dispatch=self.dispatch, - dialect=self._dialect) + return self.__class__( + self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) def dispose(self): """Dispose of this pool.""" @@ -303,8 +325,10 @@ class SingletonThreadPool(Pool): c.close() def status(self): - return "SingletonThreadPool id:%d size: %d" % \ - (id(self), len(self._all_conns)) + return "SingletonThreadPool id:%d size: %d" % ( + id(self), + len(self._all_conns), + ) def _do_return_conn(self, conn): pass @@ -347,20 +371,22 @@ class StaticPool(Pool): return "StaticPool" def dispose(self): - if '_conn' in self.__dict__: + if "_conn" in self.__dict__: self._conn.close() self._conn = None def recreate(self): self.logger.info("Pool recreating") - return self.__class__(creator=self._creator, - recycle=self._recycle, - use_threadlocal=self._use_threadlocal, - reset_on_return=self._reset_on_return, - echo=self.echo, - logging_name=self._orig_logging_name, - _dispatch=self.dispatch, - dialect=self._dialect) + return self.__class__( + creator=self._creator, + recycle=self._recycle, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect, + ) def _create_connection(self): return self._conn @@ -391,7 +417,7 @@ class AssertionPool(Pool): def __init__(self, *args, **kw): self._conn = None self._checked_out = False - self._store_traceback = kw.pop('store_traceback', True) + self._store_traceback = kw.pop("store_traceback", True) self._checkout_traceback = None Pool.__init__(self, *args, **kw) @@ -411,18 +437,22 @@ 
class AssertionPool(Pool): def recreate(self): self.logger.info("Pool recreating") - return self.__class__(self._creator, echo=self.echo, - logging_name=self._orig_logging_name, - _dispatch=self.dispatch, - dialect=self._dialect) + return self.__class__( + self._creator, + echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect, + ) def _do_get(self): if self._checked_out: if self._checkout_traceback: - suffix = ' at:\n%s' % ''.join( - chop_traceback(self._checkout_traceback)) + suffix = " at:\n%s" % "".join( + chop_traceback(self._checkout_traceback) + ) else: - suffix = '' + suffix = "" raise AssertionError("connection is already checked out" + suffix) if not self._conn: diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index 860a55b8f5..46d5dcbc63 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -32,20 +32,30 @@ def str_to_datetime_processor_factory(regexp, type_): try: m = rmatch(value) except TypeError: - raise ValueError("Couldn't parse %s string '%r' " - "- value is not a string." % - (type_.__name__, value)) + raise ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." % (type_.__name__, value) + ) if m is None: - raise ValueError("Couldn't parse %s string: " - "'%s'" % (type_.__name__, value)) + raise ValueError( + "Couldn't parse %s string: " + "'%s'" % (type_.__name__, value) + ) if has_named_groups: groups = m.groupdict(0) - return type_(**dict(list(zip( - iter(groups.keys()), - list(map(int, iter(groups.values()))) - )))) + return type_( + **dict( + list( + zip( + iter(groups.keys()), + list(map(int, iter(groups.values()))), + ) + ) + ) + ) else: return type_(*list(map(int, m.groups(0)))) + return process @@ -61,6 +71,7 @@ def py_fallback(): # len part is safe: it is done that way in the normal # 'xx'.decode(encoding) code path. return decoder(value, errors)[0] + return process def to_conditional_unicode_processor_factory(encoding, errors=None): @@ -76,6 +87,7 @@ def py_fallback(): # len part is safe: it is done that way in the normal # 'xx'.decode(encoding) code path. return decoder(value, errors)[0] + return process def to_decimal_processor_factory(target_class, scale): @@ -86,6 +98,7 @@ def py_fallback(): return None else: return target_class(fstring % value) + return process def to_float(value): @@ -107,22 +120,30 @@ def py_fallback(): return bool(value) DATETIME_RE = re.compile( - r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") + r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?" + ) TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)") - str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE, - datetime.datetime) + str_to_datetime = str_to_datetime_processor_factory( + DATETIME_RE, datetime.datetime + ) str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) return locals() + try: - from sqlalchemy.cprocessors import UnicodeResultProcessor, \ - DecimalResultProcessor, \ - to_float, to_str, int_to_boolean, \ - str_to_datetime, str_to_time, \ - str_to_date + from sqlalchemy.cprocessors import ( + UnicodeResultProcessor, + DecimalResultProcessor, + to_float, + to_str, + int_to_boolean, + str_to_datetime, + str_to_time, + str_to_date, + ) def to_unicode_processor_factory(encoding, errors=None): if errors is not None: @@ -144,5 +165,6 @@ try: # return Decimal('5'). These are equivalent of course. 
return DecimalResultProcessor(target_class, "%%.%df" % scale).process + except ImportError: globals().update(py_fallback()) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index aa7b4f0089..598d499dc8 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -9,9 +9,7 @@ """ -from .sql.base import ( - SchemaVisitor - ) +from .sql.base import SchemaVisitor from .sql.schema import ( @@ -36,8 +34,8 @@ from .sql.schema import ( UniqueConstraint, _get_table_key, ColumnCollectionConstraint, - ColumnCollectionMixin - ) + ColumnCollectionMixin, +) from .sql.naming import conv diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index aa811388b5..87e2fb6c39 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -72,7 +72,7 @@ from .expression import ( union, union_all, update, - within_group + within_group, ) from .visitors import ClauseVisitor @@ -84,12 +84,16 @@ def __go(lcls): import inspect as _inspect - __all__ = sorted(name for name, obj in lcls.items() - if not (name.startswith('_') or _inspect.ismodule(obj))) + __all__ = sorted( + name + for name, obj in lcls.items() + if not (name.startswith("_") or _inspect.ismodule(obj)) + ) from .annotation import _prepare_annotations, Annotated from .elements import AnnotatedColumnElement, ClauseList from .selectable import AnnotatedFromClause + _prepare_annotations(ColumnElement, AnnotatedColumnElement) _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) @@ -98,4 +102,5 @@ def __go(lcls): from . import naming + __go(locals()) diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index c1d484d953..64cfa630e8 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -76,8 +76,7 @@ class Annotated(object): return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): - return self.__element.__class__._compiler_dispatch( - self, visitor, **kw) + return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @property def _constructor(self): @@ -120,10 +119,13 @@ def _deep_annotate(element, annotations, exclude=None): Elements within the exclude collection will be cloned but not annotated. 
""" + def clone(elem): - if exclude and \ - hasattr(elem, 'proxy_set') and \ - elem.proxy_set.intersection(exclude): + if ( + exclude + and hasattr(elem, "proxy_set") + and elem.proxy_set.intersection(exclude) + ): newelem = elem._clone() elif annotations != elem._annotations: newelem = elem._annotate(annotations) @@ -191,8 +193,8 @@ def _new_annotation_type(cls, base_cls): break annotated_classes[cls] = anno_cls = type( - "Annotated%s" % cls.__name__, - (base_cls, cls), {}) + "Annotated%s" % cls.__name__, (base_cls, cls), {} + ) globals()["Annotated%s" % cls.__name__] = anno_cls return anno_cls diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6b9b55753e..45db215fee 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -15,8 +15,8 @@ import itertools from .visitors import ClauseVisitor import re -PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT') -NO_ARG = util.symbol('NO_ARG') +PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") +NO_ARG = util.symbol("NO_ARG") class Immutable(object): @@ -77,7 +77,8 @@ class _DialectArgView(util.collections_abc.MutableMapping): dialect, value_key = self._key(key) except KeyError: raise exc.ArgumentError( - "Keys must be of the form _") + "Keys must be of the form _" + ) else: self.obj.dialect_options[dialect][value_key] = value @@ -86,15 +87,18 @@ class _DialectArgView(util.collections_abc.MutableMapping): del self.obj.dialect_options[dialect][value_key] def __len__(self): - return sum(len(args._non_defaults) for args in - self.obj.dialect_options.values()) + return sum( + len(args._non_defaults) + for args in self.obj.dialect_options.values() + ) def __iter__(self): return ( util.safe_kwarg("%s_%s" % (dialect_name, value_name)) for dialect_name in self.obj.dialect_options - for value_name in - self.obj.dialect_options[dialect_name]._non_defaults + for value_name in self.obj.dialect_options[ + dialect_name + ]._non_defaults ) @@ -187,8 +191,8 @@ class DialectKWArgs(object): if construct_arg_dictionary is None: raise exc.ArgumentError( "Dialect '%s' does have keyword-argument " - "validation and defaults enabled configured" % - dialect_name) + "validation and defaults enabled configured" % dialect_name + ) if cls not in construct_arg_dictionary: construct_arg_dictionary[cls] = {} construct_arg_dictionary[cls][argument_name] = default @@ -230,6 +234,7 @@ class DialectKWArgs(object): if dialect_cls.construct_arguments is None: return None return dict(dialect_cls.construct_arguments) + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) def _kw_reg_for_dialect_cls(self, dialect_name): @@ -274,11 +279,12 @@ class DialectKWArgs(object): return for k in kwargs: - m = re.match('^(.+?)_(.+)$', k) + m = re.match("^(.+?)_(.+)$", k) if not m: raise TypeError( "Additional arguments should be " - "named _, got '%s'" % k) + "named _, got '%s'" % k + ) dialect_name, arg_name = m.group(1, 2) try: @@ -286,20 +292,22 @@ class DialectKWArgs(object): except exc.NoSuchModuleError: util.warn( "Can't validate argument %r; can't " - "locate any SQLAlchemy dialect named %r" % - (k, dialect_name)) + "locate any SQLAlchemy dialect named %r" + % (k, dialect_name) + ) self.dialect_options[dialect_name] = d = _DialectArgDict() d._defaults.update({"*": None}) d._non_defaults[arg_name] = kwargs[k] else: - if "*" not in construct_arg_dictionary and \ - arg_name not in construct_arg_dictionary: + if ( + "*" not in construct_arg_dictionary + and arg_name not in construct_arg_dictionary + ): raise exc.ArgumentError( "Argument %r is not 
accepted by " - "dialect %r on behalf of %r" % ( - k, - dialect_name, self.__class__ - )) + "dialect %r on behalf of %r" + % (k, dialect_name, self.__class__) + ) else: construct_arg_dictionary[arg_name] = kwargs[k] @@ -359,14 +367,14 @@ class Executable(Generative): :meth:`.Query.execution_options()` """ - if 'isolation_level' in kw: + if "isolation_level" in kw: raise exc.ArgumentError( "'isolation_level' execution option may only be specified " "on Connection.execution_options(), or " "per-engine using the isolation_level " "argument to create_engine()." ) - if 'compiled_cache' in kw: + if "compiled_cache" in kw: raise exc.ArgumentError( "'compiled_cache' execution option may only be specified " "on Connection.execution_options(), not per statement." @@ -377,10 +385,12 @@ class Executable(Generative): """Compile and execute this :class:`.Executable`.""" e = self.bind if e is None: - label = getattr(self, 'description', self.__class__.__name__) - msg = ('This %s is not directly bound to a Connection or Engine. ' - 'Use the .execute() method of a Connection or Engine ' - 'to execute this construct.' % label) + label = getattr(self, "description", self.__class__.__name__) + msg = ( + "This %s is not directly bound to a Connection or Engine. " + "Use the .execute() method of a Connection or Engine " + "to execute this construct." % label + ) raise exc.UnboundExecutionError(msg) return e._execute_clauseelement(self, multiparams, params) @@ -434,7 +444,7 @@ class SchemaEventTarget(object): class SchemaVisitor(ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" - __traverse_options__ = {'schema_visitor': True} + __traverse_options__ = {"schema_visitor": True} class ColumnCollection(util.OrderedProperties): @@ -446,11 +456,11 @@ class ColumnCollection(util.OrderedProperties): """ - __slots__ = '_all_columns' + __slots__ = "_all_columns" def __init__(self, *columns): super(ColumnCollection, self).__init__() - object.__setattr__(self, '_all_columns', []) + object.__setattr__(self, "_all_columns", []) for c in columns: self.add(c) @@ -485,8 +495,9 @@ class ColumnCollection(util.OrderedProperties): self._data[column.key] = column if remove_col is not None: - self._all_columns[:] = [column if c is remove_col - else c for c in self._all_columns] + self._all_columns[:] = [ + column if c is remove_col else c for c in self._all_columns + ] else: self._all_columns.append(column) @@ -499,7 +510,8 @@ class ColumnCollection(util.OrderedProperties): """ if not column.key: raise exc.ArgumentError( - "Can't add unnamed column to column collection") + "Can't add unnamed column to column collection" + ) self[column.key] = column def __delitem__(self, key): @@ -521,10 +533,12 @@ class ColumnCollection(util.OrderedProperties): return if not existing.shares_lineage(value): - util.warn('Column %r on table %r being replaced by ' - '%r, which has the same key. Consider ' - 'use_labels for select() statements.' % - (key, getattr(existing, 'table', None), value)) + util.warn( + "Column %r on table %r being replaced by " + "%r, which has the same key. Consider " + "use_labels for select() statements." 
+ % (key, getattr(existing, "table", None), value) + ) # pop out memoized proxy_set as this # operation may very well be occurring @@ -540,13 +554,15 @@ class ColumnCollection(util.OrderedProperties): def remove(self, column): del self._data[column.key] self._all_columns[:] = [ - c for c in self._all_columns if c is not column] + c for c in self._all_columns if c is not column + ] def update(self, iter): cols = list(iter) all_col_set = set(self._all_columns) self._all_columns.extend( - c for label, c in cols if c not in all_col_set) + c for label, c in cols if c not in all_col_set + ) self._data.update((label, c) for label, c in cols) def extend(self, iter): @@ -572,12 +588,11 @@ class ColumnCollection(util.OrderedProperties): return util.OrderedProperties.__contains__(self, other) def __getstate__(self): - return {'_data': self._data, - '_all_columns': self._all_columns} + return {"_data": self._data, "_all_columns": self._all_columns} def __setstate__(self, state): - object.__setattr__(self, '_data', state['_data']) - object.__setattr__(self, '_all_columns', state['_all_columns']) + object.__setattr__(self, "_data", state["_data"]) + object.__setattr__(self, "_all_columns", state["_all_columns"]) def contains_column(self, col): return col in set(self._all_columns) @@ -589,7 +604,7 @@ class ColumnCollection(util.OrderedProperties): class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): def __init__(self, data, all_columns): util.ImmutableProperties.__init__(self, data) - object.__setattr__(self, '_all_columns', all_columns) + object.__setattr__(self, "_all_columns", all_columns) extend = remove = util.ImmutableProperties._immutable @@ -622,15 +637,18 @@ def _bind_or_error(schemaitem, msg=None): bind = schemaitem.bind if not bind: name = schemaitem.__class__.__name__ - label = getattr(schemaitem, 'fullname', - getattr(schemaitem, 'name', None)) + label = getattr( + schemaitem, "fullname", getattr(schemaitem, "name", None) + ) if label: - item = '%s object %r' % (name, label) + item = "%s object %r" % (name, label) else: - item = '%s object' % name + item = "%s object" % name if msg is None: - msg = "%s is not bound to an Engine or Connection. "\ - "Execution can not proceed without a database to execute "\ + msg = ( + "%s is not bound to an Engine or Connection. " + "Execution can not proceed without a database to execute " "against." % item + ) raise exc.UnboundExecutionError(msg) return bind diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 80ed707edf..f641d0a844 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -25,133 +25,218 @@ To generate user-defined SQL strings, see import contextlib import re -from . import schema, sqltypes, operators, functions, visitors, \ - elements, selectable, crud +from . import ( + schema, + sqltypes, + operators, + functions, + visitors, + elements, + selectable, + crud, +) from .. 
import util, exc
 import itertools

-RESERVED_WORDS = set([
-    'all', 'analyse', 'analyze', 'and', 'any', 'array',
-    'as', 'asc', 'asymmetric', 'authorization', 'between',
-    'binary', 'both', 'case', 'cast', 'check', 'collate',
-    'column', 'constraint', 'create', 'cross', 'current_date',
-    'current_role', 'current_time', 'current_timestamp',
-    'current_user', 'default', 'deferrable', 'desc',
-    'distinct', 'do', 'else', 'end', 'except', 'false',
-    'for', 'foreign', 'freeze', 'from', 'full', 'grant',
-    'group', 'having', 'ilike', 'in', 'initially', 'inner',
-    'intersect', 'into', 'is', 'isnull', 'join', 'leading',
-    'left', 'like', 'limit', 'localtime', 'localtimestamp',
-    'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
-    'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
-    'placing', 'primary', 'references', 'right', 'select',
-    'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
-    'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
-    'using', 'verbose', 'when', 'where'])
-
-LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$'])
-
-BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
-BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)
+RESERVED_WORDS = set(
+    [
+        "all",
+        "analyse",
+        "analyze",
+        "and",
+        "any",
+        "array",
+        "as",
+        "asc",
+        "asymmetric",
+        "authorization",
+        "between",
+        "binary",
+        "both",
+        "case",
+        "cast",
+        "check",
+        "collate",
+        "column",
+        "constraint",
+        "create",
+        "cross",
+        "current_date",
+        "current_role",
+        "current_time",
+        "current_timestamp",
+        "current_user",
+        "default",
+        "deferrable",
+        "desc",
+        "distinct",
+        "do",
+        "else",
+        "end",
+        "except",
+        "false",
+        "for",
+        "foreign",
+        "freeze",
+        "from",
+        "full",
+        "grant",
+        "group",
+        "having",
+        "ilike",
+        "in",
+        "initially",
+        "inner",
+        "intersect",
+        "into",
+        "is",
+        "isnull",
+        "join",
+        "leading",
+        "left",
+        "like",
+        "limit",
+        "localtime",
+        "localtimestamp",
+        "natural",
+        "new",
+        "not",
+        "notnull",
+        "null",
+        "off",
+        "offset",
+        "old",
+        "on",
+        "only",
+        "or",
+        "order",
+        "outer",
+        "overlaps",
+        "placing",
+        "primary",
+        "references",
+        "right",
+        "select",
+        "session_user",
+        "set",
+        "similar",
+        "some",
+        "symmetric",
+        "table",
+        "then",
+        "to",
+        "trailing",
+        "true",
+        "union",
+        "unique",
+        "user",
+        "using",
+        "verbose",
+        "when",
+        "where",
+    ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)

 OPERATORS = {
     # binary
-    operators.and_: ' AND ',
-    operators.or_: ' OR ',
-    operators.add: ' + ',
-    operators.mul: ' * ',
-    operators.sub: ' - ',
-    operators.div: ' / ',
-    operators.mod: ' % ',
-    operators.truediv: ' / ',
-    operators.neg: '-',
-    operators.lt: ' < ',
-    operators.le: ' <= ',
-    operators.ne: ' != ',
-    operators.gt: ' > ',
-    operators.ge: ' >= ',
-    operators.eq: ' = ',
-    operators.is_distinct_from: ' IS DISTINCT FROM ',
-    operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
-    operators.concat_op: ' || ',
-    operators.match_op: ' MATCH ',
-    operators.notmatch_op: ' NOT MATCH ',
-    operators.in_op: ' IN ',
-    operators.notin_op: ' NOT IN ',
-    operators.comma_op: ', ',
-    operators.from_: ' FROM ',
-    operators.as_: ' AS ',
-    operators.is_: ' IS ',
-    operators.isnot: ' IS NOT ',
-    operators.collate: ' COLLATE ',
-
+    operators.and_: " AND ",
+    operators.or_: " OR ",
+    operators.add: " + ",
+    operators.mul: " * ",
+    operators.sub: " - ",
+    operators.div: " / ",
+    operators.mod: " % ",
+    operators.truediv: " / ",
+    operators.neg: "-",
+    operators.lt: " < ",
+    operators.le: " <= ",
+    operators.ne: " != ",
+    operators.gt: " > ",
+    operators.ge: " >= ",
+    operators.eq: " = ",
+    operators.is_distinct_from: " IS DISTINCT FROM ",
+    operators.isnot_distinct_from: " IS NOT DISTINCT FROM ",
+    operators.concat_op: " || ",
+    operators.match_op: " MATCH ",
+    operators.notmatch_op: " NOT MATCH ",
+    operators.in_op: " IN ",
+    operators.notin_op: " NOT IN ",
+    operators.comma_op: ", ",
+    operators.from_: " FROM ",
+    operators.as_: " AS ",
+    operators.is_: " IS ",
+    operators.isnot: " IS NOT ",
+    operators.collate: " COLLATE ",
     # unary
-    operators.exists: 'EXISTS ',
-    operators.distinct_op: 'DISTINCT ',
-    operators.inv: 'NOT ',
-    operators.any_op: 'ANY ',
-    operators.all_op: 'ALL ',
-
+    operators.exists: "EXISTS ",
+    operators.distinct_op: "DISTINCT ",
+    operators.inv: "NOT ",
+    operators.any_op: "ANY ",
+    operators.all_op: "ALL ",
     # modifiers
-    operators.desc_op: ' DESC',
-    operators.asc_op: ' ASC',
-    operators.nullsfirst_op: ' NULLS FIRST',
-    operators.nullslast_op: ' NULLS LAST',
-
+    operators.desc_op: " DESC",
+    operators.asc_op: " ASC",
+    operators.nullsfirst_op: " NULLS FIRST",
+    operators.nullslast_op: " NULLS LAST",
 }

 FUNCTIONS = {
-    functions.coalesce: 'coalesce',
-    functions.current_date: 'CURRENT_DATE',
-    functions.current_time: 'CURRENT_TIME',
-    functions.current_timestamp: 'CURRENT_TIMESTAMP',
-    functions.current_user: 'CURRENT_USER',
-    functions.localtime: 'LOCALTIME',
-    functions.localtimestamp: 'LOCALTIMESTAMP',
-    functions.random: 'random',
-    functions.sysdate: 'sysdate',
-    functions.session_user: 'SESSION_USER',
-    functions.user: 'USER',
-    
functions.cube: 'CUBE', - functions.rollup: 'ROLLUP', - functions.grouping_sets: 'GROUPING SETS', + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", } EXTRACT_MAP = { - 'month': 'month', - 'day': 'day', - 'year': 'year', - 'second': 'second', - 'hour': 'hour', - 'doy': 'doy', - 'minute': 'minute', - 'quarter': 'quarter', - 'dow': 'dow', - 'week': 'week', - 'epoch': 'epoch', - 'milliseconds': 'milliseconds', - 'microseconds': 'microseconds', - 'timezone_hour': 'timezone_hour', - 'timezone_minute': 'timezone_minute' + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", } COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: 'UNION', - selectable.CompoundSelect.UNION_ALL: 'UNION ALL', - selectable.CompoundSelect.EXCEPT: 'EXCEPT', - selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', - selectable.CompoundSelect.INTERSECT: 'INTERSECT', - selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' + selectable.CompoundSelect.UNION: "UNION", + selectable.CompoundSelect.UNION_ALL: "UNION ALL", + selectable.CompoundSelect.EXCEPT: "EXCEPT", + selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", + selectable.CompoundSelect.INTERSECT: "INTERSECT", + selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", } @@ -177,9 +262,14 @@ class Compiled(object): sub-elements of the statement can modify these. """ - def __init__(self, dialect, statement, bind=None, - schema_translate_map=None, - compile_kwargs=util.immutabledict()): + def __init__( + self, + dialect, + statement, + bind=None, + schema_translate_map=None, + compile_kwargs=util.immutabledict(), + ): """Construct a new :class:`.Compiled` object. :param dialect: :class:`.Dialect` to compile against. @@ -209,7 +299,8 @@ class Compiled(object): self.preparer = self.dialect.identifier_preparer if schema_translate_map: self.preparer = self.preparer._with_schema_translate( - schema_translate_map) + schema_translate_map + ) if statement is not None: self.statement = statement @@ -218,8 +309,10 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) - @util.deprecated("0.7", ":class:`.Compiled` objects now compile " - "within the constructor.") + @util.deprecated( + "0.7", + ":class:`.Compiled` objects now compile " "within the constructor.", + ) def compile(self): """Produce the internal string representation of this element. """ @@ -247,7 +340,7 @@ class Compiled(object): def __str__(self): """Return the string text of the generated SQL or DDL.""" - return self.string or '' + return self.string or "" def construct_params(self, params=None): """Return the bind params for this compiled object. 
@@ -271,7 +364,9 @@ class Compiled(object): if e is None: raise exc.UnboundExecutionError( "This Compiled object is not bound to any Engine " - "or Connection.", code="2afi") + "or Connection.", + code="2afi", + ) return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -284,7 +379,7 @@ class Compiled(object): class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" - ensure_kwarg = r'visit_\w+' + ensure_kwarg = r"visit_\w+" def __init__(self, dialect): self.dialect = dialect @@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" - __visit_name__ = 'label' - __slots__ = 'element', 'name' + __visit_name__ = "label" + __slots__ = "element", "name" def __init__(self, col, name, alt_names=()): self.element = col @@ -390,8 +485,9 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () - def __init__(self, dialect, statement, column_keys=None, - inline=False, **kwargs): + def __init__( + self, dialect, statement, column_keys=None, inline=False, **kwargs + ): """Construct a new :class:`.SQLCompiler` object. :param dialect: :class:`.Dialect` to be used @@ -412,7 +508,7 @@ class SQLCompiler(Compiled): # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) - self.inline = inline or getattr(statement, 'inline', False) + self.inline = inline or getattr(statement, "inline", False) # a dictionary of bind parameter keys to BindParameter # instances. @@ -440,8 +536,9 @@ class SQLCompiler(Compiled): self.ctes = None - self.label_length = dialect.label_length \ - or dialect.max_identifier_length + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) # a map which tracks "anonymous" identifiers that are created on # the fly here @@ -453,7 +550,7 @@ class SQLCompiler(Compiled): Compiled.__init__(self, dialect, statement, **kwargs) if ( - self.isinsert or self.isupdate or self.isdelete + self.isinsert or self.isupdate or self.isdelete ) and statement._returning: self.returning = statement._returning @@ -482,37 +579,43 @@ class SQLCompiler(Compiled): def _nested_result(self): """special API to support the use case of 'nested result sets'""" result_columns, ordered_columns = ( - self._result_columns, self._ordered_columns) + self._result_columns, + self._ordered_columns, + ) self._result_columns, self._ordered_columns = [], False try: if self.stack: entry = self.stack[-1] - entry['need_result_map_for_nested'] = True + entry["need_result_map_for_nested"] = True else: entry = None yield self._result_columns, self._ordered_columns finally: if entry: - entry.pop('need_result_map_for_nested') + entry.pop("need_result_map_for_nested") self._result_columns, self._ordered_columns = ( - result_columns, ordered_columns) + result_columns, + ordered_columns, + ) def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m: str(util.next(poscount)), - self.string) + r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + ) @util.memoized_property def _bind_processors(self): return dict( - (key, value) for key, value in - ((self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect) - ) - for bindparam in self.bind_names) + (key, value) + for key, value in ( + ( + self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect), + ) + for bindparam in self.bind_names + ) if value 
is not None ) @@ -539,12 +642,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) elif bindparam.callable: pd[name] = bindparam.effective_value @@ -558,12 +665,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) if bindparam.callable: pd[self.bind_names[bindparam]] = bindparam.effective_value @@ -595,9 +706,10 @@ class SQLCompiler(Compiled): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" def visit_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict if within_columns_clause: @@ -611,25 +723,30 @@ class SQLCompiler(Compiled): # to something else like a ColumnClause expression. order_by_elem = element.element._order_by_label_element - if order_by_elem is not None and order_by_elem.name in \ - resolve_dict and \ - order_by_elem.shares_lineage( - resolve_dict[order_by_elem.name]): - kwargs['render_label_as_label'] = \ - element.element._order_by_label_element + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element return self.process( - element.element, within_columns_clause=within_columns_clause, - **kwargs) + element.element, + within_columns_clause=within_columns_clause, + **kwargs + ) def visit_textual_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if not self.stack: # compiling the element outside of the context of a SELECT - return self.process( - element._text_clause - ) + return self.process(element._text_clause) - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict try: if within_columns_clause: @@ -640,26 +757,30 @@ class SQLCompiler(Compiled): # treat it like text() util.warn_limited( "Can't resolve label reference %r; converting to text()", - util.ellipses_string(element.element)) - return self.process( - element._text_clause + util.ellipses_string(element.element), ) + return self.process(element._text_clause) else: - kwargs['render_label_as_label'] = col + kwargs["render_label_as_label"] = col return self.process( - col, within_columns_clause=within_columns_clause, **kwargs) - - def visit_label(self, label, - add_to_result_map=None, - within_label_clause=False, - within_columns_clause=False, - render_label_as_label=None, - **kw): + col, 
within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw + ): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - render_label_with_as = (within_columns_clause and not - within_label_clause) + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: @@ -673,27 +794,35 @@ class SQLCompiler(Compiled): add_to_result_map( labelname, label.name, - (label, labelname, ) + label._alt_names, - label.type + (label, labelname) + label._alt_names, + label.type, ) - return label.element._compiler_dispatch( - self, within_columns_clause=True, - within_label_clause=True, **kw) + \ - OPERATORS[operators.as_] + \ - self.preparer.format_label(label, labelname) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) elif render_label_only: return self.preparer.format_label(label, labelname) else: return label.element._compiler_dispatch( - self, within_columns_clause=False, **kw) + self, within_columns_clause=False, **kw + ) def _fallback_column_name(self, column): - raise exc.CompileError("Cannot compile Column object until " - "its 'name' is assigned.") + raise exc.CompileError( + "Cannot compile Column object until " "its 'name' is assigned." + ) - def visit_column(self, column, add_to_result_map=None, - include_table=True, **kwargs): + def visit_column( + self, column, add_to_result_map=None, include_table=True, **kwargs + ): name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -704,10 +833,7 @@ class SQLCompiler(Compiled): if add_to_result_map is not None: add_to_result_map( - name, - orig_name, - (column, name, column.key), - column.type + name, orig_name, (column, name, column.key), column.type ) if is_literal: @@ -721,17 +847,16 @@ class SQLCompiler(Compiled): effective_schema = self.preparer.schema_for_object(table) if effective_schema: - schema_prefix = self.preparer.quote_schema( - effective_schema) + '.' + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) else: - schema_prefix = '' + schema_prefix = "" tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) - return schema_prefix + \ - self.preparer.quote(tablename) + \ - "." + name + return schema_prefix + self.preparer.quote(tablename) + "." 
+ name def visit_collation(self, element, **kw): return self.preparer.format_collation(element.collation) @@ -743,17 +868,17 @@ class SQLCompiler(Compiled): return index.name def visit_typeclause(self, typeclause, **kw): - kw['type_expression'] = typeclause + kw["type_expression"] = typeclause return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def escape_literal_column(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def visit_textclause(self, textclause, **kw): @@ -771,30 +896,36 @@ class SQLCompiler(Compiled): return BIND_PARAMS_ESC.sub( lambda m: m.group(1), BIND_PARAMS.sub( - do_bindparam, - self.post_process_text(textclause.text)) + do_bindparam, self.post_process_text(textclause.text) + ), ) - def visit_text_as_from(self, taf, - compound_index=None, - asfrom=False, - parens=True, **kw): + def visit_text_as_from( + self, taf, compound_index=None, asfrom=False, parens=True, **kw + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) if populate_result_map: - self._ordered_columns = \ - self._textual_ordered_columns = taf.positional + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional for c in taf.column_args: - self.process(c, within_columns_clause=True, - add_to_result_map=self._add_to_result_map) + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) text = self.process(taf.element, **kw) if asfrom and parens: @@ -802,17 +933,17 @@ class SQLCompiler(Compiled): return text def visit_null(self, expr, **kw): - return 'NULL' + return "NULL" def visit_true(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'true' + return "true" else: return "1" def visit_false(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'false' + return "false" else: return "0" @@ -823,25 +954,29 @@ class SQLCompiler(Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - ( - c._compiler_dispatch(self, **kw) - for c in clauselist.clauses) - if s) + s + for s in ( + c._compiler_dispatch(self, **kw) for c in clauselist.clauses + ) + if s + ) def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: x += clause.value._compiler_dispatch(self, **kwargs) + " " for cond, result in clause.whens: - x += "WHEN " + cond._compiler_dispatch( - self, **kwargs - ) + " THEN " + result._compiler_dispatch( - self, **kwargs) + " " + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, **kwargs) + + " " + ) if clause.else_ is not None: - x += "ELSE " + clause.else_._compiler_dispatch( - self, **kwargs - ) + " " + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) x += "END" return x @@ -849,79 +984,84 @@ class SQLCompiler(Compiled): return type_coerce.typed_expression._compiler_dispatch(self, **kw) def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % \ - 
(cast.clause._compiler_dispatch(self, **kwargs), - cast.typeclause._compiler_dispatch(self, **kwargs)) + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) def _format_frame_clause(self, range_, **kw): - return '%s AND %s' % ( + return "%s AND %s" % ( "UNBOUNDED PRECEDING" if range_[0] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[0])), **kw), ) + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) if range_[0] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[0]), **kw), ), - + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), "UNBOUNDED FOLLOWING" if range_[1] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[1])), **kw), ) + else "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) if range_[1] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[1]), **kw), ), + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), ) def visit_over(self, over, **kwargs): if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( - over.range_, **kwargs) + over.range_, **kwargs + ) elif over.rows: range_ = "ROWS BETWEEN %s" % self._format_frame_clause( - over.rows, **kwargs) + over.rows, **kwargs + ) else: range_ = None return "%s OVER (%s)" % ( over.element._compiler_dispatch(self, **kwargs), - ' '.join([ - '%s BY %s' % ( - word, clause._compiler_dispatch(self, **kwargs) - ) - for word, clause in ( - ('PARTITION', over.partition_by), - ('ORDER', over.order_by) - ) - if clause is not None and len(clause) - ] + ([range_] if range_ else []) - ) + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), ) def visit_withingroup(self, withingroup, **kwargs): return "%s WITHIN GROUP (ORDER BY %s)" % ( withingroup.element._compiler_dispatch(self, **kwargs), - withingroup.order_by._compiler_dispatch(self, **kwargs) + withingroup.order_by._compiler_dispatch(self, **kwargs), ) def visit_funcfilter(self, funcfilter, **kwargs): return "%s FILTER (WHERE %s)" % ( funcfilter.func._compiler_dispatch(self, **kwargs), - funcfilter.criterion._compiler_dispatch(self, **kwargs) + funcfilter.criterion._compiler_dispatch(self, **kwargs), ) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) return "EXTRACT(%s FROM %s)" % ( - field, extract.expr._compiler_dispatch(self, **kwargs)) + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) def visit_function(self, func, add_to_result_map=None, **kwargs): if add_to_result_map is not None: - add_to_result_map( - func.name, func.name, (), func.type - ) + add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -933,51 +1073,63 @@ class SQLCompiler(Compiled): name += "%(expr)s" else: name = func.name + "%(expr)s" - return ".".join(list(func.packagenames) + [name]) % \ - {'expr': 
self.function_argspec(func, **kwargs)} + return ".".join(list(func.packagenames) + [name]) % { + "expr": self.function_argspec(func, **kwargs) + } def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) def visit_sequence(self, sequence, **kw): raise NotImplementedError( - "Dialect '%s' does not support sequence increments." % - self.dialect.name + "Dialect '%s' does not support sequence increments." + % self.dialect.name ) def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, - parens=True, compound_index=0, **kwargs): + def visit_compound_select( + self, cs, asfrom=False, parens=True, compound_index=0, **kwargs + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - need_result_map = toplevel or \ - (compound_index == 0 - and entry.get('need_result_map_for_compound', False)) + need_result_map = toplevel or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) self.stack.append( { - 'correlate_froms': entry['correlate_froms'], - 'asfrom_froms': entry['asfrom_froms'], - 'selectable': cs, - 'need_result_map_for_compound': need_result_map - }) + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "need_result_map_for_compound": need_result_map, + } + ) keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, - compound_index=i, **kwargs) - for i, c in enumerate(cs.selects)) + ( + c._compiler_dispatch( + self, + asfrom=asfrom, + parens=False, + compound_index=i, + **kwargs + ) + for i, c in enumerate(cs.selects) + ) ) text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None - or cs._offset_clause is not None) and \ - self.limit_clause(cs, **kwargs) or "" + text += ( + (cs._limit_clause is not None or cs._offset_clause is not None) + and self.limit_clause(cs, **kwargs) + or "" + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -990,8 +1142,10 @@ class SQLCompiler(Compiled): def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): attrname = "visit_%s_%s%s" % ( - operator_.__name__, qualifier1, - "_" + qualifier2 if qualifier2 else "") + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) return getattr(self, attrname, None) def visit_unary(self, unary, **kw): @@ -999,51 +1153,63 @@ class SQLCompiler(Compiled): if unary.modifier: raise exc.CompileError( "Unary expression does not support operator " - "and modifier simultaneously") + "and modifier simultaneously" + ) disp = self._get_operator_dispatch( - unary.operator, "unary", "operator") + unary.operator, "unary", "operator" + ) if disp: return disp(unary, unary.operator, **kw) else: return self._generate_generic_unary_operator( - unary, OPERATORS[unary.operator], **kw) + unary, OPERATORS[unary.operator], **kw + ) elif unary.modifier: disp = self._get_operator_dispatch( - unary.modifier, "unary", "modifier") + unary.modifier, "unary", "modifier" + ) if disp: return disp(unary, unary.modifier, **kw) else: return self._generate_generic_unary_modifier( - unary, OPERATORS[unary.modifier], **kw) + unary, OPERATORS[unary.modifier], **kw + ) else: raise exc.CompileError( - "Unary expression has no operator or modifier") + "Unary expression has 
no operator or modifier" + ) def visit_istrue_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return self.process(element.element, **kw) else: return "%s = 1" % self.process(element.element, **kw) def visit_isfalse_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return "NOT %s" % self.process(element.element, **kw) else: return "%s = 0" % self.process(element.element, **kw) def visit_notmatch_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_binary( - binary, override_operator=operators.match_op) + binary, override_operator=operators.match_op + ) def _emit_empty_in_warning(self): util.warn( - 'The IN-predicate was invoked with an ' - 'empty sequence. This results in a ' - 'contradiction, which nonetheless can be ' - 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.') + "The IN-predicate was invoked with an " + "empty sequence. This results in a " + "contradiction, which nonetheless can be " + "expensive to evaluate. Consider alternative " + "strategies for improved performance." + ) def visit_empty_in_op_binary(self, binary, operator, **kw): if self.dialect._use_static_in: @@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled): def visit_empty_set_expr(self, element_types): raise NotImplementedError( - "Dialect '%s' does not support empty set expression." % - self.dialect.name + "Dialect '%s' does not support empty set expression." + % self.dialect.name ) - def visit_binary(self, binary, override_operator=None, - eager_grouping=False, **kw): + def visit_binary( + self, binary, override_operator=None, eager_grouping=False, **kw + ): # don't allow "? = ?" 
to render - if self.ansi_bind_rules and \ - isinstance(binary.left, elements.BindParameter) and \ - isinstance(binary.right, elements.BindParameter): - kw['literal_binds'] = True + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_binds"] = True operator_ = override_operator or binary.operator disp = self._get_operator_dispatch(operator_, "binary", None) @@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled): def visit_mod_binary(self, binary, operator, **kw): if self.preparer._double_percents: - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) else: - return self.process(binary.left, **kw) + " % " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) def visit_custom_op_binary(self, element, operator, **kw): - kw['eager_grouping'] = operator.eager_grouping + kw["eager_grouping"] = operator.eager_grouping return self._generate_generic_binary( - element, " " + operator.opstring + " ", **kw) + element, " " + operator.opstring + " ", **kw + ) def visit_custom_op_unary_operator(self, element, operator, **kw): return self._generate_generic_unary_operator( - element, operator.opstring + " ", **kw) + element, operator.opstring + " ", **kw + ) def visit_custom_op_unary_modifier(self, element, operator, **kw): return self._generate_generic_unary_modifier( - element, " " + operator.opstring, **kw) + element, " " + operator.opstring, **kw + ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw): + self, binary, opstring, eager_grouping=False, **kw + ): - _in_binary = kw.get('_in_binary', False) + _in_binary = kw.get("_in_binary", False) - kw['_in_binary'] = True - text = binary.left._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + \ - opstring + \ - binary.right._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + kw["_in_binary"] = True + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) if _in_binary and eager_grouping: text = "(%s)" % text @@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled): def visit_startswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_like_op_binary(binary, operator, **kw) def visit_notstartswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_notlike_op_binary(binary, operator, **kw) def visit_endswith_op_binary(self, binary, operator, **kw): @@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled): escape = binary.modifiers.get("escape", None) # TODO: use ternary here, not "and"/ "or" - return '%s LIKE %s' % ( + return "%s LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + 
self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT LIKE %s' % ( + return "%s NOT LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) LIKE lower(%s)' % ( + return "lower(%s) LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) NOT LIKE lower(%s)' % ( + return "lower(%s) NOT LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " BETWEEN SYMMETRIC " - if symmetric else " BETWEEN ", **kw) + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) def visit_notbetween_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " NOT BETWEEN SYMMETRIC " - if symmetric else " NOT BETWEEN ", **kw) + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw + ) - def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, - skip_bind_expression=False, - **kwargs): + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs + ): if not skip_bind_expression: impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) return self.process( - bind_expression, skip_bind_expression=True, + bind_expression, + skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, **kwargs ) - if literal_binds or \ - (within_columns_clause and - self.ansi_bind_rules): + if literal_binds or (within_columns_clause and self.ansi_bind_rules): if bindparam.value is None and bindparam.callable is None: - raise exc.CompileError("Bind parameter '%s' without a " - "renderable value not allowed here." - % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' without a " + "renderable value not allowed here." 
% bindparam.key + ) return self.render_literal_bindparam( - bindparam, within_columns_clause=True, **kwargs) + bindparam, within_columns_clause=True, **kwargs + ) name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if (existing.unique or bindparam.unique) and \ - not existing.proxy_set.intersection( - bindparam.proxy_set): + if ( + existing.unique or bindparam.unique + ) and not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % - bindparam.key + "unique bind parameter of the same name" + % bindparam.key ) elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( @@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled): "clause of this " "insert/update statement. Please use a " "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." % - (bindparam.key, bindparam.key) + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string( - name, expanding=bindparam.expanding, **kwargs) + name, expanding=bindparam.expanding, **kwargs + ) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.effective_value @@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled): return processor(value) else: raise NotImplementedError( - "Don't know how to literal-quote value %r" % value) + "Don't know how to literal-quote value %r" % value + ) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled): if len(anonname) > self.label_length - 6: counter = self.truncated_names.get(ident_class, 1) - truncname = anonname[0:max(self.label_length - 6, 0)] + \ - "_" + hex(counter)[2:] + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) self.truncated_names[ident_class] = counter + 1 else: truncname = anonname @@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled): return name % self.anon_map def _process_anon(self, key): - (ident, derived) = key.split(' ', 1) + (ident, derived) = key.split(" ", 1) anonymous_counter = self.anon_map.get(derived, 1) self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) def bindparam_string( - self, name, positional_names=None, expanding=False, **kw): + self, name, positional_names=None, expanding=False, **kw + ): if self.positional: if positional_names is not None: positional_names.append(name) @@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled): self.contains_expanding_parameters = True return "([EXPANDING_%s])" % name else: - return self.bindtemplate % {'name': name} - - def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, visiting_cte=None, - **kwargs): + return self.bindtemplate % {"name": name} + + def visit_cte( + self, + cte, + asfrom=False, + ashint=False, + fromhints=None, + visiting_cte=None, + **kwargs + ): self._init_cte_state() - kwargs['visiting_cte'] = cte + kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: @@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled): else: raise exc.CompileError( "Multiple, unrelated CTEs found with " - "the same name: %r" % - cte_name) + "the same name: %r" % cte_name + ) if asfrom or is_new_cte: if cte._cte_alias is not None: @@ -1403,7 
+1601,8 @@ class SQLCompiler(Compiled): cte_pre_alias_name = cte._cte_alias.name if isinstance(cte_pre_alias_name, elements._truncated_label): cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name) + "alias", cte_pre_alias_name + ) else: pre_alias_cte = cte cte_pre_alias_name = None @@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled): self.ctes_by_name[cte_name] = cte # look for embedded DML ctes and propagate autocommit - if 'autocommit' in cte.element._execution_options and \ - 'autocommit' not in self.execution_options: + if ( + "autocommit" in cte.element._execution_options + and "autocommit" not in self.execution_options + ): self.execution_options = self.execution_options.union( - {"autocommit": - cte.element._execution_options['autocommit']}) + { + "autocommit": cte.element._execution_options[ + "autocommit" + ] + } + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled): col_source = cte.original.selects[0] else: assert False - recur_cols = [c for c in - util.unique_list(col_source.inner_columns) - if c is not None] - - text += "(%s)" % (", ".join( - self.preparer.format_column(ident) - for ident in recur_cols)) + recur_cols = [ + c + for c in util.unique_list(col_source.inner_columns) + if c is not None + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_column(ident) + for ident in recur_cols + ) + ) if self.positional: - kwargs['positional_names'] = self.cte_positional[cte] = [] + kwargs["positional_names"] = self.cte_positional[cte] = [] - text += " AS \n" + \ - cte.original._compiler_dispatch( - self, asfrom=True, **kwargs - ) + text += " AS \n" + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) if cte._suffixes: text += " " + self._generate_prefixes( - cte, cte._suffixes, **kwargs) + cte, cte._suffixes, **kwargs + ) self.ctes[cte] = text @@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) - def visit_alias(self, alias, asfrom=False, ashint=False, - iscrud=False, - fromhints=None, **kwargs): + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + **kwargs + ): if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) @@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.original._compiler_dispatch(self, - asfrom=True, **kwargs) + \ - self.get_render_as_alias_suffix( - self.preparer.format_alias(alias, alias_name)) + ret = alias.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) if fromhints and alias in fromhints: - ret = self.format_from_hint_text(ret, alias, - fromhints[alias], iscrud) + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) return ret else: return alias.original._compiler_dispatch(self, **kwargs) def visit_lateral(self, lateral, **kw): - kw['lateral'] = True + kw["lateral"] = True return "LATERAL %s" % self.visit_alias(lateral, **kw) def visit_tablesample(self, tablesample, asfrom=False, **kw): text = "%s TABLESAMPLE %s" % ( self.visit_alias(tablesample, asfrom=True, **kw), - tablesample._get_method()._compiler_dispatch(self, **kw)) + tablesample._get_method()._compiler_dispatch(self, **kw), + ) if tablesample.seed is not None: text += " 
REPEATABLE (%s)" % ( - tablesample.seed._compiler_dispatch(self, **kw)) + tablesample.seed._compiler_dispatch(self, **kw) + ) return text @@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled): def _add_to_result_map(self, keyname, name, objects, type_): self._result_columns.append((keyname, name, objects, type_)) - def _label_select_column(self, select, column, - populate_result_map, - asfrom, column_clause_args, - name=None, - within_columns_clause=True): + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + within_columns_clause=True, + ): """produce labeled columns present in a select().""" impl = column.type.dialect_impl(self.dialect) - if impl._has_column_expression and \ - populate_result_map: + if impl._has_column_expression and populate_result_map: col_expr = impl.column_expression(column) def add_to_result_map(keyname, name, objects, type_): self._add_to_result_map( - keyname, name, - (column,) + objects, type_) + keyname, name, (column,) + objects, type_ + ) + else: col_expr = column if populate_result_map: @@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled): elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( - col_expr, - column.name, - alt_names=(column.element,) + col_expr, column.name, alt_names=(column.element,) ) else: result_expr = col_expr elif select is not None and name: + result_expr = _CompileLabel( + col_expr, name, alt_names=(column._key_label,) + ) + + elif ( + asfrom + and isinstance(column, elements.ColumnClause) + and not column.is_literal + and column.table is not None + and not isinstance(column.table, selectable.Select) + ): result_expr = _CompileLabel( col_expr, - name, - alt_names=(column._key_label,) - ) - - elif \ - asfrom and \ - isinstance(column, elements.ColumnClause) and \ - not column.is_literal and \ - column.table is not None and \ - not isinstance(column.table, selectable.Select): - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + elements._as_truncated(column.name), + alt_names=(column.key,), + ) elif ( - not isinstance(column, elements.TextClause) and - ( - not isinstance(column, elements.UnaryExpression) or - column.wraps_column_expression - ) and - ( - not hasattr(column, 'name') or - isinstance(column, functions.Function) + not isinstance(column, elements.TextClause) + and ( + not isinstance(column, elements.UnaryExpression) + or column.wraps_column_expression + ) + and ( + not hasattr(column, "name") + or isinstance(column, functions.Function) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? 
# assert isinstance(column, elements.ColumnClause) - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + result_expr = _CompileLabel( + col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,), + ) else: result_expr = col_expr column_clause_args.update( within_columns_clause=within_columns_clause, - add_to_result_map=add_to_result_map - ) - return result_expr._compiler_dispatch( - self, - **column_clause_args + add_to_result_map=add_to_result_map, ) + return result_expr._compiler_dispatch(self, **column_clause_args) def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() - if newelem.is_selectable and newelem._is_join and \ - isinstance(newelem.right, selectable.FromGrouping): + if ( + newelem.is_selectable + and newelem._is_join + and isinstance(newelem.right, selectable.FromGrouping) + ): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) @@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled): right = visit(newelem.right, **kw) selectable_ = selectable.Select( - [right.element], - use_labels=True).alias() + [right.element], use_labels=True + ).alias() for c in selectable_.c: c._key_label = c.key @@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled): elif newelem._is_from_container: # if we hit an Alias, CompoundSelect or ScalarSelect, put a # marker in the stack. - kw['transform_clue'] = 'select_container' + kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) elif newelem.is_selectable and newelem._is_select: - barrier_select = kw.get('transform_clue', None) == \ - 'select_container' + barrier_select = ( + kw.get("transform_clue", None) == "select_container" + ) # if we're still descended from an # Alias/CompoundSelect/ScalarSelect, we're # in a FROM clause, so start with a new translate collection if barrier_select: column_translate.append({}) - kw['transform_clue'] = 'inside_select' + kw["transform_clue"] = "inside_select" newelem._copy_internals(clone=visit, **kw) if barrier_select: del column_translate[-1] @@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled): return visit(select) def _transform_result_map_for_nested_joins( - self, select, transformed_select): - inner_col = dict((c._key_label, c) for - c in transformed_select.inner_columns) - - d = dict( - (inner_col[c._key_label], c) - for c in select.inner_columns + self, select, transformed_select + ): + inner_col = dict( + (c._key_label, c) for c in transformed_select.inner_columns ) + d = dict((inner_col[c._key_label], c) for c in select.inner_columns) + self._result_columns = [ (key, name, tuple([d.get(col, col) for col in objs]), typ) for key, name, objs, typ in self._result_columns ] - _default_stack_entry = util.immutabledict([ - ('correlate_froms', frozenset()), - ('asfrom_froms', frozenset()) - ]) + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select(self, select, asfrom, lateral=False): # utility method to help external dialects @@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = 
select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) return froms - def visit_select(self, select, asfrom=False, parens=True, - fromhints=None, - compound_index=0, - nested_join_translation=False, - select_wraps_for=None, - lateral=False, - **kwargs): - - needs_nested_translation = \ - select.use_labels and \ - not nested_join_translation and \ - not self.stack and \ - not self.dialect.supports_right_nested_joins + def visit_select( + self, + select, + asfrom=False, + parens=True, + fromhints=None, + compound_index=0, + nested_join_translation=False, + select_wraps_for=None, + lateral=False, + **kwargs + ): + + needs_nested_translation = ( + select.use_labels + and not nested_join_translation + and not self.stack + and not self.dialect.supports_right_nested_joins + ) if needs_nested_translation: transformed_select = self._transform_select_for_nested_joins( - select) + select + ) text = self.visit_select( - transformed_select, asfrom=asfrom, parens=parens, + transformed_select, + asfrom=asfrom, + parens=parens, fromhints=fromhints, compound_index=compound_index, - nested_join_translation=True, **kwargs + nested_join_translation=True, + **kwargs ) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) # this was first proposed as part of #3372; however, it is not # reached in current tests and could possibly be an assertion # instead. - if not populate_result_map and 'add_to_result_map' in kwargs: - del kwargs['add_to_result_map'] + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( - select, transformed_select) + select, transformed_select + ) return text froms = self._setup_select_stack(select, entry, asfrom, lateral) column_clause_args = kwargs.copy() - column_clause_args.update({ - 'within_label_clause': False, - 'within_columns_clause': False - }) + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) text = "SELECT " # we're off to a good start ! @@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled): byfrom = None if select._prefixes: - text += self._generate_prefixes( - select, select._prefixes, **kwargs) + text += self._generate_prefixes(select, select._prefixes, **kwargs) text += self.get_select_precolumns(select, **kwargs) # the actual list of columns to print in the SELECT column list. 
inner_columns = [ - c for c in [ + c + for c in [ self._label_select_column( select, column, - populate_result_map, asfrom, + populate_result_map, + asfrom, column_clause_args, - name=name) + name=name, + ) for name, column in select._columns_plus_names ] if c is not None @@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled): translate = dict( zip( [name for (key, name) in select._columns_plus_names], - [name for (key, name) in - select_wraps_for._columns_plus_names]) + [ + name + for (key, name) in select_wraps_for._columns_plus_names + ], + ) ) self._result_columns = [ @@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs) + text, select, inner_columns, froms, byfrom, kwargs + ) if select._statement_hints: per_dialect = [ - ht for (dialect_name, ht) - in select._statement_hints - if dialect_name in ('*', self.dialect.name) + ht + for (dialect_name, ht) in select._statement_hints + if dialect_name in ("*", self.dialect.name) ] if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) @@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled): if select._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs) + select, select._suffixes, **kwargs + ) self.stack.pop(-1) @@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled): return text def _setup_select_hints(self, select): - byfrom = dict([ - (from_, hinttext % { - 'name': from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) + byfrom = dict( + [ + ( + from_, + hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)}, + ) + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom def _setup_select_stack(self, select, entry, asfrom, lateral): - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { - 'asfrom_froms': new_correlate_froms, - 'correlate_froms': all_correlate_froms, - 'selectable': select, + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs): - text += ', '.join(inner_columns) + self, text, select, inner_columns, froms, byfrom, kwargs + ): + text += ", ".join(inner_columns) if froms: text += " \nFROM " if select._hints: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, - fromhints=byfrom, **kwargs) - for f in froms]) + text += ", ".join( + [ + f._compiler_dispatch( + self, asfrom=True, fromhints=byfrom, **kwargs + ) + for f in froms + ] + ) else: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, **kwargs) - for f in froms]) + text += ", ".join( + [ + 
f._compiler_dispatch(self, asfrom=True, **kwargs) + for f in froms + ] + ) else: text += self.default_from() @@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled): if select._order_by_clause.clauses: text += self.order_by_clause(select, **kwargs) - if (select._limit_clause is not None or - select._offset_clause is not None): + if ( + select._limit_clause is not None + or select._offset_clause is not None + ): text += self.limit_clause(select, **kwargs) if select._for_update_arg is not None: @@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled): clause = " ".join( prefix._compiler_dispatch(self, **kw) for prefix, dialect_name in prefixes - if dialect_name is None or - dialect_name == self.dialect.name + if dialect_name is None or dialect_name == self.dialect.name ) if clause: clause += " " @@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled): def _render_cte_clause(self): if self.positional: - self.positiontup = sum([ - self.cte_positional[cte] - for cte in self.ctes], []) + \ - self.positiontup + self.positiontup = ( + sum([self.cte_positional[cte] for cte in self.ctes], []) + + self.positiontup + ) cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join( - [txt for txt in self.ctes.values()] - ) + cte_text += ", \n".join([txt for txt in self.ctes.values()]) cte_text += "\n " return cte_text @@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled): def returning_clause(self, stmt, returning_cols): raise exc.CompileError( "RETURNING is not supported by this " - "dialect's statement compiler.") + "dialect's statement compiler." + ) def limit_clause(self, select, **kw): text = "" @@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled): text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + **kwargs + ): if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) if use_schema and effective_schema: - ret = self.preparer.quote_schema(effective_schema) + \ - "." + self.preparer.quote(table.name) + ret = ( + self.preparer.quote_schema(effective_schema) + + "." 
+ + self.preparer.quote(table.name) + ) else: ret = self.preparer.quote(table.name) if fromhints and table in fromhints: - ret = self.format_from_hint_text(ret, table, - fromhints[table], iscrud) + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) return ret else: return "" @@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + - join_type + - join.right._compiler_dispatch(self, asfrom=True, **kwargs) + - " ON " + - join.onclause._compiler_dispatch(self, **kwargs) + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + join_type + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + + join.onclause._compiler_dispatch(self, **kwargs) ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) + dialect_hints = dict( + [ + (table, hint_text) + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) if stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - stmt.table, - dialect_hints[stmt.table], - True + table_text, stmt.table, dialect_hints[stmt.table], True ) return dialect_hints, table_text @@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled): toplevel = not self.stack self.stack.append( - {'correlate_froms': set(), - "asfrom_froms": set(), - "selectable": insert_stmt}) + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) crud_params = crud._setup_crud_params( - self, insert_stmt, crud.ISINSERT, **kw) + self, insert_stmt, crud.ISINSERT, **kw + ) - if not crud_params and \ - not self.dialect.supports_default_values and \ - not self.dialect.supports_empty_insert: - raise exc.CompileError("The '%s' dialect with current database " - "version settings does not support empty " - "inserts." % - self.dialect.name) + if ( + not crud_params + and not self.dialect.supports_default_values + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) if insert_stmt._has_multi_parameters: if not self.dialect.supports_multivalues_insert: raise exc.CompileError( "The '%s' dialect with current database " "version settings does not support " - "in-place multirow inserts." % - self.dialect.name) + "in-place multirow inserts." 
% self.dialect.name + ) crud_params_single = crud_params[0] else: crud_params_single = crud_params @@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled): text = "INSERT " if insert_stmt._prefixes: - text += self._generate_prefixes(insert_stmt, - insert_stmt._prefixes, **kw) + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) text += "INTO " table_text = preparer.format_table(insert_stmt.table) if insert_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - insert_stmt, table_text) + insert_stmt, table_text + ) else: dialect_hints = None text += table_text if crud_params_single or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in crud_params_single]) + text += " (%s)" % ", ".join( + [preparer.format_column(c[0]) for c in crud_params_single] + ) if self.returning or insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning or insert_stmt._returning) + insert_stmt, self.returning or insert_stmt._returning + ) if self.returning_precedes_values: text += " " + returning_clause @@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled): elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( - "(%s)" % ( - ', '.join(c[1] for c in crud_param_set) - ) + "(%s)" % (", ".join(c[1] for c in crud_param_set)) for crud_param_set in crud_params ) ) else: - text += " VALUES (%s)" % \ - ', '.join([c[1] for c in crud_params]) + text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params]) if insert_stmt._post_values_clause is not None: post_values_clause = self.process( - insert_stmt._post_values_clause, **kw) + insert_stmt._post_values_clause, **kw + ) if post_values_clause: text += " " + post_values_clause @@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled): """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. MySQL overrides this. """ - kw['asfrom'] = True + kw["asfrom"] = True return from_table._compiler_dispatch(self, iscrud=True, **kw) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an UPDATE..FROM clause. 
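[The hook above is the dialect-level extension point for rendering the FROM list of a multiple-table UPDATE; the base class raises NotImplementedError, as the next hunk shows. A minimal sketch of an override — the PGStyleCompiler name is hypothetical — which simply mirrors the StrSQLCompiler.update_from_clause() implementation that appears near the end of this diff:

    from sqlalchemy.sql.compiler import SQLCompiler

    class PGStyleCompiler(SQLCompiler):
        # Emit "FROM extra1, extra2" after the SET clause, in the style
        # PostgreSQL uses for multiple-table UPDATE statements.
        def update_from_clause(
            self, update_stmt, from_table, extra_froms, from_hints, **kw
        ):
            return "FROM " + ", ".join(
                t._compiler_dispatch(
                    self, asfrom=True, fromhints=from_hints, **kw
                )
                for t in extra_froms
            )
]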
@@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within UPDATE") + "criteria within UPDATE" + ) def visit_update(self, update_stmt, asfrom=False, **kw): toplevel = not self.stack @@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled): correlate_froms = {update_stmt.table} self.stack.append( - {'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": update_stmt}) + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes(update_stmt, - update_stmt._prefixes, **kw) + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) - table_text = self.update_tables_clause(update_stmt, update_stmt.table, - render_extra_froms, **kw) + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) crud_params = crud._setup_crud_params( - self, update_stmt, crud.ISUPDATE, **kw) + self, update_stmt, crud.ISUPDATE, **kw + ) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text) + update_stmt, table_text + ) else: dialect_hints = None text += table_text - text += ' SET ' - include_table = is_multitable and \ - self.render_table_with_column_in_update_from - text += ', '.join( - c[0]._compiler_dispatch(self, - include_table=include_table) + - '=' + c[1] for c in crud_params + text += " SET " + include_table = ( + is_multitable and self.render_table_with_column_in_update_from + ) + text += ", ".join( + c[0]._compiler_dispatch(self, include_table=include_table) + + "=" + + c[1] + for c in crud_params ) if self.returning or update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if extra_froms: extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, render_extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if (self.returning or update_stmt._returning) and \ - not self.returning_precedes_values: + if ( + self.returning or update_stmt._returning + ) and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled): def _key_getters_for_crud_column(self): return crud._key_getters_for_crud_column(self, self.statement) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, **kw): + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an DELETE..FROM clause. 
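[delete_extra_from_clause() is the DELETE counterpart of the hook above; backends that leave it unimplemented raise the NotImplementedError reformatted in the next hunk. StrSQLCompiler, near the end of this diff, implements both hooks, so plain str() compilation can render multiple-table criteria. A rough usage sketch using only public API; the exact whitespace of the output is an assumption:

    from sqlalchemy import column, delete, table

    users = table("users", column("id"))
    orders = table("orders", column("user_id"))

    # str() compiles against the string-compilation dialect, whose
    # StrSQLCompiler.delete_extra_from_clause() renders the extra
    # tables inline; expected output is roughly:
    #   DELETE FROM users , orders WHERE users.id = orders.user_id
    stmt = delete(users).where(users.c.id == orders.c.user_id)
    print(str(stmt))
]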
@@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within DELETE") + "criteria within DELETE" + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, asfrom=False, **kw): @@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled): extra_froms = delete_stmt._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) - self.stack.append({'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": delete_stmt}) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) text = "DELETE " if delete_stmt._prefixes: - text += self._generate_prefixes(delete_stmt, - delete_stmt._prefixes, **kw) + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) text += "FROM " - table_text = self.delete_table_clause(delete_stmt, delete_stmt.table, - extra_froms) + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) if delete_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - delete_stmt, table_text) + delete_stmt, table_text + ) else: dialect_hints = None @@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled): if delete_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if extra_froms: extra_from_text = self.delete_extra_from_clause( delete_stmt, delete_stmt.table, extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled): if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) class StrSQLCompiler(SQLCompiler): @@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler): def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_json_getitem_op_binary(self, binary, operator, **kw): @@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler): for c in elements._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def update_from_clause( + self, update_stmt, 
from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return ', ' + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return ", " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) class DDLCompiler(Compiled): - @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler(self.dialect, None) @@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled): preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: - table, sch = path[0], '' + table, sch = path[0], "" else: table, sch = path[-1], path[0] - context.setdefault('table', table) - context.setdefault('schema', sch) - context.setdefault('fullname', preparer.format_table(ddl.target)) + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) return self.sql_compiler.post_process_text(ddl.statement % context) @@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled): for create_column in create.columns: column = create_column.element try: - processed = self.process(create_column, - first_pk=column.primary_key - and not first_pk) + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) if processed is not None: text += separator separator = ", \n" @@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled): except exc.CompileError as ce: util.raise_from_cause( exc.CompileError( - util.u("(in table '%s', column '%s'): %s") % - (table.description, column.name, ce.args[0]) - )) + util.u("(in table '%s', column '%s'): %s") + % (table.description, column.name, ce.args[0]) + ) + ) const = self.create_table_constraints( - table, _include_foreign_key_constraints= # noqa - create.include_foreign_key_constraints) + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) if const: text += separator + "\t" + const @@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled): if column.system: return None - text = self.get_column_specification( - column, - first_pk=first_pk + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints ) - const = " ".join(self.process(constraint) - for constraint in column.constraints) if const: text += " " + const return text def create_table_constraints( - self, table, - _include_foreign_key_constraints=None): + self, table, _include_foreign_key_constraints=None + ): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled): else: omit_fkcs = set() - constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key and - c not in omit_fkcs]) + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) return ", \n\t".join( - p for p in - (self.process(constraint) + p + for p in ( + self.process(constraint) for constraint in constraints if ( - constraint._create_rule is None or - constraint._create_rule(self)) + 
constraint._create_rule is None + or constraint._create_rule(self) + ) and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None ) def visit_drop_table(self, drop): @@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled): def _verify_index_table(self, index): if index.table is None: - raise exc.CompileError("Index '%s' is not associated " - "with any table." % index.name) + raise exc.CompileError( + "Index '%s' is not associated " "with any table." % index.name + ) - def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table, - use_schema=include_table_schema), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) return text def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX " + self._prepared_index_name( - index, include_schema=True) + index, include_schema=True + ) def _prepared_index_name(self, index, include_schema=False): if index.table is not None: @@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled): def visit_add_constraint(self, create): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), - self.process(create.element) + self.process(create.element), ) def visit_set_table_comment(self, create): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_table_comment(self, drop): - return "COMMENT ON TABLE %s IS NULL" % \ - self.preparer.format_table(drop.element) + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) def visit_set_column_comment(self, create): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( - create.element, use_table=True, use_schema=True), + create.element, use_table=True, use_schema=True + ), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_column_comment(self, drop): - return "COMMENT ON COLUMN %s IS NULL" % \ - self.preparer.format_column(drop.element, use_table=True) + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) def visit_create_sequence(self, create): - text = "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: @@ -2688,8 +3009,7 @@ class 
DDLCompiler(Compiled): return text def visit_drop_sequence(self, drop): - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): constraint = drop.element @@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled): if formatted_name is None: raise exc.CompileError( "Can't emit DROP CONSTRAINT for constraint %r; " - "it has no name" % drop.element) + "it has no name" % drop.element + ) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), formatted_name, - drop.cascade and " CASCADE" or "" + drop.cascade and " CASCADE" or "", ) def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process( - column.type, type_expression=column) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled): return colspec def create_table_suffix(self, table): - return '' + return "" def post_create_table(self, table): - return '' + return "" def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, util.string_types): return self.sql_compiler.render_literal_value( - column.server_default.arg, sqltypes.STRINGTYPE) + column.server_default.arg, sqltypes.STRINGTYPE + ) else: return self.sql_compiler.process( - column.server_default.arg, literal_binds=True) + column.server_default.arg, literal_binds=True + ) else: return None @@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text @@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in (constraint.columns_autoinc_first - if constraint._implicit_generated - else constraint.columns)) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled): 
text += "CONSTRAINT %s " % formatted_name remote_table = list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join(preparer.quote(f.parent.name) - for f in constraint.elements), + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), self.define_constraint_remote_table( - constraint, remote_table, preparer), - ', '.join(preparer.quote(f.column.name) - for f in constraint.elements) + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), ) text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) @@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled): def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) text += "CONSTRAINT %s " % formatted_name text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name) - for c in constraint)) + ", ".join(self.preparer.quote(c.name) for c in constraint) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): return "FLOAT" @@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler): if type_.precision is None: return "NUMERIC" elif type_.scale is None: - return "NUMERIC(%(precision)s)" % \ - {'precision': type_.precision} + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} else: - return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: - return "DECIMAL(%(precision)s)" % \ - {'precision': type_.precision} + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} else: - return "DECIMAL(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_INTEGER(self, type_, **kw): return "INTEGER" @@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler): return "BIGINT" def visit_TIMESTAMP(self, type_, **kw): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_DATETIME(self, type_, **kw): return "DATETIME" @@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler): return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): - raise exc.CompileError("Can't generate DDL for %r; " - "did you forget to specify a " - "type on this Column?" % type_) + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_ + ) def visit_type_decorator(self, type_, **kw): return self.process(type_.type_engine(self.dialect), **kw) @@ -3018,9 +3353,15 @@ class IdentifierPreparer(object): schema_for_object = schema._schema_getter(None) - def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', - quote_case_sensitive_collations=True, omit_schema=False): + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): """Construct a new ``IdentifierPreparer`` object. 
initial_quote @@ -3043,7 +3384,10 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self.quote_case_sensitive_collations = quote_case_sensitive_collations self._strings = {} - self._double_percents = self.dialect.paramstyle in ('format', 'pyformat') + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) @@ -3060,7 +3404,7 @@ class IdentifierPreparer(object): value = value.replace(self.escape_quote, self.escape_to_quote) if self._double_percents: - value = value.replace('%', '%%') + value = value.replace("%", "%%") return value def _unescape_identifier(self, value): @@ -3079,17 +3423,21 @@ class IdentifierPreparer(object): quoting behavior. """ - return self.initial_quote + \ - self._escape_identifier(value) + \ - self.final_quote + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - or (lc_value != value)) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + or (lc_value != value) + ) def quote_schema(self, schema, force=None): """Conditionally quote a schema. @@ -3135,8 +3483,11 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(sequence) - if (not self.omit_schema and use_schema and - effective_schema is not None): + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): name = self.quote_schema(effective_schema) + "." + name return name @@ -3159,7 +3510,8 @@ class IdentifierPreparer(object): def format_constraint(self, naming, constraint): if isinstance(constraint.name, elements._defer_name): name = naming._constraint_name_for_table( - constraint, constraint.table) + constraint, constraint.table + ) if name is None: if isinstance(constraint.name, elements._defer_none_name): @@ -3170,14 +3522,15 @@ class IdentifierPreparer(object): name = constraint.name if isinstance(name, elements._truncated_label): - if constraint.__visit_name__ == 'index': - max_ = self.dialect.max_index_name_length or \ - self.dialect.max_identifier_length + if constraint.__visit_name__ == "index": + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) else: max_ = self.dialect.max_identifier_length if len(name) > max_: - name = name[0:max_ - 8] + \ - "_" + util.md5_hex(name)[-4:] + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] else: self.dialect.validate_identifier(name) @@ -3195,8 +3548,7 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema \ - and effective_schema: + if not self.omit_schema and use_schema and effective_schema: result = self.quote_schema(effective_schema) + "." 
+ result return result @@ -3205,17 +3557,27 @@ class IdentifierPreparer(object): return self.quote(name, quote) - def format_column(self, column, use_table=False, - name=None, table_name=None, use_schema=False): + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + ): """Prepare a quoted column name.""" if name is None: name = column.name - if not getattr(column, 'is_literal', False): + if not getattr(column, "is_literal", False): if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + "." + self.quote(name) + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) else: return self.quote(name) else: @@ -3223,9 +3585,13 @@ class IdentifierPreparer(object): # which shouldn't get quoted if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + '.' + name + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) else: return name @@ -3238,31 +3604,37 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema and \ - effective_schema: - return (self.quote_schema(effective_schema), - self.format_table(table, use_schema=False)) + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) else: - return (self.format_table(table, use_schema=False), ) + return (self.format_table(table, use_schema=False),) @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] + initial, final, escaped_final = [ + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ] r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - {'initial': initial, - 'final': final, - 'escaped': escaped_final}) + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) return r def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers - return [self._unescape_identifier(i) - for i in [a or b for a, b in r.findall(identifiers)]] + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 999d48a552..602b91a255 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -15,7 +15,9 @@ from . import dml from . import elements import operator -REQUIRED = util.symbol('REQUIRED', """ +REQUIRED = util.symbol( + "REQUIRED", + """ Placeholder for the value within a :class:`.BindParameter` which is required to be present when the statement is passed to :meth:`.Connection.execute`. @@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert` or :func:`.expression.update` statement is compiled without parameter values present. 
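To see the REQUIRED placeholder described in this docstring in practice: compiling an insert() with no values marks every column's bind as required, and the actual values are collected at execution time. A minimal sketch (names are illustrative):

from sqlalchemy import (
    Column, Integer, MetaData, String, Table, create_engine,
)

metadata = MetaData()
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
)

stmt = users.insert()  # compiled without parameter values present
print(stmt)  # INSERT INTO users (id, name) VALUES (:id, :name)

engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    # the required binds are satisfied here, at execute() time
    conn.execute(stmt, {"name": "alice"})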
-""") +""", +) -ISINSERT = util.symbol('ISINSERT') -ISUPDATE = util.symbol('ISUPDATE') -ISDELETE = util.symbol('ISDELETE') +ISINSERT = util.symbol("ISINSERT") +ISUPDATE = util.symbol("ISUPDATE") +ISDELETE = util.symbol("ISDELETE") def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): @@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and stmt.parameters is None: return [ - (c, _create_bind_param( - compiler, c, None, required=True)) + (c, _create_bind_param(compiler, c, None, required=True)) for c in stmt.table.columns ] @@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw): # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account - _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler, stmt) + _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column( + compiler, stmt + ) # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: parameters = {} else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = dict( + (_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or key not in stmt_parameters + ) # create a list of column assignment clauses as tuples values = [] if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, - parameters, stmt_parameters, _column_as_key, values, kw) + compiler, parameters, stmt_parameters, _column_as_key, values, kw + ) check_columns = {} @@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw): # statements if compiler.isupdate and stmt._extra_froms and stmt_parameters: _get_multitable_params( - compiler, stmt, stmt_parameters, check_columns, - _col_bind_name, _getattr_col_key, values, kw) + compiler, + stmt, + stmt_parameters, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) if compiler.isinsert and stmt.select_names: _scan_insert_from_select_cols( - compiler, stmt, parameters, - _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) else: _scan_cols( - compiler, stmt, parameters, - _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) if parameters and stmt_parameters: - check = set(parameters).intersection( - _column_as_key(k) for k in stmt_parameters - ).difference(check_columns) + check = ( + set(parameters) + .intersection(_column_as_key(k) for k in stmt_parameters) + .difference(check_columns) + ) if check: raise exc.CompileError( - "Unconsumed column names: %s" % - (", ".join("%s" % c for c in check)) + "Unconsumed column names: %s" + % (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw): def _create_bind_param( - compiler, col, value, process=True, - required=False, name=None, **kw): + compiler, col, value, process=True, required=False, name=None, **kw +): if name is None: name = col.key bindparam = elements.BindParameter( - name, value, type_=col.type, 
required=required) + name, value, type_=col.type, required=required + ) bindparam._is_crud = True if process: bindparam = bindparam._compiler_dispatch(compiler, **kw) @@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt): def _column_as_key(key): str_key = elements._column_as_key(key) - if hasattr(key, 'table') and key.table in _et: + if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: return str_key @@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt): def _scan_insert_from_select_cols( - compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): - - need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid = \ - _get_returning_modifiers(compiler, stmt) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers( + compiler, stmt + ) - cols = [stmt.table.c[_column_as_key(name)] - for name in stmt.select_names] + cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names] compiler._insert_from_select = stmt.select @@ -228,32 +263,39 @@ def _scan_insert_from_select_cols( values.append((c, None)) else: _append_param_insert_select_hasdefault( - compiler, stmt, c, add_select_cols, kw) + compiler, stmt, c, add_select_cols, kw + ) if add_select_cols: values.extend(add_select_cols) compiler._insert_from_select = compiler._insert_from_select._generate() - compiler._insert_from_select._raw_columns = \ - tuple(compiler._insert_from_select._raw_columns) + tuple( - expr for col, expr in add_select_cols) + compiler._insert_from_select._raw_columns = tuple( + compiler._insert_from_select._raw_columns + ) + tuple(expr for col, expr in add_select_cols) def _scan_cols( - compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): - - need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid = \ - _get_returning_modifiers(compiler, stmt) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers( + compiler, stmt + ) if stmt._parameter_ordering: parameter_ordering = [ _column_as_key(key) for key in stmt._parameter_ordering ] ordered_keys = set(parameter_ordering) - cols = [ - stmt.table.c[key] for key in parameter_ordering - ] + [ + cols = [stmt.table.c[key] for key in parameter_ordering] + [ c for c in stmt.table.c if c.key not in ordered_keys ] else: @@ -265,72 +307,95 @@ def _scan_cols( if col_key in parameters and col_key not in check_columns: _append_param_parameter( - compiler, stmt, c, col_key, parameters, _col_bind_name, - implicit_returning, implicit_return_defaults, values, kw) + compiler, + stmt, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, + ) elif compiler.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): + if ( + c.primary_key + and need_pks + and ( + implicit_returning + or not postfetch_lastrowid + or c is not stmt.table._autoincrement_column + ) + ): if implicit_returning: _append_param_insert_pk_returning( - compiler, stmt, c, values, kw) + compiler, stmt, 
c, values, kw + ) else: _append_param_insert_pk(compiler, stmt, c, values, kw) elif c.default is not None: _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, - values, kw) + compiler, stmt, c, implicit_return_defaults, values, kw + ) elif c.server_default is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: + elif implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) - elif c.primary_key and \ - c is not stmt.table._autoincrement_column and \ - not c.nullable: + elif ( + c.primary_key + and c is not stmt.table._autoincrement_column + and not c.nullable + ): _warn_pk_with_no_anticipated_value(c) elif compiler.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw) + compiler, stmt, c, implicit_return_defaults, values, kw + ) def _append_param_parameter( - compiler, stmt, c, col_key, parameters, _col_bind_name, - implicit_returning, implicit_return_defaults, values, kw): + compiler, + stmt, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, +): value = parameters.pop(col_key) if elements._is_literal(value): value = _create_bind_param( - compiler, c, value, required=value is REQUIRED, + compiler, + c, + value, + required=value is REQUIRED, name=_col_bind_name(c) if not stmt._has_multi_parameters else "%s_m0" % _col_bind_name(c), **kw ) else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: + if isinstance(value, elements.BindParameter) and value.type._isnull: value = value._clone() value.type = c.type if c.primary_key and implicit_returning: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) - elif implicit_return_defaults and \ - c in implicit_return_defaults: + elif implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) else: @@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): """ if c.default is not None: if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ): proc = compiler.process(c.default, **kw) values.append((c, proc)) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process( - c.default.arg.self_group(), **kw)) + (c, compiler.process(c.default.arg.self_group(), **kw)) ) compiler.returning.append(c) else: - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement): self.type = original.type def __eq__(self, other): - return isinstance(other, _multiparam_column) and \ - other.key == self.key and \ - other.original == self.original + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) def 
_process_multiparam_default_bind(compiler, stmt, c, index, kw): @@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): raise exc.CompileError( "INSERT value for column %s is explicitly rendered as a bound" "parameter in the VALUES clause; " - "a Python-side value or SQL expression is required" % c) + "a Python-side value or SQL expression is required" % c + ) elif c.default.is_clause_element: return compiler.process(c.default.arg.self_group(), **kw) else: @@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): """ if ( - ( - # column has a Python-side default - c.default is not None and - ( - # and it won't be a Sequence - not c.default.is_sequence or - compiler.dialect.supports_sequences - ) - ) - or - ( - # column is the "autoincrement column" - c is stmt.table._autoincrement_column and - ( - # and it's either a "sequence" or a - # pre-executable "autoincrement" sequence - compiler.dialect.supports_sequences or - compiler.dialect.preexecute_autoincrement_sequences - ) - ) - ): - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) + # column has a Python-side default + c.default is not None + and ( + # and it won't be a Sequence + not c.default.is_sequence + or compiler.dialect.supports_sequences ) + ) or ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column + and ( + # and it's either a "sequence" or a + # pre-executable "autoincrement" sequence + compiler.dialect.supports_sequences + or compiler.dialect.preexecute_autoincrement_sequences + ) + ): + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): def _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, values, kw): + compiler, stmt, c, implicit_return_defaults, values, kw +): if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): proc = compiler.process(c.default, **kw) values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) @@ -488,25 +548,21 @@ def _append_param_insert_hasdefault( proc = compiler.process(c.default.arg.self_group(), **kw) values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) -def _append_param_insert_select_hasdefault( - compiler, stmt, c, values, kw): +def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not 
c.default.optional or not compiler.dialect.sequences_optional + ): proc = c.default values.append((c, proc.next_value())) elif c.default.is_clause_element: @@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault( def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw): + compiler, stmt, c, implicit_return_defaults, values, kw +): if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process( - c.onupdate.arg.self_group(), **kw)) + (c, compiler.process(c.onupdate.arg.self_group(), **kw)) ) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append( - (c, _create_update_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_update_prefetch_bind_param(compiler, c))) elif c.server_onupdate is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) - elif implicit_return_defaults and \ - stmt._return_defaults is not True and \ - c in implicit_return_defaults: + elif ( + implicit_return_defaults + and stmt._return_defaults is not True + and c in implicit_return_defaults + ): compiler.returning.append(c) def _get_multitable_params( - compiler, stmt, stmt_parameters, check_columns, - _col_bind_name, _getattr_col_key, values, kw): + compiler, + stmt, + stmt_parameters, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, +): normalized_params = dict( (elements._clause_element_as_expr(c), param) @@ -565,8 +626,12 @@ def _get_multitable_params( value = normalized_params[c] if elements._is_literal(value): value = _create_bind_param( - compiler, c, value, required=value is REQUIRED, - name=_col_bind_name(c)) + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c), + ) else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) @@ -577,20 +642,25 @@ def _get_multitable_params( for c in t.c: if c in normalized_params: continue - elif (c.onupdate is not None and not - c.onupdate.is_sequence): + elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process( - c.onupdate.arg.self_group(), - **kw) - ) + ( + c, + compiler.process( + c.onupdate.arg.self_group(), **kw + ), + ) ) compiler.postfetch.append(c) else: values.append( - (c, _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c))) + ( + c, + _create_update_prefetch_bind_param( + compiler, c, name=_col_bind_name(c) + ), + ) ) elif c.server_onupdate is not None: compiler.postfetch.append(c) @@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if elements._is_literal(row[key]): new_param = _create_bind_param( - compiler, col, row[key], - name="%s_m%d" % (col.key, i + 1), **kw + compiler, + col, + row[key], + name="%s_m%d" % (col.key, i + 1), + **kw ) else: new_param = compiler.process(row[key].self_group(), **kw) @@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw): + compiler, parameters, stmt_parameters, _column_as_key, values, kw +): for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not 
None: @@ -637,8 +711,8 @@ def _get_stmt_parameters_params( # coercing right side to bound param if elements._is_literal(v): v = compiler.process( - elements.BindParameter(None, v, type_=k.type), - **kw) + elements.BindParameter(None, v, type_=k.type), **kw + ) else: v = compiler.process(v.self_group(), **kw) @@ -646,22 +720,27 @@ def _get_stmt_parameters_params( def _get_returning_modifiers(compiler, stmt): - need_pks = compiler.isinsert and \ - not compiler.inline and \ - not stmt._returning and \ - not stmt._has_multi_parameters + need_pks = ( + compiler.isinsert + and not compiler.inline + and not stmt._returning + and not stmt._has_multi_parameters + ) - implicit_returning = need_pks and \ - compiler.dialect.implicit_returning and \ - stmt.table.implicit_returning + implicit_returning = ( + need_pks + and compiler.dialect.implicit_returning + and stmt.table.implicit_returning + ) if compiler.isinsert: - implicit_return_defaults = (implicit_returning and - stmt._return_defaults) + implicit_return_defaults = implicit_returning and stmt._return_defaults elif compiler.isupdate: - implicit_return_defaults = (compiler.dialect.implicit_returning and - stmt.table.implicit_returning and - stmt._return_defaults) + implicit_return_defaults = ( + compiler.dialect.implicit_returning + and stmt.table.implicit_returning + and stmt._return_defaults + ) else: # this line is unused, currently we are always # isinsert or isupdate @@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt): postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid - return need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid + return ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) def _warn_pk_with_no_anticipated_value(c): @@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c): "nor does it indicate 'autoincrement=True' or 'nullable=True', " "and no explicit value is passed. " "Primary key columns typically may not store NULL." - % - (c.table.fullname, c.name, c.table.fullname)) + % (c.table.fullname, c.name, c.table.fullname) + ) if len(c.table.primary_key) > 1: msg += ( " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " @@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c): "keys if AUTO_INCREMENT/SERIAL/IDENTITY " "behavior is expected for one of the columns in the primary key. " "CREATE TABLE statements are impacted by this change as well on " - "most backends.") + "most backends." 
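The remediation this warning points at is an explicit autoincrement marking on composite primary keys; a minimal sketch, assuming the first column is the one that should receive AUTO_INCREMENT/SERIAL/IDENTITY behavior:

from sqlalchemy import Column, Integer, MetaData, Table

metadata = MetaData()
pages = Table(
    "pages",
    metadata,
    # with a composite primary key, autoincrement is no longer implied
    # (SQLAlchemy 1.1+); mark the intended column explicitly
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("version", Integer, primary_key=True, autoincrement=False),
)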
+ ) util.warn(msg) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 91e93efe74..f21b3d7f06 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -56,8 +56,9 @@ class DDLElement(Executable, _DDLCompiles): """ - _execution_options = Executable.\ - _execution_options.union({'autocommit': True}) + _execution_options = Executable._execution_options.union( + {"autocommit": True} + ) target = None on = None @@ -95,11 +96,13 @@ class DDLElement(Executable, _DDLCompiles): if self._should_execute(target, bind): return bind.execute(self.against(target)) else: - bind.engine.logger.info( - "DDL execution skipped, criteria not met.") + bind.engine.logger.info("DDL execution skipped, criteria not met.") - @util.deprecated("0.7", "See :class:`.DDLEvents`, as well as " - ":meth:`.DDLElement.execute_if`.") + @util.deprecated( + "0.7", + "See :class:`.DDLEvents`, as well as " + ":meth:`.DDLElement.execute_if`.", + ) def execute_at(self, event_name, target): """Link execution of this DDL to the DDL lifecycle of a SchemaItem. @@ -129,11 +132,12 @@ class DDLElement(Executable, _DDLCompiles): """ def call_event(target, connection, **kw): - if self._should_execute_deprecated(event_name, - target, connection, **kw): + if self._should_execute_deprecated( + event_name, target, connection, **kw + ): return connection.execute(self.against(target)) - event.listen(target, "" + event_name.replace('-', '_'), call_event) + event.listen(target, "" + event_name.replace("-", "_"), call_event) @_generative def against(self, target): @@ -211,8 +215,9 @@ class DDLElement(Executable, _DDLCompiles): self.state = state def _should_execute(self, target, bind, **kw): - if self.on is not None and \ - not self._should_execute_deprecated(None, target, bind, **kw): + if self.on is not None and not self._should_execute_deprecated( + None, target, bind, **kw + ): return False if isinstance(self.dialect, util.string_types): @@ -221,9 +226,9 @@ class DDLElement(Executable, _DDLCompiles): elif isinstance(self.dialect, (tuple, list, set)): if bind.engine.name not in self.dialect: return False - if (self.callable_ is not None and - not self.callable_(self, target, bind, - state=self.state, **kw)): + if self.callable_ is not None and not self.callable_( + self, target, bind, state=self.state, **kw + ): return False return True @@ -245,13 +250,15 @@ class DDLElement(Executable, _DDLCompiles): return bind.execute(self.against(target)) def _check_ddl_on(self, on): - if (on is not None and - (not isinstance(on, util.string_types + (tuple, list, set)) and - not util.callable(on))): + if on is not None and ( + not isinstance(on, util.string_types + (tuple, list, set)) + and not util.callable(on) + ): raise exc.ArgumentError( "Expected the name of a database dialect, a tuple " "of names, or a callable for " - "'on' criteria, got type '%s'." % type(on).__name__) + "'on' criteria, got type '%s'." 
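The 'on' criteria validated here are the deprecated spelling of what DDL.execute_if() provides; a sketch of the current form, reusing the users table from the earlier sketch (the constraint DDL text is illustrative):

from sqlalchemy import DDL, event

event.listen(
    users,
    "after_create",
    DDL(
        "ALTER TABLE users ADD CONSTRAINT cst_name_length "
        "CHECK (length(name) >= 2)"
    ).execute_if(dialect="postgresql"),  # a tuple of names or a callable also works
)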
% type(on).__name__ + ) def bind(self): if self._bind: @@ -259,6 +266,7 @@ class DDLElement(Executable, _DDLCompiles): def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) def _generate(self): @@ -375,8 +383,9 @@ class DDL(DDLElement): if not isinstance(statement, util.string_types): raise exc.ArgumentError( - "Expected a string or unicode SQL statement, got '%r'" % - statement) + "Expected a string or unicode SQL statement, got '%r'" + % statement + ) self.statement = statement self.context = context or {} @@ -386,12 +395,18 @@ class DDL(DDLElement): self._bind = bind def __repr__(self): - return '<%s@%s; %s>' % ( - type(self).__name__, id(self), - ', '.join([repr(self.statement)] + - ['%s=%r' % (key, getattr(self, key)) - for key in ('on', 'context') - if getattr(self, key)])) + return "<%s@%s; %s>" % ( + type(self).__name__, + id(self), + ", ".join( + [repr(self.statement)] + + [ + "%s=%r" % (key, getattr(self, key)) + for key in ("on", "context") + if getattr(self, key) + ] + ), + ) class _CreateDropBase(DDLElement): @@ -464,8 +479,8 @@ class CreateTable(_CreateDropBase): __visit_name__ = "create_table" def __init__( - self, element, on=None, bind=None, - include_foreign_key_constraints=None): + self, element, on=None, bind=None, include_foreign_key_constraints=None + ): """Create a :class:`.CreateTable` construct. :param element: a :class:`.Table` that's the subject @@ -481,9 +496,7 @@ class CreateTable(_CreateDropBase): """ super(CreateTable, self).__init__(element, on=on, bind=bind) - self.columns = [CreateColumn(column) - for column in element.columns - ] + self.columns = [CreateColumn(column) for column in element.columns] self.include_foreign_key_constraints = include_foreign_key_constraints @@ -494,6 +507,7 @@ class _DropView(_CreateDropBase): This object will eventually be part of a public "view" API. """ + __visit_name__ = "drop_view" @@ -602,7 +616,8 @@ class CreateColumn(_DDLCompiles): to support custom column creation styles. 
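One such custom column creation style, per the documented recipe for CreateColumn, intercepts compilation with the compiler extension; returning None omits the column from the rendered CREATE TABLE entirely. A sketch (the info key is illustrative):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CreateColumn

@compiles(CreateColumn)
def skip_hidden_columns(element, compiler, **kw):
    column = element.element
    if column.info.get("hidden"):
        return None  # column is omitted from CREATE TABLE
    return compiler.visit_create_column(element, **kw)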
""" - __visit_name__ = 'create_column' + + __visit_name__ = "create_column" def __init__(self, element): self.element = element @@ -646,7 +661,8 @@ class AddConstraint(_CreateDropBase): def __init__(self, element, *args, **kw): super(AddConstraint, self).__init__(element, *args, **kw) element._create_rule = util.portable_instancemethod( - self._create_rule_disable) + self._create_rule_disable + ) class DropConstraint(_CreateDropBase): @@ -658,7 +674,8 @@ class DropConstraint(_CreateDropBase): self.cascade = cascade super(DropConstraint, self).__init__(element, **kw) element._create_rule = util.portable_instancemethod( - self._create_rule_disable) + self._create_rule_disable + ) class SetTableComment(_CreateDropBase): @@ -691,9 +708,9 @@ class DDLBase(SchemaVisitor): class SchemaGenerator(DDLBase): - - def __init__(self, dialect, connection, checkfirst=False, - tables=None, **kwargs): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): super(SchemaGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables @@ -706,25 +723,22 @@ class SchemaGenerator(DDLBase): effective_schema = self.connection.schema_for_object(table) if effective_schema: self.dialect.validate_identifier(effective_schema) - return not self.checkfirst or \ - not self.dialect.has_table(self.connection, - table.name, schema=effective_schema) + return not self.checkfirst or not self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) def _can_create_sequence(self, sequence): effective_schema = self.connection.schema_for_object(sequence) - return self.dialect.supports_sequences and \ - ( - (not self.dialect.sequences_optional or - not sequence.optional) and - ( - not self.checkfirst or - not self.dialect.has_sequence( - self.connection, - sequence.name, - schema=effective_schema) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or not self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema ) ) + ) def visit_metadata(self, metadata): if self.tables is not None: @@ -733,18 +747,23 @@ class SchemaGenerator(DDLBase): tables = list(metadata.tables.values()) collection = sort_tables_and_constraints( - [t for t in tables if self._can_create_table(t)]) - - seq_coll = [s for s in metadata._sequences.values() - if s.column is None and self._can_create_sequence(s)] + [t for t in tables if self._can_create_table(t)] + ) - event_collection = [ - t for (t, fks) in collection if t is not None + seq_coll = [ + s + for s in metadata._sequences.values() + if s.column is None and self._can_create_sequence(s) ] - metadata.dispatch.before_create(metadata, self.connection, - tables=event_collection, - checkfirst=self.checkfirst, - _ddl_runner=self) + + event_collection = [t for (t, fks) in collection if t is not None] + metadata.dispatch.before_create( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) for seq in seq_coll: self.traverse_single(seq, create_ok=True) @@ -752,30 +771,40 @@ class SchemaGenerator(DDLBase): for table, fkcs in collection: if table is not None: self.traverse_single( - table, create_ok=True, + table, + create_ok=True, include_foreign_key_constraints=fkcs, - _is_metadata_operation=True) + _is_metadata_operation=True, + ) else: for fkc in fkcs: self.traverse_single(fkc) - metadata.dispatch.after_create(metadata, self.connection, - 
tables=event_collection, - checkfirst=self.checkfirst, - _ddl_runner=self) + metadata.dispatch.after_create( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) def visit_table( - self, table, create_ok=False, - include_foreign_key_constraints=None, - _is_metadata_operation=False): + self, + table, + create_ok=False, + include_foreign_key_constraints=None, + _is_metadata_operation=False, + ): if not create_ok and not self._can_create_table(table): return table.dispatch.before_create( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) for column in table.columns: if column.default is not None: @@ -788,10 +817,11 @@ class SchemaGenerator(DDLBase): self.connection.execute( CreateTable( table, - include_foreign_key_constraints=include_foreign_key_constraints - )) + include_foreign_key_constraints=include_foreign_key_constraints, + ) + ) - if hasattr(table, 'indexes'): + if hasattr(table, "indexes"): for index in table.indexes: self.traverse_single(index) @@ -804,10 +834,12 @@ class SchemaGenerator(DDLBase): self.connection.execute(SetColumnComment(column)) table.dispatch.after_create( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: @@ -824,9 +856,9 @@ class SchemaGenerator(DDLBase): class SchemaDropper(DDLBase): - - def __init__(self, dialect, connection, checkfirst=False, - tables=None, **kwargs): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): super(SchemaDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables @@ -842,15 +874,17 @@ class SchemaDropper(DDLBase): try: unsorted_tables = [t for t in tables if self._can_drop_table(t)] - collection = list(reversed( - sort_tables_and_constraints( - unsorted_tables, - filter_fn=lambda constraint: False - if not self.dialect.supports_alter - or constraint.name is None - else None + collection = list( + reversed( + sort_tables_and_constraints( + unsorted_tables, + filter_fn=lambda constraint: False + if not self.dialect.supports_alter + or constraint.name is None + else None, + ) ) - )) + ) except exc.CircularDependencyError as err2: if not self.dialect.supports_alter: util.warn( @@ -862,16 +896,15 @@ class SchemaDropper(DDLBase): "ForeignKeyConstraint " "objects involved in the cycle to mark these as known " "cycles that will be ignored." - % ( - ", ".join(sorted([t.fullname for t in err2.cycles])) - ) + % (", ".join(sorted([t.fullname for t in err2.cycles]))) ) collection = [(t, ()) for t in unsorted_tables] else: util.raise_from_cause( exc.CircularDependencyError( err2.args[0], - err2.cycles, err2.edges, + err2.cycles, + err2.edges, msg="Can't sort tables for DROP; an " "unresolvable foreign key " "dependency exists between tables: %s. Please ensure " @@ -880,9 +913,10 @@ class SchemaDropper(DDLBase): "names so that they can be dropped using " "DROP CONSTRAINT." 
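The mitigation named in this message looks like the following sketch: give one constraint in the cycle an explicit name and use_alter=True, so it is emitted as ADD CONSTRAINT after both CREATE TABLE statements and removed via DROP CONSTRAINT first (table names are illustrative):

from sqlalchemy import Column, ForeignKeyConstraint, Integer, MetaData, Table

metadata = MetaData()
a = Table(
    "a",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("b_id", Integer),
    ForeignKeyConstraint(["b_id"], ["b.id"], name="fk_a_b", use_alter=True),
)
b = Table(
    "b",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("a_id", Integer),
    ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_b_a"),
)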
% ( - ", ".join(sorted([t.fullname for t in err2.cycles])) - ) - + ", ".join( + sorted([t.fullname for t in err2.cycles]) + ) + ), ) ) @@ -892,18 +926,21 @@ class SchemaDropper(DDLBase): if s.column is None and self._can_drop_sequence(s) ] - event_collection = [ - t for (t, fks) in collection if t is not None - ] + event_collection = [t for (t, fks) in collection if t is not None] metadata.dispatch.before_drop( - metadata, self.connection, tables=event_collection, - checkfirst=self.checkfirst, _ddl_runner=self) + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) for table, fkcs in collection: if table is not None: self.traverse_single( - table, drop_ok=True, _is_metadata_operation=True) + table, drop_ok=True, _is_metadata_operation=True + ) else: for fkc in fkcs: self.traverse_single(fkc) @@ -912,8 +949,12 @@ class SchemaDropper(DDLBase): self.traverse_single(seq, drop_ok=True) metadata.dispatch.after_drop( - metadata, self.connection, tables=event_collection, - checkfirst=self.checkfirst, _ddl_runner=self) + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) def _can_drop_table(self, table): self.dialect.validate_identifier(table.name) @@ -921,19 +962,20 @@ class SchemaDropper(DDLBase): if effective_schema: self.dialect.validate_identifier(effective_schema) return not self.checkfirst or self.dialect.has_table( - self.connection, table.name, schema=effective_schema) + self.connection, table.name, schema=effective_schema + ) def _can_drop_sequence(self, sequence): effective_schema = self.connection.schema_for_object(sequence) - return self.dialect.supports_sequences and \ - ((not self.dialect.sequences_optional or - not sequence.optional) and - (not self.checkfirst or - self.dialect.has_sequence( - self.connection, - sequence.name, - schema=effective_schema)) - ) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) def visit_index(self, index): self.connection.execute(DropIndex(index)) @@ -943,10 +985,12 @@ class SchemaDropper(DDLBase): return table.dispatch.before_drop( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) self.connection.execute(DropTable(table)) @@ -960,10 +1004,12 @@ class SchemaDropper(DDLBase): self.traverse_single(column.default) table.dispatch.after_drop( - table, self.connection, + table, + self.connection, checkfirst=self.checkfirst, _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation) + _is_metadata_operation=_is_metadata_operation, + ) def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: @@ -1019,25 +1065,29 @@ def sort_tables(tables, skip_fn=None, extra_dependencies=None): """ if skip_fn is not None: + def _skip_fn(fkc): for fk in fkc.elements: if skip_fn(fk): return True else: return None + else: _skip_fn = None return [ - t for (t, fkcs) in - sort_tables_and_constraints( - tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies) + t + for (t, fkcs) in sort_tables_and_constraints( + tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies + ) if t is not None ] def sort_tables_and_constraints( - tables, filter_fn=None, extra_dependencies=None): 
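From the caller's side, sort_tables() wraps this machinery to order tables parent-first; a minimal sketch reusing the MetaData from the previous example:

from sqlalchemy.schema import sort_tables

for table in sort_tables(metadata.tables.values()):
    print(table.name)  # dependency order; reverse it for DROP / DELETE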
+ tables, filter_fn=None, extra_dependencies=None +): """sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint` objects. @@ -1109,8 +1159,9 @@ def sort_tables_and_constraints( try: candidate_sort = list( topological.sort( - fixed_dependencies.union(mutable_dependencies), tables, - deterministic_order=True + fixed_dependencies.union(mutable_dependencies), + tables, + deterministic_order=True, ) ) except exc.CircularDependencyError as err: @@ -1118,8 +1169,10 @@ def sort_tables_and_constraints( if edge in mutable_dependencies: table = edge[1] can_remove = [ - fkc for fkc in table.foreign_key_constraints - if filter_fn is None or filter_fn(fkc) is not False] + fkc + for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False + ] remaining_fkcs.update(can_remove) for fkc in can_remove: dependent_on = fkc.referred_table @@ -1127,8 +1180,9 @@ def sort_tables_and_constraints( mutable_dependencies.discard((dependent_on, table)) candidate_sort = list( topological.sort( - fixed_dependencies.union(mutable_dependencies), tables, - deterministic_order=True + fixed_dependencies.union(mutable_dependencies), + tables, + deterministic_order=True, ) ) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 8149f9731d..fa00521987 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -11,19 +11,43 @@ from .. import exc, util from . import type_api from . import operators -from .elements import BindParameter, True_, False_, BinaryExpression, \ - Null, _const_expr, _clause_element_as_expr, \ - ClauseList, ColumnElement, TextClause, UnaryExpression, \ - collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \ - Slice, Visitable, _literal_as_binds, CollectionAggregate, \ - Tuple +from .elements import ( + BindParameter, + True_, + False_, + BinaryExpression, + Null, + _const_expr, + _clause_element_as_expr, + ClauseList, + ColumnElement, + TextClause, + UnaryExpression, + collate, + _is_literal, + _literal_as_text, + ClauseElement, + and_, + or_, + Slice, + Visitable, + _literal_as_binds, + CollectionAggregate, + Tuple, +) from .selectable import SelectBase, Alias, Selectable, ScalarSelect -def _boolean_compare(expr, op, obj, negate=None, reverse=False, - _python_is_types=(util.NoneType, bool), - result_type = None, - **kwargs): +def _boolean_compare( + expr, + op, + obj, + negate=None, + reverse=False, + _python_is_types=(util.NoneType, bool), + result_type=None, + **kwargs +): if result_type is None: result_type = type_api.BOOLEANTYPE @@ -33,57 +57,64 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False, # allow x ==/!= True/False to be treated as a literal. 
# this comes out to "== / != true/false" or "1/0" if those # constants aren't supported and works on all platforms - if op in (operators.eq, operators.ne) and \ - isinstance(obj, (bool, True_, False_)): - return BinaryExpression(expr, - _literal_as_text(obj), - op, - type_=result_type, - negate=negate, modifiers=kwargs) + if op in (operators.eq, operators.ne) and isinstance( + obj, (bool, True_, False_) + ): + return BinaryExpression( + expr, + _literal_as_text(obj), + op, + type_=result_type, + negate=negate, + modifiers=kwargs, + ) elif op in (operators.is_distinct_from, operators.isnot_distinct_from): - return BinaryExpression(expr, - _literal_as_text(obj), - op, - type_=result_type, - negate=negate, modifiers=kwargs) + return BinaryExpression( + expr, + _literal_as_text(obj), + op, + type_=result_type, + negate=negate, + modifiers=kwargs, + ) else: # all other None/True/False uses IS, IS NOT if op in (operators.eq, operators.is_): - return BinaryExpression(expr, _const_expr(obj), - operators.is_, - negate=operators.isnot, - type_=result_type - ) + return BinaryExpression( + expr, + _const_expr(obj), + operators.is_, + negate=operators.isnot, + type_=result_type, + ) elif op in (operators.ne, operators.isnot): - return BinaryExpression(expr, _const_expr(obj), - operators.isnot, - negate=operators.is_, - type_=result_type - ) + return BinaryExpression( + expr, + _const_expr(obj), + operators.isnot, + negate=operators.is_, + type_=result_type, + ) else: raise exc.ArgumentError( "Only '=', '!=', 'is_()', 'isnot()', " "'is_distinct_from()', 'isnot_distinct_from()' " - "operators can be used with None/True/False") + "operators can be used with None/True/False" + ) else: obj = _check_literal(expr, op, obj) if reverse: - return BinaryExpression(obj, - expr, - op, - type_=result_type, - negate=negate, modifiers=kwargs) + return BinaryExpression( + obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs + ) else: - return BinaryExpression(expr, - obj, - op, - type_=result_type, - negate=negate, modifiers=kwargs) + return BinaryExpression( + expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs + ) -def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, - **kw): +def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): if result_type is None: if op.return_type: result_type = op.return_type @@ -91,11 +122,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, result_type = type_api.BOOLEANTYPE return _binary_operate( - expr, op, obj, reverse=reverse, result_type=result_type, **kw) + expr, op, obj, reverse=reverse, result_type=result_type, **kw + ) -def _binary_operate(expr, op, obj, reverse=False, result_type=None, - **kw): +def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): obj = _check_literal(expr, op, obj) if reverse: @@ -105,10 +136,10 @@ def _binary_operate(expr, op, obj, reverse=False, result_type=None, if result_type is None: op, result_type = left.comparator._adapt_expression( - op, right.comparator) + op, right.comparator + ) - return BinaryExpression( - left, right, op, type_=result_type, modifiers=kw) + return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) def _conjunction_operate(expr, op, other, **kw): @@ -128,8 +159,7 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): seq_or_selectable = _clause_element_as_expr(seq_or_selectable) if isinstance(seq_or_selectable, ScalarSelect): - return _boolean_compare(expr, op, seq_or_selectable, - 
negate=negate_op) + return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op) elif isinstance(seq_or_selectable, SelectBase): # TODO: if we ever want to support (x, y, z) IN (select x, @@ -138,32 +168,33 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): # does not export itself as a FROM clause return _boolean_compare( - expr, op, seq_or_selectable.as_scalar(), - negate=negate_op, **kw) + expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw + ) elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return _boolean_compare(expr, op, seq_or_selectable, - negate=negate_op, **kw) + return _boolean_compare( + expr, op, seq_or_selectable, negate=negate_op, **kw + ) elif isinstance(seq_or_selectable, ClauseElement): - if isinstance(seq_or_selectable, BindParameter) and \ - seq_or_selectable.expanding: + if ( + isinstance(seq_or_selectable, BindParameter) + and seq_or_selectable.expanding + ): if isinstance(expr, Tuple): - seq_or_selectable = ( - seq_or_selectable._with_expanding_in_types( - [elem.type for elem in expr] - ) + seq_or_selectable = seq_or_selectable._with_expanding_in_types( + [elem.type for elem in expr] ) return _boolean_compare( - expr, op, - seq_or_selectable, - negate=negate_op) + expr, op, seq_or_selectable, negate=negate_op + ) else: raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions, ' + "in_() accepts" + " either a list of expressions, " 'a selectable, or an "expanding" bound parameter: %r' - % seq_or_selectable) + % seq_or_selectable + ) # Handle non selectable arguments as sequences args = [] @@ -171,9 +202,10 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): if not _is_literal(o): if not isinstance(o, operators.ColumnOperators): raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions, ' - 'a selectable, or an "expanding" bound parameter: %r' % o) + "in_() accepts" + " either a list of expressions, " + 'a selectable, or an "expanding" bound parameter: %r' % o + ) elif o is None: o = Null() else: @@ -182,15 +214,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): if len(args) == 0: op, negate_op = ( - operators.empty_in_op, - operators.empty_notin_op) if op is operators.in_op \ - else ( - operators.empty_notin_op, - operators.empty_in_op) + (operators.empty_in_op, operators.empty_notin_op) + if op is operators.in_op + else (operators.empty_notin_op, operators.empty_in_op) + ) - return _boolean_compare(expr, op, - ClauseList(*args).self_group(against=op), - negate=negate_op) + return _boolean_compare( + expr, op, ClauseList(*args).self_group(against=op), negate=negate_op + ) def _getitem_impl(expr, op, other, **kw): @@ -202,13 +233,14 @@ def _getitem_impl(expr, op, other, **kw): def _unsupported_impl(expr, op, *arg, **kw): - raise NotImplementedError("Operator '%s' is not supported on " - "this expression" % op.__name__) + raise NotImplementedError( + "Operator '%s' is not supported on " "this expression" % op.__name__ + ) def _inv_impl(expr, op, **kw): """See :meth:`.ColumnOperators.__inv__`.""" - if hasattr(expr, 'negation_clause'): + if hasattr(expr, "negation_clause"): return expr.negation_clause else: return expr._negate() @@ -223,20 +255,22 @@ def _match_impl(expr, op, other, **kw): """See :meth:`.ColumnOperators.match`.""" return _boolean_compare( - expr, operators.match_op, - _check_literal( - expr, operators.match_op, other), + expr, + operators.match_op, + _check_literal(expr, operators.match_op, other), result_type=type_api.MATCHTYPE, 
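The comparison rules implemented in this module show up directly in rendered SQL; a minimal sketch (column names are illustrative):

from sqlalchemy import Column, Integer, MetaData, Table

metadata = MetaData()
t = Table("t", metadata, Column("x", Integer))

print(t.c.x == None)       # t.x IS NULL
print(t.c.x != None)       # t.x IS NOT NULL
print(t.c.x.in_([1, 2]))   # t.x IN (:x_1, :x_2)
print(~t.c.x.in_([1, 2]))  # t.x NOT IN (:x_1, :x_2)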
negate=operators.notmatch_op - if op is operators.match_op else operators.match_op, + if op is operators.match_op + else operators.match_op, **kw ) def _distinct_impl(expr, op, **kw): """See :meth:`.ColumnOperators.distinct`.""" - return UnaryExpression(expr, operator=operators.distinct_op, - type_=expr.type) + return UnaryExpression( + expr, operator=operators.distinct_op, type_=expr.type + ) def _between_impl(expr, op, cleft, cright, **kw): @@ -247,17 +281,21 @@ def _between_impl(expr, op, cleft, cright, **kw): _check_literal(expr, operators.and_, cleft), _check_literal(expr, operators.and_, cright), operator=operators.and_, - group=False, group_contents=False), + group=False, + group_contents=False, + ), op, negate=operators.notbetween_op if op is operators.between_op else operators.between_op, - modifiers=kw) + modifiers=kw, + ) def _collate_impl(expr, op, other, **kw): return collate(expr, other) + # a mapping of operators with the method they use, along with # their negated operator for comparison operators operator_lookup = { @@ -271,8 +309,8 @@ operator_lookup = { "mod": (_binary_operate,), "truediv": (_binary_operate,), "custom_op": (_custom_op_operate,), - "json_path_getitem_op": (_binary_operate, ), - "json_getitem_op": (_binary_operate, ), + "json_path_getitem_op": (_binary_operate,), + "json_getitem_op": (_binary_operate,), "concat_op": (_binary_operate,), "any_op": (_scalar, CollectionAggregate._create_any), "all_op": (_scalar, CollectionAggregate._create_all), @@ -303,8 +341,8 @@ operator_lookup = { "match_op": (_match_impl,), "notmatch_op": (_match_impl,), "distinct_op": (_distinct_impl,), - "between_op": (_between_impl, ), - "notbetween_op": (_between_impl, ), + "between_op": (_between_impl,), + "notbetween_op": (_between_impl,), "neg": (_neg_impl,), "getitem": (_getitem_impl,), "lshift": (_unsupported_impl,), @@ -315,12 +353,11 @@ operator_lookup = { def _check_literal(expr, operator, other, bindparam_type=None): if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and \ - other.type._isnull: + if isinstance(other, BindParameter) and other.type._isnull: other = other._clone() other.type = expr.type return other - elif hasattr(other, '__clause_element__'): + elif hasattr(other, "__clause_element__"): other = other.__clause_element__() elif isinstance(other, type_api.TypeEngine.Comparator): other = other.expr @@ -331,4 +368,3 @@ def _check_literal(expr, operator, other, bindparam_type=None): return expr._bind_param(operator, other, type_=bindparam_type) else: return other - diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index d6890de154..0cea5ccc42 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,26 +9,43 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ -from .base import Executable, _generative, _from_objects, DialectKWArgs, \ - ColumnCollection -from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ - _column_as_key -from .selectable import _interpret_as_from, _interpret_as_select, \ - HasPrefixes, HasCTE +from .base import ( + Executable, + _generative, + _from_objects, + DialectKWArgs, + ColumnCollection, +) +from .elements import ( + ClauseElement, + _literal_as_text, + Null, + and_, + _clone, + _column_as_key, +) +from .selectable import ( + _interpret_as_from, + _interpret_as_select, + HasPrefixes, + HasCTE, +) from .. import util from .. 
import exc class UpdateBase( - HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement): + HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement +): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. """ - __visit_name__ = 'update_base' + __visit_name__ = "update_base" - _execution_options = \ - Executable._execution_options.union({'autocommit': True}) + _execution_options = Executable._execution_options.union( + {"autocommit": True} + ) _hints = util.immutabledict() _parameter_ordering = None _prefixes = () @@ -37,30 +54,33 @@ class UpdateBase( def _process_colparams(self, parameters): def process_single(p): if isinstance(p, (list, tuple)): - return dict( - (c.key, pval) - for c, pval in zip(self.table.c, p) - ) + return dict((c.key, pval) for c, pval in zip(self.table.c, p)) else: return p if self._preserve_parameter_order and parameters is not None: - if not isinstance(parameters, list) or \ - (parameters and not isinstance(parameters[0], tuple)): + if not isinstance(parameters, list) or ( + parameters and not isinstance(parameters[0], tuple) + ): raise ValueError( "When preserve_parameter_order is True, " - "values() only accepts a list of 2-tuples") + "values() only accepts a list of 2-tuples" + ) self._parameter_ordering = [key for key, value in parameters] return dict(parameters), False - if (isinstance(parameters, (list, tuple)) and parameters and - isinstance(parameters[0], (list, tuple, dict))): + if ( + isinstance(parameters, (list, tuple)) + and parameters + and isinstance(parameters[0], (list, tuple, dict)) + ): if not self._supports_multi_parameters: raise exc.InvalidRequestError( "This construct does not support " - "multiple parameter sets.") + "multiple parameter sets." + ) return [process_single(p) for p in parameters], True else: @@ -77,7 +97,8 @@ class UpdateBase( raise NotImplementedError( "params() is not supported for INSERT/UPDATE/DELETE statements." " To set the values for an INSERT or UPDATE statement, use" - " stmt.values(**parameters).") + " stmt.values(**parameters)." + ) def bind(self): """Return a 'bind' linked to this :class:`.UpdateBase` @@ -88,6 +109,7 @@ class UpdateBase( def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) @_generative @@ -181,15 +203,14 @@ class UpdateBase( if selectable is None: selectable = self.table - self._hints = self._hints.union( - {(selectable, dialect_name): text}) + self._hints = self._hints.union({(selectable, dialect_name): text}) class ValuesBase(UpdateBase): """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs.""" - __visit_name__ = 'values_base' + __visit_name__ = "values_base" _supports_multi_parameters = False _has_multi_parameters = False @@ -199,8 +220,9 @@ class ValuesBase(UpdateBase): def __init__(self, table, values, prefixes): self.table = _interpret_as_from(table) - self.parameters, self._has_multi_parameters = \ - self._process_colparams(values) + self.parameters, self._has_multi_parameters = self._process_colparams( + values + ) if prefixes: self._setup_prefixes(prefixes) @@ -332,23 +354,27 @@ class ValuesBase(UpdateBase): """ if self.select is not None: raise exc.InvalidRequestError( - "This construct already inserts from a SELECT") + "This construct already inserts from a SELECT" + ) if self._has_multi_parameters and kwargs: raise exc.InvalidRequestError( - "This construct already has multiple parameter sets.") + "This construct already has multiple parameter sets." 
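``_process_colparams`` above is what separates a single parameter
dictionary from a "multiparams" list, and enforces the 2-tuple format when
``preserve_parameter_order`` is in play. A sketch (the ``users`` table is
invented for illustration)::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    # a single parameter set
    stmt = users.insert().values(name="ed")

    # a list of dictionaries is a multiple-parameter INSERT; constructs
    # without _supports_multi_parameters reject this form
    stmt = users.insert().values([{"name": "ed"}, {"name": "wendy"}])

    # preserve_parameter_order requires a list of 2-tuples
    stmt = users.update(preserve_parameter_order=True).values(
        [("id", 1), ("name", "jack")]
    )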
+ ) if args: if len(args) > 1: raise exc.ArgumentError( "Only a single dictionary/tuple or list of " - "dictionaries/tuples is accepted positionally.") + "dictionaries/tuples is accepted positionally." + ) v = args[0] else: v = {} if self.parameters is None: - self.parameters, self._has_multi_parameters = \ - self._process_colparams(v) + self.parameters, self._has_multi_parameters = self._process_colparams( + v + ) else: if self._has_multi_parameters: self.parameters = list(self.parameters) @@ -356,7 +382,8 @@ class ValuesBase(UpdateBase): if not self._has_multi_parameters: raise exc.ArgumentError( "Can't mix single-values and multiple values " - "formats in one statement") + "formats in one statement" + ) self.parameters.extend(p) else: @@ -365,14 +392,16 @@ class ValuesBase(UpdateBase): if self._has_multi_parameters: raise exc.ArgumentError( "Can't mix single-values and multiple values " - "formats in one statement") + "formats in one statement" + ) self.parameters.update(p) if kwargs: if self._has_multi_parameters: raise exc.ArgumentError( "Can't pass kwargs and multiple parameter sets " - "simultaneously") + "simultaneously" + ) else: self.parameters.update(kwargs) @@ -456,19 +485,22 @@ class Insert(ValuesBase): :ref:`coretutorial_insert_expressions` """ - __visit_name__ = 'insert' + + __visit_name__ = "insert" _supports_multi_parameters = True - def __init__(self, - table, - values=None, - inline=False, - bind=None, - prefixes=None, - returning=None, - return_defaults=False, - **dialect_kw): + def __init__( + self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + **dialect_kw + ): """Construct an :class:`.Insert` object. Similar functionality is available via the @@ -526,7 +558,7 @@ class Insert(ValuesBase): def get_children(self, **kwargs): if self.select is not None: - return self.select, + return (self.select,) else: return () @@ -578,11 +610,12 @@ class Insert(ValuesBase): """ if self.parameters: raise exc.InvalidRequestError( - "This construct already inserts value expressions") + "This construct already inserts value expressions" + ) - self.parameters, self._has_multi_parameters = \ - self._process_colparams( - {_column_as_key(n): Null() for n in names}) + self.parameters, self._has_multi_parameters = self._process_colparams( + {_column_as_key(n): Null() for n in names} + ) self.select_names = names self.inline = True @@ -603,19 +636,22 @@ class Update(ValuesBase): function. """ - __visit_name__ = 'update' - - def __init__(self, - table, - whereclause=None, - values=None, - inline=False, - bind=None, - prefixes=None, - returning=None, - return_defaults=False, - preserve_parameter_order=False, - **dialect_kw): + + __visit_name__ = "update" + + def __init__( + self, + table, + whereclause=None, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + preserve_parameter_order=False, + **dialect_kw + ): r"""Construct an :class:`.Update` object. 
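The ``from_select()`` path above records the target column names and fills
the parameter dictionary with ``Null()`` placeholders until the SELECT
supplies the values. Roughly (both tables invented for the example)::

    from sqlalchemy import Column, Integer, MetaData, String, Table, select

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer),
        Column("name", String(50)),
    )
    archive = Table(
        "archive",
        metadata,
        Column("id", Integer),
        Column("name", String(50)),
    )

    # INSERT INTO archive (id, name) SELECT users.id, users.name ...
    stmt = archive.insert().from_select(
        ["id", "name"], select([users]).where(users.c.id > 100)
    )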
E.g.:: @@ -745,7 +781,7 @@ class Update(ValuesBase): def get_children(self, **kwargs): if self._whereclause is not None: - return self._whereclause, + return (self._whereclause,) else: return () @@ -761,8 +797,9 @@ class Update(ValuesBase): """ if self._whereclause is not None: - self._whereclause = and_(self._whereclause, - _literal_as_text(whereclause)) + self._whereclause = and_( + self._whereclause, _literal_as_text(whereclause) + ) else: self._whereclause = _literal_as_text(whereclause) @@ -788,15 +825,17 @@ class Delete(UpdateBase): """ - __visit_name__ = 'delete' - - def __init__(self, - table, - whereclause=None, - bind=None, - returning=None, - prefixes=None, - **dialect_kw): + __visit_name__ = "delete" + + def __init__( + self, + table, + whereclause=None, + bind=None, + returning=None, + prefixes=None, + **dialect_kw + ): """Construct :class:`.Delete` object. Similar functionality is available via the @@ -847,7 +886,7 @@ class Delete(UpdateBase): def get_children(self, **kwargs): if self._whereclause is not None: - return self._whereclause, + return (self._whereclause,) else: return () @@ -856,8 +895,9 @@ class Delete(UpdateBase): """Add the given WHERE clause to a newly returned delete construct.""" if self._whereclause is not None: - self._whereclause = and_(self._whereclause, - _literal_as_text(whereclause)) + self._whereclause = and_( + self._whereclause, _literal_as_text(whereclause) + ) else: self._whereclause = _literal_as_text(whereclause) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index de3b7992af..e857f2da85 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -51,9 +51,8 @@ def collate(expression, collation): expr = _literal_as_binds(expression) return BinaryExpression( - expr, - CollationClause(collation), - operators.collate, type_=expr.type) + expr, CollationClause(collation), operators.collate, type_=expr.type + ) def between(expr, lower_bound, upper_bound, symmetric=False): @@ -130,8 +129,6 @@ def literal(value, type_=None): return BindParameter(None, value, type_=type_, unique=True) - - def outparam(key, type_=None): """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. @@ -142,8 +139,7 @@ def outparam(key, type_=None): attribute, which returns a dictionary containing the values. """ - return BindParameter( - key, None, type_=type_, unique=False, isoutparam=True) + return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) def not_(clause): @@ -163,7 +159,8 @@ class ClauseElement(Visitable): expression. 
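Both ``Update.where()`` and ``Delete.where()`` reformatted here AND any new
criterion into the existing clause via ``and_()``. For illustration (table
invented)::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    # successive where() calls are joined with AND
    stmt = (
        users.update()
        .where(users.c.id > 5)
        .where(users.c.name != "ed")
        .values(name="jack")
    )

    stmt = users.delete().where(users.c.name == "ed")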
""" - __visit_name__ = 'clause' + + __visit_name__ = "clause" _annotations = {} supports_execution = False @@ -230,7 +227,7 @@ class ClauseElement(Visitable): def __getstate__(self): d = self.__dict__.copy() - d.pop('_is_clone_of', None) + d.pop("_is_clone_of", None) return d def _annotate(self, values): @@ -300,7 +297,8 @@ class ClauseElement(Visitable): kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: raise exc.ArgumentError( - "params() takes zero or one positional dictionary argument") + "params() takes zero or one positional dictionary argument" + ) def visit_bindparam(bind): if bind.key in kwargs: @@ -308,7 +306,8 @@ class ClauseElement(Visitable): bind.required = False if unique: bind._convert_to_unique() - return cloned_traverse(self, {}, {'bindparam': visit_bindparam}) + + return cloned_traverse(self, {}, {"bindparam": visit_bindparam}) def compare(self, other, **kw): r"""Compare this ClauseElement to the given ClauseElement. @@ -451,7 +450,7 @@ class ClauseElement(Visitable): if util.py3k: return str(self.compile()) else: - return unicode(self.compile()).encode('ascii', 'backslashreplace') + return unicode(self.compile()).encode("ascii", "backslashreplace") def __and__(self, other): """'and' at the ClauseElement level. @@ -472,7 +471,7 @@ class ClauseElement(Visitable): return or_(self, other) def __invert__(self): - if hasattr(self, 'negation_clause'): + if hasattr(self, "negation_clause"): return self.negation_clause else: return self._negate() @@ -481,7 +480,8 @@ class ClauseElement(Visitable): return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv, - negate=None) + negate=None, + ) def __bool__(self): raise TypeError("Boolean value of this clause is not defined") @@ -493,8 +493,12 @@ class ClauseElement(Visitable): if friendly is None: return object.__repr__(self) else: - return '<%s.%s at 0x%x; %s>' % ( - self.__module__, self.__class__.__name__, id(self), friendly) + return "<%s.%s at 0x%x; %s>" % ( + self.__module__, + self.__class__.__name__, + id(self), + friendly, + ) class ColumnElement(operators.ColumnOperators, ClauseElement): @@ -571,7 +575,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): """ - __visit_name__ = 'column_element' + __visit_name__ = "column_element" primary_key = False foreign_keys = [] @@ -646,11 +650,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): _alt_names = () def self_group(self, against=None): - if (against in (operators.and_, operators.or_, operators._asbool) and - self.type._type_affinity - is type_api.BOOLEANTYPE._type_affinity): + if ( + against in (operators.and_, operators.or_, operators._asbool) + and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity + ): return AsBoolean(self, operators.istrue, operators.isfalse) - elif (against in (operators.any_op, operators.all_op)): + elif against in (operators.any_op, operators.all_op): return Grouping(self) else: return self @@ -675,7 +680,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): except AttributeError: raise TypeError( "Object %r associated with '.type' attribute " - "is not a TypeEngine class or object" % self.type) + "is not a TypeEngine class or object" % self.type + ) else: return comparator_factory(self) @@ -684,10 +690,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): return getattr(self.comparator, key) except AttributeError: raise AttributeError( - 'Neither %r object nor %r object has an attribute %r' % ( - type(self).__name__, - 
type(self.comparator).__name__, - key) + "Neither %r object nor %r object has an attribute %r" + % (type(self).__name__, type(self.comparator).__name__, key) ) def operate(self, op, *other, **kwargs): @@ -697,10 +701,14 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): return op(other, self.comparator, **kwargs) def _bind_param(self, operator, obj, type_=None): - return BindParameter(None, obj, - _compared_to_operator=operator, - type_=type_, - _compared_to_type=self.type, unique=True) + return BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + ) @property def expression(self): @@ -713,17 +721,18 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): @property def _select_iterable(self): - return (self, ) + return (self,) @util.memoized_property def base_columns(self): - return util.column_set(c for c in self.proxy_set - if not hasattr(c, '_proxies')) + return util.column_set( + c for c in self.proxy_set if not hasattr(c, "_proxies") + ) @util.memoized_property def proxy_set(self): s = util.column_set([self]) - if hasattr(self, '_proxies'): + if hasattr(self, "_proxies"): for c in self._proxies: s.update(c.proxy_set) return s @@ -738,11 +747,15 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): """Return True if the given column element compares to this one when targeting within a result row.""" - return hasattr(other, 'name') and hasattr(self, 'name') and \ - other.name == self.name + return ( + hasattr(other, "name") + and hasattr(self, "name") + and other.name == self.name + ) def _make_proxy( - self, selectable, name=None, name_is_truncatable=False, **kw): + self, selectable, name=None, name_is_truncatable=False, **kw + ): """Create a new :class:`.ColumnElement` representing this :class:`.ColumnElement` as it appears in the select list of a descending selectable. @@ -762,13 +775,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): key = name co = ColumnClause( _as_truncated(name) if name_is_truncatable else name, - type_=getattr(self, 'type', None), - _selectable=selectable + type_=getattr(self, "type", None), + _selectable=selectable, ) co._proxies = [self] if selectable._is_clone_of is not None: - co._is_clone_of = \ - selectable._is_clone_of.columns.get(key) + co._is_clone_of = selectable._is_clone_of.columns.get(key) selectable._columns[key] = co return co @@ -788,7 +800,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): this one via foreign key or other criterion. 
""" - to_compare = (other, ) + to_compare = (other,) if equivalents and other in equivalents: to_compare = equivalents[other].union(to_compare) @@ -838,7 +850,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): self = self._is_clone_of return _anonymous_label( - '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon')) + "%%(%d %s)s" % (id(self), getattr(self, "name", "anon")) ) @@ -862,18 +874,25 @@ class BindParameter(ColumnElement): """ - __visit_name__ = 'bindparam' + __visit_name__ = "bindparam" _is_crud = False _expanding_in_types = () - def __init__(self, key, value=NO_ARG, type_=None, - unique=False, required=NO_ARG, - quote=None, callable_=None, - expanding=False, - isoutparam=False, - _compared_to_operator=None, - _compared_to_type=None): + def __init__( + self, + key, + value=NO_ARG, + type_=None, + unique=False, + required=NO_ARG, + quote=None, + callable_=None, + expanding=False, + isoutparam=False, + _compared_to_operator=None, + _compared_to_type=None, + ): r"""Produce a "bound expression". The return value is an instance of :class:`.BindParameter`; this @@ -1093,7 +1112,7 @@ class BindParameter(ColumnElement): type_ = key.type key = key.key if required is NO_ARG: - required = (value is NO_ARG and callable_ is None) + required = value is NO_ARG and callable_ is None if value is NO_ARG: value = None @@ -1101,11 +1120,11 @@ class BindParameter(ColumnElement): key = quoted_name(key, quote) if unique: - self.key = _anonymous_label('%%(%d %s)s' % (id(self), key - or 'param')) + self.key = _anonymous_label( + "%%(%d %s)s" % (id(self), key or "param") + ) else: - self.key = key or _anonymous_label('%%(%d param)s' - % id(self)) + self.key = key or _anonymous_label("%%(%d param)s" % id(self)) # identifying key that won't change across # clones, used to identify the bind's logical @@ -1114,7 +1133,7 @@ class BindParameter(ColumnElement): # key that was passed in the first place, used to # generate new keys - self._orig_key = key or 'param' + self._orig_key = key or "param" self.unique = unique self.value = value @@ -1125,9 +1144,9 @@ class BindParameter(ColumnElement): if type_ is None: if _compared_to_type is not None: - self.type = \ - _compared_to_type.coerce_compared_value( - _compared_to_operator, value) + self.type = _compared_to_type.coerce_compared_value( + _compared_to_operator, value + ) else: self.type = type_api._resolve_value_to_type(value) elif isinstance(type_, type): @@ -1174,24 +1193,28 @@ class BindParameter(ColumnElement): def _clone(self): c = ClauseElement._clone(self) if self.unique: - c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key - or 'param')) + c.key = _anonymous_label( + "%%(%d %s)s" % (id(c), c._orig_key or "param") + ) return c def _convert_to_unique(self): if not self.unique: self.unique = True self.key = _anonymous_label( - '%%(%d %s)s' % (id(self), self._orig_key or 'param')) + "%%(%d %s)s" % (id(self), self._orig_key or "param") + ) def compare(self, other, **kw): """Compare this :class:`BindParameter` to the given clause.""" - return isinstance(other, BindParameter) \ - and self.type._compare_type_affinity(other.type) \ - and self.value == other.value \ + return ( + isinstance(other, BindParameter) + and self.type._compare_type_affinity(other.type) + and self.value == other.value and self.callable == other.callable + ) def __getstate__(self): """execute a deferred value for serialization purposes.""" @@ -1200,13 +1223,16 @@ class BindParameter(ColumnElement): v = self.value if self.callable: v = self.callable() - 
d['callable'] = None - d['value'] = v + d["callable"] = None + d["value"] = v return d def __repr__(self): - return 'BindParameter(%r, %r, type_=%r)' % (self.key, - self.value, self.type) + return "BindParameter(%r, %r, type_=%r)" % ( + self.key, + self.value, + self.type, + ) class TypeClause(ClauseElement): @@ -1216,7 +1242,7 @@ class TypeClause(ClauseElement): """ - __visit_name__ = 'typeclause' + __visit_name__ = "typeclause" def __init__(self, type): self.type = type @@ -1242,12 +1268,12 @@ class TextClause(Executable, ClauseElement): """ - __visit_name__ = 'textclause' + __visit_name__ = "textclause" - _bind_params_regex = re.compile(r'(?.name.quote``') + @util.deprecated("0.9", "Use ``.name.quote``") def quote(self): """Return the value of the ``quote`` flag passed to this schema object, for those schema items which @@ -121,7 +131,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): return {} def _schema_item_copy(self, schema_item): - if 'info' in self.__dict__: + if "info" in self.__dict__: schema_item.info = self.info.copy() schema_item.dispatch._update(self.dispatch) return schema_item @@ -396,7 +406,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): """ - __visit_name__ = 'table' + __visit_name__ = "table" def __new__(cls, *args, **kw): if not args: @@ -408,26 +418,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause): except IndexError: raise TypeError("Table() takes at least two arguments") - schema = kw.get('schema', None) + schema = kw.get("schema", None) if schema is None: schema = metadata.schema elif schema is BLANK_SCHEMA: schema = None - keep_existing = kw.pop('keep_existing', False) - extend_existing = kw.pop('extend_existing', False) - if 'useexisting' in kw: + keep_existing = kw.pop("keep_existing", False) + extend_existing = kw.pop("extend_existing", False) + if "useexisting" in kw: msg = "useexisting is deprecated. Use extend_existing." util.warn_deprecated(msg) if extend_existing: msg = "useexisting is synonymous with extend_existing." raise exc.ArgumentError(msg) - extend_existing = kw.pop('useexisting', False) + extend_existing = kw.pop("useexisting", False) if keep_existing and extend_existing: msg = "keep_existing and extend_existing are mutually exclusive." raise exc.ArgumentError(msg) - mustexist = kw.pop('mustexist', False) + mustexist = kw.pop("mustexist", False) key = _get_table_key(name, schema) if key in metadata.tables: if not keep_existing and not extend_existing and bool(args): @@ -436,15 +446,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause): "instance. Specify 'extend_existing=True' " "to redefine " "options and columns on an " - "existing Table object." % key) + "existing Table object." % key + ) table = metadata.tables[key] if extend_existing: table._init_existing(*args, **kw) return table else: if mustexist: - raise exc.InvalidRequestError( - "Table '%s' not defined" % (key)) + raise exc.InvalidRequestError("Table '%s' not defined" % (key)) table = object.__new__(cls) table.dispatch.before_parent_attach(table, metadata) metadata._add_table(name, schema, table) @@ -457,7 +467,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): metadata._remove_table(name, schema) @property - @util.deprecated('0.9', 'Use ``table.schema.quote``') + @util.deprecated("0.9", "Use ``table.schema.quote``") def quote_schema(self): """Return the value of the ``quote_schema`` flag passed to this :class:`.Table`. 
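The ``Table.__new__`` logic above makes ``Table`` a per-name singleton
within a given ``MetaData``; ``extend_existing`` and ``keep_existing``
decide what a repeated definition does. A sketch::

    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()
    t1 = Table("t", metadata, Column("x", Integer))

    # the same name and MetaData returns the existing object;
    # extend_existing=True applies the new columns to it
    t2 = Table("t", metadata, Column("y", Integer), extend_existing=True)

    assert t2 is t1
    assert "y" in t2.c

    # with neither extend_existing nor keep_existing, redefining the
    # table with columns raises InvalidRequestError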
@@ -478,23 +488,25 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def _init(self, name, metadata, *args, **kwargs): super(Table, self).__init__( - quoted_name(name, kwargs.pop('quote', None))) + quoted_name(name, kwargs.pop("quote", None)) + ) self.metadata = metadata - self.schema = kwargs.pop('schema', None) + self.schema = kwargs.pop("schema", None) if self.schema is None: self.schema = metadata.schema elif self.schema is BLANK_SCHEMA: self.schema = None else: - quote_schema = kwargs.pop('quote_schema', None) + quote_schema = kwargs.pop("quote_schema", None) self.schema = quoted_name(self.schema, quote_schema) self.indexes = set() self.constraints = set() self._columns = ColumnCollection() - PrimaryKeyConstraint(_implicit_generated=True).\ - _set_parent_with_dispatch(self) + PrimaryKeyConstraint( + _implicit_generated=True + )._set_parent_with_dispatch(self) self.foreign_keys = set() self._extra_dependencies = set() if self.schema is not None: @@ -502,26 +514,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: self.fullname = self.name - autoload_with = kwargs.pop('autoload_with', None) - autoload = kwargs.pop('autoload', autoload_with is not None) + autoload_with = kwargs.pop("autoload_with", None) + autoload = kwargs.pop("autoload", autoload_with is not None) # this argument is only used with _init_existing() - kwargs.pop('autoload_replace', True) + kwargs.pop("autoload_replace", True) _extend_on = kwargs.pop("_extend_on", None) - include_columns = kwargs.pop('include_columns', None) + include_columns = kwargs.pop("include_columns", None) - self.implicit_returning = kwargs.pop('implicit_returning', True) + self.implicit_returning = kwargs.pop("implicit_returning", True) - self.comment = kwargs.pop('comment', None) + self.comment = kwargs.pop("comment", None) - if 'info' in kwargs: - self.info = kwargs.pop('info') - if 'listeners' in kwargs: - listeners = kwargs.pop('listeners') + if "info" in kwargs: + self.info = kwargs.pop("info") + if "listeners" in kwargs: + listeners = kwargs.pop("listeners") for evt, fn in listeners: event.listen(self, evt, fn) - self._prefixes = kwargs.pop('prefixes', []) + self._prefixes = kwargs.pop("prefixes", []) self._extra_kwargs(**kwargs) @@ -530,21 +542,29 @@ class Table(DialectKWArgs, SchemaItem, TableClause): # circular foreign keys if autoload: self._autoload( - metadata, autoload_with, - include_columns, _extend_on=_extend_on) + metadata, autoload_with, include_columns, _extend_on=_extend_on + ) # initialize all the column, etc. objects. 
done after reflection to # allow user-overrides self._init_items(*args) - def _autoload(self, metadata, autoload_with, include_columns, - exclude_columns=(), _extend_on=None): + def _autoload( + self, + metadata, + autoload_with, + include_columns, + exclude_columns=(), + _extend_on=None, + ): if autoload_with: autoload_with.run_callable( autoload_with.dialect.reflecttable, - self, include_columns, exclude_columns, - _extend_on=_extend_on + self, + include_columns, + exclude_columns, + _extend_on=_extend_on, ) else: bind = _bind_or_error( @@ -553,11 +573,14 @@ class Table(DialectKWArgs, SchemaItem, TableClause): "Pass an engine to the Table via " "autoload_with=, " "or associate the MetaData with an engine via " - "metadata.bind=") + "metadata.bind=", + ) bind.run_callable( bind.dialect.reflecttable, - self, include_columns, exclude_columns, - _extend_on=_extend_on + self, + include_columns, + exclude_columns, + _extend_on=_extend_on, ) @property @@ -582,34 +605,36 @@ class Table(DialectKWArgs, SchemaItem, TableClause): return set(fkc.constraint for fkc in self.foreign_keys) def _init_existing(self, *args, **kwargs): - autoload_with = kwargs.pop('autoload_with', None) - autoload = kwargs.pop('autoload', autoload_with is not None) - autoload_replace = kwargs.pop('autoload_replace', True) - schema = kwargs.pop('schema', None) - _extend_on = kwargs.pop('_extend_on', None) + autoload_with = kwargs.pop("autoload_with", None) + autoload = kwargs.pop("autoload", autoload_with is not None) + autoload_replace = kwargs.pop("autoload_replace", True) + schema = kwargs.pop("schema", None) + _extend_on = kwargs.pop("_extend_on", None) if schema and schema != self.schema: raise exc.ArgumentError( "Can't change schema of existing table from '%s' to '%s'", - (self.schema, schema)) + (self.schema, schema), + ) - include_columns = kwargs.pop('include_columns', None) + include_columns = kwargs.pop("include_columns", None) if include_columns is not None: for c in self.c: if c.name not in include_columns: self._columns.remove(c) - for key in ('quote', 'quote_schema'): + for key in ("quote", "quote_schema"): if key in kwargs: raise exc.ArgumentError( - "Can't redefine 'quote' or 'quote_schema' arguments") + "Can't redefine 'quote' or 'quote_schema' arguments" + ) - if 'comment' in kwargs: - self.comment = kwargs.pop('comment', None) + if "comment" in kwargs: + self.comment = kwargs.pop("comment", None) - if 'info' in kwargs: - self.info = kwargs.pop('info') + if "info" in kwargs: + self.info = kwargs.pop("info") if autoload: if not autoload_replace: @@ -620,8 +645,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: exclude_columns = () self._autoload( - self.metadata, autoload_with, - include_columns, exclude_columns, _extend_on=_extend_on) + self.metadata, + autoload_with, + include_columns, + exclude_columns, + _extend_on=_extend_on, + ) self._extra_kwargs(**kwargs) self._init_items(*args) @@ -653,10 +682,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause): return _get_table_key(self.name, self.schema) def __repr__(self): - return "Table(%s)" % ', '.join( - [repr(self.name)] + [repr(self.metadata)] + - [repr(x) for x in self.columns] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]) + return "Table(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]] + ) def __str__(self): return _get_table_key(self.description, self.schema) @@ -735,17 +766,19 @@ class 
Table(DialectKWArgs, SchemaItem, TableClause): def adapt_listener(target, connection, **kw): listener(event_name, target, connection) - event.listen(self, "" + event_name.replace('-', '_'), adapt_listener) + event.listen(self, "" + event_name.replace("-", "_"), adapt_listener) def _set_parent(self, metadata): metadata._add_table(self.name, self.schema, self) self.metadata = metadata - def get_children(self, column_collections=True, - schema_visitor=False, **kw): + def get_children( + self, column_collections=True, schema_visitor=False, **kw + ): if not schema_visitor: return TableClause.get_children( - self, column_collections=column_collections, **kw) + self, column_collections=column_collections, **kw + ) else: if column_collections: return list(self.columns) @@ -758,8 +791,9 @@ class Table(DialectKWArgs, SchemaItem, TableClause): if bind is None: bind = _bind_or_error(self) - return bind.run_callable(bind.dialect.has_table, - self.name, schema=self.schema) + return bind.run_callable( + bind.dialect.has_table, self.name, schema=self.schema + ) def create(self, bind=None, checkfirst=False): """Issue a ``CREATE`` statement for this @@ -774,9 +808,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaGenerator, - self, - checkfirst=checkfirst) + bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=False): """Issue a ``DROP`` statement for this @@ -790,12 +822,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause): """ if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaDropper, - self, - checkfirst=checkfirst) - - def tometadata(self, metadata, schema=RETAIN_SCHEMA, - referred_schema_fn=None, name=None): + bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst) + + def tometadata( + self, + metadata, + schema=RETAIN_SCHEMA, + referred_schema_fn=None, + name=None, + ): """Return a copy of this :class:`.Table` associated with a different :class:`.MetaData`. @@ -868,29 +903,37 @@ class Table(DialectKWArgs, SchemaItem, TableClause): schema = metadata.schema key = _get_table_key(name, schema) if key in metadata.tables: - util.warn("Table '%s' already exists within the given " - "MetaData - not copying." % self.description) + util.warn( + "Table '%s' already exists within the given " + "MetaData - not copying." 
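A sketch of the engine-bound operations reformatted in this region:
``create()``/``drop()`` with ``checkfirst``, ``exists()``, and reflection
via ``autoload_with`` (the in-memory SQLite URL is only for illustration)::

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
    )

    engine = create_engine("sqlite://")
    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    users.create(engine, checkfirst=True)  # skipped if already present
    print(users.exists(engine))  # True

    # reflect the same table into a fresh MetaData
    reflected = Table("users", MetaData(), autoload_with=engine)
    print(reflected.c.keys())

    users.drop(engine, checkfirst=True)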
% self.description + ) return metadata.tables[key] args = [] for c in self.columns: args.append(c.copy(schema=schema)) table = Table( - name, metadata, schema=schema, + name, + metadata, + schema=schema, comment=self.comment, - *args, **self.kwargs + *args, + **self.kwargs ) for c in self.constraints: if isinstance(c, ForeignKeyConstraint): referred_schema = c._referred_schema if referred_schema_fn: fk_constraint_schema = referred_schema_fn( - self, schema, c, referred_schema) + self, schema, c, referred_schema + ) else: fk_constraint_schema = ( - schema if referred_schema == self.schema else None) + schema if referred_schema == self.schema else None + ) table.append_constraint( - c.copy(schema=fk_constraint_schema, target_table=table)) + c.copy(schema=fk_constraint_schema, target_table=table) + ) elif not c._type_bound: # skip unique constraints that would be generated # by the 'unique' flag on Column @@ -898,25 +941,30 @@ class Table(DialectKWArgs, SchemaItem, TableClause): continue table.append_constraint( - c.copy(schema=schema, target_table=table)) + c.copy(schema=schema, target_table=table) + ) for index in self.indexes: # skip indexes that would be generated # by the 'index' flag on Column if index._column_flag: continue - Index(index.name, - unique=index.unique, - *[_copy_expression(expr, self, table) - for expr in index.expressions], - _table=table, - **index.kwargs) + Index( + index.name, + unique=index.unique, + *[ + _copy_expression(expr, self, table) + for expr in index.expressions + ], + _table=table, + **index.kwargs + ) return self._schema_item_copy(table) class Column(DialectKWArgs, SchemaItem, ColumnClause): """Represents a column in a database table.""" - __visit_name__ = 'column' + __visit_name__ = "column" def __init__(self, *args, **kwargs): r""" @@ -1192,14 +1240,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): """ - name = kwargs.pop('name', None) - type_ = kwargs.pop('type_', None) + name = kwargs.pop("name", None) + type_ = kwargs.pop("type_", None) args = list(args) if args: if isinstance(args[0], util.string_types): if name is not None: raise exc.ArgumentError( - "May not pass name positionally and as a keyword.") + "May not pass name positionally and as a keyword." + ) name = args.pop(0) if args: coltype = args[0] @@ -1207,40 +1256,42 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if hasattr(coltype, "_sqla_type"): if type_ is not None: raise exc.ArgumentError( - "May not pass type_ positionally and as a keyword.") + "May not pass type_ positionally and as a keyword." 
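``tometadata()`` above copies a ``Table`` with its columns, constraints and
indexes onto another ``MetaData``, optionally rewriting the schema. A
minimal sketch::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    m1 = MetaData()
    users = Table(
        "users",
        m1,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    m2 = MetaData()
    users_copy = users.tometadata(m2, schema="archive")

    assert users_copy is not users
    assert users_copy.schema == "archive"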
+ ) type_ = args.pop(0) if name is not None: - name = quoted_name(name, kwargs.pop('quote', None)) + name = quoted_name(name, kwargs.pop("quote", None)) elif "quote" in kwargs: - raise exc.ArgumentError("Explicit 'name' is required when " - "sending 'quote' argument") + raise exc.ArgumentError( + "Explicit 'name' is required when " "sending 'quote' argument" + ) super(Column, self).__init__(name, type_) - self.key = kwargs.pop('key', name) - self.primary_key = kwargs.pop('primary_key', False) - self.nullable = kwargs.pop('nullable', not self.primary_key) - self.default = kwargs.pop('default', None) - self.server_default = kwargs.pop('server_default', None) - self.server_onupdate = kwargs.pop('server_onupdate', None) + self.key = kwargs.pop("key", name) + self.primary_key = kwargs.pop("primary_key", False) + self.nullable = kwargs.pop("nullable", not self.primary_key) + self.default = kwargs.pop("default", None) + self.server_default = kwargs.pop("server_default", None) + self.server_onupdate = kwargs.pop("server_onupdate", None) # these default to None because .index and .unique is *not* # an informational flag about Column - there can still be an # Index or UniqueConstraint referring to this Column. - self.index = kwargs.pop('index', None) - self.unique = kwargs.pop('unique', None) + self.index = kwargs.pop("index", None) + self.unique = kwargs.pop("unique", None) - self.system = kwargs.pop('system', False) - self.doc = kwargs.pop('doc', None) - self.onupdate = kwargs.pop('onupdate', None) - self.autoincrement = kwargs.pop('autoincrement', "auto") + self.system = kwargs.pop("system", False) + self.doc = kwargs.pop("doc", None) + self.onupdate = kwargs.pop("onupdate", None) + self.autoincrement = kwargs.pop("autoincrement", "auto") self.constraints = set() self.foreign_keys = set() - self.comment = kwargs.pop('comment', None) + self.comment = kwargs.pop("comment", None) # check if this Column is proxying another column - if '_proxies' in kwargs: - self._proxies = kwargs.pop('_proxies') + if "_proxies" in kwargs: + self._proxies = kwargs.pop("_proxies") # otherwise, add DDL-related events elif isinstance(self.type, SchemaEventTarget): self.type._set_parent_with_dispatch(self) @@ -1249,14 +1300,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if isinstance(self.default, (ColumnDefault, Sequence)): args.append(self.default) else: - if getattr(self.type, '_warn_on_bytestring', False): + if getattr(self.type, "_warn_on_bytestring", False): if isinstance(self.default, util.binary_type): util.warn( "Unicode column '%s' has non-unicode " - "default value %r specified." % ( - self.key, - self.default - )) + "default value %r specified." 
+ % (self.key, self.default) + ) args.append(ColumnDefault(self.default)) if self.server_default is not None: @@ -1275,30 +1325,31 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if isinstance(self.server_onupdate, FetchedValue): args.append(self.server_onupdate._as_for_update(True)) else: - args.append(DefaultClause(self.server_onupdate, - for_update=True)) + args.append( + DefaultClause(self.server_onupdate, for_update=True) + ) self._init_items(*args) util.set_creation_order(self) - if 'info' in kwargs: - self.info = kwargs.pop('info') + if "info" in kwargs: + self.info = kwargs.pop("info") self._extra_kwargs(**kwargs) def _extra_kwargs(self, **kwargs): self._validate_dialect_kwargs(kwargs) -# @property -# def quote(self): -# return getattr(self.name, "quote", None) + # @property + # def quote(self): + # return getattr(self.name, "quote", None) def __str__(self): if self.name is None: return "(no name)" elif self.table is not None: if self.table.named_with_column: - return (self.table.description + "." + self.description) + return self.table.description + "." + self.description else: return self.description else: @@ -1320,40 +1371,47 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): def __repr__(self): kwarg = [] if self.key != self.name: - kwarg.append('key') + kwarg.append("key") if self.primary_key: - kwarg.append('primary_key') + kwarg.append("primary_key") if not self.nullable: - kwarg.append('nullable') + kwarg.append("nullable") if self.onupdate: - kwarg.append('onupdate') + kwarg.append("onupdate") if self.default: - kwarg.append('default') + kwarg.append("default") if self.server_default: - kwarg.append('server_default') - return "Column(%s)" % ', '.join( - [repr(self.name)] + [repr(self.type)] + - [repr(x) for x in self.foreign_keys if x is not None] + - [repr(x) for x in self.constraints] + - [(self.table is not None and "table=<%s>" % - self.table.description or "table=None")] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) + kwarg.append("server_default") + return "Column(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.type)] + + [repr(x) for x in self.foreign_keys if x is not None] + + [repr(x) for x in self.constraints] + + [ + ( + self.table is not None + and "table=<%s>" % self.table.description + or "table=None" + ) + ] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg] + ) def _set_parent(self, table): if not self.name: raise exc.ArgumentError( "Column must be constructed with a non-blank name or " - "assign a non-blank .name before adding to a Table.") + "assign a non-blank .name before adding to a Table." 
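The ``_init_items`` wiring above distinguishes Python-side defaults
(wrapped in ``ColumnDefault``), server-side DDL defaults
(``DefaultClause``), and ``onupdate`` handlers. Illustrative column
definitions::

    import datetime

    from sqlalchemy import Column, DateTime, Integer, MetaData, Table, text

    metadata = MetaData()
    t = Table(
        "t",
        metadata,
        Column("id", Integer, primary_key=True),
        # Python-side default, becomes a ColumnDefault
        Column("created", DateTime, default=datetime.datetime.utcnow),
        # server-side default, becomes a DefaultClause in the DDL
        Column("status", Integer, server_default=text("0")),
        # Python-side value applied during UPDATE statements
        Column("updated", DateTime, onupdate=datetime.datetime.utcnow),
    )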
+ ) if self.key is None: self.key = self.name - existing = getattr(self, 'table', None) + existing = getattr(self, "table", None) if existing is not None and existing is not table: raise exc.ArgumentError( - "Column object '%s' already assigned to Table '%s'" % ( - self.key, - existing.description - )) + "Column object '%s' already assigned to Table '%s'" + % (self.key, existing.description) + ) if self.key in table._columns: col = table._columns.get(self.key) @@ -1373,8 +1431,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): elif self.key in table.primary_key: raise exc.ArgumentError( "Trying to redefine primary-key column '%s' as a " - "non-primary-key column on table '%s'" % ( - self.key, table.fullname)) + "non-primary-key column on table '%s'" + % (self.key, table.fullname) + ) self.table = table @@ -1383,7 +1442,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): raise exc.ArgumentError( "The 'index' keyword argument on Column is boolean only. " "To create indexes with a specific name, create an " - "explicit Index object external to the Table.") + "explicit Index object external to the Table." + ) Index(None, self, unique=bool(self.unique), _column_flag=True) elif self.unique: if isinstance(self.unique, util.string_types): @@ -1392,9 +1452,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): "only. To create unique constraints or indexes with a " "specific name, append an explicit UniqueConstraint to " "the Table's list of elements, or create an explicit " - "Index object external to the Table.") + "Index object external to the Table." + ) table.append_constraint( - UniqueConstraint(self.key, _column_flag=True)) + UniqueConstraint(self.key, _column_flag=True) + ) self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table)) @@ -1413,7 +1475,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): if self.table is not None: fn(self, self.table) else: - event.listen(self, 'after_parent_attach', fn) + event.listen(self, "after_parent_attach", fn) def copy(self, **kw): """Create a copy of this ``Column``, unitialized. @@ -1423,9 +1485,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): """ # Constraint objects plus non-constraint-bound ForeignKey objects - args = \ - [c.copy(**kw) for c in self.constraints if not c._type_bound] + \ - [c.copy(**kw) for c in self.foreign_keys if not c.constraint] + args = [ + c.copy(**kw) for c in self.constraints if not c._type_bound + ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint] type_ = self.type if isinstance(type_, SchemaEventTarget): @@ -1452,8 +1514,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) return self._schema_item_copy(c) - def _make_proxy(self, selectable, name=None, key=None, - name_is_truncatable=False, **kw): + def _make_proxy( + self, selectable, name=None, key=None, name_is_truncatable=False, **kw + ): """Create a *proxy* for this column. This is a copy of this ``Column`` referenced by a different parent @@ -1462,22 +1525,28 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): information is not transferred. """ - fk = [ForeignKey(f.column, _constraint=f.constraint) - for f in self.foreign_keys] + fk = [ + ForeignKey(f.column, _constraint=f.constraint) + for f in self.foreign_keys + ] if name is None and self.name is None: raise exc.InvalidRequestError( "Cannot initialize a sub-selectable" " with this Column object until its 'name' has " - "been assigned.") + "been assigned." 
+ ) try: c = self._constructor( - _as_truncated(name or self.name) if - name_is_truncatable else (name or self.name), + _as_truncated(name or self.name) + if name_is_truncatable + else (name or self.name), self.type, key=key if key else name if name else self.key, primary_key=self.primary_key, nullable=self.nullable, - _proxies=[self], *fk) + _proxies=[self], + *fk + ) except TypeError: util.raise_from_cause( TypeError( @@ -1485,7 +1554,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): "Ensure the class includes a _constructor() " "attribute or method which accepts the " "standard Column constructor arguments, or " - "references the Column class itself." % self.__class__) + "references the Column class itself." % self.__class__ + ) ) c.table = selectable @@ -1499,9 +1569,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): def get_children(self, schema_visitor=False, **kwargs): if schema_visitor: - return [x for x in (self.default, self.onupdate) - if x is not None] + \ - list(self.foreign_keys) + list(self.constraints) + return ( + [x for x in (self.default, self.onupdate) if x is not None] + + list(self.foreign_keys) + + list(self.constraints) + ) else: return ColumnClause.get_children(self, **kwargs) @@ -1543,13 +1615,23 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ - __visit_name__ = 'foreign_key' - - def __init__(self, column, _constraint=None, use_alter=False, name=None, - onupdate=None, ondelete=None, deferrable=None, - initially=None, link_to_name=False, match=None, - info=None, - **dialect_kw): + __visit_name__ = "foreign_key" + + def __init__( + self, + column, + _constraint=None, + use_alter=False, + name=None, + onupdate=None, + ondelete=None, + deferrable=None, + initially=None, + link_to_name=False, + match=None, + info=None, + **dialect_kw + ): r""" Construct a column-level FOREIGN KEY. 
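A sketch of the column-level ``ForeignKey`` whose constructor is
reformatted here; the string form ``"users.id"`` resolves through the
shared ``MetaData`` (tables invented for the example)::

    from sqlalchemy import (
        Column,
        ForeignKey,
        Integer,
        MetaData,
        String,
        Table,
    )

    metadata = MetaData()
    users = Table("users", metadata, Column("id", Integer, primary_key=True))

    addresses = Table(
        "addresses",
        metadata,
        Column("id", Integer, primary_key=True),
        Column(
            "user_id", Integer, ForeignKey("users.id", ondelete="CASCADE")
        ),
        Column("email", String(100)),
    )

    fk = list(addresses.c.user_id.foreign_keys)[0]
    assert fk.column is users.c.id  # resolved via the shared MetaData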
@@ -1626,7 +1708,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): if isinstance(self._colspec, util.string_types): self._table_column = None else: - if hasattr(self._colspec, '__clause_element__'): + if hasattr(self._colspec, "__clause_element__"): self._table_column = self._colspec.__clause_element__() else: self._table_column = self._colspec @@ -1634,9 +1716,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): if not isinstance(self._table_column, ColumnClause): raise exc.ArgumentError( "String, Column, or Column-bound argument " - "expected, got %r" % self._table_column) + "expected, got %r" % self._table_column + ) elif not isinstance( - self._table_column.table, (util.NoneType, TableClause)): + self._table_column.table, (util.NoneType, TableClause) + ): raise exc.ArgumentError( "ForeignKey received Column not bound " "to a Table, got: %r" % self._table_column.table @@ -1715,7 +1799,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): return "%s.%s" % (table_name, colname) elif self._table_column is not None: return "%s.%s" % ( - self._table_column.table.fullname, self._table_column.key) + self._table_column.table.fullname, + self._table_column.key, + ) else: return self._colspec @@ -1756,12 +1842,12 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _column_tokens(self): """parse a string-based _colspec into its component parts.""" - m = self._get_colspec().split('.') + m = self._get_colspec().split(".") if m is None: raise exc.ArgumentError( - "Invalid foreign key column specification: %s" % - self._colspec) - if (len(m) == 1): + "Invalid foreign key column specification: %s" % self._colspec + ) + if len(m) == 1: tname = m.pop() colname = None else: @@ -1777,8 +1863,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): # indirectly related -- Ticket #594. This assumes that '.' # will never appear *within* any component of the FK. - if (len(m) > 0): - schema = '.'.join(m) + if len(m) > 0: + schema = ".".join(m) else: schema = None return schema, tname, colname @@ -1787,12 +1873,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): if self.parent is None: raise exc.InvalidRequestError( "this ForeignKey object does not yet have a " - "parent Column associated with it.") + "parent Column associated with it." + ) elif self.parent.table is None: raise exc.InvalidRequestError( "this ForeignKey's parent column is not yet associated " - "with a Table.") + "with a Table." 
+ ) parenttable = self.parent.table @@ -1817,7 +1905,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): return parenttable, tablekey, colname def _link_to_col_by_colstring(self, parenttable, table, colname): - if not hasattr(self.constraint, '_referred_table'): + if not hasattr(self.constraint, "_referred_table"): self.constraint._referred_table = table else: assert self.constraint._referred_table is table @@ -1843,9 +1931,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): raise exc.NoReferencedColumnError( "Could not initialize target column " "for ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" % - (self._colspec, parenttable.name, table.name, key), - table.name, key) + "table '%s' has no column named '%s'" + % (self._colspec, parenttable.name, table.name, key), + table.name, + key, + ) self._set_target_column(_column) @@ -1861,6 +1951,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): def set_type(fk): if fk.parent.type._isnull: fk.parent.type = column.type + self.parent._setup_on_memoized_fks(set_type) self.column = column @@ -1888,21 +1979,25 @@ class ForeignKey(DialectKWArgs, SchemaItem): raise exc.NoReferencedTableError( "Foreign key associated with column '%s' could not find " "table '%s' with which to generate a " - "foreign key to target column '%s'" % - (self.parent, tablekey, colname), - tablekey) + "foreign key to target column '%s'" + % (self.parent, tablekey, colname), + tablekey, + ) elif parenttable.key not in parenttable.metadata: raise exc.InvalidRequestError( "Table %s is no longer associated with its " - "parent MetaData" % parenttable) + "parent MetaData" % parenttable + ) else: raise exc.NoReferencedColumnError( "Could not initialize target column for " "ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" % ( - self._colspec, parenttable.name, tablekey, colname), - tablekey, colname) - elif hasattr(self._colspec, '__clause_element__'): + "table '%s' has no column named '%s'" + % (self._colspec, parenttable.name, tablekey, colname), + tablekey, + colname, + ) + elif hasattr(self._colspec, "__clause_element__"): _column = self._colspec.__clause_element__() return _column else: @@ -1912,7 +2007,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _set_parent(self, column): if self.parent is not None and self.parent is not column: raise exc.InvalidRequestError( - "This ForeignKey already has a parent !") + "This ForeignKey already has a parent !" + ) self.parent = column self.parent.foreign_keys.add(self) self.parent._on_table_attach(self._set_table) @@ -1935,9 +2031,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): # on the hosting Table when attached to the Table. 
if self.constraint is None and isinstance(table, Table): self.constraint = ForeignKeyConstraint( - [], [], use_alter=self.use_alter, name=self.name, - onupdate=self.onupdate, ondelete=self.ondelete, - deferrable=self.deferrable, initially=self.initially, + [], + [], + use_alter=self.use_alter, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + deferrable=self.deferrable, + initially=self.initially, match=self.match, **self._unvalidated_dialect_kw ) @@ -1953,13 +2054,12 @@ class ForeignKey(DialectKWArgs, SchemaItem): if table_key in parenttable.metadata.tables: table = parenttable.metadata.tables[table_key] try: - self._link_to_col_by_colstring( - parenttable, table, colname) + self._link_to_col_by_colstring(parenttable, table, colname) except exc.NoReferencedColumnError: # this is OK, we'll try later pass parenttable.metadata._fk_memos[fk_key].append(self) - elif hasattr(self._colspec, '__clause_element__'): + elif hasattr(self._colspec, "__clause_element__"): _column = self._colspec.__clause_element__() self._set_target_column(_column) else: @@ -1971,7 +2071,8 @@ class _NotAColumnExpr(object): def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " - "as a column expression." % self.__class__.__name__) + "as a column expression." % self.__class__.__name__ + ) __clause_element__ = self_group = lambda self: self._not_a_column_expr() _from_objects = property(lambda self: self._not_a_column_expr()) @@ -1980,7 +2081,7 @@ class _NotAColumnExpr(object): class DefaultGenerator(_NotAColumnExpr, SchemaItem): """Base class for column *default* values.""" - __visit_name__ = 'default_generator' + __visit_name__ = "default_generator" is_sequence = False is_server_default = False @@ -2007,7 +2108,7 @@ class DefaultGenerator(_NotAColumnExpr, SchemaItem): @property def bind(self): """Return the connectable associated with this default.""" - if getattr(self, 'column', None) is not None: + if getattr(self, "column", None) is not None: return self.column.table.bind else: return None @@ -2064,7 +2165,8 @@ class ColumnDefault(DefaultGenerator): super(ColumnDefault, self).__init__(**kwargs) if isinstance(arg, FetchedValue): raise exc.ArgumentError( - "ColumnDefault may not be a server-side default type.") + "ColumnDefault may not be a server-side default type." 
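``ColumnDefault`` (beginning above) accepts scalars, SQL expressions, or a
Python callable taking either zero arguments or a single execution-context
argument, per the arity check that follows. A hedged sketch of both
callable forms (names invented)::

    from sqlalchemy import Column, Integer, MetaData, Table

    metadata = MetaData()

    def double_x(context):
        # one-argument form receives the execution context
        return context.get_current_parameters()["x"] * 2

    t = Table(
        "t",
        metadata,
        Column("x", Integer),
        Column("y", Integer, default=double_x),
        Column("z", Integer, default=lambda: 42),  # zero-argument form
    )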
+ ) if util.callable(arg): arg = self._maybe_wrap_callable(arg) self.arg = arg @@ -2079,9 +2181,11 @@ class ColumnDefault(DefaultGenerator): @util.memoized_property def is_scalar(self): - return not self.is_callable and \ - not self.is_clause_element and \ - not self.is_sequence + return ( + not self.is_callable + and not self.is_clause_element + and not self.is_sequence + ) @util.memoized_property @util.dependencies("sqlalchemy.sql.sqltypes") @@ -2114,17 +2218,19 @@ class ColumnDefault(DefaultGenerator): else: raise exc.ArgumentError( "ColumnDefault Python function takes zero or one " - "positional arguments") + "positional arguments" + ) def _visit_name(self): if self.for_update: return "column_onupdate" else: return "column_default" + __visit_name__ = property(_visit_name) def __repr__(self): - return "ColumnDefault(%r)" % (self.arg, ) + return "ColumnDefault(%r)" % (self.arg,) class Sequence(DefaultGenerator): @@ -2157,15 +2263,29 @@ class Sequence(DefaultGenerator): """ - __visit_name__ = 'sequence' + __visit_name__ = "sequence" is_sequence = True - def __init__(self, name, start=None, increment=None, minvalue=None, - maxvalue=None, nominvalue=None, nomaxvalue=None, cycle=None, - schema=None, cache=None, order=None, optional=False, - quote=None, metadata=None, quote_schema=None, - for_update=False): + def __init__( + self, + name, + start=None, + increment=None, + minvalue=None, + maxvalue=None, + nominvalue=None, + nomaxvalue=None, + cycle=None, + schema=None, + cache=None, + order=None, + optional=False, + quote=None, + metadata=None, + quote_schema=None, + for_update=False, + ): """Construct a :class:`.Sequence` object. :param name: The name of the sequence. @@ -2353,27 +2473,22 @@ class Sequence(DefaultGenerator): if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaGenerator, - self, - checkfirst=checkfirst) + bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Drops this sequence from the database.""" if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaDropper, - self, - checkfirst=checkfirst) + bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst) def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " "as a column expression. Use func.next_value(sequence) " "to produce a 'next value' function that's usable " - "as a column element." - % self.__class__.__name__) - + "as a column element." 
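``Sequence`` above is a ``DefaultGenerator`` rather than a column
expression; as the adjoining error message says, ``func.next_value()``
wraps one when a value expression is needed. Sketch::

    from sqlalchemy import Column, Integer, MetaData, Sequence, Table, func

    metadata = MetaData()
    seq = Sequence("user_id_seq", start=1, metadata=metadata)

    users = Table(
        "users", metadata, Column("id", Integer, seq, primary_key=True)
    )

    # usable wherever a scalar SQL expression is accepted
    next_id = func.next_value(seq)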
% self.__class__.__name__ + ) @inspection._self_inspects @@ -2396,6 +2511,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget): :ref:`triggered_columns` """ + is_server_default = True reflected = False has_argument = False @@ -2412,7 +2528,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget): def _clone(self, for_update): n = self.__class__.__new__(self.__class__) n.__dict__.update(self.__dict__) - n.__dict__.pop('column', None) + n.__dict__.pop("column", None) n.for_update = for_update return n @@ -2452,16 +2568,15 @@ class DefaultClause(FetchedValue): has_argument = True def __init__(self, arg, for_update=False, _reflected=False): - util.assert_arg_type(arg, (util.string_types[0], - ClauseElement, - TextClause), 'arg') + util.assert_arg_type( + arg, (util.string_types[0], ClauseElement, TextClause), "arg" + ) super(DefaultClause, self).__init__(for_update) self.arg = arg self.reflected = _reflected def __repr__(self): - return "DefaultClause(%r, for_update=%r)" % \ - (self.arg, self.for_update) + return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) class PassiveDefault(DefaultClause): @@ -2471,10 +2586,13 @@ class PassiveDefault(DefaultClause): :class:`.PassiveDefault` is deprecated. Use :class:`.DefaultClause`. """ - @util.deprecated("0.6", - ":class:`.PassiveDefault` is deprecated. " - "Use :class:`.DefaultClause`.", - False) + + @util.deprecated( + "0.6", + ":class:`.PassiveDefault` is deprecated. " + "Use :class:`.DefaultClause`.", + False, + ) def __init__(self, *arg, **kw): DefaultClause.__init__(self, *arg, **kw) @@ -2482,11 +2600,18 @@ class PassiveDefault(DefaultClause): class Constraint(DialectKWArgs, SchemaItem): """A table-level SQL constraint.""" - __visit_name__ = 'constraint' - - def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None, info=None, _type_bound=False, - **dialect_kw): + __visit_name__ = "constraint" + + def __init__( + self, + name=None, + deferrable=None, + initially=None, + _create_rule=None, + info=None, + _type_bound=False, + **dialect_kw + ): r"""Create a SQL constraint. :param name: @@ -2548,7 +2673,8 @@ class Constraint(DialectKWArgs, SchemaItem): pass raise exc.InvalidRequestError( "This constraint is not bound to a table. Did you " - "mean to call table.append_constraint(constraint) ?") + "mean to call table.append_constraint(constraint) ?" 
+ ) def _set_parent(self, parent): self.parent = parent @@ -2559,7 +2685,7 @@ class Constraint(DialectKWArgs, SchemaItem): def _to_schema_column(element): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): element = element.__clause_element__() if not isinstance(element, Column): raise exc.ArgumentError("schema.Column object expected") @@ -2567,9 +2693,9 @@ def _to_schema_column(element): def _to_schema_column_or_string(element): - if hasattr(element, '__clause_element__'): + if hasattr(element, "__clause_element__"): element = element.__clause_element__() - if not isinstance(element, util.string_types + (ColumnElement, )): + if not isinstance(element, util.string_types + (ColumnElement,)): msg = "Element %r is not a string name or column element" raise exc.ArgumentError(msg % element) return element @@ -2588,11 +2714,12 @@ class ColumnCollectionMixin(object): _allow_multiple_tables = False def __init__(self, *columns, **kw): - _autoattach = kw.pop('_autoattach', True) - self._column_flag = kw.pop('_column_flag', False) + _autoattach = kw.pop("_autoattach", True) + self._column_flag = kw.pop("_column_flag", False) self.columns = ColumnCollection() - self._pending_colargs = [_to_schema_column_or_string(c) - for c in columns] + self._pending_colargs = [ + _to_schema_column_or_string(c) for c in columns + ] if _autoattach and self._pending_colargs: self._check_attach() @@ -2601,7 +2728,7 @@ class ColumnCollectionMixin(object): for expr in expressions: strname = None column = None - if hasattr(expr, '__clause_element__'): + if hasattr(expr, "__clause_element__"): expr = expr.__clause_element__() if not isinstance(expr, (ColumnElement, TextClause)): @@ -2609,21 +2736,16 @@ class ColumnCollectionMixin(object): strname = expr else: cols = [] - visitors.traverse(expr, {}, {'column': cols.append}) + visitors.traverse(expr, {}, {"column": cols.append}) if cols: column = cols[0] add_element = column if column is not None else strname yield expr, column, strname, add_element def _check_attach(self, evt=False): - col_objs = [ - c for c in self._pending_colargs - if isinstance(c, Column) - ] + col_objs = [c for c in self._pending_colargs if isinstance(c, Column)] - cols_w_table = [ - c for c in col_objs if isinstance(c.table, Table) - ] + cols_w_table = [c for c in col_objs if isinstance(c.table, Table)] cols_wo_table = set(col_objs).difference(cols_w_table) @@ -2636,6 +2758,7 @@ class ColumnCollectionMixin(object): # columns are specified as strings. has_string_cols = set(self._pending_colargs).difference(col_objs) if not has_string_cols: + def _col_attached(column, table): # this isinstance() corresponds with the # isinstance() above; only want to count Table-bound @@ -2644,6 +2767,7 @@ class ColumnCollectionMixin(object): cols_wo_table.discard(column) if not cols_wo_table: self._check_attach(evt=True) + self._cols_wo_table = cols_wo_table for col in cols_wo_table: col._on_table_attach(_col_attached) @@ -2659,9 +2783,11 @@ class ColumnCollectionMixin(object): others = [c for c in columns[1:] if c.table is not table] if others: raise exc.ArgumentError( - "Column(s) %s are not part of table '%s'." % - (", ".join("'%s'" % c for c in others), - table.description) + "Column(s) %s are not part of table '%s'." + % ( + ", ".join("'%s'" % c for c in others), + table.description, + ) ) def _set_parent(self, table): @@ -2694,11 +2820,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): arguments are propagated to the :class:`.Constraint` superclass. 
""" - _autoattach = kw.pop('_autoattach', True) - _column_flag = kw.pop('_column_flag', False) + _autoattach = kw.pop("_autoattach", True) + _column_flag = kw.pop("_column_flag", False) Constraint.__init__(self, **kw) ColumnCollectionMixin.__init__( - self, *columns, _autoattach=_autoattach, _column_flag=_column_flag) + self, *columns, _autoattach=_autoattach, _column_flag=_column_flag + ) columns = None """A :class:`.ColumnCollection` representing the set of columns @@ -2714,8 +2841,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): return x in self.columns def copy(self, **kw): - c = self.__class__(name=self.name, deferrable=self.deferrable, - initially=self.initially, *self.columns.keys()) + c = self.__class__( + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + *self.columns.keys() + ) return self._schema_item_copy(c) def contains_column(self, col): @@ -2747,9 +2878,19 @@ class CheckConstraint(ColumnCollectionConstraint): _allow_multiple_tables = True - def __init__(self, sqltext, name=None, deferrable=None, - initially=None, table=None, info=None, _create_rule=None, - _autoattach=True, _type_bound=False, **kw): + def __init__( + self, + sqltext, + name=None, + deferrable=None, + initially=None, + table=None, + info=None, + _create_rule=None, + _autoattach=True, + _type_bound=False, + **kw + ): r"""Construct a CHECK constraint. :param sqltext: @@ -2781,14 +2922,19 @@ class CheckConstraint(ColumnCollectionConstraint): self.sqltext = _literal_as_text(sqltext, warn=False) columns = [] - visitors.traverse(self.sqltext, {}, {'column': columns.append}) - - super(CheckConstraint, self).\ - __init__( - name=name, deferrable=deferrable, - initially=initially, _create_rule=_create_rule, info=info, - _type_bound=_type_bound, _autoattach=_autoattach, - *columns, **kw) + visitors.traverse(self.sqltext, {}, {"column": columns.append}) + + super(CheckConstraint, self).__init__( + name=name, + deferrable=deferrable, + initially=initially, + _create_rule=_create_rule, + info=info, + _type_bound=_type_bound, + _autoattach=_autoattach, + *columns, + **kw + ) if table is not None: self._set_parent_with_dispatch(table) @@ -2797,22 +2943,24 @@ class CheckConstraint(ColumnCollectionConstraint): return "check_constraint" else: return "column_check_constraint" + __visit_name__ = property(__visit_name__) def copy(self, target_table=None, **kw): if target_table is not None: - sqltext = _copy_expression( - self.sqltext, self.table, target_table) + sqltext = _copy_expression(self.sqltext, self.table, target_table) else: sqltext = self.sqltext - c = CheckConstraint(sqltext, - name=self.name, - initially=self.initially, - deferrable=self.deferrable, - _create_rule=self._create_rule, - table=target_table, - _autoattach=False, - _type_bound=self._type_bound) + c = CheckConstraint( + sqltext, + name=self.name, + initially=self.initially, + deferrable=self.deferrable, + _create_rule=self._create_rule, + table=target_table, + _autoattach=False, + _type_bound=self._type_bound, + ) return self._schema_item_copy(c) @@ -2828,12 +2976,25 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): Examples of foreign key configuration are in :ref:`metadata_foreignkeys`. 
""" - __visit_name__ = 'foreign_key_constraint' - def __init__(self, columns, refcolumns, name=None, onupdate=None, - ondelete=None, deferrable=None, initially=None, - use_alter=False, link_to_name=False, match=None, - table=None, info=None, **dialect_kw): + __visit_name__ = "foreign_key_constraint" + + def __init__( + self, + columns, + refcolumns, + name=None, + onupdate=None, + ondelete=None, + deferrable=None, + initially=None, + use_alter=False, + link_to_name=False, + match=None, + table=None, + info=None, + **dialect_kw + ): r"""Construct a composite-capable FOREIGN KEY. :param columns: A sequence of local column names. The named columns @@ -2905,8 +3066,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): """ Constraint.__init__( - self, name=name, deferrable=deferrable, initially=initially, - info=info, **dialect_kw) + self, + name=name, + deferrable=deferrable, + initially=initially, + info=info, + **dialect_kw + ) self.onupdate = onupdate self.ondelete = ondelete self.link_to_name = link_to_name @@ -2927,7 +3093,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): raise exc.ArgumentError( "ForeignKeyConstraint number " "of constrained columns must match the number of " - "referenced columns.") + "referenced columns." + ) # standalone ForeignKeyConstraint - create # associated ForeignKey objects which will be applied to hosted @@ -2946,7 +3113,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): deferrable=self.deferrable, initially=self.initially, **self.dialect_kwargs - ) for refcol in refcolumns + ) + for refcol in refcolumns ] ColumnCollectionMixin.__init__(self, *columns) @@ -2978,9 +3146,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): @property def _elements(self): # legacy - provide a dictionary view of (column_key, fk) - return util.OrderedDict( - zip(self.column_keys, self.elements) - ) + return util.OrderedDict(zip(self.column_keys, self.elements)) @property def _referred_schema(self): @@ -3004,18 +3170,14 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.elements[0].column.table def _validate_dest_table(self, table): - table_keys = set([elem._table_key() - for elem in self.elements]) + table_keys = set([elem._table_key() for elem in self.elements]) if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( - 'ForeignKeyConstraint on %s(%s) refers to ' - 'multiple remote tables: %s and %s' % ( - table.fullname, - self._col_description, - elem0, - elem1 - )) + "ForeignKeyConstraint on %s(%s) refers to " + "multiple remote tables: %s and %s" + % (table.fullname, self._col_description, elem0, elem1) + ) @property def column_keys(self): @@ -3034,8 +3196,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.columns.keys() else: return [ - col.key if isinstance(col, ColumnElement) - else str(col) for col in self._pending_colargs + col.key if isinstance(col, ColumnElement) else str(col) + for col in self._pending_colargs ] @property @@ -3051,11 +3213,11 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): raise exc.ArgumentError( "Can't create ForeignKeyConstraint " "on table '%s': no column " - "named '%s' is present." % (table.description, ke.args[0])) + "named '%s' is present." 
% (table.description, ke.args[0]) + ) for col, fk in zip(self.columns, self.elements): - if not hasattr(fk, 'parent') or \ - fk.parent is not col: + if not hasattr(fk, "parent") or fk.parent is not col: fk._set_parent_with_dispatch(col) self._validate_dest_table(table) @@ -3063,13 +3225,16 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): def copy(self, schema=None, target_table=None, **kw): fkc = ForeignKeyConstraint( [x.parent.key for x in self.elements], - [x._get_colspec( - schema=schema, - table_name=target_table.name - if target_table is not None - and x._table_key() == x.parent.table.key - else None) - for x in self.elements], + [ + x._get_colspec( + schema=schema, + table_name=target_table.name + if target_table is not None + and x._table_key() == x.parent.table.key + else None, + ) + for x in self.elements + ], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, @@ -3077,11 +3242,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): deferrable=self.deferrable, initially=self.initially, link_to_name=self.link_to_name, - match=self.match + match=self.match, ) - for self_fk, other_fk in zip( - self.elements, - fkc.elements): + for self_fk, other_fk in zip(self.elements, fkc.elements): self_fk._schema_item_copy(other_fk) return self._schema_item_copy(fkc) @@ -3160,10 +3323,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): """ - __visit_name__ = 'primary_key_constraint' + __visit_name__ = "primary_key_constraint" def __init__(self, *columns, **kw): - self._implicit_generated = kw.pop('_implicit_generated', False) + self._implicit_generated = kw.pop("_implicit_generated", False) super(PrimaryKeyConstraint, self).__init__(*columns, **kw) def _set_parent(self, table): @@ -3175,18 +3338,21 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): table.constraints.add(self) table_pks = [c for c in table.c if c.primary_key] - if self.columns and table_pks and \ - set(table_pks) != set(self.columns.values()): + if ( + self.columns + and table_pks + and set(table_pks) != set(self.columns.values()) + ): util.warn( "Table '%s' specifies columns %s as primary_key=True, " "not matching locally specified columns %s; setting the " "current primary key columns to %s. 
This warning " - "may become an exception in a future release" % - ( + "may become an exception in a future release" + % ( table.name, ", ".join("'%s'" % c.name for c in table_pks), ", ".join("'%s'" % c.name for c in self.columns), - ", ".join("'%s'" % c.name for c in self.columns) + ", ".join("'%s'" % c.name for c in self.columns), ) ) table_pks[:] = [] @@ -3241,28 +3407,28 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): @util.memoized_property def _autoincrement_column(self): - def _validate_autoinc(col, autoinc_true): if col.type._type_affinity is None or not issubclass( - col.type._type_affinity, - type_api.INTEGERTYPE._type_affinity): + col.type._type_affinity, type_api.INTEGERTYPE._type_affinity + ): if autoinc_true: raise exc.ArgumentError( "Column type %s on column '%s' is not " - "compatible with autoincrement=True" % ( - col.type, - col - )) + "compatible with autoincrement=True" % (col.type, col) + ) else: return False - elif not isinstance(col.default, (type(None), Sequence)) and \ - not autoinc_true: - return False + elif ( + not isinstance(col.default, (type(None), Sequence)) + and not autoinc_true + ): + return False elif col.server_default is not None and not autoinc_true: return False - elif ( - col.foreign_keys and col.autoincrement - not in (True, 'ignore_fk')): + elif col.foreign_keys and col.autoincrement not in ( + True, + "ignore_fk", + ): return False return True @@ -3272,10 +3438,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): if col.autoincrement is True: _validate_autoinc(col, True) return col - elif ( - col.autoincrement in ('auto', 'ignore_fk') and - _validate_autoinc(col, False) - ): + elif col.autoincrement in ( + "auto", + "ignore_fk", + ) and _validate_autoinc(col, False): return col else: @@ -3286,8 +3452,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): if autoinc is not None: raise exc.ArgumentError( "Only one Column may be marked " - "autoincrement=True, found both %s and %s." % - (col.name, autoinc.name) + "autoincrement=True, found both %s and %s." + % (col.name, autoinc.name) ) else: autoinc = col @@ -3304,7 +3470,7 @@ class UniqueConstraint(ColumnCollectionConstraint): UniqueConstraint. """ - __visit_name__ = 'unique_constraint' + __visit_name__ = "unique_constraint" class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): @@ -3382,7 +3548,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): """ - __visit_name__ = 'index' + __visit_name__ = "index" def __init__(self, name, *expressions, **kw): r"""Construct an index object. @@ -3420,30 +3586,35 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): columns = [] processed_expressions = [] - for expr, column, strname, add_element in self.\ - _extract_col_expression_collection(expressions): + for ( + expr, + column, + strname, + add_element, + ) in self._extract_col_expression_collection(expressions): if add_element is not None: columns.append(add_element) processed_expressions.append(expr) self.expressions = processed_expressions self.name = quoted_name(name, kw.pop("quote", None)) - self.unique = kw.pop('unique', False) - _column_flag = kw.pop('_column_flag', False) - if 'info' in kw: - self.info = kw.pop('info') + self.unique = kw.pop("unique", False) + _column_flag = kw.pop("_column_flag", False) + if "info" in kw: + self.info = kw.pop("info") # TODO: consider "table" argument being public, but for # the purpose of the fix here, it starts as private. 
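# A compact usage sketch of the table-level constraint and Index constructs
# reformatted in the hunks above; the parent/child schema and the constraint
# names are assumptions chosen for illustration only.
from sqlalchemy import (
    CheckConstraint,
    Column,
    ForeignKeyConstraint,
    Index,
    Integer,
    MetaData,
    String,
    Table,
    UniqueConstraint,
)

meta = MetaData()
parent = Table("parent", meta, Column("id", Integer, primary_key=True))
child = Table(
    "child",
    meta,
    Column("id", Integer, primary_key=True),
    Column("parent_id", Integer),
    Column("email", String(100)),
    Column("quantity", Integer),
    # composite-capable FOREIGN KEY, as built by ForeignKeyConstraint
    ForeignKeyConstraint(["parent_id"], ["parent.id"], name="fk_child_parent"),
    UniqueConstraint("email", name="uq_child_email"),
    CheckConstraint("quantity > 0", name="ck_child_quantity"),
)
# Index attaches itself to "child" via the bound column it receives
Index("ix_child_parent_id", child.c.parent_id)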
- if '_table' in kw: - table = kw.pop('_table') + if "_table" in kw: + table = kw.pop("_table") self._validate_dialect_kwargs(kw) # will call _set_parent() if table-bound column # objects are present ColumnCollectionMixin.__init__( - self, *columns, _column_flag=_column_flag) + self, *columns, _column_flag=_column_flag + ) if table is not None: self._set_parent(table) @@ -3454,20 +3625,17 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): if self.table is not None and table is not self.table: raise exc.ArgumentError( "Index '%s' is against table '%s', and " - "cannot be associated with table '%s'." % ( - self.name, - self.table.description, - table.description - ) + "cannot be associated with table '%s'." + % (self.name, self.table.description, table.description) ) self.table = table table.indexes.add(self) self.expressions = [ - expr if isinstance(expr, ClauseElement) - else colexpr - for expr, colexpr in util.zip_longest(self.expressions, - self.columns) + expr if isinstance(expr, ClauseElement) else colexpr + for expr, colexpr in util.zip_longest( + self.expressions, self.columns + ) ] @property @@ -3506,17 +3674,16 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): bind._run_visitor(ddl.SchemaDropper, self) def __repr__(self): - return 'Index(%s)' % ( + return "Index(%s)" % ( ", ".join( - [repr(self.name)] + - [repr(e) for e in self.expressions] + - (self.unique and ["unique=True"] or []) - )) + [repr(self.name)] + + [repr(e) for e in self.expressions] + + (self.unique and ["unique=True"] or []) + ) + ) -DEFAULT_NAMING_CONVENTION = util.immutabledict({ - "ix": 'ix_%(column_0_label)s' -}) +DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"}) class MetaData(SchemaItem): @@ -3542,13 +3709,17 @@ class MetaData(SchemaItem): """ - __visit_name__ = 'metadata' - - def __init__(self, bind=None, reflect=False, schema=None, - quote_schema=None, - naming_convention=DEFAULT_NAMING_CONVENTION, - info=None - ): + __visit_name__ = "metadata" + + def __init__( + self, + bind=None, + reflect=False, + schema=None, + quote_schema=None, + naming_convention=DEFAULT_NAMING_CONVENTION, + info=None, + ): """Create a new MetaData object. :param bind: @@ -3712,12 +3883,15 @@ class MetaData(SchemaItem): self.bind = bind if reflect: - util.warn_deprecated("reflect=True is deprecate; please " - "use the reflect() method.") + util.warn_deprecated( + "reflect=True is deprecated; please " + "use the reflect() method."
+ ) if not bind: raise exc.ArgumentError( "A bind must be supplied in conjunction " - "with reflect=True") + "with reflect=True" + ) self.reflect() tables = None @@ -3735,7 +3909,7 @@ class MetaData(SchemaItem): """ def __repr__(self): - return 'MetaData(bind=%r)' % self.bind + return "MetaData(bind=%r)" % self.bind def __contains__(self, table_or_key): if not isinstance(table_or_key, util.string_types): @@ -3755,27 +3929,32 @@ class MetaData(SchemaItem): for fk in removed.foreign_keys: fk._remove_from_metadata(self) if self._schemas: - self._schemas = set([t.schema - for t in self.tables.values() - if t.schema is not None]) + self._schemas = set( + [ + t.schema + for t in self.tables.values() + if t.schema is not None + ] + ) def __getstate__(self): - return {'tables': self.tables, - 'schema': self.schema, - 'schemas': self._schemas, - 'sequences': self._sequences, - 'fk_memos': self._fk_memos, - 'naming_convention': self.naming_convention - } + return { + "tables": self.tables, + "schema": self.schema, + "schemas": self._schemas, + "sequences": self._sequences, + "fk_memos": self._fk_memos, + "naming_convention": self.naming_convention, + } def __setstate__(self, state): - self.tables = state['tables'] - self.schema = state['schema'] - self.naming_convention = state['naming_convention'] + self.tables = state["tables"] + self.schema = state["schema"] + self.naming_convention = state["naming_convention"] self._bind = None - self._sequences = state['sequences'] - self._schemas = state['schemas'] - self._fk_memos = state['fk_memos'] + self._sequences = state["sequences"] + self._schemas = state["schemas"] + self._fk_memos = state["fk_memos"] def is_bound(self): """True if this MetaData is bound to an Engine or Connection.""" @@ -3805,10 +3984,11 @@ class MetaData(SchemaItem): def _bind_to(self, url, bind): """Bind this MetaData to an Engine, Connection, string or URL.""" - if isinstance(bind, util.string_types + (url.URL, )): + if isinstance(bind, util.string_types + (url.URL,)): self._bind = sqlalchemy.create_engine(bind) else: self._bind = bind + bind = property(bind, _bind_to) def clear(self): @@ -3858,12 +4038,20 @@ class MetaData(SchemaItem): """ - return ddl.sort_tables(sorted(self.tables.values(), key=lambda t: t.key)) + return ddl.sort_tables( + sorted(self.tables.values(), key=lambda t: t.key) + ) - def reflect(self, bind=None, schema=None, views=False, only=None, - extend_existing=False, - autoload_replace=True, - **dialect_kwargs): + def reflect( + self, + bind=None, + schema=None, + views=False, + only=None, + extend_existing=False, + autoload_replace=True, + **dialect_kwargs + ): r"""Load all available table definitions from the database. 
Automatically creates ``Table`` entries in this ``MetaData`` for any @@ -3926,11 +4114,11 @@ class MetaData(SchemaItem): with bind.connect() as conn: reflect_opts = { - 'autoload': True, - 'autoload_with': conn, - 'extend_existing': extend_existing, - 'autoload_replace': autoload_replace, - '_extend_on': set() + "autoload": True, + "autoload_with": conn, + "extend_existing": extend_existing, + "autoload_replace": autoload_replace, + "_extend_on": set(), } reflect_opts.update(dialect_kwargs) @@ -3939,42 +4127,49 @@ class MetaData(SchemaItem): schema = self.schema if schema is not None: - reflect_opts['schema'] = schema + reflect_opts["schema"] = schema available = util.OrderedSet( - bind.engine.table_names(schema, connection=conn)) + bind.engine.table_names(schema, connection=conn) + ) if views: - available.update( - bind.dialect.get_view_names(conn, schema) - ) + available.update(bind.dialect.get_view_names(conn, schema)) if schema is not None: - available_w_schema = util.OrderedSet(["%s.%s" % (schema, name) - for name in available]) + available_w_schema = util.OrderedSet( + ["%s.%s" % (schema, name) for name in available] + ) else: available_w_schema = available current = set(self.tables) if only is None: - load = [name for name, schname in - zip(available, available_w_schema) - if extend_existing or schname not in current] + load = [ + name + for name, schname in zip(available, available_w_schema) + if extend_existing or schname not in current + ] elif util.callable(only): - load = [name for name, schname in - zip(available, available_w_schema) - if (extend_existing or schname not in current) - and only(name, self)] + load = [ + name + for name, schname in zip(available, available_w_schema) + if (extend_existing or schname not in current) + and only(name, self) + ] else: missing = [name for name in only if name not in available] if missing: - s = schema and (" schema '%s'" % schema) or '' + s = schema and (" schema '%s'" % schema) or "" raise exc.InvalidRequestError( - 'Could not reflect: requested table(s) not available ' - 'in %r%s: (%s)' % - (bind.engine, s, ', '.join(missing))) - load = [name for name in only if extend_existing or - name not in current] + "Could not reflect: requested table(s) not available " + "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing)) + ) + load = [ + name + for name in only + if extend_existing or name not in current + ] for name in load: try: @@ -3989,11 +4184,12 @@ class MetaData(SchemaItem): See :class:`.DDLEvents`. """ + def adapt_listener(target, connection, **kw): - tables = kw['tables'] + tables = kw["tables"] listener(event, target, connection, tables=tables) - event.listen(self, "" + event_name.replace('-', '_'), adapt_listener) + event.listen(self, "" + event_name.replace("-", "_"), adapt_listener) def create_all(self, bind=None, tables=None, checkfirst=True): """Create all tables stored in this metadata. @@ -4017,10 +4213,9 @@ class MetaData(SchemaItem): """ if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaGenerator, - self, - checkfirst=checkfirst, - tables=tables) + bind._run_visitor( + ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables + ) def drop_all(self, bind=None, tables=None, checkfirst=True): """Drop all tables stored in this metadata. 
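# A minimal sketch of the MetaData.create_all() / reflect() / drop_all()
# round trip handled in the hunks above; the in-memory SQLite URL and the
# "account" table are assumptions for illustration.
from sqlalchemy import Column, Integer, MetaData, Table, create_engine

engine = create_engine("sqlite://")
meta = MetaData()
Table("account", meta, Column("id", Integer, primary_key=True))
meta.create_all(engine)  # emits CREATE TABLE for "account"

reflected = MetaData()
reflected.reflect(bind=engine)  # autoloads Table entries from the database
assert "account" in reflected.tables
meta.drop_all(engine)  # emits DROP TABLE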
@@ -4044,10 +4239,9 @@ class MetaData(SchemaItem): """ if bind is None: bind = _bind_or_error(self) - bind._run_visitor(ddl.SchemaDropper, - self, - checkfirst=checkfirst, - tables=tables) + bind._run_visitor( + ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables + ) class ThreadLocalMetaData(MetaData): @@ -4064,7 +4258,7 @@ class ThreadLocalMetaData(MetaData): """ - __visit_name__ = 'metadata' + __visit_name__ = "metadata" def __init__(self): """Construct a ThreadLocalMetaData.""" @@ -4080,13 +4274,13 @@ class ThreadLocalMetaData(MetaData): string or URL to automatically create a basic Engine for this bind with ``create_engine()``.""" - return getattr(self.context, '_engine', None) + return getattr(self.context, "_engine", None) @util.dependencies("sqlalchemy.engine.url") def _bind_to(self, url, bind): """Bind to a Connectable in the caller's thread.""" - if isinstance(bind, util.string_types + (url.URL, )): + if isinstance(bind, util.string_types + (url.URL,)): try: self.context._engine = self.__engines[bind] except KeyError: @@ -4104,14 +4298,16 @@ class ThreadLocalMetaData(MetaData): def is_bound(self): """True if there is a bind for this thread.""" - return (hasattr(self.context, '_engine') and - self.context._engine is not None) + return ( + hasattr(self.context, "_engine") + and self.context._engine is not None + ) def dispose(self): """Dispose all bound engines, in all thread contexts.""" for e in self.__engines.values(): - if hasattr(e, 'dispose'): + if hasattr(e, "dispose"): e.dispose() @@ -4128,22 +4324,25 @@ class _SchemaTranslateMap(object): """ - __slots__ = 'map_', '__call__', 'hash_key', 'is_default' + + __slots__ = "map_", "__call__", "hash_key", "is_default" _default_schema_getter = operator.attrgetter("schema") def __init__(self, map_): self.map_ = map_ if map_ is not None: + def schema_for_object(obj): effective_schema = self._default_schema_getter(obj) effective_schema = obj._translate_schema( - effective_schema, map_) + effective_schema, map_ + ) return effective_schema + self.__call__ = schema_for_object self.hash_key = ";".join( - "%s=%s" % (k, map_[k]) - for k in sorted(map_, key=str) + "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str) ) self.is_default = False else: @@ -4160,6 +4359,6 @@ class _SchemaTranslateMap(object): else: return _SchemaTranslateMap(map_) + _default_schema_map = _SchemaTranslateMap(None) _schema_getter = _SchemaTranslateMap._schema_getter - diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index f64f152c48..1f1800514f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -10,15 +10,39 @@ SQL tables and derived rowsets. 
""" -from .elements import ClauseElement, TextClause, ClauseList, \ - and_, Grouping, UnaryExpression, literal_column, BindParameter -from .elements import _clone, \ - _literal_as_text, _interpret_as_column_or_from, _expand_cloned,\ - _select_iterables, _anonymous_label, _clause_element_as_expr,\ - _cloned_intersection, _cloned_difference, True_, \ - _literal_as_label_reference, _literal_and_labels_as_label_reference -from .base import Immutable, Executable, _generative, \ - ColumnCollection, ColumnSet, _from_objects, Generative +from .elements import ( + ClauseElement, + TextClause, + ClauseList, + and_, + Grouping, + UnaryExpression, + literal_column, + BindParameter, +) +from .elements import ( + _clone, + _literal_as_text, + _interpret_as_column_or_from, + _expand_cloned, + _select_iterables, + _anonymous_label, + _clause_element_as_expr, + _cloned_intersection, + _cloned_difference, + True_, + _literal_as_label_reference, + _literal_and_labels_as_label_reference, +) +from .base import ( + Immutable, + Executable, + _generative, + ColumnCollection, + ColumnSet, + _from_objects, + Generative, +) from . import type_api from .. import inspection from .. import util @@ -40,7 +64,8 @@ def _interpret_as_from(element): "Textual SQL FROM expression %(expr)r should be " "explicitly declared as text(%(expr)r), " "or use table(%(expr)r) for more specificity", - {"expr": util.ellipses_string(element)}) + {"expr": util.ellipses_string(element)}, + ) return TextClause(util.text_type(element)) try: @@ -73,7 +98,7 @@ def _offset_or_limit_clause(element, name=None, type_=None): """ if element is None: return None - elif hasattr(element, '__clause_element__'): + elif hasattr(element, "__clause_element__"): return element.__clause_element__() elif isinstance(element, Visitable): return element @@ -97,7 +122,8 @@ def _offset_or_limit_clause_asint(clause, attrname): except AttributeError: raise exc.CompileError( "This SELECT structure does not use a simple " - "integer value for %s" % attrname) + "integer value for %s" % attrname + ) else: return util.asint(value) @@ -225,12 +251,14 @@ def tablesample(selectable, sampling, name=None, seed=None): """ return _interpret_as_from(selectable).tablesample( - sampling, name=name, seed=seed) + sampling, name=name, seed=seed + ) class Selectable(ClauseElement): """mark a class as being selectable""" - __visit_name__ = 'selectable' + + __visit_name__ = "selectable" is_selectable = True @@ -265,15 +293,17 @@ class HasPrefixes(object): limit rendering of this prefix to only that dialect. """ - dialect = kw.pop('dialect', None) + dialect = kw.pop("dialect", None) if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) + raise exc.ArgumentError( + "Unsupported argument(s): %s" % ",".join(kw) + ) self._setup_prefixes(expr, dialect) def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) + [(_literal_as_text(p, warn=False), dialect) for p in prefixes] + ) class HasSuffixes(object): @@ -301,15 +331,17 @@ class HasSuffixes(object): limit rendering of this suffix to only that dialect. 
""" - dialect = kw.pop('dialect', None) + dialect = kw.pop("dialect", None) if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) + raise exc.ArgumentError( + "Unsupported argument(s): %s" % ",".join(kw) + ) self._setup_suffixes(expr, dialect) def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in suffixes]) + [(_literal_as_text(p, warn=False), dialect) for p in suffixes] + ) class FromClause(Selectable): @@ -330,7 +362,8 @@ class FromClause(Selectable): """ - __visit_name__ = 'fromclause' + + __visit_name__ = "fromclause" named_with_column = False _hide_froms = [] @@ -359,13 +392,14 @@ class FromClause(Selectable): _memoized_property = util.group_expirable_memoized_property(["_columns"]) @util.deprecated( - '1.1', + "1.1", message="``FromClause.count()`` is deprecated. Counting " "rows requires that the correct column expression and " "accommodations for joins, DISTINCT, etc. must be made, " "otherwise results may not be what's expected. " "Please use an appropriate ``func.count()`` expression " - "directly.") + "directly.", + ) @util.dependencies("sqlalchemy.sql.functions") def count(self, functions, whereclause=None, **params): """return a SELECT COUNT generated against this @@ -392,10 +426,11 @@ class FromClause(Selectable): else: col = list(self.columns)[0] return Select( - [functions.func.count(col).label('tbl_row_count')], + [functions.func.count(col).label("tbl_row_count")], whereclause, from_obj=[self], - **params) + **params + ) def select(self, whereclause=None, **params): """return a SELECT of this :class:`.FromClause`. @@ -603,8 +638,9 @@ class FromClause(Selectable): def embedded(expanded_proxy_set, target_set): for t in target_set.difference(expanded_proxy_set): - if not set(_expand_cloned([t]) - ).intersection(expanded_proxy_set): + if not set(_expand_cloned([t])).intersection( + expanded_proxy_set + ): return False return True @@ -617,8 +653,10 @@ class FromClause(Selectable): for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) - if i and (not require_embedded - or embedded(expanded_proxy_set, target_set)): + if i and ( + not require_embedded + or embedded(expanded_proxy_set, target_set) + ): if col is None: # no corresponding column yet, pick this one. @@ -646,12 +684,20 @@ class FromClause(Selectable): col_distance = util.reduce( operator.add, - [sc._annotations.get('weight', 1) for sc in - col.proxy_set if sc.shares_lineage(column)]) + [ + sc._annotations.get("weight", 1) + for sc in col.proxy_set + if sc.shares_lineage(column) + ], + ) c_distance = util.reduce( operator.add, - [sc._annotations.get('weight', 1) for sc in - c.proxy_set if sc.shares_lineage(column)]) + [ + sc._annotations.get("weight", 1) + for sc in c.proxy_set + if sc.shares_lineage(column) + ], + ) if c_distance < col_distance: col, intersect = c, i return col @@ -663,7 +709,7 @@ class FromClause(Selectable): Used primarily for error message formatting. 
""" - return getattr(self, 'name', self.__class__.__name__ + " object") + return getattr(self, "name", self.__class__.__name__ + " object") def _reset_exported(self): """delete memoized collections when a FromClause is cloned.""" @@ -683,7 +729,7 @@ class FromClause(Selectable): """ - if '_columns' not in self.__dict__: + if "_columns" not in self.__dict__: self._init_collections() self._populate_column_collection() return self._columns.as_immutable() @@ -706,14 +752,16 @@ class FromClause(Selectable): self._populate_column_collection() return self.foreign_keys - c = property(attrgetter('columns'), - doc="An alias for the :attr:`.columns` attribute.") - _select_iterable = property(attrgetter('columns')) + c = property( + attrgetter("columns"), + doc="An alias for the :attr:`.columns` attribute.", + ) + _select_iterable = property(attrgetter("columns")) def _init_collections(self): - assert '_columns' not in self.__dict__ - assert 'primary_key' not in self.__dict__ - assert 'foreign_keys' not in self.__dict__ + assert "_columns" not in self.__dict__ + assert "primary_key" not in self.__dict__ + assert "foreign_keys" not in self.__dict__ self._columns = ColumnCollection() self.primary_key = ColumnSet() @@ -721,7 +769,7 @@ class FromClause(Selectable): @property def _cols_populated(self): - return '_columns' in self.__dict__ + return "_columns" in self.__dict__ def _populate_column_collection(self): """Called on subclasses to establish the .c collection. @@ -758,8 +806,7 @@ class FromClause(Selectable): """ if not self._cols_populated: return None - elif (column.key in self.columns and - self.columns[column.key] is column): + elif column.key in self.columns and self.columns[column.key] is column: return column else: return None @@ -780,7 +827,8 @@ class Join(FromClause): :meth:`.FromClause.join` """ - __visit_name__ = 'join' + + __visit_name__ = "join" _is_join = True @@ -829,8 +877,9 @@ class Join(FromClause): return cls(left, right, onclause, isouter=True, full=full) @classmethod - def _create_join(cls, left, right, onclause=None, isouter=False, - full=False): + def _create_join( + cls, left, right, onclause=None, isouter=False, full=False + ): """Produce a :class:`.Join` object, given two :class:`.FromClause` expressions. 
@@ -882,26 +931,34 @@ class Join(FromClause): self.left.description, id(self.left), self.right.description, - id(self.right)) + id(self.right), + ) def is_derived_from(self, fromclause): - return fromclause is self or \ - self.left.is_derived_from(fromclause) or \ - self.right.is_derived_from(fromclause) + return ( + fromclause is self + or self.left.is_derived_from(fromclause) + or self.right.is_derived_from(fromclause) + ) def self_group(self, against=None): return FromGrouping(self) @util.dependencies("sqlalchemy.sql.util") def _populate_column_collection(self, sqlutil): - columns = [c for c in self.left.columns] + \ - [c for c in self.right.columns] + columns = [c for c in self.left.columns] + [ + c for c in self.right.columns + ] - self.primary_key.extend(sqlutil.reduce_columns( - (c for c in columns if c.primary_key), self.onclause)) + self.primary_key.extend( + sqlutil.reduce_columns( + (c for c in columns if c.primary_key), self.onclause + ) + ) self._columns.update((col._label, col) for col in columns) - self.foreign_keys.update(itertools.chain( - *[col.foreign_keys for col in columns])) + self.foreign_keys.update( + itertools.chain(*[col.foreign_keys for col in columns]) + ) def _refresh_for_new_column(self, column): col = self.left._refresh_for_new_column(column) @@ -933,9 +990,14 @@ class Join(FromClause): return self._join_condition(left, right, a_subset=left_right) @classmethod - def _join_condition(cls, a, b, ignore_nonexistent_tables=False, - a_subset=None, - consider_as_foreign_keys=None): + def _join_condition( + cls, + a, + b, + ignore_nonexistent_tables=False, + a_subset=None, + consider_as_foreign_keys=None, + ): """create a join condition between two tables or selectables. e.g.:: @@ -963,26 +1025,31 @@ class Join(FromClause): """ constraints = cls._joincond_scan_left_right( - a, a_subset, b, consider_as_foreign_keys) + a, a_subset, b, consider_as_foreign_keys + ) if len(constraints) > 1: cls._joincond_trim_constraints( - a, b, constraints, consider_as_foreign_keys) + a, b, constraints, consider_as_foreign_keys + ) if len(constraints) == 0: if isinstance(b, FromGrouping): - hint = " Perhaps you meant to convert the right side to a "\ + hint = ( + " Perhaps you meant to convert the right side to a " "subquery using alias()?" 
+ ) else: hint = "" raise exc.NoForeignKeysError( "Can't find any foreign key relationships " - "between '%s' and '%s'.%s" % - (a.description, b.description, hint)) + "between '%s' and '%s'.%s" + % (a.description, b.description, hint) + ) crit = [(x == y) for x, y in list(constraints.values())[0]] if len(crit) == 1: - return (crit[0]) + return crit[0] else: return and_(*crit) @@ -994,24 +1061,30 @@ class Join(FromClause): left_right = None constraints = cls._joincond_scan_left_right( - a=left, b=right, a_subset=left_right, - consider_as_foreign_keys=consider_as_foreign_keys) + a=left, + b=right, + a_subset=left_right, + consider_as_foreign_keys=consider_as_foreign_keys, + ) return bool(constraints) @classmethod def _joincond_scan_left_right( - cls, a, a_subset, b, consider_as_foreign_keys): + cls, a, a_subset, b, consider_as_foreign_keys + ): constraints = collections.defaultdict(list) for left in (a_subset, a): if left is None: continue for fk in sorted( - b.foreign_keys, - key=lambda fk: fk.parent._creation_order): - if consider_as_foreign_keys is not None and \ - fk.parent not in consider_as_foreign_keys: + b.foreign_keys, key=lambda fk: fk.parent._creation_order + ): + if ( + consider_as_foreign_keys is not None + and fk.parent not in consider_as_foreign_keys + ): continue try: col = fk.get_referent(left) @@ -1025,10 +1098,12 @@ class Join(FromClause): constraints[fk.constraint].append((col, fk.parent)) if left is not b: for fk in sorted( - left.foreign_keys, - key=lambda fk: fk.parent._creation_order): - if consider_as_foreign_keys is not None and \ - fk.parent not in consider_as_foreign_keys: + left.foreign_keys, key=lambda fk: fk.parent._creation_order + ): + if ( + consider_as_foreign_keys is not None + and fk.parent not in consider_as_foreign_keys + ): continue try: col = fk.get_referent(b) @@ -1046,14 +1121,16 @@ class Join(FromClause): @classmethod def _joincond_trim_constraints( - cls, a, b, constraints, consider_as_foreign_keys): + cls, a, b, constraints, consider_as_foreign_keys + ): # more than one constraint matched. narrow down the list # to include just those FKCs that match exactly to # "consider_as_foreign_keys". if consider_as_foreign_keys: for const in list(constraints): if set(f.parent for f in const.elements) != set( - consider_as_foreign_keys): + consider_as_foreign_keys + ): del constraints[const] # if still multiple constraints, but @@ -1070,8 +1147,8 @@ class Join(FromClause): "tables have more than one foreign key " "constraint relationship between them. " "Please specify the 'onclause' of this " - "join explicitly." % (a.description, b.description)) - + "join explicitly." % (a.description, b.description) + ) def select(self, whereclause=None, **kwargs): r"""Create a :class:`.Select` from this :class:`.Join`. 
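# A sketch of the join-condition inference implemented by _join_condition()
# above: with a single foreign key the ON clause is derived automatically,
# while multiple candidate FKs raise AmbiguousForeignKeysError unless an
# explicit onclause is given. Table and column names are assumed examples.
from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table, select

meta = MetaData()
user = Table("user", meta, Column("id", Integer, primary_key=True))
address = Table(
    "address",
    meta,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("user.id")),
    Column("billing_user_id", Integer, ForeignKey("user.id")),
)

# two FKs point to "user", so state the join condition explicitly
j = user.join(address, user.c.id == address.c.user_id)
stmt = select([user]).select_from(j)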
@@ -1200,27 +1277,37 @@ class Join(FromClause): """ if flat: assert name is None, "Can't send name argument with flat" - left_a, right_a = self.left.alias(flat=True), \ - self.right.alias(flat=True) - adapter = sqlutil.ClauseAdapter(left_a).\ - chain(sqlutil.ClauseAdapter(right_a)) + left_a, right_a = ( + self.left.alias(flat=True), + self.right.alias(flat=True), + ) + adapter = sqlutil.ClauseAdapter(left_a).chain( + sqlutil.ClauseAdapter(right_a) + ) - return left_a.join(right_a, adapter.traverse(self.onclause), - isouter=self.isouter, full=self.full) + return left_a.join( + right_a, + adapter.traverse(self.onclause), + isouter=self.isouter, + full=self.full, + ) else: return self.select(use_labels=True, correlate=False).alias(name) @property def _hide_froms(self): - return itertools.chain(*[_from_objects(x.left, x.right) - for x in self._cloned_set]) + return itertools.chain( + *[_from_objects(x.left, x.right) for x in self._cloned_set] + ) @property def _from_objects(self): - return [self] + \ - self.onclause._from_objects + \ - self.left._from_objects + \ - self.right._from_objects + return ( + [self] + + self.onclause._from_objects + + self.left._from_objects + + self.right._from_objects + ) class Alias(FromClause): @@ -1236,7 +1323,7 @@ class Alias(FromClause): """ - __visit_name__ = 'alias' + __visit_name__ = "alias" named_with_column = True _is_from_container = True @@ -1252,15 +1339,16 @@ class Alias(FromClause): self.element = selectable if name is None: if self.original.named_with_column: - name = getattr(self.original, 'name', None) - name = _anonymous_label('%%(%d %s)s' % (id(self), name - or 'anon')) + name = getattr(self.original, "name", None) + name = _anonymous_label("%%(%d %s)s" % (id(self), name or "anon")) self.name = name def self_group(self, against=None): - if isinstance(against, CompoundSelect) and \ - isinstance(self.original, Select) and \ - self.original._needs_parens_for_grouping(): + if ( + isinstance(against, CompoundSelect) + and isinstance(self.original, Select) + and self.original._needs_parens_for_grouping() + ): return FromGrouping(self) return super(Alias, self).self_group(against=against) @@ -1270,14 +1358,15 @@ class Alias(FromClause): if util.py3k: return self.name else: - return self.name.encode('ascii', 'backslashreplace') + return self.name.encode("ascii", "backslashreplace") def as_scalar(self): try: return self.element.as_scalar() except AttributeError: - raise AttributeError("Element %s does not support " - "'as_scalar()'" % self.element) + raise AttributeError( + "Element %s does not support " "'as_scalar()'" % self.element + ) def is_derived_from(self, fromclause): if fromclause in self._cloned_set: @@ -1344,7 +1433,7 @@ class Lateral(Alias): """ - __visit_name__ = 'lateral' + __visit_name__ = "lateral" _is_lateral = True @@ -1363,11 +1452,9 @@ class TableSample(Alias): """ - __visit_name__ = 'tablesample' + __visit_name__ = "tablesample" - def __init__(self, selectable, sampling, - name=None, - seed=None): + def __init__(self, selectable, sampling, name=None, seed=None): self.sampling = sampling self.seed = seed super(TableSample, self).__init__(selectable, name=name) @@ -1390,14 +1477,18 @@ class CTE(Generative, HasSuffixes, Alias): .. 
versionadded:: 0.7.6 """ - __visit_name__ = 'cte' - - def __init__(self, selectable, - name=None, - recursive=False, - _cte_alias=None, - _restates=frozenset(), - _suffixes=None): + + __visit_name__ = "cte" + + def __init__( + self, + selectable, + name=None, + recursive=False, + _cte_alias=None, + _restates=frozenset(), + _suffixes=None, + ): self.recursive = recursive self._cte_alias = _cte_alias self._restates = _restates @@ -1409,9 +1500,9 @@ class CTE(Generative, HasSuffixes, Alias): super(CTE, self)._copy_internals(clone, **kw) if self._cte_alias is not None: self._cte_alias = clone(self._cte_alias, **kw) - self._restates = frozenset([ - clone(elem, **kw) for elem in self._restates - ]) + self._restates = frozenset( + [clone(elem, **kw) for elem in self._restates] + ) @util.dependencies("sqlalchemy.sql.dml") def _populate_column_collection(self, dml): @@ -1428,7 +1519,7 @@ class CTE(Generative, HasSuffixes, Alias): name=name, recursive=self.recursive, _cte_alias=self, - _suffixes=self._suffixes + _suffixes=self._suffixes, ) def union(self, other): @@ -1437,7 +1528,7 @@ class CTE(Generative, HasSuffixes, Alias): name=self.name, recursive=self.recursive, _restates=self._restates.union([self]), - _suffixes=self._suffixes + _suffixes=self._suffixes, ) def union_all(self, other): @@ -1446,7 +1537,7 @@ class CTE(Generative, HasSuffixes, Alias): name=self.name, recursive=self.recursive, _restates=self._restates.union([self]), - _suffixes=self._suffixes + _suffixes=self._suffixes, ) @@ -1620,7 +1711,8 @@ class HasCTE(object): class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" - __visit_name__ = 'grouping' + + __visit_name__ = "grouping" def __init__(self, element): self.element = element @@ -1651,7 +1743,7 @@ class FromGrouping(FromClause): return self.element._hide_froms def get_children(self, **kwargs): - return self.element, + return (self.element,) def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) @@ -1664,10 +1756,10 @@ class FromGrouping(FromClause): return getattr(self.element, attr) def __getstate__(self): - return {'element': self.element} + return {"element": self.element} def __setstate__(self, state): - self.element = state['element'] + self.element = state["element"] class TableClause(Immutable, FromClause): @@ -1699,7 +1791,7 @@ class TableClause(Immutable, FromClause): """ - __visit_name__ = 'table' + __visit_name__ = "table" named_with_column = True @@ -1744,7 +1836,7 @@ class TableClause(Immutable, FromClause): if util.py3k: return self.name else: - return self.name.encode('ascii', 'backslashreplace') + return self.name.encode("ascii", "backslashreplace") def append_column(self, c): self._columns[c.key] = c @@ -1773,7 +1865,8 @@ class TableClause(Immutable, FromClause): @util.dependencies("sqlalchemy.sql.dml") def update( - self, dml, whereclause=None, values=None, inline=False, **kwargs): + self, dml, whereclause=None, values=None, inline=False, **kwargs + ): """Generate an :func:`.update` construct against this :class:`.TableClause`. 
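# A brief sketch of the CTE construct whose union_all() method appears
# above, following the documented recursive-CTE pattern; the "parts"
# schema is an assumed example.
from sqlalchemy import Column, Integer, MetaData, String, Table, select

meta = MetaData()
parts = Table(
    "parts",
    meta,
    Column("part", String(50)),
    Column("sub_part", String(50)),
    Column("quantity", Integer),
)

included = (
    select([parts.c.sub_part, parts.c.quantity])
    .where(parts.c.part == "our part")
    .cte(name="included_parts", recursive=True)
)
incl_alias = included.alias()
parts_alias = parts.alias()
# union_all() ties the recursive term back to the CTE itself
included = included.union_all(
    select([parts_alias.c.sub_part, parts_alias.c.quantity]).where(
        parts_alias.c.part == incl_alias.c.sub_part
    )
)
stmt = select([included.c.sub_part, included.c.quantity])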
@@ -1785,8 +1878,13 @@ class TableClause(Immutable, FromClause): """ - return dml.Update(self, whereclause=whereclause, - values=values, inline=inline, **kwargs) + return dml.Update( + self, + whereclause=whereclause, + values=values, + inline=inline, + **kwargs + ) @util.dependencies("sqlalchemy.sql.dml") def delete(self, dml, whereclause=None, **kwargs): @@ -1809,7 +1907,6 @@ class TableClause(Immutable, FromClause): class ForUpdateArg(ClauseElement): - @classmethod def parse_legacy_select(self, arg): """Parse the for_update argument of :func:`.select`. @@ -1836,11 +1933,11 @@ class ForUpdateArg(ClauseElement): return None nowait = read = False - if arg == 'nowait': + if arg == "nowait": nowait = True - elif arg == 'read': + elif arg == "read": read = True - elif arg == 'read_nowait': + elif arg == "read_nowait": read = nowait = True elif arg is not True: raise exc.ArgumentError("Unknown for_update argument: %r" % arg) @@ -1860,12 +1957,12 @@ class ForUpdateArg(ClauseElement): def __eq__(self, other): return ( - isinstance(other, ForUpdateArg) and - other.nowait == self.nowait and - other.read == self.read and - other.skip_locked == self.skip_locked and - other.key_share == self.key_share and - other.of is self.of + isinstance(other, ForUpdateArg) + and other.nowait == self.nowait + and other.read == self.read + and other.skip_locked == self.skip_locked + and other.key_share == self.key_share + and other.of is self.of ) def __hash__(self): @@ -1876,8 +1973,13 @@ class ForUpdateArg(ClauseElement): self.of = [clone(col, **kw) for col in self.of] def __init__( - self, nowait=False, read=False, of=None, - skip_locked=False, key_share=False): + self, + nowait=False, + read=False, + of=None, + skip_locked=False, + key_share=False, + ): """Represents arguments specified to :meth:`.Select.for_update`. .. versionadded:: 0.9.0 @@ -1889,8 +1991,9 @@ class ForUpdateArg(ClauseElement): self.skip_locked = skip_locked self.key_share = key_share if of is not None: - self.of = [_interpret_as_column_or_from(elem) - for elem in util.to_list(of)] + self.of = [ + _interpret_as_column_or_from(elem) for elem in util.to_list(of) + ] else: self.of = None @@ -1930,17 +2033,20 @@ class SelectBase(HasCTE, Executable, FromClause): return self.as_scalar().label(name) @_generative - @util.deprecated('0.6', - message="``autocommit()`` is deprecated. Use " - ":meth:`.Executable.execution_options` with the " - "'autocommit' flag.") + @util.deprecated( + "0.6", + message="``autocommit()`` is deprecated. Use " + ":meth:`.Executable.execution_options` with the " + "'autocommit' flag.", + ) def autocommit(self): """return a new selectable with the 'autocommit' flag set to True. """ - self._execution_options = \ - self._execution_options.union({'autocommit': True}) + self._execution_options = self._execution_options.union( + {"autocommit": True} + ) def _generate(self): """Override the default _generate() method to also clear out @@ -1973,34 +2079,38 @@ class GenerativeSelect(SelectBase): used for other SELECT-like objects, e.g. :class:`.TextAsFrom`. 
""" + _order_by_clause = ClauseList() _group_by_clause = ClauseList() _limit_clause = None _offset_clause = None _for_update_arg = None - def __init__(self, - use_labels=False, - for_update=False, - limit=None, - offset=None, - order_by=None, - group_by=None, - bind=None, - autocommit=None): + def __init__( + self, + use_labels=False, + for_update=False, + limit=None, + offset=None, + order_by=None, + group_by=None, + bind=None, + autocommit=None, + ): self.use_labels = use_labels if for_update is not False: - self._for_update_arg = (ForUpdateArg. - parse_legacy_select(for_update)) + self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update) if autocommit is not None: - util.warn_deprecated('autocommit on select() is ' - 'deprecated. Use .execution_options(a' - 'utocommit=True)') - self._execution_options = \ - self._execution_options.union( - {'autocommit': autocommit}) + util.warn_deprecated( + "autocommit on select() is " + "deprecated. Use .execution_options(a" + "utocommit=True)" + ) + self._execution_options = self._execution_options.union( + {"autocommit": autocommit} + ) if limit is not None: self._limit_clause = _offset_or_limit_clause(limit) if offset is not None: @@ -2010,11 +2120,13 @@ class GenerativeSelect(SelectBase): if order_by is not None: self._order_by_clause = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_and_labels_as_label_reference) + _literal_as_text=_literal_and_labels_as_label_reference + ) if group_by is not None: self._group_by_clause = ClauseList( *util.to_list(group_by), - _literal_as_text=_literal_as_label_reference) + _literal_as_text=_literal_as_label_reference + ) @property def for_update(self): @@ -2030,8 +2142,14 @@ class GenerativeSelect(SelectBase): self._for_update_arg = ForUpdateArg.parse_legacy_select(value) @_generative - def with_for_update(self, nowait=False, read=False, of=None, - skip_locked=False, key_share=False): + def with_for_update( + self, + nowait=False, + read=False, + of=None, + skip_locked=False, + key_share=False, + ): """Specify a ``FOR UPDATE`` clause for this :class:`.GenerativeSelect`. E.g.:: @@ -2079,9 +2197,13 @@ class GenerativeSelect(SelectBase): .. versionadded:: 1.1.0 """ - self._for_update_arg = ForUpdateArg(nowait=nowait, read=read, of=of, - skip_locked=skip_locked, - key_share=key_share) + self._for_update_arg = ForUpdateArg( + nowait=nowait, + read=read, + of=of, + skip_locked=skip_locked, + key_share=key_share, + ) @_generative def apply_labels(self): @@ -2209,11 +2331,12 @@ class GenerativeSelect(SelectBase): if len(clauses) == 1 and clauses[0] is None: self._order_by_clause = ClauseList() else: - if getattr(self, '_order_by_clause', None) is not None: + if getattr(self, "_order_by_clause", None) is not None: clauses = list(self._order_by_clause) + list(clauses) self._order_by_clause = ClauseList( *clauses, - _literal_as_text=_literal_and_labels_as_label_reference) + _literal_as_text=_literal_and_labels_as_label_reference + ) def append_group_by(self, *clauses): """Append the given GROUP BY criterion applied to this selectable. 
@@ -2228,10 +2351,11 @@ class GenerativeSelect(SelectBase): if len(clauses) == 1 and clauses[0] is None: self._group_by_clause = ClauseList() else: - if getattr(self, '_group_by_clause', None) is not None: + if getattr(self, "_group_by_clause", None) is not None: clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList( - *clauses, _literal_as_text=_literal_as_label_reference) + *clauses, _literal_as_text=_literal_as_label_reference + ) @property def _label_resolve_dict(self): @@ -2265,19 +2389,19 @@ class CompoundSelect(GenerativeSelect): """ - __visit_name__ = 'compound_select' + __visit_name__ = "compound_select" - UNION = util.symbol('UNION') - UNION_ALL = util.symbol('UNION ALL') - EXCEPT = util.symbol('EXCEPT') - EXCEPT_ALL = util.symbol('EXCEPT ALL') - INTERSECT = util.symbol('INTERSECT') - INTERSECT_ALL = util.symbol('INTERSECT ALL') + UNION = util.symbol("UNION") + UNION_ALL = util.symbol("UNION ALL") + EXCEPT = util.symbol("EXCEPT") + EXCEPT_ALL = util.symbol("EXCEPT ALL") + INTERSECT = util.symbol("INTERSECT") + INTERSECT_ALL = util.symbol("INTERSECT ALL") _is_from_container = True def __init__(self, keyword, *selects, **kwargs): - self._auto_correlate = kwargs.pop('correlate', False) + self._auto_correlate = kwargs.pop("correlate", False) self.keyword = keyword self.selects = [] @@ -2291,12 +2415,16 @@ class CompoundSelect(GenerativeSelect): numcols = len(s.c._all_columns) elif len(s.c._all_columns) != numcols: raise exc.ArgumentError( - 'All selectables passed to ' - 'CompoundSelect must have identical numbers of ' - 'columns; select #%d has %d columns, select ' - '#%d has %d' % - (1, len(self.selects[0].c._all_columns), - n + 1, len(s.c._all_columns)) + "All selectables passed to " + "CompoundSelect must have identical numbers of " + "columns; select #%d has %d columns, select " + "#%d has %d" + % ( + 1, + len(self.selects[0].c._all_columns), + n + 1, + len(s.c._all_columns), + ) ) self.selects.append(s.self_group(against=self)) @@ -2305,9 +2433,7 @@ class CompoundSelect(GenerativeSelect): @property def _label_resolve_dict(self): - d = dict( - (c.key, c) for c in self.c - ) + d = dict((c.key, c) for c in self.c) return d, d, d @classmethod @@ -2416,8 +2542,7 @@ class CompoundSelect(GenerativeSelect): :func:`select`. """ - return CompoundSelect( - CompoundSelect.INTERSECT_ALL, *selects, **kwargs) + return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) def _scalar_type(self): return self.selects[0]._scalar_type() @@ -2445,8 +2570,10 @@ class CompoundSelect(GenerativeSelect): # those fks too. 
proxy = cols[0]._make_proxy( - self, name=cols[0]._label if self.use_labels else None, - key=cols[0]._key_label if self.use_labels else None) + self, + name=cols[0]._label if self.use_labels else None, + key=cols[0]._key_label if self.use_labels else None, + ) # hand-construct the "_proxies" collection to include all # derived columns place a 'weight' annotation corresponding @@ -2455,7 +2582,8 @@ class CompoundSelect(GenerativeSelect): # conflicts proxy._proxies = [ - c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)] + c._annotate({"weight": i + 1}) for (i, c) in enumerate(cols) + ] def _refresh_for_new_column(self, column): for s in self.selects: @@ -2464,25 +2592,32 @@ class CompoundSelect(GenerativeSelect): if not self._cols_populated: return None - raise NotImplementedError("CompoundSelect constructs don't support " - "addition of columns to underlying " - "selectables") + raise NotImplementedError( + "CompoundSelect constructs don't support " + "addition of columns to underlying " + "selectables" + ) def _copy_internals(self, clone=_clone, **kw): super(CompoundSelect, self)._copy_internals(clone, **kw) self._reset_exported() self.selects = [clone(s, **kw) for s in self.selects] - if hasattr(self, '_col_map'): + if hasattr(self, "_col_map"): del self._col_map for attr in ( - '_order_by_clause', '_group_by_clause', '_for_update_arg'): + "_order_by_clause", + "_group_by_clause", + "_for_update_arg", + ): if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.c) or []) \ - + [self._order_by_clause, self._group_by_clause] \ + return ( + (column_collections and list(self.c) or []) + + [self._order_by_clause, self._group_by_clause] + list(self.selects) + ) def bind(self): if self._bind: @@ -2496,6 +2631,7 @@ class CompoundSelect(GenerativeSelect): def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) @@ -2504,7 +2640,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ - __visit_name__ = 'select' + __visit_name__ = "select" _prefixes = () _suffixes = () @@ -2517,16 +2653,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): _memoized_property = SelectBase._memoized_property _is_select = True - def __init__(self, - columns=None, - whereclause=None, - from_obj=None, - distinct=False, - having=None, - correlate=True, - prefixes=None, - suffixes=None, - **kwargs): + def __init__( + self, + columns=None, + whereclause=None, + from_obj=None, + distinct=False, + having=None, + correlate=True, + prefixes=None, + suffixes=None, + **kwargs + ): """Construct a new :class:`.Select`. 
Similar functionality is also available via the @@ -2729,22 +2867,23 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._distinct = True else: self._distinct = [ - _literal_as_text(e) - for e in util.to_list(distinct) + _literal_as_text(e) for e in util.to_list(distinct) ] if from_obj is not None: self._from_obj = util.OrderedSet( - _interpret_as_from(f) - for f in util.to_list(from_obj)) + _interpret_as_from(f) for f in util.to_list(from_obj) + ) else: self._from_obj = util.OrderedSet() try: cols_present = bool(columns) except TypeError: - raise exc.ArgumentError("columns argument to select() must " - "be a Python list or other iterable") + raise exc.ArgumentError( + "columns argument to select() must " + "be a Python list or other iterable" + ) if cols_present: self._raw_columns = [] @@ -2757,14 +2896,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._raw_columns = [] if whereclause is not None: - self._whereclause = _literal_as_text( - whereclause).self_group(against=operators._asbool) + self._whereclause = _literal_as_text(whereclause).self_group( + against=operators._asbool + ) else: self._whereclause = None if having is not None: - self._having = _literal_as_text( - having).self_group(against=operators._asbool) + self._having = _literal_as_text(having).self_group( + against=operators._asbool + ) else: self._having = None @@ -2789,12 +2930,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): for item in itertools.chain( _from_objects(*self._raw_columns), _from_objects(self._whereclause) - if self._whereclause is not None else (), - self._from_obj + if self._whereclause is not None + else (), + self._from_obj, ): if item is self: raise exc.InvalidRequestError( - "select() construct refers to itself as a FROM") + "select() construct refers to itself as a FROM" + ) if translate and item in translate: item = translate[item] if not seen.intersection(item._cloned_set): @@ -2803,8 +2946,9 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return froms - def _get_display_froms(self, explicit_correlate_froms=None, - implicit_correlate_froms=None): + def _get_display_froms( + self, explicit_correlate_froms=None, implicit_correlate_froms=None + ): """Return the full list of 'from' clauses to be displayed. Takes into account a set of existing froms which may be @@ -2815,17 +2959,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ froms = self._froms - toremove = set(itertools.chain(*[ - _expand_cloned(f._hide_froms) - for f in froms])) + toremove = set( + itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms]) + ) if toremove: # if we're maintaining clones of froms, # add the copies out to the toremove list. only include # clones that are lexical equivalents. 
if self._from_cloned: toremove.update( - self._from_cloned[f] for f in - toremove.intersection(self._from_cloned) + self._from_cloned[f] + for f in toremove.intersection(self._from_cloned) if self._from_cloned[f]._is_lexical_equivalent(f) ) # filter out to FROM clauses not in the list, @@ -2836,41 +2980,53 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): to_correlate = self._correlate if to_correlate: froms = [ - f for f in froms if f not in - _cloned_intersection( + f + for f in froms + if f + not in _cloned_intersection( _cloned_intersection( - froms, explicit_correlate_froms or ()), - to_correlate + froms, explicit_correlate_froms or () + ), + to_correlate, ) ] if self._correlate_except is not None: froms = [ - f for f in froms if f not in - _cloned_difference( + f + for f in froms + if f + not in _cloned_difference( _cloned_intersection( - froms, explicit_correlate_froms or ()), - self._correlate_except + froms, explicit_correlate_froms or () + ), + self._correlate_except, ) ] - if self._auto_correlate and \ - implicit_correlate_froms and \ - len(froms) > 1: + if ( + self._auto_correlate + and implicit_correlate_froms + and len(froms) > 1 + ): froms = [ - f for f in froms if f not in - _cloned_intersection(froms, implicit_correlate_froms) + f + for f in froms + if f + not in _cloned_intersection(froms, implicit_correlate_froms) ] if not len(froms): - raise exc.InvalidRequestError("Select statement '%s" - "' returned no FROM clauses " - "due to auto-correlation; " - "specify correlate() " - "to control correlation " - "manually." % self) + raise exc.InvalidRequestError( + "Select statement '%s" + "' returned no FROM clauses " + "due to auto-correlation; " + "specify correlate() " + "to control correlation " + "manually." % self + ) return froms @@ -2885,7 +3041,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return self._get_display_froms() - def with_statement_hint(self, text, dialect_name='*'): + def with_statement_hint(self, text, dialect_name="*"): """add a statement hint to this :class:`.Select`. This method is similar to :meth:`.Select.with_hint` except that @@ -2906,7 +3062,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return self.with_hint(None, text, dialect_name) @_generative - def with_hint(self, selectable, text, dialect_name='*'): + def with_hint(self, selectable, text, dialect_name="*"): r"""Add an indexing or other executional context hint for the given selectable to this :class:`.Select`. @@ -2940,17 +3096,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ if selectable is None: - self._statement_hints += ((dialect_name, text), ) + self._statement_hints += ((dialect_name, text),) else: - self._hints = self._hints.union( - {(selectable, dialect_name): text}) + self._hints = self._hints.union({(selectable, dialect_name): text}) @property def type(self): - raise exc.InvalidRequestError("Select objects don't have a type. " - "Call as_scalar() on this Select " - "object to return a 'scalar' version " - "of this Select.") + raise exc.InvalidRequestError( + "Select objects don't have a type. " + "Call as_scalar() on this Select " + "object to return a 'scalar' version " + "of this Select." 
+ ) @_memoized_property.method def locate_all_froms(self): @@ -2977,10 +3134,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): with_cols = dict( (c._resolve_label or c._label or c.key, c) for c in _select_iterables(self._raw_columns) - if c._allow_label_resolve) + if c._allow_label_resolve + ) only_froms = dict( - (c.key, c) for c in - _select_iterables(self.froms) if c._allow_label_resolve) + (c.key, c) + for c in _select_iterables(self.froms) + if c._allow_label_resolve + ) only_cols = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) @@ -3011,11 +3171,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): # gets cleared on each generation. previously we were "baking" # _froms into self._from_obj. self._from_cloned = from_cloned = dict( - (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)) + (f, clone(f, **kw)) for f in self._from_obj.union(self._froms) + ) # 3. update persistent _from_obj with the cloned versions. - self._from_obj = util.OrderedSet(from_cloned[f] for f in - self._from_obj) + self._from_obj = util.OrderedSet( + from_cloned[f] for f in self._from_obj + ) # the _correlate collection is done separately, what can happen # here is the same item is _correlate as in _from_obj but the @@ -3023,16 +3185,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): # RelationshipProperty.Comparator._criterion_exists() does # this). Also keep _correlate liberally open with its previous # contents, as this set is used for matching, not rendering. - self._correlate = set(clone(f) for f in - self._correlate).union(self._correlate) + self._correlate = set(clone(f) for f in self._correlate).union( + self._correlate + ) # 4. clone other things. The difficulty here is that Column # objects are not actually cloned, and refer to their original # .table, resulting in the wrong "from" parent after a clone # operation. Hence _from_cloned and _from_obj supersede what is # present here. 
self._raw_columns = [clone(c, **kw) for c in self._raw_columns] - for attr in '_whereclause', '_having', '_order_by_clause', \ - '_group_by_clause', '_for_update_arg': + for attr in ( + "_whereclause", + "_having", + "_order_by_clause", + "_group_by_clause", + "_for_update_arg", + ): if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) @@ -3043,12 +3211,21 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): def get_children(self, column_collections=True, **kwargs): """return child elements as per the ClauseElement specification.""" - return (column_collections and list(self.columns) or []) + \ - self._raw_columns + list(self._froms) + \ - [x for x in - (self._whereclause, self._having, - self._order_by_clause, self._group_by_clause) - if x is not None] + return ( + (column_collections and list(self.columns) or []) + + self._raw_columns + + list(self._froms) + + [ + x + for x in ( + self._whereclause, + self._having, + self._order_by_clause, + self._group_by_clause, + ) + if x is not None + ] + ) @_generative def column(self, column): @@ -3094,7 +3271,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): sqlutil.reduce_columns( self.inner_columns, only_synonyms=only_synonyms, - *(self._whereclause, ) + tuple(self._from_obj) + *(self._whereclause,) + tuple(self._from_obj) ) ) @@ -3307,7 +3484,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate = () else: self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclauses) + _interpret_as_from(f) for f in fromclauses + ) @_generative def correlate_except(self, *fromclauses): @@ -3349,7 +3527,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate_except = () else: self._correlate_except = set(self._correlate_except or ()).union( - _interpret_as_from(f) for f in fromclauses) + _interpret_as_from(f) for f in fromclauses + ) def append_correlation(self, fromclause): """append the given correlation expression to this select() @@ -3363,7 +3542,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._auto_correlate = False self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclause) + _interpret_as_from(f) for f in fromclause + ) def append_column(self, column): """append the given column expression to the columns clause of this @@ -3415,8 +3595,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - self._whereclause = and_( - True_._ifnone(self._whereclause), whereclause) + self._whereclause = and_(True_._ifnone(self._whereclause), whereclause) def append_having(self, having): """append the given expression to this select() construct's HAVING @@ -3463,19 +3642,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): return [ name_for_col(c) - for c in util.unique_list( - _select_iterables(self._raw_columns)) + for c in util.unique_list(_select_iterables(self._raw_columns)) ] else: return [ (None, c) - for c in util.unique_list( - _select_iterables(self._raw_columns)) + for c in util.unique_list(_select_iterables(self._raw_columns)) ] def _populate_column_collection(self): for name, c in self._columns_plus_names: - if not hasattr(c, '_make_proxy'): + if not hasattr(c, "_make_proxy"): continue if name is None: key = None @@ -3486,9 +3663,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: key = None - c._make_proxy(self, key=key, - name=name, - name_is_truncatable=True) + c._make_proxy(self, key=key, name=name, 
name_is_truncatable=True) def _refresh_for_new_column(self, column): for fromclause in self._froms: @@ -3501,15 +3676,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self, name=col._label if self.use_labels else None, key=col._key_label if self.use_labels else None, - name_is_truncatable=True) + name_is_truncatable=True, + ) return None return None def _needs_parens_for_grouping(self): return ( - self._limit_clause is not None or - self._offset_clause is not None or - bool(self._order_by_clause.clauses) + self._limit_clause is not None + or self._offset_clause is not None + or bool(self._order_by_clause.clauses) ) def self_group(self, against=None): @@ -3521,8 +3697,10 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): expressions and should not require explicit use. """ - if isinstance(against, CompoundSelect) and \ - not self._needs_parens_for_grouping(): + if ( + isinstance(against, CompoundSelect) + and not self._needs_parens_for_grouping() + ): return self return FromGrouping(self) @@ -3586,6 +3764,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): def _set_bind(self, bind): self._bind = bind + bind = property(bind, _set_bind) @@ -3600,9 +3779,12 @@ class ScalarSelect(Generative, Grouping): @property def columns(self): - raise exc.InvalidRequestError('Scalar Select expression has no ' - 'columns; use this object directly ' - 'within a column-level expression.') + raise exc.InvalidRequestError( + "Scalar Select expression has no " + "columns; use this object directly " + "within a column-level expression." + ) + c = columns @_generative @@ -3621,6 +3803,7 @@ class Exists(UnaryExpression): """Represent an ``EXISTS`` clause. """ + __visit_name__ = UnaryExpression.__visit_name__ _from_objects = [] @@ -3646,12 +3829,16 @@ class Exists(UnaryExpression): s = args[0] else: if not args: - args = ([literal_column('*')],) + args = ([literal_column("*")],) s = Select(*args, **kwargs).as_scalar().self_group() - UnaryExpression.__init__(self, s, operator=operators.exists, - type_=type_api.BOOLEANTYPE, - wraps_column_expression=True) + UnaryExpression.__init__( + self, + s, + operator=operators.exists, + type_=type_api.BOOLEANTYPE, + wraps_column_expression=True, + ) def select(self, whereclause=None, **params): return Select([self], whereclause, **params) @@ -3706,6 +3893,7 @@ class TextAsFrom(SelectBase): :meth:`.TextClause.columns` """ + __visit_name__ = "text_as_from" _textual = True diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index c5708940b4..61fc6d3c96 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -15,10 +15,21 @@ import collections import json from . import elements -from .type_api import TypeEngine, TypeDecorator, to_instance, Variant, \ - Emulated, NativeForEmulated -from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \ - Slice, _literal_as_binds +from .type_api import ( + TypeEngine, + TypeDecorator, + to_instance, + Variant, + Emulated, + NativeForEmulated, +) +from .elements import ( + quoted_name, + TypeCoerce as type_coerce, + _defer_name, + Slice, + _literal_as_binds, +) from .. import exc, util, processors from .base import _bind_or_error, SchemaEventTarget from . 
import operators @@ -51,14 +62,15 @@ class _LookupExpressionAdapter(object): def _adapt_expression(self, op, other_comparator): othertype = other_comparator.type._type_affinity lookup = self.type._expression_adaptations.get( - op, self._blank_dict).get( - othertype, self.type) + op, self._blank_dict + ).get(othertype, self.type) if lookup is othertype: return (op, other_comparator.type) elif lookup is self.type._type_affinity: return (op, self.type) else: return (op, to_instance(lookup)) + comparator_factory = Comparator @@ -68,17 +80,16 @@ class Concatenable(object): typically strings.""" class Comparator(TypeEngine.Comparator): - def _adapt_expression(self, op, other_comparator): - if (op is operators.add and - isinstance( - other_comparator, - (Concatenable.Comparator, NullType.Comparator) - )): + if op is operators.add and isinstance( + other_comparator, + (Concatenable.Comparator, NullType.Comparator), + ): return operators.concat_op, self.expr.type else: return super(Concatenable.Comparator, self)._adapt_expression( - op, other_comparator) + op, other_comparator + ) comparator_factory = Comparator @@ -94,17 +105,15 @@ class Indexable(object): """ class Comparator(TypeEngine.Comparator): - def _setup_getitem(self, index): raise NotImplementedError() def __getitem__(self, index): - adjusted_op, adjusted_right_expr, result_type = \ - self._setup_getitem(index) + adjusted_op, adjusted_right_expr, result_type = self._setup_getitem( + index + ) return self.operate( - adjusted_op, - adjusted_right_expr, - result_type=result_type + adjusted_op, adjusted_right_expr, result_type=result_type ) comparator_factory = Comparator @@ -124,13 +133,16 @@ class String(Concatenable, TypeEngine): """ - __visit_name__ = 'string' + __visit_name__ = "string" - def __init__(self, length=None, collation=None, - convert_unicode=False, - unicode_error=None, - _warn_on_bytestring=False - ): + def __init__( + self, + length=None, + collation=None, + convert_unicode=False, + unicode_error=None, + _warn_on_bytestring=False, + ): """ Create a string-holding type. @@ -207,9 +219,10 @@ class String(Concatenable, TypeEngine): strings from a column with varied or corrupted encodings. """ - if unicode_error is not None and convert_unicode != 'force': - raise exc.ArgumentError("convert_unicode must be 'force' " - "when unicode_error is set.") + if unicode_error is not None and convert_unicode != "force": + raise exc.ArgumentError( + "convert_unicode must be 'force' " "when unicode_error is set." 
+ ) self.length = length self.collation = collation @@ -222,23 +235,29 @@ class String(Concatenable, TypeEngine): value = value.replace("'", "''") if dialect.identifier_preparer._double_percents: - value = value.replace('%', '%%') + value = value.replace("%", "%%") return "'%s'" % value + return process def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: - if dialect.supports_unicode_binds and \ - self.convert_unicode != 'force': + if ( + dialect.supports_unicode_binds + and self.convert_unicode != "force" + ): if self._warn_on_bytestring: + def process(value): if isinstance(value, util.binary_type): util.warn_limited( "Unicode type received non-unicode " "bind param value %r.", - (util.ellipses_string(value),)) + (util.ellipses_string(value),), + ) return value + return process else: return None @@ -253,29 +272,34 @@ class String(Concatenable, TypeEngine): util.warn_limited( "Unicode type received non-unicode bind " "param value %r.", - (util.ellipses_string(value),)) + (util.ellipses_string(value),), + ) return value + return process else: return None def result_processor(self, dialect, coltype): wants_unicode = self.convert_unicode or dialect.convert_unicode - needs_convert = wants_unicode and \ - (dialect.returns_unicode_strings is not True or - self.convert_unicode in ('force', 'force_nocheck')) + needs_convert = wants_unicode and ( + dialect.returns_unicode_strings is not True + or self.convert_unicode in ("force", "force_nocheck") + ) needs_isinstance = ( - needs_convert and - dialect.returns_unicode_strings and - self.convert_unicode != 'force_nocheck' + needs_convert + and dialect.returns_unicode_strings + and self.convert_unicode != "force_nocheck" ) if needs_convert: if needs_isinstance: return processors.to_conditional_unicode_processor_factory( - dialect.encoding, self.unicode_error) + dialect.encoding, self.unicode_error + ) else: return processors.to_unicode_processor_factory( - dialect.encoding, self.unicode_error) + dialect.encoding, self.unicode_error + ) else: return None @@ -301,7 +325,8 @@ class Text(String): argument here, it will be rejected by others. """ - __visit_name__ = 'text' + + __visit_name__ = "text" class Unicode(String): @@ -360,7 +385,7 @@ class Unicode(String): """ - __visit_name__ = 'unicode' + __visit_name__ = "unicode" def __init__(self, length=None, **kwargs): """ @@ -371,8 +396,8 @@ class Unicode(String): defaults to ``True``. """ - kwargs.setdefault('convert_unicode', True) - kwargs.setdefault('_warn_on_bytestring', True) + kwargs.setdefault("convert_unicode", True) + kwargs.setdefault("_warn_on_bytestring", True) super(Unicode, self).__init__(length=length, **kwargs) @@ -389,7 +414,7 @@ class UnicodeText(Text): """ - __visit_name__ = 'unicode_text' + __visit_name__ = "unicode_text" def __init__(self, length=None, **kwargs): """ @@ -400,8 +425,8 @@ class UnicodeText(Text): defaults to ``True``. 
""" - kwargs.setdefault('convert_unicode', True) - kwargs.setdefault('_warn_on_bytestring', True) + kwargs.setdefault("convert_unicode", True) + kwargs.setdefault("_warn_on_bytestring", True) super(UnicodeText, self).__init__(length=length, **kwargs) @@ -409,7 +434,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine): """A type for ``int`` integers.""" - __visit_name__ = 'integer' + __visit_name__ = "integer" def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -421,6 +446,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine): def literal_processor(self, dialect): def process(value): return str(value) + return process @util.memoized_property @@ -438,18 +464,9 @@ class Integer(_LookupExpressionAdapter, TypeEngine): Integer: self.__class__, Numeric: Numeric, }, - operators.div: { - Integer: self.__class__, - Numeric: Numeric, - }, - operators.truediv: { - Integer: self.__class__, - Numeric: Numeric, - }, - operators.sub: { - Integer: self.__class__, - Numeric: Numeric, - }, + operators.div: {Integer: self.__class__, Numeric: Numeric}, + operators.truediv: {Integer: self.__class__, Numeric: Numeric}, + operators.sub: {Integer: self.__class__, Numeric: Numeric}, } @@ -462,7 +479,7 @@ class SmallInteger(Integer): """ - __visit_name__ = 'small_integer' + __visit_name__ = "small_integer" class BigInteger(Integer): @@ -474,7 +491,7 @@ class BigInteger(Integer): """ - __visit_name__ = 'big_integer' + __visit_name__ = "big_integer" class Numeric(_LookupExpressionAdapter, TypeEngine): @@ -517,12 +534,17 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): """ - __visit_name__ = 'numeric' + __visit_name__ = "numeric" _default_decimal_return_scale = 10 - def __init__(self, precision=None, scale=None, - decimal_return_scale=None, asdecimal=True): + def __init__( + self, + precision=None, + scale=None, + decimal_return_scale=None, + asdecimal=True, + ): """ Construct a Numeric. @@ -587,6 +609,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): def literal_processor(self, dialect): def process(value): return str(value) + return process @property @@ -608,19 +631,23 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): # we're a "numeric", DBAPI will give us Decimal directly return None else: - util.warn('Dialect %s+%s does *not* support Decimal ' - 'objects natively, and SQLAlchemy must ' - 'convert from floating point - rounding ' - 'errors and other issues may occur. Please ' - 'consider storing Decimal numbers as strings ' - 'or integers on this platform for lossless ' - 'storage.' % (dialect.name, dialect.driver)) + util.warn( + "Dialect %s+%s does *not* support Decimal " + "objects natively, and SQLAlchemy must " + "convert from floating point - rounding " + "errors and other issues may occur. Please " + "consider storing Decimal numbers as strings " + "or integers on this platform for lossless " + "storage." % (dialect.name, dialect.driver) + ) # we're a "numeric", DBAPI returns floats, convert. 
return processors.to_decimal_processor_factory( decimal.Decimal, - self.scale if self.scale is not None - else self._default_decimal_return_scale) + self.scale + if self.scale is not None + else self._default_decimal_return_scale, + ) else: if dialect.supports_native_decimal: return processors.to_float @@ -635,22 +662,13 @@ class Numeric(_LookupExpressionAdapter, TypeEngine): Numeric: self.__class__, Integer: self.__class__, }, - operators.div: { - Numeric: self.__class__, - Integer: self.__class__, - }, + operators.div: {Numeric: self.__class__, Integer: self.__class__}, operators.truediv: { Numeric: self.__class__, Integer: self.__class__, }, - operators.add: { - Numeric: self.__class__, - Integer: self.__class__, - }, - operators.sub: { - Numeric: self.__class__, - Integer: self.__class__, - } + operators.add: {Numeric: self.__class__, Integer: self.__class__}, + operators.sub: {Numeric: self.__class__, Integer: self.__class__}, } @@ -675,12 +693,17 @@ class Float(Numeric): """ - __visit_name__ = 'float' + __visit_name__ = "float" scale = None - def __init__(self, precision=None, asdecimal=False, - decimal_return_scale=None, **kwargs): + def __init__( + self, + precision=None, + asdecimal=False, + decimal_return_scale=None, + **kwargs + ): r""" Construct a Float. @@ -713,14 +736,15 @@ class Float(Numeric): self.asdecimal = asdecimal self.decimal_return_scale = decimal_return_scale if kwargs: - util.warn_deprecated("Additional keyword arguments " - "passed to Float ignored.") + util.warn_deprecated( + "Additional keyword arguments " "passed to Float ignored." + ) def result_processor(self, dialect, coltype): if self.asdecimal: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif dialect.supports_native_decimal: return processors.to_float else: @@ -746,7 +770,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine): """ - __visit_name__ = 'datetime' + __visit_name__ = "datetime" def __init__(self, timezone=False): """Construct a new :class:`.DateTime`. @@ -777,13 +801,8 @@ class DateTime(_LookupExpressionAdapter, TypeEngine): # static/functions-datetime.html. return { - operators.add: { - Interval: self.__class__, - }, - operators.sub: { - Interval: self.__class__, - DateTime: Interval, - }, + operators.add: {Interval: self.__class__}, + operators.sub: {Interval: self.__class__, DateTime: Interval}, } @@ -791,7 +810,7 @@ class Date(_LookupExpressionAdapter, TypeEngine): """A type for ``datetime.date()`` objects.""" - __visit_name__ = 'date' + __visit_name__ = "date" def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -814,12 +833,9 @@ class Date(_LookupExpressionAdapter, TypeEngine): operators.sub: { # date - integer = date Integer: self.__class__, - # date - date = integer. Date: Integer, - Interval: DateTime, - # date - datetime = interval, # this one is not in the PG docs # but works @@ -832,7 +848,7 @@ class Time(_LookupExpressionAdapter, TypeEngine): """A type for ``datetime.time()`` objects.""" - __visit_name__ = 'time' + __visit_name__ = "time" def __init__(self, timezone=False): self.timezone = timezone @@ -850,14 +866,8 @@ class Time(_LookupExpressionAdapter, TypeEngine): # static/functions-datetime.html. 
return { - operators.add: { - Date: DateTime, - Interval: self.__class__ - }, - operators.sub: { - Time: Interval, - Interval: self.__class__, - }, + operators.add: {Date: DateTime, Interval: self.__class__}, + operators.sub: {Time: Interval, Interval: self.__class__}, } @@ -872,6 +882,7 @@ class _Binary(TypeEngine): def process(value): value = value.decode(dialect.encoding).replace("'", "''") return "'%s'" % value + return process @property @@ -891,14 +902,17 @@ class _Binary(TypeEngine): return DBAPIBinary(value) else: return None + return process # Python 3 has native bytes() type # both sqlite3 and pg8000 seem to return it, # psycopg2 as of 2.5 returns 'memoryview' if util.py2k: + def result_processor(self, dialect, coltype): if util.jython: + def process(value): if value is not None: if isinstance(value, array.array): @@ -906,15 +920,19 @@ class _Binary(TypeEngine): return str(value) else: return None + else: process = processors.to_str return process + else: + def result_processor(self, dialect, coltype): def process(value): if value is not None: value = bytes(value) return value + return process def coerce_compared_value(self, op, value): @@ -939,7 +957,7 @@ class LargeBinary(_Binary): """ - __visit_name__ = 'large_binary' + __visit_name__ = "large_binary" def __init__(self, length=None): """ @@ -958,8 +976,9 @@ class Binary(LargeBinary): """Deprecated. Renamed to LargeBinary.""" def __init__(self, *arg, **kw): - util.warn_deprecated('The Binary type has been renamed to ' - 'LargeBinary.') + util.warn_deprecated( + "The Binary type has been renamed to " "LargeBinary." + ) LargeBinary.__init__(self, *arg, **kw) @@ -986,8 +1005,15 @@ class SchemaType(SchemaEventTarget): """ - def __init__(self, name=None, schema=None, metadata=None, - inherit_schema=False, quote=None, _create_events=True): + def __init__( + self, + name=None, + schema=None, + metadata=None, + inherit_schema=False, + quote=None, + _create_events=True, + ): if name is not None: self.name = quoted_name(name, quote) else: @@ -1001,12 +1027,12 @@ class SchemaType(SchemaEventTarget): event.listen( self.metadata, "before_create", - util.portable_instancemethod(self._on_metadata_create) + util.portable_instancemethod(self._on_metadata_create), ) event.listen( self.metadata, "after_drop", - util.portable_instancemethod(self._on_metadata_drop) + util.portable_instancemethod(self._on_metadata_drop), ) def _translate_schema(self, effective_schema, map_): @@ -1018,7 +1044,7 @@ class SchemaType(SchemaEventTarget): def _variant_mapping_for_set_table(self, column): if isinstance(column.type, Variant): variant_mapping = column.type.mapping.copy() - variant_mapping['_default'] = column.type.impl + variant_mapping["_default"] = column.type.impl else: variant_mapping = None return variant_mapping @@ -1036,15 +1062,15 @@ class SchemaType(SchemaEventTarget): table, "before_create", util.portable_instancemethod( - self._on_table_create, - {"variant_mapping": variant_mapping}) + self._on_table_create, {"variant_mapping": variant_mapping} + ), ) event.listen( table, "after_drop", util.portable_instancemethod( - self._on_table_drop, - {"variant_mapping": variant_mapping}) + self._on_table_drop, {"variant_mapping": variant_mapping} + ), ) if self.metadata is None: # TODO: what's the difference between self.metadata @@ -1054,29 +1080,33 @@ class SchemaType(SchemaEventTarget): "before_create", util.portable_instancemethod( self._on_metadata_create, - {"variant_mapping": variant_mapping}) + {"variant_mapping": variant_mapping}, + ), ) 
event.listen( table.metadata, "after_drop", util.portable_instancemethod( self._on_metadata_drop, - {"variant_mapping": variant_mapping}) + {"variant_mapping": variant_mapping}, + ), ) def copy(self, **kw): return self.adapt(self.__class__, _create_events=True) def adapt(self, impltype, **kw): - schema = kw.pop('schema', self.schema) - metadata = kw.pop('metadata', self.metadata) - _create_events = kw.pop('_create_events', False) - return impltype(name=self.name, - schema=schema, - inherit_schema=self.inherit_schema, - metadata=metadata, - _create_events=_create_events, - **kw) + schema = kw.pop("schema", self.schema) + metadata = kw.pop("metadata", self.metadata) + _create_events = kw.pop("_create_events", False) + return impltype( + name=self.name, + schema=schema, + inherit_schema=self.inherit_schema, + metadata=metadata, + _create_events=_create_events, + **kw + ) @property def bind(self): @@ -1133,15 +1163,17 @@ class SchemaType(SchemaEventTarget): t._on_metadata_drop(target, bind, **kw) def _is_impl_for_variant(self, dialect, kw): - variant_mapping = kw.pop('variant_mapping', None) + variant_mapping = kw.pop("variant_mapping", None) if variant_mapping is None: return True - if dialect.name in variant_mapping and \ - variant_mapping[dialect.name] is self: + if ( + dialect.name in variant_mapping + and variant_mapping[dialect.name] is self + ): return True elif dialect.name not in variant_mapping: - return variant_mapping['_default'] is self + return variant_mapping["_default"] is self class Enum(Emulated, String, SchemaType): @@ -1220,7 +1252,8 @@ class Enum(Emulated, String, SchemaType): :class:`.mysql.ENUM` - MySQL-specific type """ - __visit_name__ = 'enum' + + __visit_name__ = "enum" def __init__(self, *enums, **kw): r"""Construct an enum. @@ -1322,15 +1355,15 @@ class Enum(Emulated, String, SchemaType): other arguments in kw to pass through. 
""" - self.native_enum = kw.pop('native_enum', True) - self.create_constraint = kw.pop('create_constraint', True) - self.values_callable = kw.pop('values_callable', None) + self.native_enum = kw.pop("native_enum", True) + self.create_constraint = kw.pop("create_constraint", True) + self.values_callable = kw.pop("values_callable", None) values, objects = self._parse_into_values(enums, kw) self._setup_for_values(values, objects, kw) - convert_unicode = kw.pop('convert_unicode', None) - self.validate_strings = kw.pop('validate_strings', False) + convert_unicode = kw.pop("convert_unicode", None) + self.validate_strings = kw.pop("validate_strings", False) if convert_unicode is None: for e in self.enums: @@ -1347,33 +1380,35 @@ class Enum(Emulated, String, SchemaType): self._valid_lookup[None] = self._object_lookup[None] = None super(Enum, self).__init__( - length=length, - convert_unicode=convert_unicode, + length=length, convert_unicode=convert_unicode ) if self.enum_class: - kw.setdefault('name', self.enum_class.__name__.lower()) + kw.setdefault("name", self.enum_class.__name__.lower()) SchemaType.__init__( self, - name=kw.pop('name', None), - schema=kw.pop('schema', None), - metadata=kw.pop('metadata', None), - inherit_schema=kw.pop('inherit_schema', False), - quote=kw.pop('quote', None), - _create_events=kw.pop('_create_events', True) + name=kw.pop("name", None), + schema=kw.pop("schema", None), + metadata=kw.pop("metadata", None), + inherit_schema=kw.pop("inherit_schema", False), + quote=kw.pop("quote", None), + _create_events=kw.pop("_create_events", True), ) def _parse_into_values(self, enums, kw): - if not enums and '_enums' in kw: - enums = kw.pop('_enums') + if not enums and "_enums" in kw: + enums = kw.pop("_enums") - if len(enums) == 1 and hasattr(enums[0], '__members__'): + if len(enums) == 1 and hasattr(enums[0], "__members__"): self.enum_class = enums[0] if self.values_callable: values = self.values_callable(self.enum_class) else: values = list(self.enum_class.__members__) - objects = [self.enum_class.__members__[k] for k in self.enum_class.__members__] + objects = [ + self.enum_class.__members__[k] + for k in self.enum_class.__members__ + ] return values, objects else: self.enum_class = None @@ -1382,18 +1417,16 @@ class Enum(Emulated, String, SchemaType): def _setup_for_values(self, values, objects, kw): self.enums = list(values) - self._valid_lookup = dict( - zip(reversed(objects), reversed(values)) - ) + self._valid_lookup = dict(zip(reversed(objects), reversed(values))) - self._object_lookup = dict( - zip(values, objects) - ) + self._object_lookup = dict(zip(values, objects)) - self._valid_lookup.update([ - (value, self._valid_lookup[self._object_lookup[value]]) - for value in values - ]) + self._valid_lookup.update( + [ + (value, self._valid_lookup[self._object_lookup[value]]) + for value in values + ] + ) @property def native(self): @@ -1411,22 +1444,24 @@ class Enum(Emulated, String, SchemaType): # here between an INSERT statement and a criteria used in a SELECT, # for now we're staying conservative w/ behavioral changes (perhaps # someone has a trigger that handles strings on INSERT) - if not self.validate_strings and \ - isinstance(elem, compat.string_types): + if not self.validate_strings and isinstance( + elem, compat.string_types + ): return elem else: raise LookupError( - '"%s" is not among the defined enum values' % elem) + '"%s" is not among the defined enum values' % elem + ) class Comparator(String.Comparator): - def _adapt_expression(self, op, 
other_comparator): op, typ = super(Enum.Comparator, self)._adapt_expression( - op, other_comparator) + op, other_comparator + ) if op is operators.concat_op: typ = String( - self.type.length, - convert_unicode=self.type.convert_unicode) + self.type.length, convert_unicode=self.type.convert_unicode + ) return op, typ comparator_factory = Comparator @@ -1436,38 +1471,40 @@ class Enum(Emulated, String, SchemaType): return self._object_lookup[elem] except KeyError: raise LookupError( - '"%s" is not among the defined enum values' % elem) + '"%s" is not among the defined enum values' % elem + ) def __repr__(self): return util.generic_repr( self, - additional_kw=[('native_enum', True)], + additional_kw=[("native_enum", True)], to_inspect=[Enum, SchemaType], ) def adapt_to_emulated(self, impltype, **kw): kw.setdefault("convert_unicode", self.convert_unicode) kw.setdefault("validate_strings", self.validate_strings) - kw.setdefault('name', self.name) - kw.setdefault('schema', self.schema) - kw.setdefault('inherit_schema', self.inherit_schema) - kw.setdefault('metadata', self.metadata) - kw.setdefault('_create_events', False) - kw.setdefault('native_enum', self.native_enum) - kw.setdefault('values_callable', self.values_callable) - kw.setdefault('create_constraint', self.create_constraint) - assert '_enums' in kw + kw.setdefault("name", self.name) + kw.setdefault("schema", self.schema) + kw.setdefault("inherit_schema", self.inherit_schema) + kw.setdefault("metadata", self.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("native_enum", self.native_enum) + kw.setdefault("values_callable", self.values_callable) + kw.setdefault("create_constraint", self.create_constraint) + assert "_enums" in kw return impltype(**kw) def adapt(self, impltype, **kw): - kw['_enums'] = self._enums_argument + kw["_enums"] = self._enums_argument return super(Enum, self).adapt(impltype, **kw) def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): return False - return not self.native_enum or \ - not compiler.dialect.supports_native_enum + return ( + not self.native_enum or not compiler.dialect.supports_native_enum + ) @util.dependencies("sqlalchemy.sql.schema") def _set_table(self, schema, column, table): @@ -1483,20 +1520,21 @@ class Enum(Emulated, String, SchemaType): name=_defer_name(self.name), _create_rule=util.portable_instancemethod( self._should_create_constraint, - {"variant_mapping": variant_mapping}), - _type_bound=True + {"variant_mapping": variant_mapping}, + ), + _type_bound=True, ) assert e.table is table def literal_processor(self, dialect): - parent_processor = super( - Enum, self).literal_processor(dialect) + parent_processor = super(Enum, self).literal_processor(dialect) def process(value): value = self._db_value_for_elem(value) if parent_processor: value = parent_processor(value) return value + return process def bind_processor(self, dialect): @@ -1510,8 +1548,7 @@ class Enum(Emulated, String, SchemaType): return process def result_processor(self, dialect, coltype): - parent_processor = super(Enum, self).result_processor( - dialect, coltype) + parent_processor = super(Enum, self).result_processor(dialect, coltype) def process(value): if parent_processor: @@ -1548,8 +1585,9 @@ class PickleType(TypeDecorator): impl = LargeBinary - def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, - pickler=None, comparator=None): + def __init__( + self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, comparator=None + ): """ Construct a PickleType. 
@@ -1570,40 +1608,46 @@ class PickleType(TypeDecorator): super(PickleType, self).__init__() def __reduce__(self): - return PickleType, (self.protocol, - None, - self.comparator) + return PickleType, (self.protocol, None, self.comparator) def bind_processor(self, dialect): impl_processor = self.impl.bind_processor(dialect) dumps = self.pickler.dumps protocol = self.protocol if impl_processor: + def process(value): if value is not None: value = dumps(value, protocol) return impl_processor(value) + else: + def process(value): if value is not None: value = dumps(value, protocol) return value + return process def result_processor(self, dialect, coltype): impl_processor = self.impl.result_processor(dialect, coltype) loads = self.pickler.loads if impl_processor: + def process(value): value = impl_processor(value) if value is None: return None return loads(value) + else: + def process(value): if value is None: return None return loads(value) + return process def compare_values(self, x, y): @@ -1635,11 +1679,10 @@ class Boolean(Emulated, TypeEngine, SchemaType): """ - __visit_name__ = 'boolean' + __visit_name__ = "boolean" native = True - def __init__( - self, create_constraint=True, name=None, _create_events=True): + def __init__(self, create_constraint=True, name=None, _create_events=True): """Construct a Boolean. :param create_constraint: defaults to True. If the boolean @@ -1657,8 +1700,10 @@ class Boolean(Emulated, TypeEngine, SchemaType): def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): return False - return not compiler.dialect.supports_native_boolean and \ - compiler.dialect.non_native_boolean_check_constraint + return ( + not compiler.dialect.supports_native_boolean + and compiler.dialect.non_native_boolean_check_constraint + ) @util.dependencies("sqlalchemy.sql.schema") def _set_table(self, schema, column, table): @@ -1672,8 +1717,9 @@ class Boolean(Emulated, TypeEngine, SchemaType): name=_defer_name(self.name), _create_rule=util.portable_instancemethod( self._should_create_constraint, - {"variant_mapping": variant_mapping}), - _type_bound=True + {"variant_mapping": variant_mapping}, + ), + _type_bound=True, ) assert e.table is table @@ -1686,11 +1732,11 @@ class Boolean(Emulated, TypeEngine, SchemaType): def _strict_as_bool(self, value): if value not in self._strict_bools: if not isinstance(value, int): - raise TypeError( - "Not a boolean value: %r" % value) + raise TypeError("Not a boolean value: %r" % value) else: raise ValueError( - "Value %r is not None, True, or False" % value) + "Value %r is not None, True, or False" % value + ) return value def literal_processor(self, dialect): @@ -1700,6 +1746,7 @@ class Boolean(Emulated, TypeEngine, SchemaType): def process(value): return true if self._strict_as_bool(value) else false + return process def bind_processor(self, dialect): @@ -1714,6 +1761,7 @@ class Boolean(Emulated, TypeEngine, SchemaType): if value is not None: value = _coerce(value) return value + return process def result_processor(self, dialect, coltype): @@ -1736,18 +1784,10 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine): DateTime: DateTime, Time: Time, }, - operators.sub: { - Interval: self.__class__ - }, - operators.mul: { - Numeric: self.__class__ - }, - operators.truediv: { - Numeric: self.__class__ - }, - operators.div: { - Numeric: self.__class__ - } + operators.sub: {Interval: self.__class__}, + operators.mul: {Numeric: self.__class__}, + operators.truediv: {Numeric: self.__class__}, + 
operators.div: {Numeric: self.__class__}, } @property @@ -1780,9 +1820,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator): impl = DateTime epoch = dt.datetime.utcfromtimestamp(0) - def __init__(self, native=True, - second_precision=None, - day_precision=None): + def __init__(self, native=True, second_precision=None, day_precision=None): """Construct an Interval object. :param native: when True, use the actual @@ -1815,31 +1853,39 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator): impl_processor = self.impl.bind_processor(dialect) epoch = self.epoch if impl_processor: + def process(value): if value is not None: value = epoch + value return impl_processor(value) + else: + def process(value): if value is not None: value = epoch + value return value + return process def result_processor(self, dialect, coltype): impl_processor = self.impl.result_processor(dialect, coltype) epoch = self.epoch if impl_processor: + def process(value): value = impl_processor(value) if value is None: return None return value - epoch + else: + def process(value): if value is None: return None return value - epoch + return process @@ -1986,10 +2032,11 @@ class JSON(Indexable, TypeEngine): """ - __visit_name__ = 'JSON' + + __visit_name__ = "JSON" hashable = False - NULL = util.symbol('JSON_NULL') + NULL = util.symbol("JSON_NULL") """Describe the json value of NULL. This value is used to force the JSON value of ``"null"`` to be @@ -2109,20 +2156,25 @@ class JSON(Indexable, TypeEngine): class Comparator(Indexable.Comparator, Concatenable.Comparator): """Define comparison operations for :class:`.types.JSON`.""" - @util.dependencies('sqlalchemy.sql.default_comparator') + @util.dependencies("sqlalchemy.sql.default_comparator") def _setup_getitem(self, default_comparator, index): - if not isinstance(index, util.string_types) and \ - isinstance(index, compat.collections_abc.Sequence): + if not isinstance(index, util.string_types) and isinstance( + index, compat.collections_abc.Sequence + ): index = default_comparator._check_literal( - self.expr, operators.json_path_getitem_op, - index, bindparam_type=JSON.JSONPathType + self.expr, + operators.json_path_getitem_op, + index, + bindparam_type=JSON.JSONPathType, ) operator = operators.json_path_getitem_op else: index = default_comparator._check_literal( - self.expr, operators.json_getitem_op, - index, bindparam_type=JSON.JSONIndexType + self.expr, + operators.json_getitem_op, + index, + bindparam_type=JSON.JSONIndexType, ) operator = operators.json_getitem_op @@ -2172,6 +2224,7 @@ class JSON(Indexable, TypeEngine): if string_process: value = string_process(value) return json_deserializer(value) + return process @@ -2266,7 +2319,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): :class:`.postgresql.ARRAY` """ - __visit_name__ = 'ARRAY' + + __visit_name__ = "ARRAY" zero_indexes = False """if True, Python zero-based indexes should be interpreted as one-based @@ -2285,21 +2339,23 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if isinstance(index, slice): return_type = self.type if self.type.zero_indexes: - index = slice( - index.start + 1, - index.stop + 1, - index.step - ) + index = slice(index.start + 1, index.stop + 1, index.step) index = Slice( _literal_as_binds( - index.start, name=self.expr.key, - type_=type_api.INTEGERTYPE), + index.start, + name=self.expr.key, + type_=type_api.INTEGERTYPE, + ), _literal_as_binds( - index.stop, name=self.expr.key, - type_=type_api.INTEGERTYPE), + index.stop, + 
name=self.expr.key, + type_=type_api.INTEGERTYPE, + ), _literal_as_binds( - index.step, name=self.expr.key, - type_=type_api.INTEGERTYPE) + index.step, + name=self.expr.key, + type_=type_api.INTEGERTYPE, + ), ) else: if self.type.zero_indexes: @@ -2307,16 +2363,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if self.type.dimensions is None or self.type.dimensions == 1: return_type = self.type.item_type else: - adapt_kw = {'dimensions': self.type.dimensions - 1} + adapt_kw = {"dimensions": self.type.dimensions - 1} return_type = self.type.adapt( - self.type.__class__, **adapt_kw) + self.type.__class__, **adapt_kw + ) return operators.getitem, index, return_type def contains(self, *arg, **kw): raise NotImplementedError( "ARRAY.contains() not implemented for the base " - "ARRAY type; please use the dialect-specific ARRAY type") + "ARRAY type; please use the dialect-specific ARRAY type" + ) @util.dependencies("sqlalchemy.sql.elements") def any(self, elements, other, operator=None): @@ -2350,7 +2408,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): operator = operator if operator else operators.eq return operator( elements._literal_as_binds(other), - elements.CollectionAggregate._create_any(self.expr) + elements.CollectionAggregate._create_any(self.expr), ) @util.dependencies("sqlalchemy.sql.elements") @@ -2385,13 +2443,14 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): operator = operator if operator else operators.eq return operator( elements._literal_as_binds(other), - elements.CollectionAggregate._create_all(self.expr) + elements.CollectionAggregate._create_all(self.expr), ) comparator_factory = Comparator - def __init__(self, item_type, as_tuple=False, dimensions=None, - zero_indexes=False): + def __init__( + self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + ): """Construct an :class:`.types.ARRAY`. E.g.:: @@ -2424,8 +2483,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ if isinstance(item_type, ARRAY): - raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype") + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) if isinstance(item_type, type): item_type = item_type() self.item_type = item_type @@ -2463,35 +2524,37 @@ class REAL(Float): """The SQL REAL type.""" - __visit_name__ = 'REAL' + __visit_name__ = "REAL" class FLOAT(Float): """The SQL FLOAT type.""" - __visit_name__ = 'FLOAT' + __visit_name__ = "FLOAT" class NUMERIC(Numeric): """The SQL NUMERIC type.""" - __visit_name__ = 'NUMERIC' + __visit_name__ = "NUMERIC" class DECIMAL(Numeric): """The SQL DECIMAL type.""" - __visit_name__ = 'DECIMAL' + __visit_name__ = "DECIMAL" class INTEGER(Integer): """The SQL INT or INTEGER type.""" - __visit_name__ = 'INTEGER' + __visit_name__ = "INTEGER" + + INT = INTEGER @@ -2499,14 +2562,14 @@ class SMALLINT(SmallInteger): """The SQL SMALLINT type.""" - __visit_name__ = 'SMALLINT' + __visit_name__ = "SMALLINT" class BIGINT(BigInteger): """The SQL BIGINT type.""" - __visit_name__ = 'BIGINT' + __visit_name__ = "BIGINT" class TIMESTAMP(DateTime): @@ -2520,7 +2583,7 @@ class TIMESTAMP(DateTime): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" def __init__(self, timezone=False): """Construct a new :class:`.TIMESTAMP`. 
@@ -2543,28 +2606,28 @@ class DATETIME(DateTime): """The SQL DATETIME type.""" - __visit_name__ = 'DATETIME' + __visit_name__ = "DATETIME" class DATE(Date): """The SQL DATE type.""" - __visit_name__ = 'DATE' + __visit_name__ = "DATE" class TIME(Time): """The SQL TIME type.""" - __visit_name__ = 'TIME' + __visit_name__ = "TIME" class TEXT(Text): """The SQL TEXT type.""" - __visit_name__ = 'TEXT' + __visit_name__ = "TEXT" class CLOB(Text): @@ -2574,63 +2637,63 @@ class CLOB(Text): This type is found in Oracle and Informix. """ - __visit_name__ = 'CLOB' + __visit_name__ = "CLOB" class VARCHAR(String): """The SQL VARCHAR type.""" - __visit_name__ = 'VARCHAR' + __visit_name__ = "VARCHAR" class NVARCHAR(Unicode): """The SQL NVARCHAR type.""" - __visit_name__ = 'NVARCHAR' + __visit_name__ = "NVARCHAR" class CHAR(String): """The SQL CHAR type.""" - __visit_name__ = 'CHAR' + __visit_name__ = "CHAR" class NCHAR(Unicode): """The SQL NCHAR type.""" - __visit_name__ = 'NCHAR' + __visit_name__ = "NCHAR" class BLOB(LargeBinary): """The SQL BLOB type.""" - __visit_name__ = 'BLOB' + __visit_name__ = "BLOB" class BINARY(_Binary): """The SQL BINARY type.""" - __visit_name__ = 'BINARY' + __visit_name__ = "BINARY" class VARBINARY(_Binary): """The SQL VARBINARY type.""" - __visit_name__ = 'VARBINARY' + __visit_name__ = "VARBINARY" class BOOLEAN(Boolean): """The SQL BOOLEAN type.""" - __visit_name__ = 'BOOLEAN' + __visit_name__ = "BOOLEAN" class NullType(TypeEngine): @@ -2657,7 +2720,8 @@ class NullType(TypeEngine): construct. """ - __visit_name__ = 'null' + + __visit_name__ = "null" _isnull = True @@ -2666,16 +2730,18 @@ class NullType(TypeEngine): def literal_processor(self, dialect): def process(value): return "NULL" + return process class Comparator(TypeEngine.Comparator): - def _adapt_expression(self, op, other_comparator): - if isinstance(other_comparator, NullType.Comparator) or \ - not operators.is_commutative(op): + if isinstance( + other_comparator, NullType.Comparator + ) or not operators.is_commutative(op): return op, self.expr.type else: return other_comparator._adapt_expression(op, self) + comparator_factory = Comparator @@ -2694,6 +2760,7 @@ class MatchType(Boolean): """ + NULLTYPE = NullType() BOOLEANTYPE = Boolean() STRINGTYPE = String() @@ -2709,7 +2776,7 @@ _type_map = { dt.datetime: DateTime(), dt.time: Time(), dt.timedelta: Interval(), - util.NoneType: NULLTYPE + util.NoneType: NULLTYPE, } if util.py3k: @@ -2729,19 +2796,23 @@ def _resolve_value_to_type(value): # objects. insp = inspection.inspect(value, False) if ( - insp is not None and - # foil mock.Mock() and other impostors by ensuring - # the inspection target itself self-inspects - insp.__class__ in inspection._registrars + insp is not None + and + # foil mock.Mock() and other impostors by ensuring + # the inspection target itself self-inspects + insp.__class__ in inspection._registrars ): raise exc.ArgumentError( - "Object %r is not legal as a SQL literal value" % value) + "Object %r is not legal as a SQL literal value" % value + ) return NULLTYPE else: return _result_type + # back-assign to type_api from . 
import type_api + type_api.BOOLEANTYPE = BOOLEANTYPE type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index a8dfa19be7..7fe7807832 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -49,7 +49,8 @@ class TypeEngine(Visitable): """ - __slots__ = 'expr', 'type' + + __slots__ = "expr", "type" default_comparator = None @@ -57,16 +58,15 @@ class TypeEngine(Visitable): self.expr = expr self.type = expr.type - @util.dependencies('sqlalchemy.sql.default_comparator') + @util.dependencies("sqlalchemy.sql.default_comparator") def operate(self, default_comparator, op, *other, **kwargs): o = default_comparator.operator_lookup[op.__name__] return o[0](self.expr, op, *(other + o[1:]), **kwargs) - @util.dependencies('sqlalchemy.sql.default_comparator') + @util.dependencies("sqlalchemy.sql.default_comparator") def reverse_operate(self, default_comparator, op, other, **kwargs): o = default_comparator.operator_lookup[op.__name__] - return o[0](self.expr, op, other, - reverse=True, *o[1:], **kwargs) + return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs) def _adapt_expression(self, op, other_comparator): """evaluate the return type of , @@ -97,7 +97,7 @@ class TypeEngine(Visitable): return op, self.type def __reduce__(self): - return _reconstitute_comparator, (self.expr, ) + return _reconstitute_comparator, (self.expr,) hashable = True """Flag, if False, means values from this type aren't hashable. @@ -313,8 +313,10 @@ class TypeEngine(Visitable): """ - return self.__class__.column_expression.__code__ \ + return ( + self.__class__.column_expression.__code__ is not TypeEngine.column_expression.__code__ + ) def bind_expression(self, bindvalue): """"Given a bind value (i.e. a :class:`.BindParameter` instance), @@ -351,8 +353,10 @@ class TypeEngine(Visitable): """ - return self.__class__.bind_expression.__code__ \ + return ( + self.__class__.bind_expression.__code__ is not TypeEngine.bind_expression.__code__ + ) @staticmethod def _to_instance(cls_or_self): @@ -441,9 +445,9 @@ class TypeEngine(Visitable): """ try: - return dialect._type_memos[self]['impl'] + return dialect._type_memos[self]["impl"] except KeyError: - return self._dialect_info(dialect)['impl'] + return self._dialect_info(dialect)["impl"] def _unwrapped_dialect_impl(self, dialect): """Return the 'unwrapped' dialect impl for this type. @@ -462,20 +466,20 @@ class TypeEngine(Visitable): def _cached_literal_processor(self, dialect): """Return a dialect-specific literal processor for this type.""" try: - return dialect._type_memos[self]['literal'] + return dialect._type_memos[self]["literal"] except KeyError: d = self._dialect_info(dialect) - d['literal'] = lp = d['impl'].literal_processor(dialect) + d["literal"] = lp = d["impl"].literal_processor(dialect) return lp def _cached_bind_processor(self, dialect): """Return a dialect-specific bind processor for this type.""" try: - return dialect._type_memos[self]['bind'] + return dialect._type_memos[self]["bind"] except KeyError: d = self._dialect_info(dialect) - d['bind'] = bp = d['impl'].bind_processor(dialect) + d["bind"] = bp = d["impl"].bind_processor(dialect) return bp def _cached_result_processor(self, dialect, coltype): @@ -488,7 +492,7 @@ class TypeEngine(Visitable): # key assumption: DBAPI type codes are # constants. Else this dictionary would # grow unbounded. 
- d[coltype] = rp = d['impl'].result_processor(dialect, coltype) + d[coltype] = rp = d["impl"].result_processor(dialect, coltype) return rp def _cached_custom_processor(self, dialect, key, fn): @@ -496,7 +500,7 @@ class TypeEngine(Visitable): return dialect._type_memos[self][key] except KeyError: d = self._dialect_info(dialect) - impl = d['impl'] + impl = d["impl"] d[key] = result = fn(impl) return result @@ -513,7 +517,7 @@ class TypeEngine(Visitable): impl = self.adapt(type(self)) # this can't be self, else we create a cycle assert impl is not self - dialect._type_memos[self] = d = {'impl': impl} + dialect._type_memos[self] = d = {"impl": impl} return d def _gen_dialect_impl(self, dialect): @@ -549,8 +553,10 @@ class TypeEngine(Visitable): """ _coerced_type = _resolve_value_to_type(value) - if _coerced_type is NULLTYPE or _coerced_type._type_affinity \ - is self._type_affinity: + if ( + _coerced_type is NULLTYPE + or _coerced_type._type_affinity is self._type_affinity + ): return self else: return _coerced_type @@ -586,8 +592,7 @@ class TypeEngine(Visitable): def __str__(self): if util.py2k: - return unicode(self.compile()).\ - encode('ascii', 'backslashreplace') + return unicode(self.compile()).encode("ascii", "backslashreplace") else: return str(self.compile()) @@ -645,15 +650,16 @@ class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)): ``type_expression``, if it receives ``**kw`` in its signature. """ + __visit_name__ = "user_defined" - ensure_kwarg = 'get_col_spec' + ensure_kwarg = "get_col_spec" class Comparator(TypeEngine.Comparator): __slots__ = () def _adapt_expression(self, op, other_comparator): - if hasattr(self.type, 'adapt_operator'): + if hasattr(self.type, "adapt_operator"): util.warn_deprecated( "UserDefinedType.adapt_operator is deprecated. Create " "a UserDefinedType.Comparator subclass instead which " @@ -854,6 +860,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): will cause the index value ``'foo'`` to be JSON encoded. """ + __visit_name__ = "type_decorator" def __init__(self, *args, **kwargs): @@ -874,14 +881,16 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ - if not hasattr(self.__class__, 'impl'): - raise AssertionError("TypeDecorator implementations " - "require a class-level variable " - "'impl' which refers to the class of " - "type being decorated") + if not hasattr(self.__class__, "impl"): + raise AssertionError( + "TypeDecorator implementations " + "require a class-level variable " + "'impl' which refers to the class of " + "type being decorated" + ) self.impl = to_instance(self.__class__.impl, *args, **kwargs) - coerce_to_is_types = (util.NoneType, ) + coerce_to_is_types = (util.NoneType,) """Specify those Python types which should be coerced at the expression level to "IS " when compared using ``==`` (and same for ``IS NOT`` in conjunction with ``!=``. 
@@ -906,24 +915,27 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): __slots__ = () def operate(self, op, *other, **kwargs): - kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types + kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super(TypeDecorator.Comparator, self).operate( - op, *other, **kwargs) + op, *other, **kwargs + ) def reverse_operate(self, op, other, **kwargs): - kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types + kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super(TypeDecorator.Comparator, self).reverse_operate( - op, other, **kwargs) + op, other, **kwargs + ) @property def comparator_factory(self): if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: return self.impl.comparator_factory else: - return type("TDComparator", - (TypeDecorator.Comparator, - self.impl.comparator_factory), - {}) + return type( + "TDComparator", + (TypeDecorator.Comparator, self.impl.comparator_factory), + {}, + ) def _gen_dialect_impl(self, dialect): """ @@ -939,10 +951,11 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): typedesc = self._unwrapped_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): - raise AssertionError('Type object %s does not properly ' - 'implement the copy() method, it must ' - 'return an object of type %s' % - (self, self.__class__)) + raise AssertionError( + "Type object %s does not properly " + "implement the copy() method, it must " + "return an object of type %s" % (self, self.__class__) + ) tt.impl = typedesc return tt @@ -1099,8 +1112,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ - return self.__class__.process_bind_param.__code__ \ + return ( + self.__class__.process_bind_param.__code__ is not TypeDecorator.process_bind_param.__code__ + ) @util.memoized_property def _has_literal_processor(self): @@ -1109,8 +1124,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ - return self.__class__.process_literal_param.__code__ \ + return ( + self.__class__.process_literal_param.__code__ is not TypeDecorator.process_literal_param.__code__ + ) def literal_processor(self, dialect): """Provide a literal processing function for the given @@ -1147,9 +1164,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): if process_param: impl_processor = self.impl.literal_processor(dialect) if impl_processor: + def process(value): return impl_processor(process_param(value, dialect)) + else: + def process(value): return process_param(value, dialect) @@ -1180,10 +1200,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): process_param = self.process_bind_param impl_processor = self.impl.bind_processor(dialect) if impl_processor: + def process(value): return impl_processor(process_param(value, dialect)) else: + def process(value): return process_param(value, dialect) @@ -1200,8 +1222,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): exception throw. 
""" - return self.__class__.process_result_value.__code__ \ + return ( + self.__class__.process_result_value.__code__ is not TypeDecorator.process_result_value.__code__ + ) def result_processor(self, dialect, coltype): """Provide a result value processing function for the given @@ -1225,13 +1249,14 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ if self._has_result_processor: process_value = self.process_result_value - impl_processor = self.impl.result_processor(dialect, - coltype) + impl_processor = self.impl.result_processor(dialect, coltype) if impl_processor: + def process(value): return process_value(impl_processor(value), dialect) else: + def process(value): return process_value(value, dialect) @@ -1397,7 +1422,8 @@ class Variant(TypeDecorator): if dialect_name in self.mapping: raise exc.ArgumentError( "Dialect '%s' is already present in " - "the mapping for this Variant" % dialect_name) + "the mapping for this Variant" % dialect_name + ) mapping = self.mapping.copy() mapping[dialect_name] = type_ return Variant(self.impl, mapping) @@ -1439,6 +1465,6 @@ def adapt_type(typeobj, colspecs): # but it turns out the originally given "generic" type # is actually a subclass of our resulting type, then we were already # given a more specific type than that required; so use that. - if (issubclass(typeobj.__class__, impltype)): + if issubclass(typeobj.__class__, impltype): return typeobj return typeobj.adapt(impltype) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 12cfe09d15..4feaf99389 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -15,15 +15,29 @@ from . import operators, visitors from itertools import chain from collections import deque -from .elements import BindParameter, ColumnClause, ColumnElement, \ - Null, UnaryExpression, literal_column, Label, _label_reference, \ - _textual_label_reference -from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping +from .elements import ( + BindParameter, + ColumnClause, + ColumnElement, + Null, + UnaryExpression, + literal_column, + Label, + _label_reference, + _textual_label_reference, +) +from .selectable import ( + SelectBase, + ScalarSelect, + Join, + FromClause, + FromGrouping, +) from .schema import Column join_condition = util.langhelpers.public_factory( - Join._join_condition, - ".sql.util.join_condition") + Join._join_condition, ".sql.util.join_condition" +) # names that are still being imported from the outside from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate @@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from): for idx in liberal_idx: f = clauses[idx] for s in selectables: - if set(surface_selectables(f)).\ - intersection(surface_selectables(s)): + if set(surface_selectables(f)).intersection( + surface_selectables(s) + ): conservative_idx.append(idx) break if conservative_idx: @@ -184,8 +199,9 @@ def visit_binary_product(fn, expr): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element - elif element.__visit_name__ == 'binary' and \ - operators.is_comparison(element.operator): + elif element.__visit_name__ == "binary" and operators.is_comparison( + element.operator + ): stack.insert(0, element) for l in visit(element.left): for r in visit(element.right): @@ -199,38 +215,47 @@ def visit_binary_product(fn, expr): for elem in element.get_children(): for e in visit(elem): yield e + list(visit(expr)) -def find_tables(clause, check_columns=False, 
- include_aliases=False, include_joins=False, - include_selects=False, include_crud=False): +def find_tables( + clause, + check_columns=False, + include_aliases=False, + include_joins=False, + include_selects=False, + include_crud=False, +): """locate Table objects within the given expression.""" tables = [] _visitors = {} if include_selects: - _visitors['select'] = _visitors['compound_select'] = tables.append + _visitors["select"] = _visitors["compound_select"] = tables.append if include_joins: - _visitors['join'] = tables.append + _visitors["join"] = tables.append if include_aliases: - _visitors['alias'] = tables.append + _visitors["alias"] = tables.append if include_crud: - _visitors['insert'] = _visitors['update'] = \ - _visitors['delete'] = lambda ent: tables.append(ent.table) + _visitors["insert"] = _visitors["update"] = _visitors[ + "delete" + ] = lambda ent: tables.append(ent.table) if check_columns: + def visit_column(column): tables.append(column.table) - _visitors['column'] = visit_column - _visitors['table'] = tables.append + _visitors["column"] = visit_column - visitors.traverse(clause, {'column_collections': False}, _visitors) + _visitors["table"] = tables.append + + visitors.traverse(clause, {"column_collections": False}, _visitors) return tables @@ -243,10 +268,9 @@ def unwrap_order_by(clause): stack = deque([clause]) while stack: t = stack.popleft() - if isinstance(t, ColumnElement) and \ - ( - not isinstance(t, UnaryExpression) or - not operators.is_ordering_modifier(t.modifier) + if isinstance(t, ColumnElement) and ( + not isinstance(t, UnaryExpression) + or not operators.is_ordering_modifier(t.modifier) ): if isinstance(t, _label_reference): t = t.element @@ -266,9 +290,7 @@ def unwrap_label_reference(element): if isinstance(elem, (_label_reference, _textual_label_reference)): return elem.element - return visitors.replacement_traverse( - element, {}, replace - ) + return visitors.replacement_traverse(element, {}, replace) def expand_column_list_from_order_by(collist, order_by): @@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. 
""" - cols_already_present = set([ - col.element if col._order_by_label_element is not None - else col for col in collist - ]) + cols_already_present = set( + [ + col.element if col._order_by_label_element is not None else col + for col in collist + ] + ) return [ - col for col in - chain(*[ - unwrap_order_by(o) - for o in order_by - ]) + col + for col in chain(*[unwrap_order_by(o) for o in order_by]) if col not in cols_already_present ] @@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True): be addressable in the WHERE clause of a SELECT if this element were in the columns clause.""" - filter_ = (FromGrouping, ) + filter_ = (FromGrouping,) if not include_scalar_selects: - filter_ += (SelectBase, ) + filter_ += (SelectBase,) stack = deque([clause]) while stack: @@ -343,9 +364,7 @@ def selectables_overlap(left, right): """Return True if left/right have some overlapping selectable""" return bool( - set(surface_selectables(left)).intersection( - surface_selectables(right) - ) + set(surface_selectables(left)).intersection(surface_selectables(right)) ) @@ -366,7 +385,7 @@ def bind_values(clause): def visit_bindparam(bind): v.append(bind.effective_value) - visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) + visitors.traverse(clause, {}, {"bindparam": visit_bindparam}) return v @@ -383,7 +402,7 @@ class _repr_base(object): _TUPLE = 1 _DICT = 2 - __slots__ = 'max_chars', + __slots__ = ("max_chars",) def trunc(self, value): rep = repr(value) @@ -391,10 +410,12 @@ class _repr_base(object): if lenrep > self.max_chars: segment_length = self.max_chars // 2 rep = ( - rep[0:segment_length] + - (" ... (%d characters truncated) ... " - % (lenrep - self.max_chars)) + - rep[-segment_length:] + rep[0:segment_length] + + ( + " ... (%d characters truncated) ... " + % (lenrep - self.max_chars) + ) + + rep[-segment_length:] ) return rep @@ -402,7 +423,7 @@ class _repr_base(object): class _repr_row(_repr_base): """Provide a string view of a row.""" - __slots__ = 'row', + __slots__ = ("row",) def __init__(self, row, max_chars=300): self.row = row @@ -412,7 +433,7 @@ class _repr_row(_repr_base): trunc = self.trunc return "(%s%s)" % ( ", ".join(trunc(value) for value in self.row), - "," if len(self.row) == 1 else "" + "," if len(self.row) == 1 else "", ) @@ -424,7 +445,7 @@ class _repr_params(_repr_base): """ - __slots__ = 'params', 'batches', + __slots__ = "params", "batches" def __init__(self, params, batches, max_chars=300): self.params = params @@ -435,11 +456,13 @@ class _repr_params(_repr_base): if isinstance(self.params, list): typ = self._LIST ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, tuple): typ = self._TUPLE ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, dict): typ = self._DICT ismulti = False @@ -448,11 +471,15 @@ class _repr_params(_repr_base): if ismulti and len(self.params) > self.batches: msg = " ... displaying %i of %i total bound parameter sets ... 
" - return ' '.join(( - self._repr_multi(self.params[:self.batches - 2], typ)[0:-1], - msg % (self.batches, len(self.params)), - self._repr_multi(self.params[-2:], typ)[1:] - )) + return " ".join( + ( + self._repr_multi(self.params[: self.batches - 2], typ)[ + 0:-1 + ], + msg % (self.batches, len(self.params)), + self._repr_multi(self.params[-2:], typ)[1:], + ) + ) elif ismulti: return self._repr_multi(self.params, typ) else: @@ -467,12 +494,13 @@ class _repr_params(_repr_base): elif isinstance(multi_params[0], dict): elem_type = self._DICT else: - assert False, \ - "Unknown parameter type %s" % (type(multi_params[0])) + assert False, "Unknown parameter type %s" % ( + type(multi_params[0]) + ) elements = ", ".join( - self._repr_params(params, elem_type) - for params in multi_params) + self._repr_params(params, elem_type) for params in multi_params + ) else: elements = "" @@ -493,13 +521,10 @@ class _repr_params(_repr_base): elif typ is self._TUPLE: return "(%s%s)" % ( ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "" - + "," if len(params) == 1 else "", ) else: - return "[%s]" % ( - ", ".join(trunc(value) for value in params) - ) + return "[%s]" % (", ".join(trunc(value) for value in params)) def adapt_criterion_to_null(crit, nulls): @@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls): """ def visit_binary(binary): - if isinstance(binary.left, BindParameter) \ - and binary.left._identifying_key in nulls: + if ( + isinstance(binary.left, BindParameter) + and binary.left._identifying_key in nulls + ): # reverse order if the NULL is on the left side binary.left = binary.right binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - elif isinstance(binary.right, BindParameter) \ - and binary.right._identifying_key in nulls: + elif ( + isinstance(binary.right, BindParameter) + and binary.right._identifying_key in nulls + ): binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) + return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) def splice_joins(left, right, stop_on=None): @@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw): in the selectable to just those that are not repeated. 
""" - ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - only_synonyms = kw.pop('only_synonyms', False) + ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) + only_synonyms = kw.pop("only_synonyms", False) columns = util.ordered_column_set(columns) @@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw): continue else: raise - if fk_col.shares_lineage(c) and \ - (not only_synonyms or - c.name == col.name): + if fk_col.shares_lineage(c) and ( + not only_synonyms or c.name == col.name + ): omit.add(col) break if clauses: + def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)])) + chain(*[c.proxy_set for c in columns.difference(omit)]) + ) if binary.left in cols and binary.right in cols: for c in reversed(columns): - if c.shares_lineage(binary.right) and \ - (not only_synonyms or - c.name == binary.left.name): + if c.shares_lineage(binary.right) and ( + not only_synonyms or c.name == binary.left.name + ): omit.add(c) break + for clause in clauses: if clause is not None: - visitors.traverse(clause, {}, {'binary': visit_binary}) + visitors.traverse(clause, {}, {"binary": visit_binary}) return ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, - consider_as_referenced_keys=None, any_operator=False): +def criterion_as_pairs( + expression, + consider_as_foreign_keys=None, + consider_as_referenced_keys=None, + any_operator=False, +): """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exc.ArgumentError("Can only specify one of " - "'consider_as_foreign_keys' or " - "'consider_as_referenced_keys'") + raise exc.ArgumentError( + "Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'" + ) def col_is(a, b): # return a is b @@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return - if not isinstance(binary.left, ColumnElement) or \ - not isinstance(binary.right, ColumnElement): + if not isinstance(binary.left, ColumnElement) or not isinstance( + binary.right, ColumnElement + ): return if consider_as_foreign_keys: - if binary.left in consider_as_foreign_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_foreign_keys): + if binary.left in consider_as_foreign_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_foreign_keys + ): pairs.append((binary.right, binary.left)) - elif binary.right in consider_as_foreign_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_foreign_keys): + elif binary.right in consider_as_foreign_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_foreign_keys + ): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: - if binary.left in consider_as_referenced_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_referenced_keys): + if binary.left in consider_as_referenced_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_referenced_keys + ): pairs.append((binary.left, binary.right)) - elif binary.right in consider_as_referenced_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_referenced_keys): + elif 
binary.right in consider_as_referenced_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_referenced_keys + ): pairs.append((binary.right, binary.left)) else: - if isinstance(binary.left, Column) and \ - isinstance(binary.right, Column): + if isinstance(binary.left, Column) and isinstance( + binary.right, Column + ): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) + pairs = [] - visitors.traverse(expression, {}, {'binary': visit_binary}) + visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs @@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): """ - def __init__(self, selectable, equivalents=None, - include_fn=None, exclude_fn=None, - adapt_on_names=False, anonymize_labels=False): + def __init__( + self, + selectable, + equivalents=None, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + anonymize_labels=False, + ): self.__traverse_options__ = { - 'stop_on': [selectable], - 'anonymize_labels': anonymize_labels} + "stop_on": [selectable], + "anonymize_labels": anonymize_labels, + } self.selectable = selectable self.include_fn = include_fn self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names - def _corresponding_column(self, col, require_embedded, - _seen=util.EMPTY_SET): + def _corresponding_column( + self, col, require_embedded, _seen=util.EMPTY_SET + ): newcol = self.selectable.corresponding_column( - col, - require_embedded=require_embedded) + col, require_embedded=require_embedded + ) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: newcol = self._corresponding_column( - equiv, require_embedded=require_embedded, - _seen=_seen.union([col])) + equiv, + require_embedded=require_embedded, + _seen=_seen.union([col]), + ) if newcol is not None: return newcol if self.adapt_on_names and newcol is None: @@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, FromClause) and \ - self.selectable.is_derived_from(col): + if isinstance(col, FromClause) and self.selectable.is_derived_from( + col + ): return self.selectable elif not isinstance(col, ColumnElement): return None @@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter): """ - def __init__(self, selectable, equivalents=None, - chain_to=None, adapt_required=False, - include_fn=None, exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False): - ClauseAdapter.__init__(self, selectable, equivalents, - include_fn=include_fn, exclude_fn=exclude_fn, - adapt_on_names=adapt_on_names, - anonymize_labels=anonymize_labels) + def __init__( + self, + selectable, + equivalents=None, + chain_to=None, + adapt_required=False, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + allow_label_resolve=True, + anonymize_labels=False, + ): + ClauseAdapter.__init__( + self, + selectable, + equivalents, + include_fn=include_fn, + exclude_fn=exclude_fn, + adapt_on_names=adapt_on_names, + anonymize_labels=anonymize_labels, + ) if chain_to: self.chain(chain_to) @@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter): def __getitem__(self, key): if ( self.parent.include_fn and not self.parent.include_fn(key) - ) or ( - self.parent.exclude_fn and self.parent.exclude_fn(key) - ): + ) or (self.parent.exclude_fn and 
self.parent.exclude_fn(key)): if self.parent._wrap: return self.parent._wrap.columns[key] else: @@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter): def __getstate__(self): d = self.__dict__.copy() - del d['columns'] + del d["columns"] return d def __setstate__(self, state): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index b39ec8167a..bf17436436 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -29,11 +29,20 @@ from .. import util import operator from .. import exc -__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', - 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', - 'iterate_depthfirst', 'traverse_using', 'traverse', - 'traverse_depthfirst', - 'cloned_traverse', 'replacement_traverse'] +__all__ = [ + "VisitableType", + "Visitable", + "ClauseVisitor", + "CloningVisitor", + "ReplacingCloningVisitor", + "iterate", + "iterate_depthfirst", + "traverse_using", + "traverse", + "traverse_depthfirst", + "cloned_traverse", + "replacement_traverse", +] class VisitableType(type): @@ -53,8 +62,7 @@ class VisitableType(type): """ def __init__(cls, clsname, bases, clsdict): - if clsname != 'Visitable' and \ - hasattr(cls, '__visit_name__'): + if clsname != "Visitable" and hasattr(cls, "__visit_name__"): _generate_dispatch(cls) super(VisitableType, cls).__init__(clsname, bases, clsdict) @@ -64,7 +72,7 @@ def _generate_dispatch(cls): """Return an optimized visit dispatch function for the cls for use by the compiler. """ - if '__visit_name__' in cls.__dict__: + if "__visit_name__" in cls.__dict__: visit_name = cls.__visit_name__ if isinstance(visit_name, str): # There is an optimization opportunity here because the @@ -79,12 +87,13 @@ def _generate_dispatch(cls): raise exc.UnsupportedCompilationError(visitor, cls) else: return meth(self, **kw) + else: # The optimization opportunity is lost for this case because the # __visit_name__ is not yet a string. As a result, the visit # string has to be recalculated with each compilation. def _compiler_dispatch(self, visitor, **kw): - visit_attr = 'visit_%s' % self.__visit_name__ + visit_attr = "visit_%s" % self.__visit_name__ try: meth = getattr(visitor, visit_attr) except AttributeError: @@ -92,8 +101,7 @@ def _generate_dispatch(cls): else: return meth(self, **kw) - _compiler_dispatch.__doc__ = \ - """Look for an attribute named "visit_" + self.__visit_name__ + _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ on the visitor, and call it with the same kw params. """ cls._compiler_dispatch = _compiler_dispatch @@ -137,7 +145,7 @@ class ClauseVisitor(object): visitors = {} for name in dir(self): - if name.startswith('visit_'): + if name.startswith("visit_"): visitors[name[6:]] = getattr(self, name) return visitors @@ -148,7 +156,7 @@ class ClauseVisitor(object): v = self while v: yield v - v = getattr(v, '_next', None) + v = getattr(v, "_next", None) def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. 
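
The ``visit_%s`` dispatch machinery reformatted in this file is what
drives the ``visitors.traverse(clause, opts, visitors)`` calls seen
throughout ``sql/util.py`` above: the third argument is a dict keyed by
each element's ``__visit_name__``. A minimal sketch of that pattern using
the 1.x-era public imports (the toy expression and callback are
illustrative):

    from sqlalchemy.sql import column, visitors

    expr = column("x") == 5

    seen = []

    def visit_binary(binary):
        # dispatched because BinaryExpression.__visit_name__ is "binary"
        seen.append(binary.operator)

    visitors.traverse(expr, {}, {"binary": visit_binary})
    assert len(seen) == 1
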
@@ -178,7 +186,8 @@ class CloningVisitor(ClauseVisitor): """traverse and visit the given expression structure.""" return cloned_traverse( - obj, self.__traverse_options__, self._visitor_dict) + obj, self.__traverse_options__, self._visitor_dict + ) class ReplacingCloningVisitor(CloningVisitor): @@ -204,6 +213,7 @@ class ReplacingCloningVisitor(CloningVisitor): e = v.replace(elem) if e is not None: return e + return replacement_traverse(obj, self.__traverse_options__, replace) @@ -282,7 +292,7 @@ def cloned_traverse(obj, opts, visitors): modifications by visitors.""" cloned = {} - stop_on = set(opts.get('stop_on', [])) + stop_on = set(opts.get("stop_on", [])) def clone(elem): if elem in stop_on: @@ -306,11 +316,13 @@ def replacement_traverse(obj, opts, replace): replacement by a given replacement function.""" cloned = {} - stop_on = {id(x) for x in opts.get('stop_on', [])} + stop_on = {id(x) for x in opts.get("stop_on", [])} def clone(elem, **kw): - if id(elem) in stop_on or \ - 'no_replacement_traverse' in elem._annotations: + if ( + id(elem) in stop_on + or "no_replacement_traverse" in elem._annotations + ): return elem else: newelem = replace(elem) diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 413a492b83..f46ca4528a 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -10,23 +10,62 @@ from .warnings import assert_warnings from . import config -from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ - fails_on, fails_on_everything_except, skip, only_on, exclude, \ - against as _against, _server_version, only_if, fails +from .exclusions import ( + db_spec, + _is_excluded, + fails_if, + skip_if, + future, + fails_on, + fails_on_everything_except, + skip, + only_on, + exclude, + against as _against, + _server_version, + only_if, + fails, +) def against(*queries): return _against(config._current, *queries) -from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ - eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \ - assert_raises_message, AssertsCompiledSQL, ComparesTables, \ - AssertsExecutionResults, expect_deprecated, expect_warnings, \ - in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false -from .util import run_as_contextmanager, rowset, fail, \ - provide_metadata, adict, force_drop_names, \ - teardown_events +from .assertions import ( + emits_warning, + emits_warning_on, + uses_deprecated, + eq_, + ne_, + le_, + is_, + is_not_, + startswith_, + assert_raises, + assert_raises_message, + AssertsCompiledSQL, + ComparesTables, + AssertsExecutionResults, + expect_deprecated, + expect_warnings, + in_, + not_in_, + eq_ignore_whitespace, + eq_regex, + is_true, + is_false, +) + +from .util import ( + run_as_contextmanager, + rowset, + fail, + provide_metadata, + adict, + force_drop_names, + teardown_events, +) crashes = skip diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e42376921a..73ab4556af 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -86,6 +86,7 @@ def emits_warning_on(db, *messages): were in fact seen. 
""" + @decorator def decorate(fn, *args, **kw): with expect_warnings_on(db, assert_=False, *messages): @@ -114,12 +115,14 @@ def uses_deprecated(*messages): def decorate(fn, *args, **kw): with expect_deprecated(*messages, assert_=False): return fn(*args, **kw) + return decorate @contextlib.contextmanager -def _expect_warnings(exc_cls, messages, regex=True, assert_=True, - py2konly=False): +def _expect_warnings( + exc_cls, messages, regex=True, assert_=True, py2konly=False +): if regex: filters = [re.compile(msg, re.I | re.S) for msg in messages] @@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, return for filter_ in filters: - if (regex and filter_.match(msg)) or \ - (not regex and filter_ == msg): + if (regex and filter_.match(msg)) or ( + not regex and filter_ == msg + ): seen.discard(filter_) break else: @@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, yield if assert_ and (not py2konly or not compat.py3k): - assert not seen, "Warnings were not seen: %s" % \ - ", ".join("%r" % (s.pattern if regex else s) for s in seen) + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) def global_cleanup_assertions(): @@ -170,6 +175,7 @@ def global_cleanup_assertions(): """ _assert_no_stray_pool_connections() + _STRAY_CONNECTION_FAILURES = 0 @@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections(): # OK, let's be somewhat forgiving. _STRAY_CONNECTION_FAILURES += 1 - print("Encountered a stray connection in test cleanup: %s" - % str(pool._refs)) + print( + "Encountered a stray connection in test cleanup: %s" + % str(pool._refs) + ) # then do a real GC sweep. We shouldn't even be here # so a single sweep should really be doing it, otherwise # there's probably a real unreachable cycle somewhere. 
@@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections(): pool._refs.clear() _STRAY_CONNECTION_FAILURES = 0 warnings.warn( - "Stray connection refused to leave " - "after gc.collect(): %s" % err) + "Stray connection refused to leave " "after gc.collect(): %s" % err + ) elif _STRAY_CONNECTION_FAILURES > 10: assert False, "Encountered more than 10 stray connections" _STRAY_CONNECTION_FAILURES = 0 @@ -263,14 +271,16 @@ def not_in_(a, b, msg=None): def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, fragment) + a, + fragment, + ) def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r'^\s+?|\n', "", a) - a = re.sub(r' {2,}', " ", a) - b = re.sub(r'^\s+?|\n', "", b) - b = re.sub(r' {2,}', " ", b) + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) assert a == b, msg or "%r != %r" % (a, b) @@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search( - msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) - print(util.text_type(e).encode('utf-8')) + assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( + msg, + e, + ) + print(util.text_type(e).encode("utf-8")) + class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, - checkparams=None, dialect=None, - checkpositional=None, - check_prefetch=None, - use_default_dialect=False, - allow_dialect_select=False, - literal_binds=False, - schema_translate_map=None): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False, + schema_translate_map=None, + ): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: dialect = None else: if dialect is None: - dialect = getattr(self, '__dialect__', None) + dialect = getattr(self, "__dialect__", None) if dialect is None: dialect = config.db.dialect - elif dialect == 'default': + elif dialect == "default": dialect = default.DefaultDialect() - elif dialect == 'default_enhanced': + elif dialect == "default_enhanced": dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() @@ -325,13 +344,13 @@ class AssertsCompiledSQL(object): compile_kwargs = {} if schema_translate_map: - kw['schema_translate_map'] = schema_translate_map + kw["schema_translate_map"] = schema_translate_map if params is not None: - kw['column_keys'] = list(params) + kw["column_keys"] = list(params) if literal_binds: - compile_kwargs['literal_binds'] = True + compile_kwargs["literal_binds"] = True if isinstance(clause, orm.Query): context = clause._compile_context() @@ -343,25 +362,27 @@ class AssertsCompiledSQL(object): clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: - kw['compile_kwargs'] = compile_kwargs + kw["compile_kwargs"] = compile_kwargs c = clause.compile(dialect=dialect, **kw) - param_str = repr(getattr(c, 'params', {})) + param_str = repr(getattr(c, "params", {})) if util.py3k: - param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + param_str = param_str.encode("utf-8").decode("ascii", "ignore") print( - ("\nSQL String:\n" + - 
util.text_type(c) + - param_str).encode('utf-8')) + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) else: print( - "\nSQL String:\n" + - util.text_type(c).encode('utf-8') + - param_str) + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) - cc = re.sub(r'[\n\t]', '', util.text_type(c)) + cc = re.sub(r"[\n\t]", "", util.text_type(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -375,7 +396,6 @@ class AssertsCompiledSQL(object): class ComparesTables(object): - def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -386,8 +406,10 @@ class ComparesTables(object): if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" - assert isinstance(reflected_c.type, type(c.type)), \ - msg % (reflected_c.type, c.type) + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) else: self.assert_types_base(reflected_c, c) @@ -396,20 +418,22 @@ class ComparesTables(object): eq_( {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys} + {f.column.name for f in reflected_c.foreign_keys}, ) if c.server_default: - assert isinstance(reflected_c.server_default, - schema.FetchedValue) + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): - assert c1.type._compare_type_affinity(c2.type),\ - "On column %r, type '%s' doesn't correspond to type '%s'" % \ - (c1.name, c1.type, c2.type) + assert c1.type._compare_type_affinity(c2.type), ( + "On column %r, type '%s' doesn't correspond to type '%s'" + % (c1.name, c1.type, c2.type) + ) class AssertsExecutionResults(object): @@ -419,15 +443,19 @@ class AssertsExecutionResults(object): self.assert_list(result, class_, objects) def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), - "result list is not the same size as test list, " + - "for class " + class_.__name__) + self.assert_( + len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) for i in range(0, len(list)): self.assert_row(class_, result[i], list[i]) def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, - "item class is not " + repr(class_)) + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) for key, value in desc.items(): if isinstance(value, tuple): if isinstance(value[1], list): @@ -435,9 +463,11 @@ class AssertsExecutionResults(object): else: self.assert_row(value[0], getattr(rowobj, key), value[1]) else: - self.assert_(getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" % ( - key, getattr(rowobj, key), value)) + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) def assert_unordered_result(self, result, cls, *expected): """As assert_result, but the order of objects is not considered. 
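
``assert_compile`` above ultimately compiles the clause against the
selected dialect, strips newlines and tabs, and compares strings. Reduced
to its core via the documented ``compile_kwargs`` hook (the toy table and
expected SQL are illustrative):

    import re

    from sqlalchemy.engine import default
    from sqlalchemy.sql import column, select, table

    t = table("t", column("x"))
    stmt = select([t.c.x]).where(t.c.x == 5)

    compiled = stmt.compile(
        dialect=default.DefaultDialect(),
        compile_kwargs={"literal_binds": True},
    )
    # the same r"[\n\t]" normalization assert_compile applies
    cc = re.sub(r"[\n\t]", "", str(compiled))
    assert cc == "SELECT t.x FROM t WHERE t.x = 5", cc
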
@@ -453,14 +483,19 @@ class AssertsExecutionResults(object): found = util.IdentitySet(result) expected = {immutabledict(e) for e in expected} - for wrong in util.itertools_filterfalse(lambda o: - isinstance(o, cls), found): - fail('Unexpected type "%s", expected "%s"' % ( - type(wrong).__name__, cls.__name__)) + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) if len(found) != len(expected): - fail('Unexpected object count "%s", expected "%s"' % ( - len(found), len(expected))) + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) NOVALUE = object() @@ -469,7 +504,8 @@ class AssertsExecutionResults(object): if isinstance(value, tuple): try: self.assert_unordered_result( - getattr(obj, key), value[0], *value[1]) + getattr(obj, key), value[0], *value[1] + ) except AssertionError: return False else: @@ -484,8 +520,9 @@ class AssertsExecutionResults(object): break else: fail( - "Expected %s instance with attributes %s not found." % ( - cls.__name__, repr(expected_item))) + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) return True def sql_execution_asserter(self, db=None): @@ -505,9 +542,9 @@ class AssertsExecutionResults(object): newrules = [] for rule in rules: if isinstance(rule, dict): - newrule = assertsql.AllOf(*[ - assertsql.CompiledSQL(k, v) for k, v in rule.items() - ]) + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) else: newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) @@ -516,7 +553,8 @@ class AssertsExecutionResults(object): def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( - db, callable_, assertsql.CountStatements(count)) + db, callable_, assertsql.CountStatements(count) + ) def assert_multiple_sql_count(self, dbs, callable_, counts): recs = [ diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 7a525589df..d8e924cb61 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -26,8 +26,10 @@ class AssertRule(object): pass def no_more_statements(self): - assert False, 'All statements are complete, but pending '\ - 'assertion rules remain' + assert False, ( + "All statements are complete, but pending " + "assertion rules remain" + ) class SQLMatchRule(AssertRule): @@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule): def process_statement(self, execute_observed): stmt = execute_observed.statements[0] if self.statement != stmt.statement or ( - self.params is not None and self.params != stmt.parameters): - self.errormessage = \ - "Testing for exact SQL %s parameters %s received %s %s" % ( - self.statement, self.params, - stmt.statement, stmt.parameters + self.params is not None and self.params != stmt.parameters + ): + self.errormessage = ( + "Testing for exact SQL %s parameters %s received %s %s" + % ( + self.statement, + self.params, + stmt.statement, + stmt.parameters, ) + ) else: execute_observed.statements.pop(0) self.is_consumed = True @@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - - def __init__(self, statement, params=None, dialect='default'): + def __init__(self, statement, params=None, dialect="default"): self.statement = statement self.params = params self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', 
self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - if self.dialect == 'default': + if self.dialect == "default": return DefaultDialect() else: # ugh - if self.dialect == 'postgresql': - params = {'implicit_returning': True} + if self.dialect == "postgresql": + params = {"implicit_returning": True} else: params = {} return url.URL(self.dialect).get_dialect()(**params) @@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule): context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): - compiled = \ - context.compiled.statement.compile( - dialect=compare_dialect, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), + ) else: - compiled = ( - context.compiled.statement.compile( - dialect=compare_dialect, - column_keys=context.compiled.column_keys, - inline=context.compiled.inline, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + inline=context.compiled.inline, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), ) - _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) + _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: _received_parameters = [compiled.construct_params()] else: _received_parameters = [ - compiled.construct_params(m) for m in parameters] + compiled.construct_params(m) for m in parameters + ] return _received_statement, _received_parameters def process_statement(self, execute_observed): context = execute_observed.context - _received_statement, _received_parameters = \ - self._received_statement(execute_observed) + _received_statement, _received_parameters = self._received_statement( + execute_observed + ) params = self._all_params(context) equivalent = self._compare_sql(execute_observed, _received_statement) @@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule): for param_key in param: # a key in param did not match current # 'received' - if param_key not in received or \ - received[param_key] != param[param_key]: + if ( + param_key not in received + or received[param_key] != param[param_key] + ): break else: # all keys in param matched 'received'; @@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule): self.errormessage = None else: self.errormessage = self._failure_message(params) % { - 'received_statement': _received_statement, - 'received_parameters': _received_parameters + "received_statement": _received_statement, + "received_parameters": _received_parameters, } def _all_params(self, context): @@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule): def _failure_message(self, expected_params): return ( - 'Testing for compiled statement %r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.statement.replace('%', '%%'), expected_params - ) + "Testing for compiled statement %r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" + % (self.statement.replace("%", "%%"), expected_params) ) @@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL): 
self.regex = re.compile(regex) self.orig_regex = regex self.params = params - self.dialect = 'default' + self.dialect = "default" def _failure_message(self, expected_params): return ( - 'Testing for compiled statement ~%r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.orig_regex, expected_params - ) + "Testing for compiled statement ~%r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" % (self.orig_regex, expected_params) ) def _compare_sql(self, execute_observed, received_statement): @@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL): return execute_observed.context.dialect def _compare_no_space(self, real_stmt, received_stmt): - stmt = re.sub(r'[\n\t]', '', real_stmt) + stmt = re.sub(r"[\n\t]", "", real_stmt) return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super(DialectSQL, self).\ - _received_statement(execute_observed) + received_stmt, received_params = super( + DialectSQL, self + )._received_statement(execute_observed) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL): else: raise AssertionError( "Can't locate compiled statement %r in list of " - "statements actually invoked" % received_stmt) + "statements actually invoked" % received_stmt + ) return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle - if paramstyle == 'pyformat': - stmt = re.sub( - r':([\w_]+)', r"%(\1)s", stmt) + if paramstyle == "pyformat": + stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) else: # positional params repl = None - if paramstyle == 'qmark': + if paramstyle == "qmark": repl = "?" 
- elif paramstyle == 'format': + elif paramstyle == "format": repl = r"%s" - elif paramstyle == 'numeric': + elif paramstyle == "numeric": repl = None - stmt = re.sub(r':([\w_]+)', repl, stmt) + stmt = re.sub(r":([\w_]+)", repl, stmt) return received_statement == stmt class CountStatements(AssertRule): - def __init__(self, count): self.count = count self._statement_count = 0 @@ -256,12 +264,13 @@ class CountStatements(AssertRule): def no_more_statements(self): if self.count != self._statement_count: - assert False, 'desired statement count %d does not match %d' \ - % (self.count, self._statement_count) + assert False, "desired statement count %d does not match %d" % ( + self.count, + self._statement_count, + ) class AllOf(AssertRule): - def __init__(self, *rules): self.rules = set(rules) @@ -283,7 +292,6 @@ class AllOf(AssertRule): class EachOf(AssertRule): - def __init__(self, *rules): self.rules = list(rules) @@ -309,7 +317,6 @@ class EachOf(AssertRule): class Or(AllOf): - def process_statement(self, execute_observed): for rule in self.rules: rule.process_statement(execute_observed) @@ -331,7 +338,8 @@ class SQLExecuteObserved(object): class SQLCursorExecuteObserved( collections.namedtuple( "SQLCursorExecuteObserved", - ["statement", "parameters", "context", "executemany"]) + ["statement", "parameters", "context", "executemany"], + ) ): pass @@ -374,21 +382,25 @@ def assert_engine(engine): orig[:] = clauseelement, multiparams, params @event.listens_for(engine, "after_cursor_execute") - def cursor_execute(conn, cursor, statement, parameters, - context, executemany): + def cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): if not context: return # then grab real cursor statements and associate them all # around a single context - if asserter.accumulated and \ - asserter.accumulated[-1].context is context: + if ( + asserter.accumulated + and asserter.accumulated[-1].context is context + ): obs = asserter.accumulated[-1] else: obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) asserter.accumulated.append(obs) obs.statements.append( SQLCursorExecuteObserved( - statement, parameters, context, executemany) + statement, parameters, context, executemany + ) ) try: diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index e9cfb3de93..1ff282af59 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -64,8 +64,9 @@ class Config(object): assert _current, "Can't push without a default Config set up" cls.push( Config( - db, _current.db_opts, _current.options, _current.file_config), - namespace + db, _current.db_opts, _current.options, _current.file_config + ), + namespace, ) @classmethod @@ -94,4 +95,3 @@ class Config(object): def skip_test(msg): raise _skip_test_exception(msg) - diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d17e30edff..074e3b3387 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -16,7 +16,6 @@ import warnings class ConnectionKiller(object): - def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() self.testing_engines = weakref.WeakKeyDictionary() @@ -39,8 +38,8 @@ class ConnectionKiller(object): fn() except Exception as e: warnings.warn( - "testing_reaper couldn't " - "rollback/close connection: %s" % e) + "testing_reaper couldn't " "rollback/close connection: %s" % e + ) def rollback_all(self): for rec in list(self.proxy_refs): @@ -97,18 +96,19 @@ class ConnectionKiller(object): if 
rec.is_valid: assert False + testing_reaper = ConnectionKiller() def drop_all_tables(metadata, bind): testing_reaper.close_all() - if hasattr(bind, 'close'): + if hasattr(bind, "close"): bind.close() if not config.db.dialect.supports_alter: from . import assertions - with assertions.expect_warnings( - "Can't sort tables", assert_=False): + + with assertions.expect_warnings("Can't sort tables", assert_=False): metadata.drop_all(bind) else: metadata.drop_all(bind) @@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw): def all_dialects(exclude=None): import sqlalchemy.databases as d + for name in d.__all__: # TEMPORARY if exclude and name in exclude: continue mod = getattr(d, name, None) if not mod: - mod = getattr(__import__( - 'sqlalchemy.databases.%s' % name).databases, name) + mod = getattr( + __import__("sqlalchemy.databases.%s" % name).databases, name + ) yield mod.dialect() class ReconnectFixture(object): - def __init__(self, dbapi): self.dbapi = dbapi self.connections = [] @@ -191,8 +192,8 @@ class ReconnectFixture(object): fn() except Exception as e: warnings.warn( - "ReconnectFixture couldn't " - "close connection: %s" % e) + "ReconnectFixture couldn't " "close connection: %s" % e + ) def shutdown(self, stop=False): # TODO: this doesn't cover all cases @@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None): dbapi = config.db.dialect.dbapi if not options: options = {} - options['module'] = ReconnectFixture(dbapi) + options["module"] = ReconnectFixture(dbapi) engine = testing_engine(url, options) _dispose = engine.dispose @@ -238,7 +239,7 @@ def testing_engine(url=None, options=None): if not options: use_reaper = True else: - use_reaper = options.pop('use_reaper', True) + use_reaper = options.pop("use_reaper", True) url = url or config.db.url @@ -253,15 +254,15 @@ def testing_engine(url=None, options=None): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 if use_reaper: - event.listen(engine.pool, 'connect', testing_reaper.connect) - event.listen(engine.pool, 'checkout', testing_reaper.checkout) - event.listen(engine.pool, 'invalidate', testing_reaper.invalidate) + event.listen(engine.pool, "connect", testing_reaper.connect) + event.listen(engine.pool, "checkout", testing_reaper.checkout) + event.listen(engine.pool, "invalidate", testing_reaper.invalidate) testing_reaper.add_engine(engine) return engine @@ -290,19 +291,17 @@ def mock_engine(dialect_name=None): buffer.append(sql) def assert_sql(stmts): - recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] assert recv == stmts, recv def print_sql(): d = engine.dialect - return "\n".join( - str(s.compile(dialect=d)) - for s in engine.mock - ) - - engine = create_engine(dialect_name + '://', - strategy='mock', executor=executor) - assert not hasattr(engine, 'mock') + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_engine( + dialect_name + "://", strategy="mock", executor=executor + ) + assert not hasattr(engine, "mock") engine.mock = buffer engine.assert_sql = assert_sql engine.print_sql = print_sql @@ -358,14 +357,15 @@ class DBAPIProxyConnection(object): return getattr(self.conn, key) -def proxying_engine(conn_cls=DBAPIProxyConnection, - 
cursor_cls=DBAPIProxyCursor): +def proxying_engine( + conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor +): """Produce an engine that provides proxy hooks for common methods. """ + def mock_conn(): return conn_cls(config.db, cursor_cls) - return testing_engine(options={'creator': mock_conn}) - + return testing_engine(options={"creator": mock_conn}) diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index b634735fc0..42c42149ca 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -12,7 +12,6 @@ _repr_stack = set() class BasicEntity(object): - def __init__(self, **kw): for key, value in kw.items(): setattr(self, key, value) @@ -24,17 +23,22 @@ class BasicEntity(object): try: return "%s(%s)" % ( (self.__class__.__name__), - ', '.join(["%s=%r" % (key, getattr(self, key)) - for key in sorted(self.__dict__.keys()) - if not key.startswith('_')])) + ", ".join( + [ + "%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith("_") + ] + ), + ) finally: _repr_stack.remove(id(self)) + _recursion_stack = set() class ComparableEntity(BasicEntity): - def __hash__(self): return hash(self.__class__) @@ -75,7 +79,7 @@ class ComparableEntity(BasicEntity): b = other for attr in list(a.__dict__): - if attr.startswith('_'): + if attr.startswith("_"): continue value = getattr(a, attr) @@ -85,9 +89,10 @@ class ComparableEntity(BasicEntity): except (AttributeError, sa_exc.UnboundExecutionError): return False - if hasattr(value, '__iter__'): - if hasattr(value, '__getitem__') and not hasattr( - value, 'keys'): + if hasattr(value, "__iter__"): + if hasattr(value, "__getitem__") and not hasattr( + value, "keys" + ): if list(value) != list(battr): return False else: diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 512fffb3bf..9ed9e42c3a 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -16,6 +16,7 @@ from . import config from .. 
import util from ..util import decorator + def skip_if(predicate, reason=None): rule = compound() pred = _as_predicate(predicate, reason) @@ -70,15 +71,15 @@ class compound(object): def matching_config_reasons(self, config): return [ - predicate._as_string(config) for predicate - in self.skips.union(self.fails) + predicate._as_string(config) + for predicate in self.skips.union(self.fails) if predicate(config) ] def include_test(self, include_tags, exclude_tags): return bool( - not self.tags.intersection(exclude_tags) and - (not include_tags or self.tags.intersection(include_tags)) + not self.tags.intersection(exclude_tags) + and (not include_tags or self.tags.intersection(include_tags)) ) def _extend(self, other): @@ -87,13 +88,14 @@ class compound(object): self.tags.update(other.tags) def __call__(self, fn): - if hasattr(fn, '_sa_exclusion_extend'): + if hasattr(fn, "_sa_exclusion_extend"): fn._sa_exclusion_extend._extend(self) return fn @decorator def decorate(fn, *args, **kw): return self._do(config._current, fn, *args, **kw) + decorated = decorate(fn) decorated._sa_exclusion_extend = self return decorated @@ -113,10 +115,7 @@ class compound(object): def _do(self, cfg, fn, *args, **kw): for skip in self.skips: if skip(cfg): - msg = "'%s' : %s" % ( - fn.__name__, - skip._as_string(cfg) - ) + msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg)) config.skip_test(msg) try: @@ -127,16 +126,20 @@ class compound(object): self._expect_success(cfg, name=fn.__name__) return return_value - def _expect_failure(self, config, ex, name='block'): + def _expect_failure(self, config, ex, name="block"): for fail in self.fails: if fail(config): - print(("%s failed as expected (%s): %s " % ( - name, fail._as_string(config), str(ex)))) + print( + ( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), str(ex)) + ) + ) break else: util.raise_from_cause(ex) - def _expect_success(self, config, name='block'): + def _expect_success(self, config, name="block"): if not self.fails: return for fail in self.fails: @@ -144,13 +147,12 @@ class compound(object): break else: raise AssertionError( - "Unexpected success for '%s' (%s)" % - ( + "Unexpected success for '%s' (%s)" + % ( name, " and ".join( - fail._as_string(config) - for fail in self.fails - ) + fail._as_string(config) for fail in self.fails + ), ) ) @@ -186,21 +188,24 @@ class Predicate(object): return predicate elif isinstance(predicate, (list, set)): return OrPredicate( - [cls.as_predicate(pred) for pred in predicate], - description) + [cls.as_predicate(pred) for pred in predicate], description + ) elif isinstance(predicate, tuple): return SpecPredicate(*predicate) elif isinstance(predicate, util.string_types): tokens = re.match( - r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate) + r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate + ) if not tokens: raise ValueError( - "Couldn't locate DB name in predicate: %r" % predicate) + "Couldn't locate DB name in predicate: %r" % predicate + ) db = tokens.group(1) op = tokens.group(2) spec = ( tuple(int(d) for d in tokens.group(3).split(".")) - if tokens.group(3) else None + if tokens.group(3) + else None ) return SpecPredicate(db, op, spec, description=description) @@ -215,11 +220,13 @@ class Predicate(object): bool_ = not negate return self.description % { "driver": config.db.url.get_driver_name() - if config else "", + if config + else "", "database": config.db.url.get_backend_name() - if config else "", + if config + else "", "doesnt_support": "doesn't support" if bool_ 
else "does support", - "does_support": "does support" if bool_ else "doesn't support" + "does_support": "does support" if bool_ else "doesn't support", } def _as_string(self, config=None, negate=False): @@ -246,21 +253,21 @@ class SpecPredicate(Predicate): self.description = description _ops = { - '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], } def __call__(self, config): engine = config.db if "+" in self.db: - dialect, driver = self.db.split('+') + dialect, driver = self.db.split("+") else: dialect, driver = self.db, None @@ -273,8 +280,9 @@ class SpecPredicate(Predicate): assert driver is None, "DBAPI version specs not supported yet" version = _server_version(engine) - oper = hasattr(self.op, '__call__') and self.op \ - or self._ops[self.op] + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) return oper(version, self.spec) else: return True @@ -289,17 +297,9 @@ class SpecPredicate(Predicate): return "%s" % self.db else: if negate: - return "not %s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "not %s %s %s" % (self.db, self.op, self.spec) else: - return "%s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "%s %s %s" % (self.db, self.op, self.spec) class LambdaPredicate(Predicate): @@ -356,8 +356,9 @@ class OrPredicate(Predicate): conjunction = " and " else: conjunction = " or " - return conjunction.join(p._as_string(config, negate=negate) - for p in self.predicates) + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) def _negation_str(self, config): if self.description is not None: @@ -387,7 +388,7 @@ def _server_version(engine): # force metadata to be retrieved conn = engine.connect() - version = getattr(engine.dialect, 'server_version_info', None) + version = getattr(engine.dialect, "server_version_info", None) if version is None: version = () conn.close() @@ -395,9 +396,7 @@ def _server_version(engine): def db_spec(*dbs): - return OrPredicate( - [Predicate.as_predicate(db) for db in dbs] - ) + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) def open(): @@ -422,11 +421,7 @@ def fails_on(db, reason=None): def fails_on_everything_except(*dbs): - return succeeds_if( - OrPredicate([ - Predicate.as_predicate(db) for db in dbs - ]) - ) + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) def skip(db, reason=None): @@ -435,8 +430,9 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([Predicate.as_predicate(db, reason) - for db in util.to_list(dbs)]) + OrPredicate( + [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] + ) ) @@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None): def against(config, *queries): assert queries, "no queries sent!" 
- return OrPredicate([ - Predicate.as_predicate(query) - for query in queries - ])(config) + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index dd0fa5a48f..98184cdd47 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -54,19 +54,19 @@ class TestBase(object): class TablesTest(TestBase): # 'once', None - run_setup_bind = 'once' + run_setup_bind = "once" # 'once', 'each', None - run_define_tables = 'once' + run_define_tables = "once" # 'once', 'each', None - run_create_tables = 'once' + run_create_tables = "once" # 'once', 'each', None - run_inserts = 'each' + run_inserts = "each" # 'each', None - run_deletes = 'each' + run_deletes = "each" # 'once', None run_dispose_bind = None @@ -86,10 +86,10 @@ class TablesTest(TestBase): @classmethod def _init_class(cls): - if cls.run_define_tables == 'each': - if cls.run_create_tables == 'once': - cls.run_create_tables = 'each' - assert cls.run_inserts in ('each', None) + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) cls.other = adict() cls.tables = adict() @@ -100,40 +100,40 @@ class TablesTest(TestBase): @classmethod def _setup_once_inserts(cls): - if cls.run_inserts == 'once': + if cls.run_inserts == "once": cls._load_fixtures() cls.insert_data() @classmethod def _setup_once_tables(cls): - if cls.run_define_tables == 'once': + if cls.run_define_tables == "once": cls.define_tables(cls.metadata) - if cls.run_create_tables == 'once': + if cls.run_create_tables == "once": cls.metadata.create_all(cls.bind) cls.tables.update(cls.metadata.tables) def _setup_each_tables(self): - if self.run_define_tables == 'each': + if self.run_define_tables == "each": self.tables.clear() - if self.run_create_tables == 'each': + if self.run_create_tables == "each": drop_all_tables(self.metadata, self.bind) self.metadata.clear() self.define_tables(self.metadata) - if self.run_create_tables == 'each': + if self.run_create_tables == "each": self.metadata.create_all(self.bind) self.tables.update(self.metadata.tables) - elif self.run_create_tables == 'each': + elif self.run_create_tables == "each": drop_all_tables(self.metadata, self.bind) self.metadata.create_all(self.bind) def _setup_each_inserts(self): - if self.run_inserts == 'each': + if self.run_inserts == "each": self._load_fixtures() self.insert_data() def _teardown_each_tables(self): # no need to run deletes if tables are recreated on setup - if self.run_define_tables != 'each' and self.run_deletes == 'each': + if self.run_define_tables != "each" and self.run_deletes == "each": with self.bind.connect() as conn: for table in reversed(self.metadata.sorted_tables): try: @@ -141,7 +141,8 @@ class TablesTest(TestBase): except sa.exc.DBAPIError as ex: util.print_( ("Error emptying table %s: %r" % (table, ex)), - file=sys.stderr) + file=sys.stderr, + ) def setup(self): self._setup_each_tables() @@ -155,7 +156,7 @@ class TablesTest(TestBase): if cls.run_create_tables: drop_all_tables(cls.metadata, cls.bind) - if cls.run_dispose_bind == 'once': + if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) cls.metadata.bind = None @@ -173,9 +174,9 @@ class TablesTest(TestBase): @classmethod def dispose_bind(cls, bind): - if hasattr(bind, 'dispose'): + if hasattr(bind, "dispose"): bind.dispose() - elif hasattr(bind, 'close'): + elif hasattr(bind, "close"): 
bind.close() @classmethod @@ -212,8 +213,12 @@ class TablesTest(TestBase): continue cls.bind.execute( table.insert(), - [dict(zip(headers[table], column_values)) - for column_values in rows[table]]) + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + from sqlalchemy import event @@ -236,7 +241,6 @@ class RemovesEvents(object): class _ORMTest(object): - @classmethod def teardown_class(cls): sa.orm.session.Session.close_all() @@ -249,10 +253,10 @@ class ORMTest(_ORMTest, TestBase): class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None - run_setup_classes = 'once' + run_setup_classes = "once" # 'once', 'each', None - run_setup_mappers = 'each' + run_setup_mappers = "each" classes = None @@ -292,20 +296,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): @classmethod def _setup_once_classes(cls): - if cls.run_setup_classes == 'once': + if cls.run_setup_classes == "once": cls._with_register_classes(cls.setup_classes) @classmethod def _setup_once_mappers(cls): - if cls.run_setup_mappers == 'once': + if cls.run_setup_mappers == "once": cls._with_register_classes(cls.setup_mappers) def _setup_each_mappers(self): - if self.run_setup_mappers == 'each': + if self.run_setup_mappers == "each": self._with_register_classes(self.setup_mappers) def _setup_each_classes(self): - if self.run_setup_classes == 'each': + if self.run_setup_classes == "each": self._with_register_classes(self.setup_classes) @classmethod @@ -339,11 +343,11 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # some tests create mappers in the test bodies # and will define setup_mappers as None - # clear mappers in any case - if self.run_setup_mappers != 'once': + if self.run_setup_mappers != "once": sa.orm.clear_mappers() def _teardown_each_classes(self): - if self.run_setup_classes != 'once': + if self.run_setup_classes != "once": self.classes.clear() @classmethod @@ -356,8 +360,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): class DeclarativeMappedTest(MappedTest): - run_setup_classes = 'once' - run_setup_mappers = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" @classmethod def _setup_once_tables(cls): @@ -370,15 +374,16 @@ class DeclarativeMappedTest(MappedTest): class FindFixtureDeclarative(DeclarativeMeta): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls - return DeclarativeMeta.__init__( - cls, classname, bases, dict_) + return DeclarativeMeta.__init__(cls, classname, bases, dict_) class DeclarativeBasic(object): __table_cls__ = schema.Table - _DeclBase = declarative_base(metadata=cls.metadata, - metaclass=FindFixtureDeclarative, - cls=DeclarativeBasic) + _DeclBase = declarative_base( + metadata=cls.metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic, + ) cls.DeclarativeBasic = _DeclBase fn() diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index ea0a8da820..dc530af5eb 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -18,4 +18,5 @@ else: except ImportError: raise ImportError( "SQLAlchemy's test suite requires the " - "'mock' library as of 0.8.2.") + "'mock' library as of 0.8.2." 
+ ) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 087fc1fe6e..e84cbde44f 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -46,29 +46,28 @@ class Parent(fixtures.ComparableEntity): class Screen(object): - def __init__(self, obj, parent=None): self.obj = obj self.parent = parent class Foo(object): - def __init__(self, moredata): - self.data = 'im data' - self.stuff = 'im stuff' + self.data = "im data" + self.stuff = "im stuff" self.moredata = moredata __hash__ = object.__hash__ def __eq__(self, other): - return other.data == self.data and \ - other.stuff == self.stuff and \ - other.moredata == self.moredata + return ( + other.data == self.data + and other.stuff == self.stuff + and other.moredata == self.moredata + ) class Bar(object): - def __init__(self, x, y): self.x = x self.y = y @@ -76,35 +75,36 @@ class Bar(object): __hash__ = object.__hash__ def __eq__(self, other): - return other.__class__ is self.__class__ and \ - other.x == self.x and \ - other.y == self.y + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) def __str__(self): return "Bar(%d, %d)" % (self.x, self.y) class OldSchool: - def __init__(self, x, y): self.x = x self.y = y def __eq__(self, other): - return other.__class__ is self.__class__ and \ - other.x == self.x and \ - other.y == self.y + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) class OldSchoolWithoutCompare: - def __init__(self, x, y): self.x = x self.y = y class BarWithoutCompare(object): - def __init__(self, x, y): self.x = x self.y = y @@ -114,7 +114,6 @@ class BarWithoutCompare(object): class NotComparable(object): - def __init__(self, data): self.data = data @@ -129,7 +128,6 @@ class NotComparable(object): class BrokenComparable(object): - def __init__(self, data): self.data = data diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index 497fcb7e58..bb52c125c6 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0. 
import os
import sys

-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]


 def load_file_as_module(name):
     path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
     if sys.version_info >= (3, 3):
         from importlib import machinery
+
         mod = machinery.SourceFileLoader(name, path).load_module()
     else:
         import imp
+
         mod = imp.load_source(name, path)
     return mod

+
 if to_bootstrap == "pytest":
     sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
     sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
index 20ea61d890..0c28a52136 100644
--- a/lib/sqlalchemy/testing/plugin/noseplugin.py
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -25,6 +25,7 @@ import sys
 from nose.plugins import Plugin
 import nose

+
 fixtures = None

 py3k = sys.version_info >= (3, 0)

@@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0)
 class NoseSQLAlchemy(Plugin):
     enabled = True

-    name = 'sqla_testing'
+    name = "sqla_testing"
     score = 100

     def options(self, parser, env=os.environ):
@@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin):
         opt = parser.add_option

         def make_option(name, **kw):
-            callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None)
+            callback_ = kw.pop("callback", None) or kw.pop(
+                "zeroarg_callback", None
+            )
             if callback_:
+
                 def wrap_(option, opt_str, value, parser):
                     callback_(opt_str, value, parser)
+
                 kw["callback"] = wrap_
             opt(name, **kw)
@@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin):
     def wantMethod(self, fn):
         if py3k:
-            if not hasattr(fn.__self__, 'cls'):
+            if not hasattr(fn.__self__, "cls"):
                 return False
             cls = fn.__self__.cls
         else:
@@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin):
         return plugin_base.want_class(cls)

     def beforeTest(self, test):
-        if not hasattr(test.test, 'cls'):
+        if not hasattr(test.test, "cls"):
             return
         plugin_base.before_test(
             test,
             test.test.cls.__module__,
-            test.test.cls, test.test.method.__name__)
+            test.test.cls,
+            test.test.method.__name__,
+        )

     def afterTest(self, test):
         plugin_base.after_test(test)

     def startContext(self, ctx):
-        if not isinstance(ctx, type) \
-                or not issubclass(ctx, fixtures.TestBase):
+        if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
             return
         plugin_base.start_test_class(ctx)

     def stopContext(self, ctx):
-        if not isinstance(ctx, type) \
-                or not issubclass(ctx, fixtures.TestBase):
+        if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
             return
         plugin_base.stop_test_class(ctx)
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 0ffcae0936..5d6bf2975e 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -46,58 +46,130 @@ options = None

 def setup_options(make_option):
-    make_option("--log-info", action="callback", type="string", callback=_log,
-                help="turn on info logging for <LOG> (multiple OK)")
-    make_option("--log-debug", action="callback",
-                type="string", callback=_log,
-                help="turn on debug logging for <LOG> (multiple OK)")
-    make_option("--db", action="append", type="string", dest="db",
-                help="Use prefab database uri. Multiple OK, "
-                "first one is run by default.")
-    make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
-                help="List available prefab dbs")
-    make_option("--dburi", action="append", type="string", dest="dburi",
-                help="Database uri. Multiple OK, "
-                "first one is run by default.")
-    make_option("--dropfirst", action="store_true", dest="dropfirst",
-                help="Drop all tables in the target database first")
-    make_option("--backend-only", action="store_true", dest="backend_only",
-                help="Run only tests marked with __backend__")
-    make_option("--nomemory", action="store_true", dest="nomemory",
-                help="Don't run memory profiling tests")
-    make_option("--postgresql-templatedb", type="string",
-                help="name of template database to use for PostgreSQL "
-                "CREATE DATABASE (defaults to current database)")
-    make_option("--low-connections", action="store_true",
-                dest="low_connections",
-                help="Use a low number of distinct connections - "
-                "i.e. for Oracle TNS")
-    make_option("--write-idents", type="string", dest="write_idents",
-                help="write out generated follower idents to <file>, "
-                "when -n<num> is used")
-    make_option("--reversetop", action="store_true",
-                dest="reversetop", default=False,
-                help="Use a random-ordering set implementation in the ORM "
-                "(helps reveal dependency issues)")
-    make_option("--requirements", action="callback", type="string",
-                callback=_requirements_opt,
-                help="requirements class for testing, overrides setup.cfg")
-    make_option("--with-cdecimal", action="store_true",
-                dest="cdecimal", default=False,
-                help="Monkeypatch the cdecimal library into Python 'decimal' "
-                "for all tests")
-    make_option("--include-tag", action="callback", callback=_include_tag,
-                type="string",
-                help="Include tests with tag <tag>")
-    make_option("--exclude-tag", action="callback", callback=_exclude_tag,
-                type="string",
-                help="Exclude tests with tag <tag>")
-    make_option("--write-profiles", action="store_true",
-                dest="write_profiles", default=False,
-                help="Write/update failing profiling data.")
-    make_option("--force-write-profiles", action="store_true",
-                dest="force_write_profiles", default=False,
-                help="Unconditionally write/update profiling data.")
+    make_option(
+        "--log-info",
+        action="callback",
+        type="string",
+        callback=_log,
+        help="turn on info logging for <LOG> (multiple OK)",
+    )
+    make_option(
+        "--log-debug",
+        action="callback",
+        type="string",
+        callback=_log,
+        help="turn on debug logging for <LOG> (multiple OK)",
+    )
+    make_option(
+        "--db",
+        action="append",
+        type="string",
+        dest="db",
+        help="Use prefab database uri. Multiple OK, "
+        "first one is run by default.",
+    )
+    make_option(
+        "--dbs",
+        action="callback",
+        zeroarg_callback=_list_dbs,
+        help="List available prefab dbs",
+    )
+    make_option(
+        "--dburi",
+        action="append",
+        type="string",
+        dest="dburi",
+        help="Database uri. Multiple OK, " "first one is run by default.",
+    )
+    make_option(
+        "--dropfirst",
+        action="store_true",
+        dest="dropfirst",
+        help="Drop all tables in the target database first",
+    )
+    make_option(
+        "--backend-only",
+        action="store_true",
+        dest="backend_only",
+        help="Run only tests marked with __backend__",
+    )
+    make_option(
+        "--nomemory",
+        action="store_true",
+        dest="nomemory",
+        help="Don't run memory profiling tests",
+    )
+    make_option(
+        "--postgresql-templatedb",
+        type="string",
+        help="name of template database to use for PostgreSQL "
+        "CREATE DATABASE (defaults to current database)",
+    )
+    make_option(
+        "--low-connections",
+        action="store_true",
+        dest="low_connections",
+        help="Use a low number of distinct connections - "
+        "i.e. for Oracle TNS",
+    )
+    make_option(
+        "--write-idents",
+        type="string",
+        dest="write_idents",
+        help="write out generated follower idents to <file>, "
+        "when -n<num> is used",
+    )
+    make_option(
+        "--reversetop",
+        action="store_true",
+        dest="reversetop",
+        default=False,
+        help="Use a random-ordering set implementation in the ORM "
+        "(helps reveal dependency issues)",
+    )
+    make_option(
+        "--requirements",
+        action="callback",
+        type="string",
+        callback=_requirements_opt,
+        help="requirements class for testing, overrides setup.cfg",
+    )
+    make_option(
+        "--with-cdecimal",
+        action="store_true",
+        dest="cdecimal",
+        default=False,
+        help="Monkeypatch the cdecimal library into Python 'decimal' "
+        "for all tests",
+    )
+    make_option(
+        "--include-tag",
+        action="callback",
+        callback=_include_tag,
+        type="string",
+        help="Include tests with tag <tag>",
+    )
+    make_option(
+        "--exclude-tag",
+        action="callback",
+        callback=_exclude_tag,
+        type="string",
+        help="Exclude tests with tag <tag>",
+    )
+    make_option(
+        "--write-profiles",
+        action="store_true",
+        dest="write_profiles",
+        default=False,
+        help="Write/update failing profiling data.",
+    )
+    make_option(
+        "--force-write-profiles",
+        action="store_true",
+        dest="force_write_profiles",
+        default=False,
+        help="Unconditionally write/update profiling data.",
+    )


 def configure_follower(follower_ident):
@@ -108,6 +180,7 @@
     """
     from sqlalchemy.testing import provision
+
     provision.FOLLOWER_IDENT = follower_ident
@@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_):
     callables, so we have to just copy all of that over.
""" - dict_['memoized_config'] = { - 'include_tags': include_tags, - 'exclude_tags': exclude_tags + dict_["memoized_config"] = { + "include_tags": include_tags, + "exclude_tags": exclude_tags, } @@ -134,14 +207,14 @@ def restore_important_follower_config(dict_): """ global include_tags, exclude_tags - include_tags.update(dict_['memoized_config']['include_tags']) - exclude_tags.update(dict_['memoized_config']['exclude_tags']) + include_tags.update(dict_["memoized_config"]["include_tags"]) + exclude_tags.update(dict_["memoized_config"]["exclude_tags"]) def read_config(): global file_config file_config = configparser.ConfigParser() - file_config.read(['setup.cfg', 'test.cfg']) + file_config.read(["setup.cfg", "test.cfg"]) def pre_begin(opt): @@ -155,6 +228,7 @@ def pre_begin(opt): def set_coverage_flag(value): options.has_coverage = value + _skip_test_exception = None @@ -171,34 +245,33 @@ def post_begin(): # late imports, has to happen after config as well # as nose plugins like coverage - global util, fixtures, engines, exclusions, \ - assertions, warnings, profiling,\ - config, testing - from sqlalchemy import testing # noqa + global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing + from sqlalchemy import testing # noqa from sqlalchemy.testing import fixtures, engines, exclusions # noqa - from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa from sqlalchemy.testing import config # noqa from sqlalchemy import util # noqa - warnings.setup_filters() + warnings.setup_filters() def _log(opt_str, value, parser): global logging if not logging: import logging + logging.basicConfig() - if opt_str.endswith('-info'): + if opt_str.endswith("-info"): logging.getLogger(value).setLevel(logging.INFO) - elif opt_str.endswith('-debug'): + elif opt_str.endswith("-debug"): logging.getLogger(value).setLevel(logging.DEBUG) def _list_dbs(*args): print("Available --db options (use --dburi to override)") - for macro in sorted(file_config.options('db')): - print("%20s\t%s" % (macro, file_config.get('db', macro))) + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) sys.exit(0) @@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser): def _exclude_tag(opt_str, value, parser): - exclude_tags.add(value.replace('-', '_')) + exclude_tags.add(value.replace("-", "_")) def _include_tag(opt_str, value, parser): - include_tags.add(value.replace('-', '_')) + include_tags.add(value.replace("-", "_")) + pre_configure = [] post_configure = [] @@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config): def _monkeypatch_cdecimal(options, file_config): if options.cdecimal: import cdecimal - sys.modules['decimal'] = cdecimal + + sys.modules["decimal"] = cdecimal @post @@ -266,27 +341,28 @@ def _engine_uri(options, file_config): if options.db: for db_token in options.db: - for db in re.split(r'[,\s]+', db_token): - if db not in file_config.options('db'): + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): raise RuntimeError( "Unknown URI specifier '%s'. " - "Specify --dbs for known uris." - % db) + "Specify --dbs for known uris." 
% db + ) else: - db_urls.append(file_config.get('db', db)) + db_urls.append(file_config.get("db", db)) if not db_urls: - db_urls.append(file_config.get('db', 'default')) + db_urls.append(file_config.get("db", "default")) config._current = None for db_url in db_urls: - if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': + if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': with open(options.write_idents, "a") as file_: file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n") cfg = provision.setup_config( - db_url, options, file_config, provision.FOLLOWER_IDENT) + db_url, options, file_config, provision.FOLLOWER_IDENT + ) if not config._current: cfg.set_as_current(cfg, testing) @@ -295,7 +371,7 @@ def _engine_uri(options, file_config): @post def _requirements(options, file_config): - requirement_cls = file_config.get('sqla_testing', "requirement_cls") + requirement_cls = file_config.get("sqla_testing", "requirement_cls") _setup_requirements(requirement_cls) @@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config): pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData()) - )) + e.execute( + schema._DropView( + schema.Table(vname, schema.MetaData()) + ) + ) if config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names( - schema="test_schema") + view_names = inspector.get_view_names(schema="test_schema") except NotImplementedError: pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData(), - schema="test_schema") - )) + e.execute( + schema._DropView( + schema.Table( + vname, + schema.MetaData(), + schema="test_schema", + ) + ) + ) util.drop_all_tables(e, inspector) @@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config): if against(cfg, "postgresql"): from sqlalchemy.dialects import postgresql + for enum in inspector.get_enums("*"): - e.execute(postgresql.DropEnumType( - postgresql.ENUM( - name=enum['name'], - schema=enum['schema']))) + e.execute( + postgresql.DropEnumType( + postgresql.ENUM( + name=enum["name"], schema=enum["schema"] + ) + ) + ) @post def _reverse_topological(options, file_config): if options.reversetop: from sqlalchemy.orm.util import randomize_unitofwork + randomize_unitofwork() @post def _post_setup_options(opt, file_config): from sqlalchemy.testing import config + config.options = options config.file_config = file_config @@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config): @post def _setup_profiling(options, file_config): from sqlalchemy.testing import profiling + profiling._profile_stats = profiling.ProfileStatsFile( - file_config.get('sqla_testing', 'profile_file')) + file_config.get("sqla_testing", "profile_file") + ) def want_class(cls): if not issubclass(cls, fixtures.TestBase): return False - elif cls.__name__.startswith('_'): + elif cls.__name__.startswith("_"): return False - elif config.options.backend_only and not getattr(cls, '__backend__', - False): + elif config.options.backend_only and not getattr( + cls, "__backend__", False + ): return False else: return True @@ -405,25 +496,28 @@ def want_method(cls, fn): return False elif include_tags: return ( - hasattr(cls, '__tags__') and - exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) + hasattr(cls, "__tags__") + and exclusions.tags(cls.__tags__).include_test( + include_tags, exclude_tags + ) ) or ( - hasattr(fn, '_sa_exclusion_extend') and - fn._sa_exclusion_extend.include_test( 
- include_tags, exclude_tags) + hasattr(fn, "_sa_exclusion_extend") + and fn._sa_exclusion_extend.include_test( + include_tags, exclude_tags + ) ) - elif exclude_tags and hasattr(cls, '__tags__'): + elif exclude_tags and hasattr(cls, "__tags__"): return exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) - elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'): + include_tags, exclude_tags + ) + elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"): return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags) else: return True def generate_sub_tests(cls, module): - if getattr(cls, '__backend__', False): + if getattr(cls, "__backend__", False): for cfg in _possible_configs_for_cls(cls): orig_name = cls.__name__ @@ -431,16 +525,13 @@ def generate_sub_tests(cls, module): # pytest junit plugin, which is tripped up by the brackets # and periods, so sanitize - alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name) - alpha_name = re.sub('_+$', '', alpha_name) + alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub("_+$", "", alpha_name) name = "%s_%s" % (cls.__name__, alpha_name) subcls = type( name, - (cls, ), - { - "_sa_orig_cls_name": orig_name, - "__only_on_config__": cfg - } + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, ) setattr(module, name, subcls) yield subcls @@ -454,8 +545,8 @@ def start_test_class(cls): def stop_test_class(cls): - #from sqlalchemy import inspect - #assert not inspect(testing.db).get_table_names() + # from sqlalchemy import inspect + # assert not inspect(testing.db).get_table_names() engines.testing_reaper._stop_test_ctx() try: if not options.low_connections: @@ -475,7 +566,7 @@ def final_process_cleanup(): def _setup_engine(cls): - if getattr(cls, '__engine_options__', None): + if getattr(cls, "__engine_options__", None): eng = engines.testing_engine(options=cls.__engine_options__) config._current.push_engine(eng, testing) @@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name): # like a nose id, e.g.: # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" - name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__) + name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__) id_ = "%s.%s.%s" % (test_module_name, name, test_name) @@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None): if spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on__', None): + if getattr(cls, "__only_on__", None): spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) for config_obj in list(all_configs): if not spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on_config__', None): + if getattr(cls, "__only_on_config__", None): all_configs.intersection_update([cls.__only_on_config__]) - if hasattr(cls, '__requires__'): + if hasattr(cls, "__requires__"): requirements = config.requirements for config_obj in list(all_configs): for requirement in cls.__requires__: @@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None): reasons.extend(skip_reasons) break - if hasattr(cls, '__prefer_requires__'): + if hasattr(cls, "__prefer_requires__"): non_preferred = set() requirements = config.requirements for config_obj in list(all_configs): @@ -546,30 +637,32 @@ def _do_skips(cls): reasons = [] all_configs = _possible_configs_for_cls(cls, reasons) - if getattr(cls, '__skip_if__', False): - for c in getattr(cls, '__skip_if__'): + if getattr(cls, "__skip_if__", False): + for c in 
getattr(cls, "__skip_if__"): if c(): - config.skip_test("'%s' skipped by %s" % ( - cls.__name__, c.__name__) + config.skip_test( + "'%s' skipped by %s" % (cls.__name__, c.__name__) ) if not all_configs: msg = "'%s' unsupported on any DB implementation %s%s" % ( cls.__name__, ", ".join( - "'%s(%s)+%s'" % ( + "'%s(%s)+%s'" + % ( config_obj.db.name, ".".join( - str(dig) for dig in - exclusions._server_version(config_obj.db)), - config_obj.db.driver + str(dig) + for dig in exclusions._server_version(config_obj.db) + ), + config_obj.db.driver, ) - for config_obj in config.Config.all_configs() + for config_obj in config.Config.all_configs() ), - ", ".join(reasons) + ", ".join(reasons), ) config.skip_test(msg) - elif hasattr(cls, '__prefer_backends__'): + elif hasattr(cls, "__prefer_backends__"): non_preferred = set() spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) for config_obj in all_configs: diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index da682ea008..fd0a484629 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -13,6 +13,7 @@ import os try: import xdist # noqa + has_xdist = True except ImportError: has_xdist = False @@ -24,30 +25,42 @@ def pytest_addoption(parser): def make_option(name, **kw): callback_ = kw.pop("callback", None) if callback_: + class CallableAction(argparse.Action): - def __call__(self, parser, namespace, - values, option_string=None): + def __call__( + self, parser, namespace, values, option_string=None + ): callback_(option_string, values, parser) + kw["action"] = CallableAction zeroarg_callback = kw.pop("zeroarg_callback", None) if zeroarg_callback: + class CallableAction(argparse.Action): - def __init__(self, option_strings, - dest, default=False, - required=False, help=None): - super(CallableAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=True, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, - values, option_string=None): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, + ): + super(CallableAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): zeroarg_callback(option_string, values, parser) + kw["action"] = CallableAction group.addoption(name, **kw) @@ -59,18 +72,18 @@ def pytest_addoption(parser): def pytest_configure(config): if hasattr(config, "slaveinput"): plugin_base.restore_important_follower_config(config.slaveinput) - plugin_base.configure_follower( - config.slaveinput["follower_ident"] - ) + plugin_base.configure_follower(config.slaveinput["follower_ident"]) else: - if config.option.write_idents and \ - os.path.exists(config.option.write_idents): + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): os.remove(config.option.write_idents) plugin_base.pre_begin(config.option) - plugin_base.set_coverage_flag(bool(getattr(config.option, - "cov_source", False))) + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) plugin_base.set_skip_test(pytest.skip.Exception) @@ -94,10 +107,12 @@ if has_xdist: node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] from sqlalchemy.testing import provision + 
provision.create_follower_db(node.slaveinput["follower_ident"]) def pytest_testnodedown(node, error): from sqlalchemy.testing import provision + provision.drop_follower_db(node.slaveinput["follower_ident"]) @@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) items[:] = [ - item for item in - items if isinstance(item.parent, pytest.Instance) - and not item.parent.parent.name.startswith("_")] + item + for item in items + if isinstance(item.parent, pytest.Instance) + and not item.parent.parent.name.startswith("_") + ] test_classes = set(item.parent for item in items) for test_class in test_classes: for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module): + test_class.cls, test_class.parent.module + ): if sub_cls is not test_class.cls: list_ = rebuilt_items[test_class.cls] for inst in pytest.Class( - sub_cls.__name__, - parent=test_class.parent.parent).collect(): + sub_cls.__name__, parent=test_class.parent.parent + ).collect(): list_.extend(inst.collect()) newitems = [] @@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items): # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) - items[:] = sorted(newitems, key=lambda item: ( - item.parent.parent.parent.name, - item.parent.parent.name, - item.name - )) + items[:] = sorted( + newitems, + key=lambda item: ( + item.parent.parent.parent.name, + item.parent.parent.name, + item.name, + ), + ) def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(obj): return pytest.Class(name, parent=collector) - elif inspect.isfunction(obj) and \ - isinstance(collector, pytest.Instance) and \ - plugin_base.want_method(collector.cls, obj): + elif ( + inspect.isfunction(obj) + and isinstance(collector, pytest.Instance) + and plugin_base.want_method(collector.cls, obj) + ): return pytest.Function(name, parent=collector) else: return [] + _current_class = None @@ -180,6 +204,7 @@ def pytest_runtest_setup(item): global _current_class class_teardown(item.parent.parent) _current_class = None + item.parent.parent.addfinalizer(finalize) test_setup(item) @@ -194,8 +219,9 @@ def pytest_runtest_teardown(item): def test_setup(item): - plugin_base.before_test(item, item.parent.module.__name__, - item.parent.cls, item.name) + plugin_base.before_test( + item, item.parent.module.__name__, item.parent.cls, item.name + ) def test_teardown(item): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index fab99b186e..3986985c7f 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -42,17 +42,16 @@ class ProfileStatsFile(object): def __init__(self, filename): self.force_write = ( - config.options is not None and - config.options.force_write_profiles + config.options is not None and config.options.force_write_profiles ) self.write = self.force_write or ( - config.options is not None and - config.options.write_profiles + config.options is not None and config.options.write_profiles ) self.fname = os.path.abspath(filename) self.short_fname = os.path.split(self.fname)[-1] self.data = collections.defaultdict( - lambda: collections.defaultdict(dict)) + lambda: collections.defaultdict(dict) + ) self._read() if self.write: # rewrite for the case where features changed, @@ -65,7 +64,7 @@ class ProfileStatsFile(object): dbapi_key = config.db.name + "_" + 
config.db.driver # keep it at 2.7, 3.1, 3.2, etc. for now. - py_version = '.'.join([str(v) for v in sys.version_info[0:2]]) + py_version = ".".join([str(v) for v in sys.version_info[0:2]]) platform_tokens = [py_version] platform_tokens.append(dbapi_key) @@ -87,8 +86,7 @@ class ProfileStatsFile(object): def has_stats(self): test_key = _current_test return ( - test_key in self.data and - self.platform_key in self.data[test_key] + test_key in self.data and self.platform_key in self.data[test_key] ) def result(self, callcount): @@ -96,15 +94,15 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] - if 'counts' not in per_platform: - per_platform['counts'] = counts = [] + if "counts" not in per_platform: + per_platform["counts"] = counts = [] else: - counts = per_platform['counts'] + counts = per_platform["counts"] - if 'current_count' not in per_platform: - per_platform['current_count'] = current_count = 0 + if "current_count" not in per_platform: + per_platform["current_count"] = current_count = 0 else: - current_count = per_platform['current_count'] + current_count = per_platform["current_count"] has_count = len(counts) > current_count @@ -114,16 +112,16 @@ class ProfileStatsFile(object): self._write() result = None else: - result = per_platform['lineno'], counts[current_count] - per_platform['current_count'] += 1 + result = per_platform["lineno"], counts[current_count] + per_platform["current_count"] += 1 return result def replace(self, callcount): test_key = _current_test per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] - counts = per_platform['counts'] - current_count = per_platform['current_count'] + counts = per_platform["counts"] + current_count = per_platform["current_count"] if current_count < len(counts): counts[current_count - 1] = callcount else: @@ -164,9 +162,9 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[platform_key] c = [int(count) for count in counts.split(",")] - per_platform['counts'] = c - per_platform['lineno'] = lineno + 1 - per_platform['current_count'] = 0 + per_platform["counts"] = c + per_platform["lineno"] = lineno + 1 + per_platform["current_count"] = 0 profile_f.close() def _write(self): @@ -179,7 +177,7 @@ class ProfileStatsFile(object): profile_f.write("\n# TEST: %s\n\n" % test_key) for platform_key in sorted(per_fn): per_platform = per_fn[platform_key] - c = ",".join(str(count) for count in per_platform['counts']) + c = ",".join(str(count) for count in per_platform["counts"]) profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) profile_f.close() @@ -199,7 +197,9 @@ def function_call_count(variance=0.05): def wrap(*args, **kw): with count_functions(variance=variance): return fn(*args, **kw) + return update_wrapper(wrap, fn) + return decorate @@ -213,21 +213,22 @@ def count_functions(variance=0.05): "No profiling stats available on this " "platform for this function. Run tests with " "--write-profiles to add statistics to %s for " - "this platform." % _profile_stats.short_fname) + "this platform." 
% _profile_stats.short_fname + ) gc_collect() pr = cProfile.Profile() pr.enable() - #began = time.time() + # began = time.time() yield - #ended = time.time() + # ended = time.time() pr.disable() - #s = compat.StringIO() + # s = compat.StringIO() stats = pstats.Stats(pr, stream=sys.stdout) - #timespent = ended - began + # timespent = ended - began callcount = stats.total_calls expected = _profile_stats.result(callcount) @@ -237,11 +238,7 @@ def count_functions(variance=0.05): else: line_no, expected_count = expected - print(("Pstats calls: %d Expected %s" % ( - callcount, - expected_count - ) - )) + print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) stats.sort_stats("cumulative") stats.print_stats() @@ -259,7 +256,9 @@ def count_functions(variance=0.05): "--write-profiles to " "regenerate this callcount." % ( - callcount, (variance * 100), - expected_count, _profile_stats.platform_key)) - - + callcount, + (variance * 100), + expected_count, + _profile_stats.platform_key, + ) + ) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index c0ca7c1cbc..25028ccb3e 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -8,6 +8,7 @@ import collections import os import time import logging + log = logging.getLogger(__name__) FOLLOWER_IDENT = None @@ -25,6 +26,7 @@ class register(object): def decorate(fn): self.fns[dbname] = fn return self + return decorate def __call__(self, cfg, *arg): @@ -38,7 +40,7 @@ class register(object): if backend in self.fns: return self.fns[backend](cfg, *arg) else: - return self.fns['*'](cfg, *arg) + return self.fns["*"](cfg, *arg) def create_follower_db(follower_ident): @@ -82,9 +84,7 @@ def _configs_for_db_operation(): for cfg in config.Config.all_configs(): url = cfg.db.url backend = url.get_backend_name() - host_conf = ( - backend, - url.username, url.host, url.database) + host_conf = (backend, url.username, url.host, url.database) if host_conf not in hosts: yield cfg @@ -128,14 +128,13 @@ def _follower_url_from_main(url, ident): @_update_db_opts.for_db("mssql") def _mssql_update_db_opts(db_url, db_opts): - db_opts['legacy_schema_aliasing'] = False - + db_opts["legacy_schema_aliasing"] = False @_follower_url_from_main.for_db("sqlite") def _sqlite_follower_url_from_main(url, ident): url = sa_url.make_url(url) - if not url.database or url.database == ':memory:': + if not url.database or url.database == ":memory:": return url else: return sa_url.make_url("sqlite:///%s.db" % ident) @@ -151,19 +150,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident): # as an attached if not follower_ident: dbapi_connection.execute( - 'ATTACH DATABASE "test_schema.db" AS test_schema') + 'ATTACH DATABASE "test_schema.db" AS test_schema' + ) else: dbapi_connection.execute( 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' - % follower_ident) + % follower_ident + ) @_create_db.for_db("postgresql") def _pg_create_db(cfg, eng, ident): template_db = cfg.options.postgresql_templatedb - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: try: _pg_drop_db(cfg, conn, ident) except Exception: @@ -175,7 +175,8 @@ def _pg_create_db(cfg, eng, ident): while True: try: conn.execute( - "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)) + "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db) + ) except exc.OperationalError as err: attempt += 1 if attempt >= 3: @@ -184,8 +185,11 @@ def 
_pg_create_db(cfg, eng, ident): log.info( "Waiting to create %s, URI %r, " "template DB %s is in use sleeping for .5", - ident, eng.url, template_db) - time.sleep(.5) + ident, + eng.url, + template_db, + ) + time.sleep(0.5) else: break @@ -203,9 +207,11 @@ def _mysql_create_db(cfg, eng, ident): # 1271, u"Illegal mix of collations for operation 'UNION'" conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident) conn.execute( - "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident) + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident + ) conn.execute( - "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident) + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident + ) @_configure_follower.for_db("mysql") @@ -221,14 +227,15 @@ def _sqlite_create_db(cfg, eng, ident): @_drop_db.for_db("postgresql") def _pg_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute( text( "select pg_terminate_backend(pid) from pg_stat_activity " "where usename=current_user and pid != pg_backend_pid() " "and datname=:dname" - ), dname=ident) + ), + dname=ident, + ) conn.execute("DROP DATABASE %s" % ident) @@ -257,11 +264,12 @@ def _oracle_create_db(cfg, eng, ident): conn.execute("create user %s identified by xe" % ident) conn.execute("create user %s_ts1 identified by xe" % ident) conn.execute("create user %s_ts2 identified by xe" % ident) - conn.execute("grant dba to %s" % (ident, )) + conn.execute("grant dba to %s" % (ident,)) conn.execute("grant unlimited tablespace to %s" % ident) conn.execute("grant unlimited tablespace to %s_ts1" % ident) conn.execute("grant unlimited tablespace to %s_ts2" % ident) + @_configure_follower.for_db("oracle") def _oracle_configure_follower(config, ident): config.test_schema = "%s_ts1" % ident @@ -320,6 +328,7 @@ def reap_dbs(idents_file): elif backend == "mssql": _reap_mssql_dbs(url, ident) + def _reap_oracle_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) @@ -330,8 +339,9 @@ def _reap_oracle_dbs(url, idents): to_reap = conn.execute( "select u.username from all_users u where username " "like 'TEST_%' and not exists (select username " - "from v$session where username=u.username)") - all_names = {username.lower() for (username, ) in to_reap} + "from v$session where username=u.username)" + ) + all_names = {username.lower() for (username,) in to_reap} to_drop = set() for name in all_names: if name.endswith("_ts1") or name.endswith("_ts2"): @@ -348,28 +358,28 @@ def _reap_oracle_dbs(url, idents): if _ora_drop_ignore(conn, username): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) - + "Dropped %d out of %d stale databases detected", dropped, total + ) @_follower_url_from_main.for_db("oracle") def _oracle_follower_url_from_main(url, ident): url = sa_url.make_url(url) url.username = ident - url.password = 'xe' + url.password = "xe" return url @_create_db.for_db("mssql") def _mssql_create_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute("create database %s" % ident) conn.execute( - "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident) + "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident + ) conn.execute( - "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident) + 
"ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident + ) conn.execute("use %s" % ident) conn.execute("create schema test_schema") conn.execute("create schema test_schema_2") @@ -377,10 +387,10 @@ def _mssql_create_db(cfg, eng, ident): @_drop_db.for_db("mssql") def _mssql_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: _mssql_drop_ignore(conn, ident) + def _mssql_drop_ignore(conn, ident): try: # typically when this happens, we can't KILL the session anyway, @@ -401,8 +411,7 @@ def _mssql_drop_ignore(conn, ident): def _reap_mssql_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: log.info("identifiers in file: %s", ", ".join(idents)) @@ -410,8 +419,9 @@ def _reap_mssql_dbs(url, idents): "select d.name from sys.databases as d where name " "like 'TEST_%' and not exists (select session_id " "from sys.dm_exec_sessions " - "where database_id=d.database_id)") - all_names = {dbname.lower() for (dbname, ) in to_reap} + "where database_id=d.database_id)" + ) + all_names = {dbname.lower() for (dbname,) in to_reap} to_drop = set() for name in all_names: if name in idents: @@ -422,5 +432,5 @@ def _reap_mssql_dbs(url, idents): if _mssql_drop_ignore(conn, dbname): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) + "Dropped %d out of %d stale databases detected", dropped, total + ) diff --git a/lib/sqlalchemy/testing/replay_fixture.py b/lib/sqlalchemy/testing/replay_fixture.py index b50f52e3de..9832b07a28 100644 --- a/lib/sqlalchemy/testing/replay_fixture.py +++ b/lib/sqlalchemy/testing/replay_fixture.py @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session class ReplayFixtureTest(fixtures.TestBase): - @contextlib.contextmanager def _dummy_ctx(self, *arg, **kw): yield @@ -22,8 +21,8 @@ class ReplayFixtureTest(fixtures.TestBase): creator = config.db.pool._creator recorder = lambda: dbapi_session.recorder(creator()) engine = create_engine( - config.db.url, creator=recorder, - use_native_hstore=False) + config.db.url, creator=recorder, use_native_hstore=False + ) self.metadata = MetaData(engine) self.engine = engine self.session = Session(engine) @@ -37,8 +36,8 @@ class ReplayFixtureTest(fixtures.TestBase): player = lambda: dbapi_session.player() engine = create_engine( - config.db.url, creator=player, - use_native_hstore=False) + config.db.url, creator=player, use_native_hstore=False + ) self.metadata = MetaData(engine) self.engine = engine @@ -74,21 +73,49 @@ class ReplayableSession(object): NoAttribute = object() if util.py2k: - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', 'UnboundMethodType',)]) + Natives = set( + [getattr(types, t) for t in dir(types) if not t.startswith("_")] + ).difference( + [ + getattr(types, t) + for t in ( + "FunctionType", + "BuiltinFunctionType", + "MethodType", + "BuiltinMethodType", + "LambdaType", + "UnboundMethodType", + ) + ] + ) else: - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - union([type(t) if not isinstance(t, type) - else t for t in __builtins__.values()]).\ - 
difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', )]) + Natives = ( + set( + [ + getattr(types, t) + for t in dir(types) + if not t.startswith("_") + ] + ) + .union( + [ + type(t) if not isinstance(t, type) else t + for t in __builtins__.values() + ] + ) + .difference( + [ + getattr(types, t) + for t in ( + "FunctionType", + "BuiltinFunctionType", + "MethodType", + "BuiltinMethodType", + "LambdaType", + ) + ] + ) + ) def __init__(self): self.buffer = deque() @@ -105,8 +132,10 @@ class ReplayableSession(object): self._subject = subject def __call__(self, *args, **kw): - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] + subject, buffer = [ + object.__getattribute__(self, x) + for x in ("_subject", "_buffer") + ] result = subject(*args, **kw) if type(result) not in ReplayableSession.Natives: @@ -126,8 +155,10 @@ class ReplayableSession(object): except AttributeError: pass - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] + subject, buffer = [ + object.__getattribute__(self, x) + for x in ("_subject", "_buffer") + ] try: result = type(subject).__getattribute__(subject, key) except AttributeError: @@ -146,7 +177,7 @@ class ReplayableSession(object): self._buffer = buffer def __call__(self, *args, **kw): - buffer = object.__getattribute__(self, '_buffer') + buffer = object.__getattribute__(self, "_buffer") result = buffer.popleft() if result is ReplayableSession.Callable: return self @@ -162,7 +193,7 @@ class ReplayableSession(object): return object.__getattribute__(self, key) except AttributeError: pass - buffer = object.__getattribute__(self, '_buffer') + buffer = object.__getattribute__(self, "_buffer") result = buffer.popleft() if result is ReplayableSession.Callable: return self diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 58df643f4e..c96d26d322 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -26,7 +26,6 @@ class Requirements(object): class SuiteRequirements(Requirements): - @property def create_table(self): """target platform can emit basic CreateTable DDL.""" @@ -68,8 +67,8 @@ class SuiteRequirements(Requirements): # somehow only_if([x, y]) isn't working here, negation/conjunctions # getting confused. 
return exclusions.only_if( - lambda: self.on_update_cascade.enabled or - self.deferrable_fks.enabled + lambda: self.on_update_cascade.enabled + or self.deferrable_fks.enabled ) @property @@ -231,22 +230,21 @@ class SuiteRequirements(Requirements): def sane_rowcount(self): return exclusions.skip_if( lambda config: not config.db.dialect.supports_sane_rowcount, - "driver doesn't support 'sane' rowcount" + "driver doesn't support 'sane' rowcount", ) @property def sane_multi_rowcount(self): return exclusions.fails_if( lambda config: not config.db.dialect.supports_sane_multi_rowcount, - "driver %(driver)s %(doesnt_support)s 'sane' multi row count" + "driver %(driver)s %(doesnt_support)s 'sane' multi row count", ) @property def sane_rowcount_w_returning(self): return exclusions.fails_if( - lambda config: - not config.db.dialect.supports_sane_rowcount_returning, - "driver doesn't support 'sane' rowcount when returning is on" + lambda config: not config.db.dialect.supports_sane_rowcount_returning, + "driver doesn't support 'sane' rowcount when returning is on", ) @property @@ -255,9 +253,9 @@ class SuiteRequirements(Requirements): INSERT DEFAULT VALUES or equivalent.""" return exclusions.only_if( - lambda config: config.db.dialect.supports_empty_insert or - config.db.dialect.supports_default_values, - "empty inserts not supported" + lambda config: config.db.dialect.supports_empty_insert + or config.db.dialect.supports_default_values, + "empty inserts not supported", ) @property @@ -272,7 +270,7 @@ class SuiteRequirements(Requirements): return exclusions.only_if( lambda config: config.db.dialect.implicit_returning, - "%(database)s %(does_support)s 'returning'" + "%(database)s %(does_support)s 'returning'", ) @property @@ -297,7 +295,7 @@ class SuiteRequirements(Requirements): return exclusions.skip_if( lambda config: not config.db.dialect.requires_name_normalize, - "Backend does not require denormalized names." + "Backend does not require denormalized names.", ) @property @@ -307,7 +305,7 @@ class SuiteRequirements(Requirements): return exclusions.skip_if( lambda config: not config.db.dialect.supports_multivalues_insert, - "Backend does not support multirow inserts." 
+ "Backend does not support multirow inserts.", ) @property @@ -355,27 +353,32 @@ class SuiteRequirements(Requirements): def server_side_cursors(self): """Target dialect must support server side cursors.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_server_side_cursors - ], "no server side cursors support") + return exclusions.only_if( + [lambda config: config.db.dialect.supports_server_side_cursors], + "no server side cursors support", + ) @property def sequences(self): """Target database must support SEQUENCEs.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences - ], "no sequence support") + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) @property def sequences_optional(self): """Target database supports sequences, but also optionally as a means of generating new PK values.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences and - config.db.dialect.sequences_optional - ], "no sequence support, or sequences not optional") + return exclusions.only_if( + [ + lambda config: config.db.dialect.supports_sequences + and config.db.dialect.sequences_optional + ], + "no sequence support, or sequences not optional", + ) @property def reflects_pk_names(self): @@ -841,7 +844,8 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( - lambda config: config.options.low_connections) + lambda config: config.options.low_connections + ) @property def timing_intensive(self): @@ -859,37 +863,37 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( lambda config: util.py3k and config.options.has_coverage, - "Stability issues with coverage + py3k" + "Stability issues with coverage + py3k", ) @property def python2(self): return exclusions.skip_if( lambda: sys.version_info >= (3,), - "Python version 2.xx is required." + "Python version 2.xx is required.", ) @property def python3(self): return exclusions.skip_if( - lambda: sys.version_info < (3,), - "Python version 3.xx is required." + lambda: sys.version_info < (3,), "Python version 3.xx is required." 
) @property def cpython(self): return exclusions.only_if( - lambda: util.cpython, - "cPython interpreter needed" + lambda: util.cpython, "cPython interpreter needed" ) @property def non_broken_pickle(self): from sqlalchemy.util import pickle + return exclusions.only_if( - lambda: not util.pypy and pickle.__name__ == 'cPickle' - or sys.version_info >= (3, 2), - "Needs cPickle+cPython or newer Python 3 pickle" + lambda: not util.pypy + and pickle.__name__ == "cPickle" + or sys.version_info >= (3, 2), + "Needs cPickle+cPython or newer Python 3 pickle", ) @property @@ -910,7 +914,7 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( lambda config: config.options.has_coverage, - "Issues observed when coverage is enabled" + "Issues observed when coverage is enabled", ) def _has_mysql_on_windows(self, config): @@ -931,8 +935,9 @@ class SuiteRequirements(Requirements): def _has_sqlite(self): from sqlalchemy import create_engine + try: - create_engine('sqlite://') + create_engine("sqlite://") return True except ImportError: return False @@ -940,6 +945,7 @@ class SuiteRequirements(Requirements): def _has_cextensions(self): try: from sqlalchemy import cresultproxy, cprocessors + return True except ImportError: return False diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index 87be0749c9..6aa820fd57 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -47,4 +47,4 @@ def setup_py_test(): to nose. """ - nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner']) + nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"]) diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 401c8cbb78..b345a94874 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -9,7 +9,7 @@ from . import exclusions from .. import schema, event from . import config -__all__ = 'Table', 'Column', +__all__ = "Table", "Column" table_options = {} @@ -17,30 +17,35 @@ table_options = {} def Table(*args, **kw): """A schema.Table wrapper/hook for dialect-specific tweaks.""" - test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')} + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} kw.update(table_options) - if exclusions.against(config._current, 'mysql'): - if 'mysql_engine' not in kw and 'mysql_type' not in kw and \ - "autoload_with" not in kw: - if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: - kw['mysql_engine'] = 'InnoDB' + if exclusions.against(config._current, "mysql"): + if ( + "mysql_engine" not in kw + and "mysql_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mysql_engine"] = "InnoDB" else: - kw['mysql_engine'] = 'MyISAM' + kw["mysql_engine"] = "MyISAM" # Apply some default cascading rules for self-referential foreign keys. # MySQL InnoDB has some issues around seleting self-refs too. - if exclusions.against(config._current, 'firebird'): + if exclusions.against(config._current, "firebird"): table_name = args[0] - unpack = (config.db.dialect. - identifier_preparer.unformat_identifiers) + unpack = config.db.dialect.identifier_preparer.unformat_identifiers # Only going after ForeignKeys in Columns. May need to # expand to ForeignKeyConstraint too. 
- fks = [fk - for col in args if isinstance(col, schema.Column) - for fk in col.foreign_keys] + fks = [ + fk + for col in args + if isinstance(col, schema.Column) + for fk in col.foreign_keys + ] for fk in fks: # root around in raw spec @@ -54,9 +59,9 @@ def Table(*args, **kw): name = unpack(ref)[0] if name == table_name: if fk.ondelete is None: - fk.ondelete = 'CASCADE' + fk.ondelete = "CASCADE" if fk.onupdate is None: - fk.onupdate = 'CASCADE' + fk.onupdate = "CASCADE" return schema.Table(*args, **kw) @@ -64,37 +69,46 @@ def Table(*args, **kw): def Column(*args, **kw): """A schema.Column wrapper/hook for dialect-specific tweaks.""" - test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')} + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} if not config.requirements.foreign_key_ddl.enabled_for_config(config): args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] col = schema.Column(*args, **kw) - if test_opts.get('test_needs_autoincrement', False) and \ - kw.get('primary_key', False): + if test_opts.get("test_needs_autoincrement", False) and kw.get( + "primary_key", False + ): if col.default is None and col.server_default is None: col.autoincrement = True # allow any test suite to pick up on this - col.info['test_needs_autoincrement'] = True + col.info["test_needs_autoincrement"] = True # hardcoded rule for firebird, oracle; this should # be moved out - if exclusions.against(config._current, 'firebird', 'oracle'): + if exclusions.against(config._current, "firebird", "oracle"): + def add_seq(c, tbl): c._init_items( - schema.Sequence(_truncate_name( - config.db.dialect, tbl.name + '_' + c.name + '_seq'), - optional=True) + schema.Sequence( + _truncate_name( + config.db.dialect, tbl.name + "_" + c.name + "_seq" + ), + optional=True, + ) ) - event.listen(col, 'after_parent_attach', add_seq, propagate=True) + + event.listen(col, "after_parent_attach", add_seq, propagate=True) return col def _truncate_name(dialect, name): if len(name) > dialect.max_identifier_length: - return name[0:max(dialect.max_identifier_length - 6, 0)] + \ - "_" + hex(hash(name) % 64)[2:] + return ( + name[0 : max(dialect.max_identifier_length - 6, 0)] + + "_" + + hex(hash(name) % 64)[2:] + ) else: return name diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 748d9722d6..a4e142c5a3 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -1,4 +1,3 @@ - from sqlalchemy.testing.suite.test_cte import * from sqlalchemy.testing.suite.test_dialect import * from sqlalchemy.testing.suite.test_ddl import * diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py index cc72278e6c..d2f35933bb 100644 --- a/lib/sqlalchemy/testing/suite/test_cte.py +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -10,22 +10,28 @@ from ..schema import Table, Column class CTETest(fixtures.TablesTest): __backend__ = True - __requires__ = 'ctes', + __requires__ = ("ctes",) - run_inserts = 'each' - run_deletes = 'each' + run_inserts = "each" + run_deletes = "each" @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column("parent_id", ForeignKey("some_table.id"))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) - Table("some_other_table", metadata, - 
Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column("parent_id", Integer)) + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) @classmethod def insert_data(cls): @@ -36,28 +42,33 @@ class CTETest(fixtures.TablesTest): {"id": 2, "data": "d2", "parent_id": 1}, {"id": 3, "data": "d3", "parent_id": 1}, {"id": 4, "data": "d4", "parent_id": 3}, - {"id": 5, "data": "d5", "parent_id": 3} - ] + {"id": 5, "data": "d5", "parent_id": 3}, + ], ) def test_select_nonrecursive_round_trip(self): some_table = self.tables.some_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) result = conn.execute( select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"])) ) - eq_(result.fetchall(), [("d4", )]) + eq_(result.fetchall(), [("d4",)]) def test_select_recursive_round_trip(self): some_table = self.tables.some_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"])).cte( - "some_cte", recursive=True) + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) cte_alias = cte.alias("c1") st1 = some_table.alias() @@ -67,12 +78,13 @@ class CTETest(fixtures.TablesTest): select([st1]).where(st1.c.id == cte_alias.c.parent_id) ) result = conn.execute( - select([cte.c.data]).where( - cte.c.data != "d2").order_by(cte.c.data.desc()) + select([cte.c.data]) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) ) eq_( result.fetchall(), - [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)] + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], ) def test_insert_from_select_round_trip(self): @@ -80,20 +92,21 @@ class CTETest(fixtures.TablesTest): some_other_table = self.tables.some_other_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.insert().from_select( - ["id", "data", "parent_id"], - select([cte]) + ["id", "data", "parent_id"], select([cte]) ) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)] + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], ) @testing.requires.ctes_with_update_delete @@ -105,27 +118,31 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( - some_other_table.update().values(parent_id=5).where( - some_other_table.c.data == cte.c.data - ) + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), [ - (1, "d1", None), (2, "d2", 5), - (3, "d3", 5), (4, "d4", 5), (5, "d5", 3) - ] + (1, "d1", None), + (2, "d2", 5), + (3, 
"d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], ) @testing.requires.ctes_with_update_delete @@ -137,14 +154,15 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.delete().where( some_other_table.c.data == cte.c.data @@ -154,9 +172,7 @@ class CTETest(fixtures.TablesTest): conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [ - (1, "d1", None), (5, "d5", 3) - ] + [(1, "d1", None), (5, "d5", 3)], ) @testing.requires.ctes_with_update_delete @@ -168,26 +184,26 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.delete().where( - some_other_table.c.data == - select([cte.c.data]).where( - cte.c.id == some_other_table.c.id) + some_other_table.c.data + == select([cte.c.data]).where( + cte.c.id == some_other_table.c.id + ) ) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [ - (1, "d1", None), (5, "d5", 3) - ] + [(1, "d1", None), (5, "d5", 3)], ) diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py index 1d8010c8ae..7c44388d44 100644 --- a/lib/sqlalchemy/testing/suite/test_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -1,5 +1,3 @@ - - from .. 
import fixtures, config, util from ..config import requirements from ..assertions import eq_ @@ -11,55 +9,47 @@ class TableDDLTest(fixtures.TestBase): __backend__ = True def _simple_fixture(self): - return Table('test_table', self.metadata, - Column('id', Integer, primary_key=True, - autoincrement=False), - Column('data', String(50)) - ) + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) def _underscore_fixture(self): - return Table('_test_table', self.metadata, - Column('id', Integer, primary_key=True, - autoincrement=False), - Column('_data', String(50)) - ) + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) def _simple_roundtrip(self, table): with config.db.begin() as conn: - conn.execute(table.insert().values((1, 'some data'))) + conn.execute(table.insert().values((1, "some data"))) result = conn.execute(table.select()) - eq_( - result.first(), - (1, 'some data') - ) + eq_(result.first(), (1, "some data")) @requirements.create_table @util.provide_metadata def test_create_table(self): table = self._simple_fixture() - table.create( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) self._simple_roundtrip(table) @requirements.drop_table @util.provide_metadata def test_drop_table(self): table = self._simple_fixture() - table.create( - config.db, checkfirst=False - ) - table.drop( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) @requirements.create_table @util.provide_metadata def test_underscore_names(self): table = self._underscore_fixture() - table.create( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) self._simple_roundtrip(table) -__all__ = ('TableDDLTest', ) + +__all__ = ("TableDDLTest",) diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 2c5dd0e364..5e589f3b88 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -15,16 +15,19 @@ class ExceptionTest(fixtures.TablesTest): specific exceptions from real round trips, we need to be conservative. 
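    A representative case is test_integrity_error below: the same
    primary key is inserted twice inside one transaction, and the
    second execute() is expected to raise exc.IntegrityError.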
""" - run_deletes = 'each' + + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) @requirements.duplicate_key_raises_integrity_error def test_integrity_error(self): @@ -33,15 +36,14 @@ class ExceptionTest(fixtures.TablesTest): trans = conn.begin() conn.execute( - self.tables.manual_pk.insert(), - {'id': 1, 'data': 'd1'} + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} ) assert_raises( exc.IntegrityError, conn.execute, self.tables.manual_pk.insert(), - {'id': 1, 'data': 'd1'} + {"id": 1, "data": "d1"}, ) trans.rollback() @@ -49,38 +51,39 @@ class ExceptionTest(fixtures.TablesTest): class AutocommitTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" - __requires__ = 'autocommit', + __requires__ = ("autocommit",) __backend__ = True @classmethod def define_tables(cls, metadata): - Table('some_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)), - test_needs_acid=True - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) def _test_conn_autocommits(self, conn, autocommit): trans = conn.begin() conn.execute( - self.tables.some_table.insert(), - {"id": 1, "data": "some data"} + self.tables.some_table.insert(), {"id": 1, "data": "some data"} ) trans.rollback() eq_( conn.scalar(select([self.tables.some_table.c.id])), - 1 if autocommit else None + 1 if autocommit else None, ) conn.execute(self.tables.some_table.delete()) def test_autocommit_on(self): conn = config.db.connect() - c2 = conn.execution_options(isolation_level='AUTOCOMMIT') + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(c2, True) conn.invalidate() self._test_conn_autocommits(conn, False) @@ -98,7 +101,7 @@ class EscapingTest(fixtures.TestBase): """ m = self.metadata - t = Table('t', m, Column('data', String(50))) + t = Table("t", m, Column("data", String(50))) t.create(config.db) with config.db.begin() as conn: conn.execute(t.insert(), dict(data="some % value")) @@ -107,14 +110,17 @@ class EscapingTest(fixtures.TestBase): eq_( conn.scalar( select([t.c.data]).where( - t.c.data == literal_column("'some % value'")) + t.c.data == literal_column("'some % value'") + ) ), - "some % value" + "some % value", ) eq_( conn.scalar( select([t.c.data]).where( - t.c.data == literal_column("'some %% other value'")) - ), "some %% other value" + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", ) diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index c0b6b18ebd..6257451eb0 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -10,53 +10,48 @@ from ..schema import Table, Column class LastrowidTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True - __requires__ = 'implements_get_lastrowid', 'autoincrement_insert' + __requires__ = "implements_get_lastrowid", "autoincrement_insert" __engine_options__ = {"implicit_returning": False} @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - 
test_needs_autoincrement=True), - Column('data', String(50)) - ) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (config.db.dialect.default_sequence_base, "some data") - ) + eq_(row, (config.db.dialect.default_sequence_base, "some data")) def test_autoincrement_on_insert(self): - config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.autoinc_pk.insert(), data="some data") self._assert_round_trip(self.tables.autoinc_pk, config.db) def test_last_inserted_id(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - r.inserted_primary_key, - [pk] - ) + eq_(r.inserted_primary_key, [pk]) # failed on pypy1.9 but seems to be OK on pypy 2.1 # @exclusions.fails_if(lambda: util.pypy, @@ -65,50 +60,57 @@ class LastrowidTest(fixtures.TablesTest): @requirements.dbapi_lastrowid def test_native_lastrowid_autoinc(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) lastrowid = r.lastrowid pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - lastrowid, pk - ) + eq_(lastrowid, pk) class InsertBehaviorTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) - Table('includes_defaults', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('x', Integer, default=5), - Column('y', Integer, - default=literal_column("2", type_=Integer) + literal(2))) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) def test_autoclose_on_insert(self): if requirements.returning.enabled: engine = engines.testing_engine( - options={'implicit_returning': False}) + options={"implicit_returning": False} + ) else: engine = config.db - r = engine.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + r = engine.execute(self.tables.autoinc_pk.insert(), data="some data") assert r._soft_closed assert not r.closed assert r.is_insert @@ -117,8 +119,7 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.returning def 
test_autoclose_on_insert_implicit_returning(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) assert r._soft_closed assert not r.closed @@ -127,15 +128,14 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.empty_inserts def test_empty_insert(self): - r = config.db.execute( - self.tables.autoinc_pk.insert(), - ) + r = config.db.execute(self.tables.autoinc_pk.insert()) assert r._soft_closed assert not r.closed r = config.db.execute( - self.tables.autoinc_pk.select(). - where(self.tables.autoinc_pk.c.id != None) + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) ) assert len(r.fetchall()) @@ -150,15 +150,15 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) result = config.db.execute( - dest_table.insert(). - from_select( + dest_table.insert().from_select( ("data",), - select([src_table.c.data]). - where(src_table.c.data.in_(["data2", "data3"])) + select([src_table.c.data]).where( + src_table.c.data.in_(["data2", "data3"]) + ), ) ) @@ -167,7 +167,7 @@ class InsertBehaviorTest(fixtures.TablesTest): result = config.db.execute( select([dest_table.c.data]).order_by(dest_table.c.data) ) - eq_(result.fetchall(), [("data2", ), ("data3", )]) + eq_(result.fetchall(), [("data2",), ("data3",)]) @requirements.insert_from_select def test_insert_from_select_autoinc_no_rows(self): @@ -175,11 +175,11 @@ class InsertBehaviorTest(fixtures.TablesTest): dest_table = self.tables.autoinc_pk result = config.db.execute( - dest_table.insert(). - from_select( + dest_table.insert().from_select( ("data",), - select([src_table.c.data]). - where(src_table.c.data.in_(["data2", "data3"])) + select([src_table.c.data]).where( + src_table.c.data.in_(["data2", "data3"]) + ), ) ) eq_(result.inserted_primary_key, [None]) @@ -199,23 +199,23 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) config.db.execute( - table.insert(inline=True). - from_select(("id", "data",), - select([table.c.id + 5, table.c.data]). - where(table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True).from_select( + ("id", "data"), + select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"]) + ), + ) ) eq_( config.db.execute( select([table.c.data]).order_by(table.c.data) ).fetchall(), - [("data1", ), ("data2", ), ("data2", ), - ("data3", ), ("data3", )] + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], ) @requirements.insert_from_select @@ -227,56 +227,60 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) config.db.execute( - table.insert(inline=True). - from_select(("id", "data",), - select([table.c.id + 5, table.c.data]). 
- where(table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True).from_select( + ("id", "data"), + select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"]) + ), + ) ) eq_( config.db.execute( select([table]).order_by(table.c.data, table.c.id) ).fetchall(), - [(1, 'data1', 5, 4), (2, 'data2', 5, 4), - (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)] + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], ) class ReturningTest(fixtures.TablesTest): - run_create_tables = 'each' - __requires__ = 'returning', 'autoincrement_insert' + run_create_tables = "each" + __requires__ = "returning", "autoincrement_insert" __backend__ = True __engine_options__ = {"implicit_returning": True} def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (config.db.dialect.default_sequence_base, "some data") - ) + eq_(row, (config.db.dialect.default_sequence_base, "some data")) @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @requirements.fetch_rows_post_commit def test_explicit_returning_pk_autocommit(self): engine = config.db table = self.tables.autoinc_pk r = engine.execute( - table.insert().returning( - table.c.id), - data="some data" + table.insert().returning(table.c.id), data="some data" ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) @@ -287,9 +291,7 @@ class ReturningTest(fixtures.TablesTest): table = self.tables.autoinc_pk with engine.begin() as conn: r = conn.execute( - table.insert().returning( - table.c.id), - data="some data" + table.insert().returning(table.c.id), data="some data" ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) @@ -297,23 +299,16 @@ class ReturningTest(fixtures.TablesTest): def test_autoincrement_on_insert_implcit_returning(self): - config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.autoinc_pk.insert(), data="some data") self._assert_round_trip(self.tables.autoinc_pk, config.db) def test_last_inserted_id_implicit_returning(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - r.inserted_primary_key, - [pk] - ) + eq_(r.inserted_primary_key, [pk]) -__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest') +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 00a5aac018..bfed5f1ab5 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1,5 +1,3 @@ - - import sqlalchemy as sa from sqlalchemy import exc as sa_exc from sqlalchemy import types as sql_types @@ -26,10 +24,12 @@ class HasTableTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) def 
test_has_table(self): with config.db.begin() as conn: @@ -46,8 +46,10 @@ class ComponentReflectionTest(fixtures.TablesTest): def setup_bind(cls): if config.requirements.independent_connections.enabled: from sqlalchemy import pool + return engines.testing_engine( - options=dict(poolclass=pool.StaticPool)) + options=dict(poolclass=pool.StaticPool) + ) else: return config.db @@ -65,86 +67,109 @@ class ComponentReflectionTest(fixtures.TablesTest): schema_prefix = "" if testing.requires.self_referential_foreign_keys.enabled: - users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - Column('parent_user_id', sa.Integer, - sa.ForeignKey('%susers.user_id' % - schema_prefix, - name='user_id_fk')), - schema=schema, - test_needs_fk=True, - ) + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ), + schema=schema, + test_needs_fk=True, + ) else: - users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - schema=schema, - test_needs_fk=True, - ) - - Table("dingalings", metadata, - Column('dingaling_id', sa.Integer, primary_key=True), - Column('address_id', sa.Integer, - sa.ForeignKey('%semail_addresses.address_id' % - schema_prefix)), - Column('data', sa.String(30)), - schema=schema, - test_needs_fk=True, - ) - Table('email_addresses', metadata, - Column('address_id', sa.Integer), - Column('remote_user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(20)), - sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), - schema=schema, - test_needs_fk=True, - ) - Table('comment_test', metadata, - Column('id', sa.Integer, primary_key=True, comment='id comment'), - Column('data', sa.String(20), comment='data % comment'), - Column( - 'd2', sa.String(20), - comment=r"""Comment types type speedily ' " \ '' Fun!"""), - schema=schema, - comment=r"""the test % ' " \ table comment""") + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column("data", sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column( + "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) + ), + Column("email_address", sa.String(20)), + sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) if testing.requires.cross_schema_fk_reflection.enabled: if schema is None: Table( - 
'local_table', metadata, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(20)), + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), Column( - 'remote_id', + "remote_id", ForeignKey( - '%s.remote_table_2.id' % - testing.config.test_schema) + "%s.remote_table_2.id" % testing.config.test_schema + ), ), test_needs_fk=True, - schema=config.db.dialect.default_schema_name + schema=config.db.dialect.default_schema_name, ) else: Table( - 'remote_table', metadata, - Column('id', sa.Integer, primary_key=True), + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), Column( - 'local_id', + "local_id", ForeignKey( - '%s.local_table.id' % - config.db.dialect.default_schema_name) + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), ), - Column('data', sa.String(20)), + Column("data", sa.String(20)), schema=schema, test_needs_fk=True, ) Table( - 'remote_table_2', metadata, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(20)), + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), schema=schema, test_needs_fk=True, ) @@ -155,19 +180,21 @@ class ComponentReflectionTest(fixtures.TablesTest): if not schema: # test_needs_fk is at the moment to force MySQL InnoDB noncol_idx_test_nopk = Table( - 'noncol_idx_test_nopk', metadata, - Column('q', sa.String(5)), + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), test_needs_fk=True, ) noncol_idx_test_pk = Table( - 'noncol_idx_test_pk', metadata, - Column('id', sa.Integer, primary_key=True), - Column('q', sa.String(5)), + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), test_needs_fk=True, ) - Index('noncol_idx_nopk', noncol_idx_test_nopk.c.q.desc()) - Index('noncol_idx_pk', noncol_idx_test_pk.c.q.desc()) + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) if testing.requires.view_column_reflection.enabled: cls.define_views(metadata, schema) @@ -180,34 +207,35 @@ class ComponentReflectionTest(fixtures.TablesTest): # temp table fixture if testing.against("oracle"): kw = { - 'prefixes': ["GLOBAL TEMPORARY"], - 'oracle_on_commit': 'PRESERVE ROWS' + "prefixes": ["GLOBAL TEMPORARY"], + "oracle_on_commit": "PRESERVE ROWS", } else: - kw = { - 'prefixes': ["TEMPORARY"], - } + kw = {"prefixes": ["TEMPORARY"]} user_tmp = Table( - "user_tmp", metadata, + "user_tmp", + metadata, Column("id", sa.INT, primary_key=True), - Column('name', sa.VARCHAR(50)), - Column('foo', sa.INT), - sa.UniqueConstraint('name', name='user_tmp_uq'), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + sa.UniqueConstraint("name", name="user_tmp_uq"), sa.Index("user_tmp_ix", "foo"), **kw ) - if testing.requires.view_reflection.enabled and \ - testing.requires.temporary_views.enabled: - event.listen( - user_tmp, "after_create", - DDL("create temporary view user_tmp_v as " - "select * from user_tmp") - ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): event.listen( - user_tmp, "before_drop", - DDL("drop view user_tmp_v") + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp" + ), ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) @classmethod def define_index(cls, metadata, users): @@ -216,23 +244,19 @@ class 
ComponentReflectionTest(fixtures.TablesTest): @classmethod def define_views(cls, metadata, schema): - for table_name in ('users', 'email_addresses'): + for table_name in ("users", "email_addresses"): fullname = table_name if schema: fullname = "%s.%s" % (schema, table_name) - view_name = fullname + '_v' + view_name = fullname + "_v" query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, fullname) - - event.listen( - metadata, - "after_create", - DDL(query) + view_name, + fullname, ) + + event.listen(metadata, "after_create", DDL(query)) event.listen( - metadata, - "before_drop", - DDL("DROP VIEW %s" % view_name) + metadata, "before_drop", DDL("DROP VIEW %s" % view_name) ) @testing.requires.schema_reflection @@ -244,9 +268,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_dialect_initialize(self): engine = engines.testing_engine() - assert not hasattr(engine.dialect, 'default_schema_name') + assert not hasattr(engine.dialect, "default_schema_name") inspect(engine) - assert hasattr(engine.dialect, 'default_schema_name') + assert hasattr(engine.dialect, "default_schema_name") @testing.requires.schema_reflection def test_get_default_schema_name(self): @@ -254,40 +278,49 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) @testing.provide_metadata - def _test_get_table_names(self, schema=None, table_type='table', - order_by=None): + def _test_get_table_names( + self, schema=None, table_type="table", order_by=None + ): _ignore_tables = [ - 'comment_test', 'noncol_idx_test_pk', 'noncol_idx_test_nopk', - 'local_table', 'remote_table', 'remote_table_2' + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", ] meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) - if table_type == 'view': + if table_type == "view": table_names = insp.get_view_names(schema) table_names.sort() - answer = ['email_addresses_v', 'users_v'] + answer = ["email_addresses_v", "users_v"] eq_(sorted(table_names), answer) else: table_names = [ - t for t in insp.get_table_names( - schema, - order_by=order_by) if t not in _ignore_tables] + t + for t in insp.get_table_names(schema, order_by=order_by) + if t not in _ignore_tables + ] - if order_by == 'foreign_key': - answer = ['users', 'email_addresses', 'dingalings'] + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] eq_(table_names, answer) else: - answer = ['dingalings', 'email_addresses', 'users'] + answer = ["dingalings", "email_addresses", "users"] eq_(sorted(table_names), answer) @testing.requires.temp_table_names def test_get_temp_table_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ['user_tmp']) + eq_(sorted(temp_table_names), ["user_tmp"]) @testing.requires.view_reflection @testing.requires.temp_table_names @@ -295,7 +328,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_view_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_view_names() - eq_(sorted(temp_table_names), ['user_tmp_v']) + eq_(sorted(temp_table_names), ["user_tmp_v"]) @testing.requires.table_reflection def test_get_table_names(self): @@ -304,7 
+337,7 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection @testing.requires.foreign_key_constraint_reflection def test_get_table_names_fks(self): - self._test_get_table_names(order_by='foreign_key') + self._test_get_table_names(order_by="foreign_key") @testing.requires.comment_reflection def test_get_comments(self): @@ -320,26 +353,24 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_( insp.get_table_comment("comment_test", schema=schema), - {"text": r"""the test % ' " \ table comment"""} + {"text": r"""the test % ' " \ table comment"""}, ) - eq_( - insp.get_table_comment("users", schema=schema), - {"text": None} - ) + eq_(insp.get_table_comment("users", schema=schema), {"text": None}) eq_( [ - {"name": rec['name'], "comment": rec['comment']} - for rec in - insp.get_columns("comment_test", schema=schema) + {"name": rec["name"], "comment": rec["comment"]} + for rec in insp.get_columns("comment_test", schema=schema) ], [ - {'comment': 'id comment', 'name': 'id'}, - {'comment': 'data % comment', 'name': 'data'}, - {'comment': r"""Comment types type speedily ' " \ '' Fun!""", - 'name': 'd2'} - ] + {"comment": "id comment", "name": "id"}, + {"comment": "data % comment", "name": "data"}, + { + "comment": r"""Comment types type speedily ' " \ '' Fun!""", + "name": "d2", + }, + ], ) @testing.requires.table_reflection @@ -349,30 +380,33 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.view_column_reflection def test_get_view_names(self): - self._test_get_table_names(table_type='view') + self._test_get_table_names(table_type="view") @testing.requires.view_column_reflection @testing.requires.schemas def test_get_view_names_with_schema(self): self._test_get_table_names( - testing.config.test_schema, table_type='view') + testing.config.test_schema, table_type="view" + ) @testing.requires.table_reflection @testing.requires.view_column_reflection def test_get_tables_and_views(self): self._test_get_table_names() - self._test_get_table_names(table_type='view') + self._test_get_table_names(table_type="view") - def _test_get_columns(self, schema=None, table_type='table'): + def _test_get_columns(self, schema=None, table_type="table"): meta = MetaData(testing.db) - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings - table_names = ['users', 'email_addresses'] - if table_type == 'view': - table_names = ['users_v', 'email_addresses_v'] + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) + table_names = ["users", "email_addresses"] + if table_type == "view": + table_names = ["users_v", "email_addresses_v"] insp = inspect(meta.bind) - for table_name, table in zip(table_names, (users, - addresses)): + for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) self.assert_(len(cols) > 0, len(cols)) @@ -380,36 +414,46 @@ class ComponentReflectionTest(fixtures.TablesTest): # should be in order for i, col in enumerate(table.columns): - eq_(col.name, cols[i]['name']) - ctype = cols[i]['type'].__class__ + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ ctype_def = col.type if isinstance(ctype_def, sa.types.TypeEngine): ctype_def = ctype_def.__class__ # Oracle returns Date for DateTime. 
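            # Sketch of the intersection check below with hypothetical
            # types: a column declared as sa.CHAR(5) may come back from
            # reflection as a dialect-specific CHAR subclass, but both
            # classes keep sql_types.String in their __mro__, so
            #
            #     set(ctype.__mro__) & set(ctype_def.__mro__)
            #
            # still meets the generic-type list and the assertion holds.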
- if testing.against('oracle') and ctype_def \ - in (sql_types.Date, sql_types.DateTime): + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): ctype_def = sql_types.Date # assert that the desired type and return type share # a base within one of the generic types. - self.assert_(len(set(ctype.__mro__). - intersection(ctype_def.__mro__). - intersection([ - sql_types.Integer, - sql_types.Numeric, - sql_types.DateTime, - sql_types.Date, - sql_types.Time, - sql_types.String, - sql_types._Binary, - ])) > 0, '%s(%s), %s(%s)' % - (col.name, col.type, cols[i]['name'], ctype)) + self.assert_( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) if not col.primary_key: - assert cols[i]['default'] is None + assert cols[i]["default"] is None @testing.requires.table_reflection def test_get_columns(self): @@ -417,24 +461,20 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _type_round_trip(self, *types): - t = Table('t', self.metadata, - *[ - Column('t%d' % i, type_) - for i, type_ in enumerate(types) - ] - ) + t = Table( + "t", + self.metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) t.create() return [ - c['type'] for c in - inspect(self.metadata.bind).get_columns('t') + c["type"] for c in inspect(self.metadata.bind).get_columns("t") ] @testing.requires.table_reflection def test_numeric_reflection(self): - for typ in self._type_round_trip( - sql_types.Numeric(18, 5), - ): + for typ in self._type_round_trip(sql_types.Numeric(18, 5)): assert isinstance(typ, sql_types.Numeric) eq_(typ.precision, 18) eq_(typ.scale, 5) @@ -448,16 +488,19 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection @testing.provide_metadata def test_nullable_reflection(self): - t = Table('t', self.metadata, - Column('a', Integer, nullable=True), - Column('b', Integer, nullable=False)) + t = Table( + "t", + self.metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) t.create() eq_( dict( - (col['name'], col['nullable']) - for col in inspect(self.metadata.bind).get_columns('t') + (col["name"], col["nullable"]) + for col in inspect(self.metadata.bind).get_columns("t") ), - {"a": True, "b": False} + {"a": True, "b": False}, ) @testing.requires.table_reflection @@ -470,32 +513,30 @@ class ComponentReflectionTest(fixtures.TablesTest): meta = MetaData(self.bind) user_tmp = self.tables.user_tmp insp = inspect(meta.bind) - cols = insp.get_columns('user_tmp') + cols = insp.get_columns("user_tmp") self.assert_(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): - eq_(col.name, cols[i]['name']) + eq_(col.name, cols[i]["name"]) @testing.requires.temp_table_reflection @testing.requires.view_column_reflection @testing.requires.temporary_views def test_get_temp_view_columns(self): insp = inspect(self.bind) - cols = insp.get_columns('user_tmp_v') - eq_( - [col['name'] for col in cols], - ['id', 'name', 'foo'] - ) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) @testing.requires.view_column_reflection def test_get_view_columns(self): - self._test_get_columns(table_type='view') + self._test_get_columns(table_type="view") 
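    # For orientation: the Inspector API these helpers exercise also
    # works standalone, e.g. (SQLite URL as a stand-in backend):
    #
    #     insp = inspect(create_engine("sqlite://"))
    #     cols = insp.get_columns("users")
    #     # each entry is a dict carrying at least "name", "type",
    #     # "nullable" and "default" keys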
@testing.requires.view_column_reflection @testing.requires.schemas def test_get_view_columns_with_schema(self): self._test_get_columns( - schema=testing.config.test_schema, table_type='view') + schema=testing.config.test_schema, table_type="view" + ) @testing.provide_metadata def _test_get_pk_constraint(self, schema=None): @@ -504,15 +545,15 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) users_cons = insp.get_pk_constraint(users.name, schema=schema) - users_pkeys = users_cons['constrained_columns'] - eq_(users_pkeys, ['user_id']) + users_pkeys = users_cons["constrained_columns"] + eq_(users_pkeys, ["user_id"]) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) - addr_pkeys = addr_cons['constrained_columns'] - eq_(addr_pkeys, ['address_id']) + addr_pkeys = addr_cons["constrained_columns"] + eq_(addr_pkeys, ["address_id"]) with testing.requires.reflects_pk_names.fail_if(): - eq_(addr_cons['name'], 'email_ad_pk') + eq_(addr_cons["name"], "email_ad_pk") @testing.requires.primary_key_constraint_reflection def test_get_pk_constraint(self): @@ -534,44 +575,46 @@ class ComponentReflectionTest(fixtures.TablesTest): sa_exc.SADeprecationWarning, "Call to deprecated method get_primary_keys." " Use get_pk_constraint instead.", - insp.get_primary_keys, users.name + insp.get_primary_keys, + users.name, ) @testing.provide_metadata def _test_get_foreign_keys(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) expected_schema = schema # users if testing.requires.self_referential_foreign_keys.enabled: - users_fkeys = insp.get_foreign_keys(users.name, - schema=schema) + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) fkey1 = users_fkeys[0] with testing.requires.named_constraints.fail_if(): - eq_(fkey1['name'], "user_id_fk") + eq_(fkey1["name"], "user_id_fk") - eq_(fkey1['referred_schema'], expected_schema) - eq_(fkey1['referred_table'], users.name) - eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) if testing.requires.self_referential_foreign_keys.enabled: - eq_(fkey1['constrained_columns'], ['parent_user_id']) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) # addresses - addr_fkeys = insp.get_foreign_keys(addresses.name, - schema=schema) + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] with testing.requires.implicitly_named_constraints.fail_if(): - self.assert_(fkey1['name'] is not None) + self.assert_(fkey1["name"] is not None) - eq_(fkey1['referred_schema'], expected_schema) - eq_(fkey1['referred_table'], users.name) - eq_(fkey1['referred_columns'], ['user_id', ]) - eq_(fkey1['constrained_columns'], ['remote_user_id']) + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) @testing.requires.foreign_key_constraint_reflection def test_get_foreign_keys(self): @@ -586,9 +629,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schemas def test_get_inter_schema_foreign_keys(self): local_table, remote_table, remote_table_2 = self.tables( - '%s.local_table' % 
testing.db.dialect.default_schema_name, - '%s.remote_table' % testing.config.test_schema, - '%s.remote_table_2' % testing.config.test_schema + "%s.local_table" % testing.db.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, ) insp = inspect(config.db) @@ -597,25 +640,25 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(len(local_fkeys), 1) fkey1 = local_fkeys[0] - eq_(fkey1['referred_schema'], testing.config.test_schema) - eq_(fkey1['referred_table'], remote_table_2.name) - eq_(fkey1['referred_columns'], ['id', ]) - eq_(fkey1['constrained_columns'], ['remote_id']) + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) remote_fkeys = insp.get_foreign_keys( - remote_table.name, schema=testing.config.test_schema) + remote_table.name, schema=testing.config.test_schema + ) eq_(len(remote_fkeys), 1) fkey2 = remote_fkeys[0] - assert fkey2['referred_schema'] in ( + assert fkey2["referred_schema"] in ( None, - testing.db.dialect.default_schema_name + testing.db.dialect.default_schema_name, ) - eq_(fkey2['referred_table'], local_table.name) - eq_(fkey2['referred_columns'], ['id', ]) - eq_(fkey2['constrained_columns'], ['local_id']) - + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) @testing.requires.foreign_key_constraint_option_reflection_ondelete def test_get_foreign_key_options_ondelete(self): @@ -630,26 +673,32 @@ class ComponentReflectionTest(fixtures.TablesTest): meta = self.metadata Table( - 'x', meta, - Column('id', Integer, primary_key=True), - test_needs_fk=True - ) - - Table('table', meta, - Column('id', Integer, primary_key=True), - Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')), - Column('test', String(10)), - test_needs_fk=True) - - Table('user', meta, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('tid', Integer), - sa.ForeignKeyConstraint( - ['tid'], ['table.id'], - name='myfk', - **options), - test_needs_fk=True) + "x", + meta, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + meta, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) meta.create_all() @@ -657,49 +706,44 @@ class ComponentReflectionTest(fixtures.TablesTest): # test 'options' is always present for a backend # that can reflect these, since alembic looks for this - opts = insp.get_foreign_keys('table')[0]['options'] + opts = insp.get_foreign_keys("table")[0]["options"] - eq_( - dict( - (k, opts[k]) - for k in opts if opts[k] - ), - {} - ) + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) - opts = insp.get_foreign_keys('user')[0]['options'] - eq_( - dict( - (k, opts[k]) - for k in opts if opts[k] - ), - options - ) + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(dict((k, opts[k]) for k in opts if opts[k]), options) def _assert_insp_indexes(self, indexes, expected_indexes): - index_names = [d['name'] for d in indexes] + 
index_names = [d["name"] for d in indexes] for e_index in expected_indexes: - assert e_index['name'] in index_names - index = indexes[index_names.index(e_index['name'])] + assert e_index["name"] in index_names + index = indexes[index_names.index(e_index["name"])] for key in e_index: eq_(e_index[key], index[key]) @testing.provide_metadata def _test_get_indexes(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. insp = inspect(meta.bind) - indexes = insp.get_indexes('users', schema=schema) + indexes = insp.get_indexes("users", schema=schema) expected_indexes = [ - {'unique': False, - 'column_names': ['test1', 'test2'], - 'name': 'users_t_idx'}, - {'unique': False, - 'column_names': ['user_id', 'test2', 'test1'], - 'name': 'users_all_idx'} + { + "unique": False, + "column_names": ["test1", "test2"], + "name": "users_t_idx", + }, + { + "unique": False, + "column_names": ["user_id", "test2", "test1"], + "name": "users_all_idx", + }, ] self._assert_insp_indexes(indexes, expected_indexes) @@ -721,10 +765,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # reflecting an index that has "x DESC" in it as the column. # the DB may or may not give us "x", but make sure we get the index # back, it has a name, it's connected to the table. - expected_indexes = [ - {'unique': False, - 'name': ixname} - ] + expected_indexes = [{"unique": False, "name": ixname}] self._assert_insp_indexes(indexes, expected_indexes) t = Table(tname, meta, autoload_with=meta.bind) @@ -748,24 +789,30 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.unique_constraint_reflection def test_get_temp_table_unique_constraints(self): insp = inspect(self.bind) - reflected = insp.get_unique_constraints('user_tmp') + reflected = insp.get_unique_constraints("user_tmp") for refl in reflected: # Different dialects handle duplicate index and constraints # differently, so ignore this flag - refl.pop('duplicates_index', None) - eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}]) + refl.pop("duplicates_index", None) + eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}]) @testing.requires.temp_table_reflection def test_get_temp_table_indexes(self): insp = inspect(self.bind) - indexes = insp.get_indexes('user_tmp') + indexes = insp.get_indexes("user_tmp") for ind in indexes: - ind.pop('dialect_options', None) + ind.pop("dialect_options", None) eq_( # TODO: we need to add better filtering for indexes/uq constraints # that are doubled up - [idx for idx in indexes if idx['name'] == 'user_tmp_ix'], - [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}] + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + [ + { + "unique": False, + "column_names": ["foo"], + "name": "user_tmp_ix", + } + ], ) @testing.requires.unique_constraint_reflection @@ -783,36 +830,37 @@ class ComponentReflectionTest(fixtures.TablesTest): # CREATE TABLE? 
uniques = sorted( [ - {'name': 'unique_a', 'column_names': ['a']}, - {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']}, - {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']}, - {'name': 'unique_asc_key', 'column_names': ['asc', 'key']}, - {'name': 'i.have.dots', 'column_names': ['b']}, - {'name': 'i have spaces', 'column_names': ['c']}, + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, ], - key=operator.itemgetter('name') + key=operator.itemgetter("name"), ) orig_meta = self.metadata table = Table( - 'testtbl', orig_meta, - Column('a', sa.String(20)), - Column('b', sa.String(30)), - Column('c', sa.Integer), + "testtbl", + orig_meta, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), # reserved identifiers - Column('asc', sa.String(30)), - Column('key', sa.String(30)), - schema=schema + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, ) for uc in uniques: table.append_constraint( - sa.UniqueConstraint(*uc['column_names'], name=uc['name']) + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) ) orig_meta.create_all() inspector = inspect(orig_meta.bind) reflected = sorted( - inspector.get_unique_constraints('testtbl', schema=schema), - key=operator.itemgetter('name') + inspector.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), ) names_that_duplicate_index = set() @@ -820,25 +868,31 @@ class ComponentReflectionTest(fixtures.TablesTest): for orig, refl in zip(uniques, reflected): # Different dialects handle duplicate index and constraints # differently, so ignore this flag - dupe = refl.pop('duplicates_index', None) + dupe = refl.pop("duplicates_index", None) if dupe: names_that_duplicate_index.add(dupe) eq_(orig, refl) reflected_metadata = MetaData() reflected = Table( - 'testtbl', reflected_metadata, autoload_with=orig_meta.bind, - schema=schema) + "testtbl", + reflected_metadata, + autoload_with=orig_meta.bind, + schema=schema, + ) # test "deduplicates for index" logic. MySQL and Oracle # "unique constraints" are actually unique indexes (with possible # exception of a unique that is a dupe of another one in the case # of Oracle). make sure # they aren't duplicated. 
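        # Concretely: a name such as "unique_a" must surface as either a
        # reflected Index or a reflected UniqueConstraint, not both;
        # "unique_c_a_b" is subtracted below because, per the note
        # above, Oracle may report it as a duplicate.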
idx_names = set([idx.name for idx in reflected.indexes]) - uq_names = set([ - uq.name for uq in reflected.constraints - if isinstance(uq, sa.UniqueConstraint)]).difference( - ['unique_c_a_b']) + uq_names = set( + [ + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + ] + ).difference(["unique_c_a_b"]) assert not idx_names.intersection(uq_names) if names_that_duplicate_index: @@ -858,47 +912,52 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_check_constraints(self, schema=None): orig_meta = self.metadata Table( - 'sa_cc', orig_meta, - Column('a', Integer()), - sa.CheckConstraint('a > 1 AND a < 5', name='cc1'), - sa.CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2'), - schema=schema + "sa_cc", + orig_meta, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), + schema=schema, ) orig_meta.create_all() inspector = inspect(orig_meta.bind) reflected = sorted( - inspector.get_check_constraints('sa_cc', schema=schema), - key=operator.itemgetter('name') + inspector.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), ) # trying to minimize effect of quoting, parenthesis, etc. # may need to add more to this as new dialects get CHECK # constraint reflection support def normalize(sqltext): - return " ".join(re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)) + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) reflected = [ - {"name": item["name"], - "sqltext": normalize(item["sqltext"])} + {"name": item["name"], "sqltext": normalize(item["sqltext"])} for item in reflected ] eq_( reflected, [ - {'name': 'cc1', 'sqltext': 'a > 1 and a < 5'}, - {'name': 'cc2', 'sqltext': 'a = 1 or a > 2 and a < 5'} - ] + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + {"name": "cc2", "sqltext": "a = 1 or a > 2 and a < 5"}, + ], ) @testing.provide_metadata def _test_get_view_definition(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings - view_name1 = 'users_v' - view_name2 = 'email_addresses_v' + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) + view_name1 = "users_v" + view_name2 = "email_addresses_v" insp = inspect(meta.bind) v1 = insp.get_view_definition(view_name1, schema=schema) self.assert_(v1) @@ -918,18 +977,21 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _test_get_table_oid(self, table_name, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) oid = insp.get_table_oid(table_name, schema) self.assert_(isinstance(oid, int)) def test_get_table_oid(self): - self._test_get_table_oid('users') + self._test_get_table_oid("users") @testing.requires.schemas def test_get_table_oid_with_schema(self): - self._test_get_table_oid('users', schema=testing.config.test_schema) + self._test_get_table_oid("users", schema=testing.config.test_schema) @testing.requires.table_reflection @testing.provide_metadata @@ -950,49 +1012,53 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) for tname, cname in [ - ('users', 'user_id'), - ('email_addresses', 'address_id'), - 
('dingalings', 'dingaling_id'), + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), ]: cols = insp.get_columns(tname) - id_ = {c['name']: c for c in cols}[cname] - assert id_.get('autoincrement', True) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) class NormalizedNameTest(fixtures.TablesTest): - __requires__ = 'denormalized_names', + __requires__ = ("denormalized_names",) __backend__ = True @classmethod def define_tables(cls, metadata): Table( - quoted_name('t1', quote=True), metadata, - Column('id', Integer, primary_key=True), + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), ) Table( - quoted_name('t2', quote=True), metadata, - Column('id', Integer, primary_key=True), - Column('t1id', ForeignKey('t1.id')) + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), ) def test_reflect_lowercase_forced_tables(self): m2 = MetaData(testing.db) - t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True) - t1_ref = m2.tables['t1'] + t2_ref = Table(quoted_name("t2", quote=True), m2, autoload=True) + t1_ref = m2.tables["t1"] assert t2_ref.c.t1id.references(t1_ref.c.id) m3 = MetaData(testing.db) - m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2')) - assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id) + m3.reflect(only=lambda name, m: name.lower() in ("t1", "t2")) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) def test_get_table_names(self): tablenames = [ - t for t in inspect(testing.db).get_table_names() - if t.lower() in ("t1", "t2")] + t + for t in inspect(testing.db).get_table_names() + if t.lower() in ("t1", "t2") + ] eq_(tablenames[0].upper(), tablenames[0].lower()) eq_(tablenames[1].upper(), tablenames[1].lower()) -__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest') +__all__ = ("ComponentReflectionTest", "HasTableTest", "NormalizedNameTest") diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index f464d47ebd..247f05cf5f 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -15,14 +15,18 @@ class RowFetchTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - Table('has_dates', metadata, - Column('id', Integer, primary_key=True), - Column('today', DateTime) - ) + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) @classmethod def insert_data(cls): @@ -32,65 +36,51 @@ class RowFetchTest(fixtures.TablesTest): {"id": 1, "data": "d1"}, {"id": 2, "data": "d2"}, {"id": 3, "data": "d3"}, - ] + ], ) config.db.execute( cls.tables.has_dates.insert(), - [ - {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)} - ] + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], ) def test_via_string(self): row = config.db.execute( - self.tables.plain_pk.select(). 
- order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row['id'], 1 - ) - eq_( - row['data'], "d1" - ) + eq_(row["id"], 1) + eq_(row["data"], "d1") def test_via_int(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row[0], 1 - ) - eq_( - row[1], "d1" - ) + eq_(row[0], 1) + eq_(row[1], "d1") def test_via_col_object(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row[self.tables.plain_pk.c.id], 1 - ) - eq_( - row[self.tables.plain_pk.c.data], "d1" - ) + eq_(row[self.tables.plain_pk.c.id], 1) + eq_(row[self.tables.plain_pk.c.data], "d1") @requirements.duplicate_names_in_cursor_description def test_row_with_dupe_names(self): result = config.db.execute( - select([self.tables.plain_pk.c.data, - self.tables.plain_pk.c.data.label('data')]). - order_by(self.tables.plain_pk.c.id) + select( + [ + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ] + ).order_by(self.tables.plain_pk.c.id) ) row = result.first() - eq_(result.keys(), ['data', 'data']) - eq_(row, ('d1', 'd1')) + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) def test_row_w_scalar_select(self): """test that a scalar select as a column is returned as such @@ -101,11 +91,11 @@ class RowFetchTest(fixtures.TablesTest): """ datetable = self.tables.has_dates - s = select([datetable.alias('x').c.today]).as_scalar() - s2 = select([datetable.c.id, s.label('somelabel')]) + s = select([datetable.alias("x").c.today]).as_scalar() + s2 = select([datetable.c.id, s.label("somelabel")]) row = config.db.execute(s2).first() - eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0)) + eq_(row["somelabel"], datetime.datetime(2006, 5, 12, 12, 0, 0)) class PercentSchemaNamesTest(fixtures.TablesTest): @@ -117,29 +107,31 @@ class PercentSchemaNamesTest(fixtures.TablesTest): """ - __requires__ = ('percent_schema_names', ) + __requires__ = ("percent_schema_names",) __backend__ = True @classmethod def define_tables(cls, metadata): - cls.tables.percent_table = Table('percent%table', metadata, - Column("percent%", Integer), - Column( - "spaces % more spaces", Integer), - ) + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) cls.tables.lightweight_percent_table = sql.table( - 'percent%table', sql.column("percent%"), - sql.column("spaces % more spaces") + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), ) def test_single_roundtrip(self): percent_table = self.tables.percent_table for params in [ - {'percent%': 5, 'spaces % more spaces': 12}, - {'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9} + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, ]: config.db.execute(percent_table.insert(), params) self._assert_table() @@ -147,14 +139,15 @@ class PercentSchemaNamesTest(fixtures.TablesTest): def test_executemany_roundtrip(self): percent_table = self.tables.percent_table config.db.execute( - percent_table.insert(), - {'percent%': 5, 'spaces % 
more spaces': 12} + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} ) config.db.execute( percent_table.insert(), - [{'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9}] + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], ) self._assert_table() @@ -163,85 +156,81 @@ class PercentSchemaNamesTest(fixtures.TablesTest): lightweight_percent_table = self.tables.lightweight_percent_table for table in ( - percent_table, - percent_table.alias(), - lightweight_percent_table, - lightweight_percent_table.alias()): + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): eq_( list( config.db.execute( - table.select().order_by(table.c['percent%']) + table.select().order_by(table.c["percent%"]) ) ), - [ - (5, 12), - (7, 11), - (9, 10), - (11, 9) - ] + [(5, 12), (7, 11), (9, 10), (11, 9)], ) eq_( list( config.db.execute( - table.select(). - where(table.c['spaces % more spaces'].in_([9, 10])). - order_by(table.c['percent%']), + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) ) ), - [ - (9, 10), - (11, 9) - ] + [(9, 10), (11, 9)], ) - row = config.db.execute(table.select(). - order_by(table.c['percent%'])).first() - eq_(row['percent%'], 5) - eq_(row['spaces % more spaces'], 12) + row = config.db.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row["percent%"], 5) + eq_(row["spaces % more spaces"], 12) - eq_(row[table.c['percent%']], 5) - eq_(row[table.c['spaces % more spaces']], 12) + eq_(row[table.c["percent%"]], 5) + eq_(row[table.c["spaces % more spaces"]], 12) config.db.execute( percent_table.update().values( - {percent_table.c['spaces % more spaces']: 15} + {percent_table.c["spaces % more spaces"]: 15} ) ) eq_( list( config.db.execute( - percent_table. - select(). 
- order_by(percent_table.c['percent%']) + percent_table.select().order_by( + percent_table.c["percent%"] + ) ) ), - [(5, 15), (7, 15), (9, 15), (11, 15)] + [(5, 15), (7, 15), (9, 15), (11, 15)], ) -class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): - __requires__ = ('server_side_cursors', ) + __requires__ = ("server_side_cursors",) __backend__ = True def _is_server_side(self, cursor): if self.engine.dialect.driver == "psycopg2": return cursor.name - elif self.engine.dialect.driver == 'pymysql': - sscursor = __import__('pymysql.cursors').cursors.SSCursor + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor return isinstance(cursor, sscursor) elif self.engine.dialect.driver == "mysqldb": - sscursor = __import__('MySQLdb.cursors').cursors.SSCursor + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor return isinstance(cursor, sscursor) else: return False def _fixture(self, server_side_cursors): self.engine = engines.testing_engine( - options={'server_side_cursors': server_side_cursors} + options={"server_side_cursors": server_side_cursors} ) return self.engine @@ -251,12 +240,12 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_global_string(self): engine = self._fixture(True) - result = engine.execute('select 1') + result = engine.execute("select 1") assert self._is_server_side(result.cursor) def test_global_text(self): engine = self._fixture(True) - result = engine.execute(text('select 1')) + result = engine.execute(text("select 1")) assert self._is_server_side(result.cursor) def test_global_expr(self): @@ -266,7 +255,7 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_global_off_explicit(self): engine = self._fixture(False) - result = engine.execute(text('select 1')) + result = engine.execute(text("select 1")) # It should be off globally ... 
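The server-side cursor tests in this file toggle the behavior at three scopes: engine-wide via the server_side_cursors option, per-connection via execution_options(stream_results=True), and per-statement on a select() or text() construct. A minimal per-connection sketch; the psycopg2 URL and the generate_series query are assumptions for illustration, psycopg2 being a driver that implements this as a named cursor:

    from sqlalchemy import create_engine, text

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test"  # illustrative URL
    )

    with engine.connect() as conn:
        # stream_results asks the dialect for a server-side cursor, so rows
        # are fetched incrementally rather than fully buffered by the driver
        result = conn.execution_options(stream_results=True).execute(
            text("select generate_series(1, 1000000)")
        )
        for row in result:
            pass  # process each row; memory use stays flat
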
@@ -286,10 +275,11 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): engine = self._fixture(False) # and this one - result = \ - engine.connect().execution_options(stream_results=True).\ - execute('select 1' - ) + result = ( + engine.connect() + .execution_options(stream_results=True) + .execute("select 1") + ) assert self._is_server_side(result.cursor) def test_stmt_enabled_conn_option_disabled(self): @@ -298,9 +288,9 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): s = select([1]).execution_options(stream_results=True) # not this one - result = \ - engine.connect().execution_options(stream_results=False).\ - execute(s) + result = ( + engine.connect().execution_options(stream_results=False).execute(s) + ) assert not self._is_server_side(result.cursor) def test_stmt_option_disabled(self): @@ -329,18 +319,18 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_for_update_string(self): engine = self._fixture(True) - result = engine.execute('SELECT 1 FOR UPDATE') + result = engine.execute("SELECT 1 FOR UPDATE") assert self._is_server_side(result.cursor) def test_text_no_ss(self): engine = self._fixture(False) - s = text('select 42') + s = text("select 42") result = engine.execute(s) assert not self._is_server_side(result.cursor) def test_text_ss_option(self): engine = self._fixture(False) - s = text('select 42').execution_options(stream_results=True) + s = text("select 42").execution_options(stream_results=True) result = engine.execute(s) assert self._is_server_side(result.cursor) @@ -349,19 +339,25 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): md = self.metadata engine = self._fixture(True) - test_table = Table('test_table', md, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) test_table.create(checkfirst=True) - test_table.insert().execute(data='data1') - test_table.insert().execute(data='data2') - eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, 'data1'), (2, 'data2')]) - test_table.update().where( - test_table.c.id == 2).values( - data=test_table.c.data + - ' updated').execute() - eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, 'data1'), (2, 'data2 updated')]) + test_table.insert().execute(data="data1") + test_table.insert().execute(data="data2") + eq_( + test_table.select().order_by(test_table.c.id).execute().fetchall(), + [(1, "data1"), (2, "data2")], + ) + test_table.update().where(test_table.c.id == 2).values( + data=test_table.c.data + " updated" + ).execute() + eq_( + test_table.select().order_by(test_table.c.id).execute().fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) test_table.delete().execute() - eq_(select([func.count('*')]).select_from(test_table).scalar(), 0) + eq_(select([func.count("*")]).select_from(test_table).scalar(), 0) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 73ce02492f..032b68eb62 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -16,10 +16,12 @@ class CollateTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(100)) - ) + Table( + "some_table", + metadata, + 
Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) @classmethod def insert_data(cls): @@ -28,26 +30,21 @@ class CollateTest(fixtures.TablesTest): [ {"id": 1, "data": "collate data1"}, {"id": 2, "data": "collate data2"}, - ] + ], ) def _assert_result(self, select, result): - eq_( - config.db.execute(select).fetchall(), - result - ) + eq_(config.db.execute(select).fetchall(), result) @testing.requires.order_by_collation def test_collate_order_by(self): collation = testing.requires.get_order_by_collation(testing.config) self._assert_result( - select([self.tables.some_table]). - order_by(self.tables.some_table.c.data.collate(collation).asc()), - [ - (1, "collate data1"), - (2, "collate data2"), - ] + select([self.tables.some_table]).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], ) @@ -59,17 +56,20 @@ class OrderByLabelTest(fixtures.TablesTest): setting. """ + __backend__ = True @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('q', String(50)), - Column('p', String(50)) - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) @classmethod def insert_data(cls): @@ -79,65 +79,55 @@ class OrderByLabelTest(fixtures.TablesTest): {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, - ] + ], ) def _assert_result(self, select, result): - eq_( - config.db.execute(select).fetchall(), - result - ) + eq_(config.db.execute(select).fetchall(), result) def test_plain(self): table = self.tables.some_table - lx = table.c.x.label('lx') - self._assert_result( - select([lx]).order_by(lx), - [(1, ), (2, ), (3, )] - ) + lx = table.c.x.label("lx") + self._assert_result(select([lx]).order_by(lx), [(1,), (2,), (3,)]) def test_composed_int(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') - self._assert_result( - select([lx]).order_by(lx), - [(3, ), (5, ), (7, )] - ) + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select([lx]).order_by(lx), [(3,), (5,), (7,)]) def test_composed_multiple(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') - ly = (func.lower(table.c.q) + table.c.p).label('ly') + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") self._assert_result( select([lx, ly]).order_by(lx, ly.desc()), - [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))] + [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))], ) def test_plain_desc(self): table = self.tables.some_table - lx = table.c.x.label('lx') + lx = table.c.x.label("lx") self._assert_result( - select([lx]).order_by(lx.desc()), - [(3, ), (2, ), (1, )] + select([lx]).order_by(lx.desc()), [(3,), (2,), (1,)] ) def test_composed_int_desc(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') + lx = (table.c.x + table.c.y).label("lx") self._assert_result( - select([lx]).order_by(lx.desc()), - [(7, ), (5, ), (3, )] + select([lx]).order_by(lx.desc()), [(7,), (5,), (3,)] ) @testing.requires.group_by_complex_expression def test_group_by_composed(self): table = self.tables.some_table - expr = (table.c.x + table.c.y).label('lx') - stmt = 
select([func.count(table.c.id), expr]).group_by(expr).order_by(expr) - self._assert_result( - stmt, - [(1, 3), (1, 5), (1, 7)] + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select([func.count(table.c.id), expr]) + .group_by(expr) + .order_by(expr) ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) class LimitOffsetTest(fixtures.TablesTest): @@ -145,10 +135,13 @@ class LimitOffsetTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) @classmethod def insert_data(cls): @@ -159,20 +152,17 @@ class LimitOffsetTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3}, {"id": 3, "x": 3, "y": 4}, {"id": 4, "x": 4, "y": 5}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_simple_limit(self): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).limit(2), - [(1, 1, 2), (2, 2, 3)] + [(1, 1, 2), (2, 2, 3)], ) @testing.requires.offset @@ -180,7 +170,7 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).offset(2), - [(3, 3, 4), (4, 4, 5)] + [(3, 3, 4), (4, 4, 5)], ) @testing.requires.offset @@ -188,7 +178,7 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).limit(2).offset(1), - [(2, 2, 3), (3, 3, 4)] + [(2, 2, 3), (3, 3, 4)], ) @testing.requires.offset @@ -198,41 +188,40 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table stmt = select([table]).order_by(table.c.id).limit(2).offset(1) sql = stmt.compile( - dialect=config.db.dialect, - compile_kwargs={"literal_binds": True}) + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) sql = str(sql) - self._assert_result( - sql, - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(sql, [(2, 2, 3), (3, 3, 4)]) @testing.requires.bound_limit_offset def test_bound_limit(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).limit(bindparam('l')), + select([table]).order_by(table.c.id).limit(bindparam("l")), [(1, 1, 2), (2, 2, 3)], - params={"l": 2} + params={"l": 2}, ) @testing.requires.bound_limit_offset def test_bound_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).offset(bindparam('o')), + select([table]).order_by(table.c.id).offset(bindparam("o")), [(3, 3, 4), (4, 4, 5)], - params={"o": 2} + params={"o": 2}, ) @testing.requires.bound_limit_offset def test_bound_limit_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id). 
- limit(bindparam("l")).offset(bindparam("o")), + select([table]) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), [(2, 2, 3), (3, 3, 4)], - params={"l": 2, "o": 1} + params={"l": 2, "o": 1}, ) @@ -241,10 +230,13 @@ class CompoundSelectTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) @classmethod def insert_data(cls): @@ -255,14 +247,11 @@ class CompoundSelectTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3}, {"id": 3, "x": 3, "y": 4}, {"id": 4, "x": 4, "y": 5}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_plain_union(self): table = self.tables.some_table @@ -270,10 +259,7 @@ class CompoundSelectTest(fixtures.TablesTest): s2 = select([table]).where(table.c.id == 3) u1 = union(s1, s2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) def test_select_from_plain_union(self): table = self.tables.some_table @@ -281,80 +267,88 @@ class CompoundSelectTest(fixtures.TablesTest): s2 = select([table]).where(table.c.id == 3) u1 = union(s1, s2).alias().select() - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.order_by_col_from_union @testing.requires.parens_in_union_contained_select_w_limit_offset def test_limit_offset_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + ) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.parens_in_union_contained_select_wo_limit_offset def test_order_by_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - order_by(table.c.id) + s1 = select([table]).where(table.c.id == 2).order_by(table.c.id) + s2 = select([table]).where(table.c.id == 3).order_by(table.c.id) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) def test_distinct_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - distinct() - s2 = select([table]).where(table.c.id == 3).\ - distinct() + s1 = select([table]).where(table.c.id == 2).distinct() + s2 = select([table]).where(table.c.id == 3).distinct() u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.parens_in_union_contained_select_w_limit_offset def 
test_limit_offset_in_unions_from_alias(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + ) # this necessarily has double parens u1 = union(s1, s2).alias() self._assert_result( - u1.select().limit(2).order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] ) def test_limit_offset_aliased_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id).alias().select() - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id).alias().select() + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) class ExpandingBoundInTest(fixtures.TablesTest): @@ -362,11 +356,14 @@ class ExpandingBoundInTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('z', String(50))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) @classmethod def insert_data(cls): @@ -377,178 +374,184 @@ class ExpandingBoundInTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3, "z": "z2"}, {"id": 3, "x": 3, "y": 4, "z": "z3"}, {"id": 4, "x": 4, "y": 5, "z": "z4"}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_multiple_empty_sets(self): # test that any anonymous aliasing used by the dialect # is fine with duplicates table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).where( - table.c.y.in_(bindparam('p', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": [], "p": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.y.in_(bindparam("p", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + @testing.requires.tuple_in def test_empty_heterogeneous_tuples(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.z).in_( - bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + @testing.requires.tuple_in def test_empty_homogeneous_tuples(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.y).in_( - bindparam('q', expanding=True))).order_by(table.c.id) - - 
self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_bound_in_scalar(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [(2, ), (3, ), (4, )], - params={"q": [2, 3, 4]}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + @testing.requires.tuple_in def test_bound_in_two_tuple(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.y).in_( - bindparam('q', expanding=True))).order_by(table.c.id) + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) + ) self._assert_result( - stmt, - [(2, ), (3, ), (4, )], - params={"q": [(2, 3), (3, 4), (4, 5)]}, + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} ) @testing.requires.tuple_in def test_bound_in_heterogeneous_two_tuple(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.z).in_( - bindparam('q', expanding=True))).order_by(table.c.id) + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) + ) self._assert_result( stmt, - [(2, ), (3, ), (4, )], + [(2,), (3,), (4,)], params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) def test_empty_set_against_integer(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_empty_set_against_integer_negation(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.notin_(bindparam('q', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [(1, ), (2, ), (3, ), (4, )], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.notin_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + def test_empty_set_against_string(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.z.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.z.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_empty_set_against_string_negation(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.z.notin_(bindparam('q', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [(1, ), (2, ), (3, ), (4, )], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.z.notin_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + def test_null_in_empty_set_is_false(self): - stmt 
= select([ - case( - [ - ( - null().in_(bindparam('foo', value=(), expanding=True)), - true() - ) - ], - else_=false() - ) - ]) - in_( - config.db.execute(stmt).fetchone()[0], - (False, 0) + stmt = select( + [ + case( + [ + ( + null().in_( + bindparam("foo", value=(), expanding=True) + ), + true(), + ) + ], + else_=false(), + ) + ] ) + in_(config.db.execute(stmt).fetchone()[0], (False, 0)) class LikeFunctionsTest(fixtures.TablesTest): __backend__ = True - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) @classmethod def insert_data(cls): @@ -565,7 +568,7 @@ class LikeFunctionsTest(fixtures.TablesTest): {"id": 8, "data": "ab9cdefg"}, {"id": 9, "data": "abcde#fg"}, {"id": 10, "data": "abcd9fg"}, - ] + ], ) def _test(self, expr, expected): @@ -573,8 +576,10 @@ class LikeFunctionsTest(fixtures.TablesTest): with config.db.connect() as conn: rows = { - value for value, in - conn.execute(select([some_table.c.id]).where(expr)) + value + for value, in conn.execute( + select([some_table.c.id]).where(expr) + ) } eq_(rows, expected) @@ -591,7 +596,8 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test( col.startswith(literal_column("'ab%c'")), - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) def test_startswith_escape(self): col = self.tables.some_table.c.data @@ -608,8 +614,9 @@ class LikeFunctionsTest(fixtures.TablesTest): def test_endswith_sqlexpr(self): col = self.tables.some_table.c.data - self._test(col.endswith(literal_column("'e%fg'")), - {1, 2, 3, 4, 5, 6, 7, 8, 9}) + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) def test_endswith_autoescape(self): col = self.tables.some_table.c.data @@ -640,5 +647,3 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) - - diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index f1c00de6b0..15a850fe9e 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -9,140 +9,144 @@ from ..schema import Table, Column class SequenceTest(fixtures.TablesTest): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True - run_create_tables = 'each' + run_create_tables = "each" @classmethod def define_tables(cls, metadata): - Table('seq_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq'), primary_key=True), - Column('data', String(50)) - ) + Table( + "seq_pk", + metadata, + Column("id", Integer, Sequence("tab_id_seq"), primary_key=True), + Column("data", String(50)), + ) - Table('seq_opt_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq', optional=True), - primary_key=True), - Column('data', String(50)) - ) + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq", optional=True), + primary_key=True, + ), + Column("data", String(50)), + ) def test_insert_roundtrip(self): - config.db.execute( - self.tables.seq_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.seq_pk.insert(), data="some data") 
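The insert above relies on the explicit Sequence attached to seq_pk.c.id. Outside the test suite the same pattern looks roughly like the following sketch; PostgreSQL is assumed because standalone sequences need a backend that implements them, and the URL, table, and sequence names are illustrative, not from the patch:

    from sqlalchemy import (
        create_engine,
        Column,
        Integer,
        MetaData,
        Sequence,
        String,
        Table,
    )

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    md = MetaData()
    tab = Table(
        "tab",
        md,
        Column("id", Integer, Sequence("tab_id_seq"), primary_key=True),
        Column("data", String(50)),
    )
    md.create_all(engine)

    r = engine.execute(tab.insert(), data="some data")
    print(r.inserted_primary_key)  # [1] on the first insert

    # executing the Sequence itself fires nextval() and returns the value,
    # the same call pattern test_nextval_direct makes against the default
    print(engine.execute(Sequence("tab_id_seq")))  # 2, since the insert used 1
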
self._assert_round_trip(self.tables.seq_pk, config.db) def test_insert_lastrowid(self): - r = config.db.execute( - self.tables.seq_pk.insert(), - data="some data" - ) - eq_( - r.inserted_primary_key, - [1] - ) + r = config.db.execute(self.tables.seq_pk.insert(), data="some data") + eq_(r.inserted_primary_key, [1]) def test_nextval_direct(self): - r = config.db.execute( - self.tables.seq_pk.c.id.default - ) - eq_( - r, 1 - ) + r = config.db.execute(self.tables.seq_pk.c.id.default) + eq_(r, 1) @requirements.sequences_optional def test_optional_seq(self): r = config.db.execute( - self.tables.seq_opt_pk.insert(), - data="some data" - ) - eq_( - r.inserted_primary_key, - [1] + self.tables.seq_opt_pk.insert(), data="some data" ) + eq_(r.inserted_primary_key, [1]) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (1, "some data") - ) + eq_(row, (1, "some data")) class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True def test_literal_binds_inline_compile(self): table = Table( - 'x', MetaData(), - Column('y', Integer, Sequence('y_seq')), - Column('q', Integer)) + "x", + MetaData(), + Column("y", Integer, Sequence("y_seq")), + Column("q", Integer), + ) stmt = table.insert().values(q=5) seq_nextval = testing.db.dialect.statement_compiler( - statement=None, dialect=testing.db.dialect).visit_sequence( - Sequence("y_seq")) + statement=None, dialect=testing.db.dialect + ).visit_sequence(Sequence("y_seq")) self.assert_compile( stmt, - "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ), + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), literal_binds=True, - dialect=testing.db.dialect) + dialect=testing.db.dialect, + ) class HasSequenceTest(fixtures.TestBase): - __requires__ = 'sequences', + __requires__ = ("sequences",) __backend__ = True def test_has_sequence(self): - s1 = Sequence('user_id_seq') + s1 = Sequence("user_id_seq") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, - 'user_id_seq'), True) + eq_( + testing.db.dialect.has_sequence(testing.db, "user_id_seq"), + True, + ) finally: testing.db.execute(schema.DropSequence(s1)) @testing.requires.schemas def test_has_sequence_schema(self): - s1 = Sequence('user_id_seq', schema=config.test_schema) + s1 = Sequence("user_id_seq", schema=config.test_schema) testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence( - testing.db, 'user_id_seq', schema=config.test_schema), True) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + True, + ) finally: testing.db.execute(schema.DropSequence(s1)) def test_has_sequence_neg(self): - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), - False) + eq_(testing.db.dialect.has_sequence(testing.db, "user_id_seq"), False) @testing.requires.schemas def test_has_sequence_schemas_neg(self): - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema=config.test_schema), - False) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + False, + ) @testing.requires.schemas def test_has_sequence_default_not_in_remote(self): - s1 = Sequence('user_id_seq') + s1 = Sequence("user_id_seq") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema=config.test_schema), - False) + eq_( + 
testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + False, + ) finally: testing.db.execute(schema.DropSequence(s1)) @testing.requires.schemas def test_has_sequence_remote_not_in_default(self): - s1 = Sequence('user_id_seq', schema=config.test_schema) + s1 = Sequence("user_id_seq", schema=config.test_schema) testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), - False) + eq_( + testing.db.dialect.has_sequence(testing.db, "user_id_seq"), + False, + ) finally: testing.db.execute(schema.DropSequence(s1)) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 27c7bb115c..6dfb80915a 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -4,9 +4,24 @@ from .. import fixtures, config from ..assertions import eq_ from ..config import requirements from sqlalchemy import Integer, Unicode, UnicodeText, select, TIMESTAMP -from sqlalchemy import Date, DateTime, Time, MetaData, String, \ - Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, \ - type_coerce, BigInteger +from sqlalchemy import ( + Date, + DateTime, + Time, + MetaData, + String, + Text, + Numeric, + Float, + literal, + Boolean, + cast, + null, + JSON, + and_, + type_coerce, + BigInteger, +) from ..schema import Table, Column from ... import testing import decimal @@ -24,13 +39,17 @@ class _LiteralRoundTripFixture(object): # into a typed column. we can then SELECT it back as its # official type; ideally we'd be able to use CAST here # but MySQL in particular can't CAST fully - t = Table('t', self.metadata, Column('x', type_)) + t = Table("t", self.metadata, Column("x", type_)) t.create() for value in input_: - ins = t.insert().values(x=literal(value)).compile( - dialect=testing.db.dialect, - compile_kwargs=dict(literal_binds=True) + ins = ( + t.insert() + .values(x=literal(value)) + .compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) ) testing.db.execute(ins) @@ -42,40 +61,33 @@ class _LiteralRoundTripFixture(object): class _UnicodeFixture(_LiteralRoundTripFixture): - __requires__ = 'unicode_data', + __requires__ = ("unicode_data",) - data = u("Alors vous imaginez ma surprise, au lever du jour, " - "quand une drôle de petite voix m’a réveillé. Elle " - "disait: « S’il vous plaît… dessine-moi un mouton! »") + data = u( + "Alors vous imaginez ma surprise, au lever du jour, " + "quand une drôle de petite voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi un mouton! 
»" + ) @classmethod def define_tables(cls, metadata): - Table('unicode_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('unicode_data', cls.datatype), - ) + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) def test_round_trip(self): unicode_table = self.tables.unicode_table - config.db.execute( - unicode_table.insert(), - { - 'unicode_data': self.data, - } - ) + config.db.execute(unicode_table.insert(), {"unicode_data": self.data}) - row = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) - ).first() + row = config.db.execute(select([unicode_table.c.unicode_data])).first() - eq_( - row, - (self.data, ) - ) + eq_(row, (self.data,)) assert isinstance(row[0], util.text_type) def test_round_trip_executemany(self): @@ -83,44 +95,29 @@ class _UnicodeFixture(_LiteralRoundTripFixture): config.db.execute( unicode_table.insert(), - [ - { - 'unicode_data': self.data, - } - for i in range(3) - ] + [{"unicode_data": self.data} for i in range(3)], ) rows = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) + select([unicode_table.c.unicode_data]) ).fetchall() - eq_( - rows, - [(self.data, ) for i in range(3)] - ) + eq_(rows, [(self.data,) for i in range(3)]) for row in rows: assert isinstance(row[0], util.text_type) def _test_empty_strings(self): unicode_table = self.tables.unicode_table - config.db.execute( - unicode_table.insert(), - {"unicode_data": u('')} - ) - row = config.db.execute( - select([unicode_table.c.unicode_data]) - ).first() - eq_(row, (u(''),)) + config.db.execute(unicode_table.insert(), {"unicode_data": u("")}) + row = config.db.execute(select([unicode_table.c.unicode_data])).first() + eq_(row, (u(""),)) def test_literal(self): self._literal_round_trip(self.datatype, [self.data], [self.data]) class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): - __requires__ = 'unicode_data', + __requires__ = ("unicode_data",) __backend__ = True datatype = Unicode(255) @@ -131,7 +128,7 @@ class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): - __requires__ = 'unicode_data', 'text_type' + __requires__ = "unicode_data", "text_type" __backend__ = True datatype = UnicodeText() @@ -142,54 +139,47 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): - __requires__ = 'text_type', + __requires__ = ("text_type",) __backend__ = True @classmethod def define_tables(cls, metadata): - Table('text_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('text_data', Text), - ) + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) def test_text_roundtrip(self): text_table = self.tables.text_table - config.db.execute( - text_table.insert(), - {"text_data": 'some text'} - ) - row = config.db.execute( - select([text_table.c.text_data]) - ).first() - eq_(row, ('some text',)) + config.db.execute(text_table.insert(), {"text_data": "some text"}) + row = config.db.execute(select([text_table.c.text_data])).first() + eq_(row, ("some text",)) def test_text_empty_strings(self): text_table = self.tables.text_table - config.db.execute( - text_table.insert(), - {"text_data": ''} - ) - row = config.db.execute( - 
select([text_table.c.text_data]) - ).first() - eq_(row, ('',)) + config.db.execute(text_table.insert(), {"text_data": ""}) + row = config.db.execute(select([text_table.c.text_data])).first() + eq_(row, ("",)) def test_literal(self): self._literal_round_trip(Text, ["some text"], ["some text"]) def test_literal_quoting(self): - data = '''some 'text' hey "hi there" that's text''' + data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(Text, [data], [data]) def test_literal_backslashes(self): - data = r'backslash one \ backslash two \\ end' + data = r"backslash one \ backslash two \\ end" self._literal_round_trip(Text, [data], [data]) def test_literal_percentsigns(self): - data = r'percent % signs %% percent' + data = r"percent % signs %% percent" self._literal_round_trip(Text, [data], [data]) @@ -199,9 +189,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): @requirements.unbounded_varchar def test_nolength_string(self): metadata = MetaData() - foo = Table('foo', metadata, - Column('one', String) - ) + foo = Table("foo", metadata, Column("one", String)) foo.create(config.db) foo.drop(config.db) @@ -210,11 +198,11 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): self._literal_round_trip(String(40), ["some text"], ["some text"]) def test_literal_quoting(self): - data = '''some 'text' hey "hi there" that's text''' + data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(String(40), [data], [data]) def test_literal_backslashes(self): - data = r'backslash one \ backslash two \\ end' + data = r"backslash one \ backslash two \\ end" self._literal_round_trip(String(40), [data], [data]) @@ -223,44 +211,32 @@ class _DateFixture(_LiteralRoundTripFixture): @classmethod def define_tables(cls, metadata): - Table('date_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('date_data', cls.datatype), - ) + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + ) def test_round_trip(self): date_table = self.tables.date_table - config.db.execute( - date_table.insert(), - {'date_data': self.data} - ) + config.db.execute(date_table.insert(), {"date_data": self.data}) - row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + row = config.db.execute(select([date_table.c.date_data])).first() compare = self.compare or self.data - eq_(row, - (compare, )) + eq_(row, (compare,)) assert isinstance(row[0], type(compare)) def test_null(self): date_table = self.tables.date_table - config.db.execute( - date_table.insert(), - {'date_data': None} - ) + config.db.execute(date_table.insert(), {"date_data": None}) - row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + row = config.db.execute(select([date_table.c.date_data])).first() eq_(row, (None,)) @testing.requires.datetime_literals @@ -270,48 +246,49 @@ class _DateFixture(_LiteralRoundTripFixture): class DateTimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime', + __requires__ = ("datetime",) __backend__ = True datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18) class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime_microseconds', + __requires__ = ("datetime_microseconds",) __backend__ = True datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + class TimestampMicrosecondsTest(_DateFixture, 
fixtures.TablesTest): - __requires__ = 'timestamp_microseconds', + __requires__ = ("timestamp_microseconds",) __backend__ = True datatype = TIMESTAMP data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) class TimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'time', + __requires__ = ("time",) __backend__ = True datatype = Time data = datetime.time(12, 57, 18) class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'time_microseconds', + __requires__ = ("time_microseconds",) __backend__ = True datatype = Time data = datetime.time(12, 57, 18, 396) class DateTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date', + __requires__ = ("date",) __backend__ = True datatype = Date data = datetime.date(2012, 10, 15) class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date', 'date_coerces_from_datetime' + __requires__ = "date", "date_coerces_from_datetime" __backend__ = True datatype = Date data = datetime.datetime(2012, 10, 15, 12, 57, 18) @@ -319,14 +296,14 @@ class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime_historic', + __requires__ = ("datetime_historic",) __backend__ = True datatype = DateTime data = datetime.datetime(1850, 11, 10, 11, 52, 35) class DateHistoricTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date_historic', + __requires__ = ("date_historic",) __backend__ = True datatype = Date data = datetime.date(1727, 4, 1) @@ -345,26 +322,21 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): def _round_trip(self, datatype, data): metadata = self.metadata int_table = Table( - 'integer_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('integer_data', datatype), + "integer_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("integer_data", datatype), ) metadata.create_all(config.db) - config.db.execute( - int_table.insert(), - {'integer_data': data} - ) + config.db.execute(int_table.insert(), {"integer_data": data}) - row = config.db.execute( - select([ - int_table.c.integer_data, - ]) - ).first() + row = config.db.execute(select([int_table.c.integer_data])).first() - eq_(row, (data, )) + eq_(row, (data,)) if util.py3k: assert isinstance(row[0], int) @@ -377,12 +349,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.emits_warning(r".*does \*not\* support Decimal objects natively") @testing.provide_metadata - def _do_test(self, type_, input_, output, - filter_=None, check_scale=False): + def _do_test(self, type_, input_, output, filter_=None, check_scale=False): metadata = self.metadata - t = Table('t', metadata, Column('x', type_)) + t = Table("t", metadata, Column("x", type_)) t.create() - t.insert().execute([{'x': x} for x in input_]) + t.insert().execute([{"x": x} for x in input_]) result = {row[0] for row in t.select().execute()} output = set(output) @@ -391,10 +362,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): output = set(filter_(x) for x in output) eq_(result, output) if check_scale: - eq_( - [str(x) for x in result], - [str(x) for x in output], - ) + eq_([str(x) for x in result], [str(x) for x in output]) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_render_literal_numeric(self): @@ -416,8 +384,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): 
self._literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], - [15.7563, ], - filter_=lambda n: n is not None and round(n, 5) or None + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, ) @testing.requires.precision_generic_float_type @@ -425,8 +393,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): self._do_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], - [decimal.Decimal("15.7563827"), ], - check_scale=True + [decimal.Decimal("15.7563827")], + check_scale=True, ) def test_numeric_as_decimal(self): @@ -445,18 +413,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.requires.fetch_null_from_numeric def test_numeric_null_as_decimal(self): - self._do_test( - Numeric(precision=8, scale=4), - [None], - [None], - ) + self._do_test(Numeric(precision=8, scale=4), [None], [None]) @testing.requires.fetch_null_from_numeric def test_numeric_null_as_float(self): self._do_test( - Numeric(precision=8, scale=4, asdecimal=False), - [None], - [None], + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] ) @testing.requires.floats_to_four_decimals @@ -472,15 +434,13 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): Float(precision=8), [15.7563, decimal.Decimal("15.7563")], [15.7563], - filter_=lambda n: n is not None and round(n, 5) or None + filter_=lambda n: n is not None and round(n, 5) or None, ) def test_float_coerce_round_trip(self): expr = 15.7563 - val = testing.db.scalar( - select([literal(expr)]) - ) + val = testing.db.scalar(select([literal(expr)])) eq_(val, expr) # this does not work in MySQL, see #4036, however we choose not @@ -491,34 +451,28 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): def test_decimal_coerce_round_trip(self): expr = decimal.Decimal("15.7563") - val = testing.db.scalar( - select([literal(expr)]) - ) + val = testing.db.scalar(select([literal(expr)])) eq_(val, expr) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_decimal_coerce_round_trip_w_cast(self): expr = decimal.Decimal("15.7563") - val = testing.db.scalar( - select([cast(expr, Numeric(10, 4))]) - ) + val = testing.db.scalar(select([cast(expr, Numeric(10, 4))])) eq_(val, expr) @testing.requires.precision_numerics_general def test_precision_decimal(self): - numbers = set([ - decimal.Decimal("54.234246451650"), - decimal.Decimal("0.004354"), - decimal.Decimal("900.0"), - ]) - - self._do_test( - Numeric(precision=18, scale=12), - numbers, - numbers, + numbers = set( + [ + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ] ) + self._do_test(Numeric(precision=18, scale=12), numbers, numbers) + @testing.requires.precision_numerics_enotation_large def test_enotation_decimal(self): """test exceedingly small decimals. 
@@ -528,25 +482,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set([ - decimal.Decimal('1E-2'), - decimal.Decimal('1E-3'), - decimal.Decimal('1E-4'), - decimal.Decimal('1E-5'), - decimal.Decimal('1E-6'), - decimal.Decimal('1E-7'), - decimal.Decimal('1E-8'), - decimal.Decimal("0.01000005940696"), - decimal.Decimal("0.00000005940696"), - decimal.Decimal("0.00000000000696"), - decimal.Decimal("0.70000000000696"), - decimal.Decimal("696E-12"), - ]) - self._do_test( - Numeric(precision=18, scale=14), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + ] ) + self._do_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self): @@ -554,41 +506,32 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set([ - decimal.Decimal('4E+8'), - decimal.Decimal("5748E+15"), - decimal.Decimal('1.521E+15'), - decimal.Decimal('00000000000000.1E+12'), - ]) - self._do_test( - Numeric(precision=25, scale=2), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + ] ) + self._do_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits def test_many_significant_digits(self): - numbers = set([ - decimal.Decimal("31943874831932418390.01"), - decimal.Decimal("319438950232418390.273596"), - decimal.Decimal("87673.594069654243"), - ]) - self._do_test( - Numeric(precision=38, scale=12), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + ] ) + self._do_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits def test_numeric_no_decimal(self): - numbers = set([ - decimal.Decimal("1.000") - ]) + numbers = set([decimal.Decimal("1.000")]) self._do_test( - Numeric(precision=5, scale=3), - numbers, - numbers, - check_scale=True + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -597,42 +540,32 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('boolean_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('value', Boolean), - Column('unconstrained_value', Boolean(create_constraint=False)), - ) + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) def test_render_literal_bool(self): - self._literal_round_trip( - Boolean(), - [True, False], - [True, False] - ) + self._literal_round_trip(Boolean(), [True, False], [True, False]) def test_round_trip(self): boolean_table = self.tables.boolean_table config.db.execute( boolean_table.insert(), - { - 'id': 1, - 'value': True, - 'unconstrained_value': False - } + {"id": 1, "value": True, "unconstrained_value": False}, ) row = 
config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) + select( + [boolean_table.c.value, boolean_table.c.unconstrained_value] + ) ).first() - eq_( - row, - (True, False) - ) + eq_(row, (True, False)) assert isinstance(row[0], bool) def test_null(self): @@ -640,24 +573,16 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): config.db.execute( boolean_table.insert(), - { - 'id': 1, - 'value': None, - 'unconstrained_value': None - } + {"id": 1, "value": None, "unconstrained_value": None}, ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) + select( + [boolean_table.c.value, boolean_table.c.unconstrained_value] + ) ).first() - eq_( - row, - (None, None) - ) + eq_(row, (None, None)) def test_whereclause(self): # testing "WHERE " renders a compatible expression @@ -667,92 +592,82 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): conn.execute( boolean_table.insert(), [ - {'id': 1, 'value': True, 'unconstrained_value': True}, - {'id': 2, 'value': False, 'unconstrained_value': False} - ] + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], ) eq_( conn.scalar( select([boolean_table.c.id]).where(boolean_table.c.value) ), - 1 + 1, ) eq_( conn.scalar( select([boolean_table.c.id]).where( - boolean_table.c.unconstrained_value) + boolean_table.c.unconstrained_value + ) ), - 1 + 1, ) eq_( conn.scalar( select([boolean_table.c.id]).where(~boolean_table.c.value) ), - 2 + 2, ) eq_( conn.scalar( select([boolean_table.c.id]).where( - ~boolean_table.c.unconstrained_value) + ~boolean_table.c.unconstrained_value + ) ), - 2 + 2, ) - - class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): - __requires__ = 'json_type', + __requires__ = ("json_type",) __backend__ = True datatype = JSON - data1 = { - "key1": "value1", - "key2": "value2" - } + data1 = {"key1": "value1", "key2": "value2"} data2 = { "Key 'One'": "value1", "key two": "value2", - "key three": "value ' three '" + "key three": "value ' three '", } data3 = { "key1": [1, 2, 3], "key2": ["one", "two", "three"], - "key3": [{"four": "five"}, {"six": "seven"}] + "key3": [{"four": "five"}, {"six": "seven"}], } data4 = ["one", "two", "three"] data5 = { "nested": { - "elem1": [ - {"a": "b", "c": "d"}, - {"e": "f", "g": "h"} - ], - "elem2": { - "elem3": {"elem4": "elem5"} - } + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, } } - data6 = { - "a": 5, - "b": "some value", - "c": {"foo": "bar"} - } + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} @classmethod def define_tables(cls, metadata): - Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False), - Column('data', cls.datatype), - Column('nulldata', cls.datatype(none_as_null=True)) - ) + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) def test_round_trip_data1(self): self._test_round_trip(self.data1) @@ -761,99 +676,82 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): data_table = self.tables.data_table config.db.execute( - data_table.insert(), - {'name': 'row1', 'data': data_element} + data_table.insert(), {"name": "row1", "data": data_element} ) - row = config.db.execute( - select([ - data_table.c.data, - ]) - 
).first() + row = config.db.execute(select([data_table.c.data])).first() - eq_(row, (data_element, )) + eq_(row, (data_element,)) def test_round_trip_none_as_sql_null(self): - col = self.tables.data_table.c['nulldata'] + col = self.tables.data_table.c["nulldata"] with config.db.connect() as conn: conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": None} + self.tables.data_table.insert(), {"name": "r1", "data": None} ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(col.is_(null())) + select([self.tables.data_table.c.name]).where( + col.is_(null()) + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def test_round_trip_json_null_as_json_null(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] with config.db.connect() as conn: conn.execute( self.tables.data_table.insert(), - {"name": "r1", "data": JSON.NULL} + {"name": "r1", "data": JSON.NULL}, ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(cast(col, String) == 'null') + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def test_round_trip_none_as_json_null(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] with config.db.connect() as conn: conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": None} + self.tables.data_table.insert(), {"name": "r1", "data": None} ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(cast(col, String) == 'null') + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def _criteria_fixture(self): config.db.execute( self.tables.data_table.insert(), - [{"name": "r1", "data": self.data1}, - {"name": "r2", "data": self.data2}, - {"name": "r3", "data": self.data3}, - {"name": "r4", "data": self.data4}, - {"name": "r5", "data": self.data5}, - {"name": "r6", "data": self.data6}] + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], ) def _test_index_criteria(self, crit, expected, test_literal=True): @@ -861,20 +759,20 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): with config.db.connect() as conn: stmt = select([self.tables.data_table.c.name]).where(crit) - eq_( - conn.scalar(stmt), - expected - ) + eq_(conn.scalar(stmt), expected) if test_literal: - literal_sql = str(stmt.compile( - config.db, compile_kwargs={"literal_binds": True})) + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) eq_(conn.scalar(literal_sql), expected) def test_crit_spaces_in_key(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] # limit the rows here to avoid PG error # "cannot extract field from a non-object", which is @@ -882,76 +780,74 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): self._test_index_criteria( and_( name.in_(["r1", "r2", "r3"]), - cast(col["key two"], String) == '"value2"' + cast(col["key two"], String) == '"value2"', ), - "r2" + "r2", ) 
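(Aside on the pattern above: the JSON criteria tests compare through
cast(col[...], String) rather than against Python values directly,
because backends disagree on the Python type returned for an extracted
JSON scalar, while its JSON text -- '"two"' for the string "two" -- is
stable. A minimal self-contained sketch of the idea, assuming
SQLAlchemy 1.3+ where the SQLite dialect supports the generic JSON
type; the table and data here are illustrative, not the suite's
fixture:

from sqlalchemy import (
    create_engine,
    Column,
    Integer,
    JSON,
    MetaData,
    String,
    Table,
    cast,
    select,
)

engine = create_engine("sqlite://")
metadata = MetaData()
# illustrative table, not the fixture's data_table
demo = Table(
    "demo",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("data", JSON),
)
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(demo.insert(), {"id": 1, "data": ["one", "two"]})
    # demo.c.data[1] indexes into the JSON array; casting the result
    # to String yields the JSON representation of the element, quotes
    # included, which is why the suite matches against '"two"'.
    row = conn.execute(select([cast(demo.c.data[1], String)])).first()
    print(row[0])  # '"two"' on most backends

)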
@config.requirements.json_array_indexes def test_crit_simple_int(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] # limit the rows here to avoid PG error # "cannot extract array element from a non-array", which is # fixed in 9.4 but may exist in 9.3 self._test_index_criteria( - and_(name == 'r4', cast(col[1], String) == '"two"'), - "r4" + and_(name == "r4", cast(col[1], String) == '"two"'), "r4" ) def test_crit_mixed_path(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - cast(col[("key3", 1, "six")], String) == '"seven"', - "r3" + cast(col[("key3", 1, "six")], String) == '"seven"', "r3" ) def test_crit_string_path(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( cast(col[("nested", "elem2", "elem3", "elem4")], String) == '"elem5"', - "r5" + "r5", ) def test_crit_against_string_basic(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["b"], String) == '"some value"'), - "r6" + and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6" ) def test_crit_against_string_coerce_type(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', - cast(col["b"], String) == type_coerce("some value", JSON)), + and_( + name == "r6", + cast(col["b"], String) == type_coerce("some value", JSON), + ), "r6", - test_literal=False + test_literal=False, ) def test_crit_against_int_basic(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["a"], String) == '5'), - "r6" + and_(name == "r6", cast(col["a"], String) == "5"), "r6" ) def test_crit_against_int_coerce_type(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)), + and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)), "r6", - test_literal=False + test_literal=False, ) def test_unicode_round_trip(self): @@ -961,17 +857,17 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): { "name": "r1", "data": { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} - } - } + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, + }, + }, ) eq_( conn.scalar(select([self.tables.data_table.c.data])), { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, }, ) @@ -986,7 +882,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): s = Session(testing.db) - d1 = Data(name='d1', data=None, nulldata=None) + d1 = Data(name="d1", data=None, nulldata=None) s.add(d1) s.commit() @@ -995,24 +891,46 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): ) eq_( s.query( - cast(self.tables.data_table.c.data, String(convert_unicode="force")), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd1').first(), - ("null", None) + cast( + self.tables.data_table.c.data, + String(convert_unicode="force"), + ), + 
cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), ) eq_( s.query( - cast(self.tables.data_table.c.data, String(convert_unicode="force")), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd2').first(), - ("null", None) - ) - - -__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest', - 'DateTest', 'DateTimeTest', 'TextTest', - 'NumericTest', 'IntegerTest', - 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', - 'TimeMicrosecondsTest', 'TimestampMicrosecondsTest', 'TimeTest', - 'DateTimeMicrosecondsTest', - 'DateHistoricTest', 'StringTest', 'BooleanTest') + cast( + self.tables.data_table.c.data, + String(convert_unicode="force"), + ), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +__all__ = ( + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "DateTest", + "DateTimeTest", + "TextTest", + "NumericTest", + "IntegerTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index e4c61e74a4..b232c3a786 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -6,15 +6,17 @@ from ..schema import Table, Column class SimpleUpdateDeleteTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) @classmethod def insert_data(cls): @@ -24,40 +26,29 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): {"id": 1, "data": "d1"}, {"id": 2, "data": "d2"}, {"id": 3, "data": "d3"}, - ] + ], ) def test_update(self): t = self.tables.plain_pk - r = config.db.execute( - t.update().where(t.c.id == 2), - data="d2_new" - ) + r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new") assert not r.is_insert assert not r.returns_rows eq_( config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, "d1"), - (2, "d2_new"), - (3, "d3") - ] + [(1, "d1"), (2, "d2_new"), (3, "d3")], ) def test_delete(self): t = self.tables.plain_pk - r = config.db.execute( - t.delete().where(t.c.id == 2) - ) + r = config.db.execute(t.delete().where(t.c.id == 2)) assert not r.is_insert assert not r.returns_rows eq_( config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, "d1"), - (3, "d3") - ] + [(1, "d1"), (3, "d3")], ) -__all__ = ('SimpleUpdateDeleteTest', ) + +__all__ = ("SimpleUpdateDeleteTest",) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 409d3bda5e..5b015d2142 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import sys import types if jython: + def jython_gc_collect(*args): """aggressive gc.collect for tests.""" gc.collect() @@ -25,9 +26,11 @@ if jython: # "lazy" gc, for VM's that don't GC on refcount == 0 gc_collect = lazy_gc = jython_gc_collect elif pypy: + def pypy_gc_collect(*args): gc.collect() gc.collect() + gc_collect = lazy_gc = 
pypy_gc_collect else: # assume CPython - straight gc.collect, lazy_gc() is a pass @@ -42,11 +45,13 @@ def picklers(): if py2k: try: import cPickle + picklers.add(cPickle) except ImportError: pass import pickle + picklers.add(pickle) # yes, this thing needs this much testing @@ -60,9 +65,9 @@ def round_decimal(value, prec): return round(value, prec) # can also use shift() here but that is 2.6 only - return (value * decimal.Decimal("1" + "0" * prec) - ).to_integral(decimal.ROUND_FLOOR) / \ - pow(10, prec) + return (value * decimal.Decimal("1" + "0" * prec)).to_integral( + decimal.ROUND_FLOOR + ) / pow(10, prec) class RandomSet(set): @@ -137,8 +142,9 @@ def function_named(fn, name): try: fn.__name__ = name except TypeError: - fn = types.FunctionType(fn.__code__, fn.__globals__, name, - fn.__defaults__, fn.__closure__) + fn = types.FunctionType( + fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__ + ) return fn @@ -190,7 +196,7 @@ def provide_metadata(fn, *args, **kw): metadata = schema.MetaData(config.db) self = args[0] - prev_meta = getattr(self, 'metadata', None) + prev_meta = getattr(self, "metadata", None) self.metadata = metadata try: return fn(*args, **kw) @@ -213,8 +219,8 @@ def force_drop_names(*names): try: return fn(*args, **kw) finally: - drop_all_tables( - config.db, inspect(config.db), include_names=names) + drop_all_tables(config.db, inspect(config.db), include_names=names) + return go @@ -234,8 +240,13 @@ class adict(dict): def drop_all_tables(engine, inspector, schema=None, include_names=None): - from sqlalchemy import Column, Table, Integer, MetaData, \ - ForeignKeyConstraint + from sqlalchemy import ( + Column, + Table, + Integer, + MetaData, + ForeignKeyConstraint, + ) from sqlalchemy.schema import DropTable, DropConstraint if include_names is not None: @@ -243,30 +254,35 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None): with engine.connect() as conn: for tname, fkcs in reversed( - inspector.get_sorted_table_and_fkc_names(schema=schema)): + inspector.get_sorted_table_and_fkc_names(schema=schema) + ): if tname: if include_names is not None and tname not in include_names: continue - conn.execute(DropTable( - Table(tname, MetaData(), schema=schema) - )) + conn.execute( + DropTable(Table(tname, MetaData(), schema=schema)) + ) elif fkcs: if not engine.dialect.supports_alter: continue for tname, fkc in fkcs: - if include_names is not None and \ - tname not in include_names: + if ( + include_names is not None + and tname not in include_names + ): continue tb = Table( - tname, MetaData(), - Column('x', Integer), - Column('y', Integer), - schema=schema + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + conn.execute( + DropConstraint( + ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc) + ) ) - conn.execute(DropConstraint( - ForeignKeyConstraint( - [tb.c.x], [tb.c.y], name=fkc) - )) def teardown_events(event_cls): @@ -276,5 +292,5 @@ def teardown_events(event_cls): return fn(*arg, **kw) finally: event_cls._clear() - return decorate + return decorate diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 46e7c54dbb..e0101b14d1 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -15,17 +15,20 @@ from . 
import assertions def setup_filters(): """Set global warning behavior for the test suite.""" - warnings.filterwarnings('ignore', - category=sa_exc.SAPendingDeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SAWarning) + warnings.filterwarnings( + "ignore", category=sa_exc.SAPendingDeprecationWarning + ) + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) # some selected deprecations... - warnings.filterwarnings('error', category=DeprecationWarning) + warnings.filterwarnings("error", category=DeprecationWarning) warnings.filterwarnings( - "ignore", category=DeprecationWarning, message=".*StopIteration") + "ignore", category=DeprecationWarning, message=".*StopIteration" + ) warnings.filterwarnings( - "ignore", category=DeprecationWarning, message=".*inspect.getargspec") + "ignore", category=DeprecationWarning, message=".*inspect.getargspec" + ) def assert_warnings(fn, warning_msgs, regex=False): @@ -36,6 +39,6 @@ def assert_warnings(fn, warning_msgs, regex=False): """ with assertions._expect_warnings( - sa_exc.SAWarning, warning_msgs, regex=regex): + sa_exc.SAWarning, warning_msgs, regex=regex + ): return fn() - diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index cd0ded7d26..e665828014 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -9,15 +9,55 @@ """ -__all__ = ['TypeEngine', 'TypeDecorator', 'UserDefinedType', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text', - 'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME', - 'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', - 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer', - 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime', - 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode', - 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'Enum', - 'Indexable', 'ARRAY', 'JSON'] +__all__ = [ + "TypeEngine", + "TypeDecorator", + "UserDefinedType", + "INT", + "CHAR", + "VARCHAR", + "NCHAR", + "NVARCHAR", + "TEXT", + "Text", + "FLOAT", + "NUMERIC", + "REAL", + "DECIMAL", + "TIMESTAMP", + "DATETIME", + "CLOB", + "BLOB", + "BINARY", + "VARBINARY", + "BOOLEAN", + "BIGINT", + "SMALLINT", + "INTEGER", + "DATE", + "TIME", + "String", + "Integer", + "SmallInteger", + "BigInteger", + "Numeric", + "Float", + "DateTime", + "Date", + "Time", + "LargeBinary", + "Binary", + "Boolean", + "Unicode", + "Concatenable", + "UnicodeText", + "PickleType", + "Interval", + "Enum", + "Indexable", + "ARRAY", + "JSON", +] from .sql.type_api import ( adapt_type, @@ -25,7 +65,7 @@ from .sql.type_api import ( TypeDecorator, Variant, to_instance, - UserDefinedType + UserDefinedType, ) from .sql.sqltypes import ( ARRAY, @@ -78,4 +118,4 @@ from .sql.sqltypes import ( UnicodeText, VARBINARY, VARCHAR, - ) +) diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index d8c28d6afc..103225e2a3 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -5,42 +5,146 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .compat import callable, cmp, reduce, \ - threading, py3k, py33, py36, py2k, jython, pypy, cpython, win32, \ - pickle, dottedgetter, parse_qsl, namedtuple, next, reraise, \ - raise_from_cause, text_type, safe_kwarg, string_types, int_types, \ - binary_type, nested, \ - 
quote_plus, with_metaclass, print_, itertools_filterfalse, u, ue, b,\ - unquote_plus, unquote, b64decode, b64encode, byte_buffer, itertools_filter,\ - iterbytes, StringIO, inspect_getargspec, zip_longest +from .compat import ( + callable, + cmp, + reduce, + threading, + py3k, + py33, + py36, + py2k, + jython, + pypy, + cpython, + win32, + pickle, + dottedgetter, + parse_qsl, + namedtuple, + next, + reraise, + raise_from_cause, + text_type, + safe_kwarg, + string_types, + int_types, + binary_type, + nested, + quote_plus, + with_metaclass, + print_, + itertools_filterfalse, + u, + ue, + b, + unquote_plus, + unquote, + b64decode, + b64encode, + byte_buffer, + itertools_filter, + iterbytes, + StringIO, + inspect_getargspec, + zip_longest, +) -from ._collections import KeyedTuple, ImmutableContainer, immutabledict, \ - Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ - OrderedSet, IdentitySet, OrderedIdentitySet, column_set, \ - column_dict, ordered_column_set, populate_column_dict, unique_list, \ - UniqueAppender, PopulateDict, EMPTY_SET, to_list, to_set, \ - to_column_set, update_copy, flatten_iterator, has_intersection, \ - LRUCache, ScopedRegistry, ThreadLocalRegistry, WeakSequence, \ - coerce_generator_arg, lightweight_named_tuple, collections_abc, \ - has_dupes +from ._collections import ( + KeyedTuple, + ImmutableContainer, + immutabledict, + Properties, + OrderedProperties, + ImmutableProperties, + OrderedDict, + OrderedSet, + IdentitySet, + OrderedIdentitySet, + column_set, + column_dict, + ordered_column_set, + populate_column_dict, + unique_list, + UniqueAppender, + PopulateDict, + EMPTY_SET, + to_list, + to_set, + to_column_set, + update_copy, + flatten_iterator, + has_intersection, + LRUCache, + ScopedRegistry, + ThreadLocalRegistry, + WeakSequence, + coerce_generator_arg, + lightweight_named_tuple, + collections_abc, + has_dupes, +) -from .langhelpers import iterate_attributes, class_hierarchy, \ - portable_instancemethod, unbound_method_to_callable, \ - getargspec_init, format_argspec_init, format_argspec_plus, \ - get_func_kwargs, get_cls_kwargs, decorator, as_interface, \ - memoized_property, memoized_instancemethod, md5_hex, \ - group_expirable_memoized_property, dependencies, decode_slice, \ - monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\ - duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ - classproperty, set_creation_order, warn_exception, warn, NoneType,\ - constructor_copy, methods_equivalent, chop_traceback, asint,\ - generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \ - safe_reraise, quoted_token_parser,\ - get_callable_argspec, only_once, attrsetter, ellipses_string, \ - warn_limited, map_bits, MemoizedSlots, EnsureKWArgType, wrap_callable +from .langhelpers import ( + iterate_attributes, + class_hierarchy, + portable_instancemethod, + unbound_method_to_callable, + getargspec_init, + format_argspec_init, + format_argspec_plus, + get_func_kwargs, + get_cls_kwargs, + decorator, + as_interface, + memoized_property, + memoized_instancemethod, + md5_hex, + group_expirable_memoized_property, + dependencies, + decode_slice, + monkeypatch_proxied_specials, + asbool, + bool_or_str, + coerce_kw_type, + duck_type_collection, + assert_arg_type, + symbol, + dictlike_iteritems, + classproperty, + set_creation_order, + warn_exception, + warn, + NoneType, + constructor_copy, + methods_equivalent, + chop_traceback, + asint, + generic_repr, + counter, + PluginLoader, + hybridproperty, + hybridmethod, + 
safe_reraise, + quoted_token_parser, + get_callable_argspec, + only_once, + attrsetter, + ellipses_string, + warn_limited, + map_bits, + MemoizedSlots, + EnsureKWArgType, + wrap_callable, +) -from .deprecations import warn_deprecated, warn_pending_deprecation, \ - deprecated, pending_deprecation, inject_docstring_text +from .deprecations import ( + warn_deprecated, + warn_pending_deprecation, + deprecated, + pending_deprecation, + inject_docstring_text, +) # things that used to be not always available, # but are now as of current support Python versions diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 43440134ab..67be0e6bfb 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -10,8 +10,13 @@ from __future__ import absolute_import import weakref import operator -from .compat import threading, itertools_filterfalse, string_types, \ - binary_types, collections_abc +from .compat import ( + threading, + itertools_filterfalse, + string_types, + binary_types, + collections_abc, +) from . import py2k import types @@ -77,7 +82,7 @@ class KeyedTuple(AbstractKeyedTuple): t.__dict__.update(zip(labels, vals)) else: labels = [] - t.__dict__['_labels'] = labels + t.__dict__["_labels"] = labels return t @property @@ -139,8 +144,7 @@ class ImmutableContainer(object): class immutabledict(ImmutableContainer, dict): - clear = pop = popitem = setdefault = \ - update = ImmutableContainer._immutable + clear = pop = popitem = setdefault = update = ImmutableContainer._immutable def __new__(cls, *args): new = dict.__new__(cls) @@ -151,7 +155,7 @@ class immutabledict(ImmutableContainer, dict): pass def __reduce__(self): - return immutabledict, (dict(self), ) + return immutabledict, (dict(self),) def union(self, d): if not d: @@ -173,10 +177,10 @@ class immutabledict(ImmutableContainer, dict): class Properties(object): """Provide a __getattr__/__setattr__ interface over a dict.""" - __slots__ = '_data', + __slots__ = ("_data",) def __init__(self, data): - object.__setattr__(self, '_data', data) + object.__setattr__(self, "_data", data) def __len__(self): return len(self._data) @@ -185,7 +189,9 @@ class Properties(object): return iter(list(self._data.values())) def __dir__(self): - return dir(super(Properties, self)) + [str(k) for k in self._data.keys()] + return dir(super(Properties, self)) + [ + str(k) for k in self._data.keys() + ] def __add__(self, other): return list(self) + list(other) @@ -203,10 +209,10 @@ class Properties(object): self._data[key] = obj def __getstate__(self): - return {'_data': self._data} + return {"_data": self._data} def __setstate__(self, state): - object.__setattr__(self, '_data', state['_data']) + object.__setattr__(self, "_data", state["_data"]) def __getattr__(self, key): try: @@ -266,7 +272,7 @@ class ImmutableProperties(ImmutableContainer, Properties): class OrderedDict(dict): """A dict that returns keys/values/items in the order they were added.""" - __slots__ = '_list', + __slots__ = ("_list",) def __reduce__(self): return OrderedDict, (self.items(),) @@ -294,7 +300,7 @@ class OrderedDict(dict): def update(self, ____sequence=None, **kwargs): if ____sequence is not None: - if hasattr(____sequence, 'keys'): + if hasattr(____sequence, "keys"): for key in ____sequence.keys(): self.__setitem__(key, ____sequence[key]) else: @@ -323,6 +329,7 @@ class OrderedDict(dict): return [(key, self[key]) for key in self._list] if py2k: + def itervalues(self): return iter(self.values()) @@ -402,7 +409,7 @@ class 
OrderedSet(set): return self.union(other) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self._list) + return "%s(%r)" % (self.__class__.__name__, self._list) __str__ = __repr__ @@ -502,13 +509,13 @@ class IdentitySet(object): pair = self._members.popitem() return pair[1] except KeyError: - raise KeyError('pop from an empty set') + raise KeyError("pop from an empty set") def clear(self): self._members.clear() def __cmp__(self, other): - raise TypeError('cannot compare sets using cmp()') + raise TypeError("cannot compare sets using cmp()") def __eq__(self, other): if isinstance(other, IdentitySet): @@ -527,8 +534,9 @@ class IdentitySet(object): if len(self) > len(other): return False - for m in itertools_filterfalse(other._members.__contains__, - iter(self._members.keys())): + for m in itertools_filterfalse( + other._members.__contains__, iter(self._members.keys()) + ): return False return True @@ -548,8 +556,9 @@ class IdentitySet(object): if len(self) < len(other): return False - for m in itertools_filterfalse(self._members.__contains__, - iter(other._members.keys())): + for m in itertools_filterfalse( + self._members.__contains__, iter(other._members.keys()) + ): return False return True @@ -635,7 +644,8 @@ class IdentitySet(object): members = self._member_id_tuples() other = _iter_id(iterable) result._members.update( - self._working_set(members).symmetric_difference(other)) + self._working_set(members).symmetric_difference(other) + ) return result def _member_id_tuples(self): @@ -667,10 +677,10 @@ class IdentitySet(object): return iter(self._members.values()) def __hash__(self): - raise TypeError('set objects are unhashable') + raise TypeError("set objects are unhashable") def __repr__(self): - return '%s(%r)' % (type(self).__name__, list(self._members.values())) + return "%s(%r)" % (type(self).__name__, list(self._members.values())) class WeakSequence(object): @@ -689,8 +699,9 @@ class WeakSequence(object): return len(self._storage) def __iter__(self): - return (obj for obj in - (ref() for ref in self._storage) if obj is not None) + return ( + obj for obj in (ref() for ref in self._storage) if obj is not None + ) def __getitem__(self, index): try: @@ -732,6 +743,7 @@ class PopulateDict(dict): self[key] = val = self.creator(key) return val + # Define collections that are capable of storing # ColumnElement objects as hashable keys/elements. 
# At this point, these are mostly historical, things @@ -745,20 +757,21 @@ populate_column_dict = PopulateDict _getters = PopulateDict(operator.itemgetter) _property_getters = PopulateDict( - lambda idx: property(operator.itemgetter(idx))) + lambda idx: property(operator.itemgetter(idx)) +) def unique_list(seq, hashfunc=None): seen = set() seen_add = seen.add if not hashfunc: - return [x for x in seq - if x not in seen - and not seen_add(x)] + return [x for x in seq if x not in seen and not seen_add(x)] else: - return [x for x in seq - if hashfunc(x) not in seen - and not seen_add(hashfunc(x))] + return [ + x + for x in seq + if hashfunc(x) not in seen and not seen_add(hashfunc(x)) + ] class UniqueAppender(object): @@ -773,9 +786,9 @@ class UniqueAppender(object): self._unique = {} if via: self._data_appender = getattr(data, via) - elif hasattr(data, 'append'): + elif hasattr(data, "append"): self._data_appender = data.append - elif hasattr(data, 'add'): + elif hasattr(data, "add"): self._data_appender = data.add def append(self, item): @@ -798,8 +811,9 @@ def coerce_generator_arg(arg): def to_list(x, default=None): if x is None: return default - if not isinstance(x, collections_abc.Iterable) or \ - isinstance(x, string_types + binary_types): + if not isinstance(x, collections_abc.Iterable) or isinstance( + x, string_types + binary_types + ): return [x] elif isinstance(x, list): return x @@ -815,9 +829,7 @@ def has_intersection(set_, iterable): """ # TODO: optimize, write in C, etc. - return bool( - set_.intersection([i for i in iterable if i.__hash__]) - ) + return bool(set_.intersection([i for i in iterable if i.__hash__])) def to_set(x): @@ -854,7 +866,7 @@ def flatten_iterator(x): """ for elem in x: - if not isinstance(elem, str) and hasattr(elem, '__iter__'): + if not isinstance(elem, str) and hasattr(elem, "__iter__"): for y in flatten_iterator(elem): yield y else: @@ -871,9 +883,9 @@ class LRUCache(dict): """ - __slots__ = 'capacity', 'threshold', 'size_alert', '_counter', '_mutex' + __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex" - def __init__(self, capacity=100, threshold=.5, size_alert=None): + def __init__(self, capacity=100, threshold=0.5, size_alert=None): self.capacity = capacity self.threshold = threshold self.size_alert = size_alert @@ -929,10 +941,10 @@ class LRUCache(dict): if size_alert: size_alert = False self.size_alert(self) - by_counter = sorted(dict.values(self), - key=operator.itemgetter(2), - reverse=True) - for item in by_counter[self.capacity:]: + by_counter = sorted( + dict.values(self), key=operator.itemgetter(2), reverse=True + ) + for item in by_counter[self.capacity :]: try: del self[item[0]] except KeyError: @@ -946,17 +958,22 @@ _lw_tuples = LRUCache(100) def lightweight_named_tuple(name, fields): - hash_ = (name, ) + tuple(fields) + hash_ = (name,) + tuple(fields) tp_cls = _lw_tuples.get(hash_) if tp_cls: return tp_cls tp_cls = type( - name, (_LW,), - dict([ - (field, _property_getters[idx]) - for idx, field in enumerate(fields) if field is not None - ] + [('__slots__', ())]) + name, + (_LW,), + dict( + [ + (field, _property_getters[idx]) + for idx, field in enumerate(fields) + if field is not None + ] + + [("__slots__", ())] + ), ) tp_cls._real_fields = fields @@ -1077,6 +1094,7 @@ def has_dupes(sequence, target): return True return False + # .index version. the two __contains__ calls as well # as .index() and isinstance() slow this down. 
# def has_dupes(sequence, target): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index b01471edf9..553624b49b 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -20,9 +20,9 @@ py32 = sys.version_info >= (3, 2) py3k = sys.version_info >= (3, 0) py2k = sys.version_info < (3, 0) py265 = sys.version_info >= (2, 6, 5) -jython = sys.platform.startswith('java') -pypy = hasattr(sys, 'pypy_version_info') -win32 = sys.platform.startswith('win') +jython = sys.platform.startswith("java") +pypy = hasattr(sys, "pypy_version_info") +win32 = sys.platform.startswith("win") cpython = not pypy and not jython # TODO: something better for this ? contextmanager = contextlib.contextmanager @@ -30,8 +30,9 @@ dottedgetter = operator.attrgetter namedtuple = collections.namedtuple next = next -ArgSpec = collections.namedtuple("ArgSpec", - ["args", "varargs", "keywords", "defaults"]) +ArgSpec = collections.namedtuple( + "ArgSpec", ["args", "varargs", "keywords", "defaults"] +) try: import threading @@ -58,40 +59,43 @@ if py3k: from io import BytesIO as byte_buffer from io import StringIO from itertools import zip_longest - from urllib.parse import (quote_plus, unquote_plus, parse_qsl, quote, unquote) - - string_types = str, - binary_types = bytes, + from urllib.parse import ( + quote_plus, + unquote_plus, + parse_qsl, + quote, + unquote, + ) + + string_types = (str,) + binary_types = (bytes,) binary_type = bytes text_type = str - int_types = int, + int_types = (int,) iterbytes = iter itertools_filterfalse = itertools.filterfalse itertools_filter = filter itertools_imap = map - exec_ = getattr(builtins, 'exec') - import_ = getattr(builtins, '__import__') + exec_ = getattr(builtins, "exec") + import_ = getattr(builtins, "__import__") print_ = getattr(builtins, "print") def b(s): return s.encode("latin-1") def b64decode(x): - return base64.b64decode(x.encode('ascii')) - + return base64.b64decode(x.encode("ascii")) def b64encode(x): - return base64.b64encode(x).decode('ascii') + return base64.b64encode(x).decode("ascii") def cmp(a, b): return (a > b) - (a < b) def inspect_getargspec(func): - return ArgSpec( - *inspect_getfullargspec(func)[0:4] - ) + return ArgSpec(*inspect_getfullargspec(func)[0:4]) def reraise(tp, value, tb=None, cause=None): if cause is not None: @@ -110,8 +114,11 @@ if py3k: if py32: callable = callable else: + def callable(fn): - return hasattr(fn, '__call__') + return hasattr(fn, "__call__") + + else: import base64 import ConfigParser as configparser @@ -129,8 +136,8 @@ else: except ImportError: import pickle - string_types = basestring, - binary_types = bytes, + string_types = (basestring,) + binary_types = (bytes,) binary_type = str text_type = unicode int_types = int, long @@ -153,9 +160,9 @@ else: def exec_(func_text, globals_, lcl=None): if lcl is None: - exec('exec func_text in globals_') + exec("exec func_text in globals_") else: - exec('exec func_text in globals_, lcl') + exec("exec func_text in globals_, lcl") def iterbytes(buf): return (ord(byte) for byte in buf) @@ -186,24 +193,32 @@ else: # not as nice as that of Py3K, but at least preserves # the code line where the issue occurred - exec("def reraise(tp, value, tb=None, cause=None):\n" - " if cause is not None:\n" - " assert cause is not value, 'Same cause emitted'\n" - " raise tp, value, tb\n") + exec( + "def reraise(tp, value, tb=None, cause=None):\n" + " if cause is not None:\n" + " assert cause is not value, 'Same cause emitted'\n" + " raise tp, value, tb\n" + ) if py35: 
from inspect import formatannotation def inspect_formatargspec( - args, varargs=None, varkw=None, defaults=None, - kwonlyargs=(), kwonlydefaults={}, annotations={}, - formatarg=str, - formatvarargs=lambda name: '*' + name, - formatvarkw=lambda name: '**' + name, - formatvalue=lambda value: '=' + repr(value), - formatreturns=lambda text: ' -> ' + text, - formatannotation=formatannotation): + args, + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=(), + kwonlydefaults={}, + annotations={}, + formatarg=str, + formatvarargs=lambda name: "*" + name, + formatvarkw=lambda name: "**" + name, + formatvalue=lambda value: "=" + repr(value), + formatreturns=lambda text: " -> " + text, + formatannotation=formatannotation, + ): """Copy formatargspec from python 3.7 standard library. Python 3 has deprecated formatargspec and requested that Signature @@ -221,7 +236,7 @@ if py35: def formatargandannotation(arg): result = formatarg(arg) if arg in annotations: - result += ': ' + formatannotation(annotations[arg]) + result += ": " + formatannotation(annotations[arg]) return result specs = [] @@ -237,7 +252,7 @@ if py35: specs.append(formatvarargs(formatargandannotation(varargs))) else: if kwonlyargs: - specs.append('*') + specs.append("*") if kwonlyargs: for kwonlyarg in kwonlyargs: @@ -249,10 +264,12 @@ if py35: if varkw is not None: specs.append(formatvarkw(formatargandannotation(varkw))) - result = '(' + ', '.join(specs) + ')' - if 'return' in annotations: - result += formatreturns(formatannotation(annotations['return'])) + result = "(" + ", ".join(specs) + ")" + if "return" in annotations: + result += formatreturns(formatannotation(annotations["return"])) return result + + else: from inspect import formatargspec as inspect_formatargspec @@ -330,4 +347,5 @@ def with_metaclass(meta, *bases): if this_bases is None: return type.__new__(cls, name, (), d) return meta(name, bases, d) - return metaclass('temporary_class', None, {}) + + return metaclass("temporary_class", None, {}) diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 9000cc7951..e6612f075d 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -40,8 +40,7 @@ def deprecated(version, message=None, add_deprecation_to_docstring=True): """ if add_deprecation_to_docstring: - header = ".. deprecated:: %s %s" % \ - (version, (message or '')) + header = ".. deprecated:: %s %s" % (version, (message or "")) else: header = None @@ -50,13 +49,18 @@ def deprecated(version, message=None, add_deprecation_to_docstring=True): def decorate(fn): return _decorate_with_warning( - fn, exc.SADeprecationWarning, - message % dict(func=fn.__name__), header) + fn, + exc.SADeprecationWarning, + message % dict(func=fn.__name__), + header, + ) + return decorate -def pending_deprecation(version, message=None, - add_deprecation_to_docstring=True): +def pending_deprecation( + version, message=None, add_deprecation_to_docstring=True +): """Decorates a function and issues a pending deprecation warning on use. :param version: @@ -74,8 +78,7 @@ def pending_deprecation(version, message=None, """ if add_deprecation_to_docstring: - header = ".. deprecated:: %s (pending) %s" % \ - (version, (message or '')) + header = ".. 
deprecated:: %s (pending) %s" % (version, (message or "")) else: header = None @@ -84,8 +87,12 @@ def pending_deprecation(version, message=None, def decorate(fn): return _decorate_with_warning( - fn, exc.SAPendingDeprecationWarning, - message % dict(func=fn.__name__), header) + fn, + exc.SAPendingDeprecationWarning, + message % dict(func=fn.__name__), + header, + ) + return decorate @@ -95,7 +102,8 @@ def _sanitize_restructured_text(text): if type_ in ("func", "meth"): name += "()" return name - return re.sub(r'\:(\w+)\:`~?\.?(.+?)`', repl, text) + + return re.sub(r"\:(\w+)\:`~?\.?(.+?)`", repl, text) def _decorate_with_warning(func, wtype, message, docstring_header=None): @@ -108,7 +116,7 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None): warnings.warn(message, wtype, stacklevel=3) return fn(*args, **kwargs) - doc = func.__doc__ is not None and func.__doc__ or '' + doc = func.__doc__ is not None and func.__doc__ or "" if docstring_header is not None: docstring_header %= dict(func=func.__name__) @@ -118,6 +126,7 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None): decorated.__doc__ = doc return decorated + import textwrap @@ -135,7 +144,7 @@ def _dedent_docstring(text): def inject_docstring_text(doctext, injecttext, pos): doctext = _dedent_docstring(doctext or "") - lines = doctext.split('\n') + lines = doctext.split("\n") injectlines = textwrap.dedent(injecttext).split("\n") if injectlines[0]: injectlines.insert(0, "") diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 7e387f4f25..6a286998b4 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -25,7 +25,7 @@ from . import _collections def md5_hex(x): if compat.py3k: - x = x.encode('utf-8') + x = x.encode("utf-8") m = hashlib.md5() m.update(x) return m.hexdigest() @@ -49,7 +49,7 @@ class safe_reraise(object): """ - __slots__ = ('warn_only', '_exc_info') + __slots__ = ("warn_only", "_exc_info") def __init__(self, warn_only=False): self.warn_only = warn_only @@ -61,7 +61,7 @@ class safe_reraise(object): # see #2703 for notes if type_ is None: exc_type, exc_value, exc_tb = self._exc_info - self._exc_info = None # remove potential circular references + self._exc_info = None # remove potential circular references if not self.warn_only: compat.reraise(exc_type, exc_value, exc_tb) else: @@ -71,8 +71,9 @@ class safe_reraise(object): warn( "An exception has occurred during handling of a " "previous exception. 
The previous exception " - "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])) - self._exc_info = None # remove potential circular references + "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]) + ) + self._exc_info = None # remove potential circular references compat.reraise(type_, value, traceback) @@ -84,7 +85,7 @@ def decode_slice(slc): """ ret = [] for x in slc.start, slc.stop, slc.step: - if hasattr(x, '__index__'): + if hasattr(x, "__index__"): x = x.__index__() ret.append(x) return tuple(ret) @@ -93,9 +94,10 @@ def decode_slice(slc): def _unique_symbols(used, *bases): used = set(used) for base in bases: - pool = itertools.chain((base,), - compat.itertools_imap(lambda i: base + str(i), - range(1000))) + pool = itertools.chain( + (base,), + compat.itertools_imap(lambda i: base + str(i), range(1000)), + ) for sym in pool: if sym not in used: used.add(sym) @@ -122,21 +124,25 @@ def decorator(target): raise Exception("not a decoratable function") spec = compat.inspect_getfullargspec(fn) names = tuple(spec[0]) + spec[1:3] + (fn.__name__,) - targ_name, fn_name = _unique_symbols(names, 'target', 'fn') + targ_name, fn_name = _unique_symbols(names, "target", "fn") metadata = dict(target=targ_name, fn=fn_name) metadata.update(format_argspec_plus(spec, grouped=False)) - metadata['name'] = fn.__name__ - code = """\ + metadata["name"] = fn.__name__ + code = ( + """\ def %(name)s(%(args)s): return %(target)s(%(fn)s, %(apply_kw)s) -""" % metadata - decorated = _exec_code_in_env(code, - {targ_name: target, fn_name: fn}, - fn.__name__) - decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__ +""" + % metadata + ) + decorated = _exec_code_in_env( + code, {targ_name: target, fn_name: fn}, fn.__name__ + ) + decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__ decorated.__wrapped__ = fn return update_wrapper(decorated, fn) + return update_wrapper(decorate, target) @@ -155,31 +161,38 @@ def public_factory(target, location): if isinstance(target, type): fn = target.__init__ callable_ = target - doc = "Construct a new :class:`.%s` object. \n\n"\ - "This constructor is mirrored as a public API function; "\ - "see :func:`~%s` "\ - "for a full usage and argument description." % ( - target.__name__, location, ) + doc = ( + "Construct a new :class:`.%s` object. \n\n" + "This constructor is mirrored as a public API function; " + "see :func:`~%s` " + "for a full usage and argument description." + % (target.__name__, location) + ) else: fn = callable_ = target - doc = "This function is mirrored; see :func:`~%s` "\ + doc = ( + "This function is mirrored; see :func:`~%s` " "for a description of arguments." 
% location + ) location_name = location.split(".")[-1] spec = compat.inspect_getfullargspec(fn) del spec[0][0] metadata = format_argspec_plus(spec, grouped=False) - metadata['name'] = location_name - code = """\ + metadata["name"] = location_name + code = ( + """\ def %(name)s(%(args)s): return cls(%(apply_kw)s) -""" % metadata - env = {'cls': callable_, 'symbol': symbol} +""" + % metadata + ) + env = {"cls": callable_, "symbol": symbol} exec(code, env) decorated = env[location_name] decorated.__doc__ = fn.__doc__ decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0] - if compat.py2k or hasattr(fn, '__func__'): + if compat.py2k or hasattr(fn, "__func__"): fn.__func__.__doc__ = doc else: fn.__doc__ = doc @@ -187,7 +200,6 @@ def %(name)s(%(args)s): class PluginLoader(object): - def __init__(self, group, auto_fn=None): self.group = group self.impls = {} @@ -211,14 +223,13 @@ class PluginLoader(object): except ImportError: pass else: - for impl in pkg_resources.iter_entry_points( - self.group, name): + for impl in pkg_resources.iter_entry_points(self.group, name): self.impls[name] = impl.load return impl.load() raise exc.NoSuchModuleError( - "Can't load plugin: %s:%s" % - (self.group, name)) + "Can't load plugin: %s:%s" % (self.group, name) + ) def register(self, name, modulepath, objname): def load(): @@ -226,6 +237,7 @@ class PluginLoader(object): for token in modulepath.split(".")[1:]: mod = getattr(mod, token) return getattr(mod, objname) + self.impls[name] = load @@ -245,10 +257,13 @@ def get_cls_kwargs(cls, _set=None): if toplevel: _set = set() - ctr = cls.__dict__.get('__init__', False) + ctr = cls.__dict__.get("__init__", False) - has_init = ctr and isinstance(ctr, types.FunctionType) and \ - isinstance(ctr.__code__, types.CodeType) + has_init = ( + ctr + and isinstance(ctr, types.FunctionType) + and isinstance(ctr.__code__, types.CodeType) + ) if has_init: names, has_kw = inspect_func_args(ctr) @@ -262,7 +277,7 @@ def get_cls_kwargs(cls, _set=None): if get_cls_kwargs(c, _set) is None: break - _set.discard('self') + _set.discard("self") return _set @@ -278,7 +293,9 @@ try: has_kw = bool(co.co_flags & CO_VARKEYWORDS) return args, has_kw + except ImportError: + def inspect_func_args(fn): names, _, has_kw, _ = compat.inspect_getargspec(fn) return names, bool(has_kw) @@ -309,23 +326,26 @@ def get_callable_argspec(fn, no_self=False, _is_init=False): elif inspect.isfunction(fn): if _is_init and no_self: spec = compat.inspect_getargspec(fn) - return compat.ArgSpec(spec.args[1:], spec.varargs, - spec.keywords, spec.defaults) + return compat.ArgSpec( + spec.args[1:], spec.varargs, spec.keywords, spec.defaults + ) else: return compat.inspect_getargspec(fn) elif inspect.ismethod(fn): if no_self and (_is_init or fn.__self__): spec = compat.inspect_getargspec(fn.__func__) - return compat.ArgSpec(spec.args[1:], spec.varargs, - spec.keywords, spec.defaults) + return compat.ArgSpec( + spec.args[1:], spec.varargs, spec.keywords, spec.defaults + ) else: return compat.inspect_getargspec(fn.__func__) elif inspect.isclass(fn): return get_callable_argspec( - fn.__init__, no_self=no_self, _is_init=True) - elif hasattr(fn, '__func__'): + fn.__init__, no_self=no_self, _is_init=True + ) + elif hasattr(fn, "__func__"): return compat.inspect_getargspec(fn.__func__) - elif hasattr(fn, '__call__'): + elif hasattr(fn, "__call__"): if inspect.ismethod(fn.__call__): return get_callable_argspec(fn.__call__, no_self=no_self) else: @@ -375,13 +395,14 @@ def format_argspec_plus(fn, grouped=True): if spec[0]: 
self_arg = spec[0][0] elif spec[1]: - self_arg = '%s[0]' % spec[1] + self_arg = "%s[0]" % spec[1] else: self_arg = None if compat.py3k: apply_pos = compat.inspect_formatargspec( - spec[0], spec[1], spec[2], None, spec[4]) + spec[0], spec[1], spec[2], None, spec[4] + ) num_defaults = 0 if spec[3]: num_defaults += len(spec[3]) @@ -396,19 +417,31 @@ def format_argspec_plus(fn, grouped=True): name_args = spec[0] if num_defaults: - defaulted_vals = name_args[0 - num_defaults:] + defaulted_vals = name_args[0 - num_defaults :] else: defaulted_vals = () apply_kw = compat.inspect_formatargspec( - name_args, spec[1], spec[2], defaulted_vals, - formatvalue=lambda x: '=' + x) + name_args, + spec[1], + spec[2], + defaulted_vals, + formatvalue=lambda x: "=" + x, + ) if grouped: - return dict(args=args, self_arg=self_arg, - apply_pos=apply_pos, apply_kw=apply_kw) + return dict( + args=args, + self_arg=self_arg, + apply_pos=apply_pos, + apply_kw=apply_kw, + ) else: - return dict(args=args[1:-1], self_arg=self_arg, - apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1]) + return dict( + args=args[1:-1], + self_arg=self_arg, + apply_pos=apply_pos[1:-1], + apply_kw=apply_kw[1:-1], + ) def format_argspec_init(method, grouped=True): @@ -422,14 +455,17 @@ def format_argspec_init(method, grouped=True): """ if method is object.__init__: - args = grouped and '(self)' or 'self' + args = grouped and "(self)" or "self" else: try: return format_argspec_plus(method, grouped=grouped) except TypeError: - args = (grouped and '(self, *args, **kwargs)' - or 'self, *args, **kwargs') - return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args) + args = ( + grouped + and "(self, *args, **kwargs)" + or "self, *args, **kwargs" + ) + return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args) def getargspec_init(method): @@ -445,9 +481,9 @@ def getargspec_init(method): return compat.inspect_getargspec(method) except TypeError: if method is object.__init__: - return (['self'], None, None, None) + return (["self"], None, None, None) else: - return (['self'], 'args', 'kwargs', None) + return (["self"], "args", "kwargs", None) def unbound_method_to_callable(func_or_cls): @@ -479,8 +515,9 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): vargs = None for i, insp in enumerate(to_inspect): try: - (_args, _vargs, vkw, defaults) = \ - compat.inspect_getargspec(insp.__init__) + (_args, _vargs, vkw, defaults) = compat.inspect_getargspec( + insp.__init__ + ) except TypeError: continue else: @@ -493,16 +530,17 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): else: pos_args.extend(_args[1:]) else: - kw_args.update([ - (arg, missing) for arg in _args[1:-default_len] - ]) + kw_args.update( + [(arg, missing) for arg in _args[1:-default_len]] + ) if default_len: - kw_args.update([ - (arg, default) - for arg, default - in zip(_args[-default_len:], defaults) - ]) + kw_args.update( + [ + (arg, default) + for arg, default in zip(_args[-default_len:], defaults) + ] + ) output = [] output.extend(repr(getattr(obj, arg, None)) for arg in pos_args) @@ -516,7 +554,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): try: val = getattr(obj, arg, missing) if val is not missing and val != defval: - output.append('%s=%r' % (arg, val)) + output.append("%s=%r" % (arg, val)) except Exception: pass @@ -525,7 +563,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): try: val = getattr(obj, arg, missing) if val is not missing and val != defval: 
- output.append('%s=%r' % (arg, val)) + output.append("%s=%r" % (arg, val)) except Exception: pass @@ -538,16 +576,19 @@ class portable_instancemethod(object): """ - __slots__ = 'target', 'name', 'kwargs', '__weakref__' + __slots__ = "target", "name", "kwargs", "__weakref__" def __getstate__(self): - return {'target': self.target, 'name': self.name, - 'kwargs': self.kwargs} + return { + "target": self.target, + "name": self.name, + "kwargs": self.kwargs, + } def __setstate__(self, state): - self.target = state['target'] - self.name = state['name'] - self.kwargs = state.get('kwargs', ()) + self.target = state["target"] + self.name = state["name"] + self.kwargs = state.get("kwargs", ()) def __init__(self, meth, kwargs=()): self.target = meth.__self__ @@ -583,8 +624,11 @@ def class_hierarchy(cls): if compat.py2k: if isinstance(c, types.ClassType): continue - bases = (_ for _ in c.__bases__ - if _ not in hier and not isinstance(_, types.ClassType)) + bases = ( + _ + for _ in c.__bases__ + if _ not in hier and not isinstance(_, types.ClassType) + ) else: bases = (_ for _ in c.__bases__ if _ not in hier) @@ -593,11 +637,12 @@ def class_hierarchy(cls): hier.add(b) if compat.py3k: - if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'): + if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"): continue else: - if c.__module__ == '__builtin__' or not hasattr( - c, '__subclasses__'): + if c.__module__ == "__builtin__" or not hasattr( + c, "__subclasses__" + ): continue for s in [_ for _ in c.__subclasses__() if _ not in hier]: @@ -622,26 +667,45 @@ def iterate_attributes(cls): break -def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, - name='self.proxy', from_instance=None): +def monkeypatch_proxied_specials( + into_cls, + from_cls, + skip=None, + only=None, + name="self.proxy", + from_instance=None, +): """Automates delegation of __specials__ for a proxying type.""" if only: dunders = only else: if skip is None: - skip = ('__slots__', '__del__', '__getattribute__', - '__metaclass__', '__getstate__', '__setstate__') - dunders = [m for m in dir(from_cls) - if (m.startswith('__') and m.endswith('__') and - not hasattr(into_cls, m) and m not in skip)] + skip = ( + "__slots__", + "__del__", + "__getattribute__", + "__metaclass__", + "__getstate__", + "__setstate__", + ) + dunders = [ + m + for m in dir(from_cls) + if ( + m.startswith("__") + and m.endswith("__") + and not hasattr(into_cls, m) + and m not in skip + ) + ] for method in dunders: try: fn = getattr(from_cls, method) - if not hasattr(fn, '__call__'): + if not hasattr(fn, "__call__"): continue - fn = getattr(fn, 'im_func', fn) + fn = getattr(fn, "im_func", fn) except AttributeError: continue try: @@ -649,11 +713,13 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, fn_args = compat.inspect_formatargspec(spec[0]) d_args = compat.inspect_formatargspec(spec[0][1:]) except TypeError: - fn_args = '(self, *args, **kw)' - d_args = '(*args, **kw)' + fn_args = "(self, *args, **kw)" + d_args = "(*args, **kw)" - py = ("def %(method)s%(fn_args)s: " - "return %(name)s.%(method)s%(d_args)s" % locals()) + py = ( + "def %(method)s%(fn_args)s: " + "return %(name)s.%(method)s%(d_args)s" % locals() + ) env = from_instance is not None and {name: from_instance} or {} compat.exec_(py, env) @@ -667,8 +733,9 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, def methods_equivalent(meth1, meth2): """Return True if the two methods are the same implementation.""" - 
return getattr(meth1, '__func__', meth1) is getattr( - meth2, '__func__', meth2) + return getattr(meth1, "__func__", meth1) is getattr( + meth2, "__func__", meth2 + ) def as_interface(obj, cls=None, methods=None, required=None): @@ -705,12 +772,12 @@ def as_interface(obj, cls=None, methods=None, required=None): """ if not cls and not methods: - raise TypeError('a class or collection of method names are required') + raise TypeError("a class or collection of method names are required") if isinstance(cls, type) and isinstance(obj, cls): return obj - interface = set(methods or [m for m in dir(cls) if not m.startswith('_')]) + interface = set(methods or [m for m in dir(cls) if not m.startswith("_")]) implemented = set(dir(obj)) complies = operator.ge @@ -727,15 +794,17 @@ def as_interface(obj, cls=None, methods=None, required=None): # No dict duck typing here. if not isinstance(obj, dict): - qualifier = complies is operator.gt and 'any of' or 'all of' - raise TypeError("%r does not implement %s: %s" % ( - obj, qualifier, ', '.join(interface))) + qualifier = complies is operator.gt and "any of" or "all of" + raise TypeError( + "%r does not implement %s: %s" + % (obj, qualifier, ", ".join(interface)) + ) class AnonymousInterface(object): """A callable-holding shell.""" if cls: - AnonymousInterface.__name__ = 'Anonymous' + cls.__name__ + AnonymousInterface.__name__ = "Anonymous" + cls.__name__ found = set() for method, impl in dictlike_iteritems(obj): @@ -749,8 +818,10 @@ def as_interface(obj, cls=None, methods=None, required=None): if complies(found, required): return AnonymousInterface - raise TypeError("dictionary does not contain required keys %s" % - ', '.join(required - found)) + raise TypeError( + "dictionary does not contain required keys %s" + % ", ".join(required - found) + ) class memoized_property(object): @@ -791,6 +862,7 @@ def memoized_instancemethod(fn): memo.__doc__ = fn.__doc__ self.__dict__[fn.__name__] = memo return result + return update_wrapper(oneshot, fn) @@ -831,14 +903,14 @@ class MemoizedSlots(object): raise AttributeError(key) def __getattr__(self, key): - if key.startswith('_memoized'): + if key.startswith("_memoized"): raise AttributeError(key) - elif hasattr(self, '_memoized_attr_%s' % key): - value = getattr(self, '_memoized_attr_%s' % key)() + elif hasattr(self, "_memoized_attr_%s" % key): + value = getattr(self, "_memoized_attr_%s" % key)() setattr(self, key, value) return value - elif hasattr(self, '_memoized_method_%s' % key): - fn = getattr(self, '_memoized_method_%s' % key) + elif hasattr(self, "_memoized_method_%s" % key): + fn = getattr(self, "_memoized_method_%s" % key) def oneshot(*args, **kw): result = fn(*args, **kw) @@ -847,6 +919,7 @@ class MemoizedSlots(object): memo.__doc__ = fn.__doc__ setattr(self, key, memo) return result + oneshot.__doc__ = fn.__doc__ return oneshot else: @@ -859,12 +932,14 @@ def dependency_for(modulename, add_to_all=False): # unfortunately importlib doesn't work that great either tokens = modulename.split(".") mod = compat.import_( - ".".join(tokens[0:-1]), globals(), locals(), [tokens[-1]]) + ".".join(tokens[0:-1]), globals(), locals(), [tokens[-1]] + ) mod = getattr(mod, tokens[-1]) setattr(mod, obj.__name__, obj) if add_to_all and hasattr(mod, "__all__"): mod.__all__.append(obj.__name__) return obj + return decorate @@ -891,10 +966,7 @@ class dependencies(object): for dep in deps: tokens = dep.split(".") self.import_deps.append( - dependencies._importlater( - ".".join(tokens[0:-1]), - tokens[-1] - ) + 
dependencies._importlater(".".join(tokens[0:-1]), tokens[-1]) ) def __call__(self, fn): @@ -902,7 +974,7 @@ class dependencies(object): spec = compat.inspect_getfullargspec(fn) spec_zero = list(spec[0]) - hasself = spec_zero[0] in ('self', 'cls') + hasself = spec_zero[0] in ("self", "cls") for i in range(len(import_deps)): spec[0][i + (1 if hasself else 0)] = "import_deps[%r]" % i @@ -915,13 +987,13 @@ class dependencies(object): outer_spec = format_argspec_plus(spec, grouped=False) - code = 'lambda %(args)s: fn(%(apply_kw)s)' % { - "args": outer_spec['args'], - "apply_kw": inner_spec['apply_kw'] + code = "lambda %(args)s: fn(%(apply_kw)s)" % { + "args": outer_spec["args"], + "apply_kw": inner_spec["apply_kw"], } decorated = eval(code, locals()) - decorated.__defaults__ = getattr(fn, 'im_func', fn).__defaults__ + decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__ return update_wrapper(decorated, fn) @classmethod @@ -961,26 +1033,27 @@ class dependencies(object): raise ImportError( "importlater.resolve_all() hasn't " "been called (this is %s %s)" - % (self._il_path, self._il_addtl)) + % (self._il_path, self._il_addtl) + ) return getattr(self._initial_import, self._il_addtl) def _resolve(self): dependencies._unresolved.discard(self) self._initial_import = compat.import_( - self._il_path, globals(), locals(), - [self._il_addtl]) + self._il_path, globals(), locals(), [self._il_addtl] + ) def __getattr__(self, key): - if key == 'module': - raise ImportError("Could not resolve module %s" - % self._full_path) + if key == "module": + raise ImportError( + "Could not resolve module %s" % self._full_path + ) try: attr = getattr(self.module, key) except AttributeError: raise AttributeError( - "Module %s has no attribute '%s'" % - (self._full_path, key) + "Module %s has no attribute '%s'" % (self._full_path, key) ) self.__dict__[key] = attr return attr @@ -990,9 +1063,9 @@ class dependencies(object): def asbool(obj): if isinstance(obj, compat.string_types): obj = obj.strip().lower() - if obj in ['true', 'yes', 'on', 'y', 't', '1']: + if obj in ["true", "yes", "on", "y", "t", "1"]: return True - elif obj in ['false', 'no', 'off', 'n', 'f', '0']: + elif obj in ["false", "no", "off", "n", "f", "0"]: return False else: raise ValueError("String is not true/false: %r" % obj) @@ -1004,11 +1077,13 @@ def bool_or_str(*text): boolean, or one of a set of "alternate" string values. """ + def bool_or_value(obj): if obj in text: return obj else: return asbool(obj) + return bool_or_value @@ -1026,9 +1101,11 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): when coercing to boolean. """ - if key in kw and ( - not isinstance(type_, type) or not isinstance(kw[key], type_) - ) and kw[key] is not None: + if ( + key in kw + and (not isinstance(type_, type) or not isinstance(kw[key], type_)) + and kw[key] is not None + ): if type_ is bool and flexi_bool: kw[key] = asbool(kw[key]) else: @@ -1044,8 +1121,8 @@ def constructor_copy(obj, cls, *args, **kw): names = get_cls_kwargs(cls) kw.update( - (k, obj.__dict__[k]) for k in names.difference(kw) - if k in obj.__dict__) + (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__ + ) return cls(*args, **kw) @@ -1072,10 +1149,11 @@ def duck_type_collection(specimen, default=None): property is present, return that preferentially. 
""" - if hasattr(specimen, '__emulates__'): + if hasattr(specimen, "__emulates__"): # canonicalize set vs sets.Set to a standard: the builtin set - if (specimen.__emulates__ is not None and - issubclass(specimen.__emulates__, set)): + if specimen.__emulates__ is not None and issubclass( + specimen.__emulates__, set + ): return set else: return specimen.__emulates__ @@ -1088,11 +1166,11 @@ def duck_type_collection(specimen, default=None): elif isa(specimen, dict): return dict - if hasattr(specimen, 'append'): + if hasattr(specimen, "append"): return list - elif hasattr(specimen, 'add'): + elif hasattr(specimen, "add"): return set - elif hasattr(specimen, 'set'): + elif hasattr(specimen, "set"): return dict else: return default @@ -1104,41 +1182,43 @@ def assert_arg_type(arg, argtype, name): else: if isinstance(argtype, tuple): raise exc.ArgumentError( - "Argument '%s' is expected to be one of type %s, got '%s'" % - (name, ' or '.join("'%s'" % a for a in argtype), type(arg))) + "Argument '%s' is expected to be one of type %s, got '%s'" + % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) + ) else: raise exc.ArgumentError( - "Argument '%s' is expected to be of type '%s', got '%s'" % - (name, argtype, type(arg))) + "Argument '%s' is expected to be of type '%s', got '%s'" + % (name, argtype, type(arg)) + ) def dictlike_iteritems(dictlike): """Return a (key, value) iterator for almost any dict-like object.""" if compat.py3k: - if hasattr(dictlike, 'items'): + if hasattr(dictlike, "items"): return list(dictlike.items()) else: - if hasattr(dictlike, 'iteritems'): + if hasattr(dictlike, "iteritems"): return dictlike.iteritems() - elif hasattr(dictlike, 'items'): + elif hasattr(dictlike, "items"): return iter(dictlike.items()) - getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None)) + getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None)) if getter is None: - raise TypeError( - "Object '%r' is not dict-like" % dictlike) + raise TypeError("Object '%r' is not dict-like" % dictlike) + + if hasattr(dictlike, "iterkeys"): - if hasattr(dictlike, 'iterkeys'): def iterator(): for key in dictlike.iterkeys(): yield key, getter(key) + return iterator() - elif hasattr(dictlike, 'keys'): + elif hasattr(dictlike, "keys"): return iter((key, getter(key)) for key in dictlike.keys()) else: - raise TypeError( - "Object '%r' is not dict-like" % dictlike) + raise TypeError("Object '%r' is not dict-like" % dictlike) class classproperty(property): @@ -1207,7 +1287,8 @@ class _symbol(int): def __repr__(self): return "symbol(%r)" % self.name -_symbol.__name__ = 'symbol' + +_symbol.__name__ = "symbol" class symbol(object): @@ -1231,6 +1312,7 @@ class symbol(object): ``doc`` here. 
""" + symbols = {} _lock = compat.threading.Lock() @@ -1292,9 +1374,11 @@ class _hash_limit_string(compat.text_type): """ + def __new__(cls, value, num, args): - interpolated = (value % args) + \ - (" (this warning may be suppressed after %d occurrences)" % num) + interpolated = (value % args) + ( + " (this warning may be suppressed after %d occurrences)" % num + ) self = super(_hash_limit_string, cls).__new__(cls, interpolated) self._hash = hash("%s_%d" % (value, hash(interpolated) % num)) return self @@ -1340,8 +1424,8 @@ def only_once(fn): return go -_SQLA_RE = re.compile(r'sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py') -_UNITTEST_RE = re.compile(r'unit(?:2|test2?/)') +_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py") +_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)") def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): @@ -1363,18 +1447,17 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): start += 1 while start <= end and exclude_suffix.search(tb[end]): end -= 1 - return tb[start:end + 1] + return tb[start : end + 1] + NoneType = type(None) def attrsetter(attrname): - code = \ - "def set(obj, value):"\ - " obj.%s = value" % attrname + code = "def set(obj, value):" " obj.%s = value" % attrname env = locals().copy() exec(code, env) - return env['set'] + return env["set"] class EnsureKWArgType(type): @@ -1382,6 +1465,7 @@ class EnsureKWArgType(type): don't already. """ + def __init__(cls, clsname, bases, clsdict): fn_reg = cls.ensure_kwarg if fn_reg: @@ -1396,9 +1480,9 @@ class EnsureKWArgType(type): super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) def _wrap_w_kw(self, fn): - def wrap(*arg, **kw): return fn(*arg) + return update_wrapper(wrap, fn) @@ -1410,15 +1494,15 @@ def wrap_callable(wrapper, fn): object with __call__ method """ - if hasattr(fn, '__name__'): + if hasattr(fn, "__name__"): return update_wrapper(wrapper, fn) else: _f = wrapper _f.__name__ = fn.__class__.__name__ - if hasattr(fn, '__module__'): + if hasattr(fn, "__module__"): _f.__module__ = fn.__module__ - if hasattr(fn.__call__, '__doc__') and fn.__call__.__doc__: + if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__: _f.__doc__ = fn.__call__.__doc__ elif fn.__doc__: _f.__doc__ = fn.__doc__ @@ -1468,4 +1552,3 @@ def quoted_token_parser(value): idx += 1 return ["".join(token) for token in result] - diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 640f70ea95..5e56e855ae 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -23,7 +23,7 @@ from time import time as _time from .compat import threading -__all__ = ['Empty', 'Full', 'Queue'] +__all__ = ["Empty", "Full", "Queue"] class Empty(Exception): diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 5f516d67e4..95391c31b9 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -10,7 +10,7 @@ from ..exc import CircularDependencyError from .. 
import util -__all__ = ['sort', 'sort_as_subsets', 'find_cycles'] +__all__ = ["sort", "sort_as_subsets", "find_cycles"] def sort_as_subsets(tuples, allitems, deterministic_order=False): @@ -33,7 +33,7 @@ def sort_as_subsets(tuples, allitems, deterministic_order=False): raise CircularDependencyError( "Circular dependency detected.", find_cycles(tuples, allitems), - _gen_edges(edges) + _gen_edges(edges), ) todo.difference_update(output) @@ -79,7 +79,7 @@ def find_cycles(tuples, allitems): top = stack[-1] for node in edges[top]: if node in stack: - cyc = stack[stack.index(node):] + cyc = stack[stack.index(node) :] todo.difference_update(cyc) output.update(cyc) @@ -93,8 +93,4 @@ def find_cycles(tuples, allitems): def _gen_edges(edges): - return set([ - (right, left) - for left in edges - for right in edges[left] - ]) + return set([(right, left) for left in edges for right in edges[left]]) diff --git a/setup.py b/setup.py index 909a4ebdaf..5c9d4e4f0d 100644 --- a/setup.py +++ b/setup.py @@ -15,26 +15,30 @@ cmdclass = {} if sys.version_info < (2, 7): raise Exception("SQLAlchemy requires Python 2.7 or higher.") -cpython = platform.python_implementation() == 'CPython' +cpython = platform.python_implementation() == "CPython" ext_modules = [ - Extension('sqlalchemy.cprocessors', - sources=['lib/sqlalchemy/cextension/processors.c']), - Extension('sqlalchemy.cresultproxy', - sources=['lib/sqlalchemy/cextension/resultproxy.c']), - Extension('sqlalchemy.cutils', - sources=['lib/sqlalchemy/cextension/utils.c']) + Extension( + "sqlalchemy.cprocessors", + sources=["lib/sqlalchemy/cextension/processors.c"], + ), + Extension( + "sqlalchemy.cresultproxy", + sources=["lib/sqlalchemy/cextension/resultproxy.c"], + ), + Extension( + "sqlalchemy.cutils", sources=["lib/sqlalchemy/cextension/utils.c"] + ), ] ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError) -if sys.platform == 'win32': +if sys.platform == "win32": # 2.6's distutils.msvc9compiler can raise an IOError when failing to # find the compiler ext_errors += (IOError,) class BuildFailed(Exception): - def __init__(self): self.cause = sys.exc_info()[1] # work around py 2/3 different syntax @@ -59,11 +63,11 @@ class ve_build_ext(build_ext): raise BuildFailed() raise -cmdclass['build_ext'] = ve_build_ext +cmdclass["build_ext"] = ve_build_ext -class Distribution(_Distribution): +class Distribution(_Distribution): def has_ext_modules(self): # We want to always claim that we have ext_modules. This will be fine # if we don't actually have them (such as on PyPy) because nothing @@ -79,7 +83,7 @@ class PyTest(TestCommand): # #integrating-with-setuptools-python-setup-py-test-pytest-runner # TODO: prefer pytest-runner package at some point, however it was # not working at the time of this comment. 
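# A minimal usage sketch for the topological helpers reformatted above,
# assuming this SQLAlchemy checkout is importable; per the module's
# __all__ = ["sort", "sort_as_subsets", "find_cycles"], ``sort`` consumes
# (parent, child) dependency tuples and yields parents before children,
# while ``find_cycles`` returns the set of nodes participating in a cycle.
from sqlalchemy.util.topological import find_cycles, sort

edges = [("base", "lib"), ("lib", "app")]
assert list(sort(edges, ["app", "lib", "base"])) == ["base", "lib", "app"]
assert find_cycles([("a", "b"), ("b", "a")], ["a", "b"]) == set(["a", "b"])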
- user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")] + user_options = [("pytest-args=", "a", "Arguments to pass to py.test")] default_options = ["-n", "4", "-q", "--nomemory"] @@ -94,39 +98,45 @@ class PyTest(TestCommand): def run_tests(self): import shlex + # import here, cause outside the eggs aren't loaded import pytest - errno = pytest.main(self.default_options + shlex.split(self.pytest_args)) + + errno = pytest.main( + self.default_options + shlex.split(self.pytest_args) + ) sys.exit(errno) -cmdclass['test'] = PyTest + +cmdclass["test"] = PyTest def status_msgs(*msgs): - print('*' * 75) + print("*" * 75) for msg in msgs: print(msg) - print('*' * 75) + print("*" * 75) with open( - os.path.join( - os.path.dirname(__file__), - 'lib', 'sqlalchemy', '__init__.py')) as v_file: - VERSION = re.compile( - r""".*__version__ = ["'](.*?)['"]""", - re.S).match(v_file.read()).group(1) - -with open(os.path.join(os.path.dirname(__file__), 'README.rst')) as r_file: + os.path.join(os.path.dirname(__file__), "lib", "sqlalchemy", "__init__.py") +) as v_file: + VERSION = ( + re.compile(r""".*__version__ = ["'](.*?)['"]""", re.S) + .match(v_file.read()) + .group(1) + ) + +with open(os.path.join(os.path.dirname(__file__), "README.rst")) as r_file: readme = r_file.read() def run_setup(with_cext): kwargs = {} if with_cext: - kwargs['ext_modules'] = ext_modules + kwargs["ext_modules"] = ext_modules else: - kwargs['ext_modules'] = [] + kwargs["ext_modules"] = [] setup( name="SQLAlchemy", @@ -135,11 +145,15 @@ def run_setup(with_cext): author="Mike Bayer", author_email="mike_mp@zzzcomputing.com", url="http://www.sqlalchemy.org", - packages=find_packages('lib'), - package_dir={'': 'lib'}, + packages=find_packages("lib"), + package_dir={"": "lib"}, license="MIT License", cmdclass=cmdclass, - tests_require=['pytest>=2.5.2,!=3.9.1,!=3.9.2', 'mock', 'pytest-xdist'], + tests_require=[ + "pytest>=2.5.2,!=3.9.1,!=3.9.2", + "mock", + "pytest-xdist", + ], long_description=readme, python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", classifiers=[ @@ -161,33 +175,34 @@ def run_setup(with_cext): ], distclass=Distribution, extras_require={ - 'mysql': ['mysqlclient'], - 'pymysql': ['pymysql'], - 'postgresql': ['psycopg2'], - 'postgresql_psycopg2binary': ['psycopg2-binary'], - 'postgresql_pg8000': ['pg8000'], - 'postgresql_psycopg2cffi': ['psycopg2cffi'], - 'oracle': ['cx_oracle'], - 'mssql_pyodbc': ['pyodbc'], - 'mssql_pymssql': ['pymssql'], - 'mssql': ['pyodbc'], + "mysql": ["mysqlclient"], + "pymysql": ["pymysql"], + "postgresql": ["psycopg2"], + "postgresql_psycopg2binary": ["psycopg2-binary"], + "postgresql_pg8000": ["pg8000"], + "postgresql_psycopg2cffi": ["psycopg2cffi"], + "oracle": ["cx_oracle"], + "mssql_pyodbc": ["pyodbc"], + "mssql_pymssql": ["pymssql"], + "mssql": ["pyodbc"], }, **kwargs ) + if not cpython: run_setup(False) status_msgs( - "WARNING: C extensions are not supported on " + - "this Python platform, speedups are not enabled.", - "Plain-Python build succeeded." + "WARNING: C extensions are not supported on " + + "this Python platform, speedups are not enabled.", + "Plain-Python build succeeded.", ) -elif os.environ.get('DISABLE_SQLALCHEMY_CEXT'): +elif os.environ.get("DISABLE_SQLALCHEMY_CEXT"): run_setup(False) status_msgs( - "DISABLE_SQLALCHEMY_CEXT is set; " + - "not attempting to build C extensions.", - "Plain-Python build succeeded." 
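# A standalone sketch of the ``__version__`` regex used by setup.py in the
# hunk above; the ``sample`` text here is a made-up stand-in for
# lib/sqlalchemy/__init__.py, used only to show what the pattern extracts.
import re

sample = '# package init\n__version__ = "1.3.0b2"\n'
version = (
    re.compile(r""".*__version__ = ["'](.*?)['"]""", re.S)
    .match(sample)
    .group(1)
)
assert version == "1.3.0b2"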
+ "DISABLE_SQLALCHEMY_CEXT is set; " + + "not attempting to build C extensions.", + "Plain-Python build succeeded.", ) else: @@ -196,16 +211,16 @@ else: except BuildFailed as exc: status_msgs( exc.cause, - "WARNING: The C extension could not be compiled, " + - "speedups are not enabled.", + "WARNING: The C extension could not be compiled, " + + "speedups are not enabled.", "Failure information, if any, is above.", - "Retrying the build without the C extension now." + "Retrying the build without the C extension now.", ) run_setup(False) status_msgs( - "WARNING: The C extension could not be compiled, " + - "speedups are not enabled.", - "Plain-Python build succeeded." + "WARNING: The C extension could not be compiled, " + + "speedups are not enabled.", + "Plain-Python build succeeded.", ) diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 51a92f5290..e0f308814c 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -6,7 +6,7 @@ t1 = t2 = None class CompileTest(fixtures.TestBase, AssertsExecutionResults): - __requires__ = 'cpython', + __requires__ = ("cpython",) __backend__ = True @classmethod @@ -14,13 +14,19 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): global t1, t2, metadata metadata = MetaData() - t1 = Table('t1', metadata, - Column('c1', Integer, primary_key=True), - Column('c2', String(30))) - - t2 = Table('t2', metadata, - Column('c1', Integer, primary_key=True), - Column('c2', String(30))) + t1 = Table( + "t1", + metadata, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) + + t2 = Table( + "t2", + metadata, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) cls.dialect = default.DefaultDialect() @@ -36,6 +42,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): for c in t.c: c.type._type_affinity from sqlalchemy.sql import sqltypes + for t in list(sqltypes._type_map.values()): t._type_affinity @@ -43,7 +50,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): def test_insert(self): t1.insert().compile(dialect=self.dialect) - @profiling.function_call_count(variance=.15) + @profiling.function_call_count(variance=0.15) def test_update(self): t1.update().compile(dialect=self.dialect) @@ -53,6 +60,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count() def go(): t1.update().where(t1.c.c2 == 12).compile(dialect=self.dialect) + go() def test_select(self): @@ -65,6 +73,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): def go(): s = select([t1], t1.c.c2 == t2.c.c1) s.compile(dialect=self.dialect) + go() def test_select_labels(self): @@ -77,4 +86,5 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): def go(): s = select([t1], t1.c.c2 == t2.c.c1).apply_labels() s.compile(dialect=self.dialect) + go() diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 381e82d3c9..a56fcd409a 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1,18 +1,26 @@ from sqlalchemy.testing import eq_ -from sqlalchemy.orm import mapper, relationship, create_session, \ - clear_mappers, sessionmaker, aliased,\ - Session, subqueryload +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + clear_mappers, + sessionmaker, + aliased, + Session, + subqueryload, +) from sqlalchemy.orm.mapper import _mapper_registry from sqlalchemy.orm.session import _sessions from sqlalchemy 
import testing from sqlalchemy.testing import engines -from sqlalchemy import MetaData, Integer, String, ForeignKey, \ - Unicode, select +from sqlalchemy import MetaData, Integer, String, ForeignKey, Unicode, select import sqlalchemy as sa from sqlalchemy.testing.schema import Table, Column from sqlalchemy.sql import column -from sqlalchemy.processors import to_decimal_processor_factory, \ - to_unicode_processor_factory +from sqlalchemy.processors import ( + to_decimal_processor_factory, + to_unicode_processor_factory, +) from sqlalchemy.testing.util import gc_collect import decimal import gc @@ -23,6 +31,7 @@ import itertools import multiprocessing + class A(fixtures.ComparableEntity): pass @@ -35,8 +44,9 @@ class ASub(A): pass -def profile_memory(maxtimes=250, - assert_no_sessions=True, get_num_objects=None): +def profile_memory( + maxtimes=250, assert_no_sessions=True, get_num_objects=None +): def decorate(func): # run the test N times. if length of gc.get_objects() # keeps growing, assert false @@ -48,8 +58,11 @@ def profile_memory(maxtimes=250, # just filter them out so that we get a "flatline" more quickly. if testing.against("sqlite+pysqlite"): - return [o for o in gc.get_objects() - if not isinstance(o, weakref.ref)] + return [ + o + for o in gc.get_objects() + if not isinstance(o, weakref.ref) + ] else: return gc.get_objects() @@ -72,7 +85,8 @@ def profile_memory(maxtimes=250, func(*func_args) gc_collect() samples.append( - get_num_objects() if get_num_objects is not None + get_num_objects() + if get_num_objects is not None else len(get_objects_skipping_sqlite_issue()) ) @@ -85,11 +99,10 @@ def profile_memory(maxtimes=250, if latest_max > max_: queue.put( ( - 'status', + "status", "Max grew from %s to %s, max has " - "grown for %s samples" % ( - max_, latest_max, max_grew_for - ) + "grown for %s samples" + % (max_, latest_max, max_grew_for), ) ) max_ = latest_max @@ -99,9 +112,9 @@ def profile_memory(maxtimes=250, else: queue.put( ( - 'status', - "Max remained at %s, %s more attempts left" % - (max_, max_grew_for) + "status", + "Max remained at %s, %s more attempts left" + % (max_, max_grew_for), ) ) max_grew_for -= 1 @@ -112,34 +125,30 @@ def profile_memory(maxtimes=250, if not success: queue.put( ( - 'result', + "result", False, "Ran for a total of %d times, memory kept " - "growing: %r" % ( - maxtimes, - samples - ) + "growing: %r" % (maxtimes, samples), ) ) else: - queue.put( - ('result', True, 'success') - ) + queue.put(("result", True, "success")) def run_in_process(*func_args): queue = multiprocessing.Queue() proc = multiprocessing.Process( - target=profile, args=(queue, func_args)) + target=profile, args=(queue, func_args) + ) proc.start() while True: row = queue.get() typ = row[0] - if typ == 'samples': + if typ == "samples": print("sample gc sizes:", row[1]) - elif typ == 'status': + elif typ == "status": print(row[1]) - elif typ == 'result': + elif typ == "result": break else: assert False, "can't parse row" @@ -158,7 +167,6 @@ def assert_no_mappers(): class EnsureZeroed(fixtures.ORMTest): - def setup(self): _sessions.clear() _mapper_registry.clear() @@ -166,17 +174,19 @@ class EnsureZeroed(fixtures.ORMTest): class MemUsageTest(EnsureZeroed): - __tags__ = 'memory_intensive', - __requires__ = 'cpython', + __tags__ = ("memory_intensive",) + __requires__ = ("cpython",) def test_type_compile(self): from sqlalchemy.dialects.sqlite.base import dialect as SQLiteDialect - cast = sa.cast(column('x'), sa.Integer) + + cast = sa.cast(column("x"), sa.Integer) @profile_memory() def 
go(): dialect = SQLiteDialect() cast.compile(dialect=dialect) + go() @testing.requires.cextensions @@ -184,6 +194,7 @@ class MemUsageTest(EnsureZeroed): @profile_memory() def go(): to_decimal_processor_factory({}, 10) + go() @testing.requires.cextensions @@ -191,13 +202,15 @@ class MemUsageTest(EnsureZeroed): @profile_memory() def go(): to_decimal_processor_factory(decimal.Decimal, 10)(1.2) + go() @testing.requires.cextensions def test_UnicodeResultProcessor_init(self): @profile_memory() def go(): - to_unicode_processor_factory('utf8') + to_unicode_processor_factory("utf8") + go() def test_ad_hoc_types(self): @@ -209,22 +222,24 @@ class MemUsageTest(EnsureZeroed): eng = engines.testing_engine() for args in ( - (types.Integer, ), - (types.String, ), - (types.PickleType, ), - (types.Enum, 'a', 'b', 'c'), - (sqlite.DATETIME, ), - (postgresql.ENUM, 'a', 'b', 'c'), - (types.Interval, ), - (postgresql.INTERVAL, ), - (mysql.VARCHAR, ), + (types.Integer,), + (types.String,), + (types.PickleType,), + (types.Enum, "a", "b", "c"), + (sqlite.DATETIME,), + (postgresql.ENUM, "a", "b", "c"), + (types.Interval,), + (postgresql.INTERVAL,), + (mysql.VARCHAR,), ): + @profile_memory() def go(): type_ = args[0](*args[1:]) bp = type_._cached_bind_processor(eng.dialect) rp = type_._cached_result_processor(eng.dialect, 0) bp, rp # strong reference + go() assert not eng.dialect._type_memos @@ -233,20 +248,20 @@ class MemUsageTest(EnsureZeroed): def test_fixture_failure(self): class Foo(object): pass + stuff = [] @profile_memory(maxtimes=20) def go(): - stuff.extend( - Foo() for i in range(100) - ) + stuff.extend(Foo() for i in range(100)) + go() class MemUsageWBackendTest(EnsureZeroed): - __tags__ = 'memory_intensive', - __requires__ = 'cpython', 'memory_process_intensive' + __tags__ = ("memory_intensive",) + __requires__ = "cpython", "memory_process_intensive" __backend__ = True # ensure a pure growing test trips the assertion @@ -260,27 +275,48 @@ class MemUsageWBackendTest(EnsureZeroed): @profile_memory(maxtimes=10) def go(): x[-1:] = [Foo(), Foo(), Foo(), Foo(), Foo(), Foo()] + go() def test_session(self): metadata = MetaData(self.engine) - table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30))) + table1 = Table( + "mytable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + ) - table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30)), - Column('col3', Integer, ForeignKey("mytable.col1"))) + table2 = Table( + "mytable2", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + Column("col3", Integer, ForeignKey("mytable.col1")), + ) metadata.create_all() - m1 = mapper(A, table1, properties={ - "bs": relationship(B, cascade="all, delete", - order_by=table2.c.col1)}) + m1 = mapper( + A, + table1, + properties={ + "bs": relationship( + B, cascade="all, delete", order_by=table2.c.col1 + ) + }, + ) m2 = mapper(B, table2) m3 = mapper(A, table1, non_primary=True) @@ -304,13 +340,15 @@ class MemUsageWBackendTest(EnsureZeroed): [ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), A(col2="a2", bs=[]), - A(col2="a3", bs=[B(col2="b3")]) + A(col2="a3", bs=[B(col2="b3")]), ], - alist) + alist, + ) for a in alist: sess.delete(a) sess.flush() + go() metadata.drop_all() @@ -327,44 +365,64 @@ class 
MemUsageWBackendTest(EnsureZeroed): sess.close() del sess del sessmaker + go() @testing.emits_warning("Compiled statement cache for mapper.*") @testing.emits_warning("Compiled statement cache for lazy loader.*") - @testing.crashes('sqlite', ':memory: connection not suitable here') + @testing.crashes("sqlite", ":memory: connection not suitable here") def test_orm_many_engines(self): metadata = MetaData(self.engine) - table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30))) + table1 = Table( + "mytable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + ) - table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30)), - Column('col3', Integer, ForeignKey("mytable.col1"))) + table2 = Table( + "mytable2", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + Column("col3", Integer, ForeignKey("mytable.col1")), + ) metadata.create_all() - m1 = mapper(A, table1, properties={ - "bs": relationship(B, cascade="all, delete", - order_by=table2.c.col1)}, - _compiled_cache_size=50 - ) - m2 = mapper(B, table2, - _compiled_cache_size=50 - ) + m1 = mapper( + A, + table1, + properties={ + "bs": relationship( + B, cascade="all, delete", order_by=table2.c.col1 + ) + }, + _compiled_cache_size=50, + ) + m2 = mapper(B, table2, _compiled_cache_size=50) m3 = mapper(A, table1, non_primary=True) @profile_memory() def go(): engine = engines.testing_engine( - options={'logging_name': 'FOO', - 'pool_logging_name': 'BAR', - 'use_reaper': False} + options={ + "logging_name": "FOO", + "pool_logging_name": "BAR", + "use_reaper": False, + } ) sess = create_session(bind=engine) @@ -384,15 +442,17 @@ class MemUsageWBackendTest(EnsureZeroed): [ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), A(col2="a2", bs=[]), - A(col2="a3", bs=[B(col2="b3")]) + A(col2="a3", bs=[B(col2="b3")]), ], - alist) + alist, + ) for a in alist: sess.delete(a) sess.flush() sess.close() engine.dispose() + go() metadata.drop_all() @@ -403,11 +463,14 @@ class MemUsageWBackendTest(EnsureZeroed): def test_many_updates(self): metadata = MetaData(self.engine) - wide_table = Table('t', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - *[Column('col%d' % i, Integer) for i in range(10)] - ) + wide_table = Table( + "t", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + *[Column("col%d" % i, Integer) for i in range(10)] + ) class Wide(object): pass @@ -433,7 +496,7 @@ class MemUsageWBackendTest(EnsureZeroed): # trying to count in binary here, # works enough to trip the test case if pow(2, dec) < x: - setattr(w1, 'col%d' % dec, counter[0]) + setattr(w1, "col%d" % dec, counter[0]) x -= pow(2, dec) dec -= 1 session.flush() @@ -451,9 +514,11 @@ class MemUsageWBackendTest(EnsureZeroed): metadata = self.metadata some_table = Table( - 't', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + "t", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), ) class SomeClass(object): @@ -465,14 +530,15 @@ class MemUsageWBackendTest(EnsureZeroed): session = Session(testing.db) - target_strings = session.connection().\ - dialect.identifier_preparer._strings + target_strings = ( + 
session.connection().dialect.identifier_preparer._strings + ) session.close() @profile_memory( assert_no_sessions=False, - get_num_objects=lambda: len(target_strings) + get_num_objects=lambda: len(target_strings), ) def go(): session = Session(testing.db) @@ -485,20 +551,20 @@ class MemUsageWBackendTest(EnsureZeroed): go() - @testing.crashes('mysql+cymysql', 'blocking') + @testing.crashes("mysql+cymysql", "blocking") def test_unicode_warnings(self): metadata = MetaData(self.engine) table1 = Table( - 'mytable', + "mytable", metadata, Column( - 'col1', + "col1", Integer, primary_key=True, - test_needs_autoincrement=True), - Column( - 'col2', - Unicode(30))) + test_needs_autoincrement=True, + ), + Column("col2", Unicode(30)), + ) metadata.create_all() i = [1] @@ -512,9 +578,11 @@ class MemUsageWBackendTest(EnsureZeroed): # execute with a non-unicode object. a warning is emitted, # this warning shouldn't clog up memory. - self.engine.execute(table1.select().where( - table1.c.col2 == 'foo%d' % i[0])) + self.engine.execute( + table1.select().where(table1.c.col2 == "foo%d" % i[0]) + ) i[0] += 1 + try: go() finally: @@ -523,34 +591,53 @@ class MemUsageWBackendTest(EnsureZeroed): def test_warnings_util(self): counter = itertools.count() import warnings + warnings.filterwarnings("ignore", "memusage warning.*") @profile_memory() def go(): util.warn_limited( "memusage warning, param1: %s, param2: %s", - (next(counter), next(counter))) + (next(counter), next(counter)), + ) + go() def test_mapper_reset(self): metadata = MetaData(self.engine) - table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30))) + table1 = Table( + "mytable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + ) - table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30)), - Column('col3', Integer, ForeignKey("mytable.col1"))) + table2 = Table( + "mytable2", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + Column("col3", Integer, ForeignKey("mytable.col1")), + ) @profile_memory() def go(): - mapper(A, table1, properties={ - "bs": relationship(B, order_by=table2.c.col1) - }) + mapper( + A, + table1, + properties={"bs": relationship(B, order_by=table2.c.col1)}, + ) mapper(B, table2) mapper(A, table1, non_primary=True) @@ -572,9 +659,10 @@ class MemUsageWBackendTest(EnsureZeroed): [ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), A(col2="a2", bs=[]), - A(col2="a3", bs=[B(col2="b3")]) + A(col2="a3", bs=[B(col2="b3")]), ], - alist) + alist, + ) for a in alist: sess.delete(a) @@ -592,28 +680,33 @@ class MemUsageWBackendTest(EnsureZeroed): def test_alias_pathing(self): metadata = MetaData(self.engine) - a = Table("a", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('bid', Integer, ForeignKey('b.id')), - Column('type', String(30)) - ) - - asub = Table("asub", metadata, - Column('id', Integer, ForeignKey('a.id'), - primary_key=True), - Column('data', String(30))) - - b = Table("b", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - ) - mapper(A, a, polymorphic_identity='a', - polymorphic_on=a.c.type) - mapper(ASub, asub, inherits=A, polymorphic_identity='asub') - mapper(B, b, properties={ - 'as_': relationship(A) - }) + a = 
Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("bid", Integer, ForeignKey("b.id")), + Column("type", String(30)), + ) + + asub = Table( + "asub", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("data", String(30)), + ) + + b = Table( + "b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + mapper(A, a, polymorphic_identity="a", polymorphic_on=a.c.type) + mapper(ASub, asub, inherits=A, polymorphic_identity="asub") + mapper(B, b, properties={"as_": relationship(A)}) metadata.create_all() sess = Session() @@ -633,6 +726,7 @@ class MemUsageWBackendTest(EnsureZeroed): sess = Session() sess.query(B).options(subqueryload(B.as_.of_type(ASub))).all() sess.close() + try: go() finally: @@ -641,34 +735,50 @@ class MemUsageWBackendTest(EnsureZeroed): def test_path_registry(self): metadata = MetaData() - a = Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer), - Column('bar', Integer) - ) + a = Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer), + ) m1 = mapper(A, a) @profile_memory() def go(): ma = sa.inspect(aliased(A)) m1._path_registry[m1.attrs.foo][ma][m1.attrs.bar] + go() clear_mappers() def test_with_inheritance(self): metadata = MetaData(self.engine) - table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30)) - ) + table1 = Table( + "mytable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + ) - table2 = Table("mytable2", metadata, - Column('col1', Integer, ForeignKey('mytable.col1'), - primary_key=True, test_needs_autoincrement=True), - Column('col3', String(30)), - ) + table2 = Table( + "mytable2", + metadata, + Column( + "col1", + Integer, + ForeignKey("mytable.col1"), + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col3", String(30)), + ) @profile_memory() def go(): @@ -678,29 +788,26 @@ class MemUsageWBackendTest(EnsureZeroed): class B(A): pass - mapper(A, table1, - polymorphic_on=table1.c.col2, - polymorphic_identity='a') - mapper(B, table2, - inherits=A, - polymorphic_identity='b') + mapper( + A, + table1, + polymorphic_on=table1.c.col2, + polymorphic_identity="a", + ) + mapper(B, table2, inherits=A, polymorphic_identity="b") sess = create_session() a1 = A() a2 = A() - b1 = B(col3='b1') - b2 = B(col3='b2') + b1 = B(col3="b1") + b2 = B(col3="b2") for x in [a1, a2, b1, b2]: sess.add(x) sess.flush() sess.expunge_all() alist = sess.query(A).order_by(A.col1).all() - eq_( - [ - A(), A(), B(col3='b1'), B(col3='b2') - ], - alist) + eq_([A(), A(), B(col3="b1"), B(col3="b2")], alist) for a in alist: sess.delete(a) @@ -720,22 +827,36 @@ class MemUsageWBackendTest(EnsureZeroed): def test_with_manytomany(self): metadata = MetaData(self.engine) - table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30)) - ) + table1 = Table( + "mytable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + ) - table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col2', String(30)), - ) + table2 = Table( + "mytable2", + metadata, + Column( + "col1", 
+ Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("col2", String(30)), + ) - table3 = Table('t1tot2', metadata, - Column('t1', Integer, ForeignKey('mytable.col1')), - Column('t2', Integer, ForeignKey('mytable2.col1')), - ) + table3 = Table( + "t1tot2", + metadata, + Column("t1", Integer, ForeignKey("mytable.col1")), + Column("t2", Integer, ForeignKey("mytable2.col1")), + ) @profile_memory() def go(): @@ -745,17 +866,22 @@ class MemUsageWBackendTest(EnsureZeroed): class B(fixtures.ComparableEntity): pass - mapper(A, table1, properties={ - 'bs': relationship(B, secondary=table3, - backref='as', order_by=table3.c.t1) - }) + mapper( + A, + table1, + properties={ + "bs": relationship( + B, secondary=table3, backref="as", order_by=table3.c.t1 + ) + }, + ) mapper(B, table2) sess = create_session() - a1 = A(col2='a1') - a2 = A(col2='a2') - b1 = B(col2='b1') - b2 = B(col2='b2') + a1 = A(col2="a1") + a2 = A(col2="a2") + b1 = B(col2="b1") + b2 = B(col2="b2") a1.bs.append(b1) a2.bs.append(b2) for x in [a1, a2]: @@ -764,11 +890,7 @@ class MemUsageWBackendTest(EnsureZeroed): sess.expunge_all() alist = sess.query(A).order_by(A.col1).all() - eq_( - [ - A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')]) - ], - alist) + eq_([A(bs=[B(col2="b1")]), A(bs=[B(col2="b2")])], alist) for a in alist: sess.delete(a) @@ -789,7 +911,7 @@ class MemUsageWBackendTest(EnsureZeroed): def test_key_fallback_result(self): e = self.engine m = self.metadata - t = Table('t', m, Column('x', Integer), Column('y', Integer)) + t = Table("t", m, Column("x", Integer), Column("y", Integer)) m.create_all(e) e.execute(t.insert(), {"x": 1, "y": 1}) @@ -798,6 +920,7 @@ class MemUsageWBackendTest(EnsureZeroed): r = e.execute(t.alias().select()) for row in r: row[t.c.x] + go() def test_many_discarded_relationships(self): @@ -805,13 +928,17 @@ class MemUsageWBackendTest(EnsureZeroed): guard against memleaks here so why not""" m1 = MetaData() - t1 = Table('t1', m1, Column('id', Integer, primary_key=True)) + t1 = Table("t1", m1, Column("id", Integer, primary_key=True)) t2 = Table( - 't2', m1, Column('id', Integer, primary_key=True), - Column('t1id', ForeignKey('t1.id'))) + "t2", + m1, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), + ) class T1(object): pass + t1_mapper = mapper(T1, t1) @testing.emits_warning() @@ -819,39 +946,39 @@ class MemUsageWBackendTest(EnsureZeroed): def go(): class T2(object): pass + t2_mapper = mapper(T2, t2) t1_mapper.add_property("bar", relationship(t2_mapper)) s1 = Session() # this causes the path_registry to be invoked s1.query(t1_mapper)._compile_context() + go() # fails on newer versions of pysqlite due to unusual memory behavior # in pysqlite itself.
background at: # http://thread.gmane.org/gmane.comp.python.db.pysqlite.user/2290 - @testing.crashes('mysql+cymysql', 'blocking') + @testing.crashes("mysql+cymysql", "blocking") def test_join_cache(self): metadata = MetaData(self.engine) table1 = Table( - 'table1', + "table1", metadata, Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column( - 'data', - String(30))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) table2 = Table( - 'table2', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), + "table2", + metadata, Column( - 'data', String(30)), Column( - 't1id', Integer, ForeignKey('table1.id'))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("t1id", Integer, ForeignKey("table1.id")), + ) class Foo(object): pass @@ -859,8 +986,9 @@ class MemUsageWBackendTest(EnsureZeroed): class Bar(object): pass - mapper(Foo, table1, properties={ - 'bars': relationship(mapper(Bar, table2))}) + mapper( + Foo, table1, properties={"bars": relationship(mapper(Bar, table2))} + ) metadata.create_all() session = sessionmaker() @@ -870,8 +998,8 @@ class MemUsageWBackendTest(EnsureZeroed): sess = session() sess.query(Foo).join((s, Foo.bars)).all() sess.rollback() + try: go() finally: metadata.drop_all() - diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 18434b9a4e..fc546de40e 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -1,7 +1,17 @@ from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.orm import mapper, relationship, \ - sessionmaker, Session, defer, joinedload, defaultload, selectinload, \ - Load, configure_mappers, Bundle +from sqlalchemy.orm import ( + mapper, + relationship, + sessionmaker, + Session, + defer, + joinedload, + defaultload, + selectinload, + Load, + configure_mappers, + Bundle, +) from sqlalchemy import testing from sqlalchemy.testing import profiling from sqlalchemy.testing import fixtures @@ -10,29 +20,27 @@ from sqlalchemy import inspect class MergeTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): Table( - 'parent', + "parent", metadata, Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column( - 'data', - String(20))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), + ) Table( - 'child', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), + "child", + metadata, Column( - 'data', String(20)), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), Column( - 'parent_id', Integer, ForeignKey('parent.id'), nullable=False)) + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) @classmethod def setup_classes(cls): @@ -44,26 +52,26 @@ class MergeTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - Child, Parent, parent, child = (cls.classes.Child, - cls.classes.Parent, - cls.tables.parent, - cls.tables.child) + Child, Parent, parent, child = ( + cls.classes.Child, + cls.classes.Parent, + cls.tables.parent, + cls.tables.child, + ) mapper( Parent, parent, - properties={ - 'children': relationship( - Child, - backref='parent')}) + properties={"children": relationship(Child, backref="parent")}, + ) mapper(Child, child) @classmethod def insert_data(cls): parent, child = cls.tables.parent, cls.tables.child - 
parent.insert().execute({'id': 1, 'data': 'p1'}) - child.insert().execute({'id': 1, 'data': 'p1c1', 'parent_id': 1}) + parent.insert().execute({"id": 1, "data": "p1"}) + child.insert().execute({"id": 1, "data": "p1c1", "parent_id": 1}) def test_merge_no_load(self): Parent = self.classes.Parent @@ -79,6 +87,7 @@ class MergeTest(fixtures.MappedTest): @profiling.function_call_count(variance=0.10) def go1(): return sess2.merge(p1, load=False) + p2 = go1() # third call, merge object already present. almost no calls. @@ -86,6 +95,7 @@ class MergeTest(fixtures.MappedTest): @profiling.function_call_count(variance=0.10) def go2(): return sess2.merge(p2, load=False) + go2() def test_merge_load(self): @@ -103,12 +113,14 @@ class MergeTest(fixtures.MappedTest): @profiling.function_call_count() def go(): sess2.merge(p1) + go() # one more time, count the SQL def go2(): sess2.merge(p1) + sess2 = sessionmaker(testing.db)() self.assert_sql_count(testing.db, go2, 2) @@ -126,16 +138,20 @@ class LoadManyToOneFromIdentityTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(20)), - Column('child_id', Integer, ForeignKey('child.id')) - ) + Table( + "parent", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(20)), + Column("child_id", Integer, ForeignKey("child.id")), + ) - Table('child', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(20)) - ) + Table( + "child", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(20)), + ) @classmethod def setup_classes(cls): @@ -147,31 +163,33 @@ class LoadManyToOneFromIdentityTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - Child, Parent, parent, child = (cls.classes.Child, - cls.classes.Parent, - cls.tables.parent, - cls.tables.child) + Child, Parent, parent, child = ( + cls.classes.Child, + cls.classes.Parent, + cls.tables.parent, + cls.tables.child, + ) - mapper(Parent, parent, properties={ - 'child': relationship(Child)}) + mapper(Parent, parent, properties={"child": relationship(Child)}) mapper(Child, child) @classmethod def insert_data(cls): parent, child = cls.tables.parent, cls.tables.child - child.insert().execute([ - {'id': i, 'data': 'c%d' % i} - for i in range(1, 251) - ]) - parent.insert().execute([ - { - 'id': i, - 'data': 'p%dc%d' % (i, (i % 250) + 1), - 'child_id': (i % 250) + 1 - } - for i in range(1, 1000) - ]) + child.insert().execute( + [{"id": i, "data": "c%d" % i} for i in range(1, 251)] + ) + parent.insert().execute( + [ + { + "id": i, + "data": "p%dc%d" % (i, (i % 250) + 1), + "child_id": (i % 250) + 1, + } + for i in range(1, 1000) + ] + ) def test_many_to_one_load_no_identity(self): Parent = self.classes.Parent @@ -179,10 +197,11 @@ class LoadManyToOneFromIdentityTest(fixtures.MappedTest): sess = Session() parents = sess.query(Parent).all() - @profiling.function_call_count(variance=.2) + @profiling.function_call_count(variance=0.2) def go(): for p in parents: p.child + go() def test_many_to_one_load_identity(self): @@ -197,28 +216,32 @@ class LoadManyToOneFromIdentityTest(fixtures.MappedTest): def go(): for p in parents: p.child + go() class MergeBackrefsTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('c_id', Integer, ForeignKey('c.id')) - ) - Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, 
ForeignKey('a.id')) - ) - Table('c', metadata, - Column('id', Integer, primary_key=True), - ) - Table('d', metadata, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id')) - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("c_id", Integer, ForeignKey("c.id")), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), + ) + Table("c", metadata, Column("id", Integer, primary_key=True)) + Table( + "d", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), + ) @classmethod def setup_classes(cls): @@ -236,63 +259,73 @@ class MergeBackrefsTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - A, B, C, D = cls.classes.A, cls.classes.B, \ - cls.classes.C, cls.classes.D - a, b, c, d = cls.tables.a, cls.tables.b, \ - cls.tables.c, cls.tables.d - mapper(A, a, properties={ - 'bs': relationship(B, backref='a'), - 'c': relationship(C, backref='as'), - 'ds': relationship(D, backref='a'), - }) + A, B, C, D = cls.classes.A, cls.classes.B, cls.classes.C, cls.classes.D + a, b, c, d = cls.tables.a, cls.tables.b, cls.tables.c, cls.tables.d + mapper( + A, + a, + properties={ + "bs": relationship(B, backref="a"), + "c": relationship(C, backref="as"), + "ds": relationship(D, backref="a"), + }, + ) mapper(B, b) mapper(C, c) mapper(D, d) @classmethod def insert_data(cls): - A, B, C, D = cls.classes.A, cls.classes.B, \ - cls.classes.C, cls.classes.D + A, B, C, D = cls.classes.A, cls.classes.B, cls.classes.C, cls.classes.D s = Session() - s.add_all([ - A(id=i, - bs=[B(id=(i * 5) + j) for j in range(1, 5)], - c=C(id=i), - ds=[D(id=(i * 5) + j) for j in range(1, 5)] - ) - for i in range(1, 5) - ]) + s.add_all( + [ + A( + id=i, + bs=[B(id=(i * 5) + j) for j in range(1, 5)], + c=C(id=i), + ds=[D(id=(i * 5) + j) for j in range(1, 5)], + ) + for i in range(1, 5) + ] + ) s.commit() - @profiling.function_call_count(variance=.10) + @profiling.function_call_count(variance=0.10) def test_merge_pending_with_all_pks(self): - A, B, C, D = self.classes.A, self.classes.B, \ - self.classes.C, self.classes.D + A, B, C, D = ( + self.classes.A, + self.classes.B, + self.classes.C, + self.classes.D, + ) s = Session() for a in [ - A(id=i, + A( + id=i, bs=[B(id=(i * 5) + j) for j in range(1, 5)], c=C(id=i), - ds=[D(id=(i * 5) + j) for j in range(1, 5)] - ) + ds=[D(id=(i * 5) + j) for j in range(1, 5)], + ) for i in range(1, 5) ]: s.merge(a) class DeferOptionsTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('x', String(5)), - Column('y', String(5)), - Column('z', String(5)), - Column('q', String(5)), - Column('p', String(5)), - Column('r', String(5)), - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("x", String(5)), + Column("y", String(5)), + Column("z", String(5)), + Column("q", String(5)), + Column("p", String(5)), + Column("r", String(5)), + ) @classmethod def setup_classes(cls): @@ -309,55 +342,60 @@ class DeferOptionsTest(fixtures.MappedTest): def insert_data(cls): A = cls.classes.A s = Session() - s.add_all([ - A(id=i, - **dict((letter, "%s%d" % (letter, i)) for letter in - ['x', 'y', 'z', 'p', 'q', 'r']) - ) for i in range(1, 1001) - ]) + s.add_all( + [ + A( + id=i, + **dict( + (letter, "%s%d" % (letter, i)) + for letter in ["x", "y", "z", "p", "q", "r"] + ) + ) + for i in range(1, 1001) + ] + ) s.commit() - 
@profiling.function_call_count(variance=.10) + @profiling.function_call_count(variance=0.10) def test_baseline(self): # as of [ticket:2778], this is at 39025 A = self.classes.A s = Session() s.query(A).all() - @profiling.function_call_count(variance=.10) + @profiling.function_call_count(variance=0.10) def test_defer_many_cols(self): # with [ticket:2778], this goes from 50805 to 32817, # as it should be fewer function calls than the baseline A = self.classes.A s = Session() s.query(A).options( - *[defer(letter) for letter in ['x', 'y', 'z', 'p', 'q', 'r']]).\ - all() + *[defer(letter) for letter in ["x", "y", "z", "p", "q", "r"]] + ).all() class AttributeOverheadTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): Table( - 'parent', + "parent", metadata, Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column( - 'data', - String(20))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), + ) Table( - 'child', metadata, + "child", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), Column( - 'data', String(20)), Column( - 'parent_id', Integer, ForeignKey('parent.id'), nullable=False)) + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) @classmethod def setup_classes(cls): @@ -369,18 +407,18 @@ class AttributeOverheadTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - Child, Parent, parent, child = (cls.classes.Child, - cls.classes.Parent, - cls.tables.parent, - cls.tables.child) + Child, Parent, parent, child = ( + cls.classes.Child, + cls.classes.Parent, + cls.tables.parent, + cls.tables.child, + ) mapper( Parent, parent, - properties={ - 'children': relationship( - Child, - backref='parent')}) + properties={"children": relationship(Child, backref="parent")}, + ) mapper(Child, child) def test_attribute_set(self): @@ -395,6 +433,7 @@ class AttributeOverheadTest(fixtures.MappedTest): c1.parent = None c1.parent = p1 del c1.parent + go() def test_collection_append_remove(self): @@ -408,6 +447,7 @@ class AttributeOverheadTest(fixtures.MappedTest): p1.children.append(child) for child in children: p1.children.remove(child) + go() @@ -415,18 +455,24 @@ class SessionTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'parent', + "parent", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('data', String(20))) + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), + ) Table( - 'child', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + "child", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), Column( - 'data', String(20)), Column( - 'parent_id', Integer, ForeignKey('parent.id'), nullable=False)) + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) @classmethod def setup_classes(cls): @@ -438,22 +484,25 @@ class SessionTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - Child, Parent, parent, child = (cls.classes.Child, - cls.classes.Parent, - cls.tables.parent, - cls.tables.child) + Child, Parent, parent, child = ( + cls.classes.Child, + cls.classes.Parent, + cls.tables.parent, + cls.tables.child, + ) mapper( - Parent, parent, properties={ - 'children': 
relationship( - Child, - backref='parent')}) + Parent, + parent, + properties={"children": relationship(Child, backref="parent")}, + ) mapper(Child, child) def test_expire_lots(self): Parent, Child = self.classes.Parent, self.classes.Child - obj = [Parent( - children=[Child() for j in range(10)]) for i in range(10)] + obj = [ + Parent(children=[Child() for j in range(10)]) for i in range(10) + ] sess = Session() sess.add_all(obj) @@ -462,6 +511,7 @@ class SessionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): sess.expire_all() + go() @@ -469,14 +519,15 @@ class QueryTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'parent', + "parent", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('data1', String(20)), - Column('data2', String(20)), - Column('data3', String(20)), - Column('data4', String(20)), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data1", String(20)), + Column("data2", String(20)), + Column("data3", String(20)), + Column("data4", String(20)), ) @classmethod @@ -494,10 +545,12 @@ class QueryTest(fixtures.MappedTest): def _fixture(self): Parent = self.classes.Parent sess = Session() - sess.add_all([ - Parent(data1='d1', data2='d2', data3='d3', data4='d4') - for i in range(10) - ]) + sess.add_all( + [ + Parent(data1="d1", data2="d2", data3="d3", data4="d4") + for i in range(10) + ] + ) sess.commit() sess.close() @@ -530,30 +583,33 @@ class SelectInEagerLoadTest(fixtures.MappedTest): def define_tables(cls, metadata): Table( - 'a', + "a", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('x', Integer), - Column('y', Integer) + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", Integer), + Column("y", Integer), ) Table( - 'b', + "b", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('a_id', ForeignKey('a.id')), - Column('x', Integer), - Column('y', Integer) + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", ForeignKey("a.id")), + Column("x", Integer), + Column("y", Integer), ) Table( - 'c', + "c", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('b_id', ForeignKey('b.id')), - Column('x', Integer), - Column('y', Integer) + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", ForeignKey("b.id")), + Column("x", Integer), + Column("y", Integer), ) @classmethod @@ -569,36 +625,26 @@ class SelectInEagerLoadTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - A, B, C = cls.classes('A', 'B', 'C') - a, b, c = cls.tables('a', 'b', 'c') - - mapper(A, a, properties={ - 'bs': relationship(B), - }) - mapper(B, b, properties={ - 'cs': relationship(C) - }) + A, B, C = cls.classes("A", "B", "C") + a, b, c = cls.tables("a", "b", "c") + + mapper(A, a, properties={"bs": relationship(B)}) + mapper(B, b, properties={"cs": relationship(C)}) mapper(C, c) @classmethod def insert_data(cls): - A, B, C = cls.classes('A', 'B', 'C') + A, B, C = cls.classes("A", "B", "C") s = Session() - s.add( - A( - bs=[B(cs=[C()]), B(cs=[C()])] - ) - ) + s.add(A(bs=[B(cs=[C()]), B(cs=[C()])])) s.commit() def test_round_trip_results(self): - A, B, C = self.classes('A', 'B', 'C') + A, B, C = self.classes("A", "B", "C") sess = Session() - q = sess.query(A).options( - selectinload(A.bs).selectinload(B.cs) - ) + q = 
sess.query(A).options(selectinload(A.bs).selectinload(B.cs)) @profiling.function_call_count() def go(): @@ -606,6 +652,7 @@ class SelectInEagerLoadTest(fixtures.MappedTest): obj = q.all() list(obj) sess.close() + go() @@ -613,64 +660,68 @@ class JoinedEagerLoadTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): def make_some_columns(): - return [ - Column('c%d' % i, Integer) - for i in range(10) - ] + return [Column("c%d" % i, Integer) for i in range(10)] Table( - 'a', + "a", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), *make_some_columns() ) Table( - 'b', + "b", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('a_id', ForeignKey('a.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", ForeignKey("a.id")), *make_some_columns() ) Table( - 'c', + "c", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('b_id', ForeignKey('b.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", ForeignKey("b.id")), *make_some_columns() ) Table( - 'd', + "d", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('c_id', ForeignKey('c.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c_id", ForeignKey("c.id")), *make_some_columns() ) Table( - 'e', + "e", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('a_id', ForeignKey('a.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", ForeignKey("a.id")), *make_some_columns() ) Table( - 'f', + "f", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('e_id', ForeignKey('e.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("e_id", ForeignKey("e.id")), *make_some_columns() ) Table( - 'g', + "g", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('e_id', ForeignKey('e.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("e_id", ForeignKey("e.id")), *make_some_columns() ) @@ -699,41 +750,31 @@ class JoinedEagerLoadTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - A, B, C, D, E, F, G = cls.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') - a, b, c, d, e, f, g = cls.tables('a', 'b', 'c', 'd', 'e', 'f', 'g') - - mapper(A, a, properties={ - 'bs': relationship(B), - 'es': relationship(E) - }) - mapper(B, b, properties={ - 'cs': relationship(C) - }) - mapper(C, c, properties={ - 'ds': relationship(D) - }) + A, B, C, D, E, F, G = cls.classes("A", "B", "C", "D", "E", "F", "G") + a, b, c, d, e, f, g = cls.tables("a", "b", "c", "d", "e", "f", "g") + + mapper(A, a, properties={"bs": relationship(B), "es": relationship(E)}) + mapper(B, b, properties={"cs": relationship(C)}) + mapper(C, c, properties={"ds": relationship(D)}) mapper(D, d) - mapper(E, e, properties={ - 'fs': relationship(F), - 'gs': relationship(G) - }) + mapper(E, e, properties={"fs": relationship(F), "gs": relationship(G)}) mapper(F, f) mapper(G, g) @classmethod def insert_data(cls): - A, B, C, D, E, F, G = cls.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = cls.classes("A", "B", "C", "D", "E", "F", "G") s = 
Session() s.add( A( bs=[B(cs=[C(ds=[D()])]), B(cs=[C()])], - es=[E(fs=[F()], gs=[G()])] + es=[E(fs=[F()], gs=[G()])], ) ) s.commit() def test_build_query(self): - A, B, C, D, E, F, G = self.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") sess = Session() @@ -746,10 +787,11 @@ class JoinedEagerLoadTest(fixtures.MappedTest): defaultload(A.es).joinedload(E.gs), ) q._compile_context() + go() def test_fetch_results(self): - A, B, C, D, E, F, G = self.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") sess = Session() @@ -767,6 +809,7 @@ class JoinedEagerLoadTest(fixtures.MappedTest): obj = q._execute_and_instances(context) list(obj) sess.close() + go() @@ -774,64 +817,68 @@ class BranchedOptionTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): def make_some_columns(): - return [ - Column('c%d' % i, Integer) - for i in range(2) - ] + return [Column("c%d" % i, Integer) for i in range(2)] Table( - 'a', + "a", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), *make_some_columns() ) Table( - 'b', + "b", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('a_id', ForeignKey('a.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", ForeignKey("a.id")), *make_some_columns() ) Table( - 'c', + "c", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('b_id', ForeignKey('b.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", ForeignKey("b.id")), *make_some_columns() ) Table( - 'd', + "d", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('b_id', ForeignKey('b.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", ForeignKey("b.id")), *make_some_columns() ) Table( - 'e', + "e", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('b_id', ForeignKey('b.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", ForeignKey("b.id")), *make_some_columns() ) Table( - 'f', + "f", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('b_id', ForeignKey('b.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", ForeignKey("b.id")), *make_some_columns() ) Table( - 'g', + "g", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('a_id', ForeignKey('a.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", ForeignKey("a.id")), *make_some_columns() ) @@ -860,19 +907,20 @@ class BranchedOptionTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - A, B, C, D, E, F, G = cls.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') - a, b, c, d, e, f, g = cls.tables('a', 'b', 'c', 'd', 'e', 'f', 'g') - - mapper(A, a, properties={ - 'bs': relationship(B), - 'gs': relationship(G) - }) - mapper(B, b, properties={ - 'cs': relationship(C), - 'ds': relationship(D), - 'es': relationship(E), - 'fs': relationship(F) - }) + A, B, C, D, E, F, G = cls.classes("A", "B", "C", "D", "E", "F", "G") + a, b, c, d, e, f, g = cls.tables("a", "b", 
"c", "d", "e", "f", "g") + + mapper(A, a, properties={"bs": relationship(B), "gs": relationship(G)}) + mapper( + B, + b, + properties={ + "cs": relationship(C), + "ds": relationship(D), + "es": relationship(E), + "fs": relationship(F), + }, + ) mapper(C, c) mapper(D, d) mapper(E, e) @@ -882,14 +930,14 @@ class BranchedOptionTest(fixtures.MappedTest): configure_mappers() def test_generate_cache_key_unbound_branching(self): - A, B, C, D, E, F, G = self.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = joinedload(A.bs) opts = [ base.joinedload(B.cs), base.joinedload(B.ds), base.joinedload(B.es), - base.joinedload(B.fs) + base.joinedload(B.fs), ] cache_path = inspect(A)._path_registry @@ -898,17 +946,18 @@ class BranchedOptionTest(fixtures.MappedTest): def go(): for opt in opts: opt._generate_cache_key(cache_path) + go() def test_generate_cache_key_bound_branching(self): - A, B, C, D, E, F, G = self.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = Load(A).joinedload(A.bs) opts = [ base.joinedload(B.cs), base.joinedload(B.ds), base.joinedload(B.es), - base.joinedload(B.fs) + base.joinedload(B.fs), ] cache_path = inspect(A)._path_registry @@ -917,17 +966,18 @@ class BranchedOptionTest(fixtures.MappedTest): def go(): for opt in opts: opt._generate_cache_key(cache_path) + go() def test_query_opts_unbound_branching(self): - A, B, C, D, E, F, G = self.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = joinedload(A.bs) opts = [ base.joinedload(B.cs), base.joinedload(B.ds), base.joinedload(B.es), - base.joinedload(B.fs) + base.joinedload(B.fs), ] q = Session().query(A) @@ -935,17 +985,18 @@ class BranchedOptionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): q.options(*opts) + go() def test_query_opts_key_bound_branching(self): - A, B, C, D, E, F, G = self.classes('A', 'B', 'C', 'D', 'E', 'F', 'G') + A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = Load(A).joinedload(A.bs) opts = [ base.joinedload(B.cs), base.joinedload(B.ds), base.joinedload(B.es), - base.joinedload(B.fs) + base.joinedload(B.fs), ] q = Session().query(A) @@ -953,6 +1004,7 @@ class BranchedOptionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): q.options(*opts) + go() @@ -960,10 +1012,10 @@ class AnnotatedOverheadTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'a', + "a", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) + Column("id", Integer, primary_key=True), + Column("data", String(50)), ) @classmethod @@ -982,7 +1034,7 @@ class AnnotatedOverheadTest(fixtures.MappedTest): def insert_data(cls): A = cls.classes.A s = Session() - s.add_all([A(data='asdf') for i in range(5)]) + s.add_all([A(data="asdf") for i in range(5)]) s.commit() def test_no_bundle(self): @@ -995,6 +1047,7 @@ class AnnotatedOverheadTest(fixtures.MappedTest): def go(): for i in range(100): q.all() + go() def test_no_entity_wo_annotations(self): @@ -1008,6 +1061,7 @@ class AnnotatedOverheadTest(fixtures.MappedTest): def go(): for i in range(100): q.all() + go() def test_no_entity_w_annotations(self): @@ -1019,85 +1073,80 @@ class AnnotatedOverheadTest(fixtures.MappedTest): def go(): for i in range(100): q.all() + go() def test_entity_w_annotations(self): A = self.classes.A s = Session() - q = s.query( - A, A.data - 
).select_from(A) + q = s.query(A, A.data).select_from(A) @profiling.function_call_count() def go(): for i in range(100): q.all() + go() def test_entity_wo_annotations(self): A = self.classes.A a = self.tables.a s = Session() - q = s.query( - A, a.c.data - ).select_from(A) + q = s.query(A, a.c.data).select_from(A) @profiling.function_call_count() def go(): for i in range(100): q.all() + go() def test_no_bundle_wo_annotations(self): A = self.classes.A a = self.tables.a s = Session() - q = s.query( - a.c.data, A - ).select_from(A) + q = s.query(a.c.data, A).select_from(A) @profiling.function_call_count() def go(): for i in range(100): q.all() + go() def test_no_bundle_w_annotations(self): A = self.classes.A s = Session() - q = s.query( - A.data, A - ).select_from(A) + q = s.query(A.data, A).select_from(A) @profiling.function_call_count() def go(): for i in range(100): q.all() + go() def test_bundle_wo_annotation(self): A = self.classes.A a = self.tables.a s = Session() - q = s.query( - Bundle("ASdf", a.c.data), A - ).select_from(A) + q = s.query(Bundle("ASdf", a.c.data), A).select_from(A) @profiling.function_call_count() def go(): for i in range(100): q.all() + go() def test_bundle_w_annotation(self): A = self.classes.A s = Session() - q = s.query( - Bundle("ASdf", A.data), A - ).select_from(A) + q = s.query(Bundle("ASdf", A.data), A).select_from(A) @profiling.function_call_count() def go(): for i in range(100): q.all() + go() diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py index 02248fa592..af9669294d 100644 --- a/test/aaa_profiling/test_pool.py +++ b/test/aaa_profiling/test_pool.py @@ -6,10 +6,9 @@ pool = None class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults): - __requires__ = 'cpython', + __requires__ = ("cpython",) class Connection(object): - def rollback(self): pass @@ -28,15 +27,21 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults): # has the effect of initializing # class-level event listeners on Pool, # if not present already. 
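# The priming above matters because the first Pool constructed in a
# process installs its class-level event listeners lazily; without it the
# first profiled connect() would record extra one-time calls.  A minimal
# sketch of the pooling behavior under test, assuming a sqlite3 DBAPI is
# available (the names `creator`, `pool`, `conn` are illustrative only):
#
#     from sqlalchemy.pool import QueuePool
#     import sqlite3
#
#     def creator():                       # zero-argument connection factory
#         return sqlite3.connect(":memory:")
#
#     pool = QueuePool(creator, pool_size=3, max_overflow=-1)
#     conn = pool.connect()                # checkout; may invoke creator()
#     conn.close()                         # checkin; raw connection is kept
#     conn2 = pool.connect()               # reuses the pooled connection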
- p1 = QueuePool(creator=self.Connection, - pool_size=3, max_overflow=-1, - use_threadlocal=True) + p1 = QueuePool( + creator=self.Connection, + pool_size=3, + max_overflow=-1, + use_threadlocal=True, + ) p1.connect() global pool - pool = QueuePool(creator=self.Connection, - pool_size=3, max_overflow=-1, - use_threadlocal=True) + pool = QueuePool( + creator=self.Connection, + pool_size=3, + max_overflow=-1, + use_threadlocal=True, + ) @profiling.function_call_count() def test_first_connect(self): @@ -50,6 +55,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults): def go(): conn2 = pool.connect() return conn2 + go() def test_second_samethread_connect(self): @@ -59,4 +65,5 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count() def go(): return pool.connect() + go() diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 51ff739a31..ab92ee94d9 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -1,5 +1,12 @@ -from sqlalchemy import MetaData, Table, Column, String, Unicode, Integer, \ - create_engine +from sqlalchemy import ( + MetaData, + Table, + Column, + String, + Unicode, + Integer, + create_engine, +) from sqlalchemy.testing import fixtures, AssertsExecutionResults, profiling from sqlalchemy import testing from sqlalchemy.testing import eq_ @@ -20,31 +27,55 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): def setup_class(cls): global t, t2, metadata metadata = MetaData(testing.db) - t = Table('table1', metadata, *[Column('field%d' % fnum, String(50)) - for fnum in range(NUM_FIELDS)]) + t = Table( + "table1", + metadata, + *[ + Column("field%d" % fnum, String(50)) + for fnum in range(NUM_FIELDS) + ] + ) t2 = Table( - 'table2', metadata, * - [Column('field%d' % fnum, Unicode(50)) - for fnum in range(NUM_FIELDS)]) + "table2", + metadata, + *[ + Column("field%d" % fnum, Unicode(50)) + for fnum in range(NUM_FIELDS) + ] + ) def setup(self): metadata.create_all() - t.insert().execute([dict(('field%d' % fnum, u('value%d' % fnum)) - for fnum in range(NUM_FIELDS)) for r_num in - range(NUM_RECORDS)]) - t2.insert().execute([dict(('field%d' % fnum, u('value%d' % fnum)) - for fnum in range(NUM_FIELDS)) for r_num in - range(NUM_RECORDS)]) + t.insert().execute( + [ + dict( + ("field%d" % fnum, u("value%d" % fnum)) + for fnum in range(NUM_FIELDS) + ) + for r_num in range(NUM_RECORDS) + ] + ) + t2.insert().execute( + [ + dict( + ("field%d" % fnum, u("value%d" % fnum)) + for fnum in range(NUM_FIELDS) + ) + for r_num in range(NUM_RECORDS) + ] + ) # warm up type caches t.select().execute().fetchall() t2.select().execute().fetchall() - testing.db.execute('SELECT %s FROM table1' % ( - ", ".join("field%d" % fnum for fnum in range(NUM_FIELDS)) - )).fetchall() - testing.db.execute("SELECT %s FROM table2" % ( - ", ".join("field%d" % fnum for fnum in range(NUM_FIELDS)) - )).fetchall() + testing.db.execute( + "SELECT %s FROM table1" + % (", ".join("field%d" % fnum for fnum in range(NUM_FIELDS))) + ).fetchall() + testing.db.execute( + "SELECT %s FROM table2" + % (", ".join("field%d" % fnum for fnum in range(NUM_FIELDS))) + ).fetchall() def teardown(self): metadata.drop_all() @@ -59,7 +90,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count(variance=0.10) def test_raw_string(self): - stmt = 'SELECT %s FROM table1' % ( + stmt = "SELECT %s FROM table1" % ( ", ".join("field%d" % fnum for fnum in range(NUM_FIELDS)) 
) [tuple(row) for row in testing.db.execute(stmt).fetchall()] @@ -73,12 +104,14 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): def test_contains_doesnt_compile(self): row = t.select().execute().first() - c1 = Column('some column', Integer) + \ - Column("some other column", Integer) + c1 = Column("some column", Integer) + Column( + "some other column", Integer + ) @profiling.function_call_count() def go(): c1 in row + go() @@ -87,7 +120,7 @@ class ExecutionTest(fixtures.TestBase): def test_minimal_connection_execute(self): # create an engine without any instrumentation. - e = create_engine('sqlite://') + e = create_engine("sqlite://") c = e.connect() # ensure initial connect activities complete c.execute("select 1") @@ -95,6 +128,7 @@ class ExecutionTest(fixtures.TestBase): @profiling.function_call_count() def go(): c.execute("select 1") + try: go() finally: @@ -102,31 +136,32 @@ class ExecutionTest(fixtures.TestBase): def test_minimal_engine_execute(self, variance=0.10): # create an engine without any instrumentation. - e = create_engine('sqlite://') + e = create_engine("sqlite://") # ensure initial connect activities complete e.execute("select 1") @profiling.function_call_count() def go(): e.execute("select 1") + go() class RowProxyTest(fixtures.TestBase): - __requires__ = 'cpython', + __requires__ = ("cpython",) __backend__ = True def _rowproxy_fixture(self, keys, processors, row): class MockMeta(object): - def __init__(self): pass metadata = MockMeta() keymap = {} - for index, (keyobjs, processor, values) in \ - enumerate(list(zip(keys, processors, row))): + for index, (keyobjs, processor, values) in enumerate( + list(zip(keys, processors, row)) + ): for key in keyobjs: keymap[key] = (processor, key, index) keymap[index] = (processor, key, index) @@ -137,11 +172,12 @@ class RowProxyTest(fixtures.TestBase): def proc1(value): return value + value1, value2 = "x", "y" row = self._rowproxy_fixture( [(col1, "a"), (col2, "b")], [proc1, None], - seq_factory([value1, value2]) + seq_factory([value1, value2]), ) v1_refcount = sys.getrefcount(value1) @@ -162,7 +198,6 @@ class RowProxyTest(fixtures.TestBase): def test_value_refcounts_custom_seq(self): class CustomSeq(object): - def __init__(self, data): self.data = data @@ -171,4 +206,5 @@ class RowProxyTest(fixtures.TestBase): def __iter__(self): return iter(self.data) + self._test_getitem_value_refcounts(CustomSeq) diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index 8ddca97c43..9ef824cf65 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -6,9 +6,23 @@ An adaptation of Robert Brewers' ZooMark speed tests. 
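# The ZooMark suite below, like the other files in this commit, is a
# profiling test: each step replays a recorded workload and fails if its
# Python function-call count drifts from the stored baseline by more than
# the allowed variance.  A hedged sketch of the decorator pattern these
# files use (the workload body here is illustrative, not from the suite):
#
#     from sqlalchemy.testing import profiling
#
#     @profiling.function_call_count(variance=0.10)   # tolerate ~10% drift
#     def test_something(self):
#         run_workload()   # hypothetical code whose calls are counted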
""" import datetime -from sqlalchemy import Table, Column, Integer, Unicode, Date, \ - DateTime, Time, Float, Sequence, ForeignKey, \ - select, join, and_, outerjoin, func +from sqlalchemy import ( + Table, + Column, + Integer, + Unicode, + Date, + DateTime, + Time, + Float, + Sequence, + ForeignKey, + select, + join, + and_, + outerjoin, + func, +) from sqlalchemy.testing import replay_fixture ITERATIONS = 1 @@ -18,8 +32,8 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): """Runs the ZooMark and squawks if method counts vary from the norm.""" - __requires__ = 'cpython', - __only_on__ = 'postgresql+psycopg2' + __requires__ = ("cpython",) + __only_on__ = "postgresql+psycopg2" def _run_steps(self, ctx): with ctx(): @@ -45,110 +59,143 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): def _baseline_1_create_tables(self): Table( - 'Zoo', + "Zoo", self.metadata, - Column('ID', Integer, Sequence('zoo_id_seq'), - primary_key=True, index=True), - Column('Name', Unicode(255)), - Column('Founded', Date), - Column('Opens', Time), - Column('LastEscape', DateTime), - Column('Admission', Float), + Column( + "ID", + Integer, + Sequence("zoo_id_seq"), + primary_key=True, + index=True, + ), + Column("Name", Unicode(255)), + Column("Founded", Date), + Column("Opens", Time), + Column("LastEscape", DateTime), + Column("Admission", Float), ) Table( - 'Animal', + "Animal", self.metadata, - Column('ID', Integer, Sequence('animal_id_seq'), - primary_key=True), - Column('ZooID', Integer, ForeignKey('Zoo.ID'), index=True), - Column('Name', Unicode(100)), - Column('Species', Unicode(100)), - Column('Legs', Integer, default=4), - Column('LastEscape', DateTime), - Column('Lifespan', Float(4)), - Column('MotherID', Integer, ForeignKey('Animal.ID')), - Column('PreferredFoodID', Integer), - Column('AlternateFoodID', Integer), + Column("ID", Integer, Sequence("animal_id_seq"), primary_key=True), + Column("ZooID", Integer, ForeignKey("Zoo.ID"), index=True), + Column("Name", Unicode(100)), + Column("Species", Unicode(100)), + Column("Legs", Integer, default=4), + Column("LastEscape", DateTime), + Column("Lifespan", Float(4)), + Column("MotherID", Integer, ForeignKey("Animal.ID")), + Column("PreferredFoodID", Integer), + Column("AlternateFoodID", Integer), ) self.metadata.create_all() def _baseline_1a_populate(self): - Zoo = self.metadata.tables['Zoo'] - Animal = self.metadata.tables['Animal'] + Zoo = self.metadata.tables["Zoo"] + Animal = self.metadata.tables["Animal"] engine = self.metadata.bind - wap = engine.execute(Zoo.insert(), Name='Wild Animal Park', - Founded=datetime.date(2000, 1, 1), - Opens=datetime.time(8, 15, 59), - LastEscape=datetime.datetime( - 2004, 7, 29, 5, 6, 7), - Admission=4.95).inserted_primary_key[0] - sdz = engine.execute(Zoo.insert(), Name='San Diego Zoo', - Founded=datetime.date(1935, 9, 13), - Opens=datetime.time(9, 0, 0), - Admission=0).inserted_primary_key[0] - engine.execute(Zoo.insert(inline=True), Name='Montr\xe9al Biod\xf4me', - Founded=datetime.date(1992, 6, 19), - Opens=datetime.time(9, 0, 0), Admission=11.75) - seaworld = engine.execute(Zoo.insert(), Name='Sea_World', - Admission=60).inserted_primary_key[0] + wap = engine.execute( + Zoo.insert(), + Name="Wild Animal Park", + Founded=datetime.date(2000, 1, 1), + Opens=datetime.time(8, 15, 59), + LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7), + Admission=4.95, + ).inserted_primary_key[0] + sdz = engine.execute( + Zoo.insert(), + Name="San Diego Zoo", + Founded=datetime.date(1935, 9, 13), + Opens=datetime.time(9, 0, 0), + 
Admission=0, + ).inserted_primary_key[0] + engine.execute( + Zoo.insert(inline=True), + Name="Montr\xe9al Biod\xf4me", + Founded=datetime.date(1992, 6, 19), + Opens=datetime.time(9, 0, 0), + Admission=11.75, + ) + seaworld = engine.execute( + Zoo.insert(), Name="Sea_World", Admission=60 + ).inserted_primary_key[0] # Let's add a crazy futuristic Zoo to test large date values. engine.execute( - Zoo.insert(), Name='Luna Park', + Zoo.insert(), + Name="Luna Park", Founded=datetime.date(2072, 7, 17), Opens=datetime.time(0, 0, 0), - Admission=134.95).inserted_primary_key[0] + Admission=134.95, + ).inserted_primary_key[0] # Animals - leopardid = engine.execute(Animal.insert(), Species='Leopard', - Lifespan=73.5).inserted_primary_key[0] - engine.execute(Animal.update(Animal.c.ID == leopardid), ZooID=wap, - LastEscape=datetime.datetime( - 2004, 12, 21, 8, 15, 0, 999907,) - ) + leopardid = engine.execute( + Animal.insert(), Species="Leopard", Lifespan=73.5 + ).inserted_primary_key[0] + engine.execute( + Animal.update(Animal.c.ID == leopardid), + ZooID=wap, + LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907), + ) engine.execute( - Animal.insert(), - Species='Lion', ZooID=wap).inserted_primary_key[0] + Animal.insert(), Species="Lion", ZooID=wap + ).inserted_primary_key[0] - engine.execute(Animal.insert(), Species='Slug', Legs=1, Lifespan=.75) - engine.execute(Animal.insert(), Species='Tiger', - ZooID=sdz).inserted_primary_key[0] + engine.execute(Animal.insert(), Species="Slug", Legs=1, Lifespan=0.75) + engine.execute( + Animal.insert(), Species="Tiger", ZooID=sdz + ).inserted_primary_key[0] # Override Legs.default with itself just to make sure it works. - engine.execute(Animal.insert(inline=True), Species='Bear', Legs=4) - engine.execute(Animal.insert(inline=True), Species='Ostrich', Legs=2, - Lifespan=103.2) - engine.execute(Animal.insert(inline=True), Species='Centipede', - Legs=100) - engine.execute(Animal.insert(), Species='Emperor Penguin', - Legs=2, ZooID=seaworld).inserted_primary_key[0] - engine.execute(Animal.insert(), Species='Adelie Penguin', - Legs=2, ZooID=seaworld).inserted_primary_key[0] - engine.execute(Animal.insert(inline=True), Species='Millipede', - Legs=1000000, ZooID=sdz) + engine.execute(Animal.insert(inline=True), Species="Bear", Legs=4) + engine.execute( + Animal.insert(inline=True), + Species="Ostrich", + Legs=2, + Lifespan=103.2, + ) + engine.execute( + Animal.insert(inline=True), Species="Centipede", Legs=100 + ) + engine.execute( + Animal.insert(), Species="Emperor Penguin", Legs=2, ZooID=seaworld + ).inserted_primary_key[0] + engine.execute( + Animal.insert(), Species="Adelie Penguin", Legs=2, ZooID=seaworld + ).inserted_primary_key[0] + engine.execute( + Animal.insert(inline=True), + Species="Millipede", + Legs=1000000, + ZooID=sdz, + ) # Add a mother and child to test relationships bai_yun = engine.execute( - Animal.insert(), - Species='Ape', - Name='Bai Yun', - Legs=2).inserted_primary_key[0] - engine.execute(Animal.insert(inline=True), Species='Ape', - Name='Hua Mei', Legs=2, MotherID=bai_yun) + Animal.insert(), Species="Ape", Name="Bai Yun", Legs=2 + ).inserted_primary_key[0] + engine.execute( + Animal.insert(inline=True), + Species="Ape", + Name="Hua Mei", + Legs=2, + MotherID=bai_yun, + ) def _baseline_2_insert(self): - Animal = self.metadata.tables['Animal'] + Animal = self.metadata.tables["Animal"] i = Animal.insert(inline=True) for x in range(ITERATIONS): - i.execute(Species='Tick', Name='Tick %d' % x, Legs=8) + i.execute(Species="Tick", Name="Tick %d" 
% x, Legs=8) def _baseline_3_properties(self): - Zoo = self.metadata.tables['Zoo'] - Animal = self.metadata.tables['Animal'] + Zoo = self.metadata.tables["Zoo"] + Animal = self.metadata.tables["Animal"] engine = self.metadata.bind def fullobject(select): @@ -160,23 +207,21 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): # Zoos - fullobject(Zoo.select(Zoo.c.Name == 'Wild Animal Park')) - fullobject(Zoo.select(Zoo.c.Founded == - datetime.date(1935, 9, 13))) - fullobject(Zoo.select(Zoo.c.Name == - 'Montr\xe9al Biod\xf4me')) + fullobject(Zoo.select(Zoo.c.Name == "Wild Animal Park")) + fullobject(Zoo.select(Zoo.c.Founded == datetime.date(1935, 9, 13))) + fullobject(Zoo.select(Zoo.c.Name == "Montr\xe9al Biod\xf4me")) fullobject(Zoo.select(Zoo.c.Admission == float(60))) # Animals - fullobject(Animal.select(Animal.c.Species == 'Leopard')) - fullobject(Animal.select(Animal.c.Species == 'Ostrich')) + fullobject(Animal.select(Animal.c.Species == "Leopard")) + fullobject(Animal.select(Animal.c.Species == "Ostrich")) fullobject(Animal.select(Animal.c.Legs == 1000000)) - fullobject(Animal.select(Animal.c.Species == 'Tick')) + fullobject(Animal.select(Animal.c.Species == "Tick")) def _baseline_4_expressions(self): - Zoo = self.metadata.tables['Zoo'] - Animal = self.metadata.tables['Animal'] + Zoo = self.metadata.tables["Zoo"] + Animal = self.metadata.tables["Animal"] engine = self.metadata.bind def fulltable(select): @@ -187,104 +232,159 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): for x in range(ITERATIONS): assert len(fulltable(Zoo.select())) == 5 assert len(fulltable(Animal.select())) == ITERATIONS + 12 - assert len(fulltable(Animal.select(Animal.c.Legs == 4))) \ - == 4 - assert len(fulltable(Animal.select(Animal.c.Legs == 2))) \ - == 5 - assert len( - fulltable( - Animal.select( - and_( - Animal.c.Legs >= 2, - Animal.c.Legs < 20)))) == ITERATIONS + 9 - assert len(fulltable(Animal.select(Animal.c.Legs > 10))) \ + assert len(fulltable(Animal.select(Animal.c.Legs == 4))) == 4 + assert len(fulltable(Animal.select(Animal.c.Legs == 2))) == 5 + assert ( + len( + fulltable( + Animal.select( + and_(Animal.c.Legs >= 2, Animal.c.Legs < 20) + ) + ) + ) + == ITERATIONS + 9 + ) + assert len(fulltable(Animal.select(Animal.c.Legs > 10))) == 2 + assert len(fulltable(Animal.select(Animal.c.Lifespan > 70))) == 2 + assert ( + len(fulltable(Animal.select(Animal.c.Species.startswith("L")))) + == 2 + ) + assert ( + len( + fulltable(Animal.select(Animal.c.Species.endswith("pede"))) + ) == 2 - assert len(fulltable(Animal.select(Animal.c.Lifespan - > 70))) == 2 - assert len(fulltable(Animal.select(Animal.c.Species. - startswith('L')))) == 2 - assert len(fulltable(Animal.select(Animal.c.Species. 
- endswith('pede')))) == 2 - assert len(fulltable( - Animal.select(Animal.c.LastEscape != None))) == 1 # noqa - assert len( - fulltable(Animal.select( - None == Animal.c.LastEscape))) == ITERATIONS + 11 # noqa + ) + assert ( + len(fulltable(Animal.select(Animal.c.LastEscape != None))) == 1 + ) # noqa + assert ( + len(fulltable(Animal.select(None == Animal.c.LastEscape))) + == ITERATIONS + 11 + ) # noqa # In operator (containedby) - assert len(fulltable(Animal.select(Animal.c.Species.like('%pede%' - )))) == 2 - assert len( - fulltable( - Animal.select( - Animal.c.Species.in_( - ['Lion', 'Tiger', 'Bear'])))) == 3 + assert ( + len(fulltable(Animal.select(Animal.c.Species.like("%pede%")))) + == 2 + ) + assert ( + len( + fulltable( + Animal.select( + Animal.c.Species.in_(["Lion", "Tiger", "Bear"]) + ) + ) + ) + == 3 + ) # Try In with cell references class thing(object): pass pet, pet2 = thing(), thing() - pet.Name, pet2.Name = 'Slug', 'Ostrich' - assert len( - fulltable( - Animal.select( - Animal.c.Species.in_([pet.Name, pet2.Name])))) == 2 + pet.Name, pet2.Name = "Slug", "Ostrich" + assert ( + len( + fulltable( + Animal.select( + Animal.c.Species.in_([pet.Name, pet2.Name]) + ) + ) + ) + == 2 + ) # logic and other functions - assert len(fulltable(Animal.select(Animal.c.Species.like('Slug' - )))) == 1 - assert len(fulltable(Animal.select(Animal.c.Species.like('%pede%' - )))) == 2 - name = 'Lion' - assert len( - fulltable( - Animal.select( - func.length( - Animal.c.Species) == len(name)))) == ITERATIONS + 3 - assert len( - fulltable( - Animal.select( - Animal.c.Species.like('%i%')))) == ITERATIONS + 7 + assert ( + len(fulltable(Animal.select(Animal.c.Species.like("Slug")))) + == 1 + ) + assert ( + len(fulltable(Animal.select(Animal.c.Species.like("%pede%")))) + == 2 + ) + name = "Lion" + assert ( + len( + fulltable( + Animal.select( + func.length(Animal.c.Species) == len(name) + ) + ) + ) + == ITERATIONS + 3 + ) + assert ( + len(fulltable(Animal.select(Animal.c.Species.like("%i%")))) + == ITERATIONS + 7 + ) # Test now(), today(), year(), month(), day() - assert len( - fulltable( - Zoo.select( - and_( - Zoo.c.Founded != None, # noqa - Zoo.c.Founded < func.current_timestamp( - _type=Date))))) == 3 - assert len( - fulltable( - Animal.select( - Animal.c.LastEscape == func.current_timestamp( - _type=Date)))) == 0 - assert len( - fulltable( - Animal.select( - func.date_part( - 'year', - Animal.c.LastEscape) == 2004))) == 1 - assert len( - fulltable( - Animal.select( - func.date_part( - 'month', - Animal.c.LastEscape) == 12))) == 1 - assert len( - fulltable( - Animal.select( - func.date_part( - 'day', - Animal.c.LastEscape) == 21))) == 1 + assert ( + len( + fulltable( + Zoo.select( + and_( + Zoo.c.Founded != None, # noqa + Zoo.c.Founded + < func.current_timestamp(_type=Date), + ) + ) + ) + ) + == 3 + ) + assert ( + len( + fulltable( + Animal.select( + Animal.c.LastEscape + == func.current_timestamp(_type=Date) + ) + ) + ) + == 0 + ) + assert ( + len( + fulltable( + Animal.select( + func.date_part("year", Animal.c.LastEscape) == 2004 + ) + ) + ) + == 1 + ) + assert ( + len( + fulltable( + Animal.select( + func.date_part("month", Animal.c.LastEscape) == 12 + ) + ) + ) + == 1 + ) + assert ( + len( + fulltable( + Animal.select( + func.date_part("day", Animal.c.LastEscape) == 21 + ) + ) + ) + == 1 + ) def _baseline_5_aggregates(self): - Animal = self.metadata.tables['Animal'] - Zoo = self.metadata.tables['Zoo'] + Animal = self.metadata.tables["Animal"] + Zoo = self.metadata.tables["Zoo"] engine = 
self.metadata.bind for x in range(ITERATIONS): @@ -294,29 +394,32 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): view = engine.execute(select([Animal.c.Legs])).fetchall() legs = sorted([x[0] for x in view]) expected = { - 'Leopard': 73.5, - 'Slug': .75, - 'Tiger': None, - 'Lion': None, - 'Bear': None, - 'Ostrich': 103.2, - 'Centipede': None, - 'Emperor Penguin': None, - 'Adelie Penguin': None, - 'Millipede': None, - 'Ape': None, - 'Tick': None, + "Leopard": 73.5, + "Slug": 0.75, + "Tiger": None, + "Lion": None, + "Bear": None, + "Ostrich": 103.2, + "Centipede": None, + "Emperor Penguin": None, + "Adelie Penguin": None, + "Millipede": None, + "Ape": None, + "Tick": None, } for species, lifespan in engine.execute( - select([Animal.c.Species, Animal.c.Lifespan])).fetchall(): + select([Animal.c.Species, Animal.c.Lifespan]) + ).fetchall(): assert lifespan == expected[species] - expected = ['Montr\xe9al Biod\xf4me', 'Wild Animal Park'] - e = select([Zoo.c.Name], - and_(Zoo.c.Founded != None, # noqa - Zoo.c.Founded <= func.current_timestamp(), - Zoo.c.Founded >= datetime.date(1990, - 1, - 1))) + expected = ["Montr\xe9al Biod\xf4me", "Wild Animal Park"] + e = select( + [Zoo.c.Name], + and_( + Zoo.c.Founded != None, # noqa + Zoo.c.Founded <= func.current_timestamp(), + Zoo.c.Founded >= datetime.date(1990, 1, 1), + ), + ) values = [val[0] for val in engine.execute(e).fetchall()] assert set(values) == set(expected) @@ -325,50 +428,56 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): legs = [ x[0] for x in engine.execute( - select([Animal.c.Legs], - distinct=True)).fetchall()] + select([Animal.c.Legs], distinct=True) + ).fetchall() + ] legs.sort() def _baseline_6_editing(self): - Zoo = self.metadata.tables['Zoo'] + Zoo = self.metadata.tables["Zoo"] engine = self.metadata.bind for x in range(ITERATIONS): # Edit - SDZ = engine.execute(Zoo.select(Zoo.c.Name == 'San Diego Zoo' - )).first() + SDZ = engine.execute( + Zoo.select(Zoo.c.Name == "San Diego Zoo") + ).first() engine.execute( - Zoo.update( - Zoo.c.ID == SDZ['ID']), - Name='The San Diego Zoo', + Zoo.update(Zoo.c.ID == SDZ["ID"]), + Name="The San Diego Zoo", Founded=datetime.date(1900, 1, 1), - Opens=datetime.time(7, 30, 0), Admission='35.00') + Opens=datetime.time(7, 30, 0), + Admission="35.00", + ) # Test edits - SDZ = engine.execute(Zoo.select(Zoo.c.Name == 'The San Diego Zoo' - )).first() - assert SDZ['Founded'] == datetime.date(1900, 1, 1), \ - SDZ['Founded'] + SDZ = engine.execute( + Zoo.select(Zoo.c.Name == "The San Diego Zoo") + ).first() + assert SDZ["Founded"] == datetime.date(1900, 1, 1), SDZ["Founded"] # Change it back - engine.execute(Zoo.update(Zoo.c.ID == SDZ['ID' - ]), Name='San Diego Zoo', - Founded=datetime.date(1935, 9, 13), - Opens=datetime.time(9, 0, 0), - Admission='0') + engine.execute( + Zoo.update(Zoo.c.ID == SDZ["ID"]), + Name="San Diego Zoo", + Founded=datetime.date(1935, 9, 13), + Opens=datetime.time(9, 0, 0), + Admission="0", + ) # Test re-edits - SDZ = engine.execute(Zoo.select(Zoo.c.Name == 'San Diego Zoo' - )).first() - assert SDZ['Founded'] == datetime.date(1935, 9, 13) + SDZ = engine.execute( + Zoo.select(Zoo.c.Name == "San Diego Zoo") + ).first() + assert SDZ["Founded"] == datetime.date(1935, 9, 13) def _baseline_7_multiview(self): - Zoo = self.metadata.tables['Zoo'] - Animal = self.metadata.tables['Animal'] + Zoo = self.metadata.tables["Zoo"] + Animal = self.metadata.tables["Animal"] engine = self.metadata.bind def fulltable(select): @@ -380,30 +489,42 @@ class 
ZooMarkTest(replay_fixture.ReplayFixtureTest): fulltable( select( [Zoo.c.ID] + list(Animal.c), - Zoo.c.Name == 'San Diego Zoo', - from_obj=[join(Zoo, Animal)])) - Zoo.select(Zoo.c.Name == 'San Diego Zoo') + Zoo.c.Name == "San Diego Zoo", + from_obj=[join(Zoo, Animal)], + ) + ) + Zoo.select(Zoo.c.Name == "San Diego Zoo") fulltable( select( [Zoo.c.ID, Animal.c.ID], and_( - Zoo.c.Name == 'San Diego Zoo', - Animal.c.Species == 'Leopard' + Zoo.c.Name == "San Diego Zoo", + Animal.c.Species == "Leopard", ), - from_obj=[join(Zoo, Animal)]) + from_obj=[join(Zoo, Animal)], + ) ) # Now try the same query with INNER, LEFT, and RIGHT JOINs. - fulltable(select([ - Zoo.c.Name, Animal.c.Species], - from_obj=[join(Zoo, Animal)])) - fulltable(select([ - Zoo.c.Name, Animal.c.Species], - from_obj=[outerjoin(Zoo, Animal)])) - fulltable(select([ - Zoo.c.Name, Animal.c.Species], - from_obj=[outerjoin(Animal, Zoo)])) + fulltable( + select( + [Zoo.c.Name, Animal.c.Species], + from_obj=[join(Zoo, Animal)], + ) + ) + fulltable( + select( + [Zoo.c.Name, Animal.c.Species], + from_obj=[outerjoin(Zoo, Animal)], + ) + ) + fulltable( + select( + [Zoo.c.Name, Animal.c.Species], + from_obj=[outerjoin(Animal, Zoo)], + ) + ) def _baseline_8_drop(self): self.metadata.drop_all() diff --git a/test/aaa_profiling/test_zoomark_orm.py b/test/aaa_profiling/test_zoomark_orm.py index ee21e9bc64..19ef5b74ab 100644 --- a/test/aaa_profiling/test_zoomark_orm.py +++ b/test/aaa_profiling/test_zoomark_orm.py @@ -6,9 +6,21 @@ An adaptation of Robert Brewers' ZooMark speed tests. """ import datetime -from sqlalchemy import Table, Column, Integer, Unicode, Date, \ - DateTime, Time, Float, Sequence, ForeignKey, \ - select, and_, func +from sqlalchemy import ( + Table, + Column, + Integer, + Unicode, + Date, + DateTime, + Time, + Float, + Sequence, + ForeignKey, + select, + and_, + func, +) from sqlalchemy.orm import mapper from sqlalchemy.testing import replay_fixture @@ -24,8 +36,8 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): """ - __requires__ = 'cpython', - __only_on__ = 'postgresql+psycopg2' + __requires__ = ("cpython",) + __only_on__ = "postgresql+psycopg2" def _run_steps(self, ctx): with ctx(): @@ -49,42 +61,44 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): def _baseline_1_create_tables(self): zoo = Table( - 'Zoo', + "Zoo", self.metadata, - Column('ID', Integer, Sequence('zoo_id_seq'), - primary_key=True, index=True), - Column('Name', Unicode(255)), - Column('Founded', Date), - Column('Opens', Time), - Column('LastEscape', DateTime), - Column('Admission', Float), + Column( + "ID", + Integer, + Sequence("zoo_id_seq"), + primary_key=True, + index=True, + ), + Column("Name", Unicode(255)), + Column("Founded", Date), + Column("Opens", Time), + Column("LastEscape", DateTime), + Column("Admission", Float), ) animal = Table( - 'Animal', + "Animal", self.metadata, - Column('ID', Integer, Sequence('animal_id_seq'), - primary_key=True), - Column('ZooID', Integer, ForeignKey('Zoo.ID'), index=True), - Column('Name', Unicode(100)), - Column('Species', Unicode(100)), - Column('Legs', Integer, default=4), - Column('LastEscape', DateTime), - Column('Lifespan', Float(4)), - Column('MotherID', Integer, ForeignKey('Animal.ID')), - Column('PreferredFoodID', Integer), - Column('AlternateFoodID', Integer), + Column("ID", Integer, Sequence("animal_id_seq"), primary_key=True), + Column("ZooID", Integer, ForeignKey("Zoo.ID"), index=True), + Column("Name", Unicode(100)), + Column("Species", Unicode(100)), + Column("Legs", Integer, default=4), + 
Column("LastEscape", DateTime), + Column("Lifespan", Float(4)), + Column("MotherID", Integer, ForeignKey("Animal.ID")), + Column("PreferredFoodID", Integer), + Column("AlternateFoodID", Integer), ) self.metadata.create_all() global Zoo, Animal class Zoo(object): - def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) class Animal(object): - def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) @@ -94,64 +108,79 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): def _baseline_1a_populate(self): wap = Zoo( - Name='Wild Animal Park', Founded=datetime.date( - 2000, 1, 1), Opens=datetime.time( - 8, 15, 59), LastEscape=datetime.datetime( - 2004, 7, 29, 5, 6, 7, ), Admission=4.95) + Name="Wild Animal Park", + Founded=datetime.date(2000, 1, 1), + Opens=datetime.time(8, 15, 59), + LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7), + Admission=4.95, + ) self.session.add(wap) sdz = Zoo( - Name='San Diego Zoo', Founded=datetime.date( - 1835, 9, 13), Opens=datetime.time( - 9, 0, 0), Admission=0) + Name="San Diego Zoo", + Founded=datetime.date(1835, 9, 13), + Opens=datetime.time(9, 0, 0), + Admission=0, + ) self.session.add(sdz) - bio = Zoo(Name='Montr\xe9al Biod\xf4me', - Founded=datetime.date(1992, 6, 19), - Opens=datetime.time(9, 0, 0), Admission=11.75) + bio = Zoo( + Name="Montr\xe9al Biod\xf4me", + Founded=datetime.date(1992, 6, 19), + Opens=datetime.time(9, 0, 0), + Admission=11.75, + ) self.session.add(bio) - seaworld = Zoo(Name='Sea_World', Admission=60) + seaworld = Zoo(Name="Sea_World", Admission=60) self.session.add(seaworld) # Let's add a crazy futuristic Zoo to test large date values. - lp = Zoo(Name='Luna Park', Founded=datetime.date(2072, 7, 17), - Opens=datetime.time(0, 0, 0), Admission=134.95) + lp = Zoo( + Name="Luna Park", + Founded=datetime.date(2072, 7, 17), + Opens=datetime.time(0, 0, 0), + Admission=134.95, + ) self.session.add(lp) # Animals - leopard = Animal(Species='Leopard', Lifespan=73.5) + leopard = Animal(Species="Leopard", Lifespan=73.5) self.session.add(leopard) leopard.ZooID = wap.ID - leopard.LastEscape = \ - datetime.datetime(2004, 12, 21, 8, 15, 0, 999907, ) - self.session.add(Animal(Species='Lion', ZooID=wap.ID)) - self.session.add(Animal(Species='Slug', Legs=1, Lifespan=.75)) - self.session.add(Animal(Species='Tiger', ZooID=sdz.ID)) + leopard.LastEscape = datetime.datetime(2004, 12, 21, 8, 15, 0, 999907) + self.session.add(Animal(Species="Lion", ZooID=wap.ID)) + self.session.add(Animal(Species="Slug", Legs=1, Lifespan=0.75)) + self.session.add(Animal(Species="Tiger", ZooID=sdz.ID)) # Override Legs.default with itself just to make sure it works. 
- self.session.add(Animal(Species='Bear', Legs=4)) - self.session.add(Animal(Species='Ostrich', Legs=2, Lifespan=103.2)) - self.session.add(Animal(Species='Centipede', Legs=100)) - self.session.add(Animal(Species='Emperor Penguin', Legs=2, - ZooID=seaworld.ID)) - self.session.add(Animal(Species='Adelie Penguin', Legs=2, - ZooID=seaworld.ID)) - self.session.add(Animal(Species='Millipede', Legs=1000000, - ZooID=sdz.ID)) + self.session.add(Animal(Species="Bear", Legs=4)) + self.session.add(Animal(Species="Ostrich", Legs=2, Lifespan=103.2)) + self.session.add(Animal(Species="Centipede", Legs=100)) + self.session.add( + Animal(Species="Emperor Penguin", Legs=2, ZooID=seaworld.ID) + ) + self.session.add( + Animal(Species="Adelie Penguin", Legs=2, ZooID=seaworld.ID) + ) + self.session.add( + Animal(Species="Millipede", Legs=1000000, ZooID=sdz.ID) + ) # Add a mother and child to test relationships - bai_yun = Animal(Species='Ape', Nameu='Bai Yun', Legs=2) + bai_yun = Animal(Species="Ape", Nameu="Bai Yun", Legs=2) self.session.add(bai_yun) - self.session.add(Animal(Species='Ape', Name='Hua Mei', Legs=2, - MotherID=bai_yun.ID)) + self.session.add( + Animal(Species="Ape", Name="Hua Mei", Legs=2, MotherID=bai_yun.ID) + ) self.session.commit() def _baseline_2_insert(self): for x in range(ITERATIONS): - self.session.add(Animal(Species='Tick', Name='Tick %d' % x, - Legs=8)) + self.session.add( + Animal(Species="Tick", Name="Tick %d" % x, Legs=8) + ) self.session.flush() def _baseline_3_properties(self): @@ -159,113 +188,230 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): # Zoos - list(self.session.query(Zoo).filter( - Zoo.Name == 'Wild Animal Park')) + list( + self.session.query(Zoo).filter(Zoo.Name == "Wild Animal Park") + ) list( self.session.query(Zoo).filter( - Zoo.Founded == datetime.date( - 1835, - 9, - 13))) + Zoo.Founded == datetime.date(1835, 9, 13) + ) + ) list( self.session.query(Zoo).filter( - Zoo.Name == 'Montr\xe9al Biod\xf4me')) + Zoo.Name == "Montr\xe9al Biod\xf4me" + ) + ) list(self.session.query(Zoo).filter(Zoo.Admission == float(60))) # Animals - list(self.session.query(Animal).filter( - Animal.Species == 'Leopard')) - list(self.session.query(Animal).filter( - Animal.Species == 'Ostrich')) - list(self.session.query(Animal).filter( - Animal.Legs == 1000000)) - list(self.session.query(Animal).filter( - Animal.Species == 'Tick')) + list( + self.session.query(Animal).filter(Animal.Species == "Leopard") + ) + list( + self.session.query(Animal).filter(Animal.Species == "Ostrich") + ) + list(self.session.query(Animal).filter(Animal.Legs == 1000000)) + list(self.session.query(Animal).filter(Animal.Species == "Tick")) def _baseline_4_expressions(self): for x in range(ITERATIONS): assert len(list(self.session.query(Zoo))) == 5 assert len(list(self.session.query(Animal))) == ITERATIONS + 12 - assert len(list(self.session.query(Animal) - .filter(Animal.Legs == 4))) == 4 - assert len(list(self.session.query(Animal) - .filter(Animal.Legs == 2))) == 5 - assert len( - list( - self.session.query(Animal).filter( - and_( - Animal.Legs >= 2, - Animal.Legs < 20)))) == ITERATIONS + 9 - assert len(list(self.session.query(Animal) - .filter(Animal.Legs > 10))) == 2 - assert len(list(self.session.query(Animal) - .filter(Animal.Lifespan > 70))) == 2 - assert len(list(self.session.query(Animal). - filter(Animal.Species.like('L%')))) == 2 - assert len(list(self.session.query(Animal). 
- filter(Animal.Species.like('%pede')))) == 2 - assert len(list(self.session.query(Animal) - .filter(Animal.LastEscape != None))) == 1 # noqa - assert len( - list( - self.session.query(Animal).filter( - Animal.LastEscape == None))) == ITERATIONS + 11 # noqa + assert ( + len(list(self.session.query(Animal).filter(Animal.Legs == 4))) + == 4 + ) + assert ( + len(list(self.session.query(Animal).filter(Animal.Legs == 2))) + == 5 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + and_(Animal.Legs >= 2, Animal.Legs < 20) + ) + ) + ) + == ITERATIONS + 9 + ) + assert ( + len(list(self.session.query(Animal).filter(Animal.Legs > 10))) + == 2 + ) + assert ( + len( + list( + self.session.query(Animal).filter(Animal.Lifespan > 70) + ) + ) + == 2 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.Species.like("L%") + ) + ) + ) + == 2 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.Species.like("%pede") + ) + ) + ) + == 2 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.LastEscape != None + ) + ) + ) + == 1 + ) # noqa + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.LastEscape == None + ) + ) + ) + == ITERATIONS + 11 + ) # noqa # In operator (containedby) - assert len(list(self.session.query(Animal).filter( - Animal.Species.like('%pede%')))) == 2 - assert len( - list( - self.session.query(Animal). filter( - Animal.Species.in_( - ('Lion', 'Tiger', 'Bear'))))) == 3 + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.Species.like("%pede%") + ) + ) + ) + == 2 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.Species.in_(("Lion", "Tiger", "Bear")) + ) + ) + ) + == 3 + ) # Try In with cell references class thing(object): pass pet, pet2 = thing(), thing() - pet.Name, pet2.Name = 'Slug', 'Ostrich' - assert len(list(self.session.query(Animal). - filter(Animal.Species.in_((pet.Name, - pet2.Name))))) == 2 + pet.Name, pet2.Name = "Slug", "Ostrich" + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.Species.in_((pet.Name, pet2.Name)) + ) + ) + ) + == 2 + ) # logic and other functions - name = 'Lion' - assert len(list(self.session.query(Animal). - filter(func.length(Animal.Species) - == len(name)))) == ITERATIONS + 3 - assert len(list(self.session.query(Animal). - filter(Animal.Species.like('%i%' - )))) == ITERATIONS + 7 + name = "Lion" + assert ( + len( + list( + self.session.query(Animal).filter( + func.length(Animal.Species) == len(name) + ) + ) + ) + == ITERATIONS + 3 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.Species.like("%i%") + ) + ) + ) + == ITERATIONS + 7 + ) # Test now(), today(), year(), month(), day() - assert len( - list( - self.session.query(Zoo).filter( - and_( - Zoo.Founded != None, # noqa - Zoo.Founded < func.now())))) == 3 - assert len(list(self.session.query(Animal) - .filter(Animal.LastEscape == func.now()))) == 0 - assert len(list(self.session.query(Animal).filter( - func.date_part('year', Animal.LastEscape) == 2004))) == 1 - assert len( - list( - self.session.query(Animal). 
filter( - func.date_part( - 'month', - Animal.LastEscape) == 12))) == 1 - assert len(list(self.session.query(Animal).filter( - func.date_part('day', Animal.LastEscape) == 21))) == 1 + assert ( + len( + list( + self.session.query(Zoo).filter( + and_( + Zoo.Founded != None, # noqa + Zoo.Founded < func.now(), + ) + ) + ) + ) + == 3 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + Animal.LastEscape == func.now() + ) + ) + ) + == 0 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + func.date_part("year", Animal.LastEscape) == 2004 + ) + ) + ) + == 1 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + func.date_part("month", Animal.LastEscape) == 12 + ) + ) + ) + == 1 + ) + assert ( + len( + list( + self.session.query(Animal).filter( + func.date_part("day", Animal.LastEscape) == 21 + ) + ) + ) + == 1 + ) def _baseline_5_aggregates(self): - Animal = self.metadata.tables['Animal'] - Zoo = self.metadata.tables['Zoo'] + Animal = self.metadata.tables["Animal"] + Zoo = self.metadata.tables["Zoo"] # TODO: convert to ORM engine = self.metadata.bind @@ -276,29 +422,32 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): view = engine.execute(select([Animal.c.Legs])).fetchall() legs = sorted([x[0] for x in view]) expected = { - 'Leopard': 73.5, - 'Slug': .75, - 'Tiger': None, - 'Lion': None, - 'Bear': None, - 'Ostrich': 103.2, - 'Centipede': None, - 'Emperor Penguin': None, - 'Adelie Penguin': None, - 'Millipede': None, - 'Ape': None, - 'Tick': None, + "Leopard": 73.5, + "Slug": 0.75, + "Tiger": None, + "Lion": None, + "Bear": None, + "Ostrich": 103.2, + "Centipede": None, + "Emperor Penguin": None, + "Adelie Penguin": None, + "Millipede": None, + "Ape": None, + "Tick": None, } for species, lifespan in engine.execute( - select([Animal.c.Species, Animal.c.Lifespan])).fetchall(): + select([Animal.c.Species, Animal.c.Lifespan]) + ).fetchall(): assert lifespan == expected[species] - expected = ['Montr\xe9al Biod\xf4me', 'Wild Animal Park'] - e = select([Zoo.c.Name], - and_(Zoo.c.Founded != None, # noqa - Zoo.c.Founded <= func.current_timestamp(), - Zoo.c.Founded >= datetime.date(1990, - 1, - 1))) + expected = ["Montr\xe9al Biod\xf4me", "Wild Animal Park"] + e = select( + [Zoo.c.Name], + and_( + Zoo.c.Founded != None, # noqa + Zoo.c.Founded <= func.current_timestamp(), + Zoo.c.Founded >= datetime.date(1990, 1, 1), + ), + ) values = [val[0] for val in engine.execute(e).fetchall()] assert set(values) == set(expected) @@ -307,8 +456,9 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): legs = [ x[0] for x in engine.execute( - select([Animal.c.Legs], - distinct=True)).fetchall()] + select([Animal.c.Legs], distinct=True) + ).fetchall() + ] legs.sort() def _baseline_6_editing(self): @@ -316,32 +466,40 @@ class ZooMarkTest(replay_fixture.ReplayFixtureTest): # Edit - SDZ = self.session.query(Zoo).filter(Zoo.Name == 'San Diego Zoo') \ - .one() - SDZ.Name = 'The San Diego Zoo' + SDZ = ( + self.session.query(Zoo) + .filter(Zoo.Name == "San Diego Zoo") + .one() + ) + SDZ.Name = "The San Diego Zoo" SDZ.Founded = datetime.date(1900, 1, 1) SDZ.Opens = datetime.time(7, 30, 0) SDZ.Admission = 35.00 # Test edits - SDZ = self.session.query(Zoo) \ - .filter(Zoo.Name == 'The San Diego Zoo').one() + SDZ = ( + self.session.query(Zoo) + .filter(Zoo.Name == "The San Diego Zoo") + .one() + ) assert SDZ.Founded == datetime.date(1900, 1, 1), SDZ.Founded # Change it back - SDZ.Name = 'San Diego Zoo' + SDZ.Name = "San Diego Zoo" SDZ.Founded = datetime.date(1835, 9, 13) SDZ.Opens 
= datetime.time(9, 0, 0) SDZ.Admission = 0 # Test re-edits - SDZ = self.session.query(Zoo).filter(Zoo.Name == 'San Diego Zoo') \ + SDZ = ( + self.session.query(Zoo) + .filter(Zoo.Name == "San Diego Zoo") .one() - assert SDZ.Founded == datetime.date(1835, 9, 13), \ - SDZ.Founded + ) + assert SDZ.Founded == datetime.date(1835, 9, 13), SDZ.Founded def _baseline_7_drop(self): self.session.rollback() diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 120f6a1474..3376068f45 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -6,7 +6,6 @@ from sqlalchemy.testing import fixtures class DependencySortTest(fixtures.TestBase): - def assert_sort(self, tuples, allitems=None): if allitems is None: allitems = self._nodes_from_tuples(tuples) @@ -16,9 +15,9 @@ class DependencySortTest(fixtures.TestBase): assert conforms_partial_ordering(tuples, result) def assert_sort_deterministic(self, tuples, allitems, expected): - result = list(topological.sort(tuples, - allitems, - deterministic_order=True)) + result = list( + topological.sort(tuples, allitems, deterministic_order=True) + ) assert conforms_partial_ordering(tuples, result) assert result == expected @@ -29,15 +28,15 @@ class DependencySortTest(fixtures.TestBase): return s def test_sort_one(self): - rootnode = 'root' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - subnode1 = 'subnode1' - subnode2 = 'subnode2' - subnode3 = 'subnode3' - subnode4 = 'subnode4' - subsubnode1 = 'subsubnode1' + rootnode = "root" + node2 = "node2" + node3 = "node3" + node4 = "node4" + subnode1 = "subnode1" + subnode2 = "subnode2" + subnode3 = "subnode3" + subnode4 = "subnode4" + subsubnode1 = "subsubnode1" tuples = [ (subnode3, subsubnode1), (node2, subnode1), @@ -47,37 +46,46 @@ class DependencySortTest(fixtures.TestBase): (rootnode, node4), (node4, subnode3), (node4, subnode4), - ] + ] self.assert_sort(tuples) def test_sort_two(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - node5 = 'node5' - node6 = 'node6' - node7 = 'node7' - tuples = [(node1, node2), (node3, node4), (node4, node5), - (node5, node6), (node6, node2)] + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + node5 = "node5" + node6 = "node6" + node7 = "node7" + tuples = [ + (node1, node2), + (node3, node4), + (node4, node5), + (node5, node6), + (node6, node2), + ] self.assert_sort(tuples, [node7]) def test_sort_three(self): - node1 = 'keywords' - node2 = 'itemkeyowrds' - node3 = 'items' - node4 = 'hoho' - tuples = [(node1, node2), (node4, node1), (node1, node3), - (node3, node2)] + node1 = "keywords" + node2 = "itemkeyowrds" + node3 = "items" + node4 = "hoho" + tuples = [ + (node1, node2), + (node4, node1), + (node1, node3), + (node3, node2), + ] self.assert_sort(tuples) def test_sort_deterministic_one(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - node5 = 'node5' - node6 = 'node6' + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + node5 = "node5" + node6 = "node6" allitems = [node6, node5, node4, node3, node2, node1] tuples = [(node6, node5), (node2, node1)] expected = [node6, node4, node3, node2, node5, node1] @@ -96,11 +104,11 @@ class DependencySortTest(fixtures.TestBase): self.assert_sort_deterministic(tuples, allitems, expected) def test_raise_on_cycle_one(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - node5 = 'node5' + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + node5 = 
"node5" tuples = [ (node4, node5), (node5, node4), @@ -108,44 +116,72 @@ class DependencySortTest(fixtures.TestBase): (node2, node3), (node3, node1), (node4, node1), - ] + ] allitems = self._nodes_from_tuples(tuples) try: list(topological.sort(tuples, allitems)) assert False except exc.CircularDependencyError as err: - eq_(err.cycles, set(['node1', 'node3', 'node2', 'node5', - 'node4'])) - eq_(err.edges, set([('node3', 'node1'), ('node4', 'node1'), - ('node2', 'node3'), ('node1', 'node2'), - ('node4', 'node5'), ('node5', 'node4')])) + eq_(err.cycles, set(["node1", "node3", "node2", "node5", "node4"])) + eq_( + err.edges, + set( + [ + ("node3", "node1"), + ("node4", "node1"), + ("node2", "node3"), + ("node1", "node2"), + ("node4", "node5"), + ("node5", "node4"), + ] + ), + ) def test_raise_on_cycle_two(self): # this condition was arising from ticket:362 and was not treated # properly by topological sort - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - tuples = [(node1, node2), (node3, node1), (node2, node4), - (node3, node2), (node2, node3)] + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + tuples = [ + (node1, node2), + (node3, node1), + (node2, node4), + (node3, node2), + (node2, node3), + ] allitems = self._nodes_from_tuples(tuples) try: list(topological.sort(tuples, allitems)) assert False except exc.CircularDependencyError as err: - eq_(err.cycles, set(['node1', 'node3', 'node2'])) - eq_(err.edges, set([('node3', 'node1'), ('node2', 'node3'), - ('node3', 'node2'), ('node1', 'node2'), - ('node2', 'node4')])) + eq_(err.cycles, set(["node1", "node3", "node2"])) + eq_( + err.edges, + set( + [ + ("node3", "node1"), + ("node2", "node3"), + ("node3", "node2"), + ("node1", "node2"), + ("node2", "node4"), + ] + ), + ) def test_raise_on_cycle_three(self): - question, issue, providerservice, answer, provider = \ - 'Question', 'Issue', 'ProviderService', 'Answer', 'Provider' + question, issue, providerservice, answer, provider = ( + "Question", + "Issue", + "ProviderService", + "Answer", + "Provider", + ) tuples = [ (question, issue), (providerservice, issue), @@ -155,10 +191,13 @@ class DependencySortTest(fixtures.TestBase): (provider, providerservice), (question, answer), (issue, question), - ] + ] allitems = self._nodes_from_tuples(tuples) - assert_raises(exc.CircularDependencyError, list, - topological.sort(tuples, allitems)) + assert_raises( + exc.CircularDependencyError, + list, + topological.sort(tuples, allitems), + ) # TODO: test find_cycles @@ -174,27 +213,34 @@ class DependencySortTest(fixtures.TestBase): self.assert_sort(tuples) def test_find_cycle_one(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - tuples = [(node1, node2), (node3, node1), (node2, node4), - (node3, node2), (node2, node3)] - eq_(topological.find_cycles(tuples, - self._nodes_from_tuples(tuples)), set([node1, node2, node3])) + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + tuples = [ + (node1, node2), + (node3, node1), + (node2, node4), + (node3, node2), + (node2, node3), + ] + eq_( + topological.find_cycles(tuples, self._nodes_from_tuples(tuples)), + set([node1, node2, node3]), + ) def test_find_multiple_cycles_one(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - node5 = 'node5' - node6 = 'node6' - node7 = 'node7' - node8 = 'node8' - node9 = 'node9' + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + node5 = "node5" + node6 = "node6" + node7 = "node7" + node8 = 
"node8" + node9 = "node9" tuples = [ # cycle 1 cycle 2 cycle 3 cycle 4, but only if cycle - # 1 nodes are present + # 1 nodes are present (node1, node2), (node2, node4), (node4, node1), @@ -206,36 +252,33 @@ class DependencySortTest(fixtures.TestBase): (node8, node4), (node3, node1), (node3, node2), - ] - allnodes = set([ - node1, - node2, - node3, - node4, - node5, - node6, - node7, - node8, - node9, - ]) - eq_(topological.find_cycles(tuples, allnodes), set([ - 'node8', - 'node1', - 'node2', - 'node5', - 'node4', - 'node7', - 'node6', - 'node9', - ])) + ] + allnodes = set( + [node1, node2, node3, node4, node5, node6, node7, node8, node9] + ) + eq_( + topological.find_cycles(tuples, allnodes), + set( + [ + "node8", + "node1", + "node2", + "node5", + "node4", + "node7", + "node6", + "node9", + ] + ), + ) def test_find_multiple_cycles_two(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - node5 = 'node5' - node6 = 'node6' + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + node5 = "node5" + node6 = "node6" tuples = [ # cycle 1 cycle 2 (node1, node2), (node2, node4), @@ -244,28 +287,21 @@ class DependencySortTest(fixtures.TestBase): (node6, node2), (node2, node4), (node4, node1), - ] - allnodes = set([ - node1, - node2, - node3, - node4, - node5, - node6, - ]) + ] + allnodes = set([node1, node2, node3, node4, node5, node6]) # node6 only became present here once [ticket:2282] was addressed. eq_( topological.find_cycles(tuples, allnodes), - set(['node1', 'node2', 'node4', 'node6']) + set(["node1", "node2", "node4", "node6"]), ) def test_find_multiple_cycles_three(self): - node1 = 'node1' - node2 = 'node2' - node3 = 'node3' - node4 = 'node4' - node5 = 'node5' - node6 = 'node6' + node1 = "node1" + node2 = "node2" + node3 = "node3" + node4 = "node4" + node5 = "node5" + node6 = "node6" tuples = [ # cycle 1 cycle 2 cycle3 cycle4 (node1, node2), (node2, node1), @@ -276,36 +312,61 @@ class DependencySortTest(fixtures.TestBase): (node2, node5), (node5, node6), (node6, node2), - ] - allnodes = set([ - node1, - node2, - node3, - node4, - node5, - node6, - ]) + ] + allnodes = set([node1, node2, node3, node4, node5, node6]) eq_(topological.find_cycles(tuples, allnodes), allnodes) def test_find_multiple_cycles_four(self): tuples = [ - ('node6', 'node2'), - ('node15', 'node19'), - ('node19', 'node2'), ('node4', 'node10'), - ('node15', 'node13'), - ('node17', 'node11'), ('node1', 'node19'), ('node15', 'node8'), - ('node6', 'node20'), ('node14', 'node11'), ('node6', 'node14'), - ('node11', 'node2'), ('node10', 'node20'), ('node1', 'node11'), - ('node20', 'node19'), ('node4', 'node20'), ('node15', 'node20'), - ('node9', 'node19'), ('node11', 'node10'), ('node11', 'node19'), - ('node13', 'node6'), ('node3', 'node15'), ('node9', 'node11'), - ('node4', 'node17'), ('node2', 'node20'), ('node19', 'node10'), - ('node8', 'node4'), ('node11', 'node3'), ('node6', 'node1') + ("node6", "node2"), + ("node15", "node19"), + ("node19", "node2"), + ("node4", "node10"), + ("node15", "node13"), + ("node17", "node11"), + ("node1", "node19"), + ("node15", "node8"), + ("node6", "node20"), + ("node14", "node11"), + ("node6", "node14"), + ("node11", "node2"), + ("node10", "node20"), + ("node1", "node11"), + ("node20", "node19"), + ("node4", "node20"), + ("node15", "node20"), + ("node9", "node19"), + ("node11", "node10"), + ("node11", "node19"), + ("node13", "node6"), + ("node3", "node15"), + ("node9", "node11"), + ("node4", "node17"), + ("node2", "node20"), + ("node19", "node10"), + 
("node8", "node4"), + ("node11", "node3"), + ("node6", "node1"), ] - allnodes = ['node%d' % i for i in range(1, 21)] + allnodes = ["node%d" % i for i in range(1, 21)] eq_( topological.find_cycles(tuples, allnodes), - set(['node11', 'node10', 'node13', 'node15', 'node14', 'node17', - 'node19', 'node20', 'node8', 'node1', 'node3', 'node2', - 'node4', 'node6']) + set( + [ + "node11", + "node10", + "node13", + "node15", + "node14", + "node17", + "node19", + "node20", + "node8", + "node1", + "node3", + "node2", + "node4", + "node6", + ] + ), ) diff --git a/test/base/test_events.py b/test/base/test_events.py index 288c6091ff..5cf5d89efe 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -1,7 +1,12 @@ """Test event registration and listening.""" -from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \ - is_, is_not_ +from sqlalchemy.testing import ( + eq_, + assert_raises, + assert_raises_message, + is_, + is_not_, +) from sqlalchemy import event, exc from sqlalchemy.testing import fixtures from sqlalchemy.testing.util import gc_collect @@ -25,10 +30,11 @@ class EventsTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + self.Target = Target def tearDown(self): - event.base._remove_dispatcher(self.Target.__dict__['dispatch'].events) + event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events) def test_register_class(self): def listen(x, y): @@ -54,6 +60,7 @@ class EventsTest(fixtures.TestBase): def test_bool_clslevel(self): def listen_one(x, y): pass + event.listen(self.Target, "event_one", listen_one) t = self.Target() assert t.dispatch.event_one @@ -98,7 +105,7 @@ class EventsTest(fixtures.TestBase): eq_( list(self.Target().dispatch.event_one), - [listen_three, listen_one, listen_two] + [listen_three, listen_one, listen_two], ) def test_append_vs_insert_instance(self): @@ -118,7 +125,7 @@ class EventsTest(fixtures.TestBase): eq_( list(target.dispatch.event_one), - [listen_three, listen_one, listen_two] + [listen_three, listen_one, listen_two], ) def test_decorator(self): @@ -131,44 +138,37 @@ class EventsTest(fixtures.TestBase): def listen_two(x, y): pass - eq_( - list(self.Target().dispatch.event_one), - [listen_one] - ) + eq_(list(self.Target().dispatch.event_one), [listen_one]) - eq_( - list(self.Target().dispatch.event_two), - [listen_two] - ) + eq_(list(self.Target().dispatch.event_two), [listen_two]) - eq_( - list(self.Target().dispatch.event_three), - [listen_two] - ) + eq_(list(self.Target().dispatch.event_three), [listen_two]) def test_no_instance_level_collections(self): @event.listens_for(self.Target, "event_one") def listen_one(x, y): pass + t1 = self.Target() t2 = self.Target() t1.dispatch.event_one(5, 6) t2.dispatch.event_one(5, 6) is_( - self.Target.dispatch._empty_listener_reg[self.Target]['event_one'], - t1.dispatch.event_one + self.Target.dispatch._empty_listener_reg[self.Target]["event_one"], + t1.dispatch.event_one, ) @event.listens_for(t1, "event_one") def listen_two(x, y): pass + is_not_( - self.Target.dispatch._empty_listener_reg[self.Target]['event_one'], - t1.dispatch.event_one + self.Target.dispatch._empty_listener_reg[self.Target]["event_one"], + t1.dispatch.event_one, ) is_( - self.Target.dispatch._empty_listener_reg[self.Target]['event_one'], - t2.dispatch.event_one + self.Target.dispatch._empty_listener_reg[self.Target]["event_one"], + t2.dispatch.event_one, ) def test_immutable_methods(self): @@ -181,14 +181,11 @@ class EventsTest(fixtures.TestBase): 
t1.dispatch.event_one.clear, ]: assert_raises_message( - NotImplementedError, - r"need to call for_modify\(\)", - meth + NotImplementedError, r"need to call for_modify\(\)", meth ) class NamedCallTest(fixtures.TestBase): - def _fixture(self): class TargetEventsOne(event.Events): def event_one(self, x, y): @@ -202,6 +199,7 @@ class NamedCallTest(fixtures.TestBase): class TargetOne(object): dispatch = event.dispatcher(TargetEventsOne) + return TargetOne def _wrapped_fixture(self): @@ -212,6 +210,7 @@ class NamedCallTest(fixtures.TestBase): def adapt(*args): fn(*["adapted %s" % arg for arg in args]) + event_key = event_key.with_wrapper(adapt) event_key.base_listen() @@ -224,6 +223,7 @@ class NamedCallTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + return Target def test_kw_accept(self): @@ -237,10 +237,7 @@ class NamedCallTest(fixtures.TestBase): TargetOne().dispatch.event_one(4, 5) - eq_( - canary.mock_calls, - [call({"x": 4, "y": 5})] - ) + eq_(canary.mock_calls, [call({"x": 4, "y": 5})]) def test_kw_accept_wrapped(self): TargetOne = self._wrapped_fixture() @@ -253,10 +250,7 @@ class NamedCallTest(fixtures.TestBase): TargetOne().dispatch.event_one(4, 5) - eq_( - canary.mock_calls, - [call({'y': 'adapted 5', 'x': 'adapted 4'})] - ) + eq_(canary.mock_calls, [call({"y": "adapted 5", "x": "adapted 4"})]) def test_partial_kw_accept(self): TargetOne = self._fixture() @@ -269,10 +263,7 @@ class NamedCallTest(fixtures.TestBase): TargetOne().dispatch.event_five(4, 5, 6, 7) - eq_( - canary.mock_calls, - [call(6, 5, {"x": 4, "q": 7})] - ) + eq_(canary.mock_calls, [call(6, 5, {"x": 4, "q": 7})]) def test_partial_kw_accept_wrapped(self): TargetOne = self._wrapped_fixture() @@ -287,8 +278,13 @@ class NamedCallTest(fixtures.TestBase): eq_( canary.mock_calls, - [call('adapted 6', 'adapted 5', - {'q': 'adapted 7', 'x': 'adapted 4'})] + [ + call( + "adapted 6", + "adapted 5", + {"q": "adapted 7", "x": "adapted 4"}, + ) + ], ) def test_kw_accept_plus_kw(self): @@ -301,10 +297,7 @@ class NamedCallTest(fixtures.TestBase): TargetOne().dispatch.event_two(4, 5, z=8, q=5) - eq_( - canary.mock_calls, - [call({"x": 4, "y": 5, "z": 8, "q": 5})] - ) + eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 8, "q": 5})]) class LegacySignatureTest(fixtures.TestBase): @@ -312,7 +305,6 @@ class LegacySignatureTest(fixtures.TestBase): def setUp(self): class TargetEventsOne(event.Events): - @event._legacy_signature("0.9", ["x", "y"]) def event_three(self, x, y, z, q): pass @@ -322,18 +314,20 @@ class LegacySignatureTest(fixtures.TestBase): pass @event._legacy_signature( - "0.9", ["x", "y", "z", "q"], - lambda x, y: (x, y, x + y, x * y)) + "0.9", ["x", "y", "z", "q"], lambda x, y: (x, y, x + y, x * y) + ) def event_six(self, x, y): pass class TargetOne(object): dispatch = event.dispatcher(TargetEventsOne) + self.TargetOne = TargetOne def tearDown(self): event.base._remove_dispatcher( - self.TargetOne.__dict__['dispatch'].events) + self.TargetOne.__dict__["dispatch"].events + ) def test_legacy_accept(self): canary = Mock() @@ -344,10 +338,7 @@ class LegacySignatureTest(fixtures.TestBase): self.TargetOne().dispatch.event_three(4, 5, 6, 7) - eq_( - canary.mock_calls, - [call(4, 5)] - ) + eq_(canary.mock_calls, [call(4, 5)]) def test_legacy_accept_kw_cls(self): canary = Mock() @@ -355,6 +346,7 @@ class LegacySignatureTest(fixtures.TestBase): @event.listens_for(self.TargetOne, "event_four") def handler1(x, y, **kw): canary(x, y, kw) + self._test_legacy_accept_kw(self.TargetOne(), canary) def 
test_legacy_accept_kw_instance(self): @@ -365,6 +357,7 @@ class LegacySignatureTest(fixtures.TestBase): @event.listens_for(inst, "event_four") def handler1(x, y, **kw): canary(x, y, kw) + self._test_legacy_accept_kw(inst, canary) def test_legacy_accept_partial(self): @@ -372,28 +365,23 @@ class LegacySignatureTest(fixtures.TestBase): def evt(a, x, y, **kw): canary(a, x, y, **kw) + from functools import partial + evt_partial = partial(evt, 5) target = self.TargetOne() event.listen(target, "event_four", evt_partial) # can't do legacy accept on a partial; we can't inspect it assert_raises( - TypeError, - target.dispatch.event_four, 4, 5, 6, 7, foo="bar" + TypeError, target.dispatch.event_four, 4, 5, 6, 7, foo="bar" ) target.dispatch.event_four(4, 5, foo="bar") - eq_( - canary.mock_calls, - [call(5, 4, 5, foo="bar")] - ) + eq_(canary.mock_calls, [call(5, 4, 5, foo="bar")]) def _test_legacy_accept_kw(self, target, canary): target.dispatch.event_four(4, 5, 6, 7, foo="bar") - eq_( - canary.mock_calls, - [call(4, 5, {"foo": "bar"})] - ) + eq_(canary.mock_calls, [call(4, 5, {"foo": "bar"})]) def test_complex_legacy_accept(self): canary = Mock() @@ -403,10 +391,7 @@ class LegacySignatureTest(fixtures.TestBase): canary(x, y, z, q) self.TargetOne().dispatch.event_six(4, 5) - eq_( - canary.mock_calls, - [call(4, 5, 9, 20)] - ) + eq_(canary.mock_calls, [call(4, 5, 9, 20)]) def test_legacy_accept_from_method(self): canary = Mock() @@ -418,10 +403,7 @@ class LegacySignatureTest(fixtures.TestBase): event.listen(self.TargetOne, "event_three", MyClass().handler1) self.TargetOne().dispatch.event_three(4, 5, 6, 7) - eq_( - canary.mock_calls, - [call(4, 5)] - ) + eq_(canary.mock_calls, [call(4, 5)]) def test_standard_accept_has_legacies(self): canary = Mock() @@ -430,10 +412,7 @@ class LegacySignatureTest(fixtures.TestBase): self.TargetOne().dispatch.event_three(4, 5) - eq_( - canary.mock_calls, - [call(4, 5)] - ) + eq_(canary.mock_calls, [call(4, 5)]) def test_kw_accept_has_legacies(self): canary = Mock() @@ -444,10 +423,7 @@ class LegacySignatureTest(fixtures.TestBase): self.TargetOne().dispatch.event_three(4, 5, 6, 7) - eq_( - canary.mock_calls, - [call({"x": 4, "y": 5, "z": 6, "q": 7})] - ) + eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 6, "q": 7})]) def test_kw_accept_plus_kw_has_legacies(self): canary = Mock() @@ -460,15 +436,15 @@ class LegacySignatureTest(fixtures.TestBase): eq_( canary.mock_calls, - [call({"x": 4, "y": 5, "z": 6, "q": 7, "foo": "bar"})] + [call({"x": 4, "y": 5, "z": 6, "q": 7, "foo": "bar"})], ) class ClsLevelListenTest(fixtures.TestBase): - def tearDown(self): event.base._remove_dispatcher( - self.TargetOne.__dict__['dispatch'].events) + self.TargetOne.__dict__["dispatch"].events + ) def setUp(self): class TargetEventsOne(event.Events): @@ -477,6 +453,7 @@ class ClsLevelListenTest(fixtures.TestBase): class TargetOne(object): dispatch = event.dispatcher(TargetEventsOne) + self.TargetOne = TargetOne def test_lis_subcalss_lis(self): @@ -491,10 +468,7 @@ class ClsLevelListenTest(fixtures.TestBase): def handler2(x, y): pass - eq_( - len(SubTarget().dispatch.event_one), - 2 - ) + eq_(len(SubTarget().dispatch.event_one), 2) def test_lis_multisub_lis(self): @event.listens_for(self.TargetOne, "event_one") @@ -511,14 +485,8 @@ class ClsLevelListenTest(fixtures.TestBase): def handler2(x, y): pass - eq_( - len(SubTarget().dispatch.event_one), - 2 - ) - eq_( - len(SubSubTarget().dispatch.event_one), - 2 - ) + eq_(len(SubTarget().dispatch.event_one), 2) + 
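
The class-level tests above depend on listeners registered on a base class being inherited by subclasses defined afterwards. Sketched with a throwaway fixture (illustrative names, same ``Events``/``dispatcher`` pattern as before):

    from sqlalchemy import event

    class TargetEvents(event.Events):
        def event_one(self, x, y):
            pass

    class BaseTarget(object):
        dispatch = event.dispatcher(TargetEvents)

    @event.listens_for(BaseTarget, "event_one")
    def handler1(x, y):
        pass

    # a subclass created *after* the base-class listener still sees it
    class SubTarget(BaseTarget):
        pass

    @event.listens_for(SubTarget, "event_one")
    def handler2(x, y):
        pass

    # the subclass dispatch aggregates inherited and local listeners
    assert len(SubTarget().dispatch.event_one) == 2
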
eq_(len(SubSubTarget().dispatch.event_one), 2) def test_two_sub_lis(self): class SubTarget1(self.TargetOne): @@ -561,14 +529,17 @@ class AcceptTargetsTest(fixtures.TestBase): class TargetTwo(object): dispatch = event.dispatcher(TargetEventsTwo) + self.TargetOne = TargetOne self.TargetTwo = TargetTwo def tearDown(self): event.base._remove_dispatcher( - self.TargetOne.__dict__['dispatch'].events) + self.TargetOne.__dict__["dispatch"].events + ) event.base._remove_dispatcher( - self.TargetTwo.__dict__['dispatch'].events) + self.TargetTwo.__dict__["dispatch"].events + ) def test_target_accept(self): """Test that events of the same name are routed to the correct @@ -591,15 +562,9 @@ class AcceptTargetsTest(fixtures.TestBase): event.listen(self.TargetOne, "event_one", listen_one) event.listen(self.TargetTwo, "event_one", listen_two) - eq_( - list(self.TargetOne().dispatch.event_one), - [listen_one] - ) + eq_(list(self.TargetOne().dispatch.event_one), [listen_one]) - eq_( - list(self.TargetTwo().dispatch.event_one), - [listen_two] - ) + eq_(list(self.TargetTwo().dispatch.event_one), [listen_two]) t1 = self.TargetOne() t2 = self.TargetTwo() @@ -607,15 +572,9 @@ class AcceptTargetsTest(fixtures.TestBase): event.listen(t1, "event_one", listen_three) event.listen(t2, "event_one", listen_four) - eq_( - list(t1.dispatch.event_one), - [listen_one, listen_three] - ) + eq_(list(t1.dispatch.event_one), [listen_one, listen_three]) - eq_( - list(t2.dispatch.event_one), - [listen_two, listen_four] - ) + eq_(list(t2.dispatch.event_one), [listen_two, listen_four]) class CustomTargetsTest(fixtures.TestBase): @@ -625,7 +584,7 @@ class CustomTargetsTest(fixtures.TestBase): class TargetEvents(event.Events): @classmethod def _accept_with(cls, target): - if target == 'one': + if target == "one": return Target else: return None @@ -635,10 +594,11 @@ class CustomTargetsTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + self.Target = Target def tearDown(self): - event.base._remove_dispatcher(self.Target.__dict__['dispatch'].events) + event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events) def test_indirect(self): def listen(x, y): @@ -646,15 +606,14 @@ class CustomTargetsTest(fixtures.TestBase): event.listen("one", "event_one", listen) - eq_( - list(self.Target().dispatch.event_one), - [listen] - ) + eq_(list(self.Target().dispatch.event_one), [listen]) assert_raises( exc.InvalidRequestError, event.listen, - listen, "event_one", self.Target + listen, + "event_one", + self.Target, ) @@ -693,8 +652,10 @@ class ListenOverrideTest(fixtures.TestBase): def _listen(cls, event_key, add=False): fn = event_key.fn if add: + def adapt(x, y): fn(x + y) + event_key = event_key.with_wrapper(adapt) event_key.base_listen() @@ -704,10 +665,11 @@ class ListenOverrideTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + self.Target = Target def tearDown(self): - event.base._remove_dispatcher(self.Target.__dict__['dispatch'].events) + event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events) def test_listen_override(self): listen_one = Mock() @@ -720,46 +682,28 @@ class ListenOverrideTest(fixtures.TestBase): t1.dispatch.event_one(5, 7) t1.dispatch.event_one(10, 5) - eq_( - listen_one.mock_calls, - [call(12), call(15)] - ) - eq_( - listen_two.mock_calls, - [call(5, 7), call(10, 5)] - ) + eq_(listen_one.mock_calls, [call(12), call(15)]) + eq_(listen_two.mock_calls, [call(5, 7), call(10, 5)]) def test_remove_clslevel(self): listen_one = 
Mock() event.listen(self.Target, "event_one", listen_one, add=True) t1 = self.Target() t1.dispatch.event_one(5, 7) - eq_( - listen_one.mock_calls, - [call(12)] - ) + eq_(listen_one.mock_calls, [call(12)]) event.remove(self.Target, "event_one", listen_one) t1.dispatch.event_one(10, 5) - eq_( - listen_one.mock_calls, - [call(12)] - ) + eq_(listen_one.mock_calls, [call(12)]) def test_remove_instancelevel(self): listen_one = Mock() t1 = self.Target() event.listen(t1, "event_one", listen_one, add=True) t1.dispatch.event_one(5, 7) - eq_( - listen_one.mock_calls, - [call(12)] - ) + eq_(listen_one.mock_calls, [call(12)]) event.remove(t1, "event_one", listen_one) t1.dispatch.event_one(10, 5) - eq_( - listen_one.mock_calls, - [call(12)] - ) + eq_(listen_one.mock_calls, [call(12)]) class PropagateTest(fixtures.TestBase): @@ -773,6 +717,7 @@ class PropagateTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + self.Target = Target def test_propagate(self): @@ -791,14 +736,8 @@ class PropagateTest(fixtures.TestBase): t2.dispatch.event_one(t2, 1) t2.dispatch.event_two(t2, 2) - eq_( - listen_one.mock_calls, - [call(t2, 1)] - ) - eq_( - listen_two.mock_calls, - [] - ) + eq_(listen_one.mock_calls, [call(t2, 1)]) + eq_(listen_two.mock_calls, []) class JoinTest(fixtures.TestBase): @@ -827,11 +766,9 @@ class JoinTest(fixtures.TestBase): self.TargetElement = TargetElement def tearDown(self): - for cls in ( - self.TargetElement, - self.TargetFactory, self.BaseTarget): - if 'dispatch' in cls.__dict__: - event.base._remove_dispatcher(cls.__dict__['dispatch'].events) + for cls in (self.TargetElement, self.TargetFactory, self.BaseTarget): + if "dispatch" in cls.__dict__: + event.base._remove_dispatcher(cls.__dict__["dispatch"].events) def test_neither(self): element = self.TargetFactory().create() @@ -854,7 +791,7 @@ class JoinTest(fixtures.TestBase): [ call({"target": element, "arg": 1}), call({"target": element, "arg": 2}), - ] + ], ) def test_parent_class_only(self): @@ -868,7 +805,7 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( l1.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) def test_parent_class_child_class(self): @@ -884,11 +821,11 @@ class JoinTest(fixtures.TestBase): element.run_event(3) eq_( l1.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) eq_( l2.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) def test_parent_class_child_instance_apply_after(self): @@ -906,12 +843,9 @@ class JoinTest(fixtures.TestBase): eq_( l1.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] - ) - eq_( - l2.mock_calls, - [call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) + eq_(l2.mock_calls, [call(element, 2), call(element, 3)]) def test_parent_class_child_instance_apply_before(self): l1 = Mock() @@ -928,11 +862,11 @@ class JoinTest(fixtures.TestBase): eq_( l1.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) eq_( l2.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) def test_parent_instance_child_class_apply_before(self): @@ -952,11 +886,11 @@ class JoinTest(fixtures.TestBase): eq_( l1.mock_calls, - [call(element, 1), 
call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) eq_( l2.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) def test_parent_instance_child_class_apply_after(self): @@ -984,13 +918,11 @@ class JoinTest(fixtures.TestBase): # using a @property, then we get them, at the arguable # expense of the extra method call to access the .listeners # collection - eq_( - l1.mock_calls, [call(element, 2), call(element, 3)] - ) + eq_(l1.mock_calls, [call(element, 2), call(element, 3)]) eq_( l2.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) def test_parent_instance_child_instance_apply_before(self): @@ -1009,11 +941,11 @@ class JoinTest(fixtures.TestBase): eq_( l1.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) eq_( l2.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) def test_parent_events_child_no_events(self): @@ -1029,12 +961,11 @@ class JoinTest(fixtures.TestBase): eq_( l1.mock_calls, - [call(element, 1), call(element, 2), call(element, 3)] + [call(element, 1), call(element, 2), call(element, 3)], ) class DisableClsPropagateTest(fixtures.TestBase): - def setUp(self): class TargetEvents(event.Events): def event_one(self, target, arg): @@ -1054,8 +985,8 @@ class DisableClsPropagateTest(fixtures.TestBase): def tearDown(self): for cls in (self.SubTarget, self.BaseTarget): - if 'dispatch' in cls.__dict__: - event.base._remove_dispatcher(cls.__dict__['dispatch'].events) + if "dispatch" in cls.__dict__: + event.base._remove_dispatcher(cls.__dict__["dispatch"].events) def test_listen_invoke_clslevel(self): canary = Mock() @@ -1105,6 +1036,7 @@ class RemovalTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + return Target def _wrapped_fixture(self): @@ -1115,6 +1047,7 @@ class RemovalTest(fixtures.TestBase): def adapt(value): fn("adapted " + value) + event_key = event_key.with_wrapper(adapt) event_key.base_listen() @@ -1124,6 +1057,7 @@ class RemovalTest(fixtures.TestBase): class Target(object): dispatch = event.dispatcher(TargetEvents) + return Target def test_clslevel(self): @@ -1238,11 +1172,7 @@ class RemovalTest(fixtures.TestBase): t2.dispatch.event_one("t2e1y") t2.dispatch.event_two("t2e2y") - eq_(m1.mock_calls, - [ - call('t1e1x'), call('t1e2x'), - call('t2e1x') - ]) + eq_(m1.mock_calls, [call("t1e1x"), call("t1e2x"), call("t2e1x")]) @testing.requires.predictable_gc def test_listener_collection_removed_cleanup(self): @@ -1286,7 +1216,10 @@ class RemovalTest(fixtures.TestBase): exc.InvalidRequestError, r"No listeners found for event <.*Target.*> / " r"'event_two' / ", - event.remove, t1, "event_two", m1 + event.remove, + t1, + "event_two", + m1, ) event.remove(t1, "event_three", m1) @@ -1302,9 +1235,7 @@ class RemovalTest(fixtures.TestBase): event.listen(t1, "event_one", evt) assert_raises_message( - Exception, - "deque mutated during iteration", - t1.dispatch.event_one + Exception, "deque mutated during iteration", t1.dispatch.event_one ) def test_no_add_in_event(self): @@ -1320,9 +1251,7 @@ class RemovalTest(fixtures.TestBase): event.listen(t1, "event_one", evt) assert_raises_message( - Exception, - "deque mutated during iteration", - t1.dispatch.event_one + Exception, "deque mutated during iteration", 
t1.dispatch.event_one ) def test_remove_plain_named(self): diff --git a/test/base/test_except.py b/test/base/test_except.py index ce8655af0b..2ace55e915 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -23,9 +23,8 @@ class OperationalError(DatabaseError): class ProgrammingError(DatabaseError): - def __str__(self): - return '<%s>' % self.bogus + return "<%s>" % self.bogus class OutOfSpec(DatabaseError): @@ -47,7 +46,6 @@ class SpecificIntegrityError(WrongNameError): class WrapTest(fixtures.TestBase): - def _translating_dialect_fixture(self): d = default.DefaultDialect() d.dbapi_exception_translation_map = { @@ -58,106 +56,140 @@ class WrapTest(fixtures.TestBase): def test_db_error_normal(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], - OperationalError(), DatabaseError) + "", [], OperationalError(), DatabaseError + ) except sa_exceptions.DBAPIError: self.assert_(True) def test_tostring(self): try: raise sa_exceptions.DBAPIError.instance( - 'this is a message', - None, OperationalError(), DatabaseError) + "this is a message", None, OperationalError(), DatabaseError + ) except sa_exceptions.DBAPIError as exc: eq_( str(exc), "(test.base.test_except.OperationalError) " "[SQL: 'this is a message'] (Background on this error at: " - "http://sqlalche.me/e/e3q8)") + "http://sqlalche.me/e/e3q8)", + ) def test_statement_error_no_code(self): try: raise sa_exceptions.DBAPIError.instance( - 'select * from table', [{"x": 1}], - sa_exceptions.InvalidRequestError("hello"), DatabaseError) + "select * from table", + [{"x": 1}], + sa_exceptions.InvalidRequestError("hello"), + DatabaseError, + ) except sa_exceptions.StatementError as err: eq_( str(err), "(sqlalchemy.exc.InvalidRequestError) hello " - "[SQL: 'select * from table'] [parameters: [{'x': 1}]]" + "[SQL: 'select * from table'] [parameters: [{'x': 1}]]", ) - eq_(err.args, ("(sqlalchemy.exc.InvalidRequestError) hello", )) + eq_(err.args, ("(sqlalchemy.exc.InvalidRequestError) hello",)) def test_statement_error_w_code(self): try: raise sa_exceptions.DBAPIError.instance( - 'select * from table', [{"x": 1}], + "select * from table", + [{"x": 1}], sa_exceptions.InvalidRequestError("hello", code="abcd"), - DatabaseError) + DatabaseError, + ) except sa_exceptions.StatementError as err: eq_( str(err), "(sqlalchemy.exc.InvalidRequestError) hello " "[SQL: 'select * from table'] [parameters: [{'x': 1}]] " - "(Background on this error at: http://sqlalche.me/e/abcd)" + "(Background on this error at: http://sqlalche.me/e/abcd)", ) - eq_(err.args, ("(sqlalchemy.exc.InvalidRequestError) hello", )) + eq_(err.args, ("(sqlalchemy.exc.InvalidRequestError) hello",)) def test_wrap_multi_arg(self): # this is not supported by the API but oslo_db is doing it orig = sa_exceptions.DBAPIError(False, False, False) - orig.args = [2006, 'Test raise operational error'] + orig.args = [2006, "Test raise operational error"] eq_( str(orig), "(2006, 'Test raise operational error') " - "(Background on this error at: http://sqlalche.me/e/dbapi)" + "(Background on this error at: http://sqlalche.me/e/dbapi)", ) def test_wrap_unicode_arg(self): # this is not supported by the API but oslo_db is doing it orig = sa_exceptions.DBAPIError(False, False, False) - orig.args = [u('méil')] + orig.args = [u("méil")] eq_( compat.text_type(orig), compat.u( "méil (Background on this error at: " - "http://sqlalche.me/e/dbapi)") + "http://sqlalche.me/e/dbapi)" + ), ) - eq_(orig.args, (u('méil'),)) + eq_(orig.args, (u("méil"),)) def test_tostring_large_dict(self): try: 
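
The WrapTest cases here all funnel through one classmethod, ``DBAPIError.instance()``, which picks a ``sqlalchemy.exc`` wrapper class by matching the original DBAPI exception's class name and attaches the statement and parameters to the message. A compressed sketch using stand-in DBAPI classes, as in this file's own fixtures:

    from sqlalchemy import exc as sa_exceptions

    class DatabaseError(Exception):          # stand-in DBAPI base error
        pass

    class OperationalError(DatabaseError):   # matched by class name
        pass

    err = sa_exceptions.DBAPIError.instance(
        "select * from table",
        [{"x": 1}],
        OperationalError("boom"),
        DatabaseError,
    )
    assert isinstance(err, sa_exceptions.OperationalError)
    assert "select * from table" in str(err)
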
raise sa_exceptions.DBAPIError.instance( - 'this is a message', + "this is a message", { - 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, - 'h': 8, 'i': 9, 'j': 10, 'k': 11 + "a": 1, + "b": 2, + "c": 3, + "d": 4, + "e": 5, + "f": 6, + "g": 7, + "h": 8, + "i": 9, + "j": 10, + "k": 11, }, - OperationalError(), DatabaseError) + OperationalError(), + DatabaseError, + ) except sa_exceptions.DBAPIError as exc: assert str(exc).startswith( "(test.base.test_except.OperationalError) " - "[SQL: 'this is a message'] [parameters: {") + "[SQL: 'this is a message'] [parameters: {" + ) def test_tostring_large_list(self): try: raise sa_exceptions.DBAPIError.instance( - 'this is a message', + "this is a message", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - OperationalError(), DatabaseError) + OperationalError(), + DatabaseError, + ) except sa_exceptions.DBAPIError as exc: assert str(exc).startswith( "(test.base.test_except.OperationalError) " "[SQL: 'this is a message'] [parameters: " - "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]") + "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]" + ) def test_tostring_large_executemany(self): try: raise sa_exceptions.DBAPIError.instance( - 'this is a message', - [{1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, - {1: 1}, {1: 1}, {1: 1}, {1: 1}, ], - OperationalError("sql error"), DatabaseError) + "this is a message", + [ + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + ], + OperationalError("sql error"), + DatabaseError, + ) except sa_exceptions.DBAPIError as exc: eq_( str(exc), @@ -165,33 +197,48 @@ class WrapTest(fixtures.TestBase): "[SQL: 'this is a message'] [parameters: [{1: 1}, " "{1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: " "1}, {1: 1}, {1: 1}]] (Background on this error at: " - "http://sqlalche.me/e/e3q8)" + "http://sqlalche.me/e/e3q8)", ) eq_( exc.args, - ("(test.base.test_except.OperationalError) sql error", ) + ("(test.base.test_except.OperationalError) sql error",), ) try: - raise sa_exceptions.DBAPIError.instance('this is a message', [ - {1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, - {1: 1}, {1: 1}, {1: 1}, {1: 1}, - ], OperationalError(), DatabaseError) + raise sa_exceptions.DBAPIError.instance( + "this is a message", + [ + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + {1: 1}, + ], + OperationalError(), + DatabaseError, + ) except sa_exceptions.DBAPIError as exc: - eq_(str(exc), + eq_( + str(exc), "(test.base.test_except.OperationalError) " "[SQL: 'this is a message'] [parameters: [{1: 1}, " "{1: 1}, {1: 1}, {1: 1}, {1: 1}, {1: 1}, " "{1: 1}, {1: 1} ... displaying 10 of 11 total " "bound parameter sets ... 
{1: 1}, {1: 1}]] " - "(Background on this error at: http://sqlalche.me/e/e3q8)" - ) + "(Background on this error at: http://sqlalche.me/e/e3q8)", + ) try: raise sa_exceptions.DBAPIError.instance( - 'this is a message', - [ - (1, ), (1, ), (1, ), (1, ), (1, ), (1, ), - (1, ), (1, ), (1, ), (1, ), - ], OperationalError(), DatabaseError) + "this is a message", + [(1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,)], + OperationalError(), + DatabaseError, + ) except sa_exceptions.DBAPIError as exc: eq_( @@ -199,36 +246,52 @@ class WrapTest(fixtures.TestBase): "(test.base.test_except.OperationalError) " "[SQL: 'this is a message'] [parameters: [(1,), " "(1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,)]] " - "(Background on this error at: http://sqlalche.me/e/e3q8)") + "(Background on this error at: http://sqlalche.me/e/e3q8)", + ) try: - raise sa_exceptions.DBAPIError.instance('this is a message', [ - (1, ), (1, ), (1, ), (1, ), (1, ), (1, ), (1, ), (1, ), (1, ), - (1, ), (1, ), - ], OperationalError(), DatabaseError) + raise sa_exceptions.DBAPIError.instance( + "this is a message", + [ + (1,), + (1,), + (1,), + (1,), + (1,), + (1,), + (1,), + (1,), + (1,), + (1,), + (1,), + ], + OperationalError(), + DatabaseError, + ) except sa_exceptions.DBAPIError as exc: - eq_(str(exc), + eq_( + str(exc), "(test.base.test_except.OperationalError) " "[SQL: 'this is a message'] [parameters: [(1,), " "(1,), (1,), (1,), (1,), (1,), (1,), (1,) " "... displaying 10 of 11 total bound " "parameter sets ... (1,), (1,)]] " - "(Background on this error at: http://sqlalche.me/e/e3q8)" - ) + "(Background on this error at: http://sqlalche.me/e/e3q8)", + ) def test_db_error_busted_dbapi(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], - ProgrammingError(), DatabaseError) + "", [], ProgrammingError(), DatabaseError + ) except sa_exceptions.DBAPIError as e: self.assert_(True) - self.assert_('Error in str() of DB-API' in e.args[0]) + self.assert_("Error in str() of DB-API" in e.args[0]) def test_db_error_noncompliant_dbapi(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], OutOfSpec(), - DatabaseError) + "", [], OutOfSpec(), DatabaseError + ) except sa_exceptions.DBAPIError as e: # OutOfSpec subclasses DatabaseError self.assert_(e.__class__ is sa_exceptions.DatabaseError) @@ -237,8 +300,8 @@ class WrapTest(fixtures.TestBase): try: raise sa_exceptions.DBAPIError.instance( - '', [], - sa_exceptions.ArgumentError(), DatabaseError) + "", [], sa_exceptions.ArgumentError(), DatabaseError + ) except sa_exceptions.DBAPIError as e: self.assert_(e.__class__ is sa_exceptions.DBAPIError) except sa_exceptions.ArgumentError: @@ -247,9 +310,12 @@ class WrapTest(fixtures.TestBase): dialect = self._translating_dialect_fixture() try: raise sa_exceptions.DBAPIError.instance( - '', [], - sa_exceptions.ArgumentError(), DatabaseError, - dialect=dialect) + "", + [], + sa_exceptions.ArgumentError(), + DatabaseError, + dialect=dialect, + ) except sa_exceptions.DBAPIError as e: self.assert_(e.__class__ is sa_exceptions.DBAPIError) except sa_exceptions.ArgumentError: @@ -260,22 +326,26 @@ class WrapTest(fixtures.TestBase): try: raise sa_exceptions.DBAPIError.instance( - '', [], IntegrityError(), - DatabaseError, dialect=dialect) + "", [], IntegrityError(), DatabaseError, dialect=dialect + ) except sa_exceptions.DBAPIError as e: self.assert_(e.__class__ is sa_exceptions.IntegrityError) try: raise sa_exceptions.DBAPIError.instance( - '', [], SpecificIntegrityError(), - DatabaseError, dialect=dialect) + "", + [], + 
SpecificIntegrityError(), + DatabaseError, + dialect=dialect, + ) except sa_exceptions.DBAPIError as e: self.assert_(e.__class__ is sa_exceptions.IntegrityError) try: raise sa_exceptions.DBAPIError.instance( - '', [], SpecificIntegrityError(), - DatabaseError) + "", [], SpecificIntegrityError(), DatabaseError + ) except sa_exceptions.DBAPIError as e: # doesn't work without a dialect self.assert_(e.__class__ is not sa_exceptions.IntegrityError) @@ -283,8 +353,8 @@ class WrapTest(fixtures.TestBase): def test_db_error_keyboard_interrupt(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], - KeyboardInterrupt(), DatabaseError) + "", [], KeyboardInterrupt(), DatabaseError + ) except sa_exceptions.DBAPIError: self.assert_(False) except KeyboardInterrupt: @@ -293,8 +363,8 @@ class WrapTest(fixtures.TestBase): def test_db_error_system_exit(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], - SystemExit(), DatabaseError) + "", [], SystemExit(), DatabaseError + ) except sa_exceptions.DBAPIError: self.assert_(False) except SystemExit: diff --git a/test/base/test_inspect.py b/test/base/test_inspect.py index d10aeca9a9..933d72fe94 100644 --- a/test/base/test_inspect.py +++ b/test/base/test_inspect.py @@ -11,7 +11,6 @@ class TestFixture(object): class TestInspection(fixtures.TestBase): - def tearDown(self): for type_ in list(inspection._registrars): if issubclass(type_, TestFixture): @@ -36,7 +35,8 @@ class TestInspection(fixtures.TestBase): assert_raises_message( exc.NoInspectionAvailable, "No inspection system is available for object of type ", - inspect, SomeFoo + inspect, + SomeFoo, ) def test_class_insp(self): @@ -46,6 +46,7 @@ class TestInspection(fixtures.TestBase): class SomeFooInspect(object): def __init__(self, target): self.target = target + SomeFooInspect = inspection._inspects(SomeFoo)(SomeFooInspect) somefoo = SomeFoo() diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py index 55a0b92d69..c409629c5f 100644 --- a/test/base/test_tutorials.py +++ b/test/base/test_tutorials.py @@ -10,7 +10,7 @@ import os class DocTest(fixtures.TestBase): def _setup_logger(self): - rootlogger = logging.getLogger('sqlalchemy.engine.base.Engine') + rootlogger = logging.getLogger("sqlalchemy.engine.base.Engine") class MyStream(object): def write(self, string): @@ -21,25 +21,26 @@ class DocTest(fixtures.TestBase): pass self._handler = handler = logging.StreamHandler(MyStream()) - handler.setFormatter(logging.Formatter('%(message)s')) + handler.setFormatter(logging.Formatter("%(message)s")) rootlogger.addHandler(handler) def _teardown_logger(self): - rootlogger = logging.getLogger('sqlalchemy.engine.base.Engine') + rootlogger = logging.getLogger("sqlalchemy.engine.base.Engine") rootlogger.removeHandler(self._handler) def _setup_create_table_patcher(self): from sqlalchemy.sql import ddl + self.orig_sort = ddl.sort_tables_and_constraints def our_sort(tables, **kw): - return self.orig_sort( - sorted(tables, key=lambda t: t.key), **kw - ) + return self.orig_sort(sorted(tables, key=lambda t: t.key), **kw) + ddl.sort_tables_and_constraints = our_sort def _teardown_create_table_patcher(self): from sqlalchemy.sql import ddl + ddl.sort_tables_and_constraints = self.orig_sort def setup(self): @@ -52,16 +53,17 @@ class DocTest(fixtures.TestBase): def _run_doctest_for_content(self, name, content): optionflags = ( - doctest.ELLIPSIS | - doctest.NORMALIZE_WHITESPACE | - doctest.IGNORE_EXCEPTION_DETAIL | - _get_allow_unicode_flag() + doctest.ELLIPSIS + | doctest.NORMALIZE_WHITESPACE + | 
doctest.IGNORE_EXCEPTION_DETAIL + | _get_allow_unicode_flag() ) runner = doctest.DocTestRunner( - verbose=None, optionflags=optionflags, - checker=_get_unicode_checker()) - globs = { - 'print_function': print_function} + verbose=None, + optionflags=optionflags, + checker=_get_unicode_checker(), + ) + globs = {"print_function": print_function} parser = doctest.DocTestParser() test = parser.get_doctest(content, globs, name, name, 0) runner.run(test) @@ -76,7 +78,7 @@ class DocTest(fixtures.TestBase): config.skip_test("Can't find documentation file %r" % path) with open(path) as file_: content = file_.read() - content = re.sub(r'{(?:stop|sql|opensql)}', '', content) + content = re.sub(r"{(?:stop|sql|opensql)}", "", content) self._run_doctest_for_content(fname, content) def test_orm(self): @@ -98,7 +100,7 @@ def _get_unicode_checker(): An inner class is used to avoid importing "doctest" at the module level. """ - if hasattr(_get_unicode_checker, 'UnicodeOutputChecker'): + if hasattr(_get_unicode_checker, "UnicodeOutputChecker"): return _get_unicode_checker.UnicodeOutputChecker() import doctest @@ -113,8 +115,9 @@ def _get_unicode_checker(): _literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE) def check_output(self, want, got, optionflags): - res = doctest.OutputChecker.check_output(self, want, got, - optionflags) + res = doctest.OutputChecker.check_output( + self, want, got, optionflags + ) if res: return True @@ -125,12 +128,13 @@ def _get_unicode_checker(): # the code below will end up executed only in Python 2 in # our tests, and our coverage check runs in Python 3 only def remove_u_prefixes(txt): - return re.sub(self._literal_re, r'\1\2', txt) + return re.sub(self._literal_re, r"\1\2", txt) want = remove_u_prefixes(want) got = remove_u_prefixes(got) - res = doctest.OutputChecker.check_output(self, want, got, - optionflags) + res = doctest.OutputChecker.check_output( + self, want, got, optionflags + ) return res _get_unicode_checker.UnicodeOutputChecker = UnicodeOutputChecker @@ -142,4 +146,5 @@ def _get_allow_unicode_flag(): Registers and returns the ALLOW_UNICODE flag. 
""" import doctest - return doctest.register_optionflag('ALLOW_UNICODE') + + return doctest.register_optionflag("ALLOW_UNICODE") diff --git a/test/base/test_utils.py b/test/base/test_utils.py index bf65d4fc97..88b865b1d7 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -13,13 +13,12 @@ import inspect class _KeyedTupleTest(object): - def _fixture(self, values, labels): raise NotImplementedError() def test_empty(self): keyed_tuple = self._fixture([], []) - eq_(str(keyed_tuple), '()') + eq_(str(keyed_tuple), "()") eq_(len(keyed_tuple), 0) eq_(list(keyed_tuple.keys()), []) @@ -28,7 +27,7 @@ class _KeyedTupleTest(object): def test_values_but_no_labels(self): keyed_tuple = self._fixture([1, 2], []) - eq_(str(keyed_tuple), '(1, 2)') + eq_(str(keyed_tuple), "(1, 2)") eq_(len(keyed_tuple), 2) eq_(list(keyed_tuple.keys()), []) @@ -39,37 +38,39 @@ class _KeyedTupleTest(object): eq_(keyed_tuple[1], 2) def test_basic_creation(self): - keyed_tuple = self._fixture([1, 2], ['a', 'b']) - eq_(str(keyed_tuple), '(1, 2)') - eq_(list(keyed_tuple.keys()), ['a', 'b']) - eq_(keyed_tuple._fields, ('a', 'b')) - eq_(keyed_tuple._asdict(), {'a': 1, 'b': 2}) + keyed_tuple = self._fixture([1, 2], ["a", "b"]) + eq_(str(keyed_tuple), "(1, 2)") + eq_(list(keyed_tuple.keys()), ["a", "b"]) + eq_(keyed_tuple._fields, ("a", "b")) + eq_(keyed_tuple._asdict(), {"a": 1, "b": 2}) def test_basic_index_access(self): - keyed_tuple = self._fixture([1, 2], ['a', 'b']) + keyed_tuple = self._fixture([1, 2], ["a", "b"]) eq_(keyed_tuple[0], 1) eq_(keyed_tuple[1], 2) def should_raise(): keyed_tuple[2] + assert_raises(IndexError, should_raise) def test_basic_attribute_access(self): - keyed_tuple = self._fixture([1, 2], ['a', 'b']) + keyed_tuple = self._fixture([1, 2], ["a", "b"]) eq_(keyed_tuple.a, 1) eq_(keyed_tuple.b, 2) def should_raise(): keyed_tuple.c + assert_raises(AttributeError, should_raise) def test_none_label(self): - keyed_tuple = self._fixture([1, 2, 3], ['a', None, 'b']) - eq_(str(keyed_tuple), '(1, 2, 3)') + keyed_tuple = self._fixture([1, 2, 3], ["a", None, "b"]) + eq_(str(keyed_tuple), "(1, 2, 3)") - eq_(list(keyed_tuple.keys()), ['a', 'b']) - eq_(keyed_tuple._fields, ('a', 'b')) - eq_(keyed_tuple._asdict(), {'a': 1, 'b': 3}) + eq_(list(keyed_tuple.keys()), ["a", "b"]) + eq_(keyed_tuple._fields, ("a", "b")) + eq_(keyed_tuple._asdict(), {"a": 1, "b": 3}) # attribute access: can't get at value 2 eq_(keyed_tuple.a, 1) @@ -81,12 +82,12 @@ class _KeyedTupleTest(object): eq_(keyed_tuple[2], 3) def test_duplicate_labels(self): - keyed_tuple = self._fixture([1, 2, 3], ['a', 'b', 'b']) - eq_(str(keyed_tuple), '(1, 2, 3)') + keyed_tuple = self._fixture([1, 2, 3], ["a", "b", "b"]) + eq_(str(keyed_tuple), "(1, 2, 3)") - eq_(list(keyed_tuple.keys()), ['a', 'b', 'b']) - eq_(keyed_tuple._fields, ('a', 'b', 'b')) - eq_(keyed_tuple._asdict(), {'a': 1, 'b': 3}) + eq_(list(keyed_tuple.keys()), ["a", "b", "b"]) + eq_(keyed_tuple._fields, ("a", "b", "b")) + eq_(keyed_tuple._asdict(), {"a": 1, "b": 3}) # attribute access: can't get at value 2 eq_(keyed_tuple.a, 1) @@ -98,8 +99,8 @@ class _KeyedTupleTest(object): eq_(keyed_tuple[2], 3) def test_immutable(self): - keyed_tuple = self._fixture([1, 2], ['a', 'b']) - eq_(str(keyed_tuple), '(1, 2)') + keyed_tuple = self._fixture([1, 2], ["a", "b"]) + eq_(str(keyed_tuple), "(1, 2)") eq_(keyed_tuple.a, 1) @@ -107,20 +108,21 @@ class _KeyedTupleTest(object): def should_raise(): keyed_tuple[0] = 100 + assert_raises(TypeError, should_raise) def test_serialize(self): - keyed_tuple = 
self._fixture([1, 2, 3], ['a', None, 'b']) + keyed_tuple = self._fixture([1, 2, 3], ["a", None, "b"]) for loads, dumps in picklers(): kt = loads(dumps(keyed_tuple)) - eq_(str(kt), '(1, 2, 3)') + eq_(str(kt), "(1, 2, 3)") - eq_(list(kt.keys()), ['a', 'b']) - eq_(kt._fields, ('a', 'b')) - eq_(kt._asdict(), {'a': 1, 'b': 3}) + eq_(list(kt.keys()), ["a", "b"]) + eq_(kt._fields, ("a", "b")) + eq_(kt._asdict(), {"a": 1, "b": 3}) class KeyedTupleTest(_KeyedTupleTest, fixtures.TestBase): @@ -130,7 +132,7 @@ class KeyedTupleTest(_KeyedTupleTest, fixtures.TestBase): class LWKeyedTupleTest(_KeyedTupleTest, fixtures.TestBase): def _fixture(self, values, labels): - return util.lightweight_named_tuple('n', labels)(values) + return util.lightweight_named_tuple("n", labels)(values) class WeakSequenceTest(fixtures.TestBase): @@ -138,6 +140,7 @@ class WeakSequenceTest(fixtures.TestBase): def test_cleanout_elements(self): class Foo(object): pass + f1, f2, f3 = Foo(), Foo(), Foo() w = WeakSequence([f1, f2, f3]) eq_(len(w), 3) @@ -151,6 +154,7 @@ class WeakSequenceTest(fixtures.TestBase): def test_cleanout_appended(self): class Foo(object): pass + f1, f2, f3 = Foo(), Foo(), Foo() w = WeakSequence() w.append(f1) @@ -165,63 +169,63 @@ class WeakSequenceTest(fixtures.TestBase): class OrderedDictTest(fixtures.TestBase): - def test_odict(self): o = util.OrderedDict() - o['a'] = 1 - o['b'] = 2 - o['snack'] = 'attack' - o['c'] = 3 + o["a"] = 1 + o["b"] = 2 + o["snack"] = "attack" + o["c"] = 3 - eq_(list(o.keys()), ['a', 'b', 'snack', 'c']) - eq_(list(o.values()), [1, 2, 'attack', 3]) + eq_(list(o.keys()), ["a", "b", "snack", "c"]) + eq_(list(o.values()), [1, 2, "attack", 3]) - o.pop('snack') - eq_(list(o.keys()), ['a', 'b', 'c']) + o.pop("snack") + eq_(list(o.keys()), ["a", "b", "c"]) eq_(list(o.values()), [1, 2, 3]) try: - o.pop('eep') + o.pop("eep") assert False except KeyError: pass - eq_(o.pop('eep', 'woot'), 'woot') + eq_(o.pop("eep", "woot"), "woot") try: - o.pop('whiff', 'bang', 'pow') + o.pop("whiff", "bang", "pow") assert False except TypeError: pass - eq_(list(o.keys()), ['a', 'b', 'c']) + eq_(list(o.keys()), ["a", "b", "c"]) eq_(list(o.values()), [1, 2, 3]) o2 = util.OrderedDict(d=4) - o2['e'] = 5 + o2["e"] = 5 - eq_(list(o2.keys()), ['d', 'e']) + eq_(list(o2.keys()), ["d", "e"]) eq_(list(o2.values()), [4, 5]) o.update(o2) - eq_(list(o.keys()), ['a', 'b', 'c', 'd', 'e']) + eq_(list(o.keys()), ["a", "b", "c", "d", "e"]) eq_(list(o.values()), [1, 2, 3, 4, 5]) - o.setdefault('c', 'zzz') - o.setdefault('f', 6) - eq_(list(o.keys()), ['a', 'b', 'c', 'd', 'e', 'f']) + o.setdefault("c", "zzz") + o.setdefault("f", 6) + eq_(list(o.keys()), ["a", "b", "c", "d", "e", "f"]) eq_(list(o.values()), [1, 2, 3, 4, 5, 6]) def test_odict_constructor(self): - o = util.OrderedDict([('name', 'jbe'), - ('fullname', 'jonathan'), ('password', '')]) - eq_(list(o.keys()), ['name', 'fullname', 'password']) + o = util.OrderedDict( + [("name", "jbe"), ("fullname", "jonathan"), ("password", "")] + ) + eq_(list(o.keys()), ["name", "fullname", "password"]) def test_odict_copy(self): o = util.OrderedDict() o["zzz"] = 1 o["aaa"] = 2 - eq_(list(o.keys()), ['zzz', 'aaa']) + eq_(list(o.keys()), ["zzz", "aaa"]) o2 = o.copy() eq_(list(o2.keys()), list(o.keys())) @@ -231,7 +235,6 @@ class OrderedDictTest(fixtures.TestBase): class OrderedSetTest(fixtures.TestBase): - def test_mutators_against_iter(self): # testing a set modified against an iterator o = util.OrderedSet([3, 2, 4, 5]) @@ -242,7 +245,6 @@ class OrderedSetTest(fixtures.TestBase): class 
FrozenDictTest(fixtures.TestBase): - def test_serialize(self): d = util.immutabledict({1: 2, 3: 4}) for loads, dumps in picklers(): @@ -250,7 +252,6 @@ class FrozenDictTest(fixtures.TestBase): class MemoizedAttrTest(fixtures.TestBase): - def test_memoized_property(self): val = [20] @@ -263,11 +264,11 @@ class MemoizedAttrTest(fixtures.TestBase): ne_(Foo.bar, None) f1 = Foo() - assert 'bar' not in f1.__dict__ + assert "bar" not in f1.__dict__ eq_(f1.bar, 20) eq_(f1.bar, 20) eq_(val[0], 21) - eq_(f1.__dict__['bar'], 20) + eq_(f1.__dict__["bar"], 20) def test_memoized_instancemethod(self): val = [20] @@ -282,7 +283,7 @@ class MemoizedAttrTest(fixtures.TestBase): assert inspect.ismethod(Foo().bar) ne_(Foo.bar, None) f1 = Foo() - assert 'bar' not in f1.__dict__ + assert "bar" not in f1.__dict__ eq_(f1.bar(), 20) eq_(f1.bar(), 20) eq_(val[0], 21) @@ -291,7 +292,7 @@ class MemoizedAttrTest(fixtures.TestBase): canary = mock.Mock() class Foob(util.MemoizedSlots): - __slots__ = ('foo_bar', 'gogo') + __slots__ = ("foo_bar", "gogo") def _memoized_method_gogo(self): canary.method() @@ -350,7 +351,6 @@ class WrapCallableTest(fixtures.TestBase): def test_wrapping_update_wrapper_cls_noclsdocstring(self): class MyFancyDefault(object): - def __call__(self): """run the fancy default""" return 10 @@ -374,7 +374,6 @@ class WrapCallableTest(fixtures.TestBase): def test_wrapping_update_wrapper_cls_noclsdocstring_nomethdocstring(self): class MyFancyDefault(object): - def __call__(self): return 10 @@ -388,10 +387,12 @@ class WrapCallableTest(fixtures.TestBase): return x import functools + my_functools_default = functools.partial(my_default, 5) c = util.wrap_callable( - lambda: my_functools_default(), my_functools_default) + lambda: my_functools_default(), my_functools_default + ) eq_(c.__name__, "partial") eq_(c.__doc__, my_functools_default.__call__.__doc__) eq_(c(), 5) @@ -399,59 +400,42 @@ class WrapCallableTest(fixtures.TestBase): class ToListTest(fixtures.TestBase): def test_from_string(self): - eq_( - util.to_list("xyz"), - ["xyz"] - ) + eq_(util.to_list("xyz"), ["xyz"]) def test_from_set(self): spec = util.to_list(set([1, 2, 3])) assert isinstance(spec, list) - eq_( - sorted(spec), - [1, 2, 3] - ) + eq_(sorted(spec), [1, 2, 3]) def test_from_dict(self): spec = util.to_list({1: "a", 2: "b", 3: "c"}) assert isinstance(spec, list) - eq_( - sorted(spec), - [1, 2, 3] - ) + eq_(sorted(spec), [1, 2, 3]) def test_from_tuple(self): - eq_( - util.to_list((1, 2, 3)), - [1, 2, 3] - ) + eq_(util.to_list((1, 2, 3)), [1, 2, 3]) def test_from_bytes(self): - eq_( - util.to_list(compat.b('abc')), - [compat.b('abc')] - ) + eq_(util.to_list(compat.b("abc")), [compat.b("abc")]) eq_( - util.to_list([ - compat.b('abc'), compat.b('def')]), - [compat.b('abc'), compat.b('def')] + util.to_list([compat.b("abc"), compat.b("def")]), + [compat.b("abc"), compat.b("def")], ) class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): - def test_in(self): cc = sql.ColumnCollection() - cc.add(sql.column('col1')) - cc.add(sql.column('col2')) - cc.add(sql.column('col3')) - assert 'col1' in cc - assert 'col2' in cc + cc.add(sql.column("col1")) + cc.add(sql.column("col2")) + cc.add(sql.column("col3")) + assert "col1" in cc + assert "col2" in cc try: - cc['col1'] in cc + cc["col1"] in cc assert False except exc.ArgumentError as e: eq_(str(e), "__contains__ requires a string argument") @@ -460,9 +444,9 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): cc1 = sql.ColumnCollection() cc2 = 
sql.ColumnCollection() cc3 = sql.ColumnCollection() - c1 = sql.column('col1') - c2 = c1.label('col2') - c3 = sql.column('col3') + c1 = sql.column("col1") + c2 = c1.label("col2") + c3 = sql.column("col3") cc1.add(c1) cc2.add(c2) cc3.add(c3) @@ -473,10 +457,12 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_dupes_add(self): cc = sql.ColumnCollection() - c1, c2a, c3, c2b = (column('c1'), - column('c2'), - column('c3'), - column('c2')) + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) cc.add(c1) cc.add(c2a) @@ -500,9 +486,7 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_identical_dupe_add(self): cc = sql.ColumnCollection() - c1, c2, c3 = (column('c1'), - column('c2'), - column('c3')) + c1, c2, c3 = (column("c1"), column("c2"), column("c3")) cc.add(c1) cc.add(c2) @@ -512,8 +496,7 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): eq_(cc._all_columns, [c1, c2, c3]) self.assert_compile( - cc == [c1, c2, c3], - "c1 = c1 AND c2 = c2 AND c3 = c3" + cc == [c1, c2, c3], "c1 = c1 AND c2 = c2 AND c3 = c3" ) # for iter, c2a is replaced by c2b, ordering @@ -528,17 +511,18 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): eq_(list(ci), [c1, c2, c3]) self.assert_compile( - ci == [c1, c2, c3], - "c1 = c1 AND c2 = c2 AND c3 = c3" + ci == [c1, c2, c3], "c1 = c1 AND c2 = c2 AND c3 = c3" ) def test_replace(self): cc = sql.ColumnCollection() - c1, c2a, c3, c2b = (column('c1'), - column('c2'), - column('c3'), - column('c2')) + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) cc.add(c1) cc.add(c2a) @@ -559,11 +543,13 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_replace_key_matches(self): cc = sql.ColumnCollection() - c1, c2a, c3, c2b = (column('c1'), - column('c2'), - column('c3'), - column('X')) - c2b.key = 'c2' + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("X"), + ) + c2b.key = "c2" cc.add(c1) cc.add(c2a) @@ -584,11 +570,13 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_replace_name_matches(self): cc = sql.ColumnCollection() - c1, c2a, c3, c2b = (column('c1'), - column('c2'), - column('c3'), - column('c2')) - c2b.key = 'X' + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) + c2b.key = "X" cc.add(c1) cc.add(c2a) @@ -609,8 +597,8 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_replace_no_match(self): cc = sql.ColumnCollection() - c1, c2, c3, c4 = column('c1'), column('c2'), column('c3'), column('c4') - c4.key = 'X' + c1, c2, c3, c4 = column("c1"), column("c2"), column("c3"), column("c4") + c4.key = "X" cc.add(c1) cc.add(c2) @@ -631,10 +619,12 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_dupes_extend(self): cc = sql.ColumnCollection() - c1, c2a, c3, c2b = (column('c1'), - column('c2'), - column('c3'), - column('c2')) + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) cc.add(c1) cc.add(c2a) @@ -658,10 +648,12 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_dupes_update(self): cc = sql.ColumnCollection() - c1, c2a, c3, c2b = (column('c1'), - column('c2'), - column('c3'), - column('c2')) + c1, c2a, c3, c2b = ( + column("c1"), + column("c2"), + column("c3"), + column("c2"), + ) cc.add(c1) 
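
The replace tests here pin down the difference between ``add()`` with a duplicate key and ``replace()``: the latter swaps the like-keyed column in place rather than appending a second entry. In short, using the same ``column()``/``sql.ColumnCollection`` imports this file already has (``_all_columns`` is the internal ordered view the assertions rely on):

    from sqlalchemy import sql
    from sqlalchemy.sql import column

    cc = sql.ColumnCollection()
    c1, c2a, c3, c2b = column("c1"), column("c2"), column("c3"), column("c2")

    cc.add(c1)
    cc.add(c2a)
    cc.add(c3)

    # c2b takes c2a's position under the key "c2"
    cc.replace(c2b)
    assert cc.contains_column(c2b)
    assert not cc.contains_column(c2a)
    assert cc._all_columns == [c1, c2b, c3]
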
cc.add(c2a) @@ -681,11 +673,13 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_extend_existing(self): cc = sql.ColumnCollection() - c1, c2, c3, c4, c5 = (column('c1'), - column('c2'), - column('c3'), - column('c4'), - column('c5')) + c1, c2, c3, c4, c5 = ( + column("c1"), + column("c2"), + column("c3"), + column("c4"), + column("c5"), + ) cc.extend([c1, c2]) eq_(cc._all_columns, [c1, c2]) @@ -699,24 +693,25 @@ class ColumnCollectionTest(testing.AssertsCompiledSQL, fixtures.TestBase): def test_update_existing(self): cc = sql.ColumnCollection() - c1, c2, c3, c4, c5 = (column('c1'), - column('c2'), - column('c3'), - column('c4'), - column('c5')) + c1, c2, c3, c4, c5 = ( + column("c1"), + column("c2"), + column("c3"), + column("c4"), + column("c5"), + ) - cc.update([('c1', c1), ('c2', c2)]) + cc.update([("c1", c1), ("c2", c2)]) eq_(cc._all_columns, [c1, c2]) - cc.update([('c3', c3)]) + cc.update([("c3", c3)]) eq_(cc._all_columns, [c1, c2, c3]) - cc.update([('c4', c4), ('c2', c2), ('c5', c5)]) + cc.update([("c4", c4), ("c2", c2), ("c5", c5)]) eq_(cc._all_columns, [c1, c2, c3, c4, c5]) class LRUTest(fixtures.TestBase): - def test_lru(self): class item(object): def __init__(self, id): @@ -725,7 +720,7 @@ class LRUTest(fixtures.TestBase): def __str__(self): return "item id %d" % self.id - lru = util.LRUCache(10, threshold=.2) + lru = util.LRUCache(10, threshold=0.2) for id in range(1, 20): lru[id] = item(id) @@ -764,10 +759,17 @@ class ImmutableSubclass(str): class FlattenIteratorTest(fixtures.TestBase): - def test_flatten(self): - assert list(util.flatten_iterator([[1, 2, 3], [4, 5, 6], 7, - 8])) == [1, 2, 3, 4, 5, 6, 7, 8] + assert list(util.flatten_iterator([[1, 2, 3], [4, 5, 6], 7, 8])) == [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ] def test_str_with_iter(self): """ensure that a str object with an __iter__ method (like in @@ -777,15 +779,14 @@ class FlattenIteratorTest(fixtures.TestBase): class IterString(str): def __iter__(self): - return iter(self + '') + return iter(self + "") - iter_list = [IterString('asdf'), [IterString('x'), IterString('y')]] + iter_list = [IterString("asdf"), [IterString("x"), IterString("y")]] - assert list(util.flatten_iterator(iter_list)) == ['asdf', 'x', 'y'] + assert list(util.flatten_iterator(iter_list)) == ["asdf", "x", "y"] class HashOverride(object): - def __init__(self, value=None): self.value = value @@ -794,9 +795,9 @@ class HashOverride(object): class EqOverride(object): - def __init__(self, value=None): self.value = value + __hash__ = object.__hash__ def __eq__(self, other): @@ -813,7 +814,6 @@ class EqOverride(object): class HashEqOverride(object): - def __init__(self, value=None): self.value = value @@ -834,7 +834,6 @@ class HashEqOverride(object): class IdentitySetTest(fixtures.TestBase): - def assert_eq(self, identityset, expected_iterable): expected = sorted([id(o) for o in expected_iterable]) found = sorted([id(o) for o in identityset]) @@ -876,10 +875,7 @@ class IdentitySetTest(fixtures.TestBase): o1, o2, o3 = object(), object(), object() ids1 = IdentitySet([o1]) ids2 = IdentitySet([o1, o2, o3]) - eq_( - ids2 - ids1, - IdentitySet([o2, o3]) - ) + eq_(ids2 - ids1, IdentitySet([o2, o3])) ids2 -= ids1 eq_(ids2, IdentitySet([o2, o3])) @@ -925,6 +921,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 <= not_an_identity_set + self._assert_unorderable_types(should_raise) def test_dunder_lt(self): @@ -946,6 +943,7 @@ class 
IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 < not_an_identity_set + self._assert_unorderable_types(should_raise) def test_dunder_ge(self): @@ -967,6 +965,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 >= not_an_identity_set + self._assert_unorderable_types(should_raise) def test_dunder_gt(self): @@ -988,6 +987,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 > not_an_identity_set + self._assert_unorderable_types(should_raise) def test_issubset(self): @@ -1076,6 +1076,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 | not_an_identity_set + assert_raises(TypeError, should_raise) def test_update(self): @@ -1102,6 +1103,7 @@ class IdentitySetTest(fixtures.TestBase): unique = util.IdentitySet([1]) not_an_identity_set = object() unique |= not_an_identity_set + assert_raises(TypeError, should_raise) def test_difference(self): @@ -1158,6 +1160,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() unique1 - not_an_identity_set + assert_raises(TypeError, should_raise) def test_difference_update(self): @@ -1210,6 +1213,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 & not_an_identity_set + assert_raises(TypeError, should_raise) def test_intersection_update(self): @@ -1244,7 +1248,8 @@ class IdentitySetTest(fixtures.TestBase): # not an IdentitySet not_an_identity_set = object() assert_raises( - TypeError, unique1.symmetric_difference, not_an_identity_set) + TypeError, unique1.symmetric_difference, not_an_identity_set + ) def test_dunder_xor(self): _, _, twin1, twin2, _, _ = self._create_sets() @@ -1273,6 +1278,7 @@ class IdentitySetTest(fixtures.TestBase): def should_raise(): not_an_identity_set = object() return unique1 ^ not_an_identity_set + assert_raises(TypeError, should_raise) def test_symmetric_difference_update(self): @@ -1291,13 +1297,14 @@ class IdentitySetTest(fixtures.TestBase): def _assert_unorderable_types(self, callable_): if util.py36: assert_raises_message( - TypeError, 'not supported between instances of', callable_) + TypeError, "not supported between instances of", callable_ + ) elif util.py3k: - assert_raises_message( - TypeError, 'unorderable types', callable_) + assert_raises_message(TypeError, "unorderable types", callable_) else: assert_raises_message( - TypeError, 'cannot compare sets using cmp()', callable_) + TypeError, "cannot compare sets using cmp()", callable_ + ) def test_basic_sanity(self): IdentitySet = util.IdentitySet @@ -1314,7 +1321,7 @@ class IdentitySetTest(fixtures.TestBase): # explicit __eq__ and __ne__ tests assert ids != None # noqa - assert not(ids == None) # noqa + assert not (ids == None) # noqa ne_(ids, IdentitySet([o1, o2, o3])) ids.clear() @@ -1349,9 +1356,9 @@ class IdentitySetTest(fixtures.TestBase): ids.symmetric_difference_update(isuper) ids ^= isuper - ids.update('foobar') + ids.update("foobar") try: - ids |= 'foobar' + ids |= "foobar" assert False except TypeError: assert True @@ -1368,7 +1375,6 @@ class IdentitySetTest(fixtures.TestBase): class OrderedIdentitySetTest(fixtures.TestBase): - def assert_eq(self, identityset, expected_iterable): expected = [id(o) for o in expected_iterable] found = [id(o) for o in identityset] @@ -1384,8 +1390,15 @@ class 
OrderedIdentitySetTest(fixtures.TestBase): elem = object eq_ = self.assert_eq - a, b, c, d, e, f, g = \ - elem(), elem(), elem(), elem(), elem(), elem(), elem() + a, b, c, d, e, f, g = ( + elem(), + elem(), + elem(), + elem(), + elem(), + elem(), + elem(), + ) s1 = util.OrderedIdentitySet([a, b, c]) s2 = util.OrderedIdentitySet([d, e, f]) @@ -1396,16 +1409,14 @@ class OrderedIdentitySetTest(fixtures.TestBase): class DictlikeIteritemsTest(fixtures.TestBase): - baseline = set([('a', 1), ('b', 2), ('c', 3)]) + baseline = set([("a", 1), ("b", 2), ("c", 3)]) def _ok(self, instance): iterator = util.dictlike_iteritems(instance) eq_(set(iterator), self.baseline) def _notok(self, instance): - assert_raises(TypeError, - util.dictlike_iteritems, - instance) + assert_raises(TypeError, util.dictlike_iteritems, instance) def test_dict(self): d = dict(a=1, b=2, c=3) @@ -1414,12 +1425,15 @@ class DictlikeIteritemsTest(fixtures.TestBase): def test_subdict(self): class subdict(dict): pass + d = subdict(a=1, b=2, c=3) self._ok(d) if util.py2k: + def test_UserDict(self): import UserDict + d = UserDict.UserDict(a=1, b=2, c=3) self._ok(d) @@ -1427,52 +1441,59 @@ class DictlikeIteritemsTest(fixtures.TestBase): self._notok(object()) if util.py2k: + def test_duck_1(self): class duck1(object): def iteritems(duck): return iter(self.baseline) + self._ok(duck1()) def test_duck_2(self): class duck2(object): def items(duck): return list(self.baseline) + self._ok(duck2()) if util.py2k: + def test_duck_3(self): class duck3(object): def iterkeys(duck): - return iter(['a', 'b', 'c']) + return iter(["a", "b", "c"]) def __getitem__(duck, key): return dict(a=1, b=2, c=3).get(key) + self._ok(duck3()) def test_duck_4(self): class duck4(object): def iterkeys(duck): - return iter(['a', 'b', 'c']) + return iter(["a", "b", "c"]) + self._notok(duck4()) def test_duck_5(self): class duck5(object): def keys(duck): - return ['a', 'b', 'c'] + return ["a", "b", "c"] def get(duck, key): return dict(a=1, b=2, c=3).get(key) + self._ok(duck5()) def test_duck_6(self): class duck6(object): def keys(duck): - return ['a', 'b', 'c'] + return ["a", "b", "c"] + self._notok(duck6()) class DuckTypeCollectionTest(fixtures.TestBase): - def test_sets(self): class SetLike(object): def add(self): @@ -1481,21 +1502,18 @@ class DuckTypeCollectionTest(fixtures.TestBase): class ForcedSet(list): __emulates__ = set - for type_ in (set, - SetLike, - ForcedSet): + for type_ in (set, SetLike, ForcedSet): eq_(util.duck_type_collection(type_), set) instance = type_() eq_(util.duck_type_collection(instance), set) - for type_ in (frozenset, ): + for type_ in (frozenset,): is_(util.duck_type_collection(type_), None) instance = type_() is_(util.duck_type_collection(instance), None) class PublicFactoryTest(fixtures.TestBase): - def _fixture(self): class Thingy(object): def __init__(self, value): @@ -1511,8 +1529,7 @@ class PublicFactoryTest(fixtures.TestBase): def test_classmethod(self): Thingy = self._fixture() - foob = langhelpers.public_factory( - Thingy.foobar, ".sql.elements.foob") + foob = langhelpers.public_factory(Thingy.foobar, ".sql.elements.foob") eq_(foob(3, 4).value, 7) eq_(foob(x=3, y=4).value, 7) eq_(foob.__doc__, "do the foobar") @@ -1521,20 +1538,18 @@ class PublicFactoryTest(fixtures.TestBase): def test_constructor(self): Thingy = self._fixture() - foob = langhelpers.public_factory( - Thingy, ".sql.elements.foob") + foob = langhelpers.public_factory(Thingy, ".sql.elements.foob") eq_(foob(7).value, 7) eq_(foob(value=7).value, 7) eq_(foob.__doc__, "make a 
thingy") eq_(foob.__module__, "sqlalchemy.sql.elements") assert Thingy.__init__.__doc__.startswith( - "Construct a new :class:`.Thingy` object.") + "Construct a new :class:`.Thingy` object." + ) class ArgInspectionTest(fixtures.TestBase): - def test_get_cls_kwargs(self): - class A(object): def __init__(self, a): pass @@ -1600,26 +1615,25 @@ class ArgInspectionTest(fixtures.TestBase): def test(cls, *expected): eq_(set(util.get_cls_kwargs(cls)), set(expected)) - test(A, 'a') - test(A1, 'a1') - test(A11, 'a11', 'a1') - test(B, 'b') - test(B1, 'b1', 'b') - test(AB, 'ab') - test(BA, 'ba', 'b', 'a') - test(BA1, 'ba', 'b', 'a') - test(CAB, 'a') - test(CBA, 'b', 'a') - test(CAB1, 'a') - test(CB1A, 'b1', 'b', 'a') - test(CB2A, 'b2') + test(A, "a") + test(A1, "a1") + test(A11, "a11", "a1") + test(B, "b") + test(B1, "b1", "b") + test(AB, "ab") + test(BA, "ba", "b", "a") + test(BA1, "ba", "b", "a") + test(CAB, "a") + test(CBA, "b", "a") + test(CAB1, "a") + test(CB1A, "b1", "b", "a") + test(CB2A, "b2") test(CB1A1, "a1", "b1", "b") test(D) test(BA2, "a", "b") test(A11B1, "a1", "a11", "b", "b1") def test_get_func_kwargs(self): - def f1(): pass @@ -1636,74 +1650,72 @@ class ArgInspectionTest(fixtures.TestBase): eq_(set(util.get_func_kwargs(fn)), set(expected)) test(f1) - test(f2, 'foo') + test(f2, "foo") test(f3) test(f4) def test_callable_argspec_fn(self): def foo(x, y, **kw): pass - eq_( - get_callable_argspec(foo), - (['x', 'y'], None, 'kw', None) - ) + + eq_(get_callable_argspec(foo), (["x", "y"], None, "kw", None)) def test_callable_argspec_fn_no_self(self): def foo(x, y, **kw): pass + eq_( get_callable_argspec(foo, no_self=True), - (['x', 'y'], None, 'kw', None) + (["x", "y"], None, "kw", None), ) def test_callable_argspec_fn_no_self_but_self(self): def foo(self, x, y, **kw): pass + eq_( get_callable_argspec(foo, no_self=True), - (['self', 'x', 'y'], None, 'kw', None) + (["self", "x", "y"], None, "kw", None), ) - @fails_if(lambda: util.pypy, "pypy returns plain *arg, **kw") + @fails_if(lambda: util.pypy, "pypy returns plain *arg, **kw") def test_callable_argspec_py_builtin(self): import datetime - assert_raises( - TypeError, - get_callable_argspec, datetime.datetime.now - ) - @fails_if(lambda: util.pypy, "pypy returns plain *arg, **kw") + assert_raises(TypeError, get_callable_argspec, datetime.datetime.now) + + @fails_if(lambda: util.pypy, "pypy returns plain *arg, **kw") def test_callable_argspec_obj_init(self): - assert_raises( - TypeError, - get_callable_argspec, object - ) + assert_raises(TypeError, get_callable_argspec, object) def test_callable_argspec_method(self): class Foo(object): def foo(self, x, y, **kw): pass + eq_( get_callable_argspec(Foo.foo), - (['self', 'x', 'y'], None, 'kw', None) + (["self", "x", "y"], None, "kw", None), ) def test_callable_argspec_instance_method_no_self(self): class Foo(object): def foo(self, x, y, **kw): pass + eq_( get_callable_argspec(Foo().foo, no_self=True), - (['x', 'y'], None, 'kw', None) + (["x", "y"], None, "kw", None), ) def test_callable_argspec_unbound_method_no_self(self): class Foo(object): def foo(self, x, y, **kw): pass + eq_( get_callable_argspec(Foo.foo, no_self=True), - (['self', 'x', 'y'], None, 'kw', None) + (["self", "x", "y"], None, "kw", None), ) def test_callable_argspec_init(self): @@ -1711,10 +1723,7 @@ class ArgInspectionTest(fixtures.TestBase): def __init__(self, x, y): pass - eq_( - get_callable_argspec(Foo), - (['self', 'x', 'y'], None, None, None) - ) + eq_(get_callable_argspec(Foo), (["self", "x", "y"], None, None, None)) def 
test_callable_argspec_init_no_self(self): class Foo(object): @@ -1723,58 +1732,56 @@ class ArgInspectionTest(fixtures.TestBase): eq_( get_callable_argspec(Foo, no_self=True), - (['x', 'y'], None, None, None) + (["x", "y"], None, None, None), ) def test_callable_argspec_call(self): class Foo(object): def __call__(self, x, y): pass + eq_( - get_callable_argspec(Foo()), - (['self', 'x', 'y'], None, None, None) + get_callable_argspec(Foo()), (["self", "x", "y"], None, None, None) ) def test_callable_argspec_call_no_self(self): class Foo(object): def __call__(self, x, y): pass + eq_( get_callable_argspec(Foo(), no_self=True), - (['x', 'y'], None, None, None) + (["x", "y"], None, None, None), ) - @fails_if(lambda: util.pypy, "pypy returns plain *arg, **kw") + @fails_if(lambda: util.pypy, "pypy returns plain *arg, **kw") def test_callable_argspec_partial(self): from functools import partial def foo(x, y, z, **kw): pass + bar = partial(foo, 5) - assert_raises( - TypeError, - get_callable_argspec, bar - ) + assert_raises(TypeError, get_callable_argspec, bar) class SymbolTest(fixtures.TestBase): - def test_basic(self): - sym1 = util.symbol('foo') - assert sym1.name == 'foo' - sym2 = util.symbol('foo') + sym1 = util.symbol("foo") + assert sym1.name == "foo" + sym2 = util.symbol("foo") assert sym1 is sym2 assert sym1 == sym2 - sym3 = util.symbol('bar') + sym3 = util.symbol("bar") assert sym1 is not sym3 assert sym1 != sym3 def test_pickle(self): - sym1 = util.symbol('foo') - sym2 = util.symbol('foo') + sym1 = util.symbol("foo") + sym2 = util.symbol("foo") assert sym1 is sym2 @@ -1790,18 +1797,18 @@ class SymbolTest(fixtures.TestBase): assert rt is sym2 def test_bitflags(self): - sym1 = util.symbol('sym1', canonical=1) - sym2 = util.symbol('sym2', canonical=2) + sym1 = util.symbol("sym1", canonical=1) + sym2 = util.symbol("sym2", canonical=2) assert sym1 & sym1 assert not sym1 & sym2 assert not sym1 & sym1 & sym2 def test_composites(self): - sym1 = util.symbol('sym1', canonical=1) - sym2 = util.symbol('sym2', canonical=2) - sym3 = util.symbol('sym3', canonical=4) - sym4 = util.symbol('sym4', canonical=8) + sym1 = util.symbol("sym1", canonical=1) + sym2 = util.symbol("sym2", canonical=2) + sym3 = util.symbol("sym3", canonical=4) + sym4 = util.symbol("sym4", canonical=8) assert sym1 & (sym2 | sym1 | sym4) assert not sym1 & (sym2 | sym3) @@ -1811,7 +1818,6 @@ class SymbolTest(fixtures.TestBase): class TestFormatArgspec(fixtures.TestBase): - def test_specs(self): def test(fn, wanted, grouped=None): if grouped is None: @@ -1820,91 +1826,184 @@ class TestFormatArgspec(fixtures.TestBase): parsed = util.format_argspec_plus(fn, grouped=grouped) eq_(parsed, wanted) - test(lambda: None, - {'args': '()', 'self_arg': None, - 'apply_kw': '()', 'apply_pos': '()'}) + test( + lambda: None, + { + "args": "()", + "self_arg": None, + "apply_kw": "()", + "apply_pos": "()", + }, + ) - test(lambda: None, - {'args': '', 'self_arg': None, - 'apply_kw': '', 'apply_pos': ''}, - grouped=False) + test( + lambda: None, + {"args": "", "self_arg": None, "apply_kw": "", "apply_pos": ""}, + grouped=False, + ) - test(lambda self: None, - {'args': '(self)', 'self_arg': 'self', - 'apply_kw': '(self)', 'apply_pos': '(self)'}) + test( + lambda self: None, + { + "args": "(self)", + "self_arg": "self", + "apply_kw": "(self)", + "apply_pos": "(self)", + }, + ) - test(lambda self: None, - {'args': 'self', 'self_arg': 'self', - 'apply_kw': 'self', 'apply_pos': 'self'}, - grouped=False) + test( + lambda self: None, + { + "args": "self", + 
"self_arg": "self", + "apply_kw": "self", + "apply_pos": "self", + }, + grouped=False, + ) - test(lambda *a: None, - {'args': '(*a)', 'self_arg': 'a[0]', - 'apply_kw': '(*a)', 'apply_pos': '(*a)'}) + test( + lambda *a: None, + { + "args": "(*a)", + "self_arg": "a[0]", + "apply_kw": "(*a)", + "apply_pos": "(*a)", + }, + ) - test(lambda **kw: None, - {'args': '(**kw)', 'self_arg': None, - 'apply_kw': '(**kw)', 'apply_pos': '(**kw)'}) + test( + lambda **kw: None, + { + "args": "(**kw)", + "self_arg": None, + "apply_kw": "(**kw)", + "apply_pos": "(**kw)", + }, + ) - test(lambda *a, **kw: None, - {'args': '(*a, **kw)', 'self_arg': 'a[0]', - 'apply_kw': '(*a, **kw)', 'apply_pos': '(*a, **kw)'}) + test( + lambda *a, **kw: None, + { + "args": "(*a, **kw)", + "self_arg": "a[0]", + "apply_kw": "(*a, **kw)", + "apply_pos": "(*a, **kw)", + }, + ) - test(lambda a, *b: None, - {'args': '(a, *b)', 'self_arg': 'a', - 'apply_kw': '(a, *b)', 'apply_pos': '(a, *b)'}) + test( + lambda a, *b: None, + { + "args": "(a, *b)", + "self_arg": "a", + "apply_kw": "(a, *b)", + "apply_pos": "(a, *b)", + }, + ) - test(lambda a, **b: None, - {'args': '(a, **b)', 'self_arg': 'a', - 'apply_kw': '(a, **b)', 'apply_pos': '(a, **b)'}) + test( + lambda a, **b: None, + { + "args": "(a, **b)", + "self_arg": "a", + "apply_kw": "(a, **b)", + "apply_pos": "(a, **b)", + }, + ) - test(lambda a, *b, **c: None, - {'args': '(a, *b, **c)', 'self_arg': 'a', - 'apply_kw': '(a, *b, **c)', 'apply_pos': '(a, *b, **c)'}) + test( + lambda a, *b, **c: None, + { + "args": "(a, *b, **c)", + "self_arg": "a", + "apply_kw": "(a, *b, **c)", + "apply_pos": "(a, *b, **c)", + }, + ) - test(lambda a, b=1, **c: None, - {'args': '(a, b=1, **c)', 'self_arg': 'a', - 'apply_kw': '(a, b=b, **c)', 'apply_pos': '(a, b, **c)'}) + test( + lambda a, b=1, **c: None, + { + "args": "(a, b=1, **c)", + "self_arg": "a", + "apply_kw": "(a, b=b, **c)", + "apply_pos": "(a, b, **c)", + }, + ) - test(lambda a=1, b=2: None, - {'args': '(a=1, b=2)', 'self_arg': 'a', - 'apply_kw': '(a=a, b=b)', 'apply_pos': '(a, b)'}) + test( + lambda a=1, b=2: None, + { + "args": "(a=1, b=2)", + "self_arg": "a", + "apply_kw": "(a=a, b=b)", + "apply_pos": "(a, b)", + }, + ) - test(lambda a=1, b=2: None, - {'args': 'a=1, b=2', 'self_arg': 'a', - 'apply_kw': 'a=a, b=b', 'apply_pos': 'a, b'}, - grouped=False) + test( + lambda a=1, b=2: None, + { + "args": "a=1, b=2", + "self_arg": "a", + "apply_kw": "a=a, b=b", + "apply_pos": "a, b", + }, + grouped=False, + ) - @testing.fails_if(lambda: util.pypy, - "pypy doesn't report Obj.__init__ as object.__init__") + @testing.fails_if( + lambda: util.pypy, + "pypy doesn't report Obj.__init__ as object.__init__", + ) def test_init_grouped(self): object_spec = { - 'args': '(self)', 'self_arg': 'self', - 'apply_pos': '(self)', 'apply_kw': '(self)'} + "args": "(self)", + "self_arg": "self", + "apply_pos": "(self)", + "apply_kw": "(self)", + } wrapper_spec = { - 'args': '(self, *args, **kwargs)', 'self_arg': 'self', - 'apply_pos': '(self, *args, **kwargs)', - 'apply_kw': '(self, *args, **kwargs)'} + "args": "(self, *args, **kwargs)", + "self_arg": "self", + "apply_pos": "(self, *args, **kwargs)", + "apply_kw": "(self, *args, **kwargs)", + } custom_spec = { - 'args': '(slef, a=123)', 'self_arg': 'slef', # yes, slef - 'apply_pos': '(slef, a)', 'apply_kw': '(slef, a=a)'} + "args": "(slef, a=123)", + "self_arg": "slef", # yes, slef + "apply_pos": "(slef, a)", + "apply_kw": "(slef, a=a)", + } self._test_init(None, object_spec, wrapper_spec, custom_spec) 
self._test_init(True, object_spec, wrapper_spec, custom_spec) - @testing.fails_if(lambda: util.pypy, - "pypy doesn't report Obj.__init__ as object.__init__") + @testing.fails_if( + lambda: util.pypy, + "pypy doesn't report Obj.__init__ as object.__init__", + ) def test_init_bare(self): object_spec = { - 'args': 'self', 'self_arg': 'self', - 'apply_pos': 'self', 'apply_kw': 'self'} + "args": "self", + "self_arg": "self", + "apply_pos": "self", + "apply_kw": "self", + } wrapper_spec = { - 'args': 'self, *args, **kwargs', 'self_arg': 'self', - 'apply_pos': 'self, *args, **kwargs', - 'apply_kw': 'self, *args, **kwargs'} + "args": "self, *args, **kwargs", + "self_arg": "self", + "apply_pos": "self, *args, **kwargs", + "apply_kw": "self, *args, **kwargs", + } custom_spec = { - 'args': 'slef, a=123', 'self_arg': 'slef', # yes, slef - 'apply_pos': 'slef, a', 'apply_kw': 'slef, a=a'} + "args": "slef, a=123", + "self_arg": "slef", # yes, slef + "apply_pos": "slef, a", + "apply_kw": "slef, a=a", + } self._test_init(False, object_spec, wrapper_spec, custom_spec) @@ -1958,17 +2057,14 @@ class TestFormatArgspec(fixtures.TestBase): class GenericReprTest(fixtures.TestBase): - def test_all_positional(self): class Foo(object): def __init__(self, a, b, c): self.a = a self.b = b self.c = c - eq_( - util.generic_repr(Foo(1, 2, 3)), - "Foo(1, 2, 3)" - ) + + eq_(util.generic_repr(Foo(1, 2, 3)), "Foo(1, 2, 3)") def test_positional_plus_kw(self): class Foo(object): @@ -1977,10 +2073,8 @@ class GenericReprTest(fixtures.TestBase): self.b = b self.c = c self.d = d - eq_( - util.generic_repr(Foo(1, 2, 3, 6)), - "Foo(1, 2, c=3, d=6)" - ) + + eq_(util.generic_repr(Foo(1, 2, 3, 6)), "Foo(1, 2, c=3, d=6)") def test_kw_defaults(self): class Foo(object): @@ -1989,10 +2083,8 @@ class GenericReprTest(fixtures.TestBase): self.b = b self.c = c self.d = d - eq_( - util.generic_repr(Foo(1, 5, 3, 7)), - "Foo(b=5, d=7)" - ) + + eq_(util.generic_repr(Foo(1, 5, 3, 7)), "Foo(b=5, d=7)") def test_multi_kw(self): class Foo(object): @@ -2011,18 +2103,14 @@ class GenericReprTest(fixtures.TestBase): eq_( util.generic_repr( - Bar('e', 'f', g=7, a=6, b=5, d=9), - to_inspect=[Bar, Foo] + Bar("e", "f", g=7, a=6, b=5, d=9), to_inspect=[Bar, Foo] ), - "Bar('e', 'f', g=7, a=6, b=5, d=9)" + "Bar('e', 'f', g=7, a=6, b=5, d=9)", ) eq_( - util.generic_repr( - Bar('e', 'f', a=6, b=5), - to_inspect=[Bar, Foo] - ), - "Bar('e', 'f', a=6, b=5)" + util.generic_repr(Bar("e", "f", a=6, b=5), to_inspect=[Bar, Foo]), + "Bar('e', 'f', a=6, b=5)", ) def test_multi_kw_repeated(self): @@ -2037,11 +2125,8 @@ class GenericReprTest(fixtures.TestBase): super(Bar, self).__init__(b=b, **kw) eq_( - util.generic_repr( - Bar(a='a', b='b', c='c'), - to_inspect=[Bar, Foo] - ), - "Bar(b='b', c='c', a='a')" + util.generic_repr(Bar(a="a", b="b", c="c"), to_inspect=[Bar, Foo]), + "Bar(b='b', c='c', a='a')", ) def test_discard_vargs(self): @@ -2050,10 +2135,8 @@ class GenericReprTest(fixtures.TestBase): self.a = a self.b = b self.c, self.d = args[0:2] - eq_( - util.generic_repr(Foo(1, 2, 3, 4)), - "Foo(1, 2)" - ) + + eq_(util.generic_repr(Foo(1, 2, 3, 4)), "Foo(1, 2)") def test_discard_vargs_kwargs(self): class Foo(object): @@ -2061,10 +2144,8 @@ class GenericReprTest(fixtures.TestBase): self.a = a self.b = b self.c, self.d = args[0:2] - eq_( - util.generic_repr(Foo(1, 2, 3, 4, x=7, y=4)), - "Foo(1, 2)" - ) + + eq_(util.generic_repr(Foo(1, 2, 3, 4, x=7, y=4)), "Foo(1, 2)") def test_significant_vargs(self): class Foo(object): @@ -2072,33 +2153,25 @@ class 
GenericReprTest(fixtures.TestBase): self.a = a self.b = b self.args = args - eq_( - util.generic_repr(Foo(1, 2, 3, 4)), - "Foo(1, 2, 3, 4)" - ) + + eq_(util.generic_repr(Foo(1, 2, 3, 4)), "Foo(1, 2, 3, 4)") def test_no_args(self): class Foo(object): def __init__(self): pass - eq_( - util.generic_repr(Foo()), - "Foo()" - ) + + eq_(util.generic_repr(Foo()), "Foo()") def test_no_init(self): class Foo(object): pass - eq_( - util.generic_repr(Foo()), - "Foo()" - ) + eq_(util.generic_repr(Foo()), "Foo()") -class AsInterfaceTest(fixtures.TestBase): +class AsInterfaceTest(fixtures.TestBase): class Something(object): - def _ignoreme(self): pass @@ -2109,7 +2182,6 @@ class AsInterfaceTest(fixtures.TestBase): pass class Partial(object): - def bar(self): pass @@ -2118,21 +2190,27 @@ class AsInterfaceTest(fixtures.TestBase): def test_instance(self): obj = object() - assert_raises(TypeError, util.as_interface, obj, - cls=self.Something) + assert_raises(TypeError, util.as_interface, obj, cls=self.Something) - assert_raises(TypeError, util.as_interface, obj, - methods=('foo')) + assert_raises(TypeError, util.as_interface, obj, methods=("foo")) - assert_raises(TypeError, util.as_interface, obj, - cls=self.Something, required=('foo')) + assert_raises( + TypeError, + util.as_interface, + obj, + cls=self.Something, + required=("foo"), + ) obj = self.Something() eq_(obj, util.as_interface(obj, cls=self.Something)) - eq_(obj, util.as_interface(obj, methods=('foo',))) + eq_(obj, util.as_interface(obj, methods=("foo",))) eq_( - obj, util.as_interface(obj, cls=self.Something, - required=('outofband',))) + obj, + util.as_interface( + obj, cls=self.Something, required=("outofband",) + ), + ) partial = self.Partial() slotted = self.Object() @@ -2140,57 +2218,79 @@ class AsInterfaceTest(fixtures.TestBase): for obj in partial, slotted: eq_(obj, util.as_interface(obj, cls=self.Something)) - assert_raises(TypeError, util.as_interface, obj, - methods=('foo')) - eq_(obj, util.as_interface(obj, methods=('bar',))) - eq_(obj, util.as_interface(obj, cls=self.Something, - required=('bar',))) - assert_raises(TypeError, util.as_interface, obj, - cls=self.Something, required=('foo',)) - - assert_raises(TypeError, util.as_interface, obj, - cls=self.Something, required=self.Something) + assert_raises(TypeError, util.as_interface, obj, methods=("foo")) + eq_(obj, util.as_interface(obj, methods=("bar",))) + eq_( + obj, + util.as_interface(obj, cls=self.Something, required=("bar",)), + ) + assert_raises( + TypeError, + util.as_interface, + obj, + cls=self.Something, + required=("foo",), + ) + + assert_raises( + TypeError, + util.as_interface, + obj, + cls=self.Something, + required=self.Something, + ) def test_dict(self): obj = {} - assert_raises(TypeError, util.as_interface, obj, - cls=self.Something) - assert_raises(TypeError, util.as_interface, obj, methods='foo') - assert_raises(TypeError, util.as_interface, obj, - cls=self.Something, required='foo') + assert_raises(TypeError, util.as_interface, obj, cls=self.Something) + assert_raises(TypeError, util.as_interface, obj, methods="foo") + assert_raises( + TypeError, + util.as_interface, + obj, + cls=self.Something, + required="foo", + ) def assertAdapted(obj, *methods): assert isinstance(obj, type) - found = set([m for m in dir(obj) if not m.startswith('_')]) + found = set([m for m in dir(obj) if not m.startswith("_")]) for method in methods: assert method in found found.remove(method) assert not found - def fn(self): return 123 - obj = {'foo': fn, 'bar': fn} + def fn(self): + return 
123 + + obj = {"foo": fn, "bar": fn} res = util.as_interface(obj, cls=self.Something) - assertAdapted(res, 'foo', 'bar') - res = util.as_interface(obj, cls=self.Something, - required=self.Something) - assertAdapted(res, 'foo', 'bar') - res = util.as_interface(obj, cls=self.Something, required=('foo',)) - assertAdapted(res, 'foo', 'bar') - res = util.as_interface(obj, methods=('foo', 'bar')) - assertAdapted(res, 'foo', 'bar') - res = util.as_interface(obj, methods=('foo', 'bar', 'baz')) - assertAdapted(res, 'foo', 'bar') - res = util.as_interface(obj, methods=('foo', 'bar'), required=('foo',)) - assertAdapted(res, 'foo', 'bar') - assert_raises(TypeError, util.as_interface, obj, methods=('foo',)) - assert_raises(TypeError, util.as_interface, obj, - methods=('foo', 'bar', 'baz'), required=('baz', )) - obj = {'foo': 123} + assertAdapted(res, "foo", "bar") + res = util.as_interface( + obj, cls=self.Something, required=self.Something + ) + assertAdapted(res, "foo", "bar") + res = util.as_interface(obj, cls=self.Something, required=("foo",)) + assertAdapted(res, "foo", "bar") + res = util.as_interface(obj, methods=("foo", "bar")) + assertAdapted(res, "foo", "bar") + res = util.as_interface(obj, methods=("foo", "bar", "baz")) + assertAdapted(res, "foo", "bar") + res = util.as_interface(obj, methods=("foo", "bar"), required=("foo",)) + assertAdapted(res, "foo", "bar") + assert_raises(TypeError, util.as_interface, obj, methods=("foo",)) + assert_raises( + TypeError, + util.as_interface, + obj, + methods=("foo", "bar", "baz"), + required=("baz",), + ) + obj = {"foo": 123} assert_raises(TypeError, util.as_interface, obj, cls=self.Something) class TestClassHierarchy(fixtures.TestBase): - def test_object(self): eq_(set(util.class_hierarchy(object)), set((object,))) @@ -2211,6 +2311,7 @@ class TestClassHierarchy(fixtures.TestBase): eq_(set(util.class_hierarchy(B)), set((A, B, C, object))) if util.py2k: + def test_oldstyle_mixin(self): class A(object): pass @@ -2255,11 +2356,7 @@ class ReraiseTest(fixtures.TestBase): type_, value, tb = sys.exc_info() util.reraise(type_, err, tb, value) - assert_raises_message( - AssertionError, - "Same cause emitted", - go - ) + assert_raises_message(AssertionError, "Same cause emitted", go) def test_raise_from_cause(self): class MyException(Exception): @@ -2317,26 +2414,23 @@ class ReraiseTest(fixtures.TestBase): class TestClassProperty(fixtures.TestBase): - def test_simple(self): class A(object): - something = {'foo': 1} + something = {"foo": 1} class B(A): - @classproperty def something(cls): d = dict(super(B, cls).something) - d.update({'bazz': 2}) + d.update({"bazz": 2}) return d - eq_(B.something, {'foo': 1, 'bazz': 2}) + eq_(B.something, {"foo": 1, "bazz": 2}) class TestProperties(fixtures.TestBase): - def test_pickle(self): - data = {'hello': 'bla'} + data = {"hello": "bla"} props = util.Properties(data) for loader, dumper in picklers(): @@ -2347,12 +2441,12 @@ class TestProperties(fixtures.TestBase): eq_(props.keys(), p.keys()) def test_keys_in_dir(self): - data = {'hello': 'bla'} + data = {"hello": "bla"} props = util.Properties(data) - in_('hello', dir(props)) + in_("hello", dir(props)) def test_pickle_immuatbleprops(self): - data = {'hello': 'bla'} + data = {"hello": "bla"} props = util.Properties(data).as_immutable() for loader, dumper in picklers(): @@ -2363,7 +2457,7 @@ class TestProperties(fixtures.TestBase): eq_(props.keys(), p.keys()) def test_pickle_orderedprops(self): - data = {'hello': 'bla'} + data = {"hello": "bla"} props = util.OrderedProperties() 
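# util.Properties wraps a plain dict and exposes its keys as attributes
# as well as items, which is why the pickle round-trip tests here only
# need to compare keys.  A minimal sketch, assuming nothing beyond the
# constructor and as_immutable() exercised in this class:

from sqlalchemy import util

p = util.Properties({"hello": "bla"})
assert p.hello == "bla"          # attribute-style access
assert p["hello"] == "bla"       # item-style access
assert "hello" in p.keys()
frozen = p.as_immutable()        # read-only view over the same data
assert frozen.hello == "bla"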
props.update(data) @@ -2377,118 +2471,70 @@ class TestProperties(fixtures.TestBase): class QuotedTokenParserTest(fixtures.TestBase): def _test(self, string, expected): - eq_( - langhelpers.quoted_token_parser(string), - expected - ) + eq_(langhelpers.quoted_token_parser(string), expected) def test_single(self): - self._test( - "name", - ["name"] - ) + self._test("name", ["name"]) def test_dotted(self): - self._test( - "schema.name", ["schema", "name"] - ) + self._test("schema.name", ["schema", "name"]) def test_dotted_quoted_left(self): - self._test( - '"Schema".name', ["Schema", "name"] - ) + self._test('"Schema".name', ["Schema", "name"]) def test_dotted_quoted_left_w_quote_left_edge(self): - self._test( - '"""Schema".name', ['"Schema', "name"] - ) + self._test('"""Schema".name', ['"Schema', "name"]) def test_dotted_quoted_left_w_quote_right_edge(self): - self._test( - '"Schema""".name', ['Schema"', "name"] - ) + self._test('"Schema""".name', ['Schema"', "name"]) def test_dotted_quoted_left_w_quote_middle(self): - self._test( - '"Sch""ema".name', ['Sch"ema', "name"] - ) + self._test('"Sch""ema".name', ['Sch"ema', "name"]) def test_dotted_quoted_right(self): - self._test( - 'schema."SomeName"', ["schema", "SomeName"] - ) + self._test('schema."SomeName"', ["schema", "SomeName"]) def test_dotted_quoted_right_w_quote_left_edge(self): - self._test( - 'schema."""name"', ['schema', '"name'] - ) + self._test('schema."""name"', ["schema", '"name']) def test_dotted_quoted_right_w_quote_right_edge(self): - self._test( - 'schema."name"""', ['schema', 'name"'] - ) + self._test('schema."name"""', ["schema", 'name"']) def test_dotted_quoted_right_w_quote_middle(self): - self._test( - 'schema."na""me"', ['schema', 'na"me'] - ) + self._test('schema."na""me"', ["schema", 'na"me']) def test_quoted_single_w_quote_left_edge(self): - self._test( - '"""name"', ['"name'] - ) + self._test('"""name"', ['"name']) def test_quoted_single_w_quote_right_edge(self): - self._test( - '"name"""', ['name"'] - ) + self._test('"name"""', ['name"']) def test_quoted_single_w_quote_middle(self): - self._test( - '"na""me"', ['na"me'] - ) + self._test('"na""me"', ['na"me']) def test_dotted_quoted_left_w_dot_left_edge(self): - self._test( - '".Schema".name', ['.Schema', "name"] - ) + self._test('".Schema".name', [".Schema", "name"]) def test_dotted_quoted_left_w_dot_right_edge(self): - self._test( - '"Schema.".name', ['Schema.', "name"] - ) + self._test('"Schema.".name', ["Schema.", "name"]) def test_dotted_quoted_left_w_dot_middle(self): - self._test( - '"Sch.ema".name', ['Sch.ema', "name"] - ) + self._test('"Sch.ema".name', ["Sch.ema", "name"]) def test_dotted_quoted_right_w_dot_left_edge(self): - self._test( - 'schema.".name"', ['schema', '.name'] - ) + self._test('schema.".name"', ["schema", ".name"]) def test_dotted_quoted_right_w_dot_right_edge(self): - self._test( - 'schema."name."', ['schema', 'name.'] - ) + self._test('schema."name."', ["schema", "name."]) def test_dotted_quoted_right_w_dot_middle(self): - self._test( - 'schema."na.me"', ['schema', 'na.me'] - ) + self._test('schema."na.me"', ["schema", "na.me"]) def test_quoted_single_w_dot_left_edge(self): - self._test( - '".name"', ['.name'] - ) + self._test('".name"', [".name"]) def test_quoted_single_w_dot_right_edge(self): - self._test( - '"name."', ['name.'] - ) + self._test('"name."', ["name."]) def test_quoted_single_w_dot_middle(self): - self._test( - '"na.me"', ['na.me'] - ) + self._test('"na.me"', ["na.me"]) diff --git a/test/conftest.py b/test/conftest.py 
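# The parser cases above reduce to one rule: langhelpers.quoted_token_parser()
# splits an identifier on dots that sit outside double quotes, and a doubled
# quote inside a quoted token unescapes to a literal quote.  A short sketch
# using the same langhelpers import as this test module:

from sqlalchemy.util import langhelpers

assert langhelpers.quoted_token_parser('schema."na.me"') == ["schema", "na.me"]
assert langhelpers.quoted_token_parser('"Sch""ema".name') == ['Sch"ema', "name"]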
index 1dadbaaee4..9b76b4e553 100755 --- a/test/conftest.py +++ b/test/conftest.py @@ -17,19 +17,23 @@ if not sys.flags.no_user_site: # We check no_user_site to honor the use of this flag. sys.path.insert( 0, - os.path.join( - os.path.dirname(os.path.abspath(__file__)), '..', 'lib') + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "lib"), ) # use bootstrapping so that test plugins are loaded # without touching the main library before coverage starts bootstrap_file = os.path.join( - os.path.dirname(__file__), "..", "lib", "sqlalchemy", - "testing", "plugin", "bootstrap.py" + os.path.dirname(__file__), + "..", + "lib", + "sqlalchemy", + "testing", + "plugin", + "bootstrap.py", ) with open(bootstrap_file) as f: - code = compile(f.read(), "bootstrap.py", 'exec') + code = compile(f.read(), "bootstrap.py", "exec") to_bootstrap = "pytest" exec(code, globals(), locals()) from pytestplugin import * # noqa diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 70b9a6c901..03172aeb31 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -6,9 +6,25 @@ from sqlalchemy.dialects import mssql from sqlalchemy.dialects.mssql import mxodbc from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import sql -from sqlalchemy import Integer, String, Table, Column, select, MetaData,\ - update, delete, insert, extract, union, func, PrimaryKeyConstraint, \ - UniqueConstraint, Index, Sequence, literal +from sqlalchemy import ( + Integer, + String, + Table, + Column, + select, + MetaData, + update, + delete, + insert, + extract, + union, + func, + PrimaryKeyConstraint, + UniqueConstraint, + Index, + Sequence, + literal, +) from sqlalchemy import testing from sqlalchemy.dialects.mssql import base @@ -17,153 +33,163 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = mssql.dialect() def test_true_false(self): - self.assert_compile( - sql.false(), "0" - ) - self.assert_compile( - sql.true(), - "1" - ) + self.assert_compile(sql.false(), "0") + self.assert_compile(sql.true(), "1") def test_select(self): - t = table('sometable', column('somecolumn')) - self.assert_compile(t.select(), - 'SELECT sometable.somecolumn FROM sometable') + t = table("sometable", column("somecolumn")) + self.assert_compile( + t.select(), "SELECT sometable.somecolumn FROM sometable" + ) def test_select_with_nolock(self): - t = table('sometable', column('somecolumn')) + t = table("sometable", column("somecolumn")) self.assert_compile( - t.select().with_hint(t, 'WITH (NOLOCK)'), - 'SELECT sometable.somecolumn FROM sometable WITH (NOLOCK)') + t.select().with_hint(t, "WITH (NOLOCK)"), + "SELECT sometable.somecolumn FROM sometable WITH (NOLOCK)", + ) def test_select_with_nolock_schema(self): m = MetaData() - t = Table('sometable', m, Column('somecolumn', Integer), - schema='test_schema') + t = Table( + "sometable", m, Column("somecolumn", Integer), schema="test_schema" + ) self.assert_compile( - t.select().with_hint(t, 'WITH (NOLOCK)'), - 'SELECT test_schema.sometable.somecolumn ' - 'FROM test_schema.sometable WITH (NOLOCK)') + t.select().with_hint(t, "WITH (NOLOCK)"), + "SELECT test_schema.sometable.somecolumn " + "FROM test_schema.sometable WITH (NOLOCK)", + ) def test_select_w_order_by_collate(self): m = MetaData() - t = Table('sometable', m, Column('somecolumn', String)) + t = Table("sometable", m, Column("somecolumn", String)) self.assert_compile( - select([t]). 
- order_by( - t.c.somecolumn.collate("Latin1_General_CS_AS_KS_WS_CI").asc()), + select([t]).order_by( + t.c.somecolumn.collate("Latin1_General_CS_AS_KS_WS_CI").asc() + ), "SELECT sometable.somecolumn FROM sometable " "ORDER BY sometable.somecolumn COLLATE " - "Latin1_General_CS_AS_KS_WS_CI ASC" - + "Latin1_General_CS_AS_KS_WS_CI ASC", ) def test_join_with_hint(self): - t1 = table('t1', - column('a', Integer), - column('b', String), - column('c', String), - ) - t2 = table('t2', - column("a", Integer), - column("b", Integer), - column("c", Integer), - ) - join = t1.join(t2, t1.c.a == t2.c.a).\ - select().with_hint(t1, 'WITH (NOLOCK)') + t1 = table( + "t1", + column("a", Integer), + column("b", String), + column("c", String), + ) + t2 = table( + "t2", + column("a", Integer), + column("b", Integer), + column("c", Integer), + ) + join = ( + t1.join(t2, t1.c.a == t2.c.a) + .select() + .with_hint(t1, "WITH (NOLOCK)") + ) self.assert_compile( join, - 'SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c ' - 'FROM t1 WITH (NOLOCK) JOIN t2 ON t1.a = t2.a' + "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c " + "FROM t1 WITH (NOLOCK) JOIN t2 ON t1.a = t2.a", ) def test_insert(self): - t = table('sometable', column('somecolumn')) - self.assert_compile(t.insert(), - 'INSERT INTO sometable (somecolumn) VALUES ' - '(:somecolumn)') + t = table("sometable", column("somecolumn")) + self.assert_compile( + t.insert(), + "INSERT INTO sometable (somecolumn) VALUES " "(:somecolumn)", + ) def test_update(self): - t = table('sometable', column('somecolumn')) - self.assert_compile(t.update(t.c.somecolumn == 7), - 'UPDATE sometable SET somecolumn=:somecolum' - 'n WHERE sometable.somecolumn = ' - ':somecolumn_1', dict(somecolumn=10)) + t = table("sometable", column("somecolumn")) + self.assert_compile( + t.update(t.c.somecolumn == 7), + "UPDATE sometable SET somecolumn=:somecolum" + "n WHERE sometable.somecolumn = " + ":somecolumn_1", + dict(somecolumn=10), + ) def test_insert_hint(self): - t = table('sometable', column('somecolumn')) + t = table("sometable", column("somecolumn")) for targ in (None, t): for darg in ("*", "mssql"): self.assert_compile( - t.insert(). - values(somecolumn="x"). - with_hint("WITH (PAGLOCK)", - selectable=targ, - dialect_name=darg), + t.insert() + .values(somecolumn="x") + .with_hint( + "WITH (PAGLOCK)", selectable=targ, dialect_name=darg + ), "INSERT INTO sometable WITH (PAGLOCK) " - "(somecolumn) VALUES (:somecolumn)" + "(somecolumn) VALUES (:somecolumn)", ) def test_update_hint(self): - t = table('sometable', column('somecolumn')) + t = table("sometable", column("somecolumn")) for targ in (None, t): for darg in ("*", "mssql"): self.assert_compile( - t.update().where(t.c.somecolumn == "q"). - values(somecolumn="x"). - with_hint("WITH (PAGLOCK)", - selectable=targ, - dialect_name=darg), + t.update() + .where(t.c.somecolumn == "q") + .values(somecolumn="x") + .with_hint( + "WITH (PAGLOCK)", selectable=targ, dialect_name=darg + ), "UPDATE sometable WITH (PAGLOCK) " "SET somecolumn=:somecolumn " - "WHERE sometable.somecolumn = :somecolumn_1" + "WHERE sometable.somecolumn = :somecolumn_1", ) def test_update_exclude_hint(self): - t = table('sometable', column('somecolumn')) + t = table("sometable", column("somecolumn")) self.assert_compile( - t.update().where(t.c.somecolumn == "q"). - values(somecolumn="x"). 
- with_hint("XYZ", "mysql"), + t.update() + .where(t.c.somecolumn == "q") + .values(somecolumn="x") + .with_hint("XYZ", "mysql"), "UPDATE sometable SET somecolumn=:somecolumn " - "WHERE sometable.somecolumn = :somecolumn_1" + "WHERE sometable.somecolumn = :somecolumn_1", ) def test_delete_hint(self): - t = table('sometable', column('somecolumn')) + t = table("sometable", column("somecolumn")) for targ in (None, t): for darg in ("*", "mssql"): self.assert_compile( - t.delete().where(t.c.somecolumn == "q"). - with_hint("WITH (PAGLOCK)", - selectable=targ, - dialect_name=darg), + t.delete() + .where(t.c.somecolumn == "q") + .with_hint( + "WITH (PAGLOCK)", selectable=targ, dialect_name=darg + ), "DELETE FROM sometable WITH (PAGLOCK) " - "WHERE sometable.somecolumn = :somecolumn_1" + "WHERE sometable.somecolumn = :somecolumn_1", ) def test_delete_exclude_hint(self): - t = table('sometable', column('somecolumn')) + t = table("sometable", column("somecolumn")) self.assert_compile( - t.delete(). - where(t.c.somecolumn == "q"). - with_hint("XYZ", dialect_name="mysql"), + t.delete() + .where(t.c.somecolumn == "q") + .with_hint("XYZ", dialect_name="mysql"), "DELETE FROM sometable WHERE " - "sometable.somecolumn = :somecolumn_1" + "sometable.somecolumn = :somecolumn_1", ) def test_delete_extra_froms(self): - t1 = table('t1', column('c1')) - t2 = table('t2', column('c1')) + t1 = table("t1", column("c1")) + t2 = table("t2", column("c1")) q = sql.delete(t1).where(t1.c.c1 == t2.c.c1) self.assert_compile( q, "DELETE FROM t1 FROM t1, t2 WHERE t1.c1 = t2.c1" ) def test_delete_extra_froms_alias(self): - a1 = table('t1', column('c1')).alias('a1') - t2 = table('t2', column('c1')) + a1 = table("t1", column("c1")).alias("a1") + t2 = table("t2", column("c1")) q = sql.delete(a1).where(a1.c.c1 == t2.c.c1) self.assert_compile( q, "DELETE FROM a1 FROM t1 AS a1, t2 WHERE a1.c1 = t2.c1" @@ -173,63 +199,74 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_update_from(self): metadata = MetaData() table1 = Table( - 'mytable', metadata, - Column('myid', Integer), - Column('name', String(30)), - Column('description', String(50))) + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(30)), + Column("description", String(50)), + ) table2 = Table( - 'myothertable', metadata, - Column('otherid', Integer), - Column('othername', String(30))) + "myothertable", + metadata, + Column("otherid", Integer), + Column("othername", String(30)), + ) mt = table1.alias() - u = table1.update().values(name='foo')\ - .where(table2.c.otherid == table1.c.myid) + u = ( + table1.update() + .values(name="foo") + .where(table2.c.otherid == table1.c.myid) + ) # testing mssql.base.MSSQLCompiler.update_from_clause - self.assert_compile(u, - "UPDATE mytable SET name=:name " - "FROM mytable, myothertable WHERE " - "myothertable.otherid = mytable.myid") + self.assert_compile( + u, + "UPDATE mytable SET name=:name " + "FROM mytable, myothertable WHERE " + "myothertable.otherid = mytable.myid", + ) - self.assert_compile(u.where(table2.c.othername == mt.c.name), - "UPDATE mytable SET name=:name " - "FROM mytable, myothertable, mytable AS mytable_1 " - "WHERE myothertable.otherid = mytable.myid " - "AND myothertable.othername = mytable_1.name") + self.assert_compile( + u.where(table2.c.othername == mt.c.name), + "UPDATE mytable SET name=:name " + "FROM mytable, myothertable, mytable AS mytable_1 " + "WHERE myothertable.otherid = mytable.myid " + "AND myothertable.othername = mytable_1.name", + ) def 
test_update_from_hint(self): - t = table('sometable', column('somecolumn')) - t2 = table('othertable', column('somecolumn')) + t = table("sometable", column("somecolumn")) + t2 = table("othertable", column("somecolumn")) for darg in ("*", "mssql"): self.assert_compile( - t.update().where(t.c.somecolumn == t2.c.somecolumn). - values(somecolumn="x"). - with_hint("WITH (PAGLOCK)", - selectable=t2, - dialect_name=darg), + t.update() + .where(t.c.somecolumn == t2.c.somecolumn) + .values(somecolumn="x") + .with_hint("WITH (PAGLOCK)", selectable=t2, dialect_name=darg), "UPDATE sometable SET somecolumn=:somecolumn " "FROM sometable, othertable WITH (PAGLOCK) " - "WHERE sometable.somecolumn = othertable.somecolumn" + "WHERE sometable.somecolumn = othertable.somecolumn", ) def test_update_to_select_schema(self): meta = MetaData() table = Table( - "sometable", meta, + "sometable", + meta, Column("sym", String), Column("val", Integer), - schema="schema" + schema="schema", ) other = Table( - "#other", meta, - Column("sym", String), - Column("newval", Integer) + "#other", meta, Column("sym", String), Column("newval", Integer) ) stmt = table.update().values( - val=select([other.c.newval]). - where(table.c.sym == other.c.sym).as_scalar()) + val=select([other.c.newval]) + .where(table.c.sym == other.c.sym) + .as_scalar() + ) self.assert_compile( stmt, @@ -238,8 +275,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE [schema].sometable.sym = [#other].sym)", ) - stmt = table.update().values(val=other.c.newval).\ - where(table.c.sym == other.c.sym) + stmt = ( + table.update() + .values(val=other.c.newval) + .where(table.c.sym == other.c.sym) + ) self.assert_compile( stmt, "UPDATE [schema].sometable SET val=" @@ -264,10 +304,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): """test the 'strict' compiler binds.""" from sqlalchemy.dialects.mssql.base import MSSQLStrictCompiler + mxodbc_dialect = mxodbc.dialect() mxodbc_dialect.statement_compiler = MSSQLStrictCompiler - t = table('sometable', column('foo')) + t = table("sometable", column("foo")) for expr, compile in [ ( @@ -275,14 +316,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT 'x' AS anon_1, 'y' AS anon_2", ), ( - select([t]).where(t.c.foo.in_(['x', 'y', 'z'])), + select([t]).where(t.c.foo.in_(["x", "y", "z"])), "SELECT sometable.foo FROM sometable WHERE sometable.foo " "IN ('x', 'y', 'z')", ), - ( - t.c.foo.in_([None]), - "sometable.foo IN (NULL)" - ) + (t.c.foo.in_([None]), "sometable.foo IN (NULL)"), ]: self.assert_compile(expr, compile, dialect=mxodbc_dialect) @@ -292,115 +330,121 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): """ - t = table('sometable', column('somecolumn')) - self.assert_compile(t.select().where(t.c.somecolumn - == t.select()), - 'SELECT sometable.somecolumn FROM ' - 'sometable WHERE sometable.somecolumn = ' - '(SELECT sometable.somecolumn FROM ' - 'sometable)') - self.assert_compile(t.select().where(t.c.somecolumn - != t.select()), - 'SELECT sometable.somecolumn FROM ' - 'sometable WHERE sometable.somecolumn != ' - '(SELECT sometable.somecolumn FROM ' - 'sometable)') + t = table("sometable", column("somecolumn")) + self.assert_compile( + t.select().where(t.c.somecolumn == t.select()), + "SELECT sometable.somecolumn FROM " + "sometable WHERE sometable.somecolumn = " + "(SELECT sometable.somecolumn FROM " + "sometable)", + ) + self.assert_compile( + t.select().where(t.c.somecolumn != t.select()), + "SELECT sometable.somecolumn FROM " + "sometable WHERE 
sometable.somecolumn != " + "(SELECT sometable.somecolumn FROM " + "sometable)", + ) @testing.uses_deprecated def test_count(self): - t = table('sometable', column('somecolumn')) - self.assert_compile(t.count(), - 'SELECT count(sometable.somecolumn) AS ' - 'tbl_row_count FROM sometable') + t = table("sometable", column("somecolumn")) + self.assert_compile( + t.count(), + "SELECT count(sometable.somecolumn) AS " + "tbl_row_count FROM sometable", + ) def test_noorderby_insubquery(self): """test that the ms-sql dialect removes ORDER BY clauses from subqueries""" - table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String), - ) + table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), + ) - q = select([table1.c.myid], - order_by=[table1.c.myid]).alias('foo') + q = select([table1.c.myid], order_by=[table1.c.myid]).alias("foo") crit = q.c.myid == table1.c.myid - self.assert_compile(select(['*'], crit), - "SELECT * FROM (SELECT mytable.myid AS " - "myid FROM mytable) AS foo, mytable WHERE " - "foo.myid = mytable.myid") + self.assert_compile( + select(["*"], crit), + "SELECT * FROM (SELECT mytable.myid AS " + "myid FROM mytable) AS foo, mytable WHERE " + "foo.myid = mytable.myid", + ) def test_force_schema_quoted_name_w_dot_case_insensitive(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema=quoted_name("foo.dbo", True) + "test", + metadata, + Column("id", Integer, primary_key=True), + schema=quoted_name("foo.dbo", True), ) self.assert_compile( - select([tbl]), - "SELECT [foo.dbo].test.id FROM [foo.dbo].test" + select([tbl]), "SELECT [foo.dbo].test.id FROM [foo.dbo].test" ) def test_force_schema_quoted_w_dot_case_insensitive(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema=quoted_name("foo.dbo", True) + "test", + metadata, + Column("id", Integer, primary_key=True), + schema=quoted_name("foo.dbo", True), ) self.assert_compile( - select([tbl]), - "SELECT [foo.dbo].test.id FROM [foo.dbo].test" + select([tbl]), "SELECT [foo.dbo].test.id FROM [foo.dbo].test" ) def test_force_schema_quoted_name_w_dot_case_sensitive(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema=quoted_name("Foo.dbo", True) + "test", + metadata, + Column("id", Integer, primary_key=True), + schema=quoted_name("Foo.dbo", True), ) self.assert_compile( - select([tbl]), - "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test" + select([tbl]), "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test" ) def test_force_schema_quoted_w_dot_case_sensitive(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema="[Foo.dbo]" + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="[Foo.dbo]", ) self.assert_compile( - select([tbl]), - "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test" + select([tbl]), "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test" ) def test_schema_autosplit_w_dot_case_insensitive(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema="foo.dbo" + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="foo.dbo", ) self.assert_compile( - select([tbl]), - "SELECT foo.dbo.test.id FROM foo.dbo.test" + select([tbl]), "SELECT foo.dbo.test.id FROM foo.dbo.test" ) def test_schema_autosplit_w_dot_case_sensitive(self): 
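# The schema tests around here pin down two behaviors: a plain dotted
# schema string such as "Foo.dbo" is auto-split into database and owner
# tokens that are quoted independently, while quoted_name("Foo.dbo", True)
# forces the whole value to compile as a single bracketed identifier.
# A sketch of both, assuming the standard public imports:

from sqlalchemy import Column, Integer, MetaData, Table, select
from sqlalchemy.dialects import mssql
from sqlalchemy.sql.elements import quoted_name

split = Table("test", MetaData(), Column("id", Integer, primary_key=True),
              schema="Foo.dbo")
print(select([split]).compile(dialect=mssql.dialect()))
# SELECT [Foo].dbo.test.id FROM [Foo].dbo.test

whole = Table("test", MetaData(), Column("id", Integer, primary_key=True),
              schema=quoted_name("Foo.dbo", True))
print(select([whole]).compile(dialect=mssql.dialect()))
# SELECT [Foo.dbo].test.id FROM [Foo.dbo].test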
metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema="Foo.dbo" + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="Foo.dbo", ) self.assert_compile( - select([tbl]), - "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test" + select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test" ) def test_owner_database_pairs(self): @@ -420,61 +464,82 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_delete_schema(self): metadata = MetaData() - tbl = Table('test', metadata, Column('id', Integer, - primary_key=True), schema='paj') - self.assert_compile(tbl.delete(tbl.c.id == 1), - 'DELETE FROM paj.test WHERE paj.test.id = ' - ':id_1') + tbl = Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="paj", + ) + self.assert_compile( + tbl.delete(tbl.c.id == 1), + "DELETE FROM paj.test WHERE paj.test.id = " ":id_1", + ) s = select([tbl.c.id]).where(tbl.c.id == 1) - self.assert_compile(tbl.delete().where(tbl.c.id.in_(s)), - 'DELETE FROM paj.test WHERE paj.test.id IN ' - '(SELECT paj.test.id FROM paj.test ' - 'WHERE paj.test.id = :id_1)') + self.assert_compile( + tbl.delete().where(tbl.c.id.in_(s)), + "DELETE FROM paj.test WHERE paj.test.id IN " + "(SELECT paj.test.id FROM paj.test " + "WHERE paj.test.id = :id_1)", + ) def test_delete_schema_multipart(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, - primary_key=True), - schema='banana.paj') - self.assert_compile(tbl.delete(tbl.c.id == 1), - 'DELETE FROM banana.paj.test WHERE ' - 'banana.paj.test.id = :id_1') + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="banana.paj", + ) + self.assert_compile( + tbl.delete(tbl.c.id == 1), + "DELETE FROM banana.paj.test WHERE " "banana.paj.test.id = :id_1", + ) s = select([tbl.c.id]).where(tbl.c.id == 1) - self.assert_compile(tbl.delete().where(tbl.c.id.in_(s)), - 'DELETE FROM banana.paj.test WHERE ' - 'banana.paj.test.id IN (SELECT banana.paj.test.id ' - 'FROM banana.paj.test WHERE ' - 'banana.paj.test.id = :id_1)') + self.assert_compile( + tbl.delete().where(tbl.c.id.in_(s)), + "DELETE FROM banana.paj.test WHERE " + "banana.paj.test.id IN (SELECT banana.paj.test.id " + "FROM banana.paj.test WHERE " + "banana.paj.test.id = :id_1)", + ) def test_delete_schema_multipart_needs_quoting(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - schema='banana split.paj') - self.assert_compile(tbl.delete(tbl.c.id == 1), - 'DELETE FROM [banana split].paj.test WHERE ' - '[banana split].paj.test.id = :id_1') + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="banana split.paj", + ) + self.assert_compile( + tbl.delete(tbl.c.id == 1), + "DELETE FROM [banana split].paj.test WHERE " + "[banana split].paj.test.id = :id_1", + ) s = select([tbl.c.id]).where(tbl.c.id == 1) - self.assert_compile(tbl.delete().where(tbl.c.id.in_(s)), - 'DELETE FROM [banana split].paj.test WHERE ' - '[banana split].paj.test.id IN (' - - 'SELECT [banana split].paj.test.id FROM ' - '[banana split].paj.test WHERE ' - '[banana split].paj.test.id = :id_1)') + self.assert_compile( + tbl.delete().where(tbl.c.id.in_(s)), + "DELETE FROM [banana split].paj.test WHERE " + "[banana split].paj.test.id IN (" + "SELECT [banana split].paj.test.id FROM " + "[banana split].paj.test WHERE " + "[banana split].paj.test.id = :id_1)", + ) def test_delete_schema_multipart_both_need_quoting(self): metadata = MetaData() - tbl = 
Table('test', metadata, Column('id', Integer, - primary_key=True), - schema='banana split.paj with a space') - self.assert_compile(tbl.delete(tbl.c.id == 1), - 'DELETE FROM [banana split].[paj with a ' - 'space].test WHERE [banana split].[paj ' - 'with a space].test.id = :id_1') + tbl = Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="banana split.paj with a space", + ) + self.assert_compile( + tbl.delete(tbl.c.id == 1), + "DELETE FROM [banana split].[paj with a " + "space].test WHERE [banana split].[paj " + "with a space].test.id = :id_1", + ) s = select([tbl.c.id]).where(tbl.c.id == 1) self.assert_compile( tbl.delete().where(tbl.c.id.in_(s)), @@ -482,156 +547,204 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE [banana split].[paj with a space].test.id IN " "(SELECT [banana split].[paj with a space].test.id " "FROM [banana split].[paj with a space].test " - "WHERE [banana split].[paj with a space].test.id = :id_1)" + "WHERE [banana split].[paj with a space].test.id = :id_1)", ) def test_union(self): t1 = table( - 't1', column('col1'), column('col2'), - column('col3'), column('col4')) + "t1", + column("col1"), + column("col2"), + column("col3"), + column("col4"), + ) t2 = table( - 't2', column('col1'), column('col2'), - column('col3'), column('col4')) - s1, s2 = select( - [t1.c.col3.label('col3'), t1.c.col4.label('col4')], - t1.c.col2.in_(['t1col2r1', 't1col2r2'])), \ - select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], - t2.c.col2.in_(['t2col2r2', 't2col2r3'])) - u = union(s1, s2, order_by=['col3', 'col4']) - self.assert_compile(u, - 'SELECT t1.col3 AS col3, t1.col4 AS col4 ' - 'FROM t1 WHERE t1.col2 IN (:col2_1, ' - ':col2_2) UNION SELECT t2.col3 AS col3, ' - 't2.col4 AS col4 FROM t2 WHERE t2.col2 IN ' - '(:col2_3, :col2_4) ORDER BY col3, col4') - self.assert_compile(u.alias('bar').select(), - 'SELECT bar.col3, bar.col4 FROM (SELECT ' - 't1.col3 AS col3, t1.col4 AS col4 FROM t1 ' - 'WHERE t1.col2 IN (:col2_1, :col2_2) UNION ' - 'SELECT t2.col3 AS col3, t2.col4 AS col4 ' - 'FROM t2 WHERE t2.col2 IN (:col2_3, ' - ':col2_4)) AS bar') + "t2", + column("col1"), + column("col2"), + column("col3"), + column("col4"), + ) + s1, s2 = ( + select( + [t1.c.col3.label("col3"), t1.c.col4.label("col4")], + t1.c.col2.in_(["t1col2r1", "t1col2r2"]), + ), + select( + [t2.c.col3.label("col3"), t2.c.col4.label("col4")], + t2.c.col2.in_(["t2col2r2", "t2col2r3"]), + ), + ) + u = union(s1, s2, order_by=["col3", "col4"]) + self.assert_compile( + u, + "SELECT t1.col3 AS col3, t1.col4 AS col4 " + "FROM t1 WHERE t1.col2 IN (:col2_1, " + ":col2_2) UNION SELECT t2.col3 AS col3, " + "t2.col4 AS col4 FROM t2 WHERE t2.col2 IN " + "(:col2_3, :col2_4) ORDER BY col3, col4", + ) + self.assert_compile( + u.alias("bar").select(), + "SELECT bar.col3, bar.col4 FROM (SELECT " + "t1.col3 AS col3, t1.col4 AS col4 FROM t1 " + "WHERE t1.col2 IN (:col2_1, :col2_2) UNION " + "SELECT t2.col3 AS col3, t2.col4 AS col4 " + "FROM t2 WHERE t2.col2 IN (:col2_3, " + ":col2_4)) AS bar", + ) def test_function(self): - self.assert_compile(func.foo(1, 2), 'foo(:foo_1, :foo_2)') - self.assert_compile(func.current_time(), 'CURRENT_TIME') - self.assert_compile(func.foo(), 'foo()') + self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)") + self.assert_compile(func.current_time(), "CURRENT_TIME") + self.assert_compile(func.foo(), "foo()") m = MetaData() t = Table( - 'sometable', m, Column('col1', Integer), Column('col2', Integer)) - self.assert_compile(select([func.max(t.c.col1)]), - 
'SELECT max(sometable.col1) AS max_1 FROM ' - 'sometable') + "sometable", m, Column("col1", Integer), Column("col2", Integer) + ) + self.assert_compile( + select([func.max(t.c.col1)]), + "SELECT max(sometable.col1) AS max_1 FROM " "sometable", + ) def test_function_overrides(self): self.assert_compile(func.current_date(), "GETDATE()") self.assert_compile(func.length(3), "LEN(:length_1)") def test_extract(self): - t = table('t', column('col1')) + t = table("t", column("col1")) - for field in 'day', 'month', 'year': + for field in "day", "month", "year": self.assert_compile( select([extract(field, t.c.col1)]), - 'SELECT DATEPART(%s, t.col1) AS anon_1 FROM t' % field) + "SELECT DATEPART(%s, t.col1) AS anon_1 FROM t" % field, + ) def test_update_returning(self): table1 = table( - 'mytable', - column('myid', Integer), - column('name', String(128)), - column('description', String(128))) - u = update( - table1, - values=dict(name='foo')).returning(table1.c.myid, table1.c.name) - self.assert_compile(u, - 'UPDATE mytable SET name=:name OUTPUT ' - 'inserted.myid, inserted.name') - u = update(table1, values=dict(name='foo')).returning(table1) - self.assert_compile(u, - 'UPDATE mytable SET name=:name OUTPUT ' - 'inserted.myid, inserted.name, ' - 'inserted.description') - u = update( - table1, - values=dict( - name='foo')).returning(table1).where(table1.c.name == 'bar') - self.assert_compile(u, - 'UPDATE mytable SET name=:name OUTPUT ' - 'inserted.myid, inserted.name, ' - 'inserted.description WHERE mytable.name = ' - ':name_1') - u = update(table1, values=dict(name='foo' - )).returning(func.length(table1.c.name)) - self.assert_compile(u, - 'UPDATE mytable SET name=:name OUTPUT ' - 'LEN(inserted.name) AS length_1') + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) + u = update(table1, values=dict(name="foo")).returning( + table1.c.myid, table1.c.name + ) + self.assert_compile( + u, + "UPDATE mytable SET name=:name OUTPUT " + "inserted.myid, inserted.name", + ) + u = update(table1, values=dict(name="foo")).returning(table1) + self.assert_compile( + u, + "UPDATE mytable SET name=:name OUTPUT " + "inserted.myid, inserted.name, " + "inserted.description", + ) + u = ( + update(table1, values=dict(name="foo")) + .returning(table1) + .where(table1.c.name == "bar") + ) + self.assert_compile( + u, + "UPDATE mytable SET name=:name OUTPUT " + "inserted.myid, inserted.name, " + "inserted.description WHERE mytable.name = " + ":name_1", + ) + u = update(table1, values=dict(name="foo")).returning( + func.length(table1.c.name) + ) + self.assert_compile( + u, + "UPDATE mytable SET name=:name OUTPUT " + "LEN(inserted.name) AS length_1", + ) def test_delete_returning(self): table1 = table( - 'mytable', column('myid', Integer), - column('name', String(128)), column('description', String(128))) + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) d = delete(table1).returning(table1.c.myid, table1.c.name) - self.assert_compile(d, - 'DELETE FROM mytable OUTPUT deleted.myid, ' - 'deleted.name') - d = delete(table1).where(table1.c.name == 'bar' - ).returning(table1.c.myid, - table1.c.name) - self.assert_compile(d, - 'DELETE FROM mytable OUTPUT deleted.myid, ' - 'deleted.name WHERE mytable.name = :name_1') + self.assert_compile( + d, "DELETE FROM mytable OUTPUT deleted.myid, " "deleted.name" + ) + d = ( + delete(table1) + .where(table1.c.name == "bar") + .returning(table1.c.myid, table1.c.name) + ) + 
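# SQL Server has no RETURNING keyword; as the assertions in these
# returning tests show, the dialect renders .returning() as an OUTPUT
# clause against the "inserted" / "deleted" pseudo-tables.  A compact
# sketch using the same table/column helpers as above:

from sqlalchemy import Integer, String, delete
from sqlalchemy.dialects import mssql
from sqlalchemy.sql import column, table

tab = table("mytable", column("myid", Integer), column("name", String(128)))
stmt = delete(tab).returning(tab.c.myid)
print(stmt.compile(dialect=mssql.dialect()))
# DELETE FROM mytable OUTPUT deleted.myid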
self.assert_compile( + d, + "DELETE FROM mytable OUTPUT deleted.myid, " + "deleted.name WHERE mytable.name = :name_1", + ) def test_insert_returning(self): table1 = table( - 'mytable', column('myid', Integer), - column('name', String(128)), column('description', String(128))) - i = insert( - table1, - values=dict(name='foo')).returning(table1.c.myid, table1.c.name) - self.assert_compile(i, - 'INSERT INTO mytable (name) OUTPUT ' - 'inserted.myid, inserted.name VALUES ' - '(:name)') - i = insert(table1, values=dict(name='foo')).returning(table1) - self.assert_compile(i, - 'INSERT INTO mytable (name) OUTPUT ' - 'inserted.myid, inserted.name, ' - 'inserted.description VALUES (:name)') - i = insert(table1, values=dict(name='foo' - )).returning(func.length(table1.c.name)) - self.assert_compile(i, - 'INSERT INTO mytable (name) OUTPUT ' - 'LEN(inserted.name) AS length_1 VALUES ' - '(:name)') + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) + i = insert(table1, values=dict(name="foo")).returning( + table1.c.myid, table1.c.name + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) OUTPUT " + "inserted.myid, inserted.name VALUES " + "(:name)", + ) + i = insert(table1, values=dict(name="foo")).returning(table1) + self.assert_compile( + i, + "INSERT INTO mytable (name) OUTPUT " + "inserted.myid, inserted.name, " + "inserted.description VALUES (:name)", + ) + i = insert(table1, values=dict(name="foo")).returning( + func.length(table1.c.name) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) OUTPUT " + "LEN(inserted.name) AS length_1 VALUES " + "(:name)", + ) def test_limit_using_top(self): - t = table('t', column('x', Integer), column('y', Integer)) + t = table("t", column("x", Integer), column("y", Integer)) s = select([t]).where(t.c.x == 5).order_by(t.c.y).limit(10) self.assert_compile( s, "SELECT TOP 10 t.x, t.y FROM t WHERE t.x = :x_1 ORDER BY t.y", - checkparams={'x_1': 5} + checkparams={"x_1": 5}, ) def test_limit_zero_using_top(self): - t = table('t', column('x', Integer), column('y', Integer)) + t = table("t", column("x", Integer), column("y", Integer)) s = select([t]).where(t.c.x == 5).order_by(t.c.y).limit(0) self.assert_compile( s, "SELECT TOP 0 t.x, t.y FROM t WHERE t.x = :x_1 ORDER BY t.y", - checkparams={'x_1': 5} + checkparams={"x_1": 5}, ) c = s.compile(dialect=mssql.dialect()) eq_(len(c._result_columns), 2) - assert t.c.x in set(c._create_result_map()['x'][1]) + assert t.c.x in set(c._create_result_map()["x"][1]) def test_offset_using_window(self): - t = table('t', column('x', Integer), column('y', Integer)) + t = table("t", column("x", Integer), column("y", Integer)) s = select([t]).where(t.c.x == 5).order_by(t.c.y).offset(20) @@ -644,15 +757,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "AS y, ROW_NUMBER() OVER (ORDER BY t.y) AS " "mssql_rn FROM t WHERE t.x = :x_1) AS " "anon_1 WHERE mssql_rn > :param_1", - checkparams={'param_1': 20, 'x_1': 5} + checkparams={"param_1": 20, "x_1": 5}, ) c = s.compile(dialect=mssql.dialect()) eq_(len(c._result_columns), 2) - assert t.c.x in set(c._create_result_map()['x'][1]) + assert t.c.x in set(c._create_result_map()["x"][1]) def test_limit_offset_using_window(self): - t = table('t', column('x', Integer), column('y', Integer)) + t = table("t", column("x", Integer), column("y", Integer)) s = select([t]).where(t.c.x == 5).order_by(t.c.y).limit(10).offset(20) @@ -664,17 +777,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM t " 
"WHERE t.x = :x_1) AS anon_1 " "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1", - checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5} + checkparams={"param_1": 20, "param_2": 10, "x_1": 5}, ) c = s.compile(dialect=mssql.dialect()) eq_(len(c._result_columns), 2) - assert t.c.x in set(c._create_result_map()['x'][1]) - assert t.c.y in set(c._create_result_map()['y'][1]) + assert t.c.x in set(c._create_result_map()["x"][1]) + assert t.c.y in set(c._create_result_map()["y"][1]) def test_limit_offset_w_ambiguous_cols(self): - t = table('t', column('x', Integer), column('y', Integer)) + t = table("t", column("x", Integer), column("y", Integer)) - cols = [t.c.x, t.c.x.label('q'), t.c.x.label('p'), t.c.y] + cols = [t.c.x, t.c.x.label("q"), t.c.x.label("p"), t.c.y] s = select(cols).where(t.c.x == 5).order_by(t.c.y).limit(10).offset(20) self.assert_compile( @@ -685,7 +798,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM t " "WHERE t.x = :x_1) AS anon_1 " "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1", - checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5} + checkparams={"param_1": 20, "param_2": 10, "x_1": 5}, ) c = s.compile(dialect=mssql.dialect()) eq_(len(c._result_columns), 4) @@ -696,12 +809,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): is_(result_map[col.key][1][0], col) def test_limit_offset_with_correlated_order_by(self): - t1 = table('t1', column('x', Integer), column('y', Integer)) - t2 = table('t2', column('x', Integer), column('y', Integer)) + t1 = table("t1", column("x", Integer), column("y", Integer)) + t2 = table("t2", column("x", Integer), column("y", Integer)) order_by = select([t2.c.y]).where(t1.c.x == t2.c.x).as_scalar() - s = select([t1]).where(t1.c.x == 5).order_by(order_by) \ - .limit(10).offset(20) + s = ( + select([t1]) + .where(t1.c.x == 5) + .order_by(order_by) + .limit(10) + .offset(20) + ) self.assert_compile( s, @@ -713,21 +831,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM t1 " "WHERE t1.x = :x_1) AS anon_1 " "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1", - checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5} + checkparams={"param_1": 20, "param_2": 10, "x_1": 5}, ) c = s.compile(dialect=mssql.dialect()) eq_(len(c._result_columns), 2) - assert t1.c.x in set(c._create_result_map()['x'][1]) - assert t1.c.y in set(c._create_result_map()['y'][1]) + assert t1.c.x in set(c._create_result_map()["x"][1]) + assert t1.c.y in set(c._create_result_map()["y"][1]) def test_offset_dont_misapply_labelreference(self): m = MetaData() - t = Table('t', m, Column('x', Integer)) + t = Table("t", m, Column("x", Integer)) - expr1 = func.foo(t.c.x).label('x') - expr2 = func.foo(t.c.x).label('y') + expr1 = func.foo(t.c.x).label("x") + expr2 = func.foo(t.c.x).label("y") stmt1 = select([expr1]).order_by(expr1.desc()).offset(1) stmt2 = select([expr2]).order_by(expr2.desc()).offset(1) @@ -736,18 +854,18 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): stmt1, "SELECT anon_1.x FROM (SELECT foo(t.x) AS x, " "ROW_NUMBER() OVER (ORDER BY foo(t.x) DESC) AS mssql_rn FROM t) " - "AS anon_1 WHERE mssql_rn > :param_1" + "AS anon_1 WHERE mssql_rn > :param_1", ) self.assert_compile( stmt2, "SELECT anon_1.y FROM (SELECT foo(t.x) AS y, " "ROW_NUMBER() OVER (ORDER BY foo(t.x) DESC) AS mssql_rn FROM t) " - "AS anon_1 WHERE mssql_rn > :param_1" + "AS anon_1 WHERE mssql_rn > :param_1", ) def test_limit_zero_offset_using_window(self): - t = table('t', column('x', Integer), column('y', Integer)) + 
t = table("t", column("x", Integer), column("y", Integer)) s = select([t]).where(t.c.x == 5).order_by(t.c.y).limit(0).offset(0) @@ -755,264 +873,303 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): # of zero, so produces TOP 0 self.assert_compile( s, - "SELECT TOP 0 t.x, t.y FROM t " - "WHERE t.x = :x_1 ORDER BY t.y", - checkparams={'x_1': 5} + "SELECT TOP 0 t.x, t.y FROM t " "WHERE t.x = :x_1 ORDER BY t.y", + checkparams={"x_1": 5}, ) def test_primary_key_no_identity(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, autoincrement=False, - primary_key=True)) + tbl = Table( + "test", + metadata, + Column("id", Integer, autoincrement=False, primary_key=True), + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE test (id INTEGER NOT NULL, " - "PRIMARY KEY (id))" + "CREATE TABLE test (id INTEGER NOT NULL, " "PRIMARY KEY (id))", ) def test_primary_key_defaults_to_identity(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, primary_key=True)) + tbl = Table("test", metadata, Column("id", Integer, primary_key=True)) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,1), " - "PRIMARY KEY (id))" + "PRIMARY KEY (id))", ) def test_identity_no_primary_key(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, autoincrement=True)) + tbl = Table( + "test", metadata, Column("id", Integer, autoincrement=True) + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,1)" - ")" + "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,1)" ")", ) def test_identity_separate_from_primary_key(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, autoincrement=False, - primary_key=True), - Column('x', Integer, autoincrement=True) - ) + tbl = Table( + "test", + metadata, + Column("id", Integer, autoincrement=False, primary_key=True), + Column("x", Integer, autoincrement=True), + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (id INTEGER NOT NULL, " "x INTEGER NOT NULL IDENTITY(1,1), " - "PRIMARY KEY (id))" + "PRIMARY KEY (id))", ) def test_identity_illegal_two_autoincrements(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, autoincrement=True), - Column('id2', Integer, autoincrement=True), - ) + tbl = Table( + "test", + metadata, + Column("id", Integer, autoincrement=True), + Column("id2", Integer, autoincrement=True), + ) # this will be rejected by the database, just asserting this is what # the two autoincrements will do right now self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,1), " - "id2 INTEGER NOT NULL IDENTITY(1,1))" + "id2 INTEGER NOT NULL IDENTITY(1,1))", ) def test_identity_start_0(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, mssql_identity_start=0, - primary_key=True)) + tbl = Table( + "test", + metadata, + Column("id", Integer, mssql_identity_start=0, primary_key=True), + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(0,1), " - "PRIMARY KEY (id))" + "PRIMARY KEY (id))", ) def test_identity_increment_5(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, mssql_identity_increment=5, - primary_key=True)) + tbl = Table( + "test", + metadata, + Column( + "id", Integer, mssql_identity_increment=5, primary_key=True + ), + ) self.assert_compile( 
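
# ---------------------------------------------------------------------
# Illustrative aside (editor's sketch, not part of the reformat): the
# identity tests above drive the dialect-level mssql_identity_start /
# mssql_identity_increment column options; a standalone compile shows
# how they land in DDL (names here invented for the example):

from sqlalchemy import Column, Integer, MetaData, Table, schema
from sqlalchemy.dialects import mssql

m = MetaData()
t = Table(
    "t", m, Column("id", Integer, mssql_identity_start=0, primary_key=True)
)
# prints: CREATE TABLE t (id INTEGER NOT NULL IDENTITY(0,1),
#         PRIMARY KEY (id))
print(schema.CreateTable(t).compile(dialect=mssql.dialect()))
# ---------------------------------------------------------------------
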
schema.CreateTable(tbl), "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(1,5), " - "PRIMARY KEY (id))" + "PRIMARY KEY (id))", ) def test_sequence_start_0(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, Sequence('', 0), primary_key=True)) + tbl = Table( + "test", + metadata, + Column("id", Integer, Sequence("", 0), primary_key=True), + ) with testing.expect_deprecated( - "Use of Sequence with SQL Server in order to affect "): + "Use of Sequence with SQL Server in order to affect " + ): self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(0,1), " - "PRIMARY KEY (id))" + "PRIMARY KEY (id))", ) def test_sequence_non_primary_key(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, Sequence('', start=5), - primary_key=False)) + tbl = Table( + "test", + metadata, + Column("id", Integer, Sequence("", start=5), primary_key=False), + ) with testing.expect_deprecated( - "Use of Sequence with SQL Server in order to affect "): + "Use of Sequence with SQL Server in order to affect " + ): self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(5,1))" + "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(5,1))", ) def test_sequence_ignore_nullability(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer, Sequence('', start=5), - nullable=True)) + tbl = Table( + "test", + metadata, + Column("id", Integer, Sequence("", start=5), nullable=True), + ) with testing.expect_deprecated( - "Use of Sequence with SQL Server in order to affect "): + "Use of Sequence with SQL Server in order to affect " + ): self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(5,1))" + "CREATE TABLE test (id INTEGER NOT NULL IDENTITY(5,1))", ) def test_table_pkc_clustering(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False), - PrimaryKeyConstraint("x", "y", mssql_clustered=True)) + tbl = Table( + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, autoincrement=False), + PrimaryKeyConstraint("x", "y", mssql_clustered=True), + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (x INTEGER NOT NULL, y INTEGER NOT NULL, " - "PRIMARY KEY CLUSTERED (x, y))" + "PRIMARY KEY CLUSTERED (x, y))", ) def test_table_pkc_explicit_nonclustered(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False), - PrimaryKeyConstraint("x", "y", mssql_clustered=False)) + tbl = Table( + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, autoincrement=False), + PrimaryKeyConstraint("x", "y", mssql_clustered=False), + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (x INTEGER NOT NULL, y INTEGER NOT NULL, " - "PRIMARY KEY NONCLUSTERED (x, y))" + "PRIMARY KEY NONCLUSTERED (x, y))", ) def test_table_idx_explicit_nonclustered(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False) + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, autoincrement=False), ) idx = Index("myidx", tbl.c.x, tbl.c.y, mssql_clustered=False) self.assert_compile( schema.CreateIndex(idx), - "CREATE NONCLUSTERED INDEX myidx ON test (x, y)" + "CREATE 
NONCLUSTERED INDEX myidx ON test (x, y)", ) def test_table_uc_explicit_nonclustered(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False), - UniqueConstraint("x", "y", mssql_clustered=False)) + tbl = Table( + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, autoincrement=False), + UniqueConstraint("x", "y", mssql_clustered=False), + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (x INTEGER NULL, y INTEGER NULL, " - "UNIQUE NONCLUSTERED (x, y))" + "UNIQUE NONCLUSTERED (x, y))", ) def test_table_uc_clustering(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False), - PrimaryKeyConstraint("x"), - UniqueConstraint("y", mssql_clustered=True)) + tbl = Table( + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, autoincrement=False), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE test (x INTEGER NOT NULL, y INTEGER NULL, " - "PRIMARY KEY (x), UNIQUE CLUSTERED (y))" + "PRIMARY KEY (x), UNIQUE CLUSTERED (y))", ) def test_index_clustering(self): metadata = MetaData() - tbl = Table('test', metadata, - Column('id', Integer)) + tbl = Table("test", metadata, Column("id", Integer)) idx = Index("foo", tbl.c.id, mssql_clustered=True) - self.assert_compile(schema.CreateIndex(idx), - "CREATE CLUSTERED INDEX foo ON test (id)" - ) + self.assert_compile( + schema.CreateIndex(idx), "CREATE CLUSTERED INDEX foo ON test (id)" + ) def test_index_ordering(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('x', Integer), Column('y', Integer), Column('z', Integer)) + "test", + metadata, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) idx = Index("foo", tbl.c.x.desc(), "y") - self.assert_compile(schema.CreateIndex(idx), - "CREATE INDEX foo ON test (x DESC, y)" - ) + self.assert_compile( + schema.CreateIndex(idx), "CREATE INDEX foo ON test (x DESC, y)" + ) def test_create_index_expr(self): m = MetaData() - t1 = Table('foo', m, - Column('x', Integer) - ) + t1 = Table("foo", m, Column("x", Integer)) self.assert_compile( schema.CreateIndex(Index("bar", t1.c.x > 5)), - "CREATE INDEX bar ON foo (x > 5)" + "CREATE INDEX bar ON foo (x > 5)", ) def test_drop_index_w_schema(self): m = MetaData() - t1 = Table('foo', m, - Column('x', Integer), - schema='bar' - ) + t1 = Table("foo", m, Column("x", Integer), schema="bar") self.assert_compile( schema.DropIndex(Index("idx_foo", t1.c.x)), - "DROP INDEX idx_foo ON bar.foo" + "DROP INDEX idx_foo ON bar.foo", ) def test_index_extra_include_1(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('x', Integer), Column('y', Integer), Column('z', Integer)) - idx = Index("foo", tbl.c.x, mssql_include=['y']) - self.assert_compile(schema.CreateIndex(idx), - "CREATE INDEX foo ON test (x) INCLUDE (y)" - ) + "test", + metadata, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + idx = Index("foo", tbl.c.x, mssql_include=["y"]) + self.assert_compile( + schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)" + ) def test_index_extra_include_2(self): metadata = MetaData() tbl = Table( - 'test', metadata, - Column('x', Integer), Column('y', Integer), Column('z', Integer)) + "test", + metadata, + Column("x", Integer), + Column("y", 
Integer), + Column("z", Integer), + ) idx = Index("foo", tbl.c.x, mssql_include=[tbl.c.y]) - self.assert_compile(schema.CreateIndex(idx), - "CREATE INDEX foo ON test (x) INCLUDE (y)" - ) + self.assert_compile( + schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)" + ) class SchemaTest(fixtures.TestBase): - def setup(self): - t = Table('sometable', MetaData(), - Column('pk_column', Integer), - Column('test_column', String) - ) + t = Table( + "sometable", + MetaData(), + Column("pk_column", Integer), + Column("test_column", String), + ) self.column = t.c.test_column dialect = mssql.dialect() - self.ddl_compiler = dialect.ddl_compiler(dialect, - schema.CreateTable(t)) + self.ddl_compiler = dialect.ddl_compiler( + dialect, schema.CreateTable(t) + ) def _column_spec(self): return self.ddl_compiler.get_column_specification(self.column) diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index 973eb6dbbd..40d3894fb8 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -6,8 +6,11 @@ from sqlalchemy.dialects.mssql import pyodbc, pymssql, adodbapi from sqlalchemy.engine import url from sqlalchemy.testing import fixtures from sqlalchemy import testing -from sqlalchemy.testing import assert_raises_message, \ - assert_warnings, expect_warnings +from sqlalchemy.testing import ( + assert_raises_message, + assert_warnings, + expect_warnings, +) from sqlalchemy.testing.mock import Mock from sqlalchemy.dialects.mssql import base from sqlalchemy import Integer, String, Table, Column @@ -15,30 +18,29 @@ from sqlalchemy import event class ParseConnectTest(fixtures.TestBase): - def test_pyodbc_connect_dsn_trusted(self): dialect = pyodbc.dialect() - u = url.make_url('mssql://mydsn') + u = url.make_url("mssql://mydsn") connection = dialect.create_connect_args(u) - eq_([['dsn=mydsn;Trusted_Connection=Yes'], {}], connection) + eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection) def test_pyodbc_connect_old_style_dsn_trusted(self): dialect = pyodbc.dialect() - u = url.make_url('mssql:///?dsn=mydsn') + u = url.make_url("mssql:///?dsn=mydsn") connection = dialect.create_connect_args(u) - eq_([['dsn=mydsn;Trusted_Connection=Yes'], {}], connection) + eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection) def test_pyodbc_connect_dsn_non_trusted(self): dialect = pyodbc.dialect() - u = url.make_url('mssql://username:password@mydsn') + u = url.make_url("mssql://username:password@mydsn") connection = dialect.create_connect_args(u) - eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection) + eq_([["dsn=mydsn;UID=username;PWD=password"], {}], connection) def test_pyodbc_connect_dsn_extra(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_' - 'english&foo=bar') + u = url.make_url( + "mssql://username:password@mydsn/?LANGUAGE=us_" "english&foo=bar" + ) connection = dialect.create_connect_args(u) dsn_string = connection[0][0] assert ";LANGUAGE=us_english" in dsn_string @@ -47,87 +49,151 @@ class ParseConnectTest(fixtures.TestBase): def test_pyodbc_hostname(self): dialect = pyodbc.dialect() u = url.make_url( - 'mssql://username:password@hostspec/database?driver=SQL+Server' + "mssql://username:password@hostspec/database?driver=SQL+Server" ) connection = dialect.create_connect_args(u) - eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UI' - 'D=username;PWD=password'], {}], connection) + eq_( + [ + [ + "DRIVER={SQL Server};Server=hostspec;Database=database;UI" + 
"D=username;PWD=password" + ], + {}, + ], + connection, + ) def test_pyodbc_host_no_driver(self): dialect = pyodbc.dialect() - u = url.make_url('mssql://username:password@hostspec/database') + u = url.make_url("mssql://username:password@hostspec/database") def go(): return dialect.create_connect_args(u) + connection = assert_warnings( go, - ["No driver name specified; this is expected by " - "PyODBC when using DSN-less connections"]) + [ + "No driver name specified; this is expected by " + "PyODBC when using DSN-less connections" + ], + ) - eq_([['Server=hostspec;Database=database;UI' - 'D=username;PWD=password'], {}], connection) + eq_( + [ + [ + "Server=hostspec;Database=database;UI" + "D=username;PWD=password" + ], + {}, + ], + connection, + ) def test_pyodbc_connect_comma_port(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql://username:password@hostspec:12345/data' - 'base?driver=SQL Server') + u = url.make_url( + "mssql://username:password@hostspec:12345/data" + "base?driver=SQL Server" + ) connection = dialect.create_connect_args(u) - eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=datab' - 'ase;UID=username;PWD=password'], {}], connection) + eq_( + [ + [ + "DRIVER={SQL Server};Server=hostspec,12345;Database=datab" + "ase;UID=username;PWD=password" + ], + {}, + ], + connection, + ) def test_pyodbc_connect_config_port(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql://username:password@hostspec/database?p' - 'ort=12345&driver=SQL+Server') + u = url.make_url( + "mssql://username:password@hostspec/database?p" + "ort=12345&driver=SQL+Server" + ) connection = dialect.create_connect_args(u) - eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UI' - 'D=username;PWD=password;port=12345'], {}], connection) + eq_( + [ + [ + "DRIVER={SQL Server};Server=hostspec;Database=database;UI" + "D=username;PWD=password;port=12345" + ], + {}, + ], + connection, + ) def test_pyodbc_extra_connect(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql://username:password@hostspec/database?L' - 'ANGUAGE=us_english&foo=bar&driver=SQL+Server') + u = url.make_url( + "mssql://username:password@hostspec/database?L" + "ANGUAGE=us_english&foo=bar&driver=SQL+Server" + ) connection = dialect.create_connect_args(u) eq_(connection[1], {}) - eq_(connection[0][0] - in ('DRIVER={SQL Server};Server=hostspec;Database=database;' - 'UID=username;PWD=password;foo=bar;LANGUAGE=us_english', - 'DRIVER={SQL Server};Server=hostspec;Database=database;UID=' - 'username;PWD=password;LANGUAGE=us_english;foo=bar'), True) + eq_( + connection[0][0] + in ( + "DRIVER={SQL Server};Server=hostspec;Database=database;" + "UID=username;PWD=password;foo=bar;LANGUAGE=us_english", + "DRIVER={SQL Server};Server=hostspec;Database=database;UID=" + "username;PWD=password;LANGUAGE=us_english;foo=bar", + ), + True, + ) def test_pyodbc_odbc_connect(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server' - '%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase' - '%3BUID%3Dusername%3BPWD%3Dpassword') + u = url.make_url( + "mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server" + "%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase" + "%3BUID%3Dusername%3BPWD%3Dpassword" + ) connection = dialect.create_connect_args(u) - eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UI' - 'D=username;PWD=password'], {}], connection) + eq_( + [ + [ + "DRIVER={SQL Server};Server=hostspec;Database=database;UI" + "D=username;PWD=password" + ], + {}, + ], + connection, + ) def 
test_pyodbc_odbc_connect_with_dsn(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase' - '%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword' - ) + u = url.make_url( + "mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase" + "%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword" + ) connection = dialect.create_connect_args(u) - eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], - {}], connection) + eq_( + [["dsn=mydsn;Database=database;UID=username;PWD=password"], {}], + connection, + ) def test_pyodbc_odbc_connect_ignores_other_values(self): dialect = pyodbc.dialect() - u = \ - url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?od' - 'bc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer' - '%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Duse' - 'rname%3BPWD%3Dpassword') + u = url.make_url( + "mssql://userdiff:passdiff@localhost/dbdiff?od" + "bc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer" + "%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Duse" + "rname%3BPWD%3Dpassword" + ) connection = dialect.create_connect_args(u) - eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UI' - 'D=username;PWD=password'], {}], connection) + eq_( + [ + [ + "DRIVER={SQL Server};Server=hostspec;Database=database;UI" + "D=username;PWD=password" + ], + {}, + ], + connection, + ) def test_pyodbc_token_injection(self): token1 = "someuser%3BPORT%3D50001" @@ -136,18 +202,21 @@ class ParseConnectTest(fixtures.TestBase): token4 = "somedb%3BPORT%3D50001" u = url.make_url( - 'mssql+pyodbc://%s:%s@%s/%s?driver=foob' % ( - token1, token2, token3, token4 - ) + "mssql+pyodbc://%s:%s@%s/%s?driver=foob" + % (token1, token2, token3, token4) ) dialect = pyodbc.dialect() connection = dialect.create_connect_args(u) eq_( - [[ - "DRIVER={foob};Server=somehost%3BPORT%3D50001;" - "Database=somedb%3BPORT%3D50001;UID='someuser;PORT=50001';" - "PWD='somepw;PORT=50001'"], {}], - connection + [ + [ + "DRIVER={foob};Server=somehost%3BPORT%3D50001;" + "Database=somedb%3BPORT%3D50001;UID='someuser;PORT=50001';" + "PWD='somepw;PORT=50001'" + ], + {}, + ], + connection, ) def test_adodbapi_token_injection(self): @@ -158,49 +227,67 @@ class ParseConnectTest(fixtures.TestBase): # this URL format is all wrong u = url.make_url( - 'mssql+adodbapi://@/?user=%s&password=%s&host=%s&port=%s' % ( - token1, token2, token3, token4 - ) + "mssql+adodbapi://@/?user=%s&password=%s&host=%s&port=%s" + % (token1, token2, token3, token4) ) dialect = adodbapi.dialect() connection = dialect.create_connect_args(u) eq_( - [["Provider=SQLOLEDB;" - "Data Source='somehost;PORT=50001', 'someport;PORT=50001';" - "Initial Catalog=None;User Id='someuser;PORT=50001';" - "Password='somepw;PORT=50001'"], {}], - connection + [ + [ + "Provider=SQLOLEDB;" + "Data Source='somehost;PORT=50001', 'someport;PORT=50001';" + "Initial Catalog=None;User Id='someuser;PORT=50001';" + "Password='somepw;PORT=50001'" + ], + {}, + ], + connection, ) def test_pymssql_port_setting(self): dialect = pymssql.dialect() - u = \ - url.make_url('mssql+pymssql://scott:tiger@somehost/test') + u = url.make_url("mssql+pymssql://scott:tiger@somehost/test") connection = dialect.create_connect_args(u) eq_( - [[], {'host': 'somehost', 'password': 'tiger', - 'user': 'scott', 'database': 'test'}], connection + [ + [], + { + "host": "somehost", + "password": "tiger", + "user": "scott", + "database": "test", + }, + ], + connection, ) - u = \ - url.make_url('mssql+pymssql://scott:tiger@somehost:5000/test') + u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test") 
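
# ---------------------------------------------------------------------
# Illustrative aside (editor's sketch, not part of the reformat):
# create_connect_args() is the hook all of these tests drive -- it
# turns a URL into the (args, kwargs) pair handed to the DBAPI's
# connect().  For pymssql, note the port is folded into the host value:

from sqlalchemy.dialects.mssql import pymssql
from sqlalchemy.engine import url

dialect = pymssql.dialect()
u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test")
# -> [[], {'host': 'somehost:5000', 'user': 'scott',
#          'password': 'tiger', 'database': 'test'}]  (key order may vary)
print(dialect.create_connect_args(u))
# ---------------------------------------------------------------------
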
connection = dialect.create_connect_args(u) eq_( - [[], {'host': 'somehost:5000', 'password': 'tiger', - 'user': 'scott', 'database': 'test'}], connection + [ + [], + { + "host": "somehost:5000", + "password": "tiger", + "user": "scott", + "database": "test", + }, + ], + connection, ) def test_pymssql_disconnect(self): dialect = pymssql.dialect() for error in [ - 'Adaptive Server connection timed out', - 'Net-Lib error during Connection reset by peer', - 'message 20003', - 'Error 10054', - 'Not connected to any MS SQL server', - 'Connection is closed' + "Adaptive Server connection timed out", + "Net-Lib error during Connection reset by peer", + "message 20003", + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed", ]: eq_(dialect.is_disconnect(error, None, None), True) @@ -216,24 +303,36 @@ class ParseConnectTest(fixtures.TestBase): pass dialect.dbapi = Mock( - Error=MockDBAPIError, ProgrammingError=MockProgrammingError) + Error=MockDBAPIError, ProgrammingError=MockProgrammingError + ) for error in [ MockDBAPIError("[%s] some pyodbc message" % code) for code in [ - '08S01', '01002', '08003', '08007', - '08S02', '08001', 'HYT00', 'HY010'] + "08S01", + "01002", + "08003", + "08007", + "08S02", + "08001", + "HYT00", + "HY010", + ] ] + [ MockProgrammingError(message) for message in [ "(some pyodbc stuff) The cursor's connection has been closed.", - "(some pyodbc stuff) Attempt to use a closed connection." + "(some pyodbc stuff) Attempt to use a closed connection.", ] ]: eq_(dialect.is_disconnect(error, None, None), True) - eq_(dialect.is_disconnect( - MockProgrammingError("not an error"), None, None), False) + eq_( + dialect.is_disconnect( + MockProgrammingError("not an error"), None, None + ), + False, + ) @testing.requires.mssql_freetds def test_bad_freetds_warning(self): @@ -243,33 +342,35 @@ class ParseConnectTest(fixtures.TestBase): return 95, 10, 255 engine.dialect._get_server_version_info = _bad_version - assert_raises_message(exc.SAWarning, - 'Unrecognized server version info', - engine.connect) + assert_raises_message( + exc.SAWarning, "Unrecognized server version info", engine.connect + ) class EngineFromConfigTest(fixtures.TestBase): def test_legacy_schema_flag(self): cfg = { "sqlalchemy.url": "mssql://foodsn", - "sqlalchemy.legacy_schema_aliasing": "false" + "sqlalchemy.legacy_schema_aliasing": "false", } e = engine_from_config( - cfg, module=Mock(version="MS SQL Server 11.0.92")) + cfg, module=Mock(version="MS SQL Server 11.0.92") + ) eq_(e.dialect.legacy_schema_aliasing, False) class FastExecutemanyTest(fixtures.TestBase): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True - __requires__ = ('pyodbc_fast_executemany', ) + __requires__ = ("pyodbc_fast_executemany",) @testing.provide_metadata def test_flag_on(self): t = Table( - 't', self.metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) + "t", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), ) t.create() @@ -277,20 +378,18 @@ class FastExecutemanyTest(fixtures.TestBase): @event.listens_for(eng, "after_cursor_execute") def after_cursor_execute( - conn, cursor, statement, parameters, context, executemany): + conn, cursor, statement, parameters, context, executemany + ): if executemany: assert cursor.fast_executemany with eng.connect() as conn: conn.execute( t.insert(), - [{"id": i, "data": "data_%d" % i} for i in range(100)] + [{"id": i, "data": "data_%d" % i} for i in range(100)], ) - conn.execute( - t.insert(), - 
{"id": 200, "data": "data_200"} - ) + conn.execute(t.insert(), {"id": 200, "data": "data_200"}) class VersionDetectionTest(fixtures.TestBase): @@ -302,22 +401,16 @@ class VersionDetectionTest(fixtures.TestBase): "Microsoft SQL Server (XYZ) - 11.0.9216.62 \n" "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation", "Microsoft SQL Azure (RTM) - 11.0.9216.62 \n" - "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation" + "Jul 18 2014 22:00:21 \nCopyright (c) Microsoft Corporation", ]: conn = Mock(scalar=Mock(return_value=vers)) - eq_( - dialect._get_server_version_info(conn), - (11, 0, 9216, 62) - ) + eq_(dialect._get_server_version_info(conn), (11, 0, 9216, 62)) def test_pyodbc_version_productversion(self): dialect = pyodbc.MSDialect_pyodbc() conn = Mock(scalar=Mock(return_value="11.0.9216.62")) - eq_( - dialect._get_server_version_info(conn), - (11, 0, 9216, 62) - ) + eq_(dialect._get_server_version_info(conn), (11, 0, 9216, 62)) def test_pyodbc_version_fallback(self): dialect = pyodbc.MSDialect_pyodbc() @@ -326,24 +419,19 @@ class VersionDetectionTest(fixtures.TestBase): for vers, expected in [ ("11.0.9216.62", (11, 0, 9216, 62)), ("notsqlserver.11.foo.0.9216.BAR.62", (11, 0, 9216, 62)), - ("Not SQL Server Version 10.5", (5, )) + ("Not SQL Server Version 10.5", (5,)), ]: conn = Mock( scalar=Mock( - side_effect=exc.DBAPIError("stmt", "params", None)), - connection=Mock( - getinfo=Mock(return_value=vers) - ) + side_effect=exc.DBAPIError("stmt", "params", None) + ), + connection=Mock(getinfo=Mock(return_value=vers)), ) - eq_( - dialect._get_server_version_info(conn), - expected - ) + eq_(dialect._get_server_version_info(conn), expected) class IsolationLevelDetectTest(fixtures.TestBase): - def _fixture(self, view): class Error(Exception): pass @@ -354,18 +442,17 @@ class IsolationLevelDetectTest(fixtures.TestBase): result = [] - def fail_on_exec(stmt, ): + def fail_on_exec(stmt,): if view is not None and view in stmt: - result.append(('SERIALIZABLE', )) + result.append(("SERIALIZABLE",)) else: raise Error("that didn't work") connection = Mock( cursor=Mock( return_value=Mock( - execute=fail_on_exec, - fetchone=lambda: result[0] - ), + execute=fail_on_exec, fetchone=lambda: result[0] + ) ) ) @@ -374,18 +461,12 @@ class IsolationLevelDetectTest(fixtures.TestBase): def test_dm_pdw_nodes(self): dialect, connection = self._fixture("dm_pdw_nodes_exec_sessions") - eq_( - dialect.get_isolation_level(connection), - "SERIALIZABLE" - ) + eq_(dialect.get_isolation_level(connection), "SERIALIZABLE") def test_exec_sessions(self): dialect, connection = self._fixture("exec_sessions") - eq_( - dialect.get_isolation_level(connection), - "SERIALIZABLE" - ) + eq_(dialect.get_isolation_level(connection), "SERIALIZABLE") def test_not_supported(self): dialect, connection = self._fixture(None) @@ -394,6 +475,6 @@ class IsolationLevelDetectTest(fixtures.TestBase): assert_raises_message( NotImplementedError, "Can't fetch isolation", - dialect.get_isolation_level, connection + dialect.get_isolation_level, + connection, ) - diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index 13876d6d18..4c35283b3a 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -7,8 +7,21 @@ from sqlalchemy import testing from sqlalchemy.util import ue from sqlalchemy import util from sqlalchemy.testing.assertsql import CursorSQL, DialectSQL -from sqlalchemy import Integer, String, Table, Column, select, MetaData,\ - func, PrimaryKeyConstraint, desc, DDL, ForeignKey, or_, 
and_ +from sqlalchemy import ( + Integer, + String, + Table, + Column, + select, + MetaData, + func, + PrimaryKeyConstraint, + desc, + DDL, + ForeignKey, + or_, + and_, +) from sqlalchemy import event metadata = None @@ -27,34 +40,27 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): def setup(self): metadata = MetaData() self.t1 = table( - 't1', - column('a', Integer), - column('b', String), - column('c', String), + "t1", + column("a", Integer), + column("b", String), + column("c", String), ) self.t2 = Table( - 't2', metadata, + "t2", + metadata, Column("a", Integer), Column("b", Integer), Column("c", Integer), - schema='schema' + schema="schema", ) def _assert_sql(self, element, legacy_sql, modern_sql=None): dialect = mssql.dialect(legacy_schema_aliasing=True) - self.assert_compile( - element, - legacy_sql, - dialect=dialect - ) + self.assert_compile(element, legacy_sql, dialect=dialect) dialect = mssql.dialect() - self.assert_compile( - element, - modern_sql or "foob", - dialect=dialect - ) + self.assert_compile(element, modern_sql or "foob", dialect=dialect) def _legacy_dialect(self): return mssql.dialect(legacy_schema_aliasing=True) @@ -62,19 +68,19 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): def test_result_map(self): s = self.t2.select() c = s.compile(dialect=self._legacy_dialect()) - assert self.t2.c.a in set(c._create_result_map()['a'][1]) + assert self.t2.c.a in set(c._create_result_map()["a"][1]) def test_result_map_use_labels(self): s = self.t2.select(use_labels=True) c = s.compile(dialect=self._legacy_dialect()) - assert self.t2.c.a in set(c._create_result_map()['schema_t2_a'][1]) + assert self.t2.c.a in set(c._create_result_map()["schema_t2_a"][1]) def test_straight_select(self): self._assert_sql( self.t2.select(), "SELECT t2_1.a, t2_1.b, t2_1.c FROM [schema].t2 AS t2_1", "SELECT [schema].t2.a, [schema].t2.b, " - "[schema].t2.c FROM [schema].t2" + "[schema].t2.c FROM [schema].t2", ) def test_straight_select_use_labels(self): @@ -84,7 +90,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): "t2_1.c AS schema_t2_c FROM [schema].t2 AS t2_1", "SELECT [schema].t2.a AS schema_t2_a, " "[schema].t2.b AS schema_t2_b, " - "[schema].t2.c AS schema_t2_c FROM [schema].t2" + "[schema].t2.c AS schema_t2_c FROM [schema].t2", ) def test_join_to_schema(self): @@ -93,46 +99,46 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): t1.join(t2, t1.c.a == t2.c.a).select(), "SELECT t1.a, t1.b, t1.c, t2_1.a, t2_1.b, t2_1.c FROM t1 " "JOIN [schema].t2 AS t2_1 ON t2_1.a = t1.a", - "SELECT t1.a, t1.b, t1.c, [schema].t2.a, [schema].t2.b, " - "[schema].t2.c FROM t1 JOIN [schema].t2 ON [schema].t2.a = t1.a" + "[schema].t2.c FROM t1 JOIN [schema].t2 ON [schema].t2.a = t1.a", ) def test_union_schema_to_non(self): t1, t2 = self.t1, self.t2 - s = select([t2.c.a, t2.c.b]).apply_labels().\ - union( - select([t1.c.a, t1.c.b]).apply_labels()).alias().select() + s = ( + select([t2.c.a, t2.c.b]) + .apply_labels() + .union(select([t1.c.a, t1.c.b]).apply_labels()) + .alias() + .select() + ) self._assert_sql( s, "SELECT anon_1.schema_t2_a, anon_1.schema_t2_b FROM " "(SELECT t2_1.a AS schema_t2_a, t2_1.b AS schema_t2_b " "FROM [schema].t2 AS t2_1 UNION SELECT t1.a AS t1_a, " "t1.b AS t1_b FROM t1) AS anon_1", - "SELECT anon_1.schema_t2_a, anon_1.schema_t2_b FROM " "(SELECT [schema].t2.a AS schema_t2_a, [schema].t2.b AS " "schema_t2_b FROM [schema].t2 UNION SELECT t1.a AS t1_a, " - "t1.b AS t1_b FROM t1) AS anon_1" + "t1.b AS 
t1_b FROM t1) AS anon_1", ) def test_column_subquery_to_alias(self): - a1 = self.t2.alias('a1') + a1 = self.t2.alias("a1") s = select([self.t2, select([a1.c.a]).as_scalar()]) self._assert_sql( s, "SELECT t2_1.a, t2_1.b, t2_1.c, " "(SELECT a1.a FROM [schema].t2 AS a1) " "AS anon_1 FROM [schema].t2 AS t2_1", - "SELECT [schema].t2.a, [schema].t2.b, [schema].t2.c, " - "(SELECT a1.a FROM [schema].t2 AS a1) AS anon_1 FROM [schema].t2" - + "(SELECT a1.a FROM [schema].t2 AS a1) AS anon_1 FROM [schema].t2", ) class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = 'mssql' + __only_on__ = "mssql" __dialect__ = mssql.MSDialect() __backend__ = True @@ -141,11 +147,13 @@ class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL): global metadata, cattable metadata = MetaData(testing.db) - cattable = Table('cattable', metadata, - Column('id', Integer), - Column('description', String(50)), - PrimaryKeyConstraint('id', name='PK_cattable'), - ) + cattable = Table( + "cattable", + metadata, + Column("id", Integer), + Column("description", String(50)), + PrimaryKeyConstraint("id", name="PK_cattable"), + ) def setup(self): metadata.create_all() @@ -154,40 +162,48 @@ class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL): metadata.drop_all() def test_compiled(self): - self.assert_compile(cattable.insert().values(id=9, - description='Python'), - 'INSERT INTO cattable (id, description) ' - 'VALUES (:id, :description)') + self.assert_compile( + cattable.insert().values(id=9, description="Python"), + "INSERT INTO cattable (id, description) " + "VALUES (:id, :description)", + ) def test_execute(self): - cattable.insert().values(id=9, description='Python').execute() + cattable.insert().values(id=9, description="Python").execute() cats = cattable.select().order_by(cattable.c.id).execute() - eq_([(9, 'Python')], list(cats)) + eq_([(9, "Python")], list(cats)) - result = cattable.insert().values(description='PHP').execute() + result = cattable.insert().values(description="PHP").execute() eq_([10], result.inserted_primary_key) lastcat = cattable.select().order_by(desc(cattable.c.id)).execute() - eq_((10, 'PHP'), lastcat.first()) + eq_((10, "PHP"), lastcat.first()) def test_executemany(self): - cattable.insert().execute([{'id': 89, 'description': 'Python'}, - {'id': 8, 'description': 'Ruby'}, - {'id': 3, 'description': 'Perl'}, - {'id': 1, 'description': 'Java'}]) + cattable.insert().execute( + [ + {"id": 89, "description": "Python"}, + {"id": 8, "description": "Ruby"}, + {"id": 3, "description": "Perl"}, + {"id": 1, "description": "Java"}, + ] + ) cats = cattable.select().order_by(cattable.c.id).execute() - eq_([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], - list(cats)) - cattable.insert().execute([{'description': 'PHP'}, - {'description': 'Smalltalk'}]) - lastcats = \ + eq_( + [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")], list(cats) + ) + cattable.insert().execute( + [{"description": "PHP"}, {"description": "Smalltalk"}] + ) + lastcats = ( cattable.select().order_by(desc(cattable.c.id)).limit(2).execute() - eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) + ) + eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats)) class QueryUnicodeTest(fixtures.TestBase): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True @testing.requires.mssql_freetds @@ -195,29 +211,35 @@ class QueryUnicodeTest(fixtures.TestBase): def test_convert_unicode(self): meta = MetaData(testing.db) t1 = Table( - 'unitest_table', meta, - Column('id', Integer, primary_key=True), 
- Column('descr', mssql.MSText(convert_unicode=True))) + "unitest_table", + meta, + Column("id", Integer, primary_key=True), + Column("descr", mssql.MSText(convert_unicode=True)), + ) meta.create_all() con = testing.db.connect() # encode in UTF-8 (sting object) because this is the default # dialect encoding - con.execute(ue("insert into unitest_table values ('bien u\ - umang\xc3\xa9')").encode('UTF-8')) + con.execute( + ue( + "insert into unitest_table values ('bien u\ + umang\xc3\xa9')" + ).encode("UTF-8") + ) try: r = t1.select().execute().first() - assert isinstance(r[1], util.text_type), \ - '%s is %s instead of unicode, working on %s' % ( - r[1], - type(r[1]), meta.bind) + assert isinstance(r[1], util.text_type), ( + "%s is %s instead of unicode, working on %s" + % (r[1], type(r[1]), meta.bind) + ) finally: meta.drop_all() class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True def test_fetchid_trigger(self): @@ -253,31 +275,36 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): # than the SQL Server Native Client. Maybe an assert_raises # test should be written. meta = MetaData(testing.db) - t1 = Table('t1', meta, - Column('id', Integer, mssql_identity_start=100, - primary_key=True), - Column('descr', String(200)), - # the following flag will prevent the - # MSSQLCompiler.returning_clause from getting called, - # though the ExecutionContext will still have a - # _select_lastrowid, so the SELECT SCOPE_IDENTITY() will - # hopefully be called instead. - implicit_returning=False - ) - t2 = Table('t2', meta, - Column('id', Integer, mssql_identity_start=200, - primary_key=True), - Column('descr', String(200))) + t1 = Table( + "t1", + meta, + Column("id", Integer, mssql_identity_start=100, primary_key=True), + Column("descr", String(200)), + # the following flag will prevent the + # MSSQLCompiler.returning_clause from getting called, + # though the ExecutionContext will still have a + # _select_lastrowid, so the SELECT SCOPE_IDENTITY() will + # hopefully be called instead. 
+ implicit_returning=False, + ) + t2 = Table( + "t2", + meta, + Column("id", Integer, mssql_identity_start=200, primary_key=True), + Column("descr", String(200)), + ) meta.create_all() con = testing.db.connect() - con.execute("""create trigger paj on t1 for insert as - insert into t2 (descr) select descr from inserted""") + con.execute( + """create trigger paj on t1 for insert as + insert into t2 (descr) select descr from inserted""" + ) try: tr = con.begin() - r = con.execute(t2.insert(), descr='hello') + r = con.execute(t2.insert(), descr="hello") self.assert_(r.inserted_primary_key == [200]) - r = con.execute(t1.insert(), descr='hello') + r = con.execute(t1.insert(), descr="hello") self.assert_(r.inserted_primary_key == [100]) finally: @@ -290,10 +317,11 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): engine = engines.testing_engine(options={"use_scope_identity": False}) metadata = self.metadata t1 = Table( - 't1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - implicit_returning=False + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + implicit_returning=False, ) metadata.create_all(engine) @@ -303,8 +331,7 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): # TODO: need a dialect SQL that acts like Cursor SQL asserter.assert_( DialectSQL( - "INSERT INTO t1 (data) VALUES (:data)", - {"data": "somedata"} + "INSERT INTO t1 (data) VALUES (:data)", {"data": "somedata"} ), CursorSQL("SELECT @@identity AS lastrowid"), ) @@ -314,9 +341,10 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): engine = engines.testing_engine(options={"use_scope_identity": True}) metadata = self.metadata t1 = Table( - 't1', metadata, - Column('id', Integer, primary_key=True), - implicit_returning=False + "t1", + metadata, + Column("id", Integer, primary_key=True), + implicit_returning=False, ) metadata.create_all(engine) @@ -330,100 +358,108 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): CursorSQL("SELECT scope_identity() AS lastrowid"), ) - @testing.only_on('mssql+pyodbc') + @testing.only_on("mssql+pyodbc") @testing.provide_metadata def test_embedded_scope_identity(self): engine = engines.testing_engine(options={"use_scope_identity": True}) metadata = self.metadata t1 = Table( - 't1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - implicit_returning=False + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + implicit_returning=False, ) metadata.create_all(engine) with self.sql_execution_asserter(engine) as asserter: - engine.execute(t1.insert(), {'data': 'somedata'}) + engine.execute(t1.insert(), {"data": "somedata"}) # pyodbc-specific system asserter.assert_( CursorSQL( "INSERT INTO t1 (data) VALUES (?); select scope_identity()", - ("somedata", ) - ), + ("somedata",), + ) ) @testing.provide_metadata def test_insertid_schema(self): meta = self.metadata eng = engines.testing_engine( - options=dict(legacy_schema_aliasing=False)) + options=dict(legacy_schema_aliasing=False) + ) meta.bind = eng con = eng.connect() - con.execute('create schema paj') + con.execute("create schema paj") @event.listens_for(meta, "after_drop") def cleanup(target, connection, **kw): - connection.execute('drop schema paj') + connection.execute("drop schema paj") - tbl = Table('test', meta, - Column('id', Integer, primary_key=True), schema='paj') + tbl = Table( + "test", meta, Column("id", Integer, 
primary_key=True), schema="paj" + ) tbl.create() - tbl.insert().execute({'id': 1}) + tbl.insert().execute({"id": 1}) eq_(tbl.select().scalar(), 1) @testing.provide_metadata def test_insertid_schema_legacy(self): meta = self.metadata - eng = engines.testing_engine( - options=dict(legacy_schema_aliasing=True)) + eng = engines.testing_engine(options=dict(legacy_schema_aliasing=True)) meta.bind = eng con = eng.connect() - con.execute('create schema paj') + con.execute("create schema paj") @event.listens_for(meta, "after_drop") def cleanup(target, connection, **kw): - connection.execute('drop schema paj') + connection.execute("drop schema paj") - tbl = Table('test', meta, - Column('id', Integer, primary_key=True), schema='paj') + tbl = Table( + "test", meta, Column("id", Integer, primary_key=True), schema="paj" + ) tbl.create() - tbl.insert().execute({'id': 1}) + tbl.insert().execute({"id": 1}) eq_(tbl.select().scalar(), 1) @testing.provide_metadata def test_returning_no_autoinc(self): meta = self.metadata table = Table( - 't1', meta, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + "t1", + meta, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) table.create() - result = table.insert().values( - id=1, - data=func.lower('SomeString')).\ - returning(table.c.id, table.c.data).execute() - eq_(result.fetchall(), [(1, 'somestring')]) + result = ( + table.insert() + .values(id=1, data=func.lower("SomeString")) + .returning(table.c.id, table.c.data) + .execute() + ) + eq_(result.fetchall(), [(1, "somestring")]) @testing.provide_metadata def test_delete_schema(self): meta = self.metadata eng = engines.testing_engine( - options=dict(legacy_schema_aliasing=False)) + options=dict(legacy_schema_aliasing=False) + ) meta.bind = eng con = eng.connect() - con.execute('create schema paj') + con.execute("create schema paj") @event.listens_for(meta, "after_drop") def cleanup(target, connection, **kw): - connection.execute('drop schema paj') + connection.execute("drop schema paj") tbl = Table( - 'test', meta, - Column('id', Integer, primary_key=True), schema='paj') + "test", meta, Column("id", Integer, primary_key=True), schema="paj" + ) tbl.create() - tbl.insert().execute({'id': 1}) + tbl.insert().execute({"id": 1}) eq_(tbl.select().scalar(), 1) tbl.delete(tbl.c.id == 1).execute() eq_(tbl.select().scalar(), None) @@ -431,21 +467,20 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): @testing.provide_metadata def test_delete_schema_legacy(self): meta = self.metadata - eng = engines.testing_engine( - options=dict(legacy_schema_aliasing=True)) + eng = engines.testing_engine(options=dict(legacy_schema_aliasing=True)) meta.bind = eng con = eng.connect() - con.execute('create schema paj') + con.execute("create schema paj") @event.listens_for(meta, "after_drop") def cleanup(target, connection, **kw): - connection.execute('drop schema paj') + connection.execute("drop schema paj") tbl = Table( - 'test', meta, - Column('id', Integer, primary_key=True), schema='paj') + "test", meta, Column("id", Integer, primary_key=True), schema="paj" + ) tbl.create() - tbl.insert().execute({'id': 1}) + tbl.insert().execute({"id": 1}) eq_(tbl.select().scalar(), 1) tbl.delete(tbl.c.id == 1).execute() eq_(tbl.select().scalar(), None) @@ -453,10 +488,7 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): @testing.provide_metadata def test_insertid_reserved(self): meta = self.metadata - table = Table( - 'select', meta, - Column('col', Integer, 
primary_key=True) - ) + table = Table("select", meta, Column("col", Integer, primary_key=True)) table.create() table.insert().execute(col=7) @@ -464,7 +496,6 @@ class QueryTest(testing.AssertsExecutionResults, fixtures.TestBase): class Foo(object): - def __init__(self, **kw): for k in kw: setattr(self, k, kw[k]) @@ -477,8 +508,7 @@ def full_text_search_missing(): try: connection = testing.db.connect() try: - connection.execute('CREATE FULLTEXT CATALOG Catalog AS ' - 'DEFAULT') + connection.execute("CREATE FULLTEXT CATALOG Catalog AS " "DEFAULT") return False except Exception: return True @@ -488,44 +518,64 @@ def full_text_search_missing(): class MatchTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = 'mssql' - __skip_if__ = full_text_search_missing, + __only_on__ = "mssql" + __skip_if__ = (full_text_search_missing,) __backend__ = True @classmethod def setup_class(cls): global metadata, cattable, matchtable metadata = MetaData(testing.db) - cattable = Table('cattable', metadata, Column('id', Integer), - Column('description', String(50)), - PrimaryKeyConstraint('id', name='PK_cattable')) + cattable = Table( + "cattable", + metadata, + Column("id", Integer), + Column("description", String(50)), + PrimaryKeyConstraint("id", name="PK_cattable"), + ) matchtable = Table( - 'matchtable', + "matchtable", metadata, - Column('id', Integer), - Column('title', String(200)), - Column('category_id', Integer, ForeignKey('cattable.id')), - PrimaryKeyConstraint('id', name='PK_matchtable'), + Column("id", Integer), + Column("title", String(200)), + Column("category_id", Integer, ForeignKey("cattable.id")), + PrimaryKeyConstraint("id", name="PK_matchtable"), ) - DDL("""CREATE FULLTEXT INDEX + DDL( + """CREATE FULLTEXT INDEX ON cattable (description) - KEY INDEX PK_cattable""").\ - execute_at('after-create', matchtable) - DDL("""CREATE FULLTEXT INDEX + KEY INDEX PK_cattable""" + ).execute_at("after-create", matchtable) + DDL( + """CREATE FULLTEXT INDEX ON matchtable (title) - KEY INDEX PK_matchtable""").\ - execute_at('after-create', matchtable) + KEY INDEX PK_matchtable""" + ).execute_at("after-create", matchtable) metadata.create_all() - cattable.insert().execute([{'id': 1, 'description': 'Python'}, - {'id': 2, 'description': 'Ruby'}]) - matchtable.insert().execute([ - {'id': 1, 'title': 'Web Development with Rails', 'category_id': 2}, - {'id': 2, 'title': 'Dive Into Python', 'category_id': 1}, - {'id': 3, 'title': "Programming Matz's Ruby", 'category_id': 2}, - {'id': 4, 'title': 'Guide to Django', 'category_id': 1}, - {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1}]) - DDL("WAITFOR DELAY '00:00:05'" - ).execute(bind=engines.testing_engine()) + cattable.insert().execute( + [ + {"id": 1, "description": "Python"}, + {"id": 2, "description": "Ruby"}, + ] + ) + matchtable.insert().execute( + [ + { + "id": 1, + "title": "Web Development with Rails", + "category_id": 2, + }, + {"id": 2, "title": "Dive Into Python", "category_id": 1}, + { + "id": 3, + "title": "Programming Matz's Ruby", + "category_id": 2, + }, + {"id": 4, "title": "Guide to Django", "category_id": 1}, + {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, + ] + ) + DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine()) @classmethod def teardown_class(cls): @@ -535,65 +585,106 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): connection.close() def test_expression(self): - self.assert_compile(matchtable.c.title.match('somstr'), - 'CONTAINS (matchtable.title, ?)') + self.assert_compile( + 
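
# ---------------------------------------------------------------------
# Illustrative aside (editor's sketch, not part of the reformat): on
# this dialect, ColumnOperators.match() compiles to the T-SQL full-text
# CONTAINS() predicate, which is what the tests below feed with plain
# terms, prefix queries ('"nut*"'), and FORMSOF(INFLECTIONAL, ...):

from sqlalchemy import Integer, String
from sqlalchemy.dialects import mssql
from sqlalchemy.sql import column, table

mt = table(
    "matchtable", column("id", Integer), column("title", String(200))
)
# renders as CONTAINS (matchtable.title, ...), with the bound parameter
# token shown per the active paramstyle ("?" in the assertion below):
print(mt.c.title.match("somstr").compile(dialect=mssql.dialect()))
# ---------------------------------------------------------------------
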
matchtable.c.title.match("somstr"), + "CONTAINS (matchtable.title, ?)", + ) def test_simple_match(self): - results = \ - matchtable.select().where( - matchtable.c.title.match('python')).\ - order_by(matchtable.c.id).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("python")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([2, 5], [r.id for r in results]) def test_simple_match_with_apostrophe(self): - results = \ - matchtable.select().where( - matchtable.c.title.match("Matz's")).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("Matz's")) + .execute() + .fetchall() + ) eq_([3], [r.id for r in results]) def test_simple_prefix_match(self): - results = \ - matchtable.select().where( - matchtable.c.title.match('"nut*"')).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match('"nut*"')) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results]) def test_simple_inflectional_match(self): - results = \ - matchtable.select().where( - matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")' - )).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')) + .execute() + .fetchall() + ) eq_([2], [r.id for r in results]) def test_or_match(self): - results1 = \ - matchtable.select().where(or_( - matchtable.c.title.match('nutshell'), - matchtable.c.title.match('ruby'))).\ - order_by(matchtable.c.id).execute().fetchall() + results1 = ( + matchtable.select() + .where( + or_( + matchtable.c.title.match("nutshell"), + matchtable.c.title.match("ruby"), + ) + ) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([3, 5], [r.id for r in results1]) - results2 = \ - matchtable.select().where( - matchtable.c.title.match( - 'nutshell OR ruby')).\ - order_by(matchtable.c.id).execute().fetchall() + results2 = ( + matchtable.select() + .where(matchtable.c.title.match("nutshell OR ruby")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([3, 5], [r.id for r in results2]) def test_and_match(self): - results1 = \ - matchtable.select().where(and_( - matchtable.c.title.match('python'), - matchtable.c.title.match('nutshell'))).execute().fetchall() + results1 = ( + matchtable.select() + .where( + and_( + matchtable.c.title.match("python"), + matchtable.c.title.match("nutshell"), + ) + ) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results1]) - results2 = \ - matchtable.select().where( - matchtable.c.title.match('python AND nutshell' - )).execute().fetchall() + results2 = ( + matchtable.select() + .where(matchtable.c.title.match("python AND nutshell")) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results2]) def test_match_across_joins(self): - results = matchtable.select().where( - and_(cattable.c.id == matchtable.c.category_id, - or_(cattable.c.description.match('Ruby'), - matchtable.c.title.match('nutshell')))).\ - order_by(matchtable.c.id).execute().fetchall() + results = ( + matchtable.select() + .where( + and_( + cattable.c.id == matchtable.c.category_id, + or_( + cattable.c.description.match("Ruby"), + matchtable.c.title.match("nutshell"), + ), + ) + ) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([1, 3, 5], [r.id for r in results]) diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index e526168f15..47322ced64 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -3,9 
+3,7 @@ from sqlalchemy.testing import eq_, is_, in_ from sqlalchemy import * from sqlalchemy import types, schema, event from sqlalchemy.databases import mssql -from sqlalchemy.testing import (fixtures, - AssertsCompiledSQL, - ComparesTables) +from sqlalchemy.testing import fixtures, AssertsCompiledSQL, ComparesTables from sqlalchemy import testing from sqlalchemy.engine.reflection import Inspector from sqlalchemy import util @@ -13,8 +11,9 @@ from sqlalchemy.dialects.mssql.information_schema import CoerceUnicode, tables from sqlalchemy.dialects.mssql import base from sqlalchemy.testing import mock + class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True @testing.provide_metadata @@ -22,44 +21,49 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): meta = self.metadata users = Table( - 'engine_users', + "engine_users", meta, - Column('user_id', types.INT, primary_key=True), - Column('user_name', types.VARCHAR(20), nullable=False), - Column('test1', types.CHAR(5), nullable=False), - Column('test2', types.Float(5), nullable=False), - Column('test3', types.Text()), - Column('test4', types.Numeric, nullable=False), - Column('test5', types.DateTime), - Column('parent_user_id', types.Integer, - ForeignKey('engine_users.user_id')), - Column('test6', types.DateTime, nullable=False), - Column('test7', types.Text()), - Column('test8', types.LargeBinary()), - Column('test_passivedefault2', types.Integer, - server_default='5'), - Column('test9', types.BINARY(100)), - Column('test_numeric', types.Numeric()), + Column("user_id", types.INT, primary_key=True), + Column("user_name", types.VARCHAR(20), nullable=False), + Column("test1", types.CHAR(5), nullable=False), + Column("test2", types.Float(5), nullable=False), + Column("test3", types.Text()), + Column("test4", types.Numeric, nullable=False), + Column("test5", types.DateTime), + Column( + "parent_user_id", + types.Integer, + ForeignKey("engine_users.user_id"), + ), + Column("test6", types.DateTime, nullable=False), + Column("test7", types.Text()), + Column("test8", types.LargeBinary()), + Column("test_passivedefault2", types.Integer, server_default="5"), + Column("test9", types.BINARY(100)), + Column("test_numeric", types.Numeric()), ) addresses = Table( - 'engine_email_addresses', + "engine_email_addresses", meta, - Column('address_id', types.Integer, primary_key=True), - Column('remote_user_id', types.Integer, - ForeignKey(users.c.user_id)), - Column('email_address', types.String(20)), - ) + Column("address_id", types.Integer, primary_key=True), + Column( + "remote_user_id", types.Integer, ForeignKey(users.c.user_id) + ), + Column("email_address", types.String(20)), + ) meta.create_all() meta2 = MetaData() - reflected_users = Table('engine_users', meta2, - autoload=True, - autoload_with=testing.db) - reflected_addresses = Table('engine_email_addresses', - meta2, - autoload=True, - autoload_with=testing.db) + reflected_users = Table( + "engine_users", meta2, autoload=True, autoload_with=testing.db + ) + reflected_addresses = Table( + "engine_email_addresses", + meta2, + autoload=True, + autoload_with=testing.db, + ) self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) @@ -67,17 +71,14 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): def _test_specific_type(self, type_obj, ddl): metadata = self.metadata - table = Table( - 'type_test', metadata, - 
Column('col1', type_obj) - ) + table = Table("type_test", metadata, Column("col1", type_obj)) table.create() m2 = MetaData() - table2 = Table('type_test', m2, autoload_with=testing.db) + table2 = Table("type_test", m2, autoload_with=testing.db) self.assert_compile( schema.CreateTable(table2), - "CREATE TABLE type_test (col1 %s NULL)" % ddl + "CREATE TABLE type_test (col1 %s NULL)" % ddl, ) def test_xml_type(self): @@ -93,31 +94,36 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): def test_identity(self): metadata = self.metadata table = Table( - 'identity_test', metadata, - Column('col1', Integer, mssql_identity_start=2, - mssql_identity_increment=3, primary_key=True) + "identity_test", + metadata, + Column( + "col1", + Integer, + mssql_identity_start=2, + mssql_identity_increment=3, + primary_key=True, + ), ) table.create() meta2 = MetaData(testing.db) - table2 = Table('identity_test', meta2, autoload=True) - eq_( - table2.c['col1'].dialect_options['mssql'][ - 'identity_start'], 2) - eq_( - table2.c['col1'].dialect_options['mssql'][ - 'identity_increment'], 3) + table2 = Table("identity_test", meta2, autoload=True) + eq_(table2.c["col1"].dialect_options["mssql"]["identity_start"], 2) + eq_(table2.c["col1"].dialect_options["mssql"]["identity_increment"], 3) @testing.emits_warning("Did not recognize") @testing.provide_metadata def test_skip_types(self): metadata = self.metadata - testing.db.execute(""" + testing.db.execute( + """ create table foo (id integer primary key, data xml) - """) + """ + ) with mock.patch.object( - testing.db.dialect, "ischema_names", {"int": mssql.INTEGER}): - t1 = Table('foo', metadata, autoload=True) + testing.db.dialect, "ischema_names", {"int": mssql.INTEGER} + ): + t1 = Table("foo", metadata, autoload=True) assert isinstance(t1.c.id.type, Integer) assert isinstance(t1.c.data.type, types.NullType) @@ -127,29 +133,33 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): metadata = self.metadata Table( - "subject", metadata, + "subject", + metadata, Column("id", Integer), PrimaryKeyConstraint("id", name="subj_pk"), schema=testing.config.test_schema, ) Table( - "referrer", metadata, + "referrer", + metadata, Column("id", Integer, primary_key=True), Column( - 'sid', + "sid", ForeignKey( "%s.subject.id" % testing.config.test_schema, - name='fk_subject') + name="fk_subject", + ), ), - schema=testing.config.test_schema + schema=testing.config.test_schema, ) Table( - "subject", metadata, + "subject", + metadata, Column("id", Integer), PrimaryKeyConstraint("id", name="subj_pk"), - schema=testing.config.test_schema_2 + schema=testing.config.test_schema_2, ) metadata.create_all() @@ -157,102 +167,111 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): insp = inspect(testing.db) eq_( insp.get_foreign_keys("referrer", testing.config.test_schema), - [{ - 'name': 'fk_subject', - 'constrained_columns': ['sid'], - 'referred_schema': 'test_schema', - 'referred_table': 'subject', - 'referred_columns': ['id']}] + [ + { + "name": "fk_subject", + "constrained_columns": ["sid"], + "referred_schema": "test_schema", + "referred_table": "subject", + "referred_columns": ["id"], + } + ], ) @testing.provide_metadata def test_db_qualified_items(self): metadata = self.metadata - Table('foo', metadata, Column('id', Integer, primary_key=True)) - Table('bar', - metadata, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id', name="fkfoo"))) + Table("foo", metadata, 
Column("id", Integer, primary_key=True)) + Table( + "bar", + metadata, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id", name="fkfoo")), + ) metadata.create_all() dbname = testing.db.scalar("select db_name()") owner = testing.db.scalar("SELECT user_name()") - referred_schema = '%(dbname)s.%(owner)s' % { - "dbname": dbname, "owner": owner} + referred_schema = "%(dbname)s.%(owner)s" % { + "dbname": dbname, + "owner": owner, + } inspector = inspect(testing.db) - bar_via_db = inspector.get_foreign_keys( - "bar", schema=referred_schema) + bar_via_db = inspector.get_foreign_keys("bar", schema=referred_schema) eq_( bar_via_db, - [{ - 'referred_table': 'foo', - 'referred_columns': ['id'], - 'referred_schema': referred_schema, - 'name': 'fkfoo', - 'constrained_columns': ['foo_id']}] + [ + { + "referred_table": "foo", + "referred_columns": ["id"], + "referred_schema": referred_schema, + "name": "fkfoo", + "constrained_columns": ["foo_id"], + } + ], ) assert testing.db.has_table("bar", schema=referred_schema) m2 = MetaData() - Table('bar', m2, schema=referred_schema, autoload=True, - autoload_with=testing.db) + Table( + "bar", + m2, + schema=referred_schema, + autoload=True, + autoload_with=testing.db, + ) eq_(m2.tables["%s.foo" % referred_schema].schema, referred_schema) @testing.provide_metadata def test_indexes_cols(self): metadata = self.metadata - t1 = Table('t', metadata, Column('x', Integer), Column('y', Integer)) - Index('foo', t1.c.x, t1.c.y) + t1 = Table("t", metadata, Column("x", Integer), Column("y", Integer)) + Index("foo", t1.c.x, t1.c.y) metadata.create_all() m2 = MetaData() - t2 = Table('t', m2, autoload=True, autoload_with=testing.db) + t2 = Table("t", m2, autoload=True, autoload_with=testing.db) - eq_( - set(list(t2.indexes)[0].columns), - set([t2.c['x'], t2.c.y]) - ) + eq_(set(list(t2.indexes)[0].columns), set([t2.c["x"], t2.c.y])) @testing.provide_metadata def test_indexes_cols_with_commas(self): metadata = self.metadata - t1 = Table('t', - metadata, - Column('x, col', Integer, key='x'), - Column('y', Integer)) - Index('foo', t1.c.x, t1.c.y) + t1 = Table( + "t", + metadata, + Column("x, col", Integer, key="x"), + Column("y", Integer), + ) + Index("foo", t1.c.x, t1.c.y) metadata.create_all() m2 = MetaData() - t2 = Table('t', m2, autoload=True, autoload_with=testing.db) + t2 = Table("t", m2, autoload=True, autoload_with=testing.db) - eq_( - set(list(t2.indexes)[0].columns), - set([t2.c['x, col'], t2.c.y]) - ) + eq_(set(list(t2.indexes)[0].columns), set([t2.c["x, col"], t2.c.y])) @testing.provide_metadata def test_indexes_cols_with_spaces(self): metadata = self.metadata - t1 = Table('t', - metadata, - Column('x col', Integer, key='x'), - Column('y', Integer)) - Index('foo', t1.c.x, t1.c.y) + t1 = Table( + "t", + metadata, + Column("x col", Integer, key="x"), + Column("y", Integer), + ) + Index("foo", t1.c.x, t1.c.y) metadata.create_all() m2 = MetaData() - t2 = Table('t', m2, autoload=True, autoload_with=testing.db) + t2 = Table("t", m2, autoload=True, autoload_with=testing.db) - eq_( - set(list(t2.indexes)[0].columns), - set([t2.c['x col'], t2.c.y]) - ) + eq_(set(list(t2.indexes)[0].columns), set([t2.c["x col"], t2.c.y])) @testing.provide_metadata def test_max_ident_in_varchar_not_present(self): @@ -267,75 +286,81 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): metadata = self.metadata Table( - 't', metadata, - Column('t1', types.String), - Column('t2', types.Text('max')), - Column('t3', types.Text('max')), - 
Column('t4', types.LargeBinary('max')), - Column('t5', types.VARBINARY('max')), + "t", + metadata, + Column("t1", types.String), + Column("t2", types.Text("max")), + Column("t3", types.Text("max")), + Column("t4", types.LargeBinary("max")), + Column("t5", types.VARBINARY("max")), ) metadata.create_all() - for col in inspect(testing.db).get_columns('t'): - is_(col['type'].length, None) - in_('max', str(col['type'].compile(dialect=testing.db.dialect))) + for col in inspect(testing.db).get_columns("t"): + is_(col["type"].length, None) + in_("max", str(col["type"].compile(dialect=testing.db.dialect))) class InfoCoerceUnicodeTest(fixtures.TestBase, AssertsCompiledSQL): def test_info_unicode_coercion(self): dialect = mssql.dialect() - value = CoerceUnicode().bind_processor(dialect)('a string') + value = CoerceUnicode().bind_processor(dialect)("a string") assert isinstance(value, util.text_type) def test_info_unicode_cast_no_2000(self): dialect = mssql.dialect() dialect.server_version_info = base.MS_2000_VERSION - stmt = tables.c.table_name == 'somename' + stmt = tables.c.table_name == "somename" self.assert_compile( stmt, "[INFORMATION_SCHEMA].[TABLES].[TABLE_NAME] = :table_name_1", - dialect=dialect + dialect=dialect, ) def test_info_unicode_cast(self): dialect = mssql.dialect() dialect.server_version_info = base.MS_2005_VERSION - stmt = tables.c.table_name == 'somename' + stmt = tables.c.table_name == "somename" self.assert_compile( stmt, "[INFORMATION_SCHEMA].[TABLES].[TABLE_NAME] = " "CAST(:table_name_1 AS NVARCHAR(max))", - dialect=dialect + dialect=dialect, ) class ReflectHugeViewTest(fixtures.TestBase): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True # crashes on freetds 0.91, not worth it - __skip_if__ = ( - lambda: testing.requires.mssql_freetds.enabled, - ) + __skip_if__ = (lambda: testing.requires.mssql_freetds.enabled,) def setup(self): self.col_num = 150 self.metadata = MetaData(testing.db) - t = Table('base_table', self.metadata, - *[Column("long_named_column_number_%d" % i, Integer) + t = Table( + "base_table", + self.metadata, + *[ + Column("long_named_column_number_%d" % i, Integer) + for i in range(self.col_num) + ] + ) + self.view_str = view_str = ( + "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" + % ( + ",".join( + "long_named_column_number_%d" % i for i in range(self.col_num) - ] - ) - self.view_str = view_str = \ - "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" % ( - ",".join("long_named_column_number_%d" % i - for i in range(self.col_num)) + ) ) + ) assert len(view_str) > 4000 - event.listen(t, 'after_create', DDL(view_str)) - event.listen(t, 'before_drop', DDL("DROP VIEW huge_named_view")) + event.listen(t, "after_create", DDL(view_str)) + event.listen(t, "before_drop", DDL("DROP VIEW huge_named_view")) self.metadata.create_all() diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index ac45566719..260d6d0dec 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -3,18 +3,38 @@ from sqlalchemy.testing import eq_, engines, pickleable, assert_raises_message from sqlalchemy.testing import is_, is_not_ import datetime import os -from sqlalchemy import Table, Column, MetaData, Float, \ - Integer, String, Boolean, Sequence, Numeric, select, \ - Date, Time, DateTime, DefaultClause, PickleType, text, Text, \ - UnicodeText, LargeBinary +from sqlalchemy import ( + Table, + Column, + MetaData, + Float, + Integer, + String, + Boolean, + Sequence, + Numeric, + select, + Date, + 
Time, + DateTime, + DefaultClause, + PickleType, + text, + Text, + UnicodeText, + LargeBinary, +) from sqlalchemy.dialects.mssql import TIMESTAMP, ROWVERSION from sqlalchemy import types, schema from sqlalchemy import util from sqlalchemy.databases import mssql from sqlalchemy.dialects.mssql.base import TIME, _MSDate from sqlalchemy.dialects.mssql.base import MS_2005_VERSION, MS_2008_VERSION -from sqlalchemy.testing import fixtures, \ - AssertsExecutionResults, ComparesTables +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + ComparesTables, +) from sqlalchemy import testing from sqlalchemy.testing import emits_warning_on import decimal @@ -26,15 +46,14 @@ import codecs class TimeTypeTest(fixtures.TestBase): - def test_result_processor_no_microseconds(self): expected = datetime.time(12, 34, 56) - self._assert_result_processor(expected, '12:34:56') + self._assert_result_processor(expected, "12:34:56") def test_result_processor_too_many_microseconds(self): # microsecond must be in 0..999999, should truncate (6 vs 7 digits) expected = datetime.time(12, 34, 56, 123456) - self._assert_result_processor(expected, '12:34:56.1234567') + self._assert_result_processor(expected, "12:34:56.1234567") def _assert_result_processor(self, expected, value): mssql_time_type = TIME() @@ -47,17 +66,18 @@ class TimeTypeTest(fixtures.TestBase): assert_raises_message( ValueError, "could not parse 'abc' as a time value", - result_processor, 'abc' + result_processor, + "abc", ) class MSDateTypeTest(fixtures.TestBase): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True def test_result_processor(self): expected = datetime.date(2000, 1, 2) - self._assert_result_processor(expected, '2000-01-02') + self._assert_result_processor(expected, "2000-01-02") def _assert_result_processor(self, expected, value): mssql_date_type = _MSDate() @@ -70,52 +90,53 @@ class MSDateTypeTest(fixtures.TestBase): assert_raises_message( ValueError, "could not parse 'abc' as a date value", - result_processor, 'abc' + result_processor, + "abc", ) def test_extract(self): from sqlalchemy import extract - fivedaysago = datetime.datetime.now() \ - - datetime.timedelta(days=5) - for field, exp in ('year', fivedaysago.year), \ - ('month', fivedaysago.month), ('day', fivedaysago.day): + + fivedaysago = datetime.datetime.now() - datetime.timedelta(days=5) + for field, exp in ( + ("year", fivedaysago.year), + ("month", fivedaysago.month), + ("day", fivedaysago.day), + ): r = testing.db.execute( - select([ - extract(field, fivedaysago)]) + select([extract(field, fivedaysago)]) ).scalar() eq_(r, exp) class RowVersionTest(fixtures.TablesTest): - __only_on__ = 'mssql' + __only_on__ = "mssql" __backend__ = True @classmethod def define_tables(cls, metadata): Table( - 'rv_t', metadata, - Column('data', String(50)), - Column('rv', ROWVERSION) + "rv_t", + metadata, + Column("data", String(50)), + Column("rv", ROWVERSION), ) Table( - 'ts_t', metadata, - Column('data', String(50)), - Column('rv', TIMESTAMP) + "ts_t", + metadata, + Column("data", String(50)), + Column("rv", TIMESTAMP), ) def test_rowversion_reflection(self): # ROWVERSION is only a synonym for TIMESTAMP insp = inspect(testing.db) - assert isinstance( - insp.get_columns('rv_t')[1]['type'], TIMESTAMP - ) + assert isinstance(insp.get_columns("rv_t")[1]["type"], TIMESTAMP) def test_timestamp_reflection(self): insp = inspect(testing.db) - assert isinstance( - insp.get_columns('ts_t')[1]['type'], TIMESTAMP - ) + assert isinstance(insp.get_columns("ts_t")[1]["type"], 
TIMESTAMP) def test_class_hierarchy(self): """TIMESTAMP and ROWVERSION aren't datetime types, they're binary.""" @@ -124,38 +145,40 @@ class RowVersionTest(fixtures.TablesTest): assert issubclass(ROWVERSION, sqltypes._Binary) def test_round_trip_ts(self): - self._test_round_trip('ts_t', TIMESTAMP, False) + self._test_round_trip("ts_t", TIMESTAMP, False) def test_round_trip_rv(self): - self._test_round_trip('rv_t', ROWVERSION, False) + self._test_round_trip("rv_t", ROWVERSION, False) def test_round_trip_ts_int(self): - self._test_round_trip('ts_t', TIMESTAMP, True) + self._test_round_trip("ts_t", TIMESTAMP, True) def test_round_trip_rv_int(self): - self._test_round_trip('rv_t', ROWVERSION, True) + self._test_round_trip("rv_t", ROWVERSION, True) def _test_round_trip(self, tab, cls, convert_int): t = Table( - tab, MetaData(), - Column('data', String(50)), - Column('rv', cls(convert_int=convert_int)) + tab, + MetaData(), + Column("data", String(50)), + Column("rv", cls(convert_int=convert_int)), ) with testing.db.connect() as conn: - conn.execute(t.insert().values(data='foo')) + conn.execute(t.insert().values(data="foo")) last_ts_1 = conn.scalar("SELECT @@DBTS") if convert_int: - last_ts_1 = int(codecs.encode(last_ts_1, 'hex'), 16) + last_ts_1 = int(codecs.encode(last_ts_1, "hex"), 16) eq_(conn.scalar(select([t.c.rv])), last_ts_1) conn.execute( - t.update().values(data='bar').where(t.c.data == 'foo')) + t.update().values(data="bar").where(t.c.data == "foo") + ) last_ts_2 = conn.scalar("SELECT @@DBTS") if convert_int: - last_ts_2 = int(codecs.encode(last_ts_2, 'hex'), 16) + last_ts_2 = int(codecs.encode(last_ts_2, "hex"), 16) eq_(conn.scalar(select([t.c.rv])), last_ts_2) @@ -171,27 +194,26 @@ class RowVersionTest(fixtures.TablesTest): sa.exc.DBAPIError, r".*Cannot insert an explicit value into a timestamp column.", conn.execute, - tab.insert().values(data='ins', rv=b'000') + tab.insert().values(data="ins", rv=b"000"), ) class TypeDDLTest(fixtures.TestBase): - def test_boolean(self): "Exercise type specification for boolean type."
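The spec tuples that follow (here and in the sibling tests) pair each type with the exact DDL string the MSSQL dialect is expected to render. The same rendering can be checked outside the test harness with TypeEngine.compile(); a minimal sketch against a plain mssql dialect, with expected output taken from the tuples themselves:

from sqlalchemy import Boolean, Integer
from sqlalchemy.dialects import mssql

# SQL Server has no native boolean type; the dialect renders Boolean as BIT.
print(Boolean().compile(dialect=mssql.dialect()))  # BIT
print(Integer().compile(dialect=mssql.dialect()))  # INTEGER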
columns = [ # column type, args, kwargs, expected ddl - (Boolean, [], {}, - 'BIT'), + (Boolean, [], {}, "BIT") ] metadata = MetaData() - table_args = ['test_mssql_boolean', metadata] + table_args = ["test_mssql_boolean", metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append( - Column('c%s' % index, type_(*args, **kw), nullable=None)) + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) boolean_table = Table(*table_args) dialect = mssql.dialect() @@ -201,7 +223,8 @@ class TypeDDLTest(fixtures.TestBase): index = int(col.name[1:]) testing.eq_( gen.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) def test_numeric(self): @@ -209,38 +232,26 @@ class TypeDDLTest(fixtures.TestBase): columns = [ # column type, args, kwargs, expected ddl - (types.NUMERIC, [], {}, - 'NUMERIC'), - (types.NUMERIC, [None], {}, - 'NUMERIC'), - (types.NUMERIC, [12, 4], {}, - 'NUMERIC(12, 4)'), - - (types.Float, [], {}, - 'FLOAT'), - (types.Float, [None], {}, - 'FLOAT'), - (types.Float, [12], {}, - 'FLOAT(12)'), - (mssql.MSReal, [], {}, - 'REAL'), - - (types.Integer, [], {}, - 'INTEGER'), - (types.BigInteger, [], {}, - 'BIGINT'), - (mssql.MSTinyInteger, [], {}, - 'TINYINT'), - (types.SmallInteger, [], {}, - 'SMALLINT'), + (types.NUMERIC, [], {}, "NUMERIC"), + (types.NUMERIC, [None], {}, "NUMERIC"), + (types.NUMERIC, [12, 4], {}, "NUMERIC(12, 4)"), + (types.Float, [], {}, "FLOAT"), + (types.Float, [None], {}, "FLOAT"), + (types.Float, [12], {}, "FLOAT(12)"), + (mssql.MSReal, [], {}, "REAL"), + (types.Integer, [], {}, "INTEGER"), + (types.BigInteger, [], {}, "BIGINT"), + (mssql.MSTinyInteger, [], {}, "TINYINT"), + (types.SmallInteger, [], {}, "SMALLINT"), ] metadata = MetaData() - table_args = ['test_mssql_numeric', metadata] + table_args = ["test_mssql_numeric", metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append( - Column('c%s' % index, type_(*args, **kw), nullable=None)) + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) numeric_table = Table(*table_args) dialect = mssql.dialect() @@ -250,58 +261,69 @@ class TypeDDLTest(fixtures.TestBase): index = int(col.name[1:]) testing.eq_( gen.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) def test_char(self): """Exercise COLLATE-ish options on string types.""" columns = [ - (mssql.MSChar, [], {}, - 'CHAR'), - (mssql.MSChar, [1], {}, - 'CHAR(1)'), - (mssql.MSChar, [1], {'collation': 'Latin1_General_CI_AS'}, - 'CHAR(1) COLLATE Latin1_General_CI_AS'), - - (mssql.MSNChar, [], {}, - 'NCHAR'), - (mssql.MSNChar, [1], {}, - 'NCHAR(1)'), - (mssql.MSNChar, [1], {'collation': 'Latin1_General_CI_AS'}, - 'NCHAR(1) COLLATE Latin1_General_CI_AS'), - - (mssql.MSString, [], {}, - 'VARCHAR(max)'), - (mssql.MSString, [1], {}, - 'VARCHAR(1)'), - (mssql.MSString, [1], {'collation': 'Latin1_General_CI_AS'}, - 'VARCHAR(1) COLLATE Latin1_General_CI_AS'), - - (mssql.MSNVarchar, [], {}, - 'NVARCHAR(max)'), - (mssql.MSNVarchar, [1], {}, - 'NVARCHAR(1)'), - (mssql.MSNVarchar, [1], {'collation': 'Latin1_General_CI_AS'}, - 'NVARCHAR(1) COLLATE Latin1_General_CI_AS'), - - (mssql.MSText, [], {}, - 'TEXT'), - (mssql.MSText, [], {'collation': 'Latin1_General_CI_AS'}, - 'TEXT COLLATE Latin1_General_CI_AS'), - - (mssql.MSNText, [], {}, - 'NTEXT'), - (mssql.MSNText, [], {'collation': 'Latin1_General_CI_AS'}, - 'NTEXT COLLATE 
Latin1_General_CI_AS'), + (mssql.MSChar, [], {}, "CHAR"), + (mssql.MSChar, [1], {}, "CHAR(1)"), + ( + mssql.MSChar, + [1], + {"collation": "Latin1_General_CI_AS"}, + "CHAR(1) COLLATE Latin1_General_CI_AS", + ), + (mssql.MSNChar, [], {}, "NCHAR"), + (mssql.MSNChar, [1], {}, "NCHAR(1)"), + ( + mssql.MSNChar, + [1], + {"collation": "Latin1_General_CI_AS"}, + "NCHAR(1) COLLATE Latin1_General_CI_AS", + ), + (mssql.MSString, [], {}, "VARCHAR(max)"), + (mssql.MSString, [1], {}, "VARCHAR(1)"), + ( + mssql.MSString, + [1], + {"collation": "Latin1_General_CI_AS"}, + "VARCHAR(1) COLLATE Latin1_General_CI_AS", + ), + (mssql.MSNVarchar, [], {}, "NVARCHAR(max)"), + (mssql.MSNVarchar, [1], {}, "NVARCHAR(1)"), + ( + mssql.MSNVarchar, + [1], + {"collation": "Latin1_General_CI_AS"}, + "NVARCHAR(1) COLLATE Latin1_General_CI_AS", + ), + (mssql.MSText, [], {}, "TEXT"), + ( + mssql.MSText, + [], + {"collation": "Latin1_General_CI_AS"}, + "TEXT COLLATE Latin1_General_CI_AS", + ), + (mssql.MSNText, [], {}, "NTEXT"), + ( + mssql.MSNText, + [], + {"collation": "Latin1_General_CI_AS"}, + "NTEXT COLLATE Latin1_General_CI_AS", + ), ] metadata = MetaData() - table_args = ['test_mssql_charset', metadata] + table_args = ["test_mssql_charset", metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append( - Column('c%s' % index, type_(*args, **kw), nullable=None)) + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) charset_table = Table(*table_args) dialect = mssql.dialect() @@ -311,7 +333,8 @@ class TypeDDLTest(fixtures.TestBase): index = int(col.name[1:]) testing.eq_( gen.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) def test_dates(self): @@ -319,62 +342,35 @@ class TypeDDLTest(fixtures.TestBase): columns = [ # column type, args, kwargs, expected ddl - (mssql.MSDateTime, [], {}, - 'DATETIME', None), - - (types.DATE, [], {}, - 'DATE', None), - (types.Date, [], {}, - 'DATE', None), - (types.Date, [], {}, - 'DATETIME', MS_2005_VERSION), - (mssql.MSDate, [], {}, - 'DATE', None), - (mssql.MSDate, [], {}, - 'DATETIME', MS_2005_VERSION), - - (types.TIME, [], {}, - 'TIME', None), - (types.Time, [], {}, - 'TIME', None), - (mssql.MSTime, [], {}, - 'TIME', None), - (mssql.MSTime, [1], {}, - 'TIME(1)', None), - (types.Time, [], {}, - 'DATETIME', MS_2005_VERSION), - (mssql.MSTime, [], {}, - 'TIME', None), - - (mssql.MSSmallDateTime, [], {}, - 'SMALLDATETIME', None), - - (mssql.MSDateTimeOffset, [], {}, - 'DATETIMEOFFSET', None), - (mssql.MSDateTimeOffset, [1], {}, - 'DATETIMEOFFSET(1)', None), - - (mssql.MSDateTime2, [], {}, - 'DATETIME2', None), - (mssql.MSDateTime2, [0], {}, - 'DATETIME2(0)', None), - (mssql.MSDateTime2, [1], {}, - 'DATETIME2(1)', None), - - (mssql.MSTime, [0], {}, - 'TIME(0)', None), - - (mssql.MSDateTimeOffset, [0], {}, - 'DATETIMEOFFSET(0)', None), - + (mssql.MSDateTime, [], {}, "DATETIME", None), + (types.DATE, [], {}, "DATE", None), + (types.Date, [], {}, "DATE", None), + (types.Date, [], {}, "DATETIME", MS_2005_VERSION), + (mssql.MSDate, [], {}, "DATE", None), + (mssql.MSDate, [], {}, "DATETIME", MS_2005_VERSION), + (types.TIME, [], {}, "TIME", None), + (types.Time, [], {}, "TIME", None), + (mssql.MSTime, [], {}, "TIME", None), + (mssql.MSTime, [1], {}, "TIME(1)", None), + (types.Time, [], {}, "DATETIME", MS_2005_VERSION), + (mssql.MSTime, [], {}, "TIME", None), + (mssql.MSSmallDateTime, [], {}, "SMALLDATETIME", None), + (mssql.MSDateTimeOffset, [], {}, 
"DATETIMEOFFSET", None), + (mssql.MSDateTimeOffset, [1], {}, "DATETIMEOFFSET(1)", None), + (mssql.MSDateTime2, [], {}, "DATETIME2", None), + (mssql.MSDateTime2, [0], {}, "DATETIME2(0)", None), + (mssql.MSDateTime2, [1], {}, "DATETIME2(1)", None), + (mssql.MSTime, [0], {}, "TIME(0)", None), + (mssql.MSDateTimeOffset, [0], {}, "DATETIMEOFFSET(0)", None), ] metadata = MetaData() - table_args = ['test_mssql_dates', metadata] + table_args = ["test_mssql_dates", metadata] for index, spec in enumerate(columns): type_, args, kw, res, server_version = spec table_args.append( - Column('c%s' % index, type_(*args, **kw), nullable=None)) + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) date_table = Table(*table_args) dialect = mssql.dialect() @@ -383,7 +379,8 @@ class TypeDDLTest(fixtures.TestBase): ms_2005_dialect.server_version_info = MS_2005_VERSION gen = dialect.ddl_compiler(dialect, schema.CreateTable(date_table)) gen2005 = ms_2005_dialect.ddl_compiler( - ms_2005_dialect, schema.CreateTable(date_table)) + ms_2005_dialect, schema.CreateTable(date_table) + ) for col in date_table.c: index = int(col.name[1:]) @@ -391,11 +388,13 @@ class TypeDDLTest(fixtures.TestBase): if not server_version: testing.eq_( gen.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + "%s %s" % (col.name, columns[index][3]), + ) else: testing.eq_( gen2005.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) @@ -410,52 +409,38 @@ class TypeDDLTest(fixtures.TestBase): d4._setup_version_attributes() for dialect in (d1, d3): - eq_( - str(Text().compile(dialect=dialect)), - "VARCHAR(max)" - ) - eq_( - str(UnicodeText().compile(dialect=dialect)), - "NVARCHAR(max)" - ) - eq_( - str(LargeBinary().compile(dialect=dialect)), - "VARBINARY(max)" - ) + eq_(str(Text().compile(dialect=dialect)), "VARCHAR(max)") + eq_(str(UnicodeText().compile(dialect=dialect)), "NVARCHAR(max)") + eq_(str(LargeBinary().compile(dialect=dialect)), "VARBINARY(max)") for dialect in (d2, d4): - eq_( - str(Text().compile(dialect=dialect)), - "TEXT" - ) - eq_( - str(UnicodeText().compile(dialect=dialect)), - "NTEXT" - ) - eq_( - str(LargeBinary().compile(dialect=dialect)), - "IMAGE" - ) + eq_(str(Text().compile(dialect=dialect)), "TEXT") + eq_(str(UnicodeText().compile(dialect=dialect)), "NTEXT") + eq_(str(LargeBinary().compile(dialect=dialect)), "IMAGE") def test_money(self): """Exercise type specification for money types.""" - columns = [(mssql.MSMoney, [], {}, 'MONEY'), - (mssql.MSSmallMoney, [], {}, 'SMALLMONEY')] + columns = [ + (mssql.MSMoney, [], {}, "MONEY"), + (mssql.MSSmallMoney, [], {}, "SMALLMONEY"), + ] metadata = MetaData() - table_args = ['test_mssql_money', metadata] + table_args = ["test_mssql_money", metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec - table_args.append(Column('c%s' % index, type_(*args, **kw), - nullable=None)) + table_args.append( + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) money_table = Table(*table_args) dialect = mssql.dialect() - gen = dialect.ddl_compiler(dialect, - schema.CreateTable(money_table)) + gen = dialect.ddl_compiler(dialect, schema.CreateTable(money_table)) for col in money_table.c: index = int(col.name[1:]) - testing.eq_(gen.get_column_specification(col), '%s %s' - % (col.name, columns[index][3])) + testing.eq_( + gen.get_column_specification(col), + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) def test_binary(self): 
@@ -463,50 +448,35 @@ class TypeDDLTest(fixtures.TestBase): columns = [ # column type, args, kwargs, expected ddl - (mssql.MSBinary, [], {}, - 'BINARY'), - (mssql.MSBinary, [10], {}, - 'BINARY(10)'), - - (types.BINARY, [], {}, - 'BINARY'), - (types.BINARY, [10], {}, - 'BINARY(10)'), - - (mssql.MSVarBinary, [], {}, - 'VARBINARY(max)'), - (mssql.MSVarBinary, [10], {}, - 'VARBINARY(10)'), - - (types.VARBINARY, [10], {}, - 'VARBINARY(10)'), - (types.VARBINARY, [], {}, - 'VARBINARY(max)'), - - (mssql.MSImage, [], {}, - 'IMAGE'), - - (mssql.IMAGE, [], {}, - 'IMAGE'), - - (types.LargeBinary, [], {}, - 'IMAGE'), + (mssql.MSBinary, [], {}, "BINARY"), + (mssql.MSBinary, [10], {}, "BINARY(10)"), + (types.BINARY, [], {}, "BINARY"), + (types.BINARY, [10], {}, "BINARY(10)"), + (mssql.MSVarBinary, [], {}, "VARBINARY(max)"), + (mssql.MSVarBinary, [10], {}, "VARBINARY(10)"), + (types.VARBINARY, [10], {}, "VARBINARY(10)"), + (types.VARBINARY, [], {}, "VARBINARY(max)"), + (mssql.MSImage, [], {}, "IMAGE"), + (mssql.IMAGE, [], {}, "IMAGE"), + (types.LargeBinary, [], {}, "IMAGE"), ] metadata = MetaData() - table_args = ['test_mssql_binary', metadata] + table_args = ["test_mssql_binary", metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec - table_args.append(Column('c%s' % index, type_(*args, **kw), - nullable=None)) + table_args.append( + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) binary_table = Table(*table_args) dialect = mssql.dialect() - gen = dialect.ddl_compiler(dialect, - schema.CreateTable(binary_table)) + gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table)) for col in binary_table.c: index = int(col.name[1:]) - testing.eq_(gen.get_column_specification(col), '%s %s' - % (col.name, columns[index][3])) + testing.eq_( + gen.get_column_specification(col), + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) @@ -514,8 +484,9 @@ metadata = None class TypeRoundTripTest( - fixtures.TestBase, AssertsExecutionResults, ComparesTables): - __only_on__ = 'mssql' + fixtures.TestBase, AssertsExecutionResults, ComparesTables +): + __only_on__ = "mssql" __backend__ = True @@ -529,214 +500,219 @@ class TypeRoundTripTest( def test_decimal_notation(self): numeric_table = Table( - 'numeric_table', metadata, + "numeric_table", + metadata, Column( - 'id', Integer, - Sequence('numeric_id_seq', optional=True), primary_key=True), + "id", + Integer, + Sequence("numeric_id_seq", optional=True), + primary_key=True, + ), Column( - 'numericcol', - Numeric(precision=38, scale=20, asdecimal=True))) + "numericcol", Numeric(precision=38, scale=20, asdecimal=True) + ), + ) metadata.create_all() - test_items = [decimal.Decimal(d) for d in ( - '1500000.00000000000000000000', - '-1500000.00000000000000000000', - '1500000', - '0.0000000000000000002', - '0.2', - '-0.0000000000000000002', - '-2E-2', - '156666.458923543', - '-156666.458923543', - '1', - '-1', - '-1234', - '1234', - '2E-12', - '4E8', - '3E-6', - '3E-7', - '4.1', - '1E-1', - '1E-2', - '1E-3', - '1E-4', - '1E-5', - '1E-6', - '1E-7', - '1E-1', - '1E-8', - '0.2732E2', - '-0.2432E2', - '4.35656E2', - '-02452E-2', - '45125E-2', - '1234.58965E-2', - '1.521E+15', - - # previously, these were at -1E-25, which were inserted - # cleanly however we only got back 20 digits of accuracy. - # pyodbc as of 4.0.22 now disallows the silent truncation.
- '-1E-20', - '1E-20', - '1254E-20', - '-1203E-20', - - - '0', - '-0.00', - '-0', - '4585E12', - '000000000000000000012', - '000000000000.32E12', - '00000000000000.1E+12', - - # these are no longer accepted by pyodbc 4.0.22 but it seems - # they were not actually round-tripping correctly before that - # in any case - # '-1E-25', - # '1E-25', - # '1254E-25', - # '-1203E-25', - # '000000000000.2E-32', - )] + test_items = [ + decimal.Decimal(d) + for d in ( + "1500000.00000000000000000000", + "-1500000.00000000000000000000", + "1500000", + "0.0000000000000000002", + "0.2", + "-0.0000000000000000002", + "-2E-2", + "156666.458923543", + "-156666.458923543", + "1", + "-1", + "-1234", + "1234", + "2E-12", + "4E8", + "3E-6", + "3E-7", + "4.1", + "1E-1", + "1E-2", + "1E-3", + "1E-4", + "1E-5", + "1E-6", + "1E-7", + "1E-1", + "1E-8", + "0.2732E2", + "-0.2432E2", + "4.35656E2", + "-02452E-2", + "45125E-2", + "1234.58965E-2", + "1.521E+15", + # previously, these were at -1E-25, which were inserted + # cleanly however we only got back 20 digits of accuracy. + # pyodbc as of 4.0.22 now disallows the silent truncation. + "-1E-20", + "1E-20", + "1254E-20", + "-1203E-20", + "0", + "-0.00", + "-0", + "4585E12", + "000000000000000000012", + "000000000000.32E12", + "00000000000000.1E+12", + # these are no longer accepted by pyodbc 4.0.22 but it seems + # they were not actually round-tripping correctly before that + # in any case + # '-1E-25', + # '1E-25', + # '1254E-25', + # '-1203E-25', + # '000000000000.2E-32', + ) + ] with testing.db.connect() as conn: for value in test_items: result = conn.execute( - numeric_table.insert(), - dict(numericcol=value) + numeric_table.insert(), dict(numericcol=value) ) primary_key = result.inserted_primary_key returned = conn.scalar( - select([numeric_table.c.numericcol]). - where(numeric_table.c.id == primary_key[0]) + select([numeric_table.c.numericcol]).where( + numeric_table.c.id == primary_key[0] + ) ) eq_(value, returned) def test_float(self): float_table = Table( - 'float_table', metadata, + "float_table", + metadata, Column( - 'id', Integer, - Sequence('numeric_id_seq', optional=True), primary_key=True), - Column('floatcol', Float())) + "id", + Integer, + Sequence("numeric_id_seq", optional=True), + primary_key=True, + ), + Column("floatcol", Float()), + ) metadata.create_all() try: - test_items = [float(d) for d in ( - '1500000.00000000000000000000', - '-1500000.00000000000000000000', - '1500000', - '0.0000000000000000002', - '0.2', - '-0.0000000000000000002', - '156666.458923543', - '-156666.458923543', - '1', - '-1', - '1234', - '2E-12', - '4E8', - '3E-6', - '3E-7', - '4.1', - '1E-1', - '1E-2', - '1E-3', - '1E-4', - '1E-5', - '1E-6', - '1E-7', - '1E-8', - )] + test_items = [ + float(d) + for d in ( + "1500000.00000000000000000000", + "-1500000.00000000000000000000", + "1500000", + "0.0000000000000000002", + "0.2", + "-0.0000000000000000002", + "156666.458923543", + "-156666.458923543", + "1", + "-1", + "1234", + "2E-12", + "4E8", + "3E-6", + "3E-7", + "4.1", + "1E-1", + "1E-2", + "1E-3", + "1E-4", + "1E-5", + "1E-6", + "1E-7", + "1E-8", + ) + ] for value in test_items: float_table.insert().execute(floatcol=value) except Exception as e: raise e # todo this should suppress warnings, but it does not - @emits_warning_on('mssql+mxodbc', r'.*does not have any indexes.*') + @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*") def test_dates(self): "Exercise type specification for date types."
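The spec tuples that follow gate each entry on a server version requirement: the generic Date and Time types render natively only on SQL Server 2008 (version 10) and later, and compile to DATETIME against older servers. A minimal sketch of that fallback, setting server_version_info by hand the way the DDL tests above do:

from sqlalchemy import Date, Time
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql.base import MS_2005_VERSION, MS_2008_VERSION

d2008 = mssql.dialect()
d2008.server_version_info = MS_2008_VERSION
d2005 = mssql.dialect()
d2005.server_version_info = MS_2005_VERSION  # pretend we're on SQL Server 2005

print(Date().compile(dialect=d2008))  # DATE
print(Date().compile(dialect=d2005))  # DATETIME
print(Time().compile(dialect=d2005))  # DATETIME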
columns = [ # column type, args, kwargs, expected ddl - (mssql.MSDateTime, [], {}, - 'DATETIME', []), - - (types.DATE, [], {}, - 'DATE', ['>=', (10,)]), - (types.Date, [], {}, - 'DATE', ['>=', (10,)]), - (types.Date, [], {}, - 'DATETIME', ['<', (10,)], mssql.MSDateTime), - (mssql.MSDate, [], {}, - 'DATE', ['>=', (10,)]), - (mssql.MSDate, [], {}, - 'DATETIME', ['<', (10,)], mssql.MSDateTime), - - (types.TIME, [], {}, - 'TIME', ['>=', (10,)]), - (types.Time, [], {}, - 'TIME', ['>=', (10,)]), - (mssql.MSTime, [], {}, - 'TIME', ['>=', (10,)]), - (mssql.MSTime, [1], {}, - 'TIME(1)', ['>=', (10,)]), - (types.Time, [], {}, - 'DATETIME', ['<', (10,)], mssql.MSDateTime), - (mssql.MSTime, [], {}, - 'TIME', ['>=', (10,)]), - - (mssql.MSSmallDateTime, [], {}, - 'SMALLDATETIME', []), - - (mssql.MSDateTimeOffset, [], {}, - 'DATETIMEOFFSET', ['>=', (10,)]), - (mssql.MSDateTimeOffset, [1], {}, - 'DATETIMEOFFSET(1)', ['>=', (10,)]), - - (mssql.MSDateTime2, [], {}, - 'DATETIME2', ['>=', (10,)]), - (mssql.MSDateTime2, [0], {}, - 'DATETIME2(0)', ['>=', (10,)]), - (mssql.MSDateTime2, [1], {}, - 'DATETIME2(1)', ['>=', (10,)]), - + (mssql.MSDateTime, [], {}, "DATETIME", []), + (types.DATE, [], {}, "DATE", [">=", (10,)]), + (types.Date, [], {}, "DATE", [">=", (10,)]), + (types.Date, [], {}, "DATETIME", ["<", (10,)], mssql.MSDateTime), + (mssql.MSDate, [], {}, "DATE", [">=", (10,)]), + (mssql.MSDate, [], {}, "DATETIME", ["<", (10,)], mssql.MSDateTime), + (types.TIME, [], {}, "TIME", [">=", (10,)]), + (types.Time, [], {}, "TIME", [">=", (10,)]), + (mssql.MSTime, [], {}, "TIME", [">=", (10,)]), + (mssql.MSTime, [1], {}, "TIME(1)", [">=", (10,)]), + (types.Time, [], {}, "DATETIME", ["<", (10,)], mssql.MSDateTime), + (mssql.MSTime, [], {}, "TIME", [">=", (10,)]), + (mssql.MSSmallDateTime, [], {}, "SMALLDATETIME", []), + (mssql.MSDateTimeOffset, [], {}, "DATETIMEOFFSET", [">=", (10,)]), + ( + mssql.MSDateTimeOffset, + [1], + {}, + "DATETIMEOFFSET(1)", + [">=", (10,)], + ), + (mssql.MSDateTime2, [], {}, "DATETIME2", [">=", (10,)]), + (mssql.MSDateTime2, [0], {}, "DATETIME2(0)", [">=", (10,)]), + (mssql.MSDateTime2, [1], {}, "DATETIME2(1)", [">=", (10,)]), ] - table_args = ['test_mssql_dates', metadata] + table_args = ["test_mssql_dates", metadata] for index, spec in enumerate(columns): type_, args, kw, res, requires = spec[0:5] - if requires and \ - testing._is_excluded('mssql', *requires) or not requires: - c = Column('c%s' % index, type_(*args, **kw), nullable=None) + if ( + requires + and testing._is_excluded("mssql", *requires) + or not requires + ): + c = Column("c%s" % index, type_(*args, **kw), nullable=None) testing.db.dialect.type_descriptor(c.type) table_args.append(c) dates_table = Table(*table_args) gen = testing.db.dialect.ddl_compiler( - testing.db.dialect, - schema.CreateTable(dates_table)) + testing.db.dialect, schema.CreateTable(dates_table) + ) for col in dates_table.c: index = int(col.name[1:]) - testing.eq_(gen.get_column_specification(col), '%s %s' - % (col.name, columns[index][3])) + testing.eq_( + gen.get_column_specification(col), + "%s %s" % (col.name, columns[index][3]), + ) self.assert_(repr(col)) dates_table.create(checkfirst=True) - reflected_dates = Table('test_mssql_dates', - MetaData(testing.db), autoload=True) + reflected_dates = Table( + "test_mssql_dates", MetaData(testing.db), autoload=True + ) for col in reflected_dates.c: self.assert_types_base(col, dates_table.c[col.key]) def test_date_roundtrip(self): t = Table( - 'test_dates', metadata, - Column('id', Integer, - 
Sequence('datetest_id_seq', optional=True), - primary_key=True), - Column('adate', Date), - Column('atime', Time), - Column('adatetime', DateTime)) + "test_dates", + metadata, + Column( + "id", + Integer, + Sequence("datetest_id_seq", optional=True), + primary_key=True, + ), + Column("adate", Date), + Column("atime", Time), + Column("adatetime", DateTime), + ) metadata.create_all() d1 = datetime.date(2007, 10, 30) t1 = datetime.time(11, 2, 32) @@ -759,125 +735,141 @@ class TypeRoundTripTest( t.insert().execute(adate=d1, adatetime=d2, atime=t1) - eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate - == d1).execute().fetchall(), [(d1, t1, d2)]) + eq_( + select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate == d1) + .execute() + .fetchall(), + [(d1, t1, d2)], + ) - @emits_warning_on('mssql+mxodbc', r'.*does not have any indexes.*') + @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*") @testing.provide_metadata def _test_binary_reflection(self, deprecate_large_types): "Exercise type specification for binary types." columns = [ # column type, args, kwargs, expected ddl from reflected - (mssql.MSBinary, [], {}, - 'BINARY(1)'), - (mssql.MSBinary, [10], {}, - 'BINARY(10)'), - - (types.BINARY, [], {}, - 'BINARY(1)'), - (types.BINARY, [10], {}, - 'BINARY(10)'), - - (mssql.MSVarBinary, [], {}, - 'VARBINARY(max)'), - (mssql.MSVarBinary, [10], {}, - 'VARBINARY(10)'), - - (types.VARBINARY, [10], {}, - 'VARBINARY(10)'), - (types.VARBINARY, [], {}, - 'VARBINARY(max)'), - - (mssql.MSImage, [], {}, - 'IMAGE'), - - (mssql.IMAGE, [], {}, - 'IMAGE'), - - (types.LargeBinary, [], {}, - 'IMAGE' if not deprecate_large_types else 'VARBINARY(max)'), + (mssql.MSBinary, [], {}, "BINARY(1)"), + (mssql.MSBinary, [10], {}, "BINARY(10)"), + (types.BINARY, [], {}, "BINARY(1)"), + (types.BINARY, [10], {}, "BINARY(10)"), + (mssql.MSVarBinary, [], {}, "VARBINARY(max)"), + (mssql.MSVarBinary, [10], {}, "VARBINARY(10)"), + (types.VARBINARY, [10], {}, "VARBINARY(10)"), + (types.VARBINARY, [], {}, "VARBINARY(max)"), + (mssql.MSImage, [], {}, "IMAGE"), + (mssql.IMAGE, [], {}, "IMAGE"), + ( + types.LargeBinary, + [], + {}, + "IMAGE" if not deprecate_large_types else "VARBINARY(max)", + ), ] metadata = self.metadata metadata.bind = engines.testing_engine( - options={"deprecate_large_types": deprecate_large_types}) - table_args = ['test_mssql_binary', metadata] + options={"deprecate_large_types": deprecate_large_types} + ) + table_args = ["test_mssql_binary", metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec - table_args.append(Column('c%s' % index, type_(*args, **kw), - nullable=None)) + table_args.append( + Column("c%s" % index, type_(*args, **kw), nullable=None) + ) binary_table = Table(*table_args) metadata.create_all() - reflected_binary = Table('test_mssql_binary', - MetaData(testing.db), autoload=True) + reflected_binary = Table( + "test_mssql_binary", MetaData(testing.db), autoload=True + ) for col, spec in zip(reflected_binary.c, columns): eq_( - str(col.type), spec[3], - "column %s %s != %s" % (col.key, str(col.type), spec[3]) + str(col.type), + spec[3], + "column %s %s != %s" % (col.key, str(col.type), spec[3]), ) c1 = testing.db.dialect.type_descriptor(col.type).__class__ - c2 = \ - testing.db.dialect.type_descriptor( - binary_table.c[col.name].type).__class__ - assert issubclass(c1, c2), \ - 'column %s: %r is not a subclass of %r' \ - % (col.key, c1, c2) + c2 = testing.db.dialect.type_descriptor( + binary_table.c[col.name].type + ).__class__ + assert issubclass( + c1, 
c2 + ), "column %s: %r is not a subclass of %r" % (col.key, c1, c2) if binary_table.c[col.name].type.length: - testing.eq_(col.type.length, - binary_table.c[col.name].type.length) + testing.eq_( + col.type.length, binary_table.c[col.name].type.length + ) def test_binary_reflection_legacy_large_types(self): self._test_binary_reflection(False) - @testing.only_on('mssql >= 11') + @testing.only_on("mssql >= 11") def test_binary_reflection_sql2012_large_types(self): self._test_binary_reflection(True) def test_autoincrement(self): Table( - 'ai_1', metadata, - Column('int_y', Integer, primary_key=True, autoincrement=True), - Column( - 'int_n', Integer, DefaultClause('0'), primary_key=True)) + "ai_1", + metadata, + Column("int_y", Integer, primary_key=True, autoincrement=True), + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + ) Table( - 'ai_2', metadata, - Column('int_y', Integer, primary_key=True, autoincrement=True), - Column('int_n', Integer, DefaultClause('0'), primary_key=True)) + "ai_2", + metadata, + Column("int_y", Integer, primary_key=True, autoincrement=True), + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + ) Table( - 'ai_3', metadata, - Column('int_n', Integer, DefaultClause('0'), primary_key=True), - Column('int_y', Integer, primary_key=True, autoincrement=True)) + "ai_3", + metadata, + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + Column("int_y", Integer, primary_key=True, autoincrement=True), + ) Table( - 'ai_4', metadata, - Column('int_n', Integer, DefaultClause('0'), primary_key=True), - Column('int_n2', Integer, DefaultClause('0'), primary_key=True)) + "ai_4", + metadata, + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + Column("int_n2", Integer, DefaultClause("0"), primary_key=True), + ) Table( - 'ai_5', metadata, - Column('int_y', Integer, primary_key=True, autoincrement=True), - Column('int_n', Integer, DefaultClause('0'), primary_key=True)) + "ai_5", + metadata, + Column("int_y", Integer, primary_key=True, autoincrement=True), + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + ) Table( - 'ai_6', metadata, - Column('o1', String(1), DefaultClause('x'), primary_key=True), - Column('int_y', Integer, primary_key=True, autoincrement=True)) + "ai_6", + metadata, + Column("o1", String(1), DefaultClause("x"), primary_key=True), + Column("int_y", Integer, primary_key=True, autoincrement=True), + ) Table( - 'ai_7', metadata, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, autoincrement=True, primary_key=True)) + "ai_7", + metadata, + Column("o1", String(1), DefaultClause("x"), primary_key=True), + Column("o2", String(1), DefaultClause("x"), primary_key=True), + Column("int_y", Integer, autoincrement=True, primary_key=True), + ) Table( - 'ai_8', metadata, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True)) + "ai_8", + metadata, + Column("o1", String(1), DefaultClause("x"), primary_key=True), + Column("o2", String(1), DefaultClause("x"), primary_key=True), + ) metadata.create_all() - table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', - 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + table_names = [ + "ai_1", + "ai_2", + "ai_3", + "ai_4", + "ai_5", + "ai_6", + "ai_7", + "ai_8", + ] mr = MetaData(testing.db) for name in table_names: @@ -886,65 +878,73 @@ class TypeRoundTripTest( # test that the flag itself reflects 
appropriately for col in tbl.c: - if 'int_y' in col.name: + if "int_y" in col.name: is_(col.autoincrement, True) is_(tbl._autoincrement_column, col) else: - eq_(col.autoincrement, 'auto') + eq_(col.autoincrement, "auto") is_not_(tbl._autoincrement_column, col) # mxodbc can't handle scope_identity() with DEFAULT VALUES - if testing.db.driver == 'mxodbc': - eng = \ - [engines.testing_engine(options={ - 'implicit_returning': True})] + if testing.db.driver == "mxodbc": + eng = [ + engines.testing_engine( + options={"implicit_returning": True} + ) + ] else: - eng = \ - [engines.testing_engine(options={ - 'implicit_returning': False}), - engines.testing_engine(options={ - 'implicit_returning': True})] + eng = [ + engines.testing_engine( + options={"implicit_returning": False} + ), + engines.testing_engine( + options={"implicit_returning": True} + ), + ] for counter, engine in enumerate(eng): engine.execute(tbl.insert()) - if 'int_y' in tbl.c: - assert engine.scalar(select([tbl.c.int_y])) \ - == counter + 1 - assert list( - engine.execute(tbl.select()).first()).\ - count(counter + 1) == 1 + if "int_y" in tbl.c: + assert engine.scalar(select([tbl.c.int_y])) == counter + 1 + assert ( + list(engine.execute(tbl.select()).first()).count( + counter + 1 + ) + == 1 + ) else: - assert 1 \ - not in list(engine.execute(tbl.select()).first()) + assert 1 not in list(engine.execute(tbl.select()).first()) engine.execute(tbl.delete()) class BinaryTest(fixtures.TestBase): - __only_on__ = 'mssql' - __requires__ = "non_broken_binary", + __only_on__ = "mssql" + __requires__ = ("non_broken_binary",) __backend__ = True def test_character_binary(self): - self._test_round_trip( - mssql.MSVarBinary(800), b("some normal data") - ) + self._test_round_trip(mssql.MSVarBinary(800), b("some normal data")) @testing.provide_metadata def _test_round_trip( - self, type_, data, deprecate_large_types=True, - expected=None): - if testing.db.dialect.deprecate_large_types is not \ - deprecate_large_types: + self, type_, data, deprecate_large_types=True, expected=None + ): + if ( + testing.db.dialect.deprecate_large_types + is not deprecate_large_types + ): engine = engines.testing_engine( - options={"deprecate_large_types": deprecate_large_types}) + options={"deprecate_large_types": deprecate_large_types} + ) else: engine = testing.db binary_table = Table( - 'binary_table', self.metadata, - Column('id', Integer, primary_key=True), - Column('data', type_) + "binary_table", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", type_), ) binary_table.create(engine) @@ -952,47 +952,37 @@ class BinaryTest(fixtures.TestBase): expected = data with engine.connect() as conn: - conn.execute( - binary_table.insert(), - data=data - ) + conn.execute(binary_table.insert(), data=data) - eq_( - conn.scalar(select([binary_table.c.data])), - expected - ) + eq_(conn.scalar(select([binary_table.c.data])), expected) eq_( conn.scalar( - text("select data from binary_table"). - columns(binary_table.c.data) + text("select data from binary_table").columns( + binary_table.c.data + ) ), - expected + expected, ) conn.execute(binary_table.delete()) conn.execute(binary_table.insert(), data=None) - eq_( - conn.scalar(select([binary_table.c.data])), - None - ) + eq_(conn.scalar(select([binary_table.c.data])), None) eq_( conn.scalar( - text("select data from binary_table"). 
- columns(binary_table.c.data) + text("select data from binary_table").columns( + binary_table.c.data + ) ), - None + None, ) def test_plain_pickle(self): - self._test_round_trip( - PickleType, pickleable.Foo("im foo 1") - ) + self._test_round_trip(PickleType, pickleable.Foo("im foo 1")) def test_custom_pickle(self): - class MyPickleType(types.TypeDecorator): impl = PickleType @@ -1010,46 +1000,30 @@ class BinaryTest(fixtures.TestBase): expected = pickleable.Foo("im foo 1") expected.stuff = "BINDim stuffRESULT" - self._test_round_trip( - MyPickleType, data, - expected=expected - ) + self._test_round_trip(MyPickleType, data, expected=expected) def test_image(self): - stream1 = self._load_stream('binary_data_one.dat') - self._test_round_trip( - mssql.MSImage, - stream1 - ) + stream1 = self._load_stream("binary_data_one.dat") + self._test_round_trip(mssql.MSImage, stream1) def test_large_binary(self): - stream1 = self._load_stream('binary_data_one.dat') - self._test_round_trip( - sqltypes.LargeBinary, - stream1 - ) + stream1 = self._load_stream("binary_data_one.dat") + self._test_round_trip(sqltypes.LargeBinary, stream1) def test_large_legacy_types(self): - stream1 = self._load_stream('binary_data_one.dat') + stream1 = self._load_stream("binary_data_one.dat") self._test_round_trip( - sqltypes.LargeBinary, - stream1, - deprecate_large_types=False + sqltypes.LargeBinary, stream1, deprecate_large_types=False ) def test_mssql_varbinary_max(self): - stream1 = self._load_stream('binary_data_one.dat') - self._test_round_trip( - mssql.VARBINARY("max"), - stream1 - ) + stream1 = self._load_stream("binary_data_one.dat") + self._test_round_trip(mssql.VARBINARY("max"), stream1) def test_mssql_legacy_varbinary_max(self): - stream1 = self._load_stream('binary_data_one.dat') + stream1 = self._load_stream("binary_data_one.dat") self._test_round_trip( - mssql.VARBINARY("max"), - stream1, - deprecate_large_types=False + mssql.VARBINARY("max"), stream1, deprecate_large_types=False ) def test_binary_slice(self): @@ -1071,18 +1045,16 @@ class BinaryTest(fixtures.TestBase): self._test_var_slice_zeropadding(mssql.VARBINARY, False) def _test_var_slice(self, type_): - stream1 = self._load_stream('binary_data_one.dat') + stream1 = self._load_stream("binary_data_one.dat") data = stream1[0:100] - self._test_round_trip( - type_(100), - data - ) + self._test_round_trip(type_(100), data) def _test_var_slice_zeropadding( - self, type_, pad, deprecate_large_types=True): - stream2 = self._load_stream('binary_data_two.dat') + self, type_, pad, deprecate_large_types=True + ): + stream2 = self._load_stream("binary_data_two.dat") data = stream2[0:99] @@ -1090,18 +1062,16 @@ class BinaryTest(fixtures.TestBase): # so we will get 100 bytes zero-padded if pad: - paddedstream = stream2[0:99] + b'\x00' + paddedstream = stream2[0:99] + b"\x00" else: paddedstream = stream2[0:99] - self._test_round_trip( - type_(100), - data, expected=paddedstream - ) + self._test_round_trip(type_(100), data, expected=paddedstream) def _load_stream(self, name, len=3000): fp = open( - os.path.join(os.path.dirname(__file__), "..", "..", name), 'rb') + os.path.join(os.path.dirname(__file__), "..", "..", name), "rb" + ) stream = fp.read(len) fp.close() return stream diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 9e12e2d4c6..79dff63a35 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -2,13 +2,48 @@ from sqlalchemy.testing import eq_, assert_raises_message, 
     expect_warnings
 from sqlalchemy import sql, exc, schema, types as sqltypes
-from sqlalchemy import Table, MetaData, Column, select, String, \
-    Index, Integer, ForeignKey, PrimaryKeyConstraint, extract, \
-    VARCHAR, NVARCHAR, Unicode, UnicodeText, \
-    NUMERIC, DECIMAL, Numeric, Float, FLOAT, TIMESTAMP, DATE, \
-    DATETIME, TIME, \
-    DateTime, Time, Date, Interval, NCHAR, CHAR, CLOB, TEXT, Boolean, \
-    BOOLEAN, LargeBinary, BLOB, SmallInteger, INT, func, cast, literal
+from sqlalchemy import (
+    Table,
+    MetaData,
+    Column,
+    select,
+    String,
+    Index,
+    Integer,
+    ForeignKey,
+    PrimaryKeyConstraint,
+    extract,
+    VARCHAR,
+    NVARCHAR,
+    Unicode,
+    UnicodeText,
+    NUMERIC,
+    DECIMAL,
+    Numeric,
+    Float,
+    FLOAT,
+    TIMESTAMP,
+    DATE,
+    DATETIME,
+    TIME,
+    DateTime,
+    Time,
+    Date,
+    Interval,
+    NCHAR,
+    CHAR,
+    CLOB,
+    TEXT,
+    Boolean,
+    BOOLEAN,
+    LargeBinary,
+    BLOB,
+    SmallInteger,
+    INT,
+    func,
+    cast,
+    literal,
+)
 from sqlalchemy.dialects.mysql import insert
 from sqlalchemy.dialects.mysql import base as mysql
@@ -22,237 +57,295 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = mysql.dialect()

     def test_reserved_words(self):
-        table = Table("mysql_table", MetaData(),
-                      Column("col1", Integer),
-                      Column("master_ssl_verify_server_cert", Integer))
+        table = Table(
+            "mysql_table",
+            MetaData(),
+            Column("col1", Integer),
+            Column("master_ssl_verify_server_cert", Integer),
+        )
         x = select([table.c.col1, table.c.master_ssl_verify_server_cert])
         self.assert_compile(
             x,
             "SELECT mysql_table.col1, "
-            "mysql_table.`master_ssl_verify_server_cert` FROM mysql_table")
+            "mysql_table.`master_ssl_verify_server_cert` FROM mysql_table",
+        )

     def test_create_index_simple(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)))
-        idx = Index('test_idx1', tbl.c.data)
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx = Index("test_idx1", tbl.c.data)

-        self.assert_compile(schema.CreateIndex(idx),
-                            'CREATE INDEX test_idx1 ON testtbl (data)')
+        self.assert_compile(
+            schema.CreateIndex(idx), "CREATE INDEX test_idx1 ON testtbl (data)"
+        )

     def test_create_index_with_prefix(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)))
-        idx = Index('test_idx1', tbl.c.data, mysql_length=10,
-                    mysql_prefix='FULLTEXT')
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx = Index(
+            "test_idx1", tbl.c.data, mysql_length=10, mysql_prefix="FULLTEXT"
+        )

-        self.assert_compile(schema.CreateIndex(idx),
-                            'CREATE FULLTEXT INDEX test_idx1 '
-                            'ON testtbl (data(10))')
+        self.assert_compile(
+            schema.CreateIndex(idx),
+            "CREATE FULLTEXT INDEX test_idx1 " "ON testtbl (data(10))",
+        )

     def test_create_index_with_parser(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)))
-        idx = Index('test_idx1', tbl.c.data, mysql_length=10,
-                    mysql_prefix='FULLTEXT', mysql_with_parser="ngram")
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx = Index(
+            "test_idx1",
+            tbl.c.data,
+            mysql_length=10,
+            mysql_prefix="FULLTEXT",
+            mysql_with_parser="ngram",
+        )

-        self.assert_compile(schema.CreateIndex(idx),
-                            'CREATE FULLTEXT INDEX test_idx1 '
-                            'ON testtbl (data(10)) WITH PARSER ngram')
+        self.assert_compile(
+            schema.CreateIndex(idx),
+            "CREATE FULLTEXT INDEX test_idx1 "
+            "ON testtbl (data(10)) WITH PARSER ngram",
+        )

     def test_create_index_with_length(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)))
-        idx1 = Index('test_idx1', tbl.c.data, mysql_length=10)
-        idx2 = Index('test_idx2', tbl.c.data, mysql_length=5)
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx1 = Index("test_idx1", tbl.c.data, mysql_length=10)
+        idx2 = Index("test_idx2", tbl.c.data, mysql_length=5)

-        self.assert_compile(schema.CreateIndex(idx1),
-                            'CREATE INDEX test_idx1 ON testtbl (data(10))')
-        self.assert_compile(schema.CreateIndex(idx2),
-                            'CREATE INDEX test_idx2 ON testtbl (data(5))')
+        self.assert_compile(
+            schema.CreateIndex(idx1),
+            "CREATE INDEX test_idx1 ON testtbl (data(10))",
+        )
+        self.assert_compile(
+            schema.CreateIndex(idx2),
+            "CREATE INDEX test_idx2 ON testtbl (data(5))",
+        )

     def test_create_index_with_length_quoted(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('some quoted data',
-                                         String(255), key='s'))
-        idx1 = Index('test_idx1', tbl.c.s, mysql_length=10)
+        tbl = Table(
+            "testtbl", m, Column("some quoted data", String(255), key="s")
+        )
+        idx1 = Index("test_idx1", tbl.c.s, mysql_length=10)

         self.assert_compile(
             schema.CreateIndex(idx1),
-            'CREATE INDEX test_idx1 ON testtbl (`some quoted data`(10))')
+            "CREATE INDEX test_idx1 ON testtbl (`some quoted data`(10))",
+        )

     def test_create_composite_index_with_length_quoted(self):
         m = MetaData()
-        tbl = Table('testtbl', m,
-                    Column('some Quoted a', String(255), key='a'),
-                    Column('some Quoted b', String(255), key='b'))
-        idx1 = Index('test_idx1', tbl.c.a, tbl.c.b,
-                     mysql_length={'some Quoted a': 10, 'some Quoted b': 20})
+        tbl = Table(
+            "testtbl",
+            m,
+            Column("some Quoted a", String(255), key="a"),
+            Column("some Quoted b", String(255), key="b"),
+        )
+        idx1 = Index(
+            "test_idx1",
+            tbl.c.a,
+            tbl.c.b,
+            mysql_length={"some Quoted a": 10, "some Quoted b": 20},
+        )

-        self.assert_compile(schema.CreateIndex(idx1),
-                            'CREATE INDEX test_idx1 ON testtbl '
-                            '(`some Quoted a`(10), `some Quoted b`(20))')
+        self.assert_compile(
+            schema.CreateIndex(idx1),
+            "CREATE INDEX test_idx1 ON testtbl "
+            "(`some Quoted a`(10), `some Quoted b`(20))",
+        )

     def test_create_composite_index_with_length_quoted_3085_workaround(self):
         m = MetaData()
-        tbl = Table('testtbl', m,
-                    Column('some quoted a', String(255), key='a'),
-                    Column('some quoted b', String(255), key='b'))
+        tbl = Table(
+            "testtbl",
+            m,
+            Column("some quoted a", String(255), key="a"),
+            Column("some quoted b", String(255), key="b"),
+        )
         idx1 = Index(
-            'test_idx1', tbl.c.a, tbl.c.b,
-            mysql_length={'`some quoted a`': 10, '`some quoted b`': 20}
+            "test_idx1",
+            tbl.c.a,
+            tbl.c.b,
+            mysql_length={"`some quoted a`": 10, "`some quoted b`": 20},
         )

-        self.assert_compile(schema.CreateIndex(idx1),
-                            'CREATE INDEX test_idx1 ON testtbl '
-                            '(`some quoted a`(10), `some quoted b`(20))')
+        self.assert_compile(
+            schema.CreateIndex(idx1),
+            "CREATE INDEX test_idx1 ON testtbl "
+            "(`some quoted a`(10), `some quoted b`(20))",
+        )

     def test_create_composite_index_with_length(self):
         m = MetaData()
-        tbl = Table('testtbl', m,
-                    Column('a', String(255)),
-                    Column('b', String(255)))
+        tbl = Table(
+            "testtbl", m, Column("a", String(255)), Column("b", String(255))
+        )

-        idx1 = Index('test_idx1', tbl.c.a, tbl.c.b,
-                     mysql_length={'a': 10, 'b': 20})
-        idx2 = Index('test_idx2', tbl.c.a, tbl.c.b,
-                     mysql_length={'a': 15})
-        idx3 = Index('test_idx3', tbl.c.a, tbl.c.b,
-                     mysql_length=30)
+        idx1 = Index(
+            "test_idx1", tbl.c.a, tbl.c.b, mysql_length={"a": 10, "b": 20}
+        )
+        idx2 = Index("test_idx2", tbl.c.a, tbl.c.b, mysql_length={"a": 15})
+        idx3 = Index("test_idx3", tbl.c.a, tbl.c.b, mysql_length=30)

         self.assert_compile(
             schema.CreateIndex(idx1),
-            'CREATE INDEX test_idx1 ON testtbl (a(10), b(20))'
+            "CREATE INDEX test_idx1 ON testtbl (a(10), b(20))",
         )
         self.assert_compile(
             schema.CreateIndex(idx2),
-            'CREATE INDEX test_idx2 ON testtbl (a(15), b)'
+            "CREATE INDEX test_idx2 ON testtbl (a(15), b)",
         )
         self.assert_compile(
             schema.CreateIndex(idx3),
-            'CREATE INDEX test_idx3 ON testtbl (a(30), b(30))'
+            "CREATE INDEX test_idx3 ON testtbl (a(30), b(30))",
         )

     def test_create_index_with_using(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)))
-        idx1 = Index('test_idx1', tbl.c.data, mysql_using='btree')
-        idx2 = Index('test_idx2', tbl.c.data, mysql_using='hash')
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx1 = Index("test_idx1", tbl.c.data, mysql_using="btree")
+        idx2 = Index("test_idx2", tbl.c.data, mysql_using="hash")

         self.assert_compile(
             schema.CreateIndex(idx1),
-            'CREATE INDEX test_idx1 ON testtbl (data) USING btree')
+            "CREATE INDEX test_idx1 ON testtbl (data) USING btree",
+        )
         self.assert_compile(
             schema.CreateIndex(idx2),
-            'CREATE INDEX test_idx2 ON testtbl (data) USING hash')
+            "CREATE INDEX test_idx2 ON testtbl (data) USING hash",
+        )

     def test_create_pk_plain(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)),
-                    PrimaryKeyConstraint('data'))
+        tbl = Table(
+            "testtbl",
+            m,
+            Column("data", String(255)),
+            PrimaryKeyConstraint("data"),
+        )

         self.assert_compile(
             schema.CreateTable(tbl),
             "CREATE TABLE testtbl (data VARCHAR(255) NOT NULL, "
-            "PRIMARY KEY (data))")
+            "PRIMARY KEY (data))",
+        )

     def test_create_pk_with_using(self):
         m = MetaData()
-        tbl = Table('testtbl', m, Column('data', String(255)),
-                    PrimaryKeyConstraint('data', mysql_using='btree'))
+        tbl = Table(
+            "testtbl",
+            m,
+            Column("data", String(255)),
+            PrimaryKeyConstraint("data", mysql_using="btree"),
+        )

         self.assert_compile(
             schema.CreateTable(tbl),
             "CREATE TABLE testtbl (data VARCHAR(255) NOT NULL, "
-            "PRIMARY KEY (data) USING btree)")
+            "PRIMARY KEY (data) USING btree)",
+        )

     def test_create_index_expr(self):
         m = MetaData()
-        t1 = Table('foo', m,
-                   Column('x', Integer)
-                   )
+        t1 = Table("foo", m, Column("x", Integer))
         self.assert_compile(
             schema.CreateIndex(Index("bar", t1.c.x > 5)),
-            "CREATE INDEX bar ON foo (x > 5)"
+            "CREATE INDEX bar ON foo (x > 5)",
         )

     def test_deferrable_initially_kw_not_ignored(self):
         m = MetaData()
-        Table('t1', m, Column('id', Integer, primary_key=True))
+        Table("t1", m, Column("id", Integer, primary_key=True))
         t2 = Table(
-            't2', m, Column(
-                'id', Integer,
-                ForeignKey('t1.id', deferrable=True, initially="XYZ"),
-                primary_key=True))
+            "t2",
+            m,
+            Column(
+                "id",
+                Integer,
+                ForeignKey("t1.id", deferrable=True, initially="XYZ"),
+                primary_key=True,
+            ),
+        )

         self.assert_compile(
             schema.CreateTable(t2),
             "CREATE TABLE t2 (id INTEGER NOT NULL, "
             "PRIMARY KEY (id), FOREIGN KEY(id) REFERENCES t1 (id) "
-            "DEFERRABLE INITIALLY XYZ)"
+            "DEFERRABLE INITIALLY XYZ)",
         )

     def test_match_kw_raises(self):
         m = MetaData()
-        Table('t1', m, Column('id', Integer, primary_key=True))
-        t2 = Table('t2', m, Column('id', Integer,
-                                   ForeignKey('t1.id', match="XYZ"),
-                                   primary_key=True))
+        Table("t1", m, Column("id", Integer, primary_key=True))
+        t2 = Table(
+            "t2",
+            m,
+            Column(
+                "id",
+                Integer,
+                ForeignKey("t1.id", match="XYZ"),
+                primary_key=True,
+            ),
+        )

         assert_raises_message(
             exc.CompileError,
             "MySQL ignores the 'MATCH' keyword while at the same time causes "
             "ON UPDATE/ON DELETE clauses to be ignored.",
-            schema.CreateTable(t2).compile, dialect=mysql.dialect()
+            schema.CreateTable(t2).compile,
+            dialect=mysql.dialect(),
         )

     def test_match(self):
-        matchtable = table('matchtable', column('title', String))
+        matchtable = table("matchtable", column("title", String))
         self.assert_compile(
-            matchtable.c.title.match('somstr'),
-            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)")
+            matchtable.c.title.match("somstr"),
+            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)",
+        )

     def test_match_compile_kw(self):
-        expr = literal('x').match(literal('y'))
+        expr = literal("x").match(literal("y"))
         self.assert_compile(
             expr,
             "MATCH ('x') AGAINST ('y' IN BOOLEAN MODE)",
-            literal_binds=True
+            literal_binds=True,
         )

     def test_concat_compile_kw(self):
-        expr = literal('x', type_=String) + literal('y', type_=String)
-        self.assert_compile(
-            expr,
-            "concat('x', 'y')",
-            literal_binds=True
-        )
+        expr = literal("x", type_=String) + literal("y", type_=String)
+        self.assert_compile(expr, "concat('x', 'y')", literal_binds=True)

     def test_for_update(self):
-        table1 = table('mytable',
-                       column('myid'), column('name'), column('description'))
+        table1 = table(
+            "mytable", column("myid"), column("name"), column("description")
+        )

         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(),
             "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s FOR UPDATE")
+            "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+        )

         self.assert_compile(
             table1.select(table1.c.myid == 7).with_for_update(read=True),
             "SELECT mytable.myid, mytable.name, mytable.description "
-            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE")
+            "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+        )

     def test_delete_extra_froms(self):
-        t1 = table('t1', column('c1'))
-        t2 = table('t2', column('c1'))
+        t1 = table("t1", column("c1"))
+        t2 = table("t2", column("c1"))
         q = sql.delete(t1).where(t1.c.c1 == t2.c.c1)
         self.assert_compile(
             q, "DELETE FROM t1 USING t1, t2 WHERE t1.c1 = t2.c1"
         )

     def test_delete_extra_froms_alias(self):
-        a1 = table('t1', column('c1')).alias('a1')
-        t2 = table('t2', column('c1'))
+        a1 = table("t1", column("c1")).alias("a1")
+        t2 = table("t2", column("c1"))
         q = sql.delete(a1).where(a1.c.c1 == t2.c.c1)
         self.assert_compile(
             q, "DELETE FROM a1 USING t1 AS a1, t2 WHERE a1.c1 = t2.c1"
@@ -272,60 +365,60 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
         def gen(distinct=None, prefixes=None):
             kw = {}
             if distinct is not None:
-                kw['distinct'] = distinct
+                kw["distinct"] = distinct
             if prefixes is not None:
-                kw['prefixes'] = prefixes
-            return str(select([column('q')], **kw).compile(dialect=dialect))
+                kw["prefixes"] = prefixes
+            return str(select([column("q")], **kw).compile(dialect=dialect))

-        eq_(gen(None), 'SELECT q')
-        eq_(gen(True), 'SELECT DISTINCT q')
+        eq_(gen(None), "SELECT q")
+        eq_(gen(True), "SELECT DISTINCT q")

-        eq_(gen(prefixes=['ALL']), 'SELECT ALL q')
-        eq_(gen(prefixes=['DISTINCTROW']),
-            'SELECT DISTINCTROW q')
+        eq_(gen(prefixes=["ALL"]), "SELECT ALL q")
+        eq_(gen(prefixes=["DISTINCTROW"]), "SELECT DISTINCTROW q")

         # Interaction with MySQL prefix extensions
+        eq_(gen(None, ["straight_join"]), "SELECT straight_join q")
         eq_(
-            gen(None, ['straight_join']),
-            'SELECT straight_join q')
-        eq_(
-            gen(False, ['HIGH_PRIORITY', 'SQL_SMALL_RESULT', 'ALL']),
-            'SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q')
+            gen(False, ["HIGH_PRIORITY", "SQL_SMALL_RESULT", "ALL"]),
+            "SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q",
+        )
         eq_(
-            gen(True, ['high_priority', sql.text('sql_cache')]),
-            'SELECT high_priority sql_cache DISTINCT q')
+            gen(True, ["high_priority", sql.text("sql_cache")]),
+            "SELECT high_priority sql_cache DISTINCT q",
+        )

     def test_backslash_escaping(self):
         self.assert_compile(
-            sql.column('foo').like('bar', escape='\\'),
-            "foo LIKE %s ESCAPE '\\\\'"
+            sql.column("foo").like("bar", escape="\\"),
+            "foo LIKE %s ESCAPE '\\\\'",
         )

         dialect = mysql.dialect()
         dialect._backslash_escapes = False
         self.assert_compile(
-            sql.column('foo').like('bar', escape='\\'),
+            sql.column("foo").like("bar", escape="\\"),
             "foo LIKE %s ESCAPE '\\'",
-            dialect=dialect
+            dialect=dialect,
         )

     def test_limit(self):
-        t = sql.table('t', sql.column('col1'), sql.column('col2'))
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))

         self.assert_compile(
             select([t]).limit(10).offset(20),
             "SELECT t.col1, t.col2 FROM t LIMIT %s, %s",
-            {'param_1': 20, 'param_2': 10}
+            {"param_1": 20, "param_2": 10},
         )
         self.assert_compile(
             select([t]).limit(10),
             "SELECT t.col1, t.col2 FROM t LIMIT %s",
-            {'param_1': 10})
+            {"param_1": 10},
+        )
         self.assert_compile(
             select([t]).offset(10),
             "SELECT t.col1, t.col2 FROM t LIMIT %s, 18446744073709551615",
-            {'param_1': 10}
+            {"param_1": 10},
         )
+        # 18446744073709551615 above is 2**64 - 1, the row count the MySQL
+        # manual suggests for expressing an offset with no upper limit,
+        # since MySQL's LIMIT syntax has no standalone OFFSET clause

     def test_varchar_raise(self):
@@ -343,38 +436,35 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
                 exc.CompileError,
                 "VARCHAR requires a length on dialect mysql",
                 type_.compile,
-                dialect=mysql.dialect()
+                dialect=mysql.dialect(),
             )

-            t1 = Table('sometable', MetaData(),
-                       Column('somecolumn', type_)
-                       )
+            t1 = Table("sometable", MetaData(), Column("somecolumn", type_))
             assert_raises_message(
                 exc.CompileError,
                 r"\(in table 'sometable', column 'somecolumn'\)\: "
                 r"(?:N)?VARCHAR requires a length on dialect mysql",
                 schema.CreateTable(t1).compile,
-                dialect=mysql.dialect()
+                dialect=mysql.dialect(),
             )

     def test_update_limit(self):
-        t = sql.table('t', sql.column('col1'), sql.column('col2'))
+        t = sql.table("t", sql.column("col1"), sql.column("col2"))

         self.assert_compile(
-            t.update(values={'col1': 123}),
-            "UPDATE t SET col1=%s"
+            t.update(values={"col1": 123}), "UPDATE t SET col1=%s"
         )
         self.assert_compile(
-            t.update(values={'col1': 123}, mysql_limit=5),
-            "UPDATE t SET col1=%s LIMIT 5"
+            t.update(values={"col1": 123}, mysql_limit=5),
+            "UPDATE t SET col1=%s LIMIT 5",
        )
         self.assert_compile(
-            t.update(values={'col1': 123}, mysql_limit=None),
-            "UPDATE t SET col1=%s"
+            t.update(values={"col1": 123}, mysql_limit=None),
+            "UPDATE t SET col1=%s",
         )
         self.assert_compile(
-            t.update(t.c.col2 == 456, values={'col1': 123}, mysql_limit=1),
-            "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1"
+            t.update(t.c.col2 == 456, values={"col1": 123}, mysql_limit=1),
+            "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1",
         )

     def test_utc_timestamp(self):
@@ -382,14 +472,16 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):

     def test_utc_timestamp_fsp(self):
         self.assert_compile(
-            func.utc_timestamp(5), "utc_timestamp(%s)",
-            checkparams={"utc_timestamp_1": 5})
+            func.utc_timestamp(5),
+            "utc_timestamp(%s)",
+            checkparams={"utc_timestamp_1": 5},
+        )

     def test_sysdate(self):
         self.assert_compile(func.sysdate(), "SYSDATE()")

     def test_cast(self):
-        t = sql.table('t', sql.column('col'))
+        t = sql.table("t", sql.column("col"))
         m = mysql

         specs = [
@@ -403,16 +495,13 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             # 'SIGNED INTEGER' is a bigint, so this is ok.
             (m.MSBigInteger, "CAST(t.col AS SIGNED INTEGER)"),
             (m.MSBigInteger(unsigned=False), "CAST(t.col AS SIGNED INTEGER)"),
-            (m.MSBigInteger(unsigned=True),
-             "CAST(t.col AS UNSIGNED INTEGER)"),
-
+            (m.MSBigInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"),
             # this is kind of sucky. thank you default arguments!
             (NUMERIC, "CAST(t.col AS DECIMAL)"),
             (DECIMAL, "CAST(t.col AS DECIMAL)"),
             (Numeric, "CAST(t.col AS DECIMAL)"),
             (m.MSNumeric, "CAST(t.col AS DECIMAL)"),
             (m.MSDecimal, "CAST(t.col AS DECIMAL)"),
-
             (TIMESTAMP, "CAST(t.col AS DATETIME)"),
             (DATETIME, "CAST(t.col AS DATETIME)"),
             (DATE, "CAST(t.col AS DATE)"),
@@ -424,17 +513,16 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             (Date, "CAST(t.col AS DATE)"),
             (m.MSTime, "CAST(t.col AS TIME)"),
             (m.MSTimeStamp, "CAST(t.col AS DATETIME)"),
-
             (String, "CAST(t.col AS CHAR)"),
             (Unicode, "CAST(t.col AS CHAR)"),
             (UnicodeText, "CAST(t.col AS CHAR)"),
             (VARCHAR, "CAST(t.col AS CHAR)"),
             (NCHAR, "CAST(t.col AS CHAR)"),
             (CHAR, "CAST(t.col AS CHAR)"),
-            (m.CHAR(charset='utf8'), "CAST(t.col AS CHAR CHARACTER SET utf8)"),
+            (m.CHAR(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"),
             (CLOB, "CAST(t.col AS CHAR)"),
             (TEXT, "CAST(t.col AS CHAR)"),
-            (m.TEXT(charset='utf8'), "CAST(t.col AS CHAR CHARACTER SET utf8)"),
+            (m.TEXT(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"),
             (String(32), "CAST(t.col AS CHAR(32))"),
             (Unicode(32), "CAST(t.col AS CHAR(32))"),
             (CHAR(32), "CAST(t.col AS CHAR(32))"),
@@ -445,7 +533,6 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             (m.MSLongText, "CAST(t.col AS CHAR)"),
             (m.MSNChar, "CAST(t.col AS CHAR)"),
             (m.MSNVarChar, "CAST(t.col AS CHAR)"),
-
             (LargeBinary, "CAST(t.col AS BINARY)"),
             (BLOB, "CAST(t.col AS BINARY)"),
             (m.MSBlob, "CAST(t.col AS BINARY)"),
@@ -457,9 +544,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             (m.MSBinary(32), "CAST(t.col AS BINARY)"),
             (m.MSVarBinary, "CAST(t.col AS BINARY)"),
             (m.MSVarBinary(32), "CAST(t.col AS BINARY)"),
-
             (Interval, "CAST(t.col AS DATETIME)"),
-
         ]

         for type_, expected in specs:
@@ -470,63 +555,52 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             impl = Integer

         type_ = MyInteger()
-        t = sql.table('t', sql.column('col'))
+        t = sql.table("t", sql.column("col"))
         self.assert_compile(
-            cast(t.c.col, type_), "CAST(t.col AS SIGNED INTEGER)")
+            cast(t.c.col, type_), "CAST(t.col AS SIGNED INTEGER)"
+        )

     def test_cast_literal_bind(self):
-        expr = cast(column('foo', Integer) + 5, Integer())
+        expr = cast(column("foo", Integer) + 5, Integer())

         self.assert_compile(
-            expr,
-            "CAST(foo + 5 AS SIGNED INTEGER)",
-            literal_binds=True
+            expr, "CAST(foo + 5 AS SIGNED INTEGER)", literal_binds=True
         )

     def test_unsupported_cast_literal_bind(self):
-        expr = cast(column('foo', Integer) + 5, Float)
+        expr = cast(column("foo", Integer) + 5, Float)

-        with expect_warnings(
-            "Datatype FLOAT does not support CAST on MySQL;"
-        ):
-            self.assert_compile(
-                expr,
-                "(foo + 5)",
-                literal_binds=True
-            )
+        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+            self.assert_compile(expr, "(foo + 5)", literal_binds=True)

         dialect = mysql.MySQLDialect()
         dialect.server_version_info = (3, 9, 8)
-        with expect_warnings(
-            "Current MySQL version does not support CAST"
-        ):
+        with expect_warnings("Current MySQL version does not support CAST"):
             eq_(
-                str(expr.compile(
-                    dialect=dialect,
-                    compile_kwargs={"literal_binds": True})),
-                "(foo + 5)"
+                str(
+                    expr.compile(
+                        dialect=dialect, compile_kwargs={"literal_binds": True}
+                    )
+                ),
+                "(foo + 5)",
             )

     def test_unsupported_casts(self):
-        t = sql.table('t', sql.column('col'))
+        t = sql.table("t", sql.column("col"))
         m = mysql

         specs = [
             (m.MSBit, "t.col"),
-
             (FLOAT, "t.col"),
             (Float, "t.col"),
             (m.MSFloat, "t.col"),
             (m.MSDouble, "t.col"),
             (m.MSReal, "t.col"),
-
             (m.MSYear, "t.col"),
             (m.MSYear(2), "t.col"),
-
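+            # the remaining pairings likewise expect the bare "t.col"
+            # expression: MySQL has no CAST target for these datatypes, so
+            # the dialect emits a warning and omits the CAST rather than
+            # render invalid SQL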
             (Boolean, "t.col"),
             (BOOLEAN, "t.col"),
-
             (m.MSEnum, "t.col"),
             (m.MSEnum("1", "2"), "t.col"),
             (m.MSSet, "t.col"),
@@ -541,23 +615,19 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):

     def test_no_cast_pre_4(self):
         self.assert_compile(
-            cast(Column('foo', Integer), String),
-            "CAST(foo AS CHAR)",
+            cast(Column("foo", Integer), String), "CAST(foo AS CHAR)"
         )
         dialect = mysql.dialect()
         dialect.server_version_info = (3, 2, 3)
         with expect_warnings("Current MySQL version does not support CAST;"):
             self.assert_compile(
-                cast(Column('foo', Integer), String),
-                "foo",
-                dialect=dialect
+                cast(Column("foo", Integer), String), "foo", dialect=dialect
             )

     def test_cast_grouped_expression_non_castable(self):
         with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
             self.assert_compile(
-                cast(sql.column('x') + sql.column('y'), Float),
-                "(x + y)"
+                cast(sql.column("x") + sql.column("y"), Float), "(x + y)"
             )

     def test_cast_grouped_expression_pre_4(self):
@@ -565,173 +635,199 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
         dialect.server_version_info = (3, 2, 3)
         with expect_warnings("Current MySQL version does not support CAST;"):
             self.assert_compile(
-                cast(sql.column('x') + sql.column('y'), Integer),
+                cast(sql.column("x") + sql.column("y"), Integer),
                 "(x + y)",
-                dialect=dialect
+                dialect=dialect,
             )

     def test_extract(self):
-        t = sql.table('t', sql.column('col1'))
+        t = sql.table("t", sql.column("col1"))

-        for field in 'year', 'month', 'day':
+        for field in "year", "month", "day":
             self.assert_compile(
                 select([extract(field, t.c.col1)]),
-                "SELECT EXTRACT(%s FROM t.col1) AS anon_1 FROM t" % field)
+                "SELECT EXTRACT(%s FROM t.col1) AS anon_1 FROM t" % field,
+            )

         # millsecondS to millisecond
         self.assert_compile(
-            select([extract('milliseconds', t.c.col1)]),
-            "SELECT EXTRACT(millisecond FROM t.col1) AS anon_1 FROM t")
+            select([extract("milliseconds", t.c.col1)]),
+            "SELECT EXTRACT(millisecond FROM t.col1) AS anon_1 FROM t",
+        )

     def test_too_long_index(self):
-        exp = 'ix_zyrenian_zyme_zyzzogeton_zyzzogeton_zyrenian_zyme_zyz_5cd2'
-        tname = 'zyrenian_zyme_zyzzogeton_zyzzogeton'
-        cname = 'zyrenian_zyme_zyzzogeton_zo'
+        exp = "ix_zyrenian_zyme_zyzzogeton_zyzzogeton_zyrenian_zyme_zyz_5cd2"
+        tname = "zyrenian_zyme_zyzzogeton_zyzzogeton"
+        cname = "zyrenian_zyme_zyzzogeton_zo"

-        t1 = Table(tname, MetaData(),
-                   Column(cname, Integer, index=True),
-                   )
+        t1 = Table(tname, MetaData(), Column(cname, Integer, index=True))
         ix1 = list(t1.indexes)[0]

         self.assert_compile(
             schema.CreateIndex(ix1),
-            "CREATE INDEX %s "
-            "ON %s (%s)" % (exp, tname, cname)
+            "CREATE INDEX %s " "ON %s (%s)" % (exp, tname, cname),
         )

     def test_innodb_autoincrement(self):
         t1 = Table(
-            'sometable', MetaData(),
+            "sometable",
+            MetaData(),
             Column(
-                'assigned_id', Integer(), primary_key=True,
-                autoincrement=False),
-            Column('id', Integer(), primary_key=True, autoincrement=True),
-            mysql_engine='InnoDB')
-        self.assert_compile(schema.CreateTable(t1),
-                            'CREATE TABLE sometable (assigned_id '
-                            'INTEGER NOT NULL, id INTEGER NOT NULL '
-                            'AUTO_INCREMENT, PRIMARY KEY (id, assigned_id)'
-                            ')ENGINE=InnoDB')
-
-        t1 = Table('sometable', MetaData(),
-                   Column('assigned_id', Integer(), primary_key=True,
-                          autoincrement=True),
-                   Column('id', Integer(), primary_key=True,
-                          autoincrement=False), mysql_engine='InnoDB')
-        self.assert_compile(schema.CreateTable(t1),
-                            'CREATE TABLE sometable (assigned_id '
-                            'INTEGER NOT NULL AUTO_INCREMENT, id '
-                            'INTEGER NOT NULL, PRIMARY KEY '
-                            '(assigned_id, id))ENGINE=InnoDB')
+                "assigned_id", Integer(), primary_key=True, autoincrement=False
+            ),
+            Column("id", Integer(), primary_key=True, autoincrement=True),
+            mysql_engine="InnoDB",
+        )
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE sometable (assigned_id "
+            "INTEGER NOT NULL, id INTEGER NOT NULL "
+            "AUTO_INCREMENT, PRIMARY KEY (id, assigned_id)"
+            ")ENGINE=InnoDB",
+        )
+
+        t1 = Table(
+            "sometable",
+            MetaData(),
+            Column(
+                "assigned_id", Integer(), primary_key=True, autoincrement=True
+            ),
+            Column("id", Integer(), primary_key=True, autoincrement=False),
+            mysql_engine="InnoDB",
+        )
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE sometable (assigned_id "
+            "INTEGER NOT NULL AUTO_INCREMENT, id "
+            "INTEGER NOT NULL, PRIMARY KEY "
+            "(assigned_id, id))ENGINE=InnoDB",
+        )

     def test_innodb_autoincrement_reserved_word_column_name(self):
         t1 = Table(
-            'sometable', MetaData(),
-            Column('id', Integer(), primary_key=True, autoincrement=False),
-            Column('order', Integer(), primary_key=True, autoincrement=True),
-            mysql_engine='InnoDB')
+            "sometable",
+            MetaData(),
+            Column("id", Integer(), primary_key=True, autoincrement=False),
+            Column("order", Integer(), primary_key=True, autoincrement=True),
+            mysql_engine="InnoDB",
+        )
         self.assert_compile(
             schema.CreateTable(t1),
-            'CREATE TABLE sometable ('
-            'id INTEGER NOT NULL, '
-            '`order` INTEGER NOT NULL AUTO_INCREMENT, '
-            'PRIMARY KEY (`order`, id)'
-            ')ENGINE=InnoDB')
+            "CREATE TABLE sometable ("
+            "id INTEGER NOT NULL, "
+            "`order` INTEGER NOT NULL AUTO_INCREMENT, "
+            "PRIMARY KEY (`order`, id)"
+            ")ENGINE=InnoDB",
+        )

     def test_create_table_with_partition(self):
         t1 = Table(
-            'testtable', MetaData(),
-            Column('id', Integer(), primary_key=True, autoincrement=True),
-            Column('other_id', Integer(), primary_key=True,
-                   autoincrement=False),
-            mysql_partitions='2', mysql_partition_by='KEY(other_id)')
+            "testtable",
+            MetaData(),
+            Column("id", Integer(), primary_key=True, autoincrement=True),
+            Column(
+                "other_id", Integer(), primary_key=True, autoincrement=False
+            ),
+            mysql_partitions="2",
+            mysql_partition_by="KEY(other_id)",
+        )
         self.assert_compile(
             schema.CreateTable(t1),
-            'CREATE TABLE testtable ('
-            'id INTEGER NOT NULL AUTO_INCREMENT, '
-            'other_id INTEGER NOT NULL, '
-            'PRIMARY KEY (id, other_id)'
-            ')PARTITION BY KEY(other_id) PARTITIONS 2'
+            "CREATE TABLE testtable ("
+            "id INTEGER NOT NULL AUTO_INCREMENT, "
+            "other_id INTEGER NOT NULL, "
+            "PRIMARY KEY (id, other_id)"
+            ")PARTITION BY KEY(other_id) PARTITIONS 2",
         )

     def test_create_table_with_subpartition(self):
         t1 = Table(
-            'testtable', MetaData(),
-            Column('id', Integer(), primary_key=True, autoincrement=True),
-            Column('other_id', Integer(), primary_key=True,
-                   autoincrement=False),
-            mysql_partitions='2',
-            mysql_partition_by='KEY(other_id)',
+            "testtable",
+            MetaData(),
+            Column("id", Integer(), primary_key=True, autoincrement=True),
+            Column(
+                "other_id", Integer(), primary_key=True, autoincrement=False
+            ),
+            mysql_partitions="2",
+            mysql_partition_by="KEY(other_id)",
             mysql_subpartition_by="HASH(some_expr)",
-            mysql_subpartitions='2')
+            mysql_subpartitions="2",
+        )
         self.assert_compile(
             schema.CreateTable(t1),
-            'CREATE TABLE testtable ('
-            'id INTEGER NOT NULL AUTO_INCREMENT, '
-            'other_id INTEGER NOT NULL, '
-            'PRIMARY KEY (id, other_id)'
-            ')PARTITION BY KEY(other_id) PARTITIONS 2 '
-            'SUBPARTITION BY HASH(some_expr) SUBPARTITIONS 2'
+            "CREATE TABLE testtable ("
+            "id INTEGER NOT NULL AUTO_INCREMENT, "
+            "other_id INTEGER NOT NULL, "
+            "PRIMARY KEY (id, other_id)"
+            ")PARTITION BY KEY(other_id) PARTITIONS 2 "
+            "SUBPARTITION BY HASH(some_expr) SUBPARTITIONS 2",
         )

     def test_create_table_with_partition_hash(self):
         t1 = Table(
-            'testtable', MetaData(),
-            Column('id', Integer(), primary_key=True, autoincrement=True),
-            Column('other_id', Integer(), primary_key=True,
-                   autoincrement=False),
-            mysql_partitions='2', mysql_partition_by='HASH(other_id)')
+            "testtable",
+            MetaData(),
+            Column("id", Integer(), primary_key=True, autoincrement=True),
+            Column(
+                "other_id", Integer(), primary_key=True, autoincrement=False
+            ),
+            mysql_partitions="2",
+            mysql_partition_by="HASH(other_id)",
+        )
         self.assert_compile(
             schema.CreateTable(t1),
-            'CREATE TABLE testtable ('
-            'id INTEGER NOT NULL AUTO_INCREMENT, '
-            'other_id INTEGER NOT NULL, '
-            'PRIMARY KEY (id, other_id)'
-            ')PARTITION BY HASH(other_id) PARTITIONS 2'
+            "CREATE TABLE testtable ("
+            "id INTEGER NOT NULL AUTO_INCREMENT, "
+            "other_id INTEGER NOT NULL, "
+            "PRIMARY KEY (id, other_id)"
+            ")PARTITION BY HASH(other_id) PARTITIONS 2",
         )

     def test_create_table_with_partition_and_other_opts(self):
         t1 = Table(
-            'testtable', MetaData(),
-            Column('id', Integer(), primary_key=True, autoincrement=True),
-            Column('other_id', Integer(), primary_key=True,
-                   autoincrement=False),
-            mysql_stats_sample_pages='2',
-            mysql_partitions='2', mysql_partition_by='HASH(other_id)')
+            "testtable",
+            MetaData(),
+            Column("id", Integer(), primary_key=True, autoincrement=True),
+            Column(
+                "other_id", Integer(), primary_key=True, autoincrement=False
+            ),
+            mysql_stats_sample_pages="2",
+            mysql_partitions="2",
+            mysql_partition_by="HASH(other_id)",
+        )
         self.assert_compile(
             schema.CreateTable(t1),
-            'CREATE TABLE testtable ('
-            'id INTEGER NOT NULL AUTO_INCREMENT, '
-            'other_id INTEGER NOT NULL, '
-            'PRIMARY KEY (id, other_id)'
-            ')STATS_SAMPLE_PAGES=2 PARTITION BY HASH(other_id) PARTITIONS 2'
+            "CREATE TABLE testtable ("
+            "id INTEGER NOT NULL AUTO_INCREMENT, "
+            "other_id INTEGER NOT NULL, "
+            "PRIMARY KEY (id, other_id)"
+            ")STATS_SAMPLE_PAGES=2 PARTITION BY HASH(other_id) PARTITIONS 2",
         )

     def test_inner_join(self):
-        t1 = table('t1', column('x'))
-        t2 = table('t2', column('y'))
+        t1 = table("t1", column("x"))
+        t2 = table("t2", column("y"))

         self.assert_compile(
-            t1.join(t2, t1.c.x == t2.c.y),
-            "t1 INNER JOIN t2 ON t1.x = t2.y"
+            t1.join(t2, t1.c.x == t2.c.y), "t1 INNER JOIN t2 ON t1.x = t2.y"
         )

     def test_outer_join(self):
-        t1 = table('t1', column('x'))
-        t2 = table('t2', column('y'))
+        t1 = table("t1", column("x"))
+        t2 = table("t2", column("y"))

         self.assert_compile(
             t1.outerjoin(t2, t1.c.x == t2.c.y),
-            "t1 LEFT OUTER JOIN t2 ON t1.x = t2.y"
+            "t1 LEFT OUTER JOIN t2 ON t1.x = t2.y",
         )

     def test_full_outer_join(self):
-        t1 = table('t1', column('x'))
-        t2 = table('t2', column('y'))
+        t1 = table("t1", column("x"))
+        t2 = table("t2", column("y"))

         self.assert_compile(
             t1.outerjoin(t2, t1.c.x == t2.c.y, full=True),
-            "t1 FULL OUTER JOIN t2 ON t1.x = t2.y"
+            "t1 FULL OUTER JOIN t2 ON t1.x = t2.y",
        )

@@ -740,39 +836,44 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
     def setup(self):
         self.table = Table(
-            'foos', MetaData(),
-            Column('id', Integer, primary_key=True),
-            Column('bar', String(10)),
-            Column('baz', String(10)),
+            "foos",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("bar", String(10)),
+            Column("baz", String(10)),
         )

     def test_from_values(self):
         stmt = insert(self.table).values(
-            [{'id': 1, 'bar': 'ab'}, {'id': 2, 'bar': 'b'}])
+            [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+        )
         stmt = stmt.on_duplicate_key_update(
-            bar=stmt.inserted.bar, baz=stmt.inserted.baz)
+            bar=stmt.inserted.bar, baz=stmt.inserted.baz
+        )
         expected_sql = (
-            'INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) '
-            'ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)'
+            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+            "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)"
        )
         self.assert_compile(stmt, expected_sql)
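+        # stmt.inserted refers to the row that would have been inserted;
+        # it compiles to the VALUES(col) construct seen in expected_sql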

     def test_from_literal(self):
         stmt = insert(self.table).values(
-            [{'id': 1, 'bar': 'ab'}, {'id': 2, 'bar': 'b'}])
-        stmt = stmt.on_duplicate_key_update(bar=literal_column('bb'))
+            [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+        )
+        stmt = stmt.on_duplicate_key_update(bar=literal_column("bb"))
         expected_sql = (
-            'INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) '
-            'ON DUPLICATE KEY UPDATE bar = bb'
+            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+            "ON DUPLICATE KEY UPDATE bar = bb"
         )
         self.assert_compile(stmt, expected_sql)

     def test_python_values(self):
         stmt = insert(self.table).values(
-            [{'id': 1, 'bar': 'ab'}, {'id': 2, 'bar': 'b'}])
+            [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+        )
         stmt = stmt.on_duplicate_key_update(bar="foobar")
         expected_sql = (
-            'INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) '
-            'ON DUPLICATE KEY UPDATE bar = %s'
+            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+            "ON DUPLICATE KEY UPDATE bar = %s"
         )
         self.assert_compile(stmt, expected_sql)
diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py
index 44f786ee03..ffcea3bd1b 100644
--- a/test/dialect/mysql/test_dialect.py
+++ b/test/dialect/mysql/test_dialect.py
@@ -13,121 +13,137 @@ from sqlalchemy.dialects import mysql

 class DialectTest(fixtures.TestBase):
     __backend__ = True
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"

     def test_ssl_arguments_mysqldb(self):
         from sqlalchemy.dialects.mysql import mysqldb
+
         dialect = mysqldb.dialect()
         self._test_ssl_arguments(dialect)

     def test_ssl_arguments_oursql(self):
         from sqlalchemy.dialects.mysql import oursql
+
         dialect = oursql.dialect()
         self._test_ssl_arguments(dialect)

     def _test_ssl_arguments(self, dialect):
         kwarg = dialect.create_connect_args(
-            make_url("mysql://scott:tiger@localhost:3306/test"
-                     "?ssl_ca=/ca.pem&ssl_cert=/cert.pem&ssl_key=/key.pem")
+            make_url(
+                "mysql://scott:tiger@localhost:3306/test"
+                "?ssl_ca=/ca.pem&ssl_cert=/cert.pem&ssl_key=/key.pem"
+            )
         )[1]
         # args that differ among mysqldb and oursql
-        for k in ('use_unicode', 'found_rows', 'client_flag'):
+        for k in ("use_unicode", "found_rows", "client_flag"):
             kwarg.pop(k, None)
         eq_(
             kwarg,
             {
-                'passwd': 'tiger', 'db': 'test',
-                'ssl': {'ca': '/ca.pem', 'cert': '/cert.pem',
-                        'key': '/key.pem'},
-                'host': 'localhost', 'user': 'scott',
-                'port': 3306
-            }
+                "passwd": "tiger",
+                "db": "test",
+                "ssl": {
+                    "ca": "/ca.pem",
+                    "cert": "/cert.pem",
+                    "key": "/key.pem",
+                },
+                "host": "localhost",
+                "user": "scott",
+                "port": 3306,
+            },
         )

     def test_normal_arguments_mysqldb(self):
         from sqlalchemy.dialects.mysql import mysqldb
+
         dialect = mysqldb.dialect()
         self._test_normal_arguments(dialect)

     def _test_normal_arguments(self, dialect):
         for kwarg, value in [
-            ('compress', True),
-            ('connect_timeout', 30),
-            ('read_timeout', 30),
-            ('write_timeout', 30),
-            ('client_flag', 1234),
-            ('local_infile', 1234),
-            ('use_unicode', False),
-            ('charset', 'hello')
+            ("compress", True),
+            ("connect_timeout", 30),
+            ("read_timeout", 30),
+            ("write_timeout", 30),
+            ("client_flag", 1234),
+            ("local_infile", 1234),
+            ("use_unicode", False),
+            ("charset", "hello"),
         ]:
             connect_args = dialect.create_connect_args(
-                make_url("mysql://scott:tiger@localhost:3306/test"
-                         "?%s=%s" % (kwarg, value))
+                make_url(
+                    "mysql://scott:tiger@localhost:3306/test"
+                    "?%s=%s" % (kwarg, value)
+                )
             )
             eq_(connect_args[1][kwarg], value)

     def test_mysqlconnector_buffered_arg(self):
         from sqlalchemy.dialects.mysql import mysqlconnector
+
         dialect = mysqlconnector.dialect()
         kw = dialect.create_connect_args(
-            make_url("mysql+mysqlconnector://u:p@host/db?buffered=true")
-            )[1]
-        eq_(kw['buffered'], True)
+            make_url("mysql+mysqlconnector://u:p@host/db?buffered=true")
+        )[1]
+        eq_(kw["buffered"], True)

         kw = dialect.create_connect_args(
-            make_url("mysql+mysqlconnector://u:p@host/db?buffered=false")
-            )[1]
-        eq_(kw['buffered'], False)
+            make_url("mysql+mysqlconnector://u:p@host/db?buffered=false")
+        )[1]
+        eq_(kw["buffered"], False)

         kw = dialect.create_connect_args(
-            make_url("mysql+mysqlconnector://u:p@host/db")
-            )[1]
-        eq_(kw['buffered'], True)
+            make_url("mysql+mysqlconnector://u:p@host/db")
+        )[1]
+        eq_(kw["buffered"], True)

     def test_mysqlconnector_raise_on_warnings_arg(self):
         from sqlalchemy.dialects.mysql import mysqlconnector
+
         dialect = mysqlconnector.dialect()
         kw = dialect.create_connect_args(
             make_url(
                 "mysql+mysqlconnector://u:p@host/db?raise_on_warnings=true"
             )
         )[1]
-        eq_(kw['raise_on_warnings'], True)
+        eq_(kw["raise_on_warnings"], True)

         kw = dialect.create_connect_args(
             make_url(
                 "mysql+mysqlconnector://u:p@host/db?raise_on_warnings=false"
             )
         )[1]
-        eq_(kw['raise_on_warnings'], False)
+        eq_(kw["raise_on_warnings"], False)

         kw = dialect.create_connect_args(
-            make_url("mysql+mysqlconnector://u:p@host/db")
-            )[1]
+            make_url("mysql+mysqlconnector://u:p@host/db")
+        )[1]
         assert "raise_on_warnings" not in kw

-    @testing.only_on('mysql')
+    @testing.only_on("mysql")
     def test_random_arg(self):
         dialect = testing.db.dialect
         kw = dialect.create_connect_args(
-            make_url("mysql://u:p@host/db?foo=true")
-            )[1]
-        eq_(kw['foo'], "true")
+            make_url("mysql://u:p@host/db?foo=true")
+        )[1]
+        eq_(kw["foo"], "true")

-    @testing.only_on('mysql')
-    @testing.skip_if('mysql+mysqlconnector', "totally broken for the moment")
-    @testing.fails_on('mysql+oursql', "unsupported")
+    @testing.only_on("mysql")
+    @testing.skip_if("mysql+mysqlconnector", "totally broken for the moment")
+    @testing.fails_on("mysql+oursql", "unsupported")
     def test_special_encodings(self):
-        for enc in ['utf8mb4', 'utf8']:
+        for enc in ["utf8mb4", "utf8"]:
             eng = engines.testing_engine(
-                options={"connect_args": {'charset': enc, 'use_unicode': 0}})
+                options={"connect_args": {"charset": enc, "use_unicode": 0}}
+            )
             conn = eng.connect()
             eq_(conn.dialect._connection_charset, enc)

     def test_no_show_variables(self):
         from sqlalchemy.testing import mock
+
         engine = engines.testing_engine()

         def my_execute(self, statement, *args, **kw):
@@ -137,7 +153,8 @@ class DialectTest(fixtures.TestBase):
         real_exec = engine._connection_cls._execute_text

         with mock.patch.object(
-                engine._connection_cls, "_execute_text", my_execute):
+            engine._connection_cls, "_execute_text", my_execute
+        ):
             with expect_warnings(
                 "Could not retrieve SQL_MODE; please ensure the "
                 "MySQL user has permissions to SHOW VARIABLES"
@@ -146,63 +163,59 @@ class DialectTest(fixtures.TestBase):

     def test_autocommit_isolation_level(self):
         c = testing.db.connect().execution_options(
-            isolation_level='AUTOCOMMIT'
+            isolation_level="AUTOCOMMIT"
         )
-        assert c.execute('SELECT @@autocommit;').scalar()
+        assert c.execute("SELECT @@autocommit;").scalar()

-        c = c.execution_options(isolation_level='READ COMMITTED')
-        assert not c.execute('SELECT @@autocommit;').scalar()
+        c = c.execution_options(isolation_level="READ COMMITTED")
+        assert not c.execute("SELECT @@autocommit;").scalar()

     def test_isolation_level(self):
         values = [
-            'READ UNCOMMITTED',
-            'READ COMMITTED',
-            'REPEATABLE READ',
-            'SERIALIZABLE'
+            "READ UNCOMMITTED",
+            "READ COMMITTED",
+            "REPEATABLE READ",
+            "SERIALIZABLE",
         ]
         for value in values:
-            c = testing.db.connect().execution_options(
-                isolation_level=value
-            )
-            eq_(
-                testing.db.dialect.get_isolation_level(c.connection),
-                value)
+            c = testing.db.connect().execution_options(isolation_level=value)
+            eq_(testing.db.dialect.get_isolation_level(c.connection), value)


 class ParseVersionTest(fixtures.TestBase):
     def test_mariadb_normalized_version(self):
         for expected, version in [
-            ((10, 2, 7), (10, 2, 7, 'MariaDB')),
-            ((10, 2, 7), (5, 6, 15, 10, 2, 7, 'MariaDB')),
-            ((10, 2, 10), (10, 2, 10, 'MariaDB')),
+            ((10, 2, 7), (10, 2, 7, "MariaDB")),
+            ((10, 2, 7), (5, 6, 15, 10, 2, 7, "MariaDB")),
+            ((10, 2, 10), (10, 2, 10, "MariaDB")),
             ((5, 7, 20), (5, 7, 20)),
             ((5, 6, 15), (5, 6, 15)),
-            ((10, 2, 6),
-             (10, 2, 6, 'MariaDB', 10, 2, '6+maria~stretch', 'log')),
+            (
+                (10, 2, 6),
+                (10, 2, 6, "MariaDB", 10, 2, "6+maria~stretch", "log"),
+            ),
         ]:
             dialect = mysql.dialect()
             dialect.server_version_info = version
-            eq_(
-                dialect._mariadb_normalized_version_info,
-                expected
-            )
+            eq_(dialect._mariadb_normalized_version_info, expected)

     def test_mariadb_check_warning(self):
         for expect_, version in [
-            (True, (10, 2, 7, 'MariaDB')),
-            (True, (5, 6, 15, 10, 2, 7, 'MariaDB')),
-            (False, (10, 2, 10, 'MariaDB')),
+            (True, (10, 2, 7, "MariaDB")),
+            (True, (5, 6, 15, 10, 2, 7, "MariaDB")),
+            (False, (10, 2, 10, "MariaDB")),
             (False, (5, 7, 20)),
             (False, (5, 6, 15)),
-            (True, (10, 2, 6, 'MariaDB', 10, 2, '6+maria~stretch', 'log')),
+            (True, (10, 2, 6, "MariaDB", 10, 2, "6+maria~stretch", "log")),
         ]:
             dialect = mysql.dialect()
             dialect.server_version_info = version
             if expect_:
                 with expect_warnings(
-                        ".*before 10.2.9 has known issues regarding "
-                        "CHECK constraints"):
+                    ".*before 10.2.9 has known issues regarding "
+                    "CHECK constraints"
+                ):
                     dialect._warn_for_known_db_issues()
             else:
                 dialect._warn_for_known_db_issues()
@@ -218,31 +231,34 @@ class RemoveUTCTimestampTest(fixtures.TablesTest):
     [ticket:3966]

     """
-    __only_on__ = 'mysql'
+
+    __only_on__ = "mysql"
     __backend__ = True

     @classmethod
     def define_tables(cls, metadata):
         Table(
-            't', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer),
-            Column('data', DateTime)
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer),
+            Column("data", DateTime),
         )

         Table(
-            't_default', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer),
-            Column('idata', DateTime, default=func.utc_timestamp()),
-            Column('udata', DateTime, onupdate=func.utc_timestamp())
+            "t_default",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer),
+            Column("idata", DateTime, default=func.utc_timestamp()),
+            Column("udata", DateTime, onupdate=func.utc_timestamp()),
         )

     def test_insert_executemany(self):
         with testing.db.connect() as conn:
             conn.execute(
                 self.tables.t.insert().values(data=func.utc_timestamp()),
-                [{"x": 5}, {"x": 6}, {"x": 7}]
+                [{"x": 5}, {"x": 6}, {"x": 7}],
             )

     def test_update_executemany(self):
@@ -253,21 +269,21 @@ class RemoveUTCTimestampTest(fixtures.TablesTest):
                 [
                     {"x": 5, "data": timestamp},
                     {"x": 6, "data": timestamp},
-                    {"x": 7, "data": timestamp}]
+                    {"x": 7, "data": timestamp},
+                ],
             )

             conn.execute(
-                self.tables.t.update().
-                values(data=func.utc_timestamp()).
-                where(self.tables.t.c.x == bindparam('xval')),
-                [{"xval": 5}, {"xval": 6}, {"xval": 7}]
+                self.tables.t.update()
+                .values(data=func.utc_timestamp())
+                .where(self.tables.t.c.x == bindparam("xval")),
+                [{"xval": 5}, {"xval": 6}, {"xval": 7}],
             )

     def test_insert_executemany_w_default(self):
         with testing.db.connect() as conn:
             conn.execute(
-                self.tables.t_default.insert(),
-                [{"x": 5}, {"x": 6}, {"x": 7}]
+                self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}]
             )

     def test_update_executemany_w_default(self):
@@ -278,35 +294,39 @@ class RemoveUTCTimestampTest(fixtures.TablesTest):
                 [
                     {"x": 5, "idata": timestamp},
                     {"x": 6, "idata": timestamp},
-                    {"x": 7, "idata": timestamp}]
+                    {"x": 7, "idata": timestamp},
+                ],
             )

             conn.execute(
-                self.tables.t_default.update().
-                values(idata=func.utc_timestamp()).
-                where(self.tables.t_default.c.x == bindparam('xval')),
-                [{"xval": 5}, {"xval": 6}, {"xval": 7}]
+                self.tables.t_default.update()
+                .values(idata=func.utc_timestamp())
+                .where(self.tables.t_default.c.x == bindparam("xval")),
+                [{"xval": 5}, {"xval": 6}, {"xval": 7}],
             )


 class SQLModeDetectionTest(fixtures.TestBase):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     def _options(self, modes):
         def connect(con, record):
             cursor = con.cursor()
             cursor.execute("set sql_mode='%s'" % (",".join(modes)))
-        e = engines.testing_engine(options={
-            'pool_events': [
-                (connect, 'first_connect'),
-                (connect, 'connect')
-            ]
-        })
+
+        e = engines.testing_engine(
+            options={
+                "pool_events": [
+                    (connect, "first_connect"),
+                    (connect, "connect"),
+                ]
+            }
+        )
         return e

     def test_backslash_escapes(self):
-        engine = self._options(['NO_BACKSLASH_ESCAPES'])
+        engine = self._options(["NO_BACKSLASH_ESCAPES"])
         c = engine.connect()
         assert not engine.dialect._backslash_escapes
         c.close()
@@ -319,14 +339,14 @@ class SQLModeDetectionTest(fixtures.TestBase):
         engine.dispose()

     def test_ansi_quotes(self):
-        engine = self._options(['ANSI_QUOTES'])
+        engine = self._options(["ANSI_QUOTES"])
         c = engine.connect()
         assert engine.dialect._server_ansiquotes
         c.close()
         engine.dispose()

     def test_combination(self):
-        engine = self._options(['ANSI_QUOTES,NO_BACKSLASH_ESCAPES'])
+        engine = self._options(["ANSI_QUOTES,NO_BACKSLASH_ESCAPES"])
         c = engine.connect()
         assert engine.dialect._server_ansiquotes
         assert not engine.dialect._backslash_escapes
@@ -337,7 +357,7 @@ class SQLModeDetectionTest(fixtures.TestBase):
 class ExecutionTest(fixtures.TestBase):
     """Various MySQL execution special cases."""

-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     def test_charset_caching(self):
@@ -357,7 +377,7 @@ class ExecutionTest(fixtures.TestBase):

 class AutocommitTextTest(test_execute.AutocommitTextTest):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"

     def test_load_data(self):
         self._test_keyword("LOAD DATA STUFF")
diff --git a/test/dialect/mysql/test_for_update.py b/test/dialect/mysql/test_for_update.py
index a8cbcfb87e..197c9d6937 100644
--- a/test/dialect/mysql/test_for_update.py
+++ b/test/dialect/mysql/test_for_update.py
@@ -15,15 +15,15 @@ from sqlalchemy import testing

 class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest):
     __backend__ = True
-    __only_on__ = 'mysql'
-    __requires__ = 'mysql_for_update',
+    __only_on__ = "mysql"
+    __requires__ = ("mysql_for_update",)

     @classmethod
     def setup_classes(cls):
         Base = cls.DeclarativeBasic

         class A(Base):
-            __tablename__ = 'a'
__tablename__ = "a" id = Column(Integer, primary_key=True) x = Column(Integer) y = Column(Integer) @@ -31,9 +31,9 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest): __table_args__ = {"mysql_engine": "InnoDB"} class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) + a_id = Column(ForeignKey("a.id")) x = Column(Integer) y = Column(Integer) __table_args__ = {"mysql_engine": "InnoDB"} @@ -48,7 +48,7 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest): s.add_all( [ A(x=5, y=5, bs=[B(x=4, y=4), B(x=2, y=8), B(x=7, y=1)]), - A(x=7, y=5, bs=[B(x=4, y=4), B(x=5, y=8)]) + A(x=7, y=5, bs=[B(x=4, y=4), B(x=5, y=8)]), ] ) s.commit() @@ -70,9 +70,7 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest): alt_trans.execute("set innodb_lock_wait_timeout=1") # set x/y > 10 try: - alt_trans.execute( - update(A).values(x=15, y=19) - ) + alt_trans.execute(update(A).values(x=15, y=19)) except (exc.InternalError, exc.OperationalError) as err: assert "Lock wait timeout exceeded" in str(err) assert should_be_locked @@ -85,9 +83,7 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest): alt_trans.execute("set innodb_lock_wait_timeout=1") # set x/y > 10 try: - alt_trans.execute( - update(B).values(x=15, y=19) - ) + alt_trans.execute(update(B).values(x=15, y=19)) except (exc.InternalError, exc.OperationalError) as err: assert "Lock wait timeout exceeded" in str(err) assert should_be_locked diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py index 376f9a9af6..04072d4a9c 100644 --- a/test/dialect/mysql/test_on_duplicate.py +++ b/test/dialect/mysql/test_on_duplicate.py @@ -6,84 +6,117 @@ from sqlalchemy import Table, Column, Boolean, Integer, String, func class OnDuplicateTest(fixtures.TablesTest): - __only_on__ = 'mysql', + __only_on__ = ("mysql",) __backend__ = True - run_define_tables = 'each' + run_define_tables = "each" @classmethod def define_tables(cls, metadata): Table( - 'foos', metadata, - Column('id', Integer, primary_key=True, autoincrement=True), - Column('bar', String(10)), - Column('baz', String(10)), - Column('updated_once', Boolean, default=False), + "foos", + metadata, + Column("id", Integer, primary_key=True, autoincrement=True), + Column("bar", String(10)), + Column("baz", String(10)), + Column("updated_once", Boolean, default=False), ) def test_bad_args(self): assert_raises( ValueError, - insert(self.tables.foos, values={}).on_duplicate_key_update + insert(self.tables.foos, values={}).on_duplicate_key_update, ) assert_raises( exc.ArgumentError, insert(self.tables.foos, values={}).on_duplicate_key_update, - {'id': 1, 'bar': 'b'}, + {"id": 1, "bar": "b"}, id=1, - bar='b', + bar="b", ) assert_raises( exc.ArgumentError, insert(self.tables.foos, values={}).on_duplicate_key_update, - {'id': 1, 'bar': 'b'}, - {'id': 2, 'bar': 'baz'}, + {"id": 1, "bar": "b"}, + {"id": 2, "bar": "baz"}, ) def test_on_duplicate_key_update(self): foos = self.tables.foos with testing.db.connect() as conn: - conn.execute(insert(foos, dict(id=1, bar='b', baz='bz'))) + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) stmt = insert(foos).values( - [dict(id=1, bar='ab'), dict(id=2, bar='b')]) + [dict(id=1, bar="ab"), dict(id=2, bar="b")] + ) stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) result = conn.execute(stmt) eq_(result.inserted_primary_key, [2]) eq_( conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), - [(1, 'ab', 'bz', 

             # First statement should succeed updating column bar
-            conn.execute(stmt1, dict(id=1, bar='ab'))
+            conn.execute(stmt1, dict(id=1, bar="ab"))
             eq_(
                 conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
-                [(1, 'ab', 'bz', True)],
+                [(1, "ab", "bz", True)],
             )

             # Second statement will do noop update of column bar
-            conn.execute(stmt2, dict(id=2, bar='ab'))
+            conn.execute(stmt2, dict(id=2, bar="ab"))
             eq_(
                 conn.execute(foos.select().where(foos.c.id == 2)).fetchall(),
-                [(2, 'b', 'bz2', True)]
+                [(2, "b", "bz2", True)],
             )

     def test_last_inserted_id(self):
@@ -92,13 +125,15 @@ class OnDuplicateTest(fixtures.TablesTest):
             stmt = insert(foos).values({"bar": "b", "baz": "bz"})
             result = conn.execute(
                 stmt.on_duplicate_key_update(
-                    bar=stmt.inserted.bar, baz="newbz")
+                    bar=stmt.inserted.bar, baz="newbz"
+                )
             )
             eq_(result.inserted_primary_key, [1])

             stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"})
             result = conn.execute(
                 stmt.on_duplicate_key_update(
-                    bar=stmt.inserted.bar, baz="newbz")
+                    bar=stmt.inserted.bar, baz="newbz"
+                )
             )
             eq_(result.inserted_primary_key, [1])
diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py
index 04f3ca67d6..ecd79257f0 100644
--- a/test/dialect/mysql/test_query.py
+++ b/test/dialect/mysql/test_query.py
@@ -7,29 +7,29 @@ from sqlalchemy import testing

 class IdiosyncrasyTest(fixtures.TestBase):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     @testing.emits_warning()
     def test_is_boolean_symbols_despite_no_native(self):
         is_(
             testing.db.scalar(select([cast(true().is_(true()), Boolean)])),
-            True
+            True,
         )

         is_(
             testing.db.scalar(select([cast(true().isnot(true()), Boolean)])),
-            False
+            False,
         )

         is_(
             testing.db.scalar(select([cast(false().is_(false()), Boolean)])),
-            True
+            True,
         )


 class MatchTest(fixtures.TestBase):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     @classmethod
@@ -37,138 +37,182 @@ class MatchTest(fixtures.TestBase):
         global metadata, cattable, matchtable
         metadata = MetaData(testing.db)
-        cattable = Table('cattable', metadata,
-                         Column('id', Integer, primary_key=True),
-                         Column('description', String(50)),
-                         mysql_engine='MyISAM')
-        matchtable = Table('matchtable', metadata,
-                           Column('id', Integer, primary_key=True),
-                           Column('title', String(200)),
-                           Column('category_id',
-                                  Integer,
-                                  ForeignKey('cattable.id')),
-                           mysql_engine='MyISAM')
+        cattable = Table(
+            "cattable",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("description", String(50)),
+            mysql_engine="MyISAM",
+        )
+        matchtable = Table(
+            "matchtable",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("title", String(200)),
+            Column("category_id", Integer, ForeignKey("cattable.id")),
+            mysql_engine="MyISAM",
+        )
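+        # MyISAM is used here because these tests exercise MATCH ... AGAINST,
+        # and InnoDB only gained FULLTEXT index support in MySQL 5.6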
         metadata.create_all()

-        cattable.insert().execute([
-            {'id': 1, 'description': 'Python'},
-            {'id': 2, 'description': 'Ruby'},
-        ])
-        matchtable.insert().execute([
-            {'id': 1,
-             'title': 'Agile Web Development with Ruby On Rails',
-             'category_id': 2},
-            {'id': 2,
-             'title': 'Dive Into Python',
-             'category_id': 1},
-            {'id': 3,
-             'title': "Programming Matz's Ruby",
-             'category_id': 2},
-            {'id': 4,
-             'title': 'The Definitive Guide to Django',
-             'category_id': 1},
-            {'id': 5,
-             'title': 'Python in a Nutshell',
-             'category_id': 1}
-        ])
+        cattable.insert().execute(
+            [
+                {"id": 1, "description": "Python"},
+                {"id": 2, "description": "Ruby"},
+            ]
+        )
+        matchtable.insert().execute(
+            [
+                {
+                    "id": 1,
+                    "title": "Agile Web Development with Ruby On Rails",
+                    "category_id": 2,
+                },
+                {"id": 2, "title": "Dive Into Python", "category_id": 1},
+                {
+                    "id": 3,
+                    "title": "Programming Matz's Ruby",
+                    "category_id": 2,
+                },
+                {
+                    "id": 4,
+                    "title": "The Definitive Guide to Django",
+                    "category_id": 1,
+                },
+                {"id": 5, "title": "Python in a Nutshell", "category_id": 1},
+            ]
+        )

     @classmethod
     def teardown_class(cls):
         metadata.drop_all()

     def test_simple_match(self):
-        results = (matchtable.select().
-                   where(matchtable.c.title.match('python')).
-                   order_by(matchtable.c.id).
-                   execute().
-                   fetchall())
+        results = (
+            matchtable.select()
+            .where(matchtable.c.title.match("python"))
+            .order_by(matchtable.c.id)
+            .execute()
+            .fetchall()
+        )
         eq_([2, 5], [r.id for r in results])

     def test_not_match(self):
-        results = (matchtable.select().
-                   where(~matchtable.c.title.match('python')).
-                   order_by(matchtable.c.id).
-                   execute().
-                   fetchall())
+        results = (
+            matchtable.select()
+            .where(~matchtable.c.title.match("python"))
+            .order_by(matchtable.c.id)
+            .execute()
+            .fetchall()
+        )
         eq_([1, 3, 4], [r.id for r in results])

     def test_simple_match_with_apostrophe(self):
-        results = (matchtable.select().
-                   where(matchtable.c.title.match("Matz's")).
-                   execute().
-                   fetchall())
+        results = (
+            matchtable.select()
+            .where(matchtable.c.title.match("Matz's"))
+            .execute()
+            .fetchall()
+        )
         eq_([3], [r.id for r in results])

     def test_return_value(self):
         # test [ticket:3263]
         result = testing.db.execute(
-            select([
-                matchtable.c.title.match('Agile Ruby Programming')
-                .label('ruby'),
-                matchtable.c.title.match('Dive Python').label('python'),
-                matchtable.c.title
-            ]).order_by(matchtable.c.id)
+            select(
+                [
+                    matchtable.c.title.match("Agile Ruby Programming").label(
+                        "ruby"
+                    ),
+                    matchtable.c.title.match("Dive Python").label("python"),
+                    matchtable.c.title,
+                ]
+            ).order_by(matchtable.c.id)
         ).fetchall()
         eq_(
             result,
             [
-                (2.0, 0.0, 'Agile Web Development with Ruby On Rails'),
-                (0.0, 2.0, 'Dive Into Python'),
+                (2.0, 0.0, "Agile Web Development with Ruby On Rails"),
+                (0.0, 2.0, "Dive Into Python"),
                 (2.0, 0.0, "Programming Matz's Ruby"),
-                (0.0, 0.0, 'The Definitive Guide to Django'),
-                (0.0, 1.0, 'Python in a Nutshell')
-            ]
+                (0.0, 0.0, "The Definitive Guide to Django"),
+                (0.0, 1.0, "Python in a Nutshell"),
+            ],
         )

     def test_or_match(self):
-        results1 = (matchtable.select().
-                    where(or_(matchtable.c.title.match('nutshell'),
-                              matchtable.c.title.match('ruby'))).
-                    order_by(matchtable.c.id).
-                    execute().
-                    fetchall())
+        results1 = (
+            matchtable.select()
+            .where(
+                or_(
+                    matchtable.c.title.match("nutshell"),
+                    matchtable.c.title.match("ruby"),
+                )
+            )
+            .order_by(matchtable.c.id)
+            .execute()
+            .fetchall()
+        )
         eq_([1, 3, 5], [r.id for r in results1])

-        results2 = (matchtable.select().
-                    where(matchtable.c.title.match('nutshell ruby')).
-                    order_by(matchtable.c.id).
-                    execute().
-                    fetchall())
+        results2 = (
+            matchtable.select()
+            .where(matchtable.c.title.match("nutshell ruby"))
+            .order_by(matchtable.c.id)
+            .execute()
+            .fetchall()
+        )
         eq_([1, 3, 5], [r.id for r in results2])

     def test_and_match(self):
-        results1 = (matchtable.select().
-                    where(and_(matchtable.c.title.match('python'),
-                               matchtable.c.title.match('nutshell'))).
-                    execute().
-                    fetchall())
+        results1 = (
+            matchtable.select()
+            .where(
+                and_(
+                    matchtable.c.title.match("python"),
+                    matchtable.c.title.match("nutshell"),
+                )
+            )
+            .execute()
+            .fetchall()
+        )
         eq_([5], [r.id for r in results1])

-        results2 = (matchtable.select().
-                    where(matchtable.c.title.match('+python +nutshell')).
-                    execute().
-                    fetchall())
+        results2 = (
+            matchtable.select()
+            .where(matchtable.c.title.match("+python +nutshell"))
+            .execute()
+            .fetchall()
+        )
         eq_([5], [r.id for r in results2])

     def test_match_across_joins(self):
-        results = (matchtable.select().
-                   where(and_(cattable.c.id == matchtable.c.category_id,
-                              or_(cattable.c.description.match('Ruby'),
-                                  matchtable.c.title.match('nutshell')))).
-                   order_by(matchtable.c.id).
-                   execute().
-                   fetchall())
+        results = (
+            matchtable.select()
+            .where(
+                and_(
+                    cattable.c.id == matchtable.c.category_id,
+                    or_(
+                        cattable.c.description.match("Ruby"),
+                        matchtable.c.title.match("nutshell"),
+                    ),
+                )
+            )
+            .order_by(matchtable.c.id)
+            .execute()
+            .fetchall()
+        )
         eq_([1, 3, 5], [r.id for r in results])


 class AnyAllTest(fixtures.TablesTest):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     @classmethod
     def define_tables(cls, metadata):
         Table(
-            'stuff', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('value', Integer)
+            "stuff",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("value", Integer),
         )

     @classmethod
@@ -177,38 +221,32 @@ class AnyAllTest(fixtures.TablesTest):
         testing.db.execute(
             stuff.insert(),
             [
-                {'id': 1, 'value': 1},
-                {'id': 2, 'value': 2},
-                {'id': 3, 'value': 3},
-                {'id': 4, 'value': 4},
-                {'id': 5, 'value': 5},
-            ]
+                {"id": 1, "value": 1},
+                {"id": 2, "value": 2},
+                {"id": 3, "value": 3},
+                {"id": 4, "value": 4},
+                {"id": 5, "value": 5},
+            ],
         )

     def test_any_w_comparator(self):
         stuff = self.tables.stuff
         stmt = select([stuff.c.id]).where(
-            stuff.c.value > any_(select([stuff.c.value])))
-
-        eq_(
-            testing.db.execute(stmt).fetchall(),
-            [(2,), (3,), (4,), (5,)]
+            stuff.c.value > any_(select([stuff.c.value]))
         )

+        eq_(testing.db.execute(stmt).fetchall(), [(2,), (3,), (4,), (5,)])
+
     def test_all_w_comparator(self):
         stuff = self.tables.stuff
         stmt = select([stuff.c.id]).where(
-            stuff.c.value >= all_(select([stuff.c.value])))
-
-        eq_(
-            testing.db.execute(stmt).fetchall(),
-            [(5,)]
+            stuff.c.value >= all_(select([stuff.c.value]))
         )

+        eq_(testing.db.execute(stmt).fetchall(), [(5,)])
+
     def test_any_literal(self):
         stuff = self.tables.stuff
         stmt = select([4 == any_(select([stuff.c.value]))])
-        is_(
-            testing.db.execute(stmt).scalar(), True
-        )
+        is_(testing.db.execute(stmt).scalar(), True)
diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py
index 5de855838f..2ad63ec2a6 100644
--- a/test/dialect/mysql/test_reflection.py
+++ b/test/dialect/mysql/test_reflection.py
@@ -1,10 +1,28 @@
 # coding: utf-8
 from sqlalchemy.testing import eq_, is_
-from sqlalchemy import Column, Table, DDL, MetaData, TIMESTAMP, \
-    DefaultClause, String, Integer, Text, UnicodeText, SmallInteger,\
-    NCHAR, LargeBinary, DateTime, select, UniqueConstraint, Unicode,\
-    BigInteger, Index, ForeignKey
+from sqlalchemy import (
+    Column,
+    Table,
+    DDL,
+    MetaData,
+    TIMESTAMP,
+    DefaultClause,
+    String,
+    Integer,
+    Text,
+    UnicodeText,
+    SmallInteger,
+    NCHAR,
+    LargeBinary,
+    DateTime,
+    select,
+    UniqueConstraint,
+    Unicode,
+    BigInteger,
+    Index,
+    ForeignKey,
+)
 from sqlalchemy.schema import CreateIndex
 from sqlalchemy import event
 from sqlalchemy import sql
@@ -19,39 +37,38 @@ import re

 class TypeReflectionTest(fixtures.TestBase):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     @testing.provide_metadata
     def _run_test(self, specs, attributes):
-        columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
+        columns = [Column("c%i" % (i + 1), t[0]) for i, t in enumerate(specs)]

         # Early 5.0 releases seem to report more "general" for columns
         # in a view, e.g. char -> varchar, tinyblob -> mediumblob
         use_views = testing.db.dialect.server_version_info > (5, 0, 10)

         m = self.metadata
-        Table('mysql_types', m, *columns)
+        Table("mysql_types", m, *columns)

         if use_views:
             event.listen(
-                m, 'after_create',
+                m,
+                "after_create",
                 DDL(
-                    'CREATE OR REPLACE VIEW mysql_types_v '
-                    'AS SELECT * from mysql_types')
+                    "CREATE OR REPLACE VIEW mysql_types_v "
+                    "AS SELECT * from mysql_types"
+                ),
             )
             event.listen(
-                m, 'before_drop',
-                DDL("DROP VIEW IF EXISTS mysql_types_v")
+                m, "before_drop", DDL("DROP VIEW IF EXISTS mysql_types_v")
             )

         m.create_all()

         m2 = MetaData(testing.db)
-        tables = [
-            Table('mysql_types', m2, autoload=True)
-        ]
+        tables = [Table("mysql_types", m2, autoload=True)]
         if use_views:
-            tables.append(Table('mysql_types_v', m2, autoload=True))
+            tables.append(Table("mysql_types_v", m2, autoload=True))

         for table in tables:
             for i, (reflected_col, spec) in enumerate(zip(table.c, specs)):
@@ -64,13 +81,14 @@ class TypeReflectionTest(fixtures.TestBase):
                     getattr(reflected_type, attr),
                     getattr(expected_spec, attr),
                     "Column %s: Attribute %s value of %s does not "
-                    "match %s for type %s" % (
+                    "match %s for type %s"
+                    % (
                         "c%i" % (i + 1),
                         attr,
                         getattr(reflected_type, attr),
                         getattr(expected_spec, attr),
-                        spec[0]
-                    )
+                        spec[0],
+                    ),
                 )

     def test_time_types(self):
@@ -91,13 +109,12 @@ class TypeReflectionTest(fixtures.TestBase):
             else:
                 specs.append((type_(), type_()))

-        specs.extend([
-            (TIMESTAMP(), mysql.TIMESTAMP()),
-            (DateTime(), mysql.DATETIME()),
-        ])
+        specs.extend(
+            [(TIMESTAMP(), mysql.TIMESTAMP()), (DateTime(), mysql.DATETIME())]
+        )

         # note 'timezone' should always be None on both
-        self._run_test(specs, ['fsp', 'timezone'])
+        self._run_test(specs, ["fsp", "timezone"])

     def test_year_types(self):
         specs = [
@@ -105,7 +122,7 @@ class TypeReflectionTest(fixtures.TestBase):
             (mysql.YEAR(display_width=4), mysql.YEAR(display_width=4)),
         ]

-        self._run_test(specs, ['display_width'])
+        self._run_test(specs, ["display_width"])

     def test_string_types(self):
         specs = [
@@ -119,25 +136,29 @@ class TypeReflectionTest(fixtures.TestBase):
             (mysql.MSChar(3), mysql.MSChar(3)),
             (NCHAR(2), mysql.MSChar(2)),
             (mysql.MSNChar(2), mysql.MSChar(2)),
-            (mysql.MSNVarChar(22), mysql.MSString(22),),
+            (mysql.MSNVarChar(22), mysql.MSString(22)),
         ]
-        self._run_test(specs, ['length'])
+        self._run_test(specs, ["length"])

     def test_integer_types(self):
         specs = []
         for type_ in [
-                mysql.TINYINT, mysql.SMALLINT,
-                mysql.MEDIUMINT, mysql.INTEGER, mysql.BIGINT]:
+            mysql.TINYINT,
+            mysql.SMALLINT,
+            mysql.MEDIUMINT,
+            mysql.INTEGER,
+            mysql.BIGINT,
+        ]:
             for display_width in [None, 4, 7]:
                 for unsigned in [False, True]:
                     for zerofill in [None, True]:
                         kw = {}
                         if display_width:
-                            kw['display_width'] = display_width
+                            kw["display_width"] = display_width
                         if unsigned is not None:
-                            kw['unsigned'] = unsigned
+                            kw["unsigned"] = unsigned
                         if zerofill is not None:
-                            kw['zerofill'] = zerofill
+                            kw["zerofill"] = zerofill

                         zerofill = bool(zerofill)
                         source_type = type_(**kw)
@@ -148,7 +169,7 @@ class TypeReflectionTest(fixtures.TestBase):
                                 mysql.SMALLINT: 6,
                                 mysql.TINYINT: 4,
                                 mysql.INTEGER: 11,
-                                mysql.BIGINT: 20
+                                mysql.BIGINT: 20,
                             }[type_]

                             if zerofill:
@@ -157,24 +178,24 @@ class TypeReflectionTest(fixtures.TestBase):
                         expected_type = type_(
                             display_width=display_width,
                             unsigned=unsigned,
-                            zerofill=zerofill
-                        )
-                        specs.append(
-                            (source_type, expected_type)
+                            zerofill=zerofill,
                         )
+                        specs.append((source_type, expected_type))

-        specs.extend([
-            (SmallInteger(), mysql.SMALLINT(display_width=6)),
-            (Integer(), mysql.INTEGER(display_width=11)),
-            (BigInteger, mysql.BIGINT(display_width=20))
-        ])
-        self._run_test(specs, ['display_width', 'unsigned', 'zerofill'])
+        specs.extend(
+            [
+                (SmallInteger(), mysql.SMALLINT(display_width=6)),
+                (Integer(), mysql.INTEGER(display_width=11)),
+                (BigInteger, mysql.BIGINT(display_width=20)),
+            ]
+        )
+        self._run_test(specs, ["display_width", "unsigned", "zerofill"])

     def test_binary_types(self):
         specs = [
-            (LargeBinary(3), mysql.TINYBLOB(), ),
+            (LargeBinary(3), mysql.TINYBLOB()),
             (LargeBinary(), mysql.BLOB()),
-            (mysql.MSBinary(3), mysql.MSBinary(3), ),
+            (mysql.MSBinary(3), mysql.MSBinary(3)),
             (mysql.MSVarBinary(3), mysql.MSVarBinary(3)),
             (mysql.MSTinyBlob(), mysql.MSTinyBlob()),
             (mysql.MSBlob(), mysql.MSBlob()),
@@ -184,118 +205,127 @@ class TypeReflectionTest(fixtures.TestBase):
         ]
         self._run_test(specs, [])

-    @testing.uses_deprecated('Manually quoting ENUM value literals')
+    @testing.uses_deprecated("Manually quoting ENUM value literals")
     def test_legacy_enum_types(self):
-        specs = [
-            (mysql.ENUM("''", "'fleem'"), mysql.ENUM("''", "'fleem'")),
-        ]
+        specs = [(mysql.ENUM("''", "'fleem'"), mysql.ENUM("''", "'fleem'"))]

-        self._run_test(specs, ['enums'])
+        self._run_test(specs, ["enums"])


 class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True

     def test_default_reflection(self):
         """Test reflection of column defaults."""

         from sqlalchemy.dialects.mysql import VARCHAR
+
         def_table = Table(
-            'mysql_def',
+            "mysql_def",
             MetaData(testing.db),
-            Column('c1', VARCHAR(10, collation='utf8_unicode_ci'),
-                   DefaultClause(''), nullable=False),
-            Column('c2', String(10), DefaultClause('0')),
-            Column('c3', String(10), DefaultClause('abc')),
-            Column('c4', TIMESTAMP, DefaultClause('2009-04-05 12:00:00')),
-            Column('c5', TIMESTAMP),
-            Column('c6', TIMESTAMP,
-                   DefaultClause(sql.text("CURRENT_TIMESTAMP "
-                                          "ON UPDATE CURRENT_TIMESTAMP"))),
+            Column(
+                "c1",
+                VARCHAR(10, collation="utf8_unicode_ci"),
+                DefaultClause(""),
+                nullable=False,
+            ),
+            Column("c2", String(10), DefaultClause("0")),
+            Column("c3", String(10), DefaultClause("abc")),
+            Column("c4", TIMESTAMP, DefaultClause("2009-04-05 12:00:00")),
+            Column("c5", TIMESTAMP),
+            Column(
+                "c6",
+                TIMESTAMP,
+                DefaultClause(
+                    sql.text(
+                        "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP"
+                    )
+                ),
+            ),
         )
         def_table.create()
         try:
-            reflected = Table('mysql_def', MetaData(testing.db),
-                              autoload=True)
+            reflected = Table("mysql_def", MetaData(testing.db), autoload=True)
         finally:
             def_table.drop()

-        assert def_table.c.c1.server_default.arg == ''
-        assert def_table.c.c2.server_default.arg == '0'
-        assert def_table.c.c3.server_default.arg == 'abc'
-        assert def_table.c.c4.server_default.arg \
-            == '2009-04-05 12:00:00'
+        assert def_table.c.c1.server_default.arg == ""
+        assert def_table.c.c2.server_default.arg == "0"
+        assert def_table.c.c3.server_default.arg == "abc"
+        assert def_table.c.c4.server_default.arg == "2009-04-05 12:00:00"
         assert str(reflected.c.c1.server_default.arg) == "''"
         assert str(reflected.c.c2.server_default.arg) == "'0'"
         assert str(reflected.c.c3.server_default.arg) == "'abc'"
-        assert str(reflected.c.c4.server_default.arg) \
-            == "'2009-04-05 12:00:00'"
+        assert (
+            str(reflected.c.c4.server_default.arg) == "'2009-04-05 12:00:00'"
+        )
         assert reflected.c.c5.default is None
         assert reflected.c.c5.server_default is None
         assert reflected.c.c6.default is None
         assert re.match(
             r"CURRENT_TIMESTAMP(\(\))? 
ON UPDATE CURRENT_TIMESTAMP(\(\))?", - str(reflected.c.c6.server_default.arg).upper() + str(reflected.c.c6.server_default.arg).upper(), ) reflected.create() try: - reflected2 = Table('mysql_def', MetaData(testing.db), - autoload=True) + reflected2 = Table( + "mysql_def", MetaData(testing.db), autoload=True + ) finally: reflected.drop() assert str(reflected2.c.c1.server_default.arg) == "''" assert str(reflected2.c.c2.server_default.arg) == "'0'" assert str(reflected2.c.c3.server_default.arg) == "'abc'" - assert str(reflected2.c.c4.server_default.arg) \ - == "'2009-04-05 12:00:00'" + assert ( + str(reflected2.c.c4.server_default.arg) == "'2009-04-05 12:00:00'" + ) assert reflected.c.c5.default is None assert reflected.c.c5.server_default is None assert reflected.c.c6.default is None assert re.match( r"CURRENT_TIMESTAMP(\(\))? ON UPDATE CURRENT_TIMESTAMP(\(\))?", - str(reflected.c.c6.server_default.arg).upper() + str(reflected.c.c6.server_default.arg).upper(), ) def test_reflection_with_table_options(self): comment = r"""Comment types type speedily ' " \ '' Fun!""" def_table = Table( - 'mysql_def', MetaData(testing.db), - Column('c1', Integer()), - mysql_engine='MEMORY', + "mysql_def", + MetaData(testing.db), + Column("c1", Integer()), + mysql_engine="MEMORY", comment=comment, - mysql_default_charset='utf8', - mysql_auto_increment='5', - mysql_avg_row_length='3', - mysql_password='secret', - mysql_connection='fish', + mysql_default_charset="utf8", + mysql_auto_increment="5", + mysql_avg_row_length="3", + mysql_password="secret", + mysql_connection="fish", ) def_table.create() try: - reflected = Table( - 'mysql_def', MetaData(testing.db), - autoload=True) + reflected = Table("mysql_def", MetaData(testing.db), autoload=True) finally: def_table.drop() - assert def_table.kwargs['mysql_engine'] == 'MEMORY' + assert def_table.kwargs["mysql_engine"] == "MEMORY" assert def_table.comment == comment - assert def_table.kwargs['mysql_default_charset'] == 'utf8' - assert def_table.kwargs['mysql_auto_increment'] == '5' - assert def_table.kwargs['mysql_avg_row_length'] == '3' - assert def_table.kwargs['mysql_password'] == 'secret' - assert def_table.kwargs['mysql_connection'] == 'fish' + assert def_table.kwargs["mysql_default_charset"] == "utf8" + assert def_table.kwargs["mysql_auto_increment"] == "5" + assert def_table.kwargs["mysql_avg_row_length"] == "3" + assert def_table.kwargs["mysql_password"] == "secret" + assert def_table.kwargs["mysql_connection"] == "fish" - assert reflected.kwargs['mysql_engine'] == 'MEMORY' + assert reflected.kwargs["mysql_engine"] == "MEMORY" assert reflected.comment == comment - assert reflected.kwargs['mysql_comment'] == comment - assert reflected.kwargs['mysql_default charset'] == 'utf8' - assert reflected.kwargs['mysql_avg_row_length'] == '3' - assert reflected.kwargs['mysql_connection'] == 'fish' + assert reflected.kwargs["mysql_comment"] == comment + assert reflected.kwargs["mysql_default charset"] == "utf8" + assert reflected.kwargs["mysql_avg_row_length"] == "3" + assert reflected.kwargs["mysql_connection"] == "fish" # This field doesn't seem to be returned by mysql itself. 
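# A sketch of the reflection round trip the two tests above depend on:
# server defaults and mysql_* table options come back through
# Table(..., autoload=True) or through the inspector.  This assumes a
# reachable MySQL database and a pre-existing table; the URL is a
# placeholder.
from sqlalchemy import MetaData, Table, create_engine, inspect

engine = create_engine("mysql://scott:tiger@localhost/test")

reflected = Table("mysql_def", MetaData(engine), autoload=True)
print(reflected.c.c2.server_default.arg)  # quoted SQL text, e.g. '0'
print(reflected.kwargs["mysql_engine"])  # e.g. MEMORY

# the inspector exposes the same data without building Table objects
insp = inspect(engine)
for col in insp.get_columns("mysql_def"):
    print(col["name"], col["nullable"], col["default"])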
# assert reflected.kwargs['mysql_password'] == 'secret' @@ -307,23 +337,32 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): """Test reflection of include_columns to be sure they respect case.""" case_table = Table( - 'mysql_case', MetaData(testing.db), - Column('c1', String(10)), - Column('C2', String(10)), - Column('C3', String(10))) + "mysql_case", + MetaData(testing.db), + Column("c1", String(10)), + Column("C2", String(10)), + Column("C3", String(10)), + ) try: case_table.create() - reflected = Table('mysql_case', MetaData(testing.db), - autoload=True, include_columns=['c1', 'C2']) + reflected = Table( + "mysql_case", + MetaData(testing.db), + autoload=True, + include_columns=["c1", "C2"], + ) for t in case_table, reflected: - assert 'c1' in t.c.keys() - assert 'C2' in t.c.keys() + assert "c1" in t.c.keys() + assert "C2" in t.c.keys() reflected2 = Table( - 'mysql_case', MetaData(testing.db), - autoload=True, include_columns=['c1', 'c2']) - assert 'c1' in reflected2.c.keys() - for c in ['c2', 'C2', 'C3']: + "mysql_case", + MetaData(testing.db), + autoload=True, + include_columns=["c1", "c2"], + ) + assert "c1" in reflected2.c.keys() + for c in ["c2", "C2", "C3"]: assert c not in reflected2.c.keys() finally: case_table.drop() @@ -331,71 +370,110 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): def test_autoincrement(self): meta = MetaData(testing.db) try: - Table('ai_1', meta, - Column('int_y', Integer, primary_key=True, - autoincrement=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True), - mysql_engine='MyISAM') - Table('ai_2', meta, - Column('int_y', Integer, primary_key=True, - autoincrement=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True), - mysql_engine='MyISAM') - Table('ai_3', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_y', Integer, primary_key=True, - autoincrement=True), - mysql_engine='MyISAM') - Table('ai_4', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_n2', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - mysql_engine='MyISAM') - Table('ai_5', meta, - Column('int_y', Integer, primary_key=True, - autoincrement=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - mysql_engine='MyISAM') - Table('ai_6', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True, - autoincrement=True), - mysql_engine='MyISAM') - Table('ai_7', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True, - autoincrement=True), - mysql_engine='MyISAM') - Table('ai_8', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True), - mysql_engine='MyISAM') + Table( + "ai_1", + meta, + Column("int_y", Integer, primary_key=True, autoincrement=True), + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + mysql_engine="MyISAM", + ) + Table( + "ai_2", + meta, + Column("int_y", Integer, primary_key=True, autoincrement=True), + Column("int_n", Integer, DefaultClause("0"), primary_key=True), + mysql_engine="MyISAM", + ) + Table( + "ai_3", + meta, + Column( + "int_n", + Integer, + DefaultClause("0"), + primary_key=True, + autoincrement=False, + ), + Column("int_y", Integer, 
primary_key=True, autoincrement=True), + mysql_engine="MyISAM", + ) + Table( + "ai_4", + meta, + Column( + "int_n", + Integer, + DefaultClause("0"), + primary_key=True, + autoincrement=False, + ), + Column( + "int_n2", + Integer, + DefaultClause("0"), + primary_key=True, + autoincrement=False, + ), + mysql_engine="MyISAM", + ) + Table( + "ai_5", + meta, + Column("int_y", Integer, primary_key=True, autoincrement=True), + Column( + "int_n", + Integer, + DefaultClause("0"), + primary_key=True, + autoincrement=False, + ), + mysql_engine="MyISAM", + ) + Table( + "ai_6", + meta, + Column("o1", String(1), DefaultClause("x"), primary_key=True), + Column("int_y", Integer, primary_key=True, autoincrement=True), + mysql_engine="MyISAM", + ) + Table( + "ai_7", + meta, + Column("o1", String(1), DefaultClause("x"), primary_key=True), + Column("o2", String(1), DefaultClause("x"), primary_key=True), + Column("int_y", Integer, primary_key=True, autoincrement=True), + mysql_engine="MyISAM", + ) + Table( + "ai_8", + meta, + Column("o1", String(1), DefaultClause("x"), primary_key=True), + Column("o2", String(1), DefaultClause("x"), primary_key=True), + mysql_engine="MyISAM", + ) meta.create_all() - table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', - 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + table_names = [ + "ai_1", + "ai_2", + "ai_3", + "ai_4", + "ai_5", + "ai_6", + "ai_7", + "ai_8", + ] mr = MetaData(testing.db) mr.reflect(only=table_names) for tbl in [mr.tables[name] for name in table_names]: for c in tbl.c: - if c.name.startswith('int_y'): + if c.name.startswith("int_y"): assert c.autoincrement - elif c.name.startswith('int_n'): + elif c.name.startswith("int_n"): assert not c.autoincrement tbl.insert().execute() - if 'int_y' in tbl.c: + if "int_y" in tbl.c: assert select([tbl.c.int_y]).scalar() == 1 assert list(tbl.select().execute().first()).count(1) == 1 else: @@ -405,35 +483,35 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): @testing.provide_metadata def test_view_reflection(self): - Table('x', - self.metadata, - Column('a', Integer), - Column('b', String(50))) + Table( + "x", self.metadata, Column("a", Integer), Column("b", String(50)) + ) self.metadata.create_all() with testing.db.connect() as conn: conn.execute("CREATE VIEW v1 AS SELECT * FROM x") + conn.execute("CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x") conn.execute( - "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x") - conn.execute( - "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x") + "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x" + ) conn.execute( - "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x") + "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x" + ) @event.listens_for(self.metadata, "before_drop") def cleanup(*arg, **kw): with testing.db.connect() as conn: - for v in ['v1', 'v2', 'v3', 'v4']: + for v in ["v1", "v2", "v3", "v4"]: conn.execute("DROP VIEW %s" % v) insp = inspect(testing.db) - for v in ['v1', 'v2', 'v3', 'v4']: + for v in ["v1", "v2", "v3", "v4"]: eq_( [ - (col['name'], col['type'].__class__) + (col["name"], col["type"].__class__) for col in insp.get_columns(v) ], - [('a', mysql.INTEGER), ('b', mysql.VARCHAR)] + [("a", mysql.INTEGER), ("b", mysql.VARCHAR)], ) @testing.provide_metadata @@ -448,7 +526,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): with testing.db.connect() as conn: conn.execute("CREATE TABLE test_t1 (id INTEGER)") conn.execute("CREATE TABLE test_t2 (id INTEGER)") - conn.execute("CREATE VIEW test_v AS SELECT id FROM test_t1" ) + 
conn.execute("CREATE VIEW test_v AS SELECT id FROM test_t1") conn.execute("DROP TABLE test_t1") m = MetaData() @@ -457,20 +535,23 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): "reflected: .* references invalid table" ): m.reflect(views=True, bind=conn) - eq_(m.tables['test_t2'].name, "test_t2") + eq_(m.tables["test_t2"].name, "test_t2") assert_raises_message( exc.UnreflectableTableError, "references invalid table", - Table, 'test_v', MetaData(), autoload_with=conn + Table, + "test_v", + MetaData(), + autoload_with=conn, ) - @testing.exclude('mysql', '<', (5, 0, 0), 'no information_schema support') + @testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support") def test_system_views(self): dialect = testing.db.dialect connection = testing.db.connect() view_names = dialect.get_view_names(connection, "information_schema") - self.assert_('TABLES' in view_names) + self.assert_("TABLES" in view_names) @testing.provide_metadata def test_nullable_reflection(self): @@ -486,35 +567,41 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): row = testing.db.execute( "show variables like '%%explicit_defaults_for_timestamp%%'" ).first() - explicit_defaults_for_timestamp = row[1].lower() in ('on', '1', 'true') + explicit_defaults_for_timestamp = row[1].lower() in ("on", "1", "true") reflected = [] - for idx, cols in enumerate([ + for idx, cols in enumerate( [ - "x INTEGER NULL", - "y INTEGER NOT NULL", - "z INTEGER", - "q TIMESTAMP NULL" - ], - - ["p TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP"], - ["r TIMESTAMP NOT NULL"], - ["s TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"], - ["t TIMESTAMP"], - ["u TIMESTAMP DEFAULT CURRENT_TIMESTAMP"] - ]): + [ + "x INTEGER NULL", + "y INTEGER NOT NULL", + "z INTEGER", + "q TIMESTAMP NULL", + ], + ["p TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP"], + ["r TIMESTAMP NOT NULL"], + ["s TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"], + ["t TIMESTAMP"], + ["u TIMESTAMP DEFAULT CURRENT_TIMESTAMP"], + ] + ): Table("nn_t%d" % idx, meta) # to allow DROP - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE nn_t%d ( %s ) - """ % (idx, ", \n".join(cols))) + """ + % (idx, ", \n".join(cols)) + ) reflected.extend( { - "name": d['name'], "nullable": d['nullable'], - "default": d['default']} + "name": d["name"], + "nullable": d["nullable"], + "default": d["default"], + } for d in inspect(testing.db).get_columns("nn_t%d" % idx) ) @@ -526,29 +613,38 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): eq_( reflected, [ - {'name': 'x', 'nullable': True, 'default': None}, - {'name': 'y', 'nullable': False, 'default': None}, - {'name': 'z', 'nullable': True, 'default': None}, - {'name': 'q', 'nullable': True, 'default': None}, - {'name': 'p', 'nullable': True, - 'default': current_timestamp}, - {'name': 'r', 'nullable': False, - 'default': None if explicit_defaults_for_timestamp else - "%(current_timestamp)s ON UPDATE %(current_timestamp)s" % - {"current_timestamp": current_timestamp}}, - {'name': 's', 'nullable': False, - 'default': current_timestamp}, - {'name': 't', - 'nullable': True if explicit_defaults_for_timestamp else - False, - 'default': None if explicit_defaults_for_timestamp else - "%(current_timestamp)s ON UPDATE %(current_timestamp)s" % - {"current_timestamp": current_timestamp}}, - {'name': 'u', - 'nullable': True if explicit_defaults_for_timestamp else - False, - 'default': current_timestamp}, - ] + {"name": "x", "nullable": True, "default": None}, + {"name": "y", "nullable": False, "default": None}, + {"name": "z", 
"nullable": True, "default": None}, + {"name": "q", "nullable": True, "default": None}, + {"name": "p", "nullable": True, "default": current_timestamp}, + { + "name": "r", + "nullable": False, + "default": None + if explicit_defaults_for_timestamp + else "%(current_timestamp)s ON UPDATE %(current_timestamp)s" + % {"current_timestamp": current_timestamp}, + }, + {"name": "s", "nullable": False, "default": current_timestamp}, + { + "name": "t", + "nullable": True + if explicit_defaults_for_timestamp + else False, + "default": None + if explicit_defaults_for_timestamp + else "%(current_timestamp)s ON UPDATE %(current_timestamp)s" + % {"current_timestamp": current_timestamp}, + }, + { + "name": "u", + "nullable": True + if explicit_defaults_for_timestamp + else False, + "default": current_timestamp, + }, + ], ) @testing.provide_metadata @@ -556,95 +652,101 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): insp = inspect(testing.db) meta = self.metadata - uc_table = Table('mysql_uc', meta, - Column('a', String(10)), - UniqueConstraint('a', name='uc_a')) + uc_table = Table( + "mysql_uc", + meta, + Column("a", String(10)), + UniqueConstraint("a", name="uc_a"), + ) uc_table.create() # MySQL converts unique constraints into unique indexes. # separately we get both - indexes = dict((i['name'], i) for i in insp.get_indexes('mysql_uc')) - constraints = set(i['name'] - for i in insp.get_unique_constraints('mysql_uc')) + indexes = dict((i["name"], i) for i in insp.get_indexes("mysql_uc")) + constraints = set( + i["name"] for i in insp.get_unique_constraints("mysql_uc") + ) - self.assert_('uc_a' in indexes) - self.assert_(indexes['uc_a']['unique']) - self.assert_('uc_a' in constraints) + self.assert_("uc_a" in indexes) + self.assert_(indexes["uc_a"]["unique"]) + self.assert_("uc_a" in constraints) # reflection here favors the unique index, as that's the # more "official" MySQL construct - reflected = Table('mysql_uc', MetaData(testing.db), autoload=True) + reflected = Table("mysql_uc", MetaData(testing.db), autoload=True) indexes = dict((i.name, i) for i in reflected.indexes) constraints = set(uc.name for uc in reflected.constraints) - self.assert_('uc_a' in indexes) - self.assert_(indexes['uc_a'].unique) - self.assert_('uc_a' not in constraints) + self.assert_("uc_a" in indexes) + self.assert_(indexes["uc_a"].unique) + self.assert_("uc_a" not in constraints) @testing.provide_metadata def test_reflect_fulltext(self): mt = Table( - "mytable", self.metadata, + "mytable", + self.metadata, Column("id", Integer, primary_key=True), Column("textdata", String(50)), - mysql_engine='InnoDB' + mysql_engine="InnoDB", ) Index("textdata_ix", mt.c.textdata, mysql_prefix="FULLTEXT") self.metadata.create_all(testing.db) - mt = Table( - "mytable", MetaData(), autoload_with=testing.db - ) + mt = Table("mytable", MetaData(), autoload_with=testing.db) idx = list(mt.indexes)[0] eq_(idx.name, "textdata_ix") - eq_(idx.dialect_options['mysql']['prefix'], "FULLTEXT") + eq_(idx.dialect_options["mysql"]["prefix"], "FULLTEXT") self.assert_compile( CreateIndex(idx), - "CREATE FULLTEXT INDEX textdata_ix ON mytable (textdata)" + "CREATE FULLTEXT INDEX textdata_ix ON mytable (textdata)", ) @testing.requires.mysql_ngram_fulltext @testing.provide_metadata def test_reflect_fulltext_comment(self): mt = Table( - "mytable", self.metadata, + "mytable", + self.metadata, Column("id", Integer, primary_key=True), Column("textdata", String(50)), - mysql_engine='InnoDB' + mysql_engine="InnoDB", ) Index( - "textdata_ix", mt.c.textdata, 
- mysql_prefix="FULLTEXT", mysql_with_parser="ngram") + "textdata_ix", + mt.c.textdata, + mysql_prefix="FULLTEXT", + mysql_with_parser="ngram", + ) self.metadata.create_all(testing.db) - mt = Table( - "mytable", MetaData(), autoload_with=testing.db - ) + mt = Table("mytable", MetaData(), autoload_with=testing.db) idx = list(mt.indexes)[0] eq_(idx.name, "textdata_ix") - eq_(idx.dialect_options['mysql']['prefix'], "FULLTEXT") - eq_(idx.dialect_options['mysql']['with_parser'], "ngram") + eq_(idx.dialect_options["mysql"]["prefix"], "FULLTEXT") + eq_(idx.dialect_options["mysql"]["with_parser"], "ngram") self.assert_compile( CreateIndex(idx), "CREATE FULLTEXT INDEX textdata_ix ON mytable " - "(textdata) WITH PARSER ngram" + "(textdata) WITH PARSER ngram", ) @testing.provide_metadata def test_non_column_index(self): m1 = self.metadata t1 = Table( - 'add_ix', m1, Column('x', String(50)), mysql_engine='InnoDB') - Index('foo_idx', t1.c.x.desc()) + "add_ix", m1, Column("x", String(50)), mysql_engine="InnoDB" + ) + Index("foo_idx", t1.c.x.desc()) m1.create_all() insp = inspect(testing.db) eq_( insp.get_indexes("add_ix"), - [{'name': 'foo_idx', 'column_names': ['x'], 'unique': False}] + [{"name": "foo_idx", "column_names": ["x"], "unique": False}], ) @testing.provide_metadata @@ -655,60 +757,80 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): m1 = self.metadata Table( - 'Track', m1, Column('TrackID', Integer, primary_key=True), - mysql_engine='InnoDB' + "Track", + m1, + Column("TrackID", Integer, primary_key=True), + mysql_engine="InnoDB", ) Table( - 'Track', m1, Column('TrackID', Integer, primary_key=True), + "Track", + m1, + Column("TrackID", Integer, primary_key=True), schema=testing.config.test_schema, - mysql_engine='InnoDB' + mysql_engine="InnoDB", ) Table( - 'PlaylistTrack', m1, Column('id', Integer, primary_key=True), - Column('TrackID', - ForeignKey('Track.TrackID', name='FK_PlaylistTrackId')), + "PlaylistTrack", + m1, + Column("id", Integer, primary_key=True), + Column( + "TrackID", + ForeignKey("Track.TrackID", name="FK_PlaylistTrackId"), + ), Column( - 'TTrackID', + "TTrackID", ForeignKey( - '%s.Track.TrackID' % (testing.config.test_schema,), - name='FK_PlaylistTTrackId' - ) + "%s.Track.TrackID" % (testing.config.test_schema,), + name="FK_PlaylistTTrackId", + ), ), - mysql_engine='InnoDB' + mysql_engine="InnoDB", ) m1.create_all() if testing.db.dialect._casing in (1, 2): eq_( - inspect(testing.db).get_foreign_keys('PlaylistTrack'), + inspect(testing.db).get_foreign_keys("PlaylistTrack"), [ - {'name': 'FK_PlaylistTTrackId', - 'constrained_columns': ['TTrackID'], - 'referred_schema': testing.config.test_schema, - 'referred_table': 'track', - 'referred_columns': ['TrackID'], 'options': {}}, - {'name': 'FK_PlaylistTrackId', - 'constrained_columns': ['TrackID'], - 'referred_schema': None, - 'referred_table': 'track', - 'referred_columns': ['TrackID'], 'options': {}} - ] + { + "name": "FK_PlaylistTTrackId", + "constrained_columns": ["TTrackID"], + "referred_schema": testing.config.test_schema, + "referred_table": "track", + "referred_columns": ["TrackID"], + "options": {}, + }, + { + "name": "FK_PlaylistTrackId", + "constrained_columns": ["TrackID"], + "referred_schema": None, + "referred_table": "track", + "referred_columns": ["TrackID"], + "options": {}, + }, + ], ) else: eq_( - inspect(testing.db).get_foreign_keys('PlaylistTrack'), + inspect(testing.db).get_foreign_keys("PlaylistTrack"), [ - {'name': 'FK_PlaylistTTrackId', - 'constrained_columns': ['TTrackID'], - 
'referred_schema': testing.config.test_schema, - 'referred_table': 'Track', - 'referred_columns': ['TrackID'], 'options': {}}, - {'name': 'FK_PlaylistTrackId', - 'constrained_columns': ['TrackID'], - 'referred_schema': None, - 'referred_table': 'Track', - 'referred_columns': ['TrackID'], 'options': {}} - ] + { + "name": "FK_PlaylistTTrackId", + "constrained_columns": ["TTrackID"], + "referred_schema": testing.config.test_schema, + "referred_table": "Track", + "referred_columns": ["TrackID"], + "options": {}, + }, + { + "name": "FK_PlaylistTrackId", + "constrained_columns": ["TrackID"], + "referred_schema": None, + "referred_table": "Track", + "referred_columns": ["TrackID"], + "options": {}, + }, + ], ) @testing.requires.mysql_fully_case_sensitive @@ -719,48 +841,52 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): # is case sensitive m = self.metadata Table( - 't1', m, - Column('some_id', Integer, primary_key=True), - mysql_engine='InnoDB' - + "t1", + m, + Column("some_id", Integer, primary_key=True), + mysql_engine="InnoDB", ) Table( - 'T1', m, - Column('Some_Id', Integer, primary_key=True), - mysql_engine='InnoDB' + "T1", + m, + Column("Some_Id", Integer, primary_key=True), + mysql_engine="InnoDB", ) Table( - 't2', m, - Column('id', Integer, primary_key=True), - Column('t1id', ForeignKey('t1.some_id', name='t1id_fk')), - Column('cap_t1id', ForeignKey('T1.Some_Id', name='cap_t1id_fk')), - mysql_engine='InnoDB' + "t2", + m, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.some_id", name="t1id_fk")), + Column("cap_t1id", ForeignKey("T1.Some_Id", name="cap_t1id_fk")), + mysql_engine="InnoDB", ) m.create_all(testing.db) eq_( dict( - (rec['name'], rec) - for rec in inspect(testing.db).get_foreign_keys('t2') + (rec["name"], rec) + for rec in inspect(testing.db).get_foreign_keys("t2") ), { - 'cap_t1id_fk': { - 'name': 'cap_t1id_fk', - 'constrained_columns': ['cap_t1id'], - 'referred_schema': None, - 'referred_table': 'T1', - 'referred_columns': ['Some_Id'], 'options': {} + "cap_t1id_fk": { + "name": "cap_t1id_fk", + "constrained_columns": ["cap_t1id"], + "referred_schema": None, + "referred_table": "T1", + "referred_columns": ["Some_Id"], + "options": {}, }, - 't1id_fk': { - 'name': 't1id_fk', - 'constrained_columns': ['t1id'], - 'referred_schema': None, - 'referred_table': 't1', - 'referred_columns': ['some_id'], 'options': {} + "t1id_fk": { + "name": "t1id_fk", + "constrained_columns": ["t1id"], + "referred_schema": None, + "referred_table": "t1", + "referred_columns": ["some_id"], + "options": {}, }, - } + }, ) @@ -770,37 +896,39 @@ class RawReflectionTest(fixtures.TestBase): def setup(self): dialect = mysql.dialect() self.parser = _reflection.MySQLTableDefinitionParser( - dialect, dialect.identifier_preparer) + dialect, dialect.identifier_preparer + ) def test_key_reflection(self): regex = self.parser._re_key - assert regex.match(' PRIMARY KEY (`id`),') - assert regex.match(' PRIMARY KEY USING BTREE (`id`),') - assert regex.match(' PRIMARY KEY (`id`) USING BTREE,') - assert regex.match(' PRIMARY KEY (`id`)') - assert regex.match(' PRIMARY KEY USING BTREE (`id`)') - assert regex.match(' PRIMARY KEY (`id`) USING BTREE') + assert regex.match(" PRIMARY KEY (`id`),") + assert regex.match(" PRIMARY KEY USING BTREE (`id`),") + assert regex.match(" PRIMARY KEY (`id`) USING BTREE,") + assert regex.match(" PRIMARY KEY (`id`)") + assert regex.match(" PRIMARY KEY USING BTREE (`id`)") + assert regex.match(" PRIMARY KEY (`id`) USING BTREE") assert regex.match( - ' 
PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE 16') + " PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE 16" + ) assert regex.match( - ' PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE=16') + " PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE=16" + ) assert regex.match( - ' PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE = 16') + " PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE = 16" + ) assert not regex.match( - ' PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE = = 16') - assert regex.match( - " KEY (`id`) USING BTREE COMMENT 'comment'") + " PRIMARY KEY (`id`) USING BTREE KEY_BLOCK_SIZE = = 16" + ) + assert regex.match(" KEY (`id`) USING BTREE COMMENT 'comment'") # `SHOW CREATE TABLE` returns COMMENT '''comment' # after creating table with COMMENT '\'comment' + assert regex.match(" KEY (`id`) USING BTREE COMMENT '''comment'") + assert regex.match(" KEY (`id`) USING BTREE COMMENT 'comment'''") + assert regex.match(" KEY (`id`) USING BTREE COMMENT 'prefix''suffix'") assert regex.match( - " KEY (`id`) USING BTREE COMMENT '''comment'") - assert regex.match( - " KEY (`id`) USING BTREE COMMENT 'comment'''") - assert regex.match( - " KEY (`id`) USING BTREE COMMENT 'prefix''suffix'") - assert regex.match( - " KEY (`id`) USING BTREE COMMENT 'prefix''text''suffix'") + " KEY (`id`) USING BTREE COMMENT 'prefix''text''suffix'" + ) # https://forums.mysql.com/read.php?20,567102,567111#msg-567111 # "It means if the MySQL version >= 501, execute what's in the comment" assert regex.match( @@ -811,65 +939,74 @@ class RawReflectionTest(fixtures.TestBase): def test_key_reflection_columns(self): regex = self.parser._re_key exprs = self.parser._re_keyexprs - m = regex.match( - " KEY (`id`) USING BTREE COMMENT '''comment'") - eq_(m.group("columns"), '`id`') + m = regex.match(" KEY (`id`) USING BTREE COMMENT '''comment'") + eq_(m.group("columns"), "`id`") - m = regex.match( - " KEY (`x`, `y`) USING BTREE") - eq_(m.group("columns"), '`x`, `y`') + m = regex.match(" KEY (`x`, `y`) USING BTREE") + eq_(m.group("columns"), "`x`, `y`") + eq_(exprs.findall(m.group("columns")), [("x", "", ""), ("y", "", "")]) + + m = regex.match(" KEY (`x`(25), `y`(15)) USING BTREE") + eq_(m.group("columns"), "`x`(25), `y`(15)") eq_( exprs.findall(m.group("columns")), - [("x", "", ""), ("y", "", "")] + [("x", "25", ""), ("y", "15", "")], ) - m = regex.match( - " KEY (`x`(25), `y`(15)) USING BTREE") - eq_(m.group("columns"), '`x`(25), `y`(15)') + m = regex.match(" KEY (`x`(25) DESC, `y`(15) ASC) USING BTREE") + eq_(m.group("columns"), "`x`(25) DESC, `y`(15) ASC") eq_( exprs.findall(m.group("columns")), - [("x", "25", ""), ("y", "15", "")] + [("x", "25", "DESC"), ("y", "15", "ASC")], ) + m = regex.match(" KEY `foo_idx` (`x` DESC)") + eq_(m.group("columns"), "`x` DESC") + eq_(exprs.findall(m.group("columns")), [("x", "", "DESC")]) + + eq_(exprs.findall(m.group("columns")), [("x", "", "DESC")]) + + m = regex.match(" KEY `foo_idx` (`x` DESC, `y` ASC)") + eq_(m.group("columns"), "`x` DESC, `y` ASC") + + def test_fk_reflection(self): + regex = self.parser._re_fk_constraint + m = regex.match( - " KEY (`x`(25) DESC, `y`(15) ASC) USING BTREE") - eq_(m.group("columns"), '`x`(25) DESC, `y`(15) ASC') + " CONSTRAINT `addresses_user_id_fkey` " + "FOREIGN KEY (`user_id`) " + "REFERENCES `users` (`id`) " + "ON DELETE CASCADE ON UPDATE CASCADE" + ) eq_( - exprs.findall(m.group("columns")), - [("x", "25", "DESC"), ("y", "15", "ASC")] + m.groups(), + ( + "addresses_user_id_fkey", + "`user_id`", + "`users`", + "`id`", + None, + "CASCADE", + "CASCADE", + ), ) m = 
regex.match( - " KEY `foo_idx` (`x` DESC)") - eq_(m.group("columns"), '`x` DESC') - eq_( - exprs.findall(m.group("columns")), - [("x", "", "DESC")] + " CONSTRAINT `addresses_user_id_fkey` " + "FOREIGN KEY (`user_id`) " + "REFERENCES `users` (`id`) " + "ON DELETE CASCADE ON UPDATE SET NULL" ) - eq_( - exprs.findall(m.group("columns")), - [("x", "", "DESC")] + m.groups(), + ( + "addresses_user_id_fkey", + "`user_id`", + "`users`", + "`id`", + None, + "CASCADE", + "SET NULL", + ), ) - - m = regex.match( - " KEY `foo_idx` (`x` DESC, `y` ASC)") - eq_(m.group("columns"), '`x` DESC, `y` ASC') - - def test_fk_reflection(self): - regex = self.parser._re_fk_constraint - - m = regex.match(' CONSTRAINT `addresses_user_id_fkey` ' - 'FOREIGN KEY (`user_id`) ' - 'REFERENCES `users` (`id`) ' - 'ON DELETE CASCADE ON UPDATE CASCADE') - eq_(m.groups(), ('addresses_user_id_fkey', '`user_id`', - '`users`', '`id`', None, 'CASCADE', 'CASCADE')) - - m = regex.match(' CONSTRAINT `addresses_user_id_fkey` ' - 'FOREIGN KEY (`user_id`) ' - 'REFERENCES `users` (`id`) ' - 'ON DELETE CASCADE ON UPDATE SET NULL') - eq_(m.groups(), ('addresses_user_id_fkey', '`user_id`', - '`users`', '`id`', None, 'CASCADE', 'SET NULL')) diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py index e32b92043d..07c007c532 100644 --- a/test/dialect/mysql/test_types.py +++ b/test/dialect/mysql/test_types.py @@ -6,9 +6,11 @@ from sqlalchemy import sql, exc, schema from sqlalchemy.util import u from sqlalchemy import util from sqlalchemy.dialects.mysql import base as mysql -from sqlalchemy.testing import (fixtures, - AssertsCompiledSQL, - AssertsExecutionResults) +from sqlalchemy.testing import ( + fixtures, + AssertsCompiledSQL, + AssertsExecutionResults, +) from sqlalchemy import testing import datetime import decimal @@ -16,13 +18,13 @@ from sqlalchemy import types as sqltypes from collections import OrderedDict -class TypesTest(fixtures.TestBase, - AssertsExecutionResults, - AssertsCompiledSQL): +class TypesTest( + fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL +): "Test MySQL column types" __dialect__ = mysql.dialect() - __only_on__ = 'mysql' + __only_on__ = "mysql" __backend__ = True def test_numeric(self): @@ -32,148 +34,228 @@ class TypesTest(fixtures.TestBase, # column type, args, kwargs, expected ddl # e.g. 
Column(Integer(10, unsigned=True)) == # 'INTEGER(10) UNSIGNED' - (mysql.MSNumeric, [], {}, - 'NUMERIC'), - (mysql.MSNumeric, [None], {}, - 'NUMERIC'), - (mysql.MSNumeric, [12], {}, - 'NUMERIC(12)'), - (mysql.MSNumeric, [12, 4], {'unsigned': True}, - 'NUMERIC(12, 4) UNSIGNED'), - (mysql.MSNumeric, [12, 4], {'zerofill': True}, - 'NUMERIC(12, 4) ZEROFILL'), - (mysql.MSNumeric, [12, 4], {'zerofill': True, 'unsigned': True}, - 'NUMERIC(12, 4) UNSIGNED ZEROFILL'), - - (mysql.MSDecimal, [], {}, - 'DECIMAL'), - (mysql.MSDecimal, [None], {}, - 'DECIMAL'), - (mysql.MSDecimal, [12], {}, - 'DECIMAL(12)'), - (mysql.MSDecimal, [12, None], {}, - 'DECIMAL(12)'), - (mysql.MSDecimal, [12, 4], {'unsigned': True}, - 'DECIMAL(12, 4) UNSIGNED'), - (mysql.MSDecimal, [12, 4], {'zerofill': True}, - 'DECIMAL(12, 4) ZEROFILL'), - (mysql.MSDecimal, [12, 4], {'zerofill': True, 'unsigned': True}, - 'DECIMAL(12, 4) UNSIGNED ZEROFILL'), - - (mysql.MSDouble, [None, None], {}, - 'DOUBLE'), - (mysql.MSDouble, [12, 4], {'unsigned': True}, - 'DOUBLE(12, 4) UNSIGNED'), - (mysql.MSDouble, [12, 4], {'zerofill': True}, - 'DOUBLE(12, 4) ZEROFILL'), - (mysql.MSDouble, [12, 4], {'zerofill': True, 'unsigned': True}, - 'DOUBLE(12, 4) UNSIGNED ZEROFILL'), - - (mysql.MSReal, [None, None], {}, - 'REAL'), - (mysql.MSReal, [12, 4], {'unsigned': True}, - 'REAL(12, 4) UNSIGNED'), - (mysql.MSReal, [12, 4], {'zerofill': True}, - 'REAL(12, 4) ZEROFILL'), - (mysql.MSReal, [12, 4], {'zerofill': True, 'unsigned': True}, - 'REAL(12, 4) UNSIGNED ZEROFILL'), - - (mysql.MSFloat, [], {}, - 'FLOAT'), - (mysql.MSFloat, [None], {}, - 'FLOAT'), - (mysql.MSFloat, [12], {}, - 'FLOAT(12)'), - (mysql.MSFloat, [12, 4], {}, - 'FLOAT(12, 4)'), - (mysql.MSFloat, [12, 4], {'unsigned': True}, - 'FLOAT(12, 4) UNSIGNED'), - (mysql.MSFloat, [12, 4], {'zerofill': True}, - 'FLOAT(12, 4) ZEROFILL'), - (mysql.MSFloat, [12, 4], {'zerofill': True, 'unsigned': True}, - 'FLOAT(12, 4) UNSIGNED ZEROFILL'), - - (mysql.MSInteger, [], {}, - 'INTEGER'), - (mysql.MSInteger, [4], {}, - 'INTEGER(4)'), - (mysql.MSInteger, [4], {'unsigned': True}, - 'INTEGER(4) UNSIGNED'), - (mysql.MSInteger, [4], {'zerofill': True}, - 'INTEGER(4) ZEROFILL'), - (mysql.MSInteger, [4], {'zerofill': True, 'unsigned': True}, - 'INTEGER(4) UNSIGNED ZEROFILL'), - - (mysql.MSBigInteger, [], {}, - 'BIGINT'), - (mysql.MSBigInteger, [4], {}, - 'BIGINT(4)'), - (mysql.MSBigInteger, [4], {'unsigned': True}, - 'BIGINT(4) UNSIGNED'), - (mysql.MSBigInteger, [4], {'zerofill': True}, - 'BIGINT(4) ZEROFILL'), - (mysql.MSBigInteger, [4], {'zerofill': True, 'unsigned': True}, - 'BIGINT(4) UNSIGNED ZEROFILL'), - - (mysql.MSMediumInteger, [], {}, - 'MEDIUMINT'), - (mysql.MSMediumInteger, [4], {}, - 'MEDIUMINT(4)'), - (mysql.MSMediumInteger, [4], {'unsigned': True}, - 'MEDIUMINT(4) UNSIGNED'), - (mysql.MSMediumInteger, [4], {'zerofill': True}, - 'MEDIUMINT(4) ZEROFILL'), - (mysql.MSMediumInteger, [4], {'zerofill': True, 'unsigned': True}, - 'MEDIUMINT(4) UNSIGNED ZEROFILL'), - - (mysql.MSTinyInteger, [], {}, - 'TINYINT'), - (mysql.MSTinyInteger, [1], {}, - 'TINYINT(1)'), - (mysql.MSTinyInteger, [1], {'unsigned': True}, - 'TINYINT(1) UNSIGNED'), - (mysql.MSTinyInteger, [1], {'zerofill': True}, - 'TINYINT(1) ZEROFILL'), - (mysql.MSTinyInteger, [1], {'zerofill': True, 'unsigned': True}, - 'TINYINT(1) UNSIGNED ZEROFILL'), - - (mysql.MSSmallInteger, [], {}, - 'SMALLINT'), - (mysql.MSSmallInteger, [4], {}, - 'SMALLINT(4)'), - (mysql.MSSmallInteger, [4], {'unsigned': True}, - 'SMALLINT(4) UNSIGNED'), - (mysql.MSSmallInteger, [4], 
{'zerofill': True}, - 'SMALLINT(4) ZEROFILL'), - (mysql.MSSmallInteger, [4], {'zerofill': True, 'unsigned': True}, - 'SMALLINT(4) UNSIGNED ZEROFILL'), + (mysql.MSNumeric, [], {}, "NUMERIC"), + (mysql.MSNumeric, [None], {}, "NUMERIC"), + (mysql.MSNumeric, [12], {}, "NUMERIC(12)"), + ( + mysql.MSNumeric, + [12, 4], + {"unsigned": True}, + "NUMERIC(12, 4) UNSIGNED", + ), + ( + mysql.MSNumeric, + [12, 4], + {"zerofill": True}, + "NUMERIC(12, 4) ZEROFILL", + ), + ( + mysql.MSNumeric, + [12, 4], + {"zerofill": True, "unsigned": True}, + "NUMERIC(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSDecimal, [], {}, "DECIMAL"), + (mysql.MSDecimal, [None], {}, "DECIMAL"), + (mysql.MSDecimal, [12], {}, "DECIMAL(12)"), + (mysql.MSDecimal, [12, None], {}, "DECIMAL(12)"), + ( + mysql.MSDecimal, + [12, 4], + {"unsigned": True}, + "DECIMAL(12, 4) UNSIGNED", + ), + ( + mysql.MSDecimal, + [12, 4], + {"zerofill": True}, + "DECIMAL(12, 4) ZEROFILL", + ), + ( + mysql.MSDecimal, + [12, 4], + {"zerofill": True, "unsigned": True}, + "DECIMAL(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSDouble, [None, None], {}, "DOUBLE"), + ( + mysql.MSDouble, + [12, 4], + {"unsigned": True}, + "DOUBLE(12, 4) UNSIGNED", + ), + ( + mysql.MSDouble, + [12, 4], + {"zerofill": True}, + "DOUBLE(12, 4) ZEROFILL", + ), + ( + mysql.MSDouble, + [12, 4], + {"zerofill": True, "unsigned": True}, + "DOUBLE(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSReal, [None, None], {}, "REAL"), + ( + mysql.MSReal, + [12, 4], + {"unsigned": True}, + "REAL(12, 4) UNSIGNED", + ), + ( + mysql.MSReal, + [12, 4], + {"zerofill": True}, + "REAL(12, 4) ZEROFILL", + ), + ( + mysql.MSReal, + [12, 4], + {"zerofill": True, "unsigned": True}, + "REAL(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSFloat, [], {}, "FLOAT"), + (mysql.MSFloat, [None], {}, "FLOAT"), + (mysql.MSFloat, [12], {}, "FLOAT(12)"), + (mysql.MSFloat, [12, 4], {}, "FLOAT(12, 4)"), + ( + mysql.MSFloat, + [12, 4], + {"unsigned": True}, + "FLOAT(12, 4) UNSIGNED", + ), + ( + mysql.MSFloat, + [12, 4], + {"zerofill": True}, + "FLOAT(12, 4) ZEROFILL", + ), + ( + mysql.MSFloat, + [12, 4], + {"zerofill": True, "unsigned": True}, + "FLOAT(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSInteger, [], {}, "INTEGER"), + (mysql.MSInteger, [4], {}, "INTEGER(4)"), + (mysql.MSInteger, [4], {"unsigned": True}, "INTEGER(4) UNSIGNED"), + (mysql.MSInteger, [4], {"zerofill": True}, "INTEGER(4) ZEROFILL"), + ( + mysql.MSInteger, + [4], + {"zerofill": True, "unsigned": True}, + "INTEGER(4) UNSIGNED ZEROFILL", + ), + (mysql.MSBigInteger, [], {}, "BIGINT"), + (mysql.MSBigInteger, [4], {}, "BIGINT(4)"), + ( + mysql.MSBigInteger, + [4], + {"unsigned": True}, + "BIGINT(4) UNSIGNED", + ), + ( + mysql.MSBigInteger, + [4], + {"zerofill": True}, + "BIGINT(4) ZEROFILL", + ), + ( + mysql.MSBigInteger, + [4], + {"zerofill": True, "unsigned": True}, + "BIGINT(4) UNSIGNED ZEROFILL", + ), + (mysql.MSMediumInteger, [], {}, "MEDIUMINT"), + (mysql.MSMediumInteger, [4], {}, "MEDIUMINT(4)"), + ( + mysql.MSMediumInteger, + [4], + {"unsigned": True}, + "MEDIUMINT(4) UNSIGNED", + ), + ( + mysql.MSMediumInteger, + [4], + {"zerofill": True}, + "MEDIUMINT(4) ZEROFILL", + ), + ( + mysql.MSMediumInteger, + [4], + {"zerofill": True, "unsigned": True}, + "MEDIUMINT(4) UNSIGNED ZEROFILL", + ), + (mysql.MSTinyInteger, [], {}, "TINYINT"), + (mysql.MSTinyInteger, [1], {}, "TINYINT(1)"), + ( + mysql.MSTinyInteger, + [1], + {"unsigned": True}, + "TINYINT(1) UNSIGNED", + ), + ( + mysql.MSTinyInteger, + [1], + {"zerofill": True}, + "TINYINT(1) ZEROFILL", + ), + ( + 
mysql.MSTinyInteger, + [1], + {"zerofill": True, "unsigned": True}, + "TINYINT(1) UNSIGNED ZEROFILL", + ), + (mysql.MSSmallInteger, [], {}, "SMALLINT"), + (mysql.MSSmallInteger, [4], {}, "SMALLINT(4)"), + ( + mysql.MSSmallInteger, + [4], + {"unsigned": True}, + "SMALLINT(4) UNSIGNED", + ), + ( + mysql.MSSmallInteger, + [4], + {"zerofill": True}, + "SMALLINT(4) ZEROFILL", + ), + ( + mysql.MSSmallInteger, + [4], + {"zerofill": True, "unsigned": True}, + "SMALLINT(4) UNSIGNED ZEROFILL", + ), ] for type_, args, kw, res in columns: type_inst = type_(*args, **kw) - self.assert_compile( - type_inst, - res - ) + self.assert_compile(type_inst, res) # test that repr() copies out all arguments - self.assert_compile( - eval("mysql.%r" % type_inst), - res - ) + self.assert_compile(eval("mysql.%r" % type_inst), res) # fixed in mysql-connector as of 2.0.1, # see http://bugs.mysql.com/bug.php?id=73266 @testing.provide_metadata def test_precision_float_roundtrip(self): - t = Table('t', self.metadata, - Column('scale_value', mysql.DOUBLE( - precision=15, scale=12, asdecimal=True)), - Column('unscale_value', mysql.DOUBLE( - decimal_return_scale=12, asdecimal=True))) + t = Table( + "t", + self.metadata, + Column( + "scale_value", + mysql.DOUBLE(precision=15, scale=12, asdecimal=True), + ), + Column( + "unscale_value", + mysql.DOUBLE(decimal_return_scale=12, asdecimal=True), + ), + ) t.create(testing.db) testing.db.execute( - t.insert(), scale_value=45.768392065789, - unscale_value=45.768392065789 + t.insert(), + scale_value=45.768392065789, + unscale_value=45.768392065789, ) result = testing.db.scalar(select([t.c.scale_value])) eq_(result, decimal.Decimal("45.768392065789")) @@ -181,114 +263,173 @@ class TypesTest(fixtures.TestBase, result = testing.db.scalar(select([t.c.unscale_value])) eq_(result, decimal.Decimal("45.768392065789")) - @testing.exclude('mysql', '<', (4, 1, 1), 'no charset support') + @testing.exclude("mysql", "<", (4, 1, 1), "no charset support") def test_charset(self): """Exercise CHARACTER SET and COLLATE-ish options on string types.""" columns = [ - (mysql.MSChar, [1], {}, - 'CHAR(1)'), - (mysql.NCHAR, [1], {}, - 'NATIONAL CHAR(1)'), - (mysql.MSChar, [1], {'binary': True}, - 'CHAR(1) BINARY'), - (mysql.MSChar, [1], {'ascii': True}, - 'CHAR(1) ASCII'), - (mysql.MSChar, [1], {'unicode': True}, - 'CHAR(1) UNICODE'), - (mysql.MSChar, [1], {'ascii': True, 'binary': True}, - 'CHAR(1) ASCII BINARY'), - (mysql.MSChar, [1], {'unicode': True, 'binary': True}, - 'CHAR(1) UNICODE BINARY'), - (mysql.MSChar, [1], {'charset': 'utf8'}, - 'CHAR(1) CHARACTER SET utf8'), - (mysql.MSChar, [1], {'charset': 'utf8', 'binary': True}, - 'CHAR(1) CHARACTER SET utf8 BINARY'), - (mysql.MSChar, [1], {'charset': 'utf8', 'unicode': True}, - 'CHAR(1) CHARACTER SET utf8'), - (mysql.MSChar, [1], {'charset': 'utf8', 'ascii': True}, - 'CHAR(1) CHARACTER SET utf8'), - (mysql.MSChar, [1], {'collation': 'utf8_bin'}, - 'CHAR(1) COLLATE utf8_bin'), - (mysql.MSChar, [1], {'charset': 'utf8', 'collation': 'utf8_bin'}, - 'CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin'), - (mysql.MSChar, [1], {'charset': 'utf8', 'binary': True}, - 'CHAR(1) CHARACTER SET utf8 BINARY'), - (mysql.MSChar, [1], {'charset': 'utf8', 'collation': 'utf8_bin', - 'binary': True}, - 'CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin'), - (mysql.MSChar, [1], {'national': True}, - 'NATIONAL CHAR(1)'), - (mysql.MSChar, [1], {'national': True, 'charset': 'utf8'}, - 'NATIONAL CHAR(1)'), - (mysql.MSChar, [1], {'national': True, 'charset': 'utf8', - 'binary': True}, - 
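# Both parameter matrices in this test module reduce to how a type
# renders as DDL, which TypeEngine.compile() shows directly with no
# engine involved; a minimal sketch:
from sqlalchemy.dialects import mysql

dialect = mysql.dialect()

int_t = mysql.INTEGER(display_width=4, unsigned=True, zerofill=True)
print(int_t.compile(dialect=dialect))  # INTEGER(4) UNSIGNED ZEROFILL

char_t = mysql.CHAR(1, charset="utf8", binary=True)
print(char_t.compile(dialect=dialect))  # CHAR(1) CHARACTER SET utf8 BINARY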
'NATIONAL CHAR(1) BINARY'), - (mysql.MSChar, [1], {'national': True, 'binary': True, - 'unicode': True}, - 'NATIONAL CHAR(1) BINARY'), - (mysql.MSChar, [1], {'national': True, 'collation': 'utf8_bin'}, - 'NATIONAL CHAR(1) COLLATE utf8_bin'), - - (mysql.MSString, [1], {'charset': 'utf8', 'collation': 'utf8_bin'}, - 'VARCHAR(1) CHARACTER SET utf8 COLLATE utf8_bin'), - (mysql.MSString, [1], {'national': True, 'collation': 'utf8_bin'}, - 'NATIONAL VARCHAR(1) COLLATE utf8_bin'), - - (mysql.MSTinyText, - [], - {'charset': 'utf8', 'collation': 'utf8_bin'}, - 'TINYTEXT CHARACTER SET utf8 COLLATE utf8_bin'), - - (mysql.MSMediumText, [], {'charset': 'utf8', 'binary': True}, - 'MEDIUMTEXT CHARACTER SET utf8 BINARY'), - - (mysql.MSLongText, [], {'ascii': True}, - 'LONGTEXT ASCII'), - - (mysql.ENUM, ["foo", "bar"], {'unicode': True}, - '''ENUM('foo','bar') UNICODE'''), - - (String, [20], {"collation": "utf8"}, 'VARCHAR(20) COLLATE utf8') + (mysql.MSChar, [1], {}, "CHAR(1)"), + (mysql.NCHAR, [1], {}, "NATIONAL CHAR(1)"), + (mysql.MSChar, [1], {"binary": True}, "CHAR(1) BINARY"), + (mysql.MSChar, [1], {"ascii": True}, "CHAR(1) ASCII"), + (mysql.MSChar, [1], {"unicode": True}, "CHAR(1) UNICODE"), + ( + mysql.MSChar, + [1], + {"ascii": True, "binary": True}, + "CHAR(1) ASCII BINARY", + ), + ( + mysql.MSChar, + [1], + {"unicode": True, "binary": True}, + "CHAR(1) UNICODE BINARY", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8"}, + "CHAR(1) CHARACTER SET utf8", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "binary": True}, + "CHAR(1) CHARACTER SET utf8 BINARY", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "unicode": True}, + "CHAR(1) CHARACTER SET utf8", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "ascii": True}, + "CHAR(1) CHARACTER SET utf8", + ), + ( + mysql.MSChar, + [1], + {"collation": "utf8_bin"}, + "CHAR(1) COLLATE utf8_bin", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "collation": "utf8_bin"}, + "CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "binary": True}, + "CHAR(1) CHARACTER SET utf8 BINARY", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "collation": "utf8_bin", "binary": True}, + "CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", + ), + (mysql.MSChar, [1], {"national": True}, "NATIONAL CHAR(1)"), + ( + mysql.MSChar, + [1], + {"national": True, "charset": "utf8"}, + "NATIONAL CHAR(1)", + ), + ( + mysql.MSChar, + [1], + {"national": True, "charset": "utf8", "binary": True}, + "NATIONAL CHAR(1) BINARY", + ), + ( + mysql.MSChar, + [1], + {"national": True, "binary": True, "unicode": True}, + "NATIONAL CHAR(1) BINARY", + ), + ( + mysql.MSChar, + [1], + {"national": True, "collation": "utf8_bin"}, + "NATIONAL CHAR(1) COLLATE utf8_bin", + ), + ( + mysql.MSString, + [1], + {"charset": "utf8", "collation": "utf8_bin"}, + "VARCHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", + ), + ( + mysql.MSString, + [1], + {"national": True, "collation": "utf8_bin"}, + "NATIONAL VARCHAR(1) COLLATE utf8_bin", + ), + ( + mysql.MSTinyText, + [], + {"charset": "utf8", "collation": "utf8_bin"}, + "TINYTEXT CHARACTER SET utf8 COLLATE utf8_bin", + ), + ( + mysql.MSMediumText, + [], + {"charset": "utf8", "binary": True}, + "MEDIUMTEXT CHARACTER SET utf8 BINARY", + ), + (mysql.MSLongText, [], {"ascii": True}, "LONGTEXT ASCII"), + ( + mysql.ENUM, + ["foo", "bar"], + {"unicode": True}, + """ENUM('foo','bar') UNICODE""", + ), + (String, [20], {"collation": "utf8"}, "VARCHAR(20) COLLATE utf8"), ] for type_, args, kw, res in columns: 
type_inst = type_(*args, **kw) - self.assert_compile( - type_inst, - res - ) + self.assert_compile(type_inst, res) # test that repr() copies out all arguments self.assert_compile( eval("mysql.%r" % type_inst) if type_ is not String else eval("%r" % type_inst), - res + res, ) - @testing.only_if('mysql') - @testing.fails_on('mysql+mysqlconnector', "different unicode behavior") - @testing.exclude('mysql', '<', (5, 0, 5), 'a 5.0+ feature') + @testing.only_if("mysql") + @testing.fails_on("mysql+mysqlconnector", "different unicode behavior") + @testing.exclude("mysql", "<", (5, 0, 5), "a 5.0+ feature") @testing.provide_metadata def test_charset_collate_table(self): - t = Table('foo', self.metadata, - Column('id', Integer), - Column('data', UnicodeText), - mysql_default_charset='utf8', - mysql_collate='utf8_bin') + t = Table( + "foo", + self.metadata, + Column("id", Integer), + Column("data", UnicodeText), + mysql_default_charset="utf8", + mysql_collate="utf8_bin", + ) t.create() m2 = MetaData(testing.db) - t2 = Table('foo', m2, autoload=True) - eq_(t2.kwargs['mysql_collate'], 'utf8_bin') - eq_(t2.kwargs['mysql_default charset'], 'utf8') + t2 = Table("foo", m2, autoload=True) + eq_(t2.kwargs["mysql_collate"], "utf8_bin") + eq_(t2.kwargs["mysql_default charset"], "utf8") # test [ticket:2906] # in order to test the condition here, need to use # MySQLdb 1.2.3 and also need to pass either use_unicode=1 # or charset=utf8 to the URL. - t.insert().execute(id=1, data=u('some text')) - assert isinstance(testing.db.scalar(select([t.c.data])), - util.text_type) + t.insert().execute(id=1, data=u("some text")) + assert isinstance( + testing.db.scalar(select([t.c.data])), util.text_type + ) def test_bit_50(self): """Exercise BIT types on 5.0+ (not valid for all engine types)""" @@ -300,22 +441,25 @@ class TypesTest(fixtures.TestBase, ]: self.assert_compile(type_, expected) - @testing.exclude('mysql', '<', (5, 0, 5), 'a 5.0+ feature') + @testing.exclude("mysql", "<", (5, 0, 5), "a 5.0+ feature") @testing.provide_metadata def test_bit_50_roundtrip(self): - bit_table = Table('mysql_bits', self.metadata, - Column('b1', mysql.MSBit), - Column('b2', mysql.MSBit()), - Column('b3', mysql.MSBit(), nullable=False), - Column('b4', mysql.MSBit(1)), - Column('b5', mysql.MSBit(8)), - Column('b6', mysql.MSBit(32)), - Column('b7', mysql.MSBit(63)), - Column('b8', mysql.MSBit(64))) + bit_table = Table( + "mysql_bits", + self.metadata, + Column("b1", mysql.MSBit), + Column("b2", mysql.MSBit()), + Column("b3", mysql.MSBit(), nullable=False), + Column("b4", mysql.MSBit(1)), + Column("b5", mysql.MSBit(8)), + Column("b6", mysql.MSBit(32)), + Column("b7", mysql.MSBit(63)), + Column("b8", mysql.MSBit(64)), + ) self.metadata.create_all() meta2 = MetaData(testing.db) - reflected = Table('mysql_bits', meta2, autoload=True) + reflected = Table("mysql_bits", meta2, autoload=True) for table in bit_table, reflected: @@ -351,20 +495,21 @@ class TypesTest(fixtures.TestBase, (BOOLEAN(), "BOOL"), (Boolean(), "BOOL"), (mysql.TINYINT(1), "TINYINT(1)"), - (mysql.TINYINT(1, unsigned=True), "TINYINT(1) UNSIGNED") + (mysql.TINYINT(1, unsigned=True), "TINYINT(1) UNSIGNED"), ]: self.assert_compile(type_, expected) @testing.provide_metadata def test_boolean_roundtrip(self): bool_table = Table( - 'mysql_bool', + "mysql_bool", self.metadata, - Column('b1', BOOLEAN), - Column('b2', Boolean), - Column('b3', mysql.MSTinyInteger(1)), - Column('b4', mysql.MSTinyInteger(1, unsigned=True)), - Column('b5', mysql.MSTinyInteger)) + Column("b1", BOOLEAN), + 
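# The boolean round trip being defined here works because MySQL has no
# native boolean type: the generic Boolean renders as BOOL, an alias for
# TINYINT(1), and reflection hands back TINYINT(1) unless the Table
# overrides the column types.  Compile-only check:
from sqlalchemy import Boolean
from sqlalchemy.dialects import mysql

dialect = mysql.dialect()
print(Boolean().compile(dialect=dialect))  # BOOL
print(mysql.TINYINT(1, unsigned=True).compile(dialect=dialect))
# TINYINT(1) UNSIGNED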
Column("b2", Boolean), + Column("b3", mysql.MSTinyInteger(1)), + Column("b4", mysql.MSTinyInteger(1, unsigned=True)), + Column("b5", mysql.MSTinyInteger), + ) self.metadata.create_all() table = bool_table @@ -381,120 +526,141 @@ class TypesTest(fixtures.TestBase, roundtrip([None, None, None, None, None]) roundtrip([True, True, 1, 1, 1]) roundtrip([False, False, 0, 0, 0]) - roundtrip([True, True, True, True, True], [True, True, 1, - 1, 1]) + roundtrip([True, True, True, True, True], [True, True, 1, 1, 1]) roundtrip([False, False, 0, 0, 0], [False, False, 0, 0, 0]) meta2 = MetaData(testing.db) - table = Table('mysql_bool', meta2, autoload=True) - eq_(colspec(table.c.b3), 'b3 TINYINT(1)') - eq_(colspec(table.c.b4), 'b4 TINYINT(1) UNSIGNED') + table = Table("mysql_bool", meta2, autoload=True) + eq_(colspec(table.c.b3), "b3 TINYINT(1)") + eq_(colspec(table.c.b4), "b4 TINYINT(1) UNSIGNED") meta2 = MetaData(testing.db) table = Table( - 'mysql_bool', + "mysql_bool", meta2, - Column('b1', BOOLEAN), - Column('b2', Boolean), - Column('b3', BOOLEAN), - Column('b4', BOOLEAN), - autoload=True) - eq_(colspec(table.c.b3), 'b3 BOOL') - eq_(colspec(table.c.b4), 'b4 BOOL') + Column("b1", BOOLEAN), + Column("b2", Boolean), + Column("b3", BOOLEAN), + Column("b4", BOOLEAN), + autoload=True, + ) + eq_(colspec(table.c.b3), "b3 BOOL") + eq_(colspec(table.c.b4), "b4 BOOL") roundtrip([None, None, None, None, None]) - roundtrip([True, True, 1, 1, 1], [True, True, True, True, - 1]) - roundtrip([False, False, 0, 0, 0], [False, False, False, - False, 0]) - roundtrip([True, True, True, True, True], [True, True, - True, True, 1]) - roundtrip([False, False, 0, 0, 0], [False, False, False, - False, 0]) + roundtrip([True, True, 1, 1, 1], [True, True, True, True, 1]) + roundtrip([False, False, 0, 0, 0], [False, False, False, False, 0]) + roundtrip([True, True, True, True, True], [True, True, True, True, 1]) + roundtrip([False, False, 0, 0, 0], [False, False, False, False, 0]) def test_timestamp_fsp(self): - self.assert_compile( - mysql.TIMESTAMP(fsp=5), - "TIMESTAMP(5)" - ) + self.assert_compile(mysql.TIMESTAMP(fsp=5), "TIMESTAMP(5)") def test_timestamp_defaults(self): """Exercise funky TIMESTAMP default syntax when used in columns.""" columns = [ - ([TIMESTAMP], {}, - 'TIMESTAMP NULL'), - - ([mysql.MSTimeStamp], {}, - 'TIMESTAMP NULL'), - - ([mysql.MSTimeStamp(), - DefaultClause(sql.text('CURRENT_TIMESTAMP'))], - {}, - "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP"), - - ([mysql.MSTimeStamp, - DefaultClause(sql.text('CURRENT_TIMESTAMP'))], - {'nullable': False}, - "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"), - - ([mysql.MSTimeStamp, - DefaultClause(sql.text("'1999-09-09 09:09:09'"))], - {'nullable': False}, - "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09'"), - - ([mysql.MSTimeStamp(), - DefaultClause(sql.text("'1999-09-09 09:09:09'"))], - {}, - "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09'"), - - ([mysql.MSTimeStamp(), - DefaultClause(sql.text( - "'1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP"))], - {}, - "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP"), - - ([mysql.MSTimeStamp, - DefaultClause(sql.text( - "'1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP"))], - {'nullable': False}, - "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP"), - - ([mysql.MSTimeStamp(), - DefaultClause(sql.text( - "CURRENT_TIMESTAMP " - "ON UPDATE CURRENT_TIMESTAMP"))], - {}, - "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP " - "ON UPDATE CURRENT_TIMESTAMP"), - - 
([mysql.MSTimeStamp, - DefaultClause(sql.text( - "CURRENT_TIMESTAMP " - "ON UPDATE CURRENT_TIMESTAMP"))], - {'nullable': False}, - "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP " - "ON UPDATE CURRENT_TIMESTAMP"), + ([TIMESTAMP], {}, "TIMESTAMP NULL"), + ([mysql.MSTimeStamp], {}, "TIMESTAMP NULL"), + ( + [ + mysql.MSTimeStamp(), + DefaultClause(sql.text("CURRENT_TIMESTAMP")), + ], + {}, + "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause(sql.text("CURRENT_TIMESTAMP")), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause(sql.text("'1999-09-09 09:09:09'")), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09'", + ), + ( + [ + mysql.MSTimeStamp(), + DefaultClause(sql.text("'1999-09-09 09:09:09'")), + ], + {}, + "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09'", + ), + ( + [ + mysql.MSTimeStamp(), + DefaultClause( + sql.text( + "'1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {}, + "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause( + sql.text( + "'1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp(), + DefaultClause( + sql.text( + "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {}, + "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause( + sql.text( + "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP " + "ON UPDATE CURRENT_TIMESTAMP", + ), ] for spec, kw, expected in columns: - c = Column('t', *spec, **kw) - Table('t', MetaData(), c) - self.assert_compile( - schema.CreateColumn(c), - "t %s" % expected - - ) + c = Column("t", *spec, **kw) + Table("t", MetaData(), c) + self.assert_compile(schema.CreateColumn(c), "t %s" % expected) @testing.requires.mysql_zero_date @testing.provide_metadata def test_timestamp_nullable(self): ts_table = Table( - 'mysql_timestamp', self.metadata, - Column('t1', TIMESTAMP), - Column('t2', TIMESTAMP, nullable=False), - mysql_engine='InnoDB' + "mysql_timestamp", + self.metadata, + Column("t1", TIMESTAMP), + Column("t2", TIMESTAMP, nullable=False), + mysql_engine="InnoDB", ) self.metadata.create_all() @@ -517,62 +683,44 @@ class TypesTest(fixtures.TestBase, with testing.db.begin() as conn: now = conn.scalar("select now()") - conn.execute( - ts_table.insert(), {'t1': now, 't2': None}) - conn.execute( - ts_table.insert(), {'t1': None, 't2': None}) - conn.execute( - ts_table.insert(), {'t2': None}) + conn.execute(ts_table.insert(), {"t1": now, "t2": None}) + conn.execute(ts_table.insert(), {"t1": None, "t2": None}) + conn.execute(ts_table.insert(), {"t2": None}) eq_( - [tuple([normalize(dt) for dt in row]) - for row in conn.execute(ts_table.select())], [ - (now, now), - (None, now), - (None, now) - ] + tuple([normalize(dt) for dt in row]) + for row in conn.execute(ts_table.select()) + ], + [(now, now), (None, now), (None, now)], ) def test_datetime_generic(self): - self.assert_compile( - mysql.DATETIME(), - "DATETIME" - ) + self.assert_compile(mysql.DATETIME(), "DATETIME") def test_datetime_fsp(self): - self.assert_compile( - mysql.DATETIME(fsp=4), - "DATETIME(4)" - ) + 
self.assert_compile(mysql.DATETIME(fsp=4), "DATETIME(4)") def test_time_generic(self): """"Exercise TIME.""" - self.assert_compile( - mysql.TIME(), - "TIME" - ) + self.assert_compile(mysql.TIME(), "TIME") def test_time_fsp(self): - self.assert_compile( - mysql.TIME(fsp=5), - "TIME(5)" - ) + self.assert_compile(mysql.TIME(fsp=5), "TIME(5)") def test_time_result_processor(self): eq_( mysql.TIME().result_processor(None, None)( - datetime.timedelta(seconds=35, minutes=517, - microseconds=450)), - datetime.time(8, 37, 35, 450) + datetime.timedelta(seconds=35, minutes=517, microseconds=450) + ), + datetime.time(8, 37, 35, 450), ) @testing.fails_on("mysql+oursql", "TODO: probable OurSQL bug") @testing.provide_metadata def test_time_roundtrip(self): - t = Table('mysql_time', self.metadata, - Column('t1', mysql.TIME())) + t = Table("mysql_time", self.metadata, Column("t1", mysql.TIME())) t.create() t.insert().values(t1=datetime.time(8, 37, 35)).execute() eq_(select([t.c.t1]).scalar(), datetime.time(8, 37, 35)) @@ -581,43 +729,42 @@ class TypesTest(fixtures.TestBase, def test_year(self): """Exercise YEAR.""" - year_table = Table('mysql_year', self.metadata, - Column('y1', mysql.MSYear), - Column('y2', mysql.MSYear), - Column('y3', mysql.MSYear), - Column('y5', mysql.MSYear(4))) + year_table = Table( + "mysql_year", + self.metadata, + Column("y1", mysql.MSYear), + Column("y2", mysql.MSYear), + Column("y3", mysql.MSYear), + Column("y5", mysql.MSYear(4)), + ) for col in year_table.c: self.assert_(repr(col)) year_table.create() - reflected = Table('mysql_year', MetaData(testing.db), - autoload=True) + reflected = Table("mysql_year", MetaData(testing.db), autoload=True) for table in year_table, reflected: - table.insert(['1950', '50', None, 1950]).execute() + table.insert(["1950", "50", None, 1950]).execute() row = table.select().execute().first() eq_(list(row), [1950, 2050, None, 1950]) table.delete().execute() - self.assert_(colspec(table.c.y1).startswith('y1 YEAR')) - eq_(colspec(table.c.y5), 'y5 YEAR(4)') + self.assert_(colspec(table.c.y1).startswith("y1 YEAR")) + eq_(colspec(table.c.y5), "y5 YEAR(4)") class JSONTest(fixtures.TestBase): - __requires__ = ('json_type', ) - __only_on__ = 'mysql' + __requires__ = ("json_type",) + __only_on__ = "mysql" __backend__ = True @testing.provide_metadata @testing.requires.reflects_json_type def test_reflection(self): - Table( - 'mysql_json', self.metadata, - Column('foo', mysql.JSON) - ) + Table("mysql_json", self.metadata, Column("foo", mysql.JSON)) self.metadata.create_all() - reflected = Table('mysql_json', MetaData(), autoload_with=testing.db) + reflected = Table("mysql_json", MetaData(), autoload_with=testing.db) is_(reflected.c.foo.type._type_affinity, sqltypes.JSON) assert isinstance(reflected.c.foo.type, mysql.JSON) @@ -627,29 +774,23 @@ class JSONTest(fixtures.TestBase): # using the backend-agnostic JSON type mysql_json = Table( - 'mysql_json', self.metadata, - Column('foo', mysql.JSON) + "mysql_json", self.metadata, Column("foo", mysql.JSON) ) self.metadata.create_all() - value = { - 'json': {'foo': 'bar'}, - 'recs': ['one', 'two'] - } + value = {"json": {"foo": "bar"}, "recs": ["one", "two"]} with testing.db.connect() as conn: conn.execute(mysql_json.insert(), foo=value) - eq_( - conn.scalar(select([mysql_json.c.foo])), - value - ) + eq_(conn.scalar(select([mysql_json.c.foo])), value) class EnumSetTest( - fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): + fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL +): - __only_on__ = 
'mysql' + __only_on__ = "mysql" __dialect__ = mysql.dialect() __backend__ = True @@ -663,11 +804,11 @@ class EnumSetTest( self.__members__[name] = self setattr(self.__class__, name, self) - one = SomeEnum('one', 1) - two = SomeEnum('two', 2) - three = SomeEnum('three', 3) - a_member = SomeEnum('AMember', 'a') - b_member = SomeEnum('BMember', 'b') + one = SomeEnum("one", 1) + two = SomeEnum("two", 2) + three = SomeEnum("three", 3) + a_member = SomeEnum("AMember", "a") + b_member = SomeEnum("BMember", "b") @staticmethod def get_enum_string_values(some_enum): @@ -677,173 +818,210 @@ class EnumSetTest( def test_enum(self): """Exercise the ENUM type.""" - with testing.expect_deprecated('Manually quoting ENUM value literals'): + with testing.expect_deprecated("Manually quoting ENUM value literals"): e1, e2 = mysql.ENUM("'a'", "'b'"), mysql.ENUM("'a'", "'b'") e3 = mysql.ENUM("'a'", "'b'", strict=True) e4 = mysql.ENUM("'a'", "'b'", strict=True) enum_table = Table( - 'mysql_enum', self.metadata, - Column('e1', e1), - Column('e2', e2, nullable=False), - Column('e2generic', - Enum("a", "b", validate_strings=True), nullable=False), - Column('e3', e3), - Column('e4', e4, - nullable=False), - Column('e5', mysql.ENUM("a", "b")), - Column('e5generic', Enum("a", "b")), - Column('e6', mysql.ENUM("'a'", "b")), - Column('e7', mysql.ENUM(EnumSetTest.SomeEnum, - values_callable=EnumSetTest. - get_enum_string_values)), - Column('e8', mysql.ENUM(EnumSetTest.SomeEnum)) + "mysql_enum", + self.metadata, + Column("e1", e1), + Column("e2", e2, nullable=False), + Column( + "e2generic", + Enum("a", "b", validate_strings=True), + nullable=False, + ), + Column("e3", e3), + Column("e4", e4, nullable=False), + Column("e5", mysql.ENUM("a", "b")), + Column("e5generic", Enum("a", "b")), + Column("e6", mysql.ENUM("'a'", "b")), + Column( + "e7", + mysql.ENUM( + EnumSetTest.SomeEnum, + values_callable=EnumSetTest.get_enum_string_values, + ), + ), + Column("e8", mysql.ENUM(EnumSetTest.SomeEnum)), ) + eq_(colspec(enum_table.c.e1), "e1 ENUM('a','b')") + eq_(colspec(enum_table.c.e2), "e2 ENUM('a','b') NOT NULL") eq_( - colspec(enum_table.c.e1), - "e1 ENUM('a','b')") - eq_( - colspec(enum_table.c.e2), - "e2 ENUM('a','b') NOT NULL") - eq_( - colspec(enum_table.c.e2generic), - "e2generic ENUM('a','b') NOT NULL") - eq_( - colspec(enum_table.c.e3), - "e3 ENUM('a','b')") - eq_( - colspec(enum_table.c.e4), - "e4 ENUM('a','b') NOT NULL") - eq_( - colspec(enum_table.c.e5), - "e5 ENUM('a','b')") - eq_( - colspec(enum_table.c.e5generic), - "e5generic ENUM('a','b')") - eq_( - colspec(enum_table.c.e6), - "e6 ENUM('''a''','b')") - eq_( - colspec(enum_table.c.e7), - "e7 ENUM('1','2','3','a','b')" + colspec(enum_table.c.e2generic), "e2generic ENUM('a','b') NOT NULL" ) + eq_(colspec(enum_table.c.e3), "e3 ENUM('a','b')") + eq_(colspec(enum_table.c.e4), "e4 ENUM('a','b') NOT NULL") + eq_(colspec(enum_table.c.e5), "e5 ENUM('a','b')") + eq_(colspec(enum_table.c.e5generic), "e5generic ENUM('a','b')") + eq_(colspec(enum_table.c.e6), "e6 ENUM('''a''','b')") + eq_(colspec(enum_table.c.e7), "e7 ENUM('1','2','3','a','b')") eq_( colspec(enum_table.c.e8), - "e8 ENUM('one','two','three','AMember','BMember')" + "e8 ENUM('one','two','three','AMember','BMember')", ) enum_table.create() assert_raises( - exc.DBAPIError, enum_table.insert().execute, - e1=None, e2=None, e3=None, e4=None) + exc.DBAPIError, + enum_table.insert().execute, + e1=None, + e2=None, + e3=None, + e4=None, + ) assert enum_table.c.e2generic.type.validate_strings assert_raises( exc.StatementError, 
enum_table.insert().execute, - e1='c', e2='c', e2generic='c', e3='c', - e4='c', e5='c', e5generic='c', e6='c', - e7='c', e8='c') + e1="c", + e2="c", + e2generic="c", + e3="c", + e4="c", + e5="c", + e5generic="c", + e6="c", + e7="c", + e8="c", + ) enum_table.insert().execute() - enum_table.insert().execute(e1='a', e2='a', e2generic='a', e3='a', - e4='a', e5='a', e5generic='a', e6="'a'", - e7='a', e8='AMember') - enum_table.insert().execute(e1='b', e2='b', e2generic='b', e3='b', - e4='b', e5='b', e5generic='b', e6='b', - e7='b', e8='BMember') + enum_table.insert().execute( + e1="a", + e2="a", + e2generic="a", + e3="a", + e4="a", + e5="a", + e5generic="a", + e6="'a'", + e7="a", + e8="AMember", + ) + enum_table.insert().execute( + e1="b", + e2="b", + e2generic="b", + e3="b", + e4="b", + e5="b", + e5generic="b", + e6="b", + e7="b", + e8="BMember", + ) res = enum_table.select().execute().fetchall() - expected = [(None, 'a', 'a', None, 'a', None, None, None, - None, None), - ('a', 'a', 'a', 'a', 'a', 'a', 'a', "'a'", - EnumSetTest.SomeEnum.AMember, - EnumSetTest.SomeEnum.AMember), - ('b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', - EnumSetTest.SomeEnum.BMember, - EnumSetTest.SomeEnum.BMember)] + expected = [ + (None, "a", "a", None, "a", None, None, None, None, None), + ( + "a", + "a", + "a", + "a", + "a", + "a", + "a", + "'a'", + EnumSetTest.SomeEnum.AMember, + EnumSetTest.SomeEnum.AMember, + ), + ( + "b", + "b", + "b", + "b", + "b", + "b", + "b", + "b", + EnumSetTest.SomeEnum.BMember, + EnumSetTest.SomeEnum.BMember, + ), + ] eq_(res, expected) def _set_fixture_one(self): - with testing.expect_deprecated('Manually quoting SET value literals'): + with testing.expect_deprecated("Manually quoting SET value literals"): e1, e2 = mysql.SET("'a'", "'b'"), mysql.SET("'a'", "'b'") e4 = mysql.SET("'a'", "b") e5 = mysql.SET("'a'", "'b'", quoting="quoted") set_table = Table( - 'mysql_set', self.metadata, - Column('e1', e1), - Column('e2', e2, nullable=False), - Column('e3', mysql.SET("a", "b")), - Column('e4', e4), - Column('e5', e5) + "mysql_set", + self.metadata, + Column("e1", e1), + Column("e2", e2, nullable=False), + Column("e3", mysql.SET("a", "b")), + Column("e4", e4), + Column("e5", e5), ) return set_table def test_set_colspec(self): self.metadata = MetaData() set_table = self._set_fixture_one() - eq_( - colspec(set_table.c.e1), - "e1 SET('a','b')") - eq_(colspec( - set_table.c.e2), - "e2 SET('a','b') NOT NULL") - eq_( - colspec(set_table.c.e3), - "e3 SET('a','b')") - eq_( - colspec(set_table.c.e4), - "e4 SET('''a''','b')") - eq_( - colspec(set_table.c.e5), - "e5 SET('a','b')") + eq_(colspec(set_table.c.e1), "e1 SET('a','b')") + eq_(colspec(set_table.c.e2), "e2 SET('a','b') NOT NULL") + eq_(colspec(set_table.c.e3), "e3 SET('a','b')") + eq_(colspec(set_table.c.e4), "e4 SET('''a''','b')") + eq_(colspec(set_table.c.e5), "e5 SET('a','b')") @testing.provide_metadata def test_no_null(self): set_table = self._set_fixture_one() set_table.create() assert_raises( - exc.DBAPIError, set_table.insert().execute, - e1=None, e2=None, e3=None, e4=None) + exc.DBAPIError, + set_table.insert().execute, + e1=None, + e2=None, + e3=None, + e4=None, + ) - @testing.only_on('+oursql') + @testing.only_on("+oursql") @testing.provide_metadata def test_oursql_error_one(self): set_table = self._set_fixture_one() set_table.create() assert_raises( - exc.StatementError, set_table.insert().execute, - e1='c', e2='c', e3='c', e4='c') + exc.StatementError, + set_table.insert().execute, + e1="c", + e2="c", + e3="c", + e4="c", + ) 
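
An aside on what the ENUM/SET hunks in this file are asserting: colspec() compares rendered DDL fragments such as "e5 ENUM('a','b')" and "e3 SET('a','b')". The following minimal sketch is not part of the patch; it assumes only the 1.3-era public API this change targets, and the "demo" table and its column names are invented for illustration. It shows how those quoted value lists come out of the MySQL dialect at compile time, with no database connection required:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import mysql
    from sqlalchemy.schema import CreateTable

    # Hypothetical table, not from the test suite: one ENUM and one SET column.
    demo = Table(
        "demo",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("kind", mysql.ENUM("a", "b")),
        Column("flags", mysql.SET("a", "b")),
    )

    # Rendering CREATE TABLE against the MySQL dialect produces the quoted
    # value lists the assertions compare against, e.g.
    #     kind ENUM('a','b'),
    #     flags SET('a','b')
    print(CreateTable(demo).compile(dialect=mysql.dialect()))

Because this compiles against the dialect rather than a live server, the same rendering can be checked as a pure compile test, which is how the assert_compile/colspec assertions in these hunks work.
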
@testing.requires.mysql_non_strict @testing.provide_metadata def test_empty_set_no_empty_string(self): t = Table( - 't', self.metadata, - Column('id', Integer), - Column('data', mysql.SET("a", "b")) + "t", + self.metadata, + Column("id", Integer), + Column("data", mysql.SET("a", "b")), ) t.create() with testing.db.begin() as conn: conn.execute( t.insert(), - {'id': 1, 'data': set()}, - {'id': 2, 'data': set([''])}, - {'id': 3, 'data': set(['a', ''])}, - {'id': 4, 'data': set(['b'])}, + {"id": 1, "data": set()}, + {"id": 2, "data": set([""])}, + {"id": 3, "data": set(["a", ""])}, + {"id": 4, "data": set(["b"])}, ) eq_( conn.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, set()), - (2, set()), - (3, set(['a'])), - (4, set(['b'])), - ] + [(1, set()), (2, set()), (3, set(["a"])), (4, set(["b"]))], ) def test_bitwise_required_for_empty(self): @@ -851,33 +1029,37 @@ class EnumSetTest( exc.ArgumentError, "Can't use the blank value '' in a SET without setting " "retrieve_as_bitwise=True", - mysql.SET, "a", "b", '' + mysql.SET, + "a", + "b", + "", ) @testing.provide_metadata def test_empty_set_empty_string(self): t = Table( - 't', self.metadata, - Column('id', Integer), - Column('data', mysql.SET("a", "b", '', retrieve_as_bitwise=True)) + "t", + self.metadata, + Column("id", Integer), + Column("data", mysql.SET("a", "b", "", retrieve_as_bitwise=True)), ) t.create() with testing.db.begin() as conn: conn.execute( t.insert(), - {'id': 1, 'data': set()}, - {'id': 2, 'data': set([''])}, - {'id': 3, 'data': set(['a', ''])}, - {'id': 4, 'data': set(['b'])}, + {"id": 1, "data": set()}, + {"id": 2, "data": set([""])}, + {"id": 3, "data": set(["a", ""])}, + {"id": 4, "data": set(["b"])}, ) eq_( conn.execute(t.select().order_by(t.c.id)).fetchall(), [ (1, set()), - (2, set([''])), - (3, set(['a', ''])), - (4, set(['b'])), - ] + (2, set([""])), + (3, set(["a", ""])), + (4, set(["b"])), + ], ) @testing.provide_metadata @@ -887,46 +1069,56 @@ class EnumSetTest( with testing.db.begin() as conn: conn.execute( set_table.insert(), - dict(e1='a', e2='a', e3='a', e4="'a'", e5="a,b")) + dict(e1="a", e2="a", e3="a", e4="'a'", e5="a,b"), + ) conn.execute( set_table.insert(), - dict(e1='b', e2='b', e3='b', e4='b', e5="a,b")) + dict(e1="b", e2="b", e3="b", e4="b", e5="a,b"), + ) expected = [ - (set(['a']), set(['a']), set(['a']), - set(["'a'"]), set(['a', 'b'])), - (set(['b']), set(['b']), set(['b']), - set(['b']), set(['a', 'b'])) + ( + set(["a"]), + set(["a"]), + set(["a"]), + set(["'a'"]), + set(["a", "b"]), + ), + ( + set(["b"]), + set(["b"]), + set(["b"]), + set(["b"]), + set(["a", "b"]), + ), ] - res = conn.execute( - set_table.select() - ).fetchall() + res = conn.execute(set_table.select()).fetchall() eq_(res, expected) @testing.provide_metadata def test_unicode_roundtrip(self): set_table = Table( - 't', self.metadata, - Column('id', Integer, primary_key=True), - Column('data', mysql.SET( - u('réveillé'), u('drôle'), u('S’il'), convert_unicode=True)), + "t", + self.metadata, + Column("id", Integer, primary_key=True), + Column( + "data", + mysql.SET( + u("réveillé"), u("drôle"), u("S’il"), convert_unicode=True + ), + ), ) set_table.create() with testing.db.begin() as conn: conn.execute( - set_table.insert(), - {"data": set([u('réveillé'), u('drôle')])}) + set_table.insert(), {"data": set([u("réveillé"), u("drôle")])} + ) - row = conn.execute( - set_table.select() - ).first() + row = conn.execute(set_table.select()).first() - eq_( - row, - (1, set([u('réveillé'), u('drôle')])) - ) + eq_(row, (1, 
set([u("réveillé"), u("drôle")]))) @testing.provide_metadata def test_int_roundtrip(self): @@ -934,31 +1126,35 @@ class EnumSetTest( set_table.create() with testing.db.begin() as conn: conn.execute( - set_table.insert(), - dict(e1=1, e2=2, e3=3, e4=3, e5=0) + set_table.insert(), dict(e1=1, e2=2, e3=3, e4=3, e5=0) ) res = conn.execute(set_table.select()).first() eq_( res, ( - set(['a']), set(['b']), set(['a', 'b']), - set(["'a'", 'b']), set([])) + set(["a"]), + set(["b"]), + set(["a", "b"]), + set(["'a'", "b"]), + set([]), + ), ) @testing.provide_metadata def test_set_roundtrip_plus_reflection(self): set_table = Table( - 'mysql_set', self.metadata, - Column('s1', mysql.SET("dq", "sq")), - Column('s2', mysql.SET("a")), - Column('s3', mysql.SET("5", "7", "9"))) + "mysql_set", + self.metadata, + Column("s1", mysql.SET("dq", "sq")), + Column("s2", mysql.SET("a")), + Column("s3", mysql.SET("5", "7", "9")), + ) eq_(colspec(set_table.c.s1), "s1 SET('dq','sq')") eq_(colspec(set_table.c.s2), "s2 SET('a')") eq_(colspec(set_table.c.s3), "s3 SET('5','7','9')") set_table.create() - reflected = Table('mysql_set', MetaData(testing.db), - autoload=True) + reflected = Table("mysql_set", MetaData(testing.db), autoload=True) for table in set_table, reflected: def roundtrip(store, expected=None): @@ -969,97 +1165,110 @@ class EnumSetTest( table.delete().execute() roundtrip([None, None, None], [None] * 3) - roundtrip(['', '', ''], [set([])] * 3) - roundtrip([set(['dq']), set(['a']), set(['5'])]) - roundtrip(['dq', 'a', '5'], [set(['dq']), set(['a']), - set(['5'])]) - roundtrip([1, 1, 1], [set(['dq']), set(['a']), set(['5'])]) - roundtrip([set(['dq', 'sq']), None, set(['9', '5', '7'])]) + roundtrip(["", "", ""], [set([])] * 3) + roundtrip([set(["dq"]), set(["a"]), set(["5"])]) + roundtrip(["dq", "a", "5"], [set(["dq"]), set(["a"]), set(["5"])]) + roundtrip([1, 1, 1], [set(["dq"]), set(["a"]), set(["5"])]) + roundtrip([set(["dq", "sq"]), None, set(["9", "5", "7"])]) set_table.insert().execute( - {'s3': set(['5'])}, - {'s3': set(['5', '7'])}, - {'s3': set(['5', '7', '9'])}, - {'s3': set(['7', '9'])}) - - rows = select( - [set_table.c.s3], - set_table.c.s3.in_([set(['5']), ['5', '7']]) - ).execute().fetchall() + {"s3": set(["5"])}, + {"s3": set(["5", "7"])}, + {"s3": set(["5", "7", "9"])}, + {"s3": set(["7", "9"])}, + ) + + rows = ( + select( + [set_table.c.s3], set_table.c.s3.in_([set(["5"]), ["5", "7"]]) + ) + .execute() + .fetchall() + ) found = set([frozenset(row[0]) for row in rows]) - eq_(found, set([frozenset(['5']), frozenset(['5', '7'])])) + eq_(found, set([frozenset(["5"]), frozenset(["5", "7"])])) @testing.provide_metadata def test_unicode_enum(self): metadata = self.metadata t1 = Table( - 'table', metadata, - Column('id', Integer, primary_key=True), - Column('value', Enum(u('réveillé'), u('drôle'), u('S’il'))), - Column('value2', mysql.ENUM(u('réveillé'), u('drôle'), u('S’il'))) + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("value", Enum(u("réveillé"), u("drôle"), u("S’il"))), + Column("value2", mysql.ENUM(u("réveillé"), u("drôle"), u("S’il"))), ) metadata.create_all() - t1.insert().execute(value=u('drôle'), value2=u('drôle')) - t1.insert().execute(value=u('réveillé'), value2=u('réveillé')) - t1.insert().execute(value=u('S’il'), value2=u('S’il')) - eq_(t1.select().order_by(t1.c.id).execute().fetchall(), + t1.insert().execute(value=u("drôle"), value2=u("drôle")) + t1.insert().execute(value=u("réveillé"), value2=u("réveillé")) + t1.insert().execute(value=u("S’il"), 
value2=u("S’il")) + eq_( + t1.select().order_by(t1.c.id).execute().fetchall(), [ - (1, u('drôle'), u('drôle')), - (2, u('réveillé'), u('réveillé')), - (3, u('S’il'), u('S’il')) - ]) + (1, u("drôle"), u("drôle")), + (2, u("réveillé"), u("réveillé")), + (3, u("S’il"), u("S’il")), + ], + ) # test reflection of the enum labels m2 = MetaData(testing.db) - t2 = Table('table', m2, autoload=True) + t2 = Table("table", m2, autoload=True) # TODO: what's wrong with the last element ? is there # latin-1 stuff forcing its way in ? eq_( t2.c.value.type.enums[0:2], - [u('réveillé'), u('drôle')] # u'S’il') # eh ? + [u("réveillé"), u("drôle")], # u'S’il') # eh ? ) eq_( t2.c.value2.type.enums[0:2], - [u('réveillé'), u('drôle')] # u'S’il') # eh ? + [u("réveillé"), u("drôle")], # u'S’il') # eh ? ) def test_enum_compile(self): - e1 = Enum('x', 'y', 'z', name='somename') - t1 = Table('sometable', MetaData(), Column('somecolumn', e1)) - self.assert_compile(schema.CreateTable(t1), - "CREATE TABLE sometable (somecolumn " - "ENUM('x','y','z'))") - t1 = Table('sometable', MetaData(), Column('somecolumn', - Enum('x', 'y', 'z', native_enum=False))) - self.assert_compile(schema.CreateTable(t1), - "CREATE TABLE sometable (somecolumn " - "VARCHAR(1), CHECK (somecolumn IN ('x', " - "'y', 'z')))") + e1 = Enum("x", "y", "z", name="somename") + t1 = Table("sometable", MetaData(), Column("somecolumn", e1)) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (somecolumn " "ENUM('x','y','z'))", + ) + t1 = Table( + "sometable", + MetaData(), + Column("somecolumn", Enum("x", "y", "z", native_enum=False)), + ) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (somecolumn " + "VARCHAR(1), CHECK (somecolumn IN ('x', " + "'y', 'z')))", + ) @testing.provide_metadata - @testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''") + @testing.exclude("mysql", "<", (4,), "3.23 can't handle an ENUM of ''") def test_enum_parse(self): - with testing.expect_deprecated('Manually quoting ENUM value literals'): + with testing.expect_deprecated("Manually quoting ENUM value literals"): enum_table = Table( - 'mysql_enum', self.metadata, - Column('e1', mysql.ENUM("'a'")), - Column('e2', mysql.ENUM("''")), - Column('e3', mysql.ENUM('a')), - Column('e4', mysql.ENUM('')), - Column('e5', mysql.ENUM("'a'", "''")), - Column('e6', mysql.ENUM("''", "'a'")), - Column('e7', mysql.ENUM("''", "'''a'''", "'b''b'", "''''"))) + "mysql_enum", + self.metadata, + Column("e1", mysql.ENUM("'a'")), + Column("e2", mysql.ENUM("''")), + Column("e3", mysql.ENUM("a")), + Column("e4", mysql.ENUM("")), + Column("e5", mysql.ENUM("'a'", "''")), + Column("e6", mysql.ENUM("''", "'a'")), + Column("e7", mysql.ENUM("''", "'''a'''", "'b''b'", "''''")), + ) for col in enum_table.c: self.assert_(repr(col)) enum_table.create() - reflected = Table('mysql_enum', MetaData(testing.db), - autoload=True) + reflected = Table("mysql_enum", MetaData(testing.db), autoload=True) for t in enum_table, reflected: eq_(t.c.e1.type.enums, ["a"]) eq_(t.c.e2.type.enums, [""]) @@ -1070,20 +1279,29 @@ class EnumSetTest( eq_(t.c.e7.type.enums, ["", "'a'", "b'b", "'"]) @testing.provide_metadata - @testing.exclude('mysql', '<', (5,)) + @testing.exclude("mysql", "<", (5,)) def test_set_parse(self): - with testing.expect_deprecated('Manually quoting SET value literals'): + with testing.expect_deprecated("Manually quoting SET value literals"): set_table = Table( - 'mysql_set', self.metadata, - Column('e1', mysql.SET("'a'")), - Column('e2', mysql.SET("''", 
retrieve_as_bitwise=True)), - Column('e3', mysql.SET('a')), - Column('e4', mysql.SET('', retrieve_as_bitwise=True)), - Column('e5', mysql.SET("'a'", "''", retrieve_as_bitwise=True)), - Column('e6', mysql.SET("''", "'a'", retrieve_as_bitwise=True)), - Column('e7', mysql.SET( - "''", "'''a'''", "'b''b'", "''''", - retrieve_as_bitwise=True))) + "mysql_set", + self.metadata, + Column("e1", mysql.SET("'a'")), + Column("e2", mysql.SET("''", retrieve_as_bitwise=True)), + Column("e3", mysql.SET("a")), + Column("e4", mysql.SET("", retrieve_as_bitwise=True)), + Column("e5", mysql.SET("'a'", "''", retrieve_as_bitwise=True)), + Column("e6", mysql.SET("''", "'a'", retrieve_as_bitwise=True)), + Column( + "e7", + mysql.SET( + "''", + "'''a'''", + "'b''b'", + "''''", + retrieve_as_bitwise=True, + ), + ), + ) for col in set_table.c: self.assert_(repr(col)) @@ -1091,8 +1309,7 @@ class EnumSetTest( set_table.create() # don't want any warnings on reflection - reflected = Table('mysql_set', MetaData(testing.db), - autoload=True) + reflected = Table("mysql_set", MetaData(testing.db), autoload=True) for t in set_table, reflected: eq_(t.c.e1.type.values, ("a",)) eq_(t.c.e2.type.values, ("",)) @@ -1106,17 +1323,18 @@ class EnumSetTest( @testing.provide_metadata def test_broken_enum_returns_blanks(self): t = Table( - 'enum_missing', + "enum_missing", self.metadata, - Column('id', Integer, primary_key=True), - Column('e1', sqltypes.Enum('one', 'two', 'three')), - Column('e2', mysql.ENUM('one', 'two', 'three')) + Column("id", Integer, primary_key=True), + Column("e1", sqltypes.Enum("one", "two", "three")), + Column("e2", mysql.ENUM("one", "two", "three")), ) t.create() with testing.db.connect() as conn: - conn.execute(t.insert(), - {"e1": "nonexistent", "e2": "nonexistent"}) + conn.execute( + t.insert(), {"e1": "nonexistent", "e2": "nonexistent"} + ) conn.execute(t.insert(), {"e1": "", "e2": ""}) conn.execute(t.insert(), {"e1": "two", "e2": "two"}) conn.execute(t.insert(), {"e1": None, "e2": None}) @@ -1125,10 +1343,11 @@ class EnumSetTest( conn.execute( select([t.c.e1, t.c.e2]).order_by(t.c.id) ).fetchall(), - [("", ""), ("", ""), ("two", "two"), (None, None)] + [("", ""), ("", ""), ("two", "two"), (None, None)], ) def colspec(c): return testing.db.dialect.ddl_compiler( - testing.db.dialect, None).get_column_specification(c) + testing.db.dialect, None + ).get_column_specification(c) diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 3e1ffebb3c..2edf5848e9 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -5,17 +5,53 @@ from sqlalchemy.testing import eq_ from sqlalchemy import types as sqltypes, exc, schema from sqlalchemy.sql import table, column from sqlalchemy import and_ -from sqlalchemy.testing import (fixtures, - AssertsExecutionResults, - AssertsCompiledSQL) +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy import testing -from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\ - Index, MetaData, select, inspect, ForeignKey, String, func, \ - TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, \ - literal_column, VARCHAR, create_engine, Date, NVARCHAR, \ - ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \ - union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \ - PrimaryKeyConstraint, FLOAT +from sqlalchemy import ( + Integer, + Text, + LargeBinary, + Unicode, + UniqueConstraint, + Index, + 
MetaData, + select, + inspect, + ForeignKey, + String, + func, + TypeDecorator, + bindparam, + Numeric, + TIMESTAMP, + CHAR, + text, + literal_column, + VARCHAR, + create_engine, + Date, + NVARCHAR, + ForeignKeyConstraint, + Sequence, + Float, + DateTime, + cast, + UnicodeText, + union, + except_, + type_coerce, + or_, + outerjoin, + DATE, + NCHAR, + outparam, + PrimaryKeyConstraint, + FLOAT, +) from sqlalchemy.util import u, b from sqlalchemy import util from sqlalchemy.testing import assert_raises, assert_raises_message @@ -35,34 +71,41 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "oracle" def test_true_false(self): - self.assert_compile( - sql.false(), "0" - ) - self.assert_compile( - sql.true(), - "1" - ) + self.assert_compile(sql.false(), "0") + self.assert_compile(sql.true(), "1") def test_owner(self): meta = MetaData() - parent = Table('parent', meta, Column('id', Integer, - primary_key=True), Column('name', String(50)), - schema='ed') - child = Table('child', meta, Column('id', Integer, - primary_key=True), Column('parent_id', Integer, - ForeignKey('ed.parent.id')), schema='ed') - self.assert_compile(parent.join(child), - 'ed.parent JOIN ed.child ON ed.parent.id = ' - 'ed.child.parent_id') + parent = Table( + "parent", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + schema="ed", + ) + child = Table( + "child", + meta, + Column("id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("ed.parent.id")), + schema="ed", + ) + self.assert_compile( + parent.join(child), + "ed.parent JOIN ed.child ON ed.parent.id = " "ed.child.parent_id", + ) def test_subquery(self): - t = table('sometable', column('col1'), column('col2')) + t = table("sometable", column("col1"), column("col2")) s = select([t]) s = select([s.c.col1, s.c.col2]) - self.assert_compile(s, "SELECT col1, col2 FROM (SELECT " - "sometable.col1 AS col1, sometable.col2 " - "AS col2 FROM sometable)") + self.assert_compile( + s, + "SELECT col1, col2 FROM (SELECT " + "sometable.col1 AS col1, sometable.col2 " + "AS col2 FROM sometable)", + ) def test_bindparam_quote(self): """test that bound parameters take on quoting for reserved words, @@ -70,16 +113,12 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): # note: this is only in cx_oracle at the moment. 
not sure # what other hypothetical oracle dialects might need - self.assert_compile( - bindparam("option"), ':"option"' - ) - self.assert_compile( - bindparam("plain"), ':plain' - ) - t = Table("s", MetaData(), Column('plain', Integer, quote=True)) + self.assert_compile(bindparam("option"), ':"option"') + self.assert_compile(bindparam("plain"), ":plain") + t = Table("s", MetaData(), Column("plain", Integer, quote=True)) self.assert_compile( t.insert().values(plain=5), - 'INSERT INTO s ("plain") VALUES (:"plain")' + 'INSERT INTO s ("plain") VALUES (:"plain")', ) self.assert_compile( t.update().values(plain=5), 'UPDATE s SET "plain"=:"plain"' @@ -92,38 +131,43 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "Oracle; it requires quoting which is not supported in this " "context", bindparam("uid", expanding=True).compile, - dialect=cx_oracle.dialect() + dialect=cx_oracle.dialect(), ) def test_cte(self): part = table( - 'part', - column('part'), - column('sub_part'), - column('quantity') + "part", column("part"), column("sub_part"), column("quantity") ) - included_parts = select([ - part.c.sub_part, part.c.part, part.c.quantity - ]).where(part.c.part == "p1").\ - cte(name="included_parts", recursive=True).\ - suffix_with( + included_parts = ( + select([part.c.sub_part, part.c.part, part.c.quantity]) + .where(part.c.part == "p1") + .cte(name="included_parts", recursive=True) + .suffix_with( "search depth first by part set ord1", - "cycle part set y_cycle to 1 default 0", dialect='oracle') + "cycle part set y_cycle to 1 default 0", + dialect="oracle", + ) + ) incl_alias = included_parts.alias("pr1") parts_alias = part.alias("p") included_parts = included_parts.union_all( - select([ - parts_alias.c.sub_part, - parts_alias.c.part, parts_alias.c.quantity - ]).where(parts_alias.c.part == incl_alias.c.sub_part) + select( + [ + parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.quantity, + ] + ).where(parts_alias.c.part == incl_alias.c.sub_part) ) - q = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity).label('total_quantity')]).\ - group_by(included_parts.c.sub_part) + q = select( + [ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label("total_quantity"), + ] + ).group_by(included_parts.c.sub_part) self.assert_compile( q, @@ -137,160 +181,184 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "y_cycle to 1 default 0 " "SELECT included_parts.sub_part, sum(included_parts.quantity) " "AS total_quantity FROM included_parts " - "GROUP BY included_parts.sub_part" + "GROUP BY included_parts.sub_part", ) def test_limit(self): - t = table('sometable', column('col1'), column('col2')) + t = table("sometable", column("col1"), column("col2")) s = select([t]) c = s.compile(dialect=oracle.OracleDialect()) - assert t.c.col1 in set(c._create_result_map()['col1'][1]) + assert t.c.col1 in set(c._create_result_map()["col1"][1]) s = select([t]).limit(10).offset(20) - self.assert_compile(s, - 'SELECT col1, col2 FROM (SELECT col1, ' - 'col2, ROWNUM AS ora_rn FROM (SELECT ' - 'sometable.col1 AS col1, sometable.col2 AS ' - 'col2 FROM sometable) WHERE ROWNUM <= ' - ':param_1 + :param_2) WHERE ora_rn > :param_2', - checkparams={'param_1': 10, 'param_2': 20}) + self.assert_compile( + s, + "SELECT col1, col2 FROM (SELECT col1, " + "col2, ROWNUM AS ora_rn FROM (SELECT " + "sometable.col1 AS col1, sometable.col2 AS " + "col2 FROM sometable) WHERE ROWNUM <= " + ":param_1 + :param_2) WHERE ora_rn > :param_2", + checkparams={"param_1": 10, "param_2": 20}, + 
) c = s.compile(dialect=oracle.OracleDialect()) eq_(len(c._result_columns), 2) - assert t.c.col1 in set(c._create_result_map()['col1'][1]) + assert t.c.col1 in set(c._create_result_map()["col1"][1]) s2 = select([s.c.col1, s.c.col2]) - self.assert_compile(s2, - 'SELECT col1, col2 FROM (SELECT col1, col2 ' - 'FROM (SELECT col1, col2, ROWNUM AS ora_rn ' - 'FROM (SELECT sometable.col1 AS col1, ' - 'sometable.col2 AS col2 FROM sometable) ' - 'WHERE ROWNUM <= :param_1 + :param_2) ' - 'WHERE ora_rn > :param_2)', - checkparams={'param_1': 10, 'param_2': 20}) - - self.assert_compile(s2, - 'SELECT col1, col2 FROM (SELECT col1, col2 ' - 'FROM (SELECT col1, col2, ROWNUM AS ora_rn ' - 'FROM (SELECT sometable.col1 AS col1, ' - 'sometable.col2 AS col2 FROM sometable) ' - 'WHERE ROWNUM <= :param_1 + :param_2) ' - 'WHERE ora_rn > :param_2)') + self.assert_compile( + s2, + "SELECT col1, col2 FROM (SELECT col1, col2 " + "FROM (SELECT col1, col2, ROWNUM AS ora_rn " + "FROM (SELECT sometable.col1 AS col1, " + "sometable.col2 AS col2 FROM sometable) " + "WHERE ROWNUM <= :param_1 + :param_2) " + "WHERE ora_rn > :param_2)", + checkparams={"param_1": 10, "param_2": 20}, + ) + + self.assert_compile( + s2, + "SELECT col1, col2 FROM (SELECT col1, col2 " + "FROM (SELECT col1, col2, ROWNUM AS ora_rn " + "FROM (SELECT sometable.col1 AS col1, " + "sometable.col2 AS col2 FROM sometable) " + "WHERE ROWNUM <= :param_1 + :param_2) " + "WHERE ora_rn > :param_2)", + ) c = s2.compile(dialect=oracle.OracleDialect()) eq_(len(c._result_columns), 2) - assert s.c.col1 in set(c._create_result_map()['col1'][1]) + assert s.c.col1 in set(c._create_result_map()["col1"][1]) s = select([t]).limit(10).offset(20).order_by(t.c.col2) - self.assert_compile(s, - 'SELECT col1, col2 FROM (SELECT col1, ' - 'col2, ROWNUM AS ora_rn FROM (SELECT ' - 'sometable.col1 AS col1, sometable.col2 AS ' - 'col2 FROM sometable ORDER BY ' - 'sometable.col2) WHERE ROWNUM <= ' - ':param_1 + :param_2) WHERE ora_rn > :param_2', - checkparams={'param_1': 10, 'param_2': 20} - ) + self.assert_compile( + s, + "SELECT col1, col2 FROM (SELECT col1, " + "col2, ROWNUM AS ora_rn FROM (SELECT " + "sometable.col1 AS col1, sometable.col2 AS " + "col2 FROM sometable ORDER BY " + "sometable.col2) WHERE ROWNUM <= " + ":param_1 + :param_2) WHERE ora_rn > :param_2", + checkparams={"param_1": 10, "param_2": 20}, + ) c = s.compile(dialect=oracle.OracleDialect()) eq_(len(c._result_columns), 2) - assert t.c.col1 in set(c._create_result_map()['col1'][1]) + assert t.c.col1 in set(c._create_result_map()["col1"][1]) s = select([t], for_update=True).limit(10).order_by(t.c.col2) - self.assert_compile(s, - 'SELECT col1, col2 FROM (SELECT ' - 'sometable.col1 AS col1, sometable.col2 AS ' - 'col2 FROM sometable ORDER BY ' - 'sometable.col2) WHERE ROWNUM <= :param_1 ' - 'FOR UPDATE') - - s = select([t], - for_update=True).limit(10).offset(20).order_by(t.c.col2) - self.assert_compile(s, - 'SELECT col1, col2 FROM (SELECT col1, ' - 'col2, ROWNUM AS ora_rn FROM (SELECT ' - 'sometable.col1 AS col1, sometable.col2 AS ' - 'col2 FROM sometable ORDER BY ' - 'sometable.col2) WHERE ROWNUM <= ' - ':param_1 + :param_2) WHERE ora_rn > :param_2 FOR ' - 'UPDATE') + self.assert_compile( + s, + "SELECT col1, col2 FROM (SELECT " + "sometable.col1 AS col1, sometable.col2 AS " + "col2 FROM sometable ORDER BY " + "sometable.col2) WHERE ROWNUM <= :param_1 " + "FOR UPDATE", + ) + + s = ( + select([t], for_update=True) + .limit(10) + .offset(20) + .order_by(t.c.col2) + ) + self.assert_compile( + s, + "SELECT col1, col2 
FROM (SELECT col1, " + "col2, ROWNUM AS ora_rn FROM (SELECT " + "sometable.col1 AS col1, sometable.col2 AS " + "col2 FROM sometable ORDER BY " + "sometable.col2) WHERE ROWNUM <= " + ":param_1 + :param_2) WHERE ora_rn > :param_2 FOR " + "UPDATE", + ) def test_for_update(self): - table1 = table('mytable', - column('myid'), column('name'), column('description')) + table1 = table( + "mytable", column("myid"), column("name"), column("description") + ) self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", + ) self.assert_compile( - table1 - .select(table1.c.myid == 7) - .with_for_update(of=table1.c.myid), + table1.select(table1.c.myid == 7).with_for_update( + of=table1.c.myid + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_1 " - "FOR UPDATE OF mytable.myid") + "FOR UPDATE OF mytable.myid", + ) self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(nowait=True), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT") + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT", + ) self.assert_compile( - table1 - .select(table1.c.myid == 7) - .with_for_update(nowait=True, of=table1.c.myid), + table1.select(table1.c.myid == 7).with_for_update( + nowait=True, of=table1.c.myid + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_1 " - "FOR UPDATE OF mytable.myid NOWAIT") + "FOR UPDATE OF mytable.myid NOWAIT", + ) self.assert_compile( - table1 - .select(table1.c.myid == 7) - .with_for_update(nowait=True, of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + nowait=True, of=[table1.c.myid, table1.c.name] + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF " - "mytable.myid, mytable.name NOWAIT") + "mytable.myid, mytable.name NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7) - .with_for_update(skip_locked=True, - of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + skip_locked=True, of=[table1.c.myid, table1.c.name] + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE OF " - "mytable.myid, mytable.name SKIP LOCKED") + "mytable.myid, mytable.name SKIP LOCKED", + ) # key_share has no effect self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(key_share=True), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", + ) # read has no effect self.assert_compile( - table1 - .select(table1.c.myid == 7) - .with_for_update(read=True, key_share=True), + table1.select(table1.c.myid == 7).with_for_update( + read=True, key_share=True + ), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", + ) ta = table1.alias() self.assert_compile( - ta - .select(ta.c.myid == 7) - .with_for_update(of=[ta.c.myid, ta.c.name]), + ta.select(ta.c.myid == 7).with_for_update( + of=[ta.c.myid, ta.c.name] + ), "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " "FROM 
mytable mytable_1 " "WHERE mytable_1.myid = :myid_1 FOR UPDATE OF " - "mytable_1.myid, mytable_1.name" + "mytable_1.myid, mytable_1.name", ) def test_for_update_of_w_limit_adaption_col_present(self): - table1 = table('mytable', column('myid'), column('name')) + table1 = table("mytable", column("myid"), column("name")) self.assert_compile( - select([table1.c.myid, table1.c.name]). - where(table1.c.myid == 7). - with_for_update(nowait=True, of=table1.c.name). - limit(10), + select([table1.c.myid, table1.c.name]) + .where(table1.c.myid == 7) + .with_for_update(nowait=True, of=table1.c.name) + .limit(10), "SELECT myid, name FROM " "(SELECT mytable.myid AS myid, mytable.name AS name " "FROM mytable WHERE mytable.myid = :myid_1) " @@ -298,13 +366,13 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_for_update_of_w_limit_adaption_col_unpresent(self): - table1 = table('mytable', column('myid'), column('name')) + table1 = table("mytable", column("myid"), column("name")) self.assert_compile( - select([table1.c.myid]). - where(table1.c.myid == 7). - with_for_update(nowait=True, of=table1.c.name). - limit(10), + select([table1.c.myid]) + .where(table1.c.myid == 7) + .with_for_update(nowait=True, of=table1.c.name) + .limit(10), "SELECT myid FROM " "(SELECT mytable.myid AS myid, mytable.name AS name " "FROM mytable WHERE mytable.myid = :myid_1) " @@ -312,13 +380,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_for_update_of_w_limit_offset_adaption_col_present(self): - table1 = table('mytable', column('myid'), column('name')) + table1 = table("mytable", column("myid"), column("name")) self.assert_compile( - select([table1.c.myid, table1.c.name]). - where(table1.c.myid == 7). - with_for_update(nowait=True, of=table1.c.name). - limit(10).offset(50), + select([table1.c.myid, table1.c.name]) + .where(table1.c.myid == 7) + .with_for_update(nowait=True, of=table1.c.name) + .limit(10) + .offset(50), "SELECT myid, name FROM (SELECT myid, name, ROWNUM AS ora_rn " "FROM (SELECT mytable.myid AS myid, mytable.name AS name " "FROM mytable WHERE mytable.myid = :myid_1) " @@ -327,13 +396,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_for_update_of_w_limit_offset_adaption_col_unpresent(self): - table1 = table('mytable', column('myid'), column('name')) + table1 = table("mytable", column("myid"), column("name")) self.assert_compile( - select([table1.c.myid]). - where(table1.c.myid == 7). - with_for_update(nowait=True, of=table1.c.name). - limit(10).offset(50), + select([table1.c.myid]) + .where(table1.c.myid == 7) + .with_for_update(nowait=True, of=table1.c.name) + .limit(10) + .offset(50), "SELECT myid FROM (SELECT myid, ROWNUM AS ora_rn, name " "FROM (SELECT mytable.myid AS myid, mytable.name AS name " "FROM mytable WHERE mytable.myid = :myid_1) " @@ -342,55 +412,59 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_for_update_of_w_limit_offset_adaption_partial_col_unpresent(self): - table1 = table('mytable', column('myid'), column('foo'), column('bar')) + table1 = table("mytable", column("myid"), column("foo"), column("bar")) self.assert_compile( - select([table1.c.myid, table1.c.bar]). - where(table1.c.myid == 7). - with_for_update(nowait=True, of=[table1.c.foo, table1.c.bar]). 
- limit(10).offset(50), + select([table1.c.myid, table1.c.bar]) + .where(table1.c.myid == 7) + .with_for_update(nowait=True, of=[table1.c.foo, table1.c.bar]) + .limit(10) + .offset(50), "SELECT myid, bar FROM (SELECT myid, bar, ROWNUM AS ora_rn, " "foo FROM (SELECT mytable.myid AS myid, mytable.bar AS bar, " "mytable.foo AS foo FROM mytable WHERE mytable.myid = :myid_1) " "WHERE ROWNUM <= :param_1 + :param_2) WHERE ora_rn > :param_2 " - "FOR UPDATE OF foo, bar NOWAIT" + "FOR UPDATE OF foo, bar NOWAIT", ) def test_limit_preserves_typing_information(self): class MyType(TypeDecorator): impl = Integer - stmt = select([type_coerce(column('x'), MyType).label('foo')]).limit(1) + stmt = select([type_coerce(column("x"), MyType).label("foo")]).limit(1) dialect = oracle.dialect() compiled = stmt.compile(dialect=dialect) - assert isinstance(compiled._create_result_map()['foo'][-1], MyType) + assert isinstance(compiled._create_result_map()["foo"][-1], MyType) def test_use_binds_for_limits_disabled(self): - t = table('sometable', column('col1'), column('col2')) + t = table("sometable", column("col1"), column("col2")) dialect = oracle.OracleDialect(use_binds_for_limits=False) self.assert_compile( select([t]).limit(10), "SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " "sometable.col2 AS col2 FROM sometable) WHERE ROWNUM <= 10", - dialect=dialect) + dialect=dialect, + ) self.assert_compile( select([t]).offset(10), "SELECT col1, col2 FROM (SELECT col1, col2, ROWNUM AS ora_rn " "FROM (SELECT sometable.col1 AS col1, sometable.col2 AS col2 " "FROM sometable)) WHERE ora_rn > 10", - dialect=dialect) + dialect=dialect, + ) self.assert_compile( select([t]).limit(10).offset(10), "SELECT col1, col2 FROM (SELECT col1, col2, ROWNUM AS ora_rn " "FROM (SELECT sometable.col1 AS col1, sometable.col2 AS col2 " "FROM sometable) WHERE ROWNUM <= 20) WHERE ora_rn > 10", - dialect=dialect) + dialect=dialect, + ) def test_use_binds_for_limits_enabled(self): - t = table('sometable', column('col1'), column('col2')) + t = table("sometable", column("col1"), column("col2")) dialect = oracle.OracleDialect(use_binds_for_limits=True) self.assert_compile( @@ -398,14 +472,16 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " "sometable.col2 AS col2 FROM sometable) WHERE ROWNUM " "<= :param_1", - dialect=dialect) + dialect=dialect, + ) self.assert_compile( select([t]).offset(10), "SELECT col1, col2 FROM (SELECT col1, col2, ROWNUM AS ora_rn " "FROM (SELECT sometable.col1 AS col1, sometable.col2 AS col2 " "FROM sometable)) WHERE ora_rn > :param_1", - dialect=dialect) + dialect=dialect, + ) self.assert_compile( select([t]).limit(10).offset(10), @@ -414,7 +490,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM sometable) WHERE ROWNUM <= :param_1 + :param_2) " "WHERE ora_rn > :param_2", dialect=dialect, - checkparams={'param_1': 10, 'param_2': 10}) + checkparams={"param_1": 10, "param_2": 10}, + ) def test_long_labels(self): dialect = default.DefaultDialect() @@ -424,203 +501,243 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): m = MetaData() a_table = Table( - 'thirty_characters_table_xxxxxx', + "thirty_characters_table_xxxxxx", m, - Column('id', Integer, primary_key=True) + Column("id", Integer, primary_key=True), ) other_table = Table( - 'other_thirty_characters_table_', + "other_thirty_characters_table_", m, - Column('id', Integer, primary_key=True), - Column('thirty_characters_table_id', - Integer, - 
ForeignKey('thirty_characters_table_xxxxxx.id'), - primary_key=True)) + Column("id", Integer, primary_key=True), + Column( + "thirty_characters_table_id", + Integer, + ForeignKey("thirty_characters_table_xxxxxx.id"), + primary_key=True, + ), + ) anon = a_table.alias() - self.assert_compile(select([other_table, - anon]). - select_from( - other_table.outerjoin(anon)).apply_labels(), - 'SELECT other_thirty_characters_table_.id ' - 'AS other_thirty_characters__1, ' - 'other_thirty_characters_table_.thirty_char' - 'acters_table_id AS other_thirty_characters' - '__2, thirty_characters_table__1.id AS ' - 'thirty_characters_table__3 FROM ' - 'other_thirty_characters_table_ LEFT OUTER ' - 'JOIN thirty_characters_table_xxxxxx AS ' - 'thirty_characters_table__1 ON ' - 'thirty_characters_table__1.id = ' - 'other_thirty_characters_table_.thirty_char' - 'acters_table_id', dialect=dialect) - self.assert_compile(select([other_table, - anon]).select_from( - other_table.outerjoin(anon)).apply_labels(), - 'SELECT other_thirty_characters_table_.id ' - 'AS other_thirty_characters__1, ' - 'other_thirty_characters_table_.thirty_char' - 'acters_table_id AS other_thirty_characters' - '__2, thirty_characters_table__1.id AS ' - 'thirty_characters_table__3 FROM ' - 'other_thirty_characters_table_ LEFT OUTER ' - 'JOIN thirty_characters_table_xxxxxx ' - 'thirty_characters_table__1 ON ' - 'thirty_characters_table__1.id = ' - 'other_thirty_characters_table_.thirty_char' - 'acters_table_id', dialect=ora_dialect) + self.assert_compile( + select([other_table, anon]) + .select_from(other_table.outerjoin(anon)) + .apply_labels(), + "SELECT other_thirty_characters_table_.id " + "AS other_thirty_characters__1, " + "other_thirty_characters_table_.thirty_char" + "acters_table_id AS other_thirty_characters" + "__2, thirty_characters_table__1.id AS " + "thirty_characters_table__3 FROM " + "other_thirty_characters_table_ LEFT OUTER " + "JOIN thirty_characters_table_xxxxxx AS " + "thirty_characters_table__1 ON " + "thirty_characters_table__1.id = " + "other_thirty_characters_table_.thirty_char" + "acters_table_id", + dialect=dialect, + ) + self.assert_compile( + select([other_table, anon]) + .select_from(other_table.outerjoin(anon)) + .apply_labels(), + "SELECT other_thirty_characters_table_.id " + "AS other_thirty_characters__1, " + "other_thirty_characters_table_.thirty_char" + "acters_table_id AS other_thirty_characters" + "__2, thirty_characters_table__1.id AS " + "thirty_characters_table__3 FROM " + "other_thirty_characters_table_ LEFT OUTER " + "JOIN thirty_characters_table_xxxxxx " + "thirty_characters_table__1 ON " + "thirty_characters_table__1.id = " + "other_thirty_characters_table_.thirty_char" + "acters_table_id", + dialect=ora_dialect, + ) def test_outer_join(self): - table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String)) + table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), + ) table2 = table( - 'myothertable', - column('otherid', Integer), - column('othername', String), + "myothertable", + column("otherid", Integer), + column("othername", String), ) table3 = table( - 'thirdtable', - column('userid', Integer), - column('otherstuff', String), - ) - - query = select([table1, table2], - or_(table1.c.name == 'fred', - table1.c.myid == 10, table2.c.othername != 'jack', - text('EXISTS (select yay from foo where boo = lar)') - ), - from_obj=[outerjoin(table1, - table2, - table1.c.myid == table2.c.otherid)]) - 
self.assert_compile(query, - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, myothertable.otherid,' - ' myothertable.othername FROM mytable, ' - 'myothertable WHERE (mytable.name = ' - ':name_1 OR mytable.myid = :myid_1 OR ' - 'myothertable.othername != :othername_1 OR ' - 'EXISTS (select yay from foo where boo = ' - 'lar)) AND mytable.myid = ' - 'myothertable.otherid(+)', - dialect=oracle.OracleDialect(use_ansi=False)) - query = table1.outerjoin(table2, - table1.c.myid == table2.c.otherid) \ - .outerjoin(table3, table3.c.userid == table2.c.otherid) - self.assert_compile(query.select(), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, myothertable.otherid,' - ' myothertable.othername, ' - 'thirdtable.userid, thirdtable.otherstuff ' - 'FROM mytable LEFT OUTER JOIN myothertable ' - 'ON mytable.myid = myothertable.otherid ' - 'LEFT OUTER JOIN thirdtable ON ' - 'thirdtable.userid = myothertable.otherid') - - self.assert_compile(query.select(), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, myothertable.otherid,' - ' myothertable.othername, ' - 'thirdtable.userid, thirdtable.otherstuff ' - 'FROM mytable, myothertable, thirdtable ' - 'WHERE thirdtable.userid(+) = ' - 'myothertable.otherid AND mytable.myid = ' - 'myothertable.otherid(+)', - dialect=oracle.dialect(use_ansi=False)) - query = table1.join(table2, - table1.c.myid == table2.c.otherid) \ - .join(table3, table3.c.userid == table2.c.otherid) - self.assert_compile(query.select(), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, myothertable.otherid,' - ' myothertable.othername, ' - 'thirdtable.userid, thirdtable.otherstuff ' - 'FROM mytable, myothertable, thirdtable ' - 'WHERE thirdtable.userid = ' - 'myothertable.otherid AND mytable.myid = ' - 'myothertable.otherid', - dialect=oracle.dialect(use_ansi=False)) - query = table1.join(table2, - table1.c.myid == table2.c.otherid) \ - .outerjoin(table3, table3.c.userid == table2.c.otherid) - self.assert_compile(query.select().order_by(table1.c.name). 
- limit(10).offset(5), - 'SELECT myid, name, description, otherid, ' - 'othername, userid, otherstuff FROM ' - '(SELECT myid, name, description, otherid, ' - 'othername, userid, otherstuff, ROWNUM AS ' - 'ora_rn FROM (SELECT mytable.myid AS myid, ' - 'mytable.name AS name, mytable.description ' - 'AS description, myothertable.otherid AS ' - 'otherid, myothertable.othername AS ' - 'othername, thirdtable.userid AS userid, ' - 'thirdtable.otherstuff AS otherstuff FROM ' - 'mytable, myothertable, thirdtable WHERE ' - 'thirdtable.userid(+) = ' - 'myothertable.otherid AND mytable.myid = ' - 'myothertable.otherid ORDER BY mytable.name) ' - 'WHERE ROWNUM <= :param_1 + :param_2) ' - 'WHERE ora_rn > :param_2', - checkparams={'param_1': 10, 'param_2': 5}, - dialect=oracle.dialect(use_ansi=False)) - - subq = select([table1]).select_from( - table1.outerjoin(table2, table1.c.myid == table2.c.otherid)) \ + "thirdtable", + column("userid", Integer), + column("otherstuff", String), + ) + + query = select( + [table1, table2], + or_( + table1.c.name == "fred", + table1.c.myid == 10, + table2.c.othername != "jack", + text("EXISTS (select yay from foo where boo = lar)"), + ), + from_obj=[ + outerjoin(table1, table2, table1.c.myid == table2.c.otherid) + ], + ) + self.assert_compile( + query, + "SELECT mytable.myid, mytable.name, " + "mytable.description, myothertable.otherid," + " myothertable.othername FROM mytable, " + "myothertable WHERE (mytable.name = " + ":name_1 OR mytable.myid = :myid_1 OR " + "myothertable.othername != :othername_1 OR " + "EXISTS (select yay from foo where boo = " + "lar)) AND mytable.myid = " + "myothertable.otherid(+)", + dialect=oracle.OracleDialect(use_ansi=False), + ) + query = table1.outerjoin( + table2, table1.c.myid == table2.c.otherid + ).outerjoin(table3, table3.c.userid == table2.c.otherid) + self.assert_compile( + query.select(), + "SELECT mytable.myid, mytable.name, " + "mytable.description, myothertable.otherid," + " myothertable.othername, " + "thirdtable.userid, thirdtable.otherstuff " + "FROM mytable LEFT OUTER JOIN myothertable " + "ON mytable.myid = myothertable.otherid " + "LEFT OUTER JOIN thirdtable ON " + "thirdtable.userid = myothertable.otherid", + ) + + self.assert_compile( + query.select(), + "SELECT mytable.myid, mytable.name, " + "mytable.description, myothertable.otherid," + " myothertable.othername, " + "thirdtable.userid, thirdtable.otherstuff " + "FROM mytable, myothertable, thirdtable " + "WHERE thirdtable.userid(+) = " + "myothertable.otherid AND mytable.myid = " + "myothertable.otherid(+)", + dialect=oracle.dialect(use_ansi=False), + ) + query = table1.join(table2, table1.c.myid == table2.c.otherid).join( + table3, table3.c.userid == table2.c.otherid + ) + self.assert_compile( + query.select(), + "SELECT mytable.myid, mytable.name, " + "mytable.description, myothertable.otherid," + " myothertable.othername, " + "thirdtable.userid, thirdtable.otherstuff " + "FROM mytable, myothertable, thirdtable " + "WHERE thirdtable.userid = " + "myothertable.otherid AND mytable.myid = " + "myothertable.otherid", + dialect=oracle.dialect(use_ansi=False), + ) + query = table1.join( + table2, table1.c.myid == table2.c.otherid + ).outerjoin(table3, table3.c.userid == table2.c.otherid) + self.assert_compile( + query.select().order_by(table1.c.name).limit(10).offset(5), + "SELECT myid, name, description, otherid, " + "othername, userid, otherstuff FROM " + "(SELECT myid, name, description, otherid, " + "othername, userid, otherstuff, ROWNUM AS " + "ora_rn FROM (SELECT 
mytable.myid AS myid, " + "mytable.name AS name, mytable.description " + "AS description, myothertable.otherid AS " + "otherid, myothertable.othername AS " + "othername, thirdtable.userid AS userid, " + "thirdtable.otherstuff AS otherstuff FROM " + "mytable, myothertable, thirdtable WHERE " + "thirdtable.userid(+) = " + "myothertable.otherid AND mytable.myid = " + "myothertable.otherid ORDER BY mytable.name) " + "WHERE ROWNUM <= :param_1 + :param_2) " + "WHERE ora_rn > :param_2", + checkparams={"param_1": 10, "param_2": 5}, + dialect=oracle.dialect(use_ansi=False), + ) + + subq = ( + select([table1]) + .select_from( + table1.outerjoin(table2, table1.c.myid == table2.c.otherid) + ) .alias() + ) q = select([table3]).select_from( - table3.outerjoin(subq, table3.c.userid == subq.c.myid)) - - self.assert_compile(q, - 'SELECT thirdtable.userid, ' - 'thirdtable.otherstuff FROM thirdtable ' - 'LEFT OUTER JOIN (SELECT mytable.myid AS ' - 'myid, mytable.name AS name, ' - 'mytable.description AS description FROM ' - 'mytable LEFT OUTER JOIN myothertable ON ' - 'mytable.myid = myothertable.otherid) ' - 'anon_1 ON thirdtable.userid = anon_1.myid', - dialect=oracle.dialect(use_ansi=True)) - - self.assert_compile(q, - 'SELECT thirdtable.userid, ' - 'thirdtable.otherstuff FROM thirdtable, ' - '(SELECT mytable.myid AS myid, ' - 'mytable.name AS name, mytable.description ' - 'AS description FROM mytable, myothertable ' - 'WHERE mytable.myid = myothertable.otherid(' - '+)) anon_1 WHERE thirdtable.userid = ' - 'anon_1.myid(+)', - dialect=oracle.dialect(use_ansi=False)) - - q = select([table1.c.name]).where(table1.c.name == 'foo') - self.assert_compile(q, - 'SELECT mytable.name FROM mytable WHERE ' - 'mytable.name = :name_1', - dialect=oracle.dialect(use_ansi=False)) - subq = select([table3.c.otherstuff]) \ - .where(table3.c.otherstuff == table1.c.name).label('bar') + table3.outerjoin(subq, table3.c.userid == subq.c.myid) + ) + + self.assert_compile( + q, + "SELECT thirdtable.userid, " + "thirdtable.otherstuff FROM thirdtable " + "LEFT OUTER JOIN (SELECT mytable.myid AS " + "myid, mytable.name AS name, " + "mytable.description AS description FROM " + "mytable LEFT OUTER JOIN myothertable ON " + "mytable.myid = myothertable.otherid) " + "anon_1 ON thirdtable.userid = anon_1.myid", + dialect=oracle.dialect(use_ansi=True), + ) + + self.assert_compile( + q, + "SELECT thirdtable.userid, " + "thirdtable.otherstuff FROM thirdtable, " + "(SELECT mytable.myid AS myid, " + "mytable.name AS name, mytable.description " + "AS description FROM mytable, myothertable " + "WHERE mytable.myid = myothertable.otherid(" + "+)) anon_1 WHERE thirdtable.userid = " + "anon_1.myid(+)", + dialect=oracle.dialect(use_ansi=False), + ) + + q = select([table1.c.name]).where(table1.c.name == "foo") + self.assert_compile( + q, + "SELECT mytable.name FROM mytable WHERE " "mytable.name = :name_1", + dialect=oracle.dialect(use_ansi=False), + ) + subq = ( + select([table3.c.otherstuff]) + .where(table3.c.otherstuff == table1.c.name) + .label("bar") + ) q = select([table1.c.name, subq]) - self.assert_compile(q, - 'SELECT mytable.name, (SELECT ' - 'thirdtable.otherstuff FROM thirdtable ' - 'WHERE thirdtable.otherstuff = ' - 'mytable.name) AS bar FROM mytable', - dialect=oracle.dialect(use_ansi=False)) + self.assert_compile( + q, + "SELECT mytable.name, (SELECT " + "thirdtable.otherstuff FROM thirdtable " + "WHERE thirdtable.otherstuff = " + "mytable.name) AS bar FROM mytable", + dialect=oracle.dialect(use_ansi=False), + ) def 
test_nonansi_plusses_everthing_in_the_condition(self): - table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String)) + table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), + ) table2 = table( - 'myothertable', - column('otherid', Integer), - column('othername', String), + "myothertable", + column("otherid", Integer), + column("othername", String), ) stmt = select([table1]).select_from( @@ -629,8 +746,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): and_( table1.c.myid == table2.c.otherid, table2.c.othername > 5, - table1.c.name == 'foo' - ) + table1.c.name == "foo", + ), ) ) self.assert_compile( @@ -639,7 +756,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM mytable, myothertable WHERE mytable.myid = " "myothertable.otherid(+) AND myothertable.othername(+) > " ":othername_1 AND mytable.name = :name_1", - dialect=oracle.dialect(use_ansi=False)) + dialect=oracle.dialect(use_ansi=False), + ) stmt = select([table1]).select_from( table1.outerjoin( @@ -647,8 +765,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): and_( table1.c.myid == table2.c.otherid, table2.c.othername == None, - table1.c.name == None - ) + table1.c.name == None, + ), ) ) self.assert_compile( @@ -657,12 +775,13 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM mytable, myothertable WHERE mytable.myid = " "myothertable.otherid(+) AND myothertable.othername(+) IS NULL " "AND mytable.name IS NULL", - dialect=oracle.dialect(use_ansi=False)) + dialect=oracle.dialect(use_ansi=False), + ) def test_nonansi_nested_right_join(self): - a = table('a', column('a')) - b = table('b', column('b')) - c = table('c', column('c')) + a = table("a", column("a")) + b = table("b", column("b")) + c = table("c", column("c")) j = a.join(b.join(c, b.c.b == c.c.c), a.c.a == b.c.b) @@ -670,7 +789,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([j]), "SELECT a.a, b.b, c.c FROM a, b, c " "WHERE a.a = b.b AND b.b = c.c", - dialect=oracle.OracleDialect(use_ansi=False) + dialect=oracle.OracleDialect(use_ansi=False), ) j = a.outerjoin(b.join(c, b.c.b == c.c.c), a.c.a == b.c.b) @@ -679,7 +798,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([j]), "SELECT a.a, b.b, c.c FROM a, b, c " "WHERE a.a = b.b(+) AND b.b = c.c", - dialect=oracle.OracleDialect(use_ansi=False) + dialect=oracle.OracleDialect(use_ansi=False), ) j = a.join(b.outerjoin(c, b.c.b == c.c.c), a.c.a == b.c.b) @@ -688,75 +807,94 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([j]), "SELECT a.a, b.b, c.c FROM a, b, c " "WHERE a.a = b.b AND b.b = c.c(+)", - dialect=oracle.OracleDialect(use_ansi=False) + dialect=oracle.OracleDialect(use_ansi=False), ) def test_alias_outer_join(self): - address_types = table('address_types', column('id'), - column('name')) - addresses = table('addresses', column('id'), column('user_id'), - column('address_type_id'), - column('email_address')) + address_types = table("address_types", column("id"), column("name")) + addresses = table( + "addresses", + column("id"), + column("user_id"), + column("address_type_id"), + column("email_address"), + ) at_alias = address_types.alias() - s = select([at_alias, addresses]) \ + s = ( + select([at_alias, addresses]) .select_from( addresses.outerjoin( - at_alias, - addresses.c.address_type_id == at_alias.c.id)) \ - .where(addresses.c.user_id == 7) \ + at_alias, addresses.c.address_type_id == 
at_alias.c.id + ) + ) + .where(addresses.c.user_id == 7) .order_by(addresses.c.id, address_types.c.id) - self.assert_compile(s, - 'SELECT address_types_1.id, ' - 'address_types_1.name, addresses.id, ' - 'addresses.user_id, addresses.address_type_' - 'id, addresses.email_address FROM ' - 'addresses LEFT OUTER JOIN address_types ' - 'address_types_1 ON addresses.address_type_' - 'id = address_types_1.id WHERE ' - 'addresses.user_id = :user_id_1 ORDER BY ' - 'addresses.id, address_types.id') + ) + self.assert_compile( + s, + "SELECT address_types_1.id, " + "address_types_1.name, addresses.id, " + "addresses.user_id, addresses.address_type_" + "id, addresses.email_address FROM " + "addresses LEFT OUTER JOIN address_types " + "address_types_1 ON addresses.address_type_" + "id = address_types_1.id WHERE " + "addresses.user_id = :user_id_1 ORDER BY " + "addresses.id, address_types.id", + ) def test_returning_insert(self): - t1 = table('t1', column('c1'), column('c2'), column('c3')) + t1 = table("t1", column("c1"), column("c2"), column("c3")) self.assert_compile( t1.insert().values(c1=1).returning(t1.c.c2, t1.c.c3), "INSERT INTO t1 (c1) VALUES (:c1) RETURNING " - "t1.c2, t1.c3 INTO :ret_0, :ret_1") + "t1.c2, t1.c3 INTO :ret_0, :ret_1", + ) def test_returning_insert_functional(self): - t1 = table('t1', - column('c1'), - column('c2', String()), - column('c3', String())) + t1 = table( + "t1", column("c1"), column("c2", String()), column("c3", String()) + ) fn = func.lower(t1.c.c2, type_=String()) stmt = t1.insert().values(c1=1).returning(fn, t1.c.c3) compiled = stmt.compile(dialect=oracle.dialect()) - eq_(compiled._create_result_map(), - {'c3': ('c3', (t1.c.c3, 'c3', 'c3'), t1.c.c3.type), - 'lower': ('lower', (fn, 'lower', None), fn.type)}) + eq_( + compiled._create_result_map(), + { + "c3": ("c3", (t1.c.c3, "c3", "c3"), t1.c.c3.type), + "lower": ("lower", (fn, "lower", None), fn.type), + }, + ) self.assert_compile( stmt, "INSERT INTO t1 (c1) VALUES (:c1) RETURNING " - "lower(t1.c2), t1.c3 INTO :ret_0, :ret_1") + "lower(t1.c2), t1.c3 INTO :ret_0, :ret_1", + ) def test_returning_insert_labeled(self): - t1 = table('t1', column('c1'), column('c2'), column('c3')) + t1 = table("t1", column("c1"), column("c2"), column("c3")) self.assert_compile( - t1.insert().values(c1=1).returning( - t1.c.c2.label('c2_l'), t1.c.c3.label('c3_l')), + t1.insert() + .values(c1=1) + .returning(t1.c.c2.label("c2_l"), t1.c.c3.label("c3_l")), "INSERT INTO t1 (c1) VALUES (:c1) RETURNING " - "t1.c2, t1.c3 INTO :ret_0, :ret_1") + "t1.c2, t1.c3 INTO :ret_0, :ret_1", + ) def test_compound(self): - t1 = table('t1', column('c1'), column('c2'), column('c3')) - t2 = table('t2', column('c1'), column('c2'), column('c3')) - self.assert_compile(union(t1.select(), t2.select()), - 'SELECT t1.c1, t1.c2, t1.c3 FROM t1 UNION ' - 'SELECT t2.c1, t2.c2, t2.c3 FROM t2') - self.assert_compile(except_(t1.select(), t2.select()), - 'SELECT t1.c1, t1.c2, t1.c3 FROM t1 MINUS ' - 'SELECT t2.c1, t2.c2, t2.c3 FROM t2') + t1 = table("t1", column("c1"), column("c2"), column("c3")) + t2 = table("t2", column("c1"), column("c2"), column("c3")) + self.assert_compile( + union(t1.select(), t2.select()), + "SELECT t1.c1, t1.c2, t1.c3 FROM t1 UNION " + "SELECT t2.c1, t2.c2, t2.c3 FROM t2", + ) + self.assert_compile( + except_(t1.select(), t2.select()), + "SELECT t1.c1, t1.c2, t1.c3 FROM t1 MINUS " + "SELECT t2.c1, t2.c2, t2.c3 FROM t2", + ) def test_no_paren_fns(self): for fn, expected in [ @@ -773,79 +911,91 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 
def test_create_index_alt_schema(self): m = MetaData() - t1 = Table('foo', m, - Column('x', Integer), - schema="alt_schema") + t1 = Table("foo", m, Column("x", Integer), schema="alt_schema") self.assert_compile( schema.CreateIndex(Index("bar", t1.c.x)), - "CREATE INDEX alt_schema.bar ON alt_schema.foo (x)" + "CREATE INDEX alt_schema.bar ON alt_schema.foo (x)", ) def test_create_index_expr(self): m = MetaData() - t1 = Table('foo', m, - Column('x', Integer)) + t1 = Table("foo", m, Column("x", Integer)) self.assert_compile( schema.CreateIndex(Index("bar", t1.c.x > 5)), - "CREATE INDEX bar ON foo (x > 5)" + "CREATE INDEX bar ON foo (x > 5)", ) def test_table_options(self): m = MetaData() t = Table( - 'foo', m, - Column('x', Integer), + "foo", + m, + Column("x", Integer), prefixes=["GLOBAL TEMPORARY"], - oracle_on_commit="PRESERVE ROWS" + oracle_on_commit="PRESERVE ROWS", ) self.assert_compile( schema.CreateTable(t), "CREATE GLOBAL TEMPORARY TABLE " - "foo (x INTEGER) ON COMMIT PRESERVE ROWS" + "foo (x INTEGER) ON COMMIT PRESERVE ROWS", ) def test_create_table_compress(self): m = MetaData() - tbl1 = Table('testtbl1', m, Column('data', Integer), - oracle_compress=True) - tbl2 = Table('testtbl2', m, Column('data', Integer), - oracle_compress="OLTP") + tbl1 = Table( + "testtbl1", m, Column("data", Integer), oracle_compress=True + ) + tbl2 = Table( + "testtbl2", m, Column("data", Integer), oracle_compress="OLTP" + ) - self.assert_compile(schema.CreateTable(tbl1), - "CREATE TABLE testtbl1 (data INTEGER) COMPRESS") - self.assert_compile(schema.CreateTable(tbl2), - "CREATE TABLE testtbl2 (data INTEGER) " - "COMPRESS FOR OLTP") + self.assert_compile( + schema.CreateTable(tbl1), + "CREATE TABLE testtbl1 (data INTEGER) COMPRESS", + ) + self.assert_compile( + schema.CreateTable(tbl2), + "CREATE TABLE testtbl2 (data INTEGER) " "COMPRESS FOR OLTP", + ) def test_create_index_bitmap_compress(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', Integer)) - idx1 = Index('idx1', tbl.c.data, oracle_compress=True) - idx2 = Index('idx2', tbl.c.data, oracle_compress=1) - idx3 = Index('idx3', tbl.c.data, oracle_bitmap=True) + tbl = Table("testtbl", m, Column("data", Integer)) + idx1 = Index("idx1", tbl.c.data, oracle_compress=True) + idx2 = Index("idx2", tbl.c.data, oracle_compress=1) + idx3 = Index("idx3", tbl.c.data, oracle_bitmap=True) - self.assert_compile(schema.CreateIndex(idx1), - "CREATE INDEX idx1 ON testtbl (data) COMPRESS") - self.assert_compile(schema.CreateIndex(idx2), - "CREATE INDEX idx2 ON testtbl (data) COMPRESS 1") - self.assert_compile(schema.CreateIndex(idx3), - "CREATE BITMAP INDEX idx3 ON testtbl (data)") + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX idx1 ON testtbl (data) COMPRESS", + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX idx2 ON testtbl (data) COMPRESS 1", + ) + self.assert_compile( + schema.CreateIndex(idx3), + "CREATE BITMAP INDEX idx3 ON testtbl (data)", + ) class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): - def test_basic(self): - seq = Sequence('my_seq_no_schema') + seq = Sequence("my_seq_no_schema") dialect = oracle.OracleDialect() - assert dialect.identifier_preparer.format_sequence(seq) \ - == 'my_seq_no_schema' - seq = Sequence('my_seq', schema='some_schema') - assert dialect.identifier_preparer.format_sequence(seq) \ - == 'some_schema.my_seq' - seq = Sequence('My_Seq', schema='Some_Schema') - assert dialect.identifier_preparer.format_sequence(seq) \ + assert ( + dialect.identifier_preparer.format_sequence(seq) 
+ == "my_seq_no_schema" + ) + seq = Sequence("my_seq", schema="some_schema") + assert ( + dialect.identifier_preparer.format_sequence(seq) + == "some_schema.my_seq" + ) + seq = Sequence("My_Seq", schema="Some_Schema") + assert ( + dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' - - + ) diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 9f38a515a9..cfb90c25f4 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -3,14 +3,25 @@ from sqlalchemy.testing import eq_ from sqlalchemy import exc -from sqlalchemy.testing import (fixtures, - AssertsExecutionResults, - AssertsCompiledSQL) +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy import testing from sqlalchemy import create_engine from sqlalchemy import bindparam, outparam -from sqlalchemy import text, Float, Integer, String, select, literal_column,\ - Unicode, UnicodeText, Sequence +from sqlalchemy import ( + text, + Float, + Integer, + String, + select, + literal_column, + Unicode, + UnicodeText, + Sequence, +) from sqlalchemy.util import u from sqlalchemy.testing import assert_raises, assert_raises_message from sqlalchemy.dialects.oracle import cx_oracle, base as oracle @@ -24,45 +35,41 @@ class DialectTest(fixtures.TestBase): def test_cx_oracle_version_parse(self): dialect = cx_oracle.OracleDialect_cx_oracle() - eq_( - dialect._parse_cx_oracle_ver("5.2"), - (5, 2) - ) + eq_(dialect._parse_cx_oracle_ver("5.2"), (5, 2)) - eq_( - dialect._parse_cx_oracle_ver("5.0.1"), - (5, 0, 1) - ) + eq_(dialect._parse_cx_oracle_ver("5.0.1"), (5, 0, 1)) - eq_( - dialect._parse_cx_oracle_ver("6.0b1"), - (6, 0) - ) + eq_(dialect._parse_cx_oracle_ver("6.0b1"), (6, 0)) def test_minimum_version(self): with mock.patch( - "sqlalchemy.dialects.oracle.cx_oracle.OracleDialect_cx_oracle." - "_parse_cx_oracle_ver", lambda self, vers: (5, 1, 5)): + "sqlalchemy.dialects.oracle.cx_oracle.OracleDialect_cx_oracle." + "_parse_cx_oracle_ver", + lambda self, vers: (5, 1, 5), + ): assert_raises_message( exc.InvalidRequestError, "cx_Oracle version 5.2 and above are supported", cx_oracle.OracleDialect_cx_oracle, - dbapi=Mock() + dbapi=Mock(), ) with mock.patch( - "sqlalchemy.dialects.oracle.cx_oracle.OracleDialect_cx_oracle." - "_parse_cx_oracle_ver", lambda self, vers: (5, 3, 1)): + "sqlalchemy.dialects.oracle.cx_oracle.OracleDialect_cx_oracle." 
+ "_parse_cx_oracle_ver", + lambda self, vers: (5, 3, 1), + ): cx_oracle.OracleDialect_cx_oracle(dbapi=Mock()) class OutParamTest(fixtures.TestBase, AssertsExecutionResults): - __only_on__ = 'oracle+cx_oracle' + __only_on__ = "oracle+cx_oracle" __backend__ = True @classmethod def setup_class(cls): - testing.db.execute(""" + testing.db.execute( + """ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number, z_out OUT varchar) IS retval number; @@ -72,19 +79,24 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults): y_out := x_in * 15; z_out := NULL; end; - """) + """ + ) def test_out_params(self): - result = testing.db.execute(text('begin foo(:x_in, :x_out, :y_out, ' - ':z_out); end;', - bindparams=[bindparam('x_in', Float), - outparam('x_out', Integer), - outparam('y_out', Float), - outparam('z_out', String)]), - x_in=5) - eq_(result.out_parameters, - {'x_out': 10, 'y_out': 75, 'z_out': None}) - assert isinstance(result.out_parameters['x_out'], int) + result = testing.db.execute( + text( + "begin foo(:x_in, :x_out, :y_out, " ":z_out); end;", + bindparams=[ + bindparam("x_in", Float), + outparam("x_out", Integer), + outparam("y_out", Float), + outparam("z_out", String), + ], + ), + x_in=5, + ) + eq_(result.out_parameters, {"x_out": 10, "y_out": 75, "z_out": None}) + assert isinstance(result.out_parameters["x_out"], int) @classmethod def teardown_class(cls): @@ -93,69 +105,61 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults): class QuotedBindRoundTripTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @testing.provide_metadata def test_table_round_trip(self): - oracle.RESERVED_WORDS.remove('UNION') + oracle.RESERVED_WORDS.remove("UNION") metadata = self.metadata - table = Table("t1", metadata, - Column("option", Integer), - Column("plain", Integer, quote=True), - # test that quote works for a reserved word - # that the dialect isn't aware of when quote - # is set - Column("union", Integer, quote=True)) + table = Table( + "t1", + metadata, + Column("option", Integer), + Column("plain", Integer, quote=True), + # test that quote works for a reserved word + # that the dialect isn't aware of when quote + # is set + Column("union", Integer, quote=True), + ) metadata.create_all() - table.insert().execute( - {"option": 1, "plain": 1, "union": 1} - ) - eq_( - testing.db.execute(table.select()).first(), - (1, 1, 1) - ) + table.insert().execute({"option": 1, "plain": 1, "union": 1}) + eq_(testing.db.execute(table.select()).first(), (1, 1, 1)) table.update().values(option=2, plain=2, union=2).execute() - eq_( - testing.db.execute(table.select()).first(), - (2, 2, 2) - ) + eq_(testing.db.execute(table.select()).first(), (2, 2, 2)) def test_numeric_bind_round_trip(self): eq_( testing.db.scalar( - select([ - literal_column("2", type_=Integer()) + - bindparam("2_1", value=2)]) + select( + [ + literal_column("2", type_=Integer()) + + bindparam("2_1", value=2) + ] + ) ), - 4 + 4, ) @testing.provide_metadata def test_numeric_bind_in_crud(self): - t = Table( - "asfd", self.metadata, - Column("100K", Integer) - ) + t = Table("asfd", self.metadata, Column("100K", Integer)) t.create() testing.db.execute(t.insert(), {"100K": 10}) - eq_( - testing.db.scalar(t.select()), 10 - ) + eq_(testing.db.scalar(t.select()), 10) class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): - def _dialect(self, server_version, **kw): def server_version_info(conn): return server_version dialect = oracle.dialect( - 
dbapi=Mock(version="0.0.0", paramstyle="named"), - **kw) + dbapi=Mock(version="0.0.0", paramstyle="named"), **kw + ) dialect._get_server_version_info = server_version_info dialect._check_unicode_returns = Mock() dialect._check_unicode_description = Mock() @@ -219,15 +223,19 @@ class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): class ExecuteTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True def test_basic(self): - eq_(testing.db.execute('/*+ this is a comment */ SELECT 1 FROM ' - 'DUAL').fetchall(), [(1, )]) + eq_( + testing.db.execute( + "/*+ this is a comment */ SELECT 1 FROM " "DUAL" + ).fetchall(), + [(1,)], + ) def test_sequences_are_integers(self): - seq = Sequence('foo_seq') + seq = Sequence("foo_seq") seq.create(testing.db) try: val = testing.db.execute(seq) @@ -242,25 +250,26 @@ class ExecuteTest(fixtures.TestBase): # oracle can't actually do the ROWNUM thing with FOR UPDATE # very well. - t = Table('t1', - metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer)) + t = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("data", Integer), + ) metadata.create_all() t.insert().execute( - {'id': 1, 'data': 1}, - {'id': 2, 'data': 7}, - {'id': 3, 'data': 12}, - {'id': 4, 'data': 15}, - {'id': 5, 'data': 32}, + {"id": 1, "data": 1}, + {"id": 2, "data": 7}, + {"id": 3, "data": 12}, + {"id": 4, "data": 15}, + {"id": 5, "data": 32}, ) # here, we can't use ORDER BY. eq_( t.select(for_update=True).limit(2).execute().fetchall(), - [(1, 1), - (2, 7)] + [(1, 1), (2, 7)], ) # here, its impossible. But we'd prefer it to raise ORA-02014 @@ -268,69 +277,71 @@ class ExecuteTest(fixtures.TestBase): assert_raises_message( exc.DatabaseError, "ORA-02014", - t.select(for_update=True).limit(2).offset(3).execute + t.select(for_update=True).limit(2).offset(3).execute, ) class UnicodeSchemaTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @testing.provide_metadata def test_quoted_column_non_unicode(self): metadata = self.metadata - table = Table("atable", metadata, - Column("_underscorecolumn", - Unicode(255), - primary_key=True)) + table = Table( + "atable", + metadata, + Column("_underscorecolumn", Unicode(255), primary_key=True), + ) metadata.create_all() - table.insert().execute( - {'_underscorecolumn': u('’é')}, - ) + table.insert().execute({"_underscorecolumn": u("’é")}) result = testing.db.execute( - table.select().where(table.c._underscorecolumn == u('’é')) + table.select().where(table.c._underscorecolumn == u("’é")) ).scalar() - eq_(result, u('’é')) + eq_(result, u("’é")) @testing.provide_metadata def test_quoted_column_unicode(self): metadata = self.metadata - table = Table("atable", metadata, - Column(u("méil"), Unicode(255), primary_key=True)) + table = Table( + "atable", + metadata, + Column(u("méil"), Unicode(255), primary_key=True), + ) metadata.create_all() - table.insert().execute( - {u('méil'): u('’é')}, - ) + table.insert().execute({u("méil"): u("’é")}) result = testing.db.execute( - table.select().where(table.c[u('méil')] == u('’é')) + table.select().where(table.c[u("méil")] == u("’é")) ).scalar() - eq_(result, u('’é')) + eq_(result, u("’é")) class CXOracleConnectArgsTest(fixtures.TestBase): - __only_on__ = 'oracle+cx_oracle' + __only_on__ = "oracle+cx_oracle" __backend__ = True def test_cx_oracle_service_name(self): - url_string = 'oracle+cx_oracle://scott:tiger@host/?service_name=hr' + url_string = 
"oracle+cx_oracle://scott:tiger@host/?service_name=hr" eng = create_engine(url_string, _initialize=False) cargs, cparams = eng.dialect.create_connect_args(eng.url) - assert 'SERVICE_NAME=hr' in cparams['dsn'] - assert 'SID=hr' not in cparams['dsn'] + assert "SERVICE_NAME=hr" in cparams["dsn"] + assert "SID=hr" not in cparams["dsn"] def test_cx_oracle_service_name_bad(self): - url_string = 'oracle+cx_oracle://scott:tiger@host/hr1?service_name=hr2' + url_string = "oracle+cx_oracle://scott:tiger@host/hr1?service_name=hr2" assert_raises( exc.InvalidRequestError, - create_engine, url_string, - _initialize=False + create_engine, + url_string, + _initialize=False, ) def _test_db_opt(self, url_string, key, value): import cx_Oracle + url_obj = url.make_url(url_string) dialect = cx_oracle.dialect(dbapi=cx_Oracle) arg, kw = dialect.create_connect_args(url_obj) @@ -338,6 +349,7 @@ class CXOracleConnectArgsTest(fixtures.TestBase): def _test_db_opt_unpresent(self, url_string, key): import cx_Oracle + url_obj = url.make_url(url_string) dialect = cx_oracle.dialect(dbapi=cx_Oracle) arg, kw = dialect.create_connect_args(url_obj) @@ -345,10 +357,12 @@ class CXOracleConnectArgsTest(fixtures.TestBase): def _test_dialect_param_from_url(self, url_string, key, value): import cx_Oracle + url_obj = url.make_url(url_string) dialect = cx_oracle.dialect(dbapi=cx_Oracle) with testing.expect_deprecated( - "cx_oracle dialect option %r should" % key): + "cx_oracle dialect option %r should" % key + ): arg, kw = dialect.create_connect_args(url_obj) eq_(getattr(dialect, key), value) @@ -358,32 +372,32 @@ class CXOracleConnectArgsTest(fixtures.TestBase): def test_mode(self): import cx_Oracle + self._test_db_opt( - 'oracle+cx_oracle://scott:tiger@host/?mode=sYsDBA', + "oracle+cx_oracle://scott:tiger@host/?mode=sYsDBA", "mode", - cx_Oracle.SYSDBA + cx_Oracle.SYSDBA, ) self._test_db_opt( - 'oracle+cx_oracle://scott:tiger@host/?mode=SYSOPER', + "oracle+cx_oracle://scott:tiger@host/?mode=SYSOPER", "mode", - cx_Oracle.SYSOPER + cx_Oracle.SYSOPER, ) def test_int_mode(self): self._test_db_opt( - 'oracle+cx_oracle://scott:tiger@host/?mode=32767', - "mode", - 32767 + "oracle+cx_oracle://scott:tiger@host/?mode=32767", "mode", 32767 ) @testing.requires.cxoracle6_or_greater def test_purity(self): import cx_Oracle + self._test_db_opt( - 'oracle+cx_oracle://scott:tiger@host/?purity=attr_purity_new', + "oracle+cx_oracle://scott:tiger@host/?purity=attr_purity_new", "purity", - cx_Oracle.ATTR_PURITY_NEW + cx_Oracle.ATTR_PURITY_NEW, ) def test_encoding(self): @@ -391,44 +405,45 @@ class CXOracleConnectArgsTest(fixtures.TestBase): "oracle+cx_oracle://scott:tiger@host/" "?encoding=AMERICAN_AMERICA.UTF8", "encoding", - "AMERICAN_AMERICA.UTF8" + "AMERICAN_AMERICA.UTF8", ) def test_threaded(self): self._test_db_opt( - 'oracle+cx_oracle://scott:tiger@host/?threaded=true', + "oracle+cx_oracle://scott:tiger@host/?threaded=true", "threaded", - True + True, ) self._test_db_opt_unpresent( - 'oracle+cx_oracle://scott:tiger@host/', - "threaded" + "oracle+cx_oracle://scott:tiger@host/", "threaded" ) def test_events(self): self._test_db_opt( - 'oracle+cx_oracle://scott:tiger@host/?events=true', - "events", - True + "oracle+cx_oracle://scott:tiger@host/?events=true", "events", True ) def test_threaded_deprecated_at_dialect_level(self): with testing.expect_deprecated( - "The 'threaded' parameter to the cx_oracle dialect"): + "The 'threaded' parameter to the cx_oracle dialect" + ): dialect = cx_oracle.dialect(threaded=False) arg, kw = 
dialect.create_connect_args( - url.make_url("oracle+cx_oracle://scott:tiger@dsn")) - eq_(kw['threaded'], False) + url.make_url("oracle+cx_oracle://scott:tiger@dsn") + ) + eq_(kw["threaded"], False) def test_deprecated_use_ansi(self): self._test_dialect_param_from_url( - 'oracle+cx_oracle://scott:tiger@host/?use_ansi=False', - 'use_ansi', False + "oracle+cx_oracle://scott:tiger@host/?use_ansi=False", + "use_ansi", + False, ) def test_deprecated_auto_convert_lobs(self): self._test_dialect_param_from_url( - 'oracle+cx_oracle://scott:tiger@host/?auto_convert_lobs=False', - 'auto_convert_lobs', False - ) \ No newline at end of file + "oracle+cx_oracle://scott:tiger@host/?auto_convert_lobs=False", + "auto_convert_lobs", + False, + ) diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index f749e513ac..a88703ab0b 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -6,22 +6,61 @@ from sqlalchemy import exc from sqlalchemy.sql import table from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import testing -from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\ - Index, MetaData, select, inspect, ForeignKey, String, func, \ - TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, \ - literal_column, VARCHAR, create_engine, Date, NVARCHAR, \ - ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \ - union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \ - PrimaryKeyConstraint, FLOAT, INTEGER -from sqlalchemy.dialects.oracle.base import NUMBER, BINARY_DOUBLE, \ - BINARY_FLOAT, DOUBLE_PRECISION +from sqlalchemy import ( + Integer, + Text, + LargeBinary, + Unicode, + UniqueConstraint, + Index, + MetaData, + select, + inspect, + ForeignKey, + String, + func, + TypeDecorator, + bindparam, + Numeric, + TIMESTAMP, + CHAR, + text, + literal_column, + VARCHAR, + create_engine, + Date, + NVARCHAR, + ForeignKeyConstraint, + Sequence, + Float, + DateTime, + cast, + UnicodeText, + union, + except_, + type_coerce, + or_, + outerjoin, + DATE, + NCHAR, + outparam, + PrimaryKeyConstraint, + FLOAT, + INTEGER, +) +from sqlalchemy.dialects.oracle.base import ( + NUMBER, + BINARY_DOUBLE, + BINARY_FLOAT, + DOUBLE_PRECISION, +) from sqlalchemy.testing import assert_raises from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.schema import Table, Column class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @classmethod @@ -30,7 +69,8 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): # don't really know how else to go here unless # we connect as the other user. - for stmt in (""" + for stmt in ( + """ create table %(test_schema)s.parent( id integer primary key, data varchar2(50) @@ -60,13 +100,16 @@ create synonym %(test_schema)s.local_table for local_table; -- so we give it to public. ideas welcome. 
grant references on %(test_schema)s.parent to public; grant references on %(test_schema)s.child to public; -""" % {"test_schema": testing.config.test_schema}).split(";"): +""" + % {"test_schema": testing.config.test_schema} + ).split(";"): if stmt.strip(): testing.db.execute(stmt) @classmethod def teardown_class(cls): - for stmt in (""" + for stmt in ( + """ drop table %(test_schema)s.child; drop table %(test_schema)s.parent; drop table local_table; @@ -75,7 +118,9 @@ drop synonym %(test_schema)s.ptable; drop synonym %(test_schema)s_pt; drop synonym %(test_schema)s.local_table; -""" % {"test_schema": testing.config.test_schema}).split(";"): +""" + % {"test_schema": testing.config.test_schema} + ).split(";"): if stmt.strip(): testing.db.execute(stmt) @@ -83,192 +128,235 @@ drop synonym %(test_schema)s.local_table; def test_create_same_names_explicit_schema(self): schema = testing.db.dialect.default_schema_name meta = self.metadata - parent = Table('parent', meta, - Column('pid', Integer, primary_key=True), - schema=schema) - child = Table('child', meta, - Column('cid', Integer, primary_key=True), - Column('pid', - Integer, - ForeignKey('%s.parent.pid' % schema)), - schema=schema) + parent = Table( + "parent", + meta, + Column("pid", Integer, primary_key=True), + schema=schema, + ) + child = Table( + "child", + meta, + Column("cid", Integer, primary_key=True), + Column("pid", Integer, ForeignKey("%s.parent.pid" % schema)), + schema=schema, + ) meta.create_all() - parent.insert().execute({'pid': 1}) - child.insert().execute({'cid': 1, 'pid': 1}) + parent.insert().execute({"pid": 1}) + child.insert().execute({"cid": 1, "pid": 1}) eq_(child.select().execute().fetchall(), [(1, 1)]) def test_reflect_alt_table_owner_local_synonym(self): meta = MetaData(testing.db) - parent = Table('%s_pt' % testing.config.test_schema, - meta, - autoload=True, - oracle_resolve_synonyms=True) - self.assert_compile(parent.select(), - "SELECT %(test_schema)s_pt.id, " - "%(test_schema)s_pt.data FROM %(test_schema)s_pt" - % {"test_schema": testing.config.test_schema}) + parent = Table( + "%s_pt" % testing.config.test_schema, + meta, + autoload=True, + oracle_resolve_synonyms=True, + ) + self.assert_compile( + parent.select(), + "SELECT %(test_schema)s_pt.id, " + "%(test_schema)s_pt.data FROM %(test_schema)s_pt" + % {"test_schema": testing.config.test_schema}, + ) select([parent]).execute().fetchall() def test_reflect_alt_synonym_owner_local_table(self): meta = MetaData(testing.db) parent = Table( - 'local_table', meta, autoload=True, - oracle_resolve_synonyms=True, schema=testing.config.test_schema) + "local_table", + meta, + autoload=True, + oracle_resolve_synonyms=True, + schema=testing.config.test_schema, + ) self.assert_compile( parent.select(), "SELECT %(test_schema)s.local_table.id, " "%(test_schema)s.local_table.data " - "FROM %(test_schema)s.local_table" % - {"test_schema": testing.config.test_schema} + "FROM %(test_schema)s.local_table" + % {"test_schema": testing.config.test_schema}, ) select([parent]).execute().fetchall() @testing.provide_metadata def test_create_same_names_implicit_schema(self): meta = self.metadata - parent = Table('parent', - meta, - Column('pid', Integer, primary_key=True)) - child = Table('child', meta, - Column('cid', Integer, primary_key=True), - Column('pid', Integer, ForeignKey('parent.pid'))) + parent = Table( + "parent", meta, Column("pid", Integer, primary_key=True) + ) + child = Table( + "child", + meta, + Column("cid", Integer, primary_key=True), + Column("pid", Integer, 
ForeignKey("parent.pid")), + ) meta.create_all() - parent.insert().execute({'pid': 1}) - child.insert().execute({'cid': 1, 'pid': 1}) + parent.insert().execute({"pid": 1}) + child.insert().execute({"cid": 1, "pid": 1}) eq_(child.select().execute().fetchall(), [(1, 1)]) def test_reflect_alt_owner_explicit(self): meta = MetaData(testing.db) parent = Table( - 'parent', meta, autoload=True, - schema=testing.config.test_schema) + "parent", meta, autoload=True, schema=testing.config.test_schema + ) child = Table( - 'child', meta, autoload=True, - schema=testing.config.test_schema) + "child", meta, autoload=True, schema=testing.config.test_schema + ) self.assert_compile( parent.join(child), "%(test_schema)s.parent JOIN %(test_schema)s.child ON " - "%(test_schema)s.parent.id = %(test_schema)s.child.parent_id" % { - "test_schema": testing.config.test_schema - }) - select([parent, child]).\ - select_from(parent.join(child)).\ - execute().fetchall() + "%(test_schema)s.parent.id = %(test_schema)s.child.parent_id" + % {"test_schema": testing.config.test_schema}, + ) + select([parent, child]).select_from( + parent.join(child) + ).execute().fetchall() def test_reflect_local_to_remote(self): testing.db.execute( - 'CREATE TABLE localtable (id INTEGER ' - 'PRIMARY KEY, parent_id INTEGER REFERENCES ' - '%(test_schema)s.parent(id))' % { - "test_schema": testing.config.test_schema}) + "CREATE TABLE localtable (id INTEGER " + "PRIMARY KEY, parent_id INTEGER REFERENCES " + "%(test_schema)s.parent(id))" + % {"test_schema": testing.config.test_schema} + ) try: meta = MetaData(testing.db) - lcl = Table('localtable', meta, autoload=True) - parent = meta.tables['%s.parent' % testing.config.test_schema] - self.assert_compile(parent.join(lcl), - '%(test_schema)s.parent JOIN localtable ON ' - '%(test_schema)s.parent.id = ' - 'localtable.parent_id' % { - "test_schema": testing.config.test_schema} - ) - select([parent, - lcl]).select_from(parent.join(lcl)).execute().fetchall() + lcl = Table("localtable", meta, autoload=True) + parent = meta.tables["%s.parent" % testing.config.test_schema] + self.assert_compile( + parent.join(lcl), + "%(test_schema)s.parent JOIN localtable ON " + "%(test_schema)s.parent.id = " + "localtable.parent_id" + % {"test_schema": testing.config.test_schema}, + ) + select([parent, lcl]).select_from( + parent.join(lcl) + ).execute().fetchall() finally: - testing.db.execute('DROP TABLE localtable') + testing.db.execute("DROP TABLE localtable") def test_reflect_alt_owner_implicit(self): meta = MetaData(testing.db) parent = Table( - 'parent', meta, autoload=True, - schema=testing.config.test_schema) + "parent", meta, autoload=True, schema=testing.config.test_schema + ) child = Table( - 'child', meta, autoload=True, - schema=testing.config.test_schema) + "child", meta, autoload=True, schema=testing.config.test_schema + ) self.assert_compile( parent.join(child), - '%(test_schema)s.parent JOIN %(test_schema)s.child ' - 'ON %(test_schema)s.parent.id = ' - '%(test_schema)s.child.parent_id' % { - "test_schema": testing.config.test_schema}) - select([parent, - child]).select_from(parent.join(child)).execute().fetchall() + "%(test_schema)s.parent JOIN %(test_schema)s.child " + "ON %(test_schema)s.parent.id = " + "%(test_schema)s.child.parent_id" + % {"test_schema": testing.config.test_schema}, + ) + select([parent, child]).select_from( + parent.join(child) + ).execute().fetchall() def test_reflect_alt_owner_synonyms(self): - testing.db.execute('CREATE TABLE localtable (id INTEGER ' - 'PRIMARY KEY, parent_id 
INTEGER REFERENCES ' - '%s.ptable(id))' % testing.config.test_schema) + testing.db.execute( + "CREATE TABLE localtable (id INTEGER " + "PRIMARY KEY, parent_id INTEGER REFERENCES " + "%s.ptable(id))" % testing.config.test_schema + ) try: meta = MetaData(testing.db) - lcl = Table('localtable', meta, autoload=True, - oracle_resolve_synonyms=True) - parent = meta.tables['%s.ptable' % testing.config.test_schema] + lcl = Table( + "localtable", meta, autoload=True, oracle_resolve_synonyms=True + ) + parent = meta.tables["%s.ptable" % testing.config.test_schema] self.assert_compile( parent.join(lcl), - '%(test_schema)s.ptable JOIN localtable ON ' - '%(test_schema)s.ptable.id = ' - 'localtable.parent_id' % { - "test_schema": testing.config.test_schema}) - select([parent, - lcl]).select_from(parent.join(lcl)).execute().fetchall() + "%(test_schema)s.ptable JOIN localtable ON " + "%(test_schema)s.ptable.id = " + "localtable.parent_id" + % {"test_schema": testing.config.test_schema}, + ) + select([parent, lcl]).select_from( + parent.join(lcl) + ).execute().fetchall() finally: - testing.db.execute('DROP TABLE localtable') + testing.db.execute("DROP TABLE localtable") def test_reflect_remote_synonyms(self): meta = MetaData(testing.db) - parent = Table('ptable', meta, autoload=True, - schema=testing.config.test_schema, - oracle_resolve_synonyms=True) - child = Table('ctable', meta, autoload=True, - schema=testing.config.test_schema, - oracle_resolve_synonyms=True) + parent = Table( + "ptable", + meta, + autoload=True, + schema=testing.config.test_schema, + oracle_resolve_synonyms=True, + ) + child = Table( + "ctable", + meta, + autoload=True, + schema=testing.config.test_schema, + oracle_resolve_synonyms=True, + ) self.assert_compile( parent.join(child), - '%(test_schema)s.ptable JOIN ' - '%(test_schema)s.ctable ' - 'ON %(test_schema)s.ptable.id = ' - '%(test_schema)s.ctable.parent_id' % { - "test_schema": testing.config.test_schema}) - select([parent, - child]).select_from(parent.join(child)).execute().fetchall() + "%(test_schema)s.ptable JOIN " + "%(test_schema)s.ctable " + "ON %(test_schema)s.ptable.id = " + "%(test_schema)s.ctable.parent_id" + % {"test_schema": testing.config.test_schema}, + ) + select([parent, child]).select_from( + parent.join(child) + ).execute().fetchall() class ConstraintTest(fixtures.TablesTest): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True run_deletes = None @classmethod def define_tables(cls, metadata): - Table('foo', metadata, Column('id', Integer, primary_key=True)) + Table("foo", metadata, Column("id", Integer, primary_key=True)) def test_oracle_has_no_on_update_cascade(self): - bar = Table('bar', self.metadata, - Column('id', Integer, primary_key=True), - Column('foo_id', - Integer, - ForeignKey('foo.id', onupdate='CASCADE'))) + bar = Table( + "bar", + self.metadata, + Column("id", Integer, primary_key=True), + Column( + "foo_id", Integer, ForeignKey("foo.id", onupdate="CASCADE") + ), + ) assert_raises(exc.SAWarning, bar.create) - bat = Table('bat', self.metadata, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer), - ForeignKeyConstraint(['foo_id'], ['foo.id'], - onupdate='CASCADE')) + bat = Table( + "bat", + self.metadata, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer), + ForeignKeyConstraint(["foo_id"], ["foo.id"], onupdate="CASCADE"), + ) assert_raises(exc.SAWarning, bat.create) def test_reflect_check_include_all(self): insp = inspect(testing.db) - eq_(insp.get_check_constraints('foo'), []) + 
eq_(insp.get_check_constraints("foo"), []) eq_( - [rec['sqltext'] - for rec in insp.get_check_constraints('foo', include_all=True)], - ['"ID" IS NOT NULL']) + [ + rec["sqltext"] + for rec in insp.get_check_constraints("foo", include_all=True) + ], + ['"ID" IS NOT NULL'], + ) class SystemTableTablenamesTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True def setup(self): @@ -287,23 +375,20 @@ class SystemTableTablenamesTest(fixtures.TestBase): def test_table_names_no_system(self): insp = inspect(testing.db) - eq_( - insp.get_table_names(), ["my_table"] - ) + eq_(insp.get_table_names(), ["my_table"]) def test_temp_table_names_no_system(self): insp = inspect(testing.db) - eq_( - insp.get_temp_table_names(), ["my_temp_table"] - ) + eq_(insp.get_temp_table_names(), ["my_temp_table"]) def test_table_names_w_system(self): engine = testing_engine(options={"exclude_tablespaces": ["FOO"]}) insp = inspect(engine) eq_( - set(insp.get_table_names()).intersection(["my_table", - "foo_table"]), - set(["my_table", "foo_table"]) + set(insp.get_table_names()).intersection( + ["my_table", "foo_table"] + ), + set(["my_table", "foo_table"]), ) @@ -311,11 +396,12 @@ class DontReflectIOTTest(fixtures.TestBase): """test that index overflow tables aren't included in table_names.""" - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True def setup(self): - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE admin_docindex( token char(20), doc_id NUMBER, @@ -326,7 +412,8 @@ class DontReflectIOTTest(fixtures.TestBase): TABLESPACE users PCTTHRESHOLD 20 OVERFLOW TABLESPACE users - """) + """ + ) def teardown(self): testing.db.execute("drop table admin_docindex") @@ -334,35 +421,37 @@ class DontReflectIOTTest(fixtures.TestBase): def test_reflect_all(self): m = MetaData(testing.db) m.reflect() - eq_( - set(t.name for t in m.tables.values()), - set(['admin_docindex']) - ) + eq_(set(t.name for t in m.tables.values()), set(["admin_docindex"])) class UnsupportedIndexReflectTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @testing.emits_warning("No column names") @testing.provide_metadata def test_reflect_functional_index(self): metadata = self.metadata - Table('test_index_reflect', metadata, - Column('data', String(20), primary_key=True)) + Table( + "test_index_reflect", + metadata, + Column("data", String(20), primary_key=True), + ) metadata.create_all() - testing.db.execute('CREATE INDEX DATA_IDX ON ' - 'TEST_INDEX_REFLECT (UPPER(DATA))') + testing.db.execute( + "CREATE INDEX DATA_IDX ON " "TEST_INDEX_REFLECT (UPPER(DATA))" + ) m2 = MetaData(testing.db) - Table('test_index_reflect', m2, autoload=True) + Table("test_index_reflect", m2, autoload=True) def all_tables_compression_missing(): try: - testing.db.execute('SELECT compression FROM all_tables') + testing.db.execute("SELECT compression FROM all_tables") if "Enterprise Edition" not in testing.db.scalar( - "select * from v$version"): + "select * from v$version" + ): return True return False except Exception: @@ -371,9 +460,10 @@ def all_tables_compression_missing(): def all_tables_compress_for_missing(): try: - testing.db.execute('SELECT compress_for FROM all_tables') + testing.db.execute("SELECT compress_for FROM all_tables") if "Enterprise Edition" not in testing.db.scalar( - "select * from v$version"): + "select * from v$version" + ): return True return False except Exception: @@ -381,7 +471,7 @@ def all_tables_compress_for_missing(): class 
TableReflectionTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @testing.provide_metadata @@ -389,35 +479,41 @@ class TableReflectionTest(fixtures.TestBase): def test_reflect_basic_compression(self): metadata = self.metadata - tbl = Table('test_compress', metadata, - Column('data', Integer, primary_key=True), - oracle_compress=True) + tbl = Table( + "test_compress", + metadata, + Column("data", Integer, primary_key=True), + oracle_compress=True, + ) metadata.create_all() m2 = MetaData(testing.db) - tbl = Table('test_compress', m2, autoload=True) + tbl = Table("test_compress", m2, autoload=True) # Don't hardcode the exact value, but it must be non-empty - assert tbl.dialect_options['oracle']['compress'] + assert tbl.dialect_options["oracle"]["compress"] @testing.provide_metadata @testing.fails_if(all_tables_compress_for_missing) def test_reflect_oltp_compression(self): metadata = self.metadata - tbl = Table('test_compress', metadata, - Column('data', Integer, primary_key=True), - oracle_compress="OLTP") + tbl = Table( + "test_compress", + metadata, + Column("data", Integer, primary_key=True), + oracle_compress="OLTP", + ) metadata.create_all() m2 = MetaData(testing.db) - tbl = Table('test_compress', m2, autoload=True) - assert tbl.dialect_options['oracle']['compress'] == "OLTP" + tbl = Table("test_compress", m2, autoload=True) + assert tbl.dialect_options["oracle"]["compress"] == "OLTP" class RoundTripIndexTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @testing.provide_metadata @@ -425,22 +521,27 @@ class RoundTripIndexTest(fixtures.TestBase): metadata = self.metadata s_table = Table( - "sometable", metadata, + "sometable", + metadata, Column("id_a", Unicode(255), primary_key=True), - Column("id_b", - Unicode(255), - primary_key=True, - unique=True), + Column("id_b", Unicode(255), primary_key=True, unique=True), Column("group", Unicode(255), primary_key=True), Column("col", Unicode(255)), - UniqueConstraint('col', 'group')) + UniqueConstraint("col", "group"), + ) # "group" is a keyword, so lower case - normalind = Index('tableind', s_table.c.id_b, s_table.c.group) - Index('compress1', s_table.c.id_a, s_table.c.id_b, - oracle_compress=True) - Index('compress2', s_table.c.id_a, s_table.c.id_b, s_table.c.col, - oracle_compress=1) + normalind = Index("tableind", s_table.c.id_b, s_table.c.group) + Index( + "compress1", s_table.c.id_a, s_table.c.id_b, oracle_compress=True + ) + Index( + "compress2", + s_table.c.id_a, + s_table.c.id_b, + s_table.c.col, + oracle_compress=1, + ) metadata.create_all() mirror = MetaData(testing.db) @@ -452,9 +553,11 @@ class RoundTripIndexTest(fixtures.TestBase): inspect.reflect() def obj_definition(obj): - return (obj.__class__, - tuple([c.name for c in obj.columns]), - getattr(obj, 'unique', None)) + return ( + obj.__class__, + tuple([c.name for c in obj.columns]), + getattr(obj, "unique", None), + ) # find what the primary k constraint name should be primaryconsname = testing.db.scalar( @@ -463,62 +566,72 @@ class RoundTripIndexTest(fixtures.TestBase): FROM all_constraints WHERE table_name = :table_name AND owner = :owner - AND constraint_type = 'P' """), + AND constraint_type = 'P' """ + ), table_name=s_table.name.upper(), - owner=testing.db.dialect.default_schema_name.upper()) + owner=testing.db.dialect.default_schema_name.upper(), + ) reflectedtable = inspect.tables[s_table.name] # make a dictionary of the reflected objects: - reflected = dict([(obj_definition(i), i) for 
i in - reflectedtable.indexes - | reflectedtable.constraints]) + reflected = dict( + [ + (obj_definition(i), i) + for i in reflectedtable.indexes | reflectedtable.constraints + ] + ) # assert we got primary key constraint and its name, Error # if not in dict - assert reflected[(PrimaryKeyConstraint, ('id_a', 'id_b', - 'group'), None)].name.upper() \ + assert ( + reflected[ + (PrimaryKeyConstraint, ("id_a", "id_b", "group"), None) + ].name.upper() == primaryconsname.upper() + ) # Error if not in dict - eq_( - reflected[(Index, ('id_b', 'group'), False)].name, - normalind.name - ) - assert (Index, ('id_b', ), True) in reflected - assert (Index, ('col', 'group'), True) in reflected + eq_(reflected[(Index, ("id_b", "group"), False)].name, normalind.name) + assert (Index, ("id_b",), True) in reflected + assert (Index, ("col", "group"), True) in reflected - idx = reflected[(Index, ('id_a', 'id_b', ), False)] - assert idx.dialect_options['oracle']['compress'] == 2 + idx = reflected[(Index, ("id_a", "id_b"), False)] + assert idx.dialect_options["oracle"]["compress"] == 2 - idx = reflected[(Index, ('id_a', 'id_b', 'col', ), False)] - assert idx.dialect_options['oracle']['compress'] == 1 + idx = reflected[(Index, ("id_a", "id_b", "col"), False)] + assert idx.dialect_options["oracle"]["compress"] == 1 eq_(len(reflectedtable.constraints), 1) eq_(len(reflectedtable.indexes), 5) class DBLinkReflectionTest(fixtures.TestBase): - __requires__ = 'oracle_test_dblink', - __only_on__ = 'oracle' + __requires__ = ("oracle_test_dblink",) + __only_on__ = "oracle" __backend__ = True @classmethod def setup_class(cls): from sqlalchemy.testing import config - cls.dblink = config.file_config.get('sqla_testing', 'oracle_db_link') + + cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link") # note that the synonym here is still not totally functional # when accessing via a different username as we do with the # multiprocess test suite, so testing here is minimal with testing.db.connect() as conn: - conn.execute("create table test_table " - "(id integer primary key, data varchar2(50))") - conn.execute("create synonym test_table_syn " - "for test_table@%s" % cls.dblink) + conn.execute( + "create table test_table " + "(id integer primary key, data varchar2(50))" + ) + conn.execute( + "create synonym test_table_syn " + "for test_table@%s" % cls.dblink + ) @classmethod def teardown_class(cls): @@ -530,24 +643,29 @@ class DBLinkReflectionTest(fixtures.TestBase): """test the resolution of the synonym/dblink. 
""" m = MetaData() - t = Table('test_table_syn', m, autoload=True, - autoload_with=testing.db, oracle_resolve_synonyms=True) - eq_(list(t.c.keys()), ['id', 'data']) + t = Table( + "test_table_syn", + m, + autoload=True, + autoload_with=testing.db, + oracle_resolve_synonyms=True, + ) + eq_(list(t.c.keys()), ["id", "data"]) eq_(list(t.primary_key), [t.c.id]) class TypeReflectionTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True @testing.provide_metadata def _run_test(self, specs, attributes): - columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)] + columns = [Column("c%i" % (i + 1), t[0]) for i, t in enumerate(specs)] m = self.metadata - Table('oracle_types', m, *columns) + Table("oracle_types", m, *columns) m.create_all() m2 = MetaData(testing.db) - table = Table('oracle_types', m2, autoload=True) + table = Table("oracle_types", m2, autoload=True) for i, (reflected_col, spec) in enumerate(zip(table.c, specs)): expected_spec = spec[1] reflected_type = reflected_col.type @@ -557,28 +675,23 @@ class TypeReflectionTest(fixtures.TestBase): getattr(reflected_type, attr), getattr(expected_spec, attr), "Column %s: Attribute %s value of %s does not " - "match %s for type %s" % ( + "match %s for type %s" + % ( "c%i" % (i + 1), attr, getattr(reflected_type, attr), getattr(expected_spec, attr), - spec[0] - ) + spec[0], + ), ) def test_integer_types(self): - specs = [ - (Integer, INTEGER(),), - (Numeric, INTEGER(),), - ] + specs = [(Integer, INTEGER()), (Numeric, INTEGER())] self._run_test(specs, []) def test_number_types(self): - specs = [ - (Numeric(5, 2), NUMBER(5, 2),), - (NUMBER, NUMBER(),), - ] - self._run_test(specs, ['precision', 'scale']) + specs = [(Numeric(5, 2), NUMBER(5, 2)), (NUMBER, NUMBER())] + self._run_test(specs, ["precision", "scale"]) def test_float_types(self): specs = [ @@ -587,11 +700,11 @@ class TypeReflectionTest(fixtures.TestBase): # (DOUBLE_PRECISION(), oracle.FLOAT(binary_precision=126)), (BINARY_DOUBLE(), BINARY_DOUBLE()), (BINARY_FLOAT(), BINARY_FLOAT()), - (FLOAT(5), FLOAT(),), + (FLOAT(5), FLOAT()), # when binary_precision is supported # (FLOAT(5), oracle.FLOAT(binary_precision=5),), (FLOAT(), FLOAT()), # when binary_precision is supported # (FLOAT(5), oracle.FLOAT(binary_precision=126),), ] - self._run_test(specs, ['precision']) + self._run_test(specs, ["precision"]) diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 9fbea61303..6d93d6501d 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -4,17 +4,54 @@ from sqlalchemy.testing import eq_ from sqlalchemy import types as sqltypes, exc, schema from sqlalchemy.sql import table, column -from sqlalchemy.testing import (fixtures, - AssertsExecutionResults, - AssertsCompiledSQL) +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy import testing -from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\ - Index, MetaData, select, inspect, ForeignKey, String, func, \ - TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, SmallInteger, \ - literal_column, VARCHAR, create_engine, Date, NVARCHAR, \ - ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \ - union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \ - PrimaryKeyConstraint, FLOAT +from sqlalchemy import ( + Integer, + Text, + LargeBinary, + Unicode, + UniqueConstraint, + Index, + MetaData, + select, + inspect, + 
ForeignKey, + String, + func, + TypeDecorator, + bindparam, + Numeric, + TIMESTAMP, + CHAR, + text, + SmallInteger, + literal_column, + VARCHAR, + create_engine, + Date, + NVARCHAR, + ForeignKeyConstraint, + Sequence, + Float, + DateTime, + cast, + UnicodeText, + union, + except_, + type_coerce, + or_, + outerjoin, + DATE, + NCHAR, + outparam, + PrimaryKeyConstraint, + FLOAT, +) from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.util import u, b from sqlalchemy import util @@ -50,16 +87,10 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): dbapi = FakeDBAPI() b = bindparam("foo", "hello world!") - eq_( - b.type.dialect_impl(dialect).get_dbapi_type(dbapi), - 'STRING' - ) + eq_(b.type.dialect_impl(dialect).get_dbapi_type(dbapi), "STRING") b = bindparam("foo", "hello world!") - eq_( - b.type.dialect_impl(dialect).get_dbapi_type(dbapi), - 'STRING' - ) + eq_(b.type.dialect_impl(dialect).get_dbapi_type(dbapi), "STRING") def test_long(self): self.assert_compile(oracle.LONG(), "LONG") @@ -83,8 +114,8 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): (oracle.RAW(50), cx_oracle._OracleRaw), ]: assert isinstance( - start.dialect_impl(dialect), test), \ - "wanted %r got %r" % (test, start.dialect_impl(dialect)) + start.dialect_impl(dialect), test + ), "wanted %r got %r" % (test, start.dialect_impl(dialect)) def test_type_adapt_nchar(self): dialect = cx_oracle.dialect(use_nchar_for_unicode=True) @@ -100,8 +131,8 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR), ]: assert isinstance( - start.dialect_impl(dialect), test), \ - "wanted %r got %r" % (test, start.dialect_impl(dialect)) + start.dialect_impl(dialect), test + ), "wanted %r got %r" % (test, start.dialect_impl(dialect)) def test_raw_compile(self): self.assert_compile(oracle.RAW(), "RAW") @@ -154,49 +185,52 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(typ, exp, dialect=dialect) def test_interval(self): - for type_, expected in [(oracle.INTERVAL(), - 'INTERVAL DAY TO SECOND'), - (oracle.INTERVAL(day_precision=3), - 'INTERVAL DAY(3) TO SECOND'), - (oracle.INTERVAL(second_precision=5), - 'INTERVAL DAY TO SECOND(5)'), - (oracle.INTERVAL(day_precision=2, - second_precision=5), - 'INTERVAL DAY(2) TO SECOND(5)')]: + for type_, expected in [ + (oracle.INTERVAL(), "INTERVAL DAY TO SECOND"), + (oracle.INTERVAL(day_precision=3), "INTERVAL DAY(3) TO SECOND"), + (oracle.INTERVAL(second_precision=5), "INTERVAL DAY TO SECOND(5)"), + ( + oracle.INTERVAL(day_precision=2, second_precision=5), + "INTERVAL DAY(2) TO SECOND(5)", + ), + ]: self.assert_compile(type_, expected) class TypesTest(fixtures.TestBase): - __only_on__ = 'oracle' + __only_on__ = "oracle" __dialect__ = oracle.OracleDialect() __backend__ = True - @testing.fails_on('+zxjdbc', 'zxjdbc lacks the FIXED_CHAR dbapi type') + @testing.fails_on("+zxjdbc", "zxjdbc lacks the FIXED_CHAR dbapi type") def test_fixed_char(self): m = MetaData(testing.db) - t = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('data', CHAR(30), nullable=False)) + t = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("data", CHAR(30), nullable=False), + ) t.create() try: t.insert().execute( dict(id=1, data="value 1"), dict(id=2, data="value 2"), - dict(id=3, data="value 3") + dict(id=3, data="value 3"), ) eq_( - t.select().where(t.c.data == 'value 2').execute().fetchall(), - [(2, 'value 2 ')] + t.select().where(t.c.data == "value 
2").execute().fetchall(), + [(2, "value 2 ")], ) m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True) + t2 = Table("t1", m2, autoload=True) assert type(t2.c.data.type) is CHAR eq_( - t2.select().where(t2.c.data == 'value 2').execute().fetchall(), - [(2, 'value 2 ')] + t2.select().where(t2.c.data == "value 2").execute().fetchall(), + [(2, "value 2 ")], ) finally: @@ -206,7 +240,7 @@ class TypesTest(fixtures.TestBase): @testing.provide_metadata def test_int_not_float(self): m = self.metadata - t1 = Table('t1', m, Column('foo', Integer)) + t1 = Table("t1", m, Column("foo", Integer)) t1.create() r = t1.insert().values(foo=5).returning(t1.c.foo).execute() x = r.scalar() @@ -223,7 +257,7 @@ class TypesTest(fixtures.TestBase): engine = testing_engine(options=dict(coerce_to_decimal=False)) m = self.metadata - t1 = Table('t1', m, Column('foo', Integer)) + t1 = Table("t1", m, Column("foo", Integer)) t1.create() r = engine.execute(t1.insert().values(foo=5).returning(t1.c.foo)) x = r.scalar() @@ -237,51 +271,58 @@ class TypesTest(fixtures.TestBase): @testing.provide_metadata def test_rowid(self): metadata = self.metadata - t = Table('t1', metadata, Column('x', Integer)) + t = Table("t1", metadata, Column("x", Integer)) t.create() t.insert().execute(x=5) s1 = select([t]) - s2 = select([column('rowid')]).select_from(s1) + s2 = select([column("rowid")]).select_from(s1) rowid = s2.scalar() # the ROWID type is not really needed here, # as cx_oracle just treats it as a string, # but we want to make sure the ROWID works... - rowid_col = column('rowid', oracle.ROWID) - s3 = select([t.c.x, rowid_col]) \ - .where(rowid_col == cast(rowid, oracle.ROWID)) + rowid_col = column("rowid", oracle.ROWID) + s3 = select([t.c.x, rowid_col]).where( + rowid_col == cast(rowid, oracle.ROWID) + ) eq_(s3.select().execute().fetchall(), [(5, rowid)]) - @testing.fails_on('+zxjdbc', - 'Not yet known how to pass values of the ' - 'INTERVAL type') + @testing.fails_on( + "+zxjdbc", "Not yet known how to pass values of the " "INTERVAL type" + ) @testing.provide_metadata def test_interval(self): metadata = self.metadata - interval_table = Table('intervaltable', metadata, Column('id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('day_interval', - oracle.INTERVAL(day_precision=3))) + interval_table = Table( + "intervaltable", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("day_interval", oracle.INTERVAL(day_precision=3)), + ) metadata.create_all() - interval_table.insert().\ - execute(day_interval=datetime.timedelta(days=35, seconds=5743)) + interval_table.insert().execute( + day_interval=datetime.timedelta(days=35, seconds=5743) + ) row = interval_table.select().execute().first() - eq_(row['day_interval'], datetime.timedelta(days=35, - seconds=5743)) + eq_(row["day_interval"], datetime.timedelta(days=35, seconds=5743)) @testing.provide_metadata def test_numerics(self): m = self.metadata - t1 = Table('t1', m, - Column('intcol', Integer), - Column('numericcol', Numeric(precision=9, scale=2)), - Column('floatcol1', Float()), - Column('floatcol2', FLOAT()), - Column('doubleprec', oracle.DOUBLE_PRECISION), - Column('numbercol1', oracle.NUMBER(9)), - Column('numbercol2', oracle.NUMBER(9, 3)), - Column('numbercol3', oracle.NUMBER)) + t1 = Table( + "t1", + m, + Column("intcol", Integer), + Column("numericcol", Numeric(precision=9, scale=2)), + Column("floatcol1", Float()), + Column("floatcol2", FLOAT()), + Column("doubleprec", oracle.DOUBLE_PRECISION), + 
Column("numbercol1", oracle.NUMBER(9)), + Column("numbercol2", oracle.NUMBER(9, 3)), + Column("numbercol3", oracle.NUMBER), + ) t1.create() t1.insert().execute( intcol=1, @@ -291,115 +332,123 @@ class TypesTest(fixtures.TestBase): doubleprec=9.5, numbercol1=12, numbercol2=14.85, - numbercol3=15.76 - ) + numbercol3=15.76, + ) m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True) + t2 = Table("t1", m2, autoload=True) for row in ( t1.select().execute().first(), - t2.select().execute().first() + t2.select().execute().first(), ): - for i, (val, type_) in enumerate(( - (1, int), - (decimal.Decimal("5.2"), decimal.Decimal), - (6.5, float), - (8.5, float), - (9.5, float), - (12, int), - (decimal.Decimal("14.85"), decimal.Decimal), - (15.76, float), - )): + for i, (val, type_) in enumerate( + ( + (1, int), + (decimal.Decimal("5.2"), decimal.Decimal), + (6.5, float), + (8.5, float), + (9.5, float), + (12, int), + (decimal.Decimal("14.85"), decimal.Decimal), + (15.76, float), + ) + ): eq_(row[i], val) - assert isinstance(row[i], type_), '%r is not %r' \ - % (row[i], type_) + assert isinstance(row[i], type_), "%r is not %r" % ( + row[i], + type_, + ) @testing.provide_metadata def test_numeric_infinity_float(self): m = self.metadata - t1 = Table('t1', m, - Column("intcol", Integer), - Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False))) + t1 = Table( + "t1", + m, + Column("intcol", Integer), + Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)), + ) t1.create() - t1.insert().execute([ - dict( - intcol=1, - numericcol=float("inf") - ), - dict( - intcol=2, - numericcol=float("-inf") - ), - ]) + t1.insert().execute( + [ + dict(intcol=1, numericcol=float("inf")), + dict(intcol=2, numericcol=float("-inf")), + ] + ) eq_( - select([t1.c.numericcol]). - order_by(t1.c.intcol).execute().fetchall(), - [(float('inf'), ), (float('-inf'), )] + select([t1.c.numericcol]) + .order_by(t1.c.intcol) + .execute() + .fetchall(), + [(float("inf"),), (float("-inf"),)], ) eq_( testing.db.execute( - "select numericcol from t1 order by intcol").fetchall(), - [(float('inf'), ), (float('-inf'), )] + "select numericcol from t1 order by intcol" + ).fetchall(), + [(float("inf"),), (float("-inf"),)], ) @testing.provide_metadata def test_numeric_infinity_decimal(self): m = self.metadata - t1 = Table('t1', m, - Column("intcol", Integer), - Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True))) + t1 = Table( + "t1", + m, + Column("intcol", Integer), + Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True)), + ) t1.create() - t1.insert().execute([ - dict( - intcol=1, - numericcol=decimal.Decimal("Infinity") - ), - dict( - intcol=2, - numericcol=decimal.Decimal("-Infinity") - ), - ]) + t1.insert().execute( + [ + dict(intcol=1, numericcol=decimal.Decimal("Infinity")), + dict(intcol=2, numericcol=decimal.Decimal("-Infinity")), + ] + ) eq_( - select([t1.c.numericcol]). 
- order_by(t1.c.intcol).execute().fetchall(), - [(decimal.Decimal("Infinity"), ), (decimal.Decimal("-Infinity"), )] + select([t1.c.numericcol]) + .order_by(t1.c.intcol) + .execute() + .fetchall(), + [(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)], ) eq_( testing.db.execute( - "select numericcol from t1 order by intcol").fetchall(), - [(decimal.Decimal("Infinity"), ), (decimal.Decimal("-Infinity"), )] + "select numericcol from t1 order by intcol" + ).fetchall(), + [(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)], ) @testing.provide_metadata def test_numeric_nan_float(self): m = self.metadata - t1 = Table('t1', m, - Column("intcol", Integer), - Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False))) + t1 = Table( + "t1", + m, + Column("intcol", Integer), + Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)), + ) t1.create() - t1.insert().execute([ - dict( - intcol=1, - numericcol=float("nan") - ), - dict( - intcol=2, - numericcol=float("-nan") - ), - ]) + t1.insert().execute( + [ + dict(intcol=1, numericcol=float("nan")), + dict(intcol=2, numericcol=float("-nan")), + ] + ) eq_( [ tuple(str(col) for col in row) - for row in select([t1.c.numericcol]). - order_by(t1.c.intcol).execute() + for row in select([t1.c.numericcol]) + .order_by(t1.c.intcol) + .execute() ], - [('nan', ), ('nan', )] + [("nan",), ("nan",)], ) eq_( @@ -409,39 +458,40 @@ class TypesTest(fixtures.TestBase): "select numericcol from t1 order by intcol" ) ], - [('nan', ), ('nan', )] - + [("nan",), ("nan",)], ) # needs https://github.com/oracle/python-cx_Oracle/issues/184#issuecomment-391399292 @testing.provide_metadata def _dont_test_numeric_nan_decimal(self): m = self.metadata - t1 = Table('t1', m, - Column("intcol", Integer), - Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True))) + t1 = Table( + "t1", + m, + Column("intcol", Integer), + Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True)), + ) t1.create() - t1.insert().execute([ - dict( - intcol=1, - numericcol=decimal.Decimal("NaN") - ), - dict( - intcol=2, - numericcol=decimal.Decimal("-NaN") - ), - ]) + t1.insert().execute( + [ + dict(intcol=1, numericcol=decimal.Decimal("NaN")), + dict(intcol=2, numericcol=decimal.Decimal("-NaN")), + ] + ) eq_( - select([t1.c.numericcol]). 
- order_by(t1.c.intcol).execute().fetchall(), - [(decimal.Decimal("NaN"), ), (decimal.Decimal("NaN"), )] + select([t1.c.numericcol]) + .order_by(t1.c.intcol) + .execute() + .fetchall(), + [(decimal.Decimal("NaN"),), (decimal.Decimal("NaN"),)], ) eq_( testing.db.execute( - "select numericcol from t1 order by intcol").fetchall(), - [(decimal.Decimal("NaN"), ), (decimal.Decimal("NaN"), )] + "select numericcol from t1 order by intcol" + ).fetchall(), + [(decimal.Decimal("NaN"),), (decimal.Decimal("NaN"),)], ) @testing.provide_metadata @@ -456,33 +506,43 @@ class TypesTest(fixtures.TestBase): # this test requires cx_oracle 5 - foo = Table('foo', metadata, - Column('idata', Integer), - Column('ndata', Numeric(20, 2)), - Column('ndata2', Numeric(20, 2)), - Column('nidata', Numeric(5, 0)), - Column('fdata', Float())) + foo = Table( + "foo", + metadata, + Column("idata", Integer), + Column("ndata", Numeric(20, 2)), + Column("ndata2", Numeric(20, 2)), + Column("nidata", Numeric(5, 0)), + Column("fdata", Float()), + ) foo.create() - foo.insert().execute({ - 'idata': 5, - 'ndata': decimal.Decimal("45.6"), - 'ndata2': decimal.Decimal("45.0"), - 'nidata': decimal.Decimal('53'), - 'fdata': 45.68392 - }) + foo.insert().execute( + { + "idata": 5, + "ndata": decimal.Decimal("45.6"), + "ndata2": decimal.Decimal("45.0"), + "nidata": decimal.Decimal("53"), + "fdata": 45.68392, + } + ) stmt = "SELECT idata, ndata, ndata2, nidata, fdata FROM foo" row = testing.db.execute(stmt).fetchall()[0] eq_( [type(x) for x in row], - [int, decimal.Decimal, decimal.Decimal, int, float] + [int, decimal.Decimal, decimal.Decimal, int, float], ) eq_( row, - (5, decimal.Decimal('45.6'), decimal.Decimal('45'), - 53, 45.683920000000001) + ( + 5, + decimal.Decimal("45.6"), + decimal.Decimal("45"), + 53, + 45.683920000000001, + ), ) # with a nested subquery, @@ -508,28 +568,38 @@ class TypesTest(fixtures.TestBase): row = testing.db.execute(stmt).fetchall()[0] eq_( [type(x) for x in row], - [int, decimal.Decimal, int, int, decimal.Decimal] + [int, decimal.Decimal, int, int, decimal.Decimal], ) eq_( row, - (5, decimal.Decimal('45.6'), 45, 53, decimal.Decimal('45.68392')) + (5, decimal.Decimal("45.6"), 45, 53, decimal.Decimal("45.68392")), ) - row = testing.db.execute(text(stmt, - typemap={ - 'idata': Integer(), - 'ndata': Numeric(20, 2), - 'ndata2': Numeric(20, 2), - 'nidata': Numeric(5, 0), - 'fdata': Float()})).fetchall()[0] + row = testing.db.execute( + text( + stmt, + typemap={ + "idata": Integer(), + "ndata": Numeric(20, 2), + "ndata2": Numeric(20, 2), + "nidata": Numeric(5, 0), + "fdata": Float(), + }, + ) + ).fetchall()[0] eq_( [type(x) for x in row], - [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float] + [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float], ) eq_( row, - (5, decimal.Decimal('45.6'), decimal.Decimal('45'), - decimal.Decimal('53'), 45.683920000000001) + ( + 5, + decimal.Decimal("45.6"), + decimal.Decimal("45"), + decimal.Decimal("53"), + 45.683920000000001, + ), ) stmt = """ @@ -558,48 +628,56 @@ class TypesTest(fixtures.TestBase): row = testing.db.execute(stmt).fetchall()[0] eq_( [type(x) for x in row], - [int, decimal.Decimal, int, int, decimal.Decimal] + [int, decimal.Decimal, int, int, decimal.Decimal], ) eq_( row, - (5, decimal.Decimal('45.6'), 45, 53, decimal.Decimal('45.68392')) + (5, decimal.Decimal("45.6"), 45, 53, decimal.Decimal("45.68392")), ) - row = testing.db.execute(text(stmt, - typemap={ - 'anon_1_idata': Integer(), - 'anon_1_ndata': Numeric(20, 2), - 'anon_1_ndata2': 
Numeric(20, 2), - 'anon_1_nidata': Numeric(5, 0), - 'anon_1_fdata': Float() - })).fetchall()[0] + row = testing.db.execute( + text( + stmt, + typemap={ + "anon_1_idata": Integer(), + "anon_1_ndata": Numeric(20, 2), + "anon_1_ndata2": Numeric(20, 2), + "anon_1_nidata": Numeric(5, 0), + "anon_1_fdata": Float(), + }, + ) + ).fetchall()[0] eq_( [type(x) for x in row], - [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float] + [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float], ) eq_( row, - (5, decimal.Decimal('45.6'), decimal.Decimal('45'), - decimal.Decimal('53'), 45.683920000000001) + ( + 5, + decimal.Decimal("45.6"), + decimal.Decimal("45"), + decimal.Decimal("53"), + 45.683920000000001, + ), ) - row = testing.db.execute(text( - stmt, - typemap={ - 'anon_1_idata': Integer(), - 'anon_1_ndata': Numeric(20, 2, asdecimal=False), - 'anon_1_ndata2': Numeric(20, 2, asdecimal=False), - 'anon_1_nidata': Numeric(5, 0, asdecimal=False), - 'anon_1_fdata': Float(asdecimal=True) - })).fetchall()[0] - eq_( - [type(x) for x in row], - [int, float, float, float, decimal.Decimal] - ) + row = testing.db.execute( + text( + stmt, + typemap={ + "anon_1_idata": Integer(), + "anon_1_ndata": Numeric(20, 2, asdecimal=False), + "anon_1_ndata2": Numeric(20, 2, asdecimal=False), + "anon_1_nidata": Numeric(5, 0, asdecimal=False), + "anon_1_fdata": Float(asdecimal=True), + }, + ) + ).fetchall()[0] eq_( - row, - (5, 45.6, 45, 53, decimal.Decimal('45.68392')) + [type(x) for x in row], [int, float, float, float, decimal.Decimal] ) + eq_(row, (5, 45.6, 45, 53, decimal.Decimal("45.68392"))) def test_numeric_no_coerce_decimal_mode(self): engine = testing_engine(options=dict(coerce_to_decimal=False)) @@ -611,8 +689,10 @@ class TypesTest(fixtures.TestBase): # explicit typing still *does* coerce to decimal # (change in 1.2) value = engine.scalar( - text("SELECT 5.66 AS foo FROM DUAL"). 
- columns(foo=Numeric(4, 2, asdecimal=True))) + text("SELECT 5.66 AS foo FROM DUAL").columns( + foo=Numeric(4, 2, asdecimal=True) + ) + ) assert isinstance(value, decimal.Decimal) # default behavior is raw SQL coerces to decimal @@ -621,8 +701,8 @@ class TypesTest(fixtures.TestBase): @testing.only_on("oracle+cx_oracle", "cx_oracle-specific feature") @testing.fails_if( - testing.requires.python3, - "cx_oracle always returns unicode on py3k") + testing.requires.python3, "cx_oracle always returns unicode on py3k" + ) def test_coerce_to_unicode(self): engine = testing_engine(options=dict(coerce_to_unicode=False)) value = engine.scalar("SELECT 'hello' FROM DUAL") @@ -635,18 +715,17 @@ class TypesTest(fixtures.TestBase): def test_reflect_dates(self): metadata = self.metadata Table( - "date_types", metadata, - Column('d1', sqltypes.DATE), - Column('d2', oracle.DATE), - Column('d3', TIMESTAMP), - Column('d4', TIMESTAMP(timezone=True)), - Column('d5', oracle.INTERVAL(second_precision=5)), + "date_types", + metadata, + Column("d1", sqltypes.DATE), + Column("d2", oracle.DATE), + Column("d3", TIMESTAMP), + Column("d4", TIMESTAMP(timezone=True)), + Column("d5", oracle.INTERVAL(second_precision=5)), ) metadata.create_all() m = MetaData(testing.db) - t1 = Table( - "date_types", m, - autoload=True) + t1 = Table("date_types", m, autoload=True) assert isinstance(t1.c.d1.type, oracle.DATE) assert isinstance(t1.c.d1.type, DateTime) assert isinstance(t1.c.d2.type, oracle.DATE) @@ -658,76 +737,85 @@ class TypesTest(fixtures.TestBase): assert isinstance(t1.c.d5.type, oracle.INTERVAL) def _dont_test_reflect_all_types_schema(self): - types_table = Table('all_types', MetaData(testing.db), - Column('owner', String(30), primary_key=True), - Column('type_name', String(30), primary_key=True), - autoload=True, oracle_resolve_synonyms=True) + types_table = Table( + "all_types", + MetaData(testing.db), + Column("owner", String(30), primary_key=True), + Column("type_name", String(30), primary_key=True), + autoload=True, + oracle_resolve_synonyms=True, + ) for row in types_table.select().execute().fetchall(): [row[k] for k in row.keys()] @testing.provide_metadata def test_raw_roundtrip(self): metadata = self.metadata - raw_table = Table('raw', metadata, - Column('id', Integer, primary_key=True), - Column('data', oracle.RAW(35))) + raw_table = Table( + "raw", + metadata, + Column("id", Integer, primary_key=True), + Column("data", oracle.RAW(35)), + ) metadata.create_all() testing.db.execute(raw_table.insert(), id=1, data=b("ABCDEF")) - eq_( - testing.db.execute(raw_table.select()).first(), - (1, b("ABCDEF")) - ) + eq_(testing.db.execute(raw_table.select()).first(), (1, b("ABCDEF"))) @testing.provide_metadata def test_reflect_nvarchar(self): metadata = self.metadata - Table('tnv', metadata, Column('data', sqltypes.NVARCHAR(255))) + Table("tnv", metadata, Column("data", sqltypes.NVARCHAR(255))) metadata.create_all() m2 = MetaData(testing.db) - t2 = Table('tnv', m2, autoload=True) + t2 = Table("tnv", m2, autoload=True) assert isinstance(t2.c.data.type, sqltypes.NVARCHAR) - if testing.against('oracle+cx_oracle'): + if testing.against("oracle+cx_oracle"): assert isinstance( t2.c.data.type.dialect_impl(testing.db.dialect), - cx_oracle._OracleUnicodeStringNCHAR) + cx_oracle._OracleUnicodeStringNCHAR, + ) - data = u('m’a réveillé.') + data = u("m’a réveillé.") t2.insert().execute(data=data) - res = t2.select().execute().first()['data'] + res = t2.select().execute().first()["data"] eq_(res, data) assert isinstance(res, 
util.text_type) @testing.provide_metadata def test_reflect_unicode_no_nvarchar(self): metadata = self.metadata - Table('tnv', metadata, Column('data', sqltypes.Unicode(255))) + Table("tnv", metadata, Column("data", sqltypes.Unicode(255))) metadata.create_all() m2 = MetaData(testing.db) - t2 = Table('tnv', m2, autoload=True) + t2 = Table("tnv", m2, autoload=True) assert isinstance(t2.c.data.type, sqltypes.VARCHAR) - if testing.against('oracle+cx_oracle'): + if testing.against("oracle+cx_oracle"): assert isinstance( t2.c.data.type.dialect_impl(testing.db.dialect), - cx_oracle._OracleString) + cx_oracle._OracleString, + ) - data = u('m’a réveillé.') + data = u("m’a réveillé.") t2.insert().execute(data=data) - res = t2.select().execute().first()['data'] + res = t2.select().execute().first()["data"] eq_(res, data) assert isinstance(res, util.text_type) @testing.provide_metadata def test_char_length(self): metadata = self.metadata - t1 = Table('t1', metadata, - Column("c1", VARCHAR(50)), - Column("c2", NVARCHAR(250)), - Column("c3", CHAR(200))) + t1 = Table( + "t1", + metadata, + Column("c1", VARCHAR(50)), + Column("c2", NVARCHAR(250)), + Column("c3", CHAR(200)), + ) t1.create() m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True) + t2 = Table("t1", m2, autoload=True) eq_(t2.c.c1.type.length, 50) eq_(t2.c.c2.type.length, 250) eq_(t2.c.c3.type.length, 200) @@ -736,36 +824,35 @@ class TypesTest(fixtures.TestBase): def test_long_type(self): metadata = self.metadata - t = Table('t', metadata, Column('data', oracle.LONG)) + t = Table("t", metadata, Column("data", oracle.LONG)) metadata.create_all(testing.db) - testing.db.execute(t.insert(), data='xyz') - eq_( - testing.db.scalar(select([t.c.data])), - "xyz" - ) + testing.db.execute(t.insert(), data="xyz") + eq_(testing.db.scalar(select([t.c.data])), "xyz") def test_longstring(self): metadata = MetaData(testing.db) - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE Z_TEST ( ID NUMERIC(22) PRIMARY KEY, ADD_USER VARCHAR2(20) NOT NULL ) - """) + """ + ) try: t = Table("z_test", metadata, autoload=True) - t.insert().execute(id=1.0, add_user='foobar') - assert t.select().execute().fetchall() == [(1, 'foobar')] + t.insert().execute(id=1.0, add_user="foobar") + assert t.select().execute().fetchall() == [(1, "foobar")] finally: testing.db.execute("DROP TABLE Z_TEST") class LOBFetchTest(fixtures.TablesTest): - __only_on__ = 'oracle' + __only_on__ = "oracle" __backend__ = True - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod @@ -773,31 +860,35 @@ class LOBFetchTest(fixtures.TablesTest): Table( "z_test", metadata, - Column('id', Integer, primary_key=True), - Column('data', Text), - Column('bindata', LargeBinary)) + Column("id", Integer, primary_key=True), + Column("data", Text), + Column("bindata", LargeBinary), + ) Table( - 'binary_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', LargeBinary) + "binary_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", LargeBinary), ) @classmethod def insert_data(cls): cls.data = data = [ dict( - id=i, data='this is text %d' % i, - bindata=b('this is binary %d' % i) - ) for i in range(1, 20) + id=i, + data="this is text %d" % i, + bindata=b("this is binary %d" % i), + ) + for i in range(1, 20) ] testing.db.execute(cls.tables.z_test.insert(), data) binary_table = cls.tables.binary_table fname = os.path.join( - os.path.dirname(__file__), "..", "..", - 'binary_data_one.dat') + os.path.dirname(__file__), "..", "..", 
"binary_data_one.dat" + ) with open(fname, "rb") as file_: cls.stream = stream = file_.read(12000) @@ -808,25 +899,27 @@ class LOBFetchTest(fixtures.TablesTest): engine = testing_engine(options=dict(auto_convert_lobs=False)) t = self.tables.z_test row = engine.execute(t.select().where(t.c.id == 1)).first() - eq_(row['data'].read(), 'this is text 1') - eq_(row['bindata'].read(), b('this is binary 1')) + eq_(row["data"].read(), "this is text 1") + eq_(row["bindata"].read(), b("this is binary 1")) def test_lobs_with_convert(self): t = self.tables.z_test row = testing.db.execute(t.select().where(t.c.id == 1)).first() - eq_(row['data'], 'this is text 1') - eq_(row['bindata'], b('this is binary 1')) + eq_(row["data"], "this is text 1") + eq_(row["bindata"], b("this is binary 1")) def test_lobs_with_convert_raw(self): row = testing.db.execute("select data, bindata from z_test").first() - eq_(row['data'], 'this is text 1') - eq_(row['bindata'], b('this is binary 1')) + eq_(row["data"], "this is text 1") + eq_(row["bindata"], b("this is binary 1")) def test_lobs_without_convert_many_rows(self): engine = testing_engine( - options=dict(auto_convert_lobs=False, arraysize=1)) + options=dict(auto_convert_lobs=False, arraysize=1) + ) result = engine.execute( - "select id, data, bindata from z_test order by id") + "select id, data, bindata from z_test order by id" + ) results = result.fetchall() def go(): @@ -835,17 +928,20 @@ class LOBFetchTest(fixtures.TablesTest): dict( id=row["id"], data=row["data"].read(), - bindata=row["bindata"].read() - ) for row in results + bindata=row["bindata"].read(), + ) + for row in results ], - self.data) + self.data, + ) + # this comes from cx_Oracle because these are raw # cx_Oracle.Variable objects if testing.requires.oracle5x.enabled: assert_raises_message( testing.db.dialect.dbapi.ProgrammingError, "LOB variable no longer valid after subsequent fetch", - go + go, ) else: go() @@ -853,32 +949,37 @@ class LOBFetchTest(fixtures.TablesTest): def test_lobs_with_convert_many_rows(self): # even with low arraysize, lobs are fine in autoconvert engine = testing_engine( - options=dict(auto_convert_lobs=True, arraysize=1)) + options=dict(auto_convert_lobs=True, arraysize=1) + ) result = engine.execute( - "select id, data, bindata from z_test order by id") + "select id, data, bindata from z_test order by id" + ) results = result.fetchall() eq_( [ - dict( - id=row["id"], - data=row["data"], - bindata=row["bindata"] - ) for row in results + dict(id=row["id"], data=row["data"], bindata=row["bindata"]) + for row in results ], - self.data) + self.data, + ) def test_large_stream(self): binary_table = self.tables.binary_table - result = binary_table.select().order_by(binary_table.c.id).\ - execute().fetchall() + result = ( + binary_table.select() + .order_by(binary_table.c.id) + .execute() + .fetchall() + ) eq_(result, [(i, self.stream) for i in range(1, 11)]) def test_large_stream_single_arraysize(self): binary_table = self.tables.binary_table - eng = testing_engine(options={'arraysize': 1}) - result = eng.execute(binary_table.select(). - order_by(binary_table.c.id)).fetchall() + eng = testing_engine(options={"arraysize": 1}) + result = eng.execute( + binary_table.select().order_by(binary_table.c.id) + ).fetchall() eq_(result, [(i, self.stream) for i in range(1, 11)]) @@ -887,7 +988,7 @@ class EuroNumericTest(fixtures.TestBase): test the numeric output_type_handler when using non-US locale for NLS_LANG. 
""" - __only_on__ = 'oracle+cx_oracle' + __only_on__ = "oracle+cx_oracle" __backend__ = True def setup(self): @@ -911,10 +1012,13 @@ class EuroNumericTest(fixtures.TestBase): try: cx_Oracle = self.engine.dialect.dbapi - def output_type_handler(cursor, name, defaultType, - size, precision, scale): - return cursor.var(cx_Oracle.STRING, 255, - arraysize=cursor.arraysize) + def output_type_handler( + cursor, name, defaultType, size, precision, scale + ): + return cursor.var( + cx_Oracle.STRING, 255, arraysize=cursor.arraysize + ) + cursor.outputtypehandler = output_type_handler cursor.execute("SELECT 1.1 FROM DUAL") row = cursor.fetchone() @@ -928,41 +1032,48 @@ class EuroNumericTest(fixtures.TestBase): for stmt, exp, kw in [ ("SELECT 0.1 FROM DUAL", decimal.Decimal("0.1"), {}), ("SELECT CAST(15 AS INTEGER) FROM DUAL", 15, {}), - ("SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", - decimal.Decimal("15"), {}), - ("SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", - decimal.Decimal("0.1"), {}), - ("SELECT :num FROM DUAL", decimal.Decimal("2.5"), - {'num': decimal.Decimal("2.5")}), - + ( + "SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", + decimal.Decimal("15"), + {}, + ), + ( + "SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", + decimal.Decimal("0.1"), + {}, + ), + ( + "SELECT :num FROM DUAL", + decimal.Decimal("2.5"), + {"num": decimal.Decimal("2.5")}, + ), ( text( "SELECT CAST(28.532 AS NUMERIC(5, 3)) " - "AS val FROM DUAL").columns( - val=Numeric(5, 3, asdecimal=True)), - decimal.Decimal("28.532"), {} - ) + "AS val FROM DUAL" + ).columns(val=Numeric(5, 3, asdecimal=True)), + decimal.Decimal("28.532"), + {}, + ), ]: test_exp = conn.scalar(stmt, **kw) - eq_( - test_exp, - exp - ) + eq_(test_exp, exp) assert type(test_exp) is type(exp) class SetInputSizesTest(fixtures.TestBase): - __only_on__ = 'oracle+cx_oracle' + __only_on__ = "oracle+cx_oracle" __backend__ = True @testing.provide_metadata def _test_setinputsizes( - self, datatype, value, sis_value, set_nchar_flag=False): + self, datatype, value, sis_value, set_nchar_flag=False + ): class TestTypeDec(TypeDecorator): impl = NullType() def load_dialect_impl(self, dialect): - if dialect.name == 'oracle': + if dialect.name == "oracle": return dialect.type_descriptor(datatype) else: return self.impl @@ -970,18 +1081,11 @@ class SetInputSizesTest(fixtures.TestBase): m = self.metadata # Oracle can have only one column of type LONG so we make three # tables rather than one table w/ three columns - t1 = Table( - 't1', m, - Column('foo', datatype), - ) + t1 = Table("t1", m, Column("foo", datatype)) t2 = Table( - 't2', m, - Column('foo', NullType().with_variant(datatype, "oracle")), - ) - t3 = Table( - 't3', m, - Column('foo', TestTypeDec()) + "t2", m, Column("foo", NullType().with_variant(datatype, "oracle")) ) + t3 = Table("t3", m, Column("foo", TestTypeDec())) m.create_all() class CursorWrapper(object): @@ -991,7 +1095,7 @@ class SetInputSizesTest(fixtures.TestBase): def __init__(self, connection_fairy): self.cursor = connection_fairy.connection.cursor() self.mock = mock.Mock() - connection_fairy.info['mock'] = self.mock + connection_fairy.info["mock"] = self.mock def setinputsizes(self, *arg, **kw): self.mock.setinputsizes(*arg, **kw) @@ -1009,23 +1113,21 @@ class SetInputSizesTest(fixtures.TestBase): connection_fairy = conn.connection for tab in [t1, t2, t3]: with mock.patch.object( - connection_fairy, "cursor", - lambda: CursorWrapper(connection_fairy) + connection_fairy, + "cursor", + lambda: CursorWrapper(connection_fairy), ): - conn.execute( - tab.insert(), - 
{"foo": value} - ) + conn.execute(tab.insert(), {"foo": value}) if sis_value: eq_( - conn.info['mock'].mock_calls, - [mock.call.setinputsizes(foo=sis_value)] + conn.info["mock"].mock_calls, + [mock.call.setinputsizes(foo=sis_value)], ) else: eq_( - conn.info['mock'].mock_calls, - [mock.call.setinputsizes()] + conn.info["mock"].mock_calls, + [mock.call.setinputsizes()], ) def test_smallint_setinputsizes(self): @@ -1036,58 +1138,68 @@ class SetInputSizesTest(fixtures.TestBase): def test_numeric_setinputsizes(self): self._test_setinputsizes( - Numeric(10, 8), decimal.Decimal("25.34534"), None) + Numeric(10, 8), decimal.Decimal("25.34534"), None + ) def test_float_setinputsizes(self): self._test_setinputsizes(Float(15), 25.34534, None) def test_binary_double_setinputsizes(self): self._test_setinputsizes( - oracle.BINARY_DOUBLE, 25.34534, - testing.db.dialect.dbapi.NATIVE_FLOAT) + oracle.BINARY_DOUBLE, + 25.34534, + testing.db.dialect.dbapi.NATIVE_FLOAT, + ) def test_binary_float_setinputsizes(self): self._test_setinputsizes( - oracle.BINARY_FLOAT, 25.34534, - testing.db.dialect.dbapi.NATIVE_FLOAT) + oracle.BINARY_FLOAT, + 25.34534, + testing.db.dialect.dbapi.NATIVE_FLOAT, + ) def test_double_precision_setinputsizes(self): - self._test_setinputsizes( - oracle.DOUBLE_PRECISION, 25.34534, - None) + self._test_setinputsizes(oracle.DOUBLE_PRECISION, 25.34534, None) def test_unicode_nchar_mode(self): self._test_setinputsizes( - Unicode(30), u("test"), testing.db.dialect.dbapi.NCHAR, - set_nchar_flag=True) + Unicode(30), + u("test"), + testing.db.dialect.dbapi.NCHAR, + set_nchar_flag=True, + ) def test_unicodetext_nchar_mode(self): self._test_setinputsizes( - UnicodeText(), u("test"), testing.db.dialect.dbapi.NCLOB, - set_nchar_flag=True) + UnicodeText(), + u("test"), + testing.db.dialect.dbapi.NCLOB, + set_nchar_flag=True, + ) def test_unicode(self): - self._test_setinputsizes( - Unicode(30), u("test"), None) + self._test_setinputsizes(Unicode(30), u("test"), None) def test_unicodetext(self): self._test_setinputsizes( - UnicodeText(), u("test"), testing.db.dialect.dbapi.CLOB) + UnicodeText(), u("test"), testing.db.dialect.dbapi.CLOB + ) def test_string(self): self._test_setinputsizes(String(30), "test", None) def test_char(self): self._test_setinputsizes( - CHAR(30), "test", testing.db.dialect.dbapi.FIXED_CHAR) + CHAR(30), "test", testing.db.dialect.dbapi.FIXED_CHAR + ) def test_nchar(self): self._test_setinputsizes( - NCHAR(30), u("test"), testing.db.dialect.dbapi.NCHAR) + NCHAR(30), u("test"), testing.db.dialect.dbapi.NCHAR + ) def test_long(self): - self._test_setinputsizes( - oracle.LONG(), "test", None) + self._test_setinputsizes(oracle.LONG(), "test", None) def test_event_no_native_float(self): def _remove_type(inputsizes, cursor, statement, parameters, context): @@ -1097,8 +1209,6 @@ class SetInputSizesTest(fixtures.TestBase): event.listen(testing.db, "do_setinputsizes", _remove_type) try: - self._test_setinputsizes( - oracle.BINARY_FLOAT, 25.34534, - None) + self._test_setinputsizes(oracle.BINARY_FLOAT, 25.34534, None) finally: event.remove(testing.db, "do_setinputsizes", _remove_type) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 3ebc4a1ab2..58e421f8ac 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,12 +1,33 @@ # coding: utf-8 -from sqlalchemy.testing.assertions import AssertsCompiledSQL, is_, \ - assert_raises, assert_raises_message, expect_warnings +from 
sqlalchemy.testing.assertions import ( + AssertsCompiledSQL, + is_, + assert_raises, + assert_raises_message, + expect_warnings, +) from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing -from sqlalchemy import Sequence, Table, Column, Integer, update, String,\ - func, MetaData, Enum, Index, and_, delete, select, cast, text, \ - Text, null +from sqlalchemy import ( + Sequence, + Table, + Column, + Integer, + update, + String, + func, + MetaData, + Enum, + Index, + and_, + delete, + select, + cast, + text, + Text, + null, +) from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql import ExcludeConstraint, array from sqlalchemy import exc, schema @@ -22,35 +43,42 @@ from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): - __prefer__ = 'postgresql' + __prefer__ = "postgresql" def test_format(self): - seq = Sequence('my_seq_no_schema') + seq = Sequence("my_seq_no_schema") dialect = postgresql.dialect() - assert dialect.identifier_preparer.format_sequence(seq) \ - == 'my_seq_no_schema' - seq = Sequence('my_seq', schema='some_schema') - assert dialect.identifier_preparer.format_sequence(seq) \ - == 'some_schema.my_seq' - seq = Sequence('My_Seq', schema='Some_Schema') - assert dialect.identifier_preparer.format_sequence(seq) \ + assert ( + dialect.identifier_preparer.format_sequence(seq) + == "my_seq_no_schema" + ) + seq = Sequence("my_seq", schema="some_schema") + assert ( + dialect.identifier_preparer.format_sequence(seq) + == "some_schema.my_seq" + ) + seq = Sequence("My_Seq", schema="Some_Schema") + assert ( + dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' + ) - @testing.only_on('postgresql', 'foo') + @testing.only_on("postgresql", "foo") @testing.provide_metadata def test_reverse_eng_name(self): metadata = self.metadata engine = engines.testing_engine(options=dict(implicit_returning=False)) for tname, cname in [ - ('tb1' * 30, 'abc'), - ('tb2', 'abc' * 30), - ('tb3' * 30, 'abc' * 30), - ('tb4', 'abc'), + ("tb1" * 30, "abc"), + ("tb2", "abc" * 30), + ("tb3" * 30, "abc" * 30), + ("tb4", "abc"), ]: - t = Table(tname[:57], - metadata, - Column(cname[:57], Integer, primary_key=True) - ) + t = Table( + tname[:57], + metadata, + Column(cname[:57], Integer, primary_key=True), + ) t.create(engine) r = engine.execute(t.insert()) assert r.inserted_primary_key == [1] @@ -63,478 +91,589 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_update_returning(self): dialect = postgresql.dialect() table1 = table( - 'mytable', - column( - 'myid', Integer), - column( - 'name', String(128)), - column( - 'description', String(128))) - u = update( - table1, - values=dict( - name='foo')).returning( - table1.c.myid, - table1.c.name) - self.assert_compile(u, - 'UPDATE mytable SET name=%(name)s ' - 'RETURNING mytable.myid, mytable.name', - dialect=dialect) - u = update(table1, values=dict(name='foo')).returning(table1) - self.assert_compile(u, - 'UPDATE mytable SET name=%(name)s ' - 'RETURNING mytable.myid, mytable.name, ' - 'mytable.description', dialect=dialect) - u = update(table1, values=dict(name='foo' - )).returning(func.length(table1.c.name)) + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) + u = update(table1, values=dict(name="foo")).returning( + table1.c.myid, table1.c.name + ) + self.assert_compile( + u, + "UPDATE mytable SET name=%(name)s " + "RETURNING mytable.myid, mytable.name", + 
dialect=dialect, + ) + u = update(table1, values=dict(name="foo")).returning(table1) self.assert_compile( u, - 'UPDATE mytable SET name=%(name)s ' - 'RETURNING length(mytable.name) AS length_1', - dialect=dialect) + "UPDATE mytable SET name=%(name)s " + "RETURNING mytable.myid, mytable.name, " + "mytable.description", + dialect=dialect, + ) + u = update(table1, values=dict(name="foo")).returning( + func.length(table1.c.name) + ) + self.assert_compile( + u, + "UPDATE mytable SET name=%(name)s " + "RETURNING length(mytable.name) AS length_1", + dialect=dialect, + ) def test_insert_returning(self): dialect = postgresql.dialect() - table1 = table('mytable', - column('myid', Integer), - column('name', String(128)), - column('description', String(128)), - ) + table1 = table( + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) - i = insert( - table1, - values=dict( - name='foo')).returning( - table1.c.myid, - table1.c.name) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) RETURNING mytable.myid, ' - 'mytable.name', dialect=dialect) - i = insert(table1, values=dict(name='foo')).returning(table1) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) RETURNING mytable.myid, ' - 'mytable.name, mytable.description', - dialect=dialect) - i = insert(table1, values=dict(name='foo' - )).returning(func.length(table1.c.name)) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) RETURNING length(mytable.name) ' - 'AS length_1', dialect=dialect) + i = insert(table1, values=dict(name="foo")).returning( + table1.c.myid, table1.c.name + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) RETURNING mytable.myid, " + "mytable.name", + dialect=dialect, + ) + i = insert(table1, values=dict(name="foo")).returning(table1) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) RETURNING mytable.myid, " + "mytable.name, mytable.description", + dialect=dialect, + ) + i = insert(table1, values=dict(name="foo")).returning( + func.length(table1.c.name) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) RETURNING length(mytable.name) " + "AS length_1", + dialect=dialect, + ) def test_create_drop_enum(self): # test escaping and unicode within CREATE TYPE for ENUM typ = postgresql.ENUM( - "val1", "val2", "val's 3", u('méil'), name="myname") + "val1", "val2", "val's 3", u("méil"), name="myname" + ) self.assert_compile( postgresql.CreateEnumType(typ), - u("CREATE TYPE myname AS " - "ENUM ('val1', 'val2', 'val''s 3', 'méil')")) + u( + "CREATE TYPE myname AS " + "ENUM ('val1', 'val2', 'val''s 3', 'méil')" + ), + ) - typ = postgresql.ENUM( - "val1", "val2", "val's 3", name="PleaseQuoteMe") - self.assert_compile(postgresql.CreateEnumType(typ), - "CREATE TYPE \"PleaseQuoteMe\" AS ENUM " - "('val1', 'val2', 'val''s 3')" - ) + typ = postgresql.ENUM("val1", "val2", "val's 3", name="PleaseQuoteMe") + self.assert_compile( + postgresql.CreateEnumType(typ), + 'CREATE TYPE "PleaseQuoteMe" AS ENUM ' + "('val1', 'val2', 'val''s 3')", + ) def test_generic_enum(self): - e1 = Enum('x', 'y', 'z', name='somename') - e2 = Enum('x', 'y', 'z', name='somename', schema='someschema') - self.assert_compile(postgresql.CreateEnumType(e1), - "CREATE TYPE somename AS ENUM ('x', 'y', 'z')" - ) - self.assert_compile(postgresql.CreateEnumType(e2), - "CREATE TYPE someschema.somename AS ENUM " - "('x', 'y', 'z')") - 
self.assert_compile(postgresql.DropEnumType(e1), - 'DROP TYPE somename') - self.assert_compile(postgresql.DropEnumType(e2), - 'DROP TYPE someschema.somename') - t1 = Table('sometable', MetaData(), Column('somecolumn', e1)) - self.assert_compile(schema.CreateTable(t1), - 'CREATE TABLE sometable (somecolumn ' - 'somename)') + e1 = Enum("x", "y", "z", name="somename") + e2 = Enum("x", "y", "z", name="somename", schema="someschema") + self.assert_compile( + postgresql.CreateEnumType(e1), + "CREATE TYPE somename AS ENUM ('x', 'y', 'z')", + ) + self.assert_compile( + postgresql.CreateEnumType(e2), + "CREATE TYPE someschema.somename AS ENUM " "('x', 'y', 'z')", + ) + self.assert_compile(postgresql.DropEnumType(e1), "DROP TYPE somename") + self.assert_compile( + postgresql.DropEnumType(e2), "DROP TYPE someschema.somename" + ) + t1 = Table("sometable", MetaData(), Column("somecolumn", e1)) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (somecolumn " "somename)", + ) t1 = Table( - 'sometable', + "sometable", MetaData(), - Column( - 'somecolumn', - Enum( - 'x', - 'y', - 'z', - native_enum=False))) - self.assert_compile(schema.CreateTable(t1), - "CREATE TABLE sometable (somecolumn " - "VARCHAR(1), CHECK (somecolumn IN ('x', " - "'y', 'z')))") + Column("somecolumn", Enum("x", "y", "z", native_enum=False)), + ) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (somecolumn " + "VARCHAR(1), CHECK (somecolumn IN ('x', " + "'y', 'z')))", + ) def test_create_type_schema_translate(self): - e1 = Enum('x', 'y', 'z', name='somename') - e2 = Enum('x', 'y', 'z', name='somename', schema='someschema') + e1 = Enum("x", "y", "z", name="somename") + e2 = Enum("x", "y", "z", name="somename", schema="someschema") schema_translate_map = {None: "foo", "someschema": "bar"} self.assert_compile( postgresql.CreateEnumType(e1), "CREATE TYPE foo.somename AS ENUM ('x', 'y', 'z')", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( postgresql.CreateEnumType(e2), "CREATE TYPE bar.somename AS ENUM ('x', 'y', 'z')", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) def test_create_table_with_tablespace(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - postgresql_tablespace='sometablespace') + "atable", + m, + Column("id", Integer), + postgresql_tablespace="sometablespace", + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE atable (id INTEGER) TABLESPACE sometablespace") + "CREATE TABLE atable (id INTEGER) TABLESPACE sometablespace", + ) def test_create_table_with_tablespace_quoted(self): # testing quoting of tablespace name m = MetaData() tbl = Table( - 'anothertable', m, Column("id", Integer), - postgresql_tablespace='table') + "anothertable", + m, + Column("id", Integer), + postgresql_tablespace="table", + ) self.assert_compile( schema.CreateTable(tbl), - 'CREATE TABLE anothertable (id INTEGER) TABLESPACE "table"') + 'CREATE TABLE anothertable (id INTEGER) TABLESPACE "table"', + ) def test_create_table_inherits(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - postgresql_inherits='i1') + "atable", m, Column("id", Integer), postgresql_inherits="i1" + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE atable (id INTEGER) INHERITS ( i1 )") + "CREATE TABLE atable (id INTEGER) INHERITS ( i1 )", + ) def test_create_table_inherits_tuple(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - 
postgresql_inherits=('i1', 'i2')) + "atable", + m, + Column("id", Integer), + postgresql_inherits=("i1", "i2"), + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE atable (id INTEGER) INHERITS ( i1, i2 )") + "CREATE TABLE atable (id INTEGER) INHERITS ( i1, i2 )", + ) def test_create_table_inherits_quoting(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - postgresql_inherits=('Quote Me', 'quote Me Too')) + "atable", + m, + Column("id", Integer), + postgresql_inherits=("Quote Me", "quote Me Too"), + ) self.assert_compile( schema.CreateTable(tbl), - 'CREATE TABLE atable (id INTEGER) INHERITS ' - '( "Quote Me", "quote Me Too" )') + "CREATE TABLE atable (id INTEGER) INHERITS " + '( "Quote Me", "quote Me Too" )', + ) def test_create_table_partition_by_list(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), Column("part_column", Integer), - postgresql_partition_by='LIST (part_column)') + "atable", + m, + Column("id", Integer), + Column("part_column", Integer), + postgresql_partition_by="LIST (part_column)", + ) self.assert_compile( schema.CreateTable(tbl), - 'CREATE TABLE atable (id INTEGER, part_column INTEGER) ' - 'PARTITION BY LIST (part_column)') + "CREATE TABLE atable (id INTEGER, part_column INTEGER) " + "PARTITION BY LIST (part_column)", + ) def test_create_table_partition_by_range(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), Column("part_column", Integer), - postgresql_partition_by='RANGE (part_column)') + "atable", + m, + Column("id", Integer), + Column("part_column", Integer), + postgresql_partition_by="RANGE (part_column)", + ) self.assert_compile( schema.CreateTable(tbl), - 'CREATE TABLE atable (id INTEGER, part_column INTEGER) ' - 'PARTITION BY RANGE (part_column)') + "CREATE TABLE atable (id INTEGER, part_column INTEGER) " + "PARTITION BY RANGE (part_column)", + ) def test_create_table_with_oids(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - postgresql_with_oids=True, ) + "atable", m, Column("id", Integer), postgresql_with_oids=True + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE atable (id INTEGER) WITH OIDS") + "CREATE TABLE atable (id INTEGER) WITH OIDS", + ) tbl2 = Table( - 'anothertable', m, Column("id", Integer), - postgresql_with_oids=False) + "anothertable", + m, + Column("id", Integer), + postgresql_with_oids=False, + ) self.assert_compile( schema.CreateTable(tbl2), - "CREATE TABLE anothertable (id INTEGER) WITHOUT OIDS") + "CREATE TABLE anothertable (id INTEGER) WITHOUT OIDS", + ) def test_create_table_with_oncommit_option(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - postgresql_on_commit="drop") + "atable", m, Column("id", Integer), postgresql_on_commit="drop" + ) self.assert_compile( schema.CreateTable(tbl), - "CREATE TABLE atable (id INTEGER) ON COMMIT DROP") + "CREATE TABLE atable (id INTEGER) ON COMMIT DROP", + ) def test_create_table_with_multiple_options(self): m = MetaData() tbl = Table( - 'atable', m, Column("id", Integer), - postgresql_tablespace='sometablespace', + "atable", + m, + Column("id", Integer), + postgresql_tablespace="sometablespace", postgresql_with_oids=False, - postgresql_on_commit="preserve_rows") + postgresql_on_commit="preserve_rows", + ) self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE atable (id INTEGER) WITHOUT OIDS " - "ON COMMIT PRESERVE ROWS TABLESPACE sometablespace") + "ON COMMIT PRESERVE ROWS TABLESPACE sometablespace", + ) def test_create_partial_index(self): m = 
MetaData() - tbl = Table('testtbl', m, Column('data', Integer)) - idx = Index('test_idx1', tbl.c.data, - postgresql_where=and_(tbl.c.data > 5, tbl.c.data - < 10)) - idx = Index('test_idx1', tbl.c.data, - postgresql_where=and_(tbl.c.data > 5, tbl.c.data - < 10)) + tbl = Table("testtbl", m, Column("data", Integer)) + idx = Index( + "test_idx1", + tbl.c.data, + postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10), + ) + idx = Index( + "test_idx1", + tbl.c.data, + postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10), + ) # test quoting and all that - idx2 = Index('test_idx2', tbl.c.data, - postgresql_where=and_(tbl.c.data > 'a', tbl.c.data - < "b's")) - self.assert_compile(schema.CreateIndex(idx), - 'CREATE INDEX test_idx1 ON testtbl (data) ' - 'WHERE data > 5 AND data < 10', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx2), - "CREATE INDEX test_idx2 ON testtbl (data) " - "WHERE data > 'a' AND data < 'b''s'", - dialect=postgresql.dialect()) + idx2 = Index( + "test_idx2", + tbl.c.data, + postgresql_where=and_(tbl.c.data > "a", tbl.c.data < "b's"), + ) + self.assert_compile( + schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) " + "WHERE data > 5 AND data < 10", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl (data) " + "WHERE data > 'a' AND data < 'b''s'", + dialect=postgresql.dialect(), + ) def test_create_index_with_ops(self): m = MetaData() - tbl = Table('testtbl', m, - Column('data', String), - Column('data2', Integer, key='d2')) - - idx = Index('test_idx1', tbl.c.data, - postgresql_ops={'data': 'text_pattern_ops'}) - - idx2 = Index('test_idx2', tbl.c.data, tbl.c.d2, - postgresql_ops={'data': 'text_pattern_ops', - 'd2': 'int4_ops'}) - - self.assert_compile(schema.CreateIndex(idx), - 'CREATE INDEX test_idx1 ON testtbl ' - '(data text_pattern_ops)', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx2), - 'CREATE INDEX test_idx2 ON testtbl ' - '(data text_pattern_ops, data2 int4_ops)', - dialect=postgresql.dialect()) + tbl = Table( + "testtbl", + m, + Column("data", String), + Column("data2", Integer, key="d2"), + ) + + idx = Index( + "test_idx1", + tbl.c.data, + postgresql_ops={"data": "text_pattern_ops"}, + ) + + idx2 = Index( + "test_idx2", + tbl.c.data, + tbl.c.d2, + postgresql_ops={"data": "text_pattern_ops", "d2": "int4_ops"}, + ) + + self.assert_compile( + schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl " "(data text_pattern_ops)", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl " + "(data text_pattern_ops, data2 int4_ops)", + dialect=postgresql.dialect(), + ) def test_create_index_with_labeled_ops(self): m = MetaData() - tbl = Table('testtbl', m, - Column('data', String), - Column('data2', Integer, key='d2')) + tbl = Table( + "testtbl", + m, + Column("data", String), + Column("data2", Integer, key="d2"), + ) - idx = Index('test_idx1', func.lower(tbl.c.data).label('data_lower'), - postgresql_ops={'data_lower': 'text_pattern_ops'}) + idx = Index( + "test_idx1", + func.lower(tbl.c.data).label("data_lower"), + postgresql_ops={"data_lower": "text_pattern_ops"}, + ) idx2 = Index( - 'test_idx2', - (func.xyz(tbl.c.data) + tbl.c.d2).label('bar'), - tbl.c.d2.label('foo'), - postgresql_ops={'bar': 'text_pattern_ops', - 'foo': 'int4_ops'}) - - self.assert_compile(schema.CreateIndex(idx), - 'CREATE INDEX test_idx1 ON testtbl ' - '(lower(data) 
text_pattern_ops)', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx2), - 'CREATE INDEX test_idx2 ON testtbl ' - '((xyz(data) + data2) text_pattern_ops, ' - 'data2 int4_ops)', - dialect=postgresql.dialect()) + "test_idx2", + (func.xyz(tbl.c.data) + tbl.c.d2).label("bar"), + tbl.c.d2.label("foo"), + postgresql_ops={"bar": "text_pattern_ops", "foo": "int4_ops"}, + ) + + self.assert_compile( + schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl " + "(lower(data) text_pattern_ops)", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl " + "((xyz(data) + data2) text_pattern_ops, " + "data2 int4_ops)", + dialect=postgresql.dialect(), + ) def test_create_index_with_text_or_composite(self): m = MetaData() - tbl = Table('testtbl', m, - Column('d1', String), - Column('d2', Integer)) + tbl = Table("testtbl", m, Column("d1", String), Column("d2", Integer)) - idx = Index('test_idx1', text('x')) + idx = Index("test_idx1", text("x")) tbl.append_constraint(idx) - idx2 = Index('test_idx2', text('y'), tbl.c.d2) + idx2 = Index("test_idx2", text("y"), tbl.c.d2) idx3 = Index( - 'test_idx2', tbl.c.d1, text('y'), tbl.c.d2, - postgresql_ops={'d1': 'x1', 'd2': 'x2'} + "test_idx2", + tbl.c.d1, + text("y"), + tbl.c.d2, + postgresql_ops={"d1": "x1", "d2": "x2"}, ) idx4 = Index( - 'test_idx2', tbl.c.d1, tbl.c.d2 > 5, text('q'), - postgresql_ops={'d1': 'x1', 'd2': 'x2'} + "test_idx2", + tbl.c.d1, + tbl.c.d2 > 5, + text("q"), + postgresql_ops={"d1": "x1", "d2": "x2"}, ) idx5 = Index( - 'test_idx2', tbl.c.d1, (tbl.c.d2 > 5).label('g'), text('q'), - postgresql_ops={'d1': 'x1', 'g': 'x2'} + "test_idx2", + tbl.c.d1, + (tbl.c.d2 > 5).label("g"), + text("q"), + postgresql_ops={"d1": "x1", "g": "x2"}, ) self.assert_compile( - schema.CreateIndex(idx), - "CREATE INDEX test_idx1 ON testtbl (x)" + schema.CreateIndex(idx), "CREATE INDEX test_idx1 ON testtbl (x)" ) self.assert_compile( schema.CreateIndex(idx2), - "CREATE INDEX test_idx2 ON testtbl (y, d2)" + "CREATE INDEX test_idx2 ON testtbl (y, d2)", ) self.assert_compile( schema.CreateIndex(idx3), - "CREATE INDEX test_idx2 ON testtbl (d1 x1, y, d2 x2)" + "CREATE INDEX test_idx2 ON testtbl (d1 x1, y, d2 x2)", ) # note that at the moment we do not expect the 'd2' op to # pick up on the "d2 > 5" expression self.assert_compile( schema.CreateIndex(idx4), - "CREATE INDEX test_idx2 ON testtbl (d1 x1, (d2 > 5), q)" + "CREATE INDEX test_idx2 ON testtbl (d1 x1, (d2 > 5), q)", ) # however it does work if we label! 
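        # A minimal standalone sketch of the labeling point above
        # (hypothetical names; assumes SQLAlchemy with the postgresql
        # dialect, as used throughout these tests): plain columns match a
        # postgresql_ops entry by column name, while a SQL expression needs
        # .label() to give the ops dict a key to match on.
        #
        #     from sqlalchemy import Column, Index, MetaData, String, Table, func
        #     from sqlalchemy.dialects import postgresql
        #     from sqlalchemy.schema import CreateIndex
        #
        #     m = MetaData()
        #     t = Table("t", m, Column("data", String))
        #     idx = Index(
        #         "ix_lower_data",
        #         func.lower(t.c.data).label("data_lower"),
        #         postgresql_ops={"data_lower": "text_pattern_ops"},
        #     )
        #     # Compiles to:
        #     # CREATE INDEX ix_lower_data ON t (lower(data) text_pattern_ops)
        #     print(CreateIndex(idx).compile(dialect=postgresql.dialect()))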
self.assert_compile( schema.CreateIndex(idx5), - "CREATE INDEX test_idx2 ON testtbl (d1 x1, (d2 > 5) x2, q)" + "CREATE INDEX test_idx2 ON testtbl (d1 x1, (d2 > 5) x2, q)", ) def test_create_index_with_using(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', String)) - - idx1 = Index('test_idx1', tbl.c.data) - idx2 = Index('test_idx2', tbl.c.data, postgresql_using='btree') - idx3 = Index('test_idx3', tbl.c.data, postgresql_using='hash') - - self.assert_compile(schema.CreateIndex(idx1), - 'CREATE INDEX test_idx1 ON testtbl ' - '(data)', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx2), - 'CREATE INDEX test_idx2 ON testtbl ' - 'USING btree (data)', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx3), - 'CREATE INDEX test_idx3 ON testtbl ' - 'USING hash (data)', - dialect=postgresql.dialect()) + tbl = Table("testtbl", m, Column("data", String)) + + idx1 = Index("test_idx1", tbl.c.data) + idx2 = Index("test_idx2", tbl.c.data, postgresql_using="btree") + idx3 = Index("test_idx3", tbl.c.data, postgresql_using="hash") + + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX test_idx1 ON testtbl " "(data)", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl " "USING btree (data)", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx3), + "CREATE INDEX test_idx3 ON testtbl " "USING hash (data)", + dialect=postgresql.dialect(), + ) def test_create_index_with_with(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', String)) + tbl = Table("testtbl", m, Column("data", String)) - idx1 = Index('test_idx1', tbl.c.data) + idx1 = Index("test_idx1", tbl.c.data) idx2 = Index( - 'test_idx2', tbl.c.data, postgresql_with={"fillfactor": 50}) - idx3 = Index('test_idx3', tbl.c.data, postgresql_using="gist", - postgresql_with={"buffering": "off"}) - - self.assert_compile(schema.CreateIndex(idx1), - 'CREATE INDEX test_idx1 ON testtbl ' - '(data)') - self.assert_compile(schema.CreateIndex(idx2), - 'CREATE INDEX test_idx2 ON testtbl ' - '(data) ' - 'WITH (fillfactor = 50)') - self.assert_compile(schema.CreateIndex(idx3), - 'CREATE INDEX test_idx3 ON testtbl ' - 'USING gist (data) ' - 'WITH (buffering = off)') + "test_idx2", tbl.c.data, postgresql_with={"fillfactor": 50} + ) + idx3 = Index( + "test_idx3", + tbl.c.data, + postgresql_using="gist", + postgresql_with={"buffering": "off"}, + ) + + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX test_idx1 ON testtbl " "(data)", + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl " + "(data) " + "WITH (fillfactor = 50)", + ) + self.assert_compile( + schema.CreateIndex(idx3), + "CREATE INDEX test_idx3 ON testtbl " + "USING gist (data) " + "WITH (buffering = off)", + ) def test_create_index_with_tablespace(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', String)) - - idx1 = Index('test_idx1', - tbl.c.data) - idx2 = Index('test_idx2', - tbl.c.data, - postgresql_tablespace='sometablespace') - idx3 = Index('test_idx3', - tbl.c.data, - postgresql_tablespace='another table space') - - self.assert_compile(schema.CreateIndex(idx1), - 'CREATE INDEX test_idx1 ON testtbl ' - '(data)', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx2), - 'CREATE INDEX test_idx2 ON testtbl ' - '(data) ' - 'TABLESPACE sometablespace', - dialect=postgresql.dialect()) - self.assert_compile(schema.CreateIndex(idx3), 
- 'CREATE INDEX test_idx3 ON testtbl ' - '(data) ' - 'TABLESPACE "another table space"', - dialect=postgresql.dialect()) + tbl = Table("testtbl", m, Column("data", String)) + + idx1 = Index("test_idx1", tbl.c.data) + idx2 = Index( + "test_idx2", tbl.c.data, postgresql_tablespace="sometablespace" + ) + idx3 = Index( + "test_idx3", + tbl.c.data, + postgresql_tablespace="another table space", + ) + + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX test_idx1 ON testtbl " "(data)", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl " + "(data) " + "TABLESPACE sometablespace", + dialect=postgresql.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx3), + "CREATE INDEX test_idx3 ON testtbl " + "(data) " + 'TABLESPACE "another table space"', + dialect=postgresql.dialect(), + ) def test_create_index_with_multiple_options(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', String)) - - idx1 = Index('test_idx1', - tbl.c.data, - postgresql_using='btree', - postgresql_tablespace='atablespace', - postgresql_with={"fillfactor": 60}, - postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10)) - - self.assert_compile(schema.CreateIndex(idx1), - 'CREATE INDEX test_idx1 ON testtbl ' - 'USING btree (data) ' - 'WITH (fillfactor = 60) ' - 'TABLESPACE atablespace ' - 'WHERE data > 5 AND data < 10', - dialect=postgresql.dialect()) + tbl = Table("testtbl", m, Column("data", String)) + + idx1 = Index( + "test_idx1", + tbl.c.data, + postgresql_using="btree", + postgresql_tablespace="atablespace", + postgresql_with={"fillfactor": 60}, + postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10), + ) + + self.assert_compile( + schema.CreateIndex(idx1), + "CREATE INDEX test_idx1 ON testtbl " + "USING btree (data) " + "WITH (fillfactor = 60) " + "TABLESPACE atablespace " + "WHERE data > 5 AND data < 10", + dialect=postgresql.dialect(), + ) def test_create_index_expr_gets_parens(self): m = MetaData() - tbl = Table('testtbl', m, Column('x', Integer), Column('y', Integer)) + tbl = Table("testtbl", m, Column("x", Integer), Column("y", Integer)) - idx1 = Index('test_idx1', 5 / (tbl.c.x + tbl.c.y)) + idx1 = Index("test_idx1", 5 / (tbl.c.x + tbl.c.y)) self.assert_compile( schema.CreateIndex(idx1), - "CREATE INDEX test_idx1 ON testtbl ((5 / (x + y)))" + "CREATE INDEX test_idx1 ON testtbl ((5 / (x + y)))", ) def test_create_index_literals(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', Integer)) + tbl = Table("testtbl", m, Column("data", Integer)) - idx1 = Index('test_idx1', tbl.c.data + 5) + idx1 = Index("test_idx1", tbl.c.data + 5) self.assert_compile( schema.CreateIndex(idx1), - "CREATE INDEX test_idx1 ON testtbl ((data + 5))" + "CREATE INDEX test_idx1 ON testtbl ((data + 5))", ) def test_create_index_concurrently(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', Integer)) + tbl = Table("testtbl", m, Column("data", Integer)) - idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + idx1 = Index("test_idx1", tbl.c.data, postgresql_concurrently=True) self.assert_compile( schema.CreateIndex(idx1), - "CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data)" + "CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data)", ) dialect_8_1 = postgresql.dialect() @@ -542,85 +681,86 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.CreateIndex(idx1), "CREATE INDEX test_idx1 ON testtbl (data)", - dialect=dialect_8_1 + dialect=dialect_8_1, ) def 
test_drop_index_concurrently(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', Integer)) + tbl = Table("testtbl", m, Column("data", Integer)) - idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + idx1 = Index("test_idx1", tbl.c.data, postgresql_concurrently=True) self.assert_compile( - schema.DropIndex(idx1), - "DROP INDEX CONCURRENTLY test_idx1" + schema.DropIndex(idx1), "DROP INDEX CONCURRENTLY test_idx1" ) dialect_9_1 = postgresql.dialect() dialect_9_1._supports_drop_index_concurrently = False self.assert_compile( - schema.DropIndex(idx1), - "DROP INDEX test_idx1", - dialect=dialect_9_1 + schema.DropIndex(idx1), "DROP INDEX test_idx1", dialect=dialect_9_1 ) def test_exclude_constraint_min(self): m = MetaData() - tbl = Table('testtbl', m, - Column('room', Integer, primary_key=True)) - cons = ExcludeConstraint(('room', '=')) + tbl = Table("testtbl", m, Column("room", Integer, primary_key=True)) + cons = ExcludeConstraint(("room", "=")) tbl.append_constraint(cons) - self.assert_compile(schema.AddConstraint(cons), - 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' - '(room WITH =)', - dialect=postgresql.dialect()) + self.assert_compile( + schema.AddConstraint(cons), + "ALTER TABLE testtbl ADD EXCLUDE USING gist " "(room WITH =)", + dialect=postgresql.dialect(), + ) def test_exclude_constraint_full(self): m = MetaData() - room = Column('room', Integer, primary_key=True) - tbl = Table('testtbl', m, - room, - Column('during', TSRANGE)) - room = Column('room', Integer, primary_key=True) - cons = ExcludeConstraint((room, '='), ('during', '&&'), - name='my_name', - using='gist', - where="room > 100", - deferrable=True, - initially='immediate') + room = Column("room", Integer, primary_key=True) + tbl = Table("testtbl", m, room, Column("during", TSRANGE)) + room = Column("room", Integer, primary_key=True) + cons = ExcludeConstraint( + (room, "="), + ("during", "&&"), + name="my_name", + using="gist", + where="room > 100", + deferrable=True, + initially="immediate", + ) tbl.append_constraint(cons) - self.assert_compile(schema.AddConstraint(cons), - 'ALTER TABLE testtbl ADD CONSTRAINT my_name ' - 'EXCLUDE USING gist ' - '(room WITH =, during WITH ''&&) WHERE ' - '(room > 100) DEFERRABLE INITIALLY immediate', - dialect=postgresql.dialect()) + self.assert_compile( + schema.AddConstraint(cons), + "ALTER TABLE testtbl ADD CONSTRAINT my_name " + "EXCLUDE USING gist " + "(room WITH =, during WITH " + "&&) WHERE " + "(room > 100) DEFERRABLE INITIALLY immediate", + dialect=postgresql.dialect(), + ) def test_exclude_constraint_copy(self): m = MetaData() - cons = ExcludeConstraint(('room', '=')) - tbl = Table('testtbl', m, - Column('room', Integer, primary_key=True), - cons) + cons = ExcludeConstraint(("room", "=")) + tbl = Table( + "testtbl", m, Column("room", Integer, primary_key=True), cons + ) # apparently you can't copy a ColumnCollectionConstraint until # after it has been bound to a table... 
cons_copy = cons.copy() tbl.append_constraint(cons_copy) - self.assert_compile(schema.AddConstraint(cons_copy), - 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' - '(room WITH =)') + self.assert_compile( + schema.AddConstraint(cons_copy), + "ALTER TABLE testtbl ADD EXCLUDE USING gist " "(room WITH =)", + ) def test_exclude_constraint_copy_where_using(self): m = MetaData() - tbl = Table('testtbl', m, - Column('room', Integer, primary_key=True), - ) + tbl = Table("testtbl", m, Column("room", Integer, primary_key=True)) cons = ExcludeConstraint( - (tbl.c.room, '='), where=tbl.c.room > 5, using='foobar') + (tbl.c.room, "="), where=tbl.c.room > 5, using="foobar" + ) tbl.append_constraint(cons) self.assert_compile( schema.AddConstraint(cons), "ALTER TABLE testtbl ADD EXCLUDE USING foobar " - "(room WITH =) WHERE (testtbl.room > 5)" + "(room WITH =) WHERE (testtbl.room > 5)", ) m2 = MetaData() @@ -630,213 +770,235 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE TABLE testtbl (room SERIAL NOT NULL, " "PRIMARY KEY (room), " "EXCLUDE USING foobar " - "(room WITH =) WHERE (testtbl.room > 5))" + "(room WITH =) WHERE (testtbl.room > 5))", ) def test_exclude_constraint_text(self): m = MetaData() - cons = ExcludeConstraint((text('room::TEXT'), '=')) - Table( - 'testtbl', m, - Column('room', String), - cons) + cons = ExcludeConstraint((text("room::TEXT"), "=")) + Table("testtbl", m, Column("room", String), cons) self.assert_compile( schema.AddConstraint(cons), - 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' - '(room::TEXT WITH =)') + "ALTER TABLE testtbl ADD EXCLUDE USING gist " + "(room::TEXT WITH =)", + ) def test_exclude_constraint_cast(self): m = MetaData() - tbl = Table( - 'testtbl', m, - Column('room', String) - ) - cons = ExcludeConstraint((cast(tbl.c.room, Text), '=')) + tbl = Table("testtbl", m, Column("room", String)) + cons = ExcludeConstraint((cast(tbl.c.room, Text), "=")) tbl.append_constraint(cons) self.assert_compile( schema.AddConstraint(cons), - 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' - '(CAST(room AS TEXT) WITH =)' + "ALTER TABLE testtbl ADD EXCLUDE USING gist " + "(CAST(room AS TEXT) WITH =)", ) def test_exclude_constraint_cast_quote(self): m = MetaData() - tbl = Table( - 'testtbl', m, - Column('Room', String) - ) - cons = ExcludeConstraint((cast(tbl.c.Room, Text), '=')) + tbl = Table("testtbl", m, Column("Room", String)) + cons = ExcludeConstraint((cast(tbl.c.Room, Text), "=")) tbl.append_constraint(cons) self.assert_compile( schema.AddConstraint(cons), - 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' - '(CAST("Room" AS TEXT) WITH =)' + "ALTER TABLE testtbl ADD EXCLUDE USING gist " + '(CAST("Room" AS TEXT) WITH =)', ) def test_exclude_constraint_when(self): m = MetaData() - tbl = Table( - 'testtbl', m, - Column('room', String) - ) - cons = ExcludeConstraint(('room', '='), where=tbl.c.room.in_(['12'])) + tbl = Table("testtbl", m, Column("room", String)) + cons = ExcludeConstraint(("room", "="), where=tbl.c.room.in_(["12"])) tbl.append_constraint(cons) - self.assert_compile(schema.AddConstraint(cons), - 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' - '(room WITH =) WHERE (testtbl.room IN (\'12\'))', - dialect=postgresql.dialect()) + self.assert_compile( + schema.AddConstraint(cons), + "ALTER TABLE testtbl ADD EXCLUDE USING gist " + "(room WITH =) WHERE (testtbl.room IN ('12'))", + dialect=postgresql.dialect(), + ) def test_substring(self): - self.assert_compile(func.substring('abc', 1, 2), - 'SUBSTRING(%(substring_1)s FROM %(substring_2)s ' - 'FOR 
%(substring_3)s)') - self.assert_compile(func.substring('abc', 1), - 'SUBSTRING(%(substring_1)s FROM %(substring_2)s)') + self.assert_compile( + func.substring("abc", 1, 2), + "SUBSTRING(%(substring_1)s FROM %(substring_2)s " + "FOR %(substring_3)s)", + ) + self.assert_compile( + func.substring("abc", 1), + "SUBSTRING(%(substring_1)s FROM %(substring_2)s)", + ) def test_for_update(self): - table1 = table('mytable', - column('myid'), column('name'), column('description')) + table1 = table( + "mytable", column("myid"), column("name"), column("description") + ) self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE") + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE", + ) self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(nowait=True), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT") + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR UPDATE NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(skip_locked=True), + table1.select(table1.c.myid == 7).with_for_update( + skip_locked=True + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR UPDATE SKIP LOCKED") + "FOR UPDATE SKIP LOCKED", + ) self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(read=True), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE") + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, nowait=True), + table1.select(table1.c.myid == 7).with_for_update( + read=True, nowait=True + ), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT") + "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, skip_locked=True), + table1.select(table1.c.myid == 7).with_for_update( + read=True, skip_locked=True + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR SHARE SKIP LOCKED") + "FOR SHARE SKIP LOCKED", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(of=table1.c.myid), + table1.select(table1.c.myid == 7).with_for_update( + of=table1.c.myid + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR UPDATE OF mytable") + "FOR UPDATE OF mytable", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, nowait=True, of=table1), + table1.select(table1.c.myid == 7).with_for_update( + read=True, nowait=True, of=table1 + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR SHARE OF mytable NOWAIT") + "FOR SHARE OF mytable NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7). 
- with_for_update(read=True, nowait=True, of=table1.c.myid), + table1.select(table1.c.myid == 7).with_for_update( + read=True, nowait=True, of=table1.c.myid + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR SHARE OF mytable NOWAIT") + "FOR SHARE OF mytable NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, nowait=True, - of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + read=True, nowait=True, of=[table1.c.myid, table1.c.name] + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR SHARE OF mytable NOWAIT") + "FOR SHARE OF mytable NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, skip_locked=True, - of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + read=True, skip_locked=True, of=[table1.c.myid, table1.c.name] + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR SHARE OF mytable SKIP LOCKED") + "FOR SHARE OF mytable SKIP LOCKED", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(key_share=True, nowait=True, - of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + key_share=True, nowait=True, of=[table1.c.myid, table1.c.name] + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR NO KEY UPDATE OF mytable NOWAIT") + "FOR NO KEY UPDATE OF mytable NOWAIT", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(key_share=True, skip_locked=True, - of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + key_share=True, + skip_locked=True, + of=[table1.c.myid, table1.c.name], + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR NO KEY UPDATE OF mytable SKIP LOCKED") + "FOR NO KEY UPDATE OF mytable SKIP LOCKED", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(key_share=True, - of=[table1.c.myid, table1.c.name]), + table1.select(table1.c.myid == 7).with_for_update( + key_share=True, of=[table1.c.myid, table1.c.name] + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR NO KEY UPDATE OF mytable") + "FOR NO KEY UPDATE OF mytable", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(key_share=True), + table1.select(table1.c.myid == 7).with_for_update(key_share=True), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR NO KEY UPDATE") + "FOR NO KEY UPDATE", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, key_share=True), + table1.select(table1.c.myid == 7).with_for_update( + read=True, key_share=True + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR KEY SHARE") + "FOR KEY SHARE", + ) self.assert_compile( - table1.select(table1.c.myid == 7). 
- with_for_update(read=True, key_share=True, of=table1), + table1.select(table1.c.myid == 7).with_for_update( + read=True, key_share=True, of=table1 + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR KEY SHARE OF mytable") + "FOR KEY SHARE OF mytable", + ) self.assert_compile( - table1.select(table1.c.myid == 7). - with_for_update(read=True, key_share=True, skip_locked=True), + table1.select(table1.c.myid == 7).with_for_update( + read=True, key_share=True, skip_locked=True + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s " - "FOR KEY SHARE SKIP LOCKED") + "FOR KEY SHARE SKIP LOCKED", + ) ta = table1.alias() self.assert_compile( - ta.select(ta.c.myid == 7). - with_for_update(of=[ta.c.myid, ta.c.name]), + ta.select(ta.c.myid == 7).with_for_update( + of=[ta.c.myid, ta.c.name] + ), "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " "FROM mytable AS mytable_1 " - "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1" + "WHERE mytable_1.myid = %(myid_1)s FOR UPDATE OF mytable_1", ) def test_for_update_with_schema(self): m = MetaData() table1 = Table( - 'mytable', m, - Column('myid'), - Column('name'), - schema='testschema' + "mytable", m, Column("myid"), Column("name"), schema="testschema" ) self.assert_compile( @@ -844,114 +1006,110 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT testschema.mytable.myid, testschema.mytable.name " "FROM testschema.mytable " "WHERE testschema.mytable.myid = %(myid_1)s " - "FOR UPDATE OF mytable") + "FOR UPDATE OF mytable", + ) def test_reserved_words(self): - table = Table("pg_table", MetaData(), - Column("col1", Integer), - Column("variadic", Integer)) + table = Table( + "pg_table", + MetaData(), + Column("col1", Integer), + Column("variadic", Integer), + ) x = select([table.c.col1, table.c.variadic]) self.assert_compile( - x, - '''SELECT pg_table.col1, pg_table."variadic" FROM pg_table''') + x, """SELECT pg_table.col1, pg_table."variadic" FROM pg_table""" + ) def test_array(self): - c = Column('x', postgresql.ARRAY(Integer)) + c = Column("x", postgresql.ARRAY(Integer)) self.assert_compile( - cast(c, postgresql.ARRAY(Integer)), - "CAST(x AS INTEGER[])" - ) - self.assert_compile( - c[5], - "x[%(x_1)s]", - checkparams={'x_1': 5} + cast(c, postgresql.ARRAY(Integer)), "CAST(x AS INTEGER[])" ) + self.assert_compile(c[5], "x[%(x_1)s]", checkparams={"x_1": 5}) self.assert_compile( - c[5:7], - "x[%(x_1)s:%(x_2)s]", - checkparams={'x_2': 7, 'x_1': 5} + c[5:7], "x[%(x_1)s:%(x_2)s]", checkparams={"x_2": 7, "x_1": 5} ) self.assert_compile( c[5:7][2:3], "x[%(x_1)s:%(x_2)s][%(param_1)s:%(param_2)s]", - checkparams={'x_2': 7, 'x_1': 5, 'param_1': 2, 'param_2': 3} + checkparams={"x_2": 7, "x_1": 5, "param_1": 2, "param_2": 3}, ) self.assert_compile( c[5:7][3], "x[%(x_1)s:%(x_2)s][%(param_1)s]", - checkparams={'x_2': 7, 'x_1': 5, 'param_1': 3} + checkparams={"x_2": 7, "x_1": 5, "param_1": 3}, ) self.assert_compile( - c.contains([1]), - 'x @> %(x_1)s', - checkparams={'x_1': [1]} + c.contains([1]), "x @> %(x_1)s", checkparams={"x_1": [1]} ) self.assert_compile( - c.contained_by([2]), - 'x <@ %(x_1)s', - checkparams={'x_1': [2]} + c.contained_by([2]), "x <@ %(x_1)s", checkparams={"x_1": [2]} ) self.assert_compile( - c.overlap([3]), - 'x && %(x_1)s', - checkparams={'x_1': [3]} + c.overlap([3]), "x && %(x_1)s", checkparams={"x_1": [3]} ) self.assert_compile( postgresql.Any(4, c), - '%(param_1)s = ANY (x)', - 
checkparams={'param_1': 4} + "%(param_1)s = ANY (x)", + checkparams={"param_1": 4}, ) self.assert_compile( c.any(5, operator=operators.ne), - '%(param_1)s != ANY (x)', - checkparams={'param_1': 5} + "%(param_1)s != ANY (x)", + checkparams={"param_1": 5}, ) self.assert_compile( postgresql.All(6, c, operator=operators.gt), - '%(param_1)s > ALL (x)', - checkparams={'param_1': 6} + "%(param_1)s > ALL (x)", + checkparams={"param_1": 6}, ) self.assert_compile( c.all(7, operator=operators.lt), - '%(param_1)s < ALL (x)', - checkparams={'param_1': 7} + "%(param_1)s < ALL (x)", + checkparams={"param_1": 7}, ) def _test_array_zero_indexes(self, zero_indexes): - c = Column('x', postgresql.ARRAY(Integer, zero_indexes=zero_indexes)) + c = Column("x", postgresql.ARRAY(Integer, zero_indexes=zero_indexes)) add_one = 1 if zero_indexes else 0 self.assert_compile( cast(c, postgresql.ARRAY(Integer, zero_indexes=zero_indexes)), - "CAST(x AS INTEGER[])" + "CAST(x AS INTEGER[])", ) self.assert_compile( - c[5], - "x[%(x_1)s]", - checkparams={'x_1': 5 + add_one} + c[5], "x[%(x_1)s]", checkparams={"x_1": 5 + add_one} ) self.assert_compile( c[5:7], "x[%(x_1)s:%(x_2)s]", - checkparams={'x_2': 7 + add_one, 'x_1': 5 + add_one} + checkparams={"x_2": 7 + add_one, "x_1": 5 + add_one}, ) self.assert_compile( c[5:7][2:3], "x[%(x_1)s:%(x_2)s][%(param_1)s:%(param_2)s]", - checkparams={'x_2': 7 + add_one, 'x_1': 5 + add_one, - 'param_1': 2 + add_one, 'param_2': 3 + add_one} + checkparams={ + "x_2": 7 + add_one, + "x_1": 5 + add_one, + "param_1": 2 + add_one, + "param_2": 3 + add_one, + }, ) self.assert_compile( c[5:7][3], "x[%(x_1)s:%(x_2)s][%(param_1)s]", - checkparams={'x_2': 7 + add_one, 'x_1': 5 + add_one, - 'param_1': 3 + add_one} + checkparams={ + "x_2": 7 + add_one, + "x_1": 5 + add_one, + "param_1": 3 + add_one, + }, ) def test_array_zero_indexes_true(self): @@ -964,17 +1122,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): isinstance(postgresql.array([1, 2]).type, postgresql.ARRAY) is_(postgresql.array([1, 2]).type.item_type._type_affinity, Integer) - is_(postgresql.array([1, 2], type_=String). 
- type.item_type._type_affinity, String) + is_( + postgresql.array( + [1, 2], type_=String + ).type.item_type._type_affinity, + String, + ) def test_array_literal(self): self.assert_compile( - func.array_dims(postgresql.array([1, 2]) + - postgresql.array([3, 4, 5])), + func.array_dims( + postgresql.array([1, 2]) + postgresql.array([3, 4, 5]) + ), "array_dims(ARRAY[%(param_1)s, %(param_2)s] || " "ARRAY[%(param_3)s, %(param_4)s, %(param_5)s])", - checkparams={'param_5': 5, 'param_4': 4, 'param_1': 1, - 'param_3': 3, 'param_2': 2} + checkparams={ + "param_5": 5, + "param_4": 4, + "param_1": 1, + "param_3": 3, + "param_2": 2, + }, ) def test_array_literal_compare(self): @@ -982,98 +1150,104 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): postgresql.array([1, 2]) == [3, 4, 5], "ARRAY[%(param_1)s, %(param_2)s] = " "ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]", - checkparams={'param_5': 5, - 'param_4': 4, - 'param_1': 1, - 'param_3': 3, - 'param_2': 2} - + checkparams={ + "param_5": 5, + "param_4": 4, + "param_1": 1, + "param_3": 3, + "param_2": 2, + }, ) def test_array_literal_insert(self): m = MetaData() - t = Table('t', m, Column('data', postgresql.ARRAY(Integer))) + t = Table("t", m, Column("data", postgresql.ARRAY(Integer))) self.assert_compile( t.insert().values(data=array([1, 2, 3])), "INSERT INTO t (data) VALUES (ARRAY[%(param_1)s, " - "%(param_2)s, %(param_3)s])" + "%(param_2)s, %(param_3)s])", ) def test_update_array_element(self): m = MetaData() - t = Table('t', m, Column('data', postgresql.ARRAY(Integer))) + t = Table("t", m, Column("data", postgresql.ARRAY(Integer))) self.assert_compile( t.update().values({t.c.data[5]: 1}), "UPDATE t SET data[%(data_1)s]=%(param_1)s", - checkparams={'data_1': 5, 'param_1': 1} + checkparams={"data_1": 5, "param_1": 1}, ) def test_update_array_slice(self): m = MetaData() - t = Table('t', m, Column('data', postgresql.ARRAY(Integer))) + t = Table("t", m, Column("data", postgresql.ARRAY(Integer))) self.assert_compile( t.update().values({t.c.data[2:5]: 2}), "UPDATE t SET data[%(data_1)s:%(data_2)s]=%(param_1)s", - checkparams={'param_1': 2, 'data_2': 5, 'data_1': 2} - + checkparams={"param_1": 2, "data_2": 5, "data_1": 2}, ) def test_from_only(self): m = MetaData() - tbl1 = Table('testtbl1', m, Column('id', Integer)) - tbl2 = Table('testtbl2', m, Column('id', Integer)) + tbl1 = Table("testtbl1", m, Column("id", Integer)) + tbl2 = Table("testtbl2", m, Column("id", Integer)) - stmt = tbl1.select().with_hint(tbl1, 'ONLY', 'postgresql') - expected = 'SELECT testtbl1.id FROM ONLY testtbl1' + stmt = tbl1.select().with_hint(tbl1, "ONLY", "postgresql") + expected = "SELECT testtbl1.id FROM ONLY testtbl1" self.assert_compile(stmt, expected) - talias1 = tbl1.alias('foo') - stmt = talias1.select().with_hint(talias1, 'ONLY', 'postgresql') - expected = 'SELECT foo.id FROM ONLY testtbl1 AS foo' + talias1 = tbl1.alias("foo") + stmt = talias1.select().with_hint(talias1, "ONLY", "postgresql") + expected = "SELECT foo.id FROM ONLY testtbl1 AS foo" self.assert_compile(stmt, expected) - stmt = select([tbl1, tbl2]).with_hint(tbl1, 'ONLY', 'postgresql') - expected = ('SELECT testtbl1.id, testtbl2.id FROM ONLY testtbl1, ' - 'testtbl2') + stmt = select([tbl1, tbl2]).with_hint(tbl1, "ONLY", "postgresql") + expected = ( + "SELECT testtbl1.id, testtbl2.id FROM ONLY testtbl1, " "testtbl2" + ) self.assert_compile(stmt, expected) - stmt = select([tbl1, tbl2]).with_hint(tbl2, 'ONLY', 'postgresql') - expected = ('SELECT testtbl1.id, testtbl2.id FROM testtbl1, ONLY ' - 
'testtbl2') + stmt = select([tbl1, tbl2]).with_hint(tbl2, "ONLY", "postgresql") + expected = ( + "SELECT testtbl1.id, testtbl2.id FROM testtbl1, ONLY " "testtbl2" + ) self.assert_compile(stmt, expected) stmt = select([tbl1, tbl2]) - stmt = stmt.with_hint(tbl1, 'ONLY', 'postgresql') - stmt = stmt.with_hint(tbl2, 'ONLY', 'postgresql') - expected = ('SELECT testtbl1.id, testtbl2.id FROM ONLY testtbl1, ' - 'ONLY testtbl2') + stmt = stmt.with_hint(tbl1, "ONLY", "postgresql") + stmt = stmt.with_hint(tbl2, "ONLY", "postgresql") + expected = ( + "SELECT testtbl1.id, testtbl2.id FROM ONLY testtbl1, " + "ONLY testtbl2" + ) self.assert_compile(stmt, expected) stmt = update(tbl1, values=dict(id=1)) - stmt = stmt.with_hint('ONLY', dialect_name='postgresql') - expected = 'UPDATE ONLY testtbl1 SET id=%(id)s' + stmt = stmt.with_hint("ONLY", dialect_name="postgresql") + expected = "UPDATE ONLY testtbl1 SET id=%(id)s" self.assert_compile(stmt, expected) stmt = delete(tbl1).with_hint( - 'ONLY', selectable=tbl1, dialect_name='postgresql') - expected = 'DELETE FROM ONLY testtbl1' + "ONLY", selectable=tbl1, dialect_name="postgresql" + ) + expected = "DELETE FROM ONLY testtbl1" self.assert_compile(stmt, expected) - tbl3 = Table('testtbl3', m, Column('id', Integer), schema='testschema') - stmt = tbl3.select().with_hint(tbl3, 'ONLY', 'postgresql') - expected = 'SELECT testschema.testtbl3.id FROM '\ - 'ONLY testschema.testtbl3' + tbl3 = Table("testtbl3", m, Column("id", Integer), schema="testschema") + stmt = tbl3.select().with_hint(tbl3, "ONLY", "postgresql") + expected = ( + "SELECT testschema.testtbl3.id FROM " "ONLY testschema.testtbl3" + ) self.assert_compile(stmt, expected) assert_raises( exc.CompileError, tbl3.select().with_hint(tbl3, "FAKE", "postgresql").compile, - dialect=postgresql.dialect() + dialect=postgresql.dialect(), ) def test_aggregate_order_by_one(self): m = MetaData() - table = Table('table1', m, Column('a', Integer), Column('b', Integer)) + table = Table("table1", m, Column("a", Integer), Column("b", Integer)) expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) stmt = select([expr]) @@ -1082,32 +1256,31 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "SELECT array_agg(table1.a ORDER BY table1.b DESC) " - "AS array_agg_1 FROM table1" + "AS array_agg_1 FROM table1", ) def test_aggregate_order_by_two(self): m = MetaData() - table = Table('table1', m, Column('a', Integer), Column('b', Integer)) + table = Table("table1", m, Column("a", Integer), Column("b", Integer)) expr = func.string_agg( - table.c.a, - aggregate_order_by(literal_column("','"), table.c.a) + table.c.a, aggregate_order_by(literal_column("','"), table.c.a) ) stmt = select([expr]) self.assert_compile( stmt, "SELECT string_agg(table1.a, ',' ORDER BY table1.a) " - "AS string_agg_1 FROM table1" + "AS string_agg_1 FROM table1", ) def test_aggregate_order_by_multi_col(self): m = MetaData() - table = Table('table1', m, Column('a', Integer), Column('b', Integer)) + table = Table("table1", m, Column("a", Integer), Column("b", Integer)) expr = func.string_agg( table.c.a, aggregate_order_by( - literal_column("','"), - table.c.a, table.c.b.desc()) + literal_column("','"), table.c.a, table.c.b.desc() + ), ) stmt = select([expr]) @@ -1115,75 +1288,77 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): stmt, "SELECT string_agg(table1.a, " "',' ORDER BY table1.a, table1.b DESC) " - "AS string_agg_1 FROM table1" + "AS string_agg_1 FROM table1", ) def 
test_aggregate_orcer_by_no_arg(self): assert_raises_message( TypeError, "at least one ORDER BY element is required", - aggregate_order_by, literal_column("','") + aggregate_order_by, + literal_column("','"), ) def test_pg_array_agg_implicit_pg_array(self): - expr = pg_array_agg(column('data', Integer)) + expr = pg_array_agg(column("data", Integer)) assert isinstance(expr.type, PG_ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_pg_array_agg_uses_base_array(self): - expr = pg_array_agg(column('data', sqltypes.ARRAY(Integer))) + expr = pg_array_agg(column("data", sqltypes.ARRAY(Integer))) assert isinstance(expr.type, sqltypes.ARRAY) assert not isinstance(expr.type, PG_ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_pg_array_agg_uses_pg_array(self): - expr = pg_array_agg(column('data', PG_ARRAY(Integer))) + expr = pg_array_agg(column("data", PG_ARRAY(Integer))) assert isinstance(expr.type, PG_ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_pg_array_agg_explicit_base_array(self): - expr = pg_array_agg(column( - 'data', sqltypes.ARRAY(Integer)), type_=sqltypes.ARRAY(Integer)) + expr = pg_array_agg( + column("data", sqltypes.ARRAY(Integer)), + type_=sqltypes.ARRAY(Integer), + ) assert isinstance(expr.type, sqltypes.ARRAY) assert not isinstance(expr.type, PG_ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_pg_array_agg_explicit_pg_array(self): - expr = pg_array_agg(column( - 'data', sqltypes.ARRAY(Integer)), type_=PG_ARRAY(Integer)) + expr = pg_array_agg( + column("data", sqltypes.ARRAY(Integer)), type_=PG_ARRAY(Integer) + ) assert isinstance(expr.type, PG_ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_aggregate_order_by_adapt(self): m = MetaData() - table = Table('table1', m, Column('a', Integer), Column('b', Integer)) + table = Table("table1", m, Column("a", Integer), Column("b", Integer)) expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) stmt = select([expr]) - a1 = table.alias('foo') + a1 = table.alias("foo") stmt2 = sql_util.ClauseAdapter(a1).traverse(stmt) self.assert_compile( stmt2, "SELECT array_agg(foo.a ORDER BY foo.b DESC) AS array_agg_1 " - "FROM table1 AS foo" + "FROM table1 AS foo", ) def test_delete_extra_froms(self): - t1 = table('t1', column('c1')) - t2 = table('t2', column('c1')) + t1 = table("t1", column("c1")) + t2 = table("t2", column("c1")) q = delete(t1).where(t1.c.c1 == t2.c.c1) - self.assert_compile( - q, "DELETE FROM t1 USING t2 WHERE t1.c1 = t2.c1" - ) + self.assert_compile(q, "DELETE FROM t1 USING t2 WHERE t1.c1 = t2.c1") def test_delete_extra_froms_alias(self): - a1 = table('t1', column('c1')).alias('a1') - t2 = table('t2', column('c1')) + a1 = table("t1", column("c1")).alias("a1") + t2 = table("t2", column("c1")) q = delete(a1).where(a1.c.c1 == t2.c.c1) self.assert_compile( q, "DELETE FROM t1 AS a1 USING t2 WHERE a1.c1 = t2.c1" @@ -1195,204 +1370,217 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): def setup(self): self.table1 = table1 = table( - 'mytable', - column('myid', Integer), - column('name', String(128)), - column('description', String(128)), + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), ) md = MetaData() self.table_with_metadata = Table( - 'mytable', md, - Column('myid', Integer, primary_key=True), - Column('name', String(128)), - Column('description', String(128)) + "mytable", + md, + Column("myid", Integer, primary_key=True), + Column("name", String(128)), + 
Column("description", String(128)), ) self.unique_constr = schema.UniqueConstraint( - table1.c.name, name='uq_name') + table1.c.name, name="uq_name" + ) self.excl_constr = ExcludeConstraint( - (table1.c.name, '='), - (table1.c.description, '&&'), - name='excl_thing' + (table1.c.name, "="), + (table1.c.description, "&&"), + name="excl_thing", ) self.excl_constr_anon = ExcludeConstraint( - (self.table_with_metadata.c.name, '='), - (self.table_with_metadata.c.description, '&&'), - where=self.table_with_metadata.c.description != 'foo' + (self.table_with_metadata.c.name, "="), + (self.table_with_metadata.c.description, "&&"), + where=self.table_with_metadata.c.description != "foo", ) self.goofy_index = Index( - 'goofy_index', table1.c.name, - postgresql_where=table1.c.name > 'm' + "goofy_index", table1.c.name, postgresql_where=table1.c.name > "m" ) def test_do_nothing_no_target(self): i = insert( - self.table1, values=dict(name='foo'), + self.table1, values=dict(name="foo") ).on_conflict_do_nothing() - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) ON CONFLICT DO NOTHING') + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT DO NOTHING", + ) def test_do_nothing_index_elements_target(self): i = insert( - self.table1, values=dict(name='foo'), - ).on_conflict_do_nothing( - index_elements=['myid'], - ) + self.table1, values=dict(name="foo") + ).on_conflict_do_nothing(index_elements=["myid"]) self.assert_compile( i, "INSERT INTO mytable (name) VALUES " - "(%(name)s) ON CONFLICT (myid) DO NOTHING" + "(%(name)s) ON CONFLICT (myid) DO NOTHING", ) def test_do_update_set_clause_none(self): - i = insert(self.table_with_metadata).values(myid=1, name='foo') + i = insert(self.table_with_metadata).values(myid=1, name="foo") i = i.on_conflict_do_update( - index_elements=['myid'], - set_=OrderedDict([ - ('name', "I'm a name"), - ('description', None)]) + index_elements=["myid"], + set_=OrderedDict([("name", "I'm a name"), ("description", None)]), ) self.assert_compile( i, - 'INSERT INTO mytable (myid, name) VALUES ' - '(%(myid)s, %(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = %(param_1)s, ' - 'description = %(param_2)s', - {"myid": 1, "name": "foo", - "param_1": "I'm a name", "param_2": None} - + "INSERT INTO mytable (myid, name) VALUES " + "(%(myid)s, %(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = %(param_1)s, " + "description = %(param_2)s", + { + "myid": 1, + "name": "foo", + "param_1": "I'm a name", + "param_2": None, + }, ) def test_do_update_set_clause_literal(self): - i = insert(self.table_with_metadata).values(myid=1, name='foo') + i = insert(self.table_with_metadata).values(myid=1, name="foo") i = i.on_conflict_do_update( - index_elements=['myid'], - set_=OrderedDict([ - ('name', "I'm a name"), - ('description', null())]) + index_elements=["myid"], + set_=OrderedDict( + [("name", "I'm a name"), ("description", null())] + ), ) self.assert_compile( i, - 'INSERT INTO mytable (myid, name) VALUES ' - '(%(myid)s, %(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = %(param_1)s, ' - 'description = NULL', - {"myid": 1, "name": "foo", "param_1": "I'm a name"} - + "INSERT INTO mytable (myid, name) VALUES " + "(%(myid)s, %(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = %(param_1)s, " + "description = NULL", + {"myid": 1, "name": "foo", "param_1": "I'm a name"}, ) def test_do_update_str_index_elements_target_one(self): - i = insert(self.table_with_metadata).values(myid=1, name='foo') + i = 
insert(self.table_with_metadata).values(myid=1, name="foo") i = i.on_conflict_do_update( - index_elements=['myid'], - set_=OrderedDict([ - ('name', i.excluded.name), - ('description', i.excluded.description)]) + index_elements=["myid"], + set_=OrderedDict( + [ + ("name", i.excluded.name), + ("description", i.excluded.description), + ] + ), + ) + self.assert_compile( + i, + "INSERT INTO mytable (myid, name) VALUES " + "(%(myid)s, %(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = excluded.name, " + "description = excluded.description", ) - self.assert_compile(i, - 'INSERT INTO mytable (myid, name) VALUES ' - '(%(myid)s, %(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = excluded.name, ' - 'description = excluded.description') def test_do_update_str_index_elements_target_two(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( - index_elements=['myid'], - set_=dict(name=i.excluded.name) + index_elements=["myid"], set_=dict(name=i.excluded.name) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = excluded.name') def test_do_update_col_index_elements_target(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( index_elements=[self.table1.c.myid], - set_=dict(name=i.excluded.name) + set_=dict(name=i.excluded.name), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = excluded.name') def test_do_update_unnamed_pk_constraint_target(self): - i = insert( - self.table_with_metadata, values=dict(myid=1, name='foo')) + i = insert(self.table_with_metadata, values=dict(myid=1, name="foo")) i = i.on_conflict_do_update( constraint=self.table_with_metadata.primary_key, - set_=dict(name=i.excluded.name) + set_=dict(name=i.excluded.name), + ) + self.assert_compile( + i, + "INSERT INTO mytable (myid, name) VALUES " + "(%(myid)s, %(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (myid, name) VALUES ' - '(%(myid)s, %(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = excluded.name') def test_do_update_pk_constraint_index_elements_target(self): - i = insert( - self.table_with_metadata, values=dict(myid=1, name='foo')) + i = insert(self.table_with_metadata, values=dict(myid=1, name="foo")) i = i.on_conflict_do_update( index_elements=self.table_with_metadata.primary_key, - set_=dict(name=i.excluded.name) + set_=dict(name=i.excluded.name), + ) + self.assert_compile( + i, + "INSERT INTO mytable (myid, name) VALUES " + "(%(myid)s, %(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (myid, name) VALUES ' - '(%(myid)s, %(name)s) ON CONFLICT (myid) ' - 'DO UPDATE SET name = excluded.name') def test_do_update_named_unique_constraint_target(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( - constraint=self.unique_constr, - set_=dict(myid=i.excluded.myid) + 
constraint=self.unique_constr, set_=dict(myid=i.excluded.myid) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT ON CONSTRAINT uq_name " + "DO UPDATE SET myid = excluded.myid", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) ON CONFLICT ON CONSTRAINT uq_name ' - 'DO UPDATE SET myid = excluded.myid') def test_do_update_string_constraint_target(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( - constraint=self.unique_constr.name, - set_=dict(myid=i.excluded.myid) + constraint=self.unique_constr.name, set_=dict(myid=i.excluded.myid) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT ON CONSTRAINT uq_name " + "DO UPDATE SET myid = excluded.myid", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - '(%(name)s) ON CONFLICT ON CONSTRAINT uq_name ' - 'DO UPDATE SET myid = excluded.myid') def test_do_update_index_elements_where_target(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( index_elements=self.goofy_index.expressions, - index_where=self.goofy_index.dialect_options[ - 'postgresql']['where'], - set_=dict(name=i.excluded.name) + index_where=self.goofy_index.dialect_options["postgresql"][ + "where" + ], + set_=dict(name=i.excluded.name), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name) " + "WHERE name > %(name_1)s " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - "(%(name)s) ON CONFLICT (name) " - "WHERE name > %(name_1)s " - 'DO UPDATE SET name = excluded.name') def test_do_update_index_elements_where_target_multivalues(self): i = insert( self.table1, - values=[dict(name='foo'), dict(name='bar'), dict(name='bat')]) + values=[dict(name="foo"), dict(name="bar"), dict(name="bat")], + ) i = i.on_conflict_do_update( index_elements=self.goofy_index.expressions, - index_where=self.goofy_index.dialect_options[ - 'postgresql']['where'], - set_=dict(name=i.excluded.name) + index_where=self.goofy_index.dialect_options["postgresql"][ + "where" + ], + set_=dict(name=i.excluded.name), ) self.assert_compile( i, @@ -1402,107 +1590,116 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE name > %(name_1)s " "DO UPDATE SET name = excluded.name", checkparams={ - 'name_1': 'm', 'name_m0': 'foo', - 'name_m1': 'bar', 'name_m2': 'bat'} + "name_1": "m", + "name_m0": "foo", + "name_m1": "bar", + "name_m2": "bat", + }, ) def test_do_update_unnamed_index_target(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) unnamed_goofy = Index( - None, self.table1.c.name, - postgresql_where=self.table1.c.name > 'm' + None, self.table1.c.name, postgresql_where=self.table1.c.name > "m" ) i = i.on_conflict_do_update( - constraint=unnamed_goofy, - set_=dict(name=i.excluded.name) + constraint=unnamed_goofy, set_=dict(name=i.excluded.name) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name) " + "WHERE name > %(name_1)s " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - "(%(name)s) ON CONFLICT (name) " - "WHERE name > %(name_1)s " - 'DO UPDATE SET name = excluded.name') def 
test_do_update_unnamed_exclude_constraint_target(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( - constraint=self.excl_constr_anon, - set_=dict(name=i.excluded.name) + constraint=self.excl_constr_anon, set_=dict(name=i.excluded.name) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != %(description_1)s " + "DO UPDATE SET name = excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - "(%(name)s) ON CONFLICT (name, description) " - "WHERE description != %(description_1)s " - 'DO UPDATE SET name = excluded.name') def test_do_update_add_whereclause(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( constraint=self.excl_constr_anon, set_=dict(name=i.excluded.name), where=( - (self.table1.c.name != 'brah') & - (self.table1.c.description != 'brah')) + (self.table1.c.name != "brah") + & (self.table1.c.description != "brah") + ), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != %(description_1)s " + "DO UPDATE SET name = excluded.name " + "WHERE mytable.name != %(name_1)s " + "AND mytable.description != %(description_2)s", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - "(%(name)s) ON CONFLICT (name, description) " - "WHERE description != %(description_1)s " - 'DO UPDATE SET name = excluded.name ' - "WHERE mytable.name != %(name_1)s " - "AND mytable.description != %(description_2)s") def test_do_update_add_whereclause_references_excluded(self): - i = insert( - self.table1, values=dict(name='foo')) + i = insert(self.table1, values=dict(name="foo")) i = i.on_conflict_do_update( constraint=self.excl_constr_anon, set_=dict(name=i.excluded.name), - where=( - (self.table1.c.name != i.excluded.name)) + where=((self.table1.c.name != i.excluded.name)), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != %(description_1)s " + "DO UPDATE SET name = excluded.name " + "WHERE mytable.name != excluded.name", ) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - "(%(name)s) ON CONFLICT (name, description) " - "WHERE description != %(description_1)s " - 'DO UPDATE SET name = excluded.name ' - "WHERE mytable.name != excluded.name") def test_do_update_additional_colnames(self): - i = insert( - self.table1, values=dict(name='bar')) + i = insert(self.table1, values=dict(name="bar")) i = i.on_conflict_do_update( constraint=self.excl_constr_anon, - set_=dict(name='somename', unknown='unknown') + set_=dict(name="somename", unknown="unknown"), ) with expect_warnings( - "Additional column names not matching any " - "column keys in table 'mytable': 'unknown'"): - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES ' - "(%(name)s) ON CONFLICT (name, description) " - "WHERE description != %(description_1)s " - "DO UPDATE SET name = %(param_1)s, " - "unknown = %(param_2)s", - checkparams={ - "name": "bar", - "description_1": "foo", - "param_1": "somename", - "param_2": "unknown"}) + "Additional column names not matching any " + "column keys in table 'mytable': 'unknown'" + ): + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE 
description != %(description_1)s " + "DO UPDATE SET name = %(param_1)s, " + "unknown = %(param_2)s", + checkparams={ + "name": "bar", + "description_1": "foo", + "param_1": "somename", + "param_2": "unknown", + }, + ) def test_on_conflict_as_cte(self): - i = insert( - self.table1, values=dict(name='foo')) - i = i.on_conflict_do_update( - constraint=self.excl_constr_anon, - set_=dict(name=i.excluded.name), - where=( - (self.table1.c.name != i.excluded.name)) - ).returning(literal_column("1")).cte("i_upsert") + i = insert(self.table1, values=dict(name="foo")) + i = ( + i.on_conflict_do_update( + constraint=self.excl_constr_anon, + set_=dict(name=i.excluded.name), + where=((self.table1.c.name != i.excluded.name)), + ) + .returning(literal_column("1")) + .cte("i_upsert") + ) stmt = select([i]) @@ -1515,19 +1712,24 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): "DO UPDATE SET name = excluded.name " "WHERE mytable.name != excluded.name RETURNING 1) " "SELECT i_upsert.1 " - "FROM i_upsert" + "FROM i_upsert", ) def test_quote_raw_string_col(self): - t = table('t', column("FancyName"), column("other name")) + t = table("t", column("FancyName"), column("other name")) - stmt = insert(t).values(FancyName='something new').\ - on_conflict_do_update( - index_elements=['FancyName', 'other name'], - set_=OrderedDict([ - ("FancyName", 'something updated'), - ("other name", "something else") - ]) + stmt = ( + insert(t) + .values(FancyName="something new") + .on_conflict_do_update( + index_elements=["FancyName", "other name"], + set_=OrderedDict( + [ + ("FancyName", "something updated"), + ("other name", "something else"), + ] + ), + ) ) self.assert_compile( @@ -1536,8 +1738,11 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): 'ON CONFLICT ("FancyName", "other name") ' 'DO UPDATE SET "FancyName" = %(param_1)s, ' '"other name" = %(param_2)s', - {'param_1': 'something updated', - 'param_2': 'something else', 'FancyName': 'something new'} + { + "param_1": "something updated", + "param_2": "something else", + "FancyName": "something new", + }, ) @@ -1547,68 +1752,71 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL): an emphasis on PG's 'DISTINCT ON' syntax. """ + __dialect__ = postgresql.dialect() def setup(self): - self.table = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('a', String), - Column('b', String), - ) + self.table = Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column("a", String), + Column("b", String), + ) def test_plain_generative(self): self.assert_compile( select([self.table]).distinct(), - "SELECT DISTINCT t.id, t.a, t.b FROM t" + "SELECT DISTINCT t.id, t.a, t.b FROM t", ) def test_on_columns_generative(self): self.assert_compile( select([self.table]).distinct(self.table.c.a), - "SELECT DISTINCT ON (t.a) t.id, t.a, t.b FROM t" + "SELECT DISTINCT ON (t.a) t.id, t.a, t.b FROM t", ) def test_on_columns_generative_multi_call(self): self.assert_compile( - select([self.table]).distinct(self.table.c.a). 
- distinct(self.table.c.b), - "SELECT DISTINCT ON (t.a, t.b) t.id, t.a, t.b FROM t" + select([self.table]) + .distinct(self.table.c.a) + .distinct(self.table.c.b), + "SELECT DISTINCT ON (t.a, t.b) t.id, t.a, t.b FROM t", ) def test_plain_inline(self): self.assert_compile( select([self.table], distinct=True), - "SELECT DISTINCT t.id, t.a, t.b FROM t" + "SELECT DISTINCT t.id, t.a, t.b FROM t", ) def test_on_columns_inline_list(self): self.assert_compile( - select([self.table], - distinct=[self.table.c.a, self.table.c.b]). - order_by(self.table.c.a, self.table.c.b), + select( + [self.table], distinct=[self.table.c.a, self.table.c.b] + ).order_by(self.table.c.a, self.table.c.b), "SELECT DISTINCT ON (t.a, t.b) t.id, " - "t.a, t.b FROM t ORDER BY t.a, t.b" + "t.a, t.b FROM t ORDER BY t.a, t.b", ) def test_on_columns_inline_scalar(self): self.assert_compile( select([self.table], distinct=self.table.c.a), - "SELECT DISTINCT ON (t.a) t.id, t.a, t.b FROM t" + "SELECT DISTINCT ON (t.a) t.id, t.a, t.b FROM t", ) def test_literal_binds(self): self.assert_compile( select([self.table]).distinct(self.table.c.a == 10), "SELECT DISTINCT ON (t.a = 10) t.id, t.a, t.b FROM t", - literal_binds=True + literal_binds=True, ) def test_query_plain(self): sess = Session() self.assert_compile( sess.query(self.table).distinct(), - "SELECT DISTINCT t.id AS t_id, t.a AS t_a, " - "t.b AS t_b FROM t" + "SELECT DISTINCT t.id AS t_id, t.a AS t_a, " "t.b AS t_b FROM t", ) def test_query_on_columns(self): @@ -1616,16 +1824,17 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(self.table).distinct(self.table.c.a), "SELECT DISTINCT ON (t.a) t.id AS t_id, t.a AS t_a, " - "t.b AS t_b FROM t" + "t.b AS t_b FROM t", ) def test_query_on_columns_multi_call(self): sess = Session() self.assert_compile( - sess.query(self.table).distinct(self.table.c.a). 
- distinct(self.table.c.b), + sess.query(self.table) + .distinct(self.table.c.a) + .distinct(self.table.c.b), "SELECT DISTINCT ON (t.a, t.b) t.id AS t_id, t.a AS t_a, " - "t.b AS t_b FROM t" + "t.b AS t_b FROM t", ) def test_query_on_columns_subquery(self): @@ -1633,6 +1842,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL): class Foo(object): pass + mapper(Foo, self.table) sess = Session() self.assert_compile( @@ -1640,45 +1850,50 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT DISTINCT ON (anon_1.t_a, anon_1.t_b) anon_1.t_id " "AS anon_1_t_id, anon_1.t_a AS anon_1_t_a, anon_1.t_b " "AS anon_1_t_b FROM (SELECT t.id AS t_id, t.a AS t_a, " - "t.b AS t_b FROM t) AS anon_1" + "t.b AS t_b FROM t) AS anon_1", ) def test_query_distinct_on_aliased(self): class Foo(object): pass + mapper(Foo, self.table) a1 = aliased(Foo) sess = Session() self.assert_compile( sess.query(a1).distinct(a1.a), "SELECT DISTINCT ON (t_1.a) t_1.id AS t_1_id, " - "t_1.a AS t_1_a, t_1.b AS t_1_b FROM t AS t_1" + "t_1.a AS t_1_a, t_1.b AS t_1_b FROM t AS t_1", ) def test_distinct_on_subquery_anon(self): sq = select([self.table]).alias() - q = select([self.table.c.id, sq.c.id]).\ - distinct(sq.c.id).\ - where(self.table.c.id == sq.c.id) + q = ( + select([self.table.c.id, sq.c.id]) + .distinct(sq.c.id) + .where(self.table.c.id == sq.c.id) + ) self.assert_compile( q, "SELECT DISTINCT ON (anon_1.id) t.id, anon_1.id " "FROM t, (SELECT t.id AS id, t.a AS a, t.b " - "AS b FROM t) AS anon_1 WHERE t.id = anon_1.id" + "AS b FROM t) AS anon_1 WHERE t.id = anon_1.id", ) def test_distinct_on_subquery_named(self): - sq = select([self.table]).alias('sq') - q = select([self.table.c.id, sq.c.id]).\ - distinct(sq.c.id).\ - where(self.table.c.id == sq.c.id) + sq = select([self.table]).alias("sq") + q = ( + select([self.table.c.id, sq.c.id]) + .distinct(sq.c.id) + .where(self.table.c.id == sq.c.id) + ) self.assert_compile( q, "SELECT DISTINCT ON (sq.id) t.id, sq.id " "FROM t, (SELECT t.id AS id, t.a AS a, " - "t.b AS b FROM t) AS sq WHERE t.id = sq.id" + "t.b AS b FROM t) AS sq WHERE t.id = sq.id", ) @@ -1686,18 +1901,23 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): """Tests for full text searching """ + __dialect__ = postgresql.dialect() def setup(self): - self.table = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('title', String), - Column('body', String), - ) - self.table_alt = table('mytable', - column('id', Integer), - column('title', String(128)), - column('body', String(128))) + self.table = Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column("title", String), + Column("body", String), + ) + self.table_alt = table( + "mytable", + column("id", Integer), + column("title", String(128)), + column("body", String(128)), + ) def _raise_query(self, q): """ @@ -1708,53 +1928,65 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): raise ValueError(c) def test_match_basic(self): - s = select([self.table_alt.c.id])\ - .where(self.table_alt.c.title.match('somestring')) - self.assert_compile(s, - 'SELECT mytable.id ' - 'FROM mytable ' - 'WHERE mytable.title @@ to_tsquery(%(title_1)s)') + s = select([self.table_alt.c.id]).where( + self.table_alt.c.title.match("somestring") + ) + self.assert_compile( + s, + "SELECT mytable.id " + "FROM mytable " + "WHERE mytable.title @@ to_tsquery(%(title_1)s)", + ) def test_match_regconfig(self): s = select([self.table_alt.c.id]).where( self.table_alt.c.title.match( - 'somestring', - 
postgresql_regconfig='english') + "somestring", postgresql_regconfig="english" + ) ) self.assert_compile( - s, 'SELECT mytable.id ' - 'FROM mytable ' - """WHERE mytable.title @@ to_tsquery('english', %(title_1)s)""") + s, + "SELECT mytable.id " + "FROM mytable " + """WHERE mytable.title @@ to_tsquery('english', %(title_1)s)""", + ) def test_match_tsvector(self): s = select([self.table_alt.c.id]).where( - func.to_tsvector(self.table_alt.c.title) - .match('somestring') + func.to_tsvector(self.table_alt.c.title).match("somestring") ) self.assert_compile( - s, 'SELECT mytable.id ' - 'FROM mytable ' - 'WHERE to_tsvector(mytable.title) ' - '@@ to_tsquery(%(to_tsvector_1)s)') + s, + "SELECT mytable.id " + "FROM mytable " + "WHERE to_tsvector(mytable.title) " + "@@ to_tsquery(%(to_tsvector_1)s)", + ) def test_match_tsvectorconfig(self): s = select([self.table_alt.c.id]).where( - func.to_tsvector('english', self.table_alt.c.title) - .match('somestring') + func.to_tsvector("english", self.table_alt.c.title).match( + "somestring" + ) ) self.assert_compile( - s, 'SELECT mytable.id ' - 'FROM mytable ' - 'WHERE to_tsvector(%(to_tsvector_1)s, mytable.title) @@ ' - 'to_tsquery(%(to_tsvector_2)s)') + s, + "SELECT mytable.id " + "FROM mytable " + "WHERE to_tsvector(%(to_tsvector_1)s, mytable.title) @@ " + "to_tsquery(%(to_tsvector_2)s)", + ) def test_match_tsvectorconfig_regconfig(self): s = select([self.table_alt.c.id]).where( - func.to_tsvector('english', self.table_alt.c.title) - .match('somestring', postgresql_regconfig='english') + func.to_tsvector("english", self.table_alt.c.title).match( + "somestring", postgresql_regconfig="english" + ) ) self.assert_compile( - s, 'SELECT mytable.id ' - 'FROM mytable ' - 'WHERE to_tsvector(%(to_tsvector_1)s, mytable.title) @@ ' - """to_tsquery('english', %(to_tsvector_2)s)""") + s, + "SELECT mytable.id " + "FROM mytable " + "WHERE to_tsvector(%(to_tsvector_1)s, mytable.title) @@ " + """to_tsquery('english', %(to_tsvector_2)s)""", + ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 699a8aaaf3..82dd6d3ff6 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1,15 +1,35 @@ # coding: utf-8 from sqlalchemy.testing.assertions import ( - eq_, assert_raises, assert_raises_message, AssertsExecutionResults, - AssertsCompiledSQL) + eq_, + assert_raises, + assert_raises_message, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing import datetime from sqlalchemy import ( - Table, Column, select, MetaData, text, Integer, String, Sequence, Numeric, - DateTime, BigInteger, func, extract, SmallInteger, TypeDecorator, literal, - cast, bindparam) + Table, + Column, + select, + MetaData, + text, + Integer, + String, + Sequence, + Numeric, + DateTime, + BigInteger, + func, + extract, + SmallInteger, + TypeDecorator, + literal, + cast, + bindparam, +) from sqlalchemy import exc, schema from sqlalchemy.dialects.postgresql import base as postgresql import logging @@ -27,65 +47,77 @@ class DialectTest(fixtures.TestBase): """python-side dialect tests. 
""" def test_version_parsing(self): - def mock_conn(res): return Mock( - execute=Mock(return_value=Mock(scalar=Mock(return_value=res)))) + execute=Mock(return_value=Mock(scalar=Mock(return_value=res))) + ) dialect = postgresql.dialect() for string, version in [ - ( - 'PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by ' - 'GCC gcc (GCC) 4.1.2 20070925 (Red Hat 4.1.2-33)', - (8, 3, 8)), - ( - 'PostgreSQL 8.5devel on x86_64-unknown-linux-gnu, ' - 'compiled by GCC gcc (GCC) 4.4.2, 64-bit', (8, 5)), - ( - 'EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, ' - 'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), ' - '64-bit', (9, 1, 2)), - ( - '[PostgreSQL 9.2.4 ] VMware vFabric Postgres 9.2.4.0 ' - 'release build 1080137', (9, 2, 4)), - ( - 'PostgreSQL 10devel on x86_64-pc-linux-gnu' - 'compiled by gcc (GCC) 6.3.1 20170306, 64-bit', (10,)), - ( - 'PostgreSQL 10beta1 on x86_64-pc-linux-gnu, ' - 'compiled by gcc (GCC) 4.8.5 20150623 ' - '(Red Hat 4.8.5-11), 64-bit', (10,)) + ( + "PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by " + "GCC gcc (GCC) 4.1.2 20070925 (Red Hat 4.1.2-33)", + (8, 3, 8), + ), + ( + "PostgreSQL 8.5devel on x86_64-unknown-linux-gnu, " + "compiled by GCC gcc (GCC) 4.4.2, 64-bit", + (8, 5), + ), + ( + "EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, " + "compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), " + "64-bit", + (9, 1, 2), + ), + ( + "[PostgreSQL 9.2.4 ] VMware vFabric Postgres 9.2.4.0 " + "release build 1080137", + (9, 2, 4), + ), + ( + "PostgreSQL 10devel on x86_64-pc-linux-gnu" + "compiled by gcc (GCC) 6.3.1 20170306, 64-bit", + (10,), + ), + ( + "PostgreSQL 10beta1 on x86_64-pc-linux-gnu, " + "compiled by gcc (GCC) 4.8.5 20150623 " + "(Red Hat 4.8.5-11), 64-bit", + (10,), + ), ]: - eq_(dialect._get_server_version_info(mock_conn(string)), - version) + eq_(dialect._get_server_version_info(mock_conn(string)), version) def test_deprecated_dialect_name_still_loads(self): dialects.registry.clear() with expect_deprecated( - "The 'postgres' dialect name " - "has been renamed to 'postgresql'"): + "The 'postgres' dialect name " "has been renamed to 'postgresql'" + ): dialect = url.URL("postgres").get_dialect() is_(dialect, postgresql.dialect) @testing.requires.psycopg2_compatibility def test_pg_dialect_use_native_unicode_from_config(self): config = { - 'sqlalchemy.url': testing.db.url, - 'sqlalchemy.use_native_unicode': "false"} + "sqlalchemy.url": testing.db.url, + "sqlalchemy.use_native_unicode": "false", + } e = engine_from_config(config, _initialize=False) eq_(e.dialect.use_native_unicode, False) config = { - 'sqlalchemy.url': testing.db.url, - 'sqlalchemy.use_native_unicode': "true"} + "sqlalchemy.url": testing.db.url, + "sqlalchemy.use_native_unicode": "true", + } e = engine_from_config(config, _initialize=False) eq_(e.dialect.use_native_unicode, True) class BatchInsertsTest(fixtures.TablesTest): - __only_on__ = 'postgresql+psycopg2' + __only_on__ = "postgresql+psycopg2" __backend__ = True run_create_tables = "each" @@ -93,11 +125,12 @@ class BatchInsertsTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table( - 'data', metadata, - Column('id', Integer, primary_key=True), - Column('x', String), - Column('y', String), - Column('z', Integer, server_default="5") + "data", + metadata, + Column("id", Integer, primary_key=True), + Column("x", String), + Column("y", String), + Column("z", Integer, server_default="5"), ) def setup(self): @@ -115,17 +148,13 @@ class BatchInsertsTest(fixtures.TablesTest): [ {"x": "x1", "y": "y1"}, 
{"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"} - ] + {"x": "x3", "y": "y3"}, + ], ) eq_( conn.execute(select([self.tables.data])).fetchall(), - [ - (1, "x1", "y1", 5), - (2, "x2", "y2", 5), - (3, "x3", "y3", 5) - ] + [(1, "x1", "y1", 5), (2, "x2", "y2", 5), (3, "x3", "y3", 5)], ) def test_not_sane_rowcount(self): @@ -139,55 +168,52 @@ class BatchInsertsTest(fixtures.TablesTest): [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"} - ] + {"x": "x3", "y": "y3"}, + ], ) conn.execute( - self.tables.data.update(). - where(self.tables.data.c.x == bindparam('xval')). - values(y=bindparam('yval')), - [ - {"xval": "x1", "yval": "y5"}, - {"xval": "x3", "yval": "y6"} - ] + self.tables.data.update() + .where(self.tables.data.c.x == bindparam("xval")) + .values(y=bindparam("yval")), + [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}], ) eq_( conn.execute( - select([self.tables.data]). - order_by(self.tables.data.c.id)). - fetchall(), - [ - (1, "x1", "y5", 5), - (2, "x2", "y2", 5), - (3, "x3", "y6", 5) - ] + select([self.tables.data]).order_by(self.tables.data.c.id) + ).fetchall(), + [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)], ) class MiscBackendTest( - fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): + fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL +): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @testing.provide_metadata def test_date_reflection(self): metadata = self.metadata Table( - 'pgdate', metadata, Column('date1', DateTime(timezone=True)), - Column('date2', DateTime(timezone=False))) + "pgdate", + metadata, + Column("date1", DateTime(timezone=True)), + Column("date2", DateTime(timezone=False)), + ) metadata.create_all() m2 = MetaData(testing.db) - t2 = Table('pgdate', m2, autoload=True) + t2 = Table("pgdate", m2, autoload=True) assert t2.c.date1.type.timezone is True assert t2.c.date2.type.timezone is False @testing.requires.psycopg2_compatibility def test_psycopg2_version(self): v = testing.db.dialect.psycopg2_version - assert testing.db.dialect.dbapi.__version__.\ - startswith(".".join(str(x) for x in v)) + assert testing.db.dialect.dbapi.__version__.startswith( + ".".join(str(x) for x in v) + ) @testing.requires.psycopg2_compatibility def test_psycopg2_non_standard_err(self): @@ -198,8 +224,11 @@ class MiscBackendTest( ).extensions.TransactionRollbackError exception = exc.DBAPIError.instance( - "some statement", {}, TransactionRollbackError("foo"), - psycopg2.Error) + "some statement", + {}, + TransactionRollbackError("foo"), + psycopg2.Error, + ) assert isinstance(exception, exc.OperationalError) # currently not passing with pg 9.3 that does not seem to generate @@ -207,7 +236,7 @@ class MiscBackendTest( @testing.requires.no_coverage @testing.requires.psycopg2_compatibility def _test_notice_logging(self): - log = logging.getLogger('sqlalchemy.dialects.postgresql') + log = logging.getLogger("sqlalchemy.dialects.postgresql") buf = logging.handlers.BufferingHandler(100) lev = log.level log.addHandler(buf) @@ -216,15 +245,15 @@ class MiscBackendTest( conn = testing.db.connect() trans = conn.begin() try: - conn.execute('create table foo (id serial primary key)') + conn.execute("create table foo (id serial primary key)") finally: trans.rollback() finally: log.removeHandler(buf) log.setLevel(lev) - msgs = ' '.join(b.msg for b in buf.buffer) - assert 'will create implicit sequence' in msgs - assert 'will create implicit index' in msgs + msgs = " ".join(b.msg for b in buf.buffer) + 
assert "will create implicit sequence" in msgs + assert "will create implicit index" in msgs @testing.requires.psycopg2_or_pg8000_compatibility @engines.close_open_connections @@ -235,12 +264,12 @@ class MiscBackendTest( # attempt to use an encoding that's not # already set - if current_encoding == 'UTF8': - test_encoding = 'LATIN1' + if current_encoding == "UTF8": + test_encoding = "LATIN1" else: - test_encoding = 'UTF8' + test_encoding = "UTF8" - e = engines.testing_engine(options={'client_encoding': test_encoding}) + e = engines.testing_engine(options={"client_encoding": test_encoding}) c = e.connect() new_encoding = c.execute("show client_encoding").fetchone()[0] eq_(new_encoding, test_encoding) @@ -249,7 +278,8 @@ class MiscBackendTest( @engines.close_open_connections def test_autocommit_isolation_level(self): c = testing.db.connect().execution_options( - isolation_level='AUTOCOMMIT') + isolation_level="AUTOCOMMIT" + ) # If we're really in autocommit mode then we'll get an error saying # that the prepared transaction doesn't exist. Otherwise, we'd # get an error saying that the command can't be run within a @@ -257,57 +287,79 @@ class MiscBackendTest( assert_raises_message( exc.ProgrammingError, 'prepared transaction with identifier "gilberte" does not exist', - c.execute, "commit prepared 'gilberte'") + c.execute, + "commit prepared 'gilberte'", + ) - @testing.fails_on('+zxjdbc', - "Can't infer the SQL type to use for an instance " - "of org.python.core.PyObjectDerived.") + @testing.fails_on( + "+zxjdbc", + "Can't infer the SQL type to use for an instance " + "of org.python.core.PyObjectDerived.", + ) def test_extract(self): - fivedaysago = testing.db.scalar(select([func.now()])) - \ - datetime.timedelta(days=5) - for field, exp in ('year', fivedaysago.year), \ - ('month', fivedaysago.month), ('day', fivedaysago.day): + fivedaysago = testing.db.scalar( + select([func.now()]) + ) - datetime.timedelta(days=5) + for field, exp in ( + ("year", fivedaysago.year), + ("month", fivedaysago.month), + ("day", fivedaysago.day), + ): r = testing.db.execute( - select([ - extract(field, func.now() + datetime.timedelta(days=-5))]) + select( + [extract(field, func.now() + datetime.timedelta(days=-5))] + ) ).scalar() eq_(r, exp) @testing.provide_metadata def test_checksfor_sequence(self): meta1 = self.metadata - seq = Sequence('fooseq') - t = Table( - 'mytable', meta1, - Column('col1', Integer, seq) - ) + seq = Sequence("fooseq") + t = Table("mytable", meta1, Column("col1", Integer, seq)) seq.drop() - testing.db.execute('CREATE SEQUENCE fooseq') + testing.db.execute("CREATE SEQUENCE fooseq") t.create(checkfirst=True) @testing.provide_metadata def test_schema_roundtrips(self): meta = self.metadata users = Table( - 'users', meta, Column( - 'id', Integer, primary_key=True), Column( - 'name', String(50)), schema='test_schema') + "users", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + schema="test_schema", + ) users.create() - users.insert().execute(id=1, name='name1') - users.insert().execute(id=2, name='name2') - users.insert().execute(id=3, name='name3') - users.insert().execute(id=4, name='name4') - eq_(users.select().where(users.c.name == 'name2') - .execute().fetchall(), [(2, 'name2')]) - eq_(users.select(use_labels=True).where( - users.c.name == 'name2').execute().fetchall(), [(2, 'name2')]) + users.insert().execute(id=1, name="name1") + users.insert().execute(id=2, name="name2") + users.insert().execute(id=3, name="name3") + users.insert().execute(id=4, 
name="name4") + eq_( + users.select().where(users.c.name == "name2").execute().fetchall(), + [(2, "name2")], + ) + eq_( + users.select(use_labels=True) + .where(users.c.name == "name2") + .execute() + .fetchall(), + [(2, "name2")], + ) users.delete().where(users.c.id == 3).execute() - eq_(users.select().where(users.c.name == 'name3') - .execute().fetchall(), []) - users.update().where(users.c.name == 'name4' - ).execute(name='newname') - eq_(users.select(use_labels=True).where( - users.c.id == 4).execute().fetchall(), [(4, 'newname')]) + eq_( + users.select().where(users.c.name == "name3").execute().fetchall(), + [], + ) + users.update().where(users.c.name == "name4").execute(name="newname") + eq_( + users.select(use_labels=True) + .where(users.c.id == 4) + .execute() + .fetchall(), + [(4, "newname")], + ) def test_quoted_name_bindparam_ok(self): from sqlalchemy.sql.elements import quoted_name @@ -316,11 +368,15 @@ class MiscBackendTest( eq_( conn.scalar( select( - [cast( - literal(quoted_name("some_name", False)), String)] + [ + cast( + literal(quoted_name("some_name", False)), + String, + ) + ] ) ), - "some_name" + "some_name", ) def test_preexecute_passivedefault(self): @@ -330,7 +386,8 @@ class MiscBackendTest( try: meta = MetaData(testing.db) - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE speedy_users ( speedy_user_id SERIAL PRIMARY KEY, @@ -338,57 +395,62 @@ class MiscBackendTest( user_name VARCHAR NOT NULL, user_password VARCHAR NOT NULL ); - """) - t = Table('speedy_users', meta, autoload=True) - r = t.insert().execute(user_name='user', - user_password='lala') + """ + ) + t = Table("speedy_users", meta, autoload=True) + r = t.insert().execute(user_name="user", user_password="lala") assert r.inserted_primary_key == [1] result = t.select().execute().fetchall() - assert result == [(1, 'user', 'lala')] + assert result == [(1, "user", "lala")] finally: - testing.db.execute('drop table speedy_users') + testing.db.execute("drop table speedy_users") - @testing.fails_on('+zxjdbc', 'psycopg2/pg8000 specific assertion') + @testing.fails_on("+zxjdbc", "psycopg2/pg8000 specific assertion") @testing.requires.psycopg2_or_pg8000_compatibility def test_numeric_raise(self): - stmt = text( - "select cast('hi' as char) as hi", typemap={'hi': Numeric}) + stmt = text("select cast('hi' as char) as hi", typemap={"hi": Numeric}) assert_raises(exc.InvalidRequestError, testing.db.execute, stmt) @testing.only_if( - "postgresql >= 8.2", "requires standard_conforming_strings") + "postgresql >= 8.2", "requires standard_conforming_strings" + ) def test_serial_integer(self): - class BITD(TypeDecorator): impl = Integer def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': + if dialect.name == "postgresql": return BigInteger() else: return Integer() for version, type_, expected in [ - (None, Integer, 'SERIAL'), - (None, BigInteger, 'BIGSERIAL'), - ((9, 1), SmallInteger, 'SMALLINT'), - ((9, 2), SmallInteger, 'SMALLSERIAL'), - (None, postgresql.INTEGER, 'SERIAL'), - (None, postgresql.BIGINT, 'BIGSERIAL'), + (None, Integer, "SERIAL"), + (None, BigInteger, "BIGSERIAL"), + ((9, 1), SmallInteger, "SMALLINT"), + ((9, 2), SmallInteger, "SMALLSERIAL"), + (None, postgresql.INTEGER, "SERIAL"), + (None, postgresql.BIGINT, "BIGSERIAL"), ( - None, Integer().with_variant(BigInteger(), 'postgresql'), - 'BIGSERIAL'), + None, + Integer().with_variant(BigInteger(), "postgresql"), + "BIGSERIAL", + ), ( - None, Integer().with_variant(postgresql.BIGINT, 'postgresql'), - 'BIGSERIAL'), + None, + 
Integer().with_variant(postgresql.BIGINT, "postgresql"), + "BIGSERIAL", + ), ( - (9, 2), Integer().with_variant(SmallInteger, 'postgresql'), - 'SMALLSERIAL'), - (None, BITD(), 'BIGSERIAL') + (9, 2), + Integer().with_variant(SmallInteger, "postgresql"), + "SMALLSERIAL", + ), + (None, BITD(), "BIGSERIAL"), ]: m = MetaData() - t = Table('t', m, Column('c', type_, primary_key=True)) + t = Table("t", m, Column("c", type_, primary_key=True)) if version: dialect = postgresql.dialect() @@ -400,12 +462,12 @@ class MiscBackendTest( ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) eq_( ddl_compiler.get_column_specification(t.c.c), - "c %s NOT NULL" % expected + "c %s NOT NULL" % expected, ) class AutocommitTextTest(test_execute.AutocommitTextTest): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" def test_grant(self): self._test_keyword("GRANT USAGE ON SCHEMA fooschema TO foorole") diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index c3e1b91584..4e73c38402 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -10,16 +10,17 @@ from sqlalchemy.dialects.postgresql import insert class OnConflictTest(fixtures.TablesTest): - __only_on__ = 'postgresql >= 9.5', + __only_on__ = ("postgresql >= 9.5",) __backend__ = True - run_define_tables = 'each' + run_define_tables = "each" @classmethod def define_tables(cls, metadata): Table( - 'users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)) + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), ) class SpecialType(sqltypes.TypeDecorator): @@ -29,49 +30,57 @@ class OnConflictTest(fixtures.TablesTest): return value + " processed" Table( - 'bind_targets', metadata, - Column('id', Integer, primary_key=True), - Column('data', SpecialType()) + "bind_targets", + metadata, + Column("id", Integer, primary_key=True), + Column("data", SpecialType()), ) users_xtra = Table( - 'users_xtra', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('login_email', String(50)), - Column('lets_index_this', String(50)) + "users_xtra", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("login_email", String(50)), + Column("lets_index_this", String(50)), ) cls.unique_partial_index = schema.Index( - 'idx_unique_partial_name', - users_xtra.c.name, users_xtra.c.lets_index_this, + "idx_unique_partial_name", + users_xtra.c.name, + users_xtra.c.lets_index_this, unique=True, - postgresql_where=users_xtra.c.lets_index_this == 'unique_name') + postgresql_where=users_xtra.c.lets_index_this == "unique_name", + ) cls.unique_constraint = schema.UniqueConstraint( - users_xtra.c.login_email, name='uq_login_email') + users_xtra.c.login_email, name="uq_login_email" + ) cls.bogus_index = schema.Index( - 'idx_special_ops', + "idx_special_ops", users_xtra.c.lets_index_this, - postgresql_where=users_xtra.c.lets_index_this > 'm') + postgresql_where=users_xtra.c.lets_index_this > "m", + ) def test_bad_args(self): assert_raises( ValueError, insert(self.tables.users).on_conflict_do_nothing, - constraint='id', index_elements=['id'] + constraint="id", + index_elements=["id"], ) assert_raises( ValueError, insert(self.tables.users).on_conflict_do_update, - constraint='id', index_elements=['id'] + constraint="id", + index_elements=["id"], ) assert_raises( ValueError, - insert(self.tables.users).on_conflict_do_update, 
constraint='id' + insert(self.tables.users).on_conflict_do_update, + constraint="id", ) assert_raises( - ValueError, - insert(self.tables.users).on_conflict_do_update + ValueError, insert(self.tables.users).on_conflict_do_update ) def test_on_conflict_do_nothing(self): @@ -80,22 +89,21 @@ class OnConflictTest(fixtures.TablesTest): with testing.db.connect() as conn: result = conn.execute( insert(users).on_conflict_do_nothing(), - - dict(id=1, name='name1') + dict(id=1, name="name1"), ) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) result = conn.execute( insert(users).on_conflict_do_nothing(), - dict(id=1, name='name2') + dict(id=1, name="name2"), ) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name1')] + [(1, "name1")], ) def test_on_conflict_do_nothing_connectionless(self): @@ -104,25 +112,25 @@ class OnConflictTest(fixtures.TablesTest): with testing.db.connect() as conn: result = conn.execute( insert(users).on_conflict_do_nothing( - constraint='uq_login_email'), - - dict(name='name1', login_email='email1') + constraint="uq_login_email" + ), + dict(name="name1", login_email="email1"), ) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, (1,)) result = testing.db.execute( - insert(users).on_conflict_do_nothing( - constraint='uq_login_email' - ), - dict(name='name2', login_email='email1') + insert(users).on_conflict_do_nothing(constraint="uq_login_email"), + dict(name="name2", login_email="email1"), ) eq_(result.inserted_primary_key, None) eq_(result.returned_defaults, None) eq_( - testing.db.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name1', 'email1', None)] + testing.db.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1", "email1", None)], ) @testing.provide_metadata @@ -131,100 +139,100 @@ class OnConflictTest(fixtures.TablesTest): with testing.db.connect() as conn: result = conn.execute( - insert(users) - .on_conflict_do_nothing( - index_elements=users.primary_key.columns), - dict(id=1, name='name1') + insert(users).on_conflict_do_nothing( + index_elements=users.primary_key.columns + ), + dict(id=1, name="name1"), ) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) result = conn.execute( - insert(users) - .on_conflict_do_nothing( - index_elements=users.primary_key.columns), - dict(id=1, name='name2') + insert(users).on_conflict_do_nothing( + index_elements=users.primary_key.columns + ), + dict(id=1, name="name2"), ) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name1')] + [(1, "name1")], ) def test_on_conflict_do_update_one(self): users = self.tables.users with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name='name1')) + conn.execute(users.insert(), dict(id=1, name="name1")) i = insert(users) i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_=dict(name=i.excluded.name)) - result = conn.execute(i, dict(id=1, name='name1')) + index_elements=[users.c.id], set_=dict(name=i.excluded.name) + ) + result = conn.execute(i, dict(id=1, name="name1")) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name1')] + [(1, "name1")], ) def test_on_conflict_do_update_two(self): users = self.tables.users with testing.db.connect() as conn: - 
conn.execute(users.insert(), dict(id=1, name='name1')) + conn.execute(users.insert(), dict(id=1, name="name1")) i = insert(users) i = i.on_conflict_do_update( index_elements=[users.c.id], - set_=dict(id=i.excluded.id, name=i.excluded.name) + set_=dict(id=i.excluded.id, name=i.excluded.name), ) - result = conn.execute(i, dict(id=1, name='name2')) + result = conn.execute(i, dict(id=1, name="name2")) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name2')] + [(1, "name2")], ) def test_on_conflict_do_update_three(self): users = self.tables.users with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name='name1')) + conn.execute(users.insert(), dict(id=1, name="name1")) i = insert(users) i = i.on_conflict_do_update( index_elements=users.primary_key.columns, - set_=dict(name=i.excluded.name) + set_=dict(name=i.excluded.name), ) - result = conn.execute(i, dict(id=1, name='name3')) + result = conn.execute(i, dict(id=1, name="name3")) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name3')] + [(1, "name3")], ) def test_on_conflict_do_update_four(self): users = self.tables.users with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name='name1')) + conn.execute(users.insert(), dict(id=1, name="name1")) i = insert(users) i = i.on_conflict_do_update( index_elements=users.primary_key.columns, - set_=dict(id=i.excluded.id, name=i.excluded.name) - ).values(id=1, name='name4') + set_=dict(id=i.excluded.id, name=i.excluded.name), + ).values(id=1, name="name4") result = conn.execute(i) eq_(result.inserted_primary_key, [1]) @@ -232,20 +240,20 @@ class OnConflictTest(fixtures.TablesTest): eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name4')] + [(1, "name4")], ) def test_on_conflict_do_update_five(self): users = self.tables.users with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name='name1')) + conn.execute(users.insert(), dict(id=1, name="name1")) i = insert(users) i = i.on_conflict_do_update( index_elements=users.primary_key.columns, - set_=dict(id=10, name="I'm a name") - ).values(id=1, name='name4') + set_=dict(id=10, name="I'm a name"), + ).values(id=1, name="name4") result = conn.execute(i) eq_(result.inserted_primary_key, [1]) @@ -253,42 +261,39 @@ class OnConflictTest(fixtures.TablesTest): eq_( conn.execute( - users.select().where(users.c.id == 10)).fetchall(), - [(10, "I'm a name")] + users.select().where(users.c.id == 10) + ).fetchall(), + [(10, "I'm a name")], ) def test_on_conflict_do_update_multivalues(self): users = self.tables.users with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name='name1')) - conn.execute(users.insert(), dict(id=2, name='name2')) + conn.execute(users.insert(), dict(id=1, name="name1")) + conn.execute(users.insert(), dict(id=2, name="name2")) i = insert(users) i = i.on_conflict_do_update( index_elements=users.primary_key.columns, set_=dict(name="updated"), - where=(i.excluded.name != 'name12') - ).values([ - dict(id=1, name='name11'), - dict(id=2, name='name12'), - dict(id=3, name='name13'), - dict(id=4, name='name14'), - ]) + where=(i.excluded.name != "name12"), + ).values( + [ + dict(id=1, name="name11"), + dict(id=2, name="name12"), + dict(id=3, name="name13"), + dict(id=4, name="name14"), + ] + ) result = conn.execute(i) 
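            # Editor's note -- a sketch of the statement the multi-VALUES
            # upsert above is expected to render (parameter markers are
            # illustrative, not taken from the compiled output):
            #
            #   INSERT INTO users (id, name) VALUES (%s, %s), (%s, %s), ...
            #   ON CONFLICT (id) DO UPDATE SET name = 'updated'
            #   WHERE excluded.name != 'name12'
            #
            # Row id=1 conflicts and passes the WHERE, so it becomes
            # "updated"; row id=2 conflicts but is filtered out by the
            # WHERE (excluded.name == 'name12'), so it keeps "name2";
            # ids 3 and 4 insert normally, as the eq_() below verifies.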
eq_(result.inserted_primary_key, [None]) eq_(result.returned_defaults, None) eq_( - conn.execute( - users.select().order_by(users.c.id)).fetchall(), - [ - (1, "updated"), - (2, "name2"), - (3, "name13"), - (4, "name14") - ] + conn.execute(users.select().order_by(users.c.id)).fetchall(), + [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")], ) def _exotic_targets_fixture(self, conn): @@ -297,21 +302,25 @@ class OnConflictTest(fixtures.TablesTest): conn.execute( insert(users), dict( - id=1, name='name1', - login_email='name1@gmail.com', lets_index_this='not' - ) + id=1, + name="name1", + login_email="name1@gmail.com", + lets_index_this="not", + ), ) conn.execute( users.insert(), dict( - id=2, name='name2', - login_email='name2@gmail.com', lets_index_this='not' - ) + id=2, + name="name2", + login_email="name2@gmail.com", + lets_index_this="not", + ), ) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name1', 'name1@gmail.com', 'not')] + [(1, "name1", "name1@gmail.com", "not")], ) def test_on_conflict_do_update_exotic_targets_two(self): @@ -324,19 +333,24 @@ class OnConflictTest(fixtures.TablesTest): i = i.on_conflict_do_update( index_elements=users.primary_key.columns, set_=dict( - name=i.excluded.name, - login_email=i.excluded.login_email) + name=i.excluded.name, login_email=i.excluded.login_email + ), ) - result = conn.execute(i, dict( - id=1, name='name2', login_email='name1@gmail.com', - lets_index_this='not') + result = conn.execute( + i, + dict( + id=1, + name="name2", + login_email="name1@gmail.com", + lets_index_this="not", + ), ) eq_(result.inserted_primary_key, [1]) eq_(result.returned_defaults, None) eq_( conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, 'name2', 'name1@gmail.com', 'not')] + [(1, "name2", "name1@gmail.com", "not")], ) def test_on_conflict_do_update_exotic_targets_three(self): @@ -349,23 +363,32 @@ class OnConflictTest(fixtures.TablesTest): i = insert(users) i = i.on_conflict_do_update( constraint=self.unique_constraint, - set_=dict(id=i.excluded.id, name=i.excluded.name, - login_email=i.excluded.login_email) + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), ) # note: lets_index_this value totally ignored in SET clause. - result = conn.execute(i, dict( - id=42, name='nameunique', - login_email='name2@gmail.com', lets_index_this='unique') + result = conn.execute( + i, + dict( + id=42, + name="nameunique", + login_email="name2@gmail.com", + lets_index_this="unique", + ), ) eq_(result.inserted_primary_key, [42]) eq_(result.returned_defaults, None) eq_( conn.execute( - users.select(). - where(users.c.login_email == 'name2@gmail.com') + users.select().where( + users.c.login_email == "name2@gmail.com" + ) ).fetchall(), - [(42, 'nameunique', 'name2@gmail.com', 'not')] + [(42, "nameunique", "name2@gmail.com", "not")], ) def test_on_conflict_do_update_exotic_targets_four(self): @@ -379,24 +402,32 @@ class OnConflictTest(fixtures.TablesTest): i = i.on_conflict_do_update( constraint=self.unique_constraint.name, set_=dict( - id=i.excluded.id, name=i.excluded.name, - login_email=i.excluded.login_email) + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), ) # note: lets_index_this value totally ignored in SET clause. 
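            # Editor's note -- passing the constraint *name* compiles to
            # the ON CONFLICT ON CONSTRAINT form, roughly (a sketch, with
            # values inlined for readability):
            #
            #   INSERT INTO users_xtra
            #       (id, name, login_email, lets_index_this)
            #   VALUES (43, 'nameunique2', 'name2@gmail.com', 'unique')
            #   ON CONFLICT ON CONSTRAINT uq_login_email
            #   DO UPDATE SET id = excluded.id, name = excluded.name,
            #                 login_email = excluded.login_email
            #
            # Because lets_index_this is absent from the SET list, the
            # conflicting row keeps its stored value ('not'), which the
            # assertion below checks for.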
- result = conn.execute(i, dict( - id=43, name='nameunique2', - login_email='name2@gmail.com', lets_index_this='unique') + result = conn.execute( + i, + dict( + id=43, + name="nameunique2", + login_email="name2@gmail.com", + lets_index_this="unique", + ), ) eq_(result.inserted_primary_key, [43]) eq_(result.returned_defaults, None) eq_( conn.execute( - users.select(). - where(users.c.login_email == 'name2@gmail.com') + users.select().where( + users.c.login_email == "name2@gmail.com" + ) ).fetchall(), - [(43, 'nameunique2', 'name2@gmail.com', 'not')] + [(43, "nameunique2", "name2@gmail.com", "not")], ) def test_on_conflict_do_update_exotic_targets_four_no_pk(self): @@ -410,23 +441,24 @@ class OnConflictTest(fixtures.TablesTest): i = i.on_conflict_do_update( index_elements=[users.c.login_email], set_=dict( - id=i.excluded.id, name=i.excluded.name, - login_email=i.excluded.login_email) + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), ) - result = conn.execute(i, dict( - name='name3', - login_email='name1@gmail.com') + result = conn.execute( + i, dict(name="name3", login_email="name1@gmail.com") ) eq_(result.inserted_primary_key, [1]) - eq_(result.returned_defaults, (1, )) + eq_(result.returned_defaults, (1,)) eq_( conn.execute(users.select().order_by(users.c.id)).fetchall(), [ - (1, 'name3', 'name1@gmail.com', 'not'), - (2, 'name2', 'name2@gmail.com', 'not') - ] + (1, "name3", "name1@gmail.com", "not"), + (2, "name2", "name2@gmail.com", "not"), + ], ) def test_on_conflict_do_update_exotic_targets_five(self): @@ -438,18 +470,24 @@ class OnConflictTest(fixtures.TablesTest): i = insert(users) i = i.on_conflict_do_update( index_elements=self.bogus_index.columns, - index_where=self. - bogus_index.dialect_options['postgresql']['where'], + index_where=self.bogus_index.dialect_options["postgresql"][ + "where" + ], set_=dict( - name=i.excluded.name, - login_email=i.excluded.login_email) + name=i.excluded.name, login_email=i.excluded.login_email + ), ) assert_raises( - exc.ProgrammingError, conn.execute, i, + exc.ProgrammingError, + conn.execute, + i, dict( - id=1, name='namebogus', login_email='bogus@gmail.com', - lets_index_this='bogus') + id=1, + name="namebogus", + login_email="bogus@gmail.com", + lets_index_this="bogus", + ), ) def test_on_conflict_do_update_exotic_targets_six(self): @@ -459,35 +497,38 @@ class OnConflictTest(fixtures.TablesTest): conn.execute( insert(users), dict( - id=1, name='name1', - login_email='mail1@gmail.com', - lets_index_this='unique_name' - ) + id=1, + name="name1", + login_email="mail1@gmail.com", + lets_index_this="unique_name", + ), ) i = insert(users) i = i.on_conflict_do_update( index_elements=self.unique_partial_index.columns, - index_where=self.unique_partial_index.dialect_options - ['postgresql']['where'], + index_where=self.unique_partial_index.dialect_options[ + "postgresql" + ]["where"], set_=dict( - name=i.excluded.name, - login_email=i.excluded.login_email), + name=i.excluded.name, login_email=i.excluded.login_email + ), ) conn.execute( i, [ - dict(name='name1', login_email='mail2@gmail.com', - lets_index_this='unique_name'), - ] + dict( + name="name1", + login_email="mail2@gmail.com", + lets_index_this="unique_name", + ) + ], ) eq_( conn.execute(users.select()).fetchall(), - [ - (1, 'name1', 'mail2@gmail.com', 'unique_name'), - ] + [(1, "name1", "mail2@gmail.com", "unique_name")], ) def test_on_conflict_do_update_no_row_actually_affected(self): @@ -498,11 +539,12 @@ class OnConflictTest(fixtures.TablesTest): i = 
insert(users) i = i.on_conflict_do_update( index_elements=[users.c.login_email], - set_=dict(name='new_name'), - where=(i.excluded.name == 'other_name') + set_=dict(name="new_name"), + where=(i.excluded.name == "other_name"), ) result = conn.execute( - i, dict(name='name2', login_email='name1@gmail.com')) + i, dict(name="name2", login_email="name1@gmail.com") + ) eq_(result.returned_defaults, None) eq_(result.inserted_primary_key, None) @@ -510,9 +552,9 @@ class OnConflictTest(fixtures.TablesTest): eq_( conn.execute(users.select()).fetchall(), [ - (1, 'name1', 'name1@gmail.com', 'not'), - (2, 'name2', 'name2@gmail.com', 'not') - ] + (1, "name1", "name1@gmail.com", "not"), + (2, "name2", "name2@gmail.com", "not"), + ], ) def test_on_conflict_do_update_special_types_in_set(self): @@ -524,19 +566,17 @@ class OnConflictTest(fixtures.TablesTest): eq_( conn.scalar(sql.select([bind_targets.c.data])), - "initial data processed" + "initial data processed", ) i = insert(bind_targets) i = i.on_conflict_do_update( index_elements=[bind_targets.c.id], - set_=dict(data="new updated data") - ) - conn.execute( - i, {"id": 1, "data": "new inserted data"} + set_=dict(data="new updated data"), ) + conn.execute(i, {"id": 1, "data": "new inserted data"}) eq_( conn.scalar(sql.select([bind_targets.c.data])), - "new updated data processed" + "new updated data processed", ) diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 47a12afece..4156dc0f51 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -1,10 +1,35 @@ # coding: utf-8 -from sqlalchemy.testing import AssertsExecutionResults, eq_, \ - assert_raises_message, AssertsCompiledSQL, expect_warnings, assert_raises -from sqlalchemy import Table, Column, MetaData, Integer, String, bindparam, \ - Sequence, ForeignKey, text, select, func, extract, literal_column, \ - tuple_, DateTime, Time, literal, and_, Date, or_ +from sqlalchemy.testing import ( + AssertsExecutionResults, + eq_, + assert_raises_message, + AssertsCompiledSQL, + expect_warnings, + assert_raises, +) +from sqlalchemy import ( + Table, + Column, + MetaData, + Integer, + String, + bindparam, + Sequence, + ForeignKey, + text, + select, + func, + extract, + literal_column, + tuple_, + DateTime, + Time, + literal, + and_, + Date, + or_, +) from sqlalchemy.testing import engines, fixtures from sqlalchemy.testing.assertsql import DialectSQL, CursorSQL from sqlalchemy import testing @@ -17,7 +42,7 @@ matchtable = cattable = None class InsertTest(fixtures.TestBase, AssertsExecutionResults): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @classmethod @@ -30,32 +55,25 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): def test_compiled_insert(self): table = Table( - 'testtable', self.metadata, Column( - 'id', Integer, primary_key=True), - Column( - 'data', String(30))) + "testtable", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", String(30)), + ) self.metadata.create_all() ins = table.insert( - inline=True, - values={'data': bindparam('x')}).compile() - ins.execute({'x': 'five'}, {'x': 'seven'}) - eq_( - table.select().execute().fetchall(), - [(1, 'five'), (2, 'seven')] - ) + inline=True, values={"data": bindparam("x")} + ).compile() + ins.execute({"x": "five"}, {"x": "seven"}) + eq_(table.select().execute().fetchall(), [(1, "five"), (2, "seven")]) def test_foreignkey_missing_insert(self): - Table( - 't1', self.metadata, - Column('id', Integer, 
primary_key=True)) + Table("t1", self.metadata, Column("id", Integer, primary_key=True)) t2 = Table( - 't2', + "t2", self.metadata, - Column( - 'id', - Integer, - ForeignKey('t1.id'), - primary_key=True)) + Column("id", Integer, ForeignKey("t1.id"), primary_key=True), + ) self.metadata.create_all() # want to ensure that "null value in column "id" violates not- @@ -66,178 +84,178 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): # the case here due to the foreign key. for eng in [ - engines.testing_engine(options={'implicit_returning': False}), - engines.testing_engine(options={'implicit_returning': True}) + engines.testing_engine(options={"implicit_returning": False}), + engines.testing_engine(options={"implicit_returning": True}), ]: with expect_warnings( ".*has no Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - eng.execute, t2.insert() + eng.execute, + t2.insert(), ) def test_sequence_insert(self): table = Table( - 'testtable', + "testtable", self.metadata, - Column( - 'id', - Integer, - Sequence('my_seq'), - primary_key=True), - Column( - 'data', - String(30))) + Column("id", Integer, Sequence("my_seq"), primary_key=True), + Column("data", String(30)), + ) self.metadata.create_all() - self._assert_data_with_sequence(table, 'my_seq') + self._assert_data_with_sequence(table, "my_seq") @testing.requires.returning def test_sequence_returning_insert(self): table = Table( - 'testtable', + "testtable", self.metadata, - Column( - 'id', - Integer, - Sequence('my_seq'), - primary_key=True), - Column( - 'data', - String(30))) + Column("id", Integer, Sequence("my_seq"), primary_key=True), + Column("data", String(30)), + ) self.metadata.create_all() - self._assert_data_with_sequence_returning(table, 'my_seq') + self._assert_data_with_sequence_returning(table, "my_seq") def test_opt_sequence_insert(self): table = Table( - 'testtable', self.metadata, - Column( - 'id', Integer, Sequence( - 'my_seq', optional=True), primary_key=True), + "testtable", + self.metadata, Column( - 'data', String(30))) + "id", + Integer, + Sequence("my_seq", optional=True), + primary_key=True, + ), + Column("data", String(30)), + ) self.metadata.create_all() self._assert_data_autoincrement(table) @testing.requires.returning def test_opt_sequence_returning_insert(self): table = Table( - 'testtable', self.metadata, - Column( - 'id', Integer, Sequence( - 'my_seq', optional=True), primary_key=True), + "testtable", + self.metadata, Column( - 'data', String(30))) + "id", + Integer, + Sequence("my_seq", optional=True), + primary_key=True, + ), + Column("data", String(30)), + ) self.metadata.create_all() self._assert_data_autoincrement_returning(table) def test_autoincrement_insert(self): table = Table( - 'testtable', self.metadata, - Column( - 'id', Integer, primary_key=True), - Column( - 'data', String(30))) + "testtable", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", String(30)), + ) self.metadata.create_all() self._assert_data_autoincrement(table) @testing.requires.returning def test_autoincrement_returning_insert(self): table = Table( - 'testtable', self.metadata, - Column( - 'id', Integer, primary_key=True), - Column( - 'data', String(30))) + "testtable", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data", String(30)), + ) self.metadata.create_all() self._assert_data_autoincrement_returning(table) def test_noautoincrement_insert(self): table = Table( - 'testtable', + "testtable", self.metadata, - Column( 
- 'id', - Integer, - primary_key=True, - autoincrement=False), - Column( - 'data', - String(30))) + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(30)), + ) self.metadata.create_all() self._assert_data_noautoincrement(table) def _assert_data_autoincrement(self, table): - engine = \ - engines.testing_engine(options={'implicit_returning': False}) + engine = engines.testing_engine(options={"implicit_returning": False}) with self.sql_execution_asserter(engine) as asserter: with engine.connect() as conn: # execute with explicit id - r = conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) + r = conn.execute(table.insert(), {"id": 30, "data": "d1"}) eq_(r.inserted_primary_key, [30]) # execute with prefetch id - r = conn.execute(table.insert(), {'data': 'd2'}) + r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, [1]) # executemany with explicit ids conn.execute( table.insert(), - {'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}) + {"id": 31, "data": "d3"}, + {"id": 32, "data": "d4"}, + ) # executemany, uses SERIAL - conn.execute(table.insert(), {'data': 'd5'}, {'data': 'd6'}) + conn.execute(table.insert(), {"data": "d5"}, {"data": "d6"}) # single execute, explicit id, inline conn.execute( - table.insert(inline=True), - {'id': 33, 'data': 'd7'}) + table.insert(inline=True), {"id": 33, "data": "d7"} + ) # single execute, inline, uses SERIAL - conn.execute(table.insert(inline=True), {'data': 'd8'}) + conn.execute(table.insert(inline=True), {"data": "d8"}) asserter.assert_( DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 30, 'data': 'd1'}), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 30, "data": "d1"}, + ), DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 1, 'data': 'd2'}), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 1, "data": "d2"}, + ), DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}]), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}], + ), DialectSQL( - 'INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd5'}, {'data': 'd6'}]), + "INSERT INTO testtable (data) VALUES (:data)", + [{"data": "d5"}, {"data": "d6"}], + ), DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 33, 'data': 'd7'}]), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 33, "data": "d7"}], + ), DialectSQL( - 'INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd8'}]), + "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}] + ), ) with engine.connect() as conn: eq_( conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (1, 'd2'), - (31, 'd3'), - (32, 'd4'), - (2, 'd5'), - (3, 'd6'), - (33, 'd7'), - (4, 'd8'), - ] + (30, "d1"), + (1, "d2"), + (31, "d3"), + (32, "d4"), + (2, "d5"), + (3, "d6"), + (33, "d7"), + (4, "d8"), + ], ) conn.execute(table.delete()) @@ -250,117 +268,139 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): with self.sql_execution_asserter(engine) as asserter: with engine.connect() as conn: - conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) - r = conn.execute(table.insert(), {'data': 'd2'}) + conn.execute(table.insert(), {"id": 30, "data": "d1"}) + r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, [5]) conn.execute( table.insert(), - {'id': 31, 'data': 'd3'}, {'id': 32, 'data': 
'd4'}) - conn.execute(table.insert(), {'data': 'd5'}, {'data': 'd6'}) + {"id": 31, "data": "d3"}, + {"id": 32, "data": "d4"}, + ) + conn.execute(table.insert(), {"data": "d5"}, {"data": "d6"}) conn.execute( - table.insert(inline=True), {'id': 33, 'data': 'd7'}) - conn.execute(table.insert(inline=True), {'data': 'd8'}) + table.insert(inline=True), {"id": 33, "data": "d7"} + ) + conn.execute(table.insert(inline=True), {"data": "d8"}) asserter.assert_( DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 30, 'data': 'd1'}), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 30, "data": "d1"}, + ), DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 5, 'data': 'd2'}), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 5, "data": "d2"}, + ), DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}]), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}], + ), DialectSQL( - 'INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd5'}, {'data': 'd6'}]), + "INSERT INTO testtable (data) VALUES (:data)", + [{"data": "d5"}, {"data": "d6"}], + ), DialectSQL( - 'INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 33, 'data': 'd7'}]), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 33, "data": "d7"}], + ), DialectSQL( - 'INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd8'}]), + "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}] + ), ) with engine.connect() as conn: eq_( conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (5, 'd2'), - (31, 'd3'), - (32, 'd4'), - (6, 'd5'), - (7, 'd6'), - (33, 'd7'), - (8, 'd8'), - ] + (30, "d1"), + (5, "d2"), + (31, "d3"), + (32, "d4"), + (6, "d5"), + (7, "d6"), + (33, "d7"), + (8, "d8"), + ], ) conn.execute(table.delete()) def _assert_data_autoincrement_returning(self, table): - engine = \ - engines.testing_engine(options={'implicit_returning': True}) + engine = engines.testing_engine(options={"implicit_returning": True}) with self.sql_execution_asserter(engine) as asserter: with engine.connect() as conn: # execute with explicit id - r = conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) + r = conn.execute(table.insert(), {"id": 30, "data": "d1"}) eq_(r.inserted_primary_key, [30]) # execute with prefetch id - r = conn.execute(table.insert(), {'data': 'd2'}) + r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, [1]) # executemany with explicit ids conn.execute( table.insert(), - {'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}) + {"id": 31, "data": "d3"}, + {"id": 32, "data": "d4"}, + ) # executemany, uses SERIAL - conn.execute(table.insert(), {'data': 'd5'}, {'data': 'd6'}) + conn.execute(table.insert(), {"data": "d5"}, {"data": "d6"}) # single execute, explicit id, inline conn.execute( - table.insert(inline=True), {'id': 33, 'data': 'd7'}) + table.insert(inline=True), {"id": 33, "data": "d7"} + ) # single execute, inline, uses SERIAL - conn.execute(table.insert(inline=True), {'data': 'd8'}) + conn.execute(table.insert(inline=True), {"data": "d8"}) asserter.assert_( - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 30, 'data': 'd1'}), - DialectSQL('INSERT INTO testtable (data) VALUES (:data) RETURNING ' - 'testtable.id', {'data': 'd2'}), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 31, 'data': 'd3'}, 
{'id': 32, 'data': 'd4'}]), - DialectSQL('INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd5'}, {'data': 'd6'}]), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 33, 'data': 'd7'}]), - DialectSQL('INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd8'}]), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 30, "data": "d1"}, + ), + DialectSQL( + "INSERT INTO testtable (data) VALUES (:data) RETURNING " + "testtable.id", + {"data": "d2"}, + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}], + ), + DialectSQL( + "INSERT INTO testtable (data) VALUES (:data)", + [{"data": "d5"}, {"data": "d6"}], + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 33, "data": "d7"}], + ), + DialectSQL( + "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}] + ), ) with engine.connect() as conn: eq_( conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (1, 'd2'), - (31, 'd3'), - (32, 'd4'), - (2, 'd5'), - (3, 'd6'), - (33, 'd7'), - (4, 'd8'), - ] + (30, "d1"), + (1, "d2"), + (31, "d3"), + (32, "d4"), + (2, "d5"), + (3, "d6"), + (33, "d7"), + (4, "d8"), + ], ) conn.execute(table.delete()) @@ -372,195 +412,249 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): with self.sql_execution_asserter(engine) as asserter: with engine.connect() as conn: - conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) - r = conn.execute(table.insert(), {'data': 'd2'}) + conn.execute(table.insert(), {"id": 30, "data": "d1"}) + r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, [5]) conn.execute( table.insert(), - {'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}) - conn.execute(table.insert(), {'data': 'd5'}, {'data': 'd6'}) + {"id": 31, "data": "d3"}, + {"id": 32, "data": "d4"}, + ) + conn.execute(table.insert(), {"data": "d5"}, {"data": "d6"}) conn.execute( - table.insert(inline=True), {'id': 33, 'data': 'd7'}) - conn.execute(table.insert(inline=True), {'data': 'd8'}) + table.insert(inline=True), {"id": 33, "data": "d7"} + ) + conn.execute(table.insert(inline=True), {"data": "d8"}) asserter.assert_( - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 30, 'data': 'd1'}), - DialectSQL('INSERT INTO testtable (data) VALUES (:data) RETURNING ' - 'testtable.id', {'data': 'd2'}), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}]), - DialectSQL('INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd5'}, {'data': 'd6'}]), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 33, 'data': 'd7'}]), DialectSQL( - 'INSERT INTO testtable (data) VALUES (:data)', - [{'data': 'd8'}]), + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 30, "data": "d1"}, + ), + DialectSQL( + "INSERT INTO testtable (data) VALUES (:data) RETURNING " + "testtable.id", + {"data": "d2"}, + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}], + ), + DialectSQL( + "INSERT INTO testtable (data) VALUES (:data)", + [{"data": "d5"}, {"data": "d6"}], + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 33, "data": "d7"}], + ), + DialectSQL( + "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}] + ), ) with engine.connect() as conn: eq_( 
conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (5, 'd2'), - (31, 'd3'), - (32, 'd4'), - (6, 'd5'), - (7, 'd6'), - (33, 'd7'), - (8, 'd8'), - ] + (30, "d1"), + (5, "d2"), + (31, "d3"), + (32, "d4"), + (6, "d5"), + (7, "d6"), + (33, "d7"), + (8, "d8"), + ], ) conn.execute(table.delete()) def _assert_data_with_sequence(self, table, seqname): - engine = \ - engines.testing_engine(options={'implicit_returning': False}) + engine = engines.testing_engine(options={"implicit_returning": False}) with self.sql_execution_asserter(engine) as asserter: with engine.connect() as conn: - conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) - conn.execute(table.insert(), {'data': 'd2'}) - conn.execute(table.insert(), - {'id': 31, 'data': 'd3'}, - {'id': 32, 'data': 'd4'}) - conn.execute(table.insert(), {'data': 'd5'}, {'data': 'd6'}) - conn.execute(table.insert(inline=True), - {'id': 33, 'data': 'd7'}) - conn.execute(table.insert(inline=True), {'data': 'd8'}) + conn.execute(table.insert(), {"id": 30, "data": "d1"}) + conn.execute(table.insert(), {"data": "d2"}) + conn.execute( + table.insert(), + {"id": 31, "data": "d3"}, + {"id": 32, "data": "d4"}, + ) + conn.execute(table.insert(), {"data": "d5"}, {"data": "d6"}) + conn.execute( + table.insert(inline=True), {"id": 33, "data": "d7"} + ) + conn.execute(table.insert(inline=True), {"data": "d8"}) asserter.assert_( - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 30, 'data': 'd1'}), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 30, "data": "d1"}, + ), CursorSQL("select nextval('my_seq')"), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 1, 'data': 'd2'}), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}]), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 1, "data": "d2"}, + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}], + ), DialectSQL( "INSERT INTO testtable (id, data) VALUES (nextval('%s'), " - ":data)" % seqname, [{'data': 'd5'}, {'data': 'd6'}]), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 33, 'data': 'd7'}]), + ":data)" % seqname, + [{"data": "d5"}, {"data": "d6"}], + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 33, "data": "d7"}], + ), DialectSQL( "INSERT INTO testtable (id, data) VALUES (nextval('%s'), " - ":data)" % seqname, [{'data': 'd8'}]), + ":data)" % seqname, + [{"data": "d8"}], + ), ) with engine.connect() as conn: eq_( conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (1, 'd2'), - (31, 'd3'), - (32, 'd4'), - (2, 'd5'), - (3, 'd6'), - (33, 'd7'), - (4, 'd8'), - ] + (30, "d1"), + (1, "d2"), + (31, "d3"), + (32, "d4"), + (2, "d5"), + (3, "d6"), + (33, "d7"), + (4, "d8"), + ], ) # cant test reflection here since the Sequence must be # explicitly specified def _assert_data_with_sequence_returning(self, table, seqname): - engine = \ - engines.testing_engine(options={'implicit_returning': True}) + engine = engines.testing_engine(options={"implicit_returning": True}) with self.sql_execution_asserter(engine) as asserter: with engine.connect() as conn: - conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) - conn.execute(table.insert(), {'data': 'd2'}) - conn.execute(table.insert(), - {'id': 31, 'data': 'd3'}, - {'id': 32, 'data': 'd4'}) - conn.execute(table.insert(), {'data': 
'd5'}, {'data': 'd6'}) + conn.execute(table.insert(), {"id": 30, "data": "d1"}) + conn.execute(table.insert(), {"data": "d2"}) conn.execute( - table.insert(inline=True), {'id': 33, 'data': 'd7'}) - conn.execute(table.insert(inline=True), {'data': 'd8'}) + table.insert(), + {"id": 31, "data": "d3"}, + {"id": 32, "data": "d4"}, + ) + conn.execute(table.insert(), {"data": "d5"}, {"data": "d6"}) + conn.execute( + table.insert(inline=True), {"id": 33, "data": "d7"} + ) + conn.execute(table.insert(inline=True), {"data": "d8"}) asserter.assert_( - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - {'id': 30, 'data': 'd1'}), - DialectSQL("INSERT INTO testtable (id, data) VALUES " - "(nextval('my_seq'), :data) RETURNING testtable.id", - {'data': 'd2'}), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 31, 'data': 'd3'}, {'id': 32, 'data': 'd4'}]), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {"id": 30, "data": "d1"}, + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES " + "(nextval('my_seq'), :data) RETURNING testtable.id", + {"data": "d2"}, + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}], + ), DialectSQL( "INSERT INTO testtable (id, data) VALUES (nextval('%s'), " - ":data)" % seqname, [{'data': 'd5'}, {'data': 'd6'}]), - DialectSQL('INSERT INTO testtable (id, data) VALUES (:id, :data)', - [{'id': 33, 'data': 'd7'}]), + ":data)" % seqname, + [{"data": "d5"}, {"data": "d6"}], + ), + DialectSQL( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{"id": 33, "data": "d7"}], + ), DialectSQL( "INSERT INTO testtable (id, data) VALUES (nextval('%s'), " - ":data)" % seqname, [{'data': 'd8'}]), + ":data)" % seqname, + [{"data": "d8"}], + ), ) with engine.connect() as conn: eq_( conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (1, 'd2'), - (31, 'd3'), - (32, 'd4'), - (2, 'd5'), - (3, 'd6'), - (33, 'd7'), - (4, 'd8'), - ] + (30, "d1"), + (1, "d2"), + (31, "d3"), + (32, "d4"), + (2, "d5"), + (3, "d6"), + (33, "d7"), + (4, "d8"), + ], ) # cant test reflection here since the Sequence must be # explicitly specified def _assert_data_noautoincrement(self, table): - engine = \ - engines.testing_engine(options={'implicit_returning': False}) + engine = engines.testing_engine(options={"implicit_returning": False}) with engine.connect() as conn: - conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) + conn.execute(table.insert(), {"id": 30, "data": "d1"}) with expect_warnings( - ".*has no Python-side or server-side default.*", + ".*has no Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - conn.execute, table.insert(), {'data': 'd2'}) + conn.execute, + table.insert(), + {"data": "d2"}, + ) with expect_warnings( - ".*has no Python-side or server-side default.*", + ".*has no Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - conn.execute, table.insert(), {'data': 'd2'}, - {'data': 'd3'}) + conn.execute, + table.insert(), + {"data": "d2"}, + {"data": "d3"}, + ) with expect_warnings( - ".*has no Python-side or server-side default.*", + ".*has no Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - conn.execute, table.insert(), {'data': 'd2'}) + conn.execute, + table.insert(), + {"data": "d2"}, + ) with expect_warnings( - ".*has no Python-side or server-side default.*", + ".*has no 
Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - conn.execute, table.insert(), {'data': 'd2'}, - {'data': 'd3'}) + conn.execute, + table.insert(), + {"data": "d2"}, + {"data": "d3"}, + ) conn.execute( table.insert(), - {'id': 31, 'data': 'd2'}, {'id': 32, 'data': 'd3'}) - conn.execute(table.insert(inline=True), {'id': 33, 'data': 'd4'}) - eq_(conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (31, 'd2'), - (32, 'd3'), - (33, 'd4')]) + {"id": 31, "data": "d2"}, + {"id": 32, "data": "d3"}, + ) + conn.execute(table.insert(inline=True), {"id": 33, "data": "d4"}) + eq_( + conn.execute(table.select()).fetchall(), + [(30, "d1"), (31, "d2"), (32, "d3"), (33, "d4")], + ) conn.execute(table.delete()) # test the same series of events using a reflected version of @@ -569,35 +663,42 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): m2 = MetaData(engine) table = Table(table.name, m2, autoload=True) with engine.connect() as conn: - conn.execute(table.insert(), {'id': 30, 'data': 'd1'}) + conn.execute(table.insert(), {"id": 30, "data": "d1"}) with expect_warnings( - ".*has no Python-side or server-side default.*", + ".*has no Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - conn.execute, table.insert(), {'data': 'd2'}) + conn.execute, + table.insert(), + {"data": "d2"}, + ) with expect_warnings( - ".*has no Python-side or server-side default.*", + ".*has no Python-side or server-side default.*" ): assert_raises( (exc.IntegrityError, exc.ProgrammingError), - conn.execute, table.insert(), {'data': 'd2'}, - {'data': 'd3'}) + conn.execute, + table.insert(), + {"data": "d2"}, + {"data": "d3"}, + ) conn.execute( table.insert(), - {'id': 31, 'data': 'd2'}, {'id': 32, 'data': 'd3'}) - conn.execute(table.insert(inline=True), {'id': 33, 'data': 'd4'}) - eq_(conn.execute(table.select()).fetchall(), [ - (30, 'd1'), - (31, 'd2'), - (32, 'd3'), - (33, 'd4')]) + {"id": 31, "data": "d2"}, + {"id": 32, "data": "d3"}, + ) + conn.execute(table.insert(inline=True), {"id": 33, "data": "d4"}) + eq_( + conn.execute(table.select()).fetchall(), + [(30, "d1"), (31, "d2"), (32, "d3"), (33, "d4")], + ) class MatchTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = 'postgresql >= 8.3' + __only_on__ = "postgresql >= 8.3" __backend__ = True @classmethod @@ -605,135 +706,204 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): global metadata, cattable, matchtable metadata = MetaData(testing.db) cattable = Table( - 'cattable', metadata, - Column( - 'id', Integer, primary_key=True), - Column( - 'description', String(50))) + "cattable", + metadata, + Column("id", Integer, primary_key=True), + Column("description", String(50)), + ) matchtable = Table( - 'matchtable', metadata, - Column( - 'id', Integer, primary_key=True), - Column( - 'title', String(200)), - Column( - 'category_id', Integer, ForeignKey('cattable.id'))) + "matchtable", + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(200)), + Column("category_id", Integer, ForeignKey("cattable.id")), + ) metadata.create_all() - cattable.insert().execute([{'id': 1, 'description': 'Python'}, - {'id': 2, 'description': 'Ruby'}]) + cattable.insert().execute( + [ + {"id": 1, "description": "Python"}, + {"id": 2, "description": "Ruby"}, + ] + ) matchtable.insert().execute( - [{'id': 1, 'title': 'Agile Web Development with Rails', - 'category_id': 2}, - {'id': 2, 'title': 'Dive Into Python', 'category_id': 1}, - {'id': 3, 
'title': "Programming Matz's Ruby", 'category_id': 2}, - {'id': 4, 'title': 'The Definitive Guide to Django', - 'category_id': 1}, - {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1}]) + [ + { + "id": 1, + "title": "Agile Web Development with Rails", + "category_id": 2, + }, + {"id": 2, "title": "Dive Into Python", "category_id": 1}, + { + "id": 3, + "title": "Programming Matz's Ruby", + "category_id": 2, + }, + { + "id": 4, + "title": "The Definitive Guide to Django", + "category_id": 1, + }, + {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, + ] + ) @classmethod def teardown_class(cls): metadata.drop_all() - @testing.fails_on('postgresql+pg8000', 'uses positional') - @testing.fails_on('postgresql+zxjdbc', 'uses qmark') + @testing.fails_on("postgresql+pg8000", "uses positional") + @testing.fails_on("postgresql+zxjdbc", "uses qmark") def test_expression_pyformat(self): - self.assert_compile(matchtable.c.title.match('somstr'), - 'matchtable.title @@ to_tsquery(%(title_1)s' - ')') - - @testing.fails_on('postgresql+psycopg2', 'uses pyformat') - @testing.fails_on('postgresql+pypostgresql', 'uses pyformat') - @testing.fails_on('postgresql+pygresql', 'uses pyformat') - @testing.fails_on('postgresql+zxjdbc', 'uses qmark') - @testing.fails_on('postgresql+psycopg2cffi', 'uses pyformat') + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%(title_1)s" ")", + ) + + @testing.fails_on("postgresql+psycopg2", "uses pyformat") + @testing.fails_on("postgresql+pypostgresql", "uses pyformat") + @testing.fails_on("postgresql+pygresql", "uses pyformat") + @testing.fails_on("postgresql+zxjdbc", "uses qmark") + @testing.fails_on("postgresql+psycopg2cffi", "uses pyformat") def test_expression_positional(self): - self.assert_compile(matchtable.c.title.match('somstr'), - 'matchtable.title @@ to_tsquery(%s)') + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%s)", + ) def test_simple_match(self): - results = matchtable.select().where( - matchtable.c.title.match('python')).order_by( - matchtable.c.id).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("python")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([2, 5], [r.id for r in results]) def test_not_match(self): - results = matchtable.select().where( - ~matchtable.c.title.match('python')).order_by( - matchtable.c.id).execute().fetchall() + results = ( + matchtable.select() + .where(~matchtable.c.title.match("python")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([1, 3, 4], [r.id for r in results]) def test_simple_match_with_apostrophe(self): - results = matchtable.select().where( - matchtable.c.title.match("Matz's")).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("Matz's")) + .execute() + .fetchall() + ) eq_([3], [r.id for r in results]) def test_simple_derivative_match(self): - results = matchtable.select().where( - matchtable.c.title.match('nutshells')).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("nutshells")) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results]) def test_or_match(self): - results1 = matchtable.select().where( - or_( - matchtable.c.title.match('nutshells'), - matchtable.c.title.match('rubies'))).order_by( - matchtable.c.id).execute().fetchall() + results1 = ( + matchtable.select() + .where( + or_( + matchtable.c.title.match("nutshells"), + 
matchtable.c.title.match("rubies"), + ) + ) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([3, 5], [r.id for r in results1]) - results2 = matchtable.select().where( - matchtable.c.title.match('nutshells | rubies')).order_by( - matchtable.c.id).execute().fetchall() + results2 = ( + matchtable.select() + .where(matchtable.c.title.match("nutshells | rubies")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([3, 5], [r.id for r in results2]) def test_and_match(self): - results1 = matchtable.select().where( - and_( - matchtable.c.title.match('python'), - matchtable.c.title.match('nutshells'))).execute().fetchall() + results1 = ( + matchtable.select() + .where( + and_( + matchtable.c.title.match("python"), + matchtable.c.title.match("nutshells"), + ) + ) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results1]) - results2 = \ - matchtable.select().where( - matchtable.c.title.match('python & nutshells' - )).execute().fetchall() + results2 = ( + matchtable.select() + .where(matchtable.c.title.match("python & nutshells")) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results2]) def test_match_across_joins(self): - results = matchtable.select().where( - and_( - cattable.c.id == matchtable.c.category_id, or_( - cattable.c.description.match('Ruby'), - matchtable.c.title.match('nutshells')))).order_by( - matchtable.c.id).execute().fetchall() + results = ( + matchtable.select() + .where( + and_( + cattable.c.id == matchtable.c.category_id, + or_( + cattable.c.description.match("Ruby"), + matchtable.c.title.match("nutshells"), + ), + ) + ) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([1, 3, 5], [r.id for r in results]) class TupleTest(fixtures.TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True def test_tuple_containment(self): for test, exp in [ - ([('a', 'b')], True), - ([('a', 'c')], False), - ([('f', 'q'), ('a', 'b')], True), - ([('f', 'q'), ('a', 'c')], False) + ([("a", "b")], True), + ([("a", "c")], False), + ([("f", "q"), ("a", "b")], True), + ([("f", "q"), ("a", "c")], False), ]: eq_( testing.db.execute( - select([ - tuple_( - literal_column("'a'"), - literal_column("'b'") - ). - in_([ - tuple_(*[ - literal_column("'%s'" % letter) - for letter in elem - ]) for elem in test - ]) - ]) + select( + [ + tuple_( + literal_column("'a'"), literal_column("'b'") + ).in_( + [ + tuple_( + *[ + literal_column("'%s'" % letter) + for letter in elem + ] + ) + for elem in test + ] + ) + ] + ) ).scalar(), - exp + exp, ) @@ -746,15 +916,17 @@ class ExtractTest(fixtures.TablesTest): are not needed; see [ticket:2740]. 
""" - __only_on__ = 'postgresql' + + __only_on__ = "postgresql" __backend__ = True - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def setup_bind(cls): from sqlalchemy import event + eng = engines.testing_engine() @event.listens_for(eng, "connect") @@ -767,33 +939,35 @@ class ExtractTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('t', metadata, - Column('id', Integer, primary_key=True), - Column('dtme', DateTime), - Column('dt', Date), - Column('tm', Time), - Column('intv', postgresql.INTERVAL), - Column('dttz', DateTime(timezone=True)) - ) + Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("dtme", DateTime), + Column("dt", Date), + Column("tm", Time), + Column("intv", postgresql.INTERVAL), + Column("dttz", DateTime(timezone=True)), + ) @classmethod def insert_data(cls): # TODO: why does setting hours to anything # not affect the TZ in the DB col ? class TZ(datetime.tzinfo): - def utcoffset(self, dt): return datetime.timedelta(hours=4) cls.bind.execute( cls.tables.t.insert(), { - 'dtme': datetime.datetime(2012, 5, 10, 12, 15, 25), - 'dt': datetime.date(2012, 5, 10), - 'tm': datetime.time(12, 15, 25), - 'intv': datetime.timedelta(seconds=570), - 'dttz': datetime.datetime(2012, 5, 10, 12, 15, 25, - tzinfo=TZ()) + "dtme": datetime.datetime(2012, 5, 10, 12, 15, 25), + "dt": datetime.date(2012, 5, 10), + "tm": datetime.time(12, 15, 25), + "intv": datetime.timedelta(seconds=570), + "dttz": datetime.datetime( + 2012, 5, 10, 12, 15, 25, tzinfo=TZ() + ), }, ) @@ -801,19 +975,27 @@ class ExtractTest(fixtures.TablesTest): t = self.tables.t if field == "all": - fields = {"year": 2012, "month": 5, "day": 10, - "epoch": 1336652125.0, - "hour": 12, "minute": 15} + fields = { + "year": 2012, + "month": 5, + "day": 10, + "epoch": 1336652125.0, + "hour": 12, + "minute": 15, + } elif field == "time": fields = {"hour": 12, "minute": 15, "second": 25} - elif field == 'date': + elif field == "date": fields = {"year": 2012, "month": 5, "day": 10} - elif field == 'all+tz': - fields = {"year": 2012, "month": 5, "day": 10, - "epoch": 1336637725.0, - "hour": 8, - "timezone": 0 - } + elif field == "all+tz": + fields = { + "year": 2012, + "month": 5, + "day": 10, + "epoch": 1336637725.0, + "hour": 8, + "timezone": 0, + } else: fields = field @@ -822,7 +1004,8 @@ class ExtractTest(fixtures.TablesTest): for field in fields: result = self.bind.scalar( - select([extract(field, expr)]).select_from(t)) + select([extract(field, expr)]).select_from(t) + ) eq_(result, fields[field]) def test_one(self): @@ -831,46 +1014,74 @@ class ExtractTest(fixtures.TablesTest): def test_two(self): t = self.tables.t - self._test(t.c.dtme + t.c.intv, - overrides={"epoch": 1336652695.0, "minute": 24}) + self._test( + t.c.dtme + t.c.intv, + overrides={"epoch": 1336652695.0, "minute": 24}, + ) def test_three(self): self.tables.t - actual_ts = self.bind.scalar(func.current_timestamp()) - \ - datetime.timedelta(days=5) - self._test(func.current_timestamp() - datetime.timedelta(days=5), - {"hour": actual_ts.hour, "year": actual_ts.year, - "month": actual_ts.month} - ) + actual_ts = self.bind.scalar( + func.current_timestamp() + ) - datetime.timedelta(days=5) + self._test( + func.current_timestamp() - datetime.timedelta(days=5), + { + "hour": actual_ts.hour, + "year": actual_ts.year, + "month": actual_ts.month, + }, + ) def test_four(self): t = self.tables.t - self._test(datetime.timedelta(days=5) + t.c.dt, - overrides={"day": 15, "epoch": 1337040000.0, 
"hour": 0, - "minute": 0} - ) + self._test( + datetime.timedelta(days=5) + t.c.dt, + overrides={ + "day": 15, + "epoch": 1337040000.0, + "hour": 0, + "minute": 0, + }, + ) def test_five(self): t = self.tables.t - self._test(func.coalesce(t.c.dtme, func.current_timestamp()), - overrides={"epoch": 1336652125.0}) + self._test( + func.coalesce(t.c.dtme, func.current_timestamp()), + overrides={"epoch": 1336652125.0}, + ) def test_six(self): t = self.tables.t - self._test(t.c.tm + datetime.timedelta(seconds=30), "time", - overrides={"second": 55}) + self._test( + t.c.tm + datetime.timedelta(seconds=30), + "time", + overrides={"second": 55}, + ) def test_seven(self): - self._test(literal(datetime.timedelta(seconds=10)) - - literal(datetime.timedelta(seconds=10)), "all", - overrides={"hour": 0, "minute": 0, "month": 0, - "year": 0, "day": 0, "epoch": 0}) + self._test( + literal(datetime.timedelta(seconds=10)) + - literal(datetime.timedelta(seconds=10)), + "all", + overrides={ + "hour": 0, + "minute": 0, + "month": 0, + "year": 0, + "day": 0, + "epoch": 0, + }, + ) def test_eight(self): t = self.tables.t - self._test(t.c.tm + datetime.timedelta(seconds=30), - {"hour": 12, "minute": 15, "second": 55}) + self._test( + t.c.tm + datetime.timedelta(seconds=30), + {"hour": 12, "minute": 15, "second": 55}, + ) def test_nine(self): self._test(text("t.dt + t.tm")) @@ -880,22 +1091,22 @@ class ExtractTest(fixtures.TablesTest): self._test(t.c.dt + t.c.tm) def test_eleven(self): - self._test(func.current_timestamp() - func.current_timestamp(), - {"year": 0, "month": 0, "day": 0, "hour": 0} - ) + self._test( + func.current_timestamp() - func.current_timestamp(), + {"year": 0, "month": 0, "day": 0, "hour": 0}, + ) def test_twelve(self): t = self.tables.t - actual_ts = self.bind.scalar( - func.current_timestamp()).replace(tzinfo=None) - \ - datetime.datetime(2012, 5, 10, 12, 15, 25) + actual_ts = self.bind.scalar(func.current_timestamp()).replace( + tzinfo=None + ) - datetime.datetime(2012, 5, 10, 12, 15, 25) self._test( - func.current_timestamp() - func.coalesce( - t.c.dtme, - func.current_timestamp() - ), - {"day": actual_ts.days}) + func.current_timestamp() + - func.coalesce(t.c.dtme, func.current_timestamp()), + {"day": actual_ts.days}, + ) def test_thirteen(self): t = self.tables.t @@ -907,6 +1118,7 @@ class ExtractTest(fixtures.TablesTest): def test_fifteen(self): t = self.tables.t - self._test(datetime.timedelta(days=5) + t.c.dtme, - overrides={"day": 15, "epoch": 1337084125.0} - ) + self._test( + datetime.timedelta(days=5) + t.c.dtme, + overrides={"day": 15, "epoch": 1337084125.0}, + ) diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 5c4214430f..d9facad6d9 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -2,14 +2,27 @@ from sqlalchemy.engine import reflection from sqlalchemy.sql.schema import CheckConstraint -from sqlalchemy.testing.assertions import eq_, assert_raises, \ - AssertsExecutionResults +from sqlalchemy.testing.assertions import ( + eq_, + assert_raises, + AssertsExecutionResults, +) from sqlalchemy.testing import fixtures from sqlalchemy import testing from sqlalchemy import inspect -from sqlalchemy import Table, Column, MetaData, Integer, String, \ - PrimaryKeyConstraint, ForeignKey, join, Sequence, UniqueConstraint, \ - Index +from sqlalchemy import ( + Table, + Column, + MetaData, + Integer, + String, + PrimaryKeyConstraint, + ForeignKey, + join, + Sequence, + 
UniqueConstraint, + Index, +) from sqlalchemy import exc import sqlalchemy as sa from sqlalchemy.dialects.postgresql import base as postgresql @@ -23,20 +36,24 @@ import itertools class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): """Test reflection on foreign tables""" - __requires__ = 'postgresql_test_dblink', - __only_on__ = 'postgresql >= 9.3' + __requires__ = ("postgresql_test_dblink",) + __only_on__ = "postgresql >= 9.3" __backend__ = True @classmethod def define_tables(cls, metadata): from sqlalchemy.testing import config + dblink = config.file_config.get( - 'sqla_testing', 'postgres_test_db_link') + "sqla_testing", "postgres_test_db_link" + ) testtable = Table( - 'testtable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30))) + "testtable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(30)), + ) for ddl in [ "CREATE SERVER test_server FOREIGN DATA WRAPPER postgres_fdw " @@ -51,46 +68,50 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): sa.event.listen(metadata, "after_create", sa.DDL(ddl)) for ddl in [ - 'DROP FOREIGN TABLE test_foreigntable', - 'DROP USER MAPPING FOR public SERVER test_server', - "DROP SERVER test_server" + "DROP FOREIGN TABLE test_foreigntable", + "DROP USER MAPPING FOR public SERVER test_server", + "DROP SERVER test_server", ]: sa.event.listen(metadata, "before_drop", sa.DDL(ddl)) def test_foreign_table_is_reflected(self): metadata = MetaData(testing.db) - table = Table('test_foreigntable', metadata, autoload=True) - eq_(set(table.columns.keys()), set(['id', 'data']), - "Columns of reflected foreign table didn't equal expected columns") + table = Table("test_foreigntable", metadata, autoload=True) + eq_( + set(table.columns.keys()), + set(["id", "data"]), + "Columns of reflected foreign table didn't equal expected columns", + ) def test_get_foreign_table_names(self): inspector = inspect(testing.db) with testing.db.connect() as conn: ft_names = inspector.get_foreign_table_names() - eq_(ft_names, ['test_foreigntable']) + eq_(ft_names, ["test_foreigntable"]) def test_get_table_names_no_foreign(self): inspector = inspect(testing.db) with testing.db.connect() as conn: names = inspector.get_table_names() - eq_(names, ['testtable']) + eq_(names, ["testtable"]) -class PartitionedReflectionTest( - fixtures.TablesTest, AssertsExecutionResults): +class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): # partitioned table reflection, issue #4237 - __only_on__ = 'postgresql >= 10' + __only_on__ = "postgresql >= 10" __backend__ = True @classmethod def define_tables(cls, metadata): # the actual function isn't reflected yet dv = Table( - 'data_values', metadata, - Column('modulus', Integer, nullable=False), - Column('data', String(30)), - postgresql_partition_by='range(modulus)') + "data_values", + metadata, + Column("modulus", Integer, nullable=False), + Column("data", String(30)), + postgresql_partition_by="range(modulus)", + ) # looks like this is reflected prior to #4237 sa.event.listen( @@ -98,103 +119,99 @@ class PartitionedReflectionTest( "after_create", sa.DDL( "CREATE TABLE data_values_4_10 PARTITION OF data_values " - "FOR VALUES FROM (4) TO (10)") + "FOR VALUES FROM (4) TO (10)" + ), ) def test_get_tablenames(self): - assert {'data_values', 'data_values_4_10'}.issubset( + assert {"data_values", "data_values_4_10"}.issubset( inspect(testing.db).get_table_names() ) def test_reflect_cols(self): - cols = 
inspect(testing.db).get_columns('data_values') - eq_( - [c['name'] for c in cols], - ['modulus', 'data'] - ) + cols = inspect(testing.db).get_columns("data_values") + eq_([c["name"] for c in cols], ["modulus", "data"]) def test_reflect_cols_from_partition(self): - cols = inspect(testing.db).get_columns('data_values_4_10') - eq_( - [c['name'] for c in cols], - ['modulus', 'data'] - ) + cols = inspect(testing.db).get_columns("data_values_4_10") + eq_([c["name"] for c in cols], ["modulus", "data"]) class MaterializedViewReflectionTest( - fixtures.TablesTest, AssertsExecutionResults): + fixtures.TablesTest, AssertsExecutionResults +): """Test reflection on materialized views""" - __only_on__ = 'postgresql >= 9.3' + __only_on__ = "postgresql >= 9.3" __backend__ = True @classmethod def define_tables(cls, metadata): testtable = Table( - 'testtable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30))) + "testtable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(30)), + ) # insert data before we create the view @sa.event.listens_for(testtable, "after_create") def insert_data(target, connection, **kw): - connection.execute( - target.insert(), - {"id": 89, "data": 'd1'} - ) + connection.execute(target.insert(), {"id": 89, "data": "d1"}) materialized_view = sa.DDL( - "CREATE MATERIALIZED VIEW test_mview AS " - "SELECT * FROM testtable") + "CREATE MATERIALIZED VIEW test_mview AS " "SELECT * FROM testtable" + ) plain_view = sa.DDL( - "CREATE VIEW test_regview AS " - "SELECT * FROM testtable") + "CREATE VIEW test_regview AS " "SELECT * FROM testtable" + ) - sa.event.listen(testtable, 'after_create', plain_view) - sa.event.listen(testtable, 'after_create', materialized_view) + sa.event.listen(testtable, "after_create", plain_view) + sa.event.listen(testtable, "after_create", materialized_view) sa.event.listen( - testtable, 'before_drop', - sa.DDL("DROP MATERIALIZED VIEW test_mview") + testtable, + "before_drop", + sa.DDL("DROP MATERIALIZED VIEW test_mview"), ) sa.event.listen( - testtable, 'before_drop', - sa.DDL("DROP VIEW test_regview") + testtable, "before_drop", sa.DDL("DROP VIEW test_regview") ) def test_mview_is_reflected(self): metadata = MetaData(testing.db) - table = Table('test_mview', metadata, autoload=True) - eq_(set(table.columns.keys()), set(['id', 'data']), - "Columns of reflected mview didn't equal expected columns") + table = Table("test_mview", metadata, autoload=True) + eq_( + set(table.columns.keys()), + set(["id", "data"]), + "Columns of reflected mview didn't equal expected columns", + ) def test_mview_select(self): metadata = MetaData(testing.db) - table = Table('test_mview', metadata, autoload=True) - eq_( - table.select().execute().fetchall(), - [(89, 'd1',)] - ) + table = Table("test_mview", metadata, autoload=True) + eq_(table.select().execute().fetchall(), [(89, "d1")]) def test_get_view_names(self): insp = inspect(testing.db) - eq_(set(insp.get_view_names()), set(['test_regview', 'test_mview'])) + eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) def test_get_view_names_plain(self): insp = inspect(testing.db) eq_( - set(insp.get_view_names(include=('plain',))), - set(['test_regview'])) + set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + ) def test_get_view_names_plain_string(self): insp = inspect(testing.db) - eq_(set(insp.get_view_names(include='plain')), set(['test_regview'])) + eq_(set(insp.get_view_names(include="plain")), set(["test_regview"])) def 
test_get_view_names_materialized(self): insp = inspect(testing.db) eq_( - set(insp.get_view_names(include=('materialized',))), - set(['test_mview'])) + set(insp.get_view_names(include=("materialized",))), + set(["test_mview"]), + ) def test_get_view_names_empty(self): insp = inspect(testing.db) @@ -204,16 +221,18 @@ class MaterializedViewReflectionTest( insp = inspect(testing.db) eq_( re.sub( - r'[\n\t ]+', ' ', - insp.get_view_definition("test_mview").strip()), - "SELECT testtable.id, testtable.data FROM testtable;" + r"[\n\t ]+", + " ", + insp.get_view_definition("test_mview").strip(), + ), + "SELECT testtable.id, testtable.data FROM testtable;", ) class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): """Test PostgreSQL domains""" - __only_on__ = 'postgresql > 8.3' + __only_on__ = "postgresql > 8.3" __backend__ = True @classmethod @@ -221,208 +240,231 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): con = testing.db.connect() for ddl in [ 'CREATE SCHEMA "SomeSchema"', - 'CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42', - 'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0', + "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", + "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", "CREATE TYPE testtype AS ENUM ('test')", - 'CREATE DOMAIN enumdomain AS testtype', - 'CREATE DOMAIN arraydomain AS INTEGER[]', - 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0' + "CREATE DOMAIN enumdomain AS testtype", + "CREATE DOMAIN arraydomain AS INTEGER[]", + 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', ]: try: con.execute(ddl) except exc.DBAPIError as e: - if 'already exists' not in str(e): + if "already exists" not in str(e): raise e - con.execute('CREATE TABLE testtable (question integer, answer ' - 'testdomain)') - con.execute('CREATE TABLE test_schema.testtable(question ' - 'integer, answer test_schema.testdomain, anything ' - 'integer)') - con.execute('CREATE TABLE crosschema (question integer, answer ' - 'test_schema.testdomain)') + con.execute( + "CREATE TABLE testtable (question integer, answer " "testdomain)" + ) + con.execute( + "CREATE TABLE test_schema.testtable(question " + "integer, answer test_schema.testdomain, anything " + "integer)" + ) + con.execute( + "CREATE TABLE crosschema (question integer, answer " + "test_schema.testdomain)" + ) - con.execute('CREATE TABLE enum_test (id integer, data enumdomain)') + con.execute("CREATE TABLE enum_test (id integer, data enumdomain)") - con.execute('CREATE TABLE array_test (id integer, data arraydomain)') + con.execute("CREATE TABLE array_test (id integer, data arraydomain)") con.execute( - 'CREATE TABLE quote_test ' - '(id integer, data "SomeSchema"."Quoted.Domain")') + "CREATE TABLE quote_test " + '(id integer, data "SomeSchema"."Quoted.Domain")' + ) @classmethod def teardown_class(cls): con = testing.db.connect() - con.execute('DROP TABLE testtable') - con.execute('DROP TABLE test_schema.testtable') - con.execute('DROP TABLE crosschema') - con.execute('DROP TABLE quote_test') - con.execute('DROP DOMAIN testdomain') - con.execute('DROP DOMAIN test_schema.testdomain') + con.execute("DROP TABLE testtable") + con.execute("DROP TABLE test_schema.testtable") + con.execute("DROP TABLE crosschema") + con.execute("DROP TABLE quote_test") + con.execute("DROP DOMAIN testdomain") + con.execute("DROP DOMAIN test_schema.testdomain") con.execute("DROP TABLE enum_test") con.execute("DROP DOMAIN enumdomain") con.execute("DROP TYPE testtype") - con.execute('DROP TABLE 
array_test') - con.execute('DROP DOMAIN arraydomain') + con.execute("DROP TABLE array_test") + con.execute("DROP DOMAIN arraydomain") con.execute('DROP DOMAIN "SomeSchema"."Quoted.Domain"') con.execute('DROP SCHEMA "SomeSchema"') def test_table_is_reflected(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True) - eq_(set(table.columns.keys()), set(['question', 'answer']), - "Columns of reflected table didn't equal expected columns") + table = Table("testtable", metadata, autoload=True) + eq_( + set(table.columns.keys()), + set(["question", "answer"]), + "Columns of reflected table didn't equal expected columns", + ) assert isinstance(table.c.answer.type, Integer) def test_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True) - eq_(str(table.columns.answer.server_default.arg), '42', - "Reflected default value didn't equal expected value") - assert not table.columns.answer.nullable, \ - 'Expected reflected column to not be nullable.' + table = Table("testtable", metadata, autoload=True) + eq_( + str(table.columns.answer.server_default.arg), + "42", + "Reflected default value didn't equal expected value", + ) + assert ( + not table.columns.answer.nullable + ), "Expected reflected column to not be nullable." def test_enum_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('enum_test', metadata, autoload=True) - eq_( - table.c.data.type.enums, - ['test'] - ) + table = Table("enum_test", metadata, autoload=True) + eq_(table.c.data.type.enums, ["test"]) def test_array_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('array_test', metadata, autoload=True) - eq_( - table.c.data.type.__class__, - ARRAY - ) - eq_( - table.c.data.type.item_type.__class__, - INTEGER - ) + table = Table("array_test", metadata, autoload=True) + eq_(table.c.data.type.__class__, ARRAY) + eq_(table.c.data.type.item_type.__class__, INTEGER) def test_quoted_remote_schema_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('quote_test', metadata, autoload=True) - eq_( - table.c.data.type.__class__, - INTEGER - ) + table = Table("quote_test", metadata, autoload=True) + eq_(table.c.data.type.__class__, INTEGER) def test_table_is_reflected_test_schema(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True, - schema='test_schema') - eq_(set(table.columns.keys()), set(['question', 'answer', - 'anything']), - "Columns of reflected table didn't equal expected columns") + table = Table( + "testtable", metadata, autoload=True, schema="test_schema" + ) + eq_( + set(table.columns.keys()), + set(["question", "answer", "anything"]), + "Columns of reflected table didn't equal expected columns", + ) assert isinstance(table.c.anything.type, Integer) def test_schema_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True, - schema='test_schema') - eq_(str(table.columns.answer.server_default.arg), '0', - "Reflected default value didn't equal expected value") - assert table.columns.answer.nullable, \ - 'Expected reflected column to be nullable.' + table = Table( + "testtable", metadata, autoload=True, schema="test_schema" + ) + eq_( + str(table.columns.answer.server_default.arg), + "0", + "Reflected default value didn't equal expected value", + ) + assert ( + table.columns.answer.nullable + ), "Expected reflected column to be nullable." 
def test_crosschema_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('crosschema', metadata, autoload=True) - eq_(str(table.columns.answer.server_default.arg), '0', - "Reflected default value didn't equal expected value") - assert table.columns.answer.nullable, \ - 'Expected reflected column to be nullable.' + table = Table("crosschema", metadata, autoload=True) + eq_( + str(table.columns.answer.server_default.arg), + "0", + "Reflected default value didn't equal expected value", + ) + assert ( + table.columns.answer.nullable + ), "Expected reflected column to be nullable." def test_unknown_types(self): from sqlalchemy.databases import postgresql + ischema_names = postgresql.PGDialect.ischema_names postgresql.PGDialect.ischema_names = {} try: m2 = MetaData(testing.db) - assert_raises(exc.SAWarning, Table, 'testtable', m2, - autoload=True) + assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True) - @testing.emits_warning('Did not recognize type') + @testing.emits_warning("Did not recognize type") def warns(): m3 = MetaData(testing.db) - t3 = Table('testtable', m3, autoload=True) + t3 = Table("testtable", m3, autoload=True) assert t3.c.answer.type.__class__ == sa.types.NullType + finally: postgresql.PGDialect.ischema_names = ischema_names class ReflectionTest(fixtures.TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True - @testing.fails_if("postgresql < 8.4", - "Better int2vector functions not available") + @testing.fails_if( + "postgresql < 8.4", "Better int2vector functions not available" + ) @testing.provide_metadata def test_reflected_primary_key_order(self): meta1 = self.metadata - subject = Table('subject', meta1, - Column('p1', Integer, primary_key=True), - Column('p2', Integer, primary_key=True), - PrimaryKeyConstraint('p2', 'p1') - ) + subject = Table( + "subject", + meta1, + Column("p1", Integer, primary_key=True), + Column("p2", Integer, primary_key=True), + PrimaryKeyConstraint("p2", "p1"), + ) meta1.create_all() meta2 = MetaData(testing.db) - subject = Table('subject', meta2, autoload=True) - eq_(subject.primary_key.columns.keys(), ['p2', 'p1']) + subject = Table("subject", meta2, autoload=True) + eq_(subject.primary_key.columns.keys(), ["p2", "p1"]) @testing.provide_metadata def test_pg_weirdchar_reflection(self): meta1 = self.metadata - subject = Table('subject', meta1, Column('id$', Integer, - primary_key=True)) + subject = Table( + "subject", meta1, Column("id$", Integer, primary_key=True) + ) referer = Table( - 'referer', meta1, - Column( - 'id', Integer, primary_key=True), - Column( - 'ref', Integer, ForeignKey('subject.id$'))) + "referer", + meta1, + Column("id", Integer, primary_key=True), + Column("ref", Integer, ForeignKey("subject.id$")), + ) meta1.create_all() meta2 = MetaData(testing.db) - subject = Table('subject', meta2, autoload=True) - referer = Table('referer', meta2, autoload=True) - self.assert_((subject.c['id$'] - == referer.c.ref).compare( - subject.join(referer).onclause)) + subject = Table("subject", meta2, autoload=True) + referer = Table("referer", meta2, autoload=True) + self.assert_( + (subject.c["id$"] == referer.c.ref).compare( + subject.join(referer).onclause + ) + ) @testing.provide_metadata def test_reflect_default_over_128_chars(self): - Table('t', self.metadata, - Column('x', String(200), server_default="abcd" * 40) - ).create(testing.db) + Table( + "t", + self.metadata, + Column("x", String(200), server_default="abcd" * 40), + ).create(testing.db) m = MetaData() - t = 
Table('t', m, autoload=True, autoload_with=testing.db) + t = Table("t", m, autoload=True, autoload_with=testing.db) eq_( - t.c.x.server_default.arg.text, "'%s'::character varying" % ( - "abcd" * 40) + t.c.x.server_default.arg.text, + "'%s'::character varying" % ("abcd" * 40), ) @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure") @testing.provide_metadata def test_renamed_sequence_reflection(self): metadata = self.metadata - t = Table('t', metadata, Column('id', Integer, primary_key=True)) + t = Table("t", metadata, Column("id", Integer, primary_key=True)) metadata.create_all() m2 = MetaData(testing.db) - t2 = Table('t', m2, autoload=True, implicit_returning=False) - eq_(t2.c.id.server_default.arg.text, - "nextval('t_id_seq'::regclass)") + t2 = Table("t", m2, autoload=True, implicit_returning=False) + eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)") r = t2.insert().execute() eq_(r.inserted_primary_key, [1]) - testing.db.connect().execution_options(autocommit=True).\ - execute('alter table t_id_seq rename to foobar_id_seq' - ) + testing.db.connect().execution_options(autocommit=True).execute( + "alter table t_id_seq rename to foobar_id_seq" + ) m3 = MetaData(testing.db) - t3 = Table('t', m3, autoload=True, implicit_returning=False) - eq_(t3.c.id.server_default.arg.text, - "nextval('foobar_id_seq'::regclass)") + t3 = Table("t", m3, autoload=True, implicit_returning=False) + eq_( + t3.c.id.server_default.arg.text, + "nextval('foobar_id_seq'::regclass)", + ) r = t3.insert().execute() eq_(r.inserted_primary_key, [2]) @@ -430,37 +472,41 @@ class ReflectionTest(fixtures.TestBase): def test_altered_type_autoincrement_pk_reflection(self): metadata = self.metadata t = Table( - 't', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer) + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), ) metadata.create_all() - testing.db.connect().execution_options(autocommit=True).\ - execute('alter table t alter column id type varchar(50)') + testing.db.connect().execution_options(autocommit=True).execute( + "alter table t alter column id type varchar(50)" + ) m2 = MetaData(testing.db) - t2 = Table('t', m2, autoload=True) + t2 = Table("t", m2, autoload=True) eq_(t2.c.id.autoincrement, False) eq_(t2.c.x.autoincrement, False) @testing.provide_metadata def test_renamed_pk_reflection(self): metadata = self.metadata - t = Table('t', metadata, Column('id', Integer, primary_key=True)) + t = Table("t", metadata, Column("id", Integer, primary_key=True)) metadata.create_all() - testing.db.connect().execution_options(autocommit=True).\ - execute('alter table t rename id to t_id') + testing.db.connect().execution_options(autocommit=True).execute( + "alter table t rename id to t_id" + ) m2 = MetaData(testing.db) - t2 = Table('t', m2, autoload=True) - eq_([c.name for c in t2.primary_key], ['t_id']) + t2 = Table("t", m2, autoload=True) + eq_([c.name for c in t2.primary_key], ["t_id"]) @testing.provide_metadata def test_has_temporary_table(self): assert not testing.db.has_table("some_temp_table") user_tmp = Table( - "some_temp_table", self.metadata, + "some_temp_table", + self.metadata, Column("id", Integer, primary_key=True), - Column('name', String(50)), - prefixes=['TEMPORARY'] + Column("name", String(50)), + prefixes=["TEMPORARY"], ) user_tmp.create(testing.db) assert testing.db.has_table("some_temp_table") @@ -470,108 +516,124 @@ class ReflectionTest(fixtures.TestBase): meta1 = self.metadata - users = Table('users', meta1, - 
Column('user_id', Integer, primary_key=True), - Column('user_name', String(30), nullable=False), - schema='test_schema') + users = Table( + "users", + meta1, + Column("user_id", Integer, primary_key=True), + Column("user_name", String(30), nullable=False), + schema="test_schema", + ) addresses = Table( - 'email_addresses', meta1, - Column( - 'address_id', Integer, primary_key=True), - Column( - 'remote_user_id', Integer, ForeignKey( - users.c.user_id)), - Column( - 'email_address', String(20)), schema='test_schema') + "email_addresses", + meta1, + Column("address_id", Integer, primary_key=True), + Column("remote_user_id", Integer, ForeignKey(users.c.user_id)), + Column("email_address", String(20)), + schema="test_schema", + ) meta1.create_all() meta2 = MetaData(testing.db) - addresses = Table('email_addresses', meta2, autoload=True, - schema='test_schema') - users = Table('users', meta2, mustexist=True, - schema='test_schema') + addresses = Table( + "email_addresses", meta2, autoload=True, schema="test_schema" + ) + users = Table("users", meta2, mustexist=True, schema="test_schema") j = join(users, addresses) - self.assert_((users.c.user_id - == addresses.c.remote_user_id).compare(j.onclause)) + self.assert_( + (users.c.user_id == addresses.c.remote_user_id).compare(j.onclause) + ) @testing.provide_metadata def test_cross_schema_reflection_two(self): meta1 = self.metadata - subject = Table('subject', meta1, - Column('id', Integer, primary_key=True)) - referer = Table('referer', meta1, - Column('id', Integer, primary_key=True), - Column('ref', Integer, ForeignKey('subject.id')), - schema='test_schema') + subject = Table( + "subject", meta1, Column("id", Integer, primary_key=True) + ) + referer = Table( + "referer", + meta1, + Column("id", Integer, primary_key=True), + Column("ref", Integer, ForeignKey("subject.id")), + schema="test_schema", + ) meta1.create_all() meta2 = MetaData(testing.db) - subject = Table('subject', meta2, autoload=True) - referer = Table('referer', meta2, schema='test_schema', - autoload=True) - self.assert_((subject.c.id - == referer.c.ref).compare( - subject.join(referer).onclause)) + subject = Table("subject", meta2, autoload=True) + referer = Table("referer", meta2, schema="test_schema", autoload=True) + self.assert_( + (subject.c.id == referer.c.ref).compare( + subject.join(referer).onclause + ) + ) @testing.provide_metadata def test_cross_schema_reflection_three(self): meta1 = self.metadata - subject = Table('subject', meta1, - Column('id', Integer, primary_key=True), - schema='test_schema_2') + subject = Table( + "subject", + meta1, + Column("id", Integer, primary_key=True), + schema="test_schema_2", + ) referer = Table( - 'referer', + "referer", meta1, - Column( - 'id', - Integer, - primary_key=True), - Column( - 'ref', - Integer, - ForeignKey('test_schema_2.subject.id')), - schema='test_schema') + Column("id", Integer, primary_key=True), + Column("ref", Integer, ForeignKey("test_schema_2.subject.id")), + schema="test_schema", + ) meta1.create_all() meta2 = MetaData(testing.db) - subject = Table('subject', meta2, autoload=True, - schema='test_schema_2') - referer = Table('referer', meta2, autoload=True, - schema='test_schema') - self.assert_((subject.c.id - == referer.c.ref).compare( - subject.join(referer).onclause)) + subject = Table( + "subject", meta2, autoload=True, schema="test_schema_2" + ) + referer = Table("referer", meta2, autoload=True, schema="test_schema") + self.assert_( + (subject.c.id == referer.c.ref).compare( + subject.join(referer).onclause + 
) + ) @testing.provide_metadata def test_cross_schema_reflection_four(self): meta1 = self.metadata - subject = Table('subject', meta1, - Column('id', Integer, primary_key=True), - schema='test_schema_2') + subject = Table( + "subject", + meta1, + Column("id", Integer, primary_key=True), + schema="test_schema_2", + ) referer = Table( - 'referer', + "referer", meta1, - Column( - 'id', - Integer, - primary_key=True), - Column( - 'ref', - Integer, - ForeignKey('test_schema_2.subject.id')), - schema='test_schema') + Column("id", Integer, primary_key=True), + Column("ref", Integer, ForeignKey("test_schema_2.subject.id")), + schema="test_schema", + ) meta1.create_all() conn = testing.db.connect() conn.detach() conn.execute("SET search_path TO test_schema, test_schema_2") meta2 = MetaData(bind=conn) - subject = Table('subject', meta2, autoload=True, - schema='test_schema_2', - postgresql_ignore_search_path=True) - referer = Table('referer', meta2, autoload=True, - schema='test_schema', - postgresql_ignore_search_path=True) - self.assert_((subject.c.id - == referer.c.ref).compare( - subject.join(referer).onclause)) + subject = Table( + "subject", + meta2, + autoload=True, + schema="test_schema_2", + postgresql_ignore_search_path=True, + ) + referer = Table( + "referer", + meta2, + autoload=True, + schema="test_schema", + postgresql_ignore_search_path=True, + ) + self.assert_( + (subject.c.id == referer.c.ref).compare( + subject.join(referer).onclause + ) + ) conn.close() @testing.provide_metadata @@ -580,26 +642,38 @@ class ReflectionTest(fixtures.TestBase): # we assume 'public' default_schema = testing.db.dialect.default_schema_name - subject = Table('subject', meta1, - Column('id', Integer, primary_key=True)) - referer = Table('referer', meta1, - Column('id', Integer, primary_key=True), - Column('ref', Integer, ForeignKey('subject.id'))) + subject = Table( + "subject", meta1, Column("id", Integer, primary_key=True) + ) + referer = Table( + "referer", + meta1, + Column("id", Integer, primary_key=True), + Column("ref", Integer, ForeignKey("subject.id")), + ) meta1.create_all() meta2 = MetaData(testing.db) - subject = Table('subject', meta2, autoload=True, - schema=default_schema, - postgresql_ignore_search_path=True - ) - referer = Table('referer', meta2, autoload=True, - schema=default_schema, - postgresql_ignore_search_path=True - ) + subject = Table( + "subject", + meta2, + autoload=True, + schema=default_schema, + postgresql_ignore_search_path=True, + ) + referer = Table( + "referer", + meta2, + autoload=True, + schema=default_schema, + postgresql_ignore_search_path=True, + ) assert subject.schema == default_schema - self.assert_((subject.c.id - == referer.c.ref).compare( - subject.join(referer).onclause)) + self.assert_( + (subject.c.id == referer.c.ref).compare( + subject.join(referer).onclause + ) + ) @testing.provide_metadata def test_cross_schema_reflection_six(self): @@ -607,61 +681,62 @@ class ReflectionTest(fixtures.TestBase): # by default meta1 = self.metadata - Table('some_table', meta1, - Column('id', Integer, primary_key=True), - schema='test_schema' - ) - Table('some_other_table', meta1, - Column('id', Integer, primary_key=True), - Column('sid', Integer, ForeignKey('test_schema.some_table.id')), - schema='test_schema_2' - ) + Table( + "some_table", + meta1, + Column("id", Integer, primary_key=True), + schema="test_schema", + ) + Table( + "some_other_table", + meta1, + Column("id", Integer, primary_key=True), + Column("sid", Integer, ForeignKey("test_schema.some_table.id")), + 
schema="test_schema_2", + ) meta1.create_all() with testing.db.connect() as conn: conn.detach() conn.execute( - "set search_path to test_schema_2, test_schema, public") + "set search_path to test_schema_2, test_schema, public" + ) m1 = MetaData(conn) - t1_schema = Table('some_table', - m1, - schema="test_schema", - autoload=True) - t2_schema = Table('some_other_table', - m1, - schema="test_schema_2", - autoload=True) + t1_schema = Table( + "some_table", m1, schema="test_schema", autoload=True + ) + t2_schema = Table( + "some_other_table", m1, schema="test_schema_2", autoload=True + ) - t2_no_schema = Table('some_other_table', - m1, - autoload=True) + t2_no_schema = Table("some_other_table", m1, autoload=True) - t1_no_schema = Table('some_table', - m1, - autoload=True) + t1_no_schema = Table("some_table", m1, autoload=True) m2 = MetaData(conn) - t1_schema_isp = Table('some_table', - m2, - schema="test_schema", - autoload=True, - postgresql_ignore_search_path=True) - t2_schema_isp = Table('some_other_table', - m2, - schema="test_schema_2", - autoload=True, - postgresql_ignore_search_path=True) + t1_schema_isp = Table( + "some_table", + m2, + schema="test_schema", + autoload=True, + postgresql_ignore_search_path=True, + ) + t2_schema_isp = Table( + "some_other_table", + m2, + schema="test_schema_2", + autoload=True, + postgresql_ignore_search_path=True, + ) # t2_schema refers to t1_schema, but since "test_schema" # is in the search path, we instead link to t2_no_schema - assert t2_schema.c.sid.references( - t1_no_schema.c.id) + assert t2_schema.c.sid.references(t1_no_schema.c.id) # the two no_schema tables refer to each other also. - assert t2_no_schema.c.sid.references( - t1_no_schema.c.id) + assert t2_no_schema.c.sid.references(t1_no_schema.c.id) # but if we're ignoring search path, then we maintain # those explicit schemas vs. 
what the "default" schema is @@ -673,33 +748,48 @@ class ReflectionTest(fixtures.TestBase): # by default meta1 = self.metadata - Table('some_table', meta1, - Column('id', Integer, primary_key=True), - schema='test_schema' - ) - Table('some_other_table', meta1, - Column('id', Integer, primary_key=True), - Column('sid', Integer, ForeignKey('test_schema.some_table.id')), - schema='test_schema_2' - ) + Table( + "some_table", + meta1, + Column("id", Integer, primary_key=True), + schema="test_schema", + ) + Table( + "some_other_table", + meta1, + Column("id", Integer, primary_key=True), + Column("sid", Integer, ForeignKey("test_schema.some_table.id")), + schema="test_schema_2", + ) meta1.create_all() with testing.db.connect() as conn: conn.detach() conn.execute( - "set search_path to test_schema_2, test_schema, public") + "set search_path to test_schema_2, test_schema, public" + ) meta2 = MetaData(conn) meta2.reflect(schema="test_schema_2") - eq_(set(meta2.tables), set( - ['test_schema_2.some_other_table', 'some_table'])) + eq_( + set(meta2.tables), + set(["test_schema_2.some_other_table", "some_table"]), + ) meta3 = MetaData(conn) meta3.reflect( - schema="test_schema_2", postgresql_ignore_search_path=True) + schema="test_schema_2", postgresql_ignore_search_path=True + ) - eq_(set(meta3.tables), set( - ['test_schema_2.some_other_table', 'test_schema.some_table'])) + eq_( + set(meta3.tables), + set( + [ + "test_schema_2.some_other_table", + "test_schema.some_table", + ] + ), + ) @testing.provide_metadata def test_cross_schema_reflection_metadata_uses_schema(self): @@ -707,29 +797,35 @@ class ReflectionTest(fixtures.TestBase): metadata = self.metadata - Table('some_table', metadata, - Column('id', Integer, primary_key=True), - Column('sid', Integer, ForeignKey('some_other_table.id')), - schema='test_schema' - ) - Table('some_other_table', metadata, - Column('id', Integer, primary_key=True), - schema=None - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("sid", Integer, ForeignKey("some_other_table.id")), + schema="test_schema", + ) + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + schema=None, + ) metadata.create_all() with testing.db.connect() as conn: meta2 = MetaData(conn, schema="test_schema") meta2.reflect() - eq_(set(meta2.tables), set( - ['some_other_table', 'test_schema.some_table'])) + eq_( + set(meta2.tables), + set(["some_other_table", "test_schema.some_table"]), + ) @testing.provide_metadata def test_uppercase_lowercase_table(self): metadata = self.metadata - a_table = Table('a', metadata, Column('x', Integer)) - A_table = Table('A', metadata, Column('x', Integer)) + a_table = Table("a", metadata, Column("x", Integer)) + A_table = Table("A", metadata, Column("x", Integer)) a_table.create() assert testing.db.has_table("a") @@ -739,8 +835,8 @@ class ReflectionTest(fixtures.TestBase): def test_uppercase_lowercase_sequence(self): - a_seq = Sequence('a') - A_seq = Sequence('A') + a_seq = Sequence("a") + A_seq = Sequence("A") a_seq.create(testing.db) assert testing.db.dialect.has_sequence(testing.db, "a") @@ -759,28 +855,33 @@ class ReflectionTest(fixtures.TestBase): metadata = self.metadata t1 = Table( - 'party', metadata, - Column( - 'id', String(10), nullable=False), - Column( - 'name', String(20), index=True), - Column( - 'aname', String(20))) + "party", + metadata, + Column("id", String(10), nullable=False), + Column("name", String(20), index=True), + Column("aname", String(20)), + ) metadata.create_all() - 
testing.db.execute(""" + testing.db.execute( + """ create index idx1 on party ((id || name)) - """) - testing.db.execute(""" + """ + ) + testing.db.execute( + """ create unique index idx2 on party (id) where name = 'test' - """) - testing.db.execute(""" + """ + ) + testing.db.execute( + """ create index idx3 on party using btree (lower(name::text), lower(aname::text)) - """) + """ + ) def go(): m2 = MetaData(testing.db) - t2 = Table('party', m2, autoload=True) + t2 = Table("party", m2, autoload=True) assert len(t2.indexes) == 2 # Make sure indexes are in the order we expect them in @@ -788,7 +889,7 @@ class ReflectionTest(fixtures.TestBase): tmp = [(idx.name, idx) for idx in t2.indexes] tmp.sort() r1, r2 = [idx[1] for idx in tmp] - assert r1.name == 'idx2' + assert r1.name == "idx2" assert r1.unique is True assert r2.unique is False assert [t2.c.id] == r1.columns @@ -796,12 +897,14 @@ class ReflectionTest(fixtures.TestBase): testing.assert_warnings( go, - ['Skipped unsupported reflection of ' - 'expression-based index idx1', - 'Predicate of partial index idx2 ignored during ' - 'reflection', - 'Skipped unsupported reflection of ' - 'expression-based index idx3']) + [ + "Skipped unsupported reflection of " + "expression-based index idx1", + "Predicate of partial index idx2 ignored during " "reflection", + "Skipped unsupported reflection of " + "expression-based index idx3", + ], + ) @testing.provide_metadata def test_index_reflection_modified(self): @@ -813,17 +916,19 @@ class ReflectionTest(fixtures.TestBase): metadata = self.metadata - t1 = Table('t', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer) - ) + t1 = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) metadata.create_all() conn = testing.db.connect().execution_options(autocommit=True) conn.execute("CREATE INDEX idx1 ON t (x)") conn.execute("ALTER TABLE t RENAME COLUMN x to y") ind = testing.db.dialect.get_indexes(conn, "t", None) - eq_(ind, [{'unique': False, 'column_names': ['y'], 'name': 'idx1'}]) + eq_(ind, [{"unique": False, "column_names": ["y"], "name": "idx1"}]) conn.close() @testing.fails_if("postgresql < 8.2", "reloptions not supported") @@ -834,9 +939,10 @@ class ReflectionTest(fixtures.TestBase): metadata = self.metadata Table( - 't', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer) + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), ) metadata.create_all() @@ -844,15 +950,25 @@ class ReflectionTest(fixtures.TestBase): conn.execute("CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)") ind = testing.db.dialect.get_indexes(conn, "t", None) - eq_(ind, [{'unique': False, 'column_names': ['x'], 'name': 'idx1', - 'dialect_options': - {"postgresql_with": {"fillfactor": "50"}}}]) + eq_( + ind, + [ + { + "unique": False, + "column_names": ["x"], + "name": "idx1", + "dialect_options": { + "postgresql_with": {"fillfactor": "50"} + }, + } + ], + ) m = MetaData() - t1 = Table('t', m, autoload_with=conn) + t1 = Table("t", m, autoload_with=conn) eq_( - list(t1.indexes)[0].dialect_options['postgresql']['with'], - {"fillfactor": "50"} + list(t1.indexes)[0].dialect_options["postgresql"]["with"], + {"fillfactor": "50"}, ) @testing.provide_metadata @@ -862,249 +978,316 @@ class ReflectionTest(fixtures.TestBase): metadata = self.metadata Table( - 't', metadata, - Column('id', Integer, primary_key=True), - Column('x', ARRAY(Integer)) + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("x", 
ARRAY(Integer)), ) metadata.create_all() with testing.db.connect().execution_options(autocommit=True) as conn: conn.execute("CREATE INDEX idx1 ON t USING gin (x)") ind = testing.db.dialect.get_indexes(conn, "t", None) - eq_(ind, [{'unique': False, 'column_names': ['x'], 'name': 'idx1', - 'dialect_options': {'postgresql_using': 'gin'}}]) + eq_( + ind, + [ + { + "unique": False, + "column_names": ["x"], + "name": "idx1", + "dialect_options": {"postgresql_using": "gin"}, + } + ], + ) m = MetaData() - t1 = Table('t', m, autoload_with=conn) + t1 = Table("t", m, autoload_with=conn) eq_( - list(t1.indexes)[0].dialect_options['postgresql']['using'], - 'gin' + list(t1.indexes)[0].dialect_options["postgresql"]["using"], + "gin", ) @testing.provide_metadata def test_foreign_key_option_inspection(self): metadata = self.metadata Table( - 'person', + "person", metadata, + Column("id", String(length=32), nullable=False, primary_key=True), Column( - 'id', - String( - length=32), - nullable=False, - primary_key=True), - Column( - 'company_id', + "company_id", ForeignKey( - 'company.id', - name='person_company_id_fkey', - match='FULL', - onupdate='RESTRICT', - ondelete='RESTRICT', + "company.id", + name="person_company_id_fkey", + match="FULL", + onupdate="RESTRICT", + ondelete="RESTRICT", deferrable=True, - initially='DEFERRED'))) + initially="DEFERRED", + ), + ), + ) Table( - 'company', metadata, - Column('id', String(length=32), nullable=False, primary_key=True), - Column('name', String(length=255)), + "company", + metadata, + Column("id", String(length=32), nullable=False, primary_key=True), + Column("name", String(length=255)), Column( - 'industry_id', + "industry_id", ForeignKey( - 'industry.id', - name='company_industry_id_fkey', - onupdate='CASCADE', ondelete='CASCADE', + "industry.id", + name="company_industry_id_fkey", + onupdate="CASCADE", + ondelete="CASCADE", deferrable=False, # PG default # PG default - initially='IMMEDIATE' - ) - ) + initially="IMMEDIATE", + ), + ), + ) + Table( + "industry", + metadata, + Column("id", Integer(), nullable=False, primary_key=True), + Column("name", String(length=255)), ) - Table('industry', metadata, - Column('id', Integer(), nullable=False, primary_key=True), - Column('name', String(length=255)) - ) fk_ref = { - 'person_company_id_fkey': { - 'name': 'person_company_id_fkey', - 'constrained_columns': ['company_id'], - 'referred_columns': ['id'], - 'referred_table': 'company', - 'referred_schema': None, - 'options': { - 'onupdate': 'RESTRICT', - 'deferrable': True, - 'ondelete': 'RESTRICT', - 'initially': 'DEFERRED', - 'match': 'FULL' - } + "person_company_id_fkey": { + "name": "person_company_id_fkey", + "constrained_columns": ["company_id"], + "referred_columns": ["id"], + "referred_table": "company", + "referred_schema": None, + "options": { + "onupdate": "RESTRICT", + "deferrable": True, + "ondelete": "RESTRICT", + "initially": "DEFERRED", + "match": "FULL", + }, + }, + "company_industry_id_fkey": { + "name": "company_industry_id_fkey", + "constrained_columns": ["industry_id"], + "referred_columns": ["id"], + "referred_table": "industry", + "referred_schema": None, + "options": { + "onupdate": "CASCADE", + "deferrable": None, + "ondelete": "CASCADE", + "initially": None, + "match": None, + }, }, - 'company_industry_id_fkey': { - 'name': 'company_industry_id_fkey', - 'constrained_columns': ['industry_id'], - 'referred_columns': ['id'], - 'referred_table': 'industry', - 'referred_schema': None, - 'options': { - 'onupdate': 'CASCADE', - 'deferrable': None, - 
'ondelete': 'CASCADE', - 'initially': None, - 'match': None - } - } } metadata.create_all() inspector = inspect(testing.db) - fks = inspector.get_foreign_keys('person') + \ - inspector.get_foreign_keys('company') + fks = inspector.get_foreign_keys( + "person" + ) + inspector.get_foreign_keys("company") for fk in fks: - eq_(fk, fk_ref[fk['name']]) + eq_(fk, fk_ref[fk["name"]]) @testing.provide_metadata def test_inspect_enums_schema(self): conn = testing.db.connect() enum_type = postgresql.ENUM( - 'sad', 'ok', 'happy', name='mood', - schema='test_schema', - metadata=self.metadata) + "sad", + "ok", + "happy", + name="mood", + schema="test_schema", + metadata=self.metadata, + ) enum_type.create(conn) inspector = reflection.Inspector.from_engine(conn.engine) eq_( - inspector.get_enums('test_schema'), [{ - 'visible': False, - 'name': 'mood', - 'schema': 'test_schema', - 'labels': ['sad', 'ok', 'happy'] - }]) + inspector.get_enums("test_schema"), + [ + { + "visible": False, + "name": "mood", + "schema": "test_schema", + "labels": ["sad", "ok", "happy"], + } + ], + ) @testing.provide_metadata def test_inspect_enums(self): enum_type = postgresql.ENUM( - 'cat', 'dog', 'rat', name='pet', metadata=self.metadata) + "cat", "dog", "rat", name="pet", metadata=self.metadata + ) enum_type.create(testing.db) inspector = reflection.Inspector.from_engine(testing.db) - eq_(inspector.get_enums(), [ - { - 'visible': True, - 'labels': ['cat', 'dog', 'rat'], - 'name': 'pet', - 'schema': 'public' - }]) + eq_( + inspector.get_enums(), + [ + { + "visible": True, + "labels": ["cat", "dog", "rat"], + "name": "pet", + "schema": "public", + } + ], + ) @testing.provide_metadata def test_inspect_enums_case_sensitive(self): sa.event.listen( - self.metadata, "before_create", - sa.DDL('create schema "TestSchema"')) + self.metadata, + "before_create", + sa.DDL('create schema "TestSchema"'), + ) sa.event.listen( - self.metadata, "after_drop", - sa.DDL('drop schema "TestSchema" cascade')) + self.metadata, + "after_drop", + sa.DDL('drop schema "TestSchema" cascade'), + ) - for enum in 'lower_case', 'UpperCase', 'Name.With.Dot': - for schema in None, 'test_schema', 'TestSchema': + for enum in "lower_case", "UpperCase", "Name.With.Dot": + for schema in None, "test_schema", "TestSchema": postgresql.ENUM( - 'CapsOne', 'CapsTwo', name=enum, - schema=schema, metadata=self.metadata) + "CapsOne", + "CapsTwo", + name=enum, + schema=schema, + metadata=self.metadata, + ) self.metadata.create_all(testing.db) inspector = inspect(testing.db) - for schema in None, 'test_schema', 'TestSchema': - eq_(sorted( - inspector.get_enums(schema=schema), - key=itemgetter("name")), [ - { - 'visible': schema is None, - 'labels': ['CapsOne', 'CapsTwo'], - 'name': "Name.With.Dot", - 'schema': 'public' if schema is None else schema - }, - { - 'visible': schema is None, - 'labels': ['CapsOne', 'CapsTwo'], - 'name': "UpperCase", - 'schema': 'public' if schema is None else schema - }, - { - 'visible': schema is None, - 'labels': ['CapsOne', 'CapsTwo'], - 'name': "lower_case", - 'schema': 'public' if schema is None else schema - } - ]) + for schema in None, "test_schema", "TestSchema": + eq_( + sorted( + inspector.get_enums(schema=schema), key=itemgetter("name") + ), + [ + { + "visible": schema is None, + "labels": ["CapsOne", "CapsTwo"], + "name": "Name.With.Dot", + "schema": "public" if schema is None else schema, + }, + { + "visible": schema is None, + "labels": ["CapsOne", "CapsTwo"], + "name": "UpperCase", + "schema": "public" if schema is None else schema, + 
}, + { + "visible": schema is None, + "labels": ["CapsOne", "CapsTwo"], + "name": "lower_case", + "schema": "public" if schema is None else schema, + }, + ], + ) @testing.provide_metadata def test_inspect_enums_case_sensitive_from_table(self): sa.event.listen( - self.metadata, "before_create", - sa.DDL('create schema "TestSchema"')) + self.metadata, + "before_create", + sa.DDL('create schema "TestSchema"'), + ) sa.event.listen( - self.metadata, "after_drop", - sa.DDL('drop schema "TestSchema" cascade')) + self.metadata, + "after_drop", + sa.DDL('drop schema "TestSchema" cascade'), + ) counter = itertools.count() - for enum in 'lower_case', 'UpperCase', 'Name.With.Dot': - for schema in None, 'test_schema', 'TestSchema': - - enum_type = postgresql.ENUM( - 'CapsOne', 'CapsTwo', name=enum, - metadata=self.metadata, schema=schema) + for enum in "lower_case", "UpperCase", "Name.With.Dot": + for schema in None, "test_schema", "TestSchema": + + enum_type = postgresql.ENUM( + "CapsOne", + "CapsTwo", + name=enum, + metadata=self.metadata, + schema=schema, + ) - Table( - 't%d' % next(counter), - self.metadata, Column('q', enum_type)) + Table( + "t%d" % next(counter), + self.metadata, + Column("q", enum_type), + ) self.metadata.create_all(testing.db) inspector = inspect(testing.db) counter = itertools.count() - for enum in 'lower_case', 'UpperCase', 'Name.With.Dot': - for schema in None, 'test_schema', 'TestSchema': + for enum in "lower_case", "UpperCase", "Name.With.Dot": + for schema in None, "test_schema", "TestSchema": cols = inspector.get_columns("t%d" % next(counter)) - cols[0]['type'] = ( - cols[0]['type'].schema, - cols[0]['type'].name, cols[0]['type'].enums) - eq_(cols, [ - { - 'name': 'q', - 'type': ( - schema, enum, ['CapsOne', 'CapsTwo']), - 'nullable': True, 'default': None, - 'autoincrement': False, 'comment': None} - ]) + cols[0]["type"] = ( + cols[0]["type"].schema, + cols[0]["type"].name, + cols[0]["type"].enums, + ) + eq_( + cols, + [ + { + "name": "q", + "type": (schema, enum, ["CapsOne", "CapsTwo"]), + "nullable": True, + "default": None, + "autoincrement": False, + "comment": None, + } + ], + ) @testing.provide_metadata def test_inspect_enums_star(self): enum_type = postgresql.ENUM( - 'cat', 'dog', 'rat', name='pet', metadata=self.metadata) + "cat", "dog", "rat", name="pet", metadata=self.metadata + ) schema_enum_type = postgresql.ENUM( - 'sad', 'ok', 'happy', name='mood', - schema='test_schema', - metadata=self.metadata) + "sad", + "ok", + "happy", + name="mood", + schema="test_schema", + metadata=self.metadata, + ) enum_type.create(testing.db) schema_enum_type.create(testing.db) inspector = reflection.Inspector.from_engine(testing.db) - eq_(inspector.get_enums(), [ - { - 'visible': True, - 'labels': ['cat', 'dog', 'rat'], - 'name': 'pet', - 'schema': 'public' - }]) + eq_( + inspector.get_enums(), + [ + { + "visible": True, + "labels": ["cat", "dog", "rat"], + "name": "pet", + "schema": "public", + } + ], + ) - eq_(inspector.get_enums('*'), [ - { - 'visible': True, - 'labels': ['cat', 'dog', 'rat'], - 'name': 'pet', - 'schema': 'public' - }, - { - 'visible': False, - 'name': 'mood', - 'schema': 'test_schema', - 'labels': ['sad', 'ok', 'happy'] - }]) + eq_( + inspector.get_enums("*"), + [ + { + "visible": True, + "labels": ["cat", "dog", "rat"], + "name": "pet", + "schema": "public", + }, + { + "visible": False, + "name": "mood", + "schema": "test_schema", + "labels": ["sad", "ok", "happy"], + }, + ], + ) @testing.provide_metadata @testing.only_on("postgresql >= 8.5") @@ -1112,39 
+1295,44 @@ class ReflectionTest(fixtures.TestBase): insp = inspect(testing.db) meta = self.metadata - uc_table = Table('pgsql_uc', meta, - Column('a', String(10)), - UniqueConstraint('a', name='uc_a')) + uc_table = Table( + "pgsql_uc", + meta, + Column("a", String(10)), + UniqueConstraint("a", name="uc_a"), + ) uc_table.create() # PostgreSQL will create an implicit index for a unique # constraint. Separately we get both - indexes = set(i['name'] for i in insp.get_indexes('pgsql_uc')) - constraints = set(i['name'] - for i in insp.get_unique_constraints('pgsql_uc')) + indexes = set(i["name"] for i in insp.get_indexes("pgsql_uc")) + constraints = set( + i["name"] for i in insp.get_unique_constraints("pgsql_uc") + ) - self.assert_('uc_a' in indexes) - self.assert_('uc_a' in constraints) + self.assert_("uc_a" in indexes) + self.assert_("uc_a" in constraints) # reflection corrects for the dupe - reflected = Table('pgsql_uc', MetaData(testing.db), autoload=True) + reflected = Table("pgsql_uc", MetaData(testing.db), autoload=True) indexes = set(i.name for i in reflected.indexes) constraints = set(uc.name for uc in reflected.constraints) - self.assert_('uc_a' not in indexes) - self.assert_('uc_a' in constraints) + self.assert_("uc_a" not in indexes) + self.assert_("uc_a" in constraints) @testing.requires.btree_gist @testing.provide_metadata def test_reflection_with_exclude_constraint(self): m = self.metadata Table( - 't', m, - Column('id', Integer, primary_key=True), - Column('period', TSRANGE), - ExcludeConstraint(('period', '&&'), name='quarters_period_excl') + "t", + m, + Column("id", Integer, primary_key=True), + Column("period", TSRANGE), + ExcludeConstraint(("period", "&&"), name="quarters_period_excl"), ) m.create_all() @@ -1154,15 +1342,20 @@ class ReflectionTest(fixtures.TestBase): # PostgreSQL will create an implicit index for an exclude constraint. # we don't reflect the EXCLUDE yet. 
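As an illustration of the behavior the comment above describes — a minimal
sketch, assuming a reachable PostgreSQL database with the btree_gist
extension available (the DSN below is a placeholder, not part of the test
suite):

    from sqlalchemy import (
        Column, Integer, MetaData, Table, create_engine, inspect,
    )
    from sqlalchemy.dialects.postgresql import TSRANGE, ExcludeConstraint

    engine = create_engine("postgresql://scott:tiger@localhost/test")

    m = MetaData()
    t = Table(
        "t",
        m,
        Column("id", Integer, primary_key=True),
        Column("period", TSRANGE),
        ExcludeConstraint(("period", "&&"), name="quarters_period_excl"),
    )
    m.create_all(engine)

    # only the implicit GiST index is reported; the EXCLUDE constraint
    # itself does not appear, which is what the assertion below verifies
    print(inspect(engine).get_indexes("t"))
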
eq_( - insp.get_indexes('t'), - [{'unique': False, 'name': 'quarters_period_excl', - 'duplicates_constraint': 'quarters_period_excl', - 'dialect_options': {'postgresql_using': 'gist'}, - 'column_names': ['period']}] + insp.get_indexes("t"), + [ + { + "unique": False, + "name": "quarters_period_excl", + "duplicates_constraint": "quarters_period_excl", + "dialect_options": {"postgresql_using": "gist"}, + "column_names": ["period"], + } + ], ) # reflection corrects for the dupe - reflected = Table('t', MetaData(testing.db), autoload=True) + reflected = Table("t", MetaData(testing.db), autoload=True) eq_(set(reflected.indexes), set()) @@ -1174,58 +1367,66 @@ class ReflectionTest(fixtures.TestBase): # a unique index OTOH we are able to detect is an index # and not a unique constraint - uc_table = Table('pgsql_uc', meta, - Column('a', String(10)), - Index('ix_a', 'a', unique=True)) + uc_table = Table( + "pgsql_uc", + meta, + Column("a", String(10)), + Index("ix_a", "a", unique=True), + ) uc_table.create() - indexes = dict((i['name'], i) for i in insp.get_indexes('pgsql_uc')) - constraints = set(i['name'] - for i in insp.get_unique_constraints('pgsql_uc')) + indexes = dict((i["name"], i) for i in insp.get_indexes("pgsql_uc")) + constraints = set( + i["name"] for i in insp.get_unique_constraints("pgsql_uc") + ) - self.assert_('ix_a' in indexes) - assert indexes['ix_a']['unique'] - self.assert_('ix_a' not in constraints) + self.assert_("ix_a" in indexes) + assert indexes["ix_a"]["unique"] + self.assert_("ix_a" not in constraints) - reflected = Table('pgsql_uc', MetaData(testing.db), autoload=True) + reflected = Table("pgsql_uc", MetaData(testing.db), autoload=True) indexes = dict((i.name, i) for i in reflected.indexes) constraints = set(uc.name for uc in reflected.constraints) - self.assert_('ix_a' in indexes) - assert indexes['ix_a'].unique - self.assert_('ix_a' not in constraints) + self.assert_("ix_a" in indexes) + assert indexes["ix_a"].unique + self.assert_("ix_a" not in constraints) @testing.provide_metadata def test_reflect_check_constraint(self): meta = self.metadata cc_table = Table( - 'pgsql_cc', meta, - Column('a', Integer()), - CheckConstraint('a > 1 AND a < 5', name='cc1'), - CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2')) + "pgsql_cc", + meta, + Column("a", Integer()), + CheckConstraint("a > 1 AND a < 5", name="cc1"), + CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), + ) cc_table.create() - reflected = Table('pgsql_cc', MetaData(testing.db), autoload=True) + reflected = Table("pgsql_cc", MetaData(testing.db), autoload=True) - check_constraints = dict((uc.name, uc.sqltext.text) - for uc in reflected.constraints - if isinstance(uc, CheckConstraint)) - - eq_(check_constraints, { - u'cc1': u'(a > 1) AND (a < 5)', - u'cc2': u'(a = 1) OR ((a > 2) AND (a < 5))' - }) + check_constraints = dict( + (uc.name, uc.sqltext.text) + for uc in reflected.constraints + if isinstance(uc, CheckConstraint) + ) + eq_( + check_constraints, + { + u"cc1": u"(a > 1) AND (a < 5)", + u"cc2": u"(a = 1) OR ((a > 2) AND (a < 5))", + }, + ) class CustomTypeReflectionTest(fixtures.TestBase): - class CustomType(object): - def __init__(self, arg1=None, arg2=None): self.arg1 = arg1 self.arg2 = arg2 @@ -1243,32 +1444,32 @@ class CustomTypeReflectionTest(fixtures.TestBase): def _assert_reflected(self, dialect): for sch, args in [ - ('my_custom_type', (None, None)), - ('my_custom_type()', (None, None)), - ('my_custom_type(ARG1)', ('ARG1', None)), - ('my_custom_type(ARG1, ARG2)', ('ARG1', 'ARG2')), + 
("my_custom_type", (None, None)), + ("my_custom_type()", (None, None)), + ("my_custom_type(ARG1)", ("ARG1", None)), + ("my_custom_type(ARG1, ARG2)", ("ARG1", "ARG2")), ]: column_info = dialect._get_column_info( - 'colname', sch, None, False, - {}, {}, 'public', None) - assert isinstance(column_info['type'], self.CustomType) - eq_(column_info['type'].arg1, args[0]) - eq_(column_info['type'].arg2, args[1]) + "colname", sch, None, False, {}, {}, "public", None + ) + assert isinstance(column_info["type"], self.CustomType) + eq_(column_info["type"].arg1, args[0]) + eq_(column_info["type"].arg2, args[1]) def test_clslevel(self): - postgresql.PGDialect.ischema_names['my_custom_type'] = self.CustomType + postgresql.PGDialect.ischema_names["my_custom_type"] = self.CustomType dialect = postgresql.PGDialect() self._assert_reflected(dialect) def test_instancelevel(self): dialect = postgresql.PGDialect() dialect.ischema_names = dialect.ischema_names.copy() - dialect.ischema_names['my_custom_type'] = self.CustomType + dialect.ischema_names["my_custom_type"] = self.CustomType self._assert_reflected(dialect) class IntervalReflectionTest(fixtures.TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True def test_interval_types(self): @@ -1292,14 +1493,15 @@ class IntervalReflectionTest(fixtures.TestBase): @testing.provide_metadata def _test_interval_symbol(self, sym): t = Table( - 'i_test', self.metadata, - Column('id', Integer, primary_key=True), - Column('data1', INTERVAL(fields=sym)), + "i_test", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data1", INTERVAL(fields=sym)), ) t.create(testing.db) columns = { - rec['name']: rec + rec["name"]: rec for rec in inspect(testing.db).get_columns("i_test") } assert isinstance(columns["data1"]["type"], INTERVAL) @@ -1309,17 +1511,17 @@ class IntervalReflectionTest(fixtures.TestBase): @testing.provide_metadata def test_interval_precision(self): t = Table( - 'i_test', self.metadata, - Column('id', Integer, primary_key=True), - Column('data1', INTERVAL(precision=6)), + "i_test", + self.metadata, + Column("id", Integer, primary_key=True), + Column("data1", INTERVAL(precision=6)), ) t.create(testing.db) columns = { - rec['name']: rec + rec["name"]: rec for rec in inspect(testing.db).get_columns("i_test") } assert isinstance(columns["data1"]["type"], INTERVAL) eq_(columns["data1"]["type"].fields, None) eq_(columns["data1"]["type"].precision, 6) - diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 2ea7d3024f..d7ae213963 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1,21 +1,59 @@ # coding: utf-8 -from sqlalchemy.testing.assertions import eq_, assert_raises, \ - assert_raises_message, is_, AssertsExecutionResults, \ - AssertsCompiledSQL, ComparesTables +from sqlalchemy.testing.assertions import ( + eq_, + assert_raises, + assert_raises_message, + is_, + AssertsExecutionResults, + AssertsCompiledSQL, + ComparesTables, +) from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing from sqlalchemy.sql import sqltypes import datetime -from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \ - func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \ - Text, null, text, column, ARRAY, any_, all_ +from sqlalchemy import ( + Table, + MetaData, + Column, + Integer, + Enum, + Float, + select, + func, + DateTime, + Numeric, + exc, + String, + cast, + REAL, + TypeDecorator, 
+ Unicode, + Text, + null, + text, + column, + ARRAY, + any_, + all_, +) from sqlalchemy.sql import operators from sqlalchemy import types import sqlalchemy as sa from sqlalchemy.dialects import postgresql -from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \ - INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \ - JSON, JSONB +from sqlalchemy.dialects.postgresql import ( + HSTORE, + hstore, + array, + INT4RANGE, + INT8RANGE, + NUMRANGE, + DATERANGE, + TSRANGE, + TSTZRANGE, + JSON, + JSONB, +) import decimal from sqlalchemy import util from sqlalchemy.testing.util import round_decimal @@ -28,211 +66,223 @@ tztable = notztable = metadata = table = None class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __dialect__ = postgresql.dialect() __backend__ = True @classmethod def define_tables(cls, metadata): - data_table = Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer) - ) + data_table = Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", Integer), + ) @classmethod def insert_data(cls): data_table = cls.tables.data_table data_table.insert().execute( - {'data': 3}, - {'data': 5}, - {'data': 7}, - {'data': 2}, - {'data': 15}, - {'data': 12}, - {'data': 6}, - {'data': 478}, - {'data': 52}, - {'data': 9}, + {"data": 3}, + {"data": 5}, + {"data": 7}, + {"data": 2}, + {"data": 15}, + {"data": 12}, + {"data": 6}, + {"data": 478}, + {"data": 52}, + {"data": 9}, ) @testing.fails_on( - 'postgresql+zxjdbc', - 'XXX: postgresql+zxjdbc currently returns a Decimal result for Float') + "postgresql+zxjdbc", + "XXX: postgresql+zxjdbc currently returns a Decimal result for Float", + ) def test_float_coercion(self): data_table = self.tables.data_table for type_, result in [ - (Numeric, decimal.Decimal('140.381230939')), + (Numeric, decimal.Decimal("140.381230939")), (Float, 140.381230939), - (Float(asdecimal=True), decimal.Decimal('140.381230939')), + (Float(asdecimal=True), decimal.Decimal("140.381230939")), (Numeric(asdecimal=False), 140.381230939), ]: ret = testing.db.execute( - select([ - func.stddev_pop(data_table.c.data, type_=type_) - ]) + select([func.stddev_pop(data_table.c.data, type_=type_)]) ).scalar() eq_(round_decimal(ret, 9), result) ret = testing.db.execute( - select([ - cast(func.stddev_pop(data_table.c.data), type_) - ]) + select([cast(func.stddev_pop(data_table.c.data), type_)]) ).scalar() eq_(round_decimal(ret, 9), result) - @testing.fails_on('postgresql+zxjdbc', - 'zxjdbc has no support for PG arrays') + @testing.fails_on( + "postgresql+zxjdbc", "zxjdbc has no support for PG arrays" + ) @testing.provide_metadata def test_arrays_pg(self): metadata = self.metadata - t1 = Table('t', metadata, - Column('x', postgresql.ARRAY(Float)), - Column('y', postgresql.ARRAY(REAL)), - Column('z', postgresql.ARRAY(postgresql.DOUBLE_PRECISION)), - Column('q', postgresql.ARRAY(Numeric)) - ) + t1 = Table( + "t", + metadata, + Column("x", postgresql.ARRAY(Float)), + Column("y", postgresql.ARRAY(REAL)), + Column("z", postgresql.ARRAY(postgresql.DOUBLE_PRECISION)), + Column("q", postgresql.ARRAY(Numeric)), + ) metadata.create_all() t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")]) row = t1.select().execute().first() - eq_( - row, - ([5], [5], [6], [decimal.Decimal("6.4")]) - ) + eq_(row, ([5], [5], [6], [decimal.Decimal("6.4")])) - @testing.fails_on('postgresql+zxjdbc', - 'zxjdbc has no support for PG 
arrays') + @testing.fails_on( + "postgresql+zxjdbc", "zxjdbc has no support for PG arrays" + ) @testing.provide_metadata def test_arrays_base(self): metadata = self.metadata - t1 = Table('t', metadata, - Column('x', sqltypes.ARRAY(Float)), - Column('y', sqltypes.ARRAY(REAL)), - Column('z', sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)), - Column('q', sqltypes.ARRAY(Numeric)) - ) + t1 = Table( + "t", + metadata, + Column("x", sqltypes.ARRAY(Float)), + Column("y", sqltypes.ARRAY(REAL)), + Column("z", sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)), + Column("q", sqltypes.ARRAY(Numeric)), + ) metadata.create_all() t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")]) row = t1.select().execute().first() - eq_( - row, - ([5], [5], [6], [decimal.Decimal("6.4")]) - ) + eq_(row, ([5], [5], [6], [decimal.Decimal("6.4")])) class EnumTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True - __only_on__ = 'postgresql > 8.3' + __only_on__ = "postgresql > 8.3" - @testing.fails_on('postgresql+zxjdbc', - 'zxjdbc fails on ENUM: column "XXX" is of type ' - 'XXX but expression is of type character varying') + @testing.fails_on( + "postgresql+zxjdbc", + 'zxjdbc fails on ENUM: column "XXX" is of type ' + "XXX but expression is of type character varying", + ) def test_create_table(self): metadata = MetaData(testing.db) t1 = Table( - 'table', metadata, - Column( - 'id', Integer, primary_key=True), + "table", + metadata, + Column("id", Integer, primary_key=True), Column( - 'value', Enum( - 'one', 'two', 'three', name='onetwothreetype'))) + "value", Enum("one", "two", "three", name="onetwothreetype") + ), + ) t1.create() t1.create(checkfirst=True) # check the create try: - t1.insert().execute(value='two') - t1.insert().execute(value='three') - t1.insert().execute(value='three') - eq_(t1.select().order_by(t1.c.id).execute().fetchall(), - [(1, 'two'), (2, 'three'), (3, 'three')]) + t1.insert().execute(value="two") + t1.insert().execute(value="three") + t1.insert().execute(value="three") + eq_( + t1.select().order_by(t1.c.id).execute().fetchall(), + [(1, "two"), (2, "three"), (3, "three")], + ) finally: metadata.drop_all() metadata.drop_all() def test_name_required(self): metadata = MetaData(testing.db) - etype = Enum('four', 'five', 'six', metadata=metadata) + etype = Enum("four", "five", "six", metadata=metadata) assert_raises(exc.CompileError, etype.create) - assert_raises(exc.CompileError, etype.compile, - dialect=postgresql.dialect()) + assert_raises( + exc.CompileError, etype.compile, dialect=postgresql.dialect() + ) - @testing.fails_on('postgresql+zxjdbc', - 'zxjdbc fails on ENUM: column "XXX" is of type ' - 'XXX but expression is of type character varying') + @testing.fails_on( + "postgresql+zxjdbc", + 'zxjdbc fails on ENUM: column "XXX" is of type ' + "XXX but expression is of type character varying", + ) @testing.provide_metadata def test_unicode_labels(self): metadata = self.metadata t1 = Table( - 'table', + "table", metadata, + Column("id", Integer, primary_key=True), Column( - 'id', - Integer, - primary_key=True), - Column( - 'value', + "value", Enum( - util.u('réveillé'), - util.u('drôle'), - util.u('S’il'), - name='onetwothreetype'))) + util.u("réveillé"), + util.u("drôle"), + util.u("S’il"), + name="onetwothreetype", + ), + ), + ) metadata.create_all() - t1.insert().execute(value=util.u('drôle')) - t1.insert().execute(value=util.u('réveillé')) - t1.insert().execute(value=util.u('S’il')) - eq_(t1.select().order_by(t1.c.id).execute().fetchall(), - [(1, util.u('drôle')), (2, 
util.u('réveillé')), - (3, util.u('S’il'))] - ) + t1.insert().execute(value=util.u("drôle")) + t1.insert().execute(value=util.u("réveillé")) + t1.insert().execute(value=util.u("S’il")) + eq_( + t1.select().order_by(t1.c.id).execute().fetchall(), + [ + (1, util.u("drôle")), + (2, util.u("réveillé")), + (3, util.u("S’il")), + ], + ) m2 = MetaData(testing.db) - t2 = Table('table', m2, autoload=True) + t2 = Table("table", m2, autoload=True) eq_( t2.c.value.type.enums, - [util.u('réveillé'), util.u('drôle'), util.u('S’il')] + [util.u("réveillé"), util.u("drôle"), util.u("S’il")], ) @testing.provide_metadata def test_non_native_enum(self): metadata = self.metadata t1 = Table( - 'foo', + "foo", metadata, Column( - 'bar', - Enum( - 'one', - 'two', - 'three', - name='myenum', - native_enum=False))) + "bar", + Enum("one", "two", "three", name="myenum", native_enum=False), + ), + ) def go(): t1.create(testing.db) self.assert_sql( - testing.db, go, [ - ("CREATE TABLE foo (\tbar " - "VARCHAR(5), \tCONSTRAINT myenum CHECK " - "(bar IN ('one', 'two', 'three')))", {})]) + testing.db, + go, + [ + ( + "CREATE TABLE foo (\tbar " + "VARCHAR(5), \tCONSTRAINT myenum CHECK " + "(bar IN ('one', 'two', 'three')))", + {}, + ) + ], + ) with testing.db.begin() as conn: - conn.execute( - t1.insert(), {'bar': 'two'} - ) - eq_( - conn.scalar(select([t1.c.bar])), 'two' - ) + conn.execute(t1.insert(), {"bar": "two"}) + eq_(conn.scalar(select([t1.c.bar])), "two") @testing.provide_metadata def test_non_native_enum_w_unicode(self): metadata = self.metadata t1 = Table( - 'foo', + "foo", metadata, Column( - 'bar', - Enum('B', util.u('Ü'), name='myenum', native_enum=False))) + "bar", Enum("B", util.u("Ü"), name="myenum", native_enum=False) + ), + ) def go(): t1.create(testing.db) @@ -247,29 +297,24 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): "VARCHAR(1), \tCONSTRAINT myenum CHECK " "(bar IN ('B', 'Ü')))" ), - {} + {}, ) - ]) + ], + ) with testing.db.begin() as conn: - conn.execute( - t1.insert(), {'bar': util.u('Ü')} - ) - eq_( - conn.scalar(select([t1.c.bar])), util.u('Ü') - ) + conn.execute(t1.insert(), {"bar": util.u("Ü")}) + eq_(conn.scalar(select([t1.c.bar])), util.u("Ü")) @testing.provide_metadata def test_disable_create(self): metadata = self.metadata - e1 = postgresql.ENUM('one', 'two', 'three', - name="myenum", - create_type=False) + e1 = postgresql.ENUM( + "one", "two", "three", name="myenum", create_type=False + ) - t1 = Table('e1', metadata, - Column('c1', e1) - ) + t1 = Table("e1", metadata, Column("c1", e1)) # table can be created separately # without conflict e1.create(bind=testing.db) @@ -288,20 +333,16 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): """ metadata = self.metadata - e1 = Enum('one', 'two', 'three', - name="myenum") - t1 = Table('e1', metadata, - Column('c1', e1) - ) + e1 = Enum("one", "two", "three", name="myenum") + t1 = Table("e1", metadata, Column("c1", e1)) - t2 = Table('e2', metadata, - Column('c1', e1) - ) + t2 = Table("e2", metadata, Column("c1", e1)) metadata.create_all(checkfirst=False) metadata.drop_all(checkfirst=False) - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] @testing.provide_metadata def test_generate_alone_on_metadata(self): @@ -314,37 +355,31 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): """ metadata = self.metadata - e1 = Enum('one', 'two', 'three', - name="myenum", metadata=self.metadata) + e1 = Enum("one", 
"two", "three", name="myenum", metadata=self.metadata) metadata.create_all(checkfirst=False) - assert 'myenum' in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] metadata.drop_all(checkfirst=False) - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] @testing.provide_metadata def test_generate_multiple_on_metadata(self): metadata = self.metadata - e1 = Enum('one', 'two', 'three', - name="myenum", metadata=metadata) + e1 = Enum("one", "two", "three", name="myenum", metadata=metadata) - t1 = Table('e1', metadata, - Column('c1', e1) - ) + t1 = Table("e1", metadata, Column("c1", e1)) - t2 = Table('e2', metadata, - Column('c1', e1) - ) + t2 = Table("e2", metadata, Column("c1", e1)) metadata.create_all(checkfirst=False) - assert 'myenum' in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] metadata.drop_all(checkfirst=False) - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] e1.create() # creates ENUM t1.create() # does not create ENUM @@ -354,55 +389,56 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): def test_generate_multiple_schemaname_on_metadata(self): metadata = self.metadata - Enum('one', 'two', 'three', name="myenum", metadata=metadata) - Enum('one', 'two', 'three', name="myenum", metadata=metadata, - schema="test_schema") + Enum("one", "two", "three", name="myenum", metadata=metadata) + Enum( + "one", + "two", + "three", + name="myenum", + metadata=metadata, + schema="test_schema", + ) metadata.create_all(checkfirst=False) - assert 'myenum' in [ - e['name'] for e in inspect(testing.db).get_enums()] - assert 'myenum' in [ - e['name'] for - e in inspect(testing.db).get_enums(schema="test_schema")] + assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] + assert "myenum" in [ + e["name"] + for e in inspect(testing.db).get_enums(schema="test_schema") + ] metadata.drop_all(checkfirst=False) - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] - assert 'myenum' not in [ - e['name'] for - e in inspect(testing.db).get_enums(schema="test_schema")] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] + assert "myenum" not in [ + e["name"] + for e in inspect(testing.db).get_enums(schema="test_schema") + ] @testing.provide_metadata def test_drops_on_table(self): metadata = self.metadata - e1 = Enum('one', 'two', 'three', - name="myenum") - table = Table( - 'e1', metadata, - Column('c1', e1) - ) + e1 = Enum("one", "two", "three", name="myenum") + table = Table("e1", metadata, Column("c1", e1)) table.create() table.drop() - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] table.create() - assert 'myenum' in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] table.drop() - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] @testing.provide_metadata def test_remain_on_table_metadata_wide(self): metadata = self.metadata - e1 = 
Enum('one', 'two', 'three', - name="myenum", metadata=metadata) - table = Table( - 'e1', metadata, - Column('c1', e1) - ) + e1 = Enum("one", "two", "three", name="myenum", metadata=metadata) + table = Table("e1", metadata, Column("c1", e1)) # need checkfirst here, otherwise enum will not be created assert_raises_message( @@ -414,11 +450,11 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): table.drop() table.create(checkfirst=True) table.drop() - assert 'myenum' in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" in [e["name"] for e in inspect(testing.db).get_enums()] metadata.drop_all() - assert 'myenum' not in [ - e['name'] for e in inspect(testing.db).get_enums()] + assert "myenum" not in [ + e["name"] for e in inspect(testing.db).get_enums() + ] def test_non_native_dialect(self): engine = engines.testing_engine() @@ -426,48 +462,51 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): engine.dialect.supports_native_enum = False metadata = MetaData() t1 = Table( - 'foo', + "foo", metadata, - Column( - 'bar', - Enum( - 'one', - 'two', - 'three', - name='myenum'))) + Column("bar", Enum("one", "two", "three", name="myenum")), + ) def go(): t1.create(engine) try: self.assert_sql( - engine, go, [ - ("CREATE TABLE foo (bar " - "VARCHAR(5), CONSTRAINT myenum CHECK " - "(bar IN ('one', 'two', 'three')))", {})]) + engine, + go, + [ + ( + "CREATE TABLE foo (bar " + "VARCHAR(5), CONSTRAINT myenum CHECK " + "(bar IN ('one', 'two', 'three')))", + {}, + ) + ], + ) finally: metadata.drop_all(engine) def test_standalone_enum(self): metadata = MetaData(testing.db) - etype = Enum('four', 'five', 'six', name='fourfivesixtype', - metadata=metadata) + etype = Enum( + "four", "five", "six", name="fourfivesixtype", metadata=metadata + ) etype.create() try: - assert testing.db.dialect.has_type(testing.db, - 'fourfivesixtype') + assert testing.db.dialect.has_type(testing.db, "fourfivesixtype") finally: etype.drop() - assert not testing.db.dialect.has_type(testing.db, - 'fourfivesixtype') + assert not testing.db.dialect.has_type( + testing.db, "fourfivesixtype" + ) metadata.create_all() try: - assert testing.db.dialect.has_type(testing.db, - 'fourfivesixtype') + assert testing.db.dialect.has_type(testing.db, "fourfivesixtype") finally: metadata.drop_all() - assert not testing.db.dialect.has_type(testing.db, - 'fourfivesixtype') + assert not testing.db.dialect.has_type( + testing.db, "fourfivesixtype" + ) def test_no_support(self): def server_version_info(self): @@ -489,57 +528,66 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): @testing.provide_metadata def test_reflection(self): metadata = self.metadata - etype = Enum('four', 'five', 'six', name='fourfivesixtype', - metadata=metadata) + etype = Enum( + "four", "five", "six", name="fourfivesixtype", metadata=metadata + ) t1 = Table( - 'table', metadata, - Column( - 'id', Integer, primary_key=True), + "table", + metadata, + Column("id", Integer, primary_key=True), Column( - 'value', Enum( - 'one', 'two', 'three', name='onetwothreetype')), - Column('value2', etype)) + "value", Enum("one", "two", "three", name="onetwothreetype") + ), + Column("value2", etype), + ) metadata.create_all() m2 = MetaData(testing.db) - t2 = Table('table', m2, autoload=True) - eq_(t2.c.value.type.enums, ['one', 'two', 'three']) - eq_(t2.c.value.type.name, 'onetwothreetype') - eq_(t2.c.value2.type.enums, ['four', 'five', 'six']) - eq_(t2.c.value2.type.name, 'fourfivesixtype') + t2 = Table("table", m2, autoload=True) + 
eq_(t2.c.value.type.enums, ["one", "two", "three"]) + eq_(t2.c.value.type.name, "onetwothreetype") + eq_(t2.c.value2.type.enums, ["four", "five", "six"]) + eq_(t2.c.value2.type.name, "fourfivesixtype") @testing.provide_metadata def test_schema_reflection(self): metadata = self.metadata etype = Enum( - 'four', - 'five', - 'six', - name='fourfivesixtype', - schema='test_schema', + "four", + "five", + "six", + name="fourfivesixtype", + schema="test_schema", metadata=metadata, ) Table( - 'table', metadata, - Column( - 'id', Integer, primary_key=True), + "table", + metadata, + Column("id", Integer, primary_key=True), Column( - 'value', Enum( - 'one', 'two', 'three', - name='onetwothreetype', schema='test_schema')), - Column('value2', etype)) + "value", + Enum( + "one", + "two", + "three", + name="onetwothreetype", + schema="test_schema", + ), + ), + Column("value2", etype), + ) metadata.create_all() m2 = MetaData(testing.db) - t2 = Table('table', m2, autoload=True) - eq_(t2.c.value.type.enums, ['one', 'two', 'three']) - eq_(t2.c.value.type.name, 'onetwothreetype') - eq_(t2.c.value2.type.enums, ['four', 'five', 'six']) - eq_(t2.c.value2.type.name, 'fourfivesixtype') - eq_(t2.c.value2.type.schema, 'test_schema') + t2 = Table("table", m2, autoload=True) + eq_(t2.c.value.type.enums, ["one", "two", "three"]) + eq_(t2.c.value.type.name, "onetwothreetype") + eq_(t2.c.value2.type.enums, ["four", "five", "six"]) + eq_(t2.c.value2.type.name, "fourfivesixtype") + eq_(t2.c.value2.type.schema, "test_schema") @testing.provide_metadata def test_custom_subclass(self): class MyEnum(TypeDecorator): - impl = Enum('oneHI', 'twoHI', 'threeHI', name='myenum') + impl = Enum("oneHI", "twoHI", "threeHI", name="myenum") def process_bind_param(self, value, dialect): if value is not None: @@ -551,108 +599,103 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): value += "THERE" return value - t1 = Table( - 'table1', self.metadata, - Column('data', MyEnum()) - ) + t1 = Table("table1", self.metadata, Column("data", MyEnum())) self.metadata.create_all(testing.db) with testing.db.connect() as conn: conn.execute(t1.insert(), {"data": "two"}) - eq_( - conn.scalar(select([t1.c.data])), - "twoHITHERE" - ) + eq_(conn.scalar(select([t1.c.data])), "twoHITHERE") @testing.provide_metadata def test_generic_w_pg_variant(self): some_table = Table( - 'some_table', self.metadata, + "some_table", + self.metadata, Column( - 'data', + "data", Enum( - "one", "two", "three", - native_enum=True # make sure this is True because - # it should *not* take effect due to - # the variant + "one", + "two", + "three", + native_enum=True # make sure this is True because + # it should *not* take effect due to + # the variant ).with_variant( postgresql.ENUM("four", "five", "six", name="my_enum"), - "postgresql" - ) - ) + "postgresql", + ), + ), ) with testing.db.begin() as conn: - assert 'my_enum' not in [ - e['name'] for e in inspect(conn).get_enums()] + assert "my_enum" not in [ + e["name"] for e in inspect(conn).get_enums() + ] self.metadata.create_all(conn) - assert 'my_enum' in [ - e['name'] for e in inspect(conn).get_enums()] + assert "my_enum" in [e["name"] for e in inspect(conn).get_enums()] - conn.execute( - some_table.insert(), {"data": "five"} - ) + conn.execute(some_table.insert(), {"data": "five"}) self.metadata.drop_all(conn) - assert 'my_enum' not in [ - e['name'] for e in inspect(conn).get_enums()] + assert "my_enum" not in [ + e["name"] for e in inspect(conn).get_enums() + ] @testing.provide_metadata def 
test_generic_w_some_other_variant(self): some_table = Table( - 'some_table', self.metadata, + "some_table", + self.metadata, Column( - 'data', + "data", Enum( - "one", "two", "three", - name="my_enum", - native_enum=True - ).with_variant( - Enum("four", "five", "six"), - "mysql" - ) - ) + "one", "two", "three", name="my_enum", native_enum=True + ).with_variant(Enum("four", "five", "six"), "mysql"), + ), ) with testing.db.begin() as conn: - assert 'my_enum' not in [ - e['name'] for e in inspect(conn).get_enums()] + assert "my_enum" not in [ + e["name"] for e in inspect(conn).get_enums() + ] self.metadata.create_all(conn) - assert 'my_enum' in [ - e['name'] for e in inspect(conn).get_enums()] + assert "my_enum" in [e["name"] for e in inspect(conn).get_enums()] - conn.execute( - some_table.insert(), {"data": "two"} - ) + conn.execute(some_table.insert(), {"data": "two"}) self.metadata.drop_all(conn) - assert 'my_enum' not in [ - e['name'] for e in inspect(conn).get_enums()] + assert "my_enum" not in [ + e["name"] for e in inspect(conn).get_enums() + ] class OIDTest(fixtures.TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @testing.provide_metadata def test_reflection(self): metadata = self.metadata - Table('table', metadata, Column('x', Integer), - Column('y', postgresql.OID)) + Table( + "table", + metadata, + Column("x", Integer), + Column("y", postgresql.OID), + ) metadata.create_all() m2 = MetaData() - t2 = Table('table', m2, autoload_with=testing.db, autoload=True) + t2 = Table("table", m2, autoload_with=testing.db, autoload=True) assert isinstance(t2.c.y.type, postgresql.OID) class RegClassTest(fixtures.TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True @staticmethod @@ -661,50 +704,60 @@ class RegClassTest(fixtures.TestBase): return conn.scalar(select([expression])) def test_cast_name(self): - eq_( - self._scalar(cast('pg_class', postgresql.REGCLASS)), - 'pg_class' - ) + eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class") def test_cast_path(self): eq_( - self._scalar(cast('pg_catalog.pg_class', postgresql.REGCLASS)), - 'pg_class' + self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)), + "pg_class", ) def test_cast_oid(self): - regclass = cast('pg_class', postgresql.REGCLASS) + regclass = cast("pg_class", postgresql.REGCLASS) oid = self._scalar(cast(regclass, postgresql.OID)) assert isinstance(oid, int) - eq_(self._scalar(cast(oid, postgresql.REGCLASS)), 'pg_class') + eq_(self._scalar(cast(oid, postgresql.REGCLASS)), "pg_class") def test_cast_whereclause(self): - pga = Table('pg_attribute', MetaData(testing.db), - Column('attrelid', postgresql.OID), - Column('attname', String(64))) + pga = Table( + "pg_attribute", + MetaData(testing.db), + Column("attrelid", postgresql.OID), + Column("attname", String(64)), + ) with testing.db.connect() as conn: oid = conn.scalar( select([pga.c.attrelid]).where( - pga.c.attrelid == cast('pg_class', postgresql.REGCLASS) + pga.c.attrelid == cast("pg_class", postgresql.REGCLASS) ) ) assert isinstance(oid, int) class NumericInterpretationTest(fixtures.TestBase): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True def test_numeric_codes(self): - from sqlalchemy.dialects.postgresql import pg8000, pygresql, \ - psycopg2, psycopg2cffi, base + from sqlalchemy.dialects.postgresql import ( + pg8000, + pygresql, + psycopg2, + psycopg2cffi, + base, + ) - dialects = (pg8000.dialect(), pygresql.dialect(), - psycopg2.dialect(), 
psycopg2cffi.dialect()) + dialects = ( + pg8000.dialect(), + pygresql.dialect(), + psycopg2.dialect(), + psycopg2cffi.dialect(), + ) for dialect in dialects: typ = Numeric().dialect_impl(dialect) - for code in base._INT_TYPES + base._FLOAT_TYPES + \ - base._DECIMAL_TYPES: + for code in ( + base._INT_TYPES + base._FLOAT_TYPES + base._DECIMAL_TYPES + ): proc = typ.result_processor(dialect, code) val = 23.7 if proc is not None: @@ -716,13 +769,15 @@ class NumericInterpretationTest(fixtures.TestBase): metadata = self.metadata # pg8000 appears to fail when the value is 0, # returns an int instead of decimal. - t = Table('t', metadata, - Column('id', Integer, primary_key=True), - Column('nd', Numeric(asdecimal=True), default=1), - Column('nf', Numeric(asdecimal=False), default=1), - Column('fd', Float(asdecimal=True), default=1), - Column('ff', Float(asdecimal=False), default=1), - ) + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("nd", Numeric(asdecimal=True), default=1), + Column("nf", Numeric(asdecimal=False), default=1), + Column("fd", Float(asdecimal=True), default=1), + Column("ff", Float(asdecimal=False), default=1), + ) metadata.create_all() r = t.insert().execute() @@ -731,18 +786,12 @@ class NumericInterpretationTest(fixtures.TestBase): assert isinstance(row[2], float) assert isinstance(row[3], decimal.Decimal) assert isinstance(row[4], float) - eq_( - row, - (1, decimal.Decimal("1"), 1, decimal.Decimal("1"), 1) - ) + eq_(row, (1, decimal.Decimal("1"), 1, decimal.Decimal("1"), 1)) class PythonTypeTest(fixtures.TestBase): def test_interval(self): - is_( - postgresql.INTERVAL().python_type, - datetime.timedelta - ) + is_(postgresql.INTERVAL().python_type, datetime.timedelta) class TimezoneTest(fixtures.TestBase): @@ -756,7 +805,7 @@ class TimezoneTest(fixtures.TestBase): test illustrates two ways to have datetime types with and without timezone info. 
""" - __only_on__ = 'postgresql' + __only_on__ = "postgresql" @classmethod def setup_class(cls): @@ -767,48 +816,59 @@ class TimezoneTest(fixtures.TestBase): # TIMESTAMP WITH TIMEZONE tztable = Table( - 'tztable', metadata, - Column( - 'id', Integer, primary_key=True), + "tztable", + metadata, + Column("id", Integer, primary_key=True), Column( - 'date', DateTime( - timezone=True), onupdate=func.current_timestamp()), - Column('name', String(20))) + "date", + DateTime(timezone=True), + onupdate=func.current_timestamp(), + ), + Column("name", String(20)), + ) notztable = Table( - 'notztable', metadata, - Column( - 'id', Integer, primary_key=True), + "notztable", + metadata, + Column("id", Integer, primary_key=True), Column( - 'date', DateTime( - timezone=False), onupdate=cast( - func.current_timestamp(), DateTime( - timezone=False))), - Column('name', String(20))) + "date", + DateTime(timezone=False), + onupdate=cast( + func.current_timestamp(), DateTime(timezone=False) + ), + ), + Column("name", String(20)), + ) metadata.create_all() @classmethod def teardown_class(cls): metadata.drop_all() - @testing.fails_on('postgresql+zxjdbc', - "XXX: postgresql+zxjdbc doesn't give a tzinfo back") + @testing.fails_on( + "postgresql+zxjdbc", + "XXX: postgresql+zxjdbc doesn't give a tzinfo back", + ) def test_with_timezone(self): # get a date with a tzinfo - somedate = \ - testing.db.connect().scalar(func.current_timestamp().select()) + somedate = testing.db.connect().scalar( + func.current_timestamp().select() + ) assert somedate.tzinfo - tztable.insert().execute(id=1, name='row1', date=somedate) - row = select([tztable.c.date], tztable.c.id - == 1).execute().first() + tztable.insert().execute(id=1, name="row1", date=somedate) + row = select([tztable.c.date], tztable.c.id == 1).execute().first() eq_(row[0], somedate) - eq_(somedate.tzinfo.utcoffset(somedate), - row[0].tzinfo.utcoffset(row[0])) - result = tztable.update(tztable.c.id - == 1).returning(tztable.c.date).\ - execute(name='newname' - ) + eq_( + somedate.tzinfo.utcoffset(somedate), + row[0].tzinfo.utcoffset(row[0]), + ) + result = ( + tztable.update(tztable.c.id == 1) + .returning(tztable.c.date) + .execute(name="newname") + ) row = result.first() assert row[0] >= somedate @@ -816,17 +876,17 @@ class TimezoneTest(fixtures.TestBase): # get a date without a tzinfo - somedate = datetime.datetime(2005, 10, 20, 11, 52, 0, ) + somedate = datetime.datetime(2005, 10, 20, 11, 52, 0) assert not somedate.tzinfo - notztable.insert().execute(id=1, name='row1', date=somedate) - row = select([notztable.c.date], notztable.c.id - == 1).execute().first() + notztable.insert().execute(id=1, name="row1", date=somedate) + row = select([notztable.c.date], notztable.c.id == 1).execute().first() eq_(row[0], somedate) eq_(row[0].tzinfo, None) - result = notztable.update(notztable.c.id - == 1).returning(notztable.c.date).\ - execute(name='newname' - ) + result = ( + notztable.update(notztable.c.id == 1) + .returning(notztable.c.date) + .execute(name="newname") + ) row = result.first() assert row[0] >= somedate @@ -834,46 +894,51 @@ class TimezoneTest(fixtures.TestBase): class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - __prefer__ = 'postgresql' + __prefer__ = "postgresql" __backend__ = True def test_compile(self): for type_, expected in [ - (postgresql.TIME(), 'TIME WITHOUT TIME ZONE'), - (postgresql.TIME(precision=5), 'TIME(5) WITHOUT TIME ZONE' - ), - (postgresql.TIME(timezone=True, precision=5), - 'TIME(5) WITH TIME 
ZONE'), - (postgresql.TIMESTAMP(), 'TIMESTAMP WITHOUT TIME ZONE'), - (postgresql.TIMESTAMP(precision=5), - 'TIMESTAMP(5) WITHOUT TIME ZONE'), - (postgresql.TIMESTAMP(timezone=True, precision=5), - 'TIMESTAMP(5) WITH TIME ZONE'), - (postgresql.TIME(precision=0), - 'TIME(0) WITHOUT TIME ZONE'), - (postgresql.TIMESTAMP(precision=0), - 'TIMESTAMP(0) WITHOUT TIME ZONE'), + (postgresql.TIME(), "TIME WITHOUT TIME ZONE"), + (postgresql.TIME(precision=5), "TIME(5) WITHOUT TIME ZONE"), + ( + postgresql.TIME(timezone=True, precision=5), + "TIME(5) WITH TIME ZONE", + ), + (postgresql.TIMESTAMP(), "TIMESTAMP WITHOUT TIME ZONE"), + ( + postgresql.TIMESTAMP(precision=5), + "TIMESTAMP(5) WITHOUT TIME ZONE", + ), + ( + postgresql.TIMESTAMP(timezone=True, precision=5), + "TIMESTAMP(5) WITH TIME ZONE", + ), + (postgresql.TIME(precision=0), "TIME(0) WITHOUT TIME ZONE"), + ( + postgresql.TIMESTAMP(precision=0), + "TIMESTAMP(0) WITHOUT TIME ZONE", + ), ]: self.assert_compile(type_, expected) - @testing.only_on('postgresql', 'DB specific feature') + @testing.only_on("postgresql", "DB specific feature") @testing.provide_metadata def test_reflection(self): metadata = self.metadata t1 = Table( - 't1', + "t1", metadata, - Column('c1', postgresql.TIME()), - Column('c2', postgresql.TIME(precision=5)), - Column('c3', postgresql.TIME(timezone=True, precision=5)), - Column('c4', postgresql.TIMESTAMP()), - Column('c5', postgresql.TIMESTAMP(precision=5)), - Column('c6', postgresql.TIMESTAMP(timezone=True, - precision=5)), + Column("c1", postgresql.TIME()), + Column("c2", postgresql.TIME(precision=5)), + Column("c3", postgresql.TIME(timezone=True, precision=5)), + Column("c4", postgresql.TIMESTAMP()), + Column("c5", postgresql.TIMESTAMP(precision=5)), + Column("c6", postgresql.TIMESTAMP(timezone=True, precision=5)), ) t1.create() m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True) + t2 = Table("t1", m2, autoload=True) eq_(t2.c.c1.type.precision, None) eq_(t2.c.c2.type.precision, 5) eq_(t2.c.c3.type.precision, 5) @@ -889,159 +954,140 @@ class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL): class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): - __dialect__ = 'postgresql' + __dialect__ = "postgresql" def test_array_type_render_str(self): - self.assert_compile( - postgresql.ARRAY(Unicode(30)), - "VARCHAR(30)[]" - ) + self.assert_compile(postgresql.ARRAY(Unicode(30)), "VARCHAR(30)[]") def test_array_type_render_str_collate(self): self.assert_compile( postgresql.ARRAY(Unicode(30, collation="en_US")), - 'VARCHAR(30)[] COLLATE "en_US"' + 'VARCHAR(30)[] COLLATE "en_US"', ) def test_array_type_render_str_multidim(self): self.assert_compile( - postgresql.ARRAY(Unicode(30), dimensions=2), - "VARCHAR(30)[][]" + postgresql.ARRAY(Unicode(30), dimensions=2), "VARCHAR(30)[][]" ) self.assert_compile( - postgresql.ARRAY(Unicode(30), dimensions=3), - "VARCHAR(30)[][][]" + postgresql.ARRAY(Unicode(30), dimensions=3), "VARCHAR(30)[][][]" ) def test_array_type_render_str_collate_multidim(self): self.assert_compile( postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=2), - 'VARCHAR(30)[][] COLLATE "en_US"' + 'VARCHAR(30)[][] COLLATE "en_US"', ) self.assert_compile( postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=3), - 'VARCHAR(30)[][][] COLLATE "en_US"' + 'VARCHAR(30)[][][] COLLATE "en_US"', ) - def test_array_int_index(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col[3]]), "SELECT x[%(x_1)s] AS anon_1", - 
checkparams={'x_1': 3} + checkparams={"x_1": 3}, ) def test_array_any(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col.any(7, operator=operators.lt)]), "SELECT %(param_1)s < ANY (x) AS anon_1", - checkparams={'param_1': 7} + checkparams={"param_1": 7}, ) def test_array_all(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col.all(7, operator=operators.lt)]), "SELECT %(param_1)s < ALL (x) AS anon_1", - checkparams={'param_1': 7} + checkparams={"param_1": 7}, ) def test_array_contains(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col.contains(array([4, 5, 6]))]), "SELECT x @> ARRAY[%(param_1)s, %(param_2)s, %(param_3)s] " "AS anon_1", - checkparams={'param_1': 4, 'param_3': 6, 'param_2': 5} + checkparams={"param_1": 4, "param_3": 6, "param_2": 5}, ) def test_contains_override_raises(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) assert_raises_message( NotImplementedError, "Operator 'contains' is not supported on this expression", - lambda: 'foo' in col + lambda: "foo" in col, ) def test_array_contained_by(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col.contained_by(array([4, 5, 6]))]), "SELECT x <@ ARRAY[%(param_1)s, %(param_2)s, %(param_3)s] " "AS anon_1", - checkparams={'param_1': 4, 'param_3': 6, 'param_2': 5} + checkparams={"param_1": 4, "param_3": 6, "param_2": 5}, ) def test_array_overlap(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col.overlap(array([4, 5, 6]))]), "SELECT x && ARRAY[%(param_1)s, %(param_2)s, %(param_3)s] " "AS anon_1", - checkparams={'param_1': 4, 'param_3': 6, 'param_2': 5} + checkparams={"param_1": 4, "param_3": 6, "param_2": 5}, ) def test_array_slice_index(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select([col[5:10]]), "SELECT x[%(x_1)s:%(x_2)s] AS anon_1", - checkparams={'x_2': 10, 'x_1': 5} + checkparams={"x_2": 10, "x_1": 5}, ) def test_array_dim_index(self): - col = column('x', postgresql.ARRAY(Integer, dimensions=2)) + col = column("x", postgresql.ARRAY(Integer, dimensions=2)) self.assert_compile( select([col[3][5]]), "SELECT x[%(x_1)s][%(param_1)s] AS anon_1", - checkparams={'x_1': 3, 'param_1': 5} + checkparams={"x_1": 3, "param_1": 5}, ) def test_array_concat(self): - col = column('x', postgresql.ARRAY(Integer)) + col = column("x", postgresql.ARRAY(Integer)) literal = array([4, 5]) self.assert_compile( select([col + literal]), "SELECT x || ARRAY[%(param_1)s, %(param_2)s] AS anon_1", - checkparams={'param_1': 4, 'param_2': 5} + checkparams={"param_1": 4, "param_2": 5}, ) def test_array_index_map_dimensions(self): - col = column('x', postgresql.ARRAY(Integer, dimensions=3)) - is_( - col[5].type._type_affinity, ARRAY - ) - assert isinstance( - col[5].type, postgresql.ARRAY - ) - eq_( - col[5].type.dimensions, 2 - ) - is_( - col[5][6].type._type_affinity, ARRAY - ) - assert isinstance( - col[5][6].type, postgresql.ARRAY - ) - eq_( - col[5][6].type.dimensions, 1 - ) - is_( - col[5][6][7].type._type_affinity, Integer - ) + col = column("x", postgresql.ARRAY(Integer, dimensions=3)) + is_(col[5].type._type_affinity, 
ARRAY) + assert isinstance(col[5].type, postgresql.ARRAY) + eq_(col[5].type.dimensions, 2) + is_(col[5][6].type._type_affinity, ARRAY) + assert isinstance(col[5][6].type, postgresql.ARRAY) + eq_(col[5][6].type.dimensions, 1) + is_(col[5][6][7].type._type_affinity, Integer) def test_array_getitem_single_type(self): m = MetaData() arrtable = Table( - 'arrtable', m, - Column('intarr', postgresql.ARRAY(Integer)), - Column('strarr', postgresql.ARRAY(String)), + "arrtable", + m, + Column("intarr", postgresql.ARRAY(Integer)), + Column("strarr", postgresql.ARRAY(String)), ) is_(arrtable.c.intarr[1].type._type_affinity, Integer) is_(arrtable.c.strarr[1].type._type_affinity, String) @@ -1049,9 +1095,10 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): def test_array_getitem_slice_type(self): m = MetaData() arrtable = Table( - 'arrtable', m, - Column('intarr', postgresql.ARRAY(Integer)), - Column('strarr', postgresql.ARRAY(String)), + "arrtable", + m, + Column("intarr", postgresql.ARRAY(Integer)), + Column("strarr", postgresql.ARRAY(String)), ) # type affinity is Array... @@ -1067,95 +1114,92 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): to be required by PostgreSQL. """ - stmt = select([ - func.array_cat( - array([1, 2, 3]), - array([4, 5, 6]), - type_=postgresql.ARRAY(Integer) - )[2:5] - ]) + stmt = select( + [ + func.array_cat( + array([1, 2, 3]), + array([4, 5, 6]), + type_=postgresql.ARRAY(Integer), + )[2:5] + ] + ) self.assert_compile( stmt, "SELECT (array_cat(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s], " "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))" - "[%(param_7)s:%(param_8)s] AS anon_1" + "[%(param_7)s:%(param_8)s] AS anon_1", ) self.assert_compile( func.array_cat( array([1, 2, 3]), array([4, 5, 6]), - type_=postgresql.ARRAY(Integer) + type_=postgresql.ARRAY(Integer), )[3], "(array_cat(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s], " - "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))[%(array_cat_1)s]" + "ARRAY[%(param_4)s, %(param_5)s, %(param_6)s]))[%(array_cat_1)s]", ) def test_array_agg_generic(self): - expr = func.array_agg(column('q', Integer)) + expr = func.array_agg(column("q", Integer)) is_(expr.type.__class__, types.ARRAY) is_(expr.type.item_type.__class__, Integer) def test_array_agg_specific(self): from sqlalchemy.dialects.postgresql import array_agg - expr = array_agg(column('q', Integer)) + + expr = array_agg(column("q", Integer)) is_(expr.type.__class__, postgresql.ARRAY) is_(expr.type.item_type.__class__, Integer) class ArrayRoundTripTest(object): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True - __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc' + __unsupported_on__ = "postgresql+pg8000", "postgresql+zxjdbc" ARRAY = postgresql.ARRAY @classmethod def define_tables(cls, metadata): - class ProcValue(TypeDecorator): impl = cls.ARRAY(Integer, dimensions=2) def process_bind_param(self, value, dialect): if value is None: return None - return [ - [x + 5 for x in v] - for v in value - ] + return [[x + 5 for x in v] for v in value] def process_result_value(self, value, dialect): if value is None: return None - return [ - [x - 7 for x in v] - for v in value - ] - - Table('arrtable', metadata, - Column('id', Integer, primary_key=True), - Column('intarr', cls.ARRAY(Integer)), - Column('strarr', cls.ARRAY(Unicode())), - Column('dimarr', ProcValue) - ) - - Table('dim_arrtable', metadata, - Column('id', Integer, primary_key=True), - Column('intarr', cls.ARRAY(Integer, dimensions=1)), - Column('strarr', cls.ARRAY(Unicode(), 
dimensions=1)), - Column('dimarr', ProcValue) - ) + return [[x - 7 for x in v] for v in value] - def _fixture_456(self, table): - testing.db.execute( - table.insert(), - intarr=[4, 5, 6] + Table( + "arrtable", + metadata, + Column("id", Integer, primary_key=True), + Column("intarr", cls.ARRAY(Integer)), + Column("strarr", cls.ARRAY(Unicode())), + Column("dimarr", ProcValue), ) + Table( + "dim_arrtable", + metadata, + Column("id", Integer, primary_key=True), + Column("intarr", cls.ARRAY(Integer, dimensions=1)), + Column("strarr", cls.ARRAY(Unicode(), dimensions=1)), + Column("dimarr", ProcValue), + ) + + def _fixture_456(self, table): + testing.db.execute(table.insert(), intarr=[4, 5, 6]) + def test_reflect_array_column(self): metadata2 = MetaData(testing.db) - tbl = Table('arrtable', metadata2, autoload=True) + tbl = Table("arrtable", metadata2, autoload=True) assert isinstance(tbl.c.intarr.type, self.ARRAY) assert isinstance(tbl.c.strarr.type, self.ARRAY) assert isinstance(tbl.c.intarr.type.item_type, Integer) @@ -1166,249 +1210,220 @@ class ArrayRoundTripTest(object): m = self.metadata t = Table( - 't', m, Column('data', - sqltypes.ARRAY(String(50, collation="en_US"))) + "t", + m, + Column("data", sqltypes.ARRAY(String(50, collation="en_US"))), ) t.create() @testing.provide_metadata def test_array_agg(self): - values_table = Table('values', self.metadata, Column('value', Integer)) + values_table = Table("values", self.metadata, Column("value", Integer)) self.metadata.create_all(testing.db) testing.db.execute( - values_table.insert(), - [{'value': i} for i in range(1, 10)] + values_table.insert(), [{"value": i} for i in range(1, 10)] ) stmt = select([func.array_agg(values_table.c.value)]) - eq_( - testing.db.execute(stmt).scalar(), - list(range(1, 10)) - ) + eq_(testing.db.execute(stmt).scalar(), list(range(1, 10))) stmt = select([func.array_agg(values_table.c.value)[3]]) - eq_( - testing.db.execute(stmt).scalar(), - 3 - ) + eq_(testing.db.execute(stmt).scalar(), 3) stmt = select([func.array_agg(values_table.c.value)[2:4]]) - eq_( - testing.db.execute(stmt).scalar(), - [2, 3, 4] - ) + eq_(testing.db.execute(stmt).scalar(), [2, 3, 4]) def test_array_index_slice_exprs(self): """test a variety of expressions that sometimes need parenthesizing""" stmt = select([array([1, 2, 3, 4])[2:3]]) - eq_( - testing.db.execute(stmt).scalar(), - [2, 3] - ) + eq_(testing.db.execute(stmt).scalar(), [2, 3]) stmt = select([array([1, 2, 3, 4])[2]]) - eq_( - testing.db.execute(stmt).scalar(), - 2 - ) + eq_(testing.db.execute(stmt).scalar(), 2) stmt = select([(array([1, 2]) + array([3, 4]))[2:3]]) - eq_( - testing.db.execute(stmt).scalar(), - [2, 3] - ) + eq_(testing.db.execute(stmt).scalar(), [2, 3]) stmt = select([array([1, 2]) + array([3, 4])[2:3]]) - eq_( - testing.db.execute(stmt).scalar(), - [1, 2, 4] - ) + eq_(testing.db.execute(stmt).scalar(), [1, 2, 4]) stmt = select([array([1, 2])[2:3] + array([3, 4])]) - eq_( - testing.db.execute(stmt).scalar(), - [2, 3, 4] - ) + eq_(testing.db.execute(stmt).scalar(), [2, 3, 4]) - stmt = select([ - func.array_cat( - array([1, 2, 3]), - array([4, 5, 6]), - type_=self.ARRAY(Integer) - )[2:5] - ]) - eq_( - testing.db.execute(stmt).scalar(), [2, 3, 4, 5] + stmt = select( + [ + func.array_cat( + array([1, 2, 3]), + array([4, 5, 6]), + type_=self.ARRAY(Integer), + )[2:5] + ] ) + eq_(testing.db.execute(stmt).scalar(), [2, 3, 4, 5]) def test_any_all_exprs_array(self): - stmt = select([ - 3 == any_(func.array_cat( - array([1, 2, 3]), - array([4, 5, 6]), - 
type_=self.ARRAY(Integer) - )) - ]) - eq_( - testing.db.execute(stmt).scalar(), True + stmt = select( + [ + 3 + == any_( + func.array_cat( + array([1, 2, 3]), + array([4, 5, 6]), + type_=self.ARRAY(Integer), + ) + ) + ] ) + eq_(testing.db.execute(stmt).scalar(), True) def test_insert_array(self): arrtable = self.tables.arrtable - arrtable.insert().execute(intarr=[1, 2, 3], strarr=[util.u('abc'), - util.u('def')]) + arrtable.insert().execute( + intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + ) results = arrtable.select().execute().fetchall() eq_(len(results), 1) - eq_(results[0]['intarr'], [1, 2, 3]) - eq_(results[0]['strarr'], [util.u('abc'), util.u('def')]) + eq_(results[0]["intarr"], [1, 2, 3]) + eq_(results[0]["strarr"], [util.u("abc"), util.u("def")]) def test_insert_array_w_null(self): arrtable = self.tables.arrtable - arrtable.insert().execute(intarr=[1, None, 3], strarr=[util.u('abc'), - None]) + arrtable.insert().execute( + intarr=[1, None, 3], strarr=[util.u("abc"), None] + ) results = arrtable.select().execute().fetchall() eq_(len(results), 1) - eq_(results[0]['intarr'], [1, None, 3]) - eq_(results[0]['strarr'], [util.u('abc'), None]) + eq_(results[0]["intarr"], [1, None, 3]) + eq_(results[0]["strarr"], [util.u("abc"), None]) def test_array_where(self): arrtable = self.tables.arrtable - arrtable.insert().execute(intarr=[1, 2, 3], strarr=[util.u('abc'), - util.u('def')]) - arrtable.insert().execute(intarr=[4, 5, 6], strarr=util.u('ABC')) - results = arrtable.select().where( - arrtable.c.intarr == [ - 1, - 2, - 3]).execute().fetchall() + arrtable.insert().execute( + intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + ) + arrtable.insert().execute(intarr=[4, 5, 6], strarr=util.u("ABC")) + results = ( + arrtable.select() + .where(arrtable.c.intarr == [1, 2, 3]) + .execute() + .fetchall() + ) eq_(len(results), 1) - eq_(results[0]['intarr'], [1, 2, 3]) + eq_(results[0]["intarr"], [1, 2, 3]) def test_array_concat(self): arrtable = self.tables.arrtable - arrtable.insert().execute(intarr=[1, 2, 3], - strarr=[util.u('abc'), util.u('def')]) - results = select([arrtable.c.intarr + [4, 5, - 6]]).execute().fetchall() + arrtable.insert().execute( + intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + ) + results = select([arrtable.c.intarr + [4, 5, 6]]).execute().fetchall() eq_(len(results), 1) - eq_(results[0][0], [1, 2, 3, 4, 5, 6, ]) + eq_(results[0][0], [1, 2, 3, 4, 5, 6]) def test_array_comparison(self): arrtable = self.tables.arrtable - arrtable.insert().execute(id=5, intarr=[1, 2, 3], - strarr=[util.u('abc'), util.u('def')]) - results = select([arrtable.c.id])\ - .where(arrtable.c.intarr < [4, 5, 6])\ - .execute()\ + arrtable.insert().execute( + id=5, intarr=[1, 2, 3], strarr=[util.u("abc"), util.u("def")] + ) + results = ( + select([arrtable.c.id]) + .where(arrtable.c.intarr < [4, 5, 6]) + .execute() .fetchall() + ) eq_(len(results), 1) eq_(results[0][0], 5) def test_array_subtype_resultprocessor(self): arrtable = self.tables.arrtable - arrtable.insert().execute(intarr=[4, 5, 6], - strarr=[[util.ue('m\xe4\xe4')], [ - util.ue('m\xf6\xf6')]]) - arrtable.insert().execute(intarr=[1, 2, 3], strarr=[ - util.ue('m\xe4\xe4'), util.ue('m\xf6\xf6')]) - results = \ + arrtable.insert().execute( + intarr=[4, 5, 6], + strarr=[[util.ue("m\xe4\xe4")], [util.ue("m\xf6\xf6")]], + ) + arrtable.insert().execute( + intarr=[1, 2, 3], + strarr=[util.ue("m\xe4\xe4"), util.ue("m\xf6\xf6")], + ) + results = ( arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall() + ) 
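The hunk above is a representative black transformation in this file: a
backslash line continuation on the removed lines becomes a parenthesized
expression on the added ones. A runnable stand-in for the pattern (a plain
list stands in for the SQLAlchemy query here; black only keeps the wrapping
parentheses when the expression actually exceeds the 79-column limit):

    rows = [(4, 5, 6), (1, 2, 3)]

    # old style: backslash continuation
    results = \
        sorted(rows, key=lambda r: r[0])

    # black style: the backslash is replaced by parentheses
    results = (
        sorted(rows, key=lambda r: r[0])
    )
    assert results == [(1, 2, 3), (4, 5, 6)]
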
eq_(len(results), 2) - eq_(results[0]['strarr'], [util.ue('m\xe4\xe4'), util.ue('m\xf6\xf6')]) - eq_(results[1]['strarr'], - [[util.ue('m\xe4\xe4')], - [util.ue('m\xf6\xf6')]]) + eq_(results[0]["strarr"], [util.ue("m\xe4\xe4"), util.ue("m\xf6\xf6")]) + eq_( + results[1]["strarr"], + [[util.ue("m\xe4\xe4")], [util.ue("m\xf6\xf6")]], + ) def test_array_literal(self): eq_( testing.db.scalar( - select([ - postgresql.array([1, 2]) + postgresql.array([3, 4, 5]) - ]) - ), [1, 2, 3, 4, 5] + select( + [postgresql.array([1, 2]) + postgresql.array([3, 4, 5])] + ) + ), + [1, 2, 3, 4, 5], ) def test_array_literal_compare(self): eq_( - testing.db.scalar( - select([ - postgresql.array([1, 2]) < [3, 4, 5] - ]) - ), True + testing.db.scalar(select([postgresql.array([1, 2]) < [3, 4, 5]])), + True, ) def test_array_getitem_single_exec(self): arrtable = self.tables.arrtable self._fixture_456(arrtable) - eq_( - testing.db.scalar(select([arrtable.c.intarr[2]])), - 5 - ) - testing.db.execute( - arrtable.update().values({arrtable.c.intarr[2]: 7}) - ) - eq_( - testing.db.scalar(select([arrtable.c.intarr[2]])), - 7 - ) + eq_(testing.db.scalar(select([arrtable.c.intarr[2]])), 5) + testing.db.execute(arrtable.update().values({arrtable.c.intarr[2]: 7})) + eq_(testing.db.scalar(select([arrtable.c.intarr[2]])), 7) def test_array_getitem_slice_exec(self): arrtable = self.tables.arrtable testing.db.execute( arrtable.insert(), intarr=[4, 5, 6], - strarr=[util.u('abc'), util.u('def')] - ) - eq_( - testing.db.scalar(select([arrtable.c.intarr[2:3]])), - [5, 6] + strarr=[util.u("abc"), util.u("def")], ) + eq_(testing.db.scalar(select([arrtable.c.intarr[2:3]])), [5, 6]) testing.db.execute( arrtable.update().values({arrtable.c.intarr[2:3]: [7, 8]}) ) - eq_( - testing.db.scalar(select([arrtable.c.intarr[2:3]])), - [7, 8] - ) + eq_(testing.db.scalar(select([arrtable.c.intarr[2:3]])), [7, 8]) def test_multi_dim_roundtrip(self): arrtable = self.tables.arrtable testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]]) eq_( testing.db.scalar(select([arrtable.c.dimarr])), - [[-1, 0, 1], [2, 3, 4]] + [[-1, 0, 1], [2, 3, 4]], ) def test_array_any_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[4, 5, 6] - ) + conn.execute(arrtable.insert(), intarr=[4, 5, 6]) eq_( conn.scalar( - select([arrtable.c.intarr]). - where(postgresql.Any(5, arrtable.c.intarr)) + select([arrtable.c.intarr]).where( + postgresql.Any(5, arrtable.c.intarr) + ) ), - [4, 5, 6] + [4, 5, 6], ) def test_array_all_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[4, 5, 6] - ) + conn.execute(arrtable.insert(), intarr=[4, 5, 6]) eq_( conn.scalar( - select([arrtable.c.intarr]). 
- where(arrtable.c.intarr.all(4, operator=operators.le)) + select([arrtable.c.intarr]).where( + arrtable.c.intarr.all(4, operator=operators.le) + ) ), - [4, 5, 6] + [4, 5, 6], ) @testing.provide_metadata @@ -1416,81 +1431,77 @@ class ArrayRoundTripTest(object): metadata = self.metadata t1 = Table( - 't1', metadata, - Column('id', Integer, primary_key=True), - Column('data', self.ARRAY(String(5), as_tuple=True)), + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("data", self.ARRAY(String(5), as_tuple=True)), Column( - 'data2', - self.ARRAY( - Numeric(asdecimal=False), as_tuple=True) - ) + "data2", self.ARRAY(Numeric(asdecimal=False), as_tuple=True) + ), ) metadata.create_all() testing.db.execute( - t1.insert(), id=1, data=[ - "1", "2", "3"], data2=[ - 5.4, 5.6]) + t1.insert(), id=1, data=["1", "2", "3"], data2=[5.4, 5.6] + ) + testing.db.execute( + t1.insert(), id=2, data=["4", "5", "6"], data2=[1.0] + ) testing.db.execute( t1.insert(), - id=2, - data=[ - "4", - "5", - "6"], - data2=[1.0]) - testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], - data2=[[5.4, 5.6], [1.0, 1.1]]) + id=3, + data=[["4", "5"], ["6", "7"]], + data2=[[5.4, 5.6], [1.0, 1.1]], + ) r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall() eq_( r, [ - (1, ('1', '2', '3'), (5.4, 5.6)), - (2, ('4', '5', '6'), (1.0,)), - (3, (('4', '5'), ('6', '7')), ((5.4, 5.6), (1.0, 1.1))) - ] + (1, ("1", "2", "3"), (5.4, 5.6)), + (2, ("4", "5", "6"), (1.0,)), + (3, (("4", "5"), ("6", "7")), ((5.4, 5.6), (1.0, 1.1))), + ], ) # hashable eq_( set(row[1] for row in r), - set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))]) + set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]), ) def test_array_plus_native_enum_create(self): m = MetaData() t = Table( - 't', m, + "t", + m, Column( - 'data_1', - self.ARRAY( - postgresql.ENUM('a', 'b', 'c', name='my_enum_1') - ) + "data_1", + self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")), ), Column( - 'data_2', - self.ARRAY( - types.Enum('a', 'b', 'c', name='my_enum_2') - ) - ) + "data_2", + self.ARRAY(types.Enum("a", "b", "c", name="my_enum_2")), + ), ) t.create(testing.db) eq_( - set(e['name'] for e in inspect(testing.db).get_enums()), - set(['my_enum_1', 'my_enum_2']) + set(e["name"] for e in inspect(testing.db).get_enums()), + set(["my_enum_1", "my_enum_2"]), ) t.drop(testing.db) eq_(inspect(testing.db).get_enums(), []) -class CoreArrayRoundTripTest(ArrayRoundTripTest, - fixtures.TablesTest, AssertsExecutionResults): +class CoreArrayRoundTripTest( + ArrayRoundTripTest, fixtures.TablesTest, AssertsExecutionResults +): ARRAY = sqltypes.ARRAY -class PGArrayRoundTripTest(ArrayRoundTripTest, - fixtures.TablesTest, AssertsExecutionResults): +class PGArrayRoundTripTest( + ArrayRoundTripTest, fixtures.TablesTest, AssertsExecutionResults +): ARRAY = postgresql.ARRAY def _test_undim_array_contains_typed_exec(self, struct): @@ -1498,10 +1509,11 @@ class PGArrayRoundTripTest(ArrayRoundTripTest, self._fixture_456(arrtable) eq_( testing.db.scalar( - select([arrtable.c.intarr]). 
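
The two executions just above assert ANY/ALL behavior; the same
expressions can also be examined without a database. A compile-only
sketch (the bare column name is hypothetical):

    from sqlalchemy import Integer, column, select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.sql import operators

    intarr = column("intarr", postgresql.ARRAY(Integer))

    # .any(5) renders as "<param> = ANY (intarr)"; .all(4, operators.le)
    # renders as "<param> <= ALL (intarr)"
    print(select([intarr]).where(intarr.any(5)))
    print(select([intarr]).where(intarr.all(4, operator=operators.le)))
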
- where(arrtable.c.intarr.contains(struct([4, 5]))) + select([arrtable.c.intarr]).where( + arrtable.c.intarr.contains(struct([4, 5])) + ) ), - [4, 5, 6] + [4, 5, 6], ) def test_undim_array_contains_set_exec(self): @@ -1512,17 +1524,19 @@ class PGArrayRoundTripTest(ArrayRoundTripTest, def test_undim_array_contains_generator_exec(self): self._test_undim_array_contains_typed_exec( - lambda elem: (x for x in elem)) + lambda elem: (x for x in elem) + ) def _test_dim_array_contains_typed_exec(self, struct): dim_arrtable = self.tables.dim_arrtable self._fixture_456(dim_arrtable) eq_( testing.db.scalar( - select([dim_arrtable.c.intarr]). - where(dim_arrtable.c.intarr.contains(struct([4, 5]))) + select([dim_arrtable.c.intarr]).where( + dim_arrtable.c.intarr.contains(struct([4, 5])) + ) ), - [4, 5, 6] + [4, 5, 6], ) def test_dim_array_contains_set_exec(self): @@ -1533,21 +1547,18 @@ class PGArrayRoundTripTest(ArrayRoundTripTest, def test_dim_array_contains_generator_exec(self): self._test_dim_array_contains_typed_exec( - lambda elem: ( - x for x in elem)) + lambda elem: (x for x in elem) + ) def test_array_contained_by_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[6, 5, 4] - ) + conn.execute(arrtable.insert(), intarr=[6, 5, 4]) eq_( conn.scalar( select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) ), - True + True, ) def test_undim_array_empty(self): @@ -1555,25 +1566,24 @@ class PGArrayRoundTripTest(ArrayRoundTripTest, self._fixture_456(arrtable) eq_( testing.db.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.contains([])) + select([arrtable.c.intarr]).where( + arrtable.c.intarr.contains([]) + ) ), - [4, 5, 6] + [4, 5, 6], ) def test_array_overlap_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[4, 5, 6] - ) + conn.execute(arrtable.insert(), intarr=[4, 5, 6]) eq_( conn.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.overlap([7, 6])) + select([arrtable.c.intarr]).where( + arrtable.c.intarr.overlap([7, 6]) + ) ), - [4, 5, 6] + [4, 5, 6], ) @@ -1581,33 +1591,33 @@ class HashableFlagORMTest(fixtures.TestBase): """test the various 'collection' types that they flip the 'hashable' flag appropriately. 
[ticket:3499]""" - __only_on__ = 'postgresql' + __only_on__ = "postgresql" def _test(self, type_, data): Base = declarative_base(metadata=self.metadata) class A(Base): - __tablename__ = 'a1' + __tablename__ = "a1" id = Column(Integer, primary_key=True) data = Column(type_) + Base.metadata.create_all(testing.db) s = Session(testing.db) - s.add_all([ - A(data=elem) for elem in data - ]) + s.add_all([A(data=elem) for elem in data]) s.commit() eq_( - [(obj.A.id, obj.data) for obj in - s.query(A, A.data).order_by(A.id)], - list(enumerate(data, 1)) + [ + (obj.A.id, obj.data) + for obj in s.query(A, A.data).order_by(A.id) + ], + list(enumerate(data, 1)), ) @testing.provide_metadata def test_array(self): self._test( - postgresql.ARRAY(Text()), - [['a', 'b', 'c'], ['d', 'e', 'f']] + postgresql.ARRAY(Text()), [["a", "b", "c"], ["d", "e", "f"]] ) @testing.requires.hstore @@ -1615,10 +1625,7 @@ class HashableFlagORMTest(fixtures.TestBase): def test_hstore(self): self._test( postgresql.HSTORE(), - [ - {'a': '1', 'b': '2', 'c': '3'}, - {'d': '4', 'e': '5', 'f': '6'} - ] + [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}], ) @testing.provide_metadata @@ -1626,10 +1633,13 @@ class HashableFlagORMTest(fixtures.TestBase): self._test( postgresql.JSON(), [ - {'a': '1', 'b': '2', 'c': '3'}, - {'d': '4', 'e': {'e1': '5', 'e2': '6'}, - 'f': {'f1': [9, 10, 11]}} - ] + {"a": "1", "b": "2", "c": "3"}, + { + "d": "4", + "e": {"e1": "5", "e2": "6"}, + "f": {"f1": [9, 10, 11]}, + }, + ], ) @testing.requires.postgresql_jsonb @@ -1638,15 +1648,18 @@ class HashableFlagORMTest(fixtures.TestBase): self._test( postgresql.JSONB(), [ - {'a': '1', 'b': '2', 'c': '3'}, - {'d': '4', 'e': {'e1': '5', 'e2': '6'}, - 'f': {'f1': [9, 10, 11]}} - ] + {"a": "1", "b": "2", "c": "3"}, + { + "d": "4", + "e": {"e1": "5", "e2": "6"}, + "f": {"f1": [9, 10, 11]}, + }, + ], ) class TimestampTest(fixtures.TestBase, AssertsExecutionResults): - __only_on__ = 'postgresql' + __only_on__ = "postgresql" __backend__ = True def test_timestamp(self): @@ -1668,11 +1681,12 @@ class TimestampTest(fixtures.TestBase, AssertsExecutionResults): eq_(result[0], datetime.timedelta(40)) def test_interval_coercion(self): - expr = column('bar', postgresql.INTERVAL) + column('foo', types.Date) + expr = column("bar", postgresql.INTERVAL) + column("foo", types.Date) eq_(expr.type._type_affinity, types.DateTime) - expr = column('bar', postgresql.INTERVAL) * \ - column('foo', types.Numeric) + expr = column("bar", postgresql.INTERVAL) * column( + "foo", types.Numeric + ) eq_(expr.type._type_affinity, types.Interval) assert isinstance(expr.type, postgresql.INTERVAL) @@ -1681,7 +1695,7 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): """test DDL and reflection of PG-specific types """ - __only_on__ = 'postgresql >= 8.3.0', + __only_on__ = ("postgresql >= 8.3.0",) __backend__ = True @classmethod @@ -1692,34 +1706,30 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): # create these types so that we can issue # special SQL92 INTERVAL syntax class y2m(types.UserDefinedType, postgresql.INTERVAL): - def get_col_spec(self): return "INTERVAL YEAR TO MONTH" class d2s(types.UserDefinedType, postgresql.INTERVAL): - def get_col_spec(self): return "INTERVAL DAY TO SECOND" table = Table( - 'sometable', metadata, - Column( - 'id', postgresql.UUID, primary_key=True), - Column( - 'flag', postgresql.BIT), - Column( - 'bitstring', postgresql.BIT(4)), - Column('addr', postgresql.INET), - Column('addr2', 
postgresql.MACADDR), - Column('price', postgresql.MONEY), - Column('addr3', postgresql.CIDR), - Column('doubleprec', postgresql.DOUBLE_PRECISION), - Column('plain_interval', postgresql.INTERVAL), - Column('year_interval', y2m()), - Column('month_interval', d2s()), - Column('precision_interval', postgresql.INTERVAL( - precision=3)), - Column('tsvector_document', postgresql.TSVECTOR)) + "sometable", + metadata, + Column("id", postgresql.UUID, primary_key=True), + Column("flag", postgresql.BIT), + Column("bitstring", postgresql.BIT(4)), + Column("addr", postgresql.INET), + Column("addr2", postgresql.MACADDR), + Column("price", postgresql.MONEY), + Column("addr3", postgresql.CIDR), + Column("doubleprec", postgresql.DOUBLE_PRECISION), + Column("plain_interval", postgresql.INTERVAL), + Column("year_interval", y2m()), + Column("month_interval", d2s()), + Column("precision_interval", postgresql.INTERVAL(precision=3)), + Column("tsvector_document", postgresql.TSVECTOR), + ) metadata.create_all() @@ -1734,7 +1744,7 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): def test_reflection(self): m = MetaData(testing.db) - t = Table('sometable', m, autoload=True) + t = Table("sometable", m, autoload=True) self.assert_tables_equal(table, t, strict_types=True) assert t.c.plain_interval.type.precision is None @@ -1742,38 +1752,43 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): assert t.c.bitstring.type.length == 4 def test_bit_compile(self): - pairs = [(postgresql.BIT(), 'BIT(1)'), - (postgresql.BIT(5), 'BIT(5)'), - (postgresql.BIT(varying=True), 'BIT VARYING'), - (postgresql.BIT(5, varying=True), 'BIT VARYING(5)'), - ] + pairs = [ + (postgresql.BIT(), "BIT(1)"), + (postgresql.BIT(5), "BIT(5)"), + (postgresql.BIT(varying=True), "BIT VARYING"), + (postgresql.BIT(5, varying=True), "BIT VARYING(5)"), + ] for type_, expected in pairs: self.assert_compile(type_, expected) @testing.provide_metadata def test_tsvector_round_trip(self): - t = Table('t1', self.metadata, Column('data', postgresql.TSVECTOR)) + t = Table("t1", self.metadata, Column("data", postgresql.TSVECTOR)) t.create() testing.db.execute(t.insert(), data="a fat cat sat") eq_(testing.db.scalar(select([t.c.data])), "'a' 'cat' 'fat' 'sat'") testing.db.execute(t.update(), data="'a' 'cat' 'fat' 'mat' 'sat'") - eq_(testing.db.scalar(select([t.c.data])), - "'a' 'cat' 'fat' 'mat' 'sat'") + eq_( + testing.db.scalar(select([t.c.data])), + "'a' 'cat' 'fat' 'mat' 'sat'", + ) @testing.provide_metadata def test_bit_reflection(self): metadata = self.metadata - t1 = Table('t1', metadata, - Column('bit1', postgresql.BIT()), - Column('bit5', postgresql.BIT(5)), - Column('bitvarying', postgresql.BIT(varying=True)), - Column('bitvarying5', postgresql.BIT(5, varying=True)), - ) + t1 = Table( + "t1", + metadata, + Column("bit1", postgresql.BIT()), + Column("bit5", postgresql.BIT(5)), + Column("bitvarying", postgresql.BIT(varying=True)), + Column("bitvarying5", postgresql.BIT(5, varying=True)), + ) t1.create() m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True) + t2 = Table("t1", m2, autoload=True) eq_(t2.c.bit1.type.length, 1) eq_(t2.c.bit1.type.varying, False) eq_(t2.c.bit5.type.length, 5) @@ -1788,64 +1803,82 @@ class UUIDTest(fixtures.TestBase): """Test the bind/return values of the UUID type.""" - __only_on__ = 'postgresql >= 8.3' + __only_on__ = "postgresql >= 8.3" __backend__ = True @testing.fails_on( - 'postgresql+zxjdbc', + "postgresql+zxjdbc", 'column "data" is of type uuid but 
expression ' - 'is of type character varying') + "is of type character varying", + ) def test_uuid_string(self): import uuid + self._test_round_trip( - Table('utable', MetaData(), - Column('data', postgresql.UUID(as_uuid=False)) - ), + Table( + "utable", + MetaData(), + Column("data", postgresql.UUID(as_uuid=False)), + ), + str(uuid.uuid4()), str(uuid.uuid4()), - str(uuid.uuid4()) ) @testing.fails_on( - 'postgresql+zxjdbc', + "postgresql+zxjdbc", 'column "data" is of type uuid but expression is ' - 'of type character varying') + "of type character varying", + ) def test_uuid_uuid(self): import uuid + self._test_round_trip( - Table('utable', MetaData(), - Column('data', postgresql.UUID(as_uuid=True)) - ), + Table( + "utable", + MetaData(), + Column("data", postgresql.UUID(as_uuid=True)), + ), + uuid.uuid4(), uuid.uuid4(), - uuid.uuid4() ) - @testing.fails_on('postgresql+zxjdbc', - 'column "data" is of type uuid[] but ' - 'expression is of type character varying') - @testing.fails_on('postgresql+pg8000', 'No support for UUID with ARRAY') + @testing.fails_on( + "postgresql+zxjdbc", + 'column "data" is of type uuid[] but ' + "expression is of type character varying", + ) + @testing.fails_on("postgresql+pg8000", "No support for UUID with ARRAY") def test_uuid_array(self): import uuid + self._test_round_trip( Table( - 'utable', MetaData(), - Column('data', postgresql.ARRAY(postgresql.UUID(as_uuid=True))) + "utable", + MetaData(), + Column( + "data", postgresql.ARRAY(postgresql.UUID(as_uuid=True)) + ), ), [uuid.uuid4(), uuid.uuid4()], [uuid.uuid4(), uuid.uuid4()], ) - @testing.fails_on('postgresql+zxjdbc', - 'column "data" is of type uuid[] but ' - 'expression is of type character varying') - @testing.fails_on('postgresql+pg8000', 'No support for UUID with ARRAY') + @testing.fails_on( + "postgresql+zxjdbc", + 'column "data" is of type uuid[] but ' + "expression is of type character varying", + ) + @testing.fails_on("postgresql+pg8000", "No support for UUID with ARRAY") def test_uuid_string_array(self): import uuid + self._test_round_trip( Table( - 'utable', MetaData(), + "utable", + MetaData(), Column( - 'data', - postgresql.ARRAY(postgresql.UUID(as_uuid=False))) + "data", postgresql.ARRAY(postgresql.UUID(as_uuid=False)) + ), ), [str(uuid.uuid4()), str(uuid.uuid4())], [str(uuid.uuid4()), str(uuid.uuid4())], @@ -1853,13 +1886,11 @@ class UUIDTest(fixtures.TestBase): def test_no_uuid_available(self): from sqlalchemy.dialects.postgresql import base + uuid_type = base._python_UUID base._python_UUID = None try: - assert_raises( - NotImplementedError, - postgresql.UUID, as_uuid=True - ) + assert_raises(NotImplementedError, postgresql.UUID, as_uuid=True) finally: base._python_UUID = uuid_type @@ -1872,11 +1903,10 @@ class UUIDTest(fixtures.TestBase): def _test_round_trip(self, utable, value1, value2, exp_value2=None): utable.create(self.conn) - self.conn.execute(utable.insert(), {'data': value1}) - self.conn.execute(utable.insert(), {'data': value2}) + self.conn.execute(utable.insert(), {"data": value1}) + self.conn.execute(utable.insert(), {"data": value2}) r = self.conn.execute( - select([utable.c.data]). 
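
The four UUID round-trip tests above reduce to one point: with
as_uuid=True the column accepts and returns uuid.UUID objects rather
than strings. A minimal sketch, assuming a reachable PostgreSQL
database at a hypothetical URL:

    import uuid

    from sqlalchemy import Column, MetaData, Table, create_engine, select
    from sqlalchemy.dialects import postgresql

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    metadata = MetaData()
    utable = Table(
        "utable", metadata, Column("data", postgresql.UUID(as_uuid=True))
    )
    metadata.create_all(engine)

    engine.execute(utable.insert(), {"data": uuid.uuid4()})
    result = engine.scalar(select([utable.c.data]))
    assert isinstance(result, uuid.UUID)  # a uuid.UUID, not a str
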
- where(utable.c.data != value1) + select([utable.c.data]).where(utable.c.data != value1) ) if exp_value2: eq_(r.fetchone()[0], exp_value2) @@ -1886,14 +1916,16 @@ class UUIDTest(fixtures.TestBase): class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): - __dialect__ = 'postgresql' + __dialect__ = "postgresql" def setup(self): metadata = MetaData() - self.test_table = Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('hash', HSTORE) - ) + self.test_table = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("hash", HSTORE), + ) self.hashcol = self.test_table.c.hash def _test_where(self, whereclause, expected): @@ -1901,17 +1933,14 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): self.assert_compile( stmt, "SELECT test_table.id, test_table.hash FROM test_table " - "WHERE %s" % expected + "WHERE %s" % expected, ) def _test_cols(self, colclause, expected, from_=True): stmt = select([colclause]) self.assert_compile( stmt, - ( - "SELECT %s" + - (" FROM test_table" if from_ else "") - ) % expected + ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected, ) def test_bind_serialize_default(self): @@ -1920,47 +1949,44 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): proc = self.test_table.c.hash.type._cached_bind_processor(dialect) eq_( proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])), - '"key1"=>"value1", "key2"=>"value2"' + '"key1"=>"value1", "key2"=>"value2"', ) def test_bind_serialize_with_slashes_and_quotes(self): dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_bind_processor(dialect) - eq_( - proc({'\\"a': '\\"1'}), - '"\\\\\\"a"=>"\\\\\\"1"' - ) + eq_(proc({'\\"a': '\\"1'}), '"\\\\\\"a"=>"\\\\\\"1"') def test_parse_error(self): dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( - dialect, None) + dialect, None + ) assert_raises_message( ValueError, - r'''After u?'\[\.\.\.\], "key1"=>"value1", ', could not parse ''' - r'''residual at position 36: u?'crapcrapcrap, "key3"\[\.\.\.\]''', + r"""After u?'\[\.\.\.\], "key1"=>"value1", ', could not parse """ + r"""residual at position 36: u?'crapcrapcrap, "key3"\[\.\.\.\]""", proc, '"key2"=>"value2", "key1"=>"value1", ' - 'crapcrapcrap, "key3"=>"value3"' + 'crapcrapcrap, "key3"=>"value3"', ) def test_result_deserialize_default(self): dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( - dialect, None) + dialect, None + ) eq_( proc('"key2"=>"value2", "key1"=>"value1"'), - {"key1": "value1", "key2": "value2"} + {"key1": "value1", "key2": "value2"}, ) def test_result_deserialize_with_slashes_and_quotes(self): dialect = postgresql.dialect() proc = self.test_table.c.hash.type._cached_result_processor( - dialect, None) - eq_( - proc('"\\\\\\"a"=>"\\\\\\"1"'), - {'\\"a': '\\"1'} + dialect, None ) + eq_(proc('"\\\\\\"a"=>"\\\\\\"1"'), {'\\"a': '\\"1'}) def test_bind_serialize_psycopg2(self): from sqlalchemy.dialects.postgresql import psycopg2 @@ -1975,7 +2001,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): proc = self.test_table.c.hash.type._cached_bind_processor(dialect) eq_( proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])), - '"key1"=>"value1", "key2"=>"value2"' + '"key1"=>"value1", "key2"=>"value2"', ) def test_result_deserialize_psycopg2(self): @@ -1984,236 +2010,252 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): dialect = psycopg2.PGDialect_psycopg2() dialect._has_native_hstore = True proc = 
self.test_table.c.hash.type._cached_result_processor( - dialect, None) + dialect, None + ) is_(proc, None) dialect = psycopg2.PGDialect_psycopg2() dialect._has_native_hstore = False proc = self.test_table.c.hash.type._cached_result_processor( - dialect, None) + dialect, None + ) eq_( proc('"key2"=>"value2", "key1"=>"value1"'), - {"key1": "value1", "key2": "value2"} + {"key1": "value1", "key2": "value2"}, ) def test_ret_type_text(self): - col = column('x', HSTORE()) + col = column("x", HSTORE()) - is_(col['foo'].type.__class__, Text) + is_(col["foo"].type.__class__, Text) def test_ret_type_custom(self): class MyType(types.UserDefinedType): pass - col = column('x', HSTORE(text_type=MyType)) + col = column("x", HSTORE(text_type=MyType)) - is_(col['foo'].type.__class__, MyType) + is_(col["foo"].type.__class__, MyType) def test_where_has_key(self): self._test_where( # hide from 2to3 - getattr(self.hashcol, 'has_key')('foo'), - "test_table.hash ? %(hash_1)s" + getattr(self.hashcol, "has_key")("foo"), + "test_table.hash ? %(hash_1)s", ) def test_where_has_all(self): self._test_where( - self.hashcol.has_all(postgresql.array(['1', '2'])), - "test_table.hash ?& ARRAY[%(param_1)s, %(param_2)s]" + self.hashcol.has_all(postgresql.array(["1", "2"])), + "test_table.hash ?& ARRAY[%(param_1)s, %(param_2)s]", ) def test_where_has_any(self): self._test_where( - self.hashcol.has_any(postgresql.array(['1', '2'])), - "test_table.hash ?| ARRAY[%(param_1)s, %(param_2)s]" + self.hashcol.has_any(postgresql.array(["1", "2"])), + "test_table.hash ?| ARRAY[%(param_1)s, %(param_2)s]", ) def test_where_defined(self): self._test_where( - self.hashcol.defined('foo'), - "defined(test_table.hash, %(defined_1)s)" + self.hashcol.defined("foo"), + "defined(test_table.hash, %(defined_1)s)", ) def test_where_contains(self): self._test_where( - self.hashcol.contains({'foo': '1'}), - "test_table.hash @> %(hash_1)s" + self.hashcol.contains({"foo": "1"}), + "test_table.hash @> %(hash_1)s", ) def test_where_contained_by(self): self._test_where( - self.hashcol.contained_by({'foo': '1', 'bar': None}), - "test_table.hash <@ %(hash_1)s" + self.hashcol.contained_by({"foo": "1", "bar": None}), + "test_table.hash <@ %(hash_1)s", ) def test_where_getitem(self): self._test_where( - self.hashcol['bar'] == None, # noqa - "(test_table.hash -> %(hash_1)s) IS NULL" + self.hashcol["bar"] == None, # noqa + "(test_table.hash -> %(hash_1)s) IS NULL", ) def test_cols_get(self): self._test_cols( - self.hashcol['foo'], + self.hashcol["foo"], "test_table.hash -> %(hash_1)s AS anon_1", - True + True, ) def test_cols_delete_single_key(self): self._test_cols( - self.hashcol.delete('foo'), + self.hashcol.delete("foo"), "delete(test_table.hash, %(delete_2)s) AS delete_1", - True + True, ) def test_cols_delete_array_of_keys(self): self._test_cols( - self.hashcol.delete(postgresql.array(['foo', 'bar'])), - ("delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " - "AS delete_1"), - True + self.hashcol.delete(postgresql.array(["foo", "bar"])), + ( + "delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " + "AS delete_1" + ), + True, ) def test_cols_delete_matching_pairs(self): self._test_cols( - self.hashcol.delete(hstore('1', '2')), - ("delete(test_table.hash, hstore(%(hstore_1)s, %(hstore_2)s)) " - "AS delete_1"), - True + self.hashcol.delete(hstore("1", "2")), + ( + "delete(test_table.hash, hstore(%(hstore_1)s, %(hstore_2)s)) " + "AS delete_1" + ), + True, ) def test_cols_slice(self): self._test_cols( - self.hashcol.slice(postgresql.array(['1', '2'])), 
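
The _test_where assertions above map HSTORE methods onto PostgreSQL
operators. The mapping is easiest to see by compiling the expressions
directly; a compile-only sketch (the bare column is hypothetical):

    from sqlalchemy import column
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import HSTORE

    hashcol = column("hash", HSTORE)
    pg = postgresql.dialect()

    print(hashcol.has_key("foo").compile(dialect=pg))  # hash ? ...
    print(hashcol.contains({"foo": "1"}).compile(dialect=pg))  # hash @> ...
    print(hashcol["foo"].compile(dialect=pg))  # hash -> ...
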
- ("slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " - "AS slice_1"), - True + self.hashcol.slice(postgresql.array(["1", "2"])), + ( + "slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " + "AS slice_1" + ), + True, ) def test_cols_hstore_pair_text(self): self._test_cols( - hstore('foo', '3')['foo'], + hstore("foo", "3")["foo"], "hstore(%(hstore_1)s, %(hstore_2)s) -> %(hstore_3)s AS anon_1", - False + False, ) def test_cols_hstore_pair_array(self): self._test_cols( - hstore(postgresql.array(['1', '2']), - postgresql.array(['3', None]))['1'], - ("hstore(ARRAY[%(param_1)s, %(param_2)s], " - "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"), - False + hstore( + postgresql.array(["1", "2"]), postgresql.array(["3", None]) + )["1"], + ( + "hstore(ARRAY[%(param_1)s, %(param_2)s], " + "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1" + ), + False, ) def test_cols_hstore_single_array(self): self._test_cols( - hstore(postgresql.array(['1', '2', '3', None]))['3'], - ("hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) " - "-> %(hstore_1)s AS anon_1"), - False + hstore(postgresql.array(["1", "2", "3", None]))["3"], + ( + "hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) " + "-> %(hstore_1)s AS anon_1" + ), + False, ) def test_cols_concat(self): self._test_cols( - self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), '3')), - ("test_table.hash || hstore(CAST(test_table.id AS TEXT), " - "%(hstore_1)s) AS anon_1"), - True + self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), "3")), + ( + "test_table.hash || hstore(CAST(test_table.id AS TEXT), " + "%(hstore_1)s) AS anon_1" + ), + True, ) def test_cols_concat_op(self): self._test_cols( - hstore('foo', 'bar') + self.hashcol, + hstore("foo", "bar") + self.hashcol, "hstore(%(hstore_1)s, %(hstore_2)s) || test_table.hash AS anon_1", - True + True, ) def test_cols_concat_get(self): self._test_cols( - (self.hashcol + self.hashcol)['foo'], - "(test_table.hash || test_table.hash) -> %(param_1)s AS anon_1" + (self.hashcol + self.hashcol)["foo"], + "(test_table.hash || test_table.hash) -> %(param_1)s AS anon_1", ) def test_cols_against_is(self): self._test_cols( - self.hashcol['foo'] != None, # noqa - "(test_table.hash -> %(hash_1)s) IS NOT NULL AS anon_1" + self.hashcol["foo"] != None, # noqa + "(test_table.hash -> %(hash_1)s) IS NOT NULL AS anon_1", ) def test_cols_keys(self): self._test_cols( # hide from 2to3 - getattr(self.hashcol, 'keys')(), + getattr(self.hashcol, "keys")(), "akeys(test_table.hash) AS akeys_1", - True + True, ) def test_cols_vals(self): self._test_cols( - self.hashcol.vals(), - "avals(test_table.hash) AS avals_1", - True + self.hashcol.vals(), "avals(test_table.hash) AS avals_1", True ) def test_cols_array(self): self._test_cols( self.hashcol.array(), "hstore_to_array(test_table.hash) AS hstore_to_array_1", - True + True, ) def test_cols_matrix(self): self._test_cols( self.hashcol.matrix(), "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1", - True + True, ) class HStoreRoundTripTest(fixtures.TablesTest): - __requires__ = 'hstore', - __dialect__ = 'postgresql' + __requires__ = ("hstore",) + __dialect__ = "postgresql" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False), - Column('data', HSTORE) - ) + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", HSTORE), + 
) def _fixture_data(self, engine): data_table = self.tables.data_table engine.execute( data_table.insert(), - {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}, - {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}}, - {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}}, - {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}}, - {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}}, + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, ) def _assert_data(self, compare): data = testing.db.execute( - select([self.tables.data_table.c.data]). - order_by(self.tables.data_table.c.name) + select([self.tables.data_table.c.data]).order_by( + self.tables.data_table.c.name + ) ).fetchall() eq_([d for d, in data], compare) def _test_insert(self, engine): engine.execute( self.tables.data_table.insert(), - {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}} + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, ) self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) def _non_native_engine(self): if testing.requires.psycopg2_native_hstore.enabled: engine = engines.testing_engine( - options=dict( - use_native_hstore=False)) + options=dict(use_native_hstore=False) + ) else: engine = testing.db engine.connect().close() @@ -2221,8 +2263,8 @@ class HStoreRoundTripTest(fixtures.TablesTest): def test_reflect(self): insp = inspect(testing.db) - cols = insp.get_columns('data_table') - assert isinstance(cols[2]['type'], HSTORE) + cols = insp.get_columns("data_table") + assert isinstance(cols[2]["type"], HSTORE) def test_literal_round_trip(self): # in particular, this tests that the array index @@ -2230,14 +2272,9 @@ class HStoreRoundTripTest(fixtures.TablesTest): # array functions it requires outer parenthezisation on the left and # we may not be doing that here expr = hstore( - postgresql.array(['1', '2']), - postgresql.array(['3', None]))['1'] - eq_( - testing.db.scalar( - select([expr]) - ), - "3" - ) + postgresql.array(["1", "2"]), postgresql.array(["3", None]) + )["1"] + eq_(testing.db.scalar(select([expr])), "3") @testing.requires.psycopg2_native_hstore def test_insert_native(self): @@ -2263,19 +2300,23 @@ class HStoreRoundTripTest(fixtures.TablesTest): data_table = self.tables.data_table result = engine.execute( select([data_table.c.data]).where( - data_table.c.data['k1'] == 'r3v1')).first() - eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) + data_table.c.data["k1"] == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) def _test_fixed_round_trip(self, engine): - s = select([ - hstore( - array(['key1', 'key2', 'key3']), - array(['value1', 'value2', 'value3']) - ) - ]) + s = select( + [ + hstore( + array(["key1", "key2", "key3"]), + array(["value1", "value2", "value3"]), + ) + ] + ) eq_( engine.scalar(s), - {"key1": "value1", "key2": "value2", "key3": "value3"} + {"key1": "value1", "key2": "value2", "key3": "value3"}, ) def test_fixed_round_trip_python(self): @@ -2288,19 +2329,25 @@ class HStoreRoundTripTest(fixtures.TablesTest): self._test_fixed_round_trip(engine) def _test_unicode_round_trip(self, engine): - s = select([ - hstore( - array([util.u('réveillé'), util.u('drôle'), util.u('S’il')]), - array([util.u('réveillé'), util.u('drôle'), util.u('S’il')]) - ) - ]) + s = select( + [ + hstore( + array( + [util.u("réveillé"), util.u("drôle"), util.u("S’il")] + ), + array( 
+ [util.u("réveillé"), util.u("drôle"), util.u("S’il")] + ), + ) + ] + ) eq_( engine.scalar(s), { - util.u('réveillé'): util.u('réveillé'), - util.u('drôle'): util.u('drôle'), - util.u('S’il'): util.u('S’il') - } + util.u("réveillé"): util.u("réveillé"), + util.u("drôle"): util.u("drôle"), + util.u("S’il"): util.u("S’il"), + }, ) @testing.requires.psycopg2_native_hstore @@ -2325,171 +2372,153 @@ class HStoreRoundTripTest(fixtures.TablesTest): def _test_escaped_quotes_round_trip(self, engine): engine.execute( self.tables.data_table.insert(), - {'name': 'r1', 'data': {r'key \"foo\"': r'value \"bar"\ xyz'}} + {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, ) - self._assert_data([{r'key \"foo\"': r'value \"bar"\ xyz'}]) + self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}]) def test_orm_round_trip(self): from sqlalchemy import orm class Data(object): - def __init__(self, name, data): self.name = name self.data = data + orm.mapper(Data, self.tables.data_table) s = orm.Session(testing.db) - d = Data(name='r1', data={"key1": "value1", "key2": "value2", - "key3": "value3"}) - s.add(d) - eq_( - s.query(Data.data, Data).all(), - [(d.data, d)] + d = Data( + name="r1", + data={"key1": "value1", "key2": "value2", "key3": "value3"}, ) + s.add(d) + eq_(s.query(Data.data, Data).all(), [(d.data, d)]) class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): - __dialect__ = 'postgresql' + __dialect__ = "postgresql" # operator tests @classmethod def setup_class(cls): - table = Table('data_table', MetaData(), - Column('range', cls._col_type, primary_key=True), - ) + table = Table( + "data_table", + MetaData(), + Column("range", cls._col_type, primary_key=True), + ) cls.col = table.c.range def _test_clause(self, colclause, expected): - self.assert_compile( - colclause, expected - ) + self.assert_compile(colclause, expected) def test_where_equal(self): self._test_clause( - self.col == self._data_str, - "data_table.range = %(range_1)s" + self.col == self._data_str, "data_table.range = %(range_1)s" ) def test_where_not_equal(self): self._test_clause( - self.col != self._data_str, - "data_table.range <> %(range_1)s" + self.col != self._data_str, "data_table.range <> %(range_1)s" ) def test_where_is_null(self): - self._test_clause( - self.col == None, - "data_table.range IS NULL" - ) + self._test_clause(self.col == None, "data_table.range IS NULL") def test_where_is_not_null(self): - self._test_clause( - self.col != None, - "data_table.range IS NOT NULL" - ) + self._test_clause(self.col != None, "data_table.range IS NOT NULL") def test_where_less_than(self): self._test_clause( - self.col < self._data_str, - "data_table.range < %(range_1)s" + self.col < self._data_str, "data_table.range < %(range_1)s" ) def test_where_greater_than(self): self._test_clause( - self.col > self._data_str, - "data_table.range > %(range_1)s" + self.col > self._data_str, "data_table.range > %(range_1)s" ) def test_where_less_than_or_equal(self): self._test_clause( - self.col <= self._data_str, - "data_table.range <= %(range_1)s" + self.col <= self._data_str, "data_table.range <= %(range_1)s" ) def test_where_greater_than_or_equal(self): self._test_clause( - self.col >= self._data_str, - "data_table.range >= %(range_1)s" + self.col >= self._data_str, "data_table.range >= %(range_1)s" ) def test_contains(self): self._test_clause( self.col.contains(self._data_str), - "data_table.range @> %(range_1)s" + "data_table.range @> %(range_1)s", ) def test_contained_by(self): self._test_clause( 
self.col.contained_by(self._data_str), - "data_table.range <@ %(range_1)s" + "data_table.range <@ %(range_1)s", ) def test_overlaps(self): self._test_clause( self.col.overlaps(self._data_str), - "data_table.range && %(range_1)s" + "data_table.range && %(range_1)s", ) def test_strictly_left_of(self): self._test_clause( - self.col << self._data_str, - "data_table.range << %(range_1)s" + self.col << self._data_str, "data_table.range << %(range_1)s" ) self._test_clause( self.col.strictly_left_of(self._data_str), - "data_table.range << %(range_1)s" + "data_table.range << %(range_1)s", ) def test_strictly_right_of(self): self._test_clause( - self.col >> self._data_str, - "data_table.range >> %(range_1)s" + self.col >> self._data_str, "data_table.range >> %(range_1)s" ) self._test_clause( self.col.strictly_right_of(self._data_str), - "data_table.range >> %(range_1)s" + "data_table.range >> %(range_1)s", ) def test_not_extend_right_of(self): self._test_clause( self.col.not_extend_right_of(self._data_str), - "data_table.range &< %(range_1)s" + "data_table.range &< %(range_1)s", ) def test_not_extend_left_of(self): self._test_clause( self.col.not_extend_left_of(self._data_str), - "data_table.range &> %(range_1)s" + "data_table.range &> %(range_1)s", ) def test_adjacent_to(self): self._test_clause( self.col.adjacent_to(self._data_str), - "data_table.range -|- %(range_1)s" + "data_table.range -|- %(range_1)s", ) def test_union(self): self._test_clause( - self.col + self.col, - "data_table.range + data_table.range" + self.col + self.col, "data_table.range + data_table.range" ) def test_intersection(self): self._test_clause( - self.col * self.col, - "data_table.range * data_table.range" + self.col * self.col, "data_table.range * data_table.range" ) def test_different(self): self._test_clause( - self.col - self.col, - "data_table.range - data_table.range" + self.col - self.col, "data_table.range - data_table.range" ) class _RangeTypeRoundTrip(fixtures.TablesTest): - __requires__ = 'range_types', 'psycopg2_compatibility' + __requires__ = "range_types", "psycopg2_compatibility" __backend__ = True def extras(self): @@ -2505,9 +2534,11 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def define_tables(cls, metadata): # no reason ranges shouldn't be primary keys, # so lets just use them as such - table = Table('data_table', metadata, - Column('range', cls._col_type, primary_key=True), - ) + table = Table( + "data_table", + metadata, + Column("range", cls._col_type, primary_key=True), + ) cls.col = table.c.range def test_actual_type(self): @@ -2515,75 +2546,65 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_reflect(self): from sqlalchemy import inspect + insp = inspect(testing.db) - cols = insp.get_columns('data_table') - assert isinstance(cols[0]['type'], self._col_type) + cols = insp.get_columns("data_table") + assert isinstance(cols[0]["type"], self._col_type) def _assert_data(self): data = testing.db.execute( select([self.tables.data_table.c.range]) ).fetchall() - eq_(data, [(self._data_obj(), )]) + eq_(data, [(self._data_obj(),)]) def test_insert_obj(self): testing.db.engine.execute( - self.tables.data_table.insert(), - {'range': self._data_obj()} + self.tables.data_table.insert(), {"range": self._data_obj()} ) self._assert_data() def test_insert_text(self): testing.db.engine.execute( - self.tables.data_table.insert(), - {'range': self._data_str} + self.tables.data_table.insert(), {"range": self._data_str} ) self._assert_data() def test_union_result(self): # insert testing.db.engine.execute( 
- self.tables.data_table.insert(), - {'range': self._data_str} + self.tables.data_table.insert(), {"range": self._data_str} ) # select range = self.tables.data_table.c.range - data = testing.db.execute( - select([range + range]) - ).fetchall() - eq_(data, [(self._data_obj(), )]) + data = testing.db.execute(select([range + range])).fetchall() + eq_(data, [(self._data_obj(),)]) def test_intersection_result(self): # insert testing.db.engine.execute( - self.tables.data_table.insert(), - {'range': self._data_str} + self.tables.data_table.insert(), {"range": self._data_str} ) # select range = self.tables.data_table.c.range - data = testing.db.execute( - select([range * range]) - ).fetchall() - eq_(data, [(self._data_obj(), )]) + data = testing.db.execute(select([range * range])).fetchall() + eq_(data, [(self._data_obj(),)]) def test_difference_result(self): # insert testing.db.engine.execute( - self.tables.data_table.insert(), - {'range': self._data_str} + self.tables.data_table.insert(), {"range": self._data_str} ) # select range = self.tables.data_table.c.range - data = testing.db.execute( - select([range - range]) - ).fetchall() - eq_(data, [(self._data_obj().__class__(empty=True), )]) + data = testing.db.execute(select([range - range])).fetchall() + eq_(data, [(self._data_obj().__class__(empty=True),)]) class _Int4RangeTests(object): _col_type = INT4RANGE - _col_str = 'INT4RANGE' - _data_str = '[1,2)' + _col_str = "INT4RANGE" + _data_str = "[1,2)" def _data_obj(self): return self.extras().NumericRange(1, 2) @@ -2592,8 +2613,8 @@ class _Int4RangeTests(object): class _Int8RangeTests(object): _col_type = INT8RANGE - _col_str = 'INT8RANGE' - _data_str = '[9223372036854775806,9223372036854775807)' + _col_str = "INT8RANGE" + _data_str = "[9223372036854775806,9223372036854775807)" def _data_obj(self): return self.extras().NumericRange( @@ -2604,20 +2625,20 @@ class _Int8RangeTests(object): class _NumRangeTests(object): _col_type = NUMRANGE - _col_str = 'NUMRANGE' - _data_str = '[1.0,2.0)' + _col_str = "NUMRANGE" + _data_str = "[1.0,2.0)" def _data_obj(self): return self.extras().NumericRange( - decimal.Decimal('1.0'), decimal.Decimal('2.0') + decimal.Decimal("1.0"), decimal.Decimal("2.0") ) class _DateRangeTests(object): _col_type = DATERANGE - _col_str = 'DATERANGE' - _data_str = '[2013-03-23,2013-03-24)' + _col_str = "DATERANGE" + _data_str = "[2013-03-23,2013-03-24)" def _data_obj(self): return self.extras().DateRange( @@ -2628,20 +2649,20 @@ class _DateRangeTests(object): class _DateTimeRangeTests(object): _col_type = TSRANGE - _col_str = 'TSRANGE' - _data_str = '[2013-03-23 14:30,2013-03-23 23:30)' + _col_str = "TSRANGE" + _data_str = "[2013-03-23 14:30,2013-03-23 23:30)" def _data_obj(self): return self.extras().DateTimeRange( datetime.datetime(2013, 3, 23, 14, 30), - datetime.datetime(2013, 3, 23, 23, 30) + datetime.datetime(2013, 3, 23, 23, 30), ) class _DateTimeTZRangeTests(object): _col_type = TSTZRANGE - _col_str = 'TSTZRANGE' + _col_str = "TSTZRANGE" # make sure we use one, steady timestamp with timezone pair # for all parts of all these tests @@ -2649,16 +2670,14 @@ class _DateTimeTZRangeTests(object): def tstzs(self): if self._tstzs is None: - lower = testing.db.scalar( - func.current_timestamp().select() - ) + lower = testing.db.scalar(func.current_timestamp().select()) upper = lower + datetime.timedelta(1) self._tstzs = (lower, upper) return self._tstzs @property def _data_str(self): - return '[%s,%s)' % self.tstzs() + return "[%s,%s)" % self.tstzs() def _data_obj(self): return 
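
The range fixtures above each pair a PostgreSQL literal string with the
psycopg2 object it round-trips to. A minimal sketch for the INT4RANGE
case (hypothetical URL and table name; requires the psycopg2 driver,
per the __requires__ above):

    from psycopg2.extras import NumericRange

    from sqlalchemy import Column, MetaData, Table, create_engine, select
    from sqlalchemy.dialects.postgresql import INT4RANGE

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    metadata = MetaData()
    t = Table("r", metadata, Column("range", INT4RANGE, primary_key=True))
    metadata.create_all(engine)

    # a string in PostgreSQL range syntax binds directly...
    engine.execute(t.insert(), {"range": "[1,2)"})
    # ...and comes back as a psycopg2 NumericRange object
    assert engine.scalar(select([t.c.range])) == NumericRange(1, 2)
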
self.extras().DateTimeTZRange(*self.tstzs()) @@ -2704,7 +2723,9 @@ class DateTimeRangeRoundTripTest(_DateTimeRangeTests, _RangeTypeRoundTrip): pass -class DateTimeTZRangeCompilationTest(_DateTimeTZRangeTests, _RangeTypeCompilation): +class DateTimeTZRangeCompilationTest( + _DateTimeTZRangeTests, _RangeTypeCompilation +): pass @@ -2713,14 +2734,16 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip): class JSONTest(AssertsCompiledSQL, fixtures.TestBase): - __dialect__ = 'postgresql' + __dialect__ = "postgresql" def setup(self): metadata = MetaData() - self.test_table = Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('test_column', JSON), - ) + self.test_table = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("test_column", JSON), + ) self.jsoncol = self.test_table.c.test_column def _test_where(self, whereclause, expected): @@ -2728,194 +2751,176 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): self.assert_compile( stmt, "SELECT test_table.id, test_table.test_column FROM test_table " - "WHERE %s" % expected + "WHERE %s" % expected, ) def _test_cols(self, colclause, expected, from_=True): stmt = select([colclause]) self.assert_compile( stmt, - ( - "SELECT %s" + - (" FROM test_table" if from_ else "") - ) % expected + ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected, ) # This test is a bit misleading -- in real life you will need to cast to # do anything def test_where_getitem(self): self._test_where( - self.jsoncol['bar'] == None, # noqa - "(test_table.test_column -> %(test_column_1)s) IS NULL" + self.jsoncol["bar"] == None, # noqa + "(test_table.test_column -> %(test_column_1)s) IS NULL", ) def test_where_path(self): self._test_where( self.jsoncol[("foo", 1)] == None, # noqa - "(test_table.test_column #> %(test_column_1)s) IS NULL" + "(test_table.test_column #> %(test_column_1)s) IS NULL", ) def test_path_typing(self): - col = column('x', JSON()) - is_( - col['q'].type._type_affinity, types.JSON - ) - is_( - col[('q', )].type._type_affinity, types.JSON - ) - is_( - col['q']['p'].type._type_affinity, types.JSON - ) - is_( - col[('q', 'p')].type._type_affinity, types.JSON - ) + col = column("x", JSON()) + is_(col["q"].type._type_affinity, types.JSON) + is_(col[("q",)].type._type_affinity, types.JSON) + is_(col["q"]["p"].type._type_affinity, types.JSON) + is_(col[("q", "p")].type._type_affinity, types.JSON) def test_custom_astext_type(self): class MyType(types.UserDefinedType): pass - col = column('x', JSON(astext_type=MyType)) + col = column("x", JSON(astext_type=MyType)) - is_( - col['q'].astext.type.__class__, MyType - ) + is_(col["q"].astext.type.__class__, MyType) - is_( - col[('q', 'p')].astext.type.__class__, MyType - ) + is_(col[("q", "p")].astext.type.__class__, MyType) - is_( - col['q']['p'].astext.type.__class__, MyType - ) + is_(col["q"]["p"].astext.type.__class__, MyType) def test_where_getitem_as_text(self): self._test_where( - self.jsoncol['bar'].astext == None, # noqa - "(test_table.test_column ->> %(test_column_1)s) IS NULL" + self.jsoncol["bar"].astext == None, # noqa + "(test_table.test_column ->> %(test_column_1)s) IS NULL", ) def test_where_getitem_astext_cast(self): self._test_where( - self.jsoncol['bar'].astext.cast(Integer) == 5, + self.jsoncol["bar"].astext.cast(Integer) == 5, "CAST((test_table.test_column ->> %(test_column_1)s) AS INTEGER) " - "= %(param_1)s" + "= %(param_1)s", ) def test_where_getitem_json_cast(self): self._test_where( - 
self.jsoncol['bar'].cast(Integer) == 5, + self.jsoncol["bar"].cast(Integer) == 5, "CAST((test_table.test_column -> %(test_column_1)s) AS INTEGER) " - "= %(param_1)s" + "= %(param_1)s", ) def test_where_path_as_text(self): self._test_where( self.jsoncol[("foo", 1)].astext == None, # noqa - "(test_table.test_column #>> %(test_column_1)s) IS NULL" + "(test_table.test_column #>> %(test_column_1)s) IS NULL", ) def test_cols_get(self): self._test_cols( - self.jsoncol['foo'], + self.jsoncol["foo"], "test_table.test_column -> %(test_column_1)s AS anon_1", - True + True, ) class JSONRoundTripTest(fixtures.TablesTest): - __only_on__ = ('postgresql >= 9.3',) + __only_on__ = ("postgresql >= 9.3",) __backend__ = True test_type = JSON @classmethod def define_tables(cls, metadata): - Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False), - Column('data', cls.test_type), - Column('nulldata', cls.test_type(none_as_null=True)) - ) + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.test_type), + Column("nulldata", cls.test_type(none_as_null=True)), + ) def _fixture_data(self, engine): data_table = self.tables.data_table engine.execute( data_table.insert(), - {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}, - {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}}, - {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}}, - {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}}, - {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, - {'name': 'r6', 'data': {"k1": {"r6v1": {'subr': [1, 2, 3]}}}}, + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, + {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, ) - def _assert_data(self, compare, column='data'): + def _assert_data(self, compare, column="data"): col = self.tables.data_table.c[column] data = testing.db.execute( - select([col]). - order_by(self.tables.data_table.c.name) + select([col]).order_by(self.tables.data_table.c.name) ).fetchall() eq_([d for d, in data], compare) - def _assert_column_is_NULL(self, column='data'): + def _assert_column_is_NULL(self, column="data"): col = self.tables.data_table.c[column] data = testing.db.execute( - select([col]). - where(col.is_(null())) + select([col]).where(col.is_(null())) ).fetchall() eq_([d for d, in data], [None]) - def _assert_column_is_JSON_NULL(self, column='data'): + def _assert_column_is_JSON_NULL(self, column="data"): col = self.tables.data_table.c[column] data = testing.db.execute( - select([col]). 
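
Per the "misleading" note above, -> keeps JSON typing while ->> plus a
cast is what makes the value comparable. A compile-only sketch of the
three forms these tests assert (hypothetical column name):

    from sqlalchemy import Integer, column
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import JSON

    col = column("data", JSON)
    pg = postgresql.dialect()

    print(col["k"].compile(dialect=pg))  # data -> ...
    print(col["k"].astext.compile(dialect=pg))  # data ->> ...
    print(
        (col["k"].astext.cast(Integer) == 5).compile(dialect=pg)
    )  # CAST((data ->> ...) AS INTEGER) = ...
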
- where(cast(col, String) == "null") + select([col]).where(cast(col, String) == "null") ).fetchall() eq_([d for d, in data], [None]) def _test_insert(self, engine): engine.execute( self.tables.data_table.insert(), - {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}} + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, ) self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) def _test_insert_nulls(self, engine): engine.execute( - self.tables.data_table.insert(), - {'name': 'r1', 'data': null()} + self.tables.data_table.insert(), {"name": "r1", "data": null()} ) self._assert_data([None]) def _test_insert_none_as_null(self, engine): engine.execute( - self.tables.data_table.insert(), - {'name': 'r1', 'nulldata': None} + self.tables.data_table.insert(), {"name": "r1", "nulldata": None} ) - self._assert_column_is_NULL(column='nulldata') + self._assert_column_is_NULL(column="nulldata") def _test_insert_nulljson_into_none_as_null(self, engine): engine.execute( self.tables.data_table.insert(), - {'name': 'r1', 'nulldata': JSON.NULL} + {"name": "r1", "nulldata": JSON.NULL}, ) - self._assert_column_is_JSON_NULL(column='nulldata') + self._assert_column_is_JSON_NULL(column="nulldata") def _non_native_engine(self, json_serializer=None, json_deserializer=None): if json_serializer is not None or json_deserializer is not None: options = { "json_serializer": json_serializer, - "json_deserializer": json_deserializer + "json_deserializer": json_deserializer, } else: options = {} - if testing.against("postgresql+psycopg2") and \ - testing.db.dialect.psycopg2_version >= (2, 5): + if testing.against( + "postgresql+psycopg2" + ) and testing.db.dialect.psycopg2_version >= (2, 5): from psycopg2.extras import register_default_json + engine = engines.testing_engine(options=options) @event.listens_for(engine, "connect") @@ -2924,7 +2929,9 @@ class JSONRoundTripTest(fixtures.TablesTest): def pass_(value): return value + register_default_json(dbapi_connection, loads=pass_) + elif options: engine = engines.testing_engine(options=options) else: @@ -2934,8 +2941,8 @@ class JSONRoundTripTest(fixtures.TablesTest): def test_reflect(self): insp = inspect(testing.db) - cols = insp.get_columns('data_table') - assert isinstance(cols[2]['type'], self.test_type) + cols = insp.get_columns("data_table") + assert isinstance(cols[2]["type"], self.test_type) @testing.requires.psycopg2_native_json def test_insert_native(self): @@ -2978,41 +2985,25 @@ class JSONRoundTripTest(fixtures.TablesTest): def loads(value): value = json.loads(value) - value['x'] = value['x'] + '_loads' + value["x"] = value["x"] + "_loads" return value def dumps(value): value = dict(value) - value['x'] = 'dumps_y' + value["x"] = "dumps_y" return json.dumps(value) if native: - engine = engines.testing_engine(options=dict( - json_serializer=dumps, - json_deserializer=loads - )) + engine = engines.testing_engine( + options=dict(json_serializer=dumps, json_deserializer=loads) + ) else: engine = self._non_native_engine( - json_serializer=dumps, - json_deserializer=loads + json_serializer=dumps, json_deserializer=loads ) - s = select([ - cast( - { - "key": "value", - "x": "q" - }, - self.test_type - ) - ]) - eq_( - engine.scalar(s), - { - "key": "value", - "x": "dumps_y_loads" - }, - ) + s = select([cast({"key": "value", "x": "q"}, self.test_type)]) + eq_(engine.scalar(s), {"key": "value", "x": "dumps_y_loads"}) @testing.requires.psycopg2_native_json def test_custom_native(self): @@ -3040,14 +3031,14 @@ class JSONRoundTripTest(fixtures.TablesTest): result = 
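
The NULL-handling helpers above distinguish SQL NULL from the JSON
'null' value; the distinction is configured per column. A schema-only
sketch (hypothetical table name):

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects.postgresql import JSON

    metadata = MetaData()
    t = Table(
        "j",
        metadata,
        Column("id", Integer, primary_key=True),
        # default: a Python None is persisted as the JSON 'null' value
        Column("data", JSON),
        # none_as_null=True: a Python None becomes SQL NULL instead
        Column("nulldata", JSON(none_as_null=True)),
    )
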
engine.execute( select([data_table.c.name]).where( - data_table.c.data[('k1', 'r6v1', 'subr')].astext == "[1, 2, 3]" + data_table.c.data[("k1", "r6v1", "subr")].astext == "[1, 2, 3]" ) ) - eq_(result.scalar(), 'r6') + eq_(result.scalar(), "r6") @testing.fails_on( - "postgresql < 9.4", - "Improvement in PostgreSQL behavior?") + "postgresql < 9.4", "Improvement in PostgreSQL behavior?" + ) def test_multi_index_query(self): engine = testing.db self._fixture_data(engine) @@ -3055,17 +3046,17 @@ class JSONRoundTripTest(fixtures.TablesTest): result = engine.execute( select([data_table.c.name]).where( - data_table.c.data['k1']['r6v1']['subr'].astext == "[1, 2, 3]" + data_table.c.data["k1"]["r6v1"]["subr"].astext == "[1, 2, 3]" ) ) - eq_(result.scalar(), 'r6') + eq_(result.scalar(), "r6") def test_query_returned_as_text(self): engine = testing.db self._fixture_data(engine) data_table = self.tables.data_table result = engine.execute( - select([data_table.c.data['k1'].astext]) + select([data_table.c.data["k1"].astext]) ).first() if engine.dialect.returns_unicode_strings: assert isinstance(result[0], util.text_type) @@ -3077,8 +3068,9 @@ class JSONRoundTripTest(fixtures.TablesTest): self._fixture_data(engine) data_table = self.tables.data_table result = engine.execute( - select([data_table.c.data['k3'].astext.cast(Integer)]).where( - data_table.c.name == 'r5') + select([data_table.c.data["k3"].astext.cast(Integer)]).where( + data_table.c.name == "r5" + ) ).first() assert isinstance(result[0], int) @@ -3086,34 +3078,30 @@ class JSONRoundTripTest(fixtures.TablesTest): data_table = self.tables.data_table result = engine.execute( select([data_table.c.data]).where( - data_table.c.data['k1'].astext == 'r3v1' + data_table.c.data["k1"].astext == "r3v1" ) ).first() - eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) result = engine.execute( select([data_table.c.data]).where( - data_table.c.data['k1'].astext.cast(String) == 'r3v1' + data_table.c.data["k1"].astext.cast(String) == "r3v1" ) ).first() - eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) def _test_fixed_round_trip(self, engine): - s = select([ - cast( - { - "key": "value", - "key2": {"k1": "v1", "k2": "v2"} - }, - self.test_type - ) - ]) + s = select( + [ + cast( + {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, + self.test_type, + ) + ] + ) eq_( engine.scalar(s), - { - "key": "value", - "key2": {"k1": "v1", "k2": "v2"} - }, + {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, ) def test_fixed_round_trip_python(self): @@ -3126,20 +3114,22 @@ class JSONRoundTripTest(fixtures.TablesTest): self._test_fixed_round_trip(engine) def _test_unicode_round_trip(self, engine): - s = select([ - cast( - { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} - }, - self.test_type - ) - ]) + s = select( + [ + cast( + { + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, + }, + self.test_type, + ) + ] + ) eq_( engine.scalar(s), { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, }, ) @@ -3160,7 +3150,7 @@ class JSONRoundTripTest(fixtures.TablesTest): s = Session(testing.db) - d1 = Data(name='d1', data=None, nulldata=None) + d1 = Data(name="d1", data=None, nulldata=None) s.add(d1) s.commit() @@ -3170,27 +3160,32 @@ class JSONRoundTripTest(fixtures.TablesTest): eq_( s.query( cast(self.tables.data_table.c.data, String), - 
cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd1').first(), - ("null", None) + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), ) eq_( s.query( cast(self.tables.data_table.c.data, String), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd2').first(), - ("null", None) + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), ) class JSONBTest(JSONTest): - def setup(self): metadata = MetaData() - self.test_table = Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('test_column', JSONB) - ) + self.test_table = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("test_column", JSONB), + ) self.jsoncol = self.test_table.c.test_column # Note - add fixture data for arrays [] @@ -3198,37 +3193,39 @@ class JSONBTest(JSONTest): def test_where_has_key(self): self._test_where( # hide from 2to3 - getattr(self.jsoncol, 'has_key')('data'), - "test_table.test_column ? %(test_column_1)s" + getattr(self.jsoncol, "has_key")("data"), + "test_table.test_column ? %(test_column_1)s", ) def test_where_has_all(self): self._test_where( self.jsoncol.has_all( - {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}), - "test_table.test_column ?& %(test_column_1)s") + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}} + ), + "test_table.test_column ?& %(test_column_1)s", + ) def test_where_has_any(self): self._test_where( - self.jsoncol.has_any(postgresql.array(['name', 'data'])), - "test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]" + self.jsoncol.has_any(postgresql.array(["name", "data"])), + "test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]", ) def test_where_contains(self): self._test_where( self.jsoncol.contains({"k1": "r1v1"}), - "test_table.test_column @> %(test_column_1)s" + "test_table.test_column @> %(test_column_1)s", ) def test_where_contained_by(self): self._test_where( - self.jsoncol.contained_by({'foo': '1', 'bar': None}), - "test_table.test_column <@ %(test_column_1)s" + self.jsoncol.contained_by({"foo": "1", "bar": None}), + "test_table.test_column <@ %(test_column_1)s", ) class JSONBRoundTripTest(JSONRoundTripTest): - __requires__ = ('postgresql_jsonb', ) + __requires__ = ("postgresql_jsonb",) test_type = JSONB diff --git a/test/dialect/test_all.py b/test/dialect/test_all.py index 3e028e87aa..c61a258dad 100644 --- a/test/dialect/test_all.py +++ b/test/dialect/test_all.py @@ -4,12 +4,11 @@ from sqlalchemy import dialects class ImportStarTest(fixtures.TestBase): - def _all_dialect_packages(self): return [ getattr(__import__("sqlalchemy.dialects.%s" % d).dialects, d) for d in dialects.__all__ - if not d.startswith('_') + if not d.startswith("_") ] def test_all_import(self): diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index ff8c16eb13..3aac6a5c05 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -4,14 +4,30 @@ from sqlalchemy.databases import firebird from sqlalchemy.exc import ProgrammingError from sqlalchemy.sql import table, column from sqlalchemy import types as sqltypes -from sqlalchemy.testing import (fixtures, - AssertsExecutionResults, - AssertsCompiledSQL) +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy import testing from 
sqlalchemy.testing import engines -from sqlalchemy import String, VARCHAR, NVARCHAR, Unicode, Integer,\ - func, insert, update, MetaData, select, Table, Column, text,\ - Sequence, Float +from sqlalchemy import ( + String, + VARCHAR, + NVARCHAR, + Unicode, + Integer, + func, + insert, + update, + MetaData, + select, + Table, + Column, + text, + Sequence, + Float, +) from sqlalchemy import schema from sqlalchemy.testing.mock import Mock, call @@ -19,75 +35,92 @@ from sqlalchemy.testing.mock import Mock, call class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): "Test Firebird domains" - __only_on__ = 'firebird' + __only_on__ = "firebird" @classmethod def setup_class(cls): con = testing.db.connect() try: - con.execute('CREATE DOMAIN int_domain AS INTEGER DEFAULT ' - '42 NOT NULL') - con.execute('CREATE DOMAIN str_domain AS VARCHAR(255)') - con.execute('CREATE DOMAIN rem_domain AS BLOB SUB_TYPE TEXT' - ) - con.execute('CREATE DOMAIN img_domain AS BLOB SUB_TYPE ' - 'BINARY') + con.execute( + "CREATE DOMAIN int_domain AS INTEGER DEFAULT " "42 NOT NULL" + ) + con.execute("CREATE DOMAIN str_domain AS VARCHAR(255)") + con.execute("CREATE DOMAIN rem_domain AS BLOB SUB_TYPE TEXT") + con.execute("CREATE DOMAIN img_domain AS BLOB SUB_TYPE " "BINARY") except ProgrammingError as e: - if 'attempt to store duplicate value' not in str(e): + if "attempt to store duplicate value" not in str(e): raise e - con.execute('''CREATE GENERATOR gen_testtable_id''') - con.execute('''CREATE TABLE testtable (question int_domain, + con.execute("""CREATE GENERATOR gen_testtable_id""") + con.execute( + """CREATE TABLE testtable (question int_domain, answer str_domain DEFAULT 'no answer', remark rem_domain DEFAULT '', photo img_domain, d date, t time, dt timestamp, - redundant str_domain DEFAULT NULL)''') - con.execute("ALTER TABLE testtable " - "ADD CONSTRAINT testtable_pk PRIMARY KEY " - "(question)") - con.execute("CREATE TRIGGER testtable_autoid FOR testtable " - " ACTIVE BEFORE INSERT AS" - " BEGIN" - " IF (NEW.question IS NULL) THEN" - " NEW.question = gen_id(gen_testtable_id, 1);" - " END") + redundant str_domain DEFAULT NULL)""" + ) + con.execute( + "ALTER TABLE testtable " + "ADD CONSTRAINT testtable_pk PRIMARY KEY " + "(question)" + ) + con.execute( + "CREATE TRIGGER testtable_autoid FOR testtable " + " ACTIVE BEFORE INSERT AS" + " BEGIN" + " IF (NEW.question IS NULL) THEN" + " NEW.question = gen_id(gen_testtable_id, 1);" + " END" + ) @classmethod def teardown_class(cls): con = testing.db.connect() - con.execute('DROP TABLE testtable') - con.execute('DROP DOMAIN int_domain') - con.execute('DROP DOMAIN str_domain') - con.execute('DROP DOMAIN rem_domain') - con.execute('DROP DOMAIN img_domain') - con.execute('DROP GENERATOR gen_testtable_id') + con.execute("DROP TABLE testtable") + con.execute("DROP DOMAIN int_domain") + con.execute("DROP DOMAIN str_domain") + con.execute("DROP DOMAIN rem_domain") + con.execute("DROP DOMAIN img_domain") + con.execute("DROP GENERATOR gen_testtable_id") def test_table_is_reflected(self): - from sqlalchemy.types import Integer, Text, BLOB, String, Date, \ - Time, DateTime + from sqlalchemy.types import ( + Integer, + Text, + BLOB, + String, + Date, + Time, + DateTime, + ) + metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True) - eq_(set(table.columns.keys()), set([ - 'question', - 'answer', - 'remark', - 'photo', - 'd', - 't', - 'dt', - 'redundant', - ]), - "Columns of reflected table didn't equal expected " - "columns") + table = 
Table("testtable", metadata, autoload=True) + eq_( + set(table.columns.keys()), + set( + [ + "question", + "answer", + "remark", + "photo", + "d", + "t", + "dt", + "redundant", + ] + ), + "Columns of reflected table didn't equal expected " "columns", + ) eq_(table.c.question.primary_key, True) # disabled per http://www.sqlalchemy.org/trac/ticket/1660 # eq_(table.c.question.sequence.name, 'gen_testtable_id') assert isinstance(table.c.question.type, Integer) - eq_(table.c.question.server_default.arg.text, '42') + eq_(table.c.question.server_default.arg.text, "42") assert isinstance(table.c.answer.type, String) assert table.c.answer.type.length == 255 eq_(table.c.answer.server_default.arg.text, "'no answer'") @@ -107,7 +140,7 @@ class BuggyDomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): """Test Firebird domains (and some other reflection bumps), see [ticket:1663] and http://tracker.firebirdsql.org/browse/CORE-356""" - __only_on__ = 'firebird' + __only_on__ = "firebird" # NB: spacing and newlines are *significant* here! # PS: this test is superfluous on recent FB, where the issue 356 is @@ -204,40 +237,46 @@ ID DOM_ID /* INTEGER NOT NULL */ default 0 ) @classmethod def teardown_class(cls): con = testing.db.connect() - con.execute('DROP TABLE a') + con.execute("DROP TABLE a") con.execute("DROP TABLE b") - con.execute('DROP DOMAIN dom_id') - con.execute('DROP TABLE def_error_nodom') - con.execute('DROP TABLE def_error') - con.execute('DROP DOMAIN rit_tesoreria_capitolo_dm') - con.execute('DROP DOMAIN nosi_dm') - con.execute('DROP DOMAIN money_dm') - con.execute('DROP DOMAIN autoinc_dm') + con.execute("DROP DOMAIN dom_id") + con.execute("DROP TABLE def_error_nodom") + con.execute("DROP TABLE def_error") + con.execute("DROP DOMAIN rit_tesoreria_capitolo_dm") + con.execute("DROP DOMAIN nosi_dm") + con.execute("DROP DOMAIN money_dm") + con.execute("DROP DOMAIN autoinc_dm") def test_tables_are_reflected_same_way(self): metadata = MetaData(testing.db) - table_dom = Table('def_error', metadata, autoload=True) - table_nodom = Table('def_error_nodom', metadata, autoload=True) + table_dom = Table("def_error", metadata, autoload=True) + table_nodom = Table("def_error_nodom", metadata, autoload=True) - eq_(table_dom.c.interessi.server_default.arg.text, - table_nodom.c.interessi.server_default.arg.text) - eq_(table_dom.c.ritenuta.server_default.arg.text, - table_nodom.c.ritenuta.server_default.arg.text) - eq_(table_dom.c.stampato_modulo.server_default.arg.text, - table_nodom.c.stampato_modulo.server_default.arg.text) + eq_( + table_dom.c.interessi.server_default.arg.text, + table_nodom.c.interessi.server_default.arg.text, + ) + eq_( + table_dom.c.ritenuta.server_default.arg.text, + table_nodom.c.ritenuta.server_default.arg.text, + ) + eq_( + table_dom.c.stampato_modulo.server_default.arg.text, + table_nodom.c.stampato_modulo.server_default.arg.text, + ) def test_intermixed_comment(self): metadata = MetaData(testing.db) - table_a = Table('a', metadata, autoload=True) + table_a = Table("a", metadata, autoload=True) eq_(table_a.c.id.server_default.arg.text, "0") def test_lowercase_default_name(self): metadata = MetaData(testing.db) - table_b = Table('b', metadata, autoload=True) + table_b = Table("b", metadata, autoload=True) eq_(table_b.c.id.server_default.arg.text, "0") @@ -247,17 +286,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = firebird.FBDialect() def test_alias(self): - t = table('sometable', column('col1'), column('col2')) + t = table("sometable", 
column("col1"), column("col2")) s = select([t.alias()]) - self.assert_compile(s, - 'SELECT sometable_1.col1, sometable_1.col2 ' - 'FROM sometable AS sometable_1') + self.assert_compile( + s, + "SELECT sometable_1.col1, sometable_1.col2 " + "FROM sometable AS sometable_1", + ) dialect = firebird.FBDialect() dialect._version_two = False - self.assert_compile(s, - 'SELECT sometable_1.col1, sometable_1.col2 ' - 'FROM sometable sometable_1', - dialect=dialect) + self.assert_compile( + s, + "SELECT sometable_1.col1, sometable_1.col2 " + "FROM sometable sometable_1", + dialect=dialect, + ) def test_varchar_raise(self): for type_ in ( @@ -273,121 +316,146 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): exc.CompileError, "VARCHAR requires a length on dialect firebird", type_.compile, - dialect=firebird.dialect()) + dialect=firebird.dialect(), + ) - t1 = Table('sometable', MetaData(), - Column('somecolumn', type_)) + t1 = Table("sometable", MetaData(), Column("somecolumn", type_)) assert_raises_message( exc.CompileError, r"\(in table 'sometable', column 'somecolumn'\)\: " r"(?:N)?VARCHAR requires a length on dialect firebird", schema.CreateTable(t1).compile, - dialect=firebird.dialect() + dialect=firebird.dialect(), ) def test_function(self): - self.assert_compile(func.foo(1, 2), 'foo(:foo_1, :foo_2)') - self.assert_compile(func.current_time(), 'CURRENT_TIME') - self.assert_compile(func.foo(), 'foo') + self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)") + self.assert_compile(func.current_time(), "CURRENT_TIME") + self.assert_compile(func.foo(), "foo") m = MetaData() - t = Table('sometable', - m, - Column('col1', Integer), - Column('col2', Integer)) - self.assert_compile(select([func.max(t.c.col1)]), - 'SELECT max(sometable.col1) AS max_1 FROM ' - 'sometable') + t = Table( + "sometable", m, Column("col1", Integer), Column("col2", Integer) + ) + self.assert_compile( + select([func.max(t.c.col1)]), + "SELECT max(sometable.col1) AS max_1 FROM " "sometable", + ) def test_substring(self): - self.assert_compile(func.substring('abc', 1, 2), - 'SUBSTRING(:substring_1 FROM :substring_2 ' - 'FOR :substring_3)') - self.assert_compile(func.substring('abc', 1), - 'SUBSTRING(:substring_1 FROM :substring_2)') + self.assert_compile( + func.substring("abc", 1, 2), + "SUBSTRING(:substring_1 FROM :substring_2 " "FOR :substring_3)", + ) + self.assert_compile( + func.substring("abc", 1), + "SUBSTRING(:substring_1 FROM :substring_2)", + ) def test_update_returning(self): - table1 = table('mytable', - column('myid', Integer), - column('name', String(128)), - column('description', String(128))) - u = update(table1, values=dict(name='foo'))\ - .returning(table1.c.myid, table1.c.name) - self.assert_compile(u, - 'UPDATE mytable SET name=:name RETURNING ' - 'mytable.myid, mytable.name') - u = update(table1, values=dict(name='foo')).returning(table1) - self.assert_compile(u, - 'UPDATE mytable SET name=:name RETURNING ' - 'mytable.myid, mytable.name, ' - 'mytable.description') - u = update(table1, values=dict(name='foo')) \ - .returning(func.length(table1.c.name)) - self.assert_compile(u, - 'UPDATE mytable SET name=:name RETURNING ' - 'char_length(mytable.name) AS length_1') + table1 = table( + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) + u = update(table1, values=dict(name="foo")).returning( + table1.c.myid, table1.c.name + ) + self.assert_compile( + u, + "UPDATE mytable SET name=:name RETURNING " + "mytable.myid, mytable.name", + ) + u = 
update(table1, values=dict(name="foo")).returning(table1) + self.assert_compile( + u, + "UPDATE mytable SET name=:name RETURNING " + "mytable.myid, mytable.name, " + "mytable.description", + ) + u = update(table1, values=dict(name="foo")).returning( + func.length(table1.c.name) + ) + self.assert_compile( + u, + "UPDATE mytable SET name=:name RETURNING " + "char_length(mytable.name) AS length_1", + ) def test_insert_returning(self): - table1 = table('mytable', - column('myid', Integer), - column('name', String(128)), - column('description', String(128))) - i = insert(table1, values=dict(name='foo'))\ - .returning(table1.c.myid, table1.c.name) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES (:name) ' - 'RETURNING mytable.myid, mytable.name') - i = insert(table1, values=dict(name='foo')).returning(table1) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES (:name) ' - 'RETURNING mytable.myid, mytable.name, ' - 'mytable.description') - i = insert(table1, values=dict(name='foo'))\ - .returning(func.length(table1.c.name)) - self.assert_compile(i, - 'INSERT INTO mytable (name) VALUES (:name) ' - 'RETURNING char_length(mytable.name) AS ' - 'length_1') + table1 = table( + "mytable", + column("myid", Integer), + column("name", String(128)), + column("description", String(128)), + ) + i = insert(table1, values=dict(name="foo")).returning( + table1.c.myid, table1.c.name + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES (:name) " + "RETURNING mytable.myid, mytable.name", + ) + i = insert(table1, values=dict(name="foo")).returning(table1) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES (:name) " + "RETURNING mytable.myid, mytable.name, " + "mytable.description", + ) + i = insert(table1, values=dict(name="foo")).returning( + func.length(table1.c.name) + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES (:name) " + "RETURNING char_length(mytable.name) AS " + "length_1", + ) def test_charset(self): """Exercise CHARACTER SET options on string types.""" - columns = [(firebird.CHAR, [1], {}, 'CHAR(1)'), (firebird.CHAR, - [1], {'charset': 'OCTETS'}, - 'CHAR(1) CHARACTER SET OCTETS'), (firebird.VARCHAR, - [1], {}, 'VARCHAR(1)'), (firebird.VARCHAR, [1], - {'charset': 'OCTETS'}, - 'VARCHAR(1) CHARACTER SET OCTETS')] + columns = [ + (firebird.CHAR, [1], {}, "CHAR(1)"), + ( + firebird.CHAR, + [1], + {"charset": "OCTETS"}, + "CHAR(1) CHARACTER SET OCTETS", + ), + (firebird.VARCHAR, [1], {}, "VARCHAR(1)"), + ( + firebird.VARCHAR, + [1], + {"charset": "OCTETS"}, + "VARCHAR(1) CHARACTER SET OCTETS", + ), + ] for type_, args, kw, res in columns: self.assert_compile(type_(*args, **kw), res) def test_quoting_initial_chars(self): - self.assert_compile( - column("_somecol"), - '"_somecol"' - ) - self.assert_compile( - column("$somecol"), - '"$somecol"' - ) + self.assert_compile(column("_somecol"), '"_somecol"') + self.assert_compile(column("$somecol"), '"$somecol"') class TypesTest(fixtures.TestBase): - __only_on__ = 'firebird' + __only_on__ = "firebird" @testing.provide_metadata def test_infinite_float(self): metadata = self.metadata - t = Table('t', metadata, - Column('data', Float)) + t = Table("t", metadata, Column("data", Float)) metadata.create_all() - t.insert().execute(data=float('inf')) - eq_(t.select().execute().fetchall(), - [(float('inf'),)]) + t.insert().execute(data=float("inf")) + eq_(t.select().execute().fetchall(), [(float("inf"),)]) class MiscTest(fixtures.TestBase): - __only_on__ = 'firebird' + __only_on__ = "firebird" 
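# A minimal, hedged aside (not part of the commit): the adjacent string
# literals black leaves throughout this diff — e.g. the split
# "UPDATE mytable SET name=:name RETURNING " "mytable.myid, mytable.name"
# in test_update_returning above — rely on Python concatenating neighboring
# literals at compile time, so the reformatted assertions still compare one
# single string. A self-contained sketch:
sql = (
    "UPDATE mytable SET name=:name RETURNING "
    "mytable.myid, mytable.name"
)
assert sql == (
    "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name"
)
# The reformat shown in these hunks corresponds to the command named in the
# commit message, e.g. `black -l 79 lib test examples` from the repository
# root; the exact directory list here is an assumption, not taken from the
# commit itself.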
@testing.provide_metadata def test_strlen(self): @@ -399,54 +467,52 @@ class MiscTest(fixtures.TestBase): # string length the UDF was declared to accept). This test # checks that at least it works ok in other cases. - t = Table('t1', - metadata, - Column('id', Integer, Sequence('t1idseq'), primary_key=True), - Column('name', String(10))) + t = Table( + "t1", + metadata, + Column("id", Integer, Sequence("t1idseq"), primary_key=True), + Column("name", String(10)), + ) metadata.create_all() - t.insert(values=dict(name='dante')).execute() - t.insert(values=dict(name='alighieri')).execute() - select([func.count(t.c.id)], func.length(t.c.name) - == 5).execute().first()[0] == 1 + t.insert(values=dict(name="dante")).execute() + t.insert(values=dict(name="alighieri")).execute() + select( + [func.count(t.c.id)], func.length(t.c.name) == 5 + ).execute().first()[0] == 1 def test_version_parsing(self): for string, result in [ - ("WI-V1.5.0.1234 Firebird 1.5", (1, 5, 1234, 'firebird')), - ("UI-V6.3.2.18118 Firebird 2.1", (2, 1, 18118, 'firebird')), - ("LI-V6.3.3.12981 Firebird 2.0", (2, 0, 12981, 'firebird')), - ("WI-V8.1.1.333", (8, 1, 1, 'interbase')), - ("WI-V8.1.1.333 Firebird 1.5", (1, 5, 333, 'firebird')), + ("WI-V1.5.0.1234 Firebird 1.5", (1, 5, 1234, "firebird")), + ("UI-V6.3.2.18118 Firebird 2.1", (2, 1, 18118, "firebird")), + ("LI-V6.3.3.12981 Firebird 2.0", (2, 0, 12981, "firebird")), + ("WI-V8.1.1.333", (8, 1, 1, "interbase")), + ("WI-V8.1.1.333 Firebird 1.5", (1, 5, 333, "firebird")), ]: - eq_( - testing.db.dialect._parse_version_info(string), - result - ) + eq_(testing.db.dialect._parse_version_info(string), result) @testing.provide_metadata def test_rowcount_flag(self): metadata = self.metadata - engine = engines.testing_engine( - options={'enable_rowcount': True}) + engine = engines.testing_engine(options={"enable_rowcount": True}) assert engine.dialect.supports_sane_rowcount metadata.bind = engine - t = Table('t1', metadata, Column('data', String(10))) + t = Table("t1", metadata, Column("data", String(10))) metadata.create_all() - r = t.insert().execute({'data': 'd1'}, {'data': 'd2'}, {'data': 'd3'}) - r = t.update().where(t.c.data == 'd2').values(data='d3').execute() + r = t.insert().execute({"data": "d1"}, {"data": "d2"}, {"data": "d3"}) + r = t.update().where(t.c.data == "d2").values(data="d3").execute() eq_(r.rowcount, 1) - r = t.delete().where(t.c.data == 'd3').execute() + r = t.delete().where(t.c.data == "d3").execute() eq_(r.rowcount, 2) - r = \ - t.delete().execution_options(enable_rowcount=False).execute() + r = t.delete().execution_options(enable_rowcount=False).execute() eq_(r.rowcount, -1) engine.dispose() - engine = engines.testing_engine(options={'enable_rowcount': False}) + engine = engines.testing_engine(options={"enable_rowcount": False}) assert not engine.dialect.supports_sane_rowcount metadata.bind = engine - r = t.insert().execute({'data': 'd1'}, {'data': 'd2'}, {'data': 'd3'}) - r = t.update().where(t.c.data == 'd2').values(data='d3').execute() + r = t.insert().execute({"data": "d1"}, {"data": "d2"}, {"data": "d3"}) + r = t.update().where(t.c.data == "d2").values(data="d3").execute() eq_(r.rowcount, -1) - r = t.delete().where(t.c.data == 'd3').execute() + r = t.delete().where(t.c.data == "d3").execute() eq_(r.rowcount, -1) r = t.delete().execution_options(enable_rowcount=True).execute() eq_(r.rowcount, 1) @@ -454,35 +520,33 @@ class MiscTest(fixtures.TestBase): engine.dispose() def test_percents_in_text(self): - for expr, result in (text("select '%' from 
rdb$database"), '%' - ), (text("select '%%' from rdb$database"), - '%%'), \ - (text("select '%%%' from rdb$database"), '%%%'), \ - (text("select 'hello % world' from rdb$database"), - 'hello % world'): + for expr, result in ( + (text("select '%' from rdb$database"), "%"), + (text("select '%%' from rdb$database"), "%%"), + (text("select '%%%' from rdb$database"), "%%%"), + ( + text("select 'hello % world' from rdb$database"), + "hello % world", + ), + ): eq_(testing.db.scalar(expr), result) class ArgumentTest(fixtures.TestBase): def _dbapi(self): return Mock( - paramstyle='qmark', - connect=Mock( - return_value=Mock( - server_version="UI-V6.3.2.18118 Firebird 2.1", - cursor=Mock(return_value=Mock()) - ) - ) + paramstyle="qmark", + connect=Mock( + return_value=Mock( + server_version="UI-V6.3.2.18118 Firebird 2.1", + cursor=Mock(return_value=Mock()), ) + ), + ) def _engine(self, type_, **kw): dbapi = self._dbapi() - kw.update( - dict( - module=dbapi, - _initialize=False - ) - ) + kw.update(dict(module=dbapi, _initialize=False)) engine = engines.testing_engine("firebird+%s://" % type_, options=kw) return engine @@ -516,12 +580,12 @@ class ArgumentTest(fixtures.TestBase): trans.commit() eq_( engine.dialect.dbapi.connect.return_value.commit.mock_calls, - [call(flag)] + [call(flag)], ) trans = conn.begin() trans.rollback() eq_( engine.dialect.dbapi.connect.return_value.rollback.mock_calls, - [call(flag)] + [call(flag)], ) diff --git a/test/dialect/test_mxodbc.py b/test/dialect/test_mxodbc.py index e5f9f330fe..abff52744c 100644 --- a/test/dialect/test_mxodbc.py +++ b/test/dialect/test_mxodbc.py @@ -6,38 +6,52 @@ from sqlalchemy.testing.mock import Mock def mock_dbapi(): - return Mock(paramstyle='qmark', - connect=Mock( - return_value=Mock( - cursor=Mock(return_value=Mock(description=None, - rowcount=None))))) + return Mock( + paramstyle="qmark", + connect=Mock( + return_value=Mock( + cursor=Mock(return_value=Mock(description=None, rowcount=None)) + ) + ), + ) class MxODBCTest(fixtures.TestBase): - def test_native_odbc_execute(self): - t1 = Table('t1', MetaData(), Column('c1', Integer)) + t1 = Table("t1", MetaData(), Column("c1", Integer)) dbapi = mock_dbapi() - engine = engines.testing_engine('mssql+mxodbc://localhost', - options={'module': dbapi, - '_initialize': False}) + engine = engines.testing_engine( + "mssql+mxodbc://localhost", + options={"module": dbapi, "_initialize": False}, + ) conn = engine.connect() # crud: uses execute - conn.execute(t1.insert().values(c1='foo')) - conn.execute(t1.delete().where(t1.c.c1 == 'foo')) - conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar')) + conn.execute(t1.insert().values(c1="foo")) + conn.execute(t1.delete().where(t1.c.c1 == "foo")) + conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar")) # select: uses executedirect conn.execute(t1.select()) # manual flagging conn.execution_options(native_odbc_execute=True).execute(t1.select()) - conn.execution_options(native_odbc_execute=False)\ - .execute(t1.insert().values(c1='foo')) - - eq_([c[2] for c in - dbapi.connect.return_value.cursor.return_value.execute.mock_calls], - [{'direct': True}, {'direct': True}, {'direct': True}, - {'direct': True}, {'direct': False}, {'direct': True}]) + conn.execution_options(native_odbc_execute=False).execute( + t1.insert().values(c1="foo") + ) + + eq_( + [ + c[2] + for c in dbapi.connect.return_value.cursor.return_value.execute.mock_calls + ], + [ + {"direct": True}, + {"direct": True}, + {"direct": True}, + {"direct": True}, + {"direct": False}, + 
{"direct": True}, + ], + ) diff --git a/test/dialect/test_pyodbc.py b/test/dialect/test_pyodbc.py index 8d817fdb16..5c60c5b5ae 100644 --- a/test/dialect/test_pyodbc.py +++ b/test/dialect/test_pyodbc.py @@ -7,12 +7,9 @@ class PyODBCTest(fixtures.TestBase): def test_pyodbc_version(self): connector = pyodbc.PyODBCConnector() for vers, expected in [ - ('2.1.8', (2, 1, 8)), - ("py3-3.0.1-beta4", (3, 0, 1, 'beta4')), + ("2.1.8", (2, 1, 8)), + ("py3-3.0.1-beta4", (3, 0, 1, "beta4")), ("10.15.17", (10, 15, 17)), ("crap.crap.crap", ()), ]: - eq_( - connector._parse_dbapi_version(vers), - expected - ) + eq_(connector._parse_dbapi_version(vers), expected) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 417ace5c8f..0c3dd2acd6 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -4,22 +4,47 @@ import os import datetime -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, is_, expect_warnings -from sqlalchemy import Table, select, bindparam, Column,\ - MetaData, func, extract, ForeignKey, text, DefaultClause, and_, \ - create_engine, \ - UniqueConstraint, Index, PrimaryKeyConstraint, CheckConstraint +from sqlalchemy.testing import ( + eq_, + assert_raises, + assert_raises_message, + is_, + expect_warnings, +) +from sqlalchemy import ( + Table, + select, + bindparam, + Column, + MetaData, + func, + extract, + ForeignKey, + text, + DefaultClause, + and_, + create_engine, + UniqueConstraint, + Index, + PrimaryKeyConstraint, + CheckConstraint, +) from sqlalchemy.types import Integer, String, Boolean, DateTime, Date, Time from sqlalchemy import types as sqltypes from sqlalchemy import event, inspect from sqlalchemy.util import u, ue from sqlalchemy import exc, sql, schema, pool, util -from sqlalchemy.dialects.sqlite import base as sqlite, \ - pysqlite as pysqlite_dialect +from sqlalchemy.dialects.sqlite import ( + base as sqlite, + pysqlite as pysqlite_dialect, +) from sqlalchemy.engine.url import make_url -from sqlalchemy.testing import fixtures, AssertsCompiledSQL, \ - AssertsExecutionResults, engines +from sqlalchemy.testing import ( + fixtures, + AssertsCompiledSQL, + AssertsExecutionResults, + engines, +) from sqlalchemy import testing from sqlalchemy.schema import CreateTable, FetchedValue from sqlalchemy.engine.reflection import Inspector @@ -28,7 +53,7 @@ from sqlalchemy.testing import mock class TestTypes(fixtures.TestBase, AssertsExecutionResults): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def test_boolean(self): """Test that the boolean only treats 1 as True @@ -37,68 +62,93 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): meta = MetaData(testing.db) t = Table( - 'bool_table', meta, - Column('id', Integer, primary_key=True), - Column('boo', Boolean(create_constraint=False))) + "bool_table", + meta, + Column("id", Integer, primary_key=True), + Column("boo", Boolean(create_constraint=False)), + ) try: meta.create_all() - testing.db.execute("INSERT INTO bool_table (id, boo) " - "VALUES (1, 'false');") - testing.db.execute("INSERT INTO bool_table (id, boo) " - "VALUES (2, 'true');") - testing.db.execute("INSERT INTO bool_table (id, boo) " - "VALUES (3, '1');") - testing.db.execute("INSERT INTO bool_table (id, boo) " - "VALUES (4, '0');") - testing.db.execute('INSERT INTO bool_table (id, boo) ' - 'VALUES (5, 1);') - testing.db.execute('INSERT INTO bool_table (id, boo) ' - 'VALUES (6, 0);') - eq_(t.select(t.c.boo).order_by(t.c.id).execute().fetchall(), - [(3, True), (5, True)]) + testing.db.execute( + 
"INSERT INTO bool_table (id, boo) " "VALUES (1, 'false');" + ) + testing.db.execute( + "INSERT INTO bool_table (id, boo) " "VALUES (2, 'true');" + ) + testing.db.execute( + "INSERT INTO bool_table (id, boo) " "VALUES (3, '1');" + ) + testing.db.execute( + "INSERT INTO bool_table (id, boo) " "VALUES (4, '0');" + ) + testing.db.execute( + "INSERT INTO bool_table (id, boo) " "VALUES (5, 1);" + ) + testing.db.execute( + "INSERT INTO bool_table (id, boo) " "VALUES (6, 0);" + ) + eq_( + t.select(t.c.boo).order_by(t.c.id).execute().fetchall(), + [(3, True), (5, True)], + ) finally: meta.drop_all() def test_string_dates_passed_raise(self): - assert_raises(exc.StatementError, testing.db.execute, - select([1]).where(bindparam('date', type_=Date)), - date=str(datetime.date(2007, 10, 30))) + assert_raises( + exc.StatementError, + testing.db.execute, + select([1]).where(bindparam("date", type_=Date)), + date=str(datetime.date(2007, 10, 30)), + ) def test_cant_parse_datetime_message(self): for (typ, disp) in [ (Time, "time"), (DateTime, "datetime"), - (Date, "date") + (Date, "date"), ]: assert_raises_message( ValueError, "Couldn't parse %s string." % disp, lambda: testing.db.execute( text("select 'ASDF' as value", typemap={"value": typ}) - ).scalar() + ).scalar(), ) def test_native_datetime(self): dbapi = testing.db.dialect.dbapi connect_args = { - 'detect_types': dbapi.PARSE_DECLTYPES | dbapi.PARSE_COLNAMES} + "detect_types": dbapi.PARSE_DECLTYPES | dbapi.PARSE_COLNAMES + } engine = engines.testing_engine( - options={'connect_args': connect_args, 'native_datetime': True}) + options={"connect_args": connect_args, "native_datetime": True} + ) t = Table( - 'datetest', MetaData(), - Column('id', Integer, primary_key=True), - Column('d1', Date), Column('d2', sqltypes.TIMESTAMP)) + "datetest", + MetaData(), + Column("id", Integer, primary_key=True), + Column("d1", Date), + Column("d2", sqltypes.TIMESTAMP), + ) t.create(engine) try: - engine.execute(t.insert(), { - 'd1': datetime.date(2010, 5, 10), - 'd2': datetime.datetime(2010, 5, 10, 12, 15, 25) - }) + engine.execute( + t.insert(), + { + "d1": datetime.date(2010, 5, 10), + "d2": datetime.datetime(2010, 5, 10, 12, 15, 25), + }, + ) row = engine.execute(t.select()).first() eq_( row, - (1, datetime.date(2010, 5, 10), - datetime.datetime(2010, 5, 10, 12, 15, 25))) + ( + 1, + datetime.date(2010, 5, 10), + datetime.datetime(2010, 5, 10, 12, 15, 25), + ), + ) r = engine.execute(func.current_date()).scalar() assert isinstance(r, util.string_types) finally: @@ -113,21 +163,22 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): "T%(hour)02d:%(minute)02d:%(second)02d", regexp=r"(\d+)-(\d+)-(\d+)T(\d+):(\d+):(\d+)", ) - t = Table('t', self.metadata, Column('d', sqlite_date)) + t = Table("t", self.metadata, Column("d", sqlite_date)) self.metadata.create_all(testing.db) testing.db.execute( - t.insert(). 
- values(d=datetime.datetime(2010, 10, 15, 12, 37, 0))) + t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0)) + ) testing.db.execute("insert into t (d) values ('2004-05-21T00:00:00')") eq_( testing.db.execute("select * from t order by d").fetchall(), - [('2004-05-21T00:00:00',), ('2010-10-15T12:37:00',)] + [("2004-05-21T00:00:00",), ("2010-10-15T12:37:00",)], ) eq_( testing.db.execute(select([t.c.d]).order_by(t.c.d)).fetchall(), [ (datetime.datetime(2004, 5, 21, 0, 0),), - (datetime.datetime(2010, 10, 15, 12, 37),)] + (datetime.datetime(2010, 10, 15, 12, 37),), + ], ) @testing.provide_metadata @@ -137,21 +188,22 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): "%(hour)02d%(minute)02d%(second)02d", regexp=r"(\d{4})(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})", ) - t = Table('t', self.metadata, Column('d', sqlite_date)) + t = Table("t", self.metadata, Column("d", sqlite_date)) self.metadata.create_all(testing.db) testing.db.execute( - t.insert(). - values(d=datetime.datetime(2010, 10, 15, 12, 37, 0))) + t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0)) + ) testing.db.execute("insert into t (d) values ('20040521000000')") eq_( testing.db.execute("select * from t order by d").fetchall(), - [('20040521000000',), ('20101015123700',)] + [("20040521000000",), ("20101015123700",)], ) eq_( testing.db.execute(select([t.c.d]).order_by(t.c.d)).fetchall(), [ (datetime.datetime(2004, 5, 21, 0, 0),), - (datetime.datetime(2010, 10, 15, 12, 37),)] + (datetime.datetime(2010, 10, 15, 12, 37),), + ], ) @testing.provide_metadata @@ -160,21 +212,17 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): storage_format="%(year)04d%(month)02d%(day)02d", regexp=r"(\d{4})(\d{2})(\d{2})", ) - t = Table('t', self.metadata, Column('d', sqlite_date)) + t = Table("t", self.metadata, Column("d", sqlite_date)) self.metadata.create_all(testing.db) - testing.db.execute( - t.insert(). - values(d=datetime.date(2010, 10, 15))) + testing.db.execute(t.insert().values(d=datetime.date(2010, 10, 15))) testing.db.execute("insert into t (d) values ('20040521')") eq_( testing.db.execute("select * from t order by d").fetchall(), - [('20040521',), ('20101015',)] + [("20040521",), ("20101015",)], ) eq_( testing.db.execute(select([t.c.d]).order_by(t.c.d)).fetchall(), - [ - (datetime.date(2004, 5, 21),), - (datetime.date(2010, 10, 15),)] + [(datetime.date(2004, 5, 21),), (datetime.date(2010, 10, 15),)], ) @testing.provide_metadata @@ -184,21 +232,17 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): storage_format="%(year)04d|%(month)02d|%(day)02d", regexp=r"(\d+)\|(\d+)\|(\d+)", ) - t = Table('t', self.metadata, Column('d', sqlite_date)) + t = Table("t", self.metadata, Column("d", sqlite_date)) self.metadata.create_all(testing.db) - testing.db.execute( - t.insert(). 
- values(d=datetime.date(2010, 10, 15))) + testing.db.execute(t.insert().values(d=datetime.date(2010, 10, 15))) testing.db.execute("insert into t (d) values ('2004|05|21')") eq_( testing.db.execute("select * from t order by d").fetchall(), - [('2004|05|21',), ('2010|10|15',)] + [("2004|05|21",), ("2010|10|15",)], ) eq_( testing.db.execute(select([t.c.d]).order_by(t.c.d)).fetchall(), - [ - (datetime.date(2004, 5, 21),), - (datetime.date(2010, 10, 15),)] + [(datetime.date(2004, 5, 21),), (datetime.date(2010, 10, 15),)], ) def test_no_convert_unicode(self): @@ -216,96 +260,82 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): sqltypes.UnicodeText(), ): bindproc = t.dialect_impl(dialect).bind_processor(dialect) - assert not bindproc or \ - isinstance(bindproc(util.u('some string')), util.text_type) + assert not bindproc or isinstance( + bindproc(util.u("some string")), util.text_type + ) class JSONTest(fixtures.TestBase): - __requires__ = ('json_type', ) - __only_on__ = 'sqlite' + __requires__ = ("json_type",) + __only_on__ = "sqlite" @testing.provide_metadata @testing.requires.reflects_json_type def test_reflection(self): - Table( - 'json_test', self.metadata, - Column('foo', sqlite.JSON) - ) + Table("json_test", self.metadata, Column("foo", sqlite.JSON)) self.metadata.create_all() - reflected = Table('json_test', MetaData(), autoload_with=testing.db) + reflected = Table("json_test", MetaData(), autoload_with=testing.db) is_(reflected.c.foo.type._type_affinity, sqltypes.JSON) assert isinstance(reflected.c.foo.type, sqlite.JSON) @testing.provide_metadata def test_rudimentary_roundtrip(self): sqlite_json = Table( - 'json_test', self.metadata, - Column('foo', sqlite.JSON) + "json_test", self.metadata, Column("foo", sqlite.JSON) ) self.metadata.create_all() - value = { - 'json': {'foo': 'bar'}, - 'recs': ['one', 'two'] - } + value = {"json": {"foo": "bar"}, "recs": ["one", "two"]} with testing.db.connect() as conn: conn.execute(sqlite_json.insert(), foo=value) - eq_( - conn.scalar(select([sqlite_json.c.foo])), - value - ) + eq_(conn.scalar(select([sqlite_json.c.foo])), value) @testing.provide_metadata def test_extract_subobject(self): sqlite_json = Table( - 'json_test', self.metadata, - Column('foo', sqlite.JSON) + "json_test", self.metadata, Column("foo", sqlite.JSON) ) self.metadata.create_all() - value = { - 'json': {'foo': 'bar'}, - } + value = {"json": {"foo": "bar"}} with testing.db.connect() as conn: conn.execute(sqlite_json.insert(), foo=value) eq_( - conn.scalar(select([sqlite_json.c.foo['json']])), - value['json'] + conn.scalar(select([sqlite_json.c.foo["json"]])), value["json"] ) class DateTimeTest(fixtures.TestBase, AssertsCompiledSQL): - def test_time_microseconds(self): - dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125, ) - eq_(str(dt), '2008-06-27 12:00:00.000125') + dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) + eq_(str(dt), "2008-06-27 12:00:00.000125") sldt = sqlite.DATETIME() bp = sldt.bind_processor(None) - eq_(bp(dt), '2008-06-27 12:00:00.000125') + eq_(bp(dt), "2008-06-27 12:00:00.000125") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) def test_truncate_microseconds(self): dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) dt_out = datetime.datetime(2008, 6, 27, 12, 0, 0) - eq_(str(dt), '2008-06-27 12:00:00.000125') + eq_(str(dt), "2008-06-27 12:00:00.000125") sldt = sqlite.DATETIME(truncate_microseconds=True) bp = sldt.bind_processor(None) - eq_(bp(dt), '2008-06-27 12:00:00') + eq_(bp(dt), "2008-06-27 12:00:00") rp = 
sldt.result_processor(None, None) eq_(rp(bp(dt)), dt_out) def test_custom_format_compact(self): dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) - eq_(str(dt), '2008-06-27 12:00:00.000125') + eq_(str(dt), "2008-06-27 12:00:00.000125") sldt = sqlite.DATETIME( storage_format=( "%(year)04d%(month)02d%(day)02d" @@ -314,141 +344,150 @@ class DateTimeTest(fixtures.TestBase, AssertsCompiledSQL): regexp=r"(\d{4})(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})(\d{6})", ) bp = sldt.bind_processor(None) - eq_(bp(dt), '20080627120000000125') + eq_(bp(dt), "20080627120000000125") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) class DateTest(fixtures.TestBase, AssertsCompiledSQL): - def test_default(self): dt = datetime.date(2008, 6, 27) - eq_(str(dt), '2008-06-27') + eq_(str(dt), "2008-06-27") sldt = sqlite.DATE() bp = sldt.bind_processor(None) - eq_(bp(dt), '2008-06-27') + eq_(bp(dt), "2008-06-27") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) def test_custom_format(self): dt = datetime.date(2008, 6, 27) - eq_(str(dt), '2008-06-27') + eq_(str(dt), "2008-06-27") sldt = sqlite.DATE( storage_format="%(month)02d/%(day)02d/%(year)04d", regexp=r"(?P\d+)/(?P\d+)/(?P\d+)", ) bp = sldt.bind_processor(None) - eq_(bp(dt), '06/27/2008') + eq_(bp(dt), "06/27/2008") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) class TimeTest(fixtures.TestBase, AssertsCompiledSQL): - def test_default(self): dt = datetime.date(2008, 6, 27) - eq_(str(dt), '2008-06-27') + eq_(str(dt), "2008-06-27") sldt = sqlite.DATE() bp = sldt.bind_processor(None) - eq_(bp(dt), '2008-06-27') + eq_(bp(dt), "2008-06-27") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) def test_truncate_microseconds(self): dt = datetime.time(12, 0, 0, 125) dt_out = datetime.time(12, 0, 0) - eq_(str(dt), '12:00:00.000125') + eq_(str(dt), "12:00:00.000125") sldt = sqlite.TIME(truncate_microseconds=True) bp = sldt.bind_processor(None) - eq_(bp(dt), '12:00:00') + eq_(bp(dt), "12:00:00") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt_out) def test_custom_format(self): dt = datetime.date(2008, 6, 27) - eq_(str(dt), '2008-06-27') + eq_(str(dt), "2008-06-27") sldt = sqlite.DATE( storage_format="%(year)04d%(month)02d%(day)02d", regexp=r"(\d{4})(\d{2})(\d{2})", ) bp = sldt.bind_processor(None) - eq_(bp(dt), '20080627') + eq_(bp(dt), "20080627") rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" - @testing.exclude('sqlite', '<', (3, 3, 8), - 'sqlite3 changesets 3353 and 3440 modified ' - 'behavior of default displayed in pragma ' - 'table_info()') + @testing.exclude( + "sqlite", + "<", + (3, 3, 8), + "sqlite3 changesets 3353 and 3440 modified " + "behavior of default displayed in pragma " + "table_info()", + ) def test_default_reflection(self): # (ask_for, roundtripped_as_if_different) - specs = [(String(3), '"foo"'), (sqltypes.NUMERIC(10, 2), '100.50'), - (Integer, '5'), (Boolean, 'False')] - columns = [Column('c%i' % (i + 1), t[0], - server_default=text(t[1])) for (i, t) in - enumerate(specs)] + specs = [ + (String(3), '"foo"'), + (sqltypes.NUMERIC(10, 2), "100.50"), + (Integer, "5"), + (Boolean, "False"), + ] + columns = [ + Column("c%i" % (i + 1), t[0], server_default=text(t[1])) + for (i, t) in enumerate(specs) + ] db = testing.db m = MetaData(db) - Table('t_defaults', m, *columns) + Table("t_defaults", m, *columns) try: m.create_all() m2 = MetaData(db) - rt = Table('t_defaults', m2, autoload=True) + rt = 
Table("t_defaults", m2, autoload=True) expected = [c[1] for c in specs] for i, reflected in enumerate(rt.c): eq_(str(reflected.server_default.arg), expected[i]) finally: m.drop_all() - @testing.exclude('sqlite', '<', (3, 3, 8), - 'sqlite3 changesets 3353 and 3440 modified ' - 'behavior of default displayed in pragma ' - 'table_info()') + @testing.exclude( + "sqlite", + "<", + (3, 3, 8), + "sqlite3 changesets 3353 and 3440 modified " + "behavior of default displayed in pragma " + "table_info()", + ) def test_default_reflection_2(self): db = testing.db m = MetaData(db) - expected = ["'my_default'", '0'] - table = \ - """CREATE TABLE r_defaults ( + expected = ["'my_default'", "0"] + table = """CREATE TABLE r_defaults ( data VARCHAR(40) DEFAULT 'my_default', val INTEGER NOT NULL DEFAULT 0 )""" try: db.execute(table) - rt = Table('r_defaults', m, autoload=True) + rt = Table("r_defaults", m, autoload=True) for i, reflected in enumerate(rt.c): eq_(str(reflected.server_default.arg), expected[i]) finally: - db.execute('DROP TABLE r_defaults') + db.execute("DROP TABLE r_defaults") def test_default_reflection_3(self): db = testing.db - table = \ - """CREATE TABLE r_defaults ( + table = """CREATE TABLE r_defaults ( data VARCHAR(40) DEFAULT 'my_default', val INTEGER NOT NULL DEFAULT 0 )""" try: db.execute(table) m1 = MetaData(db) - t1 = Table('r_defaults', m1, autoload=True) + t1 = Table("r_defaults", m1, autoload=True) db.execute("DROP TABLE r_defaults") t1.create() m2 = MetaData(db) - t2 = Table('r_defaults', m2, autoload=True) + t2 = Table("r_defaults", m2, autoload=True) self.assert_compile( CreateTable(t2), "CREATE TABLE r_defaults (data VARCHAR(40) " "DEFAULT 'my_default', val INTEGER DEFAULT 0 " - "NOT NULL)" + "NOT NULL)", ) finally: db.execute("DROP TABLE r_defaults") @@ -456,14 +495,16 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): @testing.provide_metadata def test_boolean_default(self): t = Table( - "t", self.metadata, - Column("x", Boolean, server_default=sql.false())) + "t", + self.metadata, + Column("x", Boolean, server_default=sql.false()), + ) t.create(testing.db) testing.db.execute(t.insert()) testing.db.execute(t.insert().values(x=True)) eq_( testing.db.execute(t.select().order_by(t.c.x)).fetchall(), - [(False,), (True,)] + [(False,), (True,)], ) def test_old_style_default(self): @@ -471,12 +512,12 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): dialect = sqlite.dialect() info = dialect._get_column_info("foo", "INTEGER", False, 3, False) - eq_(info['default'], '3') + eq_(info["default"], "3") class DialectTest(fixtures.TestBase, AssertsExecutionResults): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def test_extra_reserved_words(self): """Tests reserved words in identifiers. 
@@ -488,12 +529,12 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): meta = MetaData(testing.db) t = Table( - 'reserved', + "reserved", meta, - Column('safe', Integer), - Column('true', Integer), - Column('false', Integer), - Column('column', Integer), + Column("safe", Integer), + Column("true", Integer), + Column("false", Integer), + Column("column", Integer), ) try: meta.create_all() @@ -507,12 +548,15 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): """Tests autoload of tables created with quoted column names.""" metadata = self.metadata - testing.db.execute("""CREATE TABLE "django_content_type" ( + testing.db.execute( + """CREATE TABLE "django_content_type" ( "id" integer NOT NULL PRIMARY KEY, "django_stuff" text NULL ) - """) - testing.db.execute(""" + """ + ) + testing.db.execute( + """ CREATE TABLE "django_admin_log" ( "id" integer NOT NULL PRIMARY KEY, "action_time" datetime NOT NULL, @@ -521,12 +565,12 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): "object_id" text NULL, "change_message" text NOT NULL ) - """) - table1 = Table('django_admin_log', metadata, autoload=True) - table2 = Table('django_content_type', metadata, autoload=True) + """ + ) + table1 = Table("django_admin_log", metadata, autoload=True) + table2 = Table("django_content_type", metadata, autoload=True) j = table1.join(table2) - assert j.onclause.compare( - table1.c.content_type_id == table2.c.id) + assert j.onclause.compare(table1.c.content_type_id == table2.c.id) @testing.provide_metadata def test_quoted_identifiers_functional_two(self): @@ -541,10 +585,12 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): """ metadata = self.metadata - testing.db.execute(r'''CREATE TABLE """a""" ( + testing.db.execute( + r'''CREATE TABLE """a""" ( """id""" integer NOT NULL PRIMARY KEY ) - ''') + ''' + ) # unfortunately, still can't do this; sqlite quadruples # up the quotes on the table name here for pragma foreign_key_list @@ -570,45 +616,48 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): # as encoded bytes in py2k t = Table( - 'x', self.metadata, - Column(u('méil'), Integer, primary_key=True), - Column(ue('\u6e2c\u8a66'), Integer), + "x", + self.metadata, + Column(u("méil"), Integer, primary_key=True), + Column(ue("\u6e2c\u8a66"), Integer), ) self.metadata.create_all(testing.db) result = testing.db.execute(t.select()) - assert u('méil') in result.keys() - assert ue('\u6e2c\u8a66') in result.keys() + assert u("méil") in result.keys() + assert ue("\u6e2c\u8a66") in result.keys() def test_file_path_is_absolute(self): d = pysqlite_dialect.dialect() eq_( - d.create_connect_args(make_url('sqlite:///foo.db')), - ([os.path.abspath('foo.db')], {}) + d.create_connect_args(make_url("sqlite:///foo.db")), + ([os.path.abspath("foo.db")], {}), ) def test_pool_class(self): - e = create_engine('sqlite+pysqlite://') + e = create_engine("sqlite+pysqlite://") assert e.pool.__class__ is pool.SingletonThreadPool - e = create_engine('sqlite+pysqlite:///:memory:') + e = create_engine("sqlite+pysqlite:///:memory:") assert e.pool.__class__ is pool.SingletonThreadPool - e = create_engine('sqlite+pysqlite:///foo.db') + e = create_engine("sqlite+pysqlite:///foo.db") assert e.pool.__class__ is pool.NullPool class AttachedDBTest(fixtures.TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def _fixture(self): meta = self.metadata self.conn = testing.db.connect() ct = Table( - 'created', meta, - Column('id', Integer), - Column('name', String), - schema='test_schema') 
+ "created", + meta, + Column("id", Integer), + Column("name", String), + schema="test_schema", + ) meta.create_all(self.conn) return ct @@ -650,53 +699,47 @@ class AttachedDBTest(fixtures.TestBase): def test_reflect_system_table(self): meta = MetaData(self.conn) alt_master = Table( - 'sqlite_master', meta, autoload=True, + "sqlite_master", + meta, + autoload=True, autoload_with=self.conn, - schema='test_schema') + schema="test_schema", + ) assert len(alt_master.c) > 0 def test_reflect_user_table(self): self._fixture() m2 = MetaData() - c2 = Table('created', m2, autoload=True, autoload_with=self.conn) + c2 = Table("created", m2, autoload=True, autoload_with=self.conn) eq_(len(c2.c), 2) def test_crud(self): ct = self._fixture() - self.conn.execute(ct.insert(), {'id': 1, 'name': 'foo'}) - eq_( - self.conn.execute(ct.select()).fetchall(), - [(1, 'foo')] - ) + self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) + eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")]) - self.conn.execute(ct.update(), {'id': 2, 'name': 'bar'}) - eq_( - self.conn.execute(ct.select()).fetchall(), - [(2, 'bar')] - ) + self.conn.execute(ct.update(), {"id": 2, "name": "bar"}) + eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")]) self.conn.execute(ct.delete()) - eq_( - self.conn.execute(ct.select()).fetchall(), - [] - ) + eq_(self.conn.execute(ct.select()).fetchall(), []) def test_col_targeting(self): ct = self._fixture() - self.conn.execute(ct.insert(), {'id': 1, 'name': 'foo'}) + self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) row = self.conn.execute(ct.select()).first() - eq_(row['id'], 1) - eq_(row['name'], 'foo') + eq_(row["id"], 1) + eq_(row["name"], "foo") def test_col_targeting_union(self): ct = self._fixture() - self.conn.execute(ct.insert(), {'id': 1, 'name': 'foo'}) + self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) row = self.conn.execute(ct.select().union(ct.select())).first() - eq_(row['id'], 1) - eq_(row['name'], 'foo') + eq_(row["id"], 1) + eq_(row["name"], "foo") class SQLTest(fixtures.TestBase, AssertsCompiledSQL): @@ -706,72 +749,71 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = sqlite.dialect() def test_extract(self): - t = sql.table('t', sql.column('col1')) + t = sql.table("t", sql.column("col1")) mapping = { - 'month': '%m', - 'day': '%d', - 'year': '%Y', - 'second': '%S', - 'hour': '%H', - 'doy': '%j', - 'minute': '%M', - 'epoch': '%s', - 'dow': '%w', - 'week': '%W', + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", } for field, subst in mapping.items(): - self.assert_compile(select([extract(field, t.c.col1)]), - "SELECT CAST(STRFTIME('%s', t.col1) AS " - "INTEGER) AS anon_1 FROM t" % subst) + self.assert_compile( + select([extract(field, t.c.col1)]), + "SELECT CAST(STRFTIME('%s', t.col1) AS " + "INTEGER) AS anon_1 FROM t" % subst, + ) def test_true_false(self): - self.assert_compile( - sql.false(), "0" - ) - self.assert_compile( - sql.true(), - "1" - ) + self.assert_compile(sql.false(), "0") + self.assert_compile(sql.true(), "1") def test_is_distinct_from(self): self.assert_compile( - sql.column('x').is_distinct_from(None), - "x IS NOT NULL" + sql.column("x").is_distinct_from(None), "x IS NOT NULL" ) self.assert_compile( - sql.column('x').isnot_distinct_from(False), - "x IS 0" + sql.column("x").isnot_distinct_from(False), "x IS 0" ) def test_localtime(self): self.assert_compile( - func.localtimestamp(), - 
'DATETIME(CURRENT_TIMESTAMP, "localtime")' + func.localtimestamp(), 'DATETIME(CURRENT_TIMESTAMP, "localtime")' ) def test_constraints_with_schemas(self): metadata = MetaData() Table( - 't1', metadata, - Column('id', Integer, primary_key=True), - schema='master') + "t1", + metadata, + Column("id", Integer, primary_key=True), + schema="master", + ) t2 = Table( - 't2', metadata, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('master.t1.id')), - schema='master' + "t2", + metadata, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("master.t1.id")), + schema="master", ) t3 = Table( - 't3', metadata, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('master.t1.id')), - schema='alternate' + "t3", + metadata, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("master.t1.id")), + schema="alternate", ) t4 = Table( - 't4', metadata, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('master.t1.id')), + "t4", + metadata, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("master.t1.id")), ) # schema->schema, generate REFERENCES with no schema name @@ -782,7 +824,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): "t1_id INTEGER, " "PRIMARY KEY (id), " "FOREIGN KEY(t1_id) REFERENCES t1 (id)" - ")" + ")", ) # schema->different schema, don't generate REFERENCES @@ -792,7 +834,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): "id INTEGER NOT NULL, " "t1_id INTEGER, " "PRIMARY KEY (id)" - ")" + ")", ) # same for local schema @@ -802,179 +844,234 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): "id INTEGER NOT NULL, " "t1_id INTEGER, " "PRIMARY KEY (id)" - ")" + ")", ) def test_create_partial_index(self): m = MetaData() - tbl = Table('testtbl', m, Column('data', Integer)) - idx = Index('test_idx1', tbl.c.data, - sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + tbl = Table("testtbl", m, Column("data", Integer)) + idx = Index( + "test_idx1", + tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10), + ) # test quoting and all that - idx2 = Index('test_idx2', tbl.c.data, - sqlite_where=and_(tbl.c.data > 'a', tbl.c.data < "b's")) - self.assert_compile(schema.CreateIndex(idx), - 'CREATE INDEX test_idx1 ON testtbl (data) ' - 'WHERE data > 5 AND data < 10', - dialect=sqlite.dialect()) - self.assert_compile(schema.CreateIndex(idx2), - "CREATE INDEX test_idx2 ON testtbl (data) " - "WHERE data > 'a' AND data < 'b''s'", - dialect=sqlite.dialect()) + idx2 = Index( + "test_idx2", + tbl.c.data, + sqlite_where=and_(tbl.c.data > "a", tbl.c.data < "b's"), + ) + self.assert_compile( + schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) " + "WHERE data > 5 AND data < 10", + dialect=sqlite.dialect(), + ) + self.assert_compile( + schema.CreateIndex(idx2), + "CREATE INDEX test_idx2 ON testtbl (data) " + "WHERE data > 'a' AND data < 'b''s'", + dialect=sqlite.dialect(), + ) def test_no_autoinc_on_composite_pk(self): m = MetaData() t = Table( - 't', m, - Column('x', Integer, primary_key=True, autoincrement=True), - Column('y', Integer, primary_key=True)) + "t", + m, + Column("x", Integer, primary_key=True, autoincrement=True), + Column("y", Integer, primary_key=True), + ) assert_raises_message( exc.CompileError, "SQLite does not support autoincrement for composite", - CreateTable(t).compile, dialect=sqlite.dialect() + CreateTable(t).compile, + dialect=sqlite.dialect(), ) + class 
OnConflictDDLTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = sqlite.dialect() def test_on_conflict_clause_column_not_null(self): - c = Column('test', Integer, nullable=False, - sqlite_on_conflict_not_null='FAIL') + c = Column( + "test", Integer, nullable=False, sqlite_on_conflict_not_null="FAIL" + ) - self.assert_compile(schema.CreateColumn(c), - 'test INTEGER NOT NULL ' - 'ON CONFLICT FAIL', dialect=sqlite.dialect()) + self.assert_compile( + schema.CreateColumn(c), + "test INTEGER NOT NULL " "ON CONFLICT FAIL", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_column_many_clause(self): meta = MetaData() t = Table( - 'n', meta, - Column('test', Integer, nullable=False, primary_key=True, - sqlite_on_conflict_not_null='FAIL', - sqlite_on_conflict_primary_key='IGNORE') + "n", + meta, + Column( + "test", + Integer, + nullable=False, + primary_key=True, + sqlite_on_conflict_not_null="FAIL", + sqlite_on_conflict_primary_key="IGNORE", + ), ) - self.assert_compile(CreateTable(t), - "CREATE TABLE n (" - "test INTEGER NOT NULL ON CONFLICT FAIL, " - "PRIMARY KEY (test) ON CONFLICT IGNORE)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (" + "test INTEGER NOT NULL ON CONFLICT FAIL, " + "PRIMARY KEY (test) ON CONFLICT IGNORE)", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_unique_constraint_from_column(self): meta = MetaData() t = Table( - 'n', meta, - Column('x', String(30), unique=True, - sqlite_on_conflict_unique='FAIL'), + "n", + meta, + Column( + "x", String(30), unique=True, sqlite_on_conflict_unique="FAIL" + ), ) - self.assert_compile(CreateTable(t), - "CREATE TABLE n (x VARCHAR(30), " - "UNIQUE (x) ON CONFLICT FAIL)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (x VARCHAR(30), " "UNIQUE (x) ON CONFLICT FAIL)", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_unique_constraint(self): meta = MetaData() t = Table( - 'n', meta, - Column('id', Integer), - Column('x', String(30)), - UniqueConstraint('id', 'x', sqlite_on_conflict='FAIL'), + "n", + meta, + Column("id", Integer), + Column("x", String(30)), + UniqueConstraint("id", "x", sqlite_on_conflict="FAIL"), ) - self.assert_compile(CreateTable(t), - "CREATE TABLE n (id INTEGER, x VARCHAR(30), " - "UNIQUE (id, x) ON CONFLICT FAIL)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (id INTEGER, x VARCHAR(30), " + "UNIQUE (id, x) ON CONFLICT FAIL)", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_primary_key(self): meta = MetaData() t = Table( - 'n', meta, - Column('id', Integer, primary_key=True, - sqlite_on_conflict_primary_key='FAIL'), - sqlite_autoincrement=True + "n", + meta, + Column( + "id", + Integer, + primary_key=True, + sqlite_on_conflict_primary_key="FAIL", + ), + sqlite_autoincrement=True, ) - self.assert_compile(CreateTable(t), - "CREATE TABLE n (id INTEGER NOT NULL " - "PRIMARY KEY ON CONFLICT FAIL AUTOINCREMENT)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (id INTEGER NOT NULL " + "PRIMARY KEY ON CONFLICT FAIL AUTOINCREMENT)", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_primary_key_constraint_from_column(self): meta = MetaData() t = Table( - 'n', meta, - Column('x', String(30), sqlite_on_conflict_primary_key='FAIL', - primary_key=True), + "n", + meta, + Column( + "x", + String(30), + sqlite_on_conflict_primary_key="FAIL", + primary_key=True, + ), ) - self.assert_compile(CreateTable(t), - "CREATE 
TABLE n (x VARCHAR(30) NOT NULL, " - "PRIMARY KEY (x) ON CONFLICT FAIL)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (x VARCHAR(30) NOT NULL, " + "PRIMARY KEY (x) ON CONFLICT FAIL)", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_check_constraint(self): meta = MetaData() t = Table( - 'n', meta, - Column('id', Integer), - Column('x', Integer), - CheckConstraint('id > x', sqlite_on_conflict='FAIL'), + "n", + meta, + Column("id", Integer), + Column("x", Integer), + CheckConstraint("id > x", sqlite_on_conflict="FAIL"), ) - self.assert_compile(CreateTable(t), - "CREATE TABLE n (id INTEGER, x INTEGER, " - "CHECK (id > x) ON CONFLICT FAIL)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (id INTEGER, x INTEGER, " + "CHECK (id > x) ON CONFLICT FAIL)", + dialect=sqlite.dialect(), + ) def test_on_conflict_clause_check_constraint_from_column(self): meta = MetaData() t = Table( - 'n', meta, - Column('x', Integer, - CheckConstraint('x > 1', - sqlite_on_conflict='FAIL')), + "n", + meta, + Column( + "x", + Integer, + CheckConstraint("x > 1", sqlite_on_conflict="FAIL"), + ), ) assert_raises_message( exc.CompileError, "SQLite does not support on conflict " "clause for column check constraint", - CreateTable(t).compile, dialect=sqlite.dialect() + CreateTable(t).compile, + dialect=sqlite.dialect(), ) def test_on_conflict_clause_primary_key_constraint(self): meta = MetaData() t = Table( - 'n', meta, - Column('id', Integer), - Column('x', String(30)), - PrimaryKeyConstraint('id', 'x', sqlite_on_conflict='FAIL'), + "n", + meta, + Column("id", Integer), + Column("x", String(30)), + PrimaryKeyConstraint("id", "x", sqlite_on_conflict="FAIL"), ) - self.assert_compile(CreateTable(t), - "CREATE TABLE n (" - "id INTEGER NOT NULL, " - "x VARCHAR(30) NOT NULL, " - "PRIMARY KEY (id, x) ON CONFLICT FAIL)", - dialect=sqlite.dialect()) + self.assert_compile( + CreateTable(t), + "CREATE TABLE n (" + "id INTEGER NOT NULL, " + "x VARCHAR(30) NOT NULL, " + "PRIMARY KEY (id, x) ON CONFLICT FAIL)", + dialect=sqlite.dialect(), + ) class InsertTest(fixtures.TestBase, AssertsExecutionResults): """Tests inserts and autoincrement.""" - __only_on__ = 'sqlite' + __only_on__ = "sqlite" # empty insert (i.e. 
INSERT INTO table DEFAULT VALUES) fails on # 3.3.7 and before @@ -989,14 +1086,17 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): finally: table.drop() - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_pk1(self): self._test_empty_insert( Table( - 'a', MetaData(testing.db), - Column('id', Integer, primary_key=True))) + "a", + MetaData(testing.db), + Column("id", Integer, primary_key=True), + ) + ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_pk2(self): # now warns due to [ticket:3216] @@ -1007,24 +1107,40 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): "primary key for table 'b'", ): assert_raises( - exc.IntegrityError, self._test_empty_insert, + exc.IntegrityError, + self._test_empty_insert, Table( - 'b', MetaData(testing.db), - Column('x', Integer, primary_key=True), - Column('y', Integer, primary_key=True))) + "b", + MetaData(testing.db), + Column("x", Integer, primary_key=True), + Column("y", Integer, primary_key=True), + ), + ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_pk2_fv(self): assert_raises( - exc.DBAPIError, self._test_empty_insert, + exc.DBAPIError, + self._test_empty_insert, Table( - 'b', MetaData(testing.db), - Column('x', Integer, primary_key=True, - server_default=FetchedValue()), - Column('y', Integer, primary_key=True, - server_default=FetchedValue()))) + "b", + MetaData(testing.db), + Column( + "x", + Integer, + primary_key=True, + server_default=FetchedValue(), + ), + Column( + "y", + Integer, + primary_key=True, + server_default=FetchedValue(), + ), + ), + ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_pk3(self): # now warns due to [ticket:3216] with expect_warnings( @@ -1034,53 +1150,74 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): exc.IntegrityError, self._test_empty_insert, Table( - 'c', MetaData(testing.db), - Column('x', Integer, primary_key=True), - Column('y', Integer, - DefaultClause('123'), primary_key=True)) + "c", + MetaData(testing.db), + Column("x", Integer, primary_key=True), + Column( + "y", Integer, DefaultClause("123"), primary_key=True + ), + ), ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_pk3_fv(self): assert_raises( - exc.DBAPIError, self._test_empty_insert, + exc.DBAPIError, + self._test_empty_insert, Table( - 'c', MetaData(testing.db), - Column('x', Integer, primary_key=True, - server_default=FetchedValue()), - Column('y', Integer, DefaultClause('123'), primary_key=True))) + "c", + MetaData(testing.db), + Column( + "x", + Integer, + primary_key=True, + server_default=FetchedValue(), + ), + Column("y", Integer, DefaultClause("123"), primary_key=True), + ), + ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_pk4(self): self._test_empty_insert( Table( - 'd', MetaData(testing.db), - Column('x', Integer, primary_key=True), - Column('y', Integer, DefaultClause('123')) - )) + "d", + MetaData(testing.db), + Column("x", Integer, 
primary_key=True), + Column("y", Integer, DefaultClause("123")), + ) + ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_nopk1(self): - self._test_empty_insert(Table('e', MetaData(testing.db), - Column('id', Integer))) + self._test_empty_insert( + Table("e", MetaData(testing.db), Column("id", Integer)) + ) - @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') + @testing.exclude("sqlite", "<", (3, 3, 8), "no database support") def test_empty_insert_nopk2(self): self._test_empty_insert( Table( - 'f', MetaData(testing.db), - Column('x', Integer), Column('y', Integer))) + "f", + MetaData(testing.db), + Column("x", Integer), + Column("y", Integer), + ) + ) def test_inserts_with_spaces(self): - tbl = Table('tbl', MetaData('sqlite:///'), Column('with space', - Integer), Column('without', Integer)) + tbl = Table( + "tbl", + MetaData("sqlite:///"), + Column("with space", Integer), + Column("without", Integer), + ) tbl.create() try: - tbl.insert().execute({'without': 123}) + tbl.insert().execute({"without": 123}) assert list(tbl.select().execute()) == [(None, 123)] - tbl.insert().execute({'with space': 456}) - assert list(tbl.select().execute()) == [ - (None, 123), (456, None)] + tbl.insert().execute({"with space": 456}) + assert list(tbl.select().execute()) == [(None, 123), (456, None)] finally: tbl.drop() @@ -1090,8 +1227,8 @@ def full_text_search_missing(): it is and True otherwise.""" try: - testing.db.execute('CREATE VIRTUAL TABLE t using FTS3;') - testing.db.execute('DROP TABLE t;') + testing.db.execute("CREATE VIRTUAL TABLE t using FTS3;") + testing.db.execute("DROP TABLE t;") return False except Exception: return True @@ -1102,152 +1239,203 @@ metadata = cattable = matchtable = None class MatchTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = 'sqlite' - __skip_if__ = full_text_search_missing, + __only_on__ = "sqlite" + __skip_if__ = (full_text_search_missing,) @classmethod def setup_class(cls): global metadata, cattable, matchtable metadata = MetaData(testing.db) - testing.db.execute(""" + testing.db.execute( + """ CREATE VIRTUAL TABLE cattable using FTS3 ( id INTEGER NOT NULL, description VARCHAR(50), PRIMARY KEY (id) ) - """) - cattable = Table('cattable', metadata, autoload=True) - testing.db.execute(""" + """ + ) + cattable = Table("cattable", metadata, autoload=True) + testing.db.execute( + """ CREATE VIRTUAL TABLE matchtable using FTS3 ( id INTEGER NOT NULL, title VARCHAR(200), category_id INTEGER NOT NULL, PRIMARY KEY (id) ) - """) - matchtable = Table('matchtable', metadata, autoload=True) + """ + ) + matchtable = Table("matchtable", metadata, autoload=True) metadata.create_all() cattable.insert().execute( - [{'id': 1, 'description': 'Python'}, - {'id': 2, 'description': 'Ruby'}]) + [ + {"id": 1, "description": "Python"}, + {"id": 2, "description": "Ruby"}, + ] + ) matchtable.insert().execute( [ - {'id': 1, 'title': 'Agile Web Development with Rails', - 'category_id': 2}, - {'id': 2, 'title': 'Dive Into Python', 'category_id': 1}, - {'id': 3, 'title': "Programming Matz's Ruby", - 'category_id': 2}, - {'id': 4, 'title': 'The Definitive Guide to Django', - 'category_id': 1}, - {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1} - ]) + { + "id": 1, + "title": "Agile Web Development with Rails", + "category_id": 2, + }, + {"id": 2, "title": "Dive Into Python", "category_id": 1}, + { + "id": 3, + "title": "Programming Matz's Ruby", + 
"category_id": 2, + }, + { + "id": 4, + "title": "The Definitive Guide to Django", + "category_id": 1, + }, + {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, + ] + ) @classmethod def teardown_class(cls): metadata.drop_all() def test_expression(self): - self.assert_compile(matchtable.c.title.match('somstr'), - 'matchtable.title MATCH ?', - dialect=sqlite.dialect()) + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title MATCH ?", + dialect=sqlite.dialect(), + ) def test_simple_match(self): - results = \ - matchtable.select().where( - matchtable.c.title.match('python')).\ - order_by(matchtable.c.id).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("python")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([2, 5], [r.id for r in results]) def test_simple_prefix_match(self): - results = \ - matchtable.select().where( - matchtable.c.title.match('nut*')).execute().fetchall() + results = ( + matchtable.select() + .where(matchtable.c.title.match("nut*")) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results]) def test_or_match(self): - results2 = \ - matchtable.select().where( - matchtable.c.title.match('nutshell OR ruby')).\ - order_by(matchtable.c.id).execute().fetchall() + results2 = ( + matchtable.select() + .where(matchtable.c.title.match("nutshell OR ruby")) + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([3, 5], [r.id for r in results2]) def test_and_match(self): - results2 = \ - matchtable.select().where( - matchtable.c.title.match('python nutshell') - ).execute().fetchall() + results2 = ( + matchtable.select() + .where(matchtable.c.title.match("python nutshell")) + .execute() + .fetchall() + ) eq_([5], [r.id for r in results2]) def test_match_across_joins(self): - results = matchtable.select().where( - and_( - cattable.c.id == matchtable.c.category_id, - cattable.c.description.match('Ruby') + results = ( + matchtable.select() + .where( + and_( + cattable.c.id == matchtable.c.category_id, + cattable.c.description.match("Ruby"), + ) ) - ).order_by(matchtable.c.id).execute().fetchall() + .order_by(matchtable.c.id) + .execute() + .fetchall() + ) eq_([1, 3], [r.id for r in results]) class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL): - def test_sqlite_autoincrement(self): - table = Table('autoinctable', MetaData(), Column('id', Integer, - primary_key=True), Column('x', Integer, - default=None), sqlite_autoincrement=True) + table = Table( + "autoinctable", + MetaData(), + Column("id", Integer, primary_key=True), + Column("x", Integer, default=None), + sqlite_autoincrement=True, + ) self.assert_compile( schema.CreateTable(table), - 'CREATE TABLE autoinctable (id INTEGER NOT ' - 'NULL PRIMARY KEY AUTOINCREMENT, x INTEGER)', - dialect=sqlite.dialect()) + "CREATE TABLE autoinctable (id INTEGER NOT " + "NULL PRIMARY KEY AUTOINCREMENT, x INTEGER)", + dialect=sqlite.dialect(), + ) def test_sqlite_autoincrement_constraint(self): table = Table( - 'autoinctable', + "autoinctable", MetaData(), - Column('id', Integer, primary_key=True), - Column('x', Integer, default=None), - UniqueConstraint('x'), + Column("id", Integer, primary_key=True), + Column("x", Integer, default=None), + UniqueConstraint("x"), sqlite_autoincrement=True, ) - self.assert_compile(schema.CreateTable(table), - 'CREATE TABLE autoinctable (id INTEGER NOT ' - 'NULL PRIMARY KEY AUTOINCREMENT, x ' - 'INTEGER, UNIQUE (x))', - dialect=sqlite.dialect()) + self.assert_compile( + schema.CreateTable(table), + 
"CREATE TABLE autoinctable (id INTEGER NOT " + "NULL PRIMARY KEY AUTOINCREMENT, x " + "INTEGER, UNIQUE (x))", + dialect=sqlite.dialect(), + ) def test_sqlite_no_autoincrement(self): - table = Table('noautoinctable', MetaData(), Column('id', - Integer, primary_key=True), Column('x', Integer, - default=None)) - self.assert_compile(schema.CreateTable(table), - 'CREATE TABLE noautoinctable (id INTEGER ' - 'NOT NULL, x INTEGER, PRIMARY KEY (id))', - dialect=sqlite.dialect()) + table = Table( + "noautoinctable", + MetaData(), + Column("id", Integer, primary_key=True), + Column("x", Integer, default=None), + ) + self.assert_compile( + schema.CreateTable(table), + "CREATE TABLE noautoinctable (id INTEGER " + "NOT NULL, x INTEGER, PRIMARY KEY (id))", + dialect=sqlite.dialect(), + ) def test_sqlite_autoincrement_int_affinity(self): class MyInteger(sqltypes.TypeDecorator): impl = Integer + table = Table( - 'autoinctable', + "autoinctable", MetaData(), - Column('id', MyInteger, primary_key=True), + Column("id", MyInteger, primary_key=True), sqlite_autoincrement=True, ) - self.assert_compile(schema.CreateTable(table), - 'CREATE TABLE autoinctable (id INTEGER NOT ' - 'NULL PRIMARY KEY AUTOINCREMENT)', - dialect=sqlite.dialect()) + self.assert_compile( + schema.CreateTable(table), + "CREATE TABLE autoinctable (id INTEGER NOT " + "NULL PRIMARY KEY AUTOINCREMENT)", + dialect=sqlite.dialect(), + ) class ReflectHeadlessFKsTest(fixtures.TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def setup(self): testing.db.execute("CREATE TABLE a (id INTEGER PRIMARY KEY)") # this syntax actually works on other DBs perhaps we'd want to add # tests to test_reflection testing.db.execute( - "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)") + "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)" + ) def teardown(self): testing.db.execute("drop table b") @@ -1255,19 +1443,19 @@ class ReflectHeadlessFKsTest(fixtures.TestBase): def test_reflect_tables_fk_no_colref(self): meta = MetaData() - a = Table('a', meta, autoload=True, autoload_with=testing.db) - b = Table('b', meta, autoload=True, autoload_with=testing.db) + a = Table("a", meta, autoload=True, autoload_with=testing.db) + b = Table("b", meta, autoload=True, autoload_with=testing.db) assert b.c.id.references(a.c.id) class KeywordInDatabaseNameTest(fixtures.TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" @classmethod def setup_class(cls): with testing.db.begin() as conn: - conn.execute("ATTACH %r AS \"default\"" % conn.engine.url.database) + conn.execute('ATTACH %r AS "default"' % conn.engine.url.database) conn.execute('CREATE TABLE "default".a (id INTEGER PRIMARY KEY)') @classmethod @@ -1281,13 +1469,13 @@ class KeywordInDatabaseNameTest(fixtures.TestBase): def test_reflect(self): with testing.db.begin() as conn: - meta = MetaData(bind=conn, schema='default') + meta = MetaData(bind=conn, schema="default") meta.reflect() - assert 'default.a' in meta.tables + assert "default.a" in meta.tables class ConstraintReflectionTest(fixtures.TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" @classmethod def setup_class(cls): @@ -1299,37 +1487,46 @@ class ConstraintReflectionTest(fixtures.TestBase): "CREATE TABLE b (id INTEGER PRIMARY KEY, " "FOREIGN KEY(id) REFERENCES a1(id)," "FOREIGN KEY(id) REFERENCES a2(id)" - ")") + ")" + ) conn.execute( "CREATE TABLE c (id INTEGER, " "CONSTRAINT bar PRIMARY KEY(id)," "CONSTRAINT foo1 FOREIGN KEY(id) REFERENCES a1(id)," "CONSTRAINT foo2 FOREIGN KEY(id) REFERENCES a2(id)" - ")") + ")" + ) conn.execute( # 
the lower casing + inline is intentional here - "CREATE TABLE d (id INTEGER, x INTEGER unique)") + "CREATE TABLE d (id INTEGER, x INTEGER unique)" + ) conn.execute( # the lower casing + inline is intentional here - 'CREATE TABLE d1 ' - '(id INTEGER, "some ( STUPID n,ame" INTEGER unique)') + "CREATE TABLE d1 " + '(id INTEGER, "some ( STUPID n,ame" INTEGER unique)' + ) conn.execute( # the lower casing + inline is intentional here - 'CREATE TABLE d2 ( "some STUPID n,ame" INTEGER unique)') + 'CREATE TABLE d2 ( "some STUPID n,ame" INTEGER unique)' + ) conn.execute( # the lower casing + inline is intentional here - 'CREATE TABLE d3 ( "some STUPID n,ame" INTEGER NULL unique)') + 'CREATE TABLE d3 ( "some STUPID n,ame" INTEGER NULL unique)' + ) conn.execute( # lower casing + inline is intentional - "CREATE TABLE e (id INTEGER, x INTEGER references a2(id))") + "CREATE TABLE e (id INTEGER, x INTEGER references a2(id))" + ) conn.execute( 'CREATE TABLE e1 (id INTEGER, "some ( STUPID n,ame" INTEGER ' - 'references a2 ("some ( STUPID n,ame"))') + 'references a2 ("some ( STUPID n,ame"))' + ) conn.execute( - 'CREATE TABLE e2 (id INTEGER, ' + "CREATE TABLE e2 (id INTEGER, " '"some ( STUPID n,ame" INTEGER NOT NULL ' - 'references a2 ("some ( STUPID n,ame"))') + 'references a2 ("some ( STUPID n,ame"))' + ) conn.execute( "CREATE TABLE f (x INTEGER, CONSTRAINT foo_fx UNIQUE(x))" @@ -1357,36 +1554,33 @@ class ConstraintReflectionTest(fixtures.TestBase): ) meta = MetaData() - Table( - 'l', meta, Column('bar', String, index=True), - schema='main') + Table("l", meta, Column("bar", String, index=True), schema="main") Table( - 'm', meta, - Column('id', Integer, primary_key=True), - Column('x', String(30)), - UniqueConstraint('x') + "m", + meta, + Column("id", Integer, primary_key=True), + Column("x", String(30)), + UniqueConstraint("x"), ) Table( - 'n', meta, - Column('id', Integer, primary_key=True), - Column('x', String(30)), - UniqueConstraint('x'), - prefixes=['TEMPORARY'] + "n", + meta, + Column("id", Integer, primary_key=True), + Column("x", String(30)), + UniqueConstraint("x"), + prefixes=["TEMPORARY"], ) Table( - 'p', meta, - Column('id', Integer), - PrimaryKeyConstraint('id', name='pk_name'), + "p", + meta, + Column("id", Integer), + PrimaryKeyConstraint("id", name="pk_name"), ) - Table( - 'q', meta, - Column('id', Integer), - PrimaryKeyConstraint('id'), - ) + Table("q", meta, Column("id", Integer), PrimaryKeyConstraint("id")) meta.create_all(conn) @@ -1416,8 +1610,24 @@ class ConstraintReflectionTest(fixtures.TestBase): def teardown_class(cls): with testing.db.begin() as conn: for name in [ - "m", "main.l", "k", "j", "i", "h", "g", "f", "e", "e1", - "d", "d1", "d2", "c", "b", "a1", "a2"]: + "m", + "main.l", + "k", + "j", + "i", + "h", + "g", + "f", + "e", + "e1", + "d", + "d1", + "d2", + "c", + "b", + "a1", + "a2", + ]: try: conn.execute("drop table %s" % name) except Exception: @@ -1428,301 +1638,369 @@ class ConstraintReflectionTest(fixtures.TestBase): dialect._broken_fk_pragma_quotes = True for row in [ - (0, None, 'target', 'tid', 'id', None), - (0, None, '"target"', 'tid', 'id', None), - (0, None, '[target]', 'tid', 'id', None), - (0, None, "'target'", 'tid', 'id', None), - (0, None, '`target`', 'tid', 'id', None), + (0, None, "target", "tid", "id", None), + (0, None, '"target"', "tid", "id", None), + (0, None, "[target]", "tid", "id", None), + (0, None, "'target'", "tid", "id", None), + (0, None, "`target`", "tid", "id", None), ]: + def _get_table_pragma(*arg, **kw): return [row] def 
_get_table_sql(*arg, **kw): - return "CREATE TABLE foo "\ - "(tid INTEGER, "\ + return ( + "CREATE TABLE foo " + "(tid INTEGER, " "FOREIGN KEY(tid) REFERENCES %s (id))" % row[2] + ) + with mock.patch.object( - dialect, "_get_table_pragma", _get_table_pragma): + dialect, "_get_table_pragma", _get_table_pragma + ): with mock.patch.object( - dialect, '_get_table_sql', _get_table_sql): + dialect, "_get_table_sql", _get_table_sql + ): - fkeys = dialect.get_foreign_keys(None, 'foo') + fkeys = dialect.get_foreign_keys(None, "foo") eq_( fkeys, - [{ - 'referred_table': 'target', - 'referred_columns': ['id'], - 'referred_schema': None, - 'name': None, - 'constrained_columns': ['tid'], - 'options': {} - }]) + [ + { + "referred_table": "target", + "referred_columns": ["id"], + "referred_schema": None, + "name": None, + "constrained_columns": ["tid"], + "options": {}, + } + ], + ) def test_foreign_key_name_is_none(self): # and not "0" inspector = Inspector(testing.db) - fks = inspector.get_foreign_keys('b') + fks = inspector.get_foreign_keys("b") eq_( fks, [ - {'referred_table': 'a1', 'referred_columns': ['id'], - 'referred_schema': None, 'name': None, - 'constrained_columns': ['id'], - 'options': {}}, - {'referred_table': 'a2', 'referred_columns': ['id'], - 'referred_schema': None, 'name': None, - 'constrained_columns': ['id'], - 'options': {}}, - ] + { + "referred_table": "a1", + "referred_columns": ["id"], + "referred_schema": None, + "name": None, + "constrained_columns": ["id"], + "options": {}, + }, + { + "referred_table": "a2", + "referred_columns": ["id"], + "referred_schema": None, + "name": None, + "constrained_columns": ["id"], + "options": {}, + }, + ], ) def test_foreign_key_name_is_not_none(self): inspector = Inspector(testing.db) - fks = inspector.get_foreign_keys('c') + fks = inspector.get_foreign_keys("c") eq_( fks, [ { - 'referred_table': 'a1', 'referred_columns': ['id'], - 'referred_schema': None, 'name': 'foo1', - 'constrained_columns': ['id'], - 'options': {}}, + "referred_table": "a1", + "referred_columns": ["id"], + "referred_schema": None, + "name": "foo1", + "constrained_columns": ["id"], + "options": {}, + }, { - 'referred_table': 'a2', 'referred_columns': ['id'], - 'referred_schema': None, 'name': 'foo2', - 'constrained_columns': ['id'], - 'options': {}}, - ] + "referred_table": "a2", + "referred_columns": ["id"], + "referred_schema": None, + "name": "foo2", + "constrained_columns": ["id"], + "options": {}, + }, + ], ) def test_unnamed_inline_foreign_key(self): inspector = Inspector(testing.db) - fks = inspector.get_foreign_keys('e') + fks = inspector.get_foreign_keys("e") eq_( fks, - [{ - 'referred_table': 'a2', 'referred_columns': ['id'], - 'referred_schema': None, - 'name': None, 'constrained_columns': ['x'], - 'options': {} - }] + [ + { + "referred_table": "a2", + "referred_columns": ["id"], + "referred_schema": None, + "name": None, + "constrained_columns": ["x"], + "options": {}, + } + ], ) def test_unnamed_inline_foreign_key_quoted(self): inspector = Inspector(testing.db) - fks = inspector.get_foreign_keys('e1') + fks = inspector.get_foreign_keys("e1") eq_( fks, - [{ - 'referred_table': 'a2', - 'referred_columns': ['some ( STUPID n,ame'], - 'referred_schema': None, - 'options': {}, - 'name': None, 'constrained_columns': ['some ( STUPID n,ame'] - }] - ) - fks = inspector.get_foreign_keys('e2') + [ + { + "referred_table": "a2", + "referred_columns": ["some ( STUPID n,ame"], + "referred_schema": None, + "options": {}, + "name": None, + "constrained_columns": ["some ( 
STUPID n,ame"], + } + ], + ) + fks = inspector.get_foreign_keys("e2") eq_( fks, - [{ - 'referred_table': 'a2', - 'referred_columns': ['some ( STUPID n,ame'], - 'referred_schema': None, - 'options': {}, - 'name': None, 'constrained_columns': ['some ( STUPID n,ame'] - }] + [ + { + "referred_table": "a2", + "referred_columns": ["some ( STUPID n,ame"], + "referred_schema": None, + "options": {}, + "name": None, + "constrained_columns": ["some ( STUPID n,ame"], + } + ], ) def test_foreign_key_composite_broken_casing(self): inspector = Inspector(testing.db) - fks = inspector.get_foreign_keys('j') + fks = inspector.get_foreign_keys("j") eq_( fks, - [{ - 'referred_table': 'i', - 'referred_columns': ['x', 'y'], - 'referred_schema': None, 'name': None, - 'constrained_columns': ['q', 'p'], - 'options': {}}] - ) - fks = inspector.get_foreign_keys('k') + [ + { + "referred_table": "i", + "referred_columns": ["x", "y"], + "referred_schema": None, + "name": None, + "constrained_columns": ["q", "p"], + "options": {}, + } + ], + ) + fks = inspector.get_foreign_keys("k") eq_( fks, [ - {'referred_table': 'i', 'referred_columns': ['x', 'y'], - 'referred_schema': None, 'name': 'my_fk', - 'constrained_columns': ['q', 'p'], - 'options': {}}] + { + "referred_table": "i", + "referred_columns": ["x", "y"], + "referred_schema": None, + "name": "my_fk", + "constrained_columns": ["q", "p"], + "options": {}, + } + ], ) def test_foreign_key_ondelete_onupdate(self): inspector = Inspector(testing.db) - fks = inspector.get_foreign_keys('onud_test') + fks = inspector.get_foreign_keys("onud_test") eq_( fks, [ { - 'referred_table': 'a1', 'referred_columns': ['id'], - 'referred_schema': None, 'name': 'fk1', - 'constrained_columns': ['c1'], - 'options': {'ondelete': 'SET NULL'} + "referred_table": "a1", + "referred_columns": ["id"], + "referred_schema": None, + "name": "fk1", + "constrained_columns": ["c1"], + "options": {"ondelete": "SET NULL"}, }, { - 'referred_table': 'a1', 'referred_columns': ['id'], - 'referred_schema': None, 'name': 'fk2', - 'constrained_columns': ['c2'], - 'options': {'onupdate': 'CASCADE'} + "referred_table": "a1", + "referred_columns": ["id"], + "referred_schema": None, + "name": "fk2", + "constrained_columns": ["c2"], + "options": {"onupdate": "CASCADE"}, }, { - 'referred_table': 'a2', 'referred_columns': ['id'], - 'referred_schema': None, 'name': 'fk3', - 'constrained_columns': ['c3'], - 'options': {'ondelete': 'CASCADE', 'onupdate': 'SET NULL'} + "referred_table": "a2", + "referred_columns": ["id"], + "referred_schema": None, + "name": "fk3", + "constrained_columns": ["c3"], + "options": {"ondelete": "CASCADE", "onupdate": "SET NULL"}, }, { - 'referred_table': 'a2', 'referred_columns': ['id'], - 'referred_schema': None, 'name': 'fk4', - 'constrained_columns': ['c4'], - 'options': {'onupdate': 'NO ACTION'} + "referred_table": "a2", + "referred_columns": ["id"], + "referred_schema": None, + "name": "fk4", + "constrained_columns": ["c4"], + "options": {"onupdate": "NO ACTION"}, }, - ] + ], ) def test_foreign_key_options_unnamed_inline(self): with testing.db.connect() as conn: conn.execute( "create table foo (id integer, " - "foreign key (id) references bar (id) on update cascade)") + "foreign key (id) references bar (id) on update cascade)" + ) insp = inspect(conn) eq_( - insp.get_foreign_keys('foo'), - [{ - 'name': None, - 'referred_columns': ['id'], - 'referred_table': 'bar', - 'constrained_columns': ['id'], - 'referred_schema': None, - 'options': {'onupdate': 'CASCADE'}}] + 
insp.get_foreign_keys("foo"), + [ + { + "name": None, + "referred_columns": ["id"], + "referred_table": "bar", + "constrained_columns": ["id"], + "referred_schema": None, + "options": {"onupdate": "CASCADE"}, + } + ], ) def test_dont_reflect_autoindex(self): inspector = Inspector(testing.db) - eq_(inspector.get_indexes('o'), []) + eq_(inspector.get_indexes("o"), []) eq_( - inspector.get_indexes('o', include_auto_indexes=True), - [{ - 'unique': 1, - 'name': 'sqlite_autoindex_o_1', - 'column_names': ['foo']}]) + inspector.get_indexes("o", include_auto_indexes=True), + [ + { + "unique": 1, + "name": "sqlite_autoindex_o_1", + "column_names": ["foo"], + } + ], + ) def test_create_index_with_schema(self): """Test creation of index with explicit schema""" inspector = Inspector(testing.db) eq_( - inspector.get_indexes('l', schema='main'), - [{'unique': 0, 'name': u'ix_main_l_bar', - 'column_names': [u'bar']}]) + inspector.get_indexes("l", schema="main"), + [ + { + "unique": 0, + "name": u"ix_main_l_bar", + "column_names": [u"bar"], + } + ], + ) def test_unique_constraint_named(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("f"), - [{'column_names': ['x'], 'name': 'foo_fx'}] + [{"column_names": ["x"], "name": "foo_fx"}], ) def test_unique_constraint_named_broken_casing(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("h"), - [{'column_names': ['x'], 'name': 'foo_hx'}] + [{"column_names": ["x"], "name": "foo_hx"}], ) def test_unique_constraint_named_broken_temp(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("g"), - [{'column_names': ['x'], 'name': 'foo_gx'}] + [{"column_names": ["x"], "name": "foo_gx"}], ) def test_unique_constraint_unnamed_inline(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("d"), - [{'column_names': ['x'], 'name': None}] + [{"column_names": ["x"], "name": None}], ) def test_unique_constraint_unnamed_inline_quoted(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("d1"), - [{'column_names': ['some ( STUPID n,ame'], 'name': None}] + [{"column_names": ["some ( STUPID n,ame"], "name": None}], ) eq_( inspector.get_unique_constraints("d2"), - [{'column_names': ['some STUPID n,ame'], 'name': None}] + [{"column_names": ["some STUPID n,ame"], "name": None}], ) eq_( inspector.get_unique_constraints("d3"), - [{'column_names': ['some STUPID n,ame'], 'name': None}] + [{"column_names": ["some STUPID n,ame"], "name": None}], ) def test_unique_constraint_unnamed_normal(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("m"), - [{'column_names': ['x'], 'name': None}] + [{"column_names": ["x"], "name": None}], ) def test_unique_constraint_unnamed_normal_temporary(self): inspector = Inspector(testing.db) eq_( inspector.get_unique_constraints("n"), - [{'column_names': ['x'], 'name': None}] + [{"column_names": ["x"], "name": None}], ) def test_primary_key_constraint_named(self): inspector = Inspector(testing.db) eq_( inspector.get_pk_constraint("p"), - {'constrained_columns': ['id'], 'name': 'pk_name'} + {"constrained_columns": ["id"], "name": "pk_name"}, ) def test_primary_key_constraint_unnamed(self): inspector = Inspector(testing.db) eq_( inspector.get_pk_constraint("q"), - {'constrained_columns': ['id'], 'name': None} + {"constrained_columns": ["id"], "name": None}, ) def test_primary_key_constraint_no_pk(self): inspector = Inspector(testing.db) eq_( inspector.get_pk_constraint("d"), - {'constrained_columns': [], 
'name': None} + {"constrained_columns": [], "name": None}, ) def test_check_constraint(self): inspector = Inspector(testing.db) eq_( inspector.get_check_constraints("cp"), - [{'sqltext': 'q > 1 AND q < 6', 'name': None}, - {'sqltext': 'q == 1 OR (q > 2 AND q < 5)', 'name': 'cq'}] + [ + {"sqltext": "q > 1 AND q < 6", "name": None}, + {"sqltext": "q == 1 OR (q > 2 AND q < 5)", "name": "cq"}, + ], ) class SavepointTest(fixtures.TablesTest): """test that savepoints work when we use the correct event setup""" - __only_on__ = 'sqlite' + + __only_on__ = "sqlite" @classmethod def define_tables(cls, metadata): Table( - 'users', metadata, - Column('user_id', Integer, primary_key=True), - Column('user_name', String) + "users", + metadata, + Column("user_id", Integer, primary_key=True), + Column("user_name", String), ) @classmethod @@ -1746,53 +2024,62 @@ class SavepointTest(fixtures.TablesTest): users = self.tables.users connection = self.bind.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") trans2.rollback() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.commit() - eq_(connection.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (3, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (3,)], + ) connection.close() def test_nested_subtransaction_commit(self): users = self.tables.users connection = self.bind.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") trans2.commit() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.commit() - eq_(connection.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (2, ), (3, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (2,), (3,)], + ) connection.close() def test_rollback_to_subtransaction(self): users = self.tables.users connection = self.bind.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") connection.begin_nested() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") trans3 = connection.begin() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") trans3.rollback() - connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=4, user_name="user4") transaction.commit() - eq_(connection.execute(select([users.c.user_id]). 
- order_by(users.c.user_id)).fetchall(), - [(1, ), (4, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (4,)], + ) connection.close() class TypeReflectionTest(fixtures.TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def _fixed_lookup_fixture(self): return [ @@ -1823,22 +2110,29 @@ class TypeReflectionTest(fixtures.TestBase): (sqltypes.Time, sqltypes.TIME()), (sqltypes.BOOLEAN, sqltypes.BOOLEAN()), (sqltypes.Boolean, sqltypes.BOOLEAN()), - (sqlite.DATE( - storage_format="%(year)04d%(month)02d%(day)02d", - ), sqltypes.DATE()), - (sqlite.TIME( - storage_format="%(hour)02d%(minute)02d%(second)02d", - ), sqltypes.TIME()), - (sqlite.DATETIME( - storage_format="%(year)04d%(month)02d%(day)02d" - "%(hour)02d%(minute)02d%(second)02d", - ), sqltypes.DATETIME()), + ( + sqlite.DATE(storage_format="%(year)04d%(month)02d%(day)02d"), + sqltypes.DATE(), + ), + ( + sqlite.TIME( + storage_format="%(hour)02d%(minute)02d%(second)02d" + ), + sqltypes.TIME(), + ), + ( + sqlite.DATETIME( + storage_format="%(year)04d%(month)02d%(day)02d" + "%(hour)02d%(minute)02d%(second)02d" + ), + sqltypes.DATETIME(), + ), ] def _unsupported_args_fixture(self): return [ - ("INTEGER(5)", sqltypes.INTEGER(),), - ("DATETIME(6, 12)", sqltypes.DATETIME()) + ("INTEGER(5)", sqltypes.INTEGER()), + ("DATETIME(6, 12)", sqltypes.DATETIME()), ] def _type_affinity_fixture(self): @@ -1873,10 +2167,13 @@ class TypeReflectionTest(fixtures.TestBase): dialect = sqlite.dialect() for from_, to_ in self._fixture_as_string(fixture): if warnings: + def go(): return dialect._resolve_type_affinity(from_) + final_type = testing.assert_warnings( - go, ["Could not instantiate"], regex=True) + go, ["Could not instantiate"], regex=True + ) else: final_type = dialect._resolve_type_affinity(from_) expected_type = type(to_) @@ -1884,27 +2181,31 @@ class TypeReflectionTest(fixtures.TestBase): def _test_round_trip(self, fixture, warnings=False): from sqlalchemy import inspect + conn = testing.db.connect() for from_, to_ in self._fixture_as_string(fixture): inspector = inspect(conn) conn.execute("CREATE TABLE foo (data %s)" % from_) try: if warnings: + def go(): return inspector.get_columns("foo")[0] + col_info = testing.assert_warnings( - go, ["Could not instantiate"], regex=True) + go, ["Could not instantiate"], regex=True + ) else: col_info = inspector.get_columns("foo")[0] expected_type = type(to_) - is_(type(col_info['type']), expected_type) + is_(type(col_info["type"]), expected_type) # test args for attr in ("scale", "precision", "length"): if getattr(to_, attr, None) is not None: eq_( - getattr(col_info['type'], attr), - getattr(to_, attr, None) + getattr(col_info["type"], attr), + getattr(to_, attr, None), ) finally: conn.execute("DROP TABLE foo") @@ -1914,7 +2215,8 @@ class TypeReflectionTest(fixtures.TestBase): def test_lookup_direct_unsupported_args(self): self._test_lookup_direct( - self._unsupported_args_fixture(), warnings=True) + self._unsupported_args_fixture(), warnings=True + ) def test_lookup_direct_type_affinity(self): self._test_lookup_direct(self._type_affinity_fixture()) @@ -1923,8 +2225,7 @@ class TypeReflectionTest(fixtures.TestBase): self._test_round_trip(self._fixed_lookup_fixture()) def test_round_trip_direct_unsupported_args(self): - self._test_round_trip( - self._unsupported_args_fixture(), warnings=True) + self._test_round_trip(self._unsupported_args_fixture(), warnings=True) def test_round_trip_direct_type_affinity(self): 
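        # --- editor's aside: illustrative sketch, not part of this diff ---
        # The affinity fixtures in this class rest on SQLite's type-affinity
        # rules: an unrecognized type name collapses to one of the INTEGER,
        # TEXT, REAL, NUMERIC or BLOB affinities by substring matching, and
        # reflection then hands back the matching generic type. A minimal
        # standalone check using only public 1.x-era APIs (the table and
        # type names below are invented for the example):
        from sqlalchemy import create_engine, inspect as sa_inspect

        demo_eng = create_engine("sqlite://")
        demo_eng.execute("CREATE TABLE affinity_demo (a FLOATY, b WIDGET)")
        demo_cols = sa_inspect(demo_eng).get_columns("affinity_demo")
        # "FLOATY" contains "FLOA" -> REAL affinity; "WIDGET" matches no
        # substring rule and falls back to NUMERIC affinity.
        # --- end aside ---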
self._test_round_trip(self._type_affinity_fixture()) diff --git a/test/dialect/test_sybase.py b/test/dialect/test_sybase.py index 6027471061..df2ae784e3 100644 --- a/test/dialect/test_sybase.py +++ b/test/dialect/test_sybase.py @@ -1,49 +1,54 @@ from sqlalchemy import extract, select from sqlalchemy import sql from sqlalchemy.databases import sybase -from sqlalchemy.testing import assert_raises_message, \ - fixtures, AssertsCompiledSQL +from sqlalchemy.testing import ( + assert_raises_message, + fixtures, + AssertsCompiledSQL, +) class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = sybase.dialect() def test_extract(self): - t = sql.table('t', sql.column('col1')) + t = sql.table("t", sql.column("col1")) mapping = { - 'day': 'day', - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'millisecond': 'millisecond', - 'year': 'year', + "day": "day", + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "millisecond": "millisecond", + "year": "year", } for field, subst in list(mapping.items()): self.assert_compile( select([extract(field, t.c.col1)]), - 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst) + 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst, + ) def test_offset_not_supported(self): stmt = select([1]).offset(10) assert_raises_message( NotImplementedError, "Sybase ASE does not support OFFSET", - stmt.compile, dialect=self.__dialect__ + stmt.compile, + dialect=self.__dialect__, ) def test_delete_extra_froms(self): - t1 = sql.table('t1', sql.column('c1')) - t2 = sql.table('t2', sql.column('c1')) + t1 = sql.table("t1", sql.column("c1")) + t2 = sql.table("t2", sql.column("c1")) q = sql.delete(t1).where(t1.c.c1 == t2.c.c1) self.assert_compile( q, "DELETE FROM t1 FROM t1, t2 WHERE t1.c1 = t2.c1" ) def test_delete_extra_froms_alias(self): - a1 = sql.table('t1', sql.column('c1')).alias('a1') - t2 = sql.table('t2', sql.column('c1')) + a1 = sql.table("t1", sql.column("c1")).alias("a1") + t2 = sql.table("t2", sql.column("c1")) q = sql.delete(a1).where(a1.c.c1 == t2.c.c1) self.assert_compile( q, "DELETE FROM a1 FROM t1 AS a1, t2 WHERE a1.c1 = t2.c1" diff --git a/test/engine/test_bind.py b/test/engine/test_bind.py index a084a849f4..862626c40f 100644 --- a/test/engine/test_bind.py +++ b/test/engine/test_bind.py @@ -38,15 +38,9 @@ class BindTest(fixtures.TestBase): def test_create_drop_explicit(self): metadata = MetaData() - table = Table('test_table', metadata, Column('foo', Integer)) - for bind in ( - testing.db, - testing.db.connect() - ): - for args in [ - ([], {'bind': bind}), - ([bind], {}) - ]: + table = Table("test_table", metadata, Column("foo", Integer)) + for bind in (testing.db, testing.db.connect()): + for args in [([], {"bind": bind}), ([bind], {})]: metadata.create_all(*args[0], **args[1]) assert table.exists(*args[0], **args[1]) metadata.drop_all(*args[0], **args[1]) @@ -56,40 +50,35 @@ class BindTest(fixtures.TestBase): def test_create_drop_err_metadata(self): metadata = MetaData() - Table('test_table', metadata, Column('foo', Integer)) + Table("test_table", metadata, Column("foo", Integer)) for meth in [metadata.create_all, metadata.drop_all]: assert_raises_message( exc.UnboundExecutionError, "MetaData object is not bound to an Engine or Connection.", - meth + meth, ) def test_create_drop_err_table(self): metadata = MetaData() - table = Table('test_table', metadata, Column('foo', Integer)) + table = Table("test_table", metadata, Column("foo", Integer)) - for meth in [ - table.exists, - table.create, - 
table.drop, - ]: + for meth in [table.exists, table.create, table.drop]: assert_raises_message( exc.UnboundExecutionError, - ("Table object 'test_table' is not bound to an Engine or " - "Connection."), - meth + ( + "Table object 'test_table' is not bound to an Engine or " + "Connection." + ), + meth, ) @testing.uses_deprecated() def test_create_drop_bound(self): for meta in (MetaData, ThreadLocalMetaData): - for bind in ( - testing.db, - testing.db.connect() - ): + for bind in (testing.db, testing.db.connect()): metadata = meta() - table = Table('test_table', metadata, Column('foo', Integer)) + table = Table("test_table", metadata, Column("foo", Integer)) metadata.bind = bind assert metadata.bind is table.bind is bind metadata.create_all() @@ -100,7 +89,7 @@ class BindTest(fixtures.TestBase): assert not table.exists() metadata = meta() - table = Table('test_table', metadata, Column('foo', Integer)) + table = Table("test_table", metadata, Column("foo", Integer)) metadata.bind = bind @@ -115,18 +104,13 @@ class BindTest(fixtures.TestBase): bind.close() def test_create_drop_constructor_bound(self): - for bind in ( - testing.db, - testing.db.connect() - ): + for bind in (testing.db, testing.db.connect()): try: - for args in ( - ([bind], {}), - ([], {'bind': bind}), - ): + for args in (([bind], {}), ([], {"bind": bind})): metadata = MetaData(*args[0], **args[1]) - table = Table('test_table', metadata, - Column('foo', Integer)) + table = Table( + "test_table", metadata, Column("foo", Integer) + ) assert metadata.bind is table.bind is bind metadata.create_all() assert table.exists() @@ -140,9 +124,12 @@ class BindTest(fixtures.TestBase): def test_implicit_execution(self): metadata = MetaData() - table = Table('test_table', metadata, - Column('foo', Integer), - test_needs_acid=True) + table = Table( + "test_table", + metadata, + Column("foo", Integer), + test_needs_acid=True, + ) conn = testing.db.connect() metadata.create_all(bind=conn) try: @@ -155,26 +142,24 @@ class BindTest(fixtures.TestBase): table.insert().execute(foo=7) trans.rollback() metadata.bind = None - assert conn.execute('select count(*) from test_table' - ).scalar() == 0 + assert ( + conn.execute("select count(*) from test_table").scalar() == 0 + ) finally: metadata.drop_all(bind=conn) def test_clauseelement(self): metadata = MetaData() - table = Table('test_table', metadata, Column('foo', Integer)) + table = Table("test_table", metadata, Column("foo", Integer)) metadata.create_all(bind=testing.db) try: for elem in [ table.select, lambda **kwargs: sa.func.current_timestamp(**kwargs).select(), # func.current_timestamp().select, - lambda **kwargs:text("select * from test_table", **kwargs) + lambda **kwargs: text("select * from test_table", **kwargs), ]: - for bind in ( - testing.db, - testing.db.connect() - ): + for bind in (testing.db, testing.db.connect()): try: e = elem(bind=bind) assert e.bind is bind @@ -185,10 +170,7 @@ class BindTest(fixtures.TestBase): e = elem() assert e.bind is None - assert_raises( - exc.UnboundExecutionError, - e.execute - ) + assert_raises(exc.UnboundExecutionError, e.execute) finally: if isinstance(bind, engine.Connection): bind.close() diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 80dd8084bb..307cb1cf1f 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -1,7 +1,10 @@ - from sqlalchemy.testing import assert_raises, assert_raises_message -from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, \ - DropConstraint +from 
sqlalchemy.schema import ( + DDL, + CheckConstraint, + AddConstraint, + DropConstraint, +) from sqlalchemy import create_engine from sqlalchemy import MetaData, Integer, String, event, exc, text from sqlalchemy.testing.schema import Table @@ -15,16 +18,15 @@ from sqlalchemy.testing import mock class DDLEventTest(fixtures.TestBase): - def setup(self): self.bind = engines.mock_engine() self.metadata = MetaData() - self.table = Table('t', self.metadata, Column('id', Integer)) + self.table = Table("t", self.metadata, Column("id", Integer)) def test_table_create_before(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'before_create', canary.before_create) + event.listen(table, "before_create", canary.before_create) table.create(bind) table.drop(bind) @@ -32,15 +34,19 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY) - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ) + ], ) def test_table_create_after(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'after_create', canary.after_create) + event.listen(table, "after_create", canary.after_create) table.create(bind) table.drop(bind) @@ -48,16 +54,20 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.after_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY) - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ) + ], ) def test_table_create_both(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'before_create', canary.before_create) - event.listen(table, 'after_create', canary.after_create) + event.listen(table, "before_create", canary.before_create) + event.listen(table, "after_create", canary.after_create) table.create(bind) table.drop(bind) @@ -65,18 +75,26 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), mock.call.after_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY) - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), + ], ) def test_table_drop_before(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'before_drop', canary.before_drop) + event.listen(table, "before_drop", canary.before_drop) table.create(bind) table.drop(bind) @@ -84,34 +102,42 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_drop( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ) + ], ) def test_table_drop_after(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'after_drop', canary.after_drop) + event.listen(table, "after_drop", canary.after_drop) table.create(bind) - canary.state = 'skipped' + canary.state = "skipped" table.drop(bind) eq_( canary.mock_calls, [ mock.call.after_drop( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, 
_is_metadata_operation=mock.ANY), - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ) + ], ) def test_table_drop_both(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'before_drop', canary.before_drop) - event.listen(table, 'after_drop', canary.after_drop) + event.listen(table, "before_drop", canary.before_drop) + event.listen(table, "after_drop", canary.after_drop) table.create(bind) table.drop(bind) @@ -119,22 +145,30 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_drop( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), mock.call.after_drop( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), + ], ) def test_table_all(self): table, bind = self.table, self.bind canary = mock.Mock() - event.listen(table, 'before_create', canary.before_create) - event.listen(table, 'after_create', canary.after_create) - event.listen(table, 'before_drop', canary.before_drop) - event.listen(table, 'after_drop', canary.after_drop) + event.listen(table, "before_create", canary.before_create) + event.listen(table, "after_create", canary.after_create) + event.listen(table, "before_drop", canary.before_drop) + event.listen(table, "after_drop", canary.after_drop) table.create(bind) table.drop(bind) @@ -142,24 +176,40 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), mock.call.after_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), mock.call.before_drop( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), mock.call.after_drop( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), - ] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ), + ], ) def test_metadata_create_before(self): metadata, bind = self.metadata, self.bind canary = mock.Mock() - event.listen(metadata, 'before_create', canary.before_create) + event.listen(metadata, "before_create", canary.before_create) metadata.create_all(bind) metadata.drop_all(bind) @@ -169,16 +219,19 @@ class DDLEventTest(fixtures.TestBase): mock.call.before_create( # checkfirst is False because of the MockConnection # used in the current testing strategy. 
- metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), - ] + _ddl_runner=mock.ANY, + ) + ], ) def test_metadata_create_after(self): metadata, bind = self.metadata, self.bind canary = mock.Mock() - event.listen(metadata, 'after_create', canary.after_create) + event.listen(metadata, "after_create", canary.after_create) metadata.create_all(bind) metadata.drop_all(bind) @@ -186,18 +239,21 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.after_create( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), - ] + _ddl_runner=mock.ANY, + ) + ], ) def test_metadata_create_both(self): metadata, bind = self.metadata, self.bind canary = mock.Mock() - event.listen(metadata, 'before_create', canary.before_create) - event.listen(metadata, 'after_create', canary.after_create) + event.listen(metadata, "before_create", canary.before_create) + event.listen(metadata, "after_create", canary.after_create) metadata.create_all(bind) metadata.drop_all(bind) @@ -205,20 +261,26 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_create( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), + _ddl_runner=mock.ANY, + ), mock.call.after_create( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), - ] + _ddl_runner=mock.ANY, + ), + ], ) def test_metadata_drop_before(self): metadata, bind = self.metadata, self.bind canary = mock.Mock() - event.listen(metadata, 'before_drop', canary.before_drop) + event.listen(metadata, "before_drop", canary.before_drop) metadata.create_all(bind) metadata.drop_all(bind) @@ -226,16 +288,19 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_drop( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), - ] + _ddl_runner=mock.ANY, + ) + ], ) def test_metadata_drop_after(self): metadata, bind = self.metadata, self.bind canary = mock.Mock() - event.listen(metadata, 'after_drop', canary.after_drop) + event.listen(metadata, "after_drop", canary.after_drop) metadata.create_all(bind) metadata.drop_all(bind) @@ -243,18 +308,21 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.after_drop( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), - ] + _ddl_runner=mock.ANY, + ) + ], ) def test_metadata_drop_both(self): metadata, bind = self.metadata, self.bind canary = mock.Mock() - event.listen(metadata, 'before_drop', canary.before_drop) - event.listen(metadata, 'after_drop', canary.after_drop) + event.listen(metadata, "before_drop", canary.before_drop) + event.listen(metadata, "after_drop", canary.after_drop) metadata.create_all(bind) metadata.drop_all(bind) @@ -262,14 +330,20 @@ class DDLEventTest(fixtures.TestBase): canary.mock_calls, [ mock.call.before_drop( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), + _ddl_runner=mock.ANY, + ), mock.call.after_drop( - metadata, self.bind, checkfirst=False, + metadata, + self.bind, + 
checkfirst=False, tables=list(metadata.tables.values()), - _ddl_runner=mock.ANY), - ] + _ddl_runner=mock.ANY, + ), + ], ) def test_metadata_table_isolation(self): @@ -277,297 +351,310 @@ class DDLEventTest(fixtures.TestBase): table_canary = mock.Mock() metadata_canary = mock.Mock() - event.listen(table, 'before_create', table_canary.before_create) + event.listen(table, "before_create", table_canary.before_create) - event.listen(metadata, 'before_create', metadata_canary.before_create) + event.listen(metadata, "before_create", metadata_canary.before_create) self.table.create(self.bind) eq_( table_canary.mock_calls, [ mock.call.before_create( - table, self.bind, checkfirst=False, - _ddl_runner=mock.ANY, _is_metadata_operation=mock.ANY), - ] - ) - eq_( - metadata_canary.mock_calls, - [] + table, + self.bind, + checkfirst=False, + _ddl_runner=mock.ANY, + _is_metadata_operation=mock.ANY, + ) + ], ) + eq_(metadata_canary.mock_calls, []) def test_append_listener(self): metadata, table, bind = self.metadata, self.table, self.bind - def fn(*a): return None + def fn(*a): + return None - table.append_ddl_listener('before-create', fn) - assert_raises(exc.InvalidRequestError, table.append_ddl_listener, - 'blah', fn) + table.append_ddl_listener("before-create", fn) + assert_raises( + exc.InvalidRequestError, table.append_ddl_listener, "blah", fn + ) - metadata.append_ddl_listener('before-create', fn) - assert_raises(exc.InvalidRequestError, metadata.append_ddl_listener, - 'blah', fn) + metadata.append_ddl_listener("before-create", fn) + assert_raises( + exc.InvalidRequestError, metadata.append_ddl_listener, "blah", fn + ) class DDLExecutionTest(fixtures.TestBase): def setup(self): self.engine = engines.mock_engine() self.metadata = MetaData(self.engine) - self.users = Table('users', self.metadata, - Column('user_id', Integer, primary_key=True), - Column('user_name', String(40)), - ) + self.users = Table( + "users", + self.metadata, + Column("user_id", Integer, primary_key=True), + Column("user_name", String(40)), + ) def test_table_standalone(self): users, engine = self.users, self.engine - event.listen(users, 'before_create', DDL('mxyzptlk')) - event.listen(users, 'after_create', DDL('klptzyxm')) - event.listen(users, 'before_drop', DDL('xyzzy')) - event.listen(users, 'after_drop', DDL('fnord')) + event.listen(users, "before_create", DDL("mxyzptlk")) + event.listen(users, "after_create", DDL("klptzyxm")) + event.listen(users, "before_drop", DDL("xyzzy")) + event.listen(users, "after_drop", DDL("fnord")) users.create() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' in strings - assert 'klptzyxm' in strings - assert 'xyzzy' not in strings - assert 'fnord' not in strings + assert "mxyzptlk" in strings + assert "klptzyxm" in strings + assert "xyzzy" not in strings + assert "fnord" not in strings del engine.mock[:] users.drop() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' not in strings - assert 'klptzyxm' not in strings - assert 'xyzzy' in strings - assert 'fnord' in strings + assert "mxyzptlk" not in strings + assert "klptzyxm" not in strings + assert "xyzzy" in strings + assert "fnord" in strings def test_table_by_metadata(self): metadata, users, engine = self.metadata, self.users, self.engine - event.listen(users, 'before_create', DDL('mxyzptlk')) - event.listen(users, 'after_create', DDL('klptzyxm')) - event.listen(users, 'before_drop', DDL('xyzzy')) - event.listen(users, 'after_drop', DDL('fnord')) + event.listen(users, "before_create", DDL("mxyzptlk")) + 
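        # --- editor's aside: illustrative sketch, not part of this diff ---
        # The pattern these tests exercise: a DDL() construct attached to a
        # table or metadata event emits its raw SQL when the target is
        # created or dropped. A self-contained example against a throwaway
        # engine (the table and index names below are invented):
        from sqlalchemy import Column, Integer, MetaData, Table
        from sqlalchemy import create_engine, event
        from sqlalchemy.schema import DDL

        aside_md = MetaData()
        aside_t = Table(
            "ddl_demo", aside_md, Column("id", Integer, primary_key=True)
        )
        event.listen(
            aside_t,
            "after_create",
            DDL("CREATE INDEX ix_ddl_demo_id ON ddl_demo (id)"),
        )
        aside_eng = create_engine("sqlite://")
        aside_md.create_all(aside_eng)  # CREATE TABLE, then the index DDL
        # --- end aside ---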
event.listen(users, "after_create", DDL("klptzyxm")) + event.listen(users, "before_drop", DDL("xyzzy")) + event.listen(users, "after_drop", DDL("fnord")) metadata.create_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' in strings - assert 'klptzyxm' in strings - assert 'xyzzy' not in strings - assert 'fnord' not in strings + assert "mxyzptlk" in strings + assert "klptzyxm" in strings + assert "xyzzy" not in strings + assert "fnord" not in strings del engine.mock[:] metadata.drop_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' not in strings - assert 'klptzyxm' not in strings - assert 'xyzzy' in strings - assert 'fnord' in strings + assert "mxyzptlk" not in strings + assert "klptzyxm" not in strings + assert "xyzzy" in strings + assert "fnord" in strings - @testing.uses_deprecated(r'See DDLEvents') + @testing.uses_deprecated(r"See DDLEvents") def test_table_by_metadata_deprecated(self): metadata, users, engine = self.metadata, self.users, self.engine - DDL('mxyzptlk').execute_at('before-create', users) - DDL('klptzyxm').execute_at('after-create', users) - DDL('xyzzy').execute_at('before-drop', users) - DDL('fnord').execute_at('after-drop', users) + DDL("mxyzptlk").execute_at("before-create", users) + DDL("klptzyxm").execute_at("after-create", users) + DDL("xyzzy").execute_at("before-drop", users) + DDL("fnord").execute_at("after-drop", users) metadata.create_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' in strings - assert 'klptzyxm' in strings - assert 'xyzzy' not in strings - assert 'fnord' not in strings + assert "mxyzptlk" in strings + assert "klptzyxm" in strings + assert "xyzzy" not in strings + assert "fnord" not in strings del engine.mock[:] metadata.drop_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' not in strings - assert 'klptzyxm' not in strings - assert 'xyzzy' in strings - assert 'fnord' in strings + assert "mxyzptlk" not in strings + assert "klptzyxm" not in strings + assert "xyzzy" in strings + assert "fnord" in strings def test_deprecated_append_ddl_listener_table(self): metadata, users, engine = self.metadata, self.users, self.engine canary = [] - users.append_ddl_listener('before-create', - lambda e, t, b: canary.append('mxyzptlk')) - users.append_ddl_listener('after-create', - lambda e, t, b: canary.append('klptzyxm')) - users.append_ddl_listener('before-drop', - lambda e, t, b: canary.append('xyzzy')) - users.append_ddl_listener('after-drop', - lambda e, t, b: canary.append('fnord')) + users.append_ddl_listener( + "before-create", lambda e, t, b: canary.append("mxyzptlk") + ) + users.append_ddl_listener( + "after-create", lambda e, t, b: canary.append("klptzyxm") + ) + users.append_ddl_listener( + "before-drop", lambda e, t, b: canary.append("xyzzy") + ) + users.append_ddl_listener( + "after-drop", lambda e, t, b: canary.append("fnord") + ) metadata.create_all() - assert 'mxyzptlk' in canary - assert 'klptzyxm' in canary - assert 'xyzzy' not in canary - assert 'fnord' not in canary + assert "mxyzptlk" in canary + assert "klptzyxm" in canary + assert "xyzzy" not in canary + assert "fnord" not in canary del engine.mock[:] canary[:] = [] metadata.drop_all() - assert 'mxyzptlk' not in canary - assert 'klptzyxm' not in canary - assert 'xyzzy' in canary - assert 'fnord' in canary + assert "mxyzptlk" not in canary + assert "klptzyxm" not in canary + assert "xyzzy" in canary + assert "fnord" in canary def test_deprecated_append_ddl_listener_metadata(self): metadata, users, engine = self.metadata, 
self.users, self.engine canary = [] metadata.append_ddl_listener( - 'before-create', - lambda e, t, b, tables=None: canary.append('mxyzptlk') + "before-create", + lambda e, t, b, tables=None: canary.append("mxyzptlk"), ) metadata.append_ddl_listener( - 'after-create', - lambda e, t, b, tables=None: canary.append('klptzyxm') + "after-create", + lambda e, t, b, tables=None: canary.append("klptzyxm"), ) metadata.append_ddl_listener( - 'before-drop', - lambda e, t, b, tables=None: canary.append('xyzzy') + "before-drop", lambda e, t, b, tables=None: canary.append("xyzzy") ) metadata.append_ddl_listener( - 'after-drop', - lambda e, t, b, tables=None: canary.append('fnord') + "after-drop", lambda e, t, b, tables=None: canary.append("fnord") ) metadata.create_all() - assert 'mxyzptlk' in canary - assert 'klptzyxm' in canary - assert 'xyzzy' not in canary - assert 'fnord' not in canary + assert "mxyzptlk" in canary + assert "klptzyxm" in canary + assert "xyzzy" not in canary + assert "fnord" not in canary del engine.mock[:] canary[:] = [] metadata.drop_all() - assert 'mxyzptlk' not in canary - assert 'klptzyxm' not in canary - assert 'xyzzy' in canary - assert 'fnord' in canary + assert "mxyzptlk" not in canary + assert "klptzyxm" not in canary + assert "xyzzy" in canary + assert "fnord" in canary def test_metadata(self): metadata, engine = self.metadata, self.engine - event.listen(metadata, 'before_create', DDL('mxyzptlk')) - event.listen(metadata, 'after_create', DDL('klptzyxm')) - event.listen(metadata, 'before_drop', DDL('xyzzy')) - event.listen(metadata, 'after_drop', DDL('fnord')) + event.listen(metadata, "before_create", DDL("mxyzptlk")) + event.listen(metadata, "after_create", DDL("klptzyxm")) + event.listen(metadata, "before_drop", DDL("xyzzy")) + event.listen(metadata, "after_drop", DDL("fnord")) metadata.create_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' in strings - assert 'klptzyxm' in strings - assert 'xyzzy' not in strings - assert 'fnord' not in strings + assert "mxyzptlk" in strings + assert "klptzyxm" in strings + assert "xyzzy" not in strings + assert "fnord" not in strings del engine.mock[:] metadata.drop_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' not in strings - assert 'klptzyxm' not in strings - assert 'xyzzy' in strings - assert 'fnord' in strings + assert "mxyzptlk" not in strings + assert "klptzyxm" not in strings + assert "xyzzy" in strings + assert "fnord" in strings - @testing.uses_deprecated(r'See DDLEvents') + @testing.uses_deprecated(r"See DDLEvents") def test_metadata_deprecated(self): metadata, engine = self.metadata, self.engine - DDL('mxyzptlk').execute_at('before-create', metadata) - DDL('klptzyxm').execute_at('after-create', metadata) - DDL('xyzzy').execute_at('before-drop', metadata) - DDL('fnord').execute_at('after-drop', metadata) + DDL("mxyzptlk").execute_at("before-create", metadata) + DDL("klptzyxm").execute_at("after-create", metadata) + DDL("xyzzy").execute_at("before-drop", metadata) + DDL("fnord").execute_at("after-drop", metadata) metadata.create_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' in strings - assert 'klptzyxm' in strings - assert 'xyzzy' not in strings - assert 'fnord' not in strings + assert "mxyzptlk" in strings + assert "klptzyxm" in strings + assert "xyzzy" not in strings + assert "fnord" not in strings del engine.mock[:] metadata.drop_all() strings = [str(x) for x in engine.mock] - assert 'mxyzptlk' not in strings - assert 'klptzyxm' not in strings - assert 'xyzzy' in 
-    @testing.uses_deprecated(r'See DDLEvents')
+    @testing.uses_deprecated(r"See DDLEvents")
     def test_metadata_deprecated(self):
         metadata, engine = self.metadata, self.engine
-        DDL('mxyzptlk').execute_at('before-create', metadata)
-        DDL('klptzyxm').execute_at('after-create', metadata)
-        DDL('xyzzy').execute_at('before-drop', metadata)
-        DDL('fnord').execute_at('after-drop', metadata)
+        DDL("mxyzptlk").execute_at("before-create", metadata)
+        DDL("klptzyxm").execute_at("after-create", metadata)
+        DDL("xyzzy").execute_at("before-drop", metadata)
+        DDL("fnord").execute_at("after-drop", metadata)
         metadata.create_all()
         strings = [str(x) for x in engine.mock]
-        assert 'mxyzptlk' in strings
-        assert 'klptzyxm' in strings
-        assert 'xyzzy' not in strings
-        assert 'fnord' not in strings
+        assert "mxyzptlk" in strings
+        assert "klptzyxm" in strings
+        assert "xyzzy" not in strings
+        assert "fnord" not in strings
         del engine.mock[:]
         metadata.drop_all()
         strings = [str(x) for x in engine.mock]
-        assert 'mxyzptlk' not in strings
-        assert 'klptzyxm' not in strings
-        assert 'xyzzy' in strings
-        assert 'fnord' in strings
+        assert "mxyzptlk" not in strings
+        assert "klptzyxm" not in strings
+        assert "xyzzy" in strings
+        assert "fnord" in strings

     def test_conditional_constraint(self):
         metadata, users, engine = self.metadata, self.users, self.engine
-        nonpg_mock = engines.mock_engine(dialect_name='sqlite')
-        pg_mock = engines.mock_engine(dialect_name='postgresql')
-        constraint = CheckConstraint('a < b', name='my_test_constraint',
-                                     table=users)
+        nonpg_mock = engines.mock_engine(dialect_name="sqlite")
+        pg_mock = engines.mock_engine(dialect_name="postgresql")
+        constraint = CheckConstraint(
+            "a < b", name="my_test_constraint", table=users
+        )

         # by placing the constraint in an Add/Drop construct, the
         # 'inline_ddl' flag is set to False
         event.listen(
             users,
-            'after_create',
-            AddConstraint(constraint).execute_if(dialect='postgresql'),
+            "after_create",
+            AddConstraint(constraint).execute_if(dialect="postgresql"),
         )
         event.listen(
             users,
-            'before_drop',
-            DropConstraint(constraint).execute_if(dialect='postgresql'),
+            "before_drop",
+            DropConstraint(constraint).execute_if(dialect="postgresql"),
         )
         metadata.create_all(bind=nonpg_mock)
-        strings = ' '.join(str(x) for x in nonpg_mock.mock)
-        assert 'my_test_constraint' not in strings
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings
         metadata.drop_all(bind=nonpg_mock)
-        strings = ' '.join(str(x) for x in nonpg_mock.mock)
-        assert 'my_test_constraint' not in strings
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings

         metadata.create_all(bind=pg_mock)
-        strings = ' '.join(str(x) for x in pg_mock.mock)
-        assert 'my_test_constraint' in strings
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings
         metadata.drop_all(bind=pg_mock)
-        strings = ' '.join(str(x) for x in pg_mock.mock)
-        assert 'my_test_constraint' in strings
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings

-    @testing.uses_deprecated(r'See DDLEvents')
+    @testing.uses_deprecated(r"See DDLEvents")
     def test_conditional_constraint_deprecated(self):
         metadata, users, engine = self.metadata, self.users, self.engine
-        nonpg_mock = engines.mock_engine(dialect_name='sqlite')
-        pg_mock = engines.mock_engine(dialect_name='postgresql')
-        constraint = CheckConstraint('a < b', name='my_test_constraint',
-                                     table=users)
+        nonpg_mock = engines.mock_engine(dialect_name="sqlite")
+        pg_mock = engines.mock_engine(dialect_name="postgresql")
+        constraint = CheckConstraint(
+            "a < b", name="my_test_constraint", table=users
+        )

         # by placing the constraint in an Add/Drop construct, the
         # 'inline_ddl' flag is set to False
-        AddConstraint(constraint, on='postgresql'
-                      ).execute_at('after-create', users)
-        DropConstraint(constraint, on='postgresql'
-                       ).execute_at('before-drop', users)
+        AddConstraint(constraint, on="postgresql").execute_at(
+            "after-create", users
+        )
+        DropConstraint(constraint, on="postgresql").execute_at(
+            "before-drop", users
+        )
         metadata.create_all(bind=nonpg_mock)
-        strings = ' '.join(str(x) for x in nonpg_mock.mock)
-        assert 'my_test_constraint' not in strings
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings
         metadata.drop_all(bind=nonpg_mock)
-        strings = ' '.join(str(x) for x in nonpg_mock.mock)
-        assert 'my_test_constraint' not in strings
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings

         metadata.create_all(bind=pg_mock)
-        strings = ' '.join(str(x) for x in pg_mock.mock)
-        assert 'my_test_constraint' in strings
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings
         metadata.drop_all(bind=pg_mock)
-        strings = ' '.join(str(x) for x in pg_mock.mock)
-        assert 'my_test_constraint' in strings
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings

     @testing.requires.sqlite
     def test_ddl_execute(self):
-        engine = create_engine('sqlite:///')
+        engine = create_engine("sqlite:///")
         cx = engine.connect()
         table = self.users
-        ddl = DDL('SELECT 1')
-
-        for py in ('engine.execute(ddl)',
-                   'engine.execute(ddl, table)',
-                   'cx.execute(ddl)',
-                   'cx.execute(ddl, table)',
-                   'ddl.execute(engine)',
-                   'ddl.execute(engine, table)',
-                   'ddl.execute(cx)',
-                   'ddl.execute(cx, table)'):
+        ddl = DDL("SELECT 1")
+
+        for py in (
+            "engine.execute(ddl)",
+            "engine.execute(ddl, table)",
+            "cx.execute(ddl)",
+            "cx.execute(ddl, table)",
+            "ddl.execute(engine)",
+            "ddl.execute(engine, table)",
+            "ddl.execute(cx)",
+            "ddl.execute(cx, table)",
+        ):
             r = eval(py)
             assert list(r) == [(1,)], py

-        for py in ('ddl.execute()',
-                   'ddl.execute(target=table)'):
+        for py in ("ddl.execute()", "ddl.execute(target=table)"):
             try:
                 r = eval(py)
                 assert False
@@ -576,8 +663,7 @@ class DDLExecutionTest(fixtures.TestBase):

         for bind in engine, cx:
             ddl.bind = bind
-            for py in ('ddl.execute()',
-                       'ddl.execute(target=table)'):
+            for py in ("ddl.execute()", "ddl.execute(target=table)"):
                 r = eval(py)
                 assert list(r) == [(1,)], py

@@ -585,7 +671,8 @@ class DDLExecutionTest(fixtures.TestBase):
         """test the escaping of % characters in the DDL construct."""

         default_from = testing.db.dialect.statement_compiler(
-            testing.db.dialect, None).default_from()
+            testing.db.dialect, None
+        ).default_from()

         # We're abusing the DDL()
         # construct here by pushing a SELECT through it
@@ -598,93 +685,114 @@ class DDLExecutionTest(fixtures.TestBase):
             conn.execute(
                 text("select 'foo%something'" + default_from)
             ).scalar(),
-            'foo%something'
+            "foo%something",
         )

         eq_(
             conn.execute(
                 DDL("select 'foo%%something'" + default_from)
             ).scalar(),
-            'foo%something'
+            "foo%something",
         )


 class DDLTest(fixtures.TestBase, AssertsCompiledSQL):
     def mock_engine(self):
-        def executor(*a, **kw): return None
-        engine = create_engine(testing.db.name + '://',
-                               strategy='mock', executor=executor)
-        engine.dialect.identifier_preparer = \
-            tsa.sql.compiler.IdentifierPreparer(engine.dialect)
+        def executor(*a, **kw):
+            return None
+
+        engine = create_engine(
+            testing.db.name + "://", strategy="mock", executor=executor
+        )
+        engine.dialect.identifier_preparer = tsa.sql.compiler.IdentifierPreparer(
+            engine.dialect
+        )
         return engine

     def test_tokens(self):
         m = MetaData()
-        sane_alone = Table('t', m, Column('id', Integer))
-        sane_schema = Table('t', m, Column('id', Integer), schema='s')
-        insane_alone = Table('t t', m, Column('id', Integer))
-        insane_schema = Table('t t', m, Column('id', Integer),
-                              schema='s s')
-        ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
+        sane_alone = Table("t", m, Column("id", Integer))
+        sane_schema = Table("t", m, Column("id", Integer), schema="s")
+        insane_alone = Table("t t", m, Column("id", Integer))
+        insane_schema = Table("t t", m, Column("id", Integer), schema="s s")
+        ddl = DDL("%(schema)s-%(table)s-%(fullname)s")
         dialect = self.mock_engine().dialect
-        self.assert_compile(ddl.against(sane_alone), '-t-t',
-                            dialect=dialect)
-        self.assert_compile(ddl.against(sane_schema), 's-t-s.t',
-                            dialect=dialect)
-        self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"',
-                            dialect=dialect)
-        self.assert_compile(ddl.against(insane_schema),
-                            '"s s"-"t t"-"s s"."t t"', dialect=dialect)
+        self.assert_compile(ddl.against(sane_alone), "-t-t", dialect=dialect)
+        self.assert_compile(
+            ddl.against(sane_schema), "s-t-s.t", dialect=dialect
+        )
+        self.assert_compile(
+            ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect
+        )
+        self.assert_compile(
+            ddl.against(insane_schema),
+            '"s s"-"t t"-"s s"."t t"',
+            dialect=dialect,
+        )

         # overrides are used piece-meal and verbatim.
-        ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
-                  context={'schema': 'S S', 'table': 'T T', 'bonus': 'b'})
-        self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b',
-                            dialect=dialect)
-        self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b',
-                            dialect=dialect)
-        self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b',
-                            dialect=dialect)
-        self.assert_compile(ddl.against(insane_schema),
-                            'S S-T T-"s s"."t t"-b', dialect=dialect)
+        ddl = DDL(
+            "%(schema)s-%(table)s-%(fullname)s-%(bonus)s",
+            context={"schema": "S S", "table": "T T", "bonus": "b"},
+        )
+        self.assert_compile(
+            ddl.against(sane_alone), "S S-T T-t-b", dialect=dialect
+        )
+        self.assert_compile(
+            ddl.against(sane_schema), "S S-T T-s.t-b", dialect=dialect
+        )
+        self.assert_compile(
+            ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect
+        )
+        self.assert_compile(
+            ddl.against(insane_schema),
+            'S S-T T-"s s"."t t"-b',
+            dialect=dialect,
+        )

     def test_filter(self):
         cx = self.mock_engine()

-        tbl = Table('t', MetaData(), Column('id', Integer))
+        tbl = Table("t", MetaData(), Column("id", Integer))
         target = cx.name

-        assert DDL('')._should_execute(tbl, cx)
-        assert DDL('').execute_if(dialect=target)._should_execute(tbl, cx)
-        assert not DDL('').execute_if(dialect='bogus').\
-            _should_execute(tbl, cx)
-        assert DDL('').execute_if(callable_=lambda d, y, z, **kw: True).\
-            _should_execute(tbl, cx)
-        assert(DDL('').execute_if(
-            callable_=lambda d, y, z, **kw: z.engine.name
-            != 'bogus').
-            _should_execute(tbl, cx))
+        assert DDL("")._should_execute(tbl, cx)
+        assert DDL("").execute_if(dialect=target)._should_execute(tbl, cx)
+        assert not DDL("").execute_if(dialect="bogus")._should_execute(tbl, cx)
+        assert (
+            DDL("")
+            .execute_if(callable_=lambda d, y, z, **kw: True)
+            ._should_execute(tbl, cx)
+        )
+        assert (
+            DDL("")
+            .execute_if(
+                callable_=lambda d, y, z, **kw: z.engine.name != "bogus"
+            )
+            ._should_execute(tbl, cx)
+        )
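The conditional-constraint tests above exercise execute_if(), which is the modern replacement for the deprecated on= argument: the wrapped DDL element only fires when its predicate matches the connecting dialect. A short sketch of the public pattern, assuming a table named users defined elsewhere:

    from sqlalchemy import CheckConstraint, event
    from sqlalchemy.schema import AddConstraint

    constraint = CheckConstraint("a < b", name="my_test_constraint", table=users)

    # emit ALTER TABLE ... ADD CONSTRAINT only on PostgreSQL; on any other
    # backend the listener compiles to nothing
    event.listen(
        users,
        "after_create",
        AddConstraint(constraint).execute_if(dialect="postgresql"),
    )

execute_if() also accepts a callable_ predicate, as test_filter shows, for conditions that cannot be expressed as a dialect name alone.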
-    @testing.uses_deprecated(r'See DDLEvents')
+    @testing.uses_deprecated(r"See DDLEvents")
     def test_filter_deprecated(self):
         cx = self.mock_engine()

-        tbl = Table('t', MetaData(), Column('id', Integer))
+        tbl = Table("t", MetaData(), Column("id", Integer))
         target = cx.name

-        assert DDL('')._should_execute_deprecated('x', tbl, cx)
-        assert DDL('', on=target)._should_execute_deprecated('x', tbl, cx)
-        assert not DDL('', on='bogus').\
-            _should_execute_deprecated('x', tbl, cx)
-        assert DDL('', on=lambda d, x, y, z: True).\
-            _should_execute_deprecated('x', tbl, cx)
-        assert(DDL('', on=lambda d, x, y, z: z.engine.name != 'bogus').
-               _should_execute_deprecated('x', tbl, cx))
+        assert DDL("")._should_execute_deprecated("x", tbl, cx)
+        assert DDL("", on=target)._should_execute_deprecated("x", tbl, cx)
+        assert not DDL("", on="bogus")._should_execute_deprecated("x", tbl, cx)
+        assert DDL("", on=lambda d, x, y, z: True)._should_execute_deprecated(
+            "x", tbl, cx
+        )
+        assert DDL(
+            "", on=lambda d, x, y, z: z.engine.name != "bogus"
+        )._should_execute_deprecated("x", tbl, cx)

     def test_repr(self):
-        assert repr(DDL('s'))
-        assert repr(DDL('s', on='engine'))
-        assert repr(DDL('s', on=lambda x: 1))
-        assert repr(DDL('s', context={'a': 1}))
-        assert repr(DDL('s', on='engine', context={'a': 1}))
+        assert repr(DDL("s"))
+        assert repr(DDL("s", on="engine"))
+        assert repr(DDL("s", on=lambda x: 1))
+        assert repr(DDL("s", context={"a": 1}))
+        assert repr(DDL("s", on="engine", context={"a": 1}))
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 84262dab24..87bf52bdb6 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -1,15 +1,35 @@
 # coding: utf-8

 import weakref
-from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \
-    config, is_, is_not_, le_, expect_warnings
+from sqlalchemy.testing import (
+    eq_,
+    assert_raises,
+    assert_raises_message,
+    config,
+    is_,
+    is_not_,
+    le_,
+    expect_warnings,
+)
 import re
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.interfaces import ConnectionProxy
-from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, \
-    bindparam, select, event, TypeDecorator, create_engine, Sequence, \
-    LargeBinary
+from sqlalchemy import (
+    MetaData,
+    Integer,
+    String,
+    INT,
+    VARCHAR,
+    func,
+    bindparam,
+    select,
+    event,
+    TypeDecorator,
+    create_engine,
+    Sequence,
+    LargeBinary,
+)
 from sqlalchemy.sql import column, literal
 from sqlalchemy.testing.schema import Table, Column
 import sqlalchemy as tsa
@@ -43,16 +63,18 @@ class ExecuteTest(fixtures.TestBase):
         global users, users_autoinc, metadata
         metadata = MetaData(testing.db)
         users = Table(
-            'users', metadata,
-            Column('user_id', INT, primary_key=True, autoincrement=False),
-            Column('user_name', VARCHAR(20))
+            "users",
+            metadata,
+            Column("user_id", INT, primary_key=True, autoincrement=False),
+            Column("user_name", VARCHAR(20)),
         )
         users_autoinc = Table(
-            'users_autoinc', metadata,
+            "users_autoinc",
+            metadata,
             Column(
-                'user_id', INT, primary_key=True,
-                test_needs_autoincrement=True),
-            Column('user_name', VARCHAR(20)),
+                "user_id", INT, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("user_name", VARCHAR(20)),
         )
         metadata.create_all()
@@ -66,60 +88,72 @@ class ExecuteTest(fixtures.TestBase):

     @testing.fails_on(
         "postgresql+pg8000",
-        "pg8000 still doesn't allow single paren without params")
+        "pg8000 still doesn't allow single paren without params",
+    )
     def test_no_params_option(self):
-        stmt = "SELECT '%'" + testing.db.dialect.statement_compiler(
-            testing.db.dialect, None).default_from()
+        stmt = (
+            "SELECT '%'"
+            + testing.db.dialect.statement_compiler(
+                testing.db.dialect, None
+            ).default_from()
+        )

         conn = testing.db.connect()
-        result = conn.\
-            execution_options(no_parameters=True).\
-            scalar(stmt)
-        eq_(result, '%')
+        result = conn.execution_options(no_parameters=True).scalar(stmt)
+        eq_(result, "%")
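The test_raw_* methods that follow drive plain SQL strings through Connection.execute() using each DBAPI's native paramstyle: qmark "?" placeholders for SQLite, format/pyformat "%s" / "%(name)s" for the MySQL and psycopg2 drivers, and named ":name" for Oracle. A minimal sketch of the qmark case, assuming a throwaway SQLite database (table layout matches the tests):

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")  # illustrative in-memory database
    conn = engine.connect()
    conn.execute("create table users (user_id integer, user_name varchar(20))")

    # qmark style: positional "?" placeholders, parameters as a tuple
    conn.execute(
        "insert into users (user_id, user_name) values (?, ?)", (1, "jack")
    )
    print(conn.execute("select * from users").fetchall())  # [(1, 'jack')]

Note this string-execution style is the 1.x API exercised here; the placeholder syntax is owned by the DBAPI, not by SQLAlchemy.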
"+pyodbc", "+mxodbc", "+zxjdbc", "mysql+oursql" + ) def test_raw_qmark(self): def go(conn): - conn.execute('insert into users (user_id, user_name) ' - 'values (?, ?)', (1, 'jack')) - conn.execute('insert into users (user_id, user_name) ' - 'values (?, ?)', [2, 'fred']) - conn.execute('insert into users (user_id, user_name) ' - 'values (?, ?)', [3, 'ed'], [4, 'horse']) - conn.execute('insert into users (user_id, user_name) ' - 'values (?, ?)', (5, 'barney'), (6, 'donkey')) - conn.execute('insert into users (user_id, user_name) ' - 'values (?, ?)', 7, 'sally') - res = conn.execute('select * from users order by user_id') + conn.execute( + "insert into users (user_id, user_name) " "values (?, ?)", + (1, "jack"), + ) + conn.execute( + "insert into users (user_id, user_name) " "values (?, ?)", + [2, "fred"], + ) + conn.execute( + "insert into users (user_id, user_name) " "values (?, ?)", + [3, "ed"], + [4, "horse"], + ) + conn.execute( + "insert into users (user_id, user_name) " "values (?, ?)", + (5, "barney"), + (6, "donkey"), + ) + conn.execute( + "insert into users (user_id, user_name) " "values (?, ?)", + 7, + "sally", + ) + res = conn.execute("select * from users order by user_id") assert res.fetchall() == [ - (1, 'jack'), - (2, 'fred'), - (3, 'ed'), - (4, 'horse'), - (5, 'barney'), - (6, 'donkey'), - (7, 'sally'), + (1, "jack"), + (2, "fred"), + (3, "ed"), + (4, "horse"), + (5, "barney"), + (6, "donkey"), + (7, "sally"), ] for multiparam, param in [ (("jack", "fred"), {}), - ((["jack", "fred"],), {}) + ((["jack", "fred"],), {}), ]: res = conn.execute( "select * from users where user_name=? or " "user_name=? order by user_id", - *multiparam, **param) - assert res.fetchall() == [ - (1, 'jack'), - (2, 'fred') - ] - res = conn.execute( - "select * from users where user_name=?", - "jack" - ) - assert res.fetchall() == [(1, 'jack')] - conn.execute('delete from users') + *multiparam, + **param + ) + assert res.fetchall() == [(1, "jack"), (2, "fred")] + res = conn.execute("select * from users where user_name=?", "jack") + assert res.fetchall() == [(1, "jack")] + conn.execute("delete from users") go(testing.db) conn = testing.db.connect() @@ -130,42 +164,56 @@ class ExecuteTest(fixtures.TestBase): # some psycopg2 versions bomb this. 
@testing.fails_on_everything_except( - 'mysql+mysqldb', 'mysql+pymysql', - 'mysql+cymysql', 'mysql+mysqlconnector', 'postgresql') - @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported') + "mysql+mysqldb", + "mysql+pymysql", + "mysql+cymysql", + "mysql+mysqlconnector", + "postgresql", + ) + @testing.fails_on("postgresql+zxjdbc", "sprintf not supported") def test_raw_sprintf(self): def go(conn): - conn.execute('insert into users (user_id, user_name) ' - 'values (%s, %s)', [1, 'jack']) - conn.execute('insert into users (user_id, user_name) ' - 'values (%s, %s)', [2, 'ed'], [3, 'horse']) - conn.execute('insert into users (user_id, user_name) ' - 'values (%s, %s)', 4, 'sally') - conn.execute('insert into users (user_id) values (%s)', 5) - res = conn.execute('select * from users order by user_id') + conn.execute( + "insert into users (user_id, user_name) " "values (%s, %s)", + [1, "jack"], + ) + conn.execute( + "insert into users (user_id, user_name) " "values (%s, %s)", + [2, "ed"], + [3, "horse"], + ) + conn.execute( + "insert into users (user_id, user_name) " "values (%s, %s)", + 4, + "sally", + ) + conn.execute("insert into users (user_id) values (%s)", 5) + res = conn.execute("select * from users order by user_id") assert res.fetchall() == [ - (1, 'jack'), (2, 'ed'), - (3, 'horse'), (4, 'sally'), (5, None) + (1, "jack"), + (2, "ed"), + (3, "horse"), + (4, "sally"), + (5, None), ] for multiparam, param in [ (("jack", "ed"), {}), - ((["jack", "ed"],), {}) + ((["jack", "ed"],), {}), ]: res = conn.execute( "select * from users where user_name=%s or " "user_name=%s order by user_id", - *multiparam, **param) - assert res.fetchall() == [ - (1, 'jack'), - (2, 'ed') - ] + *multiparam, + **param + ) + assert res.fetchall() == [(1, "jack"), (2, "ed")] res = conn.execute( - "select * from users where user_name=%s", - "jack" + "select * from users where user_name=%s", "jack" ) - assert res.fetchall() == [(1, 'jack')] + assert res.fetchall() == [(1, "jack")] + + conn.execute("delete from users") - conn.execute('delete from users') go(testing.db) conn = testing.db.connect() try: @@ -177,31 +225,45 @@ class ExecuteTest(fixtures.TestBase): # versions have a bug that bombs out on this test. 
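A detail these tests rely on: passing more than one parameter set to a single execute() call makes SQLAlchemy dispatch the statement as a DBAPI executemany(). A minimal self-contained sketch of that behavior, again assuming a throwaway SQLite database:

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")  # illustrative in-memory database
    conn = engine.connect()
    conn.execute("create table users (user_id integer, user_name varchar(20))")

    # several parameter sets in one call -> one statement, executemany(),
    # three rows inserted
    conn.execute(
        "insert into users (user_id, user_name) values (?, ?)",
        [2, "fred"],
        [3, "ed"],
        [4, "horse"],
    )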
@@ -177,31 +225,45 @@ class ExecuteTest(fixtures.TestBase):

     # versions have a bug that bombs out on this test.  (1.2.2b3,
     # 1.2.2c1, 1.2.2)
-    @testing.skip_if(
-        lambda: testing.against('mysql+mysqldb'), 'db-api flaky')
+    @testing.skip_if(lambda: testing.against("mysql+mysqldb"), "db-api flaky")
     @testing.fails_on_everything_except(
-        'postgresql+psycopg2', 'postgresql+psycopg2cffi',
-        'postgresql+pypostgresql', 'postgresql+pygresql',
-        'mysql+mysqlconnector', 'mysql+pymysql', 'mysql+cymysql',
-        'mssql+pymssql')
+        "postgresql+psycopg2",
+        "postgresql+psycopg2cffi",
+        "postgresql+pypostgresql",
+        "postgresql+pygresql",
+        "mysql+mysqlconnector",
+        "mysql+pymysql",
+        "mysql+cymysql",
+        "mssql+pymssql",
+    )
     def test_raw_python(self):
         def go(conn):
             conn.execute(
-                'insert into users (user_id, user_name) '
-                'values (%(id)s, %(name)s)',
-                {'id': 1, 'name': 'jack'})
+                "insert into users (user_id, user_name) "
+                "values (%(id)s, %(name)s)",
+                {"id": 1, "name": "jack"},
+            )
             conn.execute(
-                'insert into users (user_id, user_name) '
-                'values (%(id)s, %(name)s)',
-                {'id': 2, 'name': 'ed'}, {'id': 3, 'name': 'horse'})
+                "insert into users (user_id, user_name) "
+                "values (%(id)s, %(name)s)",
+                {"id": 2, "name": "ed"},
+                {"id": 3, "name": "horse"},
+            )
             conn.execute(
-                'insert into users (user_id, user_name) '
-                'values (%(id)s, %(name)s)', id=4, name='sally'
+                "insert into users (user_id, user_name) "
+                "values (%(id)s, %(name)s)",
+                id=4,
+                name="sally",
             )
-            res = conn.execute('select * from users order by user_id')
+            res = conn.execute("select * from users order by user_id")
             assert res.fetchall() == [
-                (1, 'jack'), (2, 'ed'), (3, 'horse'), (4, 'sally')]
-            conn.execute('delete from users')
+                (1, "jack"),
+                (2, "ed"),
+                (3, "horse"),
+                (4, "sally"),
+            ]
+            conn.execute("delete from users")
+
         go(testing.db)
         conn = testing.db.connect()
         try:
@@ -209,21 +271,35 @@ class ExecuteTest(fixtures.TestBase):
         finally:
             conn.close()

-    @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle')
+    @testing.fails_on_everything_except("sqlite", "oracle+cx_oracle")
     def test_raw_named(self):
         def go(conn):
-            conn.execute('insert into users (user_id, user_name) '
-                         'values (:id, :name)', {'id': 1, 'name': 'jack'
-                                                 })
-            conn.execute('insert into users (user_id, user_name) '
-                         'values (:id, :name)', {'id': 2, 'name': 'ed'
-                                                 }, {'id': 3, 'name': 'horse'})
-            conn.execute('insert into users (user_id, user_name) '
-                         'values (:id, :name)', id=4, name='sally')
-            res = conn.execute('select * from users order by user_id')
+            conn.execute(
+                "insert into users (user_id, user_name) "
+                "values (:id, :name)",
+                {"id": 1, "name": "jack"},
+            )
+            conn.execute(
+                "insert into users (user_id, user_name) "
+                "values (:id, :name)",
+                {"id": 2, "name": "ed"},
+                {"id": 3, "name": "horse"},
+            )
+            conn.execute(
+                "insert into users (user_id, user_name) "
+                "values (:id, :name)",
+                id=4,
+                name="sally",
+            )
+            res = conn.execute("select * from users order by user_id")
             assert res.fetchall() == [
-                (1, 'jack'), (2, 'ed'), (3, 'horse'), (4, 'sally')]
-            conn.execute('delete from users')
+                (1, "jack"),
+                (2, "ed"),
+                (3, "horse"),
+                (4, "sally"),
+            ]
+            conn.execute("delete from users")
+
         go(testing.db)
         conn = testing.db.connect()
         try:
@@ -238,12 +314,13 @@ class ExecuteTest(fixtures.TestBase):
             assert_raises_message(
                 tsa.exc.DBAPIError,
                 r"not_a_valid_statement",
-                _c.execute, 'not_a_valid_statement'
+                _c.execute,
+                "not_a_valid_statement",
             )

     @testing.requires.sqlite
     def test_exception_wrapping_non_dbapi_error(self):
-        e = create_engine('sqlite://')
+        e = create_engine("sqlite://")
         e.dialect.is_disconnect = is_disconnect = Mock()

         with e.connect() as c:
@@ -251,13 +328,12 @@ class ExecuteTest(fixtures.TestBase):
                 return_value=Mock(
                     execute=Mock(
                         side_effect=TypeError("I'm not a DBAPI error")
-                    ))
+                    )
+                )
             )
             assert_raises_message(
-                TypeError,
-                "I'm not a DBAPI error",
-                c.execute, "select "
+                TypeError, "I'm not a DBAPI error", c.execute, "select "
             )
             eq_(is_disconnect.call_count, 0)
@@ -274,16 +350,17 @@ class ExecuteTest(fixtures.TestBase):
         with nested(
             patch.object(testing.db.dialect, "dbapi", Mock(Error=DBAPIError)),
             patch.object(
-                testing.db.dialect, "is_disconnect",
-                lambda *arg: False),
+                testing.db.dialect, "is_disconnect", lambda *arg: False
+            ),
             patch.object(
-                testing.db.dialect, "do_execute",
-                Mock(side_effect=NonStandardException)),
+                testing.db.dialect,
+                "do_execute",
+                Mock(side_effect=NonStandardException),
+            ),
         ):
             with testing.db.connect() as conn:
                 assert_raises(
-                    tsa.exc.OperationalError,
-                    conn.execute, "select 1"
+                    tsa.exc.OperationalError, conn.execute, "select 1"
                 )

     def test_exception_wrapping_non_dbapi_statement(self):
@@ -299,11 +376,9 @@ class ExecuteTest(fixtures.TestBase):
                 r"\(test.engine.test_execute.SomeException\) "
                 r"nope \[SQL\: u?'SELECT 1 ",
                 conn.execute,
-                select([1]).
-                where(
-                    column('foo') == literal('bar', MyType())
-                )
+                select([1]).where(column("foo") == literal("bar", MyType())),
             )
+
         _go(testing.db)
         conn = testing.db.connect()
         try:
@@ -314,39 +389,43 @@ class ExecuteTest(fixtures.TestBase):
     def test_not_an_executable(self):
         for obj in (
             Table("foo", MetaData(), Column("x", Integer)),
-            Column('x', Integer),
+            Column("x", Integer),
             tsa.and_(),
-            column('foo'),
+            column("foo"),
             tsa.and_().compile(),
-            column('foo').compile(),
+            column("foo").compile(),
             MetaData(),
             Integer(),
-            tsa.Index(name='foo'),
-            tsa.UniqueConstraint('x')
+            tsa.Index(name="foo"),
+            tsa.UniqueConstraint("x"),
         ):
             with testing.db.connect() as conn:
                 assert_raises_message(
                     tsa.exc.ObjectNotExecutableError,
                     "Not an executable object",
-                    conn.execute, obj
+                    conn.execute,
+                    obj,
                 )
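The exception-wrapping tests assert that failures raised by the DBAPI surface as subclasses of sqlalchemy.exc.DBAPIError with the offending statement attached, while non-DBAPI exceptions (such as the TypeError above) propagate unwrapped. A minimal sketch of catching the wrapped form, assuming a throwaway SQLite database:

    import sqlalchemy as tsa
    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")  # illustrative in-memory database
    conn = engine.connect()
    try:
        conn.execute("not_a_valid_statement")
    except tsa.exc.DBAPIError as err:
        # the original DBAPI exception and the failed statement are preserved
        print(type(err.orig).__name__, err.statement)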
     def test_stmt_exception_non_ascii(self):
-        name = util.u('méil')
+        name = util.u("méil")
         with testing.db.connect() as conn:
             assert_raises_message(
                 tsa.exc.StatementError,
                 util.u(
                     "A value is required for bind parameter 'uname'"
-                    r'.*SELECT users.user_name AS .m\\xe9il.') if util.py2k
-                else
-                util.u(
+                    r".*SELECT users.user_name AS .m\\xe9il."
+                )
+                if util.py2k
+                else util.u(
                     "A value is required for bind parameter 'uname'"
-                    '.*SELECT users.user_name AS .méil.'),
+                    ".*SELECT users.user_name AS .méil."
+                ),
                 conn.execute,
                 select([users.c.user_name.label(name)]).where(
-                    users.c.user_name == bindparam("uname")),
-                {'uname_incorrect': 'foo'}
+                    users.c.user_name == bindparam("uname")
+                ),
+                {"uname_incorrect": "foo"},
             )

     def test_stmt_exception_pickleable_no_dbapi(self):
@@ -354,17 +433,20 @@ class ExecuteTest(fixtures.TestBase):

     @testing.crashes(
         "postgresql+psycopg2",
-        "Older versions don't support cursor pickling, newer ones do")
+        "Older versions don't support cursor pickling, newer ones do",
+    )
     @testing.fails_on(
         "mysql+oursql",
-        "Exception doesn't come back exactly the same from pickle")
+        "Exception doesn't come back exactly the same from pickle",
+    )
     @testing.fails_on(
         "mysql+mysqlconnector",
-        "Exception doesn't come back exactly the same from pickle")
+        "Exception doesn't come back exactly the same from pickle",
+    )
     @testing.fails_on(
         "oracle+cx_oracle",
-        "cx_oracle exception seems to be having "
-        "some issue with pickling")
+        "cx_oracle exception seems to be having " "some issue with pickling",
+    )
     def test_stmt_exception_pickleable_plus_dbapi(self):
         raw = testing.db.raw_connection()
         the_orig = None
@@ -381,17 +463,17 @@ class ExecuteTest(fixtures.TestBase):

     def _test_stmt_exception_pickleable(self, orig):
         for sa_exc in (
-            tsa.exc.StatementError("some error",
-                                   "select * from table",
-                                   {"foo": "bar"},
-                                   orig),
-            tsa.exc.InterfaceError("select * from table",
-                                   {"foo": "bar"},
-                                   orig),
+            tsa.exc.StatementError(
+                "some error", "select * from table", {"foo": "bar"}, orig
+            ),
+            tsa.exc.InterfaceError(
+                "select * from table", {"foo": "bar"}, orig
+            ),
             tsa.exc.NoReferencedTableError("message", "tname"),
             tsa.exc.NoReferencedColumnError("message", "tname", "cname"),
             tsa.exc.CircularDependencyError(
-                "some message", [1, 2, 3], [(1, 2), (3, 4)]),
+                "some message", [1, 2, 3], [(1, 2), (3, 4)]
+            ),
         ):
             for loads, dumps in picklers():
                 repickled = loads(dumps(sa_exc))
@@ -400,8 +482,10 @@ class ExecuteTest(fixtures.TestBase):
                     eq_(repickled.params, {"foo": "bar"})
                     eq_(repickled.statement, sa_exc.statement)
                 if hasattr(sa_exc, "connection_invalidated"):
-                    eq_(repickled.connection_invalidated,
-                        sa_exc.connection_invalidated)
+                    eq_(
+                        repickled.connection_invalidated,
+                        sa_exc.connection_invalidated,
+                    )
                 eq_(repickled.orig.args[0], orig.args[0])

     def test_dont_wrap_mixin(self):
@@ -419,11 +503,9 @@ class ExecuteTest(fixtures.TestBase):
                 MyException,
                 "nope",
                 conn.execute,
-                select([1]).
-                where(
-                    column('foo') == literal('bar', MyType())
-                )
+                select([1]).where(column("foo") == literal("bar", MyType())),
             )
+
         _go(testing.db)
         conn = testing.db.connect()
         try:
@@ -434,14 +516,16 @@ class ExecuteTest(fixtures.TestBase):
     def test_empty_insert(self):
         """test that execute() interprets [] as a list with no params"""

-        testing.db.execute(users_autoinc.insert().
-                           values(user_name=bindparam('name', None)), [])
+        testing.db.execute(
+            users_autoinc.insert().values(user_name=bindparam("name", None)),
+            [],
+        )
         eq_(testing.db.execute(users_autoinc.select()).fetchall(), [(1, None)])

     @testing.only_on("sqlite")
     def test_execute_compiled_favors_compiled_paramstyle(self):
         with patch.object(testing.db.dialect, "do_execute") as do_exec:
-            stmt = users.update().values(user_id=1, user_name='foo')
+            stmt = users.update().values(user_id=1, user_name="foo")

             d1 = default.DefaultDialect(paramstyle="format")
             d2 = default.DefaultDialect(paramstyle="pyformat")
@@ -450,62 +534,63 @@ class ExecuteTest(fixtures.TestBase):
             testing.db.execute(stmt.compile(dialect=d2))

             eq_(
-                do_exec.mock_calls, [
+                do_exec.mock_calls,
+                [
                     call(
                         mock.ANY,
                         "UPDATE users SET user_id=%s, user_name=%s",
-                        (1, 'foo'),
-                        mock.ANY
+                        (1, "foo"),
+                        mock.ANY,
                     ),
                     call(
                         mock.ANY,
                         "UPDATE users SET user_id=%(user_id)s, "
                         "user_name=%(user_name)s",
-                        {'user_name': 'foo', 'user_id': 1},
-                        mock.ANY
-                    )
-                ]
+                        {"user_name": "foo", "user_id": 1},
+                        mock.ANY,
+                    ),
+                ],
             )

     @testing.requires.ad_hoc_engines
     def test_engine_level_options(self):
-        eng = engines.testing_engine(options={'execution_options':
-                                              {'foo': 'bar'}})
+        eng = engines.testing_engine(
+            options={"execution_options": {"foo": "bar"}}
+        )
         with eng.contextual_connect() as conn:
-            eq_(conn._execution_options['foo'], 'bar')
+            eq_(conn._execution_options["foo"], "bar")
             eq_(
-                conn.execution_options(bat='hoho')._execution_options['foo'],
-                'bar')
+                conn.execution_options(bat="hoho")._execution_options["foo"],
+                "bar",
+            )
             eq_(
-                conn.execution_options(bat='hoho')._execution_options['bat'],
-                'hoho')
+                conn.execution_options(bat="hoho")._execution_options["bat"],
+                "hoho",
+            )
             eq_(
-                conn.execution_options(foo='hoho')._execution_options['foo'],
-                'hoho')
-            eng.update_execution_options(foo='hoho')
+                conn.execution_options(foo="hoho")._execution_options["foo"],
+                "hoho",
+            )
+            eng.update_execution_options(foo="hoho")
             conn = eng.contextual_connect()
-            eq_(conn._execution_options['foo'], 'hoho')
+            eq_(conn._execution_options["foo"], "hoho")

     @testing.requires.ad_hoc_engines
     def test_generative_engine_execution_options(self):
-        eng = engines.testing_engine(options={'execution_options':
-                                              {'base': 'x1'}})
+        eng = engines.testing_engine(
+            options={"execution_options": {"base": "x1"}}
+        )

         eng1 = eng.execution_options(foo="b1")
         eng2 = eng.execution_options(foo="b2")
         eng1a = eng1.execution_options(bar="a1")
         eng2a = eng2.execution_options(foo="b3", bar="a2")

-        eq_(eng._execution_options,
-            {'base': 'x1'})
-        eq_(eng1._execution_options,
-            {'base': 'x1', 'foo': 'b1'})
-        eq_(eng2._execution_options,
-            {'base': 'x1', 'foo': 'b2'})
-        eq_(eng1a._execution_options,
-            {'base': 'x1', 'foo': 'b1', 'bar': 'a1'})
-        eq_(eng2a._execution_options,
-            {'base': 'x1', 'foo': 'b3', 'bar': 'a2'})
+        eq_(eng._execution_options, {"base": "x1"})
+        eq_(eng1._execution_options, {"base": "x1", "foo": "b1"})
+        eq_(eng2._execution_options, {"base": "x1", "foo": "b2"})
+        eq_(eng1a._execution_options, {"base": "x1", "foo": "b1", "bar": "a1"})
+        eq_(eng2a._execution_options, {"base": "x1", "foo": "b3", "bar": "a2"})
         is_(eng1a.pool, eng.pool)

         # test pool is shared
@@ -526,7 +611,7 @@ class ExecuteTest(fixtures.TestBase):
         eng = create_engine(testing.db.url)

         def my_init(connection):
-            connection.execution_options(foo='bar').execute(select([1]))
+            connection.execution_options(foo="bar").execute(select([1]))

         with patch.object(eng.dialect, "initialize", my_init):
             conn = eng.connect()
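test_generative_engine_execution_options relies on Engine.execution_options() returning a lightweight "option engine" that shares the parent's pool and dialect while layering additional option defaults. A short sketch of the generative pattern (note that _execution_options is an internal attribute, inspected here only the way the tests do):

    from sqlalchemy import create_engine

    eng = create_engine("sqlite://")            # illustrative engine
    eng1 = eng.execution_options(foo="b1")      # child engine, same pool
    eng1a = eng1.execution_options(bar="a1")    # options accumulate

    assert eng1a.pool is eng.pool
    print(eng1a._execution_options)  # {'foo': 'b1', 'bar': 'a1'}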
@@ -537,25 +622,26 @@ class ExecuteTest(fixtures.TestBase):
     def test_generative_engine_event_dispatch_hasevents(self):
         def l1(*arg, **kw):
             pass
+
         eng = create_engine(testing.db.url)
         assert not eng._has_events
         event.listen(eng, "before_execute", l1)
-        eng2 = eng.execution_options(foo='bar')
+        eng2 = eng.execution_options(foo="bar")
         assert eng2._has_events

     def test_unicode_test_fails_warning(self):
         class MockCursor(engines.DBAPIProxyCursor):
-
             def execute(self, stmt, params=None, **kw):
                 if "test unicode returns" in stmt:
                     raise self.engine.dialect.dbapi.DatabaseError("boom")
                 else:
                     return super(MockCursor, self).execute(stmt, params, **kw)
+
         eng = engines.proxying_engine(cursor_cls=MockCursor)
         assert_raises_message(
             tsa.exc.SAWarning,
             "Exception attempting to detect unicode returns",
-            eng.connect
+            eng.connect,
         )
         assert eng.dialect.returns_unicode_strings in (True, False)
         eng.dispose()
@@ -578,17 +664,20 @@ class ConvenienceExecuteTest(fixtures.TablesTest):

     @classmethod
     def define_tables(cls, metadata):
-        cls.table = Table('exec_test', metadata,
-                          Column('a', Integer),
-                          Column('b', Integer),
-                          test_needs_acid=True
-                          )
+        cls.table = Table(
+            "exec_test",
+            metadata,
+            Column("a", Integer),
+            Column("b", Integer),
+            test_needs_acid=True,
+        )

     def _trans_fn(self, is_transaction=False):
         def go(conn, x, value=None):
             if is_transaction:
                 conn = conn.connection
             conn.execute(self.table.insert().values(a=x, b=value))
+
         return go

     def _trans_rollback_fn(self, is_transaction=False):
@@ -597,19 +686,19 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
                 conn = conn.connection
             conn.execute(self.table.insert().values(a=x, b=value))
             raise SomeException("breakage")
+
         return go

     def _assert_no_data(self):
         eq_(
             testing.db.scalar(
-                select([func.count('*')]).select_from(self.table)), 0
+                select([func.count("*")]).select_from(self.table)
+            ),
+            0,
         )

     def _assert_fn(self, x, value=None):
-        eq_(
-            testing.db.execute(self.table.select()).fetchall(),
-            [(x, value)]
-        )
+        eq_(testing.db.execute(self.table.select()).fetchall(), [(x, value)])

     def test_transaction_engine_ctx_commit(self):
         fn = self._trans_fn()
@@ -621,20 +710,12 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
         engine = engines.testing_engine()

         mock_connection = Mock(
-            return_value=Mock(
-                begin=Mock(side_effect=Exception("boom"))
-            )
+            return_value=Mock(begin=Mock(side_effect=Exception("boom")))
         )
         engine._connection_cls = mock_connection
-        assert_raises(
-            Exception,
-            engine.begin
-        )
+        assert_raises(Exception, engine.begin)

-        eq_(
-            mock_connection.return_value.close.mock_calls,
-            [call()]
-        )
+        eq_(mock_connection.return_value.close.mock_calls, [call()])

     def test_transaction_engine_ctx_rollback(self):
         fn = self._trans_rollback_fn()
@@ -642,29 +723,37 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
         assert_raises_message(
             Exception,
             "breakage",
-            testing.run_as_contextmanager, ctx, fn, 5, value=8
+            testing.run_as_contextmanager,
+            ctx,
+            fn,
+            5,
+            value=8,
         )
         self._assert_no_data()

     def test_transaction_tlocal_engine_ctx_commit(self):
         fn = self._trans_fn()
-        engine = engines.testing_engine(options=dict(
-            strategy='threadlocal',
-            pool=testing.db.pool))
+        engine = engines.testing_engine(
+            options=dict(strategy="threadlocal", pool=testing.db.pool)
+        )
         ctx = engine.begin()
         testing.run_as_contextmanager(ctx, fn, 5, value=8)
         self._assert_fn(5, value=8)

     def test_transaction_tlocal_engine_ctx_rollback(self):
         fn = self._trans_rollback_fn()
-        engine = engines.testing_engine(options=dict(
-            strategy='threadlocal',
-            pool=testing.db.pool))
+        engine = engines.testing_engine(
+            options=dict(strategy="threadlocal", pool=testing.db.pool)
+        )
         ctx = engine.begin()
         assert_raises_message(
             Exception,
             "breakage",
-            testing.run_as_contextmanager, ctx, fn, 5, value=8
+            testing.run_as_contextmanager,
+            ctx,
+            fn,
+            5,
+            value=8,
         )
         self._assert_no_data()
@@ -682,7 +771,11 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
         assert_raises_message(
             Exception,
             "breakage",
-            testing.run_as_contextmanager, ctx, fn, 5, value=8
+            testing.run_as_contextmanager,
+            ctx,
+            fn,
+            5,
+            value=8,
         )
         self._assert_no_data()
@@ -693,7 +786,7 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
         # autocommit is on
         self._assert_fn(5, value=8)

-    @testing.fails_on('mysql+oursql', "oursql bug ? getting wrong rowcount")
+    @testing.fails_on("mysql+oursql", "oursql bug ? getting wrong rowcount")
     def test_connect_as_ctx_noautocommit(self):
         fn = self._trans_fn()
         self._assert_no_data()
@@ -712,9 +805,7 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
     def test_transaction_engine_fn_rollback(self):
         fn = self._trans_rollback_fn()
         assert_raises_message(
-            Exception,
-            "breakage",
-            testing.db.transaction, fn, 5, value=8
+            Exception, "breakage", testing.db.transaction, fn, 5, value=8
         )
         self._assert_no_data()
@@ -727,10 +818,7 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
     def test_transaction_connection_fn_rollback(self):
         fn = self._trans_rollback_fn()
         with testing.db.connect() as conn:
-            assert_raises(
-                Exception,
-                conn.transaction, fn, 5, value=8
-            )
+            assert_raises(Exception, conn.transaction, fn, 5, value=8)
         self._assert_no_data()
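ConvenienceExecuteTest drives the Engine.begin() / Engine.transaction() convenience APIs, which commit when the enclosed work succeeds and roll back when it raises. A minimal sketch of the context-manager form, assuming a throwaway SQLite database with the same exec_test table as the fixture:

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")  # illustrative in-memory database
    engine.execute("create table exec_test (a integer, b integer)")

    # commits if the block completes, rolls back if it raises
    with engine.begin() as conn:
        conn.execute("insert into exec_test (a, b) values (?, ?)", (5, 8))

    print(engine.execute("select * from exec_test").fetchall())  # [(5, 8)]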
@@ -741,12 +829,15 @@ class CompiledCacheTest(fixtures.TestBase):
     def setup_class(cls):
         global users, metadata
         metadata = MetaData(testing.db)
-        users = Table('users', metadata,
-                      Column('user_id', INT, primary_key=True,
-                             test_needs_autoincrement=True),
-                      Column('user_name', VARCHAR(20)),
-                      Column("extra_data", VARCHAR(20))
-                      )
+        users = Table(
+            "users",
+            metadata,
+            Column(
+                "user_id", INT, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("user_name", VARCHAR(20)),
+            Column("extra_data", VARCHAR(20)),
+        )
         metadata.create_all()

     @engines.close_first
@@ -764,17 +855,19 @@ class CompiledCacheTest(fixtures.TestBase):
         ins = users.insert()
         with patch.object(
-                ins, "compile",
-                Mock(side_effect=ins.compile)) as compile_mock:
-            cached_conn.execute(ins, {'user_name': 'u1'})
-            cached_conn.execute(ins, {'user_name': 'u2'})
-            cached_conn.execute(ins, {'user_name': 'u3'})
+            ins, "compile", Mock(side_effect=ins.compile)
+        ) as compile_mock:
+            cached_conn.execute(ins, {"user_name": "u1"})
+            cached_conn.execute(ins, {"user_name": "u2"})
+            cached_conn.execute(ins, {"user_name": "u3"})
         eq_(compile_mock.call_count, 1)
         assert len(cache) == 1
         eq_(conn.execute("select count(*) from users").scalar(), 3)

-    @testing.only_on(["sqlite", "mysql", "postgresql"],
-                     "uses blob value that is problematic for some DBAPIs")
+    @testing.only_on(
+        ["sqlite", "mysql", "postgresql"],
+        "uses blob value that is problematic for some DBAPIs",
+    )
     @testing.provide_metadata
     def test_cache_noleak_on_statement_values(self):
         # This is a non regression test for an object reference leak caused
@@ -782,10 +875,12 @@ class CompiledCacheTest(fixtures.TestBase):
         metadata = self.metadata
         photo = Table(
-            'photo', metadata,
-            Column('id', Integer, primary_key=True,
-                   test_needs_autoincrement=True),
-            Column('photo_blob', LargeBinary()),
+            "photo",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("photo_blob", LargeBinary()),
         )
         metadata.create_all()
@@ -801,9 +896,9 @@ class CompiledCacheTest(fixtures.TestBase):
         ins = photo.insert()
         with patch.object(
-                ins, "compile",
-                Mock(side_effect=ins.compile)) as compile_mock:
-            cached_conn.execute(ins, {'photo_blob': blob})
+            ins, "compile", Mock(side_effect=ins.compile)
+        ) as compile_mock:
+            cached_conn.execute(ins, {"photo_blob": blob})
         eq_(compile_mock.call_count, 1)
         eq_(len(cache), 1)
         eq_(conn.execute("select count(*) from photo").scalar(), 1)
@@ -820,35 +915,45 @@ class CompiledCacheTest(fixtures.TestBase):
         conn = testing.db.connect()
         conn.execute(
             users.insert(),
-            {"user_id": 1, "user_name": "u1", "extra_data": "e1"})
+            {"user_id": 1, "user_name": "u1", "extra_data": "e1"},
+        )
         cache = {}
         cached_conn = conn.execution_options(compiled_cache=cache)

         upd = users.update().where(users.c.user_id == bindparam("b_user_id"))
         with patch.object(
-                upd, "compile",
-                Mock(side_effect=upd.compile)) as compile_mock:
+            upd, "compile", Mock(side_effect=upd.compile)
+        ) as compile_mock:
             cached_conn.execute(
-                upd, util.OrderedDict([
-                    ("b_user_id", 1),
-                    ("user_name", "u2"),
-                    ("extra_data", "e2")
-                ])
+                upd,
+                util.OrderedDict(
+                    [
+                        ("b_user_id", 1),
+                        ("user_name", "u2"),
+                        ("extra_data", "e2"),
+                    ]
+                ),
             )
             cached_conn.execute(
-                upd, util.OrderedDict([
-                    ("b_user_id", 1),
-                    ("extra_data", "e3"),
-                    ("user_name", "u3"),
-                ])
+                upd,
+                util.OrderedDict(
+                    [
+                        ("b_user_id", 1),
+                        ("extra_data", "e3"),
+                        ("user_name", "u3"),
+                    ]
+                ),
             )
             cached_conn.execute(
-                upd, util.OrderedDict([
-                    ("extra_data", "e4"),
-                    ("user_name", "u4"),
-                    ("b_user_id", 1),
-                ])
+                upd,
+                util.OrderedDict(
+                    [
+                        ("extra_data", "e4"),
+                        ("user_name", "u4"),
+                        ("b_user_id", 1),
+                    ]
+                ),
             )
         eq_(compile_mock.call_count, 1)
         eq_(len(cache), 1)
@@ -856,88 +961,86 @@ class CompiledCacheTest(fixtures.TestBase):
     @testing.requires.schemas
     @testing.provide_metadata
     def test_schema_translate_in_key(self):
+        Table("x", self.metadata, Column("q", Integer))
         Table(
-            'x', self.metadata, Column('q', Integer))
-        Table(
-            'x', self.metadata, Column('q', Integer),
-            schema=config.test_schema)
+            "x", self.metadata, Column("q", Integer), schema=config.test_schema
+        )
         self.metadata.create_all()

         m = MetaData()
-        t1 = Table('x', m, Column('q', Integer))
+        t1 = Table("x", m, Column("q", Integer))
         ins = t1.insert()
         stmt = select([t1.c.q])

         cache = {}

         with config.db.connect().execution_options(
-            compiled_cache=cache,
+            compiled_cache=cache
         ) as conn:
             conn.execute(ins, {"q": 1})
             eq_(conn.scalar(stmt), 1)

         with config.db.connect().execution_options(
             compiled_cache=cache,
-            schema_translate_map={None: config.test_schema}
+            schema_translate_map={None: config.test_schema},
         ) as conn:
             conn.execute(ins, {"q": 2})
             eq_(conn.scalar(stmt), 2)

         with config.db.connect().execution_options(
-            compiled_cache=cache,
+            compiled_cache=cache
         ) as conn:
             eq_(conn.scalar(stmt), 1)


 class MockStrategyTest(fixtures.TestBase):
-
     def _engine_fixture(self):
         buf = util.StringIO()

         def dump(sql, *multiparams, **params):
             buf.write(util.text_type(sql.compile(dialect=engine.dialect)))
-        engine = create_engine('postgresql://', strategy='mock', executor=dump)
+
+        engine = create_engine("postgresql://", strategy="mock", executor=dump)
         return engine, buf

     def test_sequence_not_duped(self):
         engine, buf = self._engine_fixture()
         metadata = MetaData()
-        t = Table('testtable', metadata,
-                  Column('pk',
-                         Integer,
-                         Sequence('testtable_pk_seq'),
-                         primary_key=True))
+        t = Table(
+            "testtable",
+            metadata,
+            Column(
+                "pk", Integer, Sequence("testtable_pk_seq"), primary_key=True
+            ),
+        )

         t.create(engine)
         t.drop(engine)

-        eq_(
-            re.findall(r'CREATE (\w+)', buf.getvalue()),
-            ["SEQUENCE", "TABLE"]
-        )
+        eq_(re.findall(r"CREATE (\w+)", buf.getvalue()), ["SEQUENCE", "TABLE"])

-        eq_(
-            re.findall(r'DROP (\w+)', buf.getvalue()),
-            ["TABLE", "SEQUENCE"]
-        )
+        eq_(re.findall(r"DROP (\w+)", buf.getvalue()), ["TABLE", "SEQUENCE"])


 class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
-    __requires__ = 'schemas',
+    __requires__ = ("schemas",)
     __backend__ = True

     def test_create_table(self):
         map_ = {
             None: config.test_schema,
-            "foo": config.test_schema, "bar": None}
+            "foo": config.test_schema,
+            "bar": None,
+        }

         metadata = MetaData()
-        t1 = Table('t1', metadata, Column('x', Integer))
-        t2 = Table('t2', metadata, Column('x', Integer), schema="foo")
-        t3 = Table('t3', metadata, Column('x', Integer), schema="bar")
+        t1 = Table("t1", metadata, Column("x", Integer))
+        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
+        t3 = Table("t3", metadata, Column("x", Integer), schema="bar")

         with self.sql_execution_asserter(config.db) as asserter:
             with config.db.connect().execution_options(
-                    schema_translate_map=map_) as conn:
+                schema_translate_map=map_
+            ) as conn:

                 t1.create(conn)
                 t2.create(conn)
@@ -953,46 +1056,46 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
             CompiledSQL("CREATE TABLE t3 (x INTEGER)"),
             CompiledSQL("DROP TABLE t3"),
             CompiledSQL("DROP TABLE %s.t2" % config.test_schema),
-            CompiledSQL("DROP TABLE %s.t1" % config.test_schema)
+            CompiledSQL("DROP TABLE %s.t1" % config.test_schema),
         )

     def _fixture(self):
         metadata = self.metadata
-        Table(
-            't1', metadata, Column('x', Integer),
-            schema=config.test_schema)
-        Table(
-            't2', metadata, Column('x', Integer),
-            schema=config.test_schema)
-        Table('t3', metadata, Column('x', Integer), schema=None)
+        Table("t1", metadata, Column("x", Integer), schema=config.test_schema)
+        Table("t2", metadata, Column("x", Integer), schema=config.test_schema)
+        Table("t3", metadata, Column("x", Integer), schema=None)
         metadata.create_all()

     def test_ddl_hastable(self):
         map_ = {
             None: config.test_schema,
-            "foo": config.test_schema, "bar": None}
+            "foo": config.test_schema,
+            "bar": None,
+        }

         metadata = MetaData()
-        Table('t1', metadata, Column('x', Integer))
-        Table('t2', metadata, Column('x', Integer), schema="foo")
-        Table('t3', metadata, Column('x', Integer), schema="bar")
+        Table("t1", metadata, Column("x", Integer))
+        Table("t2", metadata, Column("x", Integer), schema="foo")
+        Table("t3", metadata, Column("x", Integer), schema="bar")

         with config.db.connect().execution_options(
-                schema_translate_map=map_) as conn:
+            schema_translate_map=map_
+        ) as conn:
             metadata.create_all(conn)

-        assert config.db.has_table('t1', schema=config.test_schema)
-        assert config.db.has_table('t2', schema=config.test_schema)
-        assert config.db.has_table('t3', schema=None)
+        assert config.db.has_table("t1", schema=config.test_schema)
+        assert config.db.has_table("t2", schema=config.test_schema)
+        assert config.db.has_table("t3", schema=None)

         with config.db.connect().execution_options(
-                schema_translate_map=map_) as conn:
+            schema_translate_map=map_
+        ) as conn:
             metadata.drop_all(conn)

-        assert not config.db.has_table('t1', schema=config.test_schema)
-        assert not config.db.has_table('t2', schema=config.test_schema)
-        assert not config.db.has_table('t3', schema=None)
+        assert not config.db.has_table("t1", schema=config.test_schema)
+        assert not config.db.has_table("t2", schema=config.test_schema)
+        assert not config.db.has_table("t3", schema=None)

     @testing.provide_metadata
     def test_crud(self):
@@ -1000,20 +1103,23 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         map_ = {
             None: config.test_schema,
-            "foo": config.test_schema, "bar": None}
+            "foo": config.test_schema,
+            "bar": None,
+        }

         metadata = MetaData()
-        t1 = Table('t1', metadata, Column('x', Integer))
-        t2 = Table('t2', metadata, Column('x', Integer), schema="foo")
-        t3 = Table('t3', metadata, Column('x', Integer), schema="bar")
+        t1 = Table("t1", metadata, Column("x", Integer))
+        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
+        t3 = Table("t3", metadata, Column("x", Integer), schema="bar")

         with self.sql_execution_asserter(config.db) as asserter:
             with config.db.connect().execution_options(
-                    schema_translate_map=map_) as conn:
+                schema_translate_map=map_
+            ) as conn:

-                conn.execute(t1.insert(), {'x': 1})
-                conn.execute(t2.insert(), {'x': 1})
-                conn.execute(t3.insert(), {'x': 1})
+                conn.execute(t1.insert(), {"x": 1})
+                conn.execute(t2.insert(), {"x": 1})
+                conn.execute(t3.insert(), {"x": 1})

                 conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
                 conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
@@ -1029,26 +1135,33 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         asserter.assert_(
             CompiledSQL(
-                "INSERT INTO %s.t1 (x) VALUES (:x)" % config.test_schema),
+                "INSERT INTO %s.t1 (x) VALUES (:x)" % config.test_schema
+            ),
             CompiledSQL(
-                "INSERT INTO %s.t2 (x) VALUES (:x)" % config.test_schema),
+                "INSERT INTO %s.t2 (x) VALUES (:x)" % config.test_schema
+            ),
+            CompiledSQL("INSERT INTO t3 (x) VALUES (:x)"),
             CompiledSQL(
-                "INSERT INTO t3 (x) VALUES (:x)"),
+                "UPDATE %s.t1 SET x=:x WHERE %s.t1.x = :x_1"
+                % (config.test_schema, config.test_schema)
+            ),
             CompiledSQL(
-                "UPDATE %s.t1 SET x=:x WHERE %s.t1.x = :x_1" % (
-                    config.test_schema, config.test_schema)),
-            CompiledSQL(
-                "UPDATE %s.t2 SET x=:x WHERE %s.t2.x = :x_1" % (
-                    config.test_schema, config.test_schema)),
+                "UPDATE %s.t2 SET x=:x WHERE %s.t2.x = :x_1"
+                % (config.test_schema, config.test_schema)
+            ),
             CompiledSQL("UPDATE t3 SET x=:x WHERE t3.x = :x_1"),
-            CompiledSQL("SELECT %s.t1.x FROM %s.t1" % (
-                config.test_schema, config.test_schema)),
-            CompiledSQL("SELECT %s.t2.x FROM %s.t2" % (
-                config.test_schema, config.test_schema)),
+            CompiledSQL(
+                "SELECT %s.t1.x FROM %s.t1"
+                % (config.test_schema, config.test_schema)
+            ),
+            CompiledSQL(
+                "SELECT %s.t2.x FROM %s.t2"
+                % (config.test_schema, config.test_schema)
+            ),
             CompiledSQL("SELECT t3.x FROM t3"),
             CompiledSQL("DELETE FROM %s.t1" % config.test_schema),
             CompiledSQL("DELETE FROM %s.t2" % config.test_schema),
-            CompiledSQL("DELETE FROM t3")
+            CompiledSQL("DELETE FROM t3"),
        )
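SchemaTranslateTest exercises the schema_translate_map execution option, which rewrites schema names at statement-compile time: a {None: "some_schema"} entry relocates schema-less tables into a target schema, while "foo"/"bar" keys remap explicitly named schemas. A sketch of the option on a connection, with an illustrative PostgreSQL URL and schema name:

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    metadata = MetaData()
    t1 = Table("t1", metadata, Column("x", Integer))  # no schema of its own

    engine = create_engine("postgresql://scott:tiger@localhost/test")  # illustrative
    with engine.connect().execution_options(
        schema_translate_map={None: "test_schema"}
    ) as conn:
        # compiles as "INSERT INTO test_schema.t1 (x) VALUES (...)"
        conn.execute(t1.insert(), {"x": 1})

As test_schema_translate_in_key above verifies, the translate map participates in the compiled-cache key, so translated and untranslated statements do not collide in the cache.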
     @testing.provide_metadata
@@ -1057,23 +1170,26 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         map_ = {
             None: config.test_schema,
-            "foo": config.test_schema, "bar": None}
+            "foo": config.test_schema,
+            "bar": None,
+        }

         metadata = MetaData()
-        t2 = Table('t2', metadata, Column('x', Integer), schema="foo")
+        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")

         with self.sql_execution_asserter(config.db) as asserter:
             eng = config.db.execution_options(schema_translate_map=map_)
             conn = eng.connect()
             conn.execute(select([t2.c.x]))

         asserter.assert_(
-            CompiledSQL("SELECT %s.t2.x FROM %s.t2" % (
-                config.test_schema, config.test_schema)),
+            CompiledSQL(
+                "SELECT %s.t2.x FROM %s.t2"
+                % (config.test_schema, config.test_schema)
+            )
         )


 class ExecutionOptionsTest(fixtures.TestBase):
-
     def test_dialect_conn_options(self):
         engine = testing_engine("sqlite://", options=dict(_initialize=False))
         engine.dialect = Mock()
@@ -1081,7 +1197,7 @@ class ExecutionOptionsTest(fixtures.TestBase):
         c2 = conn.execution_options(foo="bar")
         eq_(
             engine.dialect.set_connection_execution_options.mock_calls,
-            [call(c2, {"foo": "bar"})]
+            [call(c2, {"foo": "bar"})],
         )

     def test_dialect_engine_options(self):
@@ -1090,27 +1206,30 @@ class ExecutionOptionsTest(fixtures.TestBase):
         e2 = engine.execution_options(foo="bar")
         eq_(
             engine.dialect.set_engine_execution_options.mock_calls,
-            [call(e2, {"foo": "bar"})]
+            [call(e2, {"foo": "bar"})],
         )

     def test_dialect_engine_construction_options(self):
         dialect = Mock()
-        engine = Engine(Mock(), dialect, Mock(),
-                        execution_options={"foo": "bar"})
+        engine = Engine(
+            Mock(), dialect, Mock(), execution_options={"foo": "bar"}
+        )
         eq_(
             dialect.set_engine_execution_options.mock_calls,
-            [call(engine, {"foo": "bar"})]
+            [call(engine, {"foo": "bar"})],
         )

     def test_propagate_engine_to_connection(self):
-        engine = testing_engine("sqlite://",
-                                options=dict(execution_options={"foo": "bar"}))
+        engine = testing_engine(
+            "sqlite://", options=dict(execution_options={"foo": "bar"})
+        )
         conn = engine.connect()
         eq_(conn._execution_options, {"foo": "bar"})

     def test_propagate_option_engine_to_connection(self):
-        e1 = testing_engine("sqlite://",
-                            options=dict(execution_options={"foo": "bar"}))
+        e1 = testing_engine(
+            "sqlite://", options=dict(execution_options={"foo": "bar"})
+        )
         e2 = e1.execution_options(bat="hoho")
         c1 = e1.connect()
         c2 = e2.connect()
@@ -1123,14 +1242,11 @@ class ExecutionOptionsTest(fixtures.TestBase):
         conn = engine.connect()
         c2 = conn.execution_options(foo="bar")
         c2_branch = c2.connect()
-        eq_(
-            c2_branch._execution_options,
-            {"foo": "bar"}
-        )
+        eq_(c2_branch._execution_options, {"foo": "bar"})


 class EngineEventsTest(fixtures.TestBase):
-    __requires__ = 'ad_hoc_engines',
+    __requires__ = ("ad_hoc_engines",)
     __backend__ = True

     def tearDown(self):
@@ -1143,12 +1259,13 @@ class EngineEventsTest(fixtures.TestBase):
             if not received:
                 assert False, "Nothing available for stmt: %s" % stmt
             while received:
-                teststmt, testparams, testmultiparams = \
-                    received.pop(0)
-                teststmt = re.compile(r'[\n\t ]+', re.M).sub(' ',
-                                                             teststmt).strip()
+                teststmt, testparams, testmultiparams = received.pop(0)
+                teststmt = (
+                    re.compile(r"[\n\t ]+", re.M).sub(" ", teststmt).strip()
+                )
                 if teststmt.startswith(stmt) and (
-                        testparams == params or testparams == posn):
+                    testparams == params or testparams == posn
+                ):
                     break

     def test_per_engine_independence(self):
@@ -1161,9 +1278,7 @@ class EngineEventsTest(fixtures.TestBase):
         s2 = select([2])
         e1.execute(s1)
         e2.execute(s2)
-        eq_(
-            [arg[1][1] for arg in canary.mock_calls], [s1]
-        )
+        eq_([arg[1][1] for arg in canary.mock_calls], [s1])
         event.listen(e2, "before_execute", canary)
         e1.execute(s1)
         e2.execute(s2)
@@ -1231,8 +1346,9 @@ class EngineEventsTest(fixtures.TestBase):

         event.listen(e1, "before_execute", canary.be1)

-        conn = e1._connection_cls(e1, connection=e1.raw_connection(),
-                                  _has_events=False)
+        conn = e1._connection_cls(
+            e1, connection=e1.raw_connection(), _has_events=False
+        )

         conn.execute(select([1]))
@@ -1254,14 +1370,19 @@ class EngineEventsTest(fixtures.TestBase):
         dialect = conn.dialect

         ctx = dialect.execution_ctx_cls._init_statement(
-            dialect, conn, conn.connection, stmt, {})
+            dialect, conn, conn.connection, stmt, {}
+        )

         ctx._execute_scalar(stmt, Integer())

-        eq_(canary.bce.mock_calls,
-            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)])
-        eq_(canary.ace.mock_calls,
-            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)])
+        eq_(
+            canary.bce.mock_calls,
+            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
+        )
+        eq_(
+            canary.ace.mock_calls,
+            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
+        )

     def test_cursor_events_execute(self):
         canary = Mock()
@@ -1277,10 +1398,14 @@ class EngineEventsTest(fixtures.TestBase):
         result = conn.execute(stmt)

         ctx = result.context
-        eq_(canary.bce.mock_calls,
-            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)])
-        eq_(canary.ace.mock_calls,
-            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)])
+        eq_(
+            canary.bce.mock_calls,
+            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
+        )
+        eq_(
+            canary.ace.mock_calls,
+            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
+        )

     def test_argument_format_execute(self):
         def before_execute(conn, clauseelement, multiparams, params):
@@ -1290,74 +1415,91 @@ class EngineEventsTest(fixtures.TestBase):
         def after_execute(conn, clauseelement, multiparams, params, result):
             assert isinstance(multiparams, (list, tuple))
             assert isinstance(params, dict)
+
         e1 = testing_engine(config.db_url)
-        event.listen(e1, 'before_execute', before_execute)
-        event.listen(e1, 'after_execute', after_execute)
+        event.listen(e1, "before_execute", before_execute)
+        event.listen(e1, "after_execute", after_execute)

         e1.execute(select([1]))
         e1.execute(select([1]).compile(dialect=e1.dialect).statement)
         e1.execute(select([1]).compile(dialect=e1.dialect))
         e1._execute_compiled(select([1]).compile(dialect=e1.dialect), (), {})

-    @testing.fails_on('firebird', 'Data type unknown')
+    @testing.fails_on("firebird", "Data type unknown")
     def test_execute_events(self):

         stmts = []
         cursor_stmts = []

-        def execute(conn, clauseelement, multiparams,
-                    params):
+        def execute(conn, clauseelement, multiparams, params):
             stmts.append((str(clauseelement), params, multiparams))

-        def cursor_execute(conn, cursor, statement, parameters,
-                           context, executemany):
+        def cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
             cursor_stmts.append((str(statement), parameters, None))

         for engine in [
-            engines.testing_engine(options=dict(implicit_returning=False)),
-            engines.testing_engine(options=dict(implicit_returning=False,
-                                                strategy='threadlocal')),
-            engines.testing_engine(options=dict(implicit_returning=False)).
-            connect()
+            engines.testing_engine(options=dict(implicit_returning=False)),
+            engines.testing_engine(
+                options=dict(implicit_returning=False, strategy="threadlocal")
+            ),
+            engines.testing_engine(
+                options=dict(implicit_returning=False)
+            ).connect(),
         ]:
-            event.listen(engine, 'before_execute', execute)
-            event.listen(engine, 'before_cursor_execute', cursor_execute)
+            event.listen(engine, "before_execute", execute)
+            event.listen(engine, "before_cursor_execute", cursor_execute)
             m = MetaData(engine)
-            t1 = Table('t1', m,
-                       Column('c1', Integer, primary_key=True),
-                       Column('c2', String(50), default=func.lower('Foo'),
-                              primary_key=True)
-                       )
+            t1 = Table(
+                "t1",
+                m,
+                Column("c1", Integer, primary_key=True),
+                Column(
+                    "c2",
+                    String(50),
+                    default=func.lower("Foo"),
+                    primary_key=True,
+                ),
+            )
             m.create_all()
             try:
-                t1.insert().execute(c1=5, c2='some data')
+                t1.insert().execute(c1=5, c2="some data")
                 t1.insert().execute(c1=6)
                 eq_(
-                    engine.execute('select * from t1').fetchall(),
-                    [(5, 'some data'), (6, 'foo')])
+                    engine.execute("select * from t1").fetchall(),
+                    [(5, "some data"), (6, "foo")],
+                )
             finally:
                 m.drop_all()
-            compiled = [('CREATE TABLE t1', {}, None),
-                        ('INSERT INTO t1 (c1, c2)',
-                         {'c2': 'some data', 'c1': 5}, None),
-                        ('INSERT INTO t1 (c1, c2)',
-                         {'c1': 6}, None),
-                        ('select * from t1', {}, None),
-                        ('DROP TABLE t1', {}, None)]
+            compiled = [
+                ("CREATE TABLE t1", {}, None),
+                (
+                    "INSERT INTO t1 (c1, c2)",
+                    {"c2": "some data", "c1": 5},
+                    None,
+                ),
+                ("INSERT INTO t1 (c1, c2)", {"c1": 6}, None),
+                ("select * from t1", {}, None),
+                ("DROP TABLE t1", {}, None),
+            ]
             cursor = [
-                ('CREATE TABLE t1', {}, ()),
-                ('INSERT INTO t1 (c1, c2)', {
-                    'c2': 'some data', 'c1': 5},
-                 (5, 'some data')),
-                ('SELECT lower', {'lower_1': 'Foo'},
-                 ('Foo', )),
-                ('INSERT INTO t1 (c1, c2)',
-                 {'c2': 'foo', 'c1': 6},
-                 (6, 'foo')),
-                ('select * from t1', {}, ()),
-                ('DROP TABLE t1', {}, ()),
+                ("CREATE TABLE t1", {}, ()),
+                (
+                    "INSERT INTO t1 (c1, c2)",
+                    {"c2": "some data", "c1": 5},
+                    (5, "some data"),
+                ),
+                ("SELECT lower", {"lower_1": "Foo"}, ("Foo",)),
+                (
+                    "INSERT INTO t1 (c1, c2)",
+                    {"c2": "foo", "c1": 6},
+                    (6, "foo"),
+                ),
+                ("select * from t1", {}, ()),
+                ("DROP TABLE t1", {}, ()),
            ]
             self._assert_stmts(compiled, stmts)
             self._assert_stmts(cursor, cursor_stmts)
@@ -1366,21 +1508,21 @@ class EngineEventsTest(fixtures.TestBase):
         canary = []

        def execute(conn, *args, **kw):
-            canary.append('execute')
+            canary.append("execute")

         def cursor_execute(conn, *args, **kw):
-            canary.append('cursor_execute')
+            canary.append("cursor_execute")

         engine = engines.testing_engine()
-        event.listen(engine, 'before_execute', execute)
-        event.listen(engine, 'before_cursor_execute', cursor_execute)
+        event.listen(engine, "before_execute", execute)
+        event.listen(engine, "before_cursor_execute", cursor_execute)
         conn = engine.connect()
-        c2 = conn.execution_options(foo='bar')
-        eq_(c2._execution_options, {'foo': 'bar'})
+        c2 = conn.execution_options(foo="bar")
+        eq_(c2._execution_options, {"foo": "bar"})
         c2.execute(select([1]))
-        c3 = c2.execution_options(bar='bat')
-        eq_(c3._execution_options, {'foo': 'bar', 'bar': 'bat'})
-        eq_(canary, ['execute', 'cursor_execute'])
+        c3 = c2.execution_options(bar="bat")
+        eq_(c3._execution_options, {"foo": "bar", "bar": "bat"})
+        eq_(canary, ["execute", "cursor_execute"])
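EngineEventsTest registers listeners with event.listen() for engine-level hooks such as "before_execute" and "before_cursor_execute". A minimal sketch of a passive listener using the decorator form, assuming a throwaway SQLite engine:

    from sqlalchemy import create_engine, event, select

    engine = create_engine("sqlite://")  # illustrative in-memory database

    @event.listens_for(engine, "before_cursor_execute")
    def log_sql(conn, cursor, statement, parameters, context, executemany):
        # observe every statement just before it reaches the DBAPI cursor
        print("SQL:", statement, parameters)

    engine.execute(select([1]))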
engines.testing_engine(options={'execution_options':
-                                              {'base': 'x1'}})
+        eng = engines.testing_engine(
+            options={"execution_options": {"base": "x1"}}
+        )
 
         event.listen(eng, "before_execute", l1)
 
         eng1 = eng.execution_options(foo="b1")
@@ -1429,8 +1572,9 @@ class EngineEventsTest(fixtures.TestBase):
 
         event.listen(Engine, "before_execute", l1)
 
-        eng = engines.testing_engine(options={'execution_options':
-                                              {'base': 'x1'}})
+        eng = engines.testing_engine(
+            options={"execution_options": {"base": "x1"}}
+        )
         event.listen(eng, "before_execute", l2)
 
         eng1 = eng.execution_options(foo="b1")
@@ -1465,7 +1609,10 @@ class EngineEventsTest(fixtures.TestBase):
             tsa.exc.InvalidRequestError,
             r"Can't assign an event directly to the "
             "<class 'sqlalchemy.engine.base.OptionEngine'> class",
-            event.listen, base.OptionEngine, "before_cursor_execute", evt
+            event.listen,
+            base.OptionEngine,
+            "before_cursor_execute",
+            evt,
         )
 
     @testing.requires.ad_hoc_engines
@@ -1481,17 +1628,11 @@ class EngineEventsTest(fixtures.TestBase):
         conn = eng.connect()
         conn.close()
 
-        eq_(
-            canary.mock_calls,
-            [call(eng)]
-        )
+        eq_(canary.mock_calls, [call(eng)])
 
         eng.dispose()
 
-        eq_(
-            canary.mock_calls,
-            [call(eng), call(eng)]
-        )
+        eq_(canary.mock_calls, [call(eng), call(eng)])
 
     def test_retval_flag(self):
         canary = []
@@ -1499,31 +1640,36 @@ class EngineEventsTest(fixtures.TestBase):
         def tracker(name):
             def go(conn, *args, **kw):
                 canary.append(name)
+
             return go
 
         def execute(conn, clauseelement, multiparams, params):
-            canary.append('execute')
+            canary.append("execute")
             return clauseelement, multiparams, params
 
-        def cursor_execute(conn, cursor, statement,
-                           parameters, context, executemany):
-            canary.append('cursor_execute')
+        def cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
+            canary.append("cursor_execute")
             return statement, parameters
 
         engine = engines.testing_engine()
         assert_raises(
             tsa.exc.ArgumentError,
-            event.listen, engine, "begin", tracker("begin"), retval=True
+            event.listen,
+            engine,
+            "begin",
+            tracker("begin"),
+            retval=True,
         )
 
         event.listen(engine, "before_execute", execute, retval=True)
         event.listen(
-            engine, "before_cursor_execute", cursor_execute, retval=True)
-        engine.execute(select([1]))
-        eq_(
-            canary, ['execute', 'cursor_execute']
+            engine, "before_cursor_execute", cursor_execute, retval=True
         )
+        engine.execute(select([1]))
+        eq_(canary, ["execute", "cursor_execute"])
 
     def test_engine_connect(self):
         engine = engines.testing_engine()
@@ -1534,10 +1680,7 @@ class EngineEventsTest(fixtures.TestBase):
         c1 = engine.connect()
         c2 = c1._branch()
         c1.close()
-        eq_(
-            tracker.mock_calls,
-            [call(c1, False), call(c2, True)]
-        )
+        eq_(tracker.mock_calls, [call(c1, False), call(c2, True)])
 
     def test_execution_options(self):
         engine = engines.testing_engine()
@@ -1548,18 +1691,15 @@ class EngineEventsTest(fixtures.TestBase):
         event.listen(engine, "set_engine_execution_options", engine_tracker)
         event.listen(engine, "set_connection_execution_options", conn_tracker)
 
-        e2 = engine.execution_options(e1='opt_e1')
+        e2 = engine.execution_options(e1="opt_e1")
         c1 = engine.connect()
-        c2 = c1.execution_options(c1='opt_c1')
+        c2 = c1.execution_options(c1="opt_c1")
         c3 = e2.connect()
-        c4 = c3.execution_options(c3='opt_c3')
-        eq_(
-            engine_tracker.mock_calls,
-            [call(e2, {'e1': 'opt_e1'})]
-        )
+        c4 = c3.execution_options(c3="opt_c3")
+        eq_(engine_tracker.mock_calls, [call(e2, {"e1": "opt_e1"})])
         eq_(
             conn_tracker.mock_calls,
-            [call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})]
+            [call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})],
         )
 
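For reference, the engine-level hooks exercised throughout EngineEventsTest
can be wired up in a few lines. A minimal sketch using the same 1.x-era API
as these tests; the in-memory SQLite URL and the print statements are
assumptions, purely for illustration:

    from sqlalchemy import create_engine, event, select

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "before_execute")
    def before_execute(conn, clauseelement, multiparams, params):
        # fires once per Connection.execute(), before compilation
        print("before_execute:", clauseelement)

    @event.listens_for(engine, "before_cursor_execute")
    def before_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        # fires once per DBAPI cursor.execute(); `statement` is the final SQL
        print("before_cursor_execute:", statement)

    engine.execute(select([1]))

With retval=True, as test_retval_flag asserts, a before_execute listener must
return the (clauseelement, multiparams, params) triple and a
before_cursor_execute listener must return (statement, parameters), which is
what allows listeners to rewrite statements in flight.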
@testing.requires.sequences @@ -1570,25 +1710,28 @@ class EngineEventsTest(fixtures.TestBase): def tracker(name): def go(conn, cursor, statement, parameters, context, executemany): canary.append((statement, context)) + return go + engine = engines.testing_engine() - t = Table('t', self.metadata, - Column('x', Integer, Sequence('t_id_seq'), primary_key=True), - implicit_returning=False - ) + t = Table( + "t", + self.metadata, + Column("x", Integer, Sequence("t_id_seq"), primary_key=True), + implicit_returning=False, + ) self.metadata.create_all(engine) with engine.begin() as conn: event.listen( - conn, 'before_cursor_execute', tracker('cursor_execute')) + conn, "before_cursor_execute", tracker("cursor_execute") + ) conn.execute(t.insert()) # we see the sequence pre-executed in the first call assert "t_id_seq" in canary[0][0] assert "INSERT" in canary[1][0] # same context - is_( - canary[0][1], canary[1][1] - ) + is_(canary[0][1], canary[1][1]) def test_transactional(self): canary = [] @@ -1596,15 +1739,17 @@ class EngineEventsTest(fixtures.TestBase): def tracker(name): def go(conn, *args, **kw): canary.append(name) + return go engine = engines.testing_engine() - event.listen(engine, 'before_execute', tracker('execute')) + event.listen(engine, "before_execute", tracker("execute")) event.listen( - engine, 'before_cursor_execute', tracker('cursor_execute')) - event.listen(engine, 'begin', tracker('begin')) - event.listen(engine, 'commit', tracker('commit')) - event.listen(engine, 'rollback', tracker('rollback')) + engine, "before_cursor_execute", tracker("cursor_execute") + ) + event.listen(engine, "begin", tracker("begin")) + event.listen(engine, "commit", tracker("commit")) + event.listen(engine, "rollback", tracker("rollback")) conn = engine.connect() trans = conn.begin() @@ -1615,10 +1760,18 @@ class EngineEventsTest(fixtures.TestBase): trans.commit() eq_( - canary, [ - 'begin', 'execute', 'cursor_execute', 'rollback', - 'begin', 'execute', 'cursor_execute', 'commit', - ]) + canary, + [ + "begin", + "execute", + "cursor_execute", + "rollback", + "begin", + "execute", + "cursor_execute", + "commit", + ], + ) def test_transactional_named(self): canary = [] @@ -1626,16 +1779,20 @@ class EngineEventsTest(fixtures.TestBase): def tracker(name): def go(*args, **kw): canary.append((name, set(kw))) + return go engine = engines.testing_engine() - event.listen(engine, 'before_execute', tracker('execute'), named=True) + event.listen(engine, "before_execute", tracker("execute"), named=True) event.listen( - engine, 'before_cursor_execute', - tracker('cursor_execute'), named=True) - event.listen(engine, 'begin', tracker('begin'), named=True) - event.listen(engine, 'commit', tracker('commit'), named=True) - event.listen(engine, 'rollback', tracker('rollback'), named=True) + engine, + "before_cursor_execute", + tracker("cursor_execute"), + named=True, + ) + event.listen(engine, "begin", tracker("begin"), named=True) + event.listen(engine, "commit", tracker("commit"), named=True) + event.listen(engine, "rollback", tracker("rollback"), named=True) conn = engine.connect() trans = conn.begin() @@ -1646,20 +1803,47 @@ class EngineEventsTest(fixtures.TestBase): trans.commit() eq_( - canary, [ - ('begin', set(['conn', ])), - ('execute', set([ - 'conn', 'clauseelement', 'multiparams', 'params'])), - ('cursor_execute', set([ - 'conn', 'cursor', 'executemany', - 'statement', 'parameters', 'context'])), - ('rollback', set(['conn', ])), ('begin', set(['conn', ])), - ('execute', set([ - 'conn', 'clauseelement', 
'multiparams', 'params'])), - ('cursor_execute', set([ - 'conn', 'cursor', 'executemany', 'statement', - 'parameters', 'context'])), - ('commit', set(['conn', ]))] + canary, + [ + ("begin", set(["conn"])), + ( + "execute", + set(["conn", "clauseelement", "multiparams", "params"]), + ), + ( + "cursor_execute", + set( + [ + "conn", + "cursor", + "executemany", + "statement", + "parameters", + "context", + ] + ), + ), + ("rollback", set(["conn"])), + ("begin", set(["conn"])), + ( + "execute", + set(["conn", "clauseelement", "multiparams", "params"]), + ), + ( + "cursor_execute", + set( + [ + "conn", + "cursor", + "executemany", + "statement", + "parameters", + "context", + ] + ), + ), + ("commit", set(["conn"])), + ], ) @testing.requires.savepoints @@ -1670,27 +1854,42 @@ class EngineEventsTest(fixtures.TestBase): def tracker1(name): def go(*args, **kw): canary1.append(name) + return go + canary2 = [] def tracker2(name): def go(*args, **kw): canary2.append(name) + return go engine = engines.testing_engine() - for name in ['begin', 'savepoint', - 'rollback_savepoint', 'release_savepoint', - 'rollback', 'begin_twophase', - 'prepare_twophase', 'commit_twophase']: - event.listen(engine, '%s' % name, tracker1(name)) + for name in [ + "begin", + "savepoint", + "rollback_savepoint", + "release_savepoint", + "rollback", + "begin_twophase", + "prepare_twophase", + "commit_twophase", + ]: + event.listen(engine, "%s" % name, tracker1(name)) conn = engine.connect() - for name in ['begin', 'savepoint', - 'rollback_savepoint', 'release_savepoint', - 'rollback', 'begin_twophase', - 'prepare_twophase', 'commit_twophase']: - event.listen(conn, '%s' % name, tracker2(name)) + for name in [ + "begin", + "savepoint", + "rollback_savepoint", + "release_savepoint", + "rollback", + "begin_twophase", + "prepare_twophase", + "commit_twophase", + ]: + event.listen(conn, "%s" % name, tracker2(name)) trans = conn.begin() trans2 = conn.begin_nested() @@ -1706,20 +1905,38 @@ class EngineEventsTest(fixtures.TestBase): trans.prepare() trans.commit() - eq_(canary1, ['begin', 'savepoint', - 'rollback_savepoint', 'savepoint', 'release_savepoint', - 'rollback', 'begin_twophase', - 'prepare_twophase', 'commit_twophase'] - ) - eq_(canary2, ['begin', 'savepoint', - 'rollback_savepoint', 'savepoint', 'release_savepoint', - 'rollback', 'begin_twophase', - 'prepare_twophase', 'commit_twophase'] - ) + eq_( + canary1, + [ + "begin", + "savepoint", + "rollback_savepoint", + "savepoint", + "release_savepoint", + "rollback", + "begin_twophase", + "prepare_twophase", + "commit_twophase", + ], + ) + eq_( + canary2, + [ + "begin", + "savepoint", + "rollback_savepoint", + "savepoint", + "release_savepoint", + "rollback", + "begin_twophase", + "prepare_twophase", + "commit_twophase", + ], + ) class HandleErrorTest(fixtures.TestBase): - __requires__ = 'ad_hoc_engines', + __requires__ = ("ad_hoc_engines",) __backend__ = True def tearDown(self): @@ -1744,7 +1961,7 @@ class HandleErrorTest(fixtures.TestBase): engine = engines.testing_engine() listener = Mock(return_value=None) - event.listen(engine, 'dbapi_error', listener) + event.listen(engine, "dbapi_error", listener) nope = SomeException("nope") @@ -1760,8 +1977,7 @@ class HandleErrorTest(fixtures.TestBase): r"\(test.engine.test_execute.SomeException\) " r"nope \[SQL\: u?'SELECT 1 ", conn.execute, - select([1]).where( - column('foo') == literal('bar', MyType())) + select([1]).where(column("foo") == literal("bar", MyType())), ) # no legacy event eq_(listener.mock_calls, []) @@ -1770,21 +1986,16 
@@ class HandleErrorTest(fixtures.TestBase): engine = engines.testing_engine() listener = Mock(return_value=None) - event.listen(engine, 'dbapi_error', listener) + event.listen(engine, "dbapi_error", listener) nope = TypeError("I'm not a DBAPI error") with engine.connect() as c: c.connection.cursor = Mock( - return_value=Mock( - execute=Mock( - side_effect=nope - )) + return_value=Mock(execute=Mock(side_effect=nope)) ) assert_raises_message( - TypeError, - "I'm not a DBAPI error", - c.execute, "select " + TypeError, "I'm not a DBAPI error", c.execute, "select " ) # no legacy event eq_(listener.mock_calls, []) @@ -1812,7 +2023,7 @@ class HandleErrorTest(fixtures.TestBase): class MyException(Exception): pass - @event.listens_for(engine, 'handle_error', retval=True) + @event.listens_for(engine, "handle_error", retval=True) def err(context): stmt = context.statement exception = context.original_exception @@ -1828,18 +2039,21 @@ class HandleErrorTest(fixtures.TestBase): assert_raises_message( MyException, "my exception", - conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR ONE' FROM I_DONT_EXIST", ) # case 2: return the DBAPI exception we're given; # no wrapping should occur assert_raises( conn.dialect.dbapi.Error, - conn.execute, "SELECT 'ERROR TWO' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR TWO' FROM I_DONT_EXIST", ) # case 3: normal wrapping assert_raises( tsa.exc.DBAPIError, - conn.execute, "SELECT 'ERROR THREE' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR THREE' FROM I_DONT_EXIST", ) def test_exception_event_reraise_chaining(self): @@ -1854,21 +2068,25 @@ class HandleErrorTest(fixtures.TestBase): class MyException3(Exception): pass - @event.listens_for(engine, 'handle_error', retval=True) + @event.listens_for(engine, "handle_error", retval=True) def err1(context): stmt = context.statement - if "ERROR ONE" in str(stmt) or "ERROR TWO" in str(stmt) \ - or "ERROR THREE" in str(stmt): + if ( + "ERROR ONE" in str(stmt) + or "ERROR TWO" in str(stmt) + or "ERROR THREE" in str(stmt) + ): return MyException1("my exception") elif "ERROR FOUR" in str(stmt): raise MyException3("my exception short circuit") - @event.listens_for(engine, 'handle_error', retval=True) + @event.listens_for(engine, "handle_error", retval=True) def err2(context): stmt = context.statement - if ("ERROR ONE" in str(stmt) or "ERROR FOUR" in str(stmt)) \ - and isinstance(context.chained_exception, MyException1): + if ( + "ERROR ONE" in str(stmt) or "ERROR FOUR" in str(stmt) + ) and isinstance(context.chained_exception, MyException1): raise MyException2("my exception chained") elif "ERROR TWO" in str(stmt): return context.chained_exception @@ -1877,52 +2095,57 @@ class HandleErrorTest(fixtures.TestBase): conn = engine.connect() - with patch.object(engine. - dialect.execution_ctx_cls, - "handle_dbapi_exception") as patched: + with patch.object( + engine.dialect.execution_ctx_cls, "handle_dbapi_exception" + ) as patched: assert_raises_message( MyException2, "my exception chained", - conn.execute, "SELECT 'ERROR ONE' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR ONE' FROM I_DONT_EXIST", ) eq_(patched.call_count, 1) - with patch.object(engine. 
- dialect.execution_ctx_cls, - "handle_dbapi_exception") as patched: + with patch.object( + engine.dialect.execution_ctx_cls, "handle_dbapi_exception" + ) as patched: assert_raises( MyException1, - conn.execute, "SELECT 'ERROR TWO' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR TWO' FROM I_DONT_EXIST", ) eq_(patched.call_count, 1) - with patch.object(engine. - dialect.execution_ctx_cls, - "handle_dbapi_exception") as patched: + with patch.object( + engine.dialect.execution_ctx_cls, "handle_dbapi_exception" + ) as patched: # test that non None from err1 isn't cancelled out # by err2 assert_raises( MyException1, - conn.execute, "SELECT 'ERROR THREE' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR THREE' FROM I_DONT_EXIST", ) eq_(patched.call_count, 1) - with patch.object(engine. - dialect.execution_ctx_cls, - "handle_dbapi_exception") as patched: + with patch.object( + engine.dialect.execution_ctx_cls, "handle_dbapi_exception" + ) as patched: assert_raises( tsa.exc.DBAPIError, - conn.execute, "SELECT 'ERROR FIVE' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR FIVE' FROM I_DONT_EXIST", ) eq_(patched.call_count, 1) - with patch.object(engine. - dialect.execution_ctx_cls, - "handle_dbapi_exception") as patched: + with patch.object( + engine.dialect.execution_ctx_cls, "handle_dbapi_exception" + ) as patched: assert_raises_message( MyException3, "my exception short circuit", - conn.execute, "SELECT 'ERROR FOUR' FROM I_DONT_EXIST" + conn.execute, + "SELECT 'ERROR FOUR' FROM I_DONT_EXIST", ) eq_(patched.call_count, 1) @@ -1937,14 +2160,14 @@ class HandleErrorTest(fixtures.TestBase): r"An exception has occurred during handling of a previous " r"exception. The previous exception " r"is.*(?:i_dont_exist|does not exist)", - py2konly=True + py2konly=True, ): with patch.object(conn.dialect, "do_rollback", boom) as patched: assert_raises_message( tsa.exc.OperationalError, "rollback failed", conn.execute, - "insert into i_dont_exist (x) values ('y')" + "insert into i_dont_exist (x) values ('y')", ) def test_exception_event_ad_hoc_context(self): @@ -1957,7 +2180,7 @@ class HandleErrorTest(fixtures.TestBase): engine = engines.testing_engine() listener = Mock(return_value=None) - event.listen(engine, 'handle_error', listener) + event.listen(engine, "handle_error", listener) nope = SomeException("nope") @@ -1973,8 +2196,7 @@ class HandleErrorTest(fixtures.TestBase): r"\(test.engine.test_execute.SomeException\) " r"nope \[SQL\: u?'SELECT 1 ", conn.execute, - select([1]).where( - column('foo') == literal('bar', MyType())) + select([1]).where(column("foo") == literal("bar", MyType())), ) ctx = listener.mock_calls[0][1][0] @@ -1991,21 +2213,16 @@ class HandleErrorTest(fixtures.TestBase): engine = engines.testing_engine() listener = Mock(return_value=None) - event.listen(engine, 'handle_error', listener) + event.listen(engine, "handle_error", listener) nope = TypeError("I'm not a DBAPI error") with engine.connect() as c: c.connection.cursor = Mock( - return_value=Mock( - execute=Mock( - side_effect=nope - )) + return_value=Mock(execute=Mock(side_effect=nope)) ) assert_raises_message( - TypeError, - "I'm not a DBAPI error", - c.execute, "select " + TypeError, "I'm not a DBAPI error", c.execute, "select " ) ctx = listener.mock_calls[0][1][0] eq_(ctx.statement, "select ") @@ -2018,7 +2235,7 @@ class HandleErrorTest(fixtures.TestBase): class MyException1(Exception): pass - @event.listens_for(engine, 'handle_error') + @event.listens_for(engine, "handle_error") def err1(context): stmt = context.statement @@ 
-2028,16 +2245,14 @@ class HandleErrorTest(fixtures.TestBase): with engine.connect() as conn: assert_raises( tsa.exc.DBAPIError, - conn.execution_options( - skip_user_error_events=True - ).execute, "SELECT ERROR_ONE FROM I_DONT_EXIST" + conn.execution_options(skip_user_error_events=True).execute, + "SELECT ERROR_ONE FROM I_DONT_EXIST", ) assert_raises( MyException1, - conn.execution_options( - skip_user_error_events=False - ).execute, "SELECT ERROR_ONE FROM I_DONT_EXIST" + conn.execution_options(skip_user_error_events=False).execute, + "SELECT ERROR_ONE FROM I_DONT_EXIST", ) def _test_alter_disconnect(self, orig_error, evt_value): @@ -2047,8 +2262,9 @@ class HandleErrorTest(fixtures.TestBase): def evt(ctx): ctx.is_disconnect = evt_value - with patch.object(engine.dialect, "is_disconnect", - Mock(return_value=orig_error)): + with patch.object( + engine.dialect, "is_disconnect", Mock(return_value=orig_error) + ): with engine.connect() as c: try: @@ -2076,15 +2292,19 @@ class HandleErrorTest(fixtures.TestBase): if set_to_false: ctx.invalidate_pool_on_disconnect = False - c1, c2, c3 = engine.pool.connect(), \ - engine.pool.connect(), engine.pool.connect() + c1, c2, c3 = ( + engine.pool.connect(), + engine.pool.connect(), + engine.pool.connect(), + ) crecs = [conn._connection_record for conn in (c1, c2, c3)] c1.close() c2.close() c3.close() - with patch.object(engine.dialect, "is_disconnect", - Mock(return_value=orig_error)): + with patch.object( + engine.dialect, "is_disconnect", Mock(return_value=orig_error) + ): with engine.connect() as c: target_crec = c.connection._connection_record @@ -2119,36 +2339,34 @@ class HandleErrorTest(fixtures.TestBase): ProgrammingError = engine.dialect.dbapi.ProgrammingError with engine.connect() as conn: with patch.object( - conn.dialect, "get_isolation_level", - Mock(side_effect=ProgrammingError("random error")) + conn.dialect, + "get_isolation_level", + Mock(side_effect=ProgrammingError("random error")), ): - assert_raises( - MySpecialException, - conn.get_isolation_level - ) + assert_raises(MySpecialException, conn.get_isolation_level) class HandleInvalidatedOnConnectTest(fixtures.TestBase): - __requires__ = ('sqlite', ) + __requires__ = ("sqlite",) def setUp(self): - e = create_engine('sqlite://') + e = create_engine("sqlite://") - connection = Mock( - get_server_version_info=Mock(return_value='5.0')) + connection = Mock(get_server_version_info=Mock(return_value="5.0")) def connect(*args, **kwargs): return connection + dbapi = Mock( - sqlite_version_info=(99, 9, 9,), - version_info=(99, 9, 9,), - sqlite_version='99.9.9', - paramstyle='named', - connect=Mock(side_effect=connect) + sqlite_version_info=(99, 9, 9), + version_info=(99, 9, 9), + sqlite_version="99.9.9", + paramstyle="named", + connect=Mock(side_effect=connect), ) sqlite3 = e.dialect.dbapi - dbapi.Error = sqlite3.Error, + dbapi.Error = (sqlite3.Error,) dbapi.ProgrammingError = sqlite3.ProgrammingError self.dbapi = dbapi @@ -2156,23 +2374,21 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): def test_wraps_connect_in_dbapi(self): dbapi = self.dbapi - dbapi.connect = Mock( - side_effect=self.ProgrammingError("random error")) + dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) try: - create_engine('sqlite://', module=dbapi).connect() + create_engine("sqlite://", module=dbapi).connect() assert False except tsa.exc.DBAPIError as de: assert not de.connection_invalidated def test_handle_error_event_connect(self): dbapi = self.dbapi - dbapi.connect = Mock( - 
side_effect=self.ProgrammingError("random error")) + dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) class MySpecialException(Exception): pass - eng = create_engine('sqlite://', module=dbapi) + eng = create_engine("sqlite://", module=dbapi) @event.listens_for(eng, "handle_error") def handle_error(ctx): @@ -2180,10 +2396,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): assert ctx.connection is None raise MySpecialException("failed operation") - assert_raises( - MySpecialException, - eng.connect - ) + assert_raises(MySpecialException, eng.connect) def test_handle_error_event_revalidate(self): dbapi = self.dbapi @@ -2191,26 +2404,23 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): class MySpecialException(Exception): pass - eng = create_engine('sqlite://', module=dbapi, _initialize=False) + eng = create_engine("sqlite://", module=dbapi, _initialize=False) @event.listens_for(eng, "handle_error") def handle_error(ctx): assert ctx.engine is eng assert ctx.connection is conn - assert isinstance(ctx.sqlalchemy_exception, - tsa.exc.ProgrammingError) + assert isinstance( + ctx.sqlalchemy_exception, tsa.exc.ProgrammingError + ) raise MySpecialException("failed operation") conn = eng.connect() conn.invalidate() - dbapi.connect = Mock( - side_effect=self.ProgrammingError("random error")) + dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) - assert_raises( - MySpecialException, - getattr, conn, 'connection' - ) + assert_raises(MySpecialException, getattr, conn, "connection") def test_handle_error_event_implicit_revalidate(self): dbapi = self.dbapi @@ -2218,26 +2428,23 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): class MySpecialException(Exception): pass - eng = create_engine('sqlite://', module=dbapi, _initialize=False) + eng = create_engine("sqlite://", module=dbapi, _initialize=False) @event.listens_for(eng, "handle_error") def handle_error(ctx): assert ctx.engine is eng assert ctx.connection is conn assert isinstance( - ctx.sqlalchemy_exception, tsa.exc.ProgrammingError) + ctx.sqlalchemy_exception, tsa.exc.ProgrammingError + ) raise MySpecialException("failed operation") conn = eng.connect() conn.invalidate() - dbapi.connect = Mock( - side_effect=self.ProgrammingError("random error")) + dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) - assert_raises( - MySpecialException, - conn.execute, select([1]) - ) + assert_raises(MySpecialException, conn.execute, select([1])) def test_handle_error_custom_connect(self): dbapi = self.dbapi @@ -2248,7 +2455,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): def custom_connect(): raise self.ProgrammingError("random error") - eng = create_engine('sqlite://', module=dbapi, creator=custom_connect) + eng = create_engine("sqlite://", module=dbapi, creator=custom_connect) @event.listens_for(eng, "handle_error") def handle_error(ctx): @@ -2256,21 +2463,20 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): assert ctx.connection is None raise MySpecialException("failed operation") - assert_raises( - MySpecialException, - eng.connect - ) + assert_raises(MySpecialException, eng.connect) def test_handle_error_event_connect_invalidate_flag(self): dbapi = self.dbapi dbapi.connect = Mock( side_effect=self.ProgrammingError( - "Cannot operate on a closed database.")) + "Cannot operate on a closed database." 
+ ) + ) class MySpecialException(Exception): pass - eng = create_engine('sqlite://', module=dbapi) + eng = create_engine("sqlite://", module=dbapi) @event.listens_for(eng, "handle_error") def handle_error(ctx): @@ -2287,7 +2493,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): class MySpecialException(Exception): pass - eng = create_engine('sqlite://') + eng = create_engine("sqlite://") @event.listens_for(eng, "handle_error") def handle_error(ctx): @@ -2299,7 +2505,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): eng.pool._creator = Mock( side_effect=self.ProgrammingError( - "Cannot operate on a closed database.")) + "Cannot operate on a closed database." + ) + ) try: conn.connection @@ -2311,22 +2519,22 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): dbapi = self.dbapi dbapi.connect = Mock(side_effect=TypeError("I'm not a DBAPI error")) - e = create_engine('sqlite://', module=dbapi) + e = create_engine("sqlite://", module=dbapi) e.dialect.is_disconnect = is_disconnect = Mock() assert_raises_message( - TypeError, - "I'm not a DBAPI error", - connect_fn, e + TypeError, "I'm not a DBAPI error", connect_fn, e ) eq_(is_disconnect.call_count, 0) def test_dont_touch_non_dbapi_exception_on_connect(self): self._test_dont_touch_non_dbapi_exception_on_connect( - lambda engine: engine.connect()) + lambda engine: engine.connect() + ) def test_dont_touch_non_dbapi_exception_on_contextual_connect(self): self._test_dont_touch_non_dbapi_exception_on_connect( - lambda engine: engine.contextual_connect()) + lambda engine: engine.contextual_connect() + ) def test_ensure_dialect_does_is_disconnect_no_conn(self): """test that is_disconnect() doesn't choke if no connection, @@ -2334,7 +2542,8 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): dialect = testing.db.dialect dbapi = dialect.dbapi assert not dialect.is_disconnect( - dbapi.OperationalError("test"), None, None) + dbapi.OperationalError("test"), None, None + ) def _test_invalidate_on_connect(self, connect_fn): """test that is_disconnect() is called during connect. @@ -2347,9 +2556,11 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): dbapi = self.dbapi dbapi.connect = Mock( side_effect=self.ProgrammingError( - "Cannot operate on a closed database.")) + "Cannot operate on a closed database." + ) + ) try: - connect_fn(create_engine('sqlite://', module=dbapi)) + connect_fn(create_engine("sqlite://", module=dbapi)) assert False except tsa.exc.DBAPIError as de: assert de.connection_invalidated @@ -2371,7 +2582,8 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): """ self._test_invalidate_on_connect( - lambda engine: engine.contextual_connect()) + lambda engine: engine.contextual_connect() + ) class ProxyConnectionTest(fixtures.TestBase): @@ -2380,25 +2592,20 @@ class ProxyConnectionTest(fixtures.TestBase): the deprecated ConnectionProxy interface. 
""" - __requires__ = 'ad_hoc_engines', - __prefer_requires__ = 'two_phase_transactions', - @testing.uses_deprecated(r'.*Use event.listen') - @testing.fails_on('firebird', 'Data type unknown') + __requires__ = ("ad_hoc_engines",) + __prefer_requires__ = ("two_phase_transactions",) + + @testing.uses_deprecated(r".*Use event.listen") + @testing.fails_on("firebird", "Data type unknown") def test_proxy(self): stmts = [] cursor_stmts = [] class MyProxy(ConnectionProxy): - def execute( - self, - conn, - execute, - clauseelement, - *multiparams, - **params + self, conn, execute, clauseelement, *multiparams, **params ): stmts.append((str(clauseelement), params, multiparams)) return execute(clauseelement, *multiparams, **params) @@ -2420,95 +2627,122 @@ class ProxyConnectionTest(fixtures.TestBase): if not received: assert False, "Nothing available for stmt: %s" % stmt while received: - teststmt, testparams, testmultiparams = \ - received.pop(0) - teststmt = re.compile( - r'[\n\t ]+', re.M).sub(' ', teststmt).strip() + teststmt, testparams, testmultiparams = received.pop(0) + teststmt = ( + re.compile(r"[\n\t ]+", re.M) + .sub(" ", teststmt) + .strip() + ) if teststmt.startswith(stmt) and ( - testparams == params or testparams == posn): + testparams == params or testparams == posn + ): break - for engine in \ - engines.testing_engine(options=dict(implicit_returning=False, - proxy=MyProxy())), \ - engines.testing_engine(options=dict(implicit_returning=False, - proxy=MyProxy(), - strategy='threadlocal')): + for engine in ( + engines.testing_engine( + options=dict(implicit_returning=False, proxy=MyProxy()) + ), + engines.testing_engine( + options=dict( + implicit_returning=False, + proxy=MyProxy(), + strategy="threadlocal", + ) + ), + ): m = MetaData(engine) - t1 = Table('t1', m, - Column('c1', Integer, primary_key=True), - Column('c2', String(50), default=func.lower('Foo'), - primary_key=True) - ) + t1 = Table( + "t1", + m, + Column("c1", Integer, primary_key=True), + Column( + "c2", + String(50), + default=func.lower("Foo"), + primary_key=True, + ), + ) m.create_all() try: - t1.insert().execute(c1=5, c2='some data') + t1.insert().execute(c1=5, c2="some data") t1.insert().execute(c1=6) eq_( - engine.execute('select * from t1').fetchall(), - [(5, 'some data'), (6, 'foo')]) + engine.execute("select * from t1").fetchall(), + [(5, "some data"), (6, "foo")], + ) finally: m.drop_all() engine.dispose() - compiled = [('CREATE TABLE t1', {}, None), - ('INSERT INTO t1 (c1, c2)', {'c2': 'some data', - 'c1': 5}, None), - ('INSERT INTO t1 (c1, c2)', {'c1': 6}, None), - ('select * from t1', {}, None), - ('DROP TABLE t1', {}, None)] + compiled = [ + ("CREATE TABLE t1", {}, None), + ( + "INSERT INTO t1 (c1, c2)", + {"c2": "some data", "c1": 5}, + None, + ), + ("INSERT INTO t1 (c1, c2)", {"c1": 6}, None), + ("select * from t1", {}, None), + ("DROP TABLE t1", {}, None), + ] cursor = [ - ('CREATE TABLE t1', {}, ()), - ('INSERT INTO t1 (c1, c2)', { - 'c2': 'some data', 'c1': 5}, (5, 'some data')), - ('SELECT lower', {'lower_1': 'Foo'}, - ('Foo', )), - ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6}, - (6, 'foo')), - ('select * from t1', {}, ()), - ('DROP TABLE t1', {}, ()), + ("CREATE TABLE t1", {}, ()), + ( + "INSERT INTO t1 (c1, c2)", + {"c2": "some data", "c1": 5}, + (5, "some data"), + ), + ("SELECT lower", {"lower_1": "Foo"}, ("Foo",)), + ( + "INSERT INTO t1 (c1, c2)", + {"c2": "foo", "c1": 6}, + (6, "foo"), + ), + ("select * from t1", {}, ()), + ("DROP TABLE t1", {}, ()), ] assert_stmts(compiled, stmts) 
assert_stmts(cursor, cursor_stmts) - @testing.uses_deprecated(r'.*Use event.listen') + @testing.uses_deprecated(r".*Use event.listen") def test_options(self): canary = [] class TrackProxy(ConnectionProxy): - def __getattribute__(self, key): fn = object.__getattribute__(self, key) def go(*arg, **kw): canary.append(fn.__name__) return fn(*arg, **kw) + return go - engine = engines.testing_engine(options={'proxy': TrackProxy()}) + + engine = engines.testing_engine(options={"proxy": TrackProxy()}) conn = engine.connect() - c2 = conn.execution_options(foo='bar') - eq_(c2._execution_options, {'foo': 'bar'}) + c2 = conn.execution_options(foo="bar") + eq_(c2._execution_options, {"foo": "bar"}) c2.execute(select([1])) - c3 = c2.execution_options(bar='bat') - eq_(c3._execution_options, {'foo': 'bar', 'bar': 'bat'}) - eq_(canary, ['execute', 'cursor_execute']) + c3 = c2.execution_options(bar="bat") + eq_(c3._execution_options, {"foo": "bar", "bar": "bat"}) + eq_(canary, ["execute", "cursor_execute"]) - @testing.uses_deprecated(r'.*Use event.listen') + @testing.uses_deprecated(r".*Use event.listen") def test_transactional(self): canary = [] class TrackProxy(ConnectionProxy): - def __getattribute__(self, key): fn = object.__getattribute__(self, key) def go(*arg, **kw): canary.append(fn.__name__) return fn(*arg, **kw) + return go - engine = engines.testing_engine(options={'proxy': TrackProxy()}) + engine = engines.testing_engine(options={"proxy": TrackProxy()}) conn = engine.connect() trans = conn.begin() conn.execute(select([1])) @@ -2518,28 +2752,36 @@ class ProxyConnectionTest(fixtures.TestBase): trans.commit() eq_( - canary, [ - 'begin', 'execute', 'cursor_execute', 'rollback', - 'begin', 'execute', 'cursor_execute', 'commit', - ]) + canary, + [ + "begin", + "execute", + "cursor_execute", + "rollback", + "begin", + "execute", + "cursor_execute", + "commit", + ], + ) - @testing.uses_deprecated(r'.*Use event.listen') + @testing.uses_deprecated(r".*Use event.listen") @testing.requires.savepoints @testing.requires.two_phase_transactions def test_transactional_advanced(self): canary = [] class TrackProxy(ConnectionProxy): - def __getattribute__(self, key): fn = object.__getattribute__(self, key) def go(*arg, **kw): canary.append(fn.__name__) return fn(*arg, **kw) + return go - engine = engines.testing_engine(options={'proxy': TrackProxy()}) + engine = engines.testing_engine(options={"proxy": TrackProxy()}) conn = engine.connect() trans = conn.begin() @@ -2556,16 +2798,24 @@ class ProxyConnectionTest(fixtures.TestBase): trans.prepare() trans.commit() - canary = [t for t in canary if t not in ('cursor_execute', 'execute')] - eq_(canary, ['begin', 'savepoint', - 'rollback_savepoint', 'savepoint', 'release_savepoint', - 'rollback', 'begin_twophase', - 'prepare_twophase', 'commit_twophase'] - ) + canary = [t for t in canary if t not in ("cursor_execute", "execute")] + eq_( + canary, + [ + "begin", + "savepoint", + "rollback_savepoint", + "savepoint", + "release_savepoint", + "rollback", + "begin_twophase", + "prepare_twophase", + "commit_twophase", + ], + ) class DialectEventTest(fixtures.TestBase): - @contextmanager def _run_test(self, retval): m1 = Mock() @@ -2587,12 +2837,15 @@ class DialectEventTest(fixtures.TestBase): arg[-1].get_result_proxy = Mock(return_value=Mock(context=arg[-1])) return retval - m1.real_do_execute.side_effect = \ - m1.do_execute.side_effect = mock_the_cursor - m1.real_do_executemany.side_effect = \ - m1.do_executemany.side_effect = mock_the_cursor - 
m1.real_do_execute_no_params.side_effect = \ - m1.do_execute_no_params.side_effect = mock_the_cursor + m1.real_do_execute.side_effect = ( + m1.do_execute.side_effect + ) = mock_the_cursor + m1.real_do_executemany.side_effect = ( + m1.do_executemany.side_effect + ) = mock_the_cursor + m1.real_do_execute_no_params.side_effect = ( + m1.do_execute_no_params.side_effect + ) = mock_the_cursor with e.connect() as conn: yield conn, m1 @@ -2609,36 +2862,53 @@ class DialectEventTest(fixtures.TestBase): result = conn.execute("insert into table foo", {"foo": "bar"}) self._assert( retval, - m1.do_execute, m1.real_do_execute, - [call( - result.context.cursor, - "insert into table foo", - {"foo": "bar"}, result.context)] + m1.do_execute, + m1.real_do_execute, + [ + call( + result.context.cursor, + "insert into table foo", + {"foo": "bar"}, + result.context, + ) + ], ) def _test_do_executemany(self, retval): with self._run_test(retval) as (conn, m1): - result = conn.execute("insert into table foo", - [{"foo": "bar"}, {"foo": "bar"}]) + result = conn.execute( + "insert into table foo", [{"foo": "bar"}, {"foo": "bar"}] + ) self._assert( retval, - m1.do_executemany, m1.real_do_executemany, - [call( - result.context.cursor, - "insert into table foo", - [{"foo": "bar"}, {"foo": "bar"}], result.context)] + m1.do_executemany, + m1.real_do_executemany, + [ + call( + result.context.cursor, + "insert into table foo", + [{"foo": "bar"}, {"foo": "bar"}], + result.context, + ) + ], ) def _test_do_execute_no_params(self, retval): with self._run_test(retval) as (conn, m1): - result = conn.execution_options(no_parameters=True).\ - execute("insert into table foo") + result = conn.execution_options(no_parameters=True).execute( + "insert into table foo" + ) self._assert( retval, - m1.do_execute_no_params, m1.real_do_execute_no_params, - [call( - result.context.cursor, - "insert into table foo", result.context)] + m1.do_execute_no_params, + m1.real_do_execute_no_params, + [ + call( + result.context.cursor, + "insert into table foo", + result.context, + ) + ], ) def _test_cursor_execute(self, retval): @@ -2648,17 +2918,16 @@ class DialectEventTest(fixtures.TestBase): stmt = "insert into table foo" params = {"foo": "bar"} ctx = dialect.execution_ctx_cls._init_statement( - dialect, conn, conn.connection, stmt, [params]) + dialect, conn, conn.connection, stmt, [params] + ) conn._cursor_execute(ctx.cursor, stmt, params, ctx) self._assert( retval, - m1.do_execute, m1.real_do_execute, - [call( - ctx.cursor, - "insert into table foo", - {"foo": "bar"}, ctx)] + m1.do_execute, + m1.real_do_execute, + [call(ctx.cursor, "insert into table foo", {"foo": "bar"}, ctx)], ) def test_do_execute_w_replace(self): @@ -2690,17 +2959,17 @@ class DialectEventTest(fixtures.TestBase): @event.listens_for(e, "do_connect") def evt(dialect, conn_rec, cargs, cparams): - cargs[:] = ['foo', 'hoho'] + cargs[:] = ["foo", "hoho"] cparams.clear() - cparams['bar'] = 'bat' - conn_rec.info['boom'] = "bap" + cparams["bar"] = "bat" + conn_rec.info["boom"] = "bap" m1 = Mock() e.dialect.connect = m1.real_connect with e.connect() as conn: - eq_(m1.mock_calls, [call.real_connect('foo', 'hoho', bar='bat')]) - eq_(conn.info['boom'], 'bap') + eq_(m1.mock_calls, [call.real_connect("foo", "hoho", bar="bat")]) + eq_(conn.info["boom"], "bap") def test_connect_do_connect(self): e = engines.testing_engine(options={"_initialize": False}) @@ -2709,24 +2978,25 @@ class DialectEventTest(fixtures.TestBase): @event.listens_for(e, "do_connect") def evt1(dialect, conn_rec, cargs, 
cparams): - cargs[:] = ['foo', 'hoho'] + cargs[:] = ["foo", "hoho"] cparams.clear() - cparams['bar'] = 'bat' - conn_rec.info['boom'] = "one" + cparams["bar"] = "bat" + conn_rec.info["boom"] = "one" @event.listens_for(e, "do_connect") def evt2(dialect, conn_rec, cargs, cparams): - conn_rec.info['bap'] = "two" + conn_rec.info["bap"] = "two" return m1.our_connect(cargs, cparams) with e.connect() as conn: # called with args eq_( m1.mock_calls, - [call.our_connect(['foo', 'hoho'], {'bar': 'bat'})]) + [call.our_connect(["foo", "hoho"], {"bar": "bat"})], + ) - eq_(conn.info['boom'], "one") - eq_(conn.info['bap'], "two") + eq_(conn.info["boom"], "one") + eq_(conn.info["bap"], "two") # returned our mock connection is_(conn.connection.connection, m1.our_connect()) @@ -2739,15 +3009,15 @@ class DialectEventTest(fixtures.TestBase): @event.listens_for(e, "do_connect") def evt1(dialect, conn_rec, cargs, cparams): - conn_rec.info['boom'] = "one" + conn_rec.info["boom"] = "one" conn = e.connect() - eq_(conn.info['boom'], "one") + eq_(conn.info["boom"], "one") conn.connection.invalidate(soft=True) conn.close() conn = e.connect() - eq_(conn.info['boom'], "one") + eq_(conn.info["boom"], "one") def test_connect_do_connect_info_there_after_invalidate(self): # test that info is maintained after the do_connect() @@ -2758,14 +3028,14 @@ class DialectEventTest(fixtures.TestBase): @event.listens_for(e, "do_connect") def evt1(dialect, conn_rec, cargs, cparams): assert not conn_rec.info - conn_rec.info['boom'] = "one" + conn_rec.info["boom"] = "one" conn = e.connect() - eq_(conn.info['boom'], "one") + eq_(conn.info["boom"], "one") conn.connection.invalidate() conn = e.connect() - eq_(conn.info['boom'], "one") + eq_(conn.info["boom"], "one") class AutocommitTextTest(fixtures.TestBase): @@ -2775,31 +3045,19 @@ class AutocommitTextTest(fixtures.TestBase): dbapi = Mock( connect=Mock( return_value=Mock( - cursor=Mock( - return_value=Mock( - description=() - ) - ) + cursor=Mock(return_value=Mock(description=())) ) ) ) engine = engines.testing_engine( - options={ - "_initialize": False, - "pool_reset_on_return": None - }) + options={"_initialize": False, "pool_reset_on_return": None} + ) engine.dialect.dbapi = dbapi engine.execute("%s something table something" % keyword) if expected: - eq_( - dbapi.connect().mock_calls, - [call.cursor(), call.commit()] - ) + eq_(dbapi.connect().mock_calls, [call.cursor(), call.commit()]) else: - eq_( - dbapi.connect().mock_calls, - [call.cursor()] - ) + eq_(dbapi.connect().mock_calls, [call.cursor()]) def test_update(self): self._test_keyword("UPDATE") @@ -2821,4 +3079,3 @@ class AutocommitTextTest(fixtures.TestBase): def test_select(self): self._test_keyword("SELECT foo FROM table", False) - diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index 7044867cff..4a19ee86f6 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -10,108 +10,97 @@ from sqlalchemy import util class LogParamsTest(fixtures.TestBase): - __only_on__ = 'sqlite' - __requires__ = 'ad_hoc_engines', + __only_on__ = "sqlite" + __requires__ = ("ad_hoc_engines",) def setup(self): - self.eng = engines.testing_engine(options={'echo': True}) + self.eng = engines.testing_engine(options={"echo": True}) self.eng.execute("create table foo (data string)") self.buf = logging.handlers.BufferingHandler(100) - for log in [ - logging.getLogger('sqlalchemy.engine'), - ]: + for log in [logging.getLogger("sqlalchemy.engine")]: log.addHandler(self.buf) def teardown(self): self.eng.execute("drop 
table foo") - for log in [ - logging.getLogger('sqlalchemy.engine'), - ]: + for log in [logging.getLogger("sqlalchemy.engine")]: log.removeHandler(self.buf) def test_log_large_dict(self): self.eng.execute( "INSERT INTO foo (data) values (:data)", - [{"data": str(i)} for i in range(100)] + [{"data": str(i)} for i in range(100)], ) eq_( self.buf.buffer[1].message, "[{'data': '0'}, {'data': '1'}, {'data': '2'}, {'data': '3'}, " "{'data': '4'}, {'data': '5'}, {'data': '6'}, {'data': '7'}" " ... displaying 10 of 100 total bound " - "parameter sets ... {'data': '98'}, {'data': '99'}]" + "parameter sets ... {'data': '98'}, {'data': '99'}]", ) def test_log_large_list(self): self.eng.execute( "INSERT INTO foo (data) values (?)", - [(str(i), ) for i in range(100)] + [(str(i),) for i in range(100)], ) eq_( self.buf.buffer[1].message, "[('0',), ('1',), ('2',), ('3',), ('4',), ('5',), " "('6',), ('7',) ... displaying 10 of 100 total " - "bound parameter sets ... ('98',), ('99',)]" + "bound parameter sets ... ('98',), ('99',)]", ) def test_log_large_parameter_single(self): import random - largeparam = ''.join(chr(random.randint(52, 85)) for i in range(5000)) - self.eng.execute( - "INSERT INTO foo (data) values (?)", - (largeparam, ) - ) + largeparam = "".join(chr(random.randint(52, 85)) for i in range(5000)) + + self.eng.execute("INSERT INTO foo (data) values (?)", (largeparam,)) eq_( self.buf.buffer[1].message, - "('%s ... (4702 characters truncated) ... %s',)" % ( - largeparam[0:149], largeparam[-149:] - ) + "('%s ... (4702 characters truncated) ... %s',)" + % (largeparam[0:149], largeparam[-149:]), ) def test_log_large_multi_parameter(self): import random - lp1 = ''.join(chr(random.randint(52, 85)) for i in range(5)) - lp2 = ''.join(chr(random.randint(52, 85)) for i in range(8)) - lp3 = ''.join(chr(random.randint(52, 85)) for i in range(670)) - self.eng.execute( - "SELECT ?, ?, ?", - (lp1, lp2, lp3) - ) + lp1 = "".join(chr(random.randint(52, 85)) for i in range(5)) + lp2 = "".join(chr(random.randint(52, 85)) for i in range(8)) + lp3 = "".join(chr(random.randint(52, 85)) for i in range(670)) + + self.eng.execute("SELECT ?, ?, ?", (lp1, lp2, lp3)) eq_( self.buf.buffer[1].message, - "('%s', '%s', '%s ... (372 characters truncated) ... %s')" % ( - lp1, lp2, lp3[0:149], lp3[-149:] - ) + "('%s', '%s', '%s ... (372 characters truncated) ... %s')" + % (lp1, lp2, lp3[0:149], lp3[-149:]), ) def test_log_large_parameter_multiple(self): import random - lp1 = ''.join(chr(random.randint(52, 85)) for i in range(5000)) - lp2 = ''.join(chr(random.randint(52, 85)) for i in range(200)) - lp3 = ''.join(chr(random.randint(52, 85)) for i in range(670)) + + lp1 = "".join(chr(random.randint(52, 85)) for i in range(5000)) + lp2 = "".join(chr(random.randint(52, 85)) for i in range(200)) + lp3 = "".join(chr(random.randint(52, 85)) for i in range(670)) self.eng.execute( - "INSERT INTO foo (data) values (?)", - [(lp1, ), (lp2, ), (lp3, )] + "INSERT INTO foo (data) values (?)", [(lp1,), (lp2,), (lp3,)] ) eq_( self.buf.buffer[1].message, "[('%s ... (4702 characters truncated) ... %s',), ('%s',), " - "('%s ... (372 characters truncated) ... %s',)]" % ( - lp1[0:149], lp1[-149:], lp2, lp3[0:149], lp3[-149:] - ) + "('%s ... (372 characters truncated) ... 
%s',)]" + % (lp1[0:149], lp1[-149:], lp2, lp3[0:149], lp3[-149:]), ) def test_exception_format_dict_param(self): exception = tsa.exc.IntegrityError("foo", {"x": "y"}, None) eq_regex( str(exception), - r"\(.*.NoneType\) None \[SQL: 'foo'\] \[parameters: {'x': 'y'}\]" + r"\(.*.NoneType\) None \[SQL: 'foo'\] \[parameters: {'x': 'y'}\]", ) def test_exception_format_unexpected_parameter(self): @@ -120,7 +109,7 @@ class LogParamsTest(fixtures.TestBase): exception = tsa.exc.IntegrityError("foo", "bar", "bat") eq_regex( str(exception), - r"\(.*.str\) bat \[SQL: 'foo'\] \[parameters: 'bar'\]" + r"\(.*.str\) bat \[SQL: 'foo'\] \[parameters: 'bar'\]", ) def test_exception_format_unexpected_member_parameter(self): @@ -129,56 +118,49 @@ class LogParamsTest(fixtures.TestBase): exception = tsa.exc.IntegrityError("foo", ["bar", "bat"], "hoho") eq_regex( str(exception), - r"\(.*.str\) hoho \[SQL: 'foo'\] \[parameters: \['bar', 'bat'\]\]" + r"\(.*.str\) hoho \[SQL: 'foo'\] \[parameters: \['bar', 'bat'\]\]", ) def test_result_large_param(self): import random - largeparam = ''.join(chr(random.randint(52, 85)) for i in range(5000)) - self.eng.echo = 'debug' - result = self.eng.execute( - "SELECT ?", - (largeparam, ) - ) + largeparam = "".join(chr(random.randint(52, 85)) for i in range(5000)) + + self.eng.echo = "debug" + result = self.eng.execute("SELECT ?", (largeparam,)) row = result.first() eq_( self.buf.buffer[1].message, - "('%s ... (4702 characters truncated) ... %s',)" % ( - largeparam[0:149], largeparam[-149:] - ) + "('%s ... (4702 characters truncated) ... %s',)" + % (largeparam[0:149], largeparam[-149:]), ) if util.py3k: eq_( self.buf.buffer[3].message, - "Row ('%s ... (4702 characters truncated) ... %s',)" % ( - largeparam[0:149], largeparam[-149:] - ) + "Row ('%s ... (4702 characters truncated) ... %s',)" + % (largeparam[0:149], largeparam[-149:]), ) else: eq_( self.buf.buffer[3].message, - "Row (u'%s ... (4703 characters truncated) ... %s',)" % ( - largeparam[0:148], largeparam[-149:] - ) + "Row (u'%s ... (4703 characters truncated) ... %s',)" + % (largeparam[0:148], largeparam[-149:]), ) if util.py3k: eq_( repr(row), - "('%s ... (4702 characters truncated) ... %s',)" % ( - largeparam[0:149], largeparam[-149:] - ) + "('%s ... (4702 characters truncated) ... %s',)" + % (largeparam[0:149], largeparam[-149:]), ) else: eq_( repr(row), - "(u'%s ... (4703 characters truncated) ... %s',)" % ( - largeparam[0:148], largeparam[-149:] - ) + "(u'%s ... (4703 characters truncated) ... 
%s',)" + % (largeparam[0:148], largeparam[-149:]), ) def test_error_large_dict(self): @@ -193,8 +175,8 @@ class LogParamsTest(fixtures.TestBase): r"{'data': '99'}\]", lambda: self.eng.execute( "INSERT INTO nonexistent (data) values (:data)", - [{"data": str(i)} for i in range(100)] - ) + [{"data": str(i)} for i in range(100)], + ), ) def test_error_large_list(self): @@ -208,8 +190,8 @@ class LogParamsTest(fixtures.TestBase): r"\('98',\), \('99',\)\]", lambda: self.eng.execute( "INSERT INTO nonexistent (data) values (?)", - [(str(i), ) for i in range(100)] - ) + [(str(i),) for i in range(100)], + ), ) @@ -218,27 +200,23 @@ class PoolLoggingTest(fixtures.TestBase): self.existing_level = logging.getLogger("sqlalchemy.pool").level self.buf = logging.handlers.BufferingHandler(100) - for log in [ - logging.getLogger('sqlalchemy.pool') - ]: + for log in [logging.getLogger("sqlalchemy.pool")]: log.addHandler(self.buf) def teardown(self): - for log in [ - logging.getLogger('sqlalchemy.pool') - ]: + for log in [logging.getLogger("sqlalchemy.pool")]: log.removeHandler(self.buf) logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level) def _queuepool_echo_fixture(self): - return tsa.pool.QueuePool(creator=mock.Mock(), echo='debug') + return tsa.pool.QueuePool(creator=mock.Mock(), echo="debug") def _queuepool_logging_fixture(self): logging.getLogger("sqlalchemy.pool").setLevel(logging.DEBUG) return tsa.pool.QueuePool(creator=mock.Mock()) def _stpool_echo_fixture(self): - return tsa.pool.SingletonThreadPool(creator=mock.Mock(), echo='debug') + return tsa.pool.SingletonThreadPool(creator=mock.Mock(), echo="debug") def _stpool_logging_fixture(self): logging.getLogger("sqlalchemy.pool").setLevel(logging.DEBUG) @@ -262,19 +240,19 @@ class PoolLoggingTest(fixtures.TestBase): eq_( [buf.msg for buf in self.buf.buffer], [ - 'Created new connection %r', - 'Connection %r checked out from pool', - 'Connection %r being returned to pool', - 'Connection %s rollback-on-return%s', - 'Connection %r checked out from pool', - 'Connection %r being returned to pool', - 'Connection %s rollback-on-return%s', - 'Connection %r checked out from pool', - 'Connection %r being returned to pool', - 'Connection %s rollback-on-return%s', - 'Closing connection %r', - - ] + (['Pool disposed. %s'] if dispose else []) + "Created new connection %r", + "Connection %r checked out from pool", + "Connection %r being returned to pool", + "Connection %s rollback-on-return%s", + "Connection %r checked out from pool", + "Connection %r being returned to pool", + "Connection %s rollback-on-return%s", + "Connection %r checked out from pool", + "Connection %r being returned to pool", + "Connection %s rollback-on-return%s", + "Closing connection %r", + ] + + (["Pool disposed. 
%s"] if dispose else []), ) def test_stpool_echo(self): @@ -295,16 +273,16 @@ class PoolLoggingTest(fixtures.TestBase): class LoggingNameTest(fixtures.TestBase): - __requires__ = 'ad_hoc_engines', + __requires__ = ("ad_hoc_engines",) def _assert_names_in_execute(self, eng, eng_name, pool_name): eng.execute(select([1])) assert self.buf.buffer for name in [b.name for b in self.buf.buffer]: assert name in ( - 'sqlalchemy.engine.base.Engine.%s' % eng_name, - 'sqlalchemy.pool.impl.%s.%s' % - (eng.pool.__class__.__name__, pool_name) + "sqlalchemy.engine.base.Engine.%s" % eng_name, + "sqlalchemy.pool.impl.%s.%s" + % (eng.pool.__class__.__name__, pool_name), ) def _assert_no_name_in_execute(self, eng): @@ -312,35 +290,35 @@ class LoggingNameTest(fixtures.TestBase): assert self.buf.buffer for name in [b.name for b in self.buf.buffer]: assert name in ( - 'sqlalchemy.engine.base.Engine', - 'sqlalchemy.pool.impl.%s' % eng.pool.__class__.__name__ + "sqlalchemy.engine.base.Engine", + "sqlalchemy.pool.impl.%s" % eng.pool.__class__.__name__, ) def _named_engine(self, **kw): options = { - 'logging_name': 'myenginename', - 'pool_logging_name': 'mypoolname', - 'echo': True + "logging_name": "myenginename", + "pool_logging_name": "mypoolname", + "echo": True, } options.update(kw) return engines.testing_engine(options=options) def _unnamed_engine(self, **kw): - kw.update({'echo': True}) + kw.update({"echo": True}) return engines.testing_engine(options=kw) def setup(self): self.buf = logging.handlers.BufferingHandler(100) for log in [ - logging.getLogger('sqlalchemy.engine'), - logging.getLogger('sqlalchemy.pool') + logging.getLogger("sqlalchemy.engine"), + logging.getLogger("sqlalchemy.pool"), ]: log.addHandler(self.buf) def teardown(self): for log in [ - logging.getLogger('sqlalchemy.engine'), - logging.getLogger('sqlalchemy.pool') + logging.getLogger("sqlalchemy.engine"), + logging.getLogger("sqlalchemy.pool"), ]: log.removeHandler(self.buf) @@ -366,7 +344,7 @@ class LoggingNameTest(fixtures.TestBase): self._assert_names_in_execute(eng, "myenginename", "mypoolname") def test_named_logger_echoflags_execute(self): - eng = self._named_engine(echo='debug', echo_pool='debug') + eng = self._named_engine(echo="debug", echo_pool="debug") self._assert_names_in_execute(eng, "myenginename", "mypoolname") def test_named_logger_execute_after_dispose(self): @@ -380,22 +358,22 @@ class LoggingNameTest(fixtures.TestBase): self._assert_no_name_in_execute(eng) def test_unnamed_logger_echoflags_execute(self): - eng = self._unnamed_engine(echo='debug', echo_pool='debug') + eng = self._unnamed_engine(echo="debug", echo_pool="debug") self._assert_no_name_in_execute(eng) class EchoTest(fixtures.TestBase): - __requires__ = 'ad_hoc_engines', + __requires__ = ("ad_hoc_engines",) def setup(self): - self.level = logging.getLogger('sqlalchemy.engine').level - logging.getLogger('sqlalchemy.engine').setLevel(logging.WARN) + self.level = logging.getLogger("sqlalchemy.engine").level + logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN) self.buf = logging.handlers.BufferingHandler(100) - logging.getLogger('sqlalchemy.engine').addHandler(self.buf) + logging.getLogger("sqlalchemy.engine").addHandler(self.buf) def teardown(self): - logging.getLogger('sqlalchemy.engine').removeHandler(self.buf) - logging.getLogger('sqlalchemy.engine').setLevel(self.level) + logging.getLogger("sqlalchemy.engine").removeHandler(self.buf) + logging.getLogger("sqlalchemy.engine").setLevel(self.level) def _testing_engine(self): e = engines.testing_engine() 
@@ -421,7 +399,7 @@ class EchoTest(fixtures.TestBase): eq_(e1.logger.isEnabledFor(logging.INFO), True) eq_(e1.logger.getEffectiveLevel(), logging.INFO) - e1.echo = 'debug' + e1.echo = "debug" eq_(e1._should_log_info(), True) eq_(e1._should_log_debug(), True) eq_(e1.logger.isEnabledFor(logging.DEBUG), True) diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 3e3ba7b035..39ad7042bb 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -16,45 +16,53 @@ dialect = None class ParseConnectTest(fixtures.TestBase): def test_rfc1738(self): for text in ( - 'dbtype://username:password@hostspec:110//usr/db_file.db', - 'dbtype://username:password@hostspec/database', - 'dbtype+apitype://username:password@hostspec/database', - 'dbtype://username:password@hostspec', - 'dbtype://username:password@/database', - 'dbtype://username@hostspec', - 'dbtype://username:password@127.0.0.1:1521', - 'dbtype://hostspec/database', - 'dbtype://hostspec', - 'dbtype://hostspec/?arg1=val1&arg2=val2', - 'dbtype+apitype:///database', - 'dbtype:///:memory:', - 'dbtype:///foo/bar/im/a/file', - 'dbtype:///E:/work/src/LEM/db/hello.db', - 'dbtype:///E:/work/src/LEM/db/hello.db?foo=bar&hoho=lala', - 'dbtype:///E:/work/src/LEM/db/hello.db?foo=bar&hoho=lala&hoho=bat', - 'dbtype://', - 'dbtype://username:password@/database', - 'dbtype:////usr/local/_xtest@example.com/members.db', - 'dbtype://username:apples%2Foranges@hostspec/database', - 'dbtype://username:password@[2001:da8:2004:1000:202:116:160:90]' - '/database?foo=bar', - 'dbtype://username:password@[2001:da8:2004:1000:202:116:160:90]:80' - '/database?foo=bar' + "dbtype://username:password@hostspec:110//usr/db_file.db", + "dbtype://username:password@hostspec/database", + "dbtype+apitype://username:password@hostspec/database", + "dbtype://username:password@hostspec", + "dbtype://username:password@/database", + "dbtype://username@hostspec", + "dbtype://username:password@127.0.0.1:1521", + "dbtype://hostspec/database", + "dbtype://hostspec", + "dbtype://hostspec/?arg1=val1&arg2=val2", + "dbtype+apitype:///database", + "dbtype:///:memory:", + "dbtype:///foo/bar/im/a/file", + "dbtype:///E:/work/src/LEM/db/hello.db", + "dbtype:///E:/work/src/LEM/db/hello.db?foo=bar&hoho=lala", + "dbtype:///E:/work/src/LEM/db/hello.db?foo=bar&hoho=lala&hoho=bat", + "dbtype://", + "dbtype://username:password@/database", + "dbtype:////usr/local/_xtest@example.com/members.db", + "dbtype://username:apples%2Foranges@hostspec/database", + "dbtype://username:password@[2001:da8:2004:1000:202:116:160:90]" + "/database?foo=bar", + "dbtype://username:password@[2001:da8:2004:1000:202:116:160:90]:80" + "/database?foo=bar", ): u = url.make_url(text) - assert u.drivername in ('dbtype', 'dbtype+apitype') - assert u.username in ('username', None) - assert u.password in ('password', 'apples/oranges', None) + assert u.drivername in ("dbtype", "dbtype+apitype") + assert u.username in ("username", None) + assert u.password in ("password", "apples/oranges", None) assert u.host in ( - 'hostspec', '127.0.0.1', - '2001:da8:2004:1000:202:116:160:90', '', None), u.host + "hostspec", + "127.0.0.1", + "2001:da8:2004:1000:202:116:160:90", + "", + None, + ), u.host assert u.database in ( - 'database', - '/usr/local/_xtest@example.com/members.db', - '/usr/db_file.db', ':memory:', '', - 'foo/bar/im/a/file', - 'E:/work/src/LEM/db/hello.db', None), u.database + "database", + "/usr/local/_xtest@example.com/members.db", + "/usr/db_file.db", + ":memory:", + "", + 
"foo/bar/im/a/file", + "E:/work/src/LEM/db/hello.db", + None, + ), u.database eq_(str(u), text) def test_rfc1738_password(self): @@ -63,25 +71,28 @@ class ParseConnectTest(fixtures.TestBase): eq_(str(u), "dbtype://user:pass word + other%3Awords@host/dbname") u = url.make_url( - 'dbtype://username:apples%2Foranges@hostspec/database') + "dbtype://username:apples%2Foranges@hostspec/database" + ) eq_(u.password, "apples/oranges") - eq_(str(u), 'dbtype://username:apples%2Foranges@hostspec/database') + eq_(str(u), "dbtype://username:apples%2Foranges@hostspec/database") u = url.make_url( - 'dbtype://username:apples%40oranges%40%40@hostspec/database') + "dbtype://username:apples%40oranges%40%40@hostspec/database" + ) eq_(u.password, "apples@oranges@@") eq_( str(u), - 'dbtype://username:apples%40oranges%40%40@hostspec/database') + "dbtype://username:apples%40oranges%40%40@hostspec/database", + ) - u = url.make_url('dbtype://username%40:@hostspec/database') - eq_(u.password, '') + u = url.make_url("dbtype://username%40:@hostspec/database") + eq_(u.password, "") eq_(u.username, "username@") - eq_(str(u), 'dbtype://username%40:@hostspec/database') + eq_(str(u), "dbtype://username%40:@hostspec/database") - u = url.make_url('dbtype://username:pass%2Fword@hostspec/database') - eq_(u.password, 'pass/word') - eq_(str(u), 'dbtype://username:pass%2Fword@hostspec/database') + u = url.make_url("dbtype://username:pass%2Fword@hostspec/database") + eq_(u.password, "pass/word") + eq_(str(u), "dbtype://username:pass%2Fword@hostspec/database") def test_password_custom_obj(self): class SecurePassword(str): @@ -92,11 +103,7 @@ class ParseConnectTest(fixtures.TestBase): return self.value sp = SecurePassword("secured_password") - u = url.URL( - "dbtype", - username="x", password=sp, - host="localhost" - ) + u = url.URL("dbtype", username="x", password=sp, host="localhost") eq_(u.password, "secured_password") eq_(str(u), "dbtype://x:secured_password@localhost") @@ -117,18 +124,18 @@ class ParseConnectTest(fixtures.TestBase): eq_(str(u), "dbtype://x@localhost") def test_query_string(self): - u = url.make_url( - "dialect://user:pass@host/db?arg1=param1&arg2=param2") + u = url.make_url("dialect://user:pass@host/db?arg1=param1&arg2=param2") eq_(u.query, {"arg1": "param1", "arg2": "param2"}) eq_(str(u), "dialect://user:pass@host/db?arg1=param1&arg2=param2") u = url.make_url( - "dialect://user:pass@host/db?arg1=param1&arg2=param2&arg2=param3") + "dialect://user:pass@host/db?arg1=param1&arg2=param2&arg2=param3" + ) eq_(u.query, {"arg1": "param1", "arg2": ["param2", "param3"]}) eq_( str(u), - "dialect://user:pass@host/db?arg1=param1&arg2=param2&arg2=param3") - + "dialect://user:pass@host/db?arg1=param1&arg2=param2&arg2=param3", + ) class DialectImportTest(fixtures.TestBase): @@ -136,14 +143,18 @@ class DialectImportTest(fixtures.TestBase): # the globals() somehow makes it for the exec() + nose3. 
for name in ( - 'mysql', - 'firebird', - 'postgresql', - 'sqlite', - 'oracle', - 'mssql'): - exec('from sqlalchemy.dialects import %s\ndialect = ' - '%s.dialect()' % (name, name), globals()) + "mysql", + "firebird", + "postgresql", + "sqlite", + "oracle", + "mssql", + ): + exec( + "from sqlalchemy.dialects import %s\ndialect = " + "%s.dialect()" % (name, name), + globals(), + ) eq_(dialect.name, name) @@ -152,54 +163,59 @@ class CreateEngineTest(fixtures.TestBase): propagated properly""" def test_connect_query(self): - dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue') - e = \ - create_engine('postgresql://scott:tiger@somehost/test?foobe' - 'r=12&lala=18&fooz=somevalue', module=dbapi, - _initialize=False) + dbapi = MockDBAPI(foober="12", lala="18", fooz="somevalue") + e = create_engine( + "postgresql://scott:tiger@somehost/test?foobe" + "r=12&lala=18&fooz=somevalue", + module=dbapi, + _initialize=False, + ) e.connect() def test_kwargs(self): - dbapi = MockDBAPI(foober=12, lala=18, hoho={'this': 'dict'}, - fooz='somevalue') - e = \ - create_engine( - 'postgresql://scott:tiger@somehost/test?fooz=' - 'somevalue', connect_args={ - 'foober': 12, - 'lala': 18, 'hoho': {'this': 'dict'}}, - module=dbapi, _initialize=False) + dbapi = MockDBAPI( + foober=12, lala=18, hoho={"this": "dict"}, fooz="somevalue" + ) + e = create_engine( + "postgresql://scott:tiger@somehost/test?fooz=" "somevalue", + connect_args={"foober": 12, "lala": 18, "hoho": {"this": "dict"}}, + module=dbapi, + _initialize=False, + ) e.connect() def test_engine_from_config(self): dbapi = mock_dbapi config = { - 'sqlalchemy.url': 'postgresql://scott:tiger@somehost/test' - '?fooz=somevalue', - 'sqlalchemy.pool_recycle': '50', - 'sqlalchemy.echo': 'true'} + "sqlalchemy.url": "postgresql://scott:tiger@somehost/test" + "?fooz=somevalue", + "sqlalchemy.pool_recycle": "50", + "sqlalchemy.echo": "true", + } e = engine_from_config(config, module=dbapi, _initialize=False) assert e.pool._recycle == 50 - assert e.url \ - == url.make_url('postgresql://scott:tiger@somehost/test?foo' - 'z=somevalue') + assert e.url == url.make_url( + "postgresql://scott:tiger@somehost/test?foo" "z=somevalue" + ) assert e.echo is True def test_pool_threadlocal_from_config(self): dbapi = mock_dbapi config = { - 'sqlalchemy.url': 'postgresql://scott:tiger@somehost/test', - 'sqlalchemy.pool_threadlocal': "false"} + "sqlalchemy.url": "postgresql://scott:tiger@somehost/test", + "sqlalchemy.pool_threadlocal": "false", + } e = engine_from_config(config, module=dbapi, _initialize=False) eq_(e.pool._use_threadlocal, False) config = { - 'sqlalchemy.url': 'postgresql://scott:tiger@somehost/test', - 'sqlalchemy.pool_threadlocal': "true"} + "sqlalchemy.url": "postgresql://scott:tiger@somehost/test", + "sqlalchemy.pool_threadlocal": "true", + } e = engine_from_config(config, module=dbapi, _initialize=False) eq_(e.pool._use_threadlocal, True) @@ -210,23 +226,25 @@ class CreateEngineTest(fixtures.TestBase): for value, expected in [ ("rollback", pool.reset_rollback), ("commit", pool.reset_commit), - ("none", pool.reset_none) + ("none", pool.reset_none), ]: config = { - 'sqlalchemy.url': 'postgresql://scott:tiger@somehost/test', - 'sqlalchemy.pool_reset_on_return': value} + "sqlalchemy.url": "postgresql://scott:tiger@somehost/test", + "sqlalchemy.pool_reset_on_return": value, + } e = engine_from_config(config, module=dbapi, _initialize=False) eq_(e.pool._reset_on_return, expected) def test_engine_from_config_custom(self): from sqlalchemy import util + tokens = 
__name__.split(".") class MyDialect(MockDialect): engine_config_types = { "foobar": int, - "bathoho": util.bool_or_str('force') + "bathoho": util.bool_or_str("force"), } def __init__(self, foobar=None, bathoho=None, **kw): @@ -236,108 +254,143 @@ class CreateEngineTest(fixtures.TestBase): global dialect dialect = MyDialect registry.register( - "mockdialect.barb", - ".".join(tokens[0:-1]), tokens[-1]) + "mockdialect.barb", ".".join(tokens[0:-1]), tokens[-1] + ) config = { "sqlalchemy.url": "mockdialect+barb://", "sqlalchemy.foobar": "5", - "sqlalchemy.bathoho": "false" + "sqlalchemy.bathoho": "false", } e = engine_from_config(config, _initialize=False) eq_(e.dialect.foobar, 5) eq_(e.dialect.bathoho, False) def test_custom(self): - dbapi = MockDBAPI(foober=12, lala=18, hoho={'this': 'dict'}, - fooz='somevalue') + dbapi = MockDBAPI( + foober=12, lala=18, hoho={"this": "dict"}, fooz="somevalue" + ) def connect(): - return dbapi.connect(foober=12, lala=18, fooz='somevalue', - hoho={'this': 'dict'}) + return dbapi.connect( + foober=12, lala=18, fooz="somevalue", hoho={"this": "dict"} + ) # start the postgresql dialect, but put our mock DBAPI as the # module instead of psycopg - e = create_engine('postgresql://', creator=connect, - module=dbapi, _initialize=False) + e = create_engine( + "postgresql://", creator=connect, module=dbapi, _initialize=False + ) e.connect() def test_recycle(self): - dbapi = MockDBAPI(foober=12, lala=18, hoho={'this': 'dict'}, - fooz='somevalue') - e = create_engine('postgresql://', pool_recycle=472, - module=dbapi, _initialize=False) + dbapi = MockDBAPI( + foober=12, lala=18, hoho={"this": "dict"}, fooz="somevalue" + ) + e = create_engine( + "postgresql://", pool_recycle=472, module=dbapi, _initialize=False + ) assert e.pool._recycle == 472 def test_reset_on_return(self): - dbapi = MockDBAPI(foober=12, lala=18, hoho={'this': 'dict'}, - fooz='somevalue') + dbapi = MockDBAPI( + foober=12, lala=18, hoho={"this": "dict"}, fooz="somevalue" + ) for (value, expected) in [ - ('rollback', pool.reset_rollback), - ('commit', pool.reset_commit), + ("rollback", pool.reset_rollback), + ("commit", pool.reset_commit), (None, pool.reset_none), (True, pool.reset_rollback), (False, pool.reset_none), ]: e = create_engine( - 'postgresql://', pool_reset_on_return=value, - module=dbapi, _initialize=False) + "postgresql://", + pool_reset_on_return=value, + module=dbapi, + _initialize=False, + ) assert e.pool._reset_on_return is expected assert_raises( exc.ArgumentError, - create_engine, "postgresql://", - pool_reset_on_return='hi', module=dbapi, - _initialize=False + create_engine, + "postgresql://", + pool_reset_on_return="hi", + module=dbapi, + _initialize=False, ) def test_bad_args(self): - assert_raises(exc.ArgumentError, create_engine, 'foobar://', - module=mock_dbapi) + assert_raises( + exc.ArgumentError, create_engine, "foobar://", module=mock_dbapi + ) # bad arg - assert_raises(TypeError, create_engine, 'postgresql://', - use_ansi=True, module=mock_dbapi) + assert_raises( + TypeError, + create_engine, + "postgresql://", + use_ansi=True, + module=mock_dbapi, + ) # bad arg assert_raises( TypeError, create_engine, - 'oracle://', + "oracle://", lala=5, use_ansi=True, module=mock_dbapi, ) - assert_raises(TypeError, create_engine, 'postgresql://', - lala=5, module=mock_dbapi) - assert_raises(TypeError, create_engine, 'sqlite://', lala=5, - module=mock_sqlite_dbapi) - assert_raises(TypeError, create_engine, 'mysql+mysqldb://', - use_unicode=True, module=mock_dbapi) + assert_raises( + 
TypeError, + create_engine, + "postgresql://", + lala=5, + module=mock_dbapi, + ) + assert_raises( + TypeError, + create_engine, + "sqlite://", + lala=5, + module=mock_sqlite_dbapi, + ) + assert_raises( + TypeError, + create_engine, + "mysql+mysqldb://", + use_unicode=True, + module=mock_dbapi, + ) def test_urlattr(self): """test the url attribute on ``Engine``.""" - e = create_engine('mysql://scott:tiger@localhost/test', - module=mock_dbapi, _initialize=False) - u = url.make_url('mysql://scott:tiger@localhost/test') + e = create_engine( + "mysql://scott:tiger@localhost/test", + module=mock_dbapi, + _initialize=False, + ) + u = url.make_url("mysql://scott:tiger@localhost/test") e2 = create_engine(u, module=mock_dbapi, _initialize=False) - assert e.url.drivername == e2.url.drivername == 'mysql' - assert e.url.username == e2.url.username == 'scott' + assert e.url.drivername == e2.url.drivername == "mysql" + assert e.url.username == e2.url.username == "scott" assert e2.url is u - assert str(u) == 'mysql://scott:tiger@localhost/test' - assert repr(u) == 'mysql://scott:***@localhost/test' - assert repr(e) == 'Engine(mysql://scott:***@localhost/test)' - assert repr(e2) == 'Engine(mysql://scott:***@localhost/test)' + assert str(u) == "mysql://scott:tiger@localhost/test" + assert repr(u) == "mysql://scott:***@localhost/test" + assert repr(e) == "Engine(mysql://scott:***@localhost/test)" + assert repr(e2) == "Engine(mysql://scott:***@localhost/test)" def test_poolargs(self): """test that connection pool args make it thru""" e = create_engine( - 'postgresql://', + "postgresql://", creator=None, pool_recycle=50, echo_pool=None, @@ -349,7 +402,7 @@ class CreateEngineTest(fixtures.TestBase): # these args work for QueuePool e = create_engine( - 'postgresql://', + "postgresql://", max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, @@ -362,7 +415,7 @@ class CreateEngineTest(fixtures.TestBase): assert_raises( TypeError, create_engine, - 'sqlite://', + "sqlite://", max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool, @@ -390,7 +443,8 @@ class TestRegNewDBAPI(fixtures.TestBase): global dialect dialect = MockDialect registry.register( - "mockdialect.foob", ".".join(tokens[0:-1]), tokens[-1]) + "mockdialect.foob", ".".join(tokens[0:-1]), tokens[-1] + ) e = create_engine("mockdialect+foob://") assert isinstance(e.dialect, MockDialect) @@ -414,6 +468,7 @@ class TestRegNewDBAPI(fixtures.TestBase): registry.register("wrapperdialect", __name__, "WrapperFactory") from sqlalchemy.dialects import sqlite + e = create_engine("wrapperdialect://") eq_(e.dialect.name, "sqlite") @@ -423,8 +478,8 @@ class TestRegNewDBAPI(fixtures.TestBase): WrapperFactory.mock_calls, [ call.get_dialect_cls(url.make_url("sqlite://")), - call.engine_created(e) - ] + call.engine_created(e), + ], ) @testing.requires.sqlite @@ -436,11 +491,15 @@ class TestRegNewDBAPI(fixtures.TestBase): def side_effect(url, kw): eq_( url.query, - {"plugin": "engineplugin", "myplugin_arg": "bat", "foo": "bar"} + { + "plugin": "engineplugin", + "myplugin_arg": "bat", + "foo": "bar", + }, ) eq_(kw, {"logging_name": "foob"}) - kw['logging_name'] = 'bar' - url.query.pop('myplugin_arg', None) + kw["logging_name"] = "bar" + url.query.pop("myplugin_arg", None) return MyEnginePlugin MyEnginePlugin = Mock(side_effect=side_effect) @@ -449,15 +508,13 @@ class TestRegNewDBAPI(fixtures.TestBase): e = create_engine( "sqlite:///?plugin=engineplugin&foo=bar&myplugin_arg=bat", - logging_name='foob') + logging_name="foob", + ) eq_(e.dialect.name, 
"sqlite") eq_(e.logging_name, "bar") # plugin args are removed from URL. - eq_( - e.url.query, - {"foo": "bar"} - ) + eq_(e.url.query, {"foo": "bar"}) assert isinstance(e.dialect, sqlite.dialect) eq_( @@ -466,15 +523,12 @@ class TestRegNewDBAPI(fixtures.TestBase): call(url.make_url("sqlite:///?foo=bar"), {}), call.handle_dialect_kwargs(sqlite.dialect, mock.ANY), call.handle_pool_kwargs(mock.ANY, {"dialect": e.dialect}), - call.engine_created(e) - ] + call.engine_created(e), + ], ) # url was modified in place by MyEnginePlugin - eq_( - str(MyEnginePlugin.mock_calls[0][1][0]), - "sqlite:///?foo=bar" - ) + eq_(str(MyEnginePlugin.mock_calls[0][1][0]), "sqlite:///?foo=bar") @testing.requires.sqlite def test_plugin_multiple_url_registration(self): @@ -485,12 +539,12 @@ class TestRegNewDBAPI(fixtures.TestBase): def side_effect_1(url, kw): eq_(kw, {"logging_name": "foob"}) - kw['logging_name'] = 'bar' - url.query.pop('myplugin1_arg', None) + kw["logging_name"] = "bar" + url.query.pop("myplugin1_arg", None) return MyEnginePlugin1 def side_effect_2(url, kw): - url.query.pop('myplugin2_arg', None) + url.query.pop("myplugin2_arg", None) return MyEnginePlugin2 MyEnginePlugin1 = Mock(side_effect=side_effect_1) @@ -502,15 +556,13 @@ class TestRegNewDBAPI(fixtures.TestBase): e = create_engine( "sqlite:///?plugin=engineplugin1&foo=bar&myplugin1_arg=bat" "&plugin=engineplugin2&myplugin2_arg=hoho", - logging_name='foob') + logging_name="foob", + ) eq_(e.dialect.name, "sqlite") eq_(e.logging_name, "bar") # plugin args are removed from URL. - eq_( - e.url.query, - {"foo": "bar"} - ) + eq_(e.url.query, {"foo": "bar"}) assert isinstance(e.dialect, sqlite.dialect) eq_( @@ -519,8 +571,8 @@ class TestRegNewDBAPI(fixtures.TestBase): call(url.make_url("sqlite:///?foo=bar"), {}), call.handle_dialect_kwargs(sqlite.dialect, mock.ANY), call.handle_pool_kwargs(mock.ANY, {"dialect": e.dialect}), - call.engine_created(e) - ] + call.engine_created(e), + ], ) eq_( @@ -529,8 +581,8 @@ class TestRegNewDBAPI(fixtures.TestBase): call(url.make_url("sqlite:///?foo=bar"), {}), call.handle_dialect_kwargs(sqlite.dialect, mock.ANY), call.handle_pool_kwargs(mock.ANY, {"dialect": e.dialect}), - call.engine_created(e) - ] + call.engine_created(e), + ], ) @testing.requires.sqlite @@ -542,11 +594,14 @@ class TestRegNewDBAPI(fixtures.TestBase): def side_effect(url, kw): eq_( kw, - {"logging_name": "foob", 'plugins': ['engineplugin'], - 'myplugin_arg': 'bat'} + { + "logging_name": "foob", + "plugins": ["engineplugin"], + "myplugin_arg": "bat", + }, ) - kw['logging_name'] = 'bar' - kw.pop('myplugin_arg', None) + kw["logging_name"] = "bar" + kw.pop("myplugin_arg", None) return MyEnginePlugin MyEnginePlugin = Mock(side_effect=side_effect) @@ -555,7 +610,10 @@ class TestRegNewDBAPI(fixtures.TestBase): e = create_engine( "sqlite:///?foo=bar", - logging_name='foob', plugins=["engineplugin"], myplugin_arg="bat") + logging_name="foob", + plugins=["engineplugin"], + myplugin_arg="bat", + ) eq_(e.dialect.name, "sqlite") eq_(e.logging_name, "bar") @@ -567,8 +625,8 @@ class TestRegNewDBAPI(fixtures.TestBase): call(url.make_url("sqlite:///?foo=bar"), {}), call.handle_dialect_kwargs(sqlite.dialect, mock.ANY), call.handle_pool_kwargs(mock.ANY, {"dialect": e.dialect}), - call.engine_created(e) - ] + call.engine_created(e), + ], ) @@ -579,22 +637,20 @@ class MockDialect(DefaultDialect): def MockDBAPI(**assert_kwargs): - connection = Mock(get_server_version_info=Mock(return_value='5.0')) + connection = Mock(get_server_version_info=Mock(return_value="5.0")) def 
connect(*args, **kwargs): for k in assert_kwargs: - assert k in kwargs, 'key %s not present in dictionary' % k - eq_( - kwargs[k], assert_kwargs[k] - ) + assert k in kwargs, "key %s not present in dictionary" % k + eq_(kwargs[k], assert_kwargs[k]) return connection return MagicMock( - sqlite_version_info=(99, 9, 9,), - version_info=(99, 9, 9,), - sqlite_version='99.9.9', - paramstyle='named', - connect=Mock(side_effect=connect) + sqlite_version_info=(99, 9, 9), + version_info=(99, 9, 9), + sqlite_version="99.9.9", + paramstyle="named", + connect=Mock(side_effect=connect), ) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 99e50f582a..547c265bb4 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -29,7 +29,9 @@ def MockDBAPI(): # noqa # adding a side_effect for close seems to help. conn = Mock( cursor=Mock(side_effect=cursor), - close=Mock(side_effect=close), closed=False) + close=Mock(side_effect=close), + closed=False, + ) return conn def shutdown(value): @@ -40,9 +42,8 @@ def MockDBAPI(): # noqa db.is_shutdown = value db = Mock( - connect=Mock(side_effect=connect), - shutdown=shutdown, - is_shutdown=False) + connect=Mock(side_effect=connect), shutdown=shutdown, is_shutdown=False + ) return db @@ -71,18 +72,19 @@ class PoolTestBase(fixtures.TestBase): def _queuepool_dbapi_fixture(self, **kw): dbapi = MockDBAPI() - return dbapi, pool.QueuePool( - creator=lambda: dbapi.connect('foo.db'), - **kw) + return ( + dbapi, + pool.QueuePool(creator=lambda: dbapi.connect("foo.db"), **kw), + ) class PoolTest(PoolTestBase): def test_manager(self): manager = pool.manage(MockDBAPI(), use_threadlocal=True) - c1 = manager.connect('foo.db') - c2 = manager.connect('foo.db') - c3 = manager.connect('bar.db') + c1 = manager.connect("foo.db") + c2 = manager.connect("foo.db") + c3 = manager.connect("bar.db") c4 = manager.connect("foo.db", bar="bat") c5 = manager.connect("foo.db", bar="hoho") c6 = manager.connect("foo.db", bar="bat") @@ -98,21 +100,15 @@ class PoolTest(PoolTestBase): dbapi = MockDBAPI() manager = pool.manage(dbapi, use_threadlocal=True) - c1 = manager.connect('foo.db', sa_pool_key="a") - c2 = manager.connect('foo.db', sa_pool_key="b") - c3 = manager.connect('bar.db', sa_pool_key="a") + c1 = manager.connect("foo.db", sa_pool_key="a") + c2 = manager.connect("foo.db", sa_pool_key="b") + c3 = manager.connect("bar.db", sa_pool_key="a") assert c1.cursor() is not None assert c1 is not c2 assert c1 is c3 - eq_( - dbapi.connect.mock_calls, - [ - call("foo.db"), - call("foo.db"), - ] - ) + eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")]) def test_bad_args(self): manager = pool.manage(MockDBAPI()) @@ -121,20 +117,21 @@ class PoolTest(PoolTestBase): def test_non_thread_local_manager(self): manager = pool.manage(MockDBAPI(), use_threadlocal=False) - connection = manager.connect('foo.db') - connection2 = manager.connect('foo.db') + connection = manager.connect("foo.db") + connection2 = manager.connect("foo.db") self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) - @testing.fails_on('+pyodbc', - "pyodbc cursor doesn't implement tuple __eq__") + @testing.fails_on( + "+pyodbc", "pyodbc cursor doesn't implement tuple __eq__" + ) @testing.fails_on("+pg8000", "returns [1], not (1,)") def test_cursor_iterable(self): conn = testing.db.raw_connection() cursor = conn.cursor() cursor.execute(str(select([1], bind=testing.db))) - expected = [(1, )] + expected = [(1,)] for row in cursor: eq_(row, expected.pop(0)) @@ -142,8 +139,13 
@@ class PoolTest(PoolTestBase): def creator(): raise Exception("no creates allowed") - for cls in (pool.SingletonThreadPool, pool.StaticPool, - pool.QueuePool, pool.NullPool, pool.AssertionPool): + for cls in ( + pool.SingletonThreadPool, + pool.StaticPool, + pool.QueuePool, + pool.NullPool, + pool.AssertionPool, + ): p = cls(creator=creator) p.dispose() p2 = p.recreate() @@ -165,12 +167,17 @@ class PoolTest(PoolTestBase): def _do_testthreadlocal(self, useclose=False): dbapi = MockDBAPI() - for p in pool.QueuePool(creator=dbapi.connect, - pool_size=3, max_overflow=-1, - use_threadlocal=True), \ - pool.SingletonThreadPool( + for p in ( + pool.QueuePool( creator=dbapi.connect, - use_threadlocal=True): + pool_size=3, + max_overflow=-1, + use_threadlocal=True, + ), + pool.SingletonThreadPool( + creator=dbapi.connect, use_threadlocal=True + ), + ): c1 = p.connect() c2 = p.connect() self.assert_(c1 is c2) @@ -222,25 +229,25 @@ class PoolTest(PoolTestBase): self.assert_(not c.info) self.assert_(c.info is c._connection_record.info) - c.info['foo'] = 'bar' + c.info["foo"] = "bar" c.close() del c c = p.connect() - self.assert_('foo' in c.info) + self.assert_("foo" in c.info) c.invalidate() c = p.connect() - self.assert_('foo' not in c.info) + self.assert_("foo" not in c.info) - c.info['foo2'] = 'bar2' + c.info["foo2"] = "bar2" c.detach() - self.assert_('foo2' in c.info) + self.assert_("foo2" in c.info) c2 = p.connect() is_not_(c.connection, c2.connection) assert not c2.info - assert 'foo2' in c.info + assert "foo2" in c.info def test_rec_info(self): p = self._queuepool_fixture(pool_size=1, max_overflow=0) @@ -249,18 +256,18 @@ class PoolTest(PoolTestBase): self.assert_(not c.record_info) self.assert_(c.record_info is c._connection_record.record_info) - c.record_info['foo'] = 'bar' + c.record_info["foo"] = "bar" c.close() del c c = p.connect() - self.assert_('foo' in c.record_info) + self.assert_("foo" in c.record_info) c.invalidate() c = p.connect() - self.assert_('foo' in c.record_info) + self.assert_("foo" in c.record_info) - c.record_info['foo2'] = 'bar2' + c.record_info["foo2"] = "bar2" c.detach() is_(c.record_info, None) is_(c._connection_record, None) @@ -268,16 +275,14 @@ class PoolTest(PoolTestBase): c2 = p.connect() assert c2.record_info - assert 'foo2' in c2.record_info + assert "foo2" in c2.record_info def test_rec_unconnected(self): # test production of a _ConnectionRecord with an # initially unconnected state. 
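The .info lifecycle asserted in the tests above can be summarized against a plain sqlite3-backed pool. A minimal sketch under stated assumptions — the ":memory:" creator and the "tag" key are illustrative only:

    import sqlite3

    from sqlalchemy import pool

    p = pool.QueuePool(lambda: sqlite3.connect(":memory:"), pool_size=1)

    c = p.connect()
    c.info["tag"] = "x"   # stored on the underlying _ConnectionRecord
    c.close()             # checkin keeps the record, and thus .info

    c = p.connect()
    assert c.info["tag"] == "x"   # same record on re-checkout
    c.invalidate()                # discards the record's DBAPI connection
    c.close()

    c = p.connect()
    assert "tag" not in c.info    # record reconnects with a cleared .info
    c.close()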
dbapi = MockDBAPI() - p1 = pool.Pool( - creator=lambda: dbapi.connect('foo.db') - ) + p1 = pool.Pool(creator=lambda: dbapi.connect("foo.db")) r1 = pool._ConnectionRecord(p1, connect=False) @@ -289,9 +294,7 @@ class PoolTest(PoolTestBase): # test that _ConnectionRecord.close() allows # the record to be reusable dbapi = MockDBAPI() - p1 = pool.Pool( - creator=lambda: dbapi.connect('foo.db') - ) + p1 = pool.Pool(creator=lambda: dbapi.connect("foo.db")) r1 = pool._ConnectionRecord(p1) @@ -302,20 +305,14 @@ class PoolTest(PoolTestBase): r1.close() assert not r1.connection - eq_( - c1.mock_calls, - [call.close()] - ) + eq_(c1.mock_calls, [call.close()]) c2 = r1.get_connection() is_not_(c1, c2) is_(c2, r1.connection) - eq_( - c2.mock_calls, - [] - ) + eq_(c2.mock_calls, []) class PoolDialectTest(PoolTestBase): @@ -324,16 +321,17 @@ class PoolDialectTest(PoolTestBase): class PoolDialect(object): def do_rollback(self, dbapi_connection): - canary.append('R') + canary.append("R") dbapi_connection.rollback() def do_commit(self, dbapi_connection): - canary.append('C') + canary.append("C") dbapi_connection.commit() def do_close(self, dbapi_connection): - canary.append('CL') + canary.append("CL") dbapi_connection.close() + return PoolDialect(), canary def _do_test(self, pool_cls, assertion): @@ -351,19 +349,19 @@ class PoolDialectTest(PoolTestBase): eq_(canary, assertion) def test_queue_pool(self): - self._do_test(pool.QueuePool, ['R', 'CL', 'R']) + self._do_test(pool.QueuePool, ["R", "CL", "R"]) def test_assertion_pool(self): - self._do_test(pool.AssertionPool, ['R', 'CL', 'R']) + self._do_test(pool.AssertionPool, ["R", "CL", "R"]) def test_singleton_pool(self): - self._do_test(pool.SingletonThreadPool, ['R', 'CL', 'R']) + self._do_test(pool.SingletonThreadPool, ["R", "CL", "R"]) def test_null_pool(self): - self._do_test(pool.NullPool, ['R', 'CL', 'R', 'CL']) + self._do_test(pool.NullPool, ["R", "CL", "R", "CL"]) def test_static_pool(self): - self._do_test(pool.StaticPool, ['R', 'R']) + self._do_test(pool.StaticPool, ["R", "R"]) class PoolEventsTest(PoolTestBase): @@ -372,9 +370,9 @@ class PoolEventsTest(PoolTestBase): canary = [] def first_connect(*arg, **kw): - canary.append('first_connect') + canary.append("first_connect") - event.listen(p, 'first_connect', first_connect) + event.listen(p, "first_connect", first_connect) return p, canary @@ -383,9 +381,9 @@ class PoolEventsTest(PoolTestBase): canary = [] def connect(*arg, **kw): - canary.append('connect') + canary.append("connect") - event.listen(p, 'connect', connect) + event.listen(p, "connect", connect) return p, canary @@ -394,8 +392,9 @@ class PoolEventsTest(PoolTestBase): canary = [] def checkout(*arg, **kw): - canary.append('checkout') - event.listen(p, 'checkout', checkout) + canary.append("checkout") + + event.listen(p, "checkout", checkout) return p, canary @@ -404,8 +403,9 @@ class PoolEventsTest(PoolTestBase): canary = [] def checkin(*arg, **kw): - canary.append('checkin') - event.listen(p, 'checkin', checkin) + canary.append("checkin") + + event.listen(p, "checkin", checkin) return p, canary @@ -414,43 +414,44 @@ class PoolEventsTest(PoolTestBase): canary = [] def reset(*arg, **kw): - canary.append('reset') - event.listen(p, 'reset', reset) + canary.append("reset") + + event.listen(p, "reset", reset) return p, canary def _invalidate_event_fixture(self): p = self._queuepool_fixture() canary = Mock() - event.listen(p, 'invalidate', canary) + event.listen(p, "invalidate", canary) return p, canary def _soft_invalidate_event_fixture(self): p 
= self._queuepool_fixture() canary = Mock() - event.listen(p, 'soft_invalidate', canary) + event.listen(p, "soft_invalidate", canary) return p, canary def _close_event_fixture(self): p = self._queuepool_fixture() canary = Mock() - event.listen(p, 'close', canary) + event.listen(p, "close", canary) return p, canary def _detach_event_fixture(self): p = self._queuepool_fixture() canary = Mock() - event.listen(p, 'detach', canary) + event.listen(p, "detach", canary) return p, canary def _close_detached_event_fixture(self): p = self._queuepool_fixture() canary = Mock() - event.listen(p, 'close_detached', canary) + event.listen(p, "close_detached", canary) return p, canary @@ -497,7 +498,7 @@ class PoolEventsTest(PoolTestBase): p, canary = self._first_connect_event_fixture() p.connect() - eq_(canary, ['first_connect']) + eq_(canary, ["first_connect"]) def test_first_connect_event_fires_once(self): p, canary = self._first_connect_event_fixture() @@ -505,7 +506,7 @@ class PoolEventsTest(PoolTestBase): p.connect() p.connect() - eq_(canary, ['first_connect']) + eq_(canary, ["first_connect"]) def test_first_connect_on_previously_recreated(self): p, canary = self._first_connect_event_fixture() @@ -514,7 +515,7 @@ class PoolEventsTest(PoolTestBase): p.connect() p2.connect() - eq_(canary, ['first_connect', 'first_connect']) + eq_(canary, ["first_connect", "first_connect"]) def test_first_connect_on_subsequently_recreated(self): p, canary = self._first_connect_event_fixture() @@ -523,13 +524,13 @@ class PoolEventsTest(PoolTestBase): p2 = p.recreate() p2.connect() - eq_(canary, ['first_connect', 'first_connect']) + eq_(canary, ["first_connect", "first_connect"]) def test_connect_event(self): p, canary = self._connect_event_fixture() p.connect() - eq_(canary, ['connect']) + eq_(canary, ["connect"]) def test_connect_event_fires_subsequent(self): p, canary = self._connect_event_fixture() @@ -537,7 +538,7 @@ class PoolEventsTest(PoolTestBase): c1 = p.connect() # noqa c2 = p.connect() # noqa - eq_(canary, ['connect', 'connect']) + eq_(canary, ["connect", "connect"]) def test_connect_on_previously_recreated(self): p, canary = self._connect_event_fixture() @@ -547,7 +548,7 @@ class PoolEventsTest(PoolTestBase): p.connect() p2.connect() - eq_(canary, ['connect', 'connect']) + eq_(canary, ["connect", "connect"]) def test_connect_on_subsequently_recreated(self): p, canary = self._connect_event_fixture() @@ -556,20 +557,20 @@ class PoolEventsTest(PoolTestBase): p2 = p.recreate() p2.connect() - eq_(canary, ['connect', 'connect']) + eq_(canary, ["connect", "connect"]) def test_checkout_event(self): p, canary = self._checkout_event_fixture() p.connect() - eq_(canary, ['checkout']) + eq_(canary, ["checkout"]) def test_checkout_event_fires_subsequent(self): p, canary = self._checkout_event_fixture() p.connect() p.connect() - eq_(canary, ['checkout', 'checkout']) + eq_(canary, ["checkout", "checkout"]) def test_checkout_event_on_subsequently_recreated(self): p, canary = self._checkout_event_fixture() @@ -578,7 +579,7 @@ class PoolEventsTest(PoolTestBase): p2 = p.recreate() p2.connect() - eq_(canary, ['checkout', 'checkout']) + eq_(canary, ["checkout", "checkout"]) def test_checkin_event(self): p, canary = self._checkin_event_fixture() @@ -586,7 +587,7 @@ class PoolEventsTest(PoolTestBase): c1 = p.connect() eq_(canary, []) c1.close() - eq_(canary, ['checkin']) + eq_(canary, ["checkin"]) def test_reset_event(self): p, canary = self._reset_event_fixture() @@ -594,7 +595,7 @@ class PoolEventsTest(PoolTestBase): c1 = 
p.connect() eq_(canary, []) c1.close() - eq_(canary, ['reset']) + eq_(canary, ["reset"]) def test_soft_invalidate_event_no_exception(self): p, canary = self._soft_invalidate_event_fixture() @@ -653,7 +654,7 @@ class PoolEventsTest(PoolTestBase): eq_(canary, []) del c1 lazy_gc() - eq_(canary, ['checkin']) + eq_(canary, ["checkin"]) def test_checkin_event_on_subsequently_recreated(self): p, canary = self._checkin_event_fixture() @@ -665,10 +666,10 @@ class PoolEventsTest(PoolTestBase): eq_(canary, []) c1.close() - eq_(canary, ['checkin']) + eq_(canary, ["checkin"]) c2.close() - eq_(canary, ['checkin', 'checkin']) + eq_(canary, ["checkin", "checkin"]) def test_listen_targets_scope(self): canary = [] @@ -686,15 +687,14 @@ class PoolEventsTest(PoolTestBase): canary.append("listen_four") engine = testing_engine(testing.db.url) - event.listen(pool.Pool, 'connect', listen_one) - event.listen(engine.pool, 'connect', listen_two) - event.listen(engine, 'connect', listen_three) - event.listen(engine.__class__, 'connect', listen_four) + event.listen(pool.Pool, "connect", listen_one) + event.listen(engine.pool, "connect", listen_two) + event.listen(engine, "connect", listen_three) + event.listen(engine.__class__, "connect", listen_four) engine.execute(select([1])).close() eq_( - canary, - ["listen_one", "listen_four", "listen_two", "listen_three"] + canary, ["listen_one", "listen_four", "listen_two", "listen_three"] ) def test_listen_targets_per_subclass(self): @@ -712,9 +712,9 @@ class PoolEventsTest(PoolTestBase): def listen_three(*args): canary.append("listen_three") - event.listen(pool.Pool, 'connect', listen_one) - event.listen(pool.QueuePool, 'connect', listen_two) - event.listen(pool.SingletonThreadPool, 'connect', listen_three) + event.listen(pool.Pool, "connect", listen_one) + event.listen(pool.QueuePool, "connect", listen_two) + event.listen(pool.SingletonThreadPool, "connect", listen_three) p1 = pool.QueuePool(creator=MockDBAPI().connect) p2 = pool.SingletonThreadPool(creator=MockDBAPI().connect) @@ -739,15 +739,16 @@ class PoolEventsTest(PoolTestBase): raise Exception("it failed") def listen_two(conn, rec): - rec.info['important_flag'] = True + rec.info["important_flag"] = True p1 = pool.QueuePool( - creator=MockDBAPI().connect, pool_size=1, max_overflow=0) - event.listen(p1, 'connect', listen_one) - event.listen(p1, 'connect', listen_two) + creator=MockDBAPI().connect, pool_size=1, max_overflow=0 + ) + event.listen(p1, "connect", listen_one) + event.listen(p1, "connect", listen_two) conn = p1.connect() - eq_(conn.info['important_flag'], True) + eq_(conn.info["important_flag"], True) conn.invalidate() conn.close() @@ -757,7 +758,7 @@ class PoolEventsTest(PoolTestBase): fail = False conn = p1.connect() - eq_(conn.info['important_flag'], True) + eq_(conn.info["important_flag"], True) conn.close() def teardown(self): @@ -775,21 +776,21 @@ class PoolFirstConnectSyncTest(PoolTestBase): evt = Mock() - @event.listens_for(pool, 'first_connect') + @event.listens_for(pool, "first_connect") def slow_first_connect(dbapi_con, rec): time.sleep(1) evt.first_connect() - @event.listens_for(pool, 'connect') + @event.listens_for(pool, "connect") def on_connect(dbapi_con, rec): evt.connect() def checkout(): for j in range(2): c1 = pool.connect() - time.sleep(.02) + time.sleep(0.02) c1.close() - time.sleep(.02) + time.sleep(0.02) threads = [] for i in range(5): @@ -805,7 +806,8 @@ class PoolFirstConnectSyncTest(PoolTestBase): call.first_connect(), call.connect(), call.connect(), - call.connect()] + 
call.connect(), + ], ) @@ -813,16 +815,15 @@ class DeprecatedPoolListenerTest(PoolTestBase): @testing.requires.predictable_gc @testing.uses_deprecated(r".*Use event.listen") def test_listeners(self): - class InstrumentingListener(object): def __init__(self): - if hasattr(self, 'connect'): + if hasattr(self, "connect"): self.connect = self.inst_connect - if hasattr(self, 'first_connect'): + if hasattr(self, "first_connect"): self.first_connect = self.inst_first_connect - if hasattr(self, 'checkout'): + if hasattr(self, "checkout"): self.checkout = self.inst_checkout - if hasattr(self, 'checkin'): + if hasattr(self, "checkin"): self.checkin = self.inst_checkin self.clear() @@ -838,9 +839,7 @@ class DeprecatedPoolListenerTest(PoolTestBase): eq_(len(self.checked_out), cout) eq_(len(self.checked_in), cin) - def assert_in( - self, item, in_conn, in_fconn, - in_cout, in_cin): + def assert_in(self, item, in_conn, in_fconn, in_cout, in_cin): eq_((item in self.connected), in_conn) eq_((item in self.first_connected), in_fconn) eq_((item in self.checked_out), in_cout) @@ -1056,7 +1055,6 @@ class DeprecatedPoolListenerTest(PoolTestBase): class QueuePoolTest(PoolTestBase): - def test_queuepool_del(self): self._do_testqueuepool(useclose=False) @@ -1064,13 +1062,15 @@ class QueuePoolTest(PoolTestBase): self._do_testqueuepool(useclose=True) def _do_testqueuepool(self, useclose=False): - p = self._queuepool_fixture( - pool_size=3, - max_overflow=-1) + p = self._queuepool_fixture(pool_size=3, max_overflow=-1) def status(pool): - return pool.size(), pool.checkedin(), pool.overflow(), \ - pool.checkedout() + return ( + pool.size(), + pool.checkedin(), + pool.overflow(), + pool.checkedout(), + ) c1 = p.connect() self.assert_(status(p) == (3, 0, -2, 1)) @@ -1115,19 +1115,13 @@ class QueuePoolTest(PoolTestBase): @testing.requires.timing_intensive def test_timeout(self): - p = self._queuepool_fixture( - pool_size=3, - max_overflow=0, - timeout=2) + p = self._queuepool_fixture(pool_size=3, max_overflow=0, timeout=2) c1 = p.connect() # noqa c2 = p.connect() # noqa c3 = p.connect() # noqa now = time.time() - assert_raises( - tsa.exc.TimeoutError, - p.connect - ) + assert_raises(tsa.exc.TimeoutError, p.connect) assert int(time.time() - now) == 2 @testing.requires.threading_with_mock @@ -1142,9 +1136,12 @@ class QueuePoolTest(PoolTestBase): # them back to the start of do_get() dbapi = MockDBAPI() p = pool.QueuePool( - creator=lambda: dbapi.connect(delay=.05), + creator=lambda: dbapi.connect(delay=0.05), pool_size=2, - max_overflow=1, use_threadlocal=False, timeout=3) + max_overflow=1, + use_threadlocal=False, + timeout=3, + ) timeouts = [] def checkout(): @@ -1180,25 +1177,26 @@ class QueuePoolTest(PoolTestBase): mutex = threading.Lock() def creator(): - time.sleep(.05) + time.sleep(0.05) with mutex: return dbapi.connect() - p = pool.QueuePool(creator=creator, - pool_size=3, timeout=2, - max_overflow=max_overflow) + p = pool.QueuePool( + creator=creator, pool_size=3, timeout=2, max_overflow=max_overflow + ) peaks = [] def whammy(): for i in range(10): try: con = p.connect() - time.sleep(.005) + time.sleep(0.005) peaks.append(p.overflow()) con.close() del con except tsa.exc.TimeoutError: pass + threads = [] for i in range(thread_count): th = threading.Thread(target=whammy) @@ -1270,28 +1268,29 @@ class QueuePoolTest(PoolTestBase): p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3) threads = [ + threading.Thread(target=run_test, args=("success_one", p, False)), + threading.Thread(target=run_test, 
args=("success_two", p, False)), + threading.Thread(target=run_test, args=("overflow_one", p, True)), + threading.Thread(target=run_test, args=("overflow_two", p, False)), threading.Thread( - target=run_test, args=("success_one", p, False)), - threading.Thread( - target=run_test, args=("success_two", p, False)), - threading.Thread( - target=run_test, args=("overflow_one", p, True)), - threading.Thread( - target=run_test, args=("overflow_two", p, False)), - threading.Thread( - target=run_test, args=("overflow_three", p, False)) + target=run_test, args=("overflow_three", p, False) + ), ] for t in threads: t.start() - time.sleep(.2) + time.sleep(0.2) for t in threads: t.join(timeout=join_timeout) eq_( dbapi.connect().operation.mock_calls, - [call("success_one"), call("success_two"), - call("overflow_two"), call("overflow_three"), - call("overflow_one")] + [ + call("success_one"), + call("success_two"), + call("overflow_two"), + call("overflow_three"), + call("overflow_one"), + ], ) @testing.requires.threading_with_mock @@ -1314,15 +1313,18 @@ class QueuePoolTest(PoolTestBase): success = [] for timeout in (None, 30): for max_overflow in (0, -1, 3): - p = pool.QueuePool(creator=creator, - pool_size=2, timeout=timeout, - max_overflow=max_overflow) + p = pool.QueuePool( + creator=creator, + pool_size=2, + timeout=timeout, + max_overflow=max_overflow, + ) def waiter(p, timeout, max_overflow): success_key = (timeout, max_overflow) conn = p.connect() success.append(success_key) - time.sleep(.1) + time.sleep(0.1) conn.close() c1 = p.connect() # noqa @@ -1331,8 +1333,8 @@ class QueuePoolTest(PoolTestBase): threads = [] for i in range(2): t = threading.Thread( - target=waiter, - args=(p, timeout, max_overflow)) + target=waiter, args=(p, timeout, max_overflow) + ) t.daemon = True t.start() threads.append(t) @@ -1341,7 +1343,7 @@ class QueuePoolTest(PoolTestBase): # two waiter threads hit upon wait() # inside the queue, before we invalidate the other # two conns - time.sleep(.2) + time.sleep(0.2) p._invalidate(c2) for t in threads: @@ -1381,8 +1383,9 @@ class QueuePoolTest(PoolTestBase): return fairy with patch( - "sqlalchemy.pool._ConnectionRecord.checkout", - _decorate_existing_checkout): + "sqlalchemy.pool._ConnectionRecord.checkout", + _decorate_existing_checkout, + ): conn = p.connect() is_(conn._connection_record.connection, None) conn.close() @@ -1397,25 +1400,25 @@ class QueuePoolTest(PoolTestBase): def creator(): canary.append(1) return dbapi.connect() + p1 = pool.QueuePool( - creator=creator, - pool_size=1, timeout=None, - max_overflow=0) + creator=creator, pool_size=1, timeout=None, max_overflow=0 + ) def waiter(p): conn = p.connect() canary.append(2) - time.sleep(.5) + time.sleep(0.5) conn.close() c1 = p1.connect() threads = [] for i in range(5): - t = threading.Thread(target=waiter, args=(p1, )) + t = threading.Thread(target=waiter, args=(p1,)) t.start() threads.append(t) - time.sleep(.5) + time.sleep(0.5) eq_(canary, [1]) # this also calls invalidate() @@ -1430,9 +1433,9 @@ class QueuePoolTest(PoolTestBase): def test_dispose_closes_pooled(self): dbapi = MockDBAPI() - p = pool.QueuePool(creator=dbapi.connect, - pool_size=2, timeout=None, - max_overflow=0) + p = pool.QueuePool( + creator=dbapi.connect, pool_size=2, timeout=None, max_overflow=0 + ) c1 = p.connect() c2 = p.connect() c1_con = c1.connection @@ -1473,8 +1476,9 @@ class QueuePoolTest(PoolTestBase): def test_mixed_close(self): pool._refs.clear() - p = self._queuepool_fixture(pool_size=3, max_overflow=-1, - use_threadlocal=True) + p 
= self._queuepool_fixture( + pool_size=3, max_overflow=-1, use_threadlocal=True + ) c1 = p.connect() c2 = p.connect() assert c1 is c2 @@ -1494,9 +1498,7 @@ class QueuePoolTest(PoolTestBase): self._test_overflow_no_gc(False) def _test_overflow_no_gc(self, threadlocal): - p = self._queuepool_fixture( - pool_size=2, - max_overflow=2) + p = self._queuepool_fixture(pool_size=2, max_overflow=2) # disable weakref collection of the # underlying connections @@ -1521,14 +1523,14 @@ class QueuePoolTest(PoolTestBase): eq_( set([c.close.call_count for c in strong_refs]), - set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]) + set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]), ) @testing.requires.predictable_gc def test_weakref_kaboom(self): p = self._queuepool_fixture( - pool_size=3, - max_overflow=-1, use_threadlocal=True) + pool_size=3, max_overflow=-1, use_threadlocal=True + ) c1 = p.connect() c2 = p.connect() c1.close() @@ -1548,8 +1550,8 @@ class QueuePoolTest(PoolTestBase): reference counting.""" p = self._queuepool_fixture( - pool_size=3, - max_overflow=-1, use_threadlocal=True) + pool_size=3, max_overflow=-1, use_threadlocal=True + ) c1 = p.connect() c2 = p.connect() assert c1 is c2 @@ -1565,9 +1567,8 @@ class QueuePoolTest(PoolTestBase): mock.return_value = 10000 p = self._queuepool_fixture( - pool_size=1, - max_overflow=0, - recycle=30) + pool_size=1, max_overflow=0, recycle=30 + ) c1 = p.connect() c_ref = weakref.ref(c1.connection) c1.close() @@ -1583,9 +1584,7 @@ class QueuePoolTest(PoolTestBase): @testing.requires.timing_intensive def test_recycle_on_invalidate(self): - p = self._queuepool_fixture( - pool_size=1, - max_overflow=0) + p = self._queuepool_fixture(pool_size=1, max_overflow=0) c1 = p.connect() c_ref = weakref.ref(c1.connection) c1.close() @@ -1596,16 +1595,14 @@ class QueuePoolTest(PoolTestBase): p._invalidate(c2) assert c2_rec.connection is None c2.close() - time.sleep(.5) + time.sleep(0.5) c3 = p.connect() is_not_(c3.connection, c_ref()) @testing.requires.timing_intensive def test_recycle_on_soft_invalidate(self): - p = self._queuepool_fixture( - pool_size=1, - max_overflow=0) + p = self._queuepool_fixture(pool_size=1, max_overflow=0) c1 = p.connect() c_ref = weakref.ref(c1.connection) c1.close() @@ -1617,7 +1614,7 @@ class QueuePoolTest(PoolTestBase): is_(c2_rec.connection, c2.connection) c2.close() - time.sleep(.5) + time.sleep(0.5) c3 = p.connect() is_not_(c3.connection, c_ref()) is_(c3._connection_record, c2_rec) @@ -1627,15 +1624,17 @@ class QueuePoolTest(PoolTestBase): finalize_fairy = pool._finalize_fairy def assert_no_wr_callback( - connection, connection_record, - pool, ref, echo, fairy=None): + connection, connection_record, pool, ref, echo, fairy=None + ): if fairy is None: raise AssertionError( - "finalize fairy was called as a weakref callback") + "finalize fairy was called as a weakref callback" + ) return finalize_fairy( - connection, connection_record, pool, ref, echo, fairy) - return patch.object( - pool, '_finalize_fairy', assert_no_wr_callback) + connection, connection_record, pool, ref, echo, fairy + ) + + return patch.object(pool, "_finalize_fairy", assert_no_wr_callback) def _assert_cleanup_on_pooled_reconnect(self, dbapi, p): # p is QueuePool with size=1, max_overflow=2, @@ -1646,10 +1645,7 @@ class QueuePoolTest(PoolTestBase): eq_(p.checkedout(), 0) eq_(p._overflow, 0) dbapi.shutdown(True) - assert_raises( - Exception, - p.connect - ) + assert_raises(Exception, p.connect) eq_(p._overflow, 0) eq_(p.checkedout(), 0) # and not 1 @@ -1675,8 +1671,8 @@ class 
QueuePoolTest(PoolTestBase): @testing.requires.timing_intensive def test_error_on_pooled_reconnect_cleanup_recycle(self): dbapi, p = self._queuepool_dbapi_fixture( - pool_size=1, - max_overflow=2, recycle=1) + pool_size=1, max_overflow=2, recycle=1 + ) c1 = p.connect() c1.close() time.sleep(1.5) @@ -1685,8 +1681,7 @@ class QueuePoolTest(PoolTestBase): def test_connect_handler_not_called_for_recycled(self): """test [ticket:3497]""" - dbapi, p = self._queuepool_dbapi_fixture( - pool_size=2, max_overflow=2) + dbapi, p = self._queuepool_dbapi_fixture(pool_size=2, max_overflow=2) canary = Mock() @@ -1706,16 +1701,10 @@ class QueuePoolTest(PoolTestBase): event.listen(p, "connect", canary.connect) event.listen(p, "checkout", canary.checkout) - assert_raises( - Exception, - p.connect - ) + assert_raises(Exception, p.connect) p._pool.queue = collections.deque( - [ - c for c in p._pool.queue - if c.connection is not None - ] + [c for c in p._pool.queue if c.connection is not None] ) dbapi.shutdown(False) @@ -1724,17 +1713,13 @@ class QueuePoolTest(PoolTestBase): eq_( canary.mock_calls, - [ - call.connect(ANY, ANY), - call.checkout(ANY, ANY, ANY) - ] + [call.connect(ANY, ANY), call.checkout(ANY, ANY, ANY)], ) def test_connect_checkout_handler_always_gets_info(self): """test [ticket:3497]""" - dbapi, p = self._queuepool_dbapi_fixture( - pool_size=2, max_overflow=2) + dbapi, p = self._queuepool_dbapi_fixture(pool_size=2, max_overflow=2) c1 = p.connect() c2 = p.connect() @@ -1751,22 +1736,16 @@ class QueuePoolTest(PoolTestBase): @event.listens_for(p, "connect") def connect(conn, conn_rec): - conn_rec.info['x'] = True + conn_rec.info["x"] = True @event.listens_for(p, "checkout") def checkout(conn, conn_rec, conn_f): - assert 'x' in conn_rec.info + assert "x" in conn_rec.info - assert_raises( - Exception, - p.connect - ) + assert_raises(Exception, p.connect) p._pool.queue = collections.deque( - [ - c for c in p._pool.queue - if c.connection is not None - ] + [c for c in p._pool.queue if c.connection is not None] ) dbapi.shutdown(False) @@ -1774,9 +1753,7 @@ class QueuePoolTest(PoolTestBase): c.close() def test_error_on_pooled_reconnect_cleanup_wcheckout_event(self): - dbapi, p = self._queuepool_dbapi_fixture( - pool_size=1, - max_overflow=2) + dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=2) c1 = p.connect() c1.close() @@ -1791,12 +1768,12 @@ class QueuePoolTest(PoolTestBase): @testing.requires.predictable_gc def test_userspace_disconnectionerror_weakref_finalizer(self): dbapi, pool = self._queuepool_dbapi_fixture( - pool_size=1, - max_overflow=2) + pool_size=1, max_overflow=2 + ) @event.listens_for(pool, "checkout") def handle_checkout_event(dbapi_con, con_record, con_proxy): - if getattr(dbapi_con, 'boom') == 'yes': + if getattr(dbapi_con, "boom") == "yes": raise tsa.exc.DisconnectionError() conn = pool.connect() @@ -1805,7 +1782,7 @@ class QueuePoolTest(PoolTestBase): eq_(old_dbapi_conn.mock_calls, [call.rollback()]) - old_dbapi_conn.boom = 'yes' + old_dbapi_conn.boom = "yes" conn = pool.connect() dbapi_conn = conn.connection @@ -1817,16 +1794,13 @@ class QueuePoolTest(PoolTestBase): # old connection was just closed - did not get an # erroneous reset on return - eq_( - old_dbapi_conn.mock_calls, - [call.rollback(), call.close()] - ) + eq_(old_dbapi_conn.mock_calls, [call.rollback(), call.close()]) @testing.requires.timing_intensive def test_recycle_pool_no_race(self): def slow_close(): slow_closing_connection._slow_close() - time.sleep(.5) + time.sleep(0.5) slow_closing_connection 
= Mock() slow_closing_connection.connect.return_value.close = slow_close @@ -1847,9 +1821,11 @@ class QueuePoolTest(PoolTestBase): def creator(): return slow_closing_connection.connect() + p1 = TrackQueuePool(creator=creator, pool_size=20) from sqlalchemy import create_engine + eng = create_engine(testing.db.url, pool=p1, _initialize=False) eng.dialect = dialect @@ -1864,8 +1840,8 @@ class QueuePoolTest(PoolTestBase): time.sleep(random.random()) try: conn._handle_dbapi_exception( - Error(), "statement", {}, - Mock(), Mock()) + Error(), "statement", {}, Mock(), Mock() + ) except tsa.exc.DBAPIError: pass @@ -1873,7 +1849,7 @@ class QueuePoolTest(PoolTestBase): # connections threads = [] for conn in conns: - t = threading.Thread(target=attempt, args=(conn, )) + t = threading.Thread(target=attempt, args=(conn,)) t.start() threads.append(t) @@ -1908,8 +1884,9 @@ class QueuePoolTest(PoolTestBase): assert c1.connection.id != c_id def test_recreate(self): - p = self._queuepool_fixture(reset_on_return=None, pool_size=1, - max_overflow=0) + p = self._queuepool_fixture( + reset_on_return=None, pool_size=1, max_overflow=0 + ) p2 = p.recreate() assert p2.size() == 1 assert p2._reset_on_return is pool.reset_none @@ -1964,8 +1941,9 @@ class QueuePoolTest(PoolTestBase): eq_(c2_con.close.call_count, 0) def test_threadfairy(self): - p = self._queuepool_fixture(pool_size=3, max_overflow=-1, - use_threadlocal=True) + p = self._queuepool_fixture( + pool_size=3, max_overflow=-1, use_threadlocal=True + ) c1 = p.connect() c1.close() c2 = p.connect() @@ -1978,9 +1956,7 @@ class QueuePoolTest(PoolTestBase): rec = c1._connection_record c1.close() assert_raises_message( - Warning, - "Double checkin attempted on %s" % rec, - rec.checkin + Warning, "Double checkin attempted on %s" % rec, rec.checkin ) def test_lifo(self): @@ -2064,12 +2040,13 @@ class QueuePoolTest(PoolTestBase): class ResetOnReturnTest(PoolTestBase): def _fixture(self, **kw): dbapi = Mock() - return dbapi, pool.QueuePool( - creator=lambda: dbapi.connect('foo.db'), - **kw) + return ( + dbapi, + pool.QueuePool(creator=lambda: dbapi.connect("foo.db"), **kw), + ) def test_plain_rollback(self): - dbapi, p = self._fixture(reset_on_return='rollback') + dbapi, p = self._fixture(reset_on_return="rollback") c1 = p.connect() c1.close() @@ -2077,7 +2054,7 @@ class ResetOnReturnTest(PoolTestBase): assert not dbapi.connect().commit.called def test_plain_commit(self): - dbapi, p = self._fixture(reset_on_return='commit') + dbapi, p = self._fixture(reset_on_return="commit") c1 = p.connect() c1.close() @@ -2093,7 +2070,7 @@ class ResetOnReturnTest(PoolTestBase): assert not dbapi.connect().commit.called def test_agent_rollback(self): - dbapi, p = self._fixture(reset_on_return='rollback') + dbapi, p = self._fixture(reset_on_return="rollback") class Agent(object): def __init__(self, conn): @@ -2124,7 +2101,7 @@ class ResetOnReturnTest(PoolTestBase): assert not dbapi.connect().commit.called def test_agent_commit(self): - dbapi, p = self._fixture(reset_on_return='commit') + dbapi, p = self._fixture(reset_on_return="commit") class Agent(object): def __init__(self, conn): @@ -2154,7 +2131,7 @@ class ResetOnReturnTest(PoolTestBase): assert dbapi.connect().commit.called def test_reset_agent_disconnect(self): - dbapi, p = self._fixture(reset_on_return='rollback') + dbapi, p = self._fixture(reset_on_return="rollback") class Agent(object): def __init__(self, conn): @@ -2180,16 +2157,16 @@ class SingletonThreadPoolTest(PoolTestBase): def test_cleanup(self): 
self._test_cleanup(False) -# TODO: the SingletonThreadPool cleanup method -# has an unfixed race condition within the "cleanup" system that -# leads to this test being off by one connection under load; in any -# case, this connection will be closed once it is garbage collected. -# this pool is not a production-level pool and is only used for the -# SQLite "memory" connection, and is not very useful under actual -# multi-threaded conditions -# @testing.requires.threading_with_mock -# def test_cleanup_no_gc(self): -# self._test_cleanup(True) + # TODO: the SingletonThreadPool cleanup method + # has an unfixed race condition within the "cleanup" system that + # leads to this test being off by one connection under load; in any + # case, this connection will be closed once it is garbage collected. + # this pool is not a production-level pool and is only used for the + # SQLite "memory" connection, and is not very useful under actual + # multi-threaded conditions + # @testing.requires.threading_with_mock + # def test_cleanup_no_gc(self): + # self._test_cleanup(True) def _test_cleanup(self, strong_refs): """test that the pool's connections are OK after cleanup() has @@ -2203,6 +2180,7 @@ class SingletonThreadPoolTest(PoolTestBase): # the mock iterator isn't threadsafe... with lock: return dbapi.connect() + p = pool.SingletonThreadPool(creator=creator, pool_size=3) if strong_refs: @@ -2212,7 +2190,9 @@ class SingletonThreadPoolTest(PoolTestBase): c = p.connect() sr.add(c.connection) return c + else: + def _conn(): return p.connect() @@ -2222,7 +2202,7 @@ class SingletonThreadPoolTest(PoolTestBase): assert c c.cursor() c.close() - time.sleep(.1) + time.sleep(0.1) threads = [] for i in range(10): @@ -2241,13 +2221,13 @@ class SingletonThreadPoolTest(PoolTestBase): class AssertionPoolTest(PoolTestBase): def test_connect_error(self): dbapi = MockDBAPI() - p = pool.AssertionPool(creator=lambda: dbapi.connect('foo.db')) + p = pool.AssertionPool(creator=lambda: dbapi.connect("foo.db")) c1 = p.connect() # noqa assert_raises(AssertionError, p.connect) def test_connect_multiple(self): dbapi = MockDBAPI() - p = pool.AssertionPool(creator=lambda: dbapi.connect('foo.db')) + p = pool.AssertionPool(creator=lambda: dbapi.connect("foo.db")) c1 = p.connect() c1.close() c2 = p.connect() @@ -2260,7 +2240,7 @@ class AssertionPoolTest(PoolTestBase): class NullPoolTest(PoolTestBase): def test_reconnect(self): dbapi = MockDBAPI() - p = pool.NullPool(creator=lambda: dbapi.connect('foo.db')) + p = pool.NullPool(creator=lambda: dbapi.connect("foo.db")) c1 = p.connect() c1.close() @@ -2272,10 +2252,8 @@ class NullPoolTest(PoolTestBase): c1 = p.connect() dbapi.connect.assert_has_calls( - [ - call('foo.db'), - call('foo.db')], - any_order=True) + [call("foo.db"), call("foo.db")], any_order=True + ) class StaticPoolTest(PoolTestBase): @@ -2283,7 +2261,8 @@ class StaticPoolTest(PoolTestBase): dbapi = MockDBAPI() def creator(): - return dbapi.connect('foo.db') + return dbapi.connect("foo.db") + p = pool.StaticPool(creator) p2 = p.recreate() assert p._creator is p2._creator diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py index 9f0055e05f..8da838145b 100644 --- a/test/engine/test_processors.py +++ b/test/engine/test_processors.py @@ -4,42 +4,28 @@ from sqlalchemy.testing import assert_raises_message, eq_ class _BooleanProcessorTest(fixtures.TestBase): def test_int_to_bool_none(self): - eq_( - self.module.int_to_boolean(None), - None - ) + eq_(self.module.int_to_boolean(None), None) def 
test_int_to_bool_zero(self): - eq_( - self.module.int_to_boolean(0), - False - ) + eq_(self.module.int_to_boolean(0), False) def test_int_to_bool_one(self): - eq_( - self.module.int_to_boolean(1), - True - ) + eq_(self.module.int_to_boolean(1), True) def test_int_to_bool_positive_int(self): - eq_( - self.module.int_to_boolean(12), - True - ) + eq_(self.module.int_to_boolean(12), True) def test_int_to_bool_negative_int(self): - eq_( - self.module.int_to_boolean(-4), - True - ) + eq_(self.module.int_to_boolean(-4), True) class CBooleanProcessorTest(_BooleanProcessorTest): - __requires__ = ('cextensions',) + __requires__ = ("cextensions",) @classmethod def setup_class(cls): from sqlalchemy import cprocessors + cls.module = cprocessors @@ -48,42 +34,48 @@ class _DateProcessorTest(fixtures.TestBase): assert_raises_message( ValueError, "Couldn't parse date string '2012' - value is not a string", - self.module.str_to_date, 2012 + self.module.str_to_date, + 2012, ) def test_datetime_no_string(self): assert_raises_message( ValueError, "Couldn't parse datetime string '2012' - value is not a string", - self.module.str_to_datetime, 2012 + self.module.str_to_datetime, + 2012, ) def test_time_no_string(self): assert_raises_message( ValueError, "Couldn't parse time string '2012' - value is not a string", - self.module.str_to_time, 2012 + self.module.str_to_time, + 2012, ) def test_date_invalid_string(self): assert_raises_message( ValueError, "Couldn't parse date string: '5:a'", - self.module.str_to_date, "5:a" + self.module.str_to_date, + "5:a", ) def test_datetime_invalid_string(self): assert_raises_message( ValueError, "Couldn't parse datetime string: '5:a'", - self.module.str_to_datetime, "5:a" + self.module.str_to_datetime, + "5:a", ) def test_time_invalid_string(self): assert_raises_message( ValueError, "Couldn't parse time string: '5:a'", - self.module.str_to_time, "5:a" + self.module.str_to_time, + "5:a", ) @@ -91,103 +83,94 @@ class PyDateProcessorTest(_DateProcessorTest): @classmethod def setup_class(cls): from sqlalchemy import processors + cls.module = type( "util", (object,), - dict((k, staticmethod(v)) - for k, v in list(processors.py_fallback().items())) + dict( + (k, staticmethod(v)) + for k, v in list(processors.py_fallback().items()) + ), ) class CDateProcessorTest(_DateProcessorTest): - __requires__ = ('cextensions',) + __requires__ = ("cextensions",) @classmethod def setup_class(cls): from sqlalchemy import cprocessors + cls.module = cprocessors class _DistillArgsTest(fixtures.TestBase): def test_distill_none(self): - eq_( - self.module._distill_params(None, None), - [] - ) + eq_(self.module._distill_params(None, None), []) def test_distill_no_multi_no_param(self): - eq_( - self.module._distill_params((), {}), - [] - ) + eq_(self.module._distill_params((), {}), []) def test_distill_dict_multi_none_param(self): eq_( - self.module._distill_params(None, {"foo": "bar"}), - [{"foo": "bar"}] + self.module._distill_params(None, {"foo": "bar"}), [{"foo": "bar"}] ) def test_distill_dict_multi_empty_param(self): - eq_( - self.module._distill_params((), {"foo": "bar"}), - [{"foo": "bar"}] - ) + eq_(self.module._distill_params((), {"foo": "bar"}), [{"foo": "bar"}]) def test_distill_single_dict(self): eq_( self.module._distill_params(({"foo": "bar"},), {}), - [{"foo": "bar"}] + [{"foo": "bar"}], ) def test_distill_single_list_strings(self): eq_( self.module._distill_params((["foo", "bar"],), {}), - [["foo", "bar"]] + [["foo", "bar"]], ) def test_distill_single_list_tuples(self): eq_( 
self.module._distill_params( - ([("foo", "bar"), ("bat", "hoho")],), {}), - [('foo', 'bar'), ('bat', 'hoho')] + ([("foo", "bar"), ("bat", "hoho")],), {} + ), + [("foo", "bar"), ("bat", "hoho")], ) def test_distill_single_list_tuple(self): eq_( self.module._distill_params(([("foo", "bar")],), {}), - [('foo', 'bar')] + [("foo", "bar")], ) def test_distill_multi_list_tuple(self): eq_( self.module._distill_params( - ([("foo", "bar")], [("bar", "bat")]), {}), - ([('foo', 'bar')], [('bar', 'bat')]) + ([("foo", "bar")], [("bar", "bat")]), {} + ), + ([("foo", "bar")], [("bar", "bat")]), ) def test_distill_multi_strings(self): - eq_( - self.module._distill_params(("foo", "bar"), {}), - [('foo', 'bar')] - ) + eq_(self.module._distill_params(("foo", "bar"), {}), [("foo", "bar")]) def test_distill_single_list_dicts(self): eq_( self.module._distill_params( - ([{"foo": "bar"}, {"foo": "hoho"}],), {}), - [{'foo': 'bar'}, {'foo': 'hoho'}] + ([{"foo": "bar"}, {"foo": "hoho"}],), {} + ), + [{"foo": "bar"}, {"foo": "hoho"}], ) def test_distill_single_string(self): - eq_( - self.module._distill_params(("arg",), {}), - [["arg"]] - ) + eq_(self.module._distill_params(("arg",), {}), [["arg"]]) def test_distill_multi_string_tuple(self): eq_( self.module._distill_params((("arg", "arg"),), {}), - [("arg", "arg")] + [("arg", "arg")], ) @@ -195,18 +178,22 @@ class PyDistillArgsTest(_DistillArgsTest): @classmethod def setup_class(cls): from sqlalchemy.engine import util + cls.module = type( "util", (object,), - dict((k, staticmethod(v)) - for k, v in list(util.py_fallback().items())) + dict( + (k, staticmethod(v)) + for k, v in list(util.py_fallback().items()) + ), ) class CDistillArgsTest(_DistillArgsTest): - __requires__ = ('cextensions', ) + __requires__ = ("cextensions",) @classmethod def setup_class(cls): from sqlalchemy import cutils as util + cls.module = util diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index d0b6cc9590..1ba8bcac24 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,8 +1,21 @@ -from sqlalchemy.testing import eq_, ne_, assert_raises, \ - expect_warnings, assert_raises_message +from sqlalchemy.testing import ( + eq_, + ne_, + assert_raises, + expect_warnings, + assert_raises_message, +) import time from sqlalchemy import ( - select, MetaData, Integer, String, create_engine, pool, exc, util) + select, + MetaData, + Integer, + String, + create_engine, + pool, + exc, + util, +) from sqlalchemy.testing.schema import Table, Column import sqlalchemy as tsa from sqlalchemy.engine import url @@ -31,35 +44,44 @@ class MockExitIsh(BaseException): def mock_connection(): def mock_cursor(): def execute(*args, **kwargs): - if conn.explode == 'execute': + if conn.explode == "execute": raise MockDisconnect("Lost the DB connection on execute") - elif conn.explode == 'interrupt': + elif conn.explode == "interrupt": conn.explode = "explode_no_disconnect" raise MockExitIsh("Keyboard / greenlet / etc interruption") - elif conn.explode == 'interrupt_dont_break': + elif conn.explode == "interrupt_dont_break": conn.explode = None raise MockExitIsh("Keyboard / greenlet / etc interruption") - elif conn.explode in ('execute_no_disconnect', - 'explode_no_disconnect'): + elif conn.explode in ( + "execute_no_disconnect", + "explode_no_disconnect", + ): raise MockError( "something broke on execute but we didn't lose the " - "connection") - elif conn.explode in ('rollback', 'rollback_no_disconnect', - 'explode_no_disconnect'): + "connection" + ) + elif conn.explode 
in ( + "rollback", + "rollback_no_disconnect", + "explode_no_disconnect", + ): raise MockError( "something broke on execute but we didn't lose the " - "connection") + "connection" + ) elif args and "SELECT" in args[0]: - cursor.description = [('foo', None, None, None, None, None)] + cursor.description = [("foo", None, None, None, None, None)] else: return def close(): - cursor.fetchall = cursor.fetchone = \ - Mock(side_effect=MockError("cursor closed")) + cursor.fetchall = cursor.fetchone = Mock( + side_effect=MockError("cursor closed") + ) + cursor = Mock( - execute=Mock(side_effect=execute), - close=Mock(side_effect=close)) + execute=Mock(side_effect=execute), close=Mock(side_effect=close) + ) return cursor def cursor(): @@ -67,18 +89,19 @@ def mock_connection(): yield mock_cursor() def rollback(): - if conn.explode == 'rollback': + if conn.explode == "rollback": raise MockDisconnect("Lost the DB connection on rollback") - if conn.explode == 'rollback_no_disconnect': + if conn.explode == "rollback_no_disconnect": raise MockError( "something broke on rollback but we didn't lose the " - "connection") + "connection" + ) else: return conn = Mock( - rollback=Mock(side_effect=rollback), - cursor=Mock(side_effect=cursor())) + rollback=Mock(side_effect=rollback), cursor=Mock(side_effect=cursor()) + ) return conn @@ -94,7 +117,7 @@ def MockDBAPI(): connections.append(conn) yield conn - def shutdown(explode='execute', stop=False): + def shutdown(explode="execute", stop=False): stopped[0] = stop for c in connections: c.explode = explode @@ -114,9 +137,10 @@ def MockDBAPI(): shutdown=Mock(side_effect=shutdown), dispose=Mock(side_effect=dispose), restart=Mock(side_effect=restart), - paramstyle='named', + paramstyle="named", connections=connections, - Error=MockError) + Error=MockError, + ) class PrePingMockTest(fixtures.TestBase): @@ -125,14 +149,18 @@ class PrePingMockTest(fixtures.TestBase): def _pool_fixture(self, pre_ping): dialect = url.make_url( - 'postgresql://foo:bar@localhost/test').get_dialect()() + "postgresql://foo:bar@localhost/test" + ).get_dialect()() dialect.dbapi = self.dbapi _pool = pool.QueuePool( - creator=lambda: self.dbapi.connect('foo.db'), pre_ping=pre_ping, - dialect=dialect) + creator=lambda: self.dbapi.connect("foo.db"), + pre_ping=pre_ping, + dialect=dialect, + ) - dialect.is_disconnect = \ - lambda e, conn, cursor: isinstance(e, MockDisconnect) + dialect.is_disconnect = lambda e, conn, cursor: isinstance( + e, MockDisconnect + ) return _pool def teardown(self): @@ -153,10 +181,7 @@ class PrePingMockTest(fixtures.TestBase): cursor.execute("hi") stale_cursor = stale_connection.cursor() - assert_raises( - MockDisconnect, - stale_cursor.execute, "hi" - ) + assert_raises(MockDisconnect, stale_cursor.execute, "hi") def test_raise_db_is_stopped(self): pool = self._pool_fixture(pre_ping=True) @@ -167,9 +192,7 @@ class PrePingMockTest(fixtures.TestBase): self.dbapi.shutdown("execute", stop=True) assert_raises_message( - MockDisconnect, - "database is stopped", - pool.connect + MockDisconnect, "database is stopped", pool.connect ) def test_waits_til_exec_wo_ping_db_is_stopped(self): @@ -186,7 +209,8 @@ class PrePingMockTest(fixtures.TestBase): assert_raises_message( MockDisconnect, "Lost the DB connection on execute", - cursor.execute, "foo" + cursor.execute, + "foo", ) def test_waits_til_exec_wo_ping_db_is_restarted(self): @@ -204,7 +228,8 @@ class PrePingMockTest(fixtures.TestBase): assert_raises_message( MockDisconnect, "Lost the DB connection on execute", - cursor.execute, 
"foo" + cursor.execute, + "foo", ) @testing.requires.predictable_gc @@ -232,7 +257,7 @@ class PrePingMockTest(fixtures.TestBase): # erroneous reset on return eq_( old_dbapi_conn.mock_calls, - [call.cursor(), call.rollback(), call.cursor(), call.close()] + [call.cursor(), call.rollback(), call.cursor(), call.close()], ) @@ -241,14 +266,17 @@ class MockReconnectTest(fixtures.TestBase): self.dbapi = MockDBAPI() self.db = testing_engine( - 'postgresql://foo:bar@localhost/test', - options=dict(module=self.dbapi, _initialize=False)) + "postgresql://foo:bar@localhost/test", + options=dict(module=self.dbapi, _initialize=False), + ) self.mock_connect = call( - host='localhost', password='bar', user='foo', database='test') + host="localhost", password="bar", user="foo", database="test" + ) # monkeypatch disconnect checker - self.db.dialect.is_disconnect = \ - lambda e, conn, cursor: isinstance(e, MockDisconnect) + self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance( + e, MockDisconnect + ) def teardown(self): self.dbapi.dispose() @@ -281,10 +309,7 @@ class MockReconnectTest(fixtures.TestBase): # set it to fail self.dbapi.shutdown() - assert_raises( - tsa.exc.DBAPIError, - conn.execute, select([1]) - ) + assert_raises(tsa.exc.DBAPIError, conn.execute, select([1])) # assert was invalidated @@ -298,14 +323,14 @@ class MockReconnectTest(fixtures.TestBase): # ensure one connection closed... eq_( [c.close.mock_calls for c in self.dbapi.connections], - [[call()], []] + [[call()], []], ) conn = self.db.connect() eq_( [c.close.mock_calls for c in self.dbapi.connections], - [[call()], [call()], []] + [[call()], [call()], []], ) conn.execute(select([1])) @@ -313,7 +338,7 @@ class MockReconnectTest(fixtures.TestBase): eq_( [c.close.mock_calls for c in self.dbapi.connections], - [[call()], [call()], []] + [[call()], [call()], []], ) def test_invalidate_trans(self): @@ -321,29 +346,25 @@ class MockReconnectTest(fixtures.TestBase): trans = conn.begin() self.dbapi.shutdown() - assert_raises( - tsa.exc.DBAPIError, - conn.execute, select([1]) - ) + assert_raises(tsa.exc.DBAPIError, conn.execute, select([1])) - eq_( - [c.close.mock_calls for c in self.dbapi.connections], - [[call()]] - ) + eq_([c.close.mock_calls for c in self.dbapi.connections], [[call()]]) assert not conn.closed assert conn.invalidated assert trans.is_active assert_raises_message( tsa.exc.StatementError, "Can't reconnect until invalid transaction is rolled back", - conn.execute, select([1]) + conn.execute, + select([1]), ) assert trans.is_active assert_raises_message( tsa.exc.InvalidRequestError, "Can't reconnect until invalid transaction is rolled back", - trans.commit) + trans.commit, + ) assert trans.is_active trans.rollback() @@ -352,14 +373,13 @@ class MockReconnectTest(fixtures.TestBase): assert not conn.invalidated eq_( [c.close.mock_calls for c in self.dbapi.connections], - [[call()], []] + [[call()], []], ) def test_invalidate_dont_call_finalizer(self): conn = self.db.connect() finalizer = mock.Mock() - conn.connection._connection_record.\ - finalize_callback.append(finalizer) + conn.connection._connection_record.finalize_callback.append(finalizer) conn.invalidate() assert conn.invalidated eq_(finalizer.call_count, 0) @@ -369,25 +389,16 @@ class MockReconnectTest(fixtures.TestBase): conn.execute(select([1])) - eq_( - self.dbapi.connect.mock_calls, - [self.mock_connect] - ) + eq_(self.dbapi.connect.mock_calls, [self.mock_connect]) self.dbapi.shutdown() - assert_raises( - tsa.exc.DBAPIError, - conn.execute, select([1]) - ) + 
assert_raises(tsa.exc.DBAPIError, conn.execute, select([1])) assert not conn.closed assert conn.invalidated - eq_( - [c.close.mock_calls for c in self.dbapi.connections], - [[call()]] - ) + eq_([c.close.mock_calls for c in self.dbapi.connections], [[call()]]) # test reconnects conn.execute(select([1])) @@ -395,7 +406,7 @@ class MockReconnectTest(fixtures.TestBase): eq_( [c.close.mock_calls for c in self.dbapi.connections], - [[call()], []] + [[call()], []], ) def test_invalidated_close(self): @@ -403,10 +414,7 @@ class MockReconnectTest(fixtures.TestBase): self.dbapi.shutdown() - assert_raises( - tsa.exc.DBAPIError, - conn.execute, select([1]) - ) + assert_raises(tsa.exc.DBAPIError, conn.execute, select([1])) conn.close() assert conn.closed @@ -414,7 +422,8 @@ class MockReconnectTest(fixtures.TestBase): assert_raises_message( tsa.exc.StatementError, "This Connection is closed", - conn.execute, select([1]) + conn.execute, + select([1]), ) def test_noreconnect_execute_plus_closewresult(self): @@ -426,7 +435,8 @@ class MockReconnectTest(fixtures.TestBase): assert_raises_message( tsa.exc.DBAPIError, "something broke on execute but we didn't lose the connection", - conn.execute, select([1]) + conn.execute, + select([1]), ) assert conn.closed @@ -441,13 +451,14 @@ class MockReconnectTest(fixtures.TestBase): with expect_warnings( "An exception has occurred during handling .*" "something broke on execute but we didn't lose the connection", - py2konly=True + py2konly=True, ): assert_raises_message( tsa.exc.DBAPIError, "something broke on rollback but we didn't " "lose the connection", - conn.execute, select([1]) + conn.execute, + select([1]), ) assert conn.closed @@ -456,7 +467,8 @@ class MockReconnectTest(fixtures.TestBase): assert_raises_message( tsa.exc.StatementError, "This Connection is closed", - conn.execute, select([1]) + conn.execute, + select([1]), ) def test_reconnect_on_reentrant(self): @@ -472,12 +484,13 @@ class MockReconnectTest(fixtures.TestBase): with expect_warnings( "An exception has occurred during handling .*" "something broke on execute but we didn't lose the connection", - py2konly=True + py2konly=True, ): assert_raises_message( tsa.exc.DBAPIError, "Lost the DB connection on rollback", - conn.execute, select([1]) + conn.execute, + select([1]), ) assert not conn.closed @@ -492,12 +505,13 @@ class MockReconnectTest(fixtures.TestBase): with expect_warnings( "An exception has occurred during handling .*" "something broke on execute but we didn't lose the connection", - py2konly=True + py2konly=True, ): assert_raises_message( tsa.exc.DBAPIError, "Lost the DB connection on rollback", - conn.execute, select([1]) + conn.execute, + select([1]), ) assert conn.closed @@ -506,7 +520,8 @@ class MockReconnectTest(fixtures.TestBase): assert_raises_message( tsa.exc.StatementError, "This Connection is closed", - conn.execute, select([1]) + conn.execute, + select([1]), ) def test_check_disconnect_no_cursor(self): @@ -516,14 +531,13 @@ class MockReconnectTest(fixtures.TestBase): conn.close() assert_raises_message( - tsa.exc.DBAPIError, - "cursor closed", - list, result + tsa.exc.DBAPIError, "cursor closed", list, result ) def test_dialect_initialize_once(self): from sqlalchemy.engine.url import URL from sqlalchemy.engine.default import DefaultDialect + dbapi = self.dbapi mock_dialect = Mock() @@ -555,10 +569,7 @@ class MockReconnectTest(fixtures.TestBase): with conn.begin(): conn.execute(select([1])) - assert_raises( - MockExitIsh, - go - ) + assert_raises(MockExitIsh, go) assert 
conn.invalidated @@ -583,10 +594,7 @@ class MockReconnectTest(fixtures.TestBase): with conn.begin(): conn.execute(select([1])) - assert_raises( - MockExitIsh, - go - ) + assert_raises(MockExitIsh, go) assert not conn.invalidated @@ -607,10 +615,7 @@ class MockReconnectTest(fixtures.TestBase): with conn.begin(): conn.execute(select([1])) - assert_raises( - exc.DBAPIError, # wraps a MockDisconnect - go - ) + assert_raises(exc.DBAPIError, go) # wraps a MockDisconnect assert conn.invalidated @@ -619,6 +624,7 @@ class MockReconnectTest(fixtures.TestBase): conn.execute(select([1])) assert not conn.invalidated + class CursorErrTest(fixtures.TestBase): # this isn't really a "reconnect" test, it's more of # a generic "recovery". maybe this test suite should have been @@ -634,7 +640,7 @@ class CursorErrTest(fixtures.TestBase): yield Mock( description=[], close=Mock(side_effect=DBAPIError("explode")), - execute=Mock(side_effect=DBAPIError("explode")) + execute=Mock(side_effect=DBAPIError("explode")), ) else: yield Mock( @@ -645,22 +651,30 @@ class CursorErrTest(fixtures.TestBase): def connect(): while True: yield Mock( - spec=['cursor', 'commit', 'rollback', 'close'], - cursor=Mock(side_effect=cursor()),) + spec=["cursor", "commit", "rollback", "close"], + cursor=Mock(side_effect=cursor()), + ) return Mock( - Error=DBAPIError, paramstyle='qmark', - connect=Mock(side_effect=connect())) + Error=DBAPIError, + paramstyle="qmark", + connect=Mock(side_effect=connect()), + ) + dbapi = MockDBAPI() from sqlalchemy.engine import default + url = Mock( get_dialect=lambda: default.DefaultDialect, _get_entrypoint=lambda: default.DefaultDialect, _instantiate_plugins=lambda kwargs: (), - translate_connect_args=lambda: {}, query={},) + translate_connect_args=lambda: {}, + query={}, + ) eng = testing_engine( - url, options=dict(module=dbapi, _initialize=initialize)) + url, options=dict(module=dbapi, _initialize=initialize) + ) eng.pool.logger = Mock() return eng @@ -672,19 +686,17 @@ class CursorErrTest(fixtures.TestBase): conn.close() eq_( db.pool.logger.error.mock_calls, - [call('Error closing cursor', exc_info=True)] + [call("Error closing cursor", exc_info=True)], ) def test_cursor_shutdown_in_initialize(self): db = self._fixture(True, True) assert_raises_message( - exc.SAWarning, - "Exception attempting to detect", - db.connect + exc.SAWarning, "Exception attempting to detect", db.connect ) eq_( db.pool.logger.error.mock_calls, - [call('Error closing cursor', exc_info=True)] + [call("Error closing cursor", exc_info=True)], ) @@ -699,7 +711,7 @@ def _assert_invalidated(fn, *args): class RealReconnectTest(fixtures.TestBase): __backend__ = True - __requires__ = 'graceful_disconnects', 'ad_hoc_engines' + __requires__ = "graceful_disconnects", "ad_hoc_engines" def setup(self): self.engine = engines.reconnecting_engine() @@ -812,13 +824,9 @@ class RealReconnectTest(fixtures.TestBase): conn = self.engine.connect() self.engine.test_shutdown() with expect_warnings( - "An exception has occurred during handling .*", - py2konly=True + "An exception has occurred during handling .*", py2konly=True ): - assert_raises( - tsa.exc.DBAPIError, - conn.execute, select([1]) - ) + assert_raises(tsa.exc.DBAPIError, conn.execute, select([1])) def test_rollback_on_invalid_plain(self): conn = self.engine.connect() @@ -847,8 +855,8 @@ class RealReconnectTest(fixtures.TestBase): conn.invalidate() @testing.skip_if( - [lambda: util.py3k, "oracle+cx_oracle"], - "Crashes on py3k+cx_oracle") + [lambda: util.py3k, "oracle+cx_oracle"], "Crashes on 
py3k+cx_oracle" + ) def test_explode_in_initializer(self): engine = engines.testing_engine() @@ -861,8 +869,8 @@ class RealReconnectTest(fixtures.TestBase): assert_raises(exc.DBAPIError, engine.connect) @testing.skip_if( - [lambda: util.py3k, "oracle+cx_oracle"], - "Crashes on py3k+cx_oracle") + [lambda: util.py3k, "oracle+cx_oracle"], "Crashes on py3k+cx_oracle" + ) def test_explode_in_initializer_disconnect(self): engine = engines.testing_engine() @@ -882,8 +890,9 @@ class RealReconnectTest(fixtures.TestBase): assert_raises(exc.DBAPIError, engine.connect) def test_null_pool(self): - engine = \ - engines.reconnecting_engine(options=dict(poolclass=pool.NullPool)) + engine = engines.reconnecting_engine( + options=dict(poolclass=pool.NullPool) + ) conn = engine.connect() eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed @@ -920,12 +929,14 @@ class RealReconnectTest(fixtures.TestBase): assert_raises_message( tsa.exc.StatementError, "Can't reconnect until invalid transaction is rolled back", - conn.execute, select([1])) + conn.execute, + select([1]), + ) assert trans.is_active assert_raises_message( tsa.exc.InvalidRequestError, "Can't reconnect until invalid transaction is rolled back", - trans.commit + trans.commit, ) assert trans.is_active trans.rollback() @@ -941,7 +952,8 @@ class RecycleTest(fixtures.TestBase): def test_basic(self): for threadlocal in False, True: engine = engines.reconnecting_engine( - options={'pool_threadlocal': threadlocal}) + options={"pool_threadlocal": threadlocal} + ) conn = engine.contextual_connect() eq_(conn.execute(select([1])).scalar(), 1) @@ -971,9 +983,7 @@ class PrePingRealTest(fixtures.TestBase): __backend__ = True def test_pre_ping_db_is_restarted(self): - engine = engines.reconnecting_engine( - options={"pool_pre_ping": True} - ) + engine = engines.reconnecting_engine(options={"pool_pre_ping": True}) conn = engine.connect() eq_(conn.execute(select([1])).scalar(), 1) @@ -991,15 +1001,10 @@ class PrePingRealTest(fixtures.TestBase): curs = stale_connection.cursor() curs.execute("select 1") - assert_raises( - engine.dialect.dbapi.Error, - exercise_stale_connection - ) + assert_raises(engine.dialect.dbapi.Error, exercise_stale_connection) def test_pre_ping_db_stays_shutdown(self): - engine = engines.reconnecting_engine( - options={"pool_pre_ping": True} - ) + engine = engines.reconnecting_engine(options={"pool_pre_ping": True}) conn = engine.connect() eq_(conn.execute(select([1])).scalar(), 1) @@ -1007,10 +1012,7 @@ class PrePingRealTest(fixtures.TestBase): engine.test_shutdown(stop=True) - assert_raises( - exc.DBAPIError, - engine.connect - ) + assert_raises(exc.DBAPIError, engine.connect) class InvalidateDuringResultTest(fixtures.TestBase): @@ -1020,12 +1022,14 @@ class InvalidateDuringResultTest(fixtures.TestBase): self.engine = engines.reconnecting_engine() self.meta = MetaData(self.engine) table = Table( - 'sometable', self.meta, - Column('id', Integer, primary_key=True), - Column('name', String(50))) + "sometable", + self.meta, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) self.meta.create_all() table.insert().execute( - [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)] + [{"id": i, "name": "row %d" % i} for i in range(1, 100)] ) def teardown(self): @@ -1034,13 +1038,15 @@ class InvalidateDuringResultTest(fixtures.TestBase): @testing.crashes( "oracle", - "cx_oracle 6 doesn't allow a close like this due to open cursors") - @testing.fails_if([ - '+mysqlconnector', '+mysqldb', '+cymysql', '+pymysql', 
'+pg8000'], - "Buffers the result set and doesn't check for connection close") + "cx_oracle 6 doesn't allow a close like this due to open cursors", + ) + @testing.fails_if( + ["+mysqlconnector", "+mysqldb", "+cymysql", "+pymysql", "+pg8000"], + "Buffers the result set and doesn't check for connection close", + ) def test_invalidate_on_results(self): conn = self.engine.connect() - result = conn.execute('select * from sometable') + result = conn.execute("select * from sometable") for x in range(20): result.fetchone() self.engine.test_shutdown() diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 57c841dcf6..29e0900396 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,14 +1,31 @@ import unicodedata import sqlalchemy as sa from sqlalchemy import schema, inspect, sql -from sqlalchemy import MetaData, Integer, String, Index, ForeignKey, \ - UniqueConstraint, FetchedValue, DefaultClause +from sqlalchemy import ( + MetaData, + Integer, + String, + Index, + ForeignKey, + UniqueConstraint, + FetchedValue, + DefaultClause, +) from sqlalchemy.testing import ( - ComparesTables, engines, AssertsCompiledSQL, - fixtures, skip) + ComparesTables, + engines, + AssertsCompiledSQL, + fixtures, + skip, +) from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.testing import eq_, eq_regex, is_true, assert_raises, \ - assert_raises_message +from sqlalchemy.testing import ( + eq_, + eq_regex, + is_true, + assert_raises, + assert_raises_message, +) from sqlalchemy import testing from sqlalchemy.util import ue from sqlalchemy.testing import config @@ -21,52 +38,60 @@ metadata, users = None, None class ReflectionTest(fixtures.TestBase, ComparesTables): __backend__ = True - @testing.exclude('mssql', '<', (10, 0, 0), - 'Date is only supported on MSSQL 2008+') - @testing.exclude('mysql', '<', (4, 1, 1), - 'early types are squirrely') + @testing.exclude( + "mssql", "<", (10, 0, 0), "Date is only supported on MSSQL 2008+" + ) + @testing.exclude("mysql", "<", (4, 1, 1), "early types are squirrely") @testing.provide_metadata def test_basic_reflection(self): meta = self.metadata - users = Table('engine_users', meta, - Column('user_id', sa.INT, primary_key=True), - Column('user_name', sa.VARCHAR(20), nullable=False), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - Column('test3', sa.Text), - Column('test4', sa.Numeric(10, 2), nullable=False), - Column('test5', sa.Date), - Column('parent_user_id', sa.Integer, - sa.ForeignKey('engine_users.user_id')), - Column('test6', sa.Date, nullable=False), - Column('test7', sa.Text), - Column('test8', sa.LargeBinary), - Column('test_passivedefault2', - sa.Integer, server_default='5'), - Column('test9', sa.LargeBinary(100)), - Column('test10', sa.Numeric(10, 2)), - test_needs_fk=True) + users = Table( + "engine_users", + meta, + Column("user_id", sa.INT, primary_key=True), + Column("user_name", sa.VARCHAR(20), nullable=False), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column("test3", sa.Text), + Column("test4", sa.Numeric(10, 2), nullable=False), + Column("test5", sa.Date), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey("engine_users.user_id"), + ), + Column("test6", sa.Date, nullable=False), + Column("test7", sa.Text), + Column("test8", sa.LargeBinary), + Column("test_passivedefault2", sa.Integer, server_default="5"), + Column("test9", sa.LargeBinary(100)), + Column("test10", sa.Numeric(10, 
2)), + test_needs_fk=True, + ) addresses = Table( - 'engine_email_addresses', + "engine_email_addresses", meta, - Column('address_id', sa.Integer, primary_key=True), - Column('remote_user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(20)), + Column("address_id", sa.Integer, primary_key=True), + Column( + "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) + ), + Column("email_address", sa.String(20)), test_needs_fk=True, - ) + ) meta.create_all() meta2 = MetaData() - reflected_users = Table('engine_users', meta2, - autoload=True, - autoload_with=testing.db) - reflected_addresses = Table('engine_email_addresses', - meta2, - autoload=True, - autoload_with=testing.db) + reflected_users = Table( + "engine_users", meta2, autoload=True, autoload_with=testing.db + ) + reflected_addresses = Table( + "engine_email_addresses", + meta2, + autoload=True, + autoload_with=testing.db, + ) self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) @@ -74,134 +99,152 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_autoload_with_imply_autoload(self,): meta = self.metadata t = Table( - 't', - meta, - Column('id', sa.Integer, primary_key=True), - Column('x', sa.String(20)), - Column('y', sa.Integer)) + "t", + meta, + Column("id", sa.Integer, primary_key=True), + Column("x", sa.String(20)), + Column("y", sa.Integer), + ) meta.create_all() meta2 = MetaData() - reflected_t = Table('t', meta2, autoload_with=testing.db) + reflected_t = Table("t", meta2, autoload_with=testing.db) self.assert_tables_equal(t, reflected_t) @testing.provide_metadata def test_two_foreign_keys(self): meta = self.metadata Table( - 't1', + "t1", meta, - Column('id', sa.Integer, primary_key=True), - Column('t2id', sa.Integer, sa.ForeignKey('t2.id')), - Column('t3id', sa.Integer, sa.ForeignKey('t3.id')), + Column("id", sa.Integer, primary_key=True), + Column("t2id", sa.Integer, sa.ForeignKey("t2.id")), + Column("t3id", sa.Integer, sa.ForeignKey("t3.id")), test_needs_fk=True, - ) - Table('t2', - meta, - Column('id', sa.Integer, primary_key=True), - test_needs_fk=True) - Table('t3', - meta, - Column('id', sa.Integer, primary_key=True), - test_needs_fk=True) + ) + Table( + "t2", + meta, + Column("id", sa.Integer, primary_key=True), + test_needs_fk=True, + ) + Table( + "t3", + meta, + Column("id", sa.Integer, primary_key=True), + test_needs_fk=True, + ) meta.create_all() meta2 = MetaData() - t1r, t2r, t3r = [Table(x, meta2, autoload=True, - autoload_with=testing.db) for x in ('t1', - 't2', 't3')] + t1r, t2r, t3r = [ + Table(x, meta2, autoload=True, autoload_with=testing.db) + for x in ("t1", "t2", "t3") + ] assert t1r.c.t2id.references(t2r.c.id) assert t1r.c.t3id.references(t3r.c.id) def test_nonexistent(self): meta = MetaData(testing.db) - assert_raises(sa.exc.NoSuchTableError, Table, 'nonexistent', - meta, autoload=True) - assert 'nonexistent' not in meta.tables + assert_raises( + sa.exc.NoSuchTableError, Table, "nonexistent", meta, autoload=True + ) + assert "nonexistent" not in meta.tables @testing.provide_metadata def test_include_columns(self): meta = self.metadata - foo = Table('foo', meta, *[Column(n, sa.String(30)) - for n in ['a', 'b', 'c', 'd', 'e', 'f']]) + foo = Table( + "foo", + meta, + *[Column(n, sa.String(30)) for n in ["a", "b", "c", "d", "e", "f"]] + ) meta.create_all() meta2 = MetaData(testing.db) - foo = Table('foo', meta2, autoload=True, - include_columns=['b', 'f', 'e']) + foo = Table( + "foo", meta2, 
autoload=True, include_columns=["b", "f", "e"] + ) # test that cols come back in original order - eq_([c.name for c in foo.c], ['b', 'e', 'f']) - for c in ('b', 'f', 'e'): + eq_([c.name for c in foo.c], ["b", "e", "f"]) + for c in ("b", "f", "e"): assert c in foo.c - for c in ('a', 'c', 'd'): + for c in ("a", "c", "d"): assert c not in foo.c # test against a table which is already reflected meta3 = MetaData(testing.db) - foo = Table('foo', meta3, autoload=True) - foo = Table('foo', meta3, include_columns=['b', 'f', 'e'], - extend_existing=True) - eq_([c.name for c in foo.c], ['b', 'e', 'f']) - for c in ('b', 'f', 'e'): + foo = Table("foo", meta3, autoload=True) + foo = Table( + "foo", meta3, include_columns=["b", "f", "e"], extend_existing=True + ) + eq_([c.name for c in foo.c], ["b", "e", "f"]) + for c in ("b", "f", "e"): assert c in foo.c - for c in ('a', 'c', 'd'): + for c in ("a", "c", "d"): assert c not in foo.c @testing.provide_metadata def test_extend_existing(self): meta = self.metadata - Table('t', meta, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('z', Integer, server_default="5")) + Table( + "t", + meta, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", Integer, server_default="5"), + ) meta.create_all() m2 = MetaData() - old_z = Column('z', String, primary_key=True) - old_y = Column('y', String) - old_q = Column('q', Integer) - t2 = Table('t', m2, old_z, old_q) - eq_(t2.primary_key.columns, (t2.c.z, )) - t2 = Table('t', m2, old_y, - extend_existing=True, - autoload=True, - autoload_with=testing.db) - eq_( - set(t2.columns.keys()), - set(['x', 'y', 'z', 'q', 'id']) + old_z = Column("z", String, primary_key=True) + old_y = Column("y", String) + old_q = Column("q", Integer) + t2 = Table("t", m2, old_z, old_q) + eq_(t2.primary_key.columns, (t2.c.z,)) + t2 = Table( + "t", + m2, + old_y, + extend_existing=True, + autoload=True, + autoload_with=testing.db, ) - eq_(t2.primary_key.columns, (t2.c.id, )) + eq_(set(t2.columns.keys()), set(["x", "y", "z", "q", "id"])) + eq_(t2.primary_key.columns, (t2.c.id,)) assert t2.c.z is not old_z assert t2.c.y is old_y assert t2.c.z.type._type_affinity is Integer assert t2.c.q is old_q m3 = MetaData() - t3 = Table('t', m3, Column('z', Integer)) - t3 = Table('t', m3, extend_existing=False, - autoload=True, - autoload_with=testing.db) - eq_( - set(t3.columns.keys()), - set(['z']) + t3 = Table("t", m3, Column("z", Integer)) + t3 = Table( + "t", + m3, + extend_existing=False, + autoload=True, + autoload_with=testing.db, ) + eq_(set(t3.columns.keys()), set(["z"])) m4 = MetaData() - old_z = Column('z', String, primary_key=True) - old_y = Column('y', String) - old_q = Column('q', Integer) - t4 = Table('t', m4, old_z, old_q) - eq_(t4.primary_key.columns, (t4.c.z, )) - t4 = Table('t', m4, old_y, - extend_existing=True, - autoload=True, - autoload_replace=False, - autoload_with=testing.db) - eq_( - set(t4.columns.keys()), - set(['x', 'y', 'z', 'q', 'id']) + old_z = Column("z", String, primary_key=True) + old_y = Column("y", String) + old_q = Column("q", Integer) + t4 = Table("t", m4, old_z, old_q) + eq_(t4.primary_key.columns, (t4.c.z,)) + t4 = Table( + "t", + m4, + old_y, + extend_existing=True, + autoload=True, + autoload_replace=False, + autoload_with=testing.db, ) - eq_(t4.primary_key.columns, (t4.c.id, )) + eq_(set(t4.columns.keys()), set(["x", "y", "z", "q", "id"])) + eq_(t4.primary_key.columns, (t4.c.id,)) assert t4.c.z is old_z assert t4.c.y is 
old_y assert t4.c.z.type._type_affinity is String @@ -211,15 +254,19 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_extend_existing_reflect_all_dont_dupe_index(self): m = self.metadata d = Table( - "d", m, Column('id', Integer, primary_key=True), - Column('foo', String(50)), - Column('bar', String(50)), - UniqueConstraint('bar') + "d", + m, + Column("id", Integer, primary_key=True), + Column("foo", String(50)), + Column("bar", String(50)), + UniqueConstraint("bar"), ) Index("foo_idx", d.c.foo) Table( - "b", m, Column('id', Integer, primary_key=True), - Column('aid', ForeignKey('d.id')) + "b", + m, + Column("id", Integer, primary_key=True), + Column("aid", ForeignKey("d.id")), ) m.create_all() @@ -227,17 +274,27 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2.reflect(testing.db, extend_existing=True) eq_( - len([idx for idx in m2.tables['d'].indexes - if idx.name == 'foo_idx']), - 1 + len( + [ + idx + for idx in m2.tables["d"].indexes + if idx.name == "foo_idx" + ] + ), + 1, ) - if testing.requires.\ - unique_constraint_reflection_no_index_overlap.enabled: + if ( + testing.requires.unique_constraint_reflection_no_index_overlap.enabled + ): eq_( - len([ - const for const in m2.tables['d'].constraints - if isinstance(const, UniqueConstraint)]), - 1 + len( + [ + const + for const in m2.tables["d"].constraints + if isinstance(const, UniqueConstraint) + ] + ), + 1, ) @testing.emits_warning(r".*omitted columns") @@ -245,20 +302,20 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_include_columns_indexes(self): m = self.metadata - t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer)) - sa.Index('foobar', t1.c.a, t1.c.b) - sa.Index('bat', t1.c.a) + t1 = Table("t1", m, Column("a", sa.Integer), Column("b", sa.Integer)) + sa.Index("foobar", t1.c.a, t1.c.b) + sa.Index("bat", t1.c.a) m.create_all() m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True) + t2 = Table("t1", m2, autoload=True) assert len(t2.indexes) == 2 m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True, include_columns=['a']) + t2 = Table("t1", m2, autoload=True, include_columns=["a"]) assert len(t2.indexes) == 1 m2 = MetaData(testing.db) - t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b']) + t2 = Table("t1", m2, autoload=True, include_columns=["a", "b"]) assert len(t2.indexes) == 2 @testing.provide_metadata @@ -267,17 +324,26 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): establishes the FK not present in the DB. 
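
        (For context: ``autoload_replace=False`` directs reflection not
        to overwrite ``Column`` objects already present on the ``Table``;
        combined with ``extend_existing=True``, reflected columns are
        merged around the in-Python ones, so the ``ForeignKey`` given
        here survives.)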
""" - Table('a', self.metadata, Column('id', Integer, primary_key=True)) - Table('b', self.metadata, Column('id', Integer, primary_key=True), - Column('a_id', Integer)) + Table("a", self.metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + self.metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer), + ) self.metadata.create_all() m2 = MetaData() - b2 = Table('b', m2, Column('a_id', Integer, sa.ForeignKey('a.id'))) - a2 = Table('a', m2, autoload=True, autoload_with=testing.db) - b2 = Table('b', m2, extend_existing=True, autoload=True, - autoload_with=testing.db, - autoload_replace=False) + b2 = Table("b", m2, Column("a_id", Integer, sa.ForeignKey("a.id"))) + a2 = Table("a", m2, autoload=True, autoload_with=testing.db) + b2 = Table( + "b", + m2, + extend_existing=True, + autoload=True, + autoload_with=testing.db, + autoload_replace=False, + ) assert b2.c.id is not None assert b2.c.a_id.references(a2.c.id) @@ -290,17 +356,26 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): the in-python one only. """ - Table('a', self.metadata, Column('id', Integer, primary_key=True)) - Table('b', self.metadata, Column('id', Integer, primary_key=True), - Column('a_id', Integer, sa.ForeignKey('a.id'))) + Table("a", self.metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + self.metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, sa.ForeignKey("a.id")), + ) self.metadata.create_all() m2 = MetaData() - b2 = Table('b', m2, Column('a_id', Integer, sa.ForeignKey('a.id'))) - a2 = Table('a', m2, autoload=True, autoload_with=testing.db) - b2 = Table('b', m2, extend_existing=True, autoload=True, - autoload_with=testing.db, - autoload_replace=False) + b2 = Table("b", m2, Column("a_id", Integer, sa.ForeignKey("a.id"))) + a2 = Table("a", m2, autoload=True, autoload_with=testing.db) + b2 = Table( + "b", + m2, + extend_existing=True, + autoload=True, + autoload_with=testing.db, + autoload_replace=False, + ) assert b2.c.id is not None assert b2.c.a_id.references(a2.c.id) @@ -312,17 +387,26 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): DB means the FK is skipped and doesn't get installed at all. 
""" - Table('a', self.metadata, Column('id', Integer, primary_key=True)) - Table('b', self.metadata, Column('id', Integer, primary_key=True), - Column('a_id', Integer, sa.ForeignKey('a.id'))) + Table("a", self.metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + self.metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, sa.ForeignKey("a.id")), + ) self.metadata.create_all() m2 = MetaData() - b2 = Table('b', m2, Column('a_id', Integer)) - a2 = Table('a', m2, autoload=True, autoload_with=testing.db) - b2 = Table('b', m2, extend_existing=True, autoload=True, - autoload_with=testing.db, - autoload_replace=False) + b2 = Table("b", m2, Column("a_id", Integer)) + a2 = Table("a", m2, autoload=True, autoload_with=testing.db) + b2 = Table( + "b", + m2, + extend_existing=True, + autoload=True, + autoload_with=testing.db, + autoload_replace=False, + ) assert b2.c.id is not None assert not b2.c.a_id.references(a2.c.id) @@ -330,18 +414,24 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): @testing.provide_metadata def test_autoload_replace_primary_key(self): - Table('a', self.metadata, Column('id', Integer)) + Table("a", self.metadata, Column("id", Integer)) self.metadata.create_all() m2 = MetaData() - a2 = Table('a', m2, Column('id', Integer, primary_key=True)) + a2 = Table("a", m2, Column("id", Integer, primary_key=True)) - Table('a', m2, autoload=True, autoload_with=testing.db, - autoload_replace=False, extend_existing=True) + Table( + "a", + m2, + autoload=True, + autoload_with=testing.db, + autoload_replace=False, + extend_existing=True, + ) eq_(list(a2.primary_key), [a2.c.id]) def test_autoload_replace_arg(self): - Table('t', MetaData(), autoload_replace=False) + Table("t", MetaData(), autoload_replace=False) @testing.provide_metadata def test_autoincrement_col(self): @@ -353,28 +443,31 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): meta = self.metadata Table( - 'test', meta, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(50)), - mysql_engine='InnoDB' + "test", + meta, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(50)), + mysql_engine="InnoDB", ) Table( - 'test2', meta, + "test2", + meta, Column( - 'id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True), - Column('id2', sa.Integer, primary_key=True), - Column('data', sa.String(50)), - mysql_engine='InnoDB' + "id", sa.Integer, sa.ForeignKey("test.id"), primary_key=True + ), + Column("id2", sa.Integer, primary_key=True), + Column("data", sa.String(50)), + mysql_engine="InnoDB", ) meta.create_all() m2 = MetaData(testing.db) - t1a = Table('test', m2, autoload=True) + t1a = Table("test", m2, autoload=True) assert t1a._autoincrement_column is t1a.c.id - t2a = Table('test2', m2, autoload=True) + t2a = Table("test2", m2, autoload=True) assert t2a._autoincrement_column is None - @skip('sqlite') + @skip("sqlite") @testing.provide_metadata def test_unknown_types(self): """Test the handling of unknown types for the given dialect. @@ -383,8 +476,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): 'affinity types' - this feature is tested in that dialect's test spec. 
""" meta = self.metadata - t = Table("test", meta, - Column('foo', sa.DateTime)) + t = Table("test", meta, Column("foo", sa.DateTime)) ischema_names = testing.db.dialect.ischema_names t.create() @@ -393,7 +485,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2 = MetaData(testing.db) assert_raises(sa.exc.SAWarning, Table, "test", m2, autoload=True) - @testing.emits_warning('Did not recognize type') + @testing.emits_warning("Did not recognize type") def warns(): m3 = MetaData(testing.db) t3 = Table("test", m3, autoload=True) @@ -406,18 +498,22 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_basic_override(self): meta = self.metadata table = Table( - 'override_test', meta, - Column('col1', sa.Integer, primary_key=True), - Column('col2', sa.String(20)), - Column('col3', sa.Numeric) + "override_test", + meta, + Column("col1", sa.Integer, primary_key=True), + Column("col2", sa.String(20)), + Column("col3", sa.Numeric), ) table.create() meta2 = MetaData(testing.db) table = Table( - 'override_test', meta2, - Column('col2', sa.Unicode()), - Column('col4', sa.String(30)), autoload=True) + "override_test", + meta2, + Column("col2", sa.Unicode()), + Column("col4", sa.String(30)), + autoload=True, + ) self.assert_(isinstance(table.c.col1.type, sa.Integer)) self.assert_(isinstance(table.c.col2.type, sa.Unicode)) @@ -427,18 +523,21 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_override_upgrade_pk_flag(self): meta = self.metadata table = Table( - 'override_test', meta, - Column('col1', sa.Integer), - Column('col2', sa.String(20)), - Column('col3', sa.Numeric) + "override_test", + meta, + Column("col1", sa.Integer), + Column("col2", sa.String(20)), + Column("col3", sa.Numeric), ) table.create() meta2 = MetaData(testing.db) table = Table( - 'override_test', meta2, - Column('col1', sa.Integer, primary_key=True), - autoload=True) + "override_test", + meta2, + Column("col1", sa.Integer, primary_key=True), + autoload=True, + ) eq_(list(table.primary_key), [table.c.col1]) eq_(table.c.col1.primary_key, True) @@ -450,31 +549,45 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): a primary key column""" meta = self.metadata - Table('users', meta, - Column('id', sa.Integer, primary_key=True), - Column('name', sa.String(30))) - Table('addresses', meta, - Column('id', sa.Integer, primary_key=True), - Column('street', sa.String(30))) + Table( + "users", + meta, + Column("id", sa.Integer, primary_key=True), + Column("name", sa.String(30)), + ) + Table( + "addresses", + meta, + Column("id", sa.Integer, primary_key=True), + Column("street", sa.String(30)), + ) meta.create_all() meta2 = MetaData(testing.db) - a2 = Table('addresses', meta2, - Column('id', sa.Integer, - sa.ForeignKey('users.id'), primary_key=True), - autoload=True) - u2 = Table('users', meta2, autoload=True) + a2 = Table( + "addresses", + meta2, + Column( + "id", sa.Integer, sa.ForeignKey("users.id"), primary_key=True + ), + autoload=True, + ) + u2 = Table("users", meta2, autoload=True) assert list(a2.primary_key) == [a2.c.id] assert list(u2.primary_key) == [u2.c.id] assert u2.join(a2).onclause.compare(u2.c.id == a2.c.id) meta3 = MetaData(testing.db) - u3 = Table('users', meta3, autoload=True) - a3 = Table('addresses', meta3, - Column('id', sa.Integer, sa.ForeignKey('users.id'), - primary_key=True), - autoload=True) + u3 = Table("users", meta3, autoload=True) + a3 = Table( + "addresses", + meta3, + Column( + "id", sa.Integer, sa.ForeignKey("users.id"), primary_key=True + ), + autoload=True, + 
) assert list(a3.primary_key) == [a3.c.id] assert list(u3.primary_key) == [u3.c.id] @@ -487,52 +600,66 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): is common with MySQL MyISAM tables.""" meta = self.metadata - Table('users', meta, - Column('id', sa.Integer, primary_key=True), - Column('name', sa.String(30))) - Table('addresses', meta, - Column('id', sa.Integer, primary_key=True), - Column('street', sa.String(30)), - Column('user_id', sa.Integer)) + Table( + "users", + meta, + Column("id", sa.Integer, primary_key=True), + Column("name", sa.String(30)), + ) + Table( + "addresses", + meta, + Column("id", sa.Integer, primary_key=True), + Column("street", sa.String(30)), + Column("user_id", sa.Integer), + ) meta.create_all() meta2 = MetaData(testing.db) - a2 = Table('addresses', meta2, - Column('user_id', sa.Integer, sa.ForeignKey('users.id')), - autoload=True) - u2 = Table('users', meta2, autoload=True) + a2 = Table( + "addresses", + meta2, + Column("user_id", sa.Integer, sa.ForeignKey("users.id")), + autoload=True, + ) + u2 = Table("users", meta2, autoload=True) assert len(a2.c.user_id.foreign_keys) == 1 assert len(a2.foreign_keys) == 1 assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] - assert [c.parent for c in a2.c.user_id.foreign_keys] \ - == [a2.c.user_id] - assert list(a2.c.user_id.foreign_keys)[0].parent \ - is a2.c.user_id + assert [c.parent for c in a2.c.user_id.foreign_keys] == [a2.c.user_id] + assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id assert u2.join(a2).onclause.compare(u2.c.id == a2.c.user_id) meta3 = MetaData(testing.db) - u3 = Table('users', meta3, autoload=True) + u3 = Table("users", meta3, autoload=True) - a3 = Table('addresses', meta3, - Column('user_id', sa.Integer, sa.ForeignKey('users.id')), - autoload=True) + a3 = Table( + "addresses", + meta3, + Column("user_id", sa.Integer, sa.ForeignKey("users.id")), + autoload=True, + ) assert u3.join(a3).onclause.compare(u3.c.id == a3.c.user_id) meta4 = MetaData(testing.db) - u4 = Table('users', meta4, - Column('id', sa.Integer, key='u_id', primary_key=True), - autoload=True) + u4 = Table( + "users", + meta4, + Column("id", sa.Integer, key="u_id", primary_key=True), + autoload=True, + ) a4 = Table( - 'addresses', + "addresses", meta4, - Column('id', sa.Integer, key='street', - primary_key=True), - Column('street', sa.String(30), key='user_id'), - Column('user_id', sa.Integer, sa.ForeignKey('users.u_id'), - key='id'), - autoload=True) + Column("id", sa.Integer, key="street", primary_key=True), + Column("street", sa.String(30), key="user_id"), + Column( + "user_id", sa.Integer, sa.ForeignKey("users.u_id"), key="id" + ), + autoload=True, + ) assert u4.join(a4).onclause.compare(u4.c.u_id == a4.c.id) assert list(u4.primary_key) == [u4.c.u_id] assert len(u4.columns) == 2 @@ -546,28 +673,31 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): metadata = self.metadata - Table('a', - metadata, - Column('x', sa.Integer, primary_key=True), - Column('y', sa.Integer, primary_key=True)) + Table( + "a", + metadata, + Column("x", sa.Integer, primary_key=True), + Column("y", sa.Integer, primary_key=True), + ) - Table('b', - metadata, - Column('x', sa.Integer, primary_key=True), - Column('y', sa.Integer, primary_key=True), - sa.ForeignKeyConstraint(['x', 'y'], ['a.x', 'a.y'])) + Table( + "b", + metadata, + Column("x", sa.Integer, primary_key=True), + Column("y", sa.Integer, primary_key=True), + sa.ForeignKeyConstraint(["x", "y"], ["a.x", "a.y"]), + ) metadata.create_all() meta2 = MetaData() - c1 
= Column('x', sa.Integer, primary_key=True) - c2 = Column('y', sa.Integer, primary_key=True) - f1 = sa.ForeignKeyConstraint(['x', 'y'], ['a.x', 'a.y']) - b1 = Table('b', - meta2, c1, c2, f1, - autoload=True, - autoload_with=testing.db) + c1 = Column("x", sa.Integer, primary_key=True) + c2 = Column("y", sa.Integer, primary_key=True) + f1 = sa.ForeignKeyConstraint(["x", "y"], ["a.x", "a.y"]) + b1 = Table( + "b", meta2, c1, c2, f1, autoload=True, autoload_with=testing.db + ) assert b1.c.x is c1 assert b1.c.y is c2 @@ -580,19 +710,28 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): and that ForeignKey targeting during reflection still works.""" meta = self.metadata - Table('a', meta, - Column('x', sa.Integer, primary_key=True), - Column('z', sa.Integer), - test_needs_fk=True) - Table('b', meta, - Column('y', sa.Integer, sa.ForeignKey('a.x')), - test_needs_fk=True) + Table( + "a", + meta, + Column("x", sa.Integer, primary_key=True), + Column("z", sa.Integer), + test_needs_fk=True, + ) + Table( + "b", + meta, + Column("y", sa.Integer, sa.ForeignKey("a.x")), + test_needs_fk=True, + ) meta.create_all() m2 = MetaData(testing.db) - a2 = Table('a', m2, - Column('x', sa.Integer, primary_key=True, key='x1'), - autoload=True) - b2 = Table('b', m2, autoload=True) + a2 = Table( + "a", + m2, + Column("x", sa.Integer, primary_key=True, key="x1"), + autoload=True, + ) + b2 = Table("b", m2, autoload=True) assert a2.join(b2).onclause.compare(a2.c.x1 == b2.c.y) assert b2.c.y.references(a2.c.x1) @@ -605,21 +744,27 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): """ meta = self.metadata - Table('a', meta, - Column('x', sa.Integer, primary_key=True), - Column('z', sa.Integer), - test_needs_fk=True) - Table('b', meta, - Column('y', sa.Integer, sa.ForeignKey('a.x')), - test_needs_fk=True) + Table( + "a", + meta, + Column("x", sa.Integer, primary_key=True), + Column("z", sa.Integer), + test_needs_fk=True, + ) + Table( + "b", + meta, + Column("y", sa.Integer, sa.ForeignKey("a.x")), + test_needs_fk=True, + ) meta.create_all() m2 = MetaData(testing.db) - a2 = Table('a', m2, include_columns=['z'], autoload=True) - b2 = Table('b', m2, autoload=True) + a2 = Table("a", m2, include_columns=["z"], autoload=True) + b2 = Table("b", m2, autoload=True) assert_raises(sa.exc.NoReferencedColumnError, a2.join, b2) - @testing.exclude('mysql', '<', (4, 1, 1), 'innodb funkiness') + @testing.exclude("mysql", "<", (4, 1, 1), "innodb funkiness") @testing.provide_metadata def test_override_existing_fk(self): """test that you can override columns and specify new foreign @@ -627,21 +772,30 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): have that foreign key, and that the FK is not duped. 
""" meta = self.metadata - Table('users', meta, - Column('id', sa.Integer, primary_key=True), - Column('name', sa.String(30)), - test_needs_fk=True) - Table('addresses', meta, - Column('id', sa.Integer, primary_key=True), - Column('user_id', sa.Integer, sa.ForeignKey('users.id')), - test_needs_fk=True) + Table( + "users", + meta, + Column("id", sa.Integer, primary_key=True), + Column("name", sa.String(30)), + test_needs_fk=True, + ) + Table( + "addresses", + meta, + Column("id", sa.Integer, primary_key=True), + Column("user_id", sa.Integer, sa.ForeignKey("users.id")), + test_needs_fk=True, + ) meta.create_all() meta2 = MetaData(testing.db) - a2 = Table('addresses', meta2, - Column('user_id', sa.Integer, sa.ForeignKey('users.id')), - autoload=True) - u2 = Table('users', meta2, autoload=True) + a2 = Table( + "addresses", + meta2, + Column("user_id", sa.Integer, sa.ForeignKey("users.id")), + autoload=True, + ) + u2 = Table("users", meta2, autoload=True) s = sa.select([a2]) assert s.c.user_id is not None @@ -649,19 +803,24 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert len(a2.c.user_id.foreign_keys) == 1 assert len(a2.constraints) == 2 assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] - assert [c.parent for c in a2.c.user_id.foreign_keys] \ - == [a2.c.user_id] - assert list(a2.c.user_id.foreign_keys)[0].parent \ - is a2.c.user_id + assert [c.parent for c in a2.c.user_id.foreign_keys] == [a2.c.user_id] + assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id assert u2.join(a2).onclause.compare(u2.c.id == a2.c.user_id) meta2 = MetaData(testing.db) - u2 = Table('users', meta2, Column('id', sa.Integer, primary_key=True), - autoload=True) - a2 = Table('addresses', meta2, - Column('id', sa.Integer, primary_key=True), - Column('user_id', sa.Integer, sa.ForeignKey('users.id')), - autoload=True) + u2 = Table( + "users", + meta2, + Column("id", sa.Integer, primary_key=True), + autoload=True, + ) + a2 = Table( + "addresses", + meta2, + Column("id", sa.Integer, primary_key=True), + Column("user_id", sa.Integer, sa.ForeignKey("users.id")), + autoload=True, + ) s = sa.select([a2]) assert s.c.user_id is not None @@ -669,58 +828,67 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert len(a2.c.user_id.foreign_keys) == 1 assert len(a2.constraints) == 2 assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] - assert [c.parent for c in a2.c.user_id.foreign_keys] \ - == [a2.c.user_id] - assert list(a2.c.user_id.foreign_keys)[0].parent \ - is a2.c.user_id + assert [c.parent for c in a2.c.user_id.foreign_keys] == [a2.c.user_id] + assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id assert u2.join(a2).onclause.compare(u2.c.id == a2.c.user_id) - @testing.only_on(['postgresql', 'mysql']) + @testing.only_on(["postgresql", "mysql"]) @testing.provide_metadata def test_fk_options(self): """test that foreign key reflection includes options (on backends with {dialect}.get_foreign_keys() support)""" - if testing.against('postgresql'): - test_attrs = ('match', 'onupdate', 'ondelete', - 'deferrable', 'initially') + if testing.against("postgresql"): + test_attrs = ( + "match", + "onupdate", + "ondelete", + "deferrable", + "initially", + ) addresses_user_id_fkey = sa.ForeignKey( # Each option is specifically not a Postgres default, or # it won't be returned by PG's inspection - 'users.id', - name='addresses_user_id_fkey', - match='FULL', - onupdate='RESTRICT', - ondelete='RESTRICT', + "users.id", + name="addresses_user_id_fkey", + match="FULL", + 
onupdate="RESTRICT", + ondelete="RESTRICT", deferrable=True, - initially='DEFERRED' + initially="DEFERRED", ) - elif testing.against('mysql'): + elif testing.against("mysql"): # MATCH, DEFERRABLE, and INITIALLY cannot be defined for MySQL # ON UPDATE and ON DELETE have defaults of RESTRICT, which are # elided by MySQL's inspection addresses_user_id_fkey = sa.ForeignKey( - 'users.id', - name='addresses_user_id_fkey', - onupdate='CASCADE', - ondelete='CASCADE' + "users.id", + name="addresses_user_id_fkey", + onupdate="CASCADE", + ondelete="CASCADE", ) - test_attrs = ('onupdate', 'ondelete') + test_attrs = ("onupdate", "ondelete") meta = self.metadata - Table('users', meta, - Column('id', sa.Integer, primary_key=True), - Column('name', sa.String(30)), - test_needs_fk=True) - Table('addresses', meta, - Column('id', sa.Integer, primary_key=True), - Column('user_id', sa.Integer, addresses_user_id_fkey), - test_needs_fk=True) + Table( + "users", + meta, + Column("id", sa.Integer, primary_key=True), + Column("name", sa.String(30)), + test_needs_fk=True, + ) + Table( + "addresses", + meta, + Column("id", sa.Integer, primary_key=True), + Column("user_id", sa.Integer, addresses_user_id_fkey), + test_needs_fk=True, + ) meta.create_all() meta2 = MetaData() meta2.reflect(testing.db) - for fk in meta2.tables['addresses'].foreign_keys: + for fk in meta2.tables["addresses"].foreign_keys: ref = addresses_user_id_fkey for attr in test_attrs: eq_(getattr(fk, attr), getattr(ref, attr)) @@ -729,7 +897,8 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): """test that primary key reflection not tripped up by unique indexes""" - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE book ( id INTEGER NOT NULL, title VARCHAR(100) NOT NULL, @@ -737,10 +906,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): series_id INTEGER, UNIQUE(series, series_id), PRIMARY KEY(id) - )""") + )""" + ) try: metadata = MetaData(bind=testing.db) - book = Table('book', metadata, autoload=True) + book = Table("book", metadata, autoload=True) assert book.primary_key.contains_column(book.c.id) assert not book.primary_key.contains_column(book.c.series) assert len(book.primary_key) == 1 @@ -749,22 +919,27 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_fk_error(self): metadata = MetaData(testing.db) - Table('slots', metadata, - Column('slot_id', sa.Integer, primary_key=True), - Column('pkg_id', sa.Integer, sa.ForeignKey('pkgs.pkg_id')), - Column('slot', sa.String(128))) + Table( + "slots", + metadata, + Column("slot_id", sa.Integer, primary_key=True), + Column("pkg_id", sa.Integer, sa.ForeignKey("pkgs.pkg_id")), + Column("slot", sa.String(128)), + ) assert_raises_message( sa.exc.InvalidRequestError, "Foreign key associated with column 'slots.pkg_id' " "could not find table 'pkgs' with which to generate " "a foreign key to target column 'pkg_id'", - metadata.create_all) + metadata.create_all, + ) def test_composite_pks(self): """test reflection of a composite primary key""" - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE book ( id INTEGER NOT NULL, isbn VARCHAR(50) NOT NULL, @@ -773,10 +948,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): series_id INTEGER NOT NULL, UNIQUE(series, series_id), PRIMARY KEY(id, isbn) - )""") + )""" + ) try: metadata = MetaData(bind=testing.db) - book = Table('book', metadata, autoload=True) + book = Table("book", metadata, autoload=True) assert book.primary_key.contains_column(book.c.id) assert 
book.primary_key.contains_column(book.c.isbn) assert not book.primary_key.contains_column(book.c.series) @@ -784,50 +960,56 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): finally: testing.db.execute("drop table book") - @testing.exclude('mysql', '<', (4, 1, 1), 'innodb funkiness') + @testing.exclude("mysql", "<", (4, 1, 1), "innodb funkiness") @testing.provide_metadata def test_composite_fk(self): """test reflection of composite foreign keys""" meta = self.metadata multi = Table( - 'multi', meta, - Column('multi_id', sa.Integer, primary_key=True), - Column('multi_rev', sa.Integer, primary_key=True), - Column('multi_hoho', sa.Integer, primary_key=True), - Column('name', sa.String(50), nullable=False), - Column('val', sa.String(100)), + "multi", + meta, + Column("multi_id", sa.Integer, primary_key=True), + Column("multi_rev", sa.Integer, primary_key=True), + Column("multi_hoho", sa.Integer, primary_key=True), + Column("name", sa.String(50), nullable=False), + Column("val", sa.String(100)), + test_needs_fk=True, + ) + multi2 = Table( + "multi2", + meta, + Column("id", sa.Integer, primary_key=True), + Column("foo", sa.Integer), + Column("bar", sa.Integer), + Column("lala", sa.Integer), + Column("data", sa.String(50)), + sa.ForeignKeyConstraint( + ["foo", "bar", "lala"], + ["multi.multi_id", "multi.multi_rev", "multi.multi_hoho"], + ), test_needs_fk=True, ) - multi2 = Table('multi2', meta, - Column('id', sa.Integer, primary_key=True), - Column('foo', sa.Integer), - Column('bar', sa.Integer), - Column('lala', sa.Integer), - Column('data', sa.String(50)), - sa.ForeignKeyConstraint(['foo', 'bar', 'lala'], - ['multi.multi_id', - 'multi.multi_rev', - 'multi.multi_hoho']), - test_needs_fk=True, - ) meta.create_all() meta2 = MetaData() - table = Table('multi', meta2, autoload=True, - autoload_with=testing.db) - table2 = Table('multi2', meta2, autoload=True, - autoload_with=testing.db) + table = Table("multi", meta2, autoload=True, autoload_with=testing.db) + table2 = Table( + "multi2", meta2, autoload=True, autoload_with=testing.db + ) self.assert_tables_equal(multi, table) self.assert_tables_equal(multi2, table2) j = sa.join(table, table2) - self.assert_(sa.and_(table.c.multi_id == table2.c.foo, - table.c.multi_rev == table2.c.bar, - table.c.multi_hoho == table2.c.lala) - .compare(j.onclause)) + self.assert_( + sa.and_( + table.c.multi_id == table2.c.foo, + table.c.multi_rev == table2.c.bar, + table.c.multi_hoho == table2.c.lala, + ).compare(j.onclause) + ) - @testing.crashes('oracle', 'FIXME: unknown, confirm not fails_on') + @testing.crashes("oracle", "FIXME: unknown, confirm not fails_on") @testing.requires.check_constraints @testing.provide_metadata def test_reserved(self): @@ -836,50 +1018,59 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): # error meta = self.metadata - table_a = Table('select', meta, - Column('not', sa.Integer, primary_key=True), - Column('from', sa.String(12), nullable=False), - sa.UniqueConstraint('from', name='when')) - sa.Index('where', table_a.c['from']) + table_a = Table( + "select", + meta, + Column("not", sa.Integer, primary_key=True), + Column("from", sa.String(12), nullable=False), + sa.UniqueConstraint("from", name="when"), + ) + sa.Index("where", table_a.c["from"]) # There's currently no way to calculate identifier case # normalization in isolation, so... 
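        # As an aside on the quoting used just below:
        # ``identifier_preparer.quote_identifier`` wraps a reserved or
        # case-sensitive name in the dialect's quote characters.  A
        # minimal sketch against the suite's bound ``testing.db``,
        # illustrative only and not part of the original test:
        #
        #     preparer = testing.db.dialect.identifier_preparer
        #     preparer.quote_identifier("true")
        #     # -> '"true"' on most backends; MySQL's preparer
        #     #    uses backticks, giving '`true`'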
- if testing.against('firebird', 'oracle'): - check_col = 'TRUE' + if testing.against("firebird", "oracle"): + check_col = "TRUE" else: - check_col = 'true' + check_col = "true" quoter = meta.bind.dialect.identifier_preparer.quote_identifier - Table('false', meta, - Column('create', sa.Integer, primary_key=True), - Column('true', sa.Integer, sa.ForeignKey('select.not')), - sa.CheckConstraint('%s <> 1' % quoter(check_col), name='limit')) - - table_c = Table('is', meta, - Column('or', sa.Integer, nullable=False, - primary_key=True), - Column('join', sa.Integer, nullable=False, - primary_key=True), - sa.PrimaryKeyConstraint('or', 'join', name='to')) - index_c = sa.Index('else', table_c.c.join) + Table( + "false", + meta, + Column("create", sa.Integer, primary_key=True), + Column("true", sa.Integer, sa.ForeignKey("select.not")), + sa.CheckConstraint("%s <> 1" % quoter(check_col), name="limit"), + ) + + table_c = Table( + "is", + meta, + Column("or", sa.Integer, nullable=False, primary_key=True), + Column("join", sa.Integer, nullable=False, primary_key=True), + sa.PrimaryKeyConstraint("or", "join", name="to"), + ) + index_c = sa.Index("else", table_c.c.join) meta.create_all() index_c.drop() meta2 = MetaData(testing.db) - Table('select', meta2, autoload=True) - Table('false', meta2, autoload=True) - Table('is', meta2, autoload=True) + Table("select", meta2, autoload=True) + Table("false", meta2, autoload=True) + Table("is", meta2, autoload=True) @testing.provide_metadata def _test_reflect_uses_bind(self, fn): from sqlalchemy.pool import AssertionPool + e = engines.testing_engine(options={"poolclass": AssertionPool}) fn(e) @testing.uses_deprecated() def test_reflect_uses_bind_constructor_conn(self): - self._test_reflect_uses_bind(lambda e: MetaData(e.connect(), - reflect=True)) + self._test_reflect_uses_bind( + lambda e: MetaData(e.connect(), reflect=True) + ) @testing.uses_deprecated() def test_reflect_uses_bind_constructor_engine(self): @@ -901,16 +1092,16 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): def test_reflect_all(self): existing = testing.db.table_names() - names = ['rt_%s' % name for name in ('a', 'b', 'c', 'd', 'e')] + names = ["rt_%s" % name for name in ("a", "b", "c", "d", "e")] nameset = set(names) for name in names: # be sure our starting environment is sane self.assert_(name not in existing) - self.assert_('rt_f' not in existing) + self.assert_("rt_f" not in existing) baseline = self.metadata for name in names: - Table(name, baseline, Column('id', sa.Integer, primary_key=True)) + Table(name, baseline, Column("id", sa.Integer, primary_key=True)) baseline.create_all() m1 = MetaData(testing.db) @@ -919,13 +1110,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): self.assert_(nameset.issubset(set(m1.tables.keys()))) m2 = MetaData() - m2.reflect(testing.db, only=['rt_a', 'rt_b']) - self.assert_(set(m2.tables.keys()) == set(['rt_a', 'rt_b'])) + m2.reflect(testing.db, only=["rt_a", "rt_b"]) + self.assert_(set(m2.tables.keys()) == set(["rt_a", "rt_b"])) m3 = MetaData() c = testing.db.connect() - m3.reflect(bind=c, only=lambda name, meta: name == 'rt_c') - self.assert_(set(m3.tables.keys()) == set(['rt_c'])) + m3.reflect(bind=c, only=lambda name, meta: name == "rt_c") + self.assert_(set(m3.tables.keys()) == set(["rt_c"])) m4 = MetaData(testing.db) @@ -933,7 +1124,8 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): sa.exc.InvalidRequestError, r"Could not reflect: requested table\(s\) not available in " r"Engine\(.*?\): \(rt_f\)", - m4.reflect, 
only=['rt_a', 'rt_f'] + m4.reflect, + only=["rt_a", "rt_f"], ) m5 = MetaData(testing.db) @@ -949,22 +1141,19 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): self.assert_(nameset.issubset(set(m7.tables.keys()))) m8 = MetaData() - assert_raises( - sa.exc.UnboundExecutionError, - m8.reflect - ) + assert_raises(sa.exc.UnboundExecutionError, m8.reflect) m8_e1 = MetaData(testing.db) - rt_c = Table('rt_c', m8_e1) + rt_c = Table("rt_c", m8_e1) m8_e1.reflect(extend_existing=True) eq_(set(m8_e1.tables.keys()), set(names)) - eq_(rt_c.c.keys(), ['id']) + eq_(rt_c.c.keys(), ["id"]) m8_e2 = MetaData(testing.db) - rt_c = Table('rt_c', m8_e2) - m8_e2.reflect(extend_existing=True, only=['rt_a', 'rt_c']) - eq_(set(m8_e2.tables.keys()), set(['rt_a', 'rt_c'])) - eq_(rt_c.c.keys(), ['id']) + rt_c = Table("rt_c", m8_e2) + m8_e2.reflect(extend_existing=True, only=["rt_a", "rt_c"]) + eq_(set(m8_e2.tables.keys()), set(["rt_a", "rt_c"])) + eq_(rt_c.c.keys(), ["id"]) if existing: print("Other tables present in database, skipping some checks.") @@ -976,11 +1165,12 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): @testing.provide_metadata def test_reflect_all_unreflectable_table(self): - names = ['rt_%s' % name for name in ('a', 'b', 'c', 'd', 'e')] + names = ["rt_%s" % name for name in ("a", "b", "c", "d", "e")] for name in names: - Table(name, self.metadata, - Column('id', sa.Integer, primary_key=True)) + Table( + name, self.metadata, Column("id", sa.Integer, primary_key=True) + ) self.metadata.create_all() m = MetaData() @@ -988,7 +1178,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): reflecttable = testing.db.dialect.reflecttable def patched(conn, table, *arg, **kw): - if table.name == 'rt_c': + if table.name == "rt_c": raise sa.exc.UnreflectableTableError("Can't reflect rt_c") else: return reflecttable(conn, table, *arg, **kw) @@ -1000,7 +1190,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert_raises_message( sa.exc.UnreflectableTableError, "Can't reflect rt_c", - Table, 'rt_c', m, autoload_with=testing.db + Table, + "rt_c", + m, + autoload_with=testing.db, ) def test_reflect_all_conn_closing(self): @@ -1017,14 +1210,17 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): @testing.provide_metadata def test_index_reflection(self): m1 = self.metadata - t1 = Table('party', m1, - Column('id', sa.Integer, nullable=False), - Column('name', sa.String(20), index=True)) - sa.Index('idx1', t1.c.id, unique=True) - sa.Index('idx2', t1.c.name, t1.c.id, unique=False) + t1 = Table( + "party", + m1, + Column("id", sa.Integer, nullable=False), + Column("name", sa.String(20), index=True), + ) + sa.Index("idx1", t1.c.id, unique=True) + sa.Index("idx2", t1.c.name, t1.c.id, unique=False) m1.create_all() m2 = MetaData(testing.db) - t2 = Table('party', m2, autoload=True) + t2 = Table("party", m2, autoload=True) assert len(t2.indexes) == 3 # Make sure indexes are in the order we expect them in @@ -1032,8 +1228,8 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): tmp.sort() r1, r2, r3 = [idx[1] for idx in tmp] - assert r1.name == 'idx1' - assert r2.name == 'idx2' + assert r1.name == "idx1" + assert r2.name == "idx2" assert r1.unique == True # noqa assert r2.unique == False # noqa assert r3.unique == False # noqa @@ -1045,58 +1241,65 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): @testing.provide_metadata def test_comment_reflection(self): m1 = self.metadata - Table('sometable', m1, - Column('id', sa.Integer, comment='c1 comment'), - comment='t1 comment') 
+ Table( + "sometable", + m1, + Column("id", sa.Integer, comment="c1 comment"), + comment="t1 comment", + ) m1.create_all() m2 = MetaData(testing.db) - t2 = Table('sometable', m2, autoload=True) + t2 = Table("sometable", m2, autoload=True) - eq_(t2.comment, 't1 comment') - eq_(t2.c.id.comment, 'c1 comment') + eq_(t2.comment, "t1 comment") + eq_(t2.c.id.comment, "c1 comment") - t3 = Table('sometable', m2, extend_existing=True) - eq_(t3.comment, 't1 comment') - eq_(t3.c.id.comment, 'c1 comment') + t3 = Table("sometable", m2, extend_existing=True) + eq_(t3.comment, "t1 comment") + eq_(t3.c.id.comment, "c1 comment") @testing.requires.check_constraint_reflection @testing.provide_metadata def test_check_constraint_reflection(self): m1 = self.metadata Table( - 'x', m1, - Column('q', Integer), - sa.CheckConstraint('q > 10', name="ck1") + "x", + m1, + Column("q", Integer), + sa.CheckConstraint("q > 10", name="ck1"), ) m1.create_all() m2 = MetaData(testing.db) - t2 = Table('x', m2, autoload=True) + t2 = Table("x", m2, autoload=True) ck = [ - const for const in - t2.constraints if isinstance(const, sa.CheckConstraint)][0] + const + for const in t2.constraints + if isinstance(const, sa.CheckConstraint) + ][0] eq_regex(ck.sqltext.text, r".?q.? > 10") eq_(ck.name, "ck1") @testing.provide_metadata def test_index_reflection_cols_busted(self): - t = Table('x', self.metadata, - Column('a', Integer), Column('b', Integer)) - sa.Index('x_ix', t.c.a, t.c.b) + t = Table( + "x", self.metadata, Column("a", Integer), Column("b", Integer) + ) + sa.Index("x_ix", t.c.a, t.c.b) self.metadata.create_all() def mock_get_columns(self, connection, table_name, **kw): - return [ - {"name": "b", "type": Integer, "primary_key": False} - ] + return [{"name": "b", "type": Integer, "primary_key": False}] with testing.mock.patch.object( - testing.db.dialect, "get_columns", mock_get_columns): + testing.db.dialect, "get_columns", mock_get_columns + ): m = MetaData() with testing.expect_warnings( - "index key 'a' was not located in columns"): - t = Table('x', m, autoload=True, autoload_with=testing.db) + "index key 'a' was not located in columns" + ): + t = Table("x", m, autoload=True, autoload_with=testing.db) eq_(list(t.indexes)[0].columns, [t.c.b]) @@ -1134,16 +1337,22 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2.reflect(views=False) eq_( - set(m2.tables), - set(['users', 'email_addresses', 'dingalings']) + set(m2.tables), set(["users", "email_addresses", "dingalings"]) ) m2 = MetaData(testing.db) m2.reflect(views=True) eq_( set(m2.tables), - set(['email_addresses_v', 'users_v', - 'users', 'dingalings', 'email_addresses']) + set( + [ + "email_addresses_v", + "users_v", + "users", + "dingalings", + "email_addresses", + ] + ), ) finally: _drop_views(metadata.bind) @@ -1156,48 +1365,65 @@ class CreateDropTest(fixtures.TestBase): def setup_class(cls): global metadata, users metadata = MetaData() - users = Table('users', metadata, - Column('user_id', sa.Integer, - sa.Sequence('user_id_seq', optional=True), - primary_key=True), - Column('user_name', sa.String(40))) - - Table('email_addresses', metadata, - Column('address_id', sa.Integer, - sa.Sequence('address_id_seq', optional=True), - primary_key=True), - Column('user_id', - sa.Integer, sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(40))) + users = Table( + "users", + metadata, + Column( + "user_id", + sa.Integer, + sa.Sequence("user_id_seq", optional=True), + primary_key=True, + ), + Column("user_name", sa.String(40)), + ) Table( - 'orders', + 
"email_addresses", metadata, - Column('order_id', - sa.Integer, - sa.Sequence('order_id_seq', optional=True), - primary_key=True), - Column('user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('description', sa.String(50)), - Column('isopen', sa.Integer), - ) - Table('items', metadata, - Column('item_id', sa.INT, - sa.Sequence('items_id_seq', optional=True), - primary_key=True), - Column('order_id', - sa.INT, sa.ForeignKey('orders')), - Column('item_name', sa.VARCHAR(50))) + Column( + "address_id", + sa.Integer, + sa.Sequence("address_id_seq", optional=True), + primary_key=True, + ), + Column("user_id", sa.Integer, sa.ForeignKey(users.c.user_id)), + Column("email_address", sa.String(40)), + ) + + Table( + "orders", + metadata, + Column( + "order_id", + sa.Integer, + sa.Sequence("order_id_seq", optional=True), + primary_key=True, + ), + Column("user_id", sa.Integer, sa.ForeignKey(users.c.user_id)), + Column("description", sa.String(50)), + Column("isopen", sa.Integer), + ) + Table( + "items", + metadata, + Column( + "item_id", + sa.INT, + sa.Sequence("items_id_seq", optional=True), + primary_key=True, + ), + Column("order_id", sa.INT, sa.ForeignKey("orders")), + Column("item_name", sa.VARCHAR(50)), + ) def test_sorter(self): tables = metadata.sorted_tables table_names = [t.name for t in tables] - ua = [n for n in table_names if n in ('users', 'email_addresses')] - oi = [n for n in table_names if n in ('orders', 'items')] + ua = [n for n in table_names if n in ("users", "email_addresses")] + oi = [n for n in table_names if n in ("orders", "items")] - eq_(ua, ['users', 'email_addresses']) - eq_(oi, ['orders', 'items']) + eq_(ua, ["users", "email_addresses"]) + eq_(oi, ["orders", "items"]) def test_checkfirst(self): try: @@ -1215,16 +1441,16 @@ class CreateDropTest(fixtures.TestBase): def test_createdrop(self): metadata.create_all(bind=testing.db) - eq_(testing.db.has_table('items'), True) - eq_(testing.db.has_table('email_addresses'), True) + eq_(testing.db.has_table("items"), True) + eq_(testing.db.has_table("email_addresses"), True) metadata.create_all(bind=testing.db) - eq_(testing.db.has_table('items'), True) + eq_(testing.db.has_table("items"), True) metadata.drop_all(bind=testing.db) - eq_(testing.db.has_table('items'), False) - eq_(testing.db.has_table('email_addresses'), False) + eq_(testing.db.has_table("items"), False) + eq_(testing.db.has_table("email_addresses"), False) metadata.drop_all(bind=testing.db) - eq_(testing.db.has_table('items'), False) + eq_(testing.db.has_table("items"), False) def test_tablenames(self): metadata.create_all(bind=testing.db) @@ -1234,8 +1460,7 @@ class CreateDropTest(fixtures.TestBase): # "extra" tables if there is a misconfigured template. (*cough* # tsearch2 w/ the pg windows installer.) 
- self.assert_(not set(metadata.tables) - - set(testing.db.table_names())) + self.assert_(not set(metadata.tables) - set(testing.db.table_names())) metadata.drop_all(bind=testing.db) @@ -1245,12 +1470,15 @@ class SchemaManipulationTest(fixtures.TestBase): def test_append_constraint_unique(self): meta = MetaData() - users = Table('users', meta, Column('id', sa.Integer)) - addresses = Table('addresses', meta, - Column('id', sa.Integer), - Column('user_id', sa.Integer)) + users = Table("users", meta, Column("id", sa.Integer)) + addresses = Table( + "addresses", + meta, + Column("id", sa.Integer), + Column("user_id", sa.Integer), + ) - fk = sa.ForeignKeyConstraint(['user_id'], [users.c.id]) + fk = sa.ForeignKeyConstraint(["user_id"], [users.c.id]) addresses.append_constraint(fk) addresses.append_constraint(fk) @@ -1265,43 +1493,37 @@ class UnicodeReflectionTest(fixtures.TestBase): def setup_class(cls): cls.metadata = metadata = MetaData() - no_multibyte_period = set([ - ('plain', 'col_plain', 'ix_plain') - ]) + no_multibyte_period = set([("plain", "col_plain", "ix_plain")]) no_has_table = [ ( - 'no_has_table_1', - ue('col_Unit\u00e9ble'), - ue('ix_Unit\u00e9ble') - ), - ( - 'no_has_table_2', - ue('col_\u6e2c\u8a66'), - ue('ix_\u6e2c\u8a66') + "no_has_table_1", + ue("col_Unit\u00e9ble"), + ue("ix_Unit\u00e9ble"), ), + ("no_has_table_2", ue("col_\u6e2c\u8a66"), ue("ix_\u6e2c\u8a66")), ] no_case_sensitivity = [ ( - ue('\u6e2c\u8a66'), - ue('col_\u6e2c\u8a66'), - ue('ix_\u6e2c\u8a66') + ue("\u6e2c\u8a66"), + ue("col_\u6e2c\u8a66"), + ue("ix_\u6e2c\u8a66"), ), ( - ue('unit\u00e9ble'), - ue('col_unit\u00e9ble'), - ue('ix_unit\u00e9ble') + ue("unit\u00e9ble"), + ue("col_unit\u00e9ble"), + ue("ix_unit\u00e9ble"), ), ] full = [ ( - ue('Unit\u00e9ble'), - ue('col_Unit\u00e9ble'), - ue('ix_Unit\u00e9ble') + ue("Unit\u00e9ble"), + ue("col_Unit\u00e9ble"), + ue("ix_Unit\u00e9ble"), ), ( - ue('\u6e2c\u8a66'), - ue('col_\u6e2c\u8a66'), - ue('ix_\u6e2c\u8a66') + ue("\u6e2c\u8a66"), + ue("col_\u6e2c\u8a66"), + ue("ix_\u6e2c\u8a66"), ), ] @@ -1312,8 +1534,10 @@ class UnicodeReflectionTest(fixtures.TestBase): if not testing.requires.unicode_ddl.enabled: names = no_multibyte_period # mysql can't handle casing usually - elif testing.against("mysql") and \ - not testing.requires.mysql_fully_case_sensitive.enabled: + elif ( + testing.against("mysql") + and not testing.requires.mysql_fully_case_sensitive.enabled + ): names = no_multibyte_period.union(no_case_sensitivity) # mssql + pyodbc + freetds can't compare multibyte names to # information_schema.tables.table_name @@ -1323,11 +1547,17 @@ class UnicodeReflectionTest(fixtures.TestBase): names = no_multibyte_period.union(full) for tname, cname, ixname in names: - t = Table(tname, metadata, - Column('id', sa.Integer, - sa.Sequence(cname + '_id_seq'), - primary_key=True), - Column(cname, Integer)) + t = Table( + tname, + metadata, + Column( + "id", + sa.Integer, + sa.Sequence(cname + "_id_seq"), + primary_key=True, + ), + Column(cname, Integer), + ) schema.Index(ixname, t.c[cname]) metadata.create_all(testing.db) @@ -1356,14 +1586,14 @@ class UnicodeReflectionTest(fixtures.TestBase): # Jython 2.5 on Java 5 lacks unicodedata.normalize - if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'): + if not names.issubset(reflected) and hasattr(unicodedata, "normalize"): # Python source files in the utf-8 coding seem to # normalize literals as NFC (and the above are # explicitly NFC). Maybe this database normalizes NFD # on reflection. 
- nfc = set([unicodedata.normalize('NFC', n) for n in names]) + nfc = set([unicodedata.normalize("NFC", n) for n in names]) self.assert_(nfc == names) # Yep. But still ensure that bulk reflection and @@ -1384,10 +1614,10 @@ class UnicodeReflectionTest(fixtures.TestBase): assert tname in names eq_( [ - (rec['name'], rec['column_names'][0]) + (rec["name"], rec["column_names"][0]) for rec in inspector.get_indexes(tname) ], - [(names[tname][1], names[tname][0])] + [(names[tname][1], names[tname][0])], ) @@ -1399,12 +1629,19 @@ class SchemaTest(fixtures.TestBase): def test_has_schema(self): if not hasattr(testing.db.dialect, "has_schema"): testing.config.skip_test( - "dialect %s doesn't have a has_schema method" % - testing.db.dialect.name) - eq_(testing.db.dialect.has_schema(testing.db, - testing.config.test_schema), True) - eq_(testing.db.dialect.has_schema(testing.db, - 'sa_fake_schema_123'), False) + "dialect %s doesn't have a has_schema method" + % testing.db.dialect.name + ) + eq_( + testing.db.dialect.has_schema( + testing.db, testing.config.test_schema + ), + True, + ) + eq_( + testing.db.dialect.has_schema(testing.db, "sa_fake_schema_123"), + False, + ) @testing.requires.schemas @testing.requires.cross_schema_fk_reflection @@ -1413,63 +1650,79 @@ class SchemaTest(fixtures.TestBase): def test_blank_schema_arg(self): metadata = self.metadata - Table('some_table', metadata, - Column('id', Integer, primary_key=True), - Column('sid', Integer, sa.ForeignKey('some_other_table.id')), - schema=testing.config.test_schema, - test_needs_fk=True - ) - Table('some_other_table', metadata, - Column('id', Integer, primary_key=True), - schema=None, - test_needs_fk=True - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("sid", Integer, sa.ForeignKey("some_other_table.id")), + schema=testing.config.test_schema, + test_needs_fk=True, + ) + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + schema=None, + test_needs_fk=True, + ) metadata.create_all() with testing.db.connect() as conn: meta2 = MetaData(conn, schema=testing.config.test_schema) meta2.reflect() - eq_(set(meta2.tables), set( - [ - 'some_other_table', - '%s.some_table' % testing.config.test_schema])) + eq_( + set(meta2.tables), + set( + [ + "some_other_table", + "%s.some_table" % testing.config.test_schema, + ] + ), + ) @testing.requires.schemas - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.fails_on('sybase', 'FIXME: unknown') + @testing.fails_on("sqlite", "FIXME: unknown") + @testing.fails_on("sybase", "FIXME: unknown") def test_explicit_default_schema(self): engine = testing.db engine.connect().close() - if testing.against('sqlite'): + if testing.against("sqlite"): # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., # but fails on: # FOREIGN KEY(col2) REFERENCES main.table1 (col1) - schema = 'main' + schema = "main" else: schema = engine.dialect.default_schema_name assert bool(schema) metadata = MetaData(engine) - Table('table1', metadata, - Column('col1', sa.Integer, primary_key=True), - test_needs_fk=True, - schema=schema) - Table('table2', metadata, - Column('col1', sa.Integer, primary_key=True), - Column('col2', sa.Integer, - sa.ForeignKey('%s.table1.col1' % schema)), - test_needs_fk=True, - schema=schema) + Table( + "table1", + metadata, + Column("col1", sa.Integer, primary_key=True), + test_needs_fk=True, + schema=schema, + ) + Table( + "table2", + metadata, + Column("col1", sa.Integer, primary_key=True), + Column( + "col2", sa.Integer, 
sa.ForeignKey("%s.table1.col1" % schema) + ), + test_needs_fk=True, + schema=schema, + ) try: metadata.create_all() metadata.create_all(checkfirst=True) assert len(metadata.tables) == 2 metadata.clear() - Table('table1', metadata, autoload=True, schema=schema) - Table('table2', metadata, autoload=True, schema=schema) + Table("table1", metadata, autoload=True, schema=schema) + Table("table2", metadata, autoload=True, schema=schema) assert len(metadata.tables) == 2 finally: metadata.drop_all() @@ -1477,49 +1730,60 @@ class SchemaTest(fixtures.TestBase): @testing.requires.schemas @testing.provide_metadata def test_schema_translation(self): - Table('foob', self.metadata, Column('q', Integer), - schema=config.test_schema) + Table( + "foob", + self.metadata, + Column("q", Integer), + schema=config.test_schema, + ) self.metadata.create_all() m = MetaData() map_ = {"foob": config.test_schema} - with config.db.connect().execution_options(schema_translate_map=map_) \ - as conn: - t = Table('foob', m, schema="foob", autoload_with=conn) + with config.db.connect().execution_options( + schema_translate_map=map_ + ) as conn: + t = Table("foob", m, schema="foob", autoload_with=conn) eq_(t.schema, "foob") - eq_(t.c.keys(), ['q']) + eq_(t.c.keys(), ["q"]) @testing.requires.schemas - @testing.fails_on('sybase', 'FIXME: unknown') + @testing.fails_on("sybase", "FIXME: unknown") def test_explicit_default_schema_metadata(self): engine = testing.db - if testing.against('sqlite'): + if testing.against("sqlite"): # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., # but fails on: # FOREIGN KEY(col2) REFERENCES main.table1 (col1) - schema = 'main' + schema = "main" else: schema = engine.dialect.default_schema_name assert bool(schema) metadata = MetaData(engine, schema=schema) - Table('table1', metadata, - Column('col1', sa.Integer, primary_key=True), - test_needs_fk=True) - Table('table2', metadata, - Column('col1', sa.Integer, primary_key=True), - Column('col2', sa.Integer, sa.ForeignKey('table1.col1')), - test_needs_fk=True) + Table( + "table1", + metadata, + Column("col1", sa.Integer, primary_key=True), + test_needs_fk=True, + ) + Table( + "table2", + metadata, + Column("col1", sa.Integer, primary_key=True), + Column("col2", sa.Integer, sa.ForeignKey("table1.col1")), + test_needs_fk=True, + ) try: metadata.create_all() metadata.create_all(checkfirst=True) assert len(metadata.tables) == 2 metadata.clear() - Table('table1', metadata, autoload=True) - Table('table2', metadata, autoload=True) + Table("table1", metadata, autoload=True) + Table("table2", metadata, autoload=True) assert len(metadata.tables) == 2 finally: metadata.drop_all() @@ -1534,11 +1798,13 @@ class SchemaTest(fixtures.TestBase): m2.reflect() eq_( set(m2.tables), - set([ - '%s.dingalings' % testing.config.test_schema, - '%s.users' % testing.config.test_schema, - '%s.email_addresses' % testing.config.test_schema - ]) + set( + [ + "%s.dingalings" % testing.config.test_schema, + "%s.users" % testing.config.test_schema, + "%s.email_addresses" % testing.config.test_schema, + ] + ), ) @testing.requires.schemas @@ -1546,12 +1812,14 @@ class SchemaTest(fixtures.TestBase): @testing.requires.implicit_default_schema @testing.provide_metadata def test_reflect_all_schemas_default_overlap(self): - t1 = Table('t', self.metadata, - Column('id', Integer, primary_key=True)) + t1 = Table("t", self.metadata, Column("id", Integer, primary_key=True)) - t2 = Table('t', self.metadata, - Column('id1', sa.ForeignKey('t.id')), - schema=testing.config.test_schema) + 
t2 = Table( + "t", + self.metadata, + Column("id1", sa.ForeignKey("t.id")), + schema=testing.config.test_schema, + ) self.metadata.create_all() m2 = MetaData() @@ -1563,7 +1831,7 @@ class SchemaTest(fixtures.TestBase): eq_( set((t.name, t.schema) for t in m2.tables.values()), - set((t.name, t.schema) for t in m3.tables.values()) + set((t.name, t.schema) for t in m3.tables.values()), ) @@ -1576,88 +1844,102 @@ def createTables(meta, schema=None): else: schema_prefix = "" - users = Table('users', meta, - Column('user_id', sa.INT, primary_key=True), - Column('user_name', sa.VARCHAR(20), nullable=False), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - Column('test3', sa.Text), - Column('test4', sa.Numeric(10, 2), nullable=False), - Column('test5', sa.Date), - Column('parent_user_id', sa.Integer, - sa.ForeignKey('%susers.user_id' % schema_prefix)), - Column('test6', sa.Date, nullable=False), - Column('test7', sa.Text), - Column('test8', sa.LargeBinary), - Column('test_passivedefault2', sa.Integer, - server_default='5'), - Column('test9', sa.LargeBinary(100)), - Column('test10', sa.Numeric(10, 2)), - schema=schema, - test_needs_fk=True) - dingalings = Table("dingalings", meta, - Column('dingaling_id', sa.Integer, primary_key=True), - Column('address_id', sa.Integer, - sa.ForeignKey('%semail_addresses.address_id' - % schema_prefix)), - Column('data', sa.String(30)), - schema=schema, test_needs_fk=True) - addresses = Table('email_addresses', meta, - Column('address_id', sa.Integer), - Column('remote_user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(20)), - sa.PrimaryKeyConstraint('address_id', - name='email_ad_pk'), - schema=schema, - test_needs_fk=True) + users = Table( + "users", + meta, + Column("user_id", sa.INT, primary_key=True), + Column("user_name", sa.VARCHAR(20), nullable=False), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column("test3", sa.Text), + Column("test4", sa.Numeric(10, 2), nullable=False), + Column("test5", sa.Date), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey("%susers.user_id" % schema_prefix), + ), + Column("test6", sa.Date, nullable=False), + Column("test7", sa.Text), + Column("test8", sa.LargeBinary), + Column("test_passivedefault2", sa.Integer, server_default="5"), + Column("test9", sa.LargeBinary(100)), + Column("test10", sa.Numeric(10, 2)), + schema=schema, + test_needs_fk=True, + ) + dingalings = Table( + "dingalings", + meta, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column("data", sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + addresses = Table( + "email_addresses", + meta, + Column("address_id", sa.Integer), + Column("remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)), + Column("email_address", sa.String(20)), + sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) return (users, addresses, dingalings) def createIndexes(con, schema=None): - fullname = 'users' + fullname = "users" if schema: - fullname = "%s.%s" % (schema, 'users') + fullname = "%s.%s" % (schema, "users") query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname con.execute(sa.sql.text(query)) @testing.requires.views def _create_views(con, schema=None): - for table_name in ('users', 'email_addresses'): + for table_name in ("users", 
"email_addresses"): fullname = table_name if schema: fullname = "%s.%s" % (schema, table_name) - view_name = fullname + '_v' + view_name = fullname + "_v" query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name, fullname) con.execute(sa.sql.text(query)) @testing.requires.views def _drop_views(con, schema=None): - for table_name in ('email_addresses', 'users'): + for table_name in ("email_addresses", "users"): fullname = table_name if schema: fullname = "%s.%s" % (schema, table_name) - view_name = fullname + '_v' + view_name = fullname + "_v" query = "DROP VIEW %s" % view_name con.execute(sa.sql.text(query)) class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" __backend__ = True @testing.requires.denormalized_names def setup(self): - testing.db.execute(""" + testing.db.execute( + """ CREATE TABLE weird_casing( col1 char(20), "Col2" char(20), "col3" char(20) ) - """) + """ + ) @testing.requires.denormalized_names def teardown(self): @@ -1666,32 +1948,39 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): @testing.requires.denormalized_names def test_direct_quoting(self): m = MetaData(testing.db) - t = Table('weird_casing', m, autoload=True) - self.assert_compile(t.select(), - 'SELECT weird_casing.col1, ' - 'weird_casing."Col2", weird_casing."col3" ' - 'FROM weird_casing') + t = Table("weird_casing", m, autoload=True) + self.assert_compile( + t.select(), + "SELECT weird_casing.col1, " + 'weird_casing."Col2", weird_casing."col3" ' + "FROM weird_casing", + ) class CaseSensitiveTest(fixtures.TablesTest): """Nail down case sensitive behaviors, mostly on MySQL.""" + __backend__ = True @classmethod def define_tables(cls, metadata): - Table('SomeTable', metadata, - Column('x', Integer, primary_key=True), - test_needs_fk=True) - Table('SomeOtherTable', metadata, - Column('x', Integer, primary_key=True), - Column('y', Integer, sa.ForeignKey("SomeTable.x")), - test_needs_fk=True) + Table( + "SomeTable", + metadata, + Column("x", Integer, primary_key=True), + test_needs_fk=True, + ) + Table( + "SomeOtherTable", + metadata, + Column("x", Integer, primary_key=True), + Column("y", Integer, sa.ForeignKey("SomeTable.x")), + test_needs_fk=True, + ) @testing.fails_if(testing.requires._has_mysql_on_windows) def test_table_names(self): - x = testing.db.run_callable( - testing.db.dialect.get_table_names - ) + x = testing.db.run_callable(testing.db.dialect.get_table_names) assert set(["SomeTable", "SomeOtherTable"]).issubset(x) def test_reflect_exact_name(self): @@ -1700,19 +1989,20 @@ class CaseSensitiveTest(fixtures.TablesTest): eq_(t1.name, "SomeTable") assert t1.c.x is not None - @testing.fails_if(lambda: - testing.against(('mysql', '<', (5, 5))) and - not testing.requires._has_mysql_fully_case_sensitive() - ) + @testing.fails_if( + lambda: testing.against(("mysql", "<", (5, 5))) + and not testing.requires._has_mysql_fully_case_sensitive() + ) def test_reflect_via_fk(self): m = MetaData() - t2 = Table("SomeOtherTable", m, autoload=True, - autoload_with=testing.db) + t2 = Table( + "SomeOtherTable", m, autoload=True, autoload_with=testing.db + ) eq_(t2.name, "SomeOtherTable") assert "SomeTable" in m.tables @testing.fails_if(testing.requires._has_mysql_fully_case_sensitive) - @testing.fails_on_everything_except('sqlite', 'mysql', 'mssql') + @testing.fails_on_everything_except("sqlite", "mysql", "mssql") def test_reflect_case_insensitive(self): m = MetaData() t2 = Table("sOmEtAbLe", m, autoload=True, 
autoload_with=testing.db) @@ -1726,17 +2016,17 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): def setup_class(cls): cls.metadata = MetaData() cls.to_reflect = Table( - 'to_reflect', + "to_reflect", cls.metadata, - Column('x', sa.Integer, primary_key=True), - Column('y', sa.Integer), - test_needs_fk=True + Column("x", sa.Integer, primary_key=True), + Column("y", sa.Integer), + test_needs_fk=True, ) cls.related = Table( - 'related', + "related", cls.metadata, - Column('q', sa.Integer, sa.ForeignKey('to_reflect.x')), - test_needs_fk=True + Column("q", sa.Integer, sa.ForeignKey("to_reflect.x")), + test_needs_fk=True, ) sa.Index("some_index", cls.to_reflect.c.y) cls.metadata.create_all(testing.db) @@ -1753,16 +2043,19 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): m = MetaData(testing.db) def column_reflect(insp, table, column_info): - if column_info['name'] == col: + if column_info["name"] == col: column_info.update(update) - t = Table(tablename, m, autoload=True, listeners=[ - ('column_reflect', column_reflect), - ]) + t = Table( + tablename, + m, + autoload=True, + listeners=[("column_reflect", column_reflect)], + ) assert_(t) m = MetaData(testing.db) - self.event_listen(Table, 'column_reflect', column_reflect) + self.event_listen(Table, "column_reflect", column_reflect) t2 = Table(tablename, m, autoload=True) assert_(t2) @@ -1771,65 +2064,67 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): eq_(table.c.YXZ.name, "x") eq_(set(table.primary_key), set([table.c.YXZ])) - self._do_test( - "x", {"key": "YXZ"}, - assertions - ) + self._do_test("x", {"key": "YXZ"}, assertions) def test_override_index(self): def assertions(table): idx = list(table.indexes)[0] eq_(idx.columns, [table.c.YXZ]) - self._do_test( - "y", {"key": "YXZ"}, - assertions - ) + self._do_test("y", {"key": "YXZ"}, assertions) def test_override_key_fk(self): m = MetaData(testing.db) def column_reflect(insp, table, column_info): - if column_info['name'] == 'q': - column_info['key'] = 'qyz' - elif column_info['name'] == 'x': - column_info['key'] = 'xyz' + if column_info["name"] == "q": + column_info["key"] = "qyz" + elif column_info["name"] == "x": + column_info["key"] = "xyz" - to_reflect = Table("to_reflect", m, autoload=True, listeners=[ - ('column_reflect', column_reflect), - ]) - related = Table("related", m, autoload=True, - listeners=[('column_reflect', column_reflect)]) + to_reflect = Table( + "to_reflect", + m, + autoload=True, + listeners=[("column_reflect", column_reflect)], + ) + related = Table( + "related", + m, + autoload=True, + listeners=[("column_reflect", column_reflect)], + ) assert related.c.qyz.references(to_reflect.c.xyz) def test_override_type(self): def assert_(table): assert isinstance(table.c.x.type, sa.String) - self._do_test( - "x", {"type": sa.String}, - assert_ - ) + + self._do_test("x", {"type": sa.String}, assert_) def test_override_info(self): self._do_test( - "x", {"info": {"a": "b"}}, - lambda table: eq_(table.c.x.info, {"a": "b"}) + "x", + {"info": {"a": "b"}}, + lambda table: eq_(table.c.x.info, {"a": "b"}), ) def test_override_server_default_fetchedvalue(self): my_default = FetchedValue() self._do_test( - "x", {"default": my_default}, - lambda table: eq_(table.c.x.server_default, my_default) + "x", + {"default": my_default}, + lambda table: eq_(table.c.x.server_default, my_default), ) def test_override_server_default_default_clause(self): my_default = DefaultClause("1") self._do_test( - "x", {"default": my_default}, - lambda table: 
eq_(table.c.x.server_default, my_default) + "x", + {"default": my_default}, + lambda table: eq_(table.c.x.server_default, my_default), ) def test_override_server_default_plain_text(self): @@ -1838,15 +2133,12 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): def assert_text_of_one(table): is_true( isinstance( - table.c.x.server_default.arg, sql.elements.TextClause) - ) - eq_( - str(table.c.x.server_default.arg), "1" + table.c.x.server_default.arg, sql.elements.TextClause + ) ) - self._do_test( - "x", {"default": my_default}, - assert_text_of_one - ) + eq_(str(table.c.x.server_default.arg), "1") + + self._do_test("x", {"default": my_default}, assert_text_of_one) def test_override_server_default_textclause(self): my_default = sa.text("1") @@ -1854,12 +2146,9 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): def assert_text_of_one(table): is_true( isinstance( - table.c.x.server_default.arg, sql.elements.TextClause) - ) - eq_( - str(table.c.x.server_default.arg), "1" + table.c.x.server_default.arg, sql.elements.TextClause + ) ) - self._do_test( - "x", {"default": my_default}, - assert_text_of_one - ) + eq_(str(table.c.x.server_default.arg), "1") + + self._do_test("x", {"default": my_default}, assert_text_of_one) diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index d865b47a56..ad2f29fa03 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -1,10 +1,26 @@ -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, ne_, expect_warnings +from sqlalchemy.testing import ( + eq_, + assert_raises, + assert_raises_message, + ne_, + expect_warnings, +) import sys from sqlalchemy import event from sqlalchemy.testing.engines import testing_engine -from sqlalchemy import create_engine, MetaData, INT, VARCHAR, Sequence, \ - select, Integer, String, func, text, exc +from sqlalchemy import ( + create_engine, + MetaData, + INT, + VARCHAR, + Sequence, + select, + Integer, + String, + func, + text, + exc, +) from sqlalchemy.testing.schema import Table from sqlalchemy.testing.schema import Column from sqlalchemy import testing @@ -21,10 +37,13 @@ class TransactionTest(fixtures.TestBase): def setup_class(cls): global users, metadata metadata = MetaData() - users = Table('query_users', metadata, - Column('user_id', INT, primary_key=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True) + users = Table( + "query_users", + metadata, + Column("user_id", INT, primary_key=True), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, + ) users.create(testing.db) def teardown(self): @@ -37,12 +56,12 @@ class TransactionTest(fixtures.TestBase): def test_commits(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") transaction.commit() transaction = connection.begin() - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.commit() transaction = connection.begin() @@ -56,9 +75,9 @@ class TransactionTest(fixtures.TestBase): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') - connection.execute(users.insert(), user_id=2, 
user_name='user2') - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=1, user_name="user1") + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.rollback() result = connection.execute("select * from query_users") @@ -70,9 +89,9 @@ class TransactionTest(fixtures.TestBase): transaction = connection.begin() try: - connection.execute(users.insert(), user_id=1, user_name='user1') - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=1, user_name='user3') + connection.execute(users.insert(), user_id=1, user_name="user1") + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=1, user_name="user3") transaction.commit() assert False except Exception as e: @@ -84,37 +103,47 @@ class TransactionTest(fixtures.TestBase): connection.close() def test_transaction_container(self): - def go(conn, table, data): for d in data: conn.execute(table.insert(), d) - testing.db.transaction(go, users, [dict(user_id=1, - user_name='user1')]) - eq_(testing.db.execute(users.select()).fetchall(), [(1, 'user1')]) - assert_raises(exc.DBAPIError, testing.db.transaction, go, - users, [{'user_id': 2, 'user_name': 'user2'}, - {'user_id': 1, 'user_name': 'user3'}]) - eq_(testing.db.execute(users.select()).fetchall(), [(1, 'user1')]) + testing.db.transaction(go, users, [dict(user_id=1, user_name="user1")]) + eq_(testing.db.execute(users.select()).fetchall(), [(1, "user1")]) + assert_raises( + exc.DBAPIError, + testing.db.transaction, + go, + users, + [ + {"user_id": 2, "user_name": "user2"}, + {"user_id": 1, "user_name": "user3"}, + ], + ) + eq_(testing.db.execute(users.select()).fetchall(), [(1, "user1")]) def test_nested_rollback(self): connection = testing.db.connect() try: transaction = connection.begin() try: - connection.execute(users.insert(), user_id=1, - user_name='user1') - connection.execute(users.insert(), user_id=2, - user_name='user2') - connection.execute(users.insert(), user_id=3, - user_name='user3') + connection.execute( + users.insert(), user_id=1, user_name="user1" + ) + connection.execute( + users.insert(), user_id=2, user_name="user2" + ) + connection.execute( + users.insert(), user_id=3, user_name="user3" + ) trans2 = connection.begin() try: - connection.execute(users.insert(), user_id=4, - user_name='user4') - connection.execute(users.insert(), user_id=5, - user_name='user5') - raise Exception('uh oh') + connection.execute( + users.insert(), user_id=4, user_name="user4" + ) + connection.execute( + users.insert(), user_id=5, user_name="user5" + ) + raise Exception("uh oh") trans2.commit() except Exception: trans2.rollback() @@ -127,7 +156,7 @@ class TransactionTest(fixtures.TestBase): try: # and not "This transaction is inactive" # comment moved here to fix pep8 - assert str(e) == 'uh oh' + assert str(e) == "uh oh" finally: connection.close() @@ -137,9 +166,9 @@ class TransactionTest(fixtures.TestBase): connection.begin() branched = connection.connect() assert branched.in_transaction() - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") nested = branched.begin() - branched.execute(users.insert(), user_id=2, user_name='user2') + branched.execute(users.insert(), user_id=2, user_name="user2") nested.rollback() assert not connection.in_transaction() 
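These branching tests distinguish a plain nested begin() (a subtransaction whose rollback cancels the whole enclosing transaction) from begin_nested(), which emits a real SAVEPOINT so that only the inner work is discarded. A reduced sketch of the SAVEPOINT case, using SQLAlchemy's usual placeholder URL rather than this suite's testing.db (any backend with SAVEPOINT support will do):

    import sqlalchemy as sa

    # assumed URL; requires a backend with SAVEPOINT support, e.g. PostgreSQL
    engine = sa.create_engine("postgresql://scott:tiger@localhost/test")
    meta = sa.MetaData()
    users_t = sa.Table(
        "query_users", meta, sa.Column("user_id", sa.Integer, primary_key=True)
    )
    meta.create_all(engine)

    conn = engine.connect()
    outer = conn.begin()                     # outermost transaction
    conn.execute(users_t.insert(), user_id=1)
    inner = conn.begin_nested()              # emits SAVEPOINT
    conn.execute(users_t.insert(), user_id=2)
    inner.rollback()                         # ROLLBACK TO SAVEPOINT; row 1 survives
    outer.commit()                           # commits the surviving row
    count = conn.scalar(sa.select([sa.func.count()]).select_from(users_t))
    assert count == 1
    conn.close()
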
eq_(connection.scalar("select count(*) from query_users"), 0) @@ -151,9 +180,9 @@ class TransactionTest(fixtures.TestBase): connection = testing.db.connect() try: branched = connection.connect() - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") try: - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") except exc.DBAPIError: pass finally: @@ -163,10 +192,10 @@ class TransactionTest(fixtures.TestBase): connection = testing.db.connect() try: branched = connection.connect() - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") nested = branched.begin() assert branched.in_transaction() - branched.execute(users.insert(), user_id=2, user_name='user2') + branched.execute(users.insert(), user_id=2, user_name="user2") nested.rollback() eq_(connection.scalar("select count(*) from query_users"), 1) @@ -177,7 +206,7 @@ class TransactionTest(fixtures.TestBase): connection = testing.db.connect() try: branched = connection.connect() - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") finally: connection.close() eq_(testing.db.scalar("select count(*) from query_users"), 1) @@ -189,9 +218,9 @@ class TransactionTest(fixtures.TestBase): trans = connection.begin() branched = connection.connect() assert branched.in_transaction() - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") nested = branched.begin_nested() - branched.execute(users.insert(), user_id=2, user_name='user2') + branched.execute(users.insert(), user_id=2, user_name="user2") nested.rollback() assert connection.in_transaction() trans.commit() @@ -206,9 +235,9 @@ class TransactionTest(fixtures.TestBase): try: branched = connection.connect() assert not branched.in_transaction() - branched.execute(users.insert(), user_id=1, user_name='user1') + branched.execute(users.insert(), user_id=1, user_name="user1") nested = branched.begin_twophase() - branched.execute(users.insert(), user_id=2, user_name='user2') + branched.execute(users.insert(), user_id=2, user_name="user2") nested.rollback() assert not connection.in_transaction() eq_(connection.scalar("select count(*) from query_users"), 1) @@ -227,23 +256,24 @@ class TransactionTest(fixtures.TestBase): "exception. 
The previous exception " r"is:.*..SQL\:.*RELEASE SAVEPOINT" ): + def go(): with connection.begin_nested() as savepoint: connection.dialect.do_release_savepoint( - connection, savepoint._savepoint) + connection, savepoint._savepoint + ) + assert_raises_message( - exc.DBAPIError, - r".*SQL\:.*ROLLBACK TO SAVEPOINT", - go + exc.DBAPIError, r".*SQL\:.*ROLLBACK TO SAVEPOINT", go ) def test_retains_through_options(self): connection = testing.db.connect() try: transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") conn2 = connection.execution_options(dummy=True) - conn2.execute(users.insert(), user_id=2, user_name='user2') + conn2.execute(users.insert(), user_id=2, user_name="user2") transaction.rollback() eq_(connection.scalar("select count(*) from query_users"), 0) finally: @@ -252,79 +282,84 @@ class TransactionTest(fixtures.TestBase): def test_nesting(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=1, user_name="user1") + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") trans2 = connection.begin() - connection.execute(users.insert(), user_id=4, user_name='user4') - connection.execute(users.insert(), user_id=5, user_name='user5') + connection.execute(users.insert(), user_id=4, user_name="user4") + connection.execute(users.insert(), user_id=5, user_name="user5") trans2.commit() transaction.rollback() - self.assert_(connection.scalar('select count(*) from ' - 'query_users') == 0) - result = connection.execute('select * from query_users') + self.assert_( + connection.scalar("select count(*) from " "query_users") == 0 + ) + result = connection.execute("select * from query_users") assert len(result.fetchall()) == 0 connection.close() def test_with_interface(self): connection = testing.db.connect() trans = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=1, user_name="user1") + connection.execute(users.insert(), user_id=2, user_name="user2") try: - connection.execute(users.insert(), user_id=2, user_name='user2.5') + connection.execute(users.insert(), user_id=2, user_name="user2.5") except Exception as e: trans.__exit__(*sys.exc_info()) assert not trans.is_active - self.assert_(connection.scalar('select count(*) from ' - 'query_users') == 0) + self.assert_( + connection.scalar("select count(*) from " "query_users") == 0 + ) trans = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") trans.__exit__(None, None, None) assert not trans.is_active - self.assert_(connection.scalar('select count(*) from ' - 'query_users') == 1) + self.assert_( + connection.scalar("select count(*) from " "query_users") == 1 + ) connection.close() def test_close(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=3, 
user_name='user3') + connection.execute(users.insert(), user_id=1, user_name="user1") + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") trans2 = connection.begin() - connection.execute(users.insert(), user_id=4, user_name='user4') - connection.execute(users.insert(), user_id=5, user_name='user5') + connection.execute(users.insert(), user_id=4, user_name="user4") + connection.execute(users.insert(), user_id=5, user_name="user5") assert connection.in_transaction() trans2.close() assert connection.in_transaction() transaction.commit() assert not connection.in_transaction() - self.assert_(connection.scalar('select count(*) from ' - 'query_users') == 5) - result = connection.execute('select * from query_users') + self.assert_( + connection.scalar("select count(*) from " "query_users") == 5 + ) + result = connection.execute("select * from query_users") assert len(result.fetchall()) == 5 connection.close() def test_close2(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=1, user_name="user1") + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") trans2 = connection.begin() - connection.execute(users.insert(), user_id=4, user_name='user4') - connection.execute(users.insert(), user_id=5, user_name='user5') + connection.execute(users.insert(), user_id=4, user_name="user4") + connection.execute(users.insert(), user_id=5, user_name="user5") assert connection.in_transaction() trans2.close() assert connection.in_transaction() transaction.close() assert not connection.in_transaction() - self.assert_(connection.scalar('select count(*) from ' - 'query_users') == 0) - result = connection.execute('select * from query_users') + self.assert_( + connection.scalar("select count(*) from " "query_users") == 0 + ) + result = connection.execute("select * from query_users") assert len(result.fetchall()) == 0 connection.close() @@ -332,74 +367,87 @@ class TransactionTest(fixtures.TestBase): def test_nested_subtransaction_rollback(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") trans2.rollback() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.commit() - eq_(connection.execute(select([users.c.user_id]). 
- order_by(users.c.user_id)).fetchall(), - [(1, ), (3, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (3,)], + ) connection.close() @testing.requires.savepoints - @testing.crashes('oracle+zxjdbc', - 'Errors out and causes subsequent tests to ' - 'deadlock') + @testing.crashes( + "oracle+zxjdbc", + "Errors out and causes subsequent tests to " "deadlock", + ) def test_nested_subtransaction_commit(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") trans2.commit() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.commit() - eq_(connection.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (2, ), (3, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (2,), (3,)], + ) connection.close() @testing.requires.savepoints def test_rollback_to_subtransaction(self): connection = testing.db.connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") trans3 = connection.begin() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") trans3.rollback() - connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=4, user_name="user4") transaction.commit() - eq_(connection.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (4, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (4,)], + ) connection.close() @testing.requires.two_phase_transactions def test_two_phase_transaction(self): connection = testing.db.connect() transaction = connection.begin_twophase() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") transaction.prepare() transaction.commit() transaction = connection.begin_twophase() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") transaction.commit() transaction.close() transaction = connection.begin_twophase() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.rollback() transaction = connection.begin_twophase() - connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=4, user_name="user4") transaction.prepare() transaction.rollback() transaction.close() - eq_(connection.execute(select([users.c.user_id]). 
- order_by(users.c.user_id)).fetchall(), - [(1, ), (2, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (2,)], + ) connection.close() # PG emergency shutdown: @@ -408,29 +456,32 @@ class TransactionTest(fixtures.TestBase): # MySQL emergency shutdown: # for arg in `mysql -u root -e "xa recover" | cut -c 8-100 | # grep sa`; do mysql -u root -e "xa rollback '$arg'"; done - @testing.crashes('mysql', 'Crashing on 5.5, not worth it') + @testing.crashes("mysql", "Crashing on 5.5, not worth it") @testing.requires.skip_mysql_on_windows @testing.requires.two_phase_transactions @testing.requires.savepoints def test_mixed_two_phase_transaction(self): connection = testing.db.connect() transaction = connection.begin_twophase() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") transaction2 = connection.begin() - connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=2, user_name="user2") transaction3 = connection.begin_nested() - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=3, user_name="user3") transaction4 = connection.begin() - connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=4, user_name="user4") transaction4.commit() transaction3.rollback() - connection.execute(users.insert(), user_id=5, user_name='user5') + connection.execute(users.insert(), user_id=5, user_name="user5") transaction2.commit() transaction.prepare() transaction.commit() - eq_(connection.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (2, ), (5, )]) + eq_( + connection.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (2,), (5,)], + ) connection.close() @testing.requires.two_phase_transactions @@ -444,45 +495,50 @@ class TransactionTest(fixtures.TestBase): connection = testing.db.connect() transaction = connection.begin_twophase() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") transaction.prepare() connection.invalidate() connection2 = testing.db.connect() eq_( - connection2.execution_options(autocommit=True). - execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), []) + connection2.execution_options(autocommit=True) + .execute(select([users.c.user_id]).order_by(users.c.user_id)) + .fetchall(), + [], + ) recoverables = connection2.recover_twophase() assert transaction.xid in recoverables connection2.commit_prepared(transaction.xid, recover=True) - eq_(connection2.execute(select([users.c.user_id]). 
- order_by(users.c.user_id)).fetchall(), - [(1, )]) + eq_( + connection2.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,)], + ) connection2.close() @testing.requires.two_phase_transactions def test_multiple_two_phase(self): conn = testing.db.connect() xa = conn.begin_twophase() - conn.execute(users.insert(), user_id=1, user_name='user1') + conn.execute(users.insert(), user_id=1, user_name="user1") xa.prepare() xa.commit() xa = conn.begin_twophase() - conn.execute(users.insert(), user_id=2, user_name='user2') + conn.execute(users.insert(), user_id=2, user_name="user2") xa.prepare() xa.rollback() xa = conn.begin_twophase() - conn.execute(users.insert(), user_id=3, user_name='user3') + conn.execute(users.insert(), user_id=3, user_name="user3") xa.rollback() xa = conn.begin_twophase() - conn.execute(users.insert(), user_id=4, user_name='user4') + conn.execute(users.insert(), user_id=4, user_name="user4") xa.prepare() xa.commit() - result = \ - conn.execute(select([users.c.user_name]). - order_by(users.c.user_id)) - eq_(result.fetchall(), [('user1', ), ('user4', )]) + result = conn.execute( + select([users.c.user_name]).order_by(users.c.user_id) + ) + eq_(result.fetchall(), [("user1",), ("user4",)]) conn.close() @testing.requires.two_phase_transactions @@ -506,14 +562,14 @@ class TransactionTest(fixtures.TestBase): rec = conn.connection._connection_record raw_dbapi_con = rec.connection xa = conn.begin_twophase() - conn.execute(users.insert(), user_id=1, user_name='user1') + conn.execute(users.insert(), user_id=1, user_name="user1") assert rec.connection is raw_dbapi_con with eng.connect() as conn: - result = \ - conn.execute(select([users.c.user_name]). - order_by(users.c.user_id)) + result = conn.execute( + select([users.c.user_name]).order_by(users.c.user_id) + ) eq_(result.fetchall(), []) @@ -639,12 +695,15 @@ class AutoRollbackTest(fixtures.TestBase): conn1 = testing.db.connect() conn2 = testing.db.connect() - users = Table('deadlock_users', metadata, - Column('user_id', INT, primary_key=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True) + users = Table( + "deadlock_users", + metadata, + Column("user_id", INT, primary_key=True), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, + ) users.create(conn1) - conn1.execute('select * from deadlock_users') + conn1.execute("select * from deadlock_users") conn1.close() # without auto-rollback in the connection pool's return() logic, @@ -663,26 +722,31 @@ class ExplicitAutoCommitTest(fixtures.TestBase): Requires PostgreSQL so that we may define a custom function which modifies the database. 
""" - __only_on__ = 'postgresql' + __only_on__ = "postgresql" @classmethod def setup_class(cls): global metadata, foo metadata = MetaData(testing.db) - foo = Table('foo', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(100))) + foo = Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) metadata.create_all() - testing.db.execute("create function insert_foo(varchar) " - "returns integer as 'insert into foo(data) " - "values ($1);select 1;' language sql") + testing.db.execute( + "create function insert_foo(varchar) " + "returns integer as 'insert into foo(data) " + "values ($1);select 1;' language sql" + ) def teardown(self): foo.delete().execute().close() @classmethod def teardown_class(cls): - testing.db.execute('drop function insert_foo(varchar)') + testing.db.execute("drop function insert_foo(varchar)") metadata.drop_all() def test_control(self): @@ -691,82 +755,101 @@ class ExplicitAutoCommitTest(fixtures.TestBase): conn1 = testing.db.connect() conn2 = testing.db.connect() - conn1.execute(select([func.insert_foo('data1')])) + conn1.execute(select([func.insert_foo("data1")])) assert conn2.execute(select([foo.c.data])).fetchall() == [] conn1.execute(text("select insert_foo('moredata')")) assert conn2.execute(select([foo.c.data])).fetchall() == [] trans = conn1.begin() trans.commit() - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('data1', ), ('moredata', )] + assert conn2.execute(select([foo.c.data])).fetchall() == [ + ("data1",), + ("moredata",), + ] conn1.close() conn2.close() def test_explicit_compiled(self): conn1 = testing.db.connect() conn2 = testing.db.connect() - conn1.execute(select([func.insert_foo('data1')]) - .execution_options(autocommit=True)) - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('data1', )] + conn1.execute( + select([func.insert_foo("data1")]).execution_options( + autocommit=True + ) + ) + assert conn2.execute(select([foo.c.data])).fetchall() == [("data1",)] conn1.close() conn2.close() def test_explicit_connection(self): conn1 = testing.db.connect() conn2 = testing.db.connect() - conn1.execution_options(autocommit=True).\ - execute(select([func.insert_foo('data1')])) - eq_(conn2.execute(select([foo.c.data])).fetchall(), [('data1',)]) + conn1.execution_options(autocommit=True).execute( + select([func.insert_foo("data1")]) + ) + eq_(conn2.execute(select([foo.c.data])).fetchall(), [("data1",)]) # connection supersedes statement - conn1.execution_options(autocommit=False).\ - execute(select([func.insert_foo('data2')]) - .execution_options(autocommit=True)) - eq_(conn2.execute(select([foo.c.data])).fetchall(), [('data1',)]) + conn1.execution_options(autocommit=False).execute( + select([func.insert_foo("data2")]).execution_options( + autocommit=True + ) + ) + eq_(conn2.execute(select([foo.c.data])).fetchall(), [("data1",)]) # ditto - conn1.execution_options(autocommit=True).\ - execute(select([func.insert_foo('data3')]) - .execution_options(autocommit=False)) - eq_(conn2.execute(select([foo.c.data])).fetchall(), - [('data1',), ('data2', ), ('data3', )]) + conn1.execution_options(autocommit=True).execute( + select([func.insert_foo("data3")]).execution_options( + autocommit=False + ) + ) + eq_( + conn2.execute(select([foo.c.data])).fetchall(), + [("data1",), ("data2",), ("data3",)], + ) conn1.close() conn2.close() def test_explicit_text(self): conn1 = testing.db.connect() conn2 = testing.db.connect() - conn1.execute(text("select insert_foo('moredata')") - 
.execution_options(autocommit=True)) - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('moredata', )] + conn1.execute( + text("select insert_foo('moredata')").execution_options( + autocommit=True + ) + ) + assert conn2.execute(select([foo.c.data])).fetchall() == [ + ("moredata",) + ] conn1.close() conn2.close() - @testing.uses_deprecated(r'autocommit on select\(\) is deprecated', - r'``autocommit\(\)`` is deprecated') + @testing.uses_deprecated( + r"autocommit on select\(\) is deprecated", + r"``autocommit\(\)`` is deprecated", + ) def test_explicit_compiled_deprecated(self): conn1 = testing.db.connect() conn2 = testing.db.connect() - conn1.execute(select([func.insert_foo('data1')], autocommit=True)) - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('data1', )] - conn1.execute(select([func.insert_foo('data2')]).autocommit()) - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('data1', ), ('data2', )] + conn1.execute(select([func.insert_foo("data1")], autocommit=True)) + assert conn2.execute(select([foo.c.data])).fetchall() == [("data1",)] + conn1.execute(select([func.insert_foo("data2")]).autocommit()) + assert conn2.execute(select([foo.c.data])).fetchall() == [ + ("data1",), + ("data2",), + ] conn1.close() conn2.close() - @testing.uses_deprecated(r'autocommit on text\(\) is deprecated') + @testing.uses_deprecated(r"autocommit on text\(\) is deprecated") def test_explicit_text_deprecated(self): conn1 = testing.db.connect() conn2 = testing.db.connect() conn1.execute(text("select insert_foo('moredata')", autocommit=True)) - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('moredata', )] + assert conn2.execute(select([foo.c.data])).fetchall() == [ + ("moredata",) + ] conn1.close() conn2.close() @@ -774,8 +857,9 @@ class ExplicitAutoCommitTest(fixtures.TestBase): conn1 = testing.db.connect() conn2 = testing.db.connect() conn1.execute(text("insert into foo (data) values ('implicitdata')")) - assert conn2.execute(select([foo.c.data])).fetchall() \ - == [('implicitdata', )] + assert conn2.execute(select([foo.c.data])).fetchall() == [ + ("implicitdata",) + ] conn1.close() conn2.close() @@ -784,20 +868,26 @@ tlengine = None class TLTransactionTest(fixtures.TestBase): - __requires__ = ('ad_hoc_engines', ) + __requires__ = ("ad_hoc_engines",) __backend__ = True @classmethod def setup_class(cls): global users, metadata, tlengine - tlengine = testing_engine(options=dict(strategy='threadlocal')) + tlengine = testing_engine(options=dict(strategy="threadlocal")) metadata = MetaData() - users = Table('query_users', metadata, - Column('user_id', - INT, - Sequence('query_users_id_seq', optional=True), - primary_key=True), - Column('user_name', VARCHAR(20)), test_needs_acid=True) + users = Table( + "query_users", + metadata, + Column( + "user_id", + INT, + Sequence("query_users_id_seq", optional=True), + primary_key=True, + ), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, + ) metadata.create_all(tlengine) def teardown(self): @@ -815,8 +905,9 @@ class TLTransactionTest(fixtures.TestBase): tlengine.close() - @testing.crashes('oracle', - 'TNS error of unknown origin occurs on the buildbot.') + @testing.crashes( + "oracle", "TNS error of unknown origin occurs on the buildbot." 
+ ) def test_rollback_no_trans(self): tlengine = testing_engine(options=dict(strategy="threadlocal")) @@ -866,17 +957,17 @@ class TLTransactionTest(fixtures.TestBase): def test_transaction_close(self): c = tlengine.contextual_connect() t = c.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") t2 = c.begin() - tlengine.execute(users.insert(), user_id=3, user_name='user3') - tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.execute(users.insert(), user_id=3, user_name="user3") + tlengine.execute(users.insert(), user_id=4, user_name="user4") t2.close() - result = c.execute('select * from query_users') + result = c.execute("select * from query_users") assert len(result.fetchall()) == 4 t.close() external_connection = tlengine.connect() - result = external_connection.execute('select * from query_users') + result = external_connection.execute("select * from query_users") try: assert len(result.fetchall()) == 0 finally: @@ -887,12 +978,12 @@ class TLTransactionTest(fixtures.TestBase): """test a basic rollback""" tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.rollback() external_connection = tlengine.connect() - result = external_connection.execute('select * from query_users') + result = external_connection.execute("select * from query_users") try: assert len(result.fetchall()) == 0 finally: @@ -902,12 +993,12 @@ class TLTransactionTest(fixtures.TestBase): """test a basic commit""" tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.commit() external_connection = tlengine.connect() - result = external_connection.execute('select * from query_users') + result = external_connection.execute("select * from query_users") try: assert len(result.fetchall()) == 3 finally: @@ -915,43 +1006,42 @@ class TLTransactionTest(fixtures.TestBase): def test_with_interface(self): trans = tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") trans.commit() trans = tlengine.begin() - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=3, user_name="user3") trans.__exit__(Exception, "fake", None) trans = tlengine.begin() - tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.execute(users.insert(), user_id=4, user_name="user4") trans.__exit__(None, None, None) eq_( - tlengine.execute(users.select().order_by(users.c.user_id)) - .fetchall(), - [ - (1, 'user1'), - (2, 'user2'), 
- (4, 'user4'), - ] + tlengine.execute( + users.select().order_by(users.c.user_id) + ).fetchall(), + [(1, "user1"), (2, "user2"), (4, "user4")], ) def test_commits(self): connection = tlengine.connect() - assert connection.execute('select count(*) from query_users' - ).scalar() == 0 + assert ( + connection.execute("select count(*) from query_users").scalar() + == 0 + ) connection.close() connection = tlengine.contextual_connect() transaction = connection.begin() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") transaction.commit() transaction = connection.begin() - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") transaction.commit() transaction = connection.begin() - result = connection.execute('select * from query_users') + result = connection.execute("select * from query_users") rows = result.fetchall() - assert len(rows) == 3, 'expected 3 got %d' % len(rows) + assert len(rows) == 3, "expected 3 got %d" % len(rows) transaction.commit() connection.close() @@ -962,12 +1052,12 @@ class TLTransactionTest(fixtures.TestBase): conn = tlengine.contextual_connect() trans = conn.begin() - conn.execute(users.insert(), user_id=1, user_name='user1') - conn.execute(users.insert(), user_id=2, user_name='user2') - conn.execute(users.insert(), user_id=3, user_name='user3') + conn.execute(users.insert(), user_id=1, user_name="user1") + conn.execute(users.insert(), user_id=2, user_name="user2") + conn.execute(users.insert(), user_id=3, user_name="user3") trans.rollback() external_connection = tlengine.connect() - result = external_connection.execute('select * from query_users') + result = external_connection.execute("select * from query_users") try: assert len(result.fetchall()) == 0 finally: @@ -982,12 +1072,12 @@ class TLTransactionTest(fixtures.TestBase): conn = tlengine.contextual_connect() conn2 = tlengine.contextual_connect() trans = conn2.begin() - conn.execute(users.insert(), user_id=1, user_name='user1') - conn.execute(users.insert(), user_id=2, user_name='user2') - conn.execute(users.insert(), user_id=3, user_name='user3') + conn.execute(users.insert(), user_id=1, user_name="user1") + conn.execute(users.insert(), user_id=2, user_name="user2") + conn.execute(users.insert(), user_id=3, user_name="user3") trans.rollback() external_connection = tlengine.connect() - result = external_connection.execute('select * from query_users') + result = external_connection.execute("select * from query_users") try: assert len(result.fetchall()) == 0 finally: @@ -998,12 +1088,12 @@ class TLTransactionTest(fixtures.TestBase): def test_commit_off_connection(self): conn = tlengine.contextual_connect() trans = conn.begin() - conn.execute(users.insert(), user_id=1, user_name='user1') - conn.execute(users.insert(), user_id=2, user_name='user2') - conn.execute(users.insert(), user_id=3, user_name='user3') + conn.execute(users.insert(), user_id=1, user_name="user1") + conn.execute(users.insert(), user_id=2, user_name="user2") + conn.execute(users.insert(), user_id=3, user_name="user3") trans.commit() external_connection = tlengine.connect() - result = external_connection.execute('select * from query_users') + result = external_connection.execute("select * from query_users") try: assert len(result.fetchall()) == 3 finally:
@@ -1014,20 +1104,24 @@ class TLTransactionTest(fixtures.TestBase): """tests nesting of transactions, rollback at the end""" external_connection = tlengine.connect() - self.assert_(external_connection.connection - is not tlengine.contextual_connect().connection) + self.assert_( + external_connection.connection + is not tlengine.contextual_connect().connection + ) tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.begin() - tlengine.execute(users.insert(), user_id=4, user_name='user4') - tlengine.execute(users.insert(), user_id=5, user_name='user5') + tlengine.execute(users.insert(), user_id=4, user_name="user4") + tlengine.execute(users.insert(), user_id=5, user_name="user5") tlengine.commit() tlengine.rollback() try: - self.assert_(external_connection.scalar( - 'select count(*) from query_users') == 0) + self.assert_( + external_connection.scalar("select count(*) from query_users") + == 0 + ) finally: external_connection.close() @@ -1035,20 +1129,24 @@ class TLTransactionTest(fixtures.TestBase): """tests nesting of transactions, commit at the end.""" external_connection = tlengine.connect() - self.assert_(external_connection.connection - is not tlengine.contextual_connect().connection) + self.assert_( + external_connection.connection + is not tlengine.contextual_connect().connection + ) tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.begin() - tlengine.execute(users.insert(), user_id=4, user_name='user4') - tlengine.execute(users.insert(), user_id=5, user_name='user5') + tlengine.execute(users.insert(), user_id=4, user_name="user4") + tlengine.execute(users.insert(), user_id=5, user_name="user5") tlengine.commit() tlengine.commit() try: - self.assert_(external_connection.scalar( - 'select count(*) from query_users') == 5) + self.assert_( + external_connection.scalar("select count(*) from query_users") + == 5 + ) finally: external_connection.close() @@ -1057,29 +1155,33 @@ class TLTransactionTest(fixtures.TestBase): inside of transactions off the connection from the TLEngine""" external_connection = tlengine.connect() - self.assert_(external_connection.connection - is not tlengine.contextual_connect().connection) + self.assert_( + external_connection.connection + is not tlengine.contextual_connect().connection + ) conn = tlengine.contextual_connect() trans = conn.begin() trans2 = conn.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') - tlengine.execute(users.insert(), user_id=2, user_name='user2') - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=1, user_name="user1") + tlengine.execute(users.insert(), user_id=2, user_name="user2") + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.begin() - tlengine.execute(users.insert(), user_id=4, user_name='user4') +
tlengine.execute(users.insert(), user_id=4, user_name="user4") tlengine.begin() - tlengine.execute(users.insert(), user_id=5, user_name='user5') - tlengine.execute(users.insert(), user_id=6, user_name='user6') - tlengine.execute(users.insert(), user_id=7, user_name='user7') + tlengine.execute(users.insert(), user_id=5, user_name="user5") + tlengine.execute(users.insert(), user_id=6, user_name="user6") + tlengine.execute(users.insert(), user_id=7, user_name="user7") tlengine.commit() - tlengine.execute(users.insert(), user_id=8, user_name='user8') + tlengine.execute(users.insert(), user_id=8, user_name="user8") tlengine.commit() trans2.commit() trans.rollback() conn.close() try: - self.assert_(external_connection.scalar( - 'select count(*) from query_users') == 0) + self.assert_( + external_connection.scalar("select count(*) from query_users") + == 0 + ) finally: external_connection.close() @@ -1088,76 +1190,90 @@ class TLTransactionTest(fixtures.TestBase): TLEngine inside of transactions off the TLEngine directly.""" external_connection = tlengine.connect() - self.assert_(external_connection.connection - is not tlengine.contextual_connect().connection) + self.assert_( + external_connection.connection + is not tlengine.contextual_connect().connection + ) tlengine.begin() connection = tlengine.contextual_connect() - connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=1, user_name="user1") tlengine.begin() - connection.execute(users.insert(), user_id=2, user_name='user2') - connection.execute(users.insert(), user_id=3, user_name='user3') + connection.execute(users.insert(), user_id=2, user_name="user2") + connection.execute(users.insert(), user_id=3, user_name="user3") trans = connection.begin() - connection.execute(users.insert(), user_id=4, user_name='user4') - connection.execute(users.insert(), user_id=5, user_name='user5') + connection.execute(users.insert(), user_id=4, user_name="user4") + connection.execute(users.insert(), user_id=5, user_name="user5") trans.commit() tlengine.commit() tlengine.rollback() connection.close() try: - self.assert_(external_connection.scalar( - 'select count(*) from query_users') == 0) + self.assert_( + external_connection.scalar("select count(*) from query_users") + == 0 + ) finally: external_connection.close() @testing.requires.savepoints def test_nested_subtransaction_rollback(self): tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=1, user_name="user1") tlengine.begin_nested() - tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=2, user_name="user2") tlengine.rollback() - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.commit() tlengine.close() - eq_(tlengine.execute(select([users.c.user_id]). 
- order_by(users.c.user_id)).fetchall(), - [(1, ), (3, )]) + eq_( + tlengine.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (3,)], + ) tlengine.close() @testing.requires.savepoints - @testing.crashes('oracle+zxjdbc', - 'Errors out and causes subsequent tests to ' - 'deadlock') + @testing.crashes( + "oracle+zxjdbc", + "Errors out and causes subsequent tests to " "deadlock", + ) def test_nested_subtransaction_commit(self): tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=1, user_name="user1") tlengine.begin_nested() - tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=2, user_name="user2") tlengine.commit() - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.commit() tlengine.close() - eq_(tlengine.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (2, ), (3, )]) + eq_( + tlengine.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (2,), (3,)], + ) tlengine.close() @testing.requires.savepoints def test_rollback_to_subtransaction(self): tlengine.begin() - tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=1, user_name="user1") tlengine.begin_nested() - tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=2, user_name="user2") tlengine.begin() - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.rollback() tlengine.rollback() - tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.execute(users.insert(), user_id=4, user_name="user4") tlengine.commit() tlengine.close() - eq_(tlengine.execute(select([users.c.user_id]). 
- order_by(users.c.user_id)).fetchall(), - [(1, ), (4, )]) + eq_( + tlengine.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (4,)], + ) tlengine.close() def test_connections(self): @@ -1194,10 +1310,11 @@ class TLTransactionTest(fixtures.TestBase): assert r2.connection.closed assert tlengine.closed - @testing.crashes('oracle+cx_oracle', - 'intermittent failures on the buildbot') + @testing.crashes( + "oracle+cx_oracle", "intermittent failures on the buildbot" + ) def test_dispose(self): - eng = testing_engine(options=dict(strategy='threadlocal')) + eng = testing_engine(options=dict(strategy="threadlocal")) result = eng.execute(select([1])) eng.dispose() eng.execute(select([1])) @@ -1205,48 +1322,51 @@ class TLTransactionTest(fixtures.TestBase): @testing.requires.two_phase_transactions def test_two_phase_transaction(self): tlengine.begin_twophase() - tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=1, user_name="user1") tlengine.prepare() tlengine.commit() tlengine.begin_twophase() - tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=2, user_name="user2") tlengine.commit() tlengine.begin_twophase() - tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=3, user_name="user3") tlengine.rollback() tlengine.begin_twophase() - tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.execute(users.insert(), user_id=4, user_name="user4") tlengine.prepare() tlengine.rollback() - eq_(tlengine.execute(select([users.c.user_id]). - order_by(users.c.user_id)).fetchall(), - [(1, ), (2, )]) + eq_( + tlengine.execute( + select([users.c.user_id]).order_by(users.c.user_id) + ).fetchall(), + [(1,), (2,)], + ) class IsolationLevelTest(fixtures.TestBase): - __requires__ = ('isolation_level', 'ad_hoc_engines') + __requires__ = ("isolation_level", "ad_hoc_engines") __backend__ = True def _default_isolation_level(self): - if testing.against('sqlite'): - return 'SERIALIZABLE' - elif testing.against('postgresql'): - return 'READ COMMITTED' - elif testing.against('mysql'): + if testing.against("sqlite"): + return "SERIALIZABLE" + elif testing.against("postgresql"): + return "READ COMMITTED" + elif testing.against("mysql"): return "REPEATABLE READ" - elif testing.against('mssql'): + elif testing.against("mssql"): return "READ COMMITTED" else: assert False, "default isolation level not known" def _non_default_isolation_level(self): - if testing.against('sqlite'): - return 'READ UNCOMMITTED' - elif testing.against('postgresql'): - return 'SERIALIZABLE' - elif testing.against('mysql'): + if testing.against("sqlite"): + return "READ UNCOMMITTED" + elif testing.against("postgresql"): + return "SERIALIZABLE" + elif testing.against("mysql"): return "SERIALIZABLE" - elif testing.against('mssql'): + elif testing.against("mssql"): return "SERIALIZABLE" else: assert False, "non default isolation level not known" @@ -1255,37 +1375,29 @@ class IsolationLevelTest(fixtures.TestBase): eng = testing_engine() isolation_level = eng.dialect.get_isolation_level( - eng.connect().connection) + eng.connect().connection + ) level = self._non_default_isolation_level() ne_(isolation_level, level) eng = testing_engine(options=dict(isolation_level=level)) - eq_( - eng.dialect.get_isolation_level( - eng.connect().connection), - level - ) + eq_(eng.dialect.get_isolation_level(eng.connect().connection), level) # check that it stays 
conn = eng.connect() - eq_( - eng.dialect.get_isolation_level(conn.connection), - level - ) + eq_(eng.dialect.get_isolation_level(conn.connection), level) conn.close() conn = eng.connect() - eq_( - eng.dialect.get_isolation_level(conn.connection), - level - ) + eq_(eng.dialect.get_isolation_level(conn.connection), level) conn.close() def test_default_level(self): eng = testing_engine(options=dict()) isolation_level = eng.dialect.get_isolation_level( - eng.connect().connection) + eng.connect().connection + ) eq_(isolation_level, self._default_isolation_level()) def test_reset_level(self): @@ -1293,7 +1405,7 @@ class IsolationLevelTest(fixtures.TestBase): conn = eng.connect() eq_( eng.dialect.get_isolation_level(conn.connection), - self._default_isolation_level() + self._default_isolation_level(), ) eng.dialect.set_isolation_level( @@ -1301,50 +1413,60 @@ class IsolationLevelTest(fixtures.TestBase): ) eq_( eng.dialect.get_isolation_level(conn.connection), - self._non_default_isolation_level() + self._non_default_isolation_level(), ) eng.dialect.reset_isolation_level(conn.connection) eq_( eng.dialect.get_isolation_level(conn.connection), - self._default_isolation_level() + self._default_isolation_level(), ) conn.close() def test_reset_level_with_setting(self): eng = testing_engine( - options=dict( - isolation_level=self._non_default_isolation_level())) + options=dict(isolation_level=self._non_default_isolation_level()) + ) conn = eng.connect() - eq_(eng.dialect.get_isolation_level(conn.connection), - self._non_default_isolation_level()) + eq_( + eng.dialect.get_isolation_level(conn.connection), + self._non_default_isolation_level(), + ) eng.dialect.set_isolation_level( - conn.connection, - self._default_isolation_level()) - eq_(eng.dialect.get_isolation_level(conn.connection), - self._default_isolation_level()) + conn.connection, self._default_isolation_level() + ) + eq_( + eng.dialect.get_isolation_level(conn.connection), + self._default_isolation_level(), + ) eng.dialect.reset_isolation_level(conn.connection) - eq_(eng.dialect.get_isolation_level(conn.connection), - self._non_default_isolation_level()) + eq_( + eng.dialect.get_isolation_level(conn.connection), + self._non_default_isolation_level(), + ) conn.close() def test_invalid_level(self): - eng = testing_engine(options=dict(isolation_level='FOO')) + eng = testing_engine(options=dict(isolation_level="FOO")) assert_raises_message( exc.ArgumentError,
"Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - ("FOO", - eng.dialect.name, ", ".join(eng.dialect._isolation_lookup)), - eng.connect + "Valid isolation levels for %s are %s" + % ( + "FOO", + eng.dialect.name, + ", ".join(eng.dialect._isolation_lookup), + ), + eng.connect, + ) def test_connection_invalidated(self): eng = testing_engine() conn = eng.connect() c2 = conn.execution_options( - isolation_level=self._non_default_isolation_level()) + isolation_level=self._non_default_isolation_level() + ) c2.invalidate() c2.connection @@ -1354,10 +1476,10 @@ class IsolationLevelTest(fixtures.TestBase): def test_per_connection(self): from sqlalchemy.pool import QueuePool + eng = testing_engine( - options=dict( - poolclass=QueuePool, - pool_size=2, max_overflow=0)) + options=dict(poolclass=QueuePool, pool_size=2, max_overflow=0) + ) c1 = eng.connect() c1 = c1.execution_options( @@ -1366,23 +1488,23 @@ class IsolationLevelTest(fixtures.TestBase): c2 = eng.connect() eq_( eng.dialect.get_isolation_level(c1.connection), - self._non_default_isolation_level() + self._non_default_isolation_level(), ) eq_( eng.dialect.get_isolation_level(c2.connection), - self._default_isolation_level() + self._default_isolation_level(), ) c1.close() c2.close() c3 = eng.connect() eq_( eng.dialect.get_isolation_level(c3.connection), - self._default_isolation_level() + self._default_isolation_level(), ) c4 = eng.connect() eq_( eng.dialect.get_isolation_level(c4.connection), - self._default_isolation_level() + self._default_isolation_level(), ) c3.close() @@ -1392,10 +1514,10 @@ class IsolationLevelTest(fixtures.TestBase): eng = testing_engine() c1 = eng.connect() with expect_warnings( - "Connection is already established with a Transaction; " - "setting isolation_level may implicitly rollback or commit " - "the existing transaction, or have no effect until next " - "transaction" + "Connection is already established with a Transaction; " + "setting isolation_level may implicitly rollback or commit " + "the existing transaction, or have no effect until next " + "transaction" ): with c1.begin(): c1 = c1.execution_options( @@ -1404,12 +1526,12 @@ class IsolationLevelTest(fixtures.TestBase): eq_( eng.dialect.get_isolation_level(c1.connection), - self._non_default_isolation_level() + self._non_default_isolation_level(), ) # stays outside of transaction eq_( eng.dialect.get_isolation_level(c1.connection), - self._non_default_isolation_level() + self._non_default_isolation_level(), ) def test_per_statement_bzzt(self): @@ -1420,7 +1542,7 @@ class IsolationLevelTest(fixtures.TestBase): r"per-engine using the isolation_level " r"argument to create_engine\(\).", select([1]).execution_options, - isolation_level=self._non_default_isolation_level() + isolation_level=self._non_default_isolation_level(), ) def test_per_engine(self): @@ -1428,32 +1550,30 @@ class IsolationLevelTest(fixtures.TestBase): eng = create_engine( testing.db.url, execution_options={ - 'isolation_level': - self._non_default_isolation_level()} + "isolation_level": self._non_default_isolation_level() + }, ) conn = eng.connect() eq_( eng.dialect.get_isolation_level(conn.connection), - self._non_default_isolation_level() + self._non_default_isolation_level(), ) def test_isolation_level_accessors_connection_default(self): - eng = create_engine( - testing.db.url - ) + eng = create_engine(testing.db.url) with eng.connect() as conn: eq_(conn.default_isolation_level, self._default_isolation_level()) with eng.connect() as conn: eq_(conn.get_isolation_level(),
self._default_isolation_level()) def test_isolation_level_accessors_connection_option_modified(self): - eng = create_engine( - testing.db.url - ) + eng = create_engine(testing.db.url) with eng.connect() as conn: c2 = conn.execution_options( - isolation_level=self._non_default_isolation_level()) + isolation_level=self._non_default_isolation_level() + ) eq_(conn.default_isolation_level, self._default_isolation_level()) - eq_(conn.get_isolation_level(), - self._non_default_isolation_level()) + eq_( + conn.get_isolation_level(), self._non_default_isolation_level() + ) eq_(c2.get_isolation_level(), self._non_default_isolation_level()) diff --git a/test/ext/declarative/test_basic.py b/test/ext/declarative/test_basic.py index f45421ad9e..9de7f5ab9e 100644 --- a/test/ext/declarative/test_basic.py +++ b/test/ext/declarative/test_basic.py @@ -1,18 +1,41 @@ - -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, expect_warnings, is_ +from sqlalchemy.testing import ( + eq_, + assert_raises, + assert_raises_message, + expect_warnings, + is_, +) from sqlalchemy.testing import assertions from sqlalchemy.ext import declarative as decl from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy import exc import sqlalchemy as sa from sqlalchemy import testing, util -from sqlalchemy import MetaData, Integer, String, ForeignKey, \ - ForeignKeyConstraint, Index, UniqueConstraint, CheckConstraint +from sqlalchemy import ( + MetaData, + Integer, + String, + ForeignKey, + ForeignKeyConstraint, + Index, + UniqueConstraint, + CheckConstraint, +) from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import relationship, create_session, class_mapper, \ - joinedload, configure_mappers, backref, clear_mappers, \ - column_property, composite, Session, properties, deferred +from sqlalchemy.orm import ( + relationship, + create_session, + class_mapper, + joinedload, + configure_mappers, + backref, + clear_mappers, + column_property, + composite, + Session, + properties, + deferred, +) from sqlalchemy.util import with_metaclass from sqlalchemy.ext.declarative import declared_attr, synonym_for from sqlalchemy.testing import fixtures, mock @@ -26,10 +49,12 @@ Base = None User = Address = None -class DeclarativeTestBase(fixtures.TestBase, - testing.AssertsExecutionResults, - testing.AssertsCompiledSQL): - __dialect__ = 'default' +class DeclarativeTestBase( + fixtures.TestBase, + testing.AssertsExecutionResults, + testing.AssertsCompiledSQL, +): + __dialect__ = "default" def setup(self): global Base @@ -42,119 +67,132 @@ class DeclarativeTestBase(fixtures.TestBase, class DeclarativeTest(DeclarativeTestBase): - def test_basic(self): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) addresses = relationship("Address", backref="user") class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column(String(50), key='_email') - user_id = Column('user_id', Integer, ForeignKey('users.id'), - key='_user_id') + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column(String(50), key="_email") + user_id = Column( + "user_id", Integer, 
ForeignKey("users.id"), key="_user_id" + ) Base.metadata.create_all() - eq_(Address.__table__.c['id'].name, 'id') - eq_(Address.__table__.c['_email'].name, 'email') - eq_(Address.__table__.c['_user_id'].name, 'user_id') + eq_(Address.__table__.c["id"].name, "id") + eq_(Address.__table__.c["_email"].name, "email") + eq_(Address.__table__.c["_user_id"].name, "user_id") - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) - a1 = sess.query(Address).filter(Address.email == 'two').one() - eq_(a1, Address(email='two')) - eq_(a1.user, User(name='u1')) + a1 = sess.query(Address).filter(Address.email == "two").one() + eq_(a1, Address(email="two")) + eq_(a1.user, User(name="u1")) def test_unicode_string_resolve(self): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) addresses = relationship(util.u("Address"), backref="user") class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column(String(50), key='_email') - user_id = Column('user_id', Integer, ForeignKey('users.id'), - key='_user_id') + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column(String(50), key="_email") + user_id = Column( + "user_id", Integer, ForeignKey("users.id"), key="_user_id" + ) assert User.addresses.property.mapper.class_ is Address def test_unicode_string_resolve_backref(self): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column(String(50), key='_email') - user_id = Column('user_id', Integer, ForeignKey('users.id'), - key='_user_id') + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column(String(50), key="_email") + user_id = Column( + "user_id", Integer, ForeignKey("users.id"), key="_user_id" + ) user = relationship( - User, - backref=backref("addresses", - order_by=util.u("Address.email"))) + User, + backref=backref("addresses", order_by=util.u("Address.email")), + ) assert Address.user.property.mapper.class_ is User def test_no_table(self): def go(): class User(Base): - id = Column('id', Integer, primary_key=True) + id = Column("id", Integer, primary_key=True) - assert_raises_message(sa.exc.InvalidRequestError, - 'does not have a __table__', go) + assert_raises_message( + sa.exc.InvalidRequestError, "does not have a __table__", go + ) 
def test_table_args_empty_dict(self): - class MyModel(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) __table_args__ = {} def test_table_args_empty_tuple(self): - class MyModel(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) __table_args__ = () def test_cant_add_columns(self): t = Table( - 't', Base.metadata, - Column('id', Integer, primary_key=True), - Column('data', String)) + "t", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + ) def go(): class User(Base): @@ -163,9 +201,11 @@ class DeclarativeTest(DeclarativeTestBase): # can't specify new columns not already in the table - assert_raises_message(sa.exc.ArgumentError, - "Can't add additional column 'foo' when " - "specifying __table__", go) + assert_raises_message( + sa.exc.ArgumentError, + "Can't add additional column 'foo' when " "specifying __table__", + go, + ) # regular re-mapping works tho - class Bar(Base): __table__ = t some_data = t.c.data - assert class_mapper(Bar).get_property('some_data').columns[0] \ - is t.c.data + assert ( + class_mapper(Bar).get_property("some_data").columns[0] is t.c.data + ) def test_lower_case_c_column_warning(self): with assertions.expect_warnings( r"Attribute 'x' on class <class [...] - __table_args__ = ( - CheckConstraint(cprop > sa.func.foo()), - ) + __table_args__ = (CheckConstraint(cprop > sa.func.foo()),) + ck = [ - c for c in Bar.__table__.constraints - if isinstance(c, CheckConstraint)][0] + c + for c in Bar.__table__.constraints + if isinstance(c, CheckConstraint) + ][0] is_(ck.columns.cprop, Bar.__table__.c.cprop) if testing.requires.python3.enabled: # test the existing failure case in case something changes def go(): class Bat(Base): - __tablename__ = 'bat' + __tablename__ = "bat" id = Column(Integer, primary_key=True) cprop = deferred(Column(Integer)) @@ -306,43 +354,44 @@ class DeclarativeTest(DeclarativeTestBase): # "cprop > 5" because the column property isn't # a full blown column - __table_args__ = ( - CheckConstraint(cprop > 5), - ) + __table_args__ = (CheckConstraint(cprop > 5),) + assert_raises(TypeError, go) def test_relationship_level_msg_for_invalid_callable(self): class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(Integer, ForeignKey('a.id')) - a = relationship('a') + a_id = Column(Integer, ForeignKey("a.id")) + a = relationship("a") + assert_raises_message( sa.exc.ArgumentError, "relationship 'a' expects a class or a mapper " "argument .received: .*Table", - configure_mappers + configure_mappers, ) def test_relationship_level_msg_for_invalid_object(self): class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(Integer, ForeignKey('a.id')) + a_id = Column(Integer, ForeignKey("a.id")) a = relationship(A.__table__) + assert_raises_message( sa.exc.ArgumentError, "relationship 'a' expects a class or a mapper " "argument .received: .*Table", - configure_mappers + configure_mappers, ) def test_difficult_class(self): @@ -350,10 +399,9 @@ class DeclarativeTest(DeclarativeTestBase): # metaclass to mock the way zope.interface breaks getattr() class BrokenMeta(type): - def __getattribute__(self, attr): - if attr == 'xyzzy': - raise AttributeError('xyzzy')
+ if attr == "xyzzy": + raise AttributeError("xyzzy") else: return object.__getattribute__(self, attr) @@ -364,23 +412,24 @@ class DeclarativeTest(DeclarativeTestBase): # _as_declarative() inspects obj.__class__.__bases__ class User(BrokenParent, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) decl.instrument_declarative(User, {}, Base.metadata) def test_reserved_identifiers(self): def go1(): class User1(Base): - __tablename__ = 'user1' + __tablename__ = "user1" id = Column(Integer, primary_key=True) metadata = Column(Integer) def go2(): class User2(Base): - __tablename__ = 'user2' + __tablename__ = "user2" id = Column(Integer, primary_key=True) metadata = relationship("Address") @@ -390,20 +439,20 @@ class DeclarativeTest(DeclarativeTestBase): "Attribute name 'metadata' is reserved " "for the MetaData instance when using a " "declarative base class.", - go + go, ) def test_undefer_column_name(self): # TODO: not sure if there was an explicit # test for this elsewhere foo = Column(Integer) - eq_(str(foo), '(no name)') + eq_(str(foo), "(no name)") eq_(foo.key, None) eq_(foo.name, None) - decl.base._undefer_column_name('foo', foo) - eq_(str(foo), 'foo') - eq_(foo.key, 'foo') - eq_(foo.name, 'foo') + decl.base._undefer_column_name("foo", foo) + eq_(str(foo), "foo") + eq_(foo.key, "foo") + eq_(foo.name, "foo") def test_recompile_on_othermapper(self): """declarative version of the same test in mappers.py""" @@ -411,19 +460,20 @@ class DeclarativeTest(DeclarativeTestBase): from sqlalchemy.orm import mapperlib class User(Base): - __tablename__ = 'users' + __tablename__ = "users" - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) class Address(Base): - __tablename__ = 'addresses' + __tablename__ = "addresses" - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - user = relationship("User", primaryjoin=user_id == User.id, - backref="addresses") + id = Column("id", Integer, primary_key=True) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) + user = relationship( + "User", primaryjoin=user_id == User.id, backref="addresses" + ) assert mapperlib.Mapper._new_mappers is True u = User() # noqa @@ -433,81 +483,99 @@ class DeclarativeTest(DeclarativeTestBase): def test_string_dependency_resolution(self): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) addresses = relationship( - 'Address', - order_by='desc(Address.email)', - primaryjoin='User.id==Address.user_id', - foreign_keys='[Address.user_id]', - backref=backref('user', - primaryjoin='User.id==Address.user_id', - foreign_keys='[Address.user_id]')) + "Address", + order_by="desc(Address.email)", + primaryjoin="User.id==Address.user_id", + foreign_keys="[Address.user_id]", + backref=backref( + "user", + primaryjoin="User.id==Address.user_id", + foreign_keys="[Address.user_id]", + ), + ) class Address(Base, 
fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "addresses" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) email = Column(String(50)) user_id = Column(Integer) # note no foreign key Base.metadata.create_all() sess = create_session() u1 = User( - name='ed', addresses=[ - Address(email='abc'), - Address(email='def'), Address(email='xyz')]) + name="ed", + addresses=[ + Address(email="abc"), + Address(email="def"), + Address(email="xyz"), + ], + ) sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed', addresses=[ - Address(email='xyz'), - Address(email='def'), Address(email='abc')])) + eq_( + sess.query(User).filter(User.name == "ed").one(), + User( + name="ed", + addresses=[ + Address(email="xyz"), + Address(email="def"), + Address(email="abc"), + ], + ), + ) class Foo(Base, fixtures.ComparableEntity): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) - rel = relationship('User', - primaryjoin='User.addresses==Foo.id') + rel = relationship("User", primaryjoin="User.addresses==Foo.id") - assert_raises_message(exc.InvalidRequestError, - "'addresses' is not an instance of " - "ColumnProperty", configure_mappers) + assert_raises_message( + exc.InvalidRequestError, + "'addresses' is not an instance of " "ColumnProperty", + configure_mappers, + ) def test_string_dependency_resolution_synonym(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) Base.metadata.create_all() sess = create_session() - u1 = User(name='ed') + u1 = User(name="ed") sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed')) + eq_(sess.query(User).filter(User.name == "ed").one(), User(name="ed")) class Foo(Base, fixtures.ComparableEntity): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) _user_id = Column(Integer) - rel = relationship('User', - uselist=False, - foreign_keys=[User.id], - primaryjoin='Foo.user_id==User.id') + rel = relationship( + "User", + uselist=False, + foreign_keys=[User.id], + primaryjoin="Foo.user_id==User.id", + ) - @synonym_for('_user_id') + @synonym_for("_user_id") @property def user_id(self): return self._user_id @@ -520,18 +588,18 @@ class DeclarativeTest(DeclarativeTestBase): from sqlalchemy.ext.hybrid import hybrid_property class User(Base): - __tablename__ = 'user' + __tablename__ = "user" id = Column(Integer, primary_key=True) firstname = Column(String(50)) lastname = Column(String(50)) - game_id = Column(Integer, ForeignKey('game.id')) + game_id = Column(Integer, ForeignKey("game.id")) @hybrid_property def fullname(self): return self.firstname + " " + self.lastname class Game(Base): - __tablename__ = 'game' + __tablename__ = "game" id = Column(Integer, primary_key=True) name = Column(String(50)) users = relationship("User", order_by="User.fullname") @@ -543,17 +611,17 @@ class DeclarativeTest(DeclarativeTestBase): "user_1.id AS user_1_id, user_1.firstname AS user_1_firstname, " "user_1.lastname AS user_1_lastname, " "user_1.game_id AS user_1_game_id " - "FROM game LEFT OUTER JOIN \"user\" AS user_1 ON game.id = " + 'FROM game LEFT OUTER JOIN 
"user" AS user_1 ON game.id = ' "user_1.game_id ORDER BY " - "user_1.firstname || :firstname_1 || user_1.lastname" + "user_1.firstname || :firstname_1 || user_1.lastname", ) def test_string_dependency_resolution_asselectable(self): class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) d = relationship( "D", @@ -564,22 +632,23 @@ class DeclarativeTest(DeclarativeTestBase): ) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - d_id = Column(ForeignKey('d.id')) + d_id = Column(ForeignKey("d.id")) class C(Base): - __tablename__ = 'c' + __tablename__ = "c" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) - d_id = Column(ForeignKey('d.id')) + a_id = Column(ForeignKey("a.id")) + d_id = Column(ForeignKey("d.id")) class D(Base): - __tablename__ = 'd' + __tablename__ = "d" id = Column(Integer, primary_key=True) + s = Session() self.assert_compile( s.query(A).join(A.d), @@ -590,46 +659,47 @@ class DeclarativeTest(DeclarativeTestBase): ) def test_string_dependency_resolution_no_table(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) class Bar(Base, fixtures.ComparableEntity): - __tablename__ = 'bar' + __tablename__ = "bar" id = Column(Integer, primary_key=True) - rel = relationship('User', - primaryjoin='User.id==Bar.__table__.id') + rel = relationship("User", primaryjoin="User.id==Bar.__table__.id") - assert_raises_message(exc.InvalidRequestError, - "does not have a mapped column named " - "'__table__'", configure_mappers) + assert_raises_message( + exc.InvalidRequestError, + "does not have a mapped column named " "'__table__'", + configure_mappers, + ) def test_string_w_pj_annotations(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "addresses" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) email = Column(String(50)) user_id = Column(Integer) user = relationship( - "User", - primaryjoin="remote(User.id)==foreign(Address.user_id)" + "User", primaryjoin="remote(User.id)==foreign(Address.user_id)" ) eq_( Address.user.property._join_condition.local_remote_pairs, - [(Address.__table__.c.user_id, User.__table__.c.id)] + [(Address.__table__.c.user_id, User.__table__.c.id)], ) def test_string_dependency_resolution_no_magic(self): @@ -637,148 +707,166 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) addresses = relationship( - 'Address', - primaryjoin='User.id==Address.user_id.prop.columns[0]') + "Address", + primaryjoin="User.id==Address.user_id.prop.columns[0]", + ) class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" id = Column(Integer, primary_key=True) - user_id = 
Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) configure_mappers() - eq_(str(User.addresses.prop.primaryjoin), - 'users.id = addresses.user_id') + eq_( + str(User.addresses.prop.primaryjoin), + "users.id = addresses.user_id", + ) def test_string_dependency_resolution_module_qualified(self): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) addresses = relationship( - '%s.Address' % __name__, - primaryjoin='%s.User.id==%s.Address.user_id.prop.columns[0]' - % (__name__, __name__)) + "%s.Address" % __name__, + primaryjoin="%s.User.id==%s.Address.user_id.prop.columns[0]" + % (__name__, __name__), + ) class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) configure_mappers() - eq_(str(User.addresses.prop.primaryjoin), - 'users.id = addresses.user_id') + eq_( + str(User.addresses.prop.primaryjoin), + "users.id = addresses.user_id", + ) def test_string_dependency_resolution_in_backref(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) - addresses = relationship('Address', - primaryjoin='User.id==Address.user_id', - backref='user') + addresses = relationship( + "Address", + primaryjoin="User.id==Address.user_id", + backref="user", + ) class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" id = Column(Integer, primary_key=True) email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) configure_mappers() - eq_(str(User.addresses.property.primaryjoin), - str(Address.user.property.primaryjoin)) + eq_( + str(User.addresses.property.primaryjoin), + str(Address.user.property.primaryjoin), + ) def test_string_dependency_resolution_tables(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) - props = relationship('Prop', secondary='user_to_prop', - primaryjoin='User.id==user_to_prop.c.u' - 'ser_id', - secondaryjoin='user_to_prop.c.prop_id=' - '=Prop.id', backref='users') + props = relationship( + "Prop", + secondary="user_to_prop", + primaryjoin="User.id==user_to_prop.c.u" "ser_id", + secondaryjoin="user_to_prop.c.prop_id=" "=Prop.id", + backref="users", + ) class Prop(Base, fixtures.ComparableEntity): - __tablename__ = 'props' + __tablename__ = "props" id = Column(Integer, primary_key=True) name = Column(String(50)) user_to_prop = Table( - 'user_to_prop', Base.metadata, - Column('user_id', Integer, ForeignKey('users.id')), - Column('prop_id', Integer, ForeignKey('props.id'))) + "user_to_prop", + Base.metadata, + Column("user_id", Integer, ForeignKey("users.id")), + Column("prop_id", Integer, ForeignKey("props.id")), + ) configure_mappers() - assert class_mapper(User).get_property('props').secondary \ - is user_to_prop + assert ( + class_mapper(User).get_property("props").secondary is user_to_prop + ) def test_string_dependency_resolution_schemas(self): Base = decl.declarative_base() class User(Base): - __tablename__ = 'users' - __table_args__ = {'schema': 'fooschema'} + __tablename__ = "users" + __table_args__ = {"schema": 
"fooschema"} id = Column(Integer, primary_key=True) name = Column(String(50)) props = relationship( - 'Prop', secondary='fooschema.user_to_prop', - primaryjoin='User.id==fooschema.user_to_prop.c.user_id', - secondaryjoin='fooschema.user_to_prop.c.prop_id==Prop.id', - backref='users') + "Prop", + secondary="fooschema.user_to_prop", + primaryjoin="User.id==fooschema.user_to_prop.c.user_id", + secondaryjoin="fooschema.user_to_prop.c.prop_id==Prop.id", + backref="users", + ) class Prop(Base): - __tablename__ = 'props' - __table_args__ = {'schema': 'fooschema'} + __tablename__ = "props" + __table_args__ = {"schema": "fooschema"} id = Column(Integer, primary_key=True) name = Column(String(50)) user_to_prop = Table( - 'user_to_prop', Base.metadata, - Column('user_id', Integer, ForeignKey('fooschema.users.id')), - Column('prop_id', Integer, ForeignKey('fooschema.props.id')), - schema='fooschema') + "user_to_prop", + Base.metadata, + Column("user_id", Integer, ForeignKey("fooschema.users.id")), + Column("prop_id", Integer, ForeignKey("fooschema.props.id")), + schema="fooschema", + ) configure_mappers() - assert class_mapper(User).get_property('props').secondary \ - is user_to_prop + assert ( + class_mapper(User).get_property("props").secondary is user_to_prop + ) def test_string_dependency_resolution_annotations(self): Base = decl.declarative_base() class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) name = Column(String) children = relationship( "Child", primaryjoin="Parent.name==" - "remote(foreign(func.lower(Child.name_upper)))" + "remote(foreign(func.lower(Child.name_upper)))", ) class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) name_upper = Column(String) configure_mappers() eq_( Parent.children.property._calculated_foreign_keys, - set([Child.name_upper.property.columns[0]]) + set([Child.name_upper.property.columns[0]]), ) def test_shared_class_registry(self): @@ -787,11 +875,11 @@ class DeclarativeTest(DeclarativeTestBase): Base2 = decl.declarative_base(testing.db, class_registry=reg) class A(Base1): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(Base2): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) aid = Column(Integer, ForeignKey(A.id)) as_ = relationship("A") @@ -799,24 +887,28 @@ class DeclarativeTest(DeclarativeTestBase): assert B.as_.property.mapper.class_ is A def test_uncompiled_attributes_in_relationship(self): - class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "addresses" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) - addresses = relationship('Address', order_by=Address.email, - foreign_keys=Address.user_id, - remote_side=Address.user_id) + addresses = relationship( + "Address", + order_by=Address.email, + foreign_keys=Address.user_id, + remote_side=Address.user_id, + ) # get the mapper for User. 
User mapper will compile, # "addresses" relationship will call upon Address.user_id for @@ -829,28 +921,39 @@ class DeclarativeTest(DeclarativeTestBase): class_mapper(User) Base.metadata.create_all() sess = create_session() - u1 = User(name='ed', addresses=[ - Address(email='abc'), - Address(email='xyz'), Address(email='def')]) + u1 = User( + name="ed", + addresses=[ + Address(email="abc"), + Address(email="xyz"), + Address(email="def"), + ], + ) sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed', addresses=[ - Address(email='abc'), - Address(email='def'), Address(email='xyz')])) + eq_( + sess.query(User).filter(User.name == "ed").one(), + User( + name="ed", + addresses=[ + Address(email="abc"), + Address(email="def"), + Address(email="xyz"), + ], + ), + ) def test_nice_dependency_error(self): - class User(Base): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - addresses = relationship('Address') + __tablename__ = "users" + id = Column("id", Integer, primary_key=True) + addresses = relationship("Address") class Address(Base): - __tablename__ = 'addresses' + __tablename__ = "addresses" id = Column(Integer, primary_key=True) foo = sa.orm.column_property(User.id == 5) @@ -860,16 +963,15 @@ class DeclarativeTest(DeclarativeTestBase): assert_raises(sa.exc.ArgumentError, configure_mappers) def test_nice_dependency_error_works_with_hasattr(self): - class User(Base): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - addresses = relationship('Address') + __tablename__ = "users" + id = Column("id", Integer, primary_key=True) + addresses = relationship("Address") # hasattr() on a compile-loaded attribute try: - hasattr(User.addresses, 'property') + hasattr(User.addresses, "property") except exc.InvalidRequestError: assert sa.util.compat.py32 @@ -882,15 +984,16 @@ class DeclarativeTest(DeclarativeTestBase): " - can't proceed with initialization of other mappers. " r"Triggering mapper: 'Mapper\|User\|users'. 
" "Original exception was: When initializing.*", - configure_mappers) + configure_mappers, + ) def test_custom_base(self): class MyBase(object): - def foobar(self): return "foobar" + Base = decl.declarative_base(cls=MyBase) - assert hasattr(Base, 'metadata') + assert hasattr(Base, "metadata") assert Base().foobar() == "foobar" def test_uses_get_on_class_col_fk(self): @@ -899,22 +1002,23 @@ class DeclarativeTest(DeclarativeTestBase): class Master(Base): - __tablename__ = 'master' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "master" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) class Detail(Base): - __tablename__ = 'detail' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "detail" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) master_id = Column(None, ForeignKey(Master.id)) master = relationship(Master) Base.metadata.create_all() configure_mappers() - assert class_mapper(Detail).get_property('master' - ).strategy.use_get + assert class_mapper(Detail).get_property("master").strategy.use_get m1 = Master() d1 = Detail(master=m1) sess = create_session() @@ -931,12 +1035,12 @@ class DeclarativeTest(DeclarativeTestBase): def test_index_doesnt_compile(self): class User(Base): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) error = relationship("Address") - i = Index('my_index', User.name) + i = Index("my_index", User.name) # compile fails due to the nonexistent Addresses relationship assert_raises(sa.exc.InvalidRequestError, configure_mappers) @@ -950,83 +1054,96 @@ class DeclarativeTest(DeclarativeTestBase): Base.metadata.create_all() def test_add_prop(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) - User.name = Column('name', String(50)) - User.addresses = relationship('Address', backref='user') + User.name = Column("name", String(50)) + User.addresses = relationship("Address", backref="user") class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "addresses" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) - Address.email = Column(String(50), key='_email') - Address.user_id = Column('user_id', Integer, - ForeignKey('users.id'), key='_user_id') + Address.email = Column(String(50), key="_email") + Address.user_id = Column( + "user_id", Integer, ForeignKey("users.id"), key="_user_id" + ) Base.metadata.create_all() - eq_(Address.__table__.c['id'].name, 'id') - eq_(Address.__table__.c['_email'].name, 'email') - eq_(Address.__table__.c['_user_id'].name, 'user_id') - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + eq_(Address.__table__.c["id"].name, "id") + eq_(Address.__table__.c["_email"].name, "email") + eq_(Address.__table__.c["_user_id"].name, "user_id") + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User( - name='u1', - 
addresses=[Address(email='one'), Address(email='two')])]) - a1 = sess.query(Address).filter(Address.email == 'two').one() - eq_(a1, Address(email='two')) - eq_(a1.user, User(name='u1')) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) + a1 = sess.query(Address).filter(Address.email == "two").one() + eq_(a1, Address(email="two")) + eq_(a1.user, User(name="u1")) def test_alt_name_attr_subclass_column_inline(self): # [ticket:2900] class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - data = Column('data') + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + data = Column("data") class ASub(A): brap = A.data + assert ASub.brap.property is A.data.property assert isinstance( - ASub.brap.original_property, properties.SynonymProperty) + ASub.brap.original_property, properties.SynonymProperty + ) def test_alt_name_attr_subclass_relationship_inline(self): # [ticket:2900] class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - b_id = Column(Integer, ForeignKey('b.id')) + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + b_id = Column(Integer, ForeignKey("b.id")) b = relationship("B", backref="as_") class B(Base): - __tablename__ = 'b' - id = Column('id', Integer, primary_key=True) + __tablename__ = "b" + id = Column("id", Integer, primary_key=True) configure_mappers() class ASub(A): brap = A.b + assert ASub.brap.property is A.b.property assert isinstance( - ASub.brap.original_property, properties.SynonymProperty) + ASub.brap.original_property, properties.SynonymProperty + ) ASub(brap=B()) def test_alt_name_attr_subclass_column_attrset(self): # [ticket:2900] class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - data = Column('data') + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + data = Column("data") + A.brap = A.data assert A.brap.property is A.data.property assert isinstance(A.brap.original_property, properties.SynonymProperty) @@ -1034,113 +1151,131 @@ class DeclarativeTest(DeclarativeTestBase): def test_alt_name_attr_subclass_relationship_attrset(self): # [ticket:2900] class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - b_id = Column(Integer, ForeignKey('b.id')) + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + b_id = Column(Integer, ForeignKey("b.id")) b = relationship("B", backref="as_") + A.brap = A.b class B(Base): - __tablename__ = 'b' - id = Column('id', Integer, primary_key=True) + __tablename__ = "b" + id = Column("id", Integer, primary_key=True) assert A.brap.property is A.b.property assert isinstance(A.brap.original_property, properties.SynonymProperty) A(brap=B()) def test_eager_order_by(self): - class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', 
order_by=Address.email) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship("Address", order_by=Address.email) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='two'), - Address(email='one')]) + u1 = User( + name="u1", addresses=[Address(email="two"), Address(email="one")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).options(joinedload(User.addresses)).all(), - [User(name='u1', addresses=[Address(email='one'), - Address(email='two')])]) + eq_( + sess.query(User).options(joinedload(User.addresses)).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_order_by_multi(self): - class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', - order_by=(Address.email, Address.id)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship( + "Address", order_by=(Address.email, Address.id) + ) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='two'), - Address(email='one')]) + u1 = User( + name="u1", addresses=[Address(email="two"), Address(email="one")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - u = sess.query(User).filter(User.name == 'u1').one() + u = sess.query(User).filter(User.name == "u1").one() u.addresses def test_as_declarative(self): - class User(fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', backref='user') + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship("Address", backref="user") class Address(fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) reg = {} decl.instrument_declarative(User, reg, Base.metadata) decl.instrument_declarative(Address, reg, Base.metadata) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User( - 
name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_custom_mapper_attribute(self): - def mymapper(cls, tbl, **kwargs): m = sa.orm.mapper(cls, tbl, **kwargs) m.CHECK = True @@ -1149,14 +1284,13 @@ class DeclarativeTest(DeclarativeTestBase): base = decl.declarative_base() class Foo(base): - __tablename__ = 'foo' + __tablename__ = "foo" __mapper_cls__ = mymapper id = Column(Integer, primary_key=True) eq_(Foo.__mapper__.CHECK, True) def test_custom_mapper_argument(self): - def mymapper(cls, tbl, **kwargs): m = sa.orm.mapper(cls, tbl, **kwargs) m.CHECK = True @@ -1165,7 +1299,7 @@ class DeclarativeTest(DeclarativeTestBase): base = decl.declarative_base(mapper=mymapper) class Foo(base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) eq_(Foo.__mapper__.CHECK, True) @@ -1173,23 +1307,22 @@ class DeclarativeTest(DeclarativeTestBase): def test_oops(self): with testing.expect_warnings( - "Ignoring declarative-like tuple value of " - "attribute 'name'"): + "Ignoring declarative-like tuple value of " "attribute 'name'" + ): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)), + __tablename__ = "users" + id = Column("id", Integer, primary_key=True) + name = (Column("name", String(50)),) def test_table_args_no_dict(self): - class Foo1(Base): - __tablename__ = 'foo' - __table_args__ = ForeignKeyConstraint(['id'], ['foo.bar']), - id = Column('id', Integer, primary_key=True) - bar = Column('bar', Integer) + __tablename__ = "foo" + __table_args__ = (ForeignKeyConstraint(["id"], ["foo.bar"]),) + id = Column("id", Integer, primary_key=True) + bar = Column("bar", Integer) assert Foo1.__table__.c.id.references(Foo1.__table__.c.bar) @@ -1197,49 +1330,50 @@ class DeclarativeTest(DeclarativeTestBase): def err(): class Foo1(Base): - __tablename__ = 'foo' - __table_args__ = ForeignKeyConstraint(['id'], ['foo.id' - ]) - id = Column('id', Integer, primary_key=True) - assert_raises_message(sa.exc.ArgumentError, - '__table_args__ value must be a tuple, ', err) + __tablename__ = "foo" + __table_args__ = ForeignKeyConstraint(["id"], ["foo.id"]) + id = Column("id", Integer, primary_key=True) - def test_table_args_none(self): + assert_raises_message( + sa.exc.ArgumentError, "__table_args__ value must be a tuple, ", err + ) + def test_table_args_none(self): class Foo2(Base): - __tablename__ = 'foo' + __tablename__ = "foo" __table_args__ = None - id = Column('id', Integer, primary_key=True) + id = Column("id", Integer, primary_key=True) assert Foo2.__table__.kwargs == {} def test_table_args_dict_format(self): - class Foo2(Base): - __tablename__ = 'foo' - __table_args__ = {'mysql_engine': 'InnoDB'} - id = Column('id', Integer, primary_key=True) + __tablename__ = "foo" + __table_args__ = {"mysql_engine": "InnoDB"} + id = Column("id", Integer, primary_key=True) - assert Foo2.__table__.kwargs['mysql_engine'] == 'InnoDB' + assert Foo2.__table__.kwargs["mysql_engine"] == "InnoDB" def test_table_args_tuple_format(self): class Foo2(Base): - __tablename__ = 'foo' - __table_args__ = {'mysql_engine': 'InnoDB'} - id = Column('id', Integer, primary_key=True) + __tablename__ = "foo" + __table_args__ = {"mysql_engine": "InnoDB"} + id = Column("id", Integer, primary_key=True) class Bar(Base): - __tablename__ = 'bar' - __table_args__ = 
ForeignKeyConstraint(['id'], ['foo.id']), \ - {'mysql_engine': 'InnoDB'} - id = Column('id', Integer, primary_key=True) + __tablename__ = "bar" + __table_args__ = ( + ForeignKeyConstraint(["id"], ["foo.id"]), + {"mysql_engine": "InnoDB"}, + ) + id = Column("id", Integer, primary_key=True) assert Bar.__table__.c.id.references(Foo2.__table__.c.id) - assert Bar.__table__.kwargs['mysql_engine'] == 'InnoDB' + assert Bar.__table__.kwargs["mysql_engine"] == "InnoDB" def test_table_cls_attribute(self): class Foo(Base): @@ -1248,7 +1382,7 @@ class DeclarativeTest(DeclarativeTestBase): @classmethod def __table_cls__(cls, *arg, **kw): name = arg[0] - return Table(name + 'bat', *arg[1:], **kw) + return Table(name + "bat", *arg[1:], **kw) id = Column(Integer, primary_key=True) @@ -1265,8 +1399,9 @@ class DeclarativeTest(DeclarativeTestBase): @classmethod def __table_cls__(cls, *arg, **kw): for obj in arg[1:]: - if (isinstance(obj, Column) and obj.primary_key) or \ - isinstance(obj, PrimaryKeyConstraint): + if ( + isinstance(obj, Column) and obj.primary_key + ) or isinstance(obj, PrimaryKeyConstraint): return Table(*arg, **kw) return None @@ -1280,77 +1415,99 @@ class DeclarativeTest(DeclarativeTestBase): is_(inspect(Employee).local_table, Person.__table__) def test_expression(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', backref='user') + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship("Address", backref="user") class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) - User.address_count = \ - sa.orm.column_property(sa.select([sa.func.count(Address.id)]). 
- where(Address.user_id - == User.id).as_scalar()) + User.address_count = sa.orm.column_property( + sa.select([sa.func.count(Address.id)]) + .where(Address.user_id == User.id) + .as_scalar() + ) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User(name='u1', address_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + address_count=2, + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_useless_declared_attr(self): class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', backref='user') + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship("Address", backref="user") @declared_attr def address_count(cls): # this doesn't really gain us anything. but if # one is used, lets have it function as expected... return sa.orm.column_property( - sa.select([sa.func.count(Address.id)]). - where(Address.user_id == cls.id)) + sa.select([sa.func.count(Address.id)]).where( + Address.user_id == cls.id + ) + ) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User(name='u1', address_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + address_count=2, + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_declared_on_base_class(self): class MyBase(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) @declared_attr @@ -1358,25 +1515,26 @@ class DeclarativeTest(DeclarativeTestBase): return Column(Integer) class MyClass(MyBase): - __tablename__ = 'bar' - id = Column(Integer, ForeignKey('foo.id'), primary_key=True) + __tablename__ = "bar" + id = Column(Integer, ForeignKey("foo.id"), primary_key=True) # previously, the 'somecol' declared_attr would be ignored # by the mapping and would remain unused. now we take # it as part of MyBase. - assert 'somecol' in MyBase.__table__.c - assert 'somecol' not in MyClass.__table__.c + assert "somecol" in MyBase.__table__.c + assert "somecol" not in MyClass.__table__.c def test_decl_cascading_warns_non_mixin(self): with expect_warnings( - "Use of @declared_attr.cascading only applies to " - "Declarative 'mixin' and 'abstract' classes. 
" - "Currently, this flag is ignored on mapped class " - "" + "Use of @declared_attr.cascading only applies to " + "Declarative 'mixin' and 'abstract' classes. " + "Currently, this flag is ignored on mapped class " + "" ): + class MyBase(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) @declared_attr.cascading @@ -1384,109 +1542,116 @@ class DeclarativeTest(DeclarativeTestBase): return Column(Integer) def test_column(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) - User.a = Column('a', String(10)) + User.a = Column("a", String(10)) User.b = Column(String(10)) Base.metadata.create_all() - u1 = User(name='u1', a='a', b='b') - eq_(u1.a, 'a') - eq_(User.a.get_history(u1), (['a'], (), ())) + u1 = User(name="u1", a="a", b="b") + eq_(u1.a, "a") + eq_(User.a.get_history(u1), (["a"], (), ())) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [User(name='u1', a='a', b='b')]) + eq_(sess.query(User).all(), [User(name="u1", a="a", b="b")]) def test_column_properties(self): - class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "addresses" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) - adr_count = \ - sa.orm.column_property( - sa.select([sa.func.count(Address.id)], - Address.user_id == id).as_scalar()) + adr_count = sa.orm.column_property( + sa.select( + [sa.func.count(Address.id)], Address.user_id == id + ).as_scalar() + ) addresses = relationship(Address) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User(name='u1', adr_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + adr_count=2, + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_column_properties_2(self): - class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' + __tablename__ = "addresses" id = Column(Integer, primary_key=True) email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) # this is not "valid" but we want to test that Address.id # doesn't get stuck 
into user's table adr_count = Address.id - eq_(set(User.__table__.c.keys()), set(['id', 'name'])) - eq_(set(Address.__table__.c.keys()), set(['id', 'email', - 'user_id'])) + eq_(set(User.__table__.c.keys()), set(["id", "name"])) + eq_(set(Address.__table__.c.keys()), set(["id", "email", "user_id"])) def test_deferred(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = sa.orm.deferred(Column(String(50))) Base.metadata.create_all() sess = create_session() - sess.add(User(name='u1')) + sess.add(User(name="u1")) sess.flush() sess.expunge_all() - u1 = sess.query(User).filter(User.name == 'u1').one() - assert 'name' not in u1.__dict__ + u1 = sess.query(User).filter(User.name == "u1").one() + assert "name" not in u1.__dict__ def go(): - eq_(u1.name, 'u1') + eq_(u1.name, "u1") self.assert_sql_count(testing.db, go, 1) def test_composite_inline(self): class AddressComposite(fixtures.ComparableEntity): - def __init__(self, street, state): self.street = street self.state = state @@ -1495,30 +1660,27 @@ class DeclarativeTest(DeclarativeTestBase): return [self.street, self.state] class User(Base, fixtures.ComparableEntity): - __tablename__ = 'user' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - address = composite(AddressComposite, - Column('street', String(50)), - Column('state', String(2)), - ) + __tablename__ = "user" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + address = composite( + AddressComposite, + Column("street", String(50)), + Column("state", String(2)), + ) Base.metadata.create_all() sess = Session() - sess.add(User( - address=AddressComposite('123 anywhere street', - 'MD') - )) + sess.add(User(address=AddressComposite("123 anywhere street", "MD"))) sess.commit() eq_( sess.query(User).all(), - [User(address=AddressComposite('123 anywhere street', - 'MD'))] + [User(address=AddressComposite("123 anywhere street", "MD"))], ) def test_composite_separate(self): class AddressComposite(fixtures.ComparableEntity): - def __init__(self, street, state): self.street = street self.state = state @@ -1527,78 +1689,79 @@ class DeclarativeTest(DeclarativeTestBase): return [self.street, self.state] class User(Base, fixtures.ComparableEntity): - __tablename__ = 'user' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "user" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) street = Column(String(50)) state = Column(String(2)) - address = composite(AddressComposite, - street, state) + address = composite(AddressComposite, street, state) Base.metadata.create_all() sess = Session() - sess.add(User( - address=AddressComposite('123 anywhere street', - 'MD') - )) + sess.add(User(address=AddressComposite("123 anywhere street", "MD"))) sess.commit() eq_( sess.query(User).all(), - [User(address=AddressComposite('123 anywhere street', - 'MD'))] + [User(address=AddressComposite("123 anywhere street", "MD"))], ) def test_mapping_to_join(self): - users = Table('users', Base.metadata, - Column('id', Integer, primary_key=True) - ) - addresses = Table('addresses', Base.metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('users.id')) - ) - usersaddresses = sa.join(users, addresses, users.c.id - == addresses.c.user_id) + users = Table( + "users", 
Base.metadata, Column("id", Integer, primary_key=True) + ) + addresses = Table( + "addresses", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("user_id", Integer, ForeignKey("users.id")), + ) + usersaddresses = sa.join( + users, addresses, users.c.id == addresses.c.user_id + ) class User(Base): __table__ = usersaddresses - __table_args__ = {'primary_key': [users.c.id]} + __table_args__ = {"primary_key": [users.c.id]} # need to use column_property for now user_id = column_property(users.c.id, addresses.c.user_id) address_id = addresses.c.id - assert User.__mapper__.get_property('user_id').columns[0] \ - is users.c.id - assert User.__mapper__.get_property('user_id').columns[1] \ + assert User.__mapper__.get_property("user_id").columns[0] is users.c.id + assert ( + User.__mapper__.get_property("user_id").columns[1] is addresses.c.user_id + ) def test_synonym_inline(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - _name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + _name = Column("name", String(50)) def _set_name(self, name): - self._name = 'SOMENAME ' + name + self._name = "SOMENAME " + name def _get_name(self): return self._name - name = sa.orm.synonym('_name', - descriptor=property(_get_name, - _set_name)) + name = sa.orm.synonym( + "_name", descriptor=property(_get_name, _set_name) + ) Base.metadata.create_all() sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, 'SOMENAME someuser') + u1 = User(name="someuser") + eq_(u1.name, "SOMENAME someuser") sess.add(u1) sess.flush() - eq_(sess.query(User).filter(User.name == 'SOMENAME someuser' - ).one(), u1) + eq_( + sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1 + ) def test_synonym_no_descriptor(self): from sqlalchemy.orm.properties import ColumnProperty @@ -1608,68 +1771,70 @@ class DeclarativeTest(DeclarativeTestBase): __hash__ = None def __eq__(self, other): - return self.__clause_element__() == other + ' FOO' + return self.__clause_element__() == other + " FOO" class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - _name = Column('name', String(50)) - name = sa.orm.synonym('_name', - comparator_factory=CustomCompare) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + _name = Column("name", String(50)) + name = sa.orm.synonym("_name", comparator_factory=CustomCompare) Base.metadata.create_all() sess = create_session() - u1 = User(name='someuser FOO') + u1 = User(name="someuser FOO") sess.add(u1) sess.flush() - eq_(sess.query(User).filter(User.name == 'someuser').one(), u1) + eq_(sess.query(User).filter(User.name == "someuser").one(), u1) def test_synonym_added(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - _name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + _name = Column("name", String(50)) def _set_name(self, name): - self._name = 'SOMENAME ' + name + self._name = "SOMENAME " + name def _get_name(self): return self._name name = property(_get_name, _set_name) - User.name = sa.orm.synonym('_name', descriptor=User.name) + 
User.name = sa.orm.synonym("_name", descriptor=User.name) Base.metadata.create_all() sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, 'SOMENAME someuser') + u1 = User(name="someuser") + eq_(u1.name, "SOMENAME someuser") sess.add(u1) sess.flush() - eq_(sess.query(User).filter(User.name == 'SOMENAME someuser' - ).one(), u1) + eq_( + sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1 + ) def test_reentrant_compile_via_foreignkey(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', backref='user') + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship("Address", backref="user") class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey(User.id)) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey(User.id)) # previous versions would force a re-entrant mapper compile via # the User.id inside the ForeignKey but this is no longer the @@ -1678,65 +1843,83 @@ class DeclarativeTest(DeclarativeTestBase): sa.orm.configure_mappers() eq_( list(Address.user_id.property.columns[0].foreign_keys)[0].column, - User.__table__.c.id + User.__table__.c.id, ) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User(name='u1', - addresses=[Address(email='one'), Address(email='two')])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_relationship_reference(self): - class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) + __tablename__ = "addresses" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + email = Column("email", String(50)) + user_id = Column("user_id", Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - addresses = relationship('Address', backref='user', - primaryjoin=id == Address.user_id) - - User.address_count = \ - sa.orm.column_property(sa.select([sa.func.count(Address.id)]). 
- where(Address.user_id - == User.id).as_scalar()) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + addresses = relationship( + "Address", backref="user", primaryjoin=id == Address.user_id + ) + + User.address_count = sa.orm.column_property( + sa.select([sa.func.count(Address.id)]) + .where(Address.user_id == User.id) + .as_scalar() + ) Base.metadata.create_all() - u1 = User(name='u1', addresses=[Address(email='one'), - Address(email='two')]) + u1 = User( + name="u1", addresses=[Address(email="one"), Address(email="two")] + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User(name='u1', address_count=2, - addresses=[Address(email='one'), Address(email='two')])]) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + address_count=2, + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) def test_pk_with_fk_init(self): - class Bar(Base): - __tablename__ = 'bar' - id = sa.Column(sa.Integer, sa.ForeignKey('foo.id'), - primary_key=True) + __tablename__ = "bar" + id = sa.Column( + sa.Integer, sa.ForeignKey("foo.id"), primary_key=True + ) ex = sa.Column(sa.Integer, primary_key=True) class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = sa.Column(sa.Integer, primary_key=True) bars = sa.orm.relationship(Bar) @@ -1746,108 +1929,101 @@ class DeclarativeTest(DeclarativeTestBase): def test_with_explicit_autoloaded(self): meta = MetaData(testing.db) t1 = Table( - 't1', meta, - Column('id', String(50), primary_key=True), - Column('data', String(50))) + "t1", + meta, + Column("id", String(50), primary_key=True), + Column("data", String(50)), + ) meta.create_all() try: class MyObj(Base): - __table__ = Table('t1', Base.metadata, autoload=True) + __table__ = Table("t1", Base.metadata, autoload=True) sess = create_session() - m = MyObj(id='someid', data='somedata') + m = MyObj(id="someid", data="somedata") sess.add(m) sess.flush() - eq_(t1.select().execute().fetchall(), [('someid', 'somedata' - )]) + eq_(t1.select().execute().fetchall(), [("someid", "somedata")]) finally: meta.drop_all() def test_synonym_for(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) - @decl.synonym_for('name') + @decl.synonym_for("name") @property def namesyn(self): return self.name Base.metadata.create_all() sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, 'someuser') - eq_(u1.namesyn, 'someuser') + u1 = User(name="someuser") + eq_(u1.name, "someuser") + eq_(u1.namesyn, "someuser") sess.add(u1) sess.flush() - rt = sess.query(User).filter(User.namesyn == 'someuser').one() + rt = sess.query(User).filter(User.namesyn == "someuser").one() eq_(rt, u1) def test_comparable_using(self): - class NameComparator(sa.orm.PropComparator): - @property def upperself(self): cls = self.prop.parent.class_ - col = getattr(cls, 'name') + col = getattr(cls, "name") return sa.func.upper(col) - def operate( - self, - op, - other, - **kw - ): + def operate(self, op, other, **kw): return op(self.upperself, other, **kw) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True, - 
test_needs_autoincrement=True) - name = Column('name', String(50)) + __tablename__ = "users" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) @decl.comparable_using(NameComparator) @property def uc_name(self): - return self.name is not None and self.name.upper() \ - or None + return self.name is not None and self.name.upper() or None Base.metadata.create_all() sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, 'someuser', u1.name) - eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) + u1 = User(name="someuser") + eq_(u1.name, "someuser", u1.name) + eq_(u1.uc_name, "SOMEUSER", u1.uc_name) sess.add(u1) sess.flush() sess.expunge_all() - rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() + rt = sess.query(User).filter(User.uc_name == "SOMEUSER").one() eq_(rt, u1) sess.expunge_all() - rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE' - )).one() + rt = sess.query(User).filter(User.uc_name.startswith("SOMEUSE")).one() eq_(rt, u1) def test_duplicate_classes_in_base(self): - class Test(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) assert_raises_message( sa.exc.SAWarning, "This declarative base already contains a class with ", - lambda: type(Base)("Test", (Base,), dict( - __tablename__='b', - id=Column(Integer, primary_key=True) - )) + lambda: type(Base)( + "Test", + (Base,), + dict(__tablename__="b", id=Column(Integer, primary_key=True)), + ), ) @testing.teardown_events(MapperEvents) @@ -1865,19 +2041,18 @@ class DeclarativeTest(DeclarativeTestBase): canary.class_instrument(cls) class Test(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) eq_( canary.mock_calls, [ mock.call.instrument_class(Test.__mapper__, Test), - mock.call.class_instrument(Test) - ] + mock.call.class_instrument(Test), + ], ) def test_cls_docstring(self): - class MyBase(object): """MyBase Docstring""" @@ -1889,7 +2064,7 @@ class DeclarativeTest(DeclarativeTestBase): Base = decl.declarative_base() class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) data = Column(String) @@ -1900,14 +2075,14 @@ class DeclarativeTest(DeclarativeTestBase): assert_raises_message( NotImplementedError, "Can't un-map individual mapped attributes on a mapped class.", - go + go, ) def test_delattr_hybrid_fine(self): Base = decl.declarative_base() class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) data = Column(String) @@ -1928,7 +2103,7 @@ class DeclarativeTest(DeclarativeTestBase): Base = decl.declarative_base() class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) data = Column(String) @@ -1938,6 +2113,7 @@ class DeclarativeTest(DeclarativeTestBase): @hybrid_property def data_hybrid(self): return self.data + Foo.data_hybrid = data_hybrid assert "data_hybrid" in Foo.__mapper__.all_orm_descriptors.keys() @@ -1949,9 +2125,7 @@ class DeclarativeTest(DeclarativeTestBase): def _produce_test(inline, stringbased): - class ExplicitJoinTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): global User, Address @@ -1959,57 +2133,74 @@ def _produce_test(inline, stringbased): class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "users" + id = Column( + Integer, primary_key=True, 
test_needs_autoincrement=True + ) name = Column(String(50)) class Address(Base, fixtures.ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "addresses" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) if inline: if stringbased: user = relationship( - 'User', - primaryjoin='User.id==Address.user_id', - backref='addresses') + "User", + primaryjoin="User.id==Address.user_id", + backref="addresses", + ) else: - user = relationship(User, primaryjoin=User.id - == user_id, backref='addresses') + user = relationship( + User, + primaryjoin=User.id == user_id, + backref="addresses", + ) if not inline: configure_mappers() if stringbased: Address.user = relationship( - 'User', - primaryjoin='User.id==Address.user_id', - backref='addresses') + "User", + primaryjoin="User.id==Address.user_id", + backref="addresses", + ) else: Address.user = relationship( User, primaryjoin=User.id == Address.user_id, - backref='addresses') + backref="addresses", + ) @classmethod def insert_data(cls): params = [ - dict(list(zip(('id', 'name'), column_values))) + dict(list(zip(("id", "name"), column_values))) for column_values in [ - (7, 'jack'), (8, 'ed'), - (9, 'fred'), (10, 'chuck')]] + (7, "jack"), + (8, "ed"), + (9, "fred"), + (10, "chuck"), + ] + ] User.__table__.insert().execute(params) - Address.__table__.insert().execute([ - dict(list(zip(('id', 'user_id', 'email'), column_values))) - for column_values in [ - (1, 7, 'jack@bean.com'), - (2, 8, 'ed@wood.com'), - (3, 8, 'ed@bettyboop.com'), - (4, 8, 'ed@lala.com'), (5, 9, 'fred@fred.com')]]) + Address.__table__.insert().execute( + [ + dict(list(zip(("id", "user_id", "email"), column_values))) + for column_values in [ + (1, 7, "jack@bean.com"), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 8, "ed@lala.com"), + (5, 9, "fred@fred.com"), + ] + ] + ) def test_aliased_join(self): @@ -2022,20 +2213,24 @@ def _produce_test(inline, stringbased): # _orm_adapt, though. 
sess = create_session() - eq_(sess.query(User).join(User.addresses, - aliased=True).filter( - Address.email == 'ed@wood.com').filter( - User.addresses.any(Address.email == 'jack@bean.com')).all(), - []) - - ExplicitJoinTest.__name__ = 'ExplicitJoinTest%s%s' % ( - inline and 'Inline' or 'Separate', - stringbased and 'String' or 'Literal') + eq_( + sess.query(User) + .join(User.addresses, aliased=True) + .filter(Address.email == "ed@wood.com") + .filter(User.addresses.any(Address.email == "jack@bean.com")) + .all(), + [], + ) + + ExplicitJoinTest.__name__ = "ExplicitJoinTest%s%s" % ( + inline and "Inline" or "Separate", + stringbased and "String" or "Literal", + ) return ExplicitJoinTest for inline in True, False: for stringbased in True, False: testclass = _produce_test(inline, stringbased) - exec('%s = testclass' % testclass.__name__) + exec("%s = testclass" % testclass.__name__) del testclass diff --git a/test/ext/declarative/test_clsregistry.py b/test/ext/declarative/test_clsregistry.py index 000479f09d..c93fbc9dc9 100644 --- a/test/ext/declarative/test_clsregistry.py +++ b/test/ext/declarative/test_clsregistry.py @@ -7,7 +7,6 @@ import weakref class MockClass(object): - def __init__(self, base, name): self._decl_class_registry = base tokens = name.split(".") @@ -21,7 +20,7 @@ class MockProp(object): class ClsRegistryTest(fixtures.TestBase): - __requires__ = 'predictable_gc', + __requires__ = ("predictable_gc",) def test_same_module_same_name(self): base = weakref.WeakValueDictionary() @@ -35,7 +34,9 @@ class ClsRegistryTest(fixtures.TestBase): "This declarative base already contains a class with the " "same class name and module name as foo.bar.Foo, and " "will be replaced in the string-lookup table.", - clsregistry.add_class, "Foo", f2 + clsregistry.add_class, + "Foo", + f2, ) def test_resolve(self): @@ -81,9 +82,9 @@ class ClsRegistryTest(fixtures.TestBase): assert_raises_message( exc.InvalidRequestError, 'Multiple classes found for path "alt.Foo" in the registry ' - 'of this declarative base. Please use a fully ' - 'module-qualified path.', - resolver("alt.Foo") + "of this declarative base. Please use a fully " + "module-qualified path.", + resolver("alt.Foo"), ) def test_resolve_dupe_by_name(self): @@ -100,9 +101,9 @@ class ClsRegistryTest(fixtures.TestBase): assert_raises_message( exc.InvalidRequestError, 'Multiple classes found for path "Foo" in the ' - 'registry of this declarative base. Please use a ' - 'fully module-qualified path.', - resolver + "registry of this declarative base. 
Please use a " + "fully module-qualified path.", + resolver, ) def test_dupe_classes_back_to_one(self): @@ -149,7 +150,7 @@ class ClsRegistryTest(fixtures.TestBase): clsregistry.add_class("Foo", f1) clsregistry.add_class("Foo", f2) - dupe_reg = base['Foo'] + dupe_reg = base["Foo"] dupe_reg.contents = [lambda: None] resolver = clsregistry._resolver(f1, MockProp()) resolver = resolver("Foo") @@ -157,7 +158,7 @@ class ClsRegistryTest(fixtures.TestBase): exc.InvalidRequestError, r"When initializing mapper some_parent, expression " r"'Foo' failed to locate a name \('Foo'\).", - resolver + resolver, ) def test_module_reg_cleanout_race(self): @@ -167,9 +168,9 @@ class ClsRegistryTest(fixtures.TestBase): base = weakref.WeakValueDictionary() f1 = MockClass(base, "foo.bar.Foo") clsregistry.add_class("Foo", f1) - reg = base['_sa_module_registry'] + reg = base["_sa_module_registry"] - mod_entry = reg['foo']['bar'] + mod_entry = reg["foo"]["bar"] resolver = clsregistry._resolver(f1, MockProp()) resolver = resolver("foo") del mod_entry.contents["Foo"] @@ -177,60 +178,59 @@ class ClsRegistryTest(fixtures.TestBase): AttributeError, "Module 'bar' has no mapped classes registered " "under the name 'Foo'", - lambda: resolver().bar.Foo + lambda: resolver().bar.Foo, ) def test_module_reg_no_class(self): base = weakref.WeakValueDictionary() f1 = MockClass(base, "foo.bar.Foo") clsregistry.add_class("Foo", f1) - reg = base['_sa_module_registry'] - mod_entry = reg['foo']['bar'] # noqa + reg = base["_sa_module_registry"] + mod_entry = reg["foo"]["bar"] # noqa resolver = clsregistry._resolver(f1, MockProp()) resolver = resolver("foo") assert_raises_message( AttributeError, "Module 'bar' has no mapped classes registered " "under the name 'Bat'", - lambda: resolver().bar.Bat + lambda: resolver().bar.Bat, ) def test_module_reg_cleanout_two_sub(self): base = weakref.WeakValueDictionary() f1 = MockClass(base, "foo.bar.Foo") clsregistry.add_class("Foo", f1) - reg = base['_sa_module_registry'] + reg = base["_sa_module_registry"] f2 = MockClass(base, "foo.alt.Bar") clsregistry.add_class("Bar", f2) - assert reg['foo']['bar'] + assert reg["foo"]["bar"] del f1 gc_collect() - assert 'bar' not in \ - reg['foo'] - assert 'alt' in reg['foo'] + assert "bar" not in reg["foo"] + assert "alt" in reg["foo"] del f2 gc_collect() - assert 'foo' not in reg.contents + assert "foo" not in reg.contents def test_module_reg_cleanout_sub_to_base(self): base = weakref.WeakValueDictionary() f3 = MockClass(base, "bat.bar.Hoho") clsregistry.add_class("Hoho", f3) - reg = base['_sa_module_registry'] + reg = base["_sa_module_registry"] - assert reg['bat']['bar'] + assert reg["bat"]["bar"] del f3 gc_collect() - assert 'bat' not in reg + assert "bat" not in reg def test_module_reg_cleanout_cls_to_base(self): base = weakref.WeakValueDictionary() f4 = MockClass(base, "single.Blat") clsregistry.add_class("Blat", f4) - reg = base['_sa_module_registry'] - assert reg['single'] + reg = base["_sa_module_registry"] + assert reg["single"] del f4 gc_collect() - assert 'single' not in reg + assert "single" not in reg diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index a453bc165c..6cf0073b24 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -1,16 +1,34 @@ - -from sqlalchemy.testing import eq_, le_, assert_raises, \ - assert_raises_message, is_, is_true, is_false +from sqlalchemy.testing import ( + eq_, + le_, + assert_raises, + assert_raises_message, + is_, + 
is_true, + is_false, +) from sqlalchemy.ext import declarative as decl import sqlalchemy as sa from sqlalchemy import testing from sqlalchemy import Integer, String, ForeignKey from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import relationship, create_session, class_mapper, \ - configure_mappers, clear_mappers, \ - polymorphic_union, deferred, Session, mapper -from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase, \ - ConcreteBase, has_inherited_table +from sqlalchemy.orm import ( + relationship, + create_session, + class_mapper, + configure_mappers, + clear_mappers, + polymorphic_union, + deferred, + Session, + mapper, +) +from sqlalchemy.ext.declarative import ( + declared_attr, + AbstractConcreteBase, + ConcreteBase, + has_inherited_table, +) from sqlalchemy.testing import fixtures, mock from test.orm.test_events import _RemoveListeners @@ -18,7 +36,6 @@ Base = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): global Base Base = decl.declarative_base(testing.db) @@ -30,119 +47,133 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): class DeclarativeInheritanceTest(DeclarativeTestBase): - def test_we_must_copy_mapper_args(self): - class Person(Base): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator, - 'polymorphic_identity': 'person'} + discriminator = Column("type", String(50)) + __mapper_args__ = { + "polymorphic_on": discriminator, + "polymorphic_identity": "person", + } class Engineer(Person): primary_language = Column(String(50)) - assert 'inherits' not in Person.__mapper_args__ + assert "inherits" not in Person.__mapper_args__ assert class_mapper(Engineer).polymorphic_identity is None assert class_mapper(Engineer).polymorphic_on is Person.__table__.c.type def test_we_must_only_copy_column_mapper_args(self): - class Person(Base): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) a = Column(Integer) b = Column(Integer) c = Column(Integer) d = Column(Integer) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator, - 'polymorphic_identity': 'person', - 'version_id_col': 'a', - 'column_prefix': 'bar', - 'include_properties': ['id', 'a', 'b'], - } - assert class_mapper(Person).version_id_col == 'a' - assert class_mapper(Person).include_properties == set(['id', 'a', 'b']) + discriminator = Column("type", String(50)) + __mapper_args__ = { + "polymorphic_on": discriminator, + "polymorphic_identity": "person", + "version_id_col": "a", + "column_prefix": "bar", + "include_properties": ["id", "a", "b"], + } - def test_custom_join_condition(self): + assert class_mapper(Person).version_id_col == "a" + assert class_mapper(Person).include_properties == set(["id", "a", "b"]) + def test_custom_join_condition(self): class Foo(Base): - __tablename__ = 'foo' - id = Column('id', Integer, primary_key=True) + __tablename__ = "foo" + id = Column("id", Integer, primary_key=True) class Bar(Foo): - __tablename__ = 'bar' - bar_id = Column('id', Integer, primary_key=True) - foo_id = Column('foo_id', Integer) - __mapper_args__ = {'inherit_condition': foo_id == Foo.id} + __tablename__ = "bar" + bar_id = Column("id", Integer, primary_key=True) + foo_id = Column("foo_id", Integer) + __mapper_args__ = {"inherit_condition": foo_id == Foo.id} # compile succeeds 
because inherit_condition is honored configure_mappers() def test_joined(self): - class Company(Base, fixtures.ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - employees = relationship('Person') + __tablename__ = "companies" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + employees = relationship("Person") class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - company_id = Column('company_id', Integer, - ForeignKey('companies.id')) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + company_id = Column( + "company_id", Integer, ForeignKey("companies.id") + ) + name = Column("name", String(50)) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) - primary_language = Column('primary_language', String(50)) + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} + id = Column( + "id", Integer, ForeignKey("people.id"), primary_key=True + ) + primary_language = Column("primary_language", String(50)) class Manager(Person): - __tablename__ = 'managers' - __mapper_args__ = {'polymorphic_identity': 'manager'} - id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) - golf_swing = Column('golf_swing', String(50)) + __tablename__ = "managers" + __mapper_args__ = {"polymorphic_identity": "manager"} + id = Column( + "id", Integer, ForeignKey("people.id"), primary_key=True + ) + golf_swing = Column("golf_swing", String(50)) Base.metadata.create_all() sess = create_session() c1 = Company( - name='MegaCorp, Inc.', + name="MegaCorp, Inc.", employees=[ - Engineer(name='dilbert', primary_language='java'), - Engineer(name='wally', primary_language='c++'), - Manager(name='dogbert', golf_swing='fore!')]) + Engineer(name="dilbert", primary_language="java"), + Engineer(name="wally", primary_language="c++"), + Manager(name="dogbert", golf_swing="fore!"), + ], + ) - c2 = Company(name='Elbonia, Inc.', - employees=[Engineer(name='vlad', - primary_language='cobol')]) + c2 = Company( + name="Elbonia, Inc.", + employees=[Engineer(name="vlad", primary_language="cobol")], + ) sess.add(c1) sess.add(c2) sess.flush() sess.expunge_all() - eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language - == 'cobol')).first(), c2) + eq_( + sess.query(Company) + .filter( + Company.employees.of_type(Engineer).any( + Engineer.primary_language == "cobol" + ) + ) + .first(), + c2, + ) # ensure that the Manager mapper was compiled with the Manager id # column as higher priority. 
this ensures that "Manager.id" @@ -151,7 +182,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): eq_( Manager.id.property.columns, - [Manager.__table__.c.id, Person.__table__.c.id] + [Manager.__table__.c.id, Person.__table__.c.id], ) # assert that the "id" column is available without a second @@ -161,120 +192,130 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): sess.expunge_all() def go(): - assert sess.query(Manager).filter(Manager.name == 'dogbert' - ).one().id + assert ( + sess.query(Manager).filter(Manager.name == "dogbert").one().id + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - assert sess.query(Person).filter(Manager.name == 'dogbert' - ).one().id + assert ( + sess.query(Person).filter(Manager.name == "dogbert").one().id + ) self.assert_sql_count(testing.db, go, 1) def test_add_subcol_after_the_fact(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} + id = Column( + "id", Integer, ForeignKey("people.id"), primary_key=True + ) - Engineer.primary_language = Column('primary_language', - String(50)) + Engineer.primary_language = Column("primary_language", String(50)) Base.metadata.create_all() sess = create_session() - e1 = Engineer(primary_language='java', name='dilbert') + e1 = Engineer(primary_language="java", name="dilbert") sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Person).first(), - Engineer(primary_language='java', name='dilbert')) + eq_( + sess.query(Person).first(), + Engineer(primary_language="java", name="dilbert"), + ) def test_add_parentcol_after_the_fact(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} primary_language = Column(String(50)) - id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) + id = Column( + "id", Integer, ForeignKey("people.id"), primary_key=True + ) - Person.name = Column('name', String(50)) + Person.name = Column("name", String(50)) Base.metadata.create_all() sess = create_session() - e1 = Engineer(primary_language='java', name='dilbert') + e1 = Engineer(primary_language="java", name="dilbert") sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Person).first(), - 
Engineer(primary_language='java', name='dilbert')) + eq_( + sess.query(Person).first(), + Engineer(primary_language="java", name="dilbert"), + ) def test_add_sub_parentcol_after_the_fact(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} primary_language = Column(String(50)) - id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) + id = Column( + "id", Integer, ForeignKey("people.id"), primary_key=True + ) class Admin(Engineer): - __tablename__ = 'admins' - __mapper_args__ = {'polymorphic_identity': 'admin'} + __tablename__ = "admins" + __mapper_args__ = {"polymorphic_identity": "admin"} workstation = Column(String(50)) - id = Column('id', Integer, ForeignKey('engineers.id'), - primary_key=True) + id = Column( + "id", Integer, ForeignKey("engineers.id"), primary_key=True + ) - Person.name = Column('name', String(50)) + Person.name = Column("name", String(50)) Base.metadata.create_all() sess = create_session() - e1 = Admin(primary_language='java', name='dilbert', - workstation='foo') + e1 = Admin(primary_language="java", name="dilbert", workstation="foo") sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Person).first(), - Admin(primary_language='java', name='dilbert', workstation='foo')) + eq_( + sess.query(Person).first(), + Admin(primary_language="java", name="dilbert", workstation="foo"), + ) def test_subclass_mixin(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class MyMixin(object): @@ -282,11 +323,12 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Engineer(MyMixin, Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column('id', Integer, ForeignKey('people.id'), - primary_key=True) - primary_language = Column('primary_language', String(50)) + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} + id = Column( + "id", Integer, ForeignKey("people.id"), primary_key=True + ) + primary_language = Column("primary_language", String(50)) assert class_mapper(Engineer).inherits is class_mapper(Person) @@ -294,69 +336,88 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Person(object): pass - person_table = Table('people', Base.metadata, - Column('id', Integer, primary_key=True), - Column('kind', String(50))) + person_table = Table( + "people", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("kind", String(50)), + ) - mapper(Person, person_table, - polymorphic_on='kind', polymorphic_identity='person') + 
mapper( + Person, + person_table, + polymorphic_on="kind", + polymorphic_identity="person", + ) class SpecialPerson(Person): __abstract__ = True class Manager(SpecialPerson, Base): - __tablename__ = 'managers' + __tablename__ = "managers" id = Column(Integer, ForeignKey(Person.id), primary_key=True) - __mapper_args__ = { - 'polymorphic_identity': 'manager' - } + __mapper_args__ = {"polymorphic_identity": "manager"} from sqlalchemy import inspect + assert inspect(Manager).inherits is inspect(Person) - eq_(set(class_mapper(Person).class_manager), {'id', 'kind'}) - eq_(set(class_mapper(Manager).class_manager), {'id', 'kind'}) + eq_(set(class_mapper(Person).class_manager), {"id", "kind"}) + eq_(set(class_mapper(Manager).class_manager), {"id", "kind"}) def test_intermediate_unmapped_class_on_classical(self): class Person(object): pass - person_table = Table('people', Base.metadata, - Column('id', Integer, primary_key=True), - Column('kind', String(50))) + person_table = Table( + "people", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("kind", String(50)), + ) - mapper(Person, person_table, - polymorphic_on='kind', polymorphic_identity='person') + mapper( + Person, + person_table, + polymorphic_on="kind", + polymorphic_identity="person", + ) class SpecialPerson(Person): pass class Manager(SpecialPerson, Base): - __tablename__ = 'managers' + __tablename__ = "managers" id = Column(Integer, ForeignKey(Person.id), primary_key=True) - __mapper_args__ = { - 'polymorphic_identity': 'manager' - } + __mapper_args__ = {"polymorphic_identity": "manager"} from sqlalchemy import inspect + assert inspect(Manager).inherits is inspect(Person) - eq_(set(class_mapper(Person).class_manager), {'id', 'kind'}) - eq_(set(class_mapper(Manager).class_manager), {'id', 'kind'}) + eq_(set(class_mapper(Person).class_manager), {"id", "kind"}) + eq_(set(class_mapper(Manager).class_manager), {"id", "kind"}) def test_class_w_invalid_multiple_bases(self): class Person(object): pass - person_table = Table('people', Base.metadata, - Column('id', Integer, primary_key=True), - Column('kind', String(50))) + person_table = Table( + "people", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("kind", String(50)), + ) - mapper(Person, person_table, - polymorphic_on='kind', polymorphic_identity='person') + mapper( + Person, + person_table, + polymorphic_on="kind", + polymorphic_identity="person", + ) class DeclPerson(Base): - __tablename__ = 'decl_people' + __tablename__ = "decl_people" id = Column(Integer, primary_key=True) kind = Column(String(50)) @@ -365,48 +426,47 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def go(): class Manager(SpecialPerson, DeclPerson): - __tablename__ = 'managers' - id = Column(Integer, - ForeignKey(DeclPerson.id), primary_key=True) - __mapper_args__ = { - 'polymorphic_identity': 'manager' - } + __tablename__ = "managers" + id = Column( + Integer, ForeignKey(DeclPerson.id), primary_key=True + ) + __mapper_args__ = {"polymorphic_identity": "manager"} assert_raises_message( sa.exc.InvalidRequestError, r"Class .*Manager.* has multiple mapped " r"bases: \[.*Person.*DeclPerson.*\]", - go + go, ) def test_with_undefined_foreignkey(self): - class Parent(Base): - __tablename__ = 'parent' - id = Column('id', Integer, primary_key=True) - tp = Column('type', String(50)) + __tablename__ = "parent" + id = Column("id", Integer, primary_key=True) + tp = Column("type", String(50)) __mapper_args__ = dict(polymorphic_on=tp) class Child1(Parent): - __tablename__ = 'child1' - 
id = Column('id', Integer, ForeignKey('parent.id'), - primary_key=True) - related_child2 = Column('c2', Integer, - ForeignKey('child2.id')) - __mapper_args__ = dict(polymorphic_identity='child1') + __tablename__ = "child1" + id = Column( + "id", Integer, ForeignKey("parent.id"), primary_key=True + ) + related_child2 = Column("c2", Integer, ForeignKey("child2.id")) + __mapper_args__ = dict(polymorphic_identity="child1") # no exception is raised by the ForeignKey to "child2" even # though child2 doesn't exist yet class Child2(Parent): - __tablename__ = 'child2' - id = Column('id', Integer, ForeignKey('parent.id'), - primary_key=True) - related_child1 = Column('c1', Integer) - __mapper_args__ = dict(polymorphic_identity='child2') + __tablename__ = "child2" + id = Column( + "id", Integer, ForeignKey("parent.id"), primary_key=True + ) + related_child1 = Column("c1", Integer) + __mapper_args__ = dict(polymorphic_identity="child2") sa.orm.configure_mappers() # no exceptions here @@ -419,31 +479,29 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): """ class Booking(Base): - __tablename__ = 'booking' + __tablename__ = "booking" id = Column(Integer, primary_key=True) class PlanBooking(Booking): - __tablename__ = 'plan_booking' - id = Column(Integer, ForeignKey(Booking.id), - primary_key=True) + __tablename__ = "plan_booking" + id = Column(Integer, ForeignKey(Booking.id), primary_key=True) # referencing PlanBooking.id gives us the column # on plan_booking, not booking class FeatureBooking(Booking): - __tablename__ = 'feature_booking' - id = Column(Integer, ForeignKey(Booking.id), - primary_key=True) - plan_booking_id = Column(Integer, - ForeignKey(PlanBooking.id)) + __tablename__ = "feature_booking" + id = Column(Integer, ForeignKey(Booking.id), primary_key=True) + plan_booking_id = Column(Integer, ForeignKey(PlanBooking.id)) - plan_booking = relationship(PlanBooking, - backref='feature_bookings') + plan_booking = relationship( + PlanBooking, backref="feature_bookings" + ) - assert FeatureBooking.__table__.c.plan_booking_id.\ - references(PlanBooking.__table__.c.id) + assert FeatureBooking.__table__.c.plan_booking_id.references( + PlanBooking.__table__.c.id + ) - assert FeatureBooking.__table__.c.id.\ - references(Booking.__table__.c.id) + assert FeatureBooking.__table__.c.id.references(Booking.__table__.c.id) def test_single_colsonbase(self): """test single inheritance where all the columns are on the base @@ -451,55 +509,71 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Company(Base, fixtures.ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - employees = relationship('Person') + __tablename__ = "companies" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + employees = relationship("Person") class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - company_id = Column('company_id', Integer, - ForeignKey('companies.id')) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - primary_language = Column('primary_language', String(50)) - golf_swing = Column('golf_swing', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + company_id = Column( + 
"company_id", Integer, ForeignKey("companies.id") + ) + name = Column("name", String(50)) + discriminator = Column("type", String(50)) + primary_language = Column("primary_language", String(50)) + golf_swing = Column("golf_swing", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __mapper_args__ = {"polymorphic_identity": "engineer"} class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} + __mapper_args__ = {"polymorphic_identity": "manager"} Base.metadata.create_all() sess = create_session() c1 = Company( - name='MegaCorp, Inc.', + name="MegaCorp, Inc.", employees=[ - Engineer(name='dilbert', primary_language='java'), - Engineer(name='wally', primary_language='c++'), - Manager(name='dogbert', golf_swing='fore!')]) + Engineer(name="dilbert", primary_language="java"), + Engineer(name="wally", primary_language="c++"), + Manager(name="dogbert", golf_swing="fore!"), + ], + ) - c2 = Company(name='Elbonia, Inc.', - employees=[Engineer(name='vlad', - primary_language='cobol')]) + c2 = Company( + name="Elbonia, Inc.", + employees=[Engineer(name="vlad", primary_language="cobol")], + ) sess.add(c1) sess.add(c2) sess.flush() sess.expunge_all() - eq_(sess.query(Person).filter(Engineer.primary_language - == 'cobol').first(), - Engineer(name='vlad')) - eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language - == 'cobol')).first(), c2) + eq_( + sess.query(Person) + .filter(Engineer.primary_language == "cobol") + .first(), + Engineer(name="vlad"), + ) + eq_( + sess.query(Company) + .filter( + Company.employees.of_type(Engineer).any( + Engineer.primary_language == "cobol" + ) + ) + .first(), + c2, + ) def test_single_colsonsub(self): """test single inheritance where the columns are local to their @@ -511,30 +585,32 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Company(Base, fixtures.ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - employees = relationship('Person') + __tablename__ = "companies" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + employees = relationship("Person") class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - company_id = Column(Integer, ForeignKey('companies.id')) + __tablename__ = "people" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + company_id = Column(Integer, ForeignKey("companies.id")) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __mapper_args__ = {"polymorphic_identity": "engineer"} primary_language = Column(String(50)) class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} + __mapper_args__ = {"polymorphic_identity": "manager"} golf_swing = Column(String(50)) # we have here a situation that is somewhat unique. 
the Person @@ -549,31 +625,42 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): assert Person.__table__.c.primary_language is not None assert Engineer.primary_language is not None assert Manager.golf_swing is not None - assert not hasattr(Person, 'primary_language') - assert not hasattr(Person, 'golf_swing') - assert not hasattr(Engineer, 'golf_swing') - assert not hasattr(Manager, 'primary_language') + assert not hasattr(Person, "primary_language") + assert not hasattr(Person, "golf_swing") + assert not hasattr(Engineer, "golf_swing") + assert not hasattr(Manager, "primary_language") Base.metadata.create_all() sess = create_session() - e1 = Engineer(name='dilbert', primary_language='java') - e2 = Engineer(name='wally', primary_language='c++') - m1 = Manager(name='dogbert', golf_swing='fore!') - c1 = Company(name='MegaCorp, Inc.', employees=[e1, e2, m1]) - e3 = Engineer(name='vlad', primary_language='cobol') - c2 = Company(name='Elbonia, Inc.', employees=[e3]) + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++") + m1 = Manager(name="dogbert", golf_swing="fore!") + c1 = Company(name="MegaCorp, Inc.", employees=[e1, e2, m1]) + e3 = Engineer(name="vlad", primary_language="cobol") + c2 = Company(name="Elbonia, Inc.", employees=[e3]) sess.add(c1) sess.add(c2) sess.flush() sess.expunge_all() - eq_(sess.query(Person).filter(Engineer.primary_language - == 'cobol').first(), - Engineer(name='vlad')) - eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language - == 'cobol')).first(), c2) - eq_(sess.query(Engineer).filter_by(primary_language='cobol' - ).one(), - Engineer(name='vlad', primary_language='cobol')) + eq_( + sess.query(Person) + .filter(Engineer.primary_language == "cobol") + .first(), + Engineer(name="vlad"), + ) + eq_( + sess.query(Company) + .filter( + Company.employees.of_type(Engineer).any( + Engineer.primary_language == "cobol" + ) + ) + .first(), + c2, + ) + eq_( + sess.query(Engineer).filter_by(primary_language="cobol").one(), + Engineer(name="vlad", primary_language="cobol"), + ) def test_single_cols_on_sub_base_of_joined(self): """test [ticket:3895]""" @@ -584,16 +671,12 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): id = Column(Integer, primary_key=True) type = Column(String) - __mapper_args__ = { - "polymorphic_on": type, - } + __mapper_args__ = {"polymorphic_on": type} class Contractor(Person): contractor_field = Column(String) - __mapper_args__ = { - "polymorphic_identity": "contractor", - } + __mapper_args__ = {"polymorphic_identity": "contractor"} class Employee(Person): __tablename__ = "employee" @@ -601,104 +684,115 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): id = Column(Integer, ForeignKey(Person.id), primary_key=True) class Engineer(Employee): - __mapper_args__ = { - "polymorphic_identity": "engineer", - } + __mapper_args__ = {"polymorphic_identity": "engineer"} configure_mappers() - is_false(hasattr(Person, 'contractor_field')) - is_true(hasattr(Contractor, 'contractor_field')) - is_false(hasattr(Employee, 'contractor_field')) - is_false(hasattr(Engineer, 'contractor_field')) + is_false(hasattr(Person, "contractor_field")) + is_true(hasattr(Contractor, "contractor_field")) + is_false(hasattr(Employee, "contractor_field")) + is_false(hasattr(Engineer, "contractor_field")) def test_single_cols_on_sub_to_joined(self): """test [ticket:3797]""" class BaseUser(Base): - __tablename__ = 'root' + __tablename__ = "root" id = Column(Integer, 
primary_key=True) row_type = Column(String) __mapper_args__ = { - 'polymorphic_on': row_type, - 'polymorphic_identity': 'baseuser' + "polymorphic_on": row_type, + "polymorphic_identity": "baseuser", } class User(BaseUser): - __tablename__ = 'user' + __tablename__ = "user" - __mapper_args__ = { - 'polymorphic_identity': 'user' - } + __mapper_args__ = {"polymorphic_identity": "user"} baseuser_id = Column( - Integer, ForeignKey('root.id'), primary_key=True) + Integer, ForeignKey("root.id"), primary_key=True + ) class Bat(Base): - __tablename__ = 'bat' + __tablename__ = "bat" id = Column(Integer, primary_key=True) class Thing(Base): - __tablename__ = 'thing' + __tablename__ = "thing" id = Column(Integer, primary_key=True) - owner_id = Column(Integer, ForeignKey('user.baseuser_id')) - owner = relationship('User') + owner_id = Column(Integer, ForeignKey("user.baseuser_id")) + owner = relationship("User") class SubUser(User): - __mapper_args__ = { - 'polymorphic_identity': 'subuser' - } + __mapper_args__ = {"polymorphic_identity": "subuser"} - sub_user_custom_thing = Column(Integer, ForeignKey('bat.id')) + sub_user_custom_thing = Column(Integer, ForeignKey("bat.id")) eq_( User.__table__.foreign_keys, User.baseuser_id.foreign_keys.union( - SubUser.sub_user_custom_thing.foreign_keys)) - is_true(Thing.owner.property.primaryjoin.compare( - Thing.owner_id == User.baseuser_id)) + SubUser.sub_user_custom_thing.foreign_keys + ), + ) + is_true( + Thing.owner.property.primaryjoin.compare( + Thing.owner_id == User.baseuser_id + ) + ) def test_single_constraint_on_sub(self): """test the somewhat unusual case of [ticket:3341]""" class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "people" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __mapper_args__ = {"polymorphic_identity": "engineer"} primary_language = Column(String(50)) __hack_args_one__ = sa.UniqueConstraint( - Person.name, primary_language) + Person.name, primary_language + ) __hack_args_two__ = sa.CheckConstraint( - Person.name != primary_language) - - uq = [c for c in Person.__table__.constraints - if isinstance(c, sa.UniqueConstraint)][0] - ck = [c for c in Person.__table__.constraints - if isinstance(c, sa.CheckConstraint)][0] + Person.name != primary_language + ) + + uq = [ + c + for c in Person.__table__.constraints + if isinstance(c, sa.UniqueConstraint) + ][0] + ck = [ + c + for c in Person.__table__.constraints + if isinstance(c, sa.CheckConstraint) + ][0] eq_( list(uq.columns), - [Person.__table__.c.name, Person.__table__.c.primary_language] + [Person.__table__.c.name, Person.__table__.c.primary_language], ) eq_( list(ck.columns), - [Person.__table__.c.name, Person.__table__.c.primary_language] + [Person.__table__.c.name, Person.__table__.c.primary_language], ) - @testing.skip_if(lambda: testing.against('oracle'), - "Test has an empty insert in it at the moment") + @testing.skip_if( + lambda: testing.against("oracle"), + "Test has an empty insert in it at the moment", + ) def test_columns_single_inheritance_conflict_resolution(self): """Test that a declared_attr can return the existing column and 
it will be ignored. this allows conditional columns to be added. @@ -706,8 +800,9 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): See [ticket:2472]. """ + class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) class Engineer(Person): @@ -717,8 +812,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): @declared_attr def target_id(cls): return cls.__table__.c.get( - 'target_id', - Column(Integer, ForeignKey('other.id'))) + "target_id", Column(Integer, ForeignKey("other.id")) + ) @declared_attr def target(cls): @@ -731,34 +826,31 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): @declared_attr def target_id(cls): return cls.__table__.c.get( - 'target_id', - Column(Integer, ForeignKey('other.id'))) + "target_id", Column(Integer, ForeignKey("other.id")) + ) @declared_attr def target(cls): return relationship("Other") class Other(Base): - __tablename__ = 'other' + __tablename__ = "other" id = Column(Integer, primary_key=True) is_( Engineer.target_id.property.columns[0], - Person.__table__.c.target_id + Person.__table__.c.target_id, ) is_( - Manager.target_id.property.columns[0], - Person.__table__.c.target_id + Manager.target_id.property.columns[0], Person.__table__.c.target_id ) # do a brief round trip on this Base.metadata.create_all() session = Session() o1, o2 = Other(), Other() - session.add_all([ - Engineer(target=o1), - Manager(target=o2), - Manager(target=o1) - ]) + session.add_all( + [Engineer(target=o1), Manager(target=o2), Manager(target=o1)] + ) session.commit() eq_(session.query(Engineer).first().target, o1) @@ -767,8 +859,9 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): #4352. """ + class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) target_id = Column(Integer, primary_key=True) @@ -780,8 +873,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): @declared_attr def target_id(cls): return cls.__table__.c.get( - 'target_id', - Column(Integer, primary_key=True)) + "target_id", Column(Integer, primary_key=True) + ) class Manager(Person): @@ -790,22 +883,22 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): @declared_attr def target_id(cls): return cls.__table__.c.get( - 'target_id', - Column(Integer, primary_key=True)) + "target_id", Column(Integer, primary_key=True) + ) is_( Engineer.target_id.property.columns[0], - Person.__table__.c.target_id + Person.__table__.c.target_id, ) is_( - Manager.target_id.property.columns[0], - Person.__table__.c.target_id + Manager.target_id.property.columns[0], Person.__table__.c.target_id ) def test_columns_single_inheritance_cascading_resolution_pk(self): """An additional test for #4352 in terms of the requested use case. 
""" + class TestBase(Base): __abstract__ = True @@ -813,14 +906,15 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def id(cls): col_val = None if TestBase not in cls.__bases__: - col_val = cls.__table__.c.get('id') + col_val = cls.__table__.c.get("id") if col_val is None: col_val = Column(Integer, primary_key=True) return col_val class Person(TestBase): """single table base class""" - __tablename__ = 'person' + + __tablename__ = "person" class Engineer(Person): """ single table inheritance, no extra cols """ @@ -832,130 +926,143 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): is_(Manager.id.property.columns[0], Person.__table__.c.id) def test_joined_from_single(self): - class Company(Base, fixtures.ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) - name = Column('name', String(50)) - employees = relationship('Person') + __tablename__ = "companies" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column("name", String(50)) + employees = relationship("Person") class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - company_id = Column(Integer, ForeignKey('companies.id')) + __tablename__ = "people" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + company_id = Column(Integer, ForeignKey("companies.id")) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} + __mapper_args__ = {"polymorphic_identity": "manager"} golf_swing = Column(String(50)) class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column(Integer, ForeignKey('people.id'), - primary_key=True) + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} + id = Column(Integer, ForeignKey("people.id"), primary_key=True) primary_language = Column(String(50)) assert Person.__table__.c.golf_swing is not None - assert 'primary_language' not in Person.__table__.c + assert "primary_language" not in Person.__table__.c assert Engineer.__table__.c.primary_language is not None assert Engineer.primary_language is not None assert Manager.golf_swing is not None - assert not hasattr(Person, 'primary_language') - assert not hasattr(Person, 'golf_swing') - assert not hasattr(Engineer, 'golf_swing') - assert not hasattr(Manager, 'primary_language') + assert not hasattr(Person, "primary_language") + assert not hasattr(Person, "golf_swing") + assert not hasattr(Engineer, "golf_swing") + assert not hasattr(Manager, "primary_language") Base.metadata.create_all() sess = create_session() - e1 = Engineer(name='dilbert', primary_language='java') - e2 = Engineer(name='wally', primary_language='c++') - m1 = Manager(name='dogbert', golf_swing='fore!') - c1 = Company(name='MegaCorp, Inc.', employees=[e1, e2, m1]) - e3 = Engineer(name='vlad', primary_language='cobol') - c2 = Company(name='Elbonia, Inc.', employees=[e3]) + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++") + m1 = Manager(name="dogbert", golf_swing="fore!") + c1 = Company(name="MegaCorp, Inc.", employees=[e1, e2, m1]) + e3 = 
Engineer(name="vlad", primary_language="cobol") + c2 = Company(name="Elbonia, Inc.", employees=[e3]) sess.add(c1) sess.add(c2) sess.flush() sess.expunge_all() - eq_(sess.query(Person).with_polymorphic(Engineer). - filter(Engineer.primary_language - == 'cobol').first(), Engineer(name='vlad')) - eq_(sess.query(Company).filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language - == 'cobol')).first(), c2) - eq_(sess.query(Engineer).filter_by(primary_language='cobol' - ).one(), - Engineer(name='vlad', primary_language='cobol')) + eq_( + sess.query(Person) + .with_polymorphic(Engineer) + .filter(Engineer.primary_language == "cobol") + .first(), + Engineer(name="vlad"), + ) + eq_( + sess.query(Company) + .filter( + Company.employees.of_type(Engineer).any( + Engineer.primary_language == "cobol" + ) + ) + .first(), + c2, + ) + eq_( + sess.query(Engineer).filter_by(primary_language="cobol").one(), + Engineer(name="vlad", primary_language="cobol"), + ) def test_single_from_joined_colsonsub(self): class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "people" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Manager(Person): - __tablename__ = 'manager' - __mapper_args__ = {'polymorphic_identity': 'manager'} - id = Column(Integer, ForeignKey('people.id'), primary_key=True) + __tablename__ = "manager" + __mapper_args__ = {"polymorphic_identity": "manager"} + id = Column(Integer, ForeignKey("people.id"), primary_key=True) golf_swing = Column(String(50)) class Boss(Manager): boss_name = Column(String(50)) is_( - Boss.__mapper__.column_attrs['boss_name'].columns[0], - Manager.__table__.c.boss_name + Boss.__mapper__.column_attrs["boss_name"].columns[0], + Manager.__table__.c.boss_name, ) def test_polymorphic_on_converted_from_inst(self): class A(Base): - __tablename__ = 'A' + __tablename__ = "A" id = Column(Integer, primary_key=True) discriminator = Column(String) @declared_attr def __mapper_args__(cls): return { - 'polymorphic_identity': cls.__name__, - 'polymorphic_on': cls.discriminator + "polymorphic_identity": cls.__name__, + "polymorphic_on": cls.discriminator, } class B(A): pass + is_(B.__mapper__.polymorphic_on, A.__table__.c.discriminator) def test_add_deferred(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "people" + id = Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ) Person.name = deferred(Column(String(10))) Base.metadata.create_all() sess = create_session() - p = Person(name='ratbert') + p = Person(name="ratbert") sess.add(p) sess.flush() sess.expunge_all() - eq_(sess.query(Person).all(), [Person(name='ratbert')]) + eq_(sess.query(Person).all(), [Person(name="ratbert")]) sess.expunge_all() - person = sess.query(Person).filter(Person.name == 'ratbert' - ).one() - assert 'name' not in person.__dict__ + person = sess.query(Person).filter(Person.name == "ratbert").one() + assert "name" not in person.__dict__ def test_single_fksonsub(self): """test single inheritance with a foreign key-holding column on @@ -965,78 +1072,90 @@ class 
DeclarativeInheritanceTest(DeclarativeTestBase): class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "people" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - primary_language_id = Column(Integer, - ForeignKey('languages.id')) - primary_language = relationship('Language') + __mapper_args__ = {"polymorphic_identity": "engineer"} + primary_language_id = Column(Integer, ForeignKey("languages.id")) + primary_language = relationship("Language") class Language(Base, fixtures.ComparableEntity): - __tablename__ = 'languages' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "languages" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) - assert not hasattr(Person, 'primary_language_id') + assert not hasattr(Person, "primary_language_id") Base.metadata.create_all() sess = create_session() - java, cpp, cobol = Language(name='java'), Language(name='cpp'), \ - Language(name='cobol') - e1 = Engineer(name='dilbert', primary_language=java) - e2 = Engineer(name='wally', primary_language=cpp) - e3 = Engineer(name='vlad', primary_language=cobol) + java, cpp, cobol = ( + Language(name="java"), + Language(name="cpp"), + Language(name="cobol"), + ) + e1 = Engineer(name="dilbert", primary_language=java) + e2 = Engineer(name="wally", primary_language=cpp) + e3 = Engineer(name="vlad", primary_language=cobol) sess.add_all([e1, e2, e3]) sess.flush() sess.expunge_all() - eq_(sess.query(Person).filter(Engineer.primary_language.has( - Language.name - == 'cobol')).first(), - Engineer(name='vlad', primary_language=Language(name='cobol'))) - eq_(sess.query(Engineer).filter(Engineer.primary_language.has( - Language.name - == 'cobol')).one(), - Engineer(name='vlad', primary_language=Language(name='cobol'))) - eq_(sess.query(Person).join(Engineer.primary_language).order_by( - Language.name).all(), - [Engineer(name='vlad', - primary_language=Language(name='cobol')), - Engineer(name='wally', primary_language=Language(name='cpp' - )), - Engineer(name='dilbert', primary_language=Language(name='java'))]) + eq_( + sess.query(Person) + .filter(Engineer.primary_language.has(Language.name == "cobol")) + .first(), + Engineer(name="vlad", primary_language=Language(name="cobol")), + ) + eq_( + sess.query(Engineer) + .filter(Engineer.primary_language.has(Language.name == "cobol")) + .one(), + Engineer(name="vlad", primary_language=Language(name="cobol")), + ) + eq_( + sess.query(Person) + .join(Engineer.primary_language) + .order_by(Language.name) + .all(), + [ + Engineer(name="vlad", primary_language=Language(name="cobol")), + Engineer(name="wally", primary_language=Language(name="cpp")), + Engineer( + name="dilbert", primary_language=Language(name="java") + ), + ], + ) def test_single_three_levels(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", 
String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __mapper_args__ = {"polymorphic_identity": "engineer"} primary_language = Column(String(50)) class JuniorEngineer(Engineer): - __mapper_args__ = \ - {'polymorphic_identity': 'junior_engineer'} + __mapper_args__ = {"polymorphic_identity": "junior_engineer"} nerf_gun = Column(String(50)) class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} + __mapper_args__ = {"polymorphic_identity": "manager"} golf_swing = Column(String(50)) assert JuniorEngineer.nerf_gun @@ -1044,33 +1163,31 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): assert JuniorEngineer.name assert Manager.golf_swing assert Engineer.primary_language - assert not hasattr(Engineer, 'golf_swing') - assert not hasattr(Engineer, 'nerf_gun') - assert not hasattr(Manager, 'nerf_gun') - assert not hasattr(Manager, 'primary_language') + assert not hasattr(Engineer, "golf_swing") + assert not hasattr(Engineer, "nerf_gun") + assert not hasattr(Manager, "nerf_gun") + assert not hasattr(Manager, "primary_language") def test_single_detects_conflict(self): - class Person(Base): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} + __mapper_args__ = {"polymorphic_identity": "engineer"} primary_language = Column(String(50)) # test sibling col conflict def go(): - class Manager(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} + __mapper_args__ = {"polymorphic_identity": "manager"} golf_swing = Column(String(50)) primary_language = Column(String(50)) @@ -1079,72 +1196,64 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): # test parent col conflict def go(): - class Salesman(Person): - __mapper_args__ = {'polymorphic_identity': 'manager'} + __mapper_args__ = {"polymorphic_identity": "manager"} name = Column(String(50)) assert_raises(sa.exc.ArgumentError, go) def test_single_no_special_cols(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + __tablename__ = "people" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} def go(): - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - primary_language = Column('primary_language', - String(50)) + __mapper_args__ = {"polymorphic_identity": "engineer"} + primary_language = Column("primary_language", String(50)) foo_bar = Column(Integer, primary_key=True) - assert_raises_message(sa.exc.ArgumentError, - 'place primary key', go) + assert_raises_message(sa.exc.ArgumentError, "place primary key", go) def test_single_no_table_args(self): - class Person(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + 
__tablename__ = "people" + id = Column("id", Integer, primary_key=True) + name = Column("name", String(50)) + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} def go(): - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity': 'engineer'} - primary_language = Column('primary_language', - String(50)) + __mapper_args__ = {"polymorphic_identity": "engineer"} + primary_language = Column("primary_language", String(50)) # this should be on the Person class, as this is single # table inheritance, which is why we test that this # throws an exception! - __table_args__ = {'mysql_engine': 'InnoDB'} + __table_args__ = {"mysql_engine": "InnoDB"} - assert_raises_message(sa.exc.ArgumentError, - 'place __table_args__', go) + assert_raises_message(sa.exc.ArgumentError, "place __table_args__", go) @testing.emits_warning("This declarative") def test_dupe_name_in_hierarchy(self): class A(Base): __tablename__ = "a" id = Column(Integer, primary_key=True) + a_1 = A class A(a_1): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer(), ForeignKey(a_1.id), primary_key=True) assert A.__mapper__.inherits is a_1.__mapper__ @@ -1155,68 +1264,76 @@ class OverlapColPrecedenceTest(DeclarativeTestBase): """test #1892 cases when declarative does column precedence.""" def _run_test(self, Engineer, e_id, p_id): - p_table = Base.metadata.tables['person'] - e_table = Base.metadata.tables['engineer'] + p_table = Base.metadata.tables["person"] + e_table = Base.metadata.tables["engineer"] assert Engineer.id.property.columns[0] is e_table.c[e_id] assert Engineer.id.property.columns[1] is p_table.c[p_id] def test_basic(self): class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) class Engineer(Person): - __tablename__ = 'engineer' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) self._run_test(Engineer, "id", "id") def test_alt_name_base(self): class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column("pid", Integer, primary_key=True) class Engineer(Person): - __tablename__ = 'engineer' - id = Column(Integer, ForeignKey('person.pid'), primary_key=True) + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("person.pid"), primary_key=True) self._run_test(Engineer, "id", "pid") def test_alt_name_sub(self): class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) class Engineer(Person): - __tablename__ = 'engineer' - id = Column("eid", Integer, ForeignKey('person.id'), - primary_key=True) + __tablename__ = "engineer" + id = Column( + "eid", Integer, ForeignKey("person.id"), primary_key=True + ) self._run_test(Engineer, "eid", "id") def test_alt_name_both(self): class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column("pid", Integer, primary_key=True) class Engineer(Person): - __tablename__ = 'engineer' - id = Column("eid", Integer, ForeignKey('person.pid'), - primary_key=True) + __tablename__ = "engineer" + id = Column( + "eid", Integer, ForeignKey("person.pid"), primary_key=True + ) self._run_test(Engineer, "eid", "pid") class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): - - def _roundtrip(self, Employee, Manager, Engineer, Boss, - polymorphic=True, explicit_type=False): + def _roundtrip( + self, + Employee, + Manager, + Engineer, + Boss, + polymorphic=True, + 
explicit_type=False, + ): Base.metadata.create_all() sess = create_session() - e1 = Engineer(name='dilbert', primary_language='java') - e2 = Engineer(name='wally', primary_language='c++') - m1 = Manager(name='dogbert', golf_swing='fore!') - e3 = Engineer(name='vlad', primary_language='cobol') + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++") + m1 = Manager(name="dogbert", golf_swing="fore!") + e3 = Engineer(name="vlad", primary_language="cobol") b1 = Boss(name="pointy haired") if polymorphic: @@ -1228,7 +1345,9 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): AttributeError, "does not implement attribute .?'type' " "at the instance level.", - getattr, obj, "type" + getattr, + obj, + "type", ) else: assert "type" not in Engineer.__dict__ @@ -1239,66 +1358,93 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): sess.flush() sess.expunge_all() if polymorphic: - eq_(sess.query(Employee).order_by(Employee.name).all(), - [Engineer(name='dilbert'), Manager(name='dogbert'), - Boss(name='pointy haired'), - Engineer(name='vlad'), Engineer(name='wally')]) + eq_( + sess.query(Employee).order_by(Employee.name).all(), + [ + Engineer(name="dilbert"), + Manager(name="dogbert"), + Boss(name="pointy haired"), + Engineer(name="vlad"), + Engineer(name="wally"), + ], + ) else: - eq_(sess.query(Engineer).order_by(Engineer.name).all(), - [Engineer(name='dilbert'), Engineer(name='vlad'), - Engineer(name='wally')]) - eq_(sess.query(Manager).all(), [Manager(name='dogbert')]) - eq_(sess.query(Boss).all(), [Boss(name='pointy haired')]) + eq_( + sess.query(Engineer).order_by(Engineer.name).all(), + [ + Engineer(name="dilbert"), + Engineer(name="vlad"), + Engineer(name="wally"), + ], + ) + eq_(sess.query(Manager).all(), [Manager(name="dogbert")]) + eq_(sess.query(Boss).all(), [Boss(name="pointy haired")]) e1 = sess.query(Engineer).order_by(Engineer.name).first() sess.expire(e1) - eq_(e1.name, 'dilbert') + eq_(e1.name, "dilbert") def test_explicit(self): engineers = Table( - 'engineers', Base.metadata, - Column('id', - Integer, primary_key=True, test_needs_autoincrement=True), - Column('name', String(50)), - Column('primary_language', String(50))) - managers = Table('managers', Base.metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('golf_swing', String(50)) - ) - boss = Table('boss', Base.metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('golf_swing', String(50)) - ) - punion = polymorphic_union({ - 'engineer': engineers, - 'manager': managers, - 'boss': boss}, 'type', 'punion') + "engineers", + Base.metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("primary_language", String(50)), + ) + managers = Table( + "managers", + Base.metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("golf_swing", String(50)), + ) + boss = Table( + "boss", + Base.metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("golf_swing", String(50)), + ) + punion = polymorphic_union( + {"engineer": engineers, "manager": managers, "boss": boss}, + "type", + "punion", + ) class Employee(Base, fixtures.ComparableEntity): __table__ = punion - __mapper_args__ = {'polymorphic_on': punion.c.type} 
+ __mapper_args__ = {"polymorphic_on": punion.c.type} class Engineer(Employee): __table__ = engineers - __mapper_args__ = {'polymorphic_identity': 'engineer', - 'concrete': True} + __mapper_args__ = { + "polymorphic_identity": "engineer", + "concrete": True, + } class Manager(Employee): __table__ = managers - __mapper_args__ = {'polymorphic_identity': 'manager', - 'concrete': True} + __mapper_args__ = { + "polymorphic_identity": "manager", + "concrete": True, + } class Boss(Manager): __table__ = boss - __mapper_args__ = {'polymorphic_identity': 'boss', - 'concrete': True} + __mapper_args__ = { + "polymorphic_identity": "boss", + "concrete": True, + } self._roundtrip(Employee, Manager, Engineer, Boss) @@ -1307,34 +1453,38 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): class Employee(Base, fixtures.ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "people" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) class Engineer(Employee): - __tablename__ = 'engineers' - __mapper_args__ = {'concrete': True} - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "engineers" + __mapper_args__ = {"concrete": True} + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) primary_language = Column(String(50)) name = Column(String(50)) class Manager(Employee): - __tablename__ = 'manager' - __mapper_args__ = {'concrete': True} - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "manager" + __mapper_args__ = {"concrete": True} + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) golf_swing = Column(String(50)) name = Column(String(50)) class Boss(Manager): - __tablename__ = 'boss' - __mapper_args__ = {'concrete': True} - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "boss" + __mapper_args__ = {"concrete": True} + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) golf_swing = Column(String(50)) name = Column(String(50)) @@ -1345,33 +1495,40 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): pass class Manager(Employee): - __tablename__ = 'manager' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "manager" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'manager', - 'concrete': True} + "polymorphic_identity": "manager", + "concrete": True, + } class Boss(Manager): - __tablename__ = 'boss' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "boss" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'boss', - 'concrete': True} + "polymorphic_identity": "boss", + "concrete": True, + } class Engineer(Employee): - __tablename__ = 'engineer' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "engineer" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) primary_language = Column(String(40)) - __mapper_args__ = {'polymorphic_identity': 'engineer', - 'concrete': True} + 
__mapper_args__ = {
+                "polymorphic_identity": "engineer",
+                "concrete": True,
+            }
 
         self._roundtrip(Employee, Manager, Engineer, Boss)
 
@@ -1382,104 +1539,120 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase):
                 return Column(String(50))
 
         class Manager(Employee):
-            __tablename__ = 'manager'
-            employee_id = Column(Integer, primary_key=True,
-                                 test_needs_autoincrement=True)
+            __tablename__ = "manager"
+            employee_id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
             paperwork = Column(String(10))
             __mapper_args__ = {
-                'polymorphic_identity': 'manager', 'concrete': True}
+                "polymorphic_identity": "manager",
+                "concrete": True,
+            }
 
         class Engineer(Employee):
-            __tablename__ = 'engineer'
-            employee_id = Column(Integer, primary_key=True,
-                                 test_needs_autoincrement=True)
+            __tablename__ = "engineer"
+            employee_id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
 
             @property
             def paperwork(self):
                 return "p"
 
             __mapper_args__ = {
-                'polymorphic_identity': 'engineer', 'concrete': True}
+                "polymorphic_identity": "engineer",
+                "concrete": True,
+            }
 
         Base.metadata.create_all()
         sess = Session()
-        sess.add(Engineer(name='d'))
+        sess.add(Engineer(name="d"))
         sess.commit()
 
         # paperwork is excluded because there's a descriptor; so it is
         # not in the Engineer's mapped properties at all, though it is
        # inside the class manager.  Maybe it shouldn't be in the class
         # manager either.
-        assert 'paperwork' in Engineer.__mapper__.class_manager
-        assert 'paperwork' not in Engineer.__mapper__.attrs.keys()
+        assert "paperwork" in Engineer.__mapper__.class_manager
+        assert "paperwork" not in Engineer.__mapper__.attrs.keys()
 
         # type currently does get mapped, as a
         # ConcreteInheritedProperty, which means, "ignore this thing inherited
         # from the concrete base".  if we didn't specify concrete=True, then
         # this one gets stuck in the error condition also.
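         # (illustrative aside, not a line from this changeset: the same
         # split is visible through the mapper's attribute namespace used
         # just above, e.g.
         #     Engineer.__mapper__.attrs["type"]       -> ConcreteInheritedProperty
         #     Engineer.__mapper__.attrs["paperwork"]  -> raises KeyError,
        # since "paperwork" is not a mapped property)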
- assert 'type' in Engineer.__mapper__.class_manager - assert 'type' in Engineer.__mapper__.attrs.keys() + assert "type" in Engineer.__mapper__.class_manager + assert "type" in Engineer.__mapper__.attrs.keys() e1 = sess.query(Engineer).first() - eq_(e1.name, 'd') + eq_(e1.name, "d") sess.expire(e1) - eq_(e1.name, 'd') + eq_(e1.name, "d") def test_concrete_extension(self): class Employee(ConcreteBase, Base, fixtures.ComparableEntity): - __tablename__ = 'employee' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "employee" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) __mapper_args__ = { - 'polymorphic_identity': 'employee', - 'concrete': True} + "polymorphic_identity": "employee", + "concrete": True, + } class Manager(Employee): - __tablename__ = 'manager' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "manager" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'manager', - 'concrete': True} + "polymorphic_identity": "manager", + "concrete": True, + } class Boss(Manager): - __tablename__ = 'boss' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "boss" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) golf_swing = Column(String(40)) __mapper_args__ = { - 'polymorphic_identity': 'boss', - 'concrete': True} + "polymorphic_identity": "boss", + "concrete": True, + } class Engineer(Employee): - __tablename__ = 'engineer' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "engineer" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) primary_language = Column(String(40)) - __mapper_args__ = {'polymorphic_identity': 'engineer', - 'concrete': True} + __mapper_args__ = { + "polymorphic_identity": "engineer", + "concrete": True, + } + self._roundtrip(Employee, Manager, Engineer, Boss) def test_has_inherited_table_doesnt_consider_base(self): class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) assert not has_inherited_table(A) class B(A): - __tablename__ = 'b' - id = Column(Integer, ForeignKey('a.id'), primary_key=True) + __tablename__ = "b" + id = Column(Integer, ForeignKey("a.id"), primary_key=True) assert has_inherited_table(B) def test_has_inherited_table_in_mapper_args(self): class Test(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) type = Column(String(20)) @@ -1487,15 +1660,15 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): def __mapper_args__(cls): if not has_inherited_table(cls): ret = { - 'polymorphic_identity': 'default', - 'polymorphic_on': cls.type, + "polymorphic_identity": "default", + "polymorphic_on": cls.type, } else: - ret = {'polymorphic_identity': cls.__name__} + ret = {"polymorphic_identity": cls.__name__} return ret class PolyTest(Test): - __tablename__ = 'poly_test' + __tablename__ = "poly_test" id = Column(Integer, ForeignKey(Test.id), primary_key=True) configure_mappers() @@ -1508,9 +1681,10 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): pass class Manager(Employee): - __tablename__ = 'manager' - employee_id = 
Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "manager" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) golf_swing = Column(String(40)) @@ -1519,13 +1693,15 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): return "manager" __mapper_args__ = { - 'polymorphic_identity': "manager", - 'concrete': True} + "polymorphic_identity": "manager", + "concrete": True, + } class Boss(Manager): - __tablename__ = 'boss' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "boss" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) golf_swing = Column(String(40)) @@ -1534,94 +1710,116 @@ class ConcreteInhTest(_RemoveListeners, DeclarativeTestBase): return "boss" __mapper_args__ = { - 'polymorphic_identity': "boss", - 'concrete': True} + "polymorphic_identity": "boss", + "concrete": True, + } class Engineer(Employee): - __tablename__ = 'engineer' - employee_id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "engineer" + employee_id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) primary_language = Column(String(40)) @property def type(self): return "engineer" - __mapper_args__ = {'polymorphic_identity': "engineer", - 'concrete': True} + + __mapper_args__ = { + "polymorphic_identity": "engineer", + "concrete": True, + } + self._roundtrip(Employee, Manager, Engineer, Boss, explicit_type=True) class ConcreteExtensionConfigTest( - _RemoveListeners, testing.AssertsCompiledSQL, DeclarativeTestBase): - __dialect__ = 'default' + _RemoveListeners, testing.AssertsCompiledSQL, DeclarativeTestBase +): + __dialect__ = "default" def test_classreg_setup(self): class A(Base, fixtures.ComparableEntity): - __tablename__ = 'a' - id = Column(Integer, - primary_key=True, test_needs_autoincrement=True) + __tablename__ = "a" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) data = Column(String(50)) - collection = relationship("BC", primaryjoin="BC.a_id == A.id", - collection_class=set) + collection = relationship( + "BC", primaryjoin="BC.a_id == A.id", collection_class=set + ) class BC(AbstractConcreteBase, Base, fixtures.ComparableEntity): pass class B(BC): - __tablename__ = 'b' - id = Column(Integer, - primary_key=True, test_needs_autoincrement=True) + __tablename__ = "b" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) - a_id = Column(Integer, ForeignKey('a.id')) + a_id = Column(Integer, ForeignKey("a.id")) data = Column(String(50)) b_data = Column(String(50)) - __mapper_args__ = { - "polymorphic_identity": "b", - "concrete": True - } + __mapper_args__ = {"polymorphic_identity": "b", "concrete": True} class C(BC): - __tablename__ = 'c' - id = Column(Integer, - primary_key=True, test_needs_autoincrement=True) - a_id = Column(Integer, ForeignKey('a.id')) + __tablename__ = "c" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + a_id = Column(Integer, ForeignKey("a.id")) data = Column(String(50)) c_data = Column(String(50)) - __mapper_args__ = { - "polymorphic_identity": "c", - "concrete": True - } + __mapper_args__ = {"polymorphic_identity": "c", "concrete": True} Base.metadata.create_all() sess = Session() - sess.add_all([ - A(data='a1', collection=set([ - B(data='a1b1', b_data='a1b1'), - C(data='a1b2', c_data='a1c1'), - 
B(data='a1b2', b_data='a1b2'), - C(data='a1c2', c_data='a1c2'), - ])), - A(data='a2', collection=set([ - B(data='a2b1', b_data='a2b1'), - C(data='a2c1', c_data='a2c1'), - B(data='a2b2', b_data='a2b2'), - C(data='a2c2', c_data='a2c2'), - ])) - ]) + sess.add_all( + [ + A( + data="a1", + collection=set( + [ + B(data="a1b1", b_data="a1b1"), + C(data="a1b2", c_data="a1c1"), + B(data="a1b2", b_data="a1b2"), + C(data="a1c2", c_data="a1c2"), + ] + ), + ), + A( + data="a2", + collection=set( + [ + B(data="a2b1", b_data="a2b1"), + C(data="a2c1", c_data="a2c1"), + B(data="a2b2", b_data="a2b2"), + C(data="a2c2", c_data="a2c2"), + ] + ), + ), + ] + ) sess.commit() sess.expunge_all() eq_( - sess.query(A).filter_by(data='a2').all(), + sess.query(A).filter_by(data="a2").all(), [ - A(data='a2', collection=set([ - B(data='a2b1', b_data='a2b1'), - B(data='a2b2', b_data='a2b2'), - C(data='a2c1', c_data='a2c1'), - C(data='a2c2', c_data='a2c2'), - ])) - ] + A( + data="a2", + collection=set( + [ + B(data="a2b1", b_data="a2b1"), + B(data="a2b2", b_data="a2b2"), + C(data="a2c1", c_data="a2c1"), + C(data="a2c2", c_data="a2c2"), + ] + ), + ) + ], ) self.assert_compile( @@ -1632,7 +1830,7 @@ class ConcreteExtensionConfigTest( "'c' AS type FROM c UNION ALL SELECT b.id AS id, b.a_id AS a_id, " "b.data AS data, CAST(NULL AS VARCHAR(50)) AS c_data, " "b.b_data AS b_data, 'b' AS type FROM b) AS pjoin " - "ON pjoin.a_id = a.id" + "ON pjoin.a_id = a.id", ) def test_prop_on_base(self): @@ -1641,7 +1839,7 @@ class ConcreteExtensionConfigTest( counter = mock.Mock() class Something(Base): - __tablename__ = 'something' + __tablename__ = "something" id = Column(Integer, primary_key=True) class AbstractConcreteAbstraction(AbstractConcreteBase, Base): @@ -1664,23 +1862,22 @@ class ConcreteExtensionConfigTest( return relationship("Something") class ConcreteConcreteAbstraction(AbstractConcreteAbstraction): - __tablename__ = 'cca' - __mapper_args__ = { - 'polymorphic_identity': 'ccb', - 'concrete': True} + __tablename__ = "cca" + __mapper_args__ = {"polymorphic_identity": "ccb", "concrete": True} # concrete is mapped, the abstract base is not (yet) assert ConcreteConcreteAbstraction.__mapper__ - assert not hasattr(AbstractConcreteAbstraction, '__mapper__') + assert not hasattr(AbstractConcreteAbstraction, "__mapper__") session = Session() self.assert_compile( session.query(ConcreteConcreteAbstraction).filter( - ConcreteConcreteAbstraction.something.has(id=1)), + ConcreteConcreteAbstraction.something.has(id=1) + ), "SELECT cca.id AS cca_id, cca.x AS cca_x, cca.y AS cca_y, " "cca.something_id AS cca_something_id FROM cca WHERE EXISTS " "(SELECT 1 FROM something WHERE something.id = cca.something_id " - "AND something.id = :id_1)" + "AND something.id = :id_1)", ) # now it is @@ -1688,35 +1885,38 @@ class ConcreteExtensionConfigTest( self.assert_compile( session.query(ConcreteConcreteAbstraction).filter( - ConcreteConcreteAbstraction.something_else.has(id=1)), + ConcreteConcreteAbstraction.something_else.has(id=1) + ), "SELECT cca.id AS cca_id, cca.x AS cca_x, cca.y AS cca_y, " "cca.something_id AS cca_something_id FROM cca WHERE EXISTS " "(SELECT 1 FROM something WHERE something.id = cca.something_id " - "AND something.id = :id_1)" + "AND something.id = :id_1)", ) self.assert_compile( session.query(AbstractConcreteAbstraction).filter( - AbstractConcreteAbstraction.something.has(id=1)), + AbstractConcreteAbstraction.something.has(id=1) + ), "SELECT pjoin.id AS pjoin_id, pjoin.x AS pjoin_x, " "pjoin.y AS pjoin_y, pjoin.something_id AS 
pjoin_something_id, " "pjoin.type AS pjoin_type FROM " "(SELECT cca.id AS id, cca.x AS x, cca.y AS y, " "cca.something_id AS something_id, 'ccb' AS type FROM cca) " "AS pjoin WHERE EXISTS (SELECT 1 FROM something " - "WHERE something.id = pjoin.something_id AND something.id = :id_1)" + "WHERE something.id = pjoin.something_id AND something.id = :id_1)", ) self.assert_compile( session.query(AbstractConcreteAbstraction).filter( - AbstractConcreteAbstraction.something_else.has(id=1)), + AbstractConcreteAbstraction.something_else.has(id=1) + ), "SELECT pjoin.id AS pjoin_id, pjoin.x AS pjoin_x, " "pjoin.y AS pjoin_y, pjoin.something_id AS pjoin_something_id, " "pjoin.type AS pjoin_type FROM " "(SELECT cca.id AS id, cca.x AS x, cca.y AS y, " "cca.something_id AS something_id, 'ccb' AS type FROM cca) " "AS pjoin WHERE EXISTS (SELECT 1 FROM something " - "WHERE something.id = pjoin.something_id AND something.id = :id_1)" + "WHERE something.id = pjoin.something_id AND something.id = :id_1)", ) def test_abstract_in_hierarchy(self): @@ -1729,10 +1929,11 @@ class ConcreteExtensionConfigTest( send_method = Column(String) class ActualDocument(ContactDocument): - __tablename__ = 'actual_documents' + __tablename__ = "actual_documents" __mapper_args__ = { - 'concrete': True, - 'polymorphic_identity': 'actual'} + "concrete": True, + "polymorphic_identity": "actual", + } id = Column(Integer, primary_key=True) @@ -1746,22 +1947,20 @@ class ConcreteExtensionConfigTest( "FROM (SELECT actual_documents.doctype AS doctype, " "actual_documents.send_method AS send_method, " "actual_documents.id AS id, 'actual' AS type " - "FROM actual_documents) AS pjoin" + "FROM actual_documents) AS pjoin", ) def test_column_attr_names(self): """test #3480""" class Document(Base, AbstractConcreteBase): - documentType = Column('documenttype', String) + documentType = Column("documenttype", String) class Offer(Document): - __tablename__ = 'offers' + __tablename__ = "offers" id = Column(Integer, primary_key=True) - __mapper_args__ = { - 'polymorphic_identity': 'offer' - } + __mapper_args__ = {"polymorphic_identity": "offer"} configure_mappers() session = Session() @@ -1770,12 +1969,12 @@ class ConcreteExtensionConfigTest( "SELECT pjoin.documenttype AS pjoin_documenttype, " "pjoin.id AS pjoin_id, pjoin.type AS pjoin_type FROM " "(SELECT offers.documenttype AS documenttype, offers.id AS id, " - "'offer' AS type FROM offers) AS pjoin" + "'offer' AS type FROM offers) AS pjoin", ) self.assert_compile( session.query(Document.documentType), "SELECT pjoin.documenttype AS pjoin_documenttype FROM " "(SELECT offers.documenttype AS documenttype, offers.id AS id, " - "'offer' AS type FROM offers) AS pjoin" + "'offer' AS type FROM offers) AS pjoin", ) diff --git a/test/ext/declarative/test_mixin.py b/test/ext/declarative/test_mixin.py index fa8c36656c..f51a75b3e3 100644 --- a/test/ext/declarative/test_mixin.py +++ b/test/ext/declarative/test_mixin.py @@ -1,13 +1,27 @@ -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, is_, expect_warnings +from sqlalchemy.testing import ( + eq_, + assert_raises, + assert_raises_message, + is_, + expect_warnings, +) from sqlalchemy.ext import declarative as decl import sqlalchemy as sa from sqlalchemy import testing from sqlalchemy import Integer, String, ForeignKey, select, func from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import relationship, create_session, class_mapper, \ - configure_mappers, clear_mappers, \ - deferred, column_property, Session, base as 
orm_base, synonym +from sqlalchemy.orm import ( + relationship, + create_session, + class_mapper, + configure_mappers, + clear_mappers, + deferred, + column_property, + Session, + base as orm_base, + synonym, +) from sqlalchemy.util import classproperty from sqlalchemy.ext.declarative import declared_attr, declarative_base from sqlalchemy.orm import events as orm_events @@ -18,7 +32,6 @@ Base = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): global Base Base = decl.declarative_base(testing.db) @@ -30,34 +43,32 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): class DeclarativeMixinTest(DeclarativeTestBase): - def test_simple(self): - class MyMixin(object): - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) def foo(self): - return 'bar' + str(self.id) + return "bar" + str(self.id) class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" name = Column(String(100), nullable=False, index=True) Base.metadata.create_all() session = create_session() - session.add(MyModel(name='testing')) + session.add(MyModel(name="testing")) session.flush() session.expunge_all() obj = session.query(MyModel).one() eq_(obj.id, 1) - eq_(obj.name, 'testing') - eq_(obj.foo(), 'bar1') + eq_(obj.name, "testing") + eq_(obj.foo(), "bar1") def test_unique_column(self): - class MyMixin(object): id = Column(Integer, primary_key=True) @@ -65,19 +76,19 @@ class DeclarativeMixinTest(DeclarativeTestBase): class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" assert MyModel.__table__.c.value.unique def test_hierarchical_bases(self): - class MyMixinParent: - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) def foo(self): - return 'bar' + str(self.id) + return "bar" + str(self.id) class MyMixin(MyMixinParent): @@ -85,19 +96,19 @@ class DeclarativeMixinTest(DeclarativeTestBase): class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" name = Column(String(100), nullable=False, index=True) Base.metadata.create_all() session = create_session() - session.add(MyModel(name='testing', baz='fu')) + session.add(MyModel(name="testing", baz="fu")) session.flush() session.expunge_all() obj = session.query(MyModel).one() eq_(obj.id, 1) - eq_(obj.name, 'testing') - eq_(obj.foo(), 'bar1') - eq_(obj.baz, 'fu') + eq_(obj.name, "testing") + eq_(obj.foo(), "bar1") + eq_(obj.baz, "fu") def test_mixin_overrides(self): """test a mixin that overrides a column on a superclass.""" @@ -109,170 +120,158 @@ class DeclarativeMixinTest(DeclarativeTestBase): foo = Column(Integer) class MyModelA(Base, MixinA): - __tablename__ = 'testa' + __tablename__ = "testa" id = Column(Integer, primary_key=True) class MyModelB(Base, MixinB): - __tablename__ = 'testb' + __tablename__ = "testb" id = Column(Integer, primary_key=True) eq_(MyModelA.__table__.c.foo.type.__class__, String) eq_(MyModelB.__table__.c.foo.type.__class__, Integer) def test_not_allowed(self): - class MyMixin: - foo = Column(Integer, ForeignKey('bar.id')) + foo = Column(Integer, ForeignKey("bar.id")) def go(): class MyModel(Base, MyMixin): - __tablename__ = 'foo' + __tablename__ = "foo" assert_raises(sa.exc.InvalidRequestError, go) class MyRelMixin: - foo = relationship('Bar') + foo = relationship("Bar") def go(): class MyModel(Base, MyRelMixin): - 
__tablename__ = 'foo' + __tablename__ = "foo" assert_raises(sa.exc.InvalidRequestError, go) class MyDefMixin: - foo = deferred(Column('foo', String)) + foo = deferred(Column("foo", String)) def go(): class MyModel(Base, MyDefMixin): - __tablename__ = 'foo' + __tablename__ = "foo" assert_raises(sa.exc.InvalidRequestError, go) class MyCPropMixin: - foo = column_property(Column('foo', String)) + foo = column_property(Column("foo", String)) def go(): class MyModel(Base, MyCPropMixin): - __tablename__ = 'foo' + __tablename__ = "foo" assert_raises(sa.exc.InvalidRequestError, go) def test_table_name_inherited(self): - class MyMixin: - @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) class MyModel(Base, MyMixin): pass - eq_(MyModel.__table__.name, 'mymodel') + eq_(MyModel.__table__.name, "mymodel") def test_classproperty_still_works(self): class MyMixin(object): - @classproperty def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) class MyModel(Base, MyMixin): - __tablename__ = 'overridden' + __tablename__ = "overridden" - eq_(MyModel.__table__.name, 'overridden') + eq_(MyModel.__table__.name, "overridden") def test_table_name_not_inherited(self): - class MyMixin: - @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) class MyModel(Base, MyMixin): - __tablename__ = 'overridden' + __tablename__ = "overridden" - eq_(MyModel.__table__.name, 'overridden') + eq_(MyModel.__table__.name, "overridden") def test_table_name_inheritance_order(self): - class MyMixin1: - @declared_attr def __tablename__(cls): - return cls.__name__.lower() + '1' + return cls.__name__.lower() + "1" class MyMixin2: - @declared_attr def __tablename__(cls): - return cls.__name__.lower() + '2' + return cls.__name__.lower() + "2" class MyModel(Base, MyMixin1, MyMixin2): id = Column(Integer, primary_key=True) - eq_(MyModel.__table__.name, 'mymodel1') + eq_(MyModel.__table__.name, "mymodel1") def test_table_name_dependent_on_subclass(self): - class MyHistoryMixin: - @declared_attr def __tablename__(cls): - return cls.parent_name + '_changelog' + return cls.parent_name + "_changelog" class MyModel(Base, MyHistoryMixin): - parent_name = 'foo' + parent_name = "foo" id = Column(Integer, primary_key=True) - eq_(MyModel.__table__.name, 'foo_changelog') + eq_(MyModel.__table__.name, "foo_changelog") def test_table_args_inherited(self): - class MyMixin: - __table_args__ = {'mysql_engine': 'InnoDB'} + __table_args__ = {"mysql_engine": "InnoDB"} class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) - eq_(MyModel.__table__.kwargs, {'mysql_engine': 'InnoDB'}) + eq_(MyModel.__table__.kwargs, {"mysql_engine": "InnoDB"}) def test_table_args_inherited_descriptor(self): - class MyMixin: - @declared_attr def __table_args__(cls): - return {'info': cls.__name__} + return {"info": cls.__name__} class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) - eq_(MyModel.__table__.info, 'MyModel') + eq_(MyModel.__table__.info, "MyModel") def test_table_args_inherited_single_table_inheritance(self): - class MyMixin: - __table_args__ = {'mysql_engine': 'InnoDB'} + __table_args__ = {"mysql_engine": "InnoDB"} class General(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) type_ = Column(String(50)) - __mapper__args = {'polymorphic_on': 
type_} + __mapper__args = {"polymorphic_on": type_} class Specific(General): - __mapper_args__ = {'polymorphic_identity': 'specific'} + __mapper_args__ = {"polymorphic_identity": "specific"} assert Specific.__table__ is General.__table__ - eq_(General.__table__.kwargs, {'mysql_engine': 'InnoDB'}) + eq_(General.__table__.kwargs, {"mysql_engine": "InnoDB"}) def test_columns_single_table_inheritance(self): """Test a column on a mixin with an alternate attribute name, @@ -283,24 +282,26 @@ class DeclarativeMixinTest(DeclarativeTestBase): """ class MyMixin(object): - foo = Column('foo', Integer) - bar = Column('bar_newname', Integer) + foo = Column("foo", Integer) + bar = Column("bar_newname", Integer) class General(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) type_ = Column(String(50)) - __mapper__args = {'polymorphic_on': type_} + __mapper__args = {"polymorphic_on": type_} class Specific(General): - __mapper_args__ = {'polymorphic_identity': 'specific'} + __mapper_args__ = {"polymorphic_identity": "specific"} assert General.bar.prop.columns[0] is General.__table__.c.bar_newname assert len(General.bar.prop.columns) == 1 assert Specific.bar.prop is General.bar.prop - @testing.skip_if(lambda: testing.against('oracle'), - "Test has an empty insert in it at the moment") + @testing.skip_if( + lambda: testing.against("oracle"), + "Test has an empty insert in it at the moment", + ) def test_columns_single_inheritance_conflict_resolution(self): """Test that a declared_attr can return the existing column and it will be ignored. this allows conditional columns to be added. @@ -308,17 +309,16 @@ class DeclarativeMixinTest(DeclarativeTestBase): See [ticket:2472]. """ + class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) class Mixin(object): - @declared_attr def target_id(cls): return cls.__table__.c.get( - 'target_id', - Column(Integer, ForeignKey('other.id')) + "target_id", Column(Integer, ForeignKey("other.id")) ) @declared_attr @@ -334,26 +334,23 @@ class DeclarativeMixinTest(DeclarativeTestBase): """single table inheritance""" class Other(Base): - __tablename__ = 'other' + __tablename__ = "other" id = Column(Integer, primary_key=True) is_( Engineer.target_id.property.columns[0], - Person.__table__.c.target_id + Person.__table__.c.target_id, ) is_( - Manager.target_id.property.columns[0], - Person.__table__.c.target_id + Manager.target_id.property.columns[0], Person.__table__.c.target_id ) # do a brief round trip on this Base.metadata.create_all() session = Session() o1, o2 = Other(), Other() - session.add_all([ - Engineer(target=o1), - Manager(target=o2), - Manager(target=o1) - ]) + session.add_all( + [Engineer(target=o1), Manager(target=o2), Manager(target=o1)] + ) session.commit() eq_(session.query(Engineer).first().target, o1) @@ -366,19 +363,19 @@ class DeclarativeMixinTest(DeclarativeTestBase): """ class MyMixin(object): - foo = Column('foo', Integer) - bar = Column('bar_newname', Integer) + foo = Column("foo", Integer) + bar = Column("bar_newname", Integer) class General(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) type_ = Column(String(50)) - __mapper__args = {'polymorphic_on': type_} + __mapper__args = {"polymorphic_on": type_} class Specific(General): - __tablename__ = 'sub' - id = Column(Integer, ForeignKey('test.id'), primary_key=True) - __mapper_args__ = {'polymorphic_identity': 'specific'} + __tablename__ = 
"sub" + id = Column(Integer, ForeignKey("test.id"), primary_key=True) + __mapper_args__ = {"polymorphic_identity": "specific"} assert General.bar.prop.columns[0] is General.__table__.c.bar_newname assert len(General.bar.prop.columns) == 1 @@ -393,15 +390,15 @@ class DeclarativeMixinTest(DeclarativeTestBase): """ class General(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) - general_id = Column(Integer, ForeignKey('test.id')) + general_id = Column(Integer, ForeignKey("test.id")) type_ = relationship("General") class Specific(General): - __tablename__ = 'sub' - id = Column(Integer, ForeignKey('test.id'), primary_key=True) - type_ = Column('foob', String(50)) + __tablename__ = "sub" + id = Column(Integer, ForeignKey("test.id"), primary_key=True) + type_ = Column("foob", String(50)) assert isinstance(General.type_.property, sa.orm.RelationshipProperty) assert Specific.type_.property.columns[0] is Specific.__table__.c.foob @@ -414,30 +411,30 @@ class DeclarativeMixinTest(DeclarativeTestBase): def go(): class General(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) - type_ = Column('foob', Integer) + type_ = Column("foob", Integer) class Specific(General): - __tablename__ = 'sub' - id = Column(Integer, ForeignKey('test.id'), primary_key=True) - specific_id = Column(Integer, ForeignKey('sub.id')) + __tablename__ = "sub" + id = Column(Integer, ForeignKey("test.id"), primary_key=True) + specific_id = Column(Integer, ForeignKey("sub.id")) type_ = relationship("Specific") + assert_raises_message( sa.exc.ArgumentError, "column 'foob' conflicts with property", go ) def test_table_args_overridden(self): - class MyMixin: - __table_args__ = {'mysql_engine': 'Foo'} + __table_args__ = {"mysql_engine": "Foo"} class MyModel(Base, MyMixin): - __tablename__ = 'test' - __table_args__ = {'mysql_engine': 'InnoDB'} + __tablename__ = "test" + __table_args__ = {"mysql_engine": "InnoDB"} id = Column(Integer, primary_key=True) - eq_(MyModel.__table__.kwargs, {'mysql_engine': 'InnoDB'}) + eq_(MyModel.__table__.kwargs, {"mysql_engine": "InnoDB"}) @testing.teardown_events(orm_events.MapperEvents) def test_declare_first_mixin(self): @@ -453,7 +450,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): canary.declare_last__(cls) class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) configure_mappers() @@ -463,7 +460,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): [ mock.call.declare_first__(MyModel), mock.call.declare_last__(MyModel), - ] + ], ) @testing.teardown_events(orm_events.MapperEvents) @@ -481,10 +478,11 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Base(MyMixin): pass + Base = declarative_base(cls=Base) class MyModel(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) configure_mappers() @@ -494,7 +492,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): [ mock.call.declare_first__(MyModel), mock.call.declare_last__(MyModel), - ] + ], ) @testing.teardown_events(orm_events.MapperEvents) @@ -502,7 +500,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): canary = mock.Mock() class MyOtherModel(Base): - __tablename__ = 'test2' + __tablename__ = "test2" id = Column(Integer, primary_key=True) @classmethod @@ -519,33 +517,30 @@ class DeclarativeMixinTest(DeclarativeTestBase): canary.mock_calls, [ mock.call.declare_first__(MyOtherModel), - mock.call.declare_last__(MyOtherModel) - ] + 
mock.call.declare_last__(MyOtherModel), + ], ) def test_mapper_args_declared_attr(self): - class ComputedMapperArgs: - @declared_attr def __mapper_args__(cls): - if cls.__name__ == 'Person': - return {'polymorphic_on': cls.discriminator} + if cls.__name__ == "Person": + return {"polymorphic_on": cls.discriminator} else: - return {'polymorphic_identity': cls.__name__} + return {"polymorphic_identity": cls.__name__} class Person(Base, ComputedMapperArgs): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) + discriminator = Column("type", String(50)) class Engineer(Person): pass configure_mappers() - assert class_mapper(Person).polymorphic_on \ - is Person.__table__.c.type - eq_(class_mapper(Engineer).polymorphic_identity, 'Engineer') + assert class_mapper(Person).polymorphic_on is Person.__table__.c.type + eq_(class_mapper(Engineer).polymorphic_identity, "Engineer") def test_mapper_args_declared_attr_two(self): @@ -553,201 +548,190 @@ class DeclarativeMixinTest(DeclarativeTestBase): # ComputedMapperArgs on both classes for no apparent reason. class ComputedMapperArgs: - @declared_attr def __mapper_args__(cls): - if cls.__name__ == 'Person': - return {'polymorphic_on': cls.discriminator} + if cls.__name__ == "Person": + return {"polymorphic_on": cls.discriminator} else: - return {'polymorphic_identity': cls.__name__} + return {"polymorphic_identity": cls.__name__} class Person(Base, ComputedMapperArgs): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) + discriminator = Column("type", String(50)) class Engineer(Person, ComputedMapperArgs): pass configure_mappers() - assert class_mapper(Person).polymorphic_on \ - is Person.__table__.c.type - eq_(class_mapper(Engineer).polymorphic_identity, 'Engineer') + assert class_mapper(Person).polymorphic_on is Person.__table__.c.type + eq_(class_mapper(Engineer).polymorphic_identity, "Engineer") def test_table_args_composite(self): - class MyMixin1: - __table_args__ = {'info': {'baz': 'bob'}} + __table_args__ = {"info": {"baz": "bob"}} class MyMixin2: - __table_args__ = {'info': {'foo': 'bar'}} + __table_args__ = {"info": {"foo": "bar"}} class MyModel(Base, MyMixin1, MyMixin2): - __tablename__ = 'test' + __tablename__ = "test" @declared_attr def __table_args__(self): info = {} args = dict(info=info) - info.update(MyMixin1.__table_args__['info']) - info.update(MyMixin2.__table_args__['info']) + info.update(MyMixin1.__table_args__["info"]) + info.update(MyMixin2.__table_args__["info"]) return args + id = Column(Integer, primary_key=True) - eq_(MyModel.__table__.info, {'foo': 'bar', 'baz': 'bob'}) + eq_(MyModel.__table__.info, {"foo": "bar", "baz": "bob"}) def test_mapper_args_inherited(self): - class MyMixin: - __mapper_args__ = {'always_refresh': True} + __mapper_args__ = {"always_refresh": True} class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) eq_(MyModel.__mapper__.always_refresh, True) def test_mapper_args_inherited_descriptor(self): - class MyMixin: - @declared_attr def __mapper_args__(cls): # tenuous, but illustrates the problem! 
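The ComputedMapperArgs pattern in the preceding hunks is the key idea here: because @declared_attr is evaluated once per mapped class, a single __mapper_args__ mixin can hand the base class its polymorphic_on while handing every subclass its polymorphic_identity. A minimal sketch of the pattern, assuming the 1.x declarative API, with names taken from the tests themselves:

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import class_mapper, configure_mappers

Base = declarative_base()

class ComputedMapperArgs(object):
    @declared_attr
    def __mapper_args__(cls):
        # evaluated separately for each mapped class: the base
        # receives polymorphic_on, subclasses a polymorphic_identity
        if cls.__name__ == "Person":
            return {"polymorphic_on": cls.discriminator}
        else:
            return {"polymorphic_identity": cls.__name__}

class Person(Base, ComputedMapperArgs):
    __tablename__ = "people"
    id = Column(Integer, primary_key=True)
    discriminator = Column("type", String(50))

class Engineer(Person):
    # single-table inheritance; no __tablename__ of its own
    pass

configure_mappers()
assert class_mapper(Person).polymorphic_on is Person.__table__.c.type
assert class_mapper(Engineer).polymorphic_identity == "Engineer"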
- if cls.__name__ == 'MyModel': + if cls.__name__ == "MyModel": return dict(always_refresh=True) else: return dict(always_refresh=False) class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) eq_(MyModel.__mapper__.always_refresh, True) def test_mapper_args_polymorphic_on_inherited(self): - class MyMixin: type_ = Column(String(50)) - __mapper_args__ = {'polymorphic_on': type_} + __mapper_args__ = {"polymorphic_on": type_} class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) col = MyModel.__mapper__.polymorphic_on - eq_(col.name, 'type_') + eq_(col.name, "type_") assert col.table is not None def test_mapper_args_overridden(self): - class MyMixin: __mapper_args__ = dict(always_refresh=True) class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" __mapper_args__ = dict(always_refresh=False) id = Column(Integer, primary_key=True) eq_(MyModel.__mapper__.always_refresh, False) def test_mapper_args_composite(self): - class MyMixin1: type_ = Column(String(50)) - __mapper_args__ = {'polymorphic_on': type_} + __mapper_args__ = {"polymorphic_on": type_} class MyMixin2: - __mapper_args__ = {'always_refresh': True} + __mapper_args__ = {"always_refresh": True} class MyModel(Base, MyMixin1, MyMixin2): - __tablename__ = 'test' + __tablename__ = "test" @declared_attr def __mapper_args__(cls): args = {} args.update(MyMixin1.__mapper_args__) args.update(MyMixin2.__mapper_args__) - if cls.__name__ != 'MyModel': - args.pop('polymorphic_on') - args['polymorphic_identity'] = cls.__name__ + if cls.__name__ != "MyModel": + args.pop("polymorphic_on") + args["polymorphic_identity"] = cls.__name__ return args + id = Column(Integer, primary_key=True) class MySubModel(MyModel): pass - eq_( - MyModel.__mapper__.polymorphic_on.name, - 'type_' - ) + eq_(MyModel.__mapper__.polymorphic_on.name, "type_") assert MyModel.__mapper__.polymorphic_on.table is not None eq_(MyModel.__mapper__.always_refresh, True) eq_(MySubModel.__mapper__.always_refresh, True) - eq_(MySubModel.__mapper__.polymorphic_identity, 'MySubModel') + eq_(MySubModel.__mapper__.polymorphic_identity, "MySubModel") def test_mapper_args_property(self): class MyModel(Base): - @declared_attr def __tablename__(cls): return cls.__name__.lower() @declared_attr def __table_args__(cls): - return {'mysql_engine': 'InnoDB'} + return {"mysql_engine": "InnoDB"} @declared_attr def __mapper_args__(cls): args = {} - args['polymorphic_identity'] = cls.__name__ + args["polymorphic_identity"] = cls.__name__ return args + id = Column(Integer, primary_key=True) class MySubModel(MyModel): - id = Column(Integer, ForeignKey('mymodel.id'), primary_key=True) + id = Column(Integer, ForeignKey("mymodel.id"), primary_key=True) class MySubModel2(MyModel): - __tablename__ = 'sometable' - id = Column(Integer, ForeignKey('mymodel.id'), primary_key=True) + __tablename__ = "sometable" + id = Column(Integer, ForeignKey("mymodel.id"), primary_key=True) - eq_(MyModel.__mapper__.polymorphic_identity, 'MyModel') - eq_(MySubModel.__mapper__.polymorphic_identity, 'MySubModel') - eq_(MyModel.__table__.kwargs['mysql_engine'], 'InnoDB') - eq_(MySubModel.__table__.kwargs['mysql_engine'], 'InnoDB') - eq_(MySubModel2.__table__.kwargs['mysql_engine'], 'InnoDB') - eq_(MyModel.__table__.name, 'mymodel') - eq_(MySubModel.__table__.name, 'mysubmodel') + eq_(MyModel.__mapper__.polymorphic_identity, "MyModel") + eq_(MySubModel.__mapper__.polymorphic_identity, 
"MySubModel") + eq_(MyModel.__table__.kwargs["mysql_engine"], "InnoDB") + eq_(MySubModel.__table__.kwargs["mysql_engine"], "InnoDB") + eq_(MySubModel2.__table__.kwargs["mysql_engine"], "InnoDB") + eq_(MyModel.__table__.name, "mymodel") + eq_(MySubModel.__table__.name, "mysubmodel") def test_mapper_args_custom_base(self): """test the @declared_attr approach from a custom base.""" class Base(object): - @declared_attr def __tablename__(cls): return cls.__name__.lower() @declared_attr def __table_args__(cls): - return {'mysql_engine': 'InnoDB'} + return {"mysql_engine": "InnoDB"} @declared_attr def id(self): @@ -761,78 +745,77 @@ class DeclarativeMixinTest(DeclarativeTestBase): class MyOtherClass(Base): pass - eq_(MyClass.__table__.kwargs['mysql_engine'], 'InnoDB') - eq_(MyClass.__table__.name, 'myclass') - eq_(MyOtherClass.__table__.name, 'myotherclass') + eq_(MyClass.__table__.kwargs["mysql_engine"], "InnoDB") + eq_(MyClass.__table__.name, "myclass") + eq_(MyOtherClass.__table__.name, "myotherclass") assert MyClass.__table__.c.id.table is MyClass.__table__ assert MyOtherClass.__table__.c.id.table is MyOtherClass.__table__ def test_single_table_no_propagation(self): - class IdColumn: id = Column(Integer, primary_key=True) class Generic(Base, IdColumn): - __tablename__ = 'base' - discriminator = Column('type', String(50)) + __tablename__ = "base" + discriminator = Column("type", String(50)) __mapper_args__ = dict(polymorphic_on=discriminator) value = Column(Integer()) class Specific(Generic): - __mapper_args__ = dict(polymorphic_identity='specific') + __mapper_args__ = dict(polymorphic_identity="specific") assert Specific.__table__ is Generic.__table__ - eq_(list(Generic.__table__.c.keys()), ['id', 'type', 'value']) - assert class_mapper(Specific).polymorphic_on \ - is Generic.__table__.c.type - eq_(class_mapper(Specific).polymorphic_identity, 'specific') + eq_(list(Generic.__table__.c.keys()), ["id", "type", "value"]) + assert ( + class_mapper(Specific).polymorphic_on is Generic.__table__.c.type + ) + eq_(class_mapper(Specific).polymorphic_identity, "specific") def test_joined_table_propagation(self): - class CommonMixin: - @declared_attr def __tablename__(cls): return cls.__name__.lower() - __table_args__ = {'mysql_engine': 'InnoDB'} + + __table_args__ = {"mysql_engine": "InnoDB"} timestamp = Column(Integer) id = Column(Integer, primary_key=True) class Generic(Base, CommonMixin): - discriminator = Column('python_type', String(50)) + discriminator = Column("python_type", String(50)) __mapper_args__ = dict(polymorphic_on=discriminator) class Specific(Generic): - __mapper_args__ = dict(polymorphic_identity='specific') - id = Column(Integer, ForeignKey('generic.id'), - primary_key=True) + __mapper_args__ = dict(polymorphic_identity="specific") + id = Column(Integer, ForeignKey("generic.id"), primary_key=True) - eq_(Generic.__table__.name, 'generic') - eq_(Specific.__table__.name, 'specific') - eq_(list(Generic.__table__.c.keys()), ['timestamp', 'id', - 'python_type']) - eq_(list(Specific.__table__.c.keys()), ['id']) - eq_(Generic.__table__.kwargs, {'mysql_engine': 'InnoDB'}) - eq_(Specific.__table__.kwargs, {'mysql_engine': 'InnoDB'}) + eq_(Generic.__table__.name, "generic") + eq_(Specific.__table__.name, "specific") + eq_( + list(Generic.__table__.c.keys()), + ["timestamp", "id", "python_type"], + ) + eq_(list(Specific.__table__.c.keys()), ["id"]) + eq_(Generic.__table__.kwargs, {"mysql_engine": "InnoDB"}) + eq_(Specific.__table__.kwargs, {"mysql_engine": "InnoDB"}) def 
test_some_propagation(self): - class CommonMixin: - @declared_attr def __tablename__(cls): return cls.__name__.lower() - __table_args__ = {'mysql_engine': 'InnoDB'} + + __table_args__ = {"mysql_engine": "InnoDB"} timestamp = Column(Integer) class BaseType(Base, CommonMixin): - discriminator = Column('type', String(50)) + discriminator = Column("type", String(50)) __mapper_args__ = dict(polymorphic_on=discriminator) id = Column(Integer, primary_key=True) value = Column(Integer()) @@ -840,22 +823,23 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Single(BaseType): __tablename__ = None - __mapper_args__ = dict(polymorphic_identity='type1') + __mapper_args__ = dict(polymorphic_identity="type1") class Joined(BaseType): - __mapper_args__ = dict(polymorphic_identity='type2') - id = Column(Integer, ForeignKey('basetype.id'), - primary_key=True) + __mapper_args__ = dict(polymorphic_identity="type2") + id = Column(Integer, ForeignKey("basetype.id"), primary_key=True) - eq_(BaseType.__table__.name, 'basetype') - eq_(list(BaseType.__table__.c.keys()), ['timestamp', 'type', 'id', - 'value']) - eq_(BaseType.__table__.kwargs, {'mysql_engine': 'InnoDB'}) + eq_(BaseType.__table__.name, "basetype") + eq_( + list(BaseType.__table__.c.keys()), + ["timestamp", "type", "id", "value"], + ) + eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"}) assert Single.__table__ is BaseType.__table__ - eq_(Joined.__table__.name, 'joined') - eq_(list(Joined.__table__.c.keys()), ['id']) - eq_(Joined.__table__.kwargs, {'mysql_engine': 'InnoDB'}) + eq_(Joined.__table__.name, "joined") + eq_(list(Joined.__table__.c.keys()), ["id"]) + eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"}) def test_col_copy_vs_declared_attr_joined_propagation(self): class Mixin(object): @@ -866,38 +850,38 @@ class DeclarativeMixinTest(DeclarativeTestBase): return Column(Integer) class A(Mixin, Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(A): - __tablename__ = 'b' - id = Column(Integer, ForeignKey('a.id'), primary_key=True) + __tablename__ = "b" + id = Column(Integer, ForeignKey("a.id"), primary_key=True) - assert 'a' in A.__table__.c - assert 'b' in A.__table__.c - assert 'a' not in B.__table__.c - assert 'b' not in B.__table__.c + assert "a" in A.__table__.c + assert "b" in A.__table__.c + assert "a" not in B.__table__.c + assert "b" not in B.__table__.c def test_col_copy_vs_declared_attr_joined_propagation_newname(self): class Mixin(object): - a = Column('a1', Integer) + a = Column("a1", Integer) @declared_attr def b(cls): - return Column('b1', Integer) + return Column("b1", Integer) class A(Mixin, Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(A): - __tablename__ = 'b' - id = Column(Integer, ForeignKey('a.id'), primary_key=True) + __tablename__ = "b" + id = Column(Integer, ForeignKey("a.id"), primary_key=True) - assert 'a1' in A.__table__.c - assert 'b1' in A.__table__.c - assert 'a1' not in B.__table__.c - assert 'b1' not in B.__table__.c + assert "a1" in A.__table__.c + assert "b1" in A.__table__.c + assert "a1" not in B.__table__.c + assert "b1" not in B.__table__.c def test_col_copy_vs_declared_attr_single_propagation(self): class Mixin(object): @@ -908,19 +892,17 @@ class DeclarativeMixinTest(DeclarativeTestBase): return Column(Integer) class A(Mixin, Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(A): pass - assert 'a' in A.__table__.c - assert 'b' in A.__table__.c + 
assert "a" in A.__table__.c + assert "b" in A.__table__.c def test_non_propagating_mixin(self): - class NoJoinedTableNameMixin: - @declared_attr def __tablename__(cls): if decl.has_inherited_table(cls): @@ -929,105 +911,109 @@ class DeclarativeMixinTest(DeclarativeTestBase): class BaseType(Base, NoJoinedTableNameMixin): - discriminator = Column('type', String(50)) + discriminator = Column("type", String(50)) __mapper_args__ = dict(polymorphic_on=discriminator) id = Column(Integer, primary_key=True) value = Column(Integer()) class Specific(BaseType): - __mapper_args__ = dict(polymorphic_identity='specific') + __mapper_args__ = dict(polymorphic_identity="specific") - eq_(BaseType.__table__.name, 'basetype') - eq_(list(BaseType.__table__.c.keys()), ['type', 'id', 'value']) + eq_(BaseType.__table__.name, "basetype") + eq_(list(BaseType.__table__.c.keys()), ["type", "id", "value"]) assert Specific.__table__ is BaseType.__table__ - assert class_mapper(Specific).polymorphic_on \ - is BaseType.__table__.c.type - eq_(class_mapper(Specific).polymorphic_identity, 'specific') + assert ( + class_mapper(Specific).polymorphic_on is BaseType.__table__.c.type + ) + eq_(class_mapper(Specific).polymorphic_identity, "specific") def test_non_propagating_mixin_used_for_joined(self): - class TableNameMixin: - @declared_attr def __tablename__(cls): - if decl.has_inherited_table(cls) and TableNameMixin \ - not in cls.__bases__: + if ( + decl.has_inherited_table(cls) + and TableNameMixin not in cls.__bases__ + ): return None return cls.__name__.lower() class BaseType(Base, TableNameMixin): - discriminator = Column('type', String(50)) + discriminator = Column("type", String(50)) __mapper_args__ = dict(polymorphic_on=discriminator) id = Column(Integer, primary_key=True) value = Column(Integer()) class Specific(BaseType, TableNameMixin): - __mapper_args__ = dict(polymorphic_identity='specific') - id = Column(Integer, ForeignKey('basetype.id'), - primary_key=True) + __mapper_args__ = dict(polymorphic_identity="specific") + id = Column(Integer, ForeignKey("basetype.id"), primary_key=True) - eq_(BaseType.__table__.name, 'basetype') - eq_(list(BaseType.__table__.c.keys()), ['type', 'id', 'value']) - eq_(Specific.__table__.name, 'specific') - eq_(list(Specific.__table__.c.keys()), ['id']) + eq_(BaseType.__table__.name, "basetype") + eq_(list(BaseType.__table__.c.keys()), ["type", "id", "value"]) + eq_(Specific.__table__.name, "specific") + eq_(list(Specific.__table__.c.keys()), ["id"]) def test_single_back_propagate(self): - class ColumnMixin: timestamp = Column(Integer) class BaseType(Base): - __tablename__ = 'foo' - discriminator = Column('type', String(50)) + __tablename__ = "foo" + discriminator = Column("type", String(50)) __mapper_args__ = dict(polymorphic_on=discriminator) id = Column(Integer, primary_key=True) class Specific(BaseType, ColumnMixin): - __mapper_args__ = dict(polymorphic_identity='specific') + __mapper_args__ = dict(polymorphic_identity="specific") - eq_(list(BaseType.__table__.c.keys()), ['type', 'id', 'timestamp']) + eq_(list(BaseType.__table__.c.keys()), ["type", "id", "timestamp"]) def test_table_in_model_and_same_column_in_mixin(self): - class ColumnMixin: data = Column(Integer) class Model(Base, ColumnMixin): - __table__ = Table('foo', Base.metadata, - Column('data', Integer), - Column('id', Integer, primary_key=True)) + __table__ = Table( + "foo", + Base.metadata, + Column("data", Integer), + Column("id", Integer, primary_key=True), + ) model_col = Model.__table__.c.data mixin_col = 
ColumnMixin.data assert model_col is not mixin_col - eq_(model_col.name, 'data') + eq_(model_col.name, "data") assert model_col.type.__class__ is mixin_col.type.__class__ def test_table_in_model_and_different_named_column_in_mixin(self): - class ColumnMixin: tada = Column(Integer) def go(): - class Model(Base, ColumnMixin): - __table__ = Table('foo', Base.metadata, - Column('data', Integer), - Column('id', Integer, primary_key=True)) + __table__ = Table( + "foo", + Base.metadata, + Column("data", Integer), + Column("id", Integer, primary_key=True), + ) foo = relationship("Dest") - assert_raises_message(sa.exc.ArgumentError, - "Can't add additional column 'tada' when " - "specifying __table__", go) + assert_raises_message( + sa.exc.ArgumentError, + "Can't add additional column 'tada' when " "specifying __table__", + go, + ) def test_table_in_model_and_different_named_alt_key_column_in_mixin(self): @@ -1036,42 +1022,48 @@ class DeclarativeMixinTest(DeclarativeTestBase): # keyed to 'tada'. class ColumnMixin: - tada = Column('foobar', Integer) + tada = Column("foobar", Integer) def go(): - class Model(Base, ColumnMixin): - __table__ = Table('foo', Base.metadata, - Column('data', Integer), - Column('tada', Integer), - Column('id', Integer, primary_key=True)) + __table__ = Table( + "foo", + Base.metadata, + Column("data", Integer), + Column("tada", Integer), + Column("id", Integer, primary_key=True), + ) foo = relationship("Dest") - assert_raises_message(sa.exc.ArgumentError, - "Can't add additional column 'foobar' when " - "specifying __table__", go) + assert_raises_message( + sa.exc.ArgumentError, + "Can't add additional column 'foobar' when " + "specifying __table__", + go, + ) def test_table_in_model_overrides_different_typed_column_in_mixin(self): - class ColumnMixin: data = Column(String) class Model(Base, ColumnMixin): - __table__ = Table('foo', Base.metadata, - Column('data', Integer), - Column('id', Integer, primary_key=True)) + __table__ = Table( + "foo", + Base.metadata, + Column("data", Integer), + Column("id", Integer, primary_key=True), + ) model_col = Model.__table__.c.data mixin_col = ColumnMixin.data assert model_col is not mixin_col - eq_(model_col.name, 'data') + eq_(model_col.name, "data") assert model_col.type.__class__ is Integer def test_mixin_column_ordering(self): - class Foo(object): col1 = Column(Integer) @@ -1085,112 +1077,109 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Model(Base, Foo, Bar): id = Column(Integer, primary_key=True) - __tablename__ = 'model' + __tablename__ = "model" - eq_(list(Model.__table__.c.keys()), ['col1', 'col3', 'col2', 'col4', - 'id']) + eq_( + list(Model.__table__.c.keys()), + ["col1", "col3", "col2", "col4", "id"], + ) def test_honor_class_mro_one(self): class HasXMixin(object): - @declared_attr def x(self): return Column(Integer) class Parent(HasXMixin, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) class Child(Parent): - __tablename__ = 'child' - id = Column(Integer, ForeignKey('parent.id'), primary_key=True) + __tablename__ = "child" + id = Column(Integer, ForeignKey("parent.id"), primary_key=True) assert "x" not in Child.__table__.c def test_honor_class_mro_two(self): class HasXMixin(object): - @declared_attr def x(self): return Column(Integer) class Parent(HasXMixin, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) def x(self): return "hi" class C(Parent): - __tablename__ = 'c' - id = Column(Integer, 
ForeignKey('parent.id'), primary_key=True) + __tablename__ = "c" + id = Column(Integer, ForeignKey("parent.id"), primary_key=True) - assert C().x() == 'hi' + assert C().x() == "hi" def test_arbitrary_attrs_one(self): class HasMixin(object): - @declared_attr def some_attr(cls): return cls.__name__ + "SOME ATTR" class Mapped(HasMixin, Base): - __tablename__ = 't' + __tablename__ = "t" id = Column(Integer, primary_key=True) eq_(Mapped.some_attr, "MappedSOME ATTR") - eq_(Mapped.__dict__['some_attr'], "MappedSOME ATTR") + eq_(Mapped.__dict__["some_attr"], "MappedSOME ATTR") def test_arbitrary_attrs_two(self): from sqlalchemy.ext.associationproxy import association_proxy class FilterA(Base): - __tablename__ = 'filter_a' + __tablename__ = "filter_a" id = Column(Integer(), primary_key=True) - parent_id = Column(Integer(), - ForeignKey('type_a.id')) + parent_id = Column(Integer(), ForeignKey("type_a.id")) filter = Column(String()) def __init__(self, filter_, **kw): self.filter = filter_ class FilterB(Base): - __tablename__ = 'filter_b' + __tablename__ = "filter_b" id = Column(Integer(), primary_key=True) - parent_id = Column(Integer(), - ForeignKey('type_b.id')) + parent_id = Column(Integer(), ForeignKey("type_b.id")) filter = Column(String()) def __init__(self, filter_, **kw): self.filter = filter_ class FilterMixin(object): - @declared_attr def _filters(cls): - return relationship(cls.filter_class, - cascade='all,delete,delete-orphan') + return relationship( + cls.filter_class, cascade="all,delete,delete-orphan" + ) @declared_attr def filters(cls): - return association_proxy('_filters', 'filter') + return association_proxy("_filters", "filter") class TypeA(Base, FilterMixin): - __tablename__ = 'type_a' + __tablename__ = "type_a" filter_class = FilterA id = Column(Integer(), primary_key=True) class TypeB(Base, FilterMixin): - __tablename__ = 'type_b' + __tablename__ = "type_b" filter_class = FilterB id = Column(Integer(), primary_key=True) - TypeA(filters=['foo']) - TypeB(filters=['foo']) + TypeA(filters=["foo"]) + TypeB(filters=["foo"]) def test_arbitrary_attrs_three(self): class Mapped(Base): - __tablename__ = 't' + __tablename__ = "t" id = Column(Integer, primary_key=True) @declared_attr @@ -1198,7 +1187,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): return cls.__name__ + "SOME ATTR" eq_(Mapped.some_attr, "MappedSOME ATTR") - eq_(Mapped.__dict__['some_attr'], "MappedSOME ATTR") + eq_(Mapped.__dict__["some_attr"], "MappedSOME ATTR") def test_arbitrary_attrs_doesnt_apply_to_abstract_declared_attr(self): names = ["name1", "name2", "name3"] @@ -1211,21 +1200,21 @@ class DeclarativeMixinTest(DeclarativeTestBase): return names.pop(0) class M1(SomeAbstract): - __tablename__ = 't1' + __tablename__ = "t1" id = Column(Integer, primary_key=True) class M2(SomeAbstract): - __tablename__ = 't2' + __tablename__ = "t2" id = Column(Integer, primary_key=True) - eq_(M1.__dict__['some_attr'], 'name1') - eq_(M2.__dict__['some_attr'], 'name2') + eq_(M1.__dict__["some_attr"], "name1") + eq_(M2.__dict__["some_attr"], "name2") def test_arbitrary_attrs_doesnt_apply_to_prepare_nocascade(self): names = ["name1", "name2", "name3"] class SomeAbstract(Base): - __tablename__ = 't0' + __tablename__ = "t0" __no_table__ = True # used by AbstractConcreteBase @@ -1238,63 +1227,65 @@ class DeclarativeMixinTest(DeclarativeTestBase): return names.pop(0) class M1(SomeAbstract): - __tablename__ = 't1' + __tablename__ = "t1" id = Column(Integer, primary_key=True) class M2(SomeAbstract): - __tablename__ = 't2' + __tablename__ = 
"t2" id = Column(Integer, primary_key=True) eq_(M1.some_attr, "name2") eq_(M2.some_attr, "name3") - eq_(M1.__dict__['some_attr'], 'name2') - eq_(M2.__dict__['some_attr'], 'name3') - assert isinstance(SomeAbstract.__dict__['some_attr'], declared_attr) + eq_(M1.__dict__["some_attr"], "name2") + eq_(M2.__dict__["some_attr"], "name3") + assert isinstance(SomeAbstract.__dict__["some_attr"], declared_attr) class DeclarativeMixinPropertyTest( - DeclarativeTestBase, - testing.AssertsCompiledSQL): - + DeclarativeTestBase, testing.AssertsCompiledSQL +): def test_column_property(self): - class MyMixin(object): - @declared_attr def prop_hoho(cls): - return column_property(Column('prop', String(50))) + return column_property(Column("prop", String(50))) class MyModel(Base, MyMixin): - __tablename__ = 'test' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "test" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) class MyOtherModel(Base, MyMixin): - __tablename__ = 'othertest' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "othertest" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) assert MyModel.__table__.c.prop is not None assert MyOtherModel.__table__.c.prop is not None - assert MyModel.__table__.c.prop \ - is not MyOtherModel.__table__.c.prop - assert MyModel.prop_hoho.property.columns \ - == [MyModel.__table__.c.prop] - assert MyOtherModel.prop_hoho.property.columns \ - == [MyOtherModel.__table__.c.prop] - assert MyModel.prop_hoho.property \ - is not MyOtherModel.prop_hoho.property + assert MyModel.__table__.c.prop is not MyOtherModel.__table__.c.prop + assert MyModel.prop_hoho.property.columns == [MyModel.__table__.c.prop] + assert MyOtherModel.prop_hoho.property.columns == [ + MyOtherModel.__table__.c.prop + ] + assert ( + MyModel.prop_hoho.property is not MyOtherModel.prop_hoho.property + ) Base.metadata.create_all() sess = create_session() - m1, m2 = MyModel(prop_hoho='foo'), MyOtherModel(prop_hoho='bar') + m1, m2 = MyModel(prop_hoho="foo"), MyOtherModel(prop_hoho="bar") sess.add_all([m1, m2]) sess.flush() - eq_(sess.query(MyModel).filter(MyModel.prop_hoho == 'foo' - ).one(), m1) - eq_(sess.query(MyOtherModel).filter(MyOtherModel.prop_hoho - == 'bar').one(), m2) + eq_(sess.query(MyModel).filter(MyModel.prop_hoho == "foo").one(), m1) + eq_( + sess.query(MyOtherModel) + .filter(MyOtherModel.prop_hoho == "bar") + .one(), + m2, + ) def test_doc(self): """test documentation transfer. 
@@ -1305,7 +1296,6 @@ class DeclarativeMixinPropertyTest( """ class MyMixin(object): - @declared_attr def type_(cls): """this is a document.""" @@ -1320,7 +1310,7 @@ class DeclarativeMixinPropertyTest( class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) configure_mappers() @@ -1346,7 +1336,7 @@ class DeclarativeMixinPropertyTest( return hp2 class Base(declarative_base(), Mixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(String, primary_key=True) class Derived(Base): @@ -1363,15 +1353,9 @@ class DeclarativeMixinPropertyTest( # and adjusts b1 = inspect(Base) d1 = inspect(Derived) - is_( - b1.all_orm_descriptors['hp1'], - d1.all_orm_descriptors['hp1'], - ) + is_(b1.all_orm_descriptors["hp1"], d1.all_orm_descriptors["hp1"]) - is_( - b1.all_orm_descriptors['hp2'], - d1.all_orm_descriptors['hp2'], - ) + is_(b1.all_orm_descriptors["hp2"], d1.all_orm_descriptors["hp2"]) def test_correct_for_proxies_doesnt_impact_synonyms(self): from sqlalchemy import inspect @@ -1379,153 +1363,147 @@ class DeclarativeMixinPropertyTest( class Mixin(object): @declared_attr def data_syn(cls): - return synonym('data') + return synonym("data") class Base(declarative_base(), Mixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(String, primary_key=True) data = Column(String) type = Column(String) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base' + "polymorphic_on": type, + "polymorphic_identity": "base", } class Derived(Base): - __mapper_args__ = { - 'polymorphic_identity': 'derived' - } + __mapper_args__ = {"polymorphic_identity": "derived"} assert Base.data_syn._is_internal_proxy assert Derived.data_syn._is_internal_proxy b1 = inspect(Base) d1 = inspect(Derived) - is_( - b1.attrs['data_syn'], - d1.attrs['data_syn'], - ) + is_(b1.attrs["data_syn"], d1.attrs["data_syn"]) s = Session() self.assert_compile( - s.query(Base.data_syn).filter(Base.data_syn == 'foo'), - 'SELECT test.data AS test_data FROM test WHERE test.data = :data_1', - dialect='default' + s.query(Base.data_syn).filter(Base.data_syn == "foo"), + "SELECT test.data AS test_data FROM test WHERE test.data = :data_1", + dialect="default", ) self.assert_compile( - s.query(Derived.data_syn).filter(Derived.data_syn == 'foo'), - 'SELECT test.data AS test_data FROM test WHERE test.data = ' - ':data_1 AND test.type IN (:type_1)', - dialect='default', - checkparams={"type_1": "derived", "data_1": "foo"} + s.query(Derived.data_syn).filter(Derived.data_syn == "foo"), + "SELECT test.data AS test_data FROM test WHERE test.data = " + ":data_1 AND test.type IN (:type_1)", + dialect="default", + checkparams={"type_1": "derived", "data_1": "foo"}, ) def test_column_in_mapper_args(self): - class MyMixin(object): - @declared_attr def type_(cls): return Column(String(50)) - __mapper_args__ = {'polymorphic_on': type_} + + __mapper_args__ = {"polymorphic_on": type_} class MyModel(Base, MyMixin): - __tablename__ = 'test' + __tablename__ = "test" id = Column(Integer, primary_key=True) configure_mappers() col = MyModel.__mapper__.polymorphic_on - eq_(col.name, 'type_') + eq_(col.name, "type_") assert col.table is not None def test_column_in_mapper_args_used_multiple_times(self): - class MyMixin(object): version_id = Column(Integer) - __mapper_args__ = {'version_id_col': version_id} + __mapper_args__ = {"version_id_col": version_id} class ModelOne(Base, MyMixin): - __tablename__ = 'm1' + __tablename__ = "m1" id = Column(Integer, primary_key=True) class 
         class ModelTwo(Base, MyMixin):
-            __tablename__ = 'm2'
+            __tablename__ = "m2"
             id = Column(Integer, primary_key=True)

         is_(
-            ModelOne.__mapper__.version_id_col,
-            ModelOne.__table__.c.version_id
+            ModelOne.__mapper__.version_id_col, ModelOne.__table__.c.version_id
         )
         is_(
-            ModelTwo.__mapper__.version_id_col,
-            ModelTwo.__table__.c.version_id
+            ModelTwo.__mapper__.version_id_col, ModelTwo.__table__.c.version_id
         )

     def test_deferred(self):
-
         class MyMixin(object):
-
             @declared_attr
             def data(cls):
-                return deferred(Column('data', String(50)))
+                return deferred(Column("data", String(50)))

         class MyModel(Base, MyMixin):
-            __tablename__ = 'test'
-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
+            __tablename__ = "test"
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )

         Base.metadata.create_all()
         sess = create_session()
-        sess.add_all([MyModel(data='d1'), MyModel(data='d2')])
+        sess.add_all([MyModel(data="d1"), MyModel(data="d2")])
         sess.flush()
         sess.expunge_all()
         d1, d2 = sess.query(MyModel).order_by(MyModel.data)
-        assert 'data' not in d1.__dict__
-        assert d1.data == 'd1'
-        assert 'data' in d1.__dict__
+        assert "data" not in d1.__dict__
+        assert d1.data == "d1"
+        assert "data" in d1.__dict__

     def _test_relationship(self, usestring):
-
         class RefTargetMixin(object):
-
             @declared_attr
             def target_id(cls):
-                return Column('target_id', ForeignKey('target.id'))
+                return Column("target_id", ForeignKey("target.id"))
+
             if usestring:

                 @declared_attr
                 def target(cls):
-                    return relationship('Target',
-                                        primaryjoin='Target.id==%s.target_id'
-                                        % cls.__name__)
+                    return relationship(
+                        "Target",
+                        primaryjoin="Target.id==%s.target_id" % cls.__name__,
+                    )
+
             else:

                 @declared_attr
                 def target(cls):
-                    return relationship('Target')
+                    return relationship("Target")

         class Foo(Base, RefTargetMixin):
-            __tablename__ = 'foo'
-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
+            __tablename__ = "foo"
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )

         class Bar(Base, RefTargetMixin):
-            __tablename__ = 'bar'
-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
+            __tablename__ = "bar"
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )

         class Target(Base):
-            __tablename__ = 'target'
-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
+            __tablename__ = "target"
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )

         Base.metadata.create_all()
         sess = create_session()
@@ -1546,7 +1524,7 @@ class DeclarativeMixinPropertyTest(


 class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"

     def test_singleton_behavior_within_decl(self):
         counter = mock.Mock()
@@ -1555,10 +1533,10 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             @declared_attr
             def my_prop(cls):
                 counter(cls)
-                return Column('x', Integer)
+                return Column("x", Integer)

         class A(Base, Mixin):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)

             @declared_attr
@@ -1568,16 +1546,14 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
         eq_(counter.mock_calls, [mock.call(A)])

         class B(Base, Mixin):
-            __tablename__ = 'b'
+            __tablename__ = "b"
             id = Column(Integer, primary_key=True)

             @declared_attr
             def my_other_prop(cls):
                 return column_property(cls.my_prop + 5)

-        eq_(
-            counter.mock_calls,
-            [mock.call(A), mock.call(B)])
+        eq_(counter.mock_calls, [mock.call(A), mock.call(B)])

         # this is why we need singleton-per-class behavior.  We get
         # an un-bound "x" column otherwise here, because my_prop() generates
@@ -1592,11 +1568,11 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
         s = Session()
         self.assert_compile(
             s.query(A),
-            "SELECT a.x AS a_x, a.x + :x_1 AS anon_1, a.id AS a_id FROM a"
+            "SELECT a.x AS a_x, a.x + :x_1 AS anon_1, a.id AS a_id FROM a",
         )
         self.assert_compile(
             s.query(B),
-            "SELECT b.x AS b_x, b.x + :x_1 AS anon_1, b.id AS b_id FROM b"
+            "SELECT b.x AS b_x, b.x + :x_1 AS anon_1, b.id AS b_id FROM b",
         )

     @testing.requires.predictable_gc
@@ -1607,10 +1583,10 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             @declared_attr
             def my_prop(cls):
                 counter(cls.__name__)
-                return Column('x', Integer)
+                return Column("x", Integer)

         class A(Base, Mixin):
-            __tablename__ = 'b'
+            __tablename__ = "b"
             id = Column(Integer, primary_key=True)

             @declared_attr
@@ -1626,13 +1602,15 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
         class Mixin(object):
             @declared_attr
             def my_prop(cls):
-                return Column('x', Integer)
+                return Column("x", Integer)

         assert_raises_message(
             sa.exc.SAWarning,
             "Unmanaged access of declarative attribute my_prop "
             "from non-mapped class Mixin",
-            getattr, Mixin, "my_prop"
+            getattr,
+            Mixin,
+            "my_prop",
         )

     def test_can_we_access_the_mixin_straight_special_names(self):
@@ -1668,17 +1646,14 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             def y(cls):
                 cls.__tablename__

-        eq_(
-            counter.mock_calls,
-            [mock.call(Foo)]
-        )
+        eq_(counter.mock_calls, [mock.call(Foo)])

-        eq_(Foo.__tablename__, 'foo')
-        eq_(Foo.__tablename__, 'foo')
+        eq_(Foo.__tablename__, "foo")
+        eq_(Foo.__tablename__, "foo")

         eq_(
             counter.mock_calls,
-            [mock.call(Foo), mock.call(Foo), mock.call(Foo)]
+            [mock.call(Foo), mock.call(Foo), mock.call(Foo)],
         )

     def test_property_noncascade(self):
@@ -1691,7 +1666,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
                 return column_property(cls.x + 2)

         class A(Base, Mixin):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)

@@ -1711,7 +1686,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
                 return column_property(cls.x + 2)

         class A(Base, Mixin):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)

@@ -1731,17 +1706,19 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
                 return column_property(cls.x + 2)

         class A(Base, Mixin):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)

         with expect_warnings(
-                "Attribute 'my_prop' on class .*B.* "
-                "cannot be processed due to @declared_attr.cascading; "
-                "skipping"):
+            "Attribute 'my_prop' on class .*B.* "
+            "cannot be processed due to @declared_attr.cascading; "
+            "skipping"
+        ):
+
             class B(A):
-                my_prop = Column('foobar', Integer)
+                my_prop = Column("foobar", Integer)

         eq_(counter.mock_calls, [mock.call(A), mock.call(B)])

@@ -1757,7 +1734,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
                 return column_property(cls.x + 2)

         class A(Abs):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)

@@ -1777,6 +1754,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             "@declared_attr.cascading is not supported on the "
             "__tablename__ attribute on class .*A."
        ):
+
             class A(Mixin, Base):
                 id = Column(Integer, primary_key=True)

@@ -1792,19 +1770,19 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             @declared_attr.cascading
             def my_attr(cls):
                 if decl.has_inherited_table(cls):
-                    id = Column(ForeignKey('a.my_attr'), primary_key=True)
-                    asserted['b'].add(id)
+                    id = Column(ForeignKey("a.my_attr"), primary_key=True)
+                    asserted["b"].add(id)
                 else:
                     id = Column(Integer, primary_key=True)
-                    asserted['a'].add(id)
+                    asserted["a"].add(id)
                 return id

         class A(Base, Mixin):
-            __tablename__ = 'a'
+            __tablename__ = "a"

             @declared_attr
             def __mapper_args__(cls):
-                asserted['a'].add(cls.my_attr)
+                asserted["a"].add(cls.my_attr)
                 return {}

         # here:
@@ -1820,19 +1798,19 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
         # descriptor from being invoked.

         class B(A):
-            __tablename__ = 'b'
+            __tablename__ = "b"

             @declared_attr
             def __mapper_args__(cls):
-                asserted['b'].add(cls.my_attr)
+                asserted["b"].add(cls.my_attr)
                 return {}

         eq_(
             asserted,
             {
-                'a': set([A.my_attr.property.columns[0]]),
-                'b': set([B.my_attr.property.columns[0]])
-            }
+                "a": set([A.my_attr.property.columns[0]]),
+                "b": set([B.my_attr.property.columns[0]]),
+            },
         )

     def test_column_pre_map(self):
@@ -1843,10 +1821,10 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             def my_col(cls):
                 counter(cls)
                 assert not orm_base._mapper_or_none(cls)
-                return Column('x', Integer)
+                return Column("x", Integer)

         class A(Base, Mixin):
-            __tablename__ = 'a'
+            __tablename__ = "a"

             id = Column(Integer, primary_key=True)
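For context on the my_attr fixture above: a plain @declared_attr is evaluated once and its result copied to subclasses, whereas @declared_attr.cascading re-invokes the callable for every subclass, letting inheriting classes produce a different column. A rough standalone sketch of that distinction, assuming the 1.x declarative API (class and table names here are illustrative, not from the patch):

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import (
        declarative_base,
        declared_attr,
        has_inherited_table,
    )

    Base = declarative_base()

    class HasIdMixin(object):
        # re-invoked for each subclass instead of copied from the base,
        # so the subclass can receive a foreign-key primary key
        @declared_attr.cascading
        def id(cls):
            if has_inherited_table(cls):
                return Column(ForeignKey("a.id"), primary_key=True)
            return Column(Integer, primary_key=True)

    class A(HasIdMixin, Base):
        __tablename__ = "a"

    class B(A):
        __tablename__ = "b"  # joined inheritance; id becomes FK to a.id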
@@ -1866,65 +1844,60 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             def address_count(cls):
                 counter(cls.id)
                 return column_property(
-                    select([func.count(Address.id)]).
-                    where(Address.user_id == cls.id).
-                    as_scalar()
+                    select([func.count(Address.id)])
+                    .where(Address.user_id == cls.id)
+                    .as_scalar()
                 )

         class Address(Base):
-            __tablename__ = 'address'
+            __tablename__ = "address"
             id = Column(Integer, primary_key=True)
-            user_id = Column(ForeignKey('user.id'))
+            user_id = Column(ForeignKey("user.id"))

         class User(Base, HasAddressCount):
-            __tablename__ = 'user'
+            __tablename__ = "user"

-        eq_(
-            counter.mock_calls,
-            [mock.call(User.id)]
-        )
+        eq_(counter.mock_calls, [mock.call(User.id)])

         sess = Session()
         self.assert_compile(
             sess.query(User).having(User.address_count > 5),
-            'SELECT (SELECT count(address.id) AS '
+            "SELECT (SELECT count(address.id) AS "
             'count_1 FROM address WHERE address.user_id = "user".id) '
             'AS anon_1, "user".id AS user_id FROM "user" '
-            'HAVING (SELECT count(address.id) AS '
+            "HAVING (SELECT count(address.id) AS "
             'count_1 FROM address WHERE address.user_id = "user".id) '
-            '> :param_1'
+            "> :param_1",
         )


 class AbstractTest(DeclarativeTestBase):
-
     def test_abstract_boolean(self):
-
         class A(Base):
             __abstract__ = True
-            __tablename__ = 'x'
+            __tablename__ = "x"
             id = Column(Integer, primary_key=True)

         class B(Base):
             __abstract__ = False
-            __tablename__ = 'y'
+            __tablename__ = "y"
             id = Column(Integer, primary_key=True)

         class C(Base):
             __abstract__ = False
-            __tablename__ = 'z'
+            __tablename__ = "z"
             id = Column(Integer, primary_key=True)

         class D(Base):
-            __tablename__ = 'q'
+            __tablename__ = "q"
             id = Column(Integer, primary_key=True)

-        eq_(set(Base.metadata.tables), set(['y', 'z', 'q']))
+        eq_(set(Base.metadata.tables), set(["y", "z", "q"]))

     def test_middle_abstract_attributes(self):
         # test for [ticket:3219]
         class A(Base):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)

             name = Column(String)
@@ -1936,27 +1909,25 @@ class AbstractTest(DeclarativeTestBase):
         class C(B):
             c_value = Column(String)

-        eq_(
-            sa.inspect(C).attrs.keys(), ['id', 'name', 'data', 'c_value']
-        )
+        eq_(sa.inspect(C).attrs.keys(), ["id", "name", "data", "c_value"])

     def test_middle_abstract_inherits(self):
         # test for [ticket:3240]

         class A(Base):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)

         class AAbs(A):
             __abstract__ = True

         class B1(A):
-            __tablename__ = 'b1'
-            id = Column(ForeignKey('a.id'), primary_key=True)
+            __tablename__ = "b1"
+            id = Column(ForeignKey("a.id"), primary_key=True)

         class B2(AAbs):
-            __tablename__ = 'b2'
-            id = Column(ForeignKey('a.id'), primary_key=True)
+            __tablename__ = "b2"
+            id = Column(ForeignKey("a.id"), primary_key=True)

         assert B1.__mapper__.inherits is A.__mapper__
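The AbstractTest hunks above hinge on one behavior: a declarative class with __abstract__ = True is skipped entirely, so no mapper and no Table are produced for it, even when a __tablename__ is present. A compact sketch of that rule (hypothetical table names):

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class A(Base):
        __abstract__ = True
        __tablename__ = "x"  # ignored: abstract classes produce no table
        id = Column(Integer, primary_key=True)

    class B(Base):
        __tablename__ = "y"
        id = Column(Integer, primary_key=True)

    # only the non-abstract class contributed a Table to the metadata
    assert set(Base.metadata.tables) == {"y"}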
diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py
index fef9d794c4..e2bab0dd86 100644
--- a/test/ext/declarative/test_reflection.py
+++ b/test/ext/declarative/test_reflection.py
@@ -3,16 +3,14 @@ from sqlalchemy.ext import declarative as decl
 from sqlalchemy import testing
 from sqlalchemy import Integer, String, ForeignKey
 from sqlalchemy.testing.schema import Table, Column
-from sqlalchemy.orm import relationship, create_session, \
-    clear_mappers, \
-    Session
+from sqlalchemy.orm import relationship, create_session, clear_mappers, Session
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.ext.declarative.base import _DeferredMapperConfig


 class DeclarativeReflectionBase(fixtures.TablesTest):
-    __requires__ = 'reflectable_autoincrement',
+    __requires__ = ("reflectable_autoincrement",)

     def setup(self):
         global Base
@@ -24,116 +22,147 @@ class DeclarativeReflectionBase(fixtures.TablesTest):


 class DeclarativeReflectionTest(DeclarativeReflectionBase):
-
     @classmethod
     def define_tables(cls, metadata):
-        Table('users', metadata,
-              Column('id', Integer,
-                     primary_key=True, test_needs_autoincrement=True),
-              Column('name', String(50)), test_needs_fk=True)
         Table(
-            'addresses',
+            "users",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("name", String(50)),
+            test_needs_fk=True,
+        )
+        Table(
+            "addresses",
             metadata,
-            Column('id', Integer, primary_key=True,
-                   test_needs_autoincrement=True),
-            Column('email', String(50)),
-            Column('user_id', Integer, ForeignKey('users.id')),
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("email", String(50)),
+            Column("user_id", Integer, ForeignKey("users.id")),
             test_needs_fk=True,
         )
         Table(
-            'imhandles',
+            "imhandles",
             metadata,
-            Column('id', Integer, primary_key=True,
-                   test_needs_autoincrement=True),
-            Column('user_id', Integer),
-            Column('network', String(50)),
-            Column('handle', String(50)),
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("user_id", Integer),
+            Column("network", String(50)),
+            Column("handle", String(50)),
             test_needs_fk=True,
         )

     def test_basic(self):
         class User(Base, fixtures.ComparableEntity):
-            __tablename__ = 'users'
+            __tablename__ = "users"
             __autoload__ = True
-            addresses = relationship('Address', backref='user')
+            addresses = relationship("Address", backref="user")

         class Address(Base, fixtures.ComparableEntity):
-            __tablename__ = 'addresses'
+            __tablename__ = "addresses"
             __autoload__ = True

-        u1 = User(name='u1', addresses=[Address(email='one'),
-                                        Address(email='two')])
+        u1 = User(
+            name="u1", addresses=[Address(email="one"), Address(email="two")]
+        )
         sess = create_session()
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
-        eq_(sess.query(User).all(), [
-            User(name='u1',
-                 addresses=[Address(email='one'), Address(email='two')])])
-        a1 = sess.query(Address).filter(Address.email == 'two').one()
-        eq_(a1, Address(email='two'))
-        eq_(a1.user, User(name='u1'))
+        eq_(
+            sess.query(User).all(),
+            [
+                User(
+                    name="u1",
+                    addresses=[Address(email="one"), Address(email="two")],
+                )
+            ],
+        )
+        a1 = sess.query(Address).filter(Address.email == "two").one()
+        eq_(a1, Address(email="two"))
+        eq_(a1.user, User(name="u1"))

     def test_rekey(self):
         class User(Base, fixtures.ComparableEntity):
-            __tablename__ = 'users'
+            __tablename__ = "users"
             __autoload__ = True
-            nom = Column('name', String(50), key='nom')
-            addresses = relationship('Address', backref='user')
+            nom = Column("name", String(50), key="nom")
+            addresses = relationship("Address", backref="user")

         class Address(Base, fixtures.ComparableEntity):
-            __tablename__ = 'addresses'
+            __tablename__ = "addresses"
             __autoload__ = True

-        u1 = User(nom='u1', addresses=[Address(email='one'),
-                                       Address(email='two')])
+        u1 = User(
+            nom="u1", addresses=[Address(email="one"), Address(email="two")]
+        )
         sess = create_session()
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
-        eq_(sess.query(User).all(), [
-            User(nom='u1',
-                 addresses=[Address(email='one'), Address(email='two')])])
-        a1 = sess.query(Address).filter(Address.email == 'two').one()
-        eq_(a1, Address(email='two'))
-        eq_(a1.user, User(nom='u1'))
-        assert_raises(TypeError, User, name='u3')
+        eq_(
+            sess.query(User).all(),
+            [
+                User(
+                    nom="u1",
+                    addresses=[Address(email="one"), Address(email="two")],
+                )
+            ],
+        )
"two").one() + eq_(a1, Address(email="two")) + eq_(a1.user, User(nom="u1")) + assert_raises(TypeError, User, name="u3") def test_supplied_fk(self): class IMHandle(Base, fixtures.ComparableEntity): - __tablename__ = 'imhandles' + __tablename__ = "imhandles" __autoload__ = True - user_id = Column('user_id', Integer, ForeignKey('users.id')) + user_id = Column("user_id", Integer, ForeignKey("users.id")) class User(Base, fixtures.ComparableEntity): - __tablename__ = 'users' + __tablename__ = "users" __autoload__ = True - handles = relationship('IMHandle', backref='user') - - u1 = User(name='u1', handles=[ - IMHandle(network='blabber', handle='foo'), - IMHandle(network='lol', handle='zomg')]) + handles = relationship("IMHandle", backref="user") + + u1 = User( + name="u1", + handles=[ + IMHandle(network="blabber", handle="foo"), + IMHandle(network="lol", handle="zomg"), + ], + ) sess = create_session() sess.add(u1) sess.flush() sess.expunge_all() - eq_(sess.query(User).all(), [ - User(name='u1', handles=[IMHandle(network='blabber', handle='foo'), - IMHandle(network='lol', handle='zomg')])]) - a1 = sess.query(IMHandle).filter(IMHandle.handle == 'zomg' - ).one() - eq_(a1, IMHandle(network='lol', handle='zomg')) - eq_(a1.user, User(name='u1')) + eq_( + sess.query(User).all(), + [ + User( + name="u1", + handles=[ + IMHandle(network="blabber", handle="foo"), + IMHandle(network="lol", handle="zomg"), + ], + ) + ], + ) + a1 = sess.query(IMHandle).filter(IMHandle.handle == "zomg").one() + eq_(a1, IMHandle(network="lol", handle="zomg")) + eq_(a1.user, User(name="u1")) class DeferredReflectBase(DeclarativeReflectionBase): - def teardown(self): super(DeferredReflectBase, self).teardown() _DeferredMapperConfig._configs.clear() @@ -143,78 +172,90 @@ Base = None class DeferredReflectPKFKTest(DeferredReflectBase): - @classmethod def define_tables(cls, metadata): - Table("a", metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - ) - Table("b", metadata, - Column('id', Integer, - ForeignKey('a.id'), - primary_key=True), - Column('x', Integer, primary_key=True) - ) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("x", Integer, primary_key=True), + ) def test_pk_fk(self): - class B(decl.DeferredReflection, fixtures.ComparableEntity, - Base): - __tablename__ = 'b' + class B(decl.DeferredReflection, fixtures.ComparableEntity, Base): + __tablename__ = "b" a = relationship("A") - class A(decl.DeferredReflection, fixtures.ComparableEntity, - Base): - __tablename__ = 'a' + class A(decl.DeferredReflection, fixtures.ComparableEntity, Base): + __tablename__ = "a" decl.DeferredReflection.prepare(testing.db) class DeferredReflectionTest(DeferredReflectBase): - @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('name', String(50)), test_needs_fk=True) Table( - 'addresses', + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + test_needs_fk=True, + ) + Table( + "addresses", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('email', String(50)), - Column('user_id', Integer, ForeignKey('users.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("email", 
+            Column("email", String(50)),
+            Column("user_id", Integer, ForeignKey("users.id")),
             test_needs_fk=True,
         )

     def _roundtrip(self):
-        User = Base._decl_class_registry['User']
-        Address = Base._decl_class_registry['Address']
+        User = Base._decl_class_registry["User"]
+        Address = Base._decl_class_registry["Address"]

-        u1 = User(name='u1', addresses=[Address(email='one'),
-                                        Address(email='two')])
+        u1 = User(
+            name="u1", addresses=[Address(email="one"), Address(email="two")]
+        )
         sess = create_session()
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
-        eq_(sess.query(User).all(), [
-            User(name='u1',
-                 addresses=[Address(email='one'), Address(email='two')])])
-        a1 = sess.query(Address).filter(Address.email == 'two').one()
-        eq_(a1, Address(email='two'))
-        eq_(a1.user, User(name='u1'))
+        eq_(
+            sess.query(User).all(),
+            [
+                User(
+                    name="u1",
+                    addresses=[Address(email="one"), Address(email="two")],
+                )
+            ],
+        )
+        a1 = sess.query(Address).filter(Address.email == "two").one()
+        eq_(a1, Address(email="two"))
+        eq_(a1.user, User(name="u1"))

     def test_basic_deferred(self):
-        class User(decl.DeferredReflection, fixtures.ComparableEntity,
-                   Base):
-            __tablename__ = 'users'
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "users"
             addresses = relationship("Address", backref="user")

-        class Address(decl.DeferredReflection, fixtures.ComparableEntity,
-                      Base):
-            __tablename__ = 'addresses'
+        class Address(
+            decl.DeferredReflection, fixtures.ComparableEntity, Base
+        ):
+            __tablename__ = "addresses"

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
@@ -227,28 +268,28 @@ class DeferredReflectionTest(DeferredReflectBase):
             __abstract__ = True

         class User(fixtures.ComparableEntity, DefBase):
-            __tablename__ = 'users'
+            __tablename__ = "users"
             addresses = relationship("Address", backref="user")

         class Address(fixtures.ComparableEntity, DefBase):
-            __tablename__ = 'addresses'
+            __tablename__ = "addresses"

         class Fake(OtherDefBase):
-            __tablename__ = 'nonexistent'
+            __tablename__ = "nonexistent"

         DefBase.prepare(testing.db)
         self._roundtrip()

     def test_redefine_fk_double(self):
-        class User(decl.DeferredReflection, fixtures.ComparableEntity,
-                   Base):
-            __tablename__ = 'users'
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "users"
             addresses = relationship("Address", backref="user")

-        class Address(decl.DeferredReflection, fixtures.ComparableEntity,
-                      Base):
-            __tablename__ = 'addresses'
-            user_id = Column(Integer, ForeignKey('users.id'))
+        class Address(
+            decl.DeferredReflection, fixtures.ComparableEntity, Base
+        ):
+            __tablename__ = "addresses"
+            user_id = Column(Integer, ForeignKey("users.id"))

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
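A reminder of what DeferredReflection does in these fixtures: the classes declare only __tablename__ plus relationships, and the Column definitions are reflected from the live database when prepare() is called, which is also when mapping completes. A minimal sketch against an in-memory SQLite schema (the schema and names are invented for illustration):

    from sqlalchemy import create_engine
    from sqlalchemy.ext.declarative import DeferredReflection, declarative_base
    from sqlalchemy.orm import relationship

    engine = create_engine("sqlite://")
    # simulate a pre-existing schema that the classes will reflect
    engine.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR)")
    engine.execute(
        "CREATE TABLE addresses (id INTEGER PRIMARY KEY, "
        "email VARCHAR, user_id INTEGER REFERENCES users(id))"
    )

    Base = declarative_base()

    class User(DeferredReflection, Base):
        __tablename__ = "users"
        addresses = relationship("Address", backref="user")

    class Address(DeferredReflection, Base):
        __tablename__ = "addresses"

    # mapping completes only here, once the real columns are known
    DeferredReflection.prepare(engine)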
@@ -257,44 +298,34 @@ class DeferredReflectionTest(DeferredReflectBase):
         """test that __mapper_args__ is not called until *after*
         table reflection"""

-        class User(decl.DeferredReflection, fixtures.ComparableEntity,
-                   Base):
-            __tablename__ = 'users'
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "users"

             @decl.declared_attr
             def __mapper_args__(cls):
-                return {
-                    "primary_key": cls.__table__.c.id
-                }
+                return {"primary_key": cls.__table__.c.id}

         decl.DeferredReflection.prepare(testing.db)
         sess = Session()
-        sess.add_all([
-            User(name='G'),
-            User(name='Q'),
-            User(name='A'),
-            User(name='C'),
-        ])
+        sess.add_all(
+            [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
+        )
         sess.commit()
         eq_(
             sess.query(User).order_by(User.name).all(),
-            [
-                User(name='A'),
-                User(name='C'),
-                User(name='G'),
-                User(name='Q'),
-            ]
+            [User(name="A"), User(name="C"), User(name="G"), User(name="Q")],
         )

     @testing.requires.predictable_gc
     def test_cls_not_strong_ref(self):
-        class User(decl.DeferredReflection, fixtures.ComparableEntity,
-                   Base):
-            __tablename__ = 'users'
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "users"
+
+        class Address(
+            decl.DeferredReflection, fixtures.ComparableEntity, Base
+        ):
+            __tablename__ = "addresses"

-        class Address(decl.DeferredReflection, fixtures.ComparableEntity,
-                      Base):
-            __tablename__ = 'addresses'
         eq_(len(_DeferredMapperConfig._configs), 2)
         del Address
         gc_collect()
@@ -304,112 +335,128 @@ class DeferredReflectionTest(DeferredReflectBase):


 class DeferredSecondaryReflectionTest(DeferredReflectBase):
-
     @classmethod
     def define_tables(cls, metadata):
-        Table('users', metadata,
-              Column('id', Integer,
-                     primary_key=True, test_needs_autoincrement=True),
-              Column('name', String(50)), test_needs_fk=True)
-
-        Table('user_items', metadata,
-              Column('user_id', ForeignKey('users.id'), primary_key=True),
-              Column('item_id', ForeignKey('items.id'), primary_key=True),
-              test_needs_fk=True
-              )
-
-        Table('items', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
-              Column('name', String(50)),
-              test_needs_fk=True
-              )
+        Table(
+            "users",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("name", String(50)),
+            test_needs_fk=True,
+        )
+
+        Table(
+            "user_items",
+            metadata,
+            Column("user_id", ForeignKey("users.id"), primary_key=True),
+            Column("item_id", ForeignKey("items.id"), primary_key=True),
+            test_needs_fk=True,
+        )
+
+        Table(
+            "items",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("name", String(50)),
+            test_needs_fk=True,
+        )

     def _roundtrip(self):
-        User = Base._decl_class_registry['User']
-        Item = Base._decl_class_registry['Item']
+        User = Base._decl_class_registry["User"]
+        Item = Base._decl_class_registry["Item"]

-        u1 = User(name='u1', items=[Item(name='i1'), Item(name='i2')])
+        u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")])

         sess = Session()
         sess.add(u1)
         sess.commit()

-        eq_(sess.query(User).all(), [
-            User(name='u1', items=[Item(name='i1'), Item(name='i2')])])
+        eq_(
+            sess.query(User).all(),
+            [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
+        )

     def test_string_resolution(self):
         class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
-            __tablename__ = 'users'
+            __tablename__ = "users"

             items = relationship("Item", secondary="user_items")

         class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base):
-            __tablename__ = 'items'
+            __tablename__ = "items"

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()

     def test_table_resolution(self):
         class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
-            __tablename__ = 'users'
+            __tablename__ = "users"

-            items = relationship("Item",
-                                 secondary=Table("user_items", Base.metadata))
+            items = relationship(
+                "Item", secondary=Table("user_items", Base.metadata)
+            )

         class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base):
-            __tablename__ = 'items'
+            __tablename__ = "items"

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()


 class DeferredInhReflectBase(DeferredReflectBase):
-
     def _roundtrip(self):
-        Foo = Base._decl_class_registry['Foo']
-        Bar = Base._decl_class_registry['Bar']
+        Foo = Base._decl_class_registry["Foo"]
+        Bar = Base._decl_class_registry["Bar"]

         s = Session(testing.db)
-        s.add_all([
-            Bar(data='d1', bar_data='b1'),
-            Bar(data='d2', bar_data='b2'),
-            Bar(data='d3', bar_data='b3'),
-            Foo(data='d4')
-        ])
+        s.add_all(
+            [
+                Bar(data="d1", bar_data="b1"),
+                Bar(data="d2", bar_data="b2"),
+                Bar(data="d3", bar_data="b3"),
+                Foo(data="d4"),
+            ]
+        )
         s.commit()

         eq_(
             s.query(Foo).order_by(Foo.id).all(),
             [
-                Bar(data='d1', bar_data='b1'),
-                Bar(data='d2', bar_data='b2'),
-                Bar(data='d3', bar_data='b3'),
-                Foo(data='d4')
-            ]
+                Bar(data="d1", bar_data="b1"),
+                Bar(data="d2", bar_data="b2"),
+                Bar(data="d3", bar_data="b3"),
+                Foo(data="d4"),
+            ],
         )


 class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
-
     @classmethod
     def define_tables(cls, metadata):
-        Table("foo", metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
-              Column('type', String(32)),
-              Column('data', String(30)),
-              Column('bar_data', String(30))
-              )
+        Table(
+            "foo",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("type", String(32)),
+            Column("data", String(30)),
+            Column("bar_data", String(30)),
+        )

     def test_basic(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }

         class Bar(Foo):
             __mapper_args__ = {"polymorphic_identity": "bar"}

@@ -418,11 +465,12 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()

     def test_add_subclass_column(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }

         class Bar(Foo):
             __mapper_args__ = {"polymorphic_identity": "bar"}
@@ -432,11 +480,12 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()

     def test_add_pk_column(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }
             id = Column(Integer, primary_key=True)

         class Bar(Foo):
@@ -447,45 +496,51 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):


 class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
-
     @classmethod
     def define_tables(cls, metadata):
-        Table("foo", metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
-              Column('type', String(32)),
-              Column('data', String(30)),
-              test_needs_fk=True,
-              )
-        Table('bar', metadata,
-              Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
-              Column('bar_data', String(30)),
-              test_needs_fk=True,
-              )
+        Table(
+            "foo",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("type", String(32)),
+            Column("data", String(30)),
+            test_needs_fk=True,
+        )
+        Table(
+            "bar",
+            metadata,
+            Column("id", Integer, ForeignKey("foo.id"), primary_key=True),
+            Column("bar_data", String(30)),
+            test_needs_fk=True,
+        )

     def test_basic(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }

         class Bar(Foo):
-            __tablename__ = 'bar'
+            __tablename__ = "bar"
             __mapper_args__ = {"polymorphic_identity": "bar"}

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()

     def test_add_subclass_column(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }

         class Bar(Foo):
-            __tablename__ = 'bar'
+            __tablename__ = "bar"
             __mapper_args__ = {"polymorphic_identity": "bar"}
             bar_data = Column(String(30))

@@ -493,31 +548,33 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()

     def test_add_pk_column(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }
             id = Column(Integer, primary_key=True)

         class Bar(Foo):
-            __tablename__ = 'bar'
+            __tablename__ = "bar"
             __mapper_args__ = {"polymorphic_identity": "bar"}

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()

     def test_add_fk_pk_column(self):
-        class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
-                  Base):
-            __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on": "type",
-                               "polymorphic_identity": "foo"}
+        class Foo(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }

         class Bar(Foo):
-            __tablename__ = 'bar'
+            __tablename__ = "bar"
             __mapper_args__ = {"polymorphic_identity": "bar"}
-            id = Column(Integer, ForeignKey('foo.id'), primary_key=True)
+            id = Column(Integer, ForeignKey("foo.id"), primary_key=True)

         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()

diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py
index 69f0c5ed0e..4e063a84b8 100644
--- a/test/ext/test_associationproxy.py
+++ b/test/ext/test_associationproxy.py
@@ -3,8 +3,17 @@ import copy
 import pickle

 from sqlalchemy import Integer, ForeignKey, String, or_, MetaData
-from sqlalchemy.orm import relationship, configure_mappers, mapper, Session,\
-    collections, sessionmaker, aliased, clear_mappers, create_session
+from sqlalchemy.orm import (
+    relationship,
+    configure_mappers,
+    mapper,
+    Session,
+    collections,
+    sessionmaker,
+    aliased,
+    clear_mappers,
+    create_session,
+)
 from sqlalchemy import exc
 from sqlalchemy.orm.collections import collection, attribute_mapped_collection
 from sqlalchemy.ext.associationproxy import association_proxy
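Before the AutoFlushTest hunks, it may help to restate the mechanism these tests exercise: association_proxy("_children", "name") makes a parent's collection behave like a list of plain strings, creating the intermediary Child objects on demand through the proxy's creator. A small sketch of that behavior, with invented names; no database is needed for the in-memory part:

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = "parent"
        id = Column(Integer, primary_key=True)
        _children = relationship("Child")
        # reads and writes proxy through _children to each Child.name
        children = association_proxy("_children", "name")

    class Child(Base):
        __tablename__ = "child"
        id = Column(Integer, primary_key=True)
        parent_id = Column(ForeignKey("parent.id"))
        name = Column(String(50))

        def __init__(self, name):
            self.name = name

    p = Parent()
    p.children.append("regular")  # builds Child("regular") behind the scenes
    assert p.children == ["regular"]
    assert p._children[0].name == "regular"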
@@ -60,20 +69,26 @@ class AutoFlushTest(fixtures.TablesTest):
     @classmethod
     def define_tables(cls, metadata):
         Table(
-            'parent', metadata,
-            Column('id', Integer, primary_key=True,
-                   test_needs_autoincrement=True))
+            "parent",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+        )
         Table(
-            'association', metadata,
-            Column('parent_id', ForeignKey('parent.id'), primary_key=True),
-            Column('child_id', ForeignKey('child.id'), primary_key=True),
-            Column('name', String(50))
+            "association",
+            metadata,
+            Column("parent_id", ForeignKey("parent.id"), primary_key=True),
+            Column("child_id", ForeignKey("child.id"), primary_key=True),
+            Column("name", String(50)),
         )
         Table(
-            'child', metadata,
-            Column('id', Integer, primary_key=True,
-                   test_needs_autoincrement=True),
-            Column('name', String(50))
+            "child",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("name", String(50)),
         )

     def _fixture(self, collection_class, is_dict=False):
@@ -86,33 +101,45 @@ class AutoFlushTest(fixtures.TablesTest):

         class Association(object):
             if is_dict:
+
                 def __init__(self, key, child):
                     self.child = child
+
             else:
+
                 def __init__(self, child):
                     self.child = child

-        mapper(Parent, self.tables.parent, properties={
-            "_collection": relationship(Association,
-                                        collection_class=collection_class,
-                                        backref="parent")
-        })
-        mapper(Association, self.tables.association, properties={
-            "child": relationship(Child, backref="association")
-        })
+        mapper(
+            Parent,
+            self.tables.parent,
+            properties={
+                "_collection": relationship(
+                    Association,
+                    collection_class=collection_class,
+                    backref="parent",
+                )
+            },
+        )
+        mapper(
+            Association,
+            self.tables.association,
+            properties={"child": relationship(Child, backref="association")},
+        )
         mapper(Child, self.tables.child)

         return Parent, Child, Association

     def _test_premature_flush(self, collection_class, fn, is_dict=False):
         Parent, Child, Association = self._fixture(
-            collection_class, is_dict=is_dict)
+            collection_class, is_dict=is_dict
+        )

         session = Session(testing.db, autoflush=True, expire_on_commit=True)

         p1 = Parent()
-        c1 = Child('c1')
-        c2 = Child('c2')
+        c1 = Child("c1")
+        c2 = Child("c2")

         session.add(p1)
         session.add(c1)
         session.add(c2)
@@ -130,27 +157,31 @@ class AutoFlushTest(fixtures.TablesTest):

     def test_list_append(self):
         self._test_premature_flush(
-            list, lambda collection, obj: collection.append(obj))
+            list, lambda collection, obj: collection.append(obj)
+        )

     def test_list_extend(self):
         self._test_premature_flush(
-            list, lambda collection, obj: collection.extend([obj]))
+            list, lambda collection, obj: collection.extend([obj])
+        )

     def test_set_add(self):
         self._test_premature_flush(
-            set, lambda collection, obj: collection.add(obj))
+            set, lambda collection, obj: collection.add(obj)
+        )

     def test_set_extend(self):
         self._test_premature_flush(
-            set, lambda collection, obj: collection.update([obj]))
+            set, lambda collection, obj: collection.update([obj])
+        )

     def test_dict_set(self):
         def set_(collection, obj):
             collection[obj.name] = obj

         self._test_premature_flush(
-            collections.attribute_mapped_collection('name'),
-            set_, is_dict=True)
+            collections.attribute_mapped_collection("name"), set_, is_dict=True
+        )


 class _CollectionOperations(fixtures.TestBase):
     def setup(self):
         metadata = MetaData(testing.db)

-        parents_table = Table('Parent', metadata,
-                              Column('id', Integer, primary_key=True,
-                                     test_needs_autoincrement=True),
-                              Column('name', String(128)))
-        children_table = Table('Children', metadata,
-                               Column('id', Integer, primary_key=True,
-                                      test_needs_autoincrement=True),
-                               Column('parent_id', Integer,
-                                      ForeignKey('Parent.id')),
-                               Column('foo', String(128)),
-                               Column('name', String(128)))
+        parents_table = Table(
+            "Parent",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("name", String(128)),
+        )
+        children_table = Table(
+            "Children",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("parent_id", Integer, ForeignKey("Parent.id")),
+            Column("foo", String(128)),
+            Column("name", String(128)),
+        )

         class Parent(object):
-            children = association_proxy('_children', 'name')
+            children = association_proxy("_children", "name")

             def __init__(self, name):
                 self.name = name

         class Child(object):
             if collection_class and issubclass(collection_class, dict):
+
                 def __init__(self, foo, name):
                     self.foo = foo
                     self.name = name
+
             else:
+
                 def __init__(self, name):
                     self.name = name

-        mapper(Parent, parents_table, properties={
-            '_children': relationship(Child, lazy='joined', backref='parent',
-                                      collection_class=collection_class)})
+        mapper(
+            Parent,
+            parents_table,
+            properties={
+                "_children": relationship(
+                    Child,
+                    lazy="joined",
+                    backref="parent",
+                    collection_class=collection_class,
+                )
+            },
+        )
         mapper(Child, children_table)

         metadata.create_all()
@@ -211,7 +261,7 @@ class _CollectionOperations(fixtures.TestBase):
     def _test_sequence_ops(self):
         Parent, Child = self.Parent, self.Child

-        p1 = Parent('P1')
+        p1 = Parent("P1")

         def assert_index(expected, value, *args):
             """Assert index of child value is equal to expected.
@@ -229,7 +279,7 @@ class _CollectionOperations(fixtures.TestBase):
         self.assert_(not p1._children)
         self.assert_(not p1.children)

-        ch = Child('regular')
+        ch = Child("regular")
         p1._children.append(ch)

         self.assert_(ch in p1._children)
@@ -238,25 +288,25 @@ class _CollectionOperations(fixtures.TestBase):
         self.assert_(p1.children)
         self.assert_(len(p1.children) == 1)
         self.assert_(ch not in p1.children)
-        self.assert_('regular' in p1.children)
+        self.assert_("regular" in p1.children)

-        assert_index(0, 'regular')
-        assert_index(None, 'regular', 1)
+        assert_index(0, "regular")
+        assert_index(None, "regular", 1)

-        p1.children.append('proxied')
+        p1.children.append("proxied")

-        self.assert_('proxied' in p1.children)
-        self.assert_('proxied' not in p1._children)
+        self.assert_("proxied" in p1.children)
+        self.assert_("proxied" not in p1._children)
         self.assert_(len(p1.children) == 2)
         self.assert_(len(p1._children) == 2)
-        self.assert_(p1._children[0].name == 'regular')
-        self.assert_(p1._children[1].name == 'proxied')
+        self.assert_(p1._children[0].name == "regular")
+        self.assert_(p1._children[1].name == "proxied")

-        assert_index(0, 'regular')
-        assert_index(1, 'proxied')
-        assert_index(1, 'proxied', 1)
-        assert_index(None, 'proxied', 0, 1)
+        assert_index(0, "regular")
+        assert_index(1, "proxied")
+        assert_index(1, "proxied", 1)
+        assert_index(None, "proxied", 0, 1)

         del p1._children[1]

@@ -264,22 +314,22 @@ class _CollectionOperations(fixtures.TestBase):
         self.assert_(len(p1.children) == 1)
         self.assert_(p1._children[0] == ch)

-        assert_index(None, 'proxied')
+        assert_index(None, "proxied")

         del p1.children[0]

         self.assert_(len(p1._children) == 0)
         self.assert_(len(p1.children) == 0)

-        assert_index(None, 'regular')
+        assert_index(None, "regular")

-        p1.children = ['a', 'b', 'c']
+        p1.children = ["a", "b", "c"]
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)

-        assert_index(0, 'a')
-        assert_index(1, 'b')
-        assert_index(2, 'c')
+        assert_index(0, "a")
+        assert_index(1, "b")
+        assert_index(2, "c")

         del ch
         p1 = self.roundtrip(p1)
@@ -287,9 +337,9 @@ class _CollectionOperations(fixtures.TestBase):
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)

-        assert_index(0, 'a')
-        assert_index(1, 'b')
-        assert_index(2, 'c')
+        assert_index(0, "a")
+        assert_index(1, "b")
+        assert_index(2, "c")

         popped = p1.children.pop()
         self.assert_(len(p1.children) == 2)
@@ -301,61 +351,61 @@ class _CollectionOperations(fixtures.TestBase):
         self.assert_(popped not in p1.children)
         assert_index(None, popped)

-        p1.children[1] = 'changed-in-place'
-        self.assert_(p1.children[1] == 'changed-in-place')
-        assert_index(1, 'changed-in-place')
-        assert_index(None, 'b')
+        p1.children[1] = "changed-in-place"
+        self.assert_(p1.children[1] == "changed-in-place")
+        assert_index(1, "changed-in-place")
+        assert_index(None, "b")

         inplace_id = p1._children[1].id
         p1 = self.roundtrip(p1)
-        self.assert_(p1.children[1] == 'changed-in-place')
+        self.assert_(p1.children[1] == "changed-in-place")
         assert p1._children[1].id == inplace_id

-        p1.children.append('changed-in-place')
-        self.assert_(p1.children.count('changed-in-place') == 2)
-        assert_index(1, 'changed-in-place')
+        p1.children.append("changed-in-place")
+        self.assert_(p1.children.count("changed-in-place") == 2)
+        assert_index(1, "changed-in-place")

-        p1.children.remove('changed-in-place')
-        self.assert_(p1.children.count('changed-in-place') == 1)
-        assert_index(1, 'changed-in-place')
+        p1.children.remove("changed-in-place")
+        self.assert_(p1.children.count("changed-in-place") == 1)
+        assert_index(1, "changed-in-place")

         p1 = self.roundtrip(p1)
-        self.assert_(p1.children.count('changed-in-place') == 1)
-        assert_index(1, 'changed-in-place')
+        self.assert_(p1.children.count("changed-in-place") == 1)
+        assert_index(1, "changed-in-place")

         p1._children = []
         self.assert_(len(p1.children) == 0)
-        assert_index(None, 'changed-in-place')
+        assert_index(None, "changed-in-place")

-        after = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
-        p1.children = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
+        after = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
+        p1.children = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
         self.assert_(len(p1.children) == 10)
         self.assert_([c.name for c in p1._children] == after)
         for i, val in enumerate(after):
             assert_index(i, val)

-        p1.children[2:6] = ['x'] * 4
-        after = ['a', 'b', 'x', 'x', 'x', 'x', 'g', 'h', 'i', 'j']
+        p1.children[2:6] = ["x"] * 4
+        after = ["a", "b", "x", "x", "x", "x", "g", "h", "i", "j"]
         self.assert_(p1.children == after)
         self.assert_([c.name for c in p1._children] == after)
-        assert_index(2, 'x')
-        assert_index(3, 'x', 3)
-        assert_index(None, 'x', 6)
+        assert_index(2, "x")
+        assert_index(3, "x", 3)
+        assert_index(None, "x", 6)

-        p1.children[2:6] = ['y']
-        after = ['a', 'b', 'y', 'g', 'h', 'i', 'j']
+        p1.children[2:6] = ["y"]
+        after = ["a", "b", "y", "g", "h", "i", "j"]
         self.assert_(p1.children == after)
         self.assert_([c.name for c in p1._children] == after)
-        assert_index(2, 'y')
-        assert_index(None, 'y', 3)
+        assert_index(2, "y")
+        assert_index(None, "y", 3)

-        p1.children[2:3] = ['z'] * 4
-        after = ['a', 'b', 'z', 'z', 'z', 'z', 'g', 'h', 'i', 'j']
+        p1.children[2:3] = ["z"] * 4
+        after = ["a", "b", "z", "z", "z", "z", "g", "h", "i", "j"]
         self.assert_(p1.children == after)
         self.assert_([c.name for c in p1._children] == after)

-        p1.children[2::2] = ['O'] * 4
-        after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j']
+        p1.children[2::2] = ["O"] * 4
"z", "O", "h", "O", "j"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) @@ -366,45 +416,45 @@ class _CollectionOperations(fixtures.TestBase): self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - p1.children += ['a', 'b'] - after = ['a', 'b'] + p1.children += ["a", "b"] + after = ["a", "b"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - p1.children[:] = ['d', 'e'] - after = ['d', 'e'] + p1.children[:] = ["d", "e"] + after = ["d", "e"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - p1.children[:] = ['a', 'b'] + p1.children[:] = ["a", "b"] - p1.children += ['c'] - after = ['a', 'b', 'c'] + p1.children += ["c"] + after = ["a", "b", "c"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) p1.children *= 1 - after = ['a', 'b', 'c'] + after = ["a", "b", "c"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) p1.children *= 2 - after = ['a', 'b', 'c', 'a', 'b', 'c'] + after = ["a", "b", "c", "a", "b", "c"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - p1.children = ['a'] - after = ['a'] + p1.children = ["a"] + after = ["a"] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - self.assert_((p1.children * 2) == ['a', 'a']) - self.assert_((2 * p1.children) == ['a', 'a']) + self.assert_((p1.children * 2) == ["a", "a"]) + self.assert_((2 * p1.children) == ["a", "a"]) self.assert_((p1.children * 0) == []) self.assert_((0 * p1.children) == []) - self.assert_((p1.children + ['b']) == ['a', 'b']) - self.assert_((['b'] + p1.children) == ['b', 'a']) + self.assert_((p1.children + ["b"]) == ["a", "b"]) + self.assert_((["b"] + p1.children) == ["b", "a"]) try: p1.children + 123 @@ -433,12 +483,12 @@ class CustomDictTest(_CollectionOperations): def test_mapping_ops(self): Parent, Child = self.Parent, self.Child - p1 = Parent('P1') + p1 = Parent("P1") self.assert_(not p1._children) self.assert_(not p1.children) - ch = Child('a', 'regular') + ch = Child("a", "regular") p1._children.append(ch) self.assert_(ch in list(p1._children.values())) @@ -447,49 +497,49 @@ class CustomDictTest(_CollectionOperations): self.assert_(p1.children) self.assert_(len(p1.children) == 1) self.assert_(ch not in p1.children) - self.assert_('a' in p1.children) - self.assert_(p1.children['a'] == 'regular') - self.assert_(p1._children['a'] == ch) + self.assert_("a" in p1.children) + self.assert_(p1.children["a"] == "regular") + self.assert_(p1._children["a"] == ch) - p1.children['b'] = 'proxied' + p1.children["b"] = "proxied" - self.assert_('proxied' in list(p1.children.values())) - self.assert_('b' in p1.children) - self.assert_('proxied' not in p1._children) + self.assert_("proxied" in list(p1.children.values())) + self.assert_("b" in p1.children) + self.assert_("proxied" not in p1._children) self.assert_(len(p1.children) == 2) self.assert_(len(p1._children) == 2) - self.assert_(p1._children['a'].name == 'regular') - self.assert_(p1._children['b'].name == 'proxied') + self.assert_(p1._children["a"].name == "regular") + self.assert_(p1._children["b"].name == "proxied") - del p1._children['b'] + del p1._children["b"] self.assert_(len(p1._children) == 1) self.assert_(len(p1.children) == 1) - self.assert_(p1._children['a'] == ch) + self.assert_(p1._children["a"] == ch) - del p1.children['a'] + del p1.children["a"] 
         self.assert_(len(p1._children) == 0)
         self.assert_(len(p1.children) == 0)

-        p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'}
+        p1.children = {"d": "v d", "e": "v e", "f": "v f"}
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)
-        self.assert_(set(p1.children) == set(['d', 'e', 'f']))
+        self.assert_(set(p1.children) == set(["d", "e", "f"]))

         del ch
         p1 = self.roundtrip(p1)
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)

-        p1.children['e'] = 'changed-in-place'
-        self.assert_(p1.children['e'] == 'changed-in-place')
-        inplace_id = p1._children['e'].id
+        p1.children["e"] = "changed-in-place"
+        self.assert_(p1.children["e"] == "changed-in-place")
+        inplace_id = p1._children["e"].id
         p1 = self.roundtrip(p1)
-        self.assert_(p1.children['e'] == 'changed-in-place')
-        self.assert_(p1._children['e'].id == inplace_id)
+        self.assert_(p1.children["e"] == "changed-in-place")
+        self.assert_(p1._children["e"].id == inplace_id)

         p1._children = {}
         self.assert_(len(p1.children) == 0)
@@ -515,12 +565,12 @@ class SetTest(_CollectionOperations):
     def test_set_operations(self):
         Parent, Child = self.Parent, self.Child

-        p1 = Parent('P1')
+        p1 = Parent("P1")

         self.assert_(not p1._children)
         self.assert_(not p1.children)

-        ch1 = Child('regular')
+        ch1 = Child("regular")
         p1._children.add(ch1)

         self.assert_(ch1 in p1._children)
@@ -529,21 +579,22 @@ class SetTest(_CollectionOperations):
         self.assert_(p1.children)
         self.assert_(len(p1.children) == 1)
         self.assert_(ch1 not in p1.children)
-        self.assert_('regular' in p1.children)
+        self.assert_("regular" in p1.children)

-        p1.children.add('proxied')
+        p1.children.add("proxied")

-        self.assert_('proxied' in p1.children)
-        self.assert_('proxied' not in p1._children)
+        self.assert_("proxied" in p1.children)
+        self.assert_("proxied" not in p1._children)
         self.assert_(len(p1.children) == 2)
         self.assert_(len(p1._children) == 2)
-        self.assert_(set([o.name for o in p1._children]) ==
-                     set(['regular', 'proxied']))
+        self.assert_(
+            set([o.name for o in p1._children]) == set(["regular", "proxied"])
+        )

         ch2 = None
         for o in p1._children:
-            if o.name == 'proxied':
+            if o.name == "proxied":
                 ch2 = o
                 break

@@ -553,12 +604,12 @@ class SetTest(_CollectionOperations):
         self.assert_(len(p1.children) == 1)
         self.assert_(p1._children == set([ch1]))

-        p1.children.remove('regular')
+        p1.children.remove("regular")

         self.assert_(len(p1._children) == 0)
         self.assert_(len(p1.children) == 0)

-        p1.children = ['a', 'b', 'c']
+        p1.children = ["a", "b", "c"]
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)
@@ -568,19 +619,16 @@ class SetTest(_CollectionOperations):
         self.assert_(len(p1._children) == 3)
         self.assert_(len(p1.children) == 3)

-        self.assert_('a' in p1.children)
-        self.assert_('b' in p1.children)
-        self.assert_('d' not in p1.children)
+        self.assert_("a" in p1.children)
+        self.assert_("b" in p1.children)
+        self.assert_("d" not in p1.children)

-        self.assert_(p1.children == set(['a', 'b', 'c']))
+        self.assert_(p1.children == set(["a", "b", "c"]))

-        assert_raises(
-            KeyError,
-            p1.children.remove, "d"
-        )
+        assert_raises(KeyError, p1.children.remove, "d")

         self.assert_(len(p1.children) == 3)

-        p1.children.discard('d')
+        p1.children.discard("d")
         self.assert_(len(p1.children) == 3)
         p1 = self.roundtrip(p1)
         self.assert_(len(p1.children) == 3)
@@ -592,17 +640,17 @@ class SetTest(_CollectionOperations):
         self.assert_(len(p1.children) == 2)
         self.assert_(popped not in p1.children)

-        p1.children = ['a', 'b', 'c']
+        p1.children = ["a", "b", "c"]
         p1 = self.roundtrip(p1)
-        self.assert_(p1.children == set(['a', 'b', 'c']))
+        self.assert_(p1.children == set(["a", "b", "c"]))

-        p1.children.discard('b')
+        p1.children.discard("b")
         p1 = self.roundtrip(p1)
-        self.assert_(p1.children == set(['a', 'c']))
+        self.assert_(p1.children == set(["a", "c"]))

-        p1.children.remove('a')
+        p1.children.remove("a")
         p1 = self.roundtrip(p1)
-        self.assert_(p1.children == set(['c']))
+        self.assert_(p1.children == set(["c"]))

         p1._children = set()
         self.assert_(len(p1.children) == 0)
@@ -624,29 +672,30 @@ class SetTest(_CollectionOperations):
     def test_set_comparisons(self):
         Parent = self.Parent

-        p1 = Parent('P1')
-        p1.children = ['a', 'b', 'c']
-        control = set(['a', 'b', 'c'])
-
-        for other in (set(['a', 'b', 'c']), set(['a', 'b', 'c', 'd']),
-                      set(['a']), set(['a', 'b']),
-                      set(['c', 'd']), set(['e', 'f', 'g']),
-                      set()):
-
-            eq_(p1.children.union(other),
-                control.union(other))
-            eq_(p1.children.difference(other),
-                control.difference(other))
-            eq_((p1.children - other),
-                (control - other))
-            eq_(p1.children.intersection(other),
-                control.intersection(other))
-            eq_(p1.children.symmetric_difference(other),
-                control.symmetric_difference(other))
-            eq_(p1.children.issubset(other),
-                control.issubset(other))
-            eq_(p1.children.issuperset(other),
-                control.issuperset(other))
+        p1 = Parent("P1")
+        p1.children = ["a", "b", "c"]
+        control = set(["a", "b", "c"])
+
+        for other in (
+            set(["a", "b", "c"]),
+            set(["a", "b", "c", "d"]),
+            set(["a"]),
+            set(["a", "b"]),
+            set(["c", "d"]),
+            set(["e", "f", "g"]),
+            set(),
+        ):
+
+            eq_(p1.children.union(other), control.union(other))
+            eq_(p1.children.difference(other), control.difference(other))
+            eq_((p1.children - other), (control - other))
+            eq_(p1.children.intersection(other), control.intersection(other))
+            eq_(
+                p1.children.symmetric_difference(other),
+                control.symmetric_difference(other),
+            )
+            eq_(p1.children.issubset(other), control.issubset(other))
+            eq_(p1.children.issuperset(other), control.issuperset(other))

             self.assert_((p1.children == other) == (control == other))
             self.assert_((p1.children != other) == (control != other))
@@ -660,10 +709,10 @@ class SetTest(_CollectionOperations):
         # test issue #3265 which was fixed in Python version 2.7.8
         Parent = self.Parent

-        p1 = Parent('P1')
+        p1 = Parent("P1")
         p1.children = []

-        p2 = Parent('P2')
+        p2 = Parent("P2")
         p2.children = []

         set_0 = set()
@@ -684,14 +733,23 @@ class SetTest(_CollectionOperations):
         Parent = self.Parent

         # mutations
-        for op in ('update', 'intersection_update',
-                   'difference_update', 'symmetric_difference_update'):
-            for base in (['a', 'b', 'c'], []):
-                for other in (set(['a', 'b', 'c']), set(['a', 'b', 'c', 'd']),
-                              set(['a']), set(['a', 'b']),
-                              set(['c', 'd']), set(['e', 'f', 'g']),
-                              set()):
-                    p = Parent('p')
+        for op in (
+            "update",
+            "intersection_update",
+            "difference_update",
+            "symmetric_difference_update",
+        ):
+            for base in (["a", "b", "c"], []):
+                for other in (
+                    set(["a", "b", "c"]),
+                    set(["a", "b", "c", "d"]),
+                    set(["a"]),
+                    set(["a", "b"]),
+                    set(["c", "d"]),
+                    set(["e", "f", "g"]),
+                    set(),
+                ):
+                    p = Parent("p")
                     p.children = base[:]
                     control = set(base[:])
@@ -700,9 +758,9 @@ class SetTest(_CollectionOperations):
                     try:
                         self.assert_(p.children == control)
                     except Exception:
-                        print('Test %s.%s(%s):' % (set(base), op, other))
-                        print('want', repr(control))
-                        print('got', repr(p.children))
+                        print("Test %s.%s(%s):" % (set(base), op, other))
+                        print("want", repr(control))
+                        print("got", repr(p.children))
                         raise

                     p = self.roundtrip(p)
@@ -710,20 +768,25 @@ class SetTest(_CollectionOperations):
                     try:
                         self.assert_(p.children == control)
                     except Exception:
-                        print('Test %s.%s(%s):' % (base, op, other))
-                        print('want', repr(control))
-                        print('got', repr(p.children))
+                        print("Test %s.%s(%s):" % (base, op, other))
+                        print("want", repr(control))
+                        print("got", repr(p.children))
                         raise

         # in-place mutations
-        for op in ('|=', '-=', '&=', '^='):
-            for base in (['a', 'b', 'c'], []):
-                for other in (set(['a', 'b', 'c']), set(['a', 'b', 'c', 'd']),
-                              set(['a']), set(['a', 'b']),
-                              set(['c', 'd']), set(['e', 'f', 'g']),
-                              frozenset(['e', 'f', 'g']),
-                              set()):
-                    p = Parent('p')
+        for op in ("|=", "-=", "&=", "^="):
+            for base in (["a", "b", "c"], []):
+                for other in (
+                    set(["a", "b", "c"]),
+                    set(["a", "b", "c", "d"]),
+                    set(["a"]),
+                    set(["a", "b"]),
+                    set(["c", "d"]),
+                    set(["e", "f", "g"]),
+                    frozenset(["e", "f", "g"]),
+                    set(),
+                ):
+                    p = Parent("p")
                     p.children = base[:]
                     control = set(base[:])
@@ -733,9 +796,9 @@ class SetTest(_CollectionOperations):
                    try:
                         self.assert_(p.children == control)
                     except Exception:
-                        print('Test %s %s %s:' % (set(base), op, other))
-                        print('want', repr(control))
-                        print('got', repr(p.children))
+                        print("Test %s %s %s:" % (set(base), op, other))
+                        print("want", repr(control))
+                        print("got", repr(p.children))
                         raise

                     p = self.roundtrip(p)
@@ -743,9 +806,9 @@ class SetTest(_CollectionOperations):
                     try:
                         self.assert_(p.children == control)
                     except Exception:
-                        print('Test %s %s %s:' % (base, op, other))
-                        print('want', repr(control))
-                        print('got', repr(p.children))
+                        print("Test %s %s %s:" % (base, op, other))
+                        print("want", repr(control))
+                        print("got", repr(p.children))
                         raise

@@ -759,10 +822,10 @@ class CustomObjectTest(_CollectionOperations):
     def test_basic(self):
         Parent = self.Parent

-        p = Parent('p1')
+        p = Parent("p1")

         self.assert_(len(list(p.children)) == 0)
-        p.children.append('child')
+        p.children.append("child")
         self.assert_(len(list(p.children)) == 1)

         p = self.roundtrip(p)
@@ -770,48 +833,46 @@ class CustomObjectTest(_CollectionOperations):

         # We didn't provide an alternate _AssociationList implementation
         # for our ObjectCollection, so indexing will fail.
- assert_raises( - TypeError, - p.children.__getitem__, 1 - ) + assert_raises(TypeError, p.children.__getitem__, 1) class ProxyFactoryTest(ListTest): def setup(self): metadata = MetaData(testing.db) - parents_table = Table('Parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - children_table = Table('Children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('Parent.id')), - Column('foo', String(128)), - Column('name', String(128))) + parents_table = Table( + "Parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(128)), + ) + children_table = Table( + "Children", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("Parent.id")), + Column("foo", String(128)), + Column("name", String(128)), + ) class CustomProxy(_AssociationList): - def __init__(self, - lazy_collection, - creator, - value_attr, - parent): + def __init__(self, lazy_collection, creator, value_attr, parent): getter, setter = parent._default_getset(lazy_collection) _AssociationList.__init__( - self, - lazy_collection, - creator, - getter, - setter, - parent, + self, lazy_collection, creator, getter, setter, parent ) class Parent(object): - children = association_proxy('_children', 'name', - proxy_factory=CustomProxy, - proxy_bulk_set=CustomProxy.extend) + children = association_proxy( + "_children", + "name", + proxy_factory=CustomProxy, + proxy_bulk_set=CustomProxy.extend, + ) def __init__(self, name): self.name = name @@ -820,9 +881,15 @@ class ProxyFactoryTest(ListTest): def __init__(self, name): self.name = name - mapper(Parent, parents_table, properties={ - '_children': relationship(Child, lazy='joined', - collection_class=list)}) + mapper( + Parent, + parents_table, + properties={ + "_children": relationship( + Child, lazy="joined", collection_class=list + ) + }, + ) mapper(Child, children_table) metadata.create_all() @@ -840,25 +907,34 @@ class ScalarTest(fixtures.TestBase): def test_scalar_proxy(self): metadata = self.metadata - parents_table = Table('Parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - children_table = Table('Children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('Parent.id')), - Column('foo', String(128)), - Column('bar', String(128)), - Column('baz', String(128))) + parents_table = Table( + "Parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(128)), + ) + children_table = Table( + "Children", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("Parent.id")), + Column("foo", String(128)), + Column("bar", String(128)), + Column("baz", String(128)), + ) class Parent(object): - foo = association_proxy('child', 'foo') - bar = association_proxy('child', 'bar', - creator=lambda v: Child(bar=v)) - baz = association_proxy('child', 'baz', - creator=lambda v: Child(baz=v)) + foo = association_proxy("child", "foo") + bar = association_proxy( + "child", "bar", creator=lambda v: Child(bar=v) + ) + baz = association_proxy( + "child", "baz", creator=lambda v: Child(baz=v) + ) def __init__(self, name): 
self.name = name @@ -868,9 +944,15 @@ class ScalarTest(fixtures.TestBase): for attr in kw: setattr(self, attr, kw[attr]) - mapper(Parent, parents_table, properties={ - 'child': relationship(Child, lazy='joined', - backref='parent', uselist=False)}) + mapper( + Parent, + parents_table, + properties={ + "child": relationship( + Child, lazy="joined", backref="parent", uselist=False + ) + }, + ) mapper(Child, children_table) metadata.create_all() @@ -884,42 +966,39 @@ class ScalarTest(fixtures.TestBase): session.expunge_all() return session.query(type_).get(id) - p = Parent('p') + p = Parent("p") eq_(p.child, None) eq_(p.foo, None) - p.child = Child(foo='a', bar='b', baz='c') + p.child = Child(foo="a", bar="b", baz="c") - self.assert_(p.foo == 'a') - self.assert_(p.bar == 'b') - self.assert_(p.baz == 'c') + self.assert_(p.foo == "a") + self.assert_(p.bar == "b") + self.assert_(p.baz == "c") - p.bar = 'x' - self.assert_(p.foo == 'a') - self.assert_(p.bar == 'x') - self.assert_(p.baz == 'c') + p.bar = "x" + self.assert_(p.foo == "a") + self.assert_(p.bar == "x") + self.assert_(p.baz == "c") p = roundtrip(p) - self.assert_(p.foo == 'a') - self.assert_(p.bar == 'x') - self.assert_(p.baz == 'c') + self.assert_(p.foo == "a") + self.assert_(p.bar == "x") + self.assert_(p.baz == "c") p.child = None eq_(p.foo, None) # Bogus creator for this scalar type - assert_raises( - TypeError, - setattr, p, "foo", "zzz" - ) + assert_raises(TypeError, setattr, p, "foo", "zzz") - p.bar = 'yyy' + p.bar = "yyy" self.assert_(p.foo is None) - self.assert_(p.bar == 'yyy') + self.assert_(p.bar == "yyy") self.assert_(p.baz is None) del p.child @@ -928,37 +1007,46 @@ class ScalarTest(fixtures.TestBase): self.assert_(p.child is None) - p.baz = 'xxx' + p.baz = "xxx" self.assert_(p.foo is None) self.assert_(p.bar is None) - self.assert_(p.baz == 'xxx') + self.assert_(p.baz == "xxx") p = roundtrip(p) self.assert_(p.foo is None) self.assert_(p.bar is None) - self.assert_(p.baz == 'xxx') + self.assert_(p.baz == "xxx") # Ensure an immediate __set__ works. 
- p2 = Parent('p2') - p2.bar = 'quux' + p2 = Parent("p2") + p2.bar = "quux" @testing.provide_metadata def test_empty_scalars(self): metadata = self.metadata - a = Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50))) - a2b = Table('a2b', metadata, - Column('id', Integer, primary_key=True), - Column('id_a', Integer, ForeignKey('a.id')), - Column('id_b', Integer, ForeignKey('b.id')), - Column('name', String(50))) - b = Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50))) + a = Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + a2b = Table( + "a2b", + metadata, + Column("id", Integer, primary_key=True), + Column("id_a", Integer, ForeignKey("a.id")), + Column("id_b", Integer, ForeignKey("b.id")), + Column("name", String(50)), + ) + b = Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) class A(object): a2b_name = association_proxy("a2b_single", "name") @@ -970,13 +1058,11 @@ class ScalarTest(fixtures.TestBase): class B(object): pass - mapper(A, a, properties=dict( - a2b_single=relationship(A2B, uselist=False) - )) + mapper( + A, a, properties=dict(a2b_single=relationship(A2B, uselist=False)) + ) - mapper(A2B, a2b, properties=dict( - b=relationship(B) - )) + mapper(A2B, a2b, properties=dict(b=relationship(B))) mapper(B, b) a1 = A() @@ -985,32 +1071,38 @@ class ScalarTest(fixtures.TestBase): def test_custom_getset(self): metadata = MetaData() - p = Table('p', metadata, - Column('id', Integer, primary_key=True), - Column('cid', Integer, ForeignKey('c.id'))) - c = Table('c', metadata, - Column('id', Integer, primary_key=True), - Column('foo', String(128))) + p = Table( + "p", + metadata, + Column("id", Integer, primary_key=True), + Column("cid", Integer, ForeignKey("c.id")), + ) + c = Table( + "c", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", String(128)), + ) get = Mock() set_ = Mock() class Parent(object): - foo = association_proxy('child', 'foo', - getset_factory=lambda cc, - parent: (get, set_)) + foo = association_proxy( + "child", "foo", getset_factory=lambda cc, parent: (get, set_) + ) class Child(object): def __init__(self, foo): self.foo = foo - mapper(Parent, p, properties={'child': relationship(Child)}) + mapper(Parent, p, properties={"child": relationship(Child)}) mapper(Child, c) p1 = Parent() eq_(p1.foo, get(None)) - p1.child = child = Child(foo='x') + p1.child = child = Child(foo="x") eq_(p1.foo, get(child)) p1.foo = "y" eq_(set_.mock_calls, [call(child, "y")]) @@ -1020,20 +1112,27 @@ class LazyLoadTest(fixtures.TestBase): def setup(self): metadata = MetaData(testing.db) - parents_table = Table('Parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - children_table = Table('Children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('Parent.id')), - Column('foo', String(128)), - Column('name', String(128))) + parents_table = Table( + "Parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(128)), + ) + children_table = Table( + "Children", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("Parent.id")), + Column("foo", String(128)), + Column("name", String(128)), + ) class 
Parent(object): - children = association_proxy('_children', 'name') + children = association_proxy("_children", "name") def __init__(self, name): self.name = name @@ -1063,78 +1162,106 @@ class LazyLoadTest(fixtures.TestBase): def test_lazy_list(self): Parent, Child = self.Parent, self.Child - mapper(Parent, self.table, properties={ - '_children': relationship(Child, lazy='select', - collection_class=list)}) + mapper( + Parent, + self.table, + properties={ + "_children": relationship( + Child, lazy="select", collection_class=list + ) + }, + ) - p = Parent('p') - p.children = ['a', 'b', 'c'] + p = Parent("p") + p.children = ["a", "b", "c"] p = self.roundtrip(p) # Is there a better way to ensure that the association_proxy # didn't convert a lazy load to an eager load? This does work though. - self.assert_('_children' not in p.__dict__) + self.assert_("_children" not in p.__dict__) self.assert_(len(p._children) == 3) - self.assert_('_children' in p.__dict__) + self.assert_("_children" in p.__dict__) def test_eager_list(self): Parent, Child = self.Parent, self.Child - mapper(Parent, self.table, properties={ - '_children': relationship(Child, lazy='joined', - collection_class=list)}) + mapper( + Parent, + self.table, + properties={ + "_children": relationship( + Child, lazy="joined", collection_class=list + ) + }, + ) - p = Parent('p') - p.children = ['a', 'b', 'c'] + p = Parent("p") + p.children = ["a", "b", "c"] p = self.roundtrip(p) - self.assert_('_children' in p.__dict__) + self.assert_("_children" in p.__dict__) self.assert_(len(p._children) == 3) def test_slicing_list(self): Parent, Child = self.Parent, self.Child - mapper(Parent, self.table, properties={ - '_children': relationship(Child, lazy='select', - collection_class=list)}) + mapper( + Parent, + self.table, + properties={ + "_children": relationship( + Child, lazy="select", collection_class=list + ) + }, + ) - p = Parent('p') - p.children = ['a', 'b', 'c'] + p = Parent("p") + p.children = ["a", "b", "c"] p = self.roundtrip(p) self.assert_(len(p._children) == 3) - eq_('b', p.children[1]) - eq_(['b', 'c'], p.children[-2:]) + eq_("b", p.children[1]) + eq_(["b", "c"], p.children[-2:]) def test_lazy_scalar(self): Parent, Child = self.Parent, self.Child - mapper(Parent, self.table, properties={ - '_children': relationship(Child, lazy='select', uselist=False)}) + mapper( + Parent, + self.table, + properties={ + "_children": relationship(Child, lazy="select", uselist=False) + }, + ) - p = Parent('p') - p.children = 'value' + p = Parent("p") + p.children = "value" p = self.roundtrip(p) - self.assert_('_children' not in p.__dict__) + self.assert_("_children" not in p.__dict__) self.assert_(p._children is not None) def test_eager_scalar(self): Parent, Child = self.Parent, self.Child - mapper(Parent, self.table, properties={ - '_children': relationship(Child, lazy='joined', uselist=False)}) + mapper( + Parent, + self.table, + properties={ + "_children": relationship(Child, lazy="joined", uselist=False) + }, + ) - p = Parent('p') - p.children = 'value' + p = Parent("p") + p.children = "value" p = self.roundtrip(p) - self.assert_('_children' in p.__dict__) + self.assert_("_children" in p.__dict__) self.assert_(p._children is not None) @@ -1155,99 +1282,117 @@ class KVChild(object): class ReconstitutionTest(fixtures.TestBase): - def setup(self): metadata = MetaData(testing.db) - parents = Table('parents', metadata, Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), Column('name', - String(30))) - children = 
Table('children', metadata, Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('parents.id')), Column('name', - String(30))) + parents = Table( + "parents", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) + children = Table( + "children", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("parents.id")), + Column("name", String(30)), + ) metadata.create_all() - parents.insert().execute(name='p1') + parents.insert().execute(name="p1") self.metadata = metadata self.parents = parents self.children = children - Parent.kids = association_proxy('children', 'name') + Parent.kids = association_proxy("children", "name") def teardown(self): self.metadata.drop_all() clear_mappers() def test_weak_identity_map(self): - mapper(Parent, self.parents, - properties=dict(children=relationship(Child))) + mapper( + Parent, self.parents, properties=dict(children=relationship(Child)) + ) mapper(Child, self.children) session = create_session(weak_identity_map=True) def add_child(parent_name, child_name): - parent = \ - session.query(Parent).filter_by(name=parent_name).one() + parent = session.query(Parent).filter_by(name=parent_name).one() parent.kids.append(child_name) - add_child('p1', 'c1') + add_child("p1", "c1") gc_collect() - add_child('p1', 'c2') + add_child("p1", "c2") session.flush() - p = session.query(Parent).filter_by(name='p1').one() - assert set(p.kids) == set(['c1', 'c2']), p.kids + p = session.query(Parent).filter_by(name="p1").one() + assert set(p.kids) == set(["c1", "c2"]), p.kids def test_copy(self): - mapper(Parent, self.parents, - properties=dict(children=relationship(Child))) + mapper( + Parent, self.parents, properties=dict(children=relationship(Child)) + ) mapper(Child, self.children) - p = Parent('p1') - p.kids.extend(['c1', 'c2']) + p = Parent("p1") + p.kids.extend(["c1", "c2"]) p_copy = copy.copy(p) del p gc_collect() - assert set(p_copy.kids) == set(['c1', 'c2']), p_copy.kids + assert set(p_copy.kids) == set(["c1", "c2"]), p_copy.kids def test_pickle_list(self): - mapper(Parent, self.parents, - properties=dict(children=relationship(Child))) + mapper( + Parent, self.parents, properties=dict(children=relationship(Child)) + ) mapper(Child, self.children) - p = Parent('p1') - p.kids.extend(['c1', 'c2']) + p = Parent("p1") + p.kids.extend(["c1", "c2"]) r1 = pickle.loads(pickle.dumps(p)) - assert r1.kids == ['c1', 'c2'] + assert r1.kids == ["c1", "c2"] # can't do this without parent having a cycle # r2 = pickle.loads(pickle.dumps(p.kids)) # assert r2 == ['c1', 'c2'] def test_pickle_set(self): - mapper(Parent, self.parents, - properties=dict(children=relationship(Child, - collection_class=set))) + mapper( + Parent, + self.parents, + properties=dict( + children=relationship(Child, collection_class=set) + ), + ) mapper(Child, self.children) - p = Parent('p1') - p.kids.update(['c1', 'c2']) + p = Parent("p1") + p.kids.update(["c1", "c2"]) r1 = pickle.loads(pickle.dumps(p)) - assert r1.kids == set(['c1', 'c2']) + assert r1.kids == set(["c1", "c2"]) # can't do this without parent having a cycle # r2 = pickle.loads(pickle.dumps(p.kids)) # assert r2 == set(['c1', 'c2']) def test_pickle_dict(self): - mapper(Parent, self.parents, - properties=dict( - children=relationship( - KVChild, - collection_class=collections.mapped_collection( - PickleKeyFunc('name'))))) + mapper( + 
Parent, + self.parents, + properties=dict( + children=relationship( + KVChild, + collection_class=collections.mapped_collection( + PickleKeyFunc("name") + ), + ) + ), + ) mapper(KVChild, self.children) - p = Parent('p1') - p.kids.update({'c1': 'v1', 'c2': 'v2'}) - assert p.kids == {'c1': 'c1', 'c2': 'c2'} + p = Parent("p1") + p.kids.update({"c1": "v1", "c2": "v2"}) + assert p.kids == {"c1": "c1", "c2": "c2"} r1 = pickle.loads(pickle.dumps(p)) - assert r1.kids == {'c1': 'c1', 'c2': 'c2'} + assert r1.kids == {"c1": "c1", "c2": "c2"} # can't do this without parent having a cycle # r2 = pickle.loads(pickle.dumps(p.kids)) @@ -1263,34 +1408,53 @@ class PickleKeyFunc(object): class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - run_inserts = 'once' + run_inserts = "once" run_deletes = None - run_setup_mappers = 'once' - run_setup_classes = 'once' + run_setup_mappers = "once" + run_setup_classes = "once" @classmethod def define_tables(cls, metadata): - Table('userkeywords', metadata, - Column('keyword_id', Integer, ForeignKey('keywords.id'), - primary_key=True), - Column('user_id', Integer, ForeignKey('users.id')), - Column('value', String(50))) - Table('users', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('name', String(64)), - Column('singular_id', Integer, ForeignKey('singular.id'))) - Table('keywords', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('keyword', String(64)), - Column('singular_id', Integer, ForeignKey('singular.id'))) - Table('singular', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('value', String(50))) + Table( + "userkeywords", + metadata, + Column( + "keyword_id", + Integer, + ForeignKey("keywords.id"), + primary_key=True, + ), + Column("user_id", Integer, ForeignKey("users.id")), + Column("value", String(50)), + ) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(64)), + Column("singular_id", Integer, ForeignKey("singular.id")), + ) + Table( + "keywords", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("keyword", String(64)), + Column("singular_id", Integer, ForeignKey("singular.id")), + ) + Table( + "singular", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("value", String(50)), + ) @classmethod def setup_classes(cls): @@ -1301,20 +1465,21 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): # o2m -> m2o # uselist -> nonuselist keywords = association_proxy( - 'user_keywords', - 'keyword', - creator=lambda k: UserKeyword(keyword=k)) + "user_keywords", + "keyword", + creator=lambda k: UserKeyword(keyword=k), + ) # m2o -> o2m # nonuselist -> uselist - singular_keywords = association_proxy('singular', 'keywords') + singular_keywords = association_proxy("singular", "keywords") # m2o -> scalar # nonuselist - singular_value = association_proxy('singular', 'value') + singular_value = association_proxy("singular", "value") # o2m -> scalar - singular_collection = association_proxy('user_keywords', 'value') + singular_collection = association_proxy("user_keywords", "value") # uselist assoc_proxy -> assoc_proxy -> obj common_users = association_proxy("user_keywords", "common_users") @@ -1327,7 +1492,8 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): # uselist 
assoc_proxy -> assoc_proxy -> scalar common_keyword_name = association_proxy( - "user_keywords", "keyword_name") + "user_keywords", "keyword_name" + ) class Keyword(cls.Comparable): def __init__(self, keyword): @@ -1335,7 +1501,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): # o2o -> m2o # nonuselist -> nonuselist - user = association_proxy('user_keyword', 'user') + user = association_proxy("user_keyword", "user") # uselist assoc_proxy -> collection -> assoc_proxy -> scalar object # (o2m relationship, @@ -1360,64 +1526,70 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod def setup_mappers(cls): - users, Keyword, UserKeyword, singular, \ - userkeywords, User, keywords, Singular = (cls.tables.users, - cls.classes.Keyword, - cls.classes.UserKeyword, - cls.tables.singular, - cls.tables.userkeywords, - cls.classes.User, - cls.tables.keywords, - cls.classes.Singular) - - mapper(User, users, properties={ - 'singular': relationship(Singular) - }) - mapper(Keyword, keywords, properties={ - 'user_keyword': relationship(UserKeyword, uselist=False), - 'user_keywords': relationship(UserKeyword) - }) - - mapper(UserKeyword, userkeywords, properties={ - 'user': relationship(User, backref='user_keywords'), - 'keyword': relationship(Keyword) - }) - mapper(Singular, singular, properties={ - 'keywords': relationship(Keyword) - }) + users, Keyword, UserKeyword, singular, userkeywords, User, keywords, Singular = ( + cls.tables.users, + cls.classes.Keyword, + cls.classes.UserKeyword, + cls.tables.singular, + cls.tables.userkeywords, + cls.classes.User, + cls.tables.keywords, + cls.classes.Singular, + ) + + mapper(User, users, properties={"singular": relationship(Singular)}) + mapper( + Keyword, + keywords, + properties={ + "user_keyword": relationship(UserKeyword, uselist=False), + "user_keywords": relationship(UserKeyword), + }, + ) + + mapper( + UserKeyword, + userkeywords, + properties={ + "user": relationship(User, backref="user_keywords"), + "keyword": relationship(Keyword), + }, + ) + mapper( + Singular, singular, properties={"keywords": relationship(Keyword)} + ) @classmethod def insert_data(cls): - UserKeyword, User, Keyword, Singular = (cls.classes.UserKeyword, - cls.classes.User, - cls.classes.Keyword, - cls.classes.Singular) + UserKeyword, User, Keyword, Singular = ( + cls.classes.UserKeyword, + cls.classes.User, + cls.classes.Keyword, + cls.classes.Singular, + ) session = sessionmaker()() - words = ( - 'quick', 'brown', - 'fox', 'jumped', 'over', - 'the', 'lazy', - ) + words = ("quick", "brown", "fox", "jumped", "over", "the", "lazy") for ii in range(16): - user = User('user%d' % ii) + user = User("user%d" % ii) if ii % 2 == 0: - user.singular = Singular(value=("singular%d" % ii) - if ii % 4 == 0 else None) + user.singular = Singular( + value=("singular%d" % ii) if ii % 4 == 0 else None + ) session.add(user) - for jj in words[(ii % len(words)):((ii + 3) % len(words))]: + for jj in words[(ii % len(words)) : ((ii + 3) % len(words))]: k = Keyword(jj) user.keywords.append(k) if ii % 2 == 0: user.singular.keywords.append(k) user.user_keywords[-1].value = "singular%d" % ii - orphan = Keyword('orphan') + orphan = Keyword("orphan") orphan.user_keyword = UserKeyword(keyword=orphan, user=None) session.add(orphan) - keyword_with_nothing = Keyword('kwnothing') + keyword_with_nothing = Keyword("kwnothing") session.add(keyword_with_nothing) session.commit() @@ -1428,7 +1600,8 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def _equivalent(self, 
q_proxy, q_direct): proxy_sql = q_proxy.statement.compile(dialect=default.DefaultDialect()) direct_sql = q_direct.statement.compile( - dialect=default.DefaultDialect()) + dialect=default.DefaultDialect() + ) eq_(str(proxy_sql), str(direct_sql)) eq_(q_proxy.all(), q_direct.all()) @@ -1436,7 +1609,8 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): UserKeyword, User = self.classes.UserKeyword, self.classes.User q1 = self.session.query(User).filter( - User.singular_collection.any(UserKeyword.value == 'singular8')) + User.singular_collection.any(UserKeyword.value == "singular8") + ) self.assert_compile( q1, "SELECT users.id AS users_id, users.name AS users_name, " @@ -1446,97 +1620,126 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM userkeywords " "WHERE users.id = userkeywords.user_id AND " "userkeywords.value = :value_1)", - checkparams={'value_1': 'singular8'} + checkparams={"value_1": "singular8"}, ) q2 = self.session.query(User).filter( - User.user_keywords.any(UserKeyword.value == 'singular8')) + User.user_keywords.any(UserKeyword.value == "singular8") + ) self._equivalent(q1, q2) def test_filter_any_kwarg_ul_nul(self): UserKeyword, User = self.classes.UserKeyword, self.classes.User - self._equivalent(self.session.query(User). - filter(User.keywords.any(keyword='jumped')), - self.session.query(User).filter( - User.user_keywords.any( - UserKeyword.keyword.has(keyword='jumped')))) + self._equivalent( + self.session.query(User).filter( + User.keywords.any(keyword="jumped") + ), + self.session.query(User).filter( + User.user_keywords.any( + UserKeyword.keyword.has(keyword="jumped") + ) + ), + ) def test_filter_has_kwarg_nul_nul(self): UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword - self._equivalent(self.session.query(Keyword). - filter(Keyword.user.has(name='user2')), - self.session.query(Keyword). - filter(Keyword.user_keyword.has( - UserKeyword.user.has(name='user2')))) + self._equivalent( + self.session.query(Keyword).filter(Keyword.user.has(name="user2")), + self.session.query(Keyword).filter( + Keyword.user_keyword.has(UserKeyword.user.has(name="user2")) + ), + ) def test_filter_has_kwarg_nul_ul(self): User, Singular = self.classes.User, self.classes.Singular self._equivalent( self.session.query(User).filter( - User.singular_keywords.any(keyword='jumped')), + User.singular_keywords.any(keyword="jumped") + ), self.session.query(User).filter( - User.singular.has(Singular.keywords.any(keyword='jumped')))) + User.singular.has(Singular.keywords.any(keyword="jumped")) + ), + ) def test_filter_any_criterion_ul_nul(self): - UserKeyword, User, Keyword = (self.classes.UserKeyword, - self.classes.User, - self.classes.Keyword) + UserKeyword, User, Keyword = ( + self.classes.UserKeyword, + self.classes.User, + self.classes.Keyword, + ) self._equivalent( self.session.query(User).filter( - User.keywords.any(Keyword.keyword == 'jumped')), + User.keywords.any(Keyword.keyword == "jumped") + ), self.session.query(User).filter( User.user_keywords.any( - UserKeyword.keyword.has(Keyword.keyword == 'jumped')))) + UserKeyword.keyword.has(Keyword.keyword == "jumped") + ) + ), + ) def test_filter_has_criterion_nul_nul(self): - UserKeyword, User, Keyword = (self.classes.UserKeyword, - self.classes.User, - self.classes.Keyword) + UserKeyword, User, Keyword = ( + self.classes.UserKeyword, + self.classes.User, + self.classes.Keyword, + ) - self._equivalent(self.session.query(Keyword). 
- filter(Keyword.user.has(User.name == 'user2')), - self.session.query(Keyword). - filter(Keyword.user_keyword.has( - UserKeyword.user.has(User.name == 'user2')))) + self._equivalent( + self.session.query(Keyword).filter( + Keyword.user.has(User.name == "user2") + ), + self.session.query(Keyword).filter( + Keyword.user_keyword.has( + UserKeyword.user.has(User.name == "user2") + ) + ), + ) def test_filter_any_criterion_nul_ul(self): - User, Keyword, Singular = (self.classes.User, - self.classes.Keyword, - self.classes.Singular) + User, Keyword, Singular = ( + self.classes.User, + self.classes.Keyword, + self.classes.Singular, + ) self._equivalent( - self.session.query(User). - filter(User.singular_keywords.any( - Keyword.keyword == 'jumped')), - self.session.query(User). - filter(User.singular.has( - Singular.keywords.any(Keyword.keyword == 'jumped')))) + self.session.query(User).filter( + User.singular_keywords.any(Keyword.keyword == "jumped") + ), + self.session.query(User).filter( + User.singular.has( + Singular.keywords.any(Keyword.keyword == "jumped") + ) + ), + ) def test_filter_contains_ul_nul(self): User = self.classes.User - self._equivalent(self.session.query(User). - filter(User.keywords.contains(self.kw)), - self.session.query(User). - filter(User.user_keywords.any(keyword=self.kw))) + self._equivalent( + self.session.query(User).filter(User.keywords.contains(self.kw)), + self.session.query(User).filter( + User.user_keywords.any(keyword=self.kw) + ), + ) def test_filter_contains_nul_ul(self): User, Singular = self.classes.User, self.classes.Singular with expect_warnings( - "Got None for value of column keywords.singular_id;"): + "Got None for value of column keywords.singular_id;" + ): self._equivalent( self.session.query(User).filter( User.singular_keywords.contains(self.kw) ), self.session.query(User).filter( - User.singular.has( - Singular.keywords.contains(self.kw) - ) + User.singular.has(Singular.keywords.contains(self.kw)) ), ) @@ -1545,8 +1748,9 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(Keyword).filter(Keyword.user == self.u), - self.session.query(Keyword). 
- filter(Keyword.user_keyword.has(user=self.u)) + self.session.query(Keyword).filter( + Keyword.user_keyword.has(user=self.u) + ), ) def test_filter_ne_nul_nul(self): @@ -1556,7 +1760,9 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(Keyword).filter(Keyword.user != self.u), self.session.query(Keyword).filter( - Keyword.user_keyword.has(UserKeyword.user != self.u))) + Keyword.user_keyword.has(UserKeyword.user != self.u) + ), + ) def test_filter_eq_null_nul_nul(self): UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword @@ -1564,17 +1770,22 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(Keyword).filter(Keyword.user == None), # noqa self.session.query(Keyword).filter( - or_(Keyword.user_keyword.has(UserKeyword.user == None), - Keyword.user_keyword == None))) + or_( + Keyword.user_keyword.has(UserKeyword.user == None), + Keyword.user_keyword == None, + ) + ), + ) def test_filter_ne_null_nul_nul(self): UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword self._equivalent( - self.session.query(Keyword).filter( - Keyword.user != None), # noqa - self.session.query(Keyword).filter( - Keyword.user_keyword.has(UserKeyword.user != None))) + self.session.query(Keyword).filter(Keyword.user != None), # noqa + self.session.query(Keyword).filter( + Keyword.user_keyword.has(UserKeyword.user != None) + ), + ) def test_filter_object_eq_None_nul(self): UserKeyword = self.classes.UserKeyword @@ -1582,11 +1793,14 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(UserKeyword).filter( - UserKeyword.singular == None), # noqa - self.session.query(UserKeyword).filter(or_( - UserKeyword.user.has(User.singular == None), - UserKeyword.user_id == None) - ) + UserKeyword.singular == None + ), # noqa + self.session.query(UserKeyword).filter( + or_( + UserKeyword.user.has(User.singular == None), + UserKeyword.user_id == None, + ) + ), ) def test_filter_column_eq_None_nul(self): @@ -1595,11 +1809,14 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter( - User.singular_value == None), # noqa - self.session.query(User).filter(or_( - User.singular.has(Singular.value == None), - User.singular == None) - ) + User.singular_value == None + ), # noqa + self.session.query(User).filter( + or_( + User.singular.has(Singular.value == None), + User.singular == None, + ) + ), ) def test_filter_object_ne_value_nul(self): @@ -1609,10 +1826,11 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): s4 = self.session.query(Singular).filter_by(value="singular4").one() self._equivalent( + self.session.query(UserKeyword).filter(UserKeyword.singular != s4), self.session.query(UserKeyword).filter( - UserKeyword.singular != s4), - self.session.query(UserKeyword).filter( - UserKeyword.user.has(User.singular != s4))) + UserKeyword.user.has(User.singular != s4) + ), + ) def test_filter_column_ne_value_nul(self): User = self.classes.User @@ -1620,9 +1838,12 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter( - User.singular_value != "singular4"), + User.singular_value != "singular4" + ), self.session.query(User).filter( - User.singular.has(Singular.value != "singular4"))) + User.singular.has(Singular.value != "singular4") + ), + ) def test_filter_eq_value_nul(self): User = self.classes.User @@ -1630,9 +1851,12 @@ class 
ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter( - User.singular_value == "singular4"), + User.singular_value == "singular4" + ), self.session.query(User).filter( - User.singular.has(Singular.value == "singular4"))) + User.singular.has(Singular.value == "singular4") + ), + ) def test_filter_ne_None_nul(self): User = self.classes.User @@ -1640,9 +1864,12 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter( - User.singular_value != None), # noqa + User.singular_value != None + ), # noqa self.session.query(User).filter( - User.singular.has(Singular.value != None))) + User.singular.has(Singular.value != None) + ), + ) def test_has_nul(self): # a special case where we provide an empty has() on a @@ -1652,9 +1879,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter(User.singular_value.has()), - self.session.query(User).filter( - User.singular.has(), - ) + self.session.query(User).filter(User.singular.has()), ) def test_nothas_nul(self): @@ -1665,9 +1890,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter(~User.singular_value.has()), - self.session.query(User).filter( - ~User.singular.has(), - ) + self.session.query(User).filter(~User.singular.has()), ) def test_filter_any_chained(self): @@ -1677,11 +1900,10 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Keyword = self.classes.Keyword q1 = self.session.query(User).filter( - User.common_users.any(User.name == 'user7') + User.common_users.any(User.name == "user7") ) self.assert_compile( q1, - "SELECT users.id AS users_id, users.name AS users_name, " "users.singular_id AS users_singular_id " "FROM users " @@ -1697,18 +1919,18 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM users " "WHERE users.id = userkeywords.user_id AND users.name = :name_1)" "))))))", - checkparams={'name_1': 'user7'} + checkparams={"name_1": "user7"}, ) q2 = self.session.query(User).filter( User.user_keywords.any( UserKeyword.keyword.has( Keyword.user_keyword.has( - UserKeyword.user.has( - User.name == 'user7' - ) + UserKeyword.user.has(User.name == "user7") ) - ))) + ) + ) + ) self._equivalent(q1, q2) def test_filter_has_chained_has_to_any(self): @@ -1717,11 +1939,10 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Keyword = self.classes.Keyword q1 = self.session.query(User).filter( - User.common_singular.has(Keyword.keyword == 'brown') + User.common_singular.has(Keyword.keyword == "brown") ) self.assert_compile( q1, - "SELECT users.id AS users_id, users.name AS users_name, " "users.singular_id AS users_singular_id " "FROM users " @@ -1731,12 +1952,14 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM keywords " "WHERE singular.id = keywords.singular_id AND " "keywords.keyword = :keyword_1)))", - checkparams={'keyword_1': 'brown'} + checkparams={"keyword_1": "brown"}, ) q2 = self.session.query(User).filter( User.singular.has( - Singular.keywords.any(Keyword.keyword == 'brown'))) + Singular.keywords.any(Keyword.keyword == "brown") + ) + ) self._equivalent(q1, q2) def test_filter_has_scalar_raises(self): @@ -1744,7 +1967,8 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): assert_raises_message( exc.ArgumentError, r"Can't apply keyword arguments to column-targeted", - User.singular_keyword.has, keyword="brown" + 
User.singular_keyword.has, + keyword="brown", ) def test_filter_eq_chained_has_to_any(self): @@ -1752,9 +1976,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Keyword = self.classes.Keyword Singular = self.classes.Singular - q1 = self.session.query(User).filter( - User.singular_keyword == "brown" - ) + q1 = self.session.query(User).filter(User.singular_keyword == "brown") self.assert_compile( q1, "SELECT users.id AS users_id, users.name AS users_name, " @@ -1766,13 +1988,11 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM keywords " "WHERE singular.id = keywords.singular_id " "AND keywords.keyword = :keyword_1)))", - checkparams={'keyword_1': 'brown'} + checkparams={"keyword_1": "brown"}, ) q2 = self.session.query(User).filter( User.singular.has( - Singular.keywords.any( - Keyword.keyword == 'brown' - ) + Singular.keywords.any(Keyword.keyword == "brown") ) ) @@ -1797,7 +2017,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM keywords " "WHERE keywords.id = userkeywords.keyword_id AND " "keywords.keyword = :keyword_1)))", - checkparams={'keyword_1': 'brown'} + checkparams={"keyword_1": "brown"}, ) q2 = self.session.query(User).filter( @@ -1831,7 +2051,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM users " "WHERE users.id = userkeywords.user_id AND " ":param_1 = users.singular_id)))", - checkparams={"param_1": singular.id} + checkparams={"param_1": singular.id}, ) q2 = self.session.query(Keyword).filter( @@ -1850,7 +2070,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): exc.ArgumentError, r"Non-empty has\(\) not allowed", User.singular_value.has, - User.singular_value == "singular4" + User.singular_value == "singular4", ) def test_has_kwargs_nul(self): @@ -1861,30 +2081,33 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): assert_raises_message( exc.ArgumentError, r"Can't apply keyword arguments to column-targeted", - User.singular_value.has, singular_value="singular4" + User.singular_value.has, + singular_value="singular4", ) def test_filter_scalar_object_contains_fails_nul_nul(self): Keyword = self.classes.Keyword - assert_raises(exc.InvalidRequestError, - lambda: Keyword.user.contains(self.u)) + assert_raises( + exc.InvalidRequestError, lambda: Keyword.user.contains(self.u) + ) def test_filter_scalar_object_any_fails_nul_nul(self): Keyword = self.classes.Keyword - assert_raises(exc.InvalidRequestError, - lambda: Keyword.user.any(name='user2')) + assert_raises( + exc.InvalidRequestError, lambda: Keyword.user.any(name="user2") + ) def test_filter_scalar_column_like(self): User = self.classes.User Singular = self.classes.Singular self._equivalent( - self.session.query(User).filter(User.singular_value.like('foo')), + self.session.query(User).filter(User.singular_value.like("foo")), self.session.query(User).filter( - User.singular.has(Singular.value.like('foo')), - ) + User.singular.has(Singular.value.like("foo")) + ), ) def test_filter_scalar_column_contains(self): @@ -1892,10 +2115,12 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Singular = self.classes.Singular self._equivalent( - self.session.query(User).filter(User.singular_value.contains('foo')), self.session.query(User).filter( - User.singular.has(Singular.value.contains('foo')), - ) + User.singular_value.contains("foo") + ), + self.session.query(User).filter( + User.singular.has(Singular.value.contains("foo")) + ), ) def test_filter_scalar_column_eq(self): @@ -1903,10 +2128,10 @@ class 
ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Singular = self.classes.Singular self._equivalent( - self.session.query(User).filter(User.singular_value == 'foo'), + self.session.query(User).filter(User.singular_value == "foo"), self.session.query(User).filter( - User.singular.has(Singular.value == 'foo'), - ) + User.singular.has(Singular.value == "foo") + ), ) def test_filter_scalar_column_ne(self): @@ -1914,10 +2139,10 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Singular = self.classes.Singular self._equivalent( - self.session.query(User).filter(User.singular_value != 'foo'), + self.session.query(User).filter(User.singular_value != "foo"), self.session.query(User).filter( - User.singular.has(Singular.value != 'foo'), - ) + User.singular.has(Singular.value != "foo") + ), ) def test_filter_scalar_column_eq_nul(self): @@ -1926,53 +2151,57 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self._equivalent( self.session.query(User).filter(User.singular_value == None), - self.session.query(User).filter(or_( - User.singular.has(Singular.value == None), - User.singular == None - )) + self.session.query(User).filter( + or_( + User.singular.has(Singular.value == None), + User.singular == None, + ) + ), ) def test_filter_collection_has_fails_ul_nul(self): User = self.classes.User - assert_raises(exc.InvalidRequestError, - lambda: User.keywords.has(keyword='quick')) + assert_raises( + exc.InvalidRequestError, lambda: User.keywords.has(keyword="quick") + ) def test_filter_collection_eq_fails_ul_nul(self): User = self.classes.User - assert_raises(exc.InvalidRequestError, - lambda: User.keywords == self.kw) + assert_raises( + exc.InvalidRequestError, lambda: User.keywords == self.kw + ) def test_filter_collection_ne_fails_ul_nul(self): User = self.classes.User - assert_raises(exc.InvalidRequestError, - lambda: User.keywords != self.kw) + assert_raises( + exc.InvalidRequestError, lambda: User.keywords != self.kw + ) def test_join_separate_attr(self): User = self.classes.User self.assert_compile( self.session.query(User).join( - User.keywords.local_attr, - User.keywords.remote_attr), + User.keywords.local_attr, User.keywords.remote_attr + ), "SELECT users.id AS users_id, users.name AS users_name, " "users.singular_id AS users_singular_id " "FROM users JOIN userkeywords ON users.id = " "userkeywords.user_id JOIN keywords ON keywords.id = " - "userkeywords.keyword_id" + "userkeywords.keyword_id", ) def test_join_single_attr(self): User = self.classes.User self.assert_compile( - self.session.query(User).join( - *User.keywords.attr), + self.session.query(User).join(*User.keywords.attr), "SELECT users.id AS users_id, users.name AS users_name, " "users.singular_id AS users_singular_id " "FROM users JOIN userkeywords ON users.id = " "userkeywords.user_id JOIN keywords ON keywords.id = " - "userkeywords.keyword_id" + "userkeywords.keyword_id", ) @@ -1987,46 +2216,54 @@ class DictOfTupleUpdateTest(fixtures.TestBase): elements = association_proxy("orig", "elem", creator=B) m = MetaData() - a = Table('a', m, Column('id', Integer, primary_key=True)) - b = Table('b', m, Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('a.id')), - Column('elem', String)) - mapper(A, a, properties={ - 'orig': relationship( - B, - collection_class=attribute_mapped_collection('key')) - }) + a = Table("a", m, Column("id", Integer, primary_key=True)) + b = Table( + "b", + m, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id")), + 
Column("elem", String), + ) + mapper( + A, + a, + properties={ + "orig": relationship( + B, collection_class=attribute_mapped_collection("key") + ) + }, + ) mapper(B, b) self.A = A self.B = B def test_update_one_elem_dict(self): a1 = self.A() - a1.elements.update({("B", 3): 'elem2'}) - eq_(a1.elements, {("B", 3): 'elem2'}) + a1.elements.update({("B", 3): "elem2"}) + eq_(a1.elements, {("B", 3): "elem2"}) def test_update_multi_elem_dict(self): a1 = self.A() - a1.elements.update({("B", 3): 'elem2', ("C", 4): "elem3"}) - eq_(a1.elements, {("B", 3): 'elem2', ("C", 4): "elem3"}) + a1.elements.update({("B", 3): "elem2", ("C", 4): "elem3"}) + eq_(a1.elements, {("B", 3): "elem2", ("C", 4): "elem3"}) def test_update_one_elem_list(self): a1 = self.A() - a1.elements.update([(("B", 3), 'elem2')]) - eq_(a1.elements, {("B", 3): 'elem2'}) + a1.elements.update([(("B", 3), "elem2")]) + eq_(a1.elements, {("B", 3): "elem2"}) def test_update_multi_elem_list(self): a1 = self.A() - a1.elements.update([(("B", 3), 'elem2'), (("C", 4), "elem3")]) - eq_(a1.elements, {("B", 3): 'elem2', ("C", 4): "elem3"}) + a1.elements.update([(("B", 3), "elem2"), (("C", 4), "elem3")]) + eq_(a1.elements, {("B", 3): "elem2", ("C", 4): "elem3"}) def test_update_one_elem_varg(self): a1 = self.A() assert_raises_message( ValueError, - "dictionary update sequence requires " - "2-element tuples", - a1.elements.update, (("B", 3), 'elem2') + "dictionary update sequence requires " "2-element tuples", + a1.elements.update, + (("B", 3), "elem2"), ) def test_update_multi_elem_varg(self): @@ -2035,7 +2272,8 @@ class DictOfTupleUpdateTest(fixtures.TestBase): TypeError, "update expected at most 1 arguments, got 2", a1.elements.update, - (("B", 3), 'elem2'), (("C", 4), "elem3") + (("B", 3), "elem2"), + (("C", 4), "elem3"), ) @@ -2047,16 +2285,16 @@ class AttributeAccessTest(fixtures.TestBase): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) value = Column(String) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) a_id = Column(Integer, ForeignKey(A.id)) a = relationship(A) - a_value = association_proxy('a', 'value') + a_value = association_proxy("a", "value") spec = aliased(B).a_value @@ -2074,20 +2312,21 @@ class AttributeAccessTest(fixtures.TestBase): class Mixin(object): @declared_attr def children(cls): - return association_proxy('_children', 'value') + return association_proxy("_children", "value") # 1. build parent, Mixin.children gets invoked, we add # Parent.children class Parent(Mixin, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) _children = relationship("Child") class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" parent_id = Column( - Integer, ForeignKey(Parent.id), primary_key=True) + Integer, ForeignKey(Parent.id), primary_key=True + ) value = Column(String) # 2. declarative builds up SubParent, scans through all attributes @@ -2096,7 +2335,7 @@ class AttributeAccessTest(fixtures.TestBase): # mapped yet. association proxy then sets up "owning_class" # as NoneType. 
class SubParent(Parent): - __tablename__ = 'subparent' + __tablename__ = "subparent" id = Column(Integer, ForeignKey(Parent.id), primary_key=True) configure_mappers() @@ -2109,19 +2348,20 @@ class AttributeAccessTest(fixtures.TestBase): Base = declarative_base() class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) _children = relationship("Child") - children = association_proxy('_children', 'value') + children = association_proxy("_children", "value") class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" parent_id = Column( - Integer, ForeignKey(Parent.id), primary_key=True) + Integer, ForeignKey(Parent.id), primary_key=True + ) value = Column(String) class SubParent(Parent): - __tablename__ = 'subparent' + __tablename__ = "subparent" id = Column(Integer, ForeignKey(Parent.id), primary_key=True) is_(SubParent.children.owning_class, SubParent) @@ -2131,20 +2371,21 @@ class AttributeAccessTest(fixtures.TestBase): Base = declarative_base() class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) _children = relationship("Child") class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" parent_id = Column( - Integer, ForeignKey(Parent.id), primary_key=True) + Integer, ForeignKey(Parent.id), primary_key=True + ) value = Column(String) class SubParent(Parent): - __tablename__ = 'subparent' + __tablename__ = "subparent" id = Column(Integer, ForeignKey(Parent.id), primary_key=True) - children = association_proxy('_children', 'value') + children = association_proxy("_children", "value") is_(SubParent.children.owning_class, SubParent) @@ -2152,23 +2393,24 @@ class AttributeAccessTest(fixtures.TestBase): Base = declarative_base() class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) _children = relationship("Child") class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" parent_id = Column( - Integer, ForeignKey(Parent.id), primary_key=True) + Integer, ForeignKey(Parent.id), primary_key=True + ) value = Column(String) class SubParent(Parent): - __tablename__ = 'subparent' + __tablename__ = "subparent" id = Column(Integer, ForeignKey(Parent.id), primary_key=True) - children = association_proxy('_children', 'value') + children = association_proxy("_children", "value") class SubSubParent(SubParent): - __tablename__ = 'subsubparent' + __tablename__ = "subsubparent" id = Column(Integer, ForeignKey(SubParent.id), primary_key=True) is_(SubParent.children.owning_class, SubParent) @@ -2178,24 +2420,26 @@ class AttributeAccessTest(fixtures.TestBase): Base = declarative_base() class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) _children = relationship("Child") children = association_proxy( - '_children', 'value', creator=lambda value: Child(value=value)) + "_children", "value", creator=lambda value: Child(value=value) + ) class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" parent_id = Column( - Integer, ForeignKey(Parent.id), primary_key=True) + Integer, ForeignKey(Parent.id), primary_key=True + ) value = Column(String) class SubParent(Parent): - __tablename__ = 'subparent' + __tablename__ = "subparent" id = Column(Integer, ForeignKey(Parent.id), primary_key=True) sp = SubParent() - sp.children = 'c' + sp.children = "c" is_(SubParent.children.owning_class, SubParent) is_(Parent.children.owning_class, Parent) @@ 
-2203,17 +2447,18 @@ class AttributeAccessTest(fixtures.TestBase): Base = declarative_base() class Mixin(object): - children = association_proxy('_children', 'value') + children = association_proxy("_children", "value") class Parent(Mixin, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) _children = relationship("Child") class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" parent_id = Column( - Integer, ForeignKey(Parent.id), primary_key=True) + Integer, ForeignKey(Parent.id), primary_key=True + ) value = Column(String) # this triggers the owning routine, doesn't fail @@ -2221,18 +2466,18 @@ class AttributeAccessTest(fixtures.TestBase): p1 = Parent() - c1 = Child(value='c1') + c1 = Child(value="c1") p1._children.append(c1) is_(Parent.children.owning_class, Parent) eq_(p1.children, ["c1"]) def _test_never_assign_nonetype(self): - foo = association_proxy('x', 'y') + foo = association_proxy("x", "y") foo._calc_owner(None, None) is_(foo.owning_class, None) class Bat(object): - foo = association_proxy('x', 'y') + foo = association_proxy("x", "y") Bat.foo is_(Bat.foo.owning_class, None) @@ -2242,14 +2487,17 @@ class AttributeAccessTest(fixtures.TestBase): exc.InvalidRequestError, "This association proxy has no mapped owning class; " "can't locate a mapped property", - getattr, b1, "foo" + getattr, + b1, + "foo", ) is_(Bat.foo.owning_class, None) # after all that, we can map it mapper( Bat, - Table('bat', MetaData(), Column('x', Integer, primary_key=True))) + Table("bat", MetaData(), Column("x", Integer, primary_key=True)), + ) # answer is correct is_(Bat.foo.owning_class, Bat) @@ -2265,29 +2513,32 @@ class ScalarRemoveTest(object): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'test_a' + __tablename__ = "test_a" id = Column(Integer, primary_key=True) - ab = relationship( - 'AB', backref='a', - uselist=cls.uselist) + ab = relationship("AB", backref="a", uselist=cls.uselist) b = association_proxy( - 'ab', 'b', creator=lambda b: AB(b=b), - cascade_scalar_deletes=cls.cascade_scalar_deletes) + "ab", + "b", + creator=lambda b: AB(b=b), + cascade_scalar_deletes=cls.cascade_scalar_deletes, + ) if cls.useobject: + class B(Base): - __tablename__ = 'test_b' + __tablename__ = "test_b" id = Column(Integer, primary_key=True) - ab = relationship('AB', backref="b") + ab = relationship("AB", backref="b") class AB(Base): - __tablename__ = 'test_ab' + __tablename__ = "test_ab" a_id = Column(Integer, ForeignKey(A.id), primary_key=True) b_id = Column(Integer, ForeignKey(B.id), primary_key=True) else: + class AB(Base): - __tablename__ = 'test_ab' + __tablename__ = "test_ab" b = Column(Integer) a_id = Column(Integer, ForeignKey(A.id), primary_key=True) @@ -2354,13 +2605,12 @@ class ScalarRemoveTest(object): eq_(a1.ab, []) else: + def go(): del a1.b assert_raises_message( - AttributeError, - "A.ab object does not have a value", - go + AttributeError, "A.ab object does not have a value", go ) def test_del(self): @@ -2425,7 +2675,8 @@ class ScalarRemoveTest(object): class ScalarRemoveListObjectCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = True cascade_scalar_deletes = True @@ -2433,7 +2684,8 @@ class ScalarRemoveListObjectCascade( class ScalarRemoveScalarObjectCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = True cascade_scalar_deletes = True @@ -2441,7 +2693,8 @@ class 
ScalarRemoveScalarObjectCascade( class ScalarRemoveListScalarCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = False cascade_scalar_deletes = True @@ -2449,7 +2702,8 @@ class ScalarRemoveListScalarCascade( class ScalarRemoveScalarScalarCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = False cascade_scalar_deletes = True @@ -2457,7 +2711,8 @@ class ScalarRemoveScalarScalarCascade( class ScalarRemoveListObjectNoCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = True cascade_scalar_deletes = False @@ -2465,7 +2720,8 @@ class ScalarRemoveListObjectNoCascade( class ScalarRemoveScalarObjectNoCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = True cascade_scalar_deletes = False @@ -2473,7 +2729,8 @@ class ScalarRemoveScalarObjectNoCascade( class ScalarRemoveListScalarNoCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = False cascade_scalar_deletes = False @@ -2481,7 +2738,8 @@ class ScalarRemoveListScalarNoCascade( class ScalarRemoveScalarScalarNoCascade( - ScalarRemoveTest, fixtures.DeclarativeMappedTest): + ScalarRemoveTest, fixtures.DeclarativeMappedTest +): useobject = False cascade_scalar_deletes = False @@ -2490,22 +2748,22 @@ class ScalarRemoveScalarScalarNoCascade( class InfoTest(fixtures.TestBase): def test_constructor(self): - assoc = association_proxy('a', 'b', info={'some_assoc': 'some_value'}) + assoc = association_proxy("a", "b", info={"some_assoc": "some_value"}) eq_(assoc.info, {"some_assoc": "some_value"}) def test_empty(self): - assoc = association_proxy('a', 'b') + assoc = association_proxy("a", "b") eq_(assoc.info, {}) def test_via_cls(self): class Foob(object): - assoc = association_proxy('a', 'b') + assoc = association_proxy("a", "b") eq_(Foob.assoc.info, {}) - Foob.assoc.info["foo"] = 'bar' + Foob.assoc.info["foo"] = "bar" - eq_(Foob.assoc.info, {'foo': 'bar'}) + eq_(Foob.assoc.info, {"foo": "bar"}) class OnlyRelationshipTest(fixtures.DeclarativeMappedTest): @@ -2519,7 +2777,7 @@ class OnlyRelationshipTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) foo = Column(String) # assume some composite datatype @@ -2535,7 +2793,10 @@ class OnlyRelationshipTest(fixtures.DeclarativeMappedTest): NotImplementedError, "association proxy to a non-relationship " "intermediary is not supported", - setattr, f1, 'bar', 'asdf' + setattr, + f1, + "bar", + "asdf", ) def test_getattr(self): @@ -2547,7 +2808,9 @@ class OnlyRelationshipTest(fixtures.DeclarativeMappedTest): NotImplementedError, "association proxy to a non-relationship " "intermediary is not supported", - getattr, f1, 'bar' + getattr, + f1, + "bar", ) def test_get_class_attr(self): @@ -2557,28 +2820,30 @@ class OnlyRelationshipTest(fixtures.DeclarativeMappedTest): NotImplementedError, "association proxy to a non-relationship " "intermediary is not supported", - getattr, Foo, 'bar' + getattr, + Foo, + "bar", ) class MultiOwnerTest( - fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" - run_define_tables = 
'each' + run_define_tables = "each" run_create_tables = None run_inserts = None run_deletes = None - run_setup_classes = 'each' - run_setup_mappers = 'each' + run_setup_classes = "each" + run_setup_mappers = "each" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) type = Column(String(5), nullable=False) d_values = association_proxy("ds", "value") @@ -2586,8 +2851,8 @@ class MultiOwnerTest( __mapper_args__ = {"polymorphic_on": type} class B(A): - __tablename__ = 'b' - id = Column(ForeignKey('a.id'), primary_key=True) + __tablename__ = "b" + id = Column(ForeignKey("a.id"), primary_key=True) c1_id = Column(ForeignKey("c1.id")) @@ -2596,44 +2861,47 @@ class MultiOwnerTest( __mapper_args__ = {"polymorphic_identity": "b"} class C(A): - __tablename__ = 'c' - id = Column(ForeignKey('a.id'), primary_key=True) + __tablename__ = "c" + id = Column(ForeignKey("a.id"), primary_key=True) ds = relationship( - "D", primaryjoin="D.c_id == C.id", back_populates="c") + "D", primaryjoin="D.c_id == C.id", back_populates="c" + ) __mapper_args__ = {"polymorphic_identity": "c"} class C1(C): - __tablename__ = 'c1' - id = Column(ForeignKey('c.id'), primary_key=True) + __tablename__ = "c1" + id = Column(ForeignKey("c.id"), primary_key=True) csub_only_data = relationship("B") # uselist=True relationship ds = relationship( - "D", primaryjoin="D.c1_id == C1.id", back_populates="c") + "D", primaryjoin="D.c1_id == C1.id", back_populates="c" + ) __mapper_args__ = {"polymorphic_identity": "c1"} class C2(C): - __tablename__ = 'c2' - id = Column(ForeignKey('c.id'), primary_key=True) + __tablename__ = "c2" + id = Column(ForeignKey("c.id"), primary_key=True) csub_only_data = Column(String(50)) # scalar Column ds = relationship( - "D", primaryjoin="D.c2_id == C2.id", back_populates="c") + "D", primaryjoin="D.c2_id == C2.id", back_populates="c" + ) __mapper_args__ = {"polymorphic_identity": "c2"} class D(Base): - __tablename__ = 'd' + __tablename__ = "d" id = Column(Integer, primary_key=True) value = Column(String(50)) - b_id = Column(ForeignKey('b.id')) - c_id = Column(ForeignKey('c.id')) - c1_id = Column(ForeignKey('c1.id')) - c2_id = Column(ForeignKey('c2.id')) + b_id = Column(ForeignKey("b.id")) + c_id = Column(ForeignKey("c.id")) + c1_id = Column(ForeignKey("c1.id")) + c2_id = Column(ForeignKey("c2.id")) c = relationship("C", primaryjoin="D.c_id == C.id") @@ -2643,106 +2911,96 @@ class MultiOwnerTest( assert_raises_message( AttributeError, "Association proxy D.c refers to an attribute 'csub_only_data'", - fn, *arg, **kw + fn, + *arg, + **kw ) def test_column_collection_expressions(self): B, C, C2 = self.classes("B", "C", "C2") self.assert_compile( - B.d_values.contains('b1'), + B.d_values.contains("b1"), "EXISTS (SELECT 1 FROM d, b WHERE d.b_id = b.id " - "AND (d.value LIKE '%' || :value_1 || '%'))" + "AND (d.value LIKE '%' || :value_1 || '%'))", ) self.assert_compile( C2.d_values.contains("c2"), "EXISTS (SELECT 1 FROM d, c2 WHERE d.c2_id = c2.id " - "AND (d.value LIKE '%' || :value_1 || '%'))" + "AND (d.value LIKE '%' || :value_1 || '%'))", ) self.assert_compile( - C.d_values.contains('c1'), + C.d_values.contains("c1"), "EXISTS (SELECT 1 FROM d, c WHERE d.c_id = c.id " - "AND (d.value LIKE '%' || :value_1 || '%'))" + "AND (d.value LIKE '%' || :value_1 || '%'))", ) def test_subclass_only_owner_none(self): D, C, C2 = self.classes("D", "C", "C2") d1 = D() - self._assert_raises_ambiguous( - getattr, d1, 'c_data' - ) 
+ self._assert_raises_ambiguous(getattr, d1, "c_data") def test_subclass_only_owner_assign(self): D, C, C2 = self.classes("D", "C", "C2") d1 = D(c=C2()) - d1.c_data = 'some c2' + d1.c_data = "some c2" eq_(d1.c_data, "some c2") def test_subclass_only_owner_get(self): D, C, C2 = self.classes("D", "C", "C2") - d1 = D(c=C2(csub_only_data='some c2')) + d1 = D(c=C2(csub_only_data="some c2")) eq_(d1.c_data, "some c2") def test_subclass_only_owner_none_raise(self): D, C, C2 = self.classes("D", "C", "C2") d1 = D() - self._assert_raises_ambiguous( - getattr, d1, "c_data" - ) + self._assert_raises_ambiguous(getattr, d1, "c_data") def test_subclass_only_owner_delete(self): D, C, C2 = self.classes("D", "C", "C2") - d1 = D(c=C2(csub_only_data='some c2')) + d1 = D(c=C2(csub_only_data="some c2")) del d1.c_data - self._assert_raises_ambiguous( - getattr, d1, "c_data" - ) + self._assert_raises_ambiguous(getattr, d1, "c_data") def test_subclass_only_owner_assign_raises(self): D, C, C2 = self.classes("D", "C", "C2") d1 = D(c=C()) - self._assert_raises_ambiguous( - setattr, d1, "c_data", 'some c1' - ) + self._assert_raises_ambiguous(setattr, d1, "c_data", "some c1") def test_subclass_only_owner_get_raises(self): D, C, C2 = self.classes("D", "C", "C2") d1 = D(c=C()) - self._assert_raises_ambiguous( - getattr, d1, "c_data" - ) + self._assert_raises_ambiguous(getattr, d1, "c_data") def test_subclass_only_owner_delete_raises(self): D, C, C2 = self.classes("D", "C", "C2") - d1 = D(c=C2(csub_only_data='some c2')) + d1 = D(c=C2(csub_only_data="some c2")) eq_(d1.c_data, "some c2") # now switch d1.c = C() - self._assert_raises_ambiguous( - delattr, d1, "c_data" - ) + self._assert_raises_ambiguous(delattr, d1, "c_data") def test_subclasses_conflicting_types(self): B, D, C, C1, C2 = self.classes("B", "D", "C", "C1", "C2") bs = [B(), B()] d1 = D(c=C1(csub_only_data=bs)) - d2 = D(c=C2(csub_only_data='some c2')) + d2 = D(c=C2(csub_only_data="some c2")) - association_proxy_object = inspect(D).all_orm_descriptors['c_data'] + association_proxy_object = inspect(D).all_orm_descriptors["c_data"] inst1 = association_proxy_object.for_class(D, d1) inst2 = association_proxy_object.for_class(D, d2) @@ -2760,16 +3018,12 @@ class MultiOwnerTest( def test_col_expressions_not_available(self): D, = self.classes("D") - self._assert_raises_ambiguous( - lambda: D.c_data == 5 - ) + self._assert_raises_ambiguous(lambda: D.c_data == 5) def test_rel_expressions_not_available(self): B, D, = self.classes("B", "D") - self._assert_raises_ambiguous( - lambda: D.c_data.any(B.id == 5) - ) + self._assert_raises_ambiguous(lambda: D.c_data.any(B.id == 5)) class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): @@ -2780,7 +3034,7 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) data = Column(String(50)) @@ -2793,10 +3047,10 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): b_dynamic_data = association_proxy("bs", "data") class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - aid = Column(ForeignKey('a.id')) + aid = Column(ForeignKey("a.id")) data = Column(String(50)) @classmethod @@ -2804,9 +3058,12 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): A, B = cls.classes("A", "B") s = Session(testing.db) - s.add_all([ - A(id=1, bs=[B(data='b1'), B(data='b2')]), - A(id=2, bs=[B(data='b3'), B(data='b4')])]) + s.add_all( + [ + A(id=1, bs=[B(data="b1"), 
B(data="b2")]), + A(id=2, bs=[B(data="b3"), B(data="b4")]), + ] + ) s.commit() s.close() @@ -2822,7 +3079,7 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): gc_collect() - assert (A, (1, ), None) not in s.identity_map + assert (A, (1,), None) not in s.identity_map @testing.fails("dynamic relationship strong references parent") def test_dynamic_collection_gc(self): @@ -2839,7 +3096,7 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): gc_collect() # also fails, AppenderQuery holds onto parent - assert (A, (1, ), None) not in s.identity_map + assert (A, (1,), None) not in s.identity_map @testing.fails("association proxy strong references parent") def test_associated_collection_gc(self): @@ -2849,13 +3106,13 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): a1 = s.query(A).filter_by(id=1).one() - a1bs = a1.b_data # noqa + a1bs = a1.b_data # noqa del a1 gc_collect() - assert (A, (1, ), None) not in s.identity_map + assert (A, (1,), None) not in s.identity_map @testing.fails("association proxy strong references parent") def test_associated_dynamic_gc(self): @@ -2865,13 +3122,13 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): a1 = s.query(A).filter_by(id=1).one() - a1bs = a1.b_dynamic_data # noqa + a1bs = a1.b_dynamic_data # noqa del a1 gc_collect() - assert (A, (1, ), None) not in s.identity_map + assert (A, (1,), None) not in s.identity_map def test_plain_collection_iterate(self): A, B = self.classes("A", "B") @@ -2932,5 +3189,3 @@ class ScopeBehaviorTest(fixtures.DeclarativeMappedTest): gc_collect() assert len(a1bs) == 2 - - diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index 4ac0860c88..e6bdd1d8cc 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -26,8 +26,8 @@ class AutomapTest(fixtures.MappedTest): User = Base.classes.users Address = Base.classes.addresses - a1 = Address(email_address='e1') - u1 = User(name='u1', addresses_collection=[a1]) + a1 = Address(email_address="e1") + u1 = User(name="u1", addresses_collection=[a1]) assert a1.users is u1 def test_relationship_explicit_override_o2m(self): @@ -35,7 +35,7 @@ class AutomapTest(fixtures.MappedTest): prop = relationship("addresses", collection_class=set) class User(Base): - __tablename__ = 'users' + __tablename__ = "users" addresses_collection = prop @@ -43,8 +43,8 @@ class AutomapTest(fixtures.MappedTest): assert User.addresses_collection.property is prop Address = Base.classes.addresses - a1 = Address(email_address='e1') - u1 = User(name='u1', addresses_collection=set([a1])) + a1 = Address(email_address="e1") + u1 = User(name="u1", addresses_collection=set([a1])) assert a1.user is u1 def test_relationship_explicit_override_m2o(self): @@ -53,7 +53,7 @@ class AutomapTest(fixtures.MappedTest): prop = relationship("users") class Address(Base): - __tablename__ = 'addresses' + __tablename__ = "addresses" users = prop @@ -61,8 +61,8 @@ class AutomapTest(fixtures.MappedTest): User = Base.classes.users assert Address.users.property is prop - a1 = Address(email_address='e1') - u1 = User(name='u1', address_collection=[a1]) + a1 = Address(email_address="e1") + u1 = User(name="u1", address_collection=[a1]) assert a1.users is u1 def test_relationship_self_referential(self): @@ -119,17 +119,19 @@ class AutomapTest(fixtures.MappedTest): return str("cls_" + tablename) def name_for_scalar_relationship( - base, local_cls, referred_cls, constraint): + base, local_cls, referred_cls, constraint + ): return "scalar_" + referred_cls.__name__ def name_for_collection_relationship( 
- base, local_cls, referred_cls, constraint): + base, local_cls, referred_cls, constraint + ): return "coll_" + referred_cls.__name__ Base.prepare( classname_for_table=classname_for_table, name_for_scalar_relationship=name_for_scalar_relationship, - name_for_collection_relationship=name_for_collection_relationship + name_for_collection_relationship=name_for_collection_relationship, ) User = Base.classes.cls_users @@ -145,7 +147,7 @@ class AutomapTest(fixtures.MappedTest): Base.prepare() - Order, Item = Base.classes.orders, Base.classes['items'] + Order, Item = Base.classes.orders, Base.classes["items"] o1 = Order() i1 = Item() @@ -156,15 +158,15 @@ class AutomapTest(fixtures.MappedTest): Base = automap_base(metadata=self.metadata) class Order(Base): - __tablename__ = 'orders' + __tablename__ = "orders" items_collection = relationship( - "items", - secondary="order_items", - collection_class=set) + "items", secondary="order_items", collection_class=set + ) + Base.prepare() - Item = Base.classes['items'] + Item = Base.classes["items"] o1 = Order() i1 = Item() @@ -181,52 +183,62 @@ class AutomapTest(fixtures.MappedTest): mock = Mock() def _gen_relationship( - base, direction, return_fn, attrname, - local_cls, referred_cls, **kw): + base, direction, return_fn, attrname, local_cls, referred_cls, **kw + ): mock(base, direction, attrname) return generate_relationship( - base, direction, return_fn, - attrname, local_cls, referred_cls, **kw) + base, + direction, + return_fn, + attrname, + local_cls, + referred_cls, + **kw + ) Base.prepare(generate_relationship=_gen_relationship) - assert set(tuple(c[1]) for c in mock.mock_calls).issuperset([ - (Base, interfaces.MANYTOONE, "nodes"), - (Base, interfaces.MANYTOMANY, "keywords_collection"), - (Base, interfaces.MANYTOMANY, "items_collection"), - (Base, interfaces.MANYTOONE, "users"), - (Base, interfaces.ONETOMANY, "addresses_collection"), - ]) + assert set(tuple(c[1]) for c in mock.mock_calls).issuperset( + [ + (Base, interfaces.MANYTOONE, "nodes"), + (Base, interfaces.MANYTOMANY, "keywords_collection"), + (Base, interfaces.MANYTOMANY, "items_collection"), + (Base, interfaces.MANYTOONE, "users"), + (Base, interfaces.ONETOMANY, "addresses_collection"), + ] + ) class CascadeTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): + Table("a", metadata, Column("id", Integer, primary_key=True)) Table( - "a", metadata, - Column('id', Integer, primary_key=True) + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", ForeignKey("a.id"), nullable=True), ) Table( - "b", metadata, - Column('id', Integer, primary_key=True), - Column('aid', ForeignKey('a.id'), nullable=True) + "c", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", ForeignKey("a.id"), nullable=False), ) Table( - "c", metadata, - Column('id', Integer, primary_key=True), - Column('aid', ForeignKey('a.id'), nullable=False) - ) - Table( - "d", metadata, - Column('id', Integer, primary_key=True), + "d", + metadata, + Column("id", Integer, primary_key=True), Column( - 'aid', ForeignKey('a.id', ondelete="cascade"), nullable=False) + "aid", ForeignKey("a.id", ondelete="cascade"), nullable=False + ), ) Table( - "e", metadata, - Column('id', Integer, primary_key=True), + "e", + metadata, + Column("id", Integer, primary_key=True), Column( - 'aid', ForeignKey('a.id', ondelete="set null"), - nullable=True) + "aid", ForeignKey("a.id", ondelete="set null"), nullable=True + ), ) def test_o2m_relationship_cascade(self): @@ -268,25 +280,28 @@ class 
AutomapInhTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'single', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(10)), - test_needs_fk=True + "single", + metadata, + Column("id", Integer, primary_key=True), + Column("type", String(10)), + test_needs_fk=True, ) Table( - 'joined_base', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(10)), - test_needs_fk=True + "joined_base", + metadata, + Column("id", Integer, primary_key=True), + Column("type", String(10)), + test_needs_fk=True, ) Table( - 'joined_inh', metadata, + "joined_inh", + metadata, Column( - 'id', Integer, - ForeignKey('joined_base.id'), primary_key=True), - test_needs_fk=True + "id", Integer, ForeignKey("joined_base.id"), primary_key=True + ), + test_needs_fk=True, ) FixtureTest.define_tables(metadata) @@ -295,13 +310,14 @@ class AutomapInhTest(fixtures.MappedTest): Base = automap_base() class Single(Base): - __tablename__ = 'single' + __tablename__ = "single" type = Column(String) __mapper_args__ = { "polymorphic_identity": "u0", - "polymorphic_on": type} + "polymorphic_on": type, + } class SubUser1(Single): __mapper_args__ = {"polymorphic_identity": "u1"} @@ -317,16 +333,17 @@ class AutomapInhTest(fixtures.MappedTest): Base = automap_base() class Joined(Base): - __tablename__ = 'joined_base' + __tablename__ = "joined_base" type = Column(String) __mapper_args__ = { "polymorphic_identity": "u0", - "polymorphic_on": type} + "polymorphic_on": type, + } class SubJoined(Joined): - __tablename__ = 'joined_inh' + __tablename__ = "joined_inh" __mapper_args__ = {"polymorphic_identity": "u1"} Base.prepare(engine=testing.db, reflect=True) @@ -341,26 +358,30 @@ class AutomapInhTest(fixtures.MappedTest): def _gen_relationship(*arg, **kw): return None + Base.prepare( - engine=testing.db, reflect=True, - generate_relationship=_gen_relationship) + engine=testing.db, + reflect=True, + generate_relationship=_gen_relationship, + ) class ConcurrentAutomapTest(fixtures.TestBase): - __only_on__ = 'sqlite' + __only_on__ = "sqlite" def _make_tables(self, e): m = MetaData() for i in range(15): Table( - 'table_%d' % i, + "table_%d" % i, m, - Column('id', Integer, primary_key=True), - Column('data', String(50)), + Column("id", Integer, primary_key=True), + Column("data", String(50)), Column( - 't_%d_id' % (i - 1), - ForeignKey('table_%d.id' % (i - 1)) - ) if i > 4 else None + "t_%d_id" % (i - 1), ForeignKey("table_%d.id" % (i - 1)) + ) + if i > 4 + else None, ) m.drop_all(e) m.create_all(e) @@ -370,7 +391,7 @@ class ConcurrentAutomapTest(fixtures.TestBase): Base.prepare(e, reflect=True) - time.sleep(.01) + time.sleep(0.01) configure_mappers() def _chaos(self): @@ -396,4 +417,4 @@ class ConcurrentAutomapTest(fixtures.TestBase): for t in threads: t.join() - assert self._success, "One or more threads failed" \ No newline at end of file + assert self._success, "One or more threads failed" diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index f6afabd2da..8a6702dce9 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -1,5 +1,14 @@ -from sqlalchemy.orm import Session, subqueryload, \ - mapper, relationship, lazyload, backref, aliased, Load, defaultload +from sqlalchemy.orm import ( + Session, + subqueryload, + mapper, + relationship, + lazyload, + backref, + aliased, + Load, + defaultload, +) from sqlalchemy.testing import eq_, is_, is_not_ from sqlalchemy.testing import assert_raises_message from sqlalchemy import testing @@ -16,8 +25,8 
@@ from sqlalchemy import exc as sa_exc class BakedTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None def setup(self): @@ -32,10 +41,7 @@ class StateChangeTest(BakedTest): mapper(User, cls.tables.users) def _assert_cache_key(self, key, elements): - eq_( - key, - tuple(elem.__code__ for elem in elements) - ) + eq_(key, tuple(elem.__code__ for elem in elements)) def test_initial_key(self): User = self.classes.User @@ -45,10 +51,7 @@ class StateChangeTest(BakedTest): return session.query(User) q1 = self.bakery(l1) - self._assert_cache_key( - q1._cache_key, - [l1] - ) + self._assert_cache_key(q1._cache_key, [l1]) eq_(q1.steps, [l1]) def test_inplace_add(self): @@ -59,22 +62,16 @@ class StateChangeTest(BakedTest): return session.query(User) def l2(q): - return q.filter(User.name == bindparam('name')) + return q.filter(User.name == bindparam("name")) q1 = self.bakery(l1) - self._assert_cache_key( - q1._cache_key, - [l1] - ) + self._assert_cache_key(q1._cache_key, [l1]) eq_(q1.steps, [l1]) q2 = q1.add_criteria(l2) is_(q2, q1) - self._assert_cache_key( - q1._cache_key, - [l1, l2] - ) + self._assert_cache_key(q1._cache_key, [l1, l2]) eq_(q1.steps, [l1, l2]) def test_inplace_add_operator(self): @@ -85,20 +82,14 @@ class StateChangeTest(BakedTest): return session.query(User) def l2(q): - return q.filter(User.name == bindparam('name')) + return q.filter(User.name == bindparam("name")) q1 = self.bakery(l1) - self._assert_cache_key( - q1._cache_key, - [l1] - ) + self._assert_cache_key(q1._cache_key, [l1]) q1 += l2 - self._assert_cache_key( - q1._cache_key, - [l1, l2] - ) + self._assert_cache_key(q1._cache_key, [l1, l2]) def test_chained_add(self): User = self.classes.User @@ -108,42 +99,33 @@ class StateChangeTest(BakedTest): return session.query(User) def l2(q): - return q.filter(User.name == bindparam('name')) + return q.filter(User.name == bindparam("name")) q1 = self.bakery(l1) q2 = q1.with_criteria(l2) is_not_(q2, q1) - self._assert_cache_key( - q1._cache_key, - [l1] - ) - self._assert_cache_key( - q2._cache_key, - [l1, l2] - ) + self._assert_cache_key(q1._cache_key, [l1]) + self._assert_cache_key(q2._cache_key, [l1, l2]) def test_chained_add_operator(self): User = self.classes.User session = Session() - def l1(): return session.query(User) + def l1(): + return session.query(User) + + def l2(q): + return q.filter(User.name == bindparam("name")) - def l2(q): return q.filter(User.name == bindparam('name')) q1 = self.bakery(l1) q2 = q1 + l2 is_not_(q2, q1) - self._assert_cache_key( - q1._cache_key, - [l1] - ) - self._assert_cache_key( - q2._cache_key, - [l1, l2] - ) + self._assert_cache_key(q1._cache_key, [l1]) + self._assert_cache_key(q2._cache_key, [l1, l2]) class LikeQueryTest(BakedTest): @@ -157,87 +139,78 @@ class LikeQueryTest(BakedTest): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) - bq += lambda q: q.filter(User.name == 'asdf') + bq += lambda q: q.filter(User.name == "asdf") - eq_( - bq(Session()).first(), - None - ) + eq_(bq(Session()).first(), None) def test_first_multiple_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User.id)) - bq += lambda q: q.filter(User.name.like('%ed%')).order_by(User.id) + bq += lambda q: q.filter(User.name.like("%ed%")).order_by(User.id) - eq_( - bq(Session()).first(), - (8, ) - ) + eq_(bq(Session()).first(), (8,)) def test_one_or_none_no_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) 
- bq += lambda q: q.filter(User.name == 'asdf') + bq += lambda q: q.filter(User.name == "asdf") - eq_( - bq(Session()).one_or_none(), - None - ) + eq_(bq(Session()).one_or_none(), None) def test_one_or_none_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) - bq += lambda q: q.filter(User.name == 'ed') + bq += lambda q: q.filter(User.name == "ed") u1 = bq(Session()).one_or_none() - eq_(u1.name, 'ed') + eq_(u1.name, "ed") def test_one_or_none_multiple_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) - bq += lambda q: q.filter(User.name.like('%ed%')) + bq += lambda q: q.filter(User.name.like("%ed%")) assert_raises_message( orm_exc.MultipleResultsFound, "Multiple rows were found for one_or_none()", - bq(Session()).one_or_none + bq(Session()).one_or_none, ) def test_one_no_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) - bq += lambda q: q.filter(User.name == 'asdf') + bq += lambda q: q.filter(User.name == "asdf") assert_raises_message( orm_exc.NoResultFound, "No row was found for one()", - bq(Session()).one + bq(Session()).one, ) def test_one_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) - bq += lambda q: q.filter(User.name == 'ed') + bq += lambda q: q.filter(User.name == "ed") u1 = bq(Session()).one() - eq_(u1.name, 'ed') + eq_(u1.name, "ed") def test_one_multiple_result(self): User = self.classes.User bq = self.bakery(lambda s: s.query(User)) - bq += lambda q: q.filter(User.name.like('%ed%')) + bq += lambda q: q.filter(User.name.like("%ed%")) assert_raises_message( orm_exc.MultipleResultsFound, "Multiple rows were found for one()", - bq(Session()).one + bq(Session()).one, ) def test_get(self): @@ -249,19 +222,22 @@ class LikeQueryTest(BakedTest): def go(): u1 = bq(sess).get(7) - eq_(u1.name, 'jack') + eq_(u1.name, "jack") + self.assert_sql_count(testing.db, go, 1) u1 = sess.query(User).get(7) # noqa def go(): u2 = bq(sess).get(7) - eq_(u2.name, 'jack') + eq_(u2.name, "jack") + self.assert_sql_count(testing.db, go, 0) def go(): u2 = bq(sess).get(8) - eq_(u2.name, 'ed') + eq_(u2.name, "ed") + self.assert_sql_count(testing.db, go, 1) def test_scalar(self): @@ -273,9 +249,7 @@ class LikeQueryTest(BakedTest): bq += lambda q: q.filter(User.id == 7) - eq_( - bq(sess).scalar(), 7 - ) + eq_(bq(sess).scalar(), 7) def test_count(self): User = self.classes.User @@ -284,21 +258,16 @@ class LikeQueryTest(BakedTest): sess = Session() - eq_( - bq(sess).count(), - 4 - ) + eq_(bq(sess).count(), 4) bq += lambda q: q.filter(User.id.in_([8, 9])) - eq_( - bq(sess).count(), 2 - ) + eq_(bq(sess).count(), 2) # original query still works eq_( set([(u.id, u.name) for u in bq(sess).all()]), - set([(8, 'ed'), (9, 'fred')]) + set([(8, "ed"), (9, "fred")]), ) def test_count_with_bindparams(self): @@ -308,39 +277,33 @@ class LikeQueryTest(BakedTest): sess = Session() - eq_( - bq(sess).count(), - 4 - ) + eq_(bq(sess).count(), 4) bq += lambda q: q.filter(User.name == bindparam("uname")) # calling with *args - eq_( - bq(sess).params(uname='fred').count(), 1 - ) + eq_(bq(sess).params(uname="fred").count(), 1) # with multiple params, the **kwargs will be used bq += lambda q: q.filter(User.id == bindparam("anid")) - eq_( - bq(sess).params(uname='fred', anid=9).count(), 1 - ) + eq_(bq(sess).params(uname="fred", anid=9).count(), 1) eq_( # wrong id, so 0 results: - bq(sess).params(uname='fred', anid=8).count(), 0 + bq(sess).params(uname="fred", anid=8).count(), + 0, ) - def test_get_pk_w_null(self): 
"""test the re-implementation of logic to do get with IS NULL.""" class AddressUser(object): pass + mapper( AddressUser, self.tables.users.outerjoin(self.tables.addresses), properties={ "id": self.tables.users.c.id, - "address_id": self.tables.addresses.c.id - } + "address_id": self.tables.addresses.c.id, + }, ) bq = self.bakery(lambda s: s.query(AddressUser)) @@ -349,14 +312,16 @@ class LikeQueryTest(BakedTest): def go(): u1 = bq(sess).get((10, None)) - eq_(u1.name, 'chuck') + eq_(u1.name, "chuck") + self.assert_sql_count(testing.db, go, 1) u1 = sess.query(AddressUser).get((10, None)) # noqa def go(): u2 = bq(sess).get((10, None)) - eq_(u2.name, 'chuck') + eq_(u2.name, "chuck") + self.assert_sql_count(testing.db, go, 0) def test_get_includes_getclause(self): @@ -368,7 +333,7 @@ class LikeQueryTest(BakedTest): for i in range(5): sess = Session() u1 = bq(sess).get(7) - eq_(u1.name, 'jack') + eq_(u1.name, "jack") sess.close() eq_(len(bq._bakery), 2) @@ -376,45 +341,50 @@ class LikeQueryTest(BakedTest): # simulate race where mapper._get_clause # may be generated more than once from sqlalchemy import inspect - del inspect(User).__dict__['_get_clause'] + + del inspect(User).__dict__["_get_clause"] for i in range(5): sess = Session() u1 = bq(sess).get(7) - eq_(u1.name, 'jack') + eq_(u1.name, "jack") sess.close() eq_(len(bq._bakery), 4) class ResultPostCriteriaTest(BakedTest): - @classmethod def setup_mappers(cls): User = cls.classes.User Address = cls.classes.Address Order = cls.classes.Order - mapper(User, cls.tables.users, properties={ - "addresses": relationship( - Address, order_by=cls.tables.addresses.c.id), - "orders": relationship( - Order, order_by=cls.tables.orders.c.id) - }) + mapper( + User, + cls.tables.users, + properties={ + "addresses": relationship( + Address, order_by=cls.tables.addresses.c.id + ), + "orders": relationship(Order, order_by=cls.tables.orders.c.id), + }, + ) mapper(Address, cls.tables.addresses) mapper(Order, cls.tables.orders) @contextlib.contextmanager def _fixture(self): from sqlalchemy import event + User = self.classes.User with testing.db.connect() as conn: + @event.listens_for(conn, "before_execute") def before_execute(conn, clauseelement, multiparams, params): assert "yes" in conn._execution_options - bq = self.bakery( - lambda s: s.query(User.id).order_by(User.id)) + bq = self.bakery(lambda s: s.query(User.id).order_by(User.id)) sess = Session(conn) @@ -423,31 +393,34 @@ class ResultPostCriteriaTest(BakedTest): def test_first(self): with self._fixture() as (sess, bq): result = bq(sess).with_post_criteria( - lambda q: q.execution_options(yes=True)) - eq_(result.first(), (7, )) + lambda q: q.execution_options(yes=True) + ) + eq_(result.first(), (7,)) def test_iter(self): with self._fixture() as (sess, bq): result = bq(sess).with_post_criteria( - lambda q: q.execution_options(yes=True)) - eq_(list(result)[0], (7, )) + lambda q: q.execution_options(yes=True) + ) + eq_(list(result)[0], (7,)) def test_spoiled(self): with self._fixture() as (sess, bq): result = bq.spoil()(sess).with_post_criteria( - lambda q: q.execution_options(yes=True)) + lambda q: q.execution_options(yes=True) + ) - eq_(list(result)[0], (7, )) + eq_(list(result)[0], (7,)) def test_get(self): User = self.classes.User with self._fixture() as (sess, bq): - bq = self.bakery( - lambda s: s.query(User)) + bq = self.bakery(lambda s: s.query(User)) result = bq(sess).with_post_criteria( - lambda q: q.execution_options(yes=True)) + lambda q: q.execution_options(yes=True) + ) eq_(result.get(7), 
User(id=7)) @@ -460,12 +433,16 @@ class ResultTest(BakedTest): Address = cls.classes.Address Order = cls.classes.Order - mapper(User, cls.tables.users, properties={ - "addresses": relationship( - Address, order_by=cls.tables.addresses.c.id), - "orders": relationship( - Order, order_by=cls.tables.orders.c.id) - }) + mapper( + User, + cls.tables.users, + properties={ + "addresses": relationship( + Address, order_by=cls.tables.addresses.c.id + ), + "orders": relationship(Order, order_by=cls.tables.orders.c.id), + }, + ) mapper(Address, cls.tables.addresses) mapper(Order, cls.tables.orders) @@ -482,48 +459,41 @@ class ResultTest(BakedTest): for i in range(3): session = Session(autocommit=True) - eq_( - bq1(session).all(), - [(7,)] - ) + eq_(bq1(session).all(), [(7,)]) - eq_( - bq2(session).all(), - [(8,)] - ) + eq_(bq2(session).all(), [(8,)]) def test_no_steps(self): User = self.classes.User bq = self.bakery( - lambda s: s.query(User.id, User.name).order_by(User.id)) + lambda s: s.query(User.id, User.name).order_by(User.id) + ) for i in range(3): session = Session(autocommit=True) eq_( bq(session).all(), - [(7, 'jack'), (8, 'ed'), (9, 'fred'), (10, 'chuck')] + [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")], ) def test_different_limits(self): User = self.classes.User bq = self.bakery( - lambda s: s.query(User.id, User.name).order_by(User.id)) + lambda s: s.query(User.id, User.name).order_by(User.id) + ) - bq += lambda q: q.limit(bindparam('limit')).offset(bindparam('offset')) + bq += lambda q: q.limit(bindparam("limit")).offset(bindparam("offset")) session = Session(autocommit=True) for i in range(4): for limit, offset, exp in [ - (2, 1, [(8, 'ed'), (9, 'fred')]), - (3, 0, [(7, 'jack'), (8, 'ed'), (9, 'fred')]), - (1, 2, [(9, 'fred')]) + (2, 1, [(8, "ed"), (9, "fred")]), + (3, 0, [(7, "jack"), (8, "ed"), (9, "fred")]), + (1, 2, [(9, "fred")]), ]: - eq_( - bq(session).params(limit=limit, offset=offset).all(), - exp - ) + eq_(bq(session).params(limit=limit, offset=offset).all(), exp) def test_disable_on_session(self): User = self.classes.User @@ -536,7 +506,7 @@ class ResultTest(BakedTest): def fn2(q): canary.fn2() - return q.filter(User.id == bindparam('id')) + return q.filter(User.id == bindparam("id")) def fn3(q): canary.fn3() @@ -548,16 +518,21 @@ class ResultTest(BakedTest): bq += fn2 sess = Session(autocommit=True, enable_baked_queries=False) - eq_( - bq.add_criteria(fn3)(sess).params(id=7).all(), - [(7, 'jack')] - ) + eq_(bq.add_criteria(fn3)(sess).params(id=7).all(), [(7, "jack")]) eq_( canary.mock_calls, - [mock.call.fn1(), mock.call.fn2(), mock.call.fn3(), - mock.call.fn1(), mock.call.fn2(), mock.call.fn3(), - mock.call.fn1(), mock.call.fn2(), mock.call.fn3()] + [ + mock.call.fn1(), + mock.call.fn2(), + mock.call.fn3(), + mock.call.fn1(), + mock.call.fn2(), + mock.call.fn3(), + mock.call.fn1(), + mock.call.fn2(), + mock.call.fn3(), + ], ) def test_spoiled_full_w_params(self): @@ -571,7 +546,7 @@ class ResultTest(BakedTest): def fn2(q): canary.fn2() - return q.filter(User.id == bindparam('id')) + return q.filter(User.id == bindparam("id")) def fn3(q): canary.fn3() @@ -585,14 +560,22 @@ class ResultTest(BakedTest): sess = Session(autocommit=True) eq_( bq.spoil(full=True).add_criteria(fn3)(sess).params(id=7).all(), - [(7, 'jack')] + [(7, "jack")], ) eq_( canary.mock_calls, - [mock.call.fn1(), mock.call.fn2(), mock.call.fn3(), - mock.call.fn1(), mock.call.fn2(), mock.call.fn3(), - mock.call.fn1(), mock.call.fn2(), mock.call.fn3()] + [ + mock.call.fn1(), + mock.call.fn2(), + 
mock.call.fn3(), + mock.call.fn1(), + mock.call.fn2(), + mock.call.fn3(), + mock.call.fn1(), + mock.call.fn2(), + mock.call.fn3(), + ], ) def test_spoiled_half_w_params(self): @@ -606,7 +589,7 @@ class ResultTest(BakedTest): def fn2(q): canary.fn2() - return q.filter(User.id == bindparam('id')) + return q.filter(User.id == bindparam("id")) def fn3(q): canary.fn3() @@ -624,13 +607,18 @@ class ResultTest(BakedTest): sess = Session(autocommit=True) eq_( bq.spoil().add_criteria(fn3)(sess).params(id=7).all(), - [(7, 'jack')] + [(7, "jack")], ) eq_( canary.mock_calls, - [mock.call.fn1(), mock.call.fn2(), - mock.call.fn3(), mock.call.fn3(), mock.call.fn3()] + [ + mock.call.fn1(), + mock.call.fn2(), + mock.call.fn3(), + mock.call.fn3(), + mock.call.fn3(), + ], ) def test_w_new_entities(self): @@ -641,18 +629,13 @@ class ResultTest(BakedTest): """ User = self.classes.User - bq = self.bakery( - lambda s: s.query(User.id, User.name)) + bq = self.bakery(lambda s: s.query(User.id, User.name)) - bq += lambda q: q.from_self().with_entities( - func.count(User.id)) + bq += lambda q: q.from_self().with_entities(func.count(User.id)) for i in range(3): session = Session(autocommit=True) - eq_( - bq(session).all(), - [(4, )] - ) + eq_(bq(session).all(), [(4,)]) def test_conditional_step(self): """Test a large series of conditionals and assert that @@ -662,29 +645,30 @@ class ResultTest(BakedTest): """ User = self.classes.User - base_bq = self.bakery( - lambda s: s.query(User.id, User.name)) + base_bq = self.bakery(lambda s: s.query(User.id, User.name)) base_bq += lambda q: q.order_by(User.id) for i in range(4): for cond1, cond2, cond3, cond4 in itertools.product( - *[(False, True) for j in range(4)]): + *[(False, True) for j in range(4)] + ): bq = base_bq._clone() if cond1: - bq += lambda q: q.filter(User.name != 'jack') + bq += lambda q: q.filter(User.name != "jack") if cond2: bq += lambda q: q.join(User.addresses) else: bq += lambda q: q.outerjoin(User.addresses) elif cond3: - bq += lambda q: q.filter(User.name.like('%ed%')) + bq += lambda q: q.filter(User.name.like("%ed%")) else: - bq += lambda q: q.filter(User.name == 'jack') + bq += lambda q: q.filter(User.name == "jack") if cond4: bq += lambda q: q.from_self().with_entities( - func.count(User.id)) + func.count(User.id) + ) sess = Session(autocommit=True) result = bq(sess).all() if cond4: @@ -702,27 +686,30 @@ class ResultTest(BakedTest): if cond2: eq_( result, - [(8, 'ed'), (8, 'ed'), (8, 'ed'), - (9, 'fred')] + [(8, "ed"), (8, "ed"), (8, "ed"), (9, "fred")], ) else: eq_( result, - [(8, 'ed'), (8, 'ed'), (8, 'ed'), - (9, 'fred'), (10, 'chuck')] + [ + (8, "ed"), + (8, "ed"), + (8, "ed"), + (9, "fred"), + (10, "chuck"), + ], ) elif cond3: - eq_(result, [(8, 'ed'), (9, 'fred')]) + eq_(result, [(8, "ed"), (9, "fred")]) else: - eq_(result, [(7, 'jack')]) + eq_(result, [(7, "jack")]) sess.close() def test_conditional_step_oneline(self): User = self.classes.User - base_bq = self.bakery( - lambda s: s.query(User.id, User.name)) + base_bq = self.bakery(lambda s: s.query(User.id, User.name)) base_bq += lambda q: q.order_by(User.id) @@ -732,14 +719,18 @@ class ResultTest(BakedTest): # we were using (filename, firstlineno) as cache key, # which fails for this kind of thing! 
- bq += (lambda q: q.filter(User.name != 'jack')) if cond1 else (lambda q: q.filter(User.name == 'jack')) # noqa + bq += ( + (lambda q: q.filter(User.name != "jack")) + if cond1 + else (lambda q: q.filter(User.name == "jack")) + ) # noqa sess = Session(autocommit=True) result = bq(sess).all() if cond1: - eq_(result, [(8, u'ed'), (9, u'fred'), (10, u'chuck')]) + eq_(result, [(8, u"ed"), (9, u"fred"), (10, u"chuck")]) else: - eq_(result, [(7, 'jack')]) + eq_(result, [(7, "jack")]) sess.close() @@ -747,16 +738,15 @@ class ResultTest(BakedTest): User = self.classes.User Address = self.classes.Address - sub_bq = self.bakery( - lambda s: s.query(User.name) + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += ( + lambda q: q.filter(User.id == Address.user_id) + .filter(User.name == "ed") + .correlate(Address) ) - sub_bq += lambda q: q.filter( - User.id == Address.user_id).filter(User.name == 'ed').\ - correlate(Address) main_bq = self.bakery(lambda s: s.query(Address.id)) - main_bq += lambda q: q.filter( - sub_bq.to_query(q).exists()) + main_bq += lambda q: q.filter(sub_bq.to_query(q).exists()) main_bq += lambda q: q.order_by(Address.id) sess = Session() @@ -767,38 +757,38 @@ class ResultTest(BakedTest): User = self.classes.User Address = self.classes.Address - sub_bq = self.bakery( - lambda s: s.query(User.name) + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate( + Address ) - sub_bq += lambda q: q.filter( - User.id == Address.user_id).correlate(Address) main_bq = self.bakery( - lambda s: s.query(Address.id, sub_bq.to_query(s).as_scalar())) - main_bq += lambda q: q.filter(sub_bq.to_query(q).as_scalar() == 'ed') + lambda s: s.query(Address.id, sub_bq.to_query(s).as_scalar()) + ) + main_bq += lambda q: q.filter(sub_bq.to_query(q).as_scalar() == "ed") main_bq += lambda q: q.order_by(Address.id) sess = Session() result = main_bq(sess).all() - eq_(result, [(2, 'ed'), (3, 'ed'), (4, 'ed')]) + eq_(result, [(2, "ed"), (3, "ed"), (4, "ed")]) def test_to_query_args(self): User = self.classes.User - sub_bq = self.bakery( - lambda s: s.query(User.name) - ) + sub_bq = self.bakery(lambda s: s.query(User.name)) q = Query([], None) assert_raises_message( sa_exc.ArgumentError, "Given Query needs to be associated with a Session", - sub_bq.to_query, q + sub_bq.to_query, + q, ) assert_raises_message( TypeError, "Query or Session object expected, got .*'int'.*", - sub_bq.to_query, 5 + sub_bq.to_query, + 5, ) def test_subquery_eagerloading(self): @@ -811,67 +801,86 @@ class ResultTest(BakedTest): self.bakery = baked.bakery(size=3) base_bq = self.bakery(lambda s: s.query(User)) - base_bq += lambda q: q.options(subqueryload(User.addresses), - subqueryload(User.orders)) + base_bq += lambda q: q.options( + subqueryload(User.addresses), subqueryload(User.orders) + ) base_bq += lambda q: q.order_by(User.id) assert_result = [ - User(id=7, - addresses=[Address(id=1, email_address='jack@bean.com')], - orders=[Order(id=1), Order(id=3), Order(id=5)]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, - addresses=[Address(id=5)], - orders=[Order(id=2), Order(id=4)]), - User(id=10, addresses=[]) + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + orders=[Order(id=1), Order(id=3), Order(id=5)], + ), + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=3, 
email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User( + id=9, + addresses=[Address(id=5)], + orders=[Order(id=2), Order(id=4)], + ), + User(id=10, addresses=[]), ] for i in range(4): for cond1, cond2 in itertools.product( - *[(False, True) for j in range(2)]): + *[(False, True) for j in range(2)] + ): bq = base_bq._clone() sess = Session() if cond1: - bq += lambda q: q.filter(User.name == 'jack') + bq += lambda q: q.filter(User.name == "jack") else: - bq += lambda q: q.filter(User.name.like('%ed%')) + bq += lambda q: q.filter(User.name.like("%ed%")) if cond2: - ct = func.count(Address.id).label('count') - subq = sess.query( - ct, - Address.user_id).group_by(Address.user_id).\ - having(ct > 2).subquery() + ct = func.count(Address.id).label("count") + subq = ( + sess.query(ct, Address.user_id) + .group_by(Address.user_id) + .having(ct > 2) + .subquery() + ) bq += lambda q: q.join(subq) if cond2: if cond1: + def go(): result = bq(sess).all() eq_([], result) + self.assert_sql_count(testing.db, go, 1) else: + def go(): result = bq(sess).all() eq_(assert_result[1:2], result) + self.assert_sql_count(testing.db, go, 3) else: if cond1: + def go(): result = bq(sess).all() eq_(assert_result[0:1], result) + self.assert_sql_count(testing.db, go, 3) else: + def go(): result = bq(sess).all() eq_(assert_result[1:3], result) + self.assert_sql_count(testing.db, go, 3) sess.close() @@ -881,8 +890,9 @@ class ResultTest(BakedTest): Address = self.classes.Address assert_result = [ - User(id=7, - addresses=[Address(id=1, email_address='jack@bean.com')]) + User( + id=7, addresses=[Address(id=1, email_address="jack@bean.com")] + ) ] self.bakery = baked.bakery(size=3) @@ -891,11 +901,11 @@ class ResultTest(BakedTest): bq += lambda q: q.options(subqueryload(User.addresses)) bq += lambda q: q.order_by(User.id) - bq += lambda q: q.filter(User.name == bindparam('name')) + bq += lambda q: q.filter(User.name == bindparam("name")) sess = Session() def set_params(q): - return q.params(name='jack') + return q.params(name="jack") # test that the changes we make using with_post_criteria() # are also applied to the subqueryload query. 
@@ -907,17 +917,24 @@ class ResultTest(BakedTest): class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): - run_setup_mappers = 'each' + run_setup_mappers = "each" def _o2m_fixture(self, lazy="select", **kw): User = self.classes.User Address = self.classes.Address - mapper(User, self.tables.users, properties={ - 'addresses': relationship( - Address, order_by=self.tables.addresses.c.id, - lazy=lazy, **kw) - }) + mapper( + User, + self.tables.users, + properties={ + "addresses": relationship( + Address, + order_by=self.tables.addresses.c.id, + lazy=lazy, + **kw + ) + }, + ) mapper(Address, self.tables.addresses) return User, Address @@ -926,14 +943,23 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): Address = self.classes.Address Dingaling = self.classes.Dingaling - mapper(User, self.tables.users, properties={ - 'addresses': relationship( - Address, order_by=self.tables.addresses.c.id, - lazy=lazy, **kw) - }) - mapper(Address, self.tables.addresses, properties={ - "dingalings": relationship(Dingaling, lazy=lazy) - }) + mapper( + User, + self.tables.users, + properties={ + "addresses": relationship( + Address, + order_by=self.tables.addresses.c.id, + lazy=lazy, + **kw + ) + }, + ) + mapper( + Address, + self.tables.addresses, + properties={"dingalings": relationship(Dingaling, lazy=lazy)}, + ) mapper(Dingaling, self.tables.dingalings) return User, Address, Dingaling @@ -942,9 +968,11 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): Address = self.classes.Address mapper(User, self.tables.users) - mapper(Address, self.tables.addresses, properties={ - 'user': relationship(User) - }) + mapper( + Address, + self.tables.addresses, + properties={"user": relationship(User)}, + ) return User, Address def test_unsafe_unbound_option_cancels_bake(self): @@ -952,16 +980,24 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): class SubDingaling(Dingaling): pass + mapper(SubDingaling, None, inherits=Dingaling) lru = Address.dingalings.property._lazy_strategy._bakery( - lambda q: None)._bakery + lambda q: None + )._bakery l1 = len(lru) for i in range(5): sess = Session() - u1 = sess.query(User).options( - defaultload(User.addresses).lazyload( - Address.dingalings.of_type(aliased(SubDingaling)))).first() + u1 = ( + sess.query(User) + .options( + defaultload(User.addresses).lazyload( + Address.dingalings.of_type(aliased(SubDingaling)) + ) + ) + .first() + ) for ad in u1.addresses: ad.dingalings l2 = len(lru) @@ -973,16 +1009,26 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): class SubDingaling(Dingaling): pass + mapper(SubDingaling, None, inherits=Dingaling) lru = Address.dingalings.property._lazy_strategy._bakery( - lambda q: None)._bakery + lambda q: None + )._bakery l1 = len(lru) for i in range(5): sess = Session() - u1 = sess.query(User).options( - Load(User).defaultload(User.addresses).lazyload( - Address.dingalings.of_type(aliased(SubDingaling)))).first() + u1 = ( + sess.query(User) + .options( + Load(User) + .defaultload(User.addresses) + .lazyload( + Address.dingalings.of_type(aliased(SubDingaling)) + ) + ) + .first() + ) for ad in u1.addresses: ad.dingalings l2 = len(lru) @@ -993,13 +1039,18 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): User, Address, Dingaling = self._o2m_twolevel_fixture(lazy="joined") lru = Address.dingalings.property._lazy_strategy._bakery( - lambda q: None)._bakery + lambda q: None + )._bakery l1 = len(lru) for i in range(5): sess = Session() - u1 = sess.query(User).options( - 
defaultload(User.addresses).lazyload( - Address.dingalings)).first() + u1 = ( + sess.query(User) + .options( + defaultload(User.addresses).lazyload(Address.dingalings) + ) + .first() + ) for ad in u1.addresses: ad.dingalings l2 = len(lru) @@ -1010,13 +1061,20 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): User, Address, Dingaling = self._o2m_twolevel_fixture(lazy="joined") lru = Address.dingalings.property._lazy_strategy._bakery( - lambda q: None)._bakery + lambda q: None + )._bakery l1 = len(lru) for i in range(5): sess = Session() - u1 = sess.query(User).options( - Load(User).defaultload(User.addresses).lazyload( - Address.dingalings)).first() + u1 = ( + sess.query(User) + .options( + Load(User) + .defaultload(User.addresses) + .lazyload(Address.dingalings) + ) + .first() + ) for ad in u1.addresses: ad.dingalings l2 = len(lru) @@ -1044,15 +1102,11 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): real_compile_context = Query._compile_context def _my_compile_context(*arg, **kw): - if arg[0].column_descriptions[0]['entity'] is Address: + if arg[0].column_descriptions[0]["entity"] is Address: canary() return real_compile_context(*arg, **kw) - with mock.patch.object( - Query, - "_compile_context", - _my_compile_context - ): + with mock.patch.object(Query, "_compile_context", _my_compile_context): u1.addresses sess.expire(u1) @@ -1074,8 +1128,7 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): def _test_baked_lazy_loading(self, set_option): User, Address = self.classes.User, self.classes.Address - base_bq = self.bakery( - lambda s: s.query(User)) + base_bq = self.bakery(lambda s: s.query(User)) if set_option: base_bq += lambda q: q.options(lazyload(User.addresses)) @@ -1086,46 +1139,57 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): for i in range(4): for cond1, cond2 in itertools.product( - *[(False, True) for j in range(2)]): + *[(False, True) for j in range(2)] + ): bq = base_bq._clone() sess = Session() if cond1: - bq += lambda q: q.filter(User.name == 'jack') + bq += lambda q: q.filter(User.name == "jack") else: - bq += lambda q: q.filter(User.name.like('%ed%')) + bq += lambda q: q.filter(User.name.like("%ed%")) if cond2: - ct = func.count(Address.id).label('count') - subq = sess.query( - ct, - Address.user_id).group_by(Address.user_id).\ - having(ct > 2).subquery() + ct = func.count(Address.id).label("count") + subq = ( + sess.query(ct, Address.user_id) + .group_by(Address.user_id) + .having(ct > 2) + .subquery() + ) bq += lambda q: q.join(subq) if cond2: if cond1: + def go(): result = bq(sess).all() eq_([], result) + self.assert_sql_count(testing.db, go, 1) else: + def go(): result = bq(sess).all() eq_(assert_result[1:2], result) + self.assert_sql_count(testing.db, go, 2) else: if cond1: + def go(): result = bq(sess).all() eq_(assert_result[0:1], result) + self.assert_sql_count(testing.db, go, 2) else: + def go(): result = bq(sess).all() eq_(assert_result[1:3], result) + self.assert_sql_count(testing.db, go, 3) sess.close() @@ -1133,8 +1197,7 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): def test_baked_lazy_loading_m2o(self): User, Address = self._m2o_fixture() - base_bq = self.bakery( - lambda s: s.query(Address)) + base_bq = self.bakery(lambda s: s.query(Address)) base_bq += lambda q: q.options(lazyload(Address.user)) base_bq += lambda q: q.order_by(Address.id) @@ -1149,20 +1212,26 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): if cond1: bq += lambda q: q.filter( - Address.email_address == 
'jack@bean.com') + Address.email_address == "jack@bean.com" + ) else: bq += lambda q: q.filter( - Address.email_address.like('ed@%')) + Address.email_address.like("ed@%") + ) if cond1: + def go(): result = bq(sess).all() eq_(assert_result[0:1], result) + self.assert_sql_count(testing.db, go, 2) else: + def go(): result = bq(sess).all() eq_(assert_result[1:4], result) + self.assert_sql_count(testing.db, go, 2) sess.close() @@ -1175,26 +1244,34 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): Address = self.classes.Address mapper(User, self.tables.users) - mapper(Address, self.tables.addresses, properties={ - 'user': relationship( - User, lazy='joined', - backref=backref('addresses', lazy='baked_select') - ) - }) + mapper( + Address, + self.tables.addresses, + properties={ + "user": relationship( + User, + lazy="joined", + backref=backref("addresses", lazy="baked_select"), + ) + }, + ) sess = Session() u1 = sess.query(User).filter(User.id == 8).one() def go(): eq_(u1.addresses[0].user, u1) + self.assert_sql_execution( - testing.db, go, + testing.db, + go, CompiledSQL( "SELECT addresses.id AS addresses_id, addresses.user_id AS " "addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE :param_1 = " "addresses.user_id", - {'param_1': 8}) + {"param_1": 8}, + ), ) def test_useget_cancels_eager_propagated_present(self): @@ -1206,12 +1283,17 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): Address = self.classes.Address mapper(User, self.tables.users) - mapper(Address, self.tables.addresses, properties={ - 'user': relationship( - User, lazy='joined', - backref=backref('addresses', lazy='baked_select') - ) - }) + mapper( + Address, + self.tables.addresses, + properties={ + "user": relationship( + User, + lazy="joined", + backref=backref("addresses", lazy="baked_select"), + ) + }, + ) from sqlalchemy.orm.interfaces import MapperOption @@ -1219,19 +1301,26 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): propagate_to_loaders = True sess = Session() - u1 = sess.query(User).options(MyBogusOption()).filter(User.id == 8) \ + u1 = ( + sess.query(User) + .options(MyBogusOption()) + .filter(User.id == 8) .one() + ) def go(): eq_(u1.addresses[0].user, u1) + self.assert_sql_execution( - testing.db, go, + testing.db, + go, CompiledSQL( "SELECT addresses.id AS addresses_id, addresses.user_id AS " "addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE :param_1 = " "addresses.user_id", - {'param_1': 8}) + {"param_1": 8}, + ), ) # additional tests: @@ -1244,17 +1333,24 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): # assert that the integration style illustrated in the dogpile.cache # example works w/ baked class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): - run_setup_mappers = 'each' + run_setup_mappers = "each" def _o2m_fixture(self, lazy="select", **kw): User = self.classes.User Address = self.classes.Address - mapper(User, self.tables.users, properties={ - 'addresses': relationship( - Address, order_by=self.tables.addresses.c.id, - lazy=lazy, **kw) - }) + mapper( + User, + self.tables.users, + properties={ + "addresses": relationship( + Address, + order_by=self.tables.addresses.c.id, + lazy=lazy, + **kw + ) + }, + ) mapper(Address, self.tables.addresses) return User, Address @@ -1271,16 +1367,17 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): def __iter__(self): super_ = super(CachingQuery, self) - if hasattr(self, '_cache_key'): + 
if hasattr(self, "_cache_key"): return self.get_value( - createfunc=lambda: list(super_.__iter__())) + createfunc=lambda: list(super_.__iter__()) + ) else: return super_.__iter__() def _execute_and_instances(self, context): super_ = super(CachingQuery, self) - if context.query is not self and hasattr(self, '_cache_key'): + if context.query is not self and hasattr(self, "_cache_key"): return self.get_value( createfunc=lambda: list( super_._execute_and_instances(context) @@ -1323,17 +1420,11 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): q = sess.query(User).filter(User.id == 7).set_cache_key("user7") - eq_( - q.all(), - [User(id=7, addresses=[Address(id=1)])] - ) + eq_(q.all(), [User(id=7, addresses=[Address(id=1)])]) eq_(q.cache, {"user7": [User(id=7, addresses=[Address(id=1)])]}) - eq_( - q.all(), - [User(id=7, addresses=[Address(id=1)])] - ) + eq_(q.all(), [User(id=7, addresses=[Address(id=1)])]) def test_use_w_baked(self): User, Address = self._o2m_fixture() @@ -1342,22 +1433,15 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): q = sess._query_cls eq_(q.cache, {}) - base_bq = self.bakery( - lambda s: s.query(User)) + base_bq = self.bakery(lambda s: s.query(User)) base_bq += lambda q: q.filter(User.id == 7) base_bq += lambda q: q.set_cache_key("user7") - eq_( - base_bq(sess).all(), - [User(id=7, addresses=[Address(id=1)])] - ) + eq_(base_bq(sess).all(), [User(id=7, addresses=[Address(id=1)])]) eq_(q.cache, {"user7": [User(id=7, addresses=[Address(id=1)])]}) - eq_( - base_bq(sess).all(), - [User(id=7, addresses=[Address(id=1)])] - ) + eq_(base_bq(sess).all(), [User(id=7, addresses=[Address(id=1)])]) def test_plain_w_baked_lazyload(self): User, Address = self._o2m_fixture() diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index c23d5f2ac1..22ab1f1634 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -1,8 +1,13 @@ from sqlalchemy import * from sqlalchemy.types import TypeEngine -from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\ - FunctionElement, Select, \ - BindParameter, ColumnElement +from sqlalchemy.sql.expression import ( + ClauseElement, + ColumnClause, + FunctionElement, + Select, + BindParameter, + ColumnElement, +) from sqlalchemy.schema import DDLElement, CreateColumn, CreateTable from sqlalchemy.ext.compiler import compiles, deregister @@ -14,92 +19,87 @@ from sqlalchemy.testing import fixtures, AssertsCompiledSQL class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_column(self): - class MyThingy(ColumnClause): def __init__(self, arg=None): - super(MyThingy, self).__init__(arg or 'MYTHINGY!') + super(MyThingy, self).__init__(arg or "MYTHINGY!") @compiles(MyThingy) def visit_thingy(thingy, compiler, **kw): return ">>%s<<" % thingy.name self.assert_compile( - select([column('foo'), MyThingy()]), - "SELECT foo, >>MYTHINGY!<<" + select([column("foo"), MyThingy()]), "SELECT foo, >>MYTHINGY!<<" ) self.assert_compile( - select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), - "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" + select([MyThingy("x"), MyThingy("y")]).where(MyThingy() == 5), + "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1", ) def test_create_column_skip(self): @compiles(CreateColumn) def skip_xmin(element, compiler, **kw): - if element.element.name == 'xmin': + if element.element.name == "xmin": return None else: return compiler.visit_create_column(element, **kw) - t = 
Table('t', MetaData(), Column('a', Integer),
-                  Column('xmin', Integer),
-                  Column('c', Integer))
+        t = Table(
+            "t",
+            MetaData(),
+            Column("a", Integer),
+            Column("xmin", Integer),
+            Column("c", Integer),
+        )
         self.assert_compile(
-            CreateTable(t),
-            "CREATE TABLE t (a INTEGER, c INTEGER)"
+            CreateTable(t), "CREATE TABLE t (a INTEGER, c INTEGER)"
         )

     def test_types(self):
         class MyType(TypeEngine):
             pass

-        @compiles(MyType, 'sqlite')
+        @compiles(MyType, "sqlite")
         def visit_type(type, compiler, **kw):
             return "SQLITE_FOO"

-        @compiles(MyType, 'postgresql')
+        @compiles(MyType, "postgresql")
         def visit_type(type, compiler, **kw):
             return "POSTGRES_FOO"

         from sqlalchemy.dialects.sqlite import base as sqlite
         from sqlalchemy.dialects.postgresql import base as postgresql

-        self.assert_compile(
-            MyType(),
-            "SQLITE_FOO",
-            dialect=sqlite.dialect()
-        )
+        self.assert_compile(MyType(), "SQLITE_FOO", dialect=sqlite.dialect())

         self.assert_compile(
-            MyType(),
-            "POSTGRES_FOO",
-            dialect=postgresql.dialect()
+            MyType(), "POSTGRES_FOO", dialect=postgresql.dialect()
         )

     def test_stateful(self):
         class MyThingy(ColumnClause):
             def __init__(self):
-                super(MyThingy, self).__init__('MYTHINGY!')
+                super(MyThingy, self).__init__("MYTHINGY!")

         @compiles(MyThingy)
         def visit_thingy(thingy, compiler, **kw):
-            if not hasattr(compiler, 'counter'):
+            if not hasattr(compiler, "counter"):
                 compiler.counter = 0
             compiler.counter += 1
             return str(compiler.counter)

         self.assert_compile(
-            select([column('foo'), MyThingy()]).order_by(desc(MyThingy())),
-            "SELECT foo, 1 ORDER BY 2 DESC"
+            select([column("foo"), MyThingy()]).order_by(desc(MyThingy())),
+            "SELECT foo, 1 ORDER BY 2 DESC",
         )

         self.assert_compile(
             select([MyThingy(), MyThingy()]).where(MyThingy() == 5),
-            "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1"
+            "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1",
         )

     def test_callout_to_compiler(self):
@@ -112,34 +112,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
         def visit_insert_from_select(element, compiler, **kw):
             return "INSERT INTO %s (%s)" % (
                 compiler.process(element.table, asfrom=True),
-                compiler.process(element.select)
+                compiler.process(element.select),
             )

-        t1 = table("mytable", column('x'), column('y'), column('z'))
+        t1 = table("mytable", column("x"), column("y"), column("z"))
         self.assert_compile(
-            InsertFromSelect(
-                t1,
-                select([t1]).where(t1.c.x > 5)
-            ),
+            InsertFromSelect(t1, select([t1]).where(t1.c.x > 5)),
             "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z "
-            "FROM mytable WHERE mytable.x > :x_1)"
+            "FROM mytable WHERE mytable.x > :x_1)",
         )

     def test_no_default_but_has_a_visit(self):
         class MyThingy(ColumnClause):
             pass

-        @compiles(MyThingy, 'postgresql')
+        @compiles(MyThingy, "postgresql")
         def visit_thingy(thingy, compiler, **kw):
             return "mythingy"

-        eq_(str(MyThingy('x')), "x")
+        eq_(str(MyThingy("x")), "x")

     def test_no_default_has_no_visit(self):
         class MyThingy(TypeEngine):
             pass

-        @compiles(MyThingy, 'postgresql')
+        @compiles(MyThingy, "postgresql")
         def visit_thingy(thingy, compiler, **kw):
             return "mythingy"

@@ -147,14 +144,15 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
             exc.CompileError,
             "<class 'test.ext.test_compiler..*MyThingy'> "
             "construct has no default compilation handler.",
-            str, MyThingy()
+            str,
+            MyThingy(),
         )

     def test_no_default_message(self):
         class MyThingy(ClauseElement):
             pass

-        @compiles(MyThingy, 'postgresql')
+        @compiles(MyThingy, "postgresql")
         def visit_thingy(thingy, compiler, **kw):
             return "mythingy"

@@ -162,7 +160,8 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
             exc.CompileError,
             "<class 'test.ext.test_compiler..*MyThingy'> "
             "construct has no default compilation handler.",
-            str, MyThingy()
+            str,
+            MyThingy(),
         )

     def test_default_subclass(self):
@@ -176,9 +175,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
                 return "array"

         self.assert_compile(
-            MyArray(Integer),
-            "INTEGER[]",
-            dialect="postgresql"
+            MyArray(Integer), "INTEGER[]", dialect="postgresql"
         )

     def test_annotations(self):
@@ -187,35 +184,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):

         """
-        t1 = table('t1', column('c1'), column('c2'))
+        t1 = table("t1", column("c1"), column("c2"))

         dispatch = Select._compiler_dispatch
         try:
+
             @compiles(Select)
             def compile(element, compiler, **kw):
                 return "OVERRIDE"

             s1 = select([t1])
-            self.assert_compile(
-                s1, "OVERRIDE"
-            )
-            self.assert_compile(
-                s1._annotate({}),
-                "OVERRIDE"
-            )
+            self.assert_compile(s1, "OVERRIDE")
+            self.assert_compile(s1._annotate({}), "OVERRIDE")
         finally:
             Select._compiler_dispatch = dispatch
-            if hasattr(Select, '_compiler_dispatcher'):
+            if hasattr(Select, "_compiler_dispatcher"):
                 del Select._compiler_dispatcher

     def test_dialect_specific(self):
         class AddThingy(DDLElement):
-            __visit_name__ = 'add_thingy'
+            __visit_name__ = "add_thingy"

         class DropThingy(DDLElement):
-            __visit_name__ = 'drop_thingy'
+            __visit_name__ = "drop_thingy"

-        @compiles(AddThingy, 'sqlite')
+        @compiles(AddThingy, "sqlite")
         def visit_add_thingy(thingy, compiler, **kw):
             return "ADD SPECIAL SL THINGY"

@@ -232,21 +225,22 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(DropThingy(), "DROP THINGY")

         from sqlalchemy.dialects.sqlite import base
-        self.assert_compile(AddThingy(),
-                            "ADD SPECIAL SL THINGY",
-                            dialect=base.dialect())
-        self.assert_compile(DropThingy(),
-                            "DROP THINGY",
-                            dialect=base.dialect())
+
+        self.assert_compile(
+            AddThingy(), "ADD SPECIAL SL THINGY", dialect=base.dialect()
+        )

-        @compiles(DropThingy, 'sqlite')
+        self.assert_compile(
+            DropThingy(), "DROP THINGY", dialect=base.dialect()
+        )
+
+        @compiles(DropThingy, "sqlite")
         def visit_drop_thingy(thingy, compiler, **kw):
             return "DROP SPECIAL SL THINGY"

-        self.assert_compile(DropThingy(),
-                            "DROP SPECIAL SL THINGY",
-                            dialect=base.dialect())
+        self.assert_compile(
+            DropThingy(), "DROP SPECIAL SL THINGY", dialect=base.dialect()
+        )
         self.assert_compile(DropThingy(), "DROP THINGY")

@@ -260,19 +254,17 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
         def visit_myfunc(element, compiler, **kw):
             return "utcnow()"

-        @compiles(MyUtcFunction, 'postgresql')
+        @compiles(MyUtcFunction, "postgresql")
         def visit_myfunc(element, compiler, **kw):
             return "timezone('utc', current_timestamp)"

         self.assert_compile(
-            MyUtcFunction(),
-            "utcnow()",
-            use_default_dialect=True
+            MyUtcFunction(), "utcnow()", use_default_dialect=True
         )
         self.assert_compile(
             MyUtcFunction(),
             "timezone('utc', current_timestamp)",
-            dialect=postgresql.dialect()
+            dialect=postgresql.dialect(),
         )

     def test_function_calls_base(self):
@@ -280,13 +272,13 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):

         class greatest(FunctionElement):
             type = Numeric()
-            name = 'greatest'
+            name = "greatest"

         @compiles(greatest)
         def default_greatest(element, compiler, **kw):
             return compiler.visit_function(element)

-        @compiles(greatest, 'mssql')
+        @compiles(greatest, "mssql")
         def case_greatest(element, compiler, **kw):
             arg1, arg2 = list(element.clauses)
             return "CASE WHEN %s > %s THEN %s ELSE %s END" % (
@@ -297,26 +289,26 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
             )

         self.assert_compile(
-            greatest('a', 'b'),
-            'greatest(:greatest_1, :greatest_2)',
-            use_default_dialect=True
+            greatest("a", "b"),
+            "greatest(:greatest_1, :greatest_2)",
+            use_default_dialect=True,
         )
         self.assert_compile(
-            greatest('a', 'b'),
+            greatest("a", "b"),
             "CASE WHEN :greatest_1 > :greatest_2 "
             "THEN :greatest_1 ELSE :greatest_2 END",
-            dialect=mssql.dialect()
+            dialect=mssql.dialect(),
         )

     def test_subclasses_one(self):
         class Base(FunctionElement):
-            name = 'base'
+            name = "base"

         class Sub1(Base):
-            name = 'sub1'
+            name = "sub1"

         class Sub2(Base):
-            name = 'sub2'
+            name = "sub2"

         @compiles(Base)
         def visit_base(element, compiler, **kw):
@@ -328,31 +320,31 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):

         self.assert_compile(
             select([Sub1(), Sub2()]),
-            'SELECT FOOsub1, sub2',
-            use_default_dialect=True
+            "SELECT FOOsub1, sub2",
+            use_default_dialect=True,
         )

     def test_subclasses_two(self):
         class Base(FunctionElement):
-            name = 'base'
+            name = "base"

         class Sub1(Base):
-            name = 'sub1'
+            name = "sub1"

         @compiles(Base)
         def visit_base(element, compiler, **kw):
             return element.name

         class Sub2(Base):
-            name = 'sub2'
+            name = "sub2"

         class SubSub1(Sub1):
-            name = 'subsub1'
+            name = "subsub1"

         self.assert_compile(
             select([Sub1(), Sub2(), SubSub1()]),
-            'SELECT sub1, sub2, subsub1',
-            use_default_dialect=True
+            "SELECT sub1, sub2, subsub1",
+            use_default_dialect=True,
         )

         @compiles(Sub1)
@@ -361,42 +353,36 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):

         self.assert_compile(
             select([Sub1(), Sub2(), SubSub1()]),
-            'SELECT FOOsub1, sub2, FOOsubsub1',
-            use_default_dialect=True
+            "SELECT FOOsub1, sub2, FOOsubsub1",
+            use_default_dialect=True,
         )


 class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
     """Test replacement of default compilation on existing constructs."""
-    __dialect__ = 'default'
+
+    __dialect__ = "default"

     def teardown(self):
         for cls in (Select, BindParameter):
             deregister(cls)

     def test_select(self):
-        t1 = table('t1', column('c1'), column('c2'))
+        t1 = table("t1", column("c1"), column("c2"))

-        @compiles(Select, 'sqlite')
+        @compiles(Select, "sqlite")
         def compile(element, compiler, **kw):
             return "OVERRIDE"

         s1 = select([t1])
-        self.assert_compile(
-            s1, "SELECT t1.c1, t1.c2 FROM t1",
-        )
+        self.assert_compile(s1, "SELECT t1.c1, t1.c2 FROM t1")

         from sqlalchemy.dialects.sqlite import base as sqlite
-        self.assert_compile(
-            s1, "OVERRIDE",
-            dialect=sqlite.dialect()
-        )
+
+        self.assert_compile(s1, "OVERRIDE", dialect=sqlite.dialect())

     def test_binds_in_select(self):
-        t = table('t',
-                  column('a'),
-                  column('b'),
-                  column('c'))
+        t = table("t", column("a"), column("b"), column("c"))

         @compiles(BindParameter)
         def gen_bind(element, compiler, **kw):
@@ -405,14 +391,11 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             t.select().where(t.c.c == 5),
             "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
-            use_default_dialect=True
+            use_default_dialect=True,
         )

     def test_binds_in_dml(self):
-        t = table('t',
-                  column('a'),
-                  column('b'),
-                  column('c'))
+        t = table("t", column("a"), column("b"), column("c"))

         @compiles(BindParameter)
         def gen_bind(element, compiler, **kw):
@@ -421,6 +404,6 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             t.insert(),
             "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
-            {'a': 1, 'b': 2},
-            use_default_dialect=True
+            {"a": 1, "b": 2},
+            use_default_dialect=True,
         )
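The test_compiler.py changes above all revolve around the @compiles decorator from sqlalchemy.ext.compiler. As a minimal sketch of that API, separate from the patch itself and assuming SQLAlchemy 1.x — MyColumn and its handler functions are illustrative names, not part of the test suite:

    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql.expression import ColumnClause

    class MyColumn(ColumnClause):
        pass

    @compiles(MyColumn)
    def compile_mycolumn(element, compiler, **kw):
        # default rendering, used by any dialect without its own handler
        return "[%s]" % element.name

    @compiles(MyColumn, "postgresql")
    def compile_mycolumn_pg(element, compiler, **kw):
        # dialect-specific handler; takes precedence when compiling
        # against a PostgreSQL dialect
        return '"%s"' % element.name

Registering a second handler for a specific dialect name, as the tests above do with "sqlite" and "postgresql", is what drives assertions such as compiling MyType() to "SQLITE_FOO" under one dialect and "POSTGRES_FOO" under another.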
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py
index cf0613a522..277178af99 100644
--- a/test/ext/test_extendedattr.py
+++ b/test/ext/test_extendedattr.py
@@ -3,8 +3,11 @@ from sqlalchemy import util
 import sqlalchemy as sa
 from sqlalchemy.orm import class_mapper
 from sqlalchemy.orm import attributes
-from sqlalchemy.orm.attributes import set_attribute, \
-    get_attribute, del_attribute
+from sqlalchemy.orm.attributes import (
+    set_attribute,
+    get_attribute,
+    del_attribute,
+)
 from sqlalchemy.orm.instrumentation import is_instrumented
 from sqlalchemy.orm import clear_mappers
 from sqlalchemy.testing import fixtures
@@ -32,7 +35,6 @@ class _ExtBase(object):


 class MyTypesManager(instrumentation.InstrumentationManager):
-
     def instrument_attribute(self, class_, key, attr):
         pass

@@ -49,13 +51,13 @@ class MyTypesManager(instrumentation.InstrumentationManager):
         return instance._goofy_dict

     def initialize_instance_dict(self, class_, instance):
-        instance.__dict__['_goofy_dict'] = {}
+        instance.__dict__["_goofy_dict"] = {}

     def install_state(self, class_, instance, state):
-        instance.__dict__['_my_state'] = state
+        instance.__dict__["_my_state"] = state

     def state_getter(self, class_):
-        return lambda instance: instance.__dict__['_my_state']
+        return lambda instance: instance.__dict__["_my_state"]


 class MyListLike(list):
@@ -68,6 +70,7 @@ class MyListLike(list):
         if _sa_initiator is not False:
             self._sa_adapter.fire_append_event(item, _sa_initiator)
         list.append(self, item)
+
     append = _sa_appender

     def _sa_remover(self, item, _sa_initiator=None):
@@ -75,6 +78,7 @@ class MyListLike(list):
         if _sa_initiator is not False:
             self._sa_adapter.fire_remove_event(item, _sa_initiator)
         list.remove(self, item)
+
     remove = _sa_remover

@@ -82,14 +86,14 @@ MyBaseClass, MyClass = None, None


 class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
-
     @classmethod
     def setup_class(cls):
         global MyBaseClass, MyClass

         class MyBaseClass(object):
-            __sa_instrumentation_manager__ = \
+            __sa_instrumentation_manager__ = (
                 instrumentation.InstrumentationManager
+            )

         class MyClass(object):

@@ -99,11 +103,12 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
                 return MyTypesManager(cls)

             __sa_instrumentation_manager__ = staticmethod(
-                __sa_instrumentation_manager__)
+                __sa_instrumentation_manager__
+            )

             # This proves SA can handle a class with non-string dict keys
             if not util.pypy and not util.jython:
-                locals()[42] = 99   # Don't remove this line!
+                locals()[42] = 99  # Don't remove this line!

             def __init__(self, **kwargs):
                 for k in kwargs:
@@ -145,64 +150,76 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):

         register_class(User)
         attributes.register_attribute(
-            User, 'user_id', uselist=False, useobject=False)
+            User, "user_id", uselist=False, useobject=False
+        )
         attributes.register_attribute(
-            User, 'user_name', uselist=False, useobject=False)
+            User, "user_name", uselist=False, useobject=False
+        )
         attributes.register_attribute(
-            User, 'email_address', uselist=False, useobject=False)
+            User, "email_address", uselist=False, useobject=False
+        )

         u = User()
         u.user_id = 7
-        u.user_name = 'john'
-        u.email_address = 'lala@123.com'
+        u.user_name = "john"
+        u.email_address = "lala@123.com"
         eq_(
             u.__dict__,
             {
-                '_my_state': u._my_state,
-                '_goofy_dict': {
-                    'user_id': 7, 'user_name': 'john',
-                    'email_address': 'lala@123.com'}}
+                "_my_state": u._my_state,
+                "_goofy_dict": {
+                    "user_id": 7,
+                    "user_name": "john",
+                    "email_address": "lala@123.com",
+                },
+            },
         )

     def test_basic(self):
         for base in (object, MyBaseClass, MyClass):
+
             class User(base):
                 pass

             register_class(User)
             attributes.register_attribute(
-                User, 'user_id', uselist=False, useobject=False)
+                User, "user_id", uselist=False, useobject=False
+            )
             attributes.register_attribute(
-                User, 'user_name', uselist=False, useobject=False)
+                User, "user_name", uselist=False, useobject=False
+            )
             attributes.register_attribute(
-                User, 'email_address', uselist=False, useobject=False)
+                User, "email_address", uselist=False, useobject=False
+            )

             u = User()
             u.user_id = 7
-            u.user_name = 'john'
-            u.email_address = 'lala@123.com'
+            u.user_name = "john"
+            u.email_address = "lala@123.com"
             eq_(u.user_id, 7)
             eq_(u.user_name, "john")
             eq_(u.email_address, "lala@123.com")

             attributes.instance_state(u)._commit_all(
-                attributes.instance_dict(u))
+                attributes.instance_dict(u)
+            )
             eq_(u.user_id, 7)
             eq_(u.user_name, "john")
             eq_(u.email_address, "lala@123.com")

-            u.user_name = 'heythere'
-            u.email_address = 'foo@bar.com'
+            u.user_name = "heythere"
+            u.email_address = "foo@bar.com"
             eq_(u.user_id, 7)
             eq_(u.user_name, "heythere")
             eq_(u.email_address, "foo@bar.com")

     def test_deferred(self):
         for base in (object, MyBaseClass, MyClass):
+
             class Foo(base):
                 pass

-            data = {'a': 'this is a', 'b': 12}
+            data = {"a": "this is a", "b": 12}

             def loader(state, keys):
                 for k in keys:
@@ -212,37 +229,47 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             manager = register_class(Foo)
             manager.deferred_scalar_loader = loader
             attributes.register_attribute(
-                Foo, 'a', uselist=False, useobject=False)
+                Foo, "a", uselist=False, useobject=False
+            )
             attributes.register_attribute(
-                Foo, 'b', uselist=False, useobject=False)
+                Foo, "b", uselist=False, useobject=False
+            )

             if base is object:
-                assert Foo not in \
-                    instrumentation._instrumentation_factory._state_finders
+                assert (
+                    Foo
+                    not in instrumentation._instrumentation_factory._state_finders
+                )
             else:
-                assert Foo in \
-                    instrumentation._instrumentation_factory._state_finders
+                assert (
+                    Foo
+                    in instrumentation._instrumentation_factory._state_finders
+                )

             f = Foo()
             attributes.instance_state(f)._expire(
-                attributes.instance_dict(f), set())
+                attributes.instance_dict(f), set()
+            )
             eq_(f.a, "this is a")
             eq_(f.b, 12)

             f.a = "this is some new a"
             attributes.instance_state(f)._expire(
-                attributes.instance_dict(f), set())
+                attributes.instance_dict(f), set()
+            )
             eq_(f.a, "this is a")
             eq_(f.b, 12)

             attributes.instance_state(f)._expire(
-                attributes.instance_dict(f), set())
+                attributes.instance_dict(f), set()
+            )
             f.a = "this is another new a"
             eq_(f.a, "this is another new a")
             eq_(f.b, 12)

             attributes.instance_state(f)._expire(
-                attributes.instance_dict(f), set())
+                attributes.instance_dict(f), set()
+            )
             eq_(f.a, "this is a")
             eq_(f.b, 12)

@@ -251,7 +278,8 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             eq_(f.b, 12)

             attributes.instance_state(f)._commit_all(
-                attributes.instance_dict(f))
+                attributes.instance_dict(f)
+            )
             eq_(f.a, None)
             eq_(f.b, 12)

@@ -259,6 +287,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
         """tests that attributes are polymorphic"""

         for base in (object, MyBaseClass, MyClass):
+
             class Foo(base):
                 pass

@@ -276,25 +305,27 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             def func3(state, passive):
                 return "this is the shared attr"

-            attributes.register_attribute(Foo, 'element',
-                                          uselist=False, callable_=func1,
-                                          useobject=True)
-            attributes.register_attribute(Foo, 'element2',
-                                          uselist=False, callable_=func3,
-                                          useobject=True)
-            attributes.register_attribute(Bar, 'element',
-                                          uselist=False, callable_=func2,
-                                          useobject=True)
+
+            attributes.register_attribute(
+                Foo, "element", uselist=False, callable_=func1, useobject=True
+            )
+            attributes.register_attribute(
+                Foo, "element2", uselist=False, callable_=func3, useobject=True
+            )
+            attributes.register_attribute(
+                Bar, "element", uselist=False, callable_=func2, useobject=True
+            )

             x = Foo()
             y = Bar()
-            assert x.element == 'this is the foo attr'
-            assert y.element == 'this is the bar attr', y.element
-            assert x.element2 == 'this is the shared attr'
-            assert y.element2 == 'this is the shared attr'
+            assert x.element == "this is the foo attr"
+            assert y.element == "this is the bar attr", y.element
+            assert x.element2 == "this is the shared attr"
+            assert y.element2 == "this is the shared attr"

     def test_collection_with_backref(self):
         for base in (object, MyBaseClass, MyClass):
+
             class Post(base):
                 pass

@@ -304,11 +335,21 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             register_class(Post)
             register_class(Blog)
             attributes.register_attribute(
-                Post, 'blog', uselist=False,
-                backref='posts', trackparent=True, useobject=True)
+                Post,
+                "blog",
+                uselist=False,
+                backref="posts",
+                trackparent=True,
+                useobject=True,
+            )
             attributes.register_attribute(
-                Blog, 'posts', uselist=True,
-                backref='blog', trackparent=True, useobject=True)
+                Blog,
+                "posts",
+                uselist=True,
+                backref="blog",
+                trackparent=True,
+                useobject=True,
+            )
             b = Blog()
             (p1, p2, p3) = (Post(), Post(), Post())
             b.posts.append(p1)
@@ -334,6 +375,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):

     def test_history(self):
         for base in (object, MyBaseClass, MyClass):
+
             class Foo(base):
                 pass

@@ -343,73 +385,93 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             register_class(Foo)
             register_class(Bar)
             attributes.register_attribute(
-                Foo, "name", uselist=False, useobject=False)
+                Foo, "name", uselist=False, useobject=False
+            )
             attributes.register_attribute(
-                Foo, "bars", uselist=True, trackparent=True, useobject=True)
+                Foo, "bars", uselist=True, trackparent=True, useobject=True
+            )
             attributes.register_attribute(
-                Bar, "name", uselist=False, useobject=False)
+                Bar, "name", uselist=False, useobject=False
+            )

             f1 = Foo()
-            f1.name = 'f1'
+            f1.name = "f1"

             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1), 'name'),
-                (['f1'], (), ()))
+                    attributes.instance_state(f1), "name"
+                ),
+                (["f1"], (), ()),
+            )

             b1 = Bar()
-            b1.name = 'b1'
+            b1.name = "b1"
             f1.bars.append(b1)
             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1), 'bars'),
-                ([b1], [], []))
+                    attributes.instance_state(f1), "bars"
+                ),
+                ([b1], [], []),
+            )

             attributes.instance_state(f1)._commit_all(
-                attributes.instance_dict(f1))
+                attributes.instance_dict(f1)
+            )
             attributes.instance_state(b1)._commit_all(
-                attributes.instance_dict(b1))
+                attributes.instance_dict(b1)
+            )
             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1),
-                    'name'),
-                ((), ['f1'], ()))
+                    attributes.instance_state(f1), "name"
+                ),
+                ((), ["f1"], ()),
+            )
             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1),
-                    'bars'),
-                ((), [b1], ()))
+                    attributes.instance_state(f1), "bars"
+                ),
+                ((), [b1], ()),
+            )

-            f1.name = 'f1mod'
+            f1.name = "f1mod"
             b2 = Bar()
-            b2.name = 'b2'
+            b2.name = "b2"
             f1.bars.append(b2)
             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1), 'name'),
-                (['f1mod'], (), ['f1']))
+                    attributes.instance_state(f1), "name"
+                ),
+                (["f1mod"], (), ["f1"]),
+            )
             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1), 'bars'),
-                ([b2], [b1], []))
+                    attributes.instance_state(f1), "bars"
+                ),
+                ([b2], [b1], []),
+            )
             f1.bars.remove(b1)
             eq_(
                 attributes.get_state_history(
-                    attributes.instance_state(f1), 'bars'),
-                ([b2], [], [b1]))
+                    attributes.instance_state(f1), "bars"
+                ),
+                ([b2], [], [b1]),
+            )

     def test_null_instrumentation(self):
         class Foo(MyBaseClass):
             pass
+
         register_class(Foo)
         attributes.register_attribute(
-            Foo, "name", uselist=False, useobject=False)
+            Foo, "name", uselist=False, useobject=False
+        )
         attributes.register_attribute(
-            Foo, "bars", uselist=True, trackparent=True, useobject=True)
+            Foo, "bars", uselist=True, trackparent=True, useobject=True
+        )

-        assert Foo.name == attributes.manager_of_class(Foo)['name']
-        assert Foo.bars == attributes.manager_of_class(Foo)['bars']
+        assert Foo.name == attributes.manager_of_class(Foo)["name"]
+        assert Foo.bars == attributes.manager_of_class(Foo)["bars"]

     def test_alternate_finders(self):
         """Ensure the generic finder front-end deals with edge cases."""
@@ -428,10 +490,10 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
         assert instrumentation.manager_of_class(None) is None
         assert attributes.instance_state(k) is not None
-        assert_raises((AttributeError, KeyError),
-                      attributes.instance_state, u)
-        assert_raises((AttributeError, KeyError),
-                      attributes.instance_state, None)
+        assert_raises((AttributeError, KeyError), attributes.instance_state, u)
+        assert_raises(
+            (AttributeError, KeyError), attributes.instance_state, None
+        )

     def test_unmapped_not_type_error(self):
         """extension version of the same test in test_mapper.
@@ -441,7 +503,8 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
         assert_raises_message(
             sa.exc.ArgumentError,
             "Class object expected, got '5'.",
-            class_mapper, 5
+            class_mapper,
+            5,
         )

     def test_unmapped_not_type_error_iter_ok(self):
@@ -452,31 +515,28 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
         assert_raises_message(
             sa.exc.ArgumentError,
             r"Class object expected, got '\(5, 6\)'.",
-            class_mapper, (5, 6)
+            class_mapper,
+            (5, 6),
         )


 class FinderTest(_ExtBase, fixtures.ORMTest):
-
     def test_standard(self):
         class A(object):
             pass

         register_class(A)

-        eq_(
-            type(manager_of_class(A)),
-            instrumentation.ClassManager)
+        eq_(type(manager_of_class(A)), instrumentation.ClassManager)

     def test_nativeext_interfaceexact(self):
         class A(object):
-            __sa_instrumentation_manager__ = \
+            __sa_instrumentation_manager__ = (
                 instrumentation.InstrumentationManager
+            )

         register_class(A)
-        ne_(
-            type(manager_of_class(A)),
-            instrumentation.ClassManager)
+        ne_(type(manager_of_class(A)), instrumentation.ClassManager)

     def test_nativeext_submanager(self):
         class Mine(instrumentation.ClassManager):
@@ -514,60 +574,72 @@ class FinderTest(_ExtBase, fixtures.ORMTest):
         instrumentation.instrumentation_finders.insert(0, find)
         register_class(A)
-        eq_(
-            type(manager_of_class(A)),
-            instrumentation.ClassManager)
+        eq_(type(manager_of_class(A)), instrumentation.ClassManager)


 class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
-
     def test_none(self):
         class A(object):
             pass
+
         register_class(A)

-        def mgr_factory(cls): return instrumentation.ClassManager(cls)
+        def mgr_factory(cls):
+            return instrumentation.ClassManager(cls)

         class B(object):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+
         register_class(B)

         class C(object):
             __sa_instrumentation_manager__ = instrumentation.ClassManager
+
         register_class(C)

     def test_single_down(self):
         class A(object):
             pass
+
         register_class(A)

-        def mgr_factory(cls): return instrumentation.ClassManager(cls)
+        def mgr_factory(cls):
+            return instrumentation.ClassManager(cls)

         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)

         assert_raises_message(
-            TypeError, "multiple instrumentation implementations",
-            register_class, B)
+            TypeError,
+            "multiple instrumentation implementations",
+            register_class,
+            B,
+        )

     def test_single_up(self):
-        class A(object): pass
+        class A(object):
+            pass
+
         # delay registration

-        def mgr_factory(cls): return instrumentation.ClassManager(cls)
+        def mgr_factory(cls):
+            return instrumentation.ClassManager(cls)

         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+
         register_class(B)

         assert_raises_message(
-            TypeError, "multiple instrumentation implementations",
-            register_class, A)
+            TypeError,
+            "multiple instrumentation implementations",
+            register_class,
+            A,
+        )

     def test_diamond_b1(self):
-        def mgr_factory(cls): return instrumentation.ClassManager(cls)
+        def mgr_factory(cls):
+            return instrumentation.ClassManager(cls)

         class A(object):
             pass
@@ -582,11 +654,15 @@ class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
             pass

         assert_raises_message(
-            TypeError, "multiple instrumentation implementations",
-            register_class, B1)
+            TypeError,
+            "multiple instrumentation implementations",
+            register_class,
+            B1,
+        )

     def test_diamond_b2(self):
-        def mgr_factory(cls): return instrumentation.ClassManager(cls)
+        def mgr_factory(cls):
+            return instrumentation.ClassManager(cls)

         class A(object):
             pass
@@ -602,11 +678,15 @@ class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
         register_class(B2)
         assert_raises_message(
-            TypeError, "multiple instrumentation implementations",
-            register_class, B1)
+            TypeError,
+            "multiple instrumentation implementations",
+            register_class,
+            B1,
+        )

     def test_diamond_c_b(self):
-        def mgr_factory(cls): return instrumentation.ClassManager(cls)
+        def mgr_factory(cls):
+            return instrumentation.ClassManager(cls)

         class A(object):
             pass
@@ -623,8 +703,11 @@ class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
         register_class(C)

         assert_raises_message(
-            TypeError, "multiple instrumentation implementations",
-            register_class, B1)
+            TypeError,
+            "multiple instrumentation implementations",
+            register_class,
+            B1,
+        )


 class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):
@@ -640,7 +723,8 @@ class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):
             dispatch = event.dispatcher(MyEvents)

         instrumentation.instrumentation_finders.insert(
-            0, lambda cls: MyClassManager)
+            0, lambda cls: MyClassManager
+        )

         class A(object):
             pass
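The test_extendedattr.py changes above exercise custom class instrumentation. As a rough sketch of that extension point (an illustration under assumed SQLAlchemy 1.x imports, not part of this patch), a class opts in by naming an InstrumentationManager subclass, which is exactly the hook MyTypesManager in the diff implements:

    from sqlalchemy.ext.instrumentation import InstrumentationManager

    class MyManager(InstrumentationManager):
        # store the ORM's per-instance state under a custom dict key
        def install_state(self, class_, instance, state):
            instance.__dict__["_my_state"] = state

        def state_getter(self, class_):
            return lambda instance: instance.__dict__["_my_state"]

    class MyEntity(object):
        # the attribute name the ORM looks for when instrumenting
        __sa_instrumentation_manager__ = MyManager
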
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index 164162a5df..def82faf7e 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py
@@ -16,7 +16,7 @@ from sqlalchemy import testing

 class ShardTest(object):
     __skip_if__ = (lambda: util.win32,)
-    __requires__ = 'sqlite',
+    __requires__ = ("sqlite",)

     schema = None

@@ -26,8 +26,7 @@ class ShardTest(object):
         db1, db2, db3, db4 = self._init_dbs()

         meta = MetaData()
-        ids = Table('ids', meta,
-                    Column('nextid', Integer, nullable=False))
+        ids = Table("ids", meta, Column("nextid", Integer, nullable=False))

         def id_generator(ctx):
             # in reality, might want to use a separate transaction for this.
@@ -38,24 +37,23 @@ class ShardTest(object):
             return nextid

         weather_locations = Table(
-            "weather_locations", meta,
-            Column('id', Integer, primary_key=True, default=id_generator),
-            Column('continent', String(30), nullable=False),
-            Column('city', String(50), nullable=False),
-            schema=self.schema
+            "weather_locations",
+            meta,
+            Column("id", Integer, primary_key=True, default=id_generator),
+            Column("continent", String(30), nullable=False),
+            Column("city", String(50), nullable=False),
+            schema=self.schema,
         )

         weather_reports = Table(
-            'weather_reports',
+            "weather_reports",
             meta,
-            Column('id', Integer, primary_key=True),
-            Column('location_id', Integer,
-                   ForeignKey(weather_locations.c.id)),
-            Column('temperature', Float),
-            Column('report_time', DateTime,
-                   default=datetime.datetime.now),
-            schema=self.schema
-        )
+            Column("id", Integer, primary_key=True),
+            Column("location_id", Integer, ForeignKey(weather_locations.c.id)),
+            Column("temperature", Float),
+            Column("report_time", DateTime, default=datetime.datetime.now),
+            schema=self.schema,
+        )

         for db in (db1, db2, db3, db4):
             meta.create_all(db)
@@ -69,11 +67,11 @@ class ShardTest(object):
     def setup_session(cls):
         global create_session
         shard_lookup = {
-            'North America': 'north_america',
-            'Asia': 'asia',
-            'Europe': 'europe',
-            'South America': 'south_america',
-        }
+            "North America": "north_america",
+            "Asia": "asia",
+            "Europe": "europe",
+            "South America": "south_america",
+        }

         def shard_chooser(mapper, instance, clause=None):
             if isinstance(instance, WeatherLocation):
@@ -82,16 +80,16 @@ class ShardTest(object):
                 return shard_chooser(mapper, instance.location)

         def id_chooser(query, ident):
-            return ['north_america', 'asia', 'europe', 'south_america']
+            return ["north_america", "asia", "europe", "south_america"]

         def query_chooser(query):
             ids = []

             class FindContinent(sql.ClauseVisitor):
-
                 def visit_binary(self, binary):
                     if binary.left.shares_lineage(
-                            weather_locations.c.continent):
+                        weather_locations.c.continent
+                    ):
                         if binary.operator == operators.eq:
                             ids.append(shard_lookup[binary.right.value])
                         elif binary.operator == operators.in_op:
@@ -101,20 +99,24 @@ class ShardTest(object):
             if query._criterion is not None:
                 FindContinent().traverse(query._criterion)
             if len(ids) == 0:
-                return ['north_america', 'asia', 'europe',
-                        'south_america']
+                return ["north_america", "asia", "europe", "south_america"]
             else:
                 return ids

-        create_session = sessionmaker(class_=ShardedSession,
-                                      autoflush=True, autocommit=False)
-        create_session.configure(shards={
-            'north_america': db1,
-            'asia': db2,
-            'europe': db3,
-            'south_america': db4,
-        }, shard_chooser=shard_chooser, id_chooser=id_chooser,
-            query_chooser=query_chooser)
+        create_session = sessionmaker(
+            class_=ShardedSession, autoflush=True, autocommit=False
+        )
+        create_session.configure(
+            shards={
+                "north_america": db1,
+                "asia": db2,
+                "europe": db3,
+                "south_america": db4,
+            },
+            shard_chooser=shard_chooser,
+            id_chooser=id_chooser,
+            query_chooser=query_chooser,
+        )

     @classmethod
     def setup_mappers(cls):
@@ -131,34 +133,30 @@ class ShardTest(object):
         if id_:
             self.id = id_

-        mapper(WeatherLocation, weather_locations, properties={
-            'reports': relationship(Report, backref='location'),
-            'city': deferred(weather_locations.c.city),
-        })
+        mapper(
+            WeatherLocation,
+            weather_locations,
+            properties={
+                "reports": relationship(Report, backref="location"),
+                "city": deferred(weather_locations.c.city),
+            },
+        )

         mapper(Report, weather_reports)

     def _fixture_data(self):
-        tokyo = WeatherLocation('Asia', 'Tokyo')
-        newyork = WeatherLocation('North America', 'New York')
-        toronto = WeatherLocation('North America', 'Toronto')
-        london = WeatherLocation('Europe', 'London')
-        dublin = WeatherLocation('Europe', 'Dublin')
-        brasilia = WeatherLocation('South America', 'Brasila')
-        quito = WeatherLocation('South America', 'Quito')
+        tokyo = WeatherLocation("Asia", "Tokyo")
+        newyork = WeatherLocation("North America", "New York")
+        toronto = WeatherLocation("North America", "Toronto")
+        london = WeatherLocation("Europe", "London")
+        dublin = WeatherLocation("Europe", "Dublin")
+        brasilia = WeatherLocation("South America", "Brasila")
+        quito = WeatherLocation("South America", "Quito")
         tokyo.reports.append(Report(80.0, id_=1))
         newyork.reports.append(Report(75, id_=1))
         quito.reports.append(Report(85))
         sess = create_session()
-        for c in [
-            tokyo,
-            newyork,
-            toronto,
-            london,
-            dublin,
-            brasilia,
-            quito,
-        ]:
+        for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
             sess.add(c)
         sess.flush()
@@ -176,34 +174,51 @@ class ShardTest(object):
         tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
         tokyo.city  # reload 'city' attribute on tokyo
         sess.expire_all()

-        eq_(db2.execute(weather_locations.select()).fetchall(), [(1,
-            'Asia', 'Tokyo')])
-        eq_(db1.execute(weather_locations.select()).fetchall(), [(2,
-            'North America', 'New York'), (3, 'North America', 'Toronto')])
-        eq_(sess.execute(weather_locations.select(), shard_id='asia')
-            .fetchall(), [(1, 'Asia', 'Tokyo')])
+        eq_(
+            db2.execute(weather_locations.select()).fetchall(),
+            [(1, "Asia", "Tokyo")],
+        )
+        eq_(
+            db1.execute(weather_locations.select()).fetchall(),
+            [
+                (2, "North America", "New York"),
+                (3, "North America", "Toronto"),
+            ],
+        )
+        eq_(
+            sess.execute(
+                weather_locations.select(), shard_id="asia"
+            ).fetchall(),
+            [(1, "Asia", "Tokyo")],
+        )

         t = sess.query(WeatherLocation).get(tokyo.id)
         eq_(t.city, tokyo.city)
         eq_(t.reports[0].temperature, 80.0)

-        north_american_cities = \
-            sess.query(WeatherLocation).filter(
-                WeatherLocation.continent == 'North America')
-        eq_(set([c.city for c in north_american_cities]),
-            set(['New York', 'Toronto']))
+        north_american_cities = sess.query(WeatherLocation).filter(
+            WeatherLocation.continent == "North America"
+        )
+        eq_(
+            set([c.city for c in north_american_cities]),
+            set(["New York", "Toronto"]),
+        )

-        asia_and_europe = \
-            sess.query(WeatherLocation).filter(
-                WeatherLocation.continent.in_(['Europe', 'Asia']))
-        eq_(set([c.city for c in asia_and_europe]), set(['Tokyo',
-            'London', 'Dublin']))
+        asia_and_europe = sess.query(WeatherLocation).filter(
+            WeatherLocation.continent.in_(["Europe", "Asia"])
+        )
+        eq_(
+            set([c.city for c in asia_and_europe]),
+            set(["Tokyo", "London", "Dublin"]),
+        )

         # inspect the shard token stored with each instance
         eq_(
             set(inspect(c).key[2] for c in asia_and_europe),
-            set(['europe', 'asia']))
+            set(["europe", "asia"]),
+        )

         eq_(
             set(inspect(c).identity_token for c in asia_and_europe),
-            set(['europe', 'asia']))
+            set(["europe", "asia"]),
+        )

         newyork = sess.query(WeatherLocation).filter_by(city="New York").one()
         newyork_report = newyork.reports[0]
@@ -212,12 +227,9 @@ class ShardTest(object):
         # same primary key, two identity keys
         eq_(
             inspect(newyork_report).identity_key,
-            (Report, (1, ), "north_america")
-        )
-        eq_(
-            inspect(tokyo_report).identity_key,
-            (Report, (1, ), "asia")
+            (Report, (1,), "north_america"),
         )
+        eq_(inspect(tokyo_report).identity_key, (Report, (1,), "asia"))

         # the token representing the originating shard is available
         eq_(inspect(newyork_report).identity_token, "north_america")
@@ -238,7 +250,7 @@ class ShardTest(object):

         t = bq(sess).get(tokyo.id)
         eq_(t.city, tokyo.city)

-        eq_(inspect(t).key[2], 'asia')
+        eq_(inspect(t).key[2], "asia")

     def test_get_baked_query_shard_id(self):
         sess = self._fixture_data()
@@ -252,11 +264,14 @@ class ShardTest(object):
         bakery = BakedQuery.bakery()
         bq = bakery(lambda session: session.query(WeatherLocation))

-        t = bq(sess).with_post_criteria(
-            lambda q: q.set_shard("asia")).get(tokyo.id)
+        t = (
+            bq(sess)
+            .with_post_criteria(lambda q: q.set_shard("asia"))
+            .get(tokyo.id)
+        )
         eq_(t.city, tokyo.city)

-        eq_(inspect(t).key[2], 'asia')
+        eq_(inspect(t).key[2], "asia")

     def test_filter_baked_query_shard_id(self):
         sess = self._fixture_data()
@@ -269,10 +284,10 @@ class ShardTest(object):

         bakery = BakedQuery.bakery()

-        bq = bakery(lambda session: session.query(WeatherLocation)).\
-            with_criteria(lambda q: q.filter_by(id=tokyo.id))
-        t = bq(sess).with_post_criteria(
-            lambda q: q.set_shard("asia")).one()
+        bq = bakery(
+            lambda session: session.query(WeatherLocation)
+        ).with_criteria(lambda q: q.filter_by(id=tokyo.id))
+        t = bq(sess).with_post_criteria(lambda q: q.set_shard("asia")).one()
         eq_(t.city, tokyo.city)

     def test_shard_id_event(self):
@@ -284,15 +299,25 @@ class ShardTest(object):
         event.listen(WeatherLocation, "load", load)
         sess = self._fixture_data()

-        tokyo = sess.query(WeatherLocation).\
-            filter_by(city="Tokyo").set_shard("asia").one()
+        tokyo = (
+            sess.query(WeatherLocation)
+            .filter_by(city="Tokyo")
+            .set_shard("asia")
+            .one()
+        )

         sess.query(WeatherLocation).all()
         eq_(
             canary,
-            ['asia', 'north_america', 'north_america',
-             'europe', 'europe', 'south_america',
-             'south_america']
+            [
+                "asia",
+                "north_america",
+                "north_america",
+                "europe",
+                "europe",
+                "south_america",
+                "south_america",
+            ],
         )

     def test_baked_mix(self):
@@ -311,8 +336,9 @@ class ShardTest(object):
             t = bq(sess).get(tokyo.id)
             return t

-        Sess = sessionmaker(class_=Session, bind=db2,
-                            autoflush=True, autocommit=False)
+        Sess = sessionmaker(
+            class_=Session, bind=db2, autoflush=True, autocommit=False
+        )
         sess2 = Sess()

         t = get_tokyo(sess)
@@ -326,45 +352,35 @@ class ShardTest(object):

         eq_(
             set(row.temperature for row in sess.query(Report.temperature)),
-            {80.0, 75.0, 85.0}
+            {80.0, 75.0, 85.0},
         )

         temps = sess.query(Report).all()
-        eq_(
-            set(t.temperature for t in temps),
-            {80.0, 75.0, 85.0}
-        )
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})

-        sess.query(Report).filter(
-            Report.temperature >= 80).update(
-            {"temperature": Report.temperature + 6})
+        sess.query(Report).filter(Report.temperature >= 80).update(
+            {"temperature": Report.temperature + 6}
+        )

         eq_(
             set(row.temperature for row in sess.query(Report.temperature)),
-            {86.0, 75.0, 91.0}
+            {86.0, 75.0, 91.0},
         )

         # test synchronize session as well
-        eq_(
-            set(t.temperature for t in temps),
-            {86.0, 75.0, 91.0}
-        )
+        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})

     def test_bulk_delete(self):
         sess = self._fixture_data()

         temps = sess.query(Report).all()
-        eq_(
-            set(t.temperature for t in temps),
-            {80.0, 75.0, 85.0}
-        )
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})

-        sess.query(Report).filter(
-            Report.temperature >= 80).delete()
+        sess.query(Report).filter(Report.temperature >= 80).delete()

         eq_(
             set(row.temperature for row in sess.query(Report.temperature)),
-            {75.0}
+            {75.0},
         )

         # test synchronize session as well
@@ -372,21 +388,24 @@ class ShardTest(object):
             assert inspect(t).deleted is (t.temperature >= 80)

-
 from sqlalchemy.testing import provision


 class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
     def _init_dbs(self):
         db1 = testing_engine(
-            'sqlite:///shard1_%s.db' % provision.FOLLOWER_IDENT,
-            options=dict(pool_threadlocal=True))
+            "sqlite:///shard1_%s.db" % provision.FOLLOWER_IDENT,
+            options=dict(pool_threadlocal=True),
+        )
         db2 = testing_engine(
-            'sqlite:///shard2_%s.db' % provision.FOLLOWER_IDENT)
+            "sqlite:///shard2_%s.db" % provision.FOLLOWER_IDENT
+        )
         db3 = testing_engine(
-            'sqlite:///shard3_%s.db' % provision.FOLLOWER_IDENT)
+            "sqlite:///shard3_%s.db" % provision.FOLLOWER_IDENT
+        )
         db4 = testing_engine(
-            'sqlite:///shard4_%s.db' % provision.FOLLOWER_IDENT)
+            "sqlite:///shard4_%s.db" % provision.FOLLOWER_IDENT
+        )

         self.dbs = [db1, db2, db3, db4]
         return self.dbs
@@ -404,8 +423,9 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):
     schema = "changeme"

     def _init_dbs(self):
-        db1 = testing_engine('sqlite://', options={"execution_options":
-                             {"shard_id": "shard1"}})
+        db1 = testing_engine(
+            "sqlite://", options={"execution_options": {"shard_id": "shard1"}}
+        )
         db2 = db1.execution_options(shard_id="shard2")
         db3 = db1.execution_options(shard_id="shard3")
         db4 = db1.execution_options(shard_id="shard4")
@@ -414,7 +434,7 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):

         @event.listens_for(db1, "before_cursor_execute", retval=True)
         def _switch_shard(conn, cursor, stmt, params, context, executemany):
-            shard_id = conn._execution_options['shard_id']
+            shard_id = conn._execution_options["shard_id"]
             # because SQLite can't just give us a "use" statement, we have
             # to use the schema hack to locate table names
             if shard_id:
@@ -428,26 +448,27 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):

 class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
     """test #4175
     """
+
     @classmethod
     def setup_classes(cls):
         Base = cls.DeclarativeBasic

         class Book(Base):
-            __tablename__ = 'book'
+            __tablename__ = "book"
             id = Column(Integer, primary_key=True)
-            pages = relationship('Page')
+            pages = relationship("Page")

         class Page(Base):
-            __tablename__ = 'page'
+            __tablename__ = "page"
             id = Column(Integer, primary_key=True)
-            book_id = Column(ForeignKey('book.id'))
+            book_id = Column(ForeignKey("book.id"))

     def test_selectinload_query(self):
         session = ShardedSession(
             shards={"test": testing.db},
-            shard_chooser=lambda *args: 'test',
+            shard_chooser=lambda *args: "test",
             id_chooser=lambda *args: None,
-            query_chooser=lambda *args: ['test']
+            query_chooser=lambda *args: ["test"],
         )

         Book, Page = self.classes("Book", "Page")
@@ -457,16 +478,17 @@ class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
         session.add(book)
         session.commit()

-        result = session.query(Book).options(selectinload('pages')).all()
+        result = session.query(Book).options(selectinload("pages")).all()
         eq_(result, [book])

+
 class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
     @classmethod
     def setup_classes(cls):
         Base = cls.DeclarativeBasic

         class A(Base):
-            __tablename__ = 'a'
+            __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             data = Column(String(30))
             deferred_data = deferred(Column(String(30)))
@@ -475,18 +497,16 @@ class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
     def insert_data(cls):
         A = cls.classes.A
         s = Session()
-        s.add(A(data='d1', deferred_data='d2'))
+        s.add(A(data="d1", deferred_data="d2"))
         s.commit()

     def _session_fixture(self):

         return ShardedSession(
-            shards={
-                "main": testing.db,
-            },
-            shard_chooser=lambda *args: 'main',
-            id_chooser=lambda *args: ['fake', 'main'],
-            query_chooser=lambda *args: ['fake', 'main']
+            shards={"main": testing.db},
+            shard_chooser=lambda *args: "main",
+            id_chooser=lambda *args: ["fake", "main"],
+            query_chooser=lambda *args: ["fake", "main"],
         )

     def test_refresh(self):
@@ -515,10 +535,12 @@ class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
 class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
     def _init_dbs(self):
         self.db1 = db1 = testing_engine(
-            'sqlite:///shard1_%s.db' % provision.FOLLOWER_IDENT,
-            options=dict(pool_threadlocal=True))
+            "sqlite:///shard1_%s.db" % provision.FOLLOWER_IDENT,
+            options=dict(pool_threadlocal=True),
+        )
         self.db2 = db2 = testing_engine(
-            'sqlite:///shard2_%s.db' % provision.FOLLOWER_IDENT)
+            "sqlite:///shard2_%s.db" % provision.FOLLOWER_IDENT
+        )

         for db in (db1, db2):
             self.metadata.create_all(db)
@@ -538,15 +560,15 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
         Base = cls.DeclarativeBasic

         class Book(Base):
-            __tablename__ = 'book'
+            __tablename__ = "book"
             id = Column(Integer, primary_key=True)
             title = Column(String(50), nullable=False)
-            pages = relationship('Page', backref='book')
+            pages = relationship("Page", backref="book")

         class Page(Base):
-            __tablename__ = 'page'
+            __tablename__ = "page"
             id = Column(Integer, primary_key=True)
-            book_id = Column(ForeignKey('book.id'))
+            book_id = Column(ForeignKey("book.id"))
             title = Column(String(50))

     def _fixture(self, lazy_load_book=False, lazy_load_pages=False):
@@ -569,13 +591,16 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
                 return [query.lazy_loaded_from.identity_token]

         def no_query_chooser(query):
-            if query.column_descriptions[0]['type'] is Book and lazy_load_book:
+            if query.column_descriptions[0]["type"] is Book and lazy_load_book:
                 assert isinstance(query.lazy_loaded_from.obj(), Page)
-            elif query.column_descriptions[0]['type'] is Page and lazy_load_pages:
+            elif (
+                query.column_descriptions[0]["type"] is Page
+                and lazy_load_pages
+            ):
                 assert isinstance(query.lazy_loaded_from.obj(), Book)

             if query.lazy_loaded_from is None:
-                return ['test', 'test2']
+                return ["test", "test2"]
             else:
                 return [query.lazy_loaded_from.identity_token]

@@ -590,7 +615,7 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
             shards={"test": db1, "test2": db2},
             shard_chooser=shard_chooser,
             id_chooser=id_chooser,
-            query_chooser=no_query_chooser
+            query_chooser=no_query_chooser,
         )

         return session
@@ -607,16 +632,13 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):

         page = session.query(Page).first()

-        session.expire(page, ['book'])
+        session.expire(page, ["book"])

         def go():
             eq_(page.book, book)

         # doesn't emit SQL
-        self.assert_multiple_sql_count(
-            self.dbs,
-            go,
-            [0, 0])
+        self.assert_multiple_sql_count(self.dbs, go, [0, 0])

     def test_lazy_load_from_db(self):
         session = self._fixture(lazy_load_book=True)
@@ -632,16 +654,13 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
         session.expunge(book1)

         book1_page = session.query(Page).first()
-        session.expire(book1_page, ['book'])
+        session.expire(book1_page, ["book"])

         def go():
             eq_(inspect(book1_page.book).identity_key, book1_id)

         # emits one query
-        self.assert_multiple_sql_count(
-            self.dbs,
-            go,
-            [1, 0])
+        self.assert_multiple_sql_count(self.dbs, go, [1, 0])

     def test_lazy_load_no_baked_conflict(self):
         session = self._fixture(lazy_load_pages=True)
@@ -657,8 +676,8 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
         session.add(book2)
         session.flush()

-        session.expire(book1, ['pages'])
-        session.expire(book2, ['pages'])
+        session.expire(book1, ["pages"])
+        session.expire(book2, ["pages"])

         eq_(book1.pages[0].title, "book 1 page 1")
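The sharding tests above all drive one configuration surface. Reduced to essentials — a sketch, not the fixture itself, under the assumption of two throwaway in-memory SQLite engines and deliberately trivial chooser callables:

    from sqlalchemy import create_engine
    from sqlalchemy.ext.horizontal_shard import ShardedSession
    from sqlalchemy.orm import sessionmaker

    db1 = create_engine("sqlite://")
    db2 = create_engine("sqlite://")

    Session = sessionmaker(class_=ShardedSession)
    Session.configure(
        shards={"shard_a": db1, "shard_b": db2},
        # picks the single shard an instance is written to on flush
        shard_chooser=lambda mapper, instance, clause=None: "shard_a",
        # shards to search for a primary-key lookup
        id_chooser=lambda query, ident: ["shard_a", "shard_b"],
        # shards an arbitrary Query should run against
        query_chooser=lambda query: ["shard_a", "shard_b"],
    )

The test suite's query_chooser narrows the shard list by inspecting the query's criterion (the FindContinent visitor), which is what keeps continent-filtered queries on a single database.
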
FROM a AS a_1" + "SELECT a_1.value AS a_1_value FROM a AS a_1", ) def test_aliased_filter(self): @@ -81,7 +83,7 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(aliased(A)).filter_by(value="foo"), "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id " - "FROM a AS a_1 WHERE upper(a_1.value) = upper(:upper_1)" + "FROM a AS a_1 WHERE upper(a_1.value) = upper(:upper_1)", ) def test_docstring(self): @@ -90,13 +92,13 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) _value = Column("value", String) @@ -120,14 +122,13 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): return A - def _relationship_fixture(self): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - b_id = Column('bid', Integer, ForeignKey('b.id')) + b_id = Column("bid", Integer, ForeignKey("b.id")) _value = Column("value", String) @hybrid.hybrid_property @@ -147,7 +148,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): return func.bar(cls._value) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) as_ = relationship("A") @@ -159,7 +160,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): inspect(A).all_orm_descriptors.value.info["some key"] = "some value" eq_( inspect(A).all_orm_descriptors.value.info, - {"some key": "some value"} + {"some key": "some value"}, ) def test_set_get(self): @@ -171,8 +172,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_expression(self): A = self._fixture() self.assert_compile( - A.value.__clause_element__(), - "foo(a.value) + bar(a.value)" + A.value.__clause_element__(), "foo(a.value) + bar(a.value)" ) def test_any(self): @@ -182,14 +182,14 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): sess.query(B).filter(B.as_.any(value=5)), "SELECT b.id AS b_id FROM b WHERE EXISTS " "(SELECT 1 FROM a WHERE b.id = a.bid " - "AND foo(a.value) + bar(a.value) = :param_1)" + "AND foo(a.value) + bar(a.value) = :param_1)", ) def test_aliased_expression(self): A = self._fixture() self.assert_compile( aliased(A).value.__clause_element__(), - "foo(a_1.value) + bar(a_1.value)" + "foo(a_1.value) + bar(a_1.value)", ) def test_query(self): @@ -198,7 +198,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(A).filter_by(value="foo"), "SELECT a.value AS a_value, a.id AS a_id " - "FROM a WHERE foo(a.value) + bar(a.value) = :param_1" + "FROM a WHERE foo(a.value) + bar(a.value) = :param_1", ) def test_aliased_query(self): @@ -207,7 +207,7 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(aliased(A)).filter_by(value="foo"), "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id " - "FROM a AS a_1 WHERE foo(a_1.value) + bar(a_1.value) = :param_1" + "FROM a AS a_1 WHERE foo(a_1.value) + bar(a_1.value) = :param_1", ) def test_docstring(self): @@ -220,13 +220,13 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self, 
assignable): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) _value = Column("value", String) @@ -235,6 +235,7 @@ class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): return self._value - 5 if assignable: + @value.setter def value(self, v): self._value = v + 5 @@ -245,18 +246,14 @@ class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): A = self._fixture(False) a1 = A(_value=5) assert_raises_message( - AttributeError, - "can't set attribute", - setattr, a1, 'value', 10 + AttributeError, "can't set attribute", setattr, a1, "value", 10 ) def test_nondeletable(self): A = self._fixture(False) a1 = A(_value=5) assert_raises_message( - AttributeError, - "can't delete attribute", - delattr, a1, 'value' + AttributeError, "can't delete attribute", delattr, a1, "value" ) def test_set_get(self): @@ -267,13 +264,13 @@ class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self): Base = declarative_base() class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) _name = Column(String) @@ -286,8 +283,8 @@ class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): self._name = value.title() class OverrideSetter(Person): - __tablename__ = 'override_setter' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "override_setter" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) other = Column(String) @Person.name.setter @@ -295,8 +292,8 @@ class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): self._name = value.upper() class OverrideGetter(Person): - __tablename__ = 'override_getter' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "override_getter" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) other = Column(String) @Person.name.getter @@ -304,8 +301,8 @@ class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): return "Hello " + self._name class OverrideExpr(Person): - __tablename__ = 'override_expr' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "override_expr" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) other = Column(String) @Person.name.overrides.expression @@ -317,8 +314,8 @@ class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): return func.concat("Hello", self.expression._name) class OverrideComparator(Person): - __tablename__ = 'override_comp' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "override_comp" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) other = Column(String) @Person.name.overrides.comparator @@ -326,66 +323,63 @@ class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): return FooComparator(self) return ( - Person, OverrideSetter, OverrideGetter, - OverrideExpr, OverrideComparator + Person, + OverrideSetter, + OverrideGetter, + OverrideExpr, + OverrideComparator, ) def test_property(self): Person, _, _, _, _ = self._fixture() p1 = Person() - p1.name = 'mike' - eq_(p1._name, 'Mike') - eq_(p1.name, 'Mike') + p1.name = "mike" + eq_(p1._name, "Mike") + eq_(p1.name, "Mike") def test_override_setter(self): _, OverrideSetter, _, _, _ = self._fixture() p1 = OverrideSetter() - p1.name = 'mike' - eq_(p1._name, 'MIKE') - eq_(p1.name, 'MIKE') + 
p1.name = "mike" + eq_(p1._name, "MIKE") + eq_(p1.name, "MIKE") def test_override_getter(self): _, _, OverrideGetter, _, _ = self._fixture() p1 = OverrideGetter() - p1.name = 'mike' - eq_(p1._name, 'Mike') - eq_(p1.name, 'Hello Mike') + p1.name = "mike" + eq_(p1._name, "Mike") + eq_(p1.name, "Hello Mike") def test_override_expr(self): Person, _, _, OverrideExpr, _ = self._fixture() - self.assert_compile( - Person.name.__clause_element__(), - "person._name" - ) + self.assert_compile(Person.name.__clause_element__(), "person._name") self.assert_compile( OverrideExpr.name.__clause_element__(), - "concat(:concat_1, person._name)" + "concat(:concat_1, person._name)", ) def test_override_comparator(self): Person, _, _, _, OverrideComparator = self._fixture() - self.assert_compile( - Person.name.__clause_element__(), - "person._name" - ) + self.assert_compile(Person.name.__clause_element__(), "person._name") self.assert_compile( OverrideComparator.name.__clause_element__(), - "concat(:concat_1, person._name)" + "concat(:concat_1, person._name)", ) class PropertyMirrorTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) _value = Column("value", String) @@ -393,6 +387,7 @@ class PropertyMirrorTest(fixtures.TestBase, AssertsCompiledSQL): def value(self): "This is an instance-level docstring" return self._value + return A def test_property(self): @@ -416,29 +411,29 @@ class PropertyMirrorTest(fixtures.TestBase, AssertsCompiledSQL): def test_info_not_mirrored(self): A = self._fixture() - A._value.info['foo'] = 'bar' - A.value.info['bar'] = 'hoho' + A._value.info["foo"] = "bar" + A.value.info["bar"] = "hoho" - eq_(A._value.info, {'foo': 'bar'}) - eq_(A.value.info, {'bar': 'hoho'}) + eq_(A._value.info, {"foo": "bar"}) + eq_(A.value.info, {"bar": "hoho"}) def test_info_from_hybrid(self): A = self._fixture() - A._value.info['foo'] = 'bar' - A.value.info['bar'] = 'hoho' + A._value.info["foo"] = "bar" + A.value.info["bar"] = "hoho" insp = inspect(A) - is_(insp.all_orm_descriptors['value'].info, A.value.info) + is_(insp.all_orm_descriptors["value"].info, A.value.info) class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) _value = Column("value", String) @@ -470,24 +465,20 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): def test_expression(self): A = self._fixture() - self.assert_compile( - A.value(5), - "foo(a.value, :foo_1) + :foo_2" - ) + self.assert_compile(A.value(5), "foo(a.value, :foo_1) + :foo_2") def test_info(self): A = self._fixture() inspect(A).all_orm_descriptors.value.info["some key"] = "some value" eq_( inspect(A).all_orm_descriptors.value.info, - {"some key": "some value"} + {"some key": "some value"}, ) def test_aliased_expression(self): A = self._fixture() self.assert_compile( - aliased(A).value(5), - "foo(a_1.value, :foo_1) + :foo_2" + aliased(A).value(5), "foo(a_1.value, :foo_1) + :foo_2" ) def test_query(self): @@ -496,7 +487,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(A).filter(A.value(5) == "foo"), "SELECT a.value AS a_value, a.id AS a_id " - "FROM a WHERE foo(a.value, :foo_1) + :foo_2 = :param_1" + "FROM a WHERE 
foo(a.value, :foo_1) + :foo_2 = :param_1", ) def test_aliased_query(self): @@ -506,7 +497,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(a1).filter(a1.value(5) == "foo"), "SELECT a_1.value AS a_1_value, a_1.id AS a_1_id " - "FROM a AS a_1 WHERE foo(a_1.value, :foo_1) + :foo_2 = :param_1" + "FROM a AS a_1 WHERE foo(a_1.value, :foo_1) + :foo_2 = :param_1", ) def test_query_col(self): @@ -514,7 +505,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): sess = Session() self.assert_compile( sess.query(A.value(5)), - "SELECT foo(a.value, :foo_1) + :foo_2 AS anon_1 FROM a" + "SELECT foo(a.value, :foo_1) + :foo_2 AS anon_1 FROM a", ) def test_aliased_query_col(self): @@ -522,7 +513,7 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): sess = Session() self.assert_compile( sess.query(aliased(A).value(5)), - "SELECT foo(a_1.value, :foo_1) + :foo_2 AS anon_1 FROM a AS a_1" + "SELECT foo(a_1.value, :foo_1) + :foo_2 AS anon_1 FROM a AS a_1", ) def test_docstring(self): @@ -539,14 +530,14 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) first_name = Column(String(10)) @@ -554,19 +545,19 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): @hybrid.hybrid_property def name(self): - return self.first_name + ' ' + self.last_name + return self.first_name + " " + self.last_name @name.setter def name(self, value): - self.first_name, self.last_name = value.split(' ', 1) + self.first_name, self.last_name = value.split(" ", 1) @name.expression def name(cls): - return func.concat(cls.first_name, ' ', cls.last_name) + return func.concat(cls.first_name, " ", cls.last_name) @name.update_expression def name(cls, value): - f, l = value.split(' ', 1) + f, l = value.split(" ", 1) return [(cls.first_name, f), (cls.last_name, l)] @hybrid.hybrid_property @@ -584,7 +575,7 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): @classmethod def insert_data(cls): s = Session() - jill = cls.classes.Person(id=3, first_name='jill') + jill = cls.classes.Person(id=3, first_name="jill") s.add(jill) s.commit() @@ -595,12 +586,13 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): q = s.query(Person) bulk_ud = persistence.BulkUpdate.factory( - q, False, {Person.fname: "Dr."}, {}) + q, False, {Person.fname: "Dr."}, {} + ) self.assert_compile( bulk_ud, "UPDATE person SET first_name=:first_name", - params={'first_name': 'Dr.'} + params={"first_name": "Dr."}, ) def test_update_expr(self): @@ -610,12 +602,13 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): q = s.query(Person) bulk_ud = persistence.BulkUpdate.factory( - q, False, {Person.name: "Dr. No"}, {}) + q, False, {Person.name: "Dr. 
No"}, {} + ) self.assert_compile( bulk_ud, "UPDATE person SET first_name=:first_name, last_name=:last_name", - params={'first_name': 'Dr.', 'last_name': 'No'} + params={"first_name": "Dr.", "last_name": "No"}, ) def test_evaluate_hybrid_attr_indirect(self): @@ -625,9 +618,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.fname2: 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.fname2, 'moonbeam') + {Person.fname2: "moonbeam"}, synchronize_session="evaluate" + ) + eq_(jill.fname2, "moonbeam") def test_evaluate_hybrid_attr_plain(self): Person = self.classes.Person @@ -636,9 +629,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.fname: 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.fname, 'moonbeam') + {Person.fname: "moonbeam"}, synchronize_session="evaluate" + ) + eq_(jill.fname, "moonbeam") def test_fetch_hybrid_attr_indirect(self): Person = self.classes.Person @@ -647,9 +640,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.fname2: 'moonbeam'}, - synchronize_session='fetch') - eq_(jill.fname2, 'moonbeam') + {Person.fname2: "moonbeam"}, synchronize_session="fetch" + ) + eq_(jill.fname2, "moonbeam") def test_fetch_hybrid_attr_plain(self): Person = self.classes.Person @@ -658,9 +651,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.fname: 'moonbeam'}, - synchronize_session='fetch') - eq_(jill.fname, 'moonbeam') + {Person.fname: "moonbeam"}, synchronize_session="fetch" + ) + eq_(jill.fname, "moonbeam") def test_evaluate_hybrid_attr_w_update_expr(self): Person = self.classes.Person @@ -669,9 +662,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.name: 'moonbeam sunshine'}, - synchronize_session='evaluate') - eq_(jill.name, 'moonbeam sunshine') + {Person.name: "moonbeam sunshine"}, synchronize_session="evaluate" + ) + eq_(jill.name, "moonbeam sunshine") def test_fetch_hybrid_attr_w_update_expr(self): Person = self.classes.Person @@ -680,9 +673,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.name: 'moonbeam sunshine'}, - synchronize_session='fetch') - eq_(jill.name, 'moonbeam sunshine') + {Person.name: "moonbeam sunshine"}, synchronize_session="fetch" + ) + eq_(jill.name, "moonbeam sunshine") def test_evaluate_hybrid_attr_indirect_w_update_expr(self): Person = self.classes.Person @@ -691,9 +684,9 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): jill = s.query(Person).get(3) s.query(Person).update( - {Person.uname: 'moonbeam sunshine'}, - synchronize_session='evaluate') - eq_(jill.uname, 'moonbeam sunshine') + {Person.uname: "moonbeam sunshine"}, synchronize_session="evaluate" + ) + eq_(jill.uname, "moonbeam sunshine") class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): @@ -703,13 +696,14 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): http://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/ """ - __dialect__ = 'default' + + __dialect__ = "default" @classmethod def setup_class(cls): from sqlalchemy import literal - symbols = ('usd', 'gbp', 'cad', 'eur', 
'aud') + symbols = ("usd", "gbp", "cad", "eur", "aud") currency_lookup = dict( ((currency_from, currency_to), Decimal(str(rate))) for currency_to, values in zip( @@ -720,7 +714,8 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): (1.01152, 1.6084, 1, 1.39569, 1.04148), (0.724743, 1.1524, 0.716489, 1, 0.746213), (0.971228, 1.54434, 0.960166, 1.34009, 1), - ]) + ], + ) for currency_from, rate in zip(symbols, values) ) @@ -731,16 +726,14 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def __add__(self, other): return Amount( - self.amount + - other.as_currency(self.currency).amount, - self.currency + self.amount + other.as_currency(self.currency).amount, + self.currency, ) def __sub__(self, other): return Amount( - self.amount - - other.as_currency(self.currency).amount, - self.currency + self.amount - other.as_currency(self.currency).amount, + self.currency, ) def __lt__(self, other): @@ -754,9 +747,9 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): def as_currency(self, other_currency): return Amount( - currency_lookup[(self.currency, other_currency)] * - self.amount, - other_currency + currency_lookup[(self.currency, other_currency)] + * self.amount, + other_currency, ) def __clause_element__(self): @@ -776,10 +769,10 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): Base = declarative_base() class BankAccount(Base): - __tablename__ = 'bank_account' + __tablename__ = "bank_account" id = Column(Integer, primary_key=True) - _balance = Column('balance', Numeric) + _balance = Column("balance", Numeric) @hybrid.hybrid_property def balance(self): @@ -805,7 +798,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): account = BankAccount(balance=Amount(4000, "usd")) # 3c. print balance in gbp - eq_(account.balance.as_currency("gbp").amount, Decimal('2515.58')) + eq_(account.balance.as_currency("gbp").amount, Decimal("2515.58")) def test_instance_three(self): BankAccount, Amount = self.BankAccount, self.Amount @@ -819,22 +812,23 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): account = BankAccount(balance=Amount(4000, "usd")) eq_( account.balance + Amount(500, "cad") - Amount(50, "eur"), - Amount(Decimal("4425.316"), "usd") + Amount(Decimal("4425.316"), "usd"), ) def test_query_one(self): BankAccount, Amount = self.BankAccount, self.Amount session = Session() - query = session.query(BankAccount).\ - filter(BankAccount.balance == Amount(10000, "cad")) + query = session.query(BankAccount).filter( + BankAccount.balance == Amount(10000, "cad") + ) self.assert_compile( query, "SELECT bank_account.balance AS bank_account_balance, " "bank_account.id AS bank_account_id FROM bank_account " "WHERE bank_account.balance = :balance_1", - checkparams={'balance_1': Decimal('9886.110000')} + checkparams={"balance_1": Decimal("9886.110000")}, ) def test_query_two(self): @@ -842,11 +836,15 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): session = Session() # alternatively we can do the calc on the DB side. 
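The query below pushes the currency conversion into SQL; the mechanics it relies on, reduced to a minimal sketch (Celsius and Reading are illustrative names, not part of the fixture): the hybrid getter returns a plain value object, the object's operators emit SQL whenever it wraps a column expression, and __clause_element__ lets the object stand in for an expression directly.

    from sqlalchemy import Column, Integer, Numeric, literal
    from sqlalchemy.ext import hybrid
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Celsius(object):
        # value object, usable on instances and inside SQL expressions
        def __init__(self, degrees):
            self.degrees = degrees

        def __eq__(self, other):
            # a plain bool on instances; a BinaryExpression when
            # self.degrees is the mapped column
            return self.degrees == other.degrees

        def __clause_element__(self):
            # lets select(), func.*, etc. accept the object directly
            if hasattr(self.degrees, "__clause_element__"):
                return self.degrees.__clause_element__()
            return literal(self.degrees)

    class Reading(Base):
        __tablename__ = "reading"
        id = Column(Integer, primary_key=True)
        _degrees = Column("degrees", Numeric)

        @hybrid.hybrid_property
        def temperature(self):
            # instance access wraps the loaded value; class access
            # wraps the column expression
            return Celsius(self._degrees)

    # Reading(_degrees=20).temperature == Celsius(20)   -> True
    # session.query(Reading).filter(Reading.temperature == Celsius(20))
    #   renders roughly: ... WHERE degrees = :degrees_1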
- query = session.query(BankAccount).\ - filter( - BankAccount.balance.as_currency("cad") > Amount(9999, "cad")).\ - filter( - BankAccount.balance.as_currency("cad") < Amount(10001, "cad")) + query = ( + session.query(BankAccount) + .filter( + BankAccount.balance.as_currency("cad") > Amount(9999, "cad") + ) + .filter( + BankAccount.balance.as_currency("cad") < Amount(10001, "cad") + ) + ) self.assert_compile( query, "SELECT bank_account.balance AS bank_account_balance, " @@ -855,20 +853,21 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE :balance_1 * bank_account.balance > :param_1 " "AND :balance_2 * bank_account.balance < :param_2", checkparams={ - 'balance_1': Decimal('1.01152'), - 'balance_2': Decimal('1.01152'), - 'param_1': Decimal('9999'), - 'param_2': Decimal('10001')} + "balance_1": Decimal("1.01152"), + "balance_2": Decimal("1.01152"), + "param_1": Decimal("9999"), + "param_2": Decimal("10001"), + }, ) def test_query_three(self): BankAccount = self.BankAccount session = Session() - query = session.query(BankAccount).\ - filter( - BankAccount.balance.as_currency("cad") > - BankAccount.balance.as_currency("eur")) + query = session.query(BankAccount).filter( + BankAccount.balance.as_currency("cad") + > BankAccount.balance.as_currency("eur") + ) self.assert_compile( query, "SELECT bank_account.balance AS bank_account_balance, " @@ -876,9 +875,10 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE :balance_1 * bank_account.balance > " ":param_1 * :balance_2 * bank_account.balance", checkparams={ - 'balance_1': Decimal('1.01152'), - 'balance_2': Decimal('0.724743'), - 'param_1': Decimal('1.39569')} + "balance_1": Decimal("1.01152"), + "balance_2": Decimal("0.724743"), + "param_1": Decimal("1.39569"), + }, ) def test_query_four(self): @@ -891,7 +891,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): query, "SELECT :balance_1 * bank_account.balance AS anon_1 " "FROM bank_account", - checkparams={'balance_1': Decimal('1.01152')} + checkparams={"balance_1": Decimal("1.01152")}, ) def test_query_five(self): @@ -904,11 +904,12 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): query, "SELECT avg(:balance_1 * bank_account.balance) AS avg_1 " "FROM bank_account", - checkparams={'balance_1': Decimal('0.724743')} + checkparams={"balance_1": Decimal("0.724743")}, ) def test_docstring(self): BankAccount = self.BankAccount eq_( BankAccount.balance.__doc__, - "Return an Amount view of the current balance.") + "Return an Amount view of the current balance.", + ) diff --git a/test/ext/test_indexable.py b/test/ext/test_indexable.py index 44d8619669..28ae6ab6bf 100644 --- a/test/ext/test_indexable.py +++ b/test/ext/test_indexable.py @@ -13,17 +13,15 @@ from sqlalchemy import inspect class IndexPropertyTest(fixtures.TestBase): - def test_array(self): Base = declarative_base() class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - array = Column('_array', ARRAY(Integer), - default=[]) - first = index_property('array', 0) - tenth = index_property('array', 9) + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + array = Column("_array", ARRAY(Integer), default=[]) + first = index_property("array", 0) + tenth = index_property("array", 9) a = A(array=[1, 2, 3]) eq_(a.first, 1) @@ -42,13 +40,12 @@ class IndexPropertyTest(fixtures.TestBase): Base = declarative_base() class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - array = Column('_array', 
ARRAY(Integer), - default=[]) - first = index_property('array', 0) + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + array = Column("_array", ARRAY(Integer), default=[]) + first = index_property("array", 0) - fifth = index_property('array', 4) + fifth = index_property("array", 4) a1 = A(fifth=10) a2 = A(first=5) @@ -62,18 +59,18 @@ class IndexPropertyTest(fixtures.TestBase): Base = declarative_base() class J(Base): - __tablename__ = 'j' - id = Column('id', Integer, primary_key=True) - json = Column('_json', JSON, default={}) - field = index_property('json', 'field') + __tablename__ = "j" + id = Column("id", Integer, primary_key=True) + json = Column("_json", JSON, default={}) + field = index_property("json", "field") - j = J(json={'a': 1, 'b': 2}) + j = J(json={"a": 1, "b": 2}) assert_raises(AttributeError, lambda: j.field) - j.field = 'test' - eq_(j.field, 'test') - eq_(j.json, {'a': 1, 'b': 2, 'field': 'test'}) + j.field = "test" + eq_(j.field, "test") + eq_(j.json, {"a": 1, "b": 2, "field": "test"}) - j2 = J(field='test') + j2 = J(field="test") eq_(j2.json, {"field": "test"}) eq_(j2.field, "test") @@ -81,10 +78,10 @@ class IndexPropertyTest(fixtures.TestBase): Base = declarative_base() class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - array = Column('_array', ARRAY(Integer)) - first = index_property('array', 1) + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + array = Column("_array", ARRAY(Integer)) + first = index_property("array", 1) a = A() assert_raises(AttributeError, getattr, a, "first") @@ -95,10 +92,10 @@ class IndexPropertyTest(fixtures.TestBase): Base = declarative_base() class A(Base): - __tablename__ = 'a' - id = Column('id', Integer, primary_key=True) - array = Column('_array', ARRAY(Integer)) - first = index_property('array', 1) + __tablename__ = "a" + id = Column("id", Integer, primary_key=True) + array = Column("_array", ARRAY(Integer)) + first = index_property("array", 1) a = A(array=[]) assert_raises(AttributeError, lambda: a.first) @@ -107,25 +104,26 @@ class IndexPropertyTest(fixtures.TestBase): Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) array = Column(ARRAY(Integer)) - first = index_property('array', 1, mutable=False) + first = index_property("array", 1, mutable=False) a = A() def set_(): a.first = 10 + assert_raises(AttributeError, set_) def test_set_mutable_dict(self): Base = declarative_base() class J(Base): - __tablename__ = 'j' + __tablename__ = "j" id = Column(Integer, primary_key=True) json = Column(JSON, default={}) - field = index_property('json', 'field') + field = index_property("json", "field") j = J() @@ -142,19 +140,19 @@ class IndexPropertyTest(fixtures.TestBase): Base = declarative_base() class J(Base): - __tablename__ = 'j' + __tablename__ = "j" id = Column(Integer, primary_key=True) json = Column(JSON, default={}) - default = index_property('json', 'field', default='default') - none = index_property('json', 'field', default=None) + default = index_property("json", "field", default="default") + none = index_property("json", "field", default=None) j = J() assert j.json is None - assert j.default == 'default' + assert j.default == "default" assert j.none is None j.json = {} - assert j.default == 'default' + assert j.default == "default" assert j.none is None j.default = None assert j.default is None @@ -166,7 +164,7 @@ class IndexPropertyTest(fixtures.TestBase): class 
IndexPropertyArrayTest(fixtures.DeclarativeMappedTest): - __requires__ = ('array_type',) + __requires__ = ("array_type",) __backend__ = True @classmethod @@ -176,21 +174,25 @@ class IndexPropertyArrayTest(fixtures.DeclarativeMappedTest): class Array(fixtures.ComparableEntity, Base): __tablename__ = "array" - id = Column(sa.Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + sa.Integer, primary_key=True, test_needs_autoincrement=True + ) array = Column(ARRAY(Integer), default=[]) array0 = Column(ARRAY(Integer, zero_indexes=True), default=[]) - first = index_property('array', 0) - first0 = index_property('array0', 0, onebased=False) + first = index_property("array", 0) + first0 = index_property("array0", 0, onebased=False) def test_query(self): Array = self.classes.Array s = Session(testing.db) - s.add_all([ - Array(), - Array(array=[1, 2, 3], array0=[1, 2, 3]), - Array(array=[4, 5, 6], array0=[4, 5, 6])]) + s.add_all( + [ + Array(), + Array(array=[1, 2, 3], array0=[1, 2, 3]), + Array(array=[4, 5, 6], array0=[4, 5, 6]), + ] + ) s.commit() a1 = s.query(Array).filter(Array.array == [1, 2, 3]).one() @@ -233,19 +235,19 @@ class IndexPropertyArrayTest(fixtures.DeclarativeMappedTest): i = inspect(a) is_(i.modified, False) - in_('array', i.unmodified) + in_("array", i.unmodified) a.first = 10 is_(i.modified, True) - not_in_('array', i.unmodified) + not_in_("array", i.unmodified) class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest): # TODO: remove reliance on "astext" for these tests - __requires__ = ('json_type',) - __only_on__ = 'postgresql' + __requires__ = ("json_type",) + __only_on__ = "postgresql" __backend__ = True @@ -267,39 +269,39 @@ class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest): class Json(fixtures.ComparableEntity, Base): __tablename__ = "json" - id = Column(sa.Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + sa.Integer, primary_key=True, test_needs_autoincrement=True + ) json = Column(JSON, default={}) - field = index_property('json', 'field') - json_field = index_property('json', 'field') - int_field = json_property('json', 'field', Integer) - text_field = json_property('json', 'field', Text) - other = index_property('json', 'other') - subfield = json_property('other', 'field', Text) + field = index_property("json", "field") + json_field = index_property("json", "field") + int_field = json_property("json", "field", Integer) + text_field = json_property("json", "field", Text) + other = index_property("json", "other") + subfield = json_property("other", "field", Text) def test_query(self): Json = self.classes.Json s = Session(testing.db) - s.add_all([ - Json(), - Json(json={'field': 10}), - Json(json={'field': 20})]) + s.add_all([Json(), Json(json={"field": 10}), Json(json={"field": 20})]) s.commit() - a1 = s.query(Json)\ - .filter(Json.json['field'].astext.cast(Integer) == 10)\ + a1 = ( + s.query(Json) + .filter(Json.json["field"].astext.cast(Integer) == 10) .one() - a2 = s.query(Json).filter(Json.field.astext == '10').one() + ) + a2 = s.query(Json).filter(Json.field.astext == "10").one() eq_(a1.id, a2.id) - a3 = s.query(Json).filter(Json.field.astext == '20').one() + a3 = s.query(Json).filter(Json.field.astext == "20").one() ne_(a1.id, a3.id) - a4 = s.query(Json).filter(Json.json_field.astext == '10').one() + a4 = s.query(Json).filter(Json.json_field.astext == "10").one() eq_(a2.id, a4.id) a5 = s.query(Json).filter(Json.int_field == 10).one() eq_(a2.id, a5.id) - a6 = s.query(Json).filter(Json.text_field == 
'10').one() + a6 = s.query(Json).filter(Json.text_field == "10").one() eq_(a2.id, a6.id) def test_mutable(self): @@ -326,37 +328,37 @@ class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest): i = inspect(j) is_(i.modified, False) - in_('json', i.unmodified) + in_("json", i.unmodified) j.other = 42 is_(i.modified, True) - not_in_('json', i.unmodified) + not_in_("json", i.unmodified) def test_cast_type(self): Json = self.classes.Json s = Session(testing.db) - j = Json(json={'field': 10}) + j = Json(json={"field": 10}) s.add(j) s.commit() jq = s.query(Json).filter(Json.int_field == 10).one() eq_(j.id, jq.id) - jq = s.query(Json).filter(Json.text_field == '10').one() + jq = s.query(Json).filter(Json.text_field == "10").one() eq_(j.id, jq.id) - jq = s.query(Json).filter(Json.json_field.astext == '10').one() + jq = s.query(Json).filter(Json.json_field.astext == "10").one() eq_(j.id, jq.id) - jq = s.query(Json).filter(Json.text_field == 'wrong').first() + jq = s.query(Json).filter(Json.text_field == "wrong").first() is_(jq, None) - j.json = {'field': True} + j.json = {"field": True} s.commit() - jq = s.query(Json).filter(Json.text_field == 'true').one() + jq = s.query(Json).filter(Json.text_field == "true").one() eq_(j.id, jq.id) def test_multi_dimension(self): @@ -364,26 +366,26 @@ class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest): s = Session(testing.db) - j = Json(json={'other': {'field': 'multi'}}) + j = Json(json={"other": {"field": "multi"}}) s.add(j) s.commit() - eq_(j.other, {'field': 'multi'}) - eq_(j.subfield, 'multi') + eq_(j.other, {"field": "multi"}) + eq_(j.subfield, "multi") - jq = s.query(Json).filter(Json.subfield == 'multi').first() + jq = s.query(Json).filter(Json.subfield == "multi").first() eq_(j.id, jq.id) def test_nested_property_init(self): Json = self.classes.Json # subfield initializer - j = Json(subfield='a') - eq_(j.json, {'other': {'field': 'a'}}) + j = Json(subfield="a") + eq_(j.json, {"other": {"field": "a"}}) def test_nested_property_set(self): Json = self.classes.Json j = Json() - j.subfield = 'a' - eq_(j.json, {'other': {'field': 'a'}}) + j.subfield = "a" + eq_(j.json, {"other": {"field": "a"}}) diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 654a85e745..d46ace9a8a 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -23,7 +23,6 @@ class SubFoo(Foo): class FooWithEq(object): - def __init__(self, **kw): for k in kw: setattr(self, k, kw[k]) @@ -36,7 +35,6 @@ class FooWithEq(object): class Point(MutableComposite): - def __init__(self, x, y): self.x = x self.y = y @@ -55,13 +53,14 @@ class Point(MutableComposite): self.x, self.y = state def __eq__(self, other): - return isinstance(other, Point) and \ - other.x == self.x and \ - other.y == self.y + return ( + isinstance(other, Point) + and other.x == self.x + and other.y == self.y + ) class MyPoint(Point): - @classmethod def coerce(cls, key, value): if isinstance(value, tuple): @@ -82,7 +81,7 @@ class _MutableDictTestFixture(object): class _MutableDictTestBase(_MutableDictTestFixture): - run_define_tables = 'each' + run_define_tables = "each" def setup_mappers(cls): foo = cls.tables.foo @@ -100,20 +99,21 @@ class _MutableDictTestBase(_MutableDictTestFixture): assert_raises_message( ValueError, "Attribute 'data' does not accept objects of type", - Foo, data=set([1, 2, 3]) + Foo, + data=set([1, 2, 3]), ) def test_in_place_mutation(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.commit() - f1.data['a'] = 'c' + 
f1.data["a"] = "c" sess.commit() - eq_(f1.data, {'a': 'c'}) + eq_(f1.data, {"a": "c"}) def test_modified_event(self): canary = mock.Mock() @@ -124,14 +124,17 @@ class _MutableDictTestBase(_MutableDictTestFixture): eq_( canary.mock_calls, - [mock.call( - f1, attributes.Event(Foo.data.impl, attributes.OP_MODIFIED))] + [ + mock.call( + f1, attributes.Event(Foo.data.impl, attributes.OP_MODIFIED) + ) + ], ) def test_clear(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.commit() @@ -143,46 +146,46 @@ class _MutableDictTestBase(_MutableDictTestFixture): def test_update(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.commit() - f1.data.update({'a': 'z'}) + f1.data.update({"a": "z"}) sess.commit() - eq_(f1.data, {'a': 'z'}) + eq_(f1.data, {"a": "z"}) def test_pop(self): sess = Session() - f1 = Foo(data={'a': 'b', 'c': 'd'}) + f1 = Foo(data={"a": "b", "c": "d"}) sess.add(f1) sess.commit() - eq_(f1.data.pop('a'), 'b') + eq_(f1.data.pop("a"), "b") sess.commit() - assert_raises(KeyError, f1.data.pop, 'g') + assert_raises(KeyError, f1.data.pop, "g") - eq_(f1.data, {'c': 'd'}) + eq_(f1.data, {"c": "d"}) def test_pop_default(self): sess = Session() - f1 = Foo(data={'a': 'b', 'c': 'd'}) + f1 = Foo(data={"a": "b", "c": "d"}) sess.add(f1) sess.commit() - eq_(f1.data.pop('a', 'q'), 'b') - eq_(f1.data.pop('a', 'q'), 'q') + eq_(f1.data.pop("a", "q"), "b") + eq_(f1.data.pop("a", "q"), "q") sess.commit() - eq_(f1.data, {'c': 'd'}) + eq_(f1.data, {"c": "d"}) def test_popitem(self): sess = Session() - orig = {'a': 'b', 'c': 'd'} + orig = {"a": "b", "c": "d"} # the orig dict remains unchanged when we assign, # but just making this future-proof @@ -192,7 +195,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): sess.commit() k, v = f1.data.popitem() - assert k in ('a', 'c') + assert k in ("a", "c") orig.pop(k) sess.commit() @@ -202,45 +205,45 @@ class _MutableDictTestBase(_MutableDictTestFixture): def test_setdefault(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.commit() - eq_(f1.data.setdefault('c', 'd'), 'd') + eq_(f1.data.setdefault("c", "d"), "d") sess.commit() - eq_(f1.data, {'a': 'b', 'c': 'd'}) + eq_(f1.data, {"a": "b", "c": "d"}) - eq_(f1.data.setdefault('c', 'q'), 'd') + eq_(f1.data.setdefault("c", "q"), "d") sess.commit() - eq_(f1.data, {'a': 'b', 'c': 'd'}) + eq_(f1.data, {"a": "b", "c": "d"}) def test_replace(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.flush() - f1.data = {'b': 'c'} + f1.data = {"b": "c"} sess.commit() - eq_(f1.data, {'b': 'c'}) + eq_(f1.data, {"b": "c"}) def test_replace_itself_still_ok(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.flush() f1.data = f1.data - f1.data['b'] = 'c' + f1.data["b"] = "c" sess.commit() - eq_(f1.data, {'a': 'b', 'b': 'c'}) + eq_(f1.data, {"a": "b", "b": "c"}) def test_pickle_parent(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.commit() f1.data @@ -250,7 +253,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): sess = Session() f2 = loads(dumps(f1)) sess.add(f2) - f2.data['a'] = 'c' + f2.data["a"] = "c" assert f2 in sess.dirty def test_unrelated_flush(self): @@ -267,14 +270,14 @@ class _MutableDictTestBase(_MutableDictTestFixture): def _test_non_mutable(self): sess = Session() - f1 = Foo(non_mutable_data={'a': 'b'}) + f1 = Foo(non_mutable_data={"a": 
"b"}) sess.add(f1) sess.commit() - f1.non_mutable_data['a'] = 'c' + f1.non_mutable_data["a"] = "c" sess.commit() - eq_(f1.non_mutable_data, {'a': 'b'}) + eq_(f1.non_mutable_data, {"a": "b"}) class _MutableListTestFixture(object): @@ -290,7 +293,7 @@ class _MutableListTestFixture(object): class _MutableListTestBase(_MutableListTestFixture): - run_define_tables = 'each' + run_define_tables = "each" def setup_mappers(cls): foo = cls.tables.foo @@ -308,7 +311,8 @@ class _MutableListTestBase(_MutableListTestFixture): assert_raises_message( ValueError, "Attribute 'data' does not accept objects of type", - Foo, data=set([1, 2, 3]) + Foo, + data=set([1, 2, 3]), ) def test_in_place_mutation(self): @@ -348,7 +352,7 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 4]) def test_clear(self): - if not hasattr(list, 'clear'): + if not hasattr(list, "clear"): # py2 list doesn't have 'clear' return sess = Session() @@ -502,7 +506,7 @@ class _MutableSetTestFixture(object): class _MutableSetTestBase(_MutableSetTestFixture): - run_define_tables = 'each' + run_define_tables = "each" def setup_mappers(cls): foo = cls.tables.foo @@ -520,7 +524,8 @@ class _MutableSetTestBase(_MutableSetTestFixture): assert_raises_message( ValueError, "Attribute 'data' does not accept objects of type", - Foo, data=[1, 2, 3] + Foo, + data=[1, 2, 3], ) def test_clear(self): @@ -721,11 +726,12 @@ class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest): mutable_pickle = MutableDict.as_mutable(PickleType) Table( - 'foo', metadata, + "foo", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', mutable_pickle, default={}), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", mutable_pickle, default={}), ) def setup_mappers(cls): @@ -743,32 +749,33 @@ class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest): sess.flush() assert isinstance(f1.data, self._type_fixture()) assert f1 not in sess.dirty - f1.data['foo'] = 'bar' + f1.data["foo"] = "bar" assert f1 in sess.dirty class MutableWithScalarPickleTest(_MutableDictTestBase, fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): MutableDict = cls._type_fixture() mutable_pickle = MutableDict.as_mutable(PickleType) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('skip', mutable_pickle), - Column('data', mutable_pickle), - Column('non_mutable_data', PickleType), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("skip", mutable_pickle), + Column("data", mutable_pickle), + Column("non_mutable_data", PickleType), + Column("unrelated_data", String(50)), + ) def test_non_mutable(self): self._test_non_mutable() class MutableWithScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): import json @@ -789,13 +796,16 @@ class MutableWithScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest): MutableDict = cls._type_fixture() - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', MutableDict.as_mutable(JSONEncodedDict)), - Column('non_mutable_data', JSONEncodedDict), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", 
MutableDict.as_mutable(JSONEncodedDict)), + Column("non_mutable_data", JSONEncodedDict), + Column("unrelated_data", String(50)), + ) def test_non_mutable(self): self._test_non_mutable() @@ -807,13 +817,10 @@ class MutableIncludeNonPrimaryTest(MutableWithScalarJSONTest): foo = cls.tables.foo mapper(Foo, foo) - mapper(Foo, foo, non_primary=True, properties={ - "foo_bar": foo.c.data - }) + mapper(Foo, foo, non_primary=True, properties={"foo_bar": foo.c.data}) class MutableColumnCopyJSONTest(_MutableDictTestBase, fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): import json @@ -840,8 +847,9 @@ class MutableColumnCopyJSONTest(_MutableDictTestBase, fixtures.MappedTest): class AbstractFoo(Base): __abstract__ = True - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) data = Column(MutableDict.as_mutable(JSONEncodedDict)) non_mutable_data = Column(JSONEncodedDict) unrelated_data = Column(String(50)) @@ -849,7 +857,8 @@ class MutableColumnCopyJSONTest(_MutableDictTestBase, fixtures.MappedTest): class Foo(AbstractFoo): __tablename__ = "foo" column_prop = column_property( - func.lower(AbstractFoo.unrelated_data)) + func.lower(AbstractFoo.unrelated_data) + ) assert Foo.data.property.columns[0].type is not AbstractFoo.data.type @@ -858,7 +867,7 @@ class MutableColumnCopyJSONTest(_MutableDictTestBase, fixtures.MappedTest): class MutableColumnCopyArrayTest(_MutableListTestBase, fixtures.MappedTest): - __requires__ = 'array_type', + __requires__ = ("array_type",) @classmethod def define_tables(cls, metadata): @@ -873,62 +882,72 @@ class MutableColumnCopyArrayTest(_MutableListTestBase, fixtures.MappedTest): data = Column(MutableList.as_mutable(ARRAY(Integer))) class Foo(Mixin, Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) -class MutableListWithScalarPickleTest(_MutableListTestBase, - fixtures.MappedTest): - +class MutableListWithScalarPickleTest( + _MutableListTestBase, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): MutableList = cls._type_fixture() mutable_pickle = MutableList.as_mutable(PickleType) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('skip', mutable_pickle), - Column('data', mutable_pickle), - Column('non_mutable_data', PickleType), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("skip", mutable_pickle), + Column("data", mutable_pickle), + Column("non_mutable_data", PickleType), + Column("unrelated_data", String(50)), + ) class MutableSetWithScalarPickleTest(_MutableSetTestBase, fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): MutableSet = cls._type_fixture() mutable_pickle = MutableSet.as_mutable(PickleType) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('skip', mutable_pickle), - Column('data', mutable_pickle), - Column('non_mutable_data', PickleType), - Column('unrelated_data', String(50)) - ) - + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("skip", mutable_pickle), + Column("data", mutable_pickle), + Column("non_mutable_data", PickleType), + Column("unrelated_data", String(50)), + ) -class MutableAssocWithAttrInheritTest(_MutableDictTestBase, - fixtures.MappedTest): +class 
MutableAssocWithAttrInheritTest( + _MutableDictTestBase, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', PickleType), - Column('non_mutable_data', PickleType), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", PickleType), + Column("non_mutable_data", PickleType), + Column("unrelated_data", String(50)), + ) - Table('subfoo', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - ) + Table( + "subfoo", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + ) def setup_mappers(cls): foo = cls.tables.foo @@ -941,41 +960,44 @@ class MutableAssocWithAttrInheritTest(_MutableDictTestBase, def test_in_place_mutation(self): sess = Session() - f1 = SubFoo(data={'a': 'b'}) + f1 = SubFoo(data={"a": "b"}) sess.add(f1) sess.commit() - f1.data['a'] = 'c' + f1.data["a"] = "c" sess.commit() - eq_(f1.data, {'a': 'c'}) + eq_(f1.data, {"a": "c"}) def test_replace(self): sess = Session() - f1 = SubFoo(data={'a': 'b'}) + f1 = SubFoo(data={"a": "b"}) sess.add(f1) sess.flush() - f1.data = {'b': 'c'} + f1.data = {"b": "c"} sess.commit() - eq_(f1.data, {'b': 'c'}) - + eq_(f1.data, {"b": "c"}) -class MutableAssociationScalarPickleTest(_MutableDictTestBase, - fixtures.MappedTest): +class MutableAssociationScalarPickleTest( + _MutableDictTestBase, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): MutableDict = cls._type_fixture() MutableDict.associate_with(PickleType) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('skip', PickleType), - Column('data', PickleType), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("skip", PickleType), + Column("data", PickleType), + Column("unrelated_data", String(50)), + ) class MutableAssocIncludeNonPrimaryTest(MutableAssociationScalarPickleTest): @@ -984,14 +1006,12 @@ class MutableAssocIncludeNonPrimaryTest(MutableAssociationScalarPickleTest): foo = cls.tables.foo mapper(Foo, foo) - mapper(Foo, foo, non_primary=True, properties={ - "foo_bar": foo.c.data - }) + mapper(Foo, foo, non_primary=True, properties={"foo_bar": foo.c.data}) -class MutableAssociationScalarJSONTest(_MutableDictTestBase, - fixtures.MappedTest): - +class MutableAssociationScalarJSONTest( + _MutableDictTestBase, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): import json @@ -1013,27 +1033,33 @@ class MutableAssociationScalarJSONTest(_MutableDictTestBase, MutableDict = cls._type_fixture() MutableDict.associate_with(JSONEncodedDict) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', JSONEncodedDict), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", JSONEncodedDict), + Column("unrelated_data", String(50)), + ) -class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, - fixtures.MappedTest): +class CustomMutableAssociationScalarJSONTest( + _MutableDictTestBase, fixtures.MappedTest +): CustomMutableDict = None @classmethod def _type_fixture(cls): - if not(getattr(cls, 'CustomMutableDict')): + if not 
(getattr(cls, "CustomMutableDict")): MutableDict = super( - CustomMutableAssociationScalarJSONTest, cls)._type_fixture() + CustomMutableAssociationScalarJSONTest, cls + )._type_fixture() class CustomMutableDict(MutableDict): pass + cls.CustomMutableDict = CustomMutableDict return cls.CustomMutableDict @@ -1058,12 +1084,15 @@ class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, CustomMutableDict = cls._type_fixture() CustomMutableDict.associate_with(JSONEncodedDict) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', JSONEncodedDict), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", JSONEncodedDict), + Column("unrelated_data", String(50)), + ) def test_pickle_parent(self): # Picklers don't know how to pickle CustomMutableDict, @@ -1072,26 +1101,29 @@ class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, def test_coerce(self): sess = Session() - f1 = Foo(data={'a': 'b'}) + f1 = Foo(data={"a": "b"}) sess.add(f1) sess.flush() eq_(type(f1.data), self._type_fixture()) class _CompositeTestBase(object): - @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x', Integer), - Column('y', Integer), - Column('unrelated_data', String(50)) - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", Integer), + Column("y", Integer), + Column("unrelated_data", String(50)), + ) def setup(self): from sqlalchemy.ext import mutable + mutable._setup_composite_listener() super(_CompositeTestBase, self).setup() @@ -1107,17 +1139,20 @@ class _CompositeTestBase(object): return Point -class MutableCompositeColumnDefaultTest(_CompositeTestBase, - fixtures.MappedTest): +class MutableCompositeColumnDefaultTest( + _CompositeTestBase, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): Table( - 'foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x', Integer, default=5), - Column('y', Integer, default=9), - Column('unrelated_data', String(50)) + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", Integer, default=5), + Column("y", Integer, default=9), + Column("unrelated_data", String(50)), ) @classmethod @@ -1126,9 +1161,11 @@ class MutableCompositeColumnDefaultTest(_CompositeTestBase, cls.Point = cls._type_fixture() - mapper(Foo, foo, properties={ - 'data': composite(cls.Point, foo.c.x, foo.c.y) - }) + mapper( + Foo, + foo, + properties={"data": composite(cls.Point, foo.c.x, foo.c.y)}, + ) def test_evt_on_flush_refresh(self): # this still worked prior to #3427 being fixed in any case @@ -1145,16 +1182,17 @@ class MutableCompositeColumnDefaultTest(_CompositeTestBase, class MutableCompositesUnpickleTest(_CompositeTestBase, fixtures.MappedTest): - @classmethod def setup_mappers(cls): foo = cls.tables.foo cls.Point = cls._type_fixture() - mapper(FooWithEq, foo, properties={ - 'data': composite(cls.Point, foo.c.x, foo.c.y) - }) + mapper( + FooWithEq, + foo, + properties={"data": composite(cls.Point, foo.c.x, foo.c.y)}, + ) def test_unpickle_modified_eq(self): u1 = FooWithEq(data=self.Point(3, 5)) @@ -1163,16 +1201,15 @@ class MutableCompositesUnpickleTest(_CompositeTestBase, fixtures.MappedTest): class 
MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): - @classmethod def setup_mappers(cls): foo = cls.tables.foo Point = cls._type_fixture() - mapper(Foo, foo, properties={ - 'data': composite(Point, foo.c.x, foo.c.y) - }) + mapper( + Foo, foo, properties={"data": composite(Point, foo.c.x, foo.c.y)} + ) def test_in_place_mutation(self): sess = Session() @@ -1194,7 +1231,7 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): sess.commit() f1.data - assert 'data' in f1.__dict__ + assert "data" in f1.__dict__ sess.close() for loads, dumps in picklers(): @@ -1220,7 +1257,10 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): assert_raises_message( ValueError, "Attribute 'data' does not accept objects", - setattr, f1, 'data', 'foo' + setattr, + f1, + "data", + "foo", ) def test_unrelated_flush(self): @@ -1237,7 +1277,6 @@ class MutableCompositesTest(_CompositeTestBase, fixtures.MappedTest): class MutableCompositeCallableTest(_CompositeTestBase, fixtures.MappedTest): - @classmethod def setup_mappers(cls): foo = cls.tables.foo @@ -1246,9 +1285,13 @@ class MutableCompositeCallableTest(_CompositeTestBase, fixtures.MappedTest): # in this case, this is not actually a MutableComposite. # so we don't expect it to track changes - mapper(Foo, foo, properties={ - 'data': composite(lambda x, y: Point(x, y), foo.c.x, foo.c.y) - }) + mapper( + Foo, + foo, + properties={ + "data": composite(lambda x, y: Point(x, y), foo.c.x, foo.c.y) + }, + ) def test_basic(self): sess = Session() @@ -1262,9 +1305,9 @@ class MutableCompositeCallableTest(_CompositeTestBase, fixtures.MappedTest): eq_(f1.data.x, 3) -class MutableCompositeCustomCoerceTest(_CompositeTestBase, - fixtures.MappedTest): - +class MutableCompositeCustomCoerceTest( + _CompositeTestBase, fixtures.MappedTest +): @classmethod def _type_fixture(cls): @@ -1276,9 +1319,9 @@ class MutableCompositeCustomCoerceTest(_CompositeTestBase, Point = cls._type_fixture() - mapper(Foo, foo, properties={ - 'data': composite(Point, foo.c.x, foo.c.y) - }) + mapper( + Foo, foo, properties={"data": composite(Point, foo.c.x, foo.c.y)} + ) def test_custom_coerce(self): f = Foo() @@ -1297,18 +1340,22 @@ class MutableCompositeCustomCoerceTest(_CompositeTestBase, class MutableInheritedCompositesTest(_CompositeTestBase, fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x', Integer), - Column('y', Integer) - ) - Table('subfoo', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", Integer), + Column("y", Integer), + ) + Table( + "subfoo", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + ) @classmethod def setup_mappers(cls): @@ -1317,9 +1364,9 @@ class MutableInheritedCompositesTest(_CompositeTestBase, fixtures.MappedTest): Point = cls._type_fixture() - mapper(Foo, foo, properties={ - 'data': composite(Point, foo.c.x, foo.c.y) - }) + mapper( + Foo, foo, properties={"data": composite(Point, foo.c.x, foo.c.y)} + ) mapper(SubFoo, subfoo, inherits=Foo) def test_in_place_mutation_subclass(self): @@ -1342,7 +1389,7 @@ class MutableInheritedCompositesTest(_CompositeTestBase, fixtures.MappedTest): sess.commit() f1.data - assert 'data' in f1.__dict__ + assert "data" in f1.__dict__ sess.close() for loads, dumps in picklers(): diff --git 
a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index 59c1a5161c..c417c7153d 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -13,8 +13,10 @@ metadata = None def step_numbering(step): """ order in whole steps """ + def f(index, collection): return step * index + return f @@ -24,14 +26,17 @@ def fibonacci_numbering(order_col): e.g. 1, 2, 3, 5, 8, ... instead of 0, 1, 1, 2, 3, ... otherwise ordering of the elements at '1' is undefined... ;) """ + def f(index, collection): if index == 0: return 1 elif index == 1: return 2 else: - return (getattr(collection[index - 1], order_col) + - getattr(collection[index - 2], order_col)) + return getattr(collection[index - 1], order_col) + getattr( + collection[index - 2], order_col + ) + return f @@ -39,7 +44,7 @@ def alpha_ordering(index, collection): """ 0 -> A, 1 -> B, ... 25 -> Z, 26 -> AA, 27 -> AB, ... """ - s = '' + s = "" while index > 25: d = index / 26 s += chr((d % 26) + 64) @@ -61,17 +66,24 @@ class OrderingListTest(fixtures.TestBase): global metadata, slides_table, bullets_table, Slide, Bullet - slides_table = Table('test_Slides', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - bullets_table = Table('test_Bullets', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('slide_id', Integer, - ForeignKey('test_Slides.id')), - Column('position', Integer), - Column('text', String(128))) + slides_table = Table( + "test_Slides", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(128)), + ) + bullets_table = Table( + "test_Bullets", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("slide_id", Integer, ForeignKey("test_Slides.id")), + Column("position", Integer), + Column("text", String(128)), + ) class Slide(object): def __init__(self, name): @@ -87,11 +99,19 @@ class OrderingListTest(fixtures.TestBase): def __repr__(self): return '<Bullet "%s" pos %s>' % (self.text, self.position) - mapper(Slide, slides_table, properties={ - 'bullets': relationship(Bullet, lazy='joined', - collection_class=test_collection_class, - backref='slide', - order_by=[bullets_table.c.position])}) + mapper( + Slide, + slides_table, + properties={ + "bullets": relationship( + Bullet, + lazy="joined", + collection_class=test_collection_class, + backref="slide", + order_by=[bullets_table.c.position], + ) + }, + ) mapper(Bullet, bullets_table) metadata.create_all() @@ -100,27 +120,28 @@ class OrderingListTest(fixtures.TestBase): metadata.drop_all() def test_append_no_reorder(self): - self._setup(ordering_list('position', count_from=1, - reorder_on_append=False)) + self._setup( + ordering_list("position", count_from=1, reorder_on_append=False) + ) - s1 = Slide('Slide #1') + s1 = Slide("Slide #1") self.assert_(not s1.bullets) self.assert_(len(s1.bullets) == 0) - s1.bullets.append(Bullet('s1/b1')) + s1.bullets.append(Bullet("s1/b1")) self.assert_(s1.bullets) self.assert_(len(s1.bullets) == 1) self.assert_(s1.bullets[0].position == 1) - s1.bullets.append(Bullet('s1/b2')) + s1.bullets.append(Bullet("s1/b2")) self.assert_(len(s1.bullets) == 2) self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2) - bul = Bullet('s1/b100') + bul = Bullet("s1/b100") bul.position = 100 s1.bullets.append(bul) @@ -128,7 +149,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2)
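These position assertions check the core ordering_list contract: as the list mutates, the collection stamps each child's list index (offset here by count_from=1) onto its "position" attribute, while reorder_on_append=False leaves a preset position such as 100 untouched. For reference, the same wiring in modern declarative form; a sketch, since the fixture above uses classical mappers and its own table names:

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.orderinglist import ordering_list
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Slide(Base):
        __tablename__ = "slide"
        id = Column(Integer, primary_key=True)
        bullets = relationship(
            "Bullet",
            order_by="Bullet.position",
            # keep Bullet.position in sync with list position
            collection_class=ordering_list("position"),
        )

    class Bullet(Base):
        __tablename__ = "bullet"
        id = Column(Integer, primary_key=True)
        slide_id = Column(Integer, ForeignKey("slide.id"))
        position = Column(Integer)

    s = Slide()
    s.bullets.append(Bullet())     # position 0
    s.bullets.append(Bullet())     # position 1
    s.bullets.insert(1, Bullet())  # splices in; later bullets renumber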
self.assert_(s1.bullets[2].position == 100) - s1.bullets.append(Bullet('s1/b4')) + s1.bullets.append(Bullet("s1/b4")) self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2) self.assert_(s1.bullets[2].position == 100) @@ -153,33 +174,34 @@ class OrderingListTest(fixtures.TestBase): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 4) - titles = ['s1/b1', 's1/b2', 's1/b100', 's1/b4'] + titles = ["s1/b1", "s1/b2", "s1/b100", "s1/b4"] found = [b.text for b in srt.bullets] self.assert_(titles == found) def test_append_reorder(self): - self._setup(ordering_list('position', count_from=1, - reorder_on_append=True)) + self._setup( + ordering_list("position", count_from=1, reorder_on_append=True) + ) - s1 = Slide('Slide #1') + s1 = Slide("Slide #1") self.assert_(not s1.bullets) self.assert_(len(s1.bullets) == 0) - s1.bullets.append(Bullet('s1/b1')) + s1.bullets.append(Bullet("s1/b1")) self.assert_(s1.bullets) self.assert_(len(s1.bullets) == 1) self.assert_(s1.bullets[0].position == 1) - s1.bullets.append(Bullet('s1/b2')) + s1.bullets.append(Bullet("s1/b2")) self.assert_(len(s1.bullets) == 2) self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2) - bul = Bullet('s1/b100') + bul = Bullet("s1/b100") bul.position = 100 s1.bullets.append(bul) @@ -187,7 +209,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[1].position == 2) self.assert_(s1.bullets[2].position == 3) - s1.bullets.append(Bullet('s1/b4')) + s1.bullets.append(Bullet("s1/b4")) self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2) self.assert_(s1.bullets[2].position == 3) @@ -199,7 +221,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[2].position == 3) self.assert_(s1.bullets[3].position == 4) - s1.bullets._raw_append(Bullet('raw')) + s1.bullets._raw_append(Bullet("raw")) self.assert_(s1.bullets[4].position is None) s1.bullets._reorder() @@ -217,46 +239,46 @@ class OrderingListTest(fixtures.TestBase): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 5) - titles = ['s1/b1', 's1/b2', 's1/b100', 's1/b4', 'raw'] + titles = ["s1/b1", "s1/b2", "s1/b100", "s1/b4", "raw"] found = [b.text for b in srt.bullets] eq_(titles, found) - srt.bullets._raw_append(Bullet('raw2')) + srt.bullets._raw_append(Bullet("raw2")) srt.bullets[-1].position = 6 session.flush() session.expunge_all() srt = session.query(Slide).get(id) - titles = ['s1/b1', 's1/b2', 's1/b100', 's1/b4', 'raw', 'raw2'] + titles = ["s1/b1", "s1/b2", "s1/b100", "s1/b4", "raw", "raw2"] found = [b.text for b in srt.bullets] eq_(titles, found) def test_insert(self): - self._setup(ordering_list('position')) + self._setup(ordering_list("position")) - s1 = Slide('Slide #1') - s1.bullets.append(Bullet('1')) - s1.bullets.append(Bullet('2')) - s1.bullets.append(Bullet('3')) - s1.bullets.append(Bullet('4')) + s1 = Slide("Slide #1") + s1.bullets.append(Bullet("1")) + s1.bullets.append(Bullet("2")) + s1.bullets.append(Bullet("3")) + s1.bullets.append(Bullet("4")) self.assert_(s1.bullets[0].position == 0) self.assert_(s1.bullets[1].position == 1) self.assert_(s1.bullets[2].position == 2) self.assert_(s1.bullets[3].position == 3) - s1.bullets.insert(2, Bullet('insert_at_2')) + s1.bullets.insert(2, Bullet("insert_at_2")) self.assert_(s1.bullets[0].position == 0) self.assert_(s1.bullets[1].position == 1) self.assert_(s1.bullets[2].position == 2) self.assert_(s1.bullets[3].position == 3) self.assert_(s1.bullets[4].position == 4) - self.assert_(s1.bullets[1].text == 
'2') - self.assert_(s1.bullets[2].text == 'insert_at_2') - self.assert_(s1.bullets[3].text == '3') + self.assert_(s1.bullets[1].text == "2") + self.assert_(s1.bullets[2].text == "insert_at_2") + self.assert_(s1.bullets[3].text == "3") - s1.bullets.insert(999, Bullet('999')) + s1.bullets.insert(999, Bullet("999")) self.assert_(len(s1.bullets) == 6) self.assert_(s1.bullets[5].position == 5) @@ -274,17 +296,23 @@ class OrderingListTest(fixtures.TestBase): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 6) - texts = ['1', '2', 'insert_at_2', '3', '4', '999'] + texts = ["1", "2", "insert_at_2", "3", "4", "999"] found = [b.text for b in srt.bullets] self.assert_(texts == found) def test_slice(self): - self._setup(ordering_list('position')) - - b = [Bullet('1'), Bullet('2'), Bullet('3'), - Bullet('4'), Bullet('5'), Bullet('6')] - s1 = Slide('Slide #1') + self._setup(ordering_list("position")) + + b = [ + Bullet("1"), + Bullet("2"), + Bullet("3"), + Bullet("4"), + Bullet("5"), + Bullet("6"), + ] + s1 = Slide("Slide #1") # 1, 2, 3 s1.bullets[0:3] = b[0:3] @@ -317,16 +345,16 @@ class OrderingListTest(fixtures.TestBase): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 3) - texts = ['1', '6', '3'] + texts = ["1", "6", "3"] for i, text in enumerate(texts): self.assert_(srt.bullets[i].position == i) self.assert_(srt.bullets[i].text == text) def test_replace(self): - self._setup(ordering_list('position')) + self._setup(ordering_list("position")) - s1 = Slide('Slide #1') - s1.bullets = [Bullet('1'), Bullet('2'), Bullet('3')] + s1 = Slide("Slide #1") + s1.bullets = [Bullet("1"), Bullet("2"), Bullet("3")] self.assert_(len(s1.bullets) == 3) self.assert_(s1.bullets[2].position == 2) @@ -335,7 +363,7 @@ class OrderingListTest(fixtures.TestBase): session.add(s1) session.flush() - new_bullet = Bullet('new 2') + new_bullet = Bullet("new 2") self.assert_(new_bullet.position is None) # mark existing bullet as db-deleted before replacement. 
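The custom numbering functions at the top of this file (step_numbering, fibonacci_numbering, alpha_ordering) all satisfy one contract: given a list index and the collection, return the value to stamp onto the position attribute. Condensed from test_funky_ordering below, reusing the file's own step_numbering:

    from sqlalchemy.ext.orderinglist import ordering_list

    def step_numbering(step):
        """ order in whole steps """
        def f(index, collection):
            return step * index
        return f

    class Pos(object):
        def __init__(self):
            self.position = None

    factory = ordering_list("position", ordering_func=step_numbering(2))
    stepped = factory()   # a bare OrderingList, as the test builds it
    stepped.append(Pos())
    stepped.append(Pos())
    stepped.insert(1, Pos())
    # insert() renumbers the whole list, in steps of 2
    assert [p.position for p in stepped] == [0, 2, 4]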
@@ -355,37 +383,32 @@ class OrderingListTest(fixtures.TestBase): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 3) - self.assert_(srt.bullets[1].text == 'new 2') - self.assert_(srt.bullets[2].text == '3') + self.assert_(srt.bullets[1].text == "new 2") + self.assert_(srt.bullets[2].text == "3") def test_replace_two(self): """test #3191""" - self._setup(ordering_list('position', reorder_on_append=True)) + self._setup(ordering_list("position", reorder_on_append=True)) - s1 = Slide('Slide #1') + s1 = Slide("Slide #1") - b1, b2, b3, b4 = Bullet('1'), Bullet('2'), Bullet('3'), Bullet('4') + b1, b2, b3, b4 = Bullet("1"), Bullet("2"), Bullet("3"), Bullet("4") s1.bullets = [b1, b2, b3] - eq_( - [b.position for b in s1.bullets], - [0, 1, 2] - ) + eq_([b.position for b in s1.bullets], [0, 1, 2]) s1.bullets = [b4, b2, b1] - eq_( - [b.position for b in s1.bullets], - [0, 1, 2] - ) + eq_([b.position for b in s1.bullets], [0, 1, 2]) def test_funky_ordering(self): class Pos(object): def __init__(self): self.position = None - step_factory = ordering_list('position', - ordering_func=step_numbering(2)) + step_factory = ordering_list( + "position", ordering_func=step_numbering(2) + ) stepped = step_factory() stepped.append(Pos()) @@ -397,8 +420,8 @@ class OrderingListTest(fixtures.TestBase): self.assert_(stepped[li].position == pos) fib_factory = ordering_list( - 'position', - ordering_func=fibonacci_numbering('position')) + "position", ordering_func=fibonacci_numbering("position") + ) fibbed = fib_factory() fibbed.append(Pos()) @@ -426,8 +449,7 @@ class OrderingListTest(fixtures.TestBase): ): self.assert_(fibbed[li].position == pos) - alpha_factory = ordering_list('position', - ordering_func=alpha_ordering) + alpha_factory = ordering_list("position", ordering_func=alpha_ordering) alpha = alpha_factory() alpha.append(Pos()) alpha.append(Pos()) @@ -435,13 +457,13 @@ class OrderingListTest(fixtures.TestBase): alpha.insert(1, Pos()) - for li, pos in (0, 'A'), (1, 'B'), (2, 'C'), (3, 'D'): + for li, pos in (0, "A"), (1, "B"), (2, "C"), (3, "D"): self.assert_(alpha[li].position == pos) def test_picklability(self): from sqlalchemy.ext.orderinglist import OrderingList - olist = OrderingList('order', reorder_on_append=True) + olist = OrderingList("order", reorder_on_append=True) olist.append(DummyItem()) for loads, dumps in picklers(): diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index 1ea5dfd1de..3173c0234e 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -2,12 +2,30 @@ from sqlalchemy.ext import serializer from sqlalchemy import testing -from sqlalchemy import Integer, String, ForeignKey, select, \ - desc, func, util, MetaData, literal_column, join +from sqlalchemy import ( + Integer, + String, + ForeignKey, + select, + desc, + func, + util, + MetaData, + literal_column, + join, +) from sqlalchemy.testing.schema import Table from sqlalchemy.testing.schema import Column -from sqlalchemy.orm import relationship, sessionmaker, scoped_session, \ - class_mapper, mapper, joinedload, configure_mappers, aliased +from sqlalchemy.orm import ( + relationship, + sessionmaker, + scoped_session, + class_mapper, + mapper, + joinedload, + configure_mappers, + aliased, +) from sqlalchemy.testing import eq_, AssertsCompiledSQL from sqlalchemy.util import u, ue from sqlalchemy.testing import fixtures @@ -15,7 +33,7 @@ from sqlalchemy.testing import fixtures def pickle_protocols(): return iter([-1, 1, 2]) - #return iter([-1, 0, 1, 2]) + # return iter([-1, 0, 1, 2]) 
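Every serializer test below repeats one round trip: serializer.dumps() pickles a query or expression with its tables, mappers and session reduced to symbolic references, and serializer.loads() resolves those references against a MetaData and, for ORM constructs, a Session. A condensed sketch in terms of this module's fixtures (User, users and Session are defined just below):

    from sqlalchemy.ext import serializer

    # protocol -1 is pickle.HIGHEST_PROTOCOL, as used throughout the tests
    payload = serializer.dumps(
        Session.query(User).filter(User.name == "ed"), -1
    )
    # users.metadata resolves Table/Column references; Session re-binds
    # the query so q2.all(), q2.filter(), etc. work like the original
    q2 = serializer.loads(payload, users.metadata, Session)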
class User(fixtures.ComparableEntity): @@ -31,131 +49,193 @@ users = addresses = Session = None class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): global users, addresses - users = Table('users', metadata, Column('id', Integer, - primary_key=True), Column('name', String(50))) - addresses = Table('addresses', metadata, Column('id', Integer, - primary_key=True), Column('email', - String(50)), Column('user_id', Integer, - ForeignKey('users.id'))) + users = Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + addresses = Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("email", String(50)), + Column("user_id", Integer, ForeignKey("users.id")), + ) @classmethod def setup_mappers(cls): global Session Session = scoped_session(sessionmaker()) - mapper(User, users, - properties={'addresses': relationship(Address, backref='user', - order_by=addresses.c.id)}) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", order_by=addresses.c.id + ) + }, + ) mapper(Address, addresses) configure_mappers() @classmethod def insert_data(cls): - params = [dict(list(zip(('id', 'name'), column_values))) - for column_values in [(7, 'jack'), (8, 'ed'), (9, 'fred'), - (10, 'chuck')]] + params = [ + dict(list(zip(("id", "name"), column_values))) + for column_values in [ + (7, "jack"), + (8, "ed"), + (9, "fred"), + (10, "chuck"), + ] + ] users.insert().execute(params) - addresses.insert().execute([dict(list(zip(('id', 'user_id', 'email'), - column_values))) - for column_values in [ - (1, 7, 'jack@bean.com'), - (2, 8, 'ed@wood.com'), - (3, 8, 'ed@bettyboop.com'), - (4, 8, 'ed@lala.com'), - (5, 9, 'fred@fred.com')]]) + addresses.insert().execute( + [ + dict(list(zip(("id", "user_id", "email"), column_values))) + for column_values in [ + (1, 7, "jack@bean.com"), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 8, "ed@lala.com"), + (5, 9, "fred@fred.com"), + ] + ] + ) def test_tables(self): - assert serializer.loads(serializer.dumps(users, -1), - users.metadata, Session) is users + assert ( + serializer.loads( + serializer.dumps(users, -1), users.metadata, Session + ) + is users + ) def test_columns(self): - assert serializer.loads(serializer.dumps(users.c.name, -1), - users.metadata, Session) is users.c.name + assert ( + serializer.loads( + serializer.dumps(users.c.name, -1), users.metadata, Session + ) + is users.c.name + ) def test_mapper(self): user_mapper = class_mapper(User) - assert serializer.loads(serializer.dumps(user_mapper, -1), - None, None) is user_mapper + assert ( + serializer.loads(serializer.dumps(user_mapper, -1), None, None) + is user_mapper + ) def test_attribute(self): - assert serializer.loads(serializer.dumps(User.name, -1), None, - None) is User.name + assert ( + serializer.loads(serializer.dumps(User.name, -1), None, None) + is User.name + ) def test_expression(self): - expr = \ - select([users]).select_from(users.join(addresses)).limit(5) - re_expr = serializer.loads(serializer.dumps(expr, -1), - users.metadata, None) + expr = select([users]).select_from(users.join(addresses)).limit(5) + re_expr = serializer.loads( + serializer.dumps(expr, -1), users.metadata, None + ) eq_(str(expr), str(re_expr)) assert re_expr.bind is testing.db - eq_(re_expr.execute().fetchall(), [(7, 
'jack'), (8, 'ed'), - (8, 'ed'), (8, 'ed'), (9, 'fred')]) + eq_( + re_expr.execute().fetchall(), + [(7, "jack"), (8, "ed"), (8, "ed"), (8, "ed"), (9, "fred")], + ) def test_query_one(self): - q = Session.query(User).filter(User.name == 'ed').\ - options(joinedload(User.addresses)) + q = ( + Session.query(User) + .filter(User.name == "ed") + .options(joinedload(User.addresses)) + ) q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) def go(): - eq_(q2.all(), - [User(name='ed', - addresses=[Address(id=2), Address(id=3), Address(id=4)])] - ) + eq_( + q2.all(), + [ + User( + name="ed", + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + ) + ], + ) self.assert_sql_count(testing.db, go, 1) - eq_(q2.join(User.addresses).filter(Address.email - == 'ed@bettyboop.com').value(func.count(literal_column('*'))), 1) + eq_( + q2.join(User.addresses) + .filter(Address.email == "ed@bettyboop.com") + .value(func.count(literal_column("*"))), + 1, + ) u1 = Session.query(User).get(8) - q = Session.query(Address).filter(Address.user == u1)\ + q = ( + Session.query(Address) + .filter(Address.user == u1) .order_by(desc(Address.email)) - q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, - Session) - eq_(q2.all(), [Address(email='ed@wood.com'), - Address(email='ed@lala.com'), - Address(email='ed@bettyboop.com')]) + ) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) + eq_( + q2.all(), + [ + Address(email="ed@wood.com"), + Address(email="ed@lala.com"), + Address(email="ed@bettyboop.com"), + ], + ) @testing.requires.non_broken_pickle def test_query_two(self): - q = Session.query(User).join(User.addresses).\ - filter(Address.email.like('%fred%')) - q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, - Session) - eq_(q2.all(), [User(name='fred')]) - eq_(list(q2.values(User.id, User.name)), [(9, 'fred')]) + q = ( + Session.query(User) + .join(User.addresses) + .filter(Address.email.like("%fred%")) + ) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) + eq_(q2.all(), [User(name="fred")]) + eq_(list(q2.values(User.id, User.name)), [(9, "fred")]) @testing.requires.non_broken_pickle def test_query_three(self): ua = aliased(User) - q = \ - Session.query(ua).join(ua.addresses).\ - filter(Address.email.like('%fred%')) + q = ( + Session.query(ua) + .join(ua.addresses) + .filter(Address.email.like("%fred%")) + ) for prot in pickle_protocols(): - q2 = serializer.loads(serializer.dumps(q, prot), users.metadata, - Session) - eq_(q2.all(), [User(name='fred')]) + q2 = serializer.loads( + serializer.dumps(q, prot), users.metadata, Session + ) + eq_(q2.all(), [User(name="fred")]) - # try to pull out the aliased entity here... + # try to pull out the aliased entity here... 
ua_2 = q2._entities[0].entity_zero.entity - eq_(list(q2.values(ua_2.id, ua_2.name)), [(9, 'fred')]) + eq_(list(q2.values(ua_2.id, ua_2.name)), [(9, "fred")]) def test_annotated_one(self): j = join(users, addresses)._annotate({"foo": "bar"}) - query = select([addresses]).select_from( - j - ) + query = select([addresses]).select_from(j) str(query) for prot in pickle_protocols(): - pickled_failing = serializer.dumps( - j, prot) + pickled_failing = serializer.dumps(j, prot) serializer.loads(pickled_failing, users.metadata, None) @testing.requires.non_broken_pickle @@ -169,45 +249,53 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): assert j2.right is j.right assert j2._target_adapter._next - @testing.exclude('sqlite', '<=', (3, 5, 9), - 'id comparison failing on the buildbot') + @testing.exclude( + "sqlite", "<=", (3, 5, 9), "id comparison failing on the buildbot" + ) def test_aliases(self): u7, u8, u9, u10 = Session.query(User).order_by(User.id).all() ualias = aliased(User) - q = Session.query(User, ualias)\ - .join(ualias, User.id < ualias.id)\ - .filter(User.id < 9)\ + q = ( + Session.query(User, ualias) + .join(ualias, User.id < ualias.id) + .filter(User.id < 9) .order_by(User.id, ualias.id) - eq_(list(q.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), - (u8, u10)]) - q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, - Session) - eq_(list(q2.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), - (u8, u10)]) + ) + eq_( + list(q.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)] + ) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) + eq_( + list(q2.all()), + [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)], + ) @testing.requires.non_broken_pickle def test_any(self): - r = User.addresses.any(Address.email == 'x') + r = User.addresses.any(Address.email == "x") ser = serializer.dumps(r, -1) x = serializer.loads(ser, users.metadata) eq_(str(r), str(x)) def test_unicode(self): m = MetaData() - t = Table(ue('\u6e2c\u8a66'), m, - Column(ue('\u6e2c\u8a66_id'), Integer)) + t = Table( + ue("\u6e2c\u8a66"), m, Column(ue("\u6e2c\u8a66_id"), Integer) + ) - expr = select([t]).where(t.c[ue('\u6e2c\u8a66_id')] == 5) + expr = select([t]).where(t.c[ue("\u6e2c\u8a66_id")] == 5) expr2 = serializer.loads(serializer.dumps(expr, -1), m) self.assert_compile( expr2, - ue('SELECT "\u6e2c\u8a66"."\u6e2c\u8a66_id" FROM "\u6e2c\u8a66" ' - 'WHERE "\u6e2c\u8a66"."\u6e2c\u8a66_id" = :\u6e2c\u8a66_id_1'), - dialect="default" + ue( + 'SELECT "\u6e2c\u8a66"."\u6e2c\u8a66_id" FROM "\u6e2c\u8a66" ' + 'WHERE "\u6e2c\u8a66"."\u6e2c\u8a66_id" = :\u6e2c\u8a66_id_1' + ), + dialect="default", ) -if __name__ == '__main__': +if __name__ == "__main__": testing.main() diff --git a/test/orm/_fixtures.py b/test/orm/_fixtures.py index 97ca6d8b8a..57a042447a 100644 --- a/test/orm/_fixtures.py +++ b/test/orm/_fixtures.py @@ -2,8 +2,13 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey from sqlalchemy import util from sqlalchemy.testing.schema import Table from sqlalchemy.testing.schema import Column -from sqlalchemy.orm import attributes, mapper, relationship, \ - backref, configure_mappers +from sqlalchemy.orm import ( + attributes, + mapper, + relationship, + backref, + configure_mappers, +) from sqlalchemy.testing import fixtures __all__ = () @@ -14,11 +19,11 @@ class FixtureTest(fixtures.MappedTest): """ - run_define_tables = 'once' - run_setup_classes = 'once' - run_setup_mappers = 'each' - run_inserts = 'each' - run_deletes = 'each' + run_define_tables = "once" 
+ run_setup_classes = "once" + run_setup_mappers = "each" + run_inserts = "each" + run_deletes = "each" @classmethod def setup_classes(cls): @@ -51,50 +56,91 @@ class FixtureTest(fixtures.MappedTest): @classmethod def _setup_stock_mapping(cls): - Node, composite_pk_table, users, Keyword, items, Dingaling, \ - order_items, item_keywords, Item, User, dingalings, \ - Address, keywords, CompositePk, nodes, Order, orders, \ - addresses = cls.classes.Node, \ - cls.tables.composite_pk_table, cls.tables.users, \ - cls.classes.Keyword, cls.tables.items, \ - cls.classes.Dingaling, cls.tables.order_items, \ - cls.tables.item_keywords, cls.classes.Item, \ - cls.classes.User, cls.tables.dingalings, \ - cls.classes.Address, cls.tables.keywords, \ - cls.classes.CompositePk, cls.tables.nodes, \ - cls.classes.Order, cls.tables.orders, cls.tables.addresses + Node, composite_pk_table, users, Keyword, items, Dingaling, order_items, item_keywords, Item, User, dingalings, Address, keywords, CompositePk, nodes, Order, orders, addresses = ( + cls.classes.Node, + cls.tables.composite_pk_table, + cls.tables.users, + cls.classes.Keyword, + cls.tables.items, + cls.classes.Dingaling, + cls.tables.order_items, + cls.tables.item_keywords, + cls.classes.Item, + cls.classes.User, + cls.tables.dingalings, + cls.classes.Address, + cls.tables.keywords, + cls.classes.CompositePk, + cls.tables.nodes, + cls.classes.Order, + cls.tables.orders, + cls.tables.addresses, + ) # use OrderedDict on this one to support some tests that # assert the order of attributes (e.g. orm/test_inspect) - mapper(User, users, properties=util.OrderedDict( - [('addresses', relationship(Address, backref='user', - order_by=addresses.c.id)), - ('orders', relationship(Order, backref='user', - order_by=orders.c.id)), # o2m, m2o - ] - )) - mapper(Address, addresses, properties={ - # o2o - 'dingaling': relationship(Dingaling, uselist=False, - backref="address") - }) + mapper( + User, + users, + properties=util.OrderedDict( + [ + ( + "addresses", + relationship( + Address, backref="user", order_by=addresses.c.id + ), + ), + ( + "orders", + relationship( + Order, backref="user", order_by=orders.c.id + ), + ), # o2m, m2o + ] + ), + ) + mapper( + Address, + addresses, + properties={ + # o2o + "dingaling": relationship( + Dingaling, uselist=False, backref="address" + ) + }, + ) mapper(Dingaling, dingalings) - mapper(Order, orders, properties={ - # m2m - 'items': relationship(Item, secondary=order_items, - order_by=items.c.id), - 'address': relationship(Address), # m2o - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords) # m2m - }) + mapper( + Order, + orders, + properties={ + # m2m + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ), + "address": relationship(Address), # m2o + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords + ) # m2m + }, + ) mapper(Keyword, keywords) - mapper(Node, nodes, properties={ - 'children': relationship(Node, - backref=backref('parent', - remote_side=[nodes.c.id])) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, backref=backref("parent", remote_side=[nodes.c.id]) + ) + }, + ) mapper(CompositePk, composite_pk_table) @@ -102,81 +148,121 @@ class FixtureTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), 
nullable=False), - test_needs_acid=True, - test_needs_fk=True) - - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False), - test_needs_acid=True, - test_needs_fk=True) - - Table('email_bounces', metadata, - Column('id', Integer, ForeignKey('addresses.id')), - Column('bounces', Integer)) - - Table('orders', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('address_id', None, ForeignKey('addresses.id')), - Column('description', String(30)), - Column('isopen', Integer), - test_needs_acid=True, - test_needs_fk=True) - - Table("dingalings", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', None, ForeignKey('addresses.id')), - Column('data', String(30)), - test_needs_acid=True, - test_needs_fk=True) - - Table('items', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('description', String(30), nullable=False), - test_needs_acid=True, - test_needs_fk=True) - - Table('order_items', metadata, - Column('item_id', None, ForeignKey('items.id')), - Column('order_id', None, ForeignKey('orders.id')), - test_needs_acid=True, - test_needs_fk=True) - - Table('keywords', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False), - test_needs_acid=True, - test_needs_fk=True) - - Table('item_keywords', metadata, - Column('item_id', None, ForeignKey('items.id')), - Column('keyword_id', None, ForeignKey('keywords.id')), - test_needs_acid=True, - test_needs_fk=True) - - Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30)), - test_needs_acid=True, - test_needs_fk=True) - - Table('composite_pk_table', metadata, - Column('i', Integer, primary_key=True), - Column('j', Integer, primary_key=True), - Column('k', Integer, nullable=False)) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "email_bounces", + metadata, + Column("id", Integer, ForeignKey("addresses.id")), + Column("bounces", Integer), + ) + + Table( + "orders", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("address_id", None, ForeignKey("addresses.id")), + Column("description", String(30)), + Column("isopen", Integer), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", None, ForeignKey("addresses.id")), + Column("data", String(30)), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "items", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("description", 
String(30), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "order_items", + metadata, + Column("item_id", None, ForeignKey("items.id")), + Column("order_id", None, ForeignKey("orders.id")), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "keywords", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "item_keywords", + metadata, + Column("item_id", None, ForeignKey("items.id")), + Column("keyword_id", None, ForeignKey("keywords.id")), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "composite_pk_table", + metadata, + Column("i", Integer, primary_key=True), + Column("j", Integer, primary_key=True), + Column("k", Integer, nullable=False), + ) @classmethod def setup_mappers(cls): @@ -186,88 +272,76 @@ class FixtureTest(fixtures.MappedTest): def fixtures(cls): return dict( users=( - ('id', 'name'), - (7, 'jack'), - (8, 'ed'), - (9, 'fred'), - (10, 'chuck') + ("id", "name"), + (7, "jack"), + (8, "ed"), + (9, "fred"), + (10, "chuck"), ), - addresses=( - ('id', 'user_id', 'email_address'), + ("id", "user_id", "email_address"), (1, 7, "jack@bean.com"), (2, 8, "ed@wood.com"), (3, 8, "ed@bettyboop.com"), (4, 8, "ed@lala.com"), - (5, 9, "fred@fred.com") + (5, 9, "fred@fred.com"), ), - email_bounces=( - ('id', 'bounces'), + ("id", "bounces"), (1, 1), (2, 0), (3, 5), (4, 0), - (5, 0) + (5, 0), ), - orders=( - ('id', 'user_id', 'description', 'isopen', 'address_id'), - (1, 7, 'order 1', 0, 1), - (2, 9, 'order 2', 0, 4), - (3, 7, 'order 3', 1, 1), - (4, 9, 'order 4', 1, 4), - (5, 7, 'order 5', 0, None) + ("id", "user_id", "description", "isopen", "address_id"), + (1, 7, "order 1", 0, 1), + (2, 9, "order 2", 0, 4), + (3, 7, "order 3", 1, 1), + (4, 9, "order 4", 1, 4), + (5, 7, "order 5", 0, None), ), - dingalings=( - ('id', 'address_id', 'data'), - (1, 2, 'ding 1/2'), - (2, 5, 'ding 2/5') + ("id", "address_id", "data"), + (1, 2, "ding 1/2"), + (2, 5, "ding 2/5"), ), - items=( - ('id', 'description'), - (1, 'item 1'), - (2, 'item 2'), - (3, 'item 3'), - (4, 'item 4'), - (5, 'item 5') + ("id", "description"), + (1, "item 1"), + (2, "item 2"), + (3, "item 3"), + (4, "item 4"), + (5, "item 5"), ), - order_items=( - ('item_id', 'order_id'), + ("item_id", "order_id"), (1, 1), (2, 1), (3, 1), - (1, 2), (2, 2), (3, 2), - (3, 3), (4, 3), (5, 3), - (1, 4), (5, 4), - - (5, 5) + (5, 5), ), - keywords=( - ('id', 'name'), - (1, 'blue'), - (2, 'red'), - (3, 'green'), - (4, 'big'), - (5, 'small'), - (6, 'round'), - (7, 'square') + ("id", "name"), + (1, "blue"), + (2, "red"), + (3, "green"), + (4, "big"), + (5, "small"), + (6, "round"), + (7, "square"), ), - item_keywords=( - ('keyword_id', 'item_id'), + ("keyword_id", "item_id"), (2, 1), (2, 2), (4, 1), @@ -276,20 +350,16 @@ class FixtureTest(fixtures.MappedTest): (3, 3), (4, 3), (7, 2), - (6, 3) + (6, 3), ), - - nodes=( - ('id', 'parent_id', 'data'), - ), - + nodes=(("id", "parent_id", "data"),), composite_pk_table=( - ('i', 'j', 'k'), + ("i", "j", "k"), (1, 2, 3), (2, 1, 4), (1, 1, 5), - (2, 2, 6) - ) + (2, 2, 6), + ), ) @util.memoized_property @@ -307,29 +377,25 @@ class CannedResults(object): def 
user_result(self): User = self.test.classes.User - return [ - User(id=7), - User(id=8), - User(id=9), - User(id=10)] + return [User(id=7), User(id=8), User(id=9), User(id=10)] @property def user_address_result(self): User, Address = self.test.classes.User, self.test.classes.Address return [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[])] + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ] @property def address_user_result(self): @@ -338,133 +404,176 @@ class CannedResults(object): u8 = User(id=8) u9 = User(id=9) return [ - Address(id=1, email_address='jack@bean.com', user=u7), - Address(id=2, email_address='ed@wood.com', user=u8), - Address(id=3, email_address='ed@bettyboop.com', user=u8), - Address(id=4, email_address='ed@lala.com', user=u8), - Address(id=5, user=u9) + Address(id=1, email_address="jack@bean.com", user=u7), + Address(id=2, email_address="ed@wood.com", user=u8), + Address(id=3, email_address="ed@bettyboop.com", user=u8), + Address(id=4, email_address="ed@lala.com", user=u8), + Address(id=5, user=u9), ] @property def user_all_result(self): - User, Address, Order, Item = self.test.classes.User, \ - self.test.classes.Address, self.test.classes.Order, \ - self.test.classes.Item + User, Address, Order, Item = ( + self.test.classes.User, + self.test.classes.Address, + self.test.classes.Order, + self.test.classes.Item, + ) return [ - User(id=7, - addresses=[Address(id=1)], - orders=[ - Order(description='order 1', - items=[ - Item(description='item 1'), - Item(description='item 2'), - Item(description='item 3')]), - Order(description='order 3'), - Order(description='order 5')]), - User(id=8, - addresses=[Address(id=2), Address(id=3), Address(id=4)]), - User(id=9, - addresses=[ - Address(id=5)], - orders=[ - Order(description='order 2', - items=[ - Item(description='item 1'), - Item(description='item 2'), - Item(description='item 3')]), - Order(description='order 4', - items=[ - Item(description='item 1'), - Item(description='item 5')])]), - User(id=10, addresses=[])] + User( + id=7, + addresses=[Address(id=1)], + orders=[ + Order( + description="order 1", + items=[ + Item(description="item 1"), + Item(description="item 2"), + Item(description="item 3"), + ], + ), + Order(description="order 3"), + Order(description="order 5"), + ], + ), + User( + id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)] + ), + User( + id=9, + addresses=[Address(id=5)], + orders=[ + Order( + description="order 2", + items=[ + Item(description="item 1"), + Item(description="item 2"), + Item(description="item 3"), + ], + ), + Order( + description="order 4", + items=[ + Item(description="item 1"), + Item(description="item 5"), + ], + ), + ], + ), + User(id=10, addresses=[]), + ] @property def user_order_result(self): - User, Order, Item = self.test.classes.User, \ - self.test.classes.Order, self.test.classes.Item + User, Order, Item = ( + self.test.classes.User, + self.test.classes.Order, + self.test.classes.Item, + ) return [ - User(id=7, - orders=[ - Order(id=1, - items=[Item(id=1), Item(id=2), Item(id=3)]), - 
Order(id=3, - items=[Item(id=3), Item(id=4), Item(id=5)]), - Order(id=5, - items=[Item(id=5)])]), - User(id=8, - orders=[]), - User(id=9, - orders=[ - Order(id=2, - items=[Item(id=1), Item(id=2), Item(id=3)]), - Order(id=4, - items=[Item(id=1), Item(id=5)])]), - User(id=10)] + User( + id=7, + orders=[ + Order(id=1, items=[Item(id=1), Item(id=2), Item(id=3)]), + Order(id=3, items=[Item(id=3), Item(id=4), Item(id=5)]), + Order(id=5, items=[Item(id=5)]), + ], + ), + User(id=8, orders=[]), + User( + id=9, + orders=[ + Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)]), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ], + ), + User(id=10), + ] @property def item_keyword_result(self): Item, Keyword = self.test.classes.Item, self.test.classes.Keyword return [ - Item(id=1, - keywords=[ - Keyword(name='red'), - Keyword(name='big'), - Keyword(name='round')]), - Item(id=2, - keywords=[ - Keyword(name='red'), - Keyword(name='small'), - Keyword(name='square')]), - Item(id=3, - keywords=[ - Keyword(name='green'), - Keyword(name='big'), - Keyword(name='round')]), - Item(id=4, - keywords=[]), - Item(id=5, - keywords=[])] + Item( + id=1, + keywords=[ + Keyword(name="red"), + Keyword(name="big"), + Keyword(name="round"), + ], + ), + Item( + id=2, + keywords=[ + Keyword(name="red"), + Keyword(name="small"), + Keyword(name="square"), + ], + ), + Item( + id=3, + keywords=[ + Keyword(name="green"), + Keyword(name="big"), + Keyword(name="round"), + ], + ), + Item(id=4, keywords=[]), + Item(id=5, keywords=[]), + ] @property def user_item_keyword_result(self): Item, Keyword = self.test.classes.Item, self.test.classes.Keyword User, Order = self.test.classes.User, self.test.classes.Order - item1, item2, item3, item4, item5 = \ - Item(id=1, - keywords=[ - Keyword(name='red'), - Keyword(name='big'), - Keyword(name='round')]),\ - Item(id=2, - keywords=[ - Keyword(name='red'), - Keyword(name='small'), - Keyword(name='square')]),\ - Item(id=3, - keywords=[ - Keyword(name='green'), - Keyword(name='big'), - Keyword(name='round')]),\ - Item(id=4, keywords=[]),\ - Item(id=5, keywords=[]) + item1, item2, item3, item4, item5 = ( + Item( + id=1, + keywords=[ + Keyword(name="red"), + Keyword(name="big"), + Keyword(name="round"), + ], + ), + Item( + id=2, + keywords=[ + Keyword(name="red"), + Keyword(name="small"), + Keyword(name="square"), + ], + ), + Item( + id=3, + keywords=[ + Keyword(name="green"), + Keyword(name="big"), + Keyword(name="round"), + ], + ), + Item(id=4, keywords=[]), + Item(id=5, keywords=[]), + ) user_result = [ - User(id=7, - orders=[ - Order(id=1, - items=[item1, item2, item3]), - Order(id=3, - items=[item3, item4, item5]), - Order(id=5, - items=[item5])]), + User( + id=7, + orders=[ + Order(id=1, items=[item1, item2, item3]), + Order(id=3, items=[item3, item4, item5]), + Order(id=5, items=[item5]), + ], + ), User(id=8, orders=[]), - User(id=9, - orders=[ - Order(id=2, - items=[item1, item2, item3]), - Order(id=4, - items=[item1, item5])]), - User(id=10, orders=[])] + User( + id=9, + orders=[ + Order(id=2, items=[item1, item2, item3]), + Order(id=4, items=[item1, item5]), + ], + ), + User(id=10, orders=[]), + ] return user_result diff --git a/test/orm/inheritance/_poly_fixtures.py b/test/orm/inheritance/_poly_fixtures.py index f1f9cd6f36..668ea61602 100644 --- a/test/orm/inheritance/_poly_fixtures.py +++ b/test/orm/inheritance/_poly_fixtures.py @@ -1,6 +1,10 @@ from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.orm import relationship, mapper, \ - create_session, polymorphic_union 
+from sqlalchemy.orm import ( + relationship, + mapper, + create_session, + polymorphic_union, +) from sqlalchemy.testing import AssertsCompiledSQL, fixtures from sqlalchemy.testing.schema import Table, Column @@ -44,8 +48,8 @@ class Page(fixtures.ComparableEntity): class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): - run_inserts = 'once' - run_setup_mappers = 'once' + run_inserts = "once" + run_setup_mappers = "once" run_deletes = None @classmethod @@ -53,57 +57,96 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): global people, engineers, managers, boss global companies, paperwork, machines - companies = Table('companies', metadata, - Column('company_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - people = Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('company_id', Integer, - ForeignKey('companies.company_id')), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('engineer_name', String(50)), - Column('primary_language', String(50))) - - machines = Table('machines', metadata, - Column('machine_id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('engineer_id', Integer, - ForeignKey('engineers.person_id'))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50))) - - boss = Table('boss', metadata, - Column('boss_id', Integer, - ForeignKey('managers.person_id'), - primary_key=True), - Column('golf_swing', String(30))) - - paperwork = Table('paperwork', metadata, - Column('paperwork_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('description', String(50)), - Column('person_id', Integer, - ForeignKey('people.person_id'))) + companies = Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_id", Integer, ForeignKey("companies.company_id")), + Column("name", String(50)), + Column("type", String(30)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("engineer_name", String(50)), + Column("primary_language", String(50)), + ) + + machines = Table( + "machines", + metadata, + Column( + "machine_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("engineer_id", Integer, ForeignKey("engineers.person_id")), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("manager_name", String(50)), + ) + + boss = Table( + "boss", + metadata, + Column( + "boss_id", + Integer, + ForeignKey("managers.person_id"), + primary_key=True, + ), + Column("golf_swing", String(30)), + ) + + paperwork = Table( + "paperwork", + metadata, + Column( + "paperwork_id", + 
Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("description", String(50)), + Column("person_id", Integer, ForeignKey("people.person_id")), + ) @classmethod def insert_data(cls): @@ -115,10 +158,10 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): status="regular engineer", paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - machines=[ - Machine(name='IBM ThinkPad'), - Machine(name='IPhone')]) + Paperwork(description="tps report #2"), + ], + machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")], + ) cls.e2 = e2 = Engineer( name="wally", @@ -127,15 +170,18 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): status="regular engineer", paperwork=[ Paperwork(description="tps report #3"), - Paperwork(description="tps report #4")], - machines=[Machine(name="Commodore 64")]) + Paperwork(description="tps report #4"), + ], + machines=[Machine(name="Commodore 64")], + ) cls.b1 = b1 = Boss( name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss", - paperwork=[Paperwork(description="review #1")]) + paperwork=[Paperwork(description="review #1")], + ) cls.m1 = m1 = Manager( name="dogbert", @@ -143,18 +189,18 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): status="regular manager", paperwork=[ Paperwork(description="review #2"), - Paperwork(description="review #3")]) + Paperwork(description="review #3"), + ], + ) cls.e3 = e3 = Engineer( name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer", - paperwork=[ - Paperwork(description='elbonian missive #3')], - machines=[ - Machine(name="Commodore 64"), - Machine(name="IBM 3270")]) + paperwork=[Paperwork(description="elbonian missive #3")], + machines=[Machine(name="Commodore 64"), Machine(name="IBM 3270")], + ) cls.c1 = c1 = Company(name="MegaCorp, Inc.") c1.employees = [e1, e2, b1, m1] @@ -177,9 +223,7 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): Machine(name="IBM ThinkPad"), Machine(name="IPhone"), ] - fixture[0].employees[1].machines = [ - Machine(name="Commodore 64") - ] + fixture[0].employees[1].machines = [Machine(name="Commodore 64")] return fixture def _company_with_emps_fixture(self): @@ -191,23 +235,27 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): name="dilbert", engineer_name="dilbert", primary_language="java", - status="regular engineer" + status="regular engineer", ), Engineer( name="wally", engineer_name="wally", primary_language="c++", - status="regular engineer"), + status="regular engineer", + ), Boss( name="pointy haired boss", golf_swing="fore", manager_name="pointy", - status="da boss"), + status="da boss", + ), Manager( name="dogbert", manager_name="dogbert", - status="regular manager"), - ]), + status="regular manager", + ), + ], + ), Company( name="Elbonia, Inc.", employees=[ @@ -215,8 +263,10 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): name="vlad", engineer_name="vlad", primary_language="cobol", - status="elbonian engineer") - ]) + status="elbonian engineer", + ) + ], + ), ] def _emps_wo_relationships_fixture(self): @@ -225,66 +275,83 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): name="dilbert", engineer_name="dilbert", primary_language="java", - status="regular engineer"), + status="regular engineer", + ), Engineer( name="wally", engineer_name="wally", primary_language="c++", - status="regular 
engineer"), + status="regular engineer", + ), Boss( name="pointy haired boss", golf_swing="fore", manager_name="pointy", - status="da boss"), + status="da boss", + ), Manager( name="dogbert", manager_name="dogbert", - status="regular manager"), + status="regular manager", + ), Engineer( name="vlad", engineer_name="vlad", primary_language="cobol", - status="elbonian engineer") + status="elbonian engineer", + ), ] @classmethod def setup_mappers(cls): - mapper(Company, companies, - properties={ - 'employees': relationship( - Person, - order_by=people.c.person_id)}) + mapper( + Company, + companies, + properties={ + "employees": relationship(Person, order_by=people.c.person_id) + }, + ) mapper(Machine, machines) - person_with_polymorphic,\ - manager_with_polymorphic = cls._get_polymorphics() - - mapper(Person, people, - with_polymorphic=person_with_polymorphic, - polymorphic_on=people.c.type, - polymorphic_identity='person', - properties={ - 'paperwork': relationship( - Paperwork, - order_by=paperwork.c.paperwork_id)}) - - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer', - properties={ - 'machines': relationship( - Machine, - order_by=machines.c.machine_id)}) - - mapper(Manager, managers, - with_polymorphic=manager_with_polymorphic, - inherits=Person, - polymorphic_identity='manager') - - mapper(Boss, boss, - inherits=Manager, - polymorphic_identity='boss') + person_with_polymorphic, manager_with_polymorphic = ( + cls._get_polymorphics() + ) + + mapper( + Person, + people, + with_polymorphic=person_with_polymorphic, + polymorphic_on=people.c.type, + polymorphic_identity="person", + properties={ + "paperwork": relationship( + Paperwork, order_by=paperwork.c.paperwork_id + ) + }, + ) + + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "machines": relationship( + Machine, order_by=machines.c.machine_id + ) + }, + ) + + mapper( + Manager, + managers, + with_polymorphic=manager_with_polymorphic, + inherits=Person, + polymorphic_identity="manager", + ) + + mapper(Boss, boss, inherits=Manager, polymorphic_identity="boss") mapper(Paperwork, paperwork) @@ -302,7 +369,7 @@ class _PolymorphicPolymorphic(_PolymorphicFixtureBase): @classmethod def _get_polymorphics(cls): - return '*', '*' + return "*", "*" class _PolymorphicUnions(_PolymorphicFixtureBase): @@ -310,19 +377,24 @@ class _PolymorphicUnions(_PolymorphicFixtureBase): @classmethod def _get_polymorphics(cls): - people, engineers, managers, boss = \ - cls.tables.people, cls.tables.engineers, \ - cls.tables.managers, cls.tables.boss - person_join = polymorphic_union({ - 'engineer': people.join(engineers), - 'manager': people.join(managers)}, - None, 'pjoin') + people, engineers, managers, boss = ( + cls.tables.people, + cls.tables.engineers, + cls.tables.managers, + cls.tables.boss, + ) + person_join = polymorphic_union( + { + "engineer": people.join(engineers), + "manager": people.join(managers), + }, + None, + "pjoin", + ) manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ( - [Person, Manager, Engineer], person_join) - manager_with_polymorphic = ('*', manager_join) - return person_with_polymorphic,\ - manager_with_polymorphic + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ("*", manager_join) + return person_with_polymorphic, manager_with_polymorphic class _PolymorphicAliasedJoins(_PolymorphicFixtureBase): @@ -330,24 +402,27 @@ class 
_PolymorphicAliasedJoins(_PolymorphicFixtureBase): @classmethod def _get_polymorphics(cls): - people, engineers, managers, boss = \ - cls.tables.people, cls.tables.engineers, \ - cls.tables.managers, cls.tables.boss - person_join = people \ - .outerjoin(engineers) \ - .outerjoin(managers) \ - .select(use_labels=True) \ - .alias('pjoin') - manager_join = people \ - .join(managers) \ - .outerjoin(boss) \ - .select(use_labels=True) \ - .alias('mjoin') - person_with_polymorphic = ( - [Person, Manager, Engineer], person_join) - manager_with_polymorphic = ('*', manager_join) - return person_with_polymorphic,\ - manager_with_polymorphic + people, engineers, managers, boss = ( + cls.tables.people, + cls.tables.engineers, + cls.tables.managers, + cls.tables.boss, + ) + person_join = ( + people.outerjoin(engineers) + .outerjoin(managers) + .select(use_labels=True) + .alias("pjoin") + ) + manager_join = ( + people.join(managers) + .outerjoin(boss) + .select(use_labels=True) + .alias("mjoin") + ) + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ("*", manager_join) + return person_with_polymorphic, manager_with_polymorphic class _PolymorphicJoins(_PolymorphicFixtureBase): @@ -355,16 +430,17 @@ class _PolymorphicJoins(_PolymorphicFixtureBase): @classmethod def _get_polymorphics(cls): - people, engineers, managers, boss = \ - cls.tables.people, cls.tables.engineers, \ - cls.tables.managers, cls.tables.boss + people, engineers, managers, boss = ( + cls.tables.people, + cls.tables.engineers, + cls.tables.managers, + cls.tables.boss, + ) person_join = people.outerjoin(engineers).outerjoin(managers) manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ( - [Person, Manager, Engineer], person_join) - manager_with_polymorphic = ('*', manager_join) - return person_with_polymorphic,\ - manager_with_polymorphic + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ("*", manager_join) + return person_with_polymorphic, manager_with_polymorphic class GeometryFixtureBase(fixtures.DeclarativeMappedTest): @@ -432,10 +508,10 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): """ - run_create_tables = 'each' - run_define_tables = 'each' - run_setup_classes = 'each' - run_setup_mappers = 'each' + run_create_tables = "each" + run_define_tables = "each" + run_setup_classes = "each" + run_setup_mappers = "each" def _fixture_from_geometry(self, geometry, base=None): if not base: @@ -453,34 +529,30 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): "type": type_, "__mapper_args__": { "polymorphic_on": type_, - "polymorphic_identity": key - } - + "polymorphic_identity": key, + }, } else: - items = { - "__mapper_args__": { - "polymorphic_identity": key - } - } + items = {"__mapper_args__": {"polymorphic_identity": key}} if not value.get("single", False): items["__tablename__"] = key items["id"] = Column( ForeignKey("%s.id" % base.__tablename__), - primary_key=True) + primary_key=True, + ) items["%s_data" % key] = Column(String(50)) # add other mapper options to be transferred here as needed. 
- for mapper_opt in ("polymorphic_load", ): + for mapper_opt in ("polymorphic_load",): if mapper_opt in value: items["__mapper_args__"][mapper_opt] = value[mapper_opt] if is_base: - klass = type(key, (fixtures.ComparableEntity, base, ), items) + klass = type(key, (fixtures.ComparableEntity, base), items) else: - klass = type(key, (base, ), items) + klass = type(key, (base,), items) if "subclasses" in value: self._fixture_from_geometry(value["subclasses"], klass) @@ -488,4 +560,3 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): if is_base and self.metadata.tables and self.run_create_tables: self.tables.update(self.metadata.tables) self.metadata.create_all(config.db) - diff --git a/test/orm/inheritance/test_abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py index fb62acb31f..55ae264968 100644 --- a/test/orm/inheritance/test_abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -15,59 +15,112 @@ def produce_test(parent, child, direction): the old "no discriminator column" pattern is used. """ + class ABCTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global ta, tb, tc ta = ["a", metadata] - ta.append(Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)), - ta.append(Column('a_data', String(30))) + ta.append( + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ) + ), + ta.append(Column("a_data", String(30))) if "a" == parent and direction == MANYTOONE: - ta.append(Column('child_id', Integer, ForeignKey( - "%s.id" % child, use_alter=True, name="foo"))) + ta.append( + Column( + "child_id", + Integer, + ForeignKey( + "%s.id" % child, use_alter=True, name="foo" + ), + ) + ) elif "a" == child and direction == ONETOMANY: - ta.append(Column('parent_id', Integer, ForeignKey( - "%s.id" % parent, use_alter=True, name="foo"))) + ta.append( + Column( + "parent_id", + Integer, + ForeignKey( + "%s.id" % parent, use_alter=True, name="foo" + ), + ) + ) ta = Table(*ta) tb = ["b", metadata] - tb.append(Column('id', Integer, ForeignKey( - "a.id"), primary_key=True, )) + tb.append( + Column("id", Integer, ForeignKey("a.id"), primary_key=True) + ) - tb.append(Column('b_data', String(30))) + tb.append(Column("b_data", String(30))) if "b" == parent and direction == MANYTOONE: - tb.append(Column('child_id', Integer, ForeignKey( - "%s.id" % child, use_alter=True, name="foo"))) + tb.append( + Column( + "child_id", + Integer, + ForeignKey( + "%s.id" % child, use_alter=True, name="foo" + ), + ) + ) elif "b" == child and direction == ONETOMANY: - tb.append(Column('parent_id', Integer, ForeignKey( - "%s.id" % parent, use_alter=True, name="foo"))) + tb.append( + Column( + "parent_id", + Integer, + ForeignKey( + "%s.id" % parent, use_alter=True, name="foo" + ), + ) + ) tb = Table(*tb) tc = ["c", metadata] - tc.append(Column('id', Integer, ForeignKey( - "b.id"), primary_key=True, )) + tc.append( + Column("id", Integer, ForeignKey("b.id"), primary_key=True) + ) - tc.append(Column('c_data', String(30))) + tc.append(Column("c_data", String(30))) if "c" == parent and direction == MANYTOONE: - tc.append(Column('child_id', Integer, ForeignKey( - "%s.id" % child, use_alter=True, name="foo"))) + tc.append( + Column( + "child_id", + Integer, + ForeignKey( + "%s.id" % child, use_alter=True, name="foo" + ), + ) + ) elif "c" == child and direction == ONETOMANY: - tc.append(Column('parent_id', Integer, ForeignKey( - "%s.id" % parent, use_alter=True, name="foo"))) + tc.append( + Column( + "parent_id", + Integer, + 
ForeignKey( + "%s.id" % parent, use_alter=True, name="foo" + ), + ) + ) tc = Table(*tc) def teardown(self): if direction == MANYTOONE: parent_table = {"a": ta, "b": tb, "c": tc}[parent] parent_table.update( - values={parent_table.c.child_id: None}).execute() + values={parent_table.c.child_id: None} + ).execute() elif direction == ONETOMANY: child_table = {"a": ta, "b": tb, "c": tc}[child] child_table.update( - values={child_table.c.parent_id: None}).execute() + values={child_table.c.parent_id: None} + ).execute() super(ABCTest, self).teardown() def test_roundtrip(self): @@ -92,19 +145,31 @@ def produce_test(parent, child, direction): remote_side = [child_table.c.id] abcjoin = polymorphic_union( - {"a": ta.select(tb.c.id == None, # noqa - from_obj=[ta.outerjoin(tb, onclause=atob)]), - "b": ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc)\ - .select(tc.c.id == None).reduce_columns(), # noqa - "c": tc.join(tb, onclause=btoc).join(ta, onclause=atob)}, - "type", "abcjoin" + { + "a": ta.select( + tb.c.id == None, # noqa + from_obj=[ta.outerjoin(tb, onclause=atob)], + ), + "b": ta.join(tb, onclause=atob) + .outerjoin(tc, onclause=btoc) + .select(tc.c.id == None) + .reduce_columns(), # noqa + "c": tc.join(tb, onclause=btoc).join(ta, onclause=atob), + }, + "type", + "abcjoin", ) bcjoin = polymorphic_union( - {"b": ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc) - .select(tc.c.id == None).reduce_columns(), # noqa - "c": tc.join(tb, onclause=btoc).join(ta, onclause=atob)}, - "type", "bcjoin" + { + "b": ta.join(tb, onclause=atob) + .outerjoin(tc, onclause=btoc) + .select(tc.c.id == None) + .reduce_columns(), # noqa + "c": tc.join(tb, onclause=btoc).join(ta, onclause=atob), + }, + "type", + "bcjoin", ) class A(object): @@ -117,15 +182,29 @@ def produce_test(parent, child, direction): class C(B): pass - mapper(A, ta, polymorphic_on=abcjoin.c.type, with_polymorphic=( - '*', abcjoin), polymorphic_identity="a") - mapper(B, tb, polymorphic_on=bcjoin.c.type, - with_polymorphic=('*', bcjoin), - polymorphic_identity="b", - inherits=A, - inherit_condition=atob) - mapper(C, tc, polymorphic_identity="c", - inherits=B, inherit_condition=btoc) + mapper( + A, + ta, + polymorphic_on=abcjoin.c.type, + with_polymorphic=("*", abcjoin), + polymorphic_identity="a", + ) + mapper( + B, + tb, + polymorphic_on=bcjoin.c.type, + with_polymorphic=("*", bcjoin), + polymorphic_identity="b", + inherits=A, + inherit_condition=atob, + ) + mapper( + C, + tc, + polymorphic_identity="c", + inherits=B, + inherit_condition=btoc, + ) parent_mapper = class_mapper({ta: A, tb: B, tc: C}[parent_table]) child_mapper = class_mapper({ta: A, tb: B, tc: C}[child_table]) @@ -135,19 +214,23 @@ def produce_test(parent, child, direction): parent_mapper.add_property( "collection", - relationship(child_mapper, - primaryjoin=relationshipjoin, - foreign_keys=foreign_keys, - order_by=child_mapper.c.id, - remote_side=remote_side, uselist=True)) + relationship( + child_mapper, + primaryjoin=relationshipjoin, + foreign_keys=foreign_keys, + order_by=child_mapper.c.id, + remote_side=remote_side, + uselist=True, + ), + ) sess = create_session() - parent_obj = parent_class('parent1') - child_obj = child_class('child1') - somea = A('somea') - someb = B('someb') - somec = C('somec') + parent_obj = parent_class("parent1") + child_obj = child_class("child1") + somea = A("somea") + someb = B("someb") + somec = C("somec") # print "APPENDING", parent.__class__.__name__ , "TO", # child.__class__.__name__ @@ -155,11 +238,11 @@ def produce_test(parent, child, 
direction): sess.add(parent_obj) parent_obj.collection.append(child_obj) if direction == ONETOMANY: - child2 = child_class('child2') + child2 = child_class("child2") parent_obj.collection.append(child2) sess.add(child2) elif direction == MANYTOONE: - parent2 = parent_class('parent2') + parent2 = parent_class("parent2") parent2.collection.append(child_obj) sess.add(parent2) sess.add(somea) @@ -193,7 +276,10 @@ def produce_test(parent, child, direction): assert result2.collection[0].id == child_obj.id ABCTest.__name__ = "Test%sTo%s%s" % ( - parent, child, (direction is ONETOMANY and "O2M" or "M2O")) + parent, + child, + (direction is ONETOMANY and "O2M" or "M2O"), + ) return ABCTest diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 76703bffad..98166c21b5 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -13,18 +13,27 @@ class ABCTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global a, b, c - a = Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('adata', String(30)), - Column('type', String(30)), - ) - b = Table('b', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('bdata', String(30))) - c = Table('c', metadata, - Column('id', Integer, ForeignKey('b.id'), primary_key=True), - Column('cdata', String(30))) + a = Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("adata", String(30)), + Column("type", String(30)), + ) + b = Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("bdata", String(30)), + ) + c = Table( + "c", + metadata, + Column("id", Integer, ForeignKey("b.id"), primary_key=True), + Column("cdata", String(30)), + ) def _make_test(fetchtype): def test_roundtrip(self): @@ -37,25 +46,35 @@ class ABCTest(fixtures.MappedTest): class C(B): pass - if fetchtype == 'union': + if fetchtype == "union": abc = a.outerjoin(b).outerjoin(c) bc = a.join(b).outerjoin(c) else: abc = bc = None - mapper(A, a, with_polymorphic=('*', abc), - polymorphic_on=a.c.type, polymorphic_identity='a') - mapper(B, b, with_polymorphic=('*', bc), - inherits=A, polymorphic_identity='b') - mapper(C, c, inherits=B, polymorphic_identity='c') + mapper( + A, + a, + with_polymorphic=("*", abc), + polymorphic_on=a.c.type, + polymorphic_identity="a", + ) + mapper( + B, + b, + with_polymorphic=("*", bc), + inherits=A, + polymorphic_identity="b", + ) + mapper(C, c, inherits=B, polymorphic_identity="c") - a1 = A(adata='a1') - b1 = B(bdata='b1', adata='b1') - b2 = B(bdata='b2', adata='b2') - b3 = B(bdata='b3', adata='b3') - c1 = C(cdata='c1', bdata='c1', adata='c1') - c2 = C(cdata='c2', bdata='c2', adata='c2') - c3 = C(cdata='c2', bdata='c2', adata='c2') + a1 = A(adata="a1") + b1 = B(bdata="b1", adata="b1") + b2 = B(bdata="b2", adata="b2") + b3 = B(bdata="b3", adata="b3") + c1 = C(cdata="c1", bdata="c1", adata="c1") + c2 = C(cdata="c2", bdata="c2", adata="c2") + c3 = C(cdata="c2", bdata="c2", adata="c2") sess = create_session() for x in (a1, b1, b2, b3, c1, c2, c3): @@ -65,33 +84,42 @@ class ABCTest(fixtures.MappedTest): # for obj in sess.query(A).all(): # print obj - eq_([A(adata='a1'), - B(bdata='b1', adata='b1'), - B(bdata='b2', adata='b2'), - B(bdata='b3', adata='b3'), - C(cdata='c1', bdata='c1', adata='c1'), - C(cdata='c2', bdata='c2', adata='c2'), - C(cdata='c2', bdata='c2', 
adata='c2')], - sess.query(A).order_by(A.id).all()) + eq_( + [ + A(adata="a1"), + B(bdata="b1", adata="b1"), + B(bdata="b2", adata="b2"), + B(bdata="b3", adata="b3"), + C(cdata="c1", bdata="c1", adata="c1"), + C(cdata="c2", bdata="c2", adata="c2"), + C(cdata="c2", bdata="c2", adata="c2"), + ], + sess.query(A).order_by(A.id).all(), + ) - eq_([ - B(bdata='b1', adata='b1'), - B(bdata='b2', adata='b2'), - B(bdata='b3', adata='b3'), - C(cdata='c1', bdata='c1', adata='c1'), - C(cdata='c2', bdata='c2', adata='c2'), - C(cdata='c2', bdata='c2', adata='c2'), - ], sess.query(B).order_by(A.id).all()) + eq_( + [ + B(bdata="b1", adata="b1"), + B(bdata="b2", adata="b2"), + B(bdata="b3", adata="b3"), + C(cdata="c1", bdata="c1", adata="c1"), + C(cdata="c2", bdata="c2", adata="c2"), + C(cdata="c2", bdata="c2", adata="c2"), + ], + sess.query(B).order_by(A.id).all(), + ) - eq_([ - C(cdata='c1', bdata='c1', adata='c1'), - C(cdata='c2', bdata='c2', adata='c2'), - C(cdata='c2', bdata='c2', adata='c2'), - ], sess.query(C).order_by(A.id).all()) + eq_( + [ + C(cdata="c1", bdata="c1", adata="c1"), + C(cdata="c2", bdata="c2", adata="c2"), + C(cdata="c2", bdata="c2", adata="c2"), + ], + sess.query(C).order_by(A.id).all(), + ) - test_roundtrip = function_named( - test_roundtrip, 'test_%s' % fetchtype) + test_roundtrip = function_named(test_roundtrip, "test_%s" % fetchtype) return test_roundtrip - test_union = _make_test('union') - test_none = _make_test('none') + test_union = _make_test("union") + test_none = _make_test("none") diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index a6b5861ed8..43c667abd7 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -27,27 +27,43 @@ class AttrSettable(object): class RelationshipTest1(fixtures.MappedTest): """test self-referential relationships on polymorphic mappers""" + @classmethod def define_tables(cls, metadata): global people, managers - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', - optional=True), - primary_key=True), - Column('manager_id', Integer, - ForeignKey('managers.person_id', - use_alter=True, name="mpid_fq")), - Column('name', String(50)), - Column('type', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50)) - ) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + Sequence("person_id_seq", optional=True), + primary_key=True, + ), + Column( + "manager_id", + Integer, + ForeignKey( + "managers.person_id", use_alter=True, name="mpid_fq" + ), + ), + Column("name", String(50)), + Column("type", String(30)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("manager_name", String(50)), + ) def teardown(self): people.update(values={people.c.manager_id: None}).execute() @@ -60,21 +76,33 @@ class RelationshipTest1(fixtures.MappedTest): class Manager(Person): pass - mapper(Person, people, properties={ - 'manager': relationship(Manager, primaryjoin=( - people.c.manager_id == - managers.c.person_id), - uselist=False, post_update=True) - }) - mapper(Manager, managers, inherits=Person, - inherit_condition=people.c.person_id == managers.c.person_id) + mapper( + Person, + people, + properties={ + "manager": 
relationship( + Manager, + primaryjoin=(people.c.manager_id == managers.c.person_id), + uselist=False, + post_update=True, + ) + }, + ) + mapper( + Manager, + managers, + inherits=Person, + inherit_condition=people.c.person_id == managers.c.person_id, + ) - eq_(class_mapper(Person).get_property('manager').synchronize_pairs, - [(managers.c.person_id, people.c.manager_id)]) + eq_( + class_mapper(Person).get_property("manager").synchronize_pairs, + [(managers.c.person_id, people.c.manager_id)], + ) session = create_session() - p = Person(name='some person') - m = Manager(name='some manager') + p = Person(name="some person") + m = Manager(name="some manager") p.manager = m session.add(p) session.flush() @@ -92,19 +120,25 @@ class RelationshipTest1(fixtures.MappedTest): pass mapper(Person, people) - mapper(Manager, managers, inherits=Person, - inherit_condition=people.c.person_id == - managers.c.person_id, - properties={ - 'employee': relationship(Person, primaryjoin=( - people.c.manager_id == - managers.c.person_id), - foreign_keys=[people.c.manager_id], - uselist=False, post_update=True)}) + mapper( + Manager, + managers, + inherits=Person, + inherit_condition=people.c.person_id == managers.c.person_id, + properties={ + "employee": relationship( + Person, + primaryjoin=(people.c.manager_id == managers.c.person_id), + foreign_keys=[people.c.manager_id], + uselist=False, + post_update=True, + ) + }, + ) session = create_session() - p = Person(name='some person') - m = Manager(name='some manager') + p = Person(name="some person") + m = Manager(name="some manager") m.employee = p session.add(m) session.flush() @@ -117,28 +151,47 @@ class RelationshipTest1(fixtures.MappedTest): class RelationshipTest2(fixtures.MappedTest): """test self-referential relationships on polymorphic mappers""" + @classmethod def define_tables(cls, metadata): global people, managers, data - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('manager_id', Integer, - ForeignKey('people.person_id')), - Column('status', String(30))) - - data = Table('data', metadata, - Column('person_id', Integer, - ForeignKey('managers.person_id'), - primary_key=True), - Column('data', String(30))) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("manager_id", Integer, ForeignKey("people.person_id")), + Column("status", String(30)), + ) + + data = Table( + "data", + metadata, + Column( + "person_id", + Integer, + ForeignKey("managers.person_id"), + primary_key=True, + ), + Column("data", String(30)), + ) def test_relationshiponsubclass_j1_nodata(self): self._do_test("join1", False) @@ -166,66 +219,91 @@ class RelationshipTest2(fixtures.MappedTest): pass if jointype == "join1": - poly_union = polymorphic_union({ - 'person': people.select(people.c.type == 'person'), - 'manager': join(people, managers, - people.c.person_id == managers.c.person_id) - }, None) + poly_union = polymorphic_union( + { + "person": people.select(people.c.type == "person"), + "manager": join( + people, + 
managers, + people.c.person_id == managers.c.person_id, + ), + }, + None, + ) polymorphic_on = poly_union.c.type elif jointype == "join2": - poly_union = polymorphic_union({ - 'person': people.select(people.c.type == 'person'), - 'manager': managers.join( - people, - people.c.person_id == managers.c.person_id) - }, None) + poly_union = polymorphic_union( + { + "person": people.select(people.c.type == "person"), + "manager": managers.join( + people, people.c.person_id == managers.c.person_id + ), + }, + None, + ) polymorphic_on = poly_union.c.type elif jointype == "join3": poly_union = None polymorphic_on = people.c.type if usedata: + class Data(object): def __init__(self, data): self.data = data + mapper(Data, data) - mapper(Person, people, - with_polymorphic=('*', poly_union), - polymorphic_identity='person', - polymorphic_on=polymorphic_on) + mapper( + Person, + people, + with_polymorphic=("*", poly_union), + polymorphic_identity="person", + polymorphic_on=polymorphic_on, + ) if usedata: - mapper(Manager, managers, - inherits=Person, - inherit_condition=people.c.person_id == - managers.c.person_id, - polymorphic_identity='manager', - properties={ - 'colleague': relationship( - Person, - primaryjoin=managers.c.manager_id == - people.c.person_id, - lazy='select', uselist=False), - 'data': relationship(Data, uselist=False)}) + mapper( + Manager, + managers, + inherits=Person, + inherit_condition=people.c.person_id == managers.c.person_id, + polymorphic_identity="manager", + properties={ + "colleague": relationship( + Person, + primaryjoin=managers.c.manager_id + == people.c.person_id, + lazy="select", + uselist=False, + ), + "data": relationship(Data, uselist=False), + }, + ) else: - mapper(Manager, managers, inherits=Person, - inherit_condition=people.c.person_id == - managers.c.person_id, - polymorphic_identity='manager', - properties={ - 'colleague': relationship( - Person, - primaryjoin=managers.c.manager_id == - people.c.person_id, - lazy='select', uselist=False)}) + mapper( + Manager, + managers, + inherits=Person, + inherit_condition=people.c.person_id == managers.c.person_id, + polymorphic_identity="manager", + properties={ + "colleague": relationship( + Person, + primaryjoin=managers.c.manager_id + == people.c.person_id, + lazy="select", + uselist=False, + ) + }, + ) sess = create_session() - p = Person(name='person1') - m = Manager(name='manager1') + p = Person(name="person1") + m = Manager(name="manager1") m.colleague = p if usedata: - m.data = Data('ms data') + m.data = Data("ms data") sess.add(m) sess.flush() @@ -234,33 +312,52 @@ class RelationshipTest2(fixtures.MappedTest): m = sess.query(Manager).get(m.person_id) assert m.colleague is p if usedata: - assert m.data.data == 'ms data' + assert m.data.data == "ms data" class RelationshipTest3(fixtures.MappedTest): """test self-referential relationships on polymorphic mappers""" + @classmethod def define_tables(cls, metadata): global people, managers, data - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('colleague_id', Integer, - ForeignKey('people.person_id')), - Column('name', String(50)), - Column('type', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30))) - - data = Table('data', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('data', String(30))) + people = Table( + 
"people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("colleague_id", Integer, ForeignKey("people.person_id")), + Column("name", String(50)), + Column("type", String(30)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + ) + + data = Table( + "data", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("data", String(30)), + ) def _generate_test(jointype="join1", usedata=False): @@ -272,24 +369,34 @@ def _generate_test(jointype="join1", usedata=False): pass if usedata: + class Data(object): def __init__(self, data): self.data = data if jointype == "join1": - poly_union = polymorphic_union({ - 'manager': managers.join( - people, - people.c.person_id == managers.c.person_id), - 'person': people.select(people.c.type == 'person') - }, None) + poly_union = polymorphic_union( + { + "manager": managers.join( + people, people.c.person_id == managers.c.person_id + ), + "person": people.select(people.c.type == "person"), + }, + None, + ) elif jointype == "join2": - poly_union = polymorphic_union({ - 'manager': join(people, managers, - people.c.person_id == managers.c.person_id), - 'person': people.select(people.c.type == 'person') - }, None) - elif jointype == 'join3': + poly_union = polymorphic_union( + { + "manager": join( + people, + managers, + people.c.person_id == managers.c.person_id, + ), + "person": people.select(people.c.type == "person"), + }, + None, + ) + elif jointype == "join3": poly_union = people.outerjoin(managers) elif jointype == "join4": poly_union = None @@ -298,45 +405,59 @@ def _generate_test(jointype="join1", usedata=False): mapper(Data, data) if usedata: - mapper(Person, people, - with_polymorphic=('*', poly_union), - polymorphic_identity='person', - polymorphic_on=people.c.type, - properties={ - 'colleagues': relationship( - Person, - primaryjoin=people.c.colleague_id == - people.c.person_id, - remote_side=people.c.colleague_id, - uselist=True), - 'data': relationship(Data, uselist=False)}) + mapper( + Person, + people, + with_polymorphic=("*", poly_union), + polymorphic_identity="person", + polymorphic_on=people.c.type, + properties={ + "colleagues": relationship( + Person, + primaryjoin=people.c.colleague_id + == people.c.person_id, + remote_side=people.c.colleague_id, + uselist=True, + ), + "data": relationship(Data, uselist=False), + }, + ) else: - mapper(Person, people, - with_polymorphic=('*', poly_union), - polymorphic_identity='person', - polymorphic_on=people.c.type, - properties={ - 'colleagues': relationship( - Person, - primaryjoin=people.c.colleague_id == - people.c.person_id, - remote_side=people.c.colleague_id, uselist=True)}) - - mapper(Manager, managers, inherits=Person, - inherit_condition=people.c.person_id == - managers.c.person_id, - polymorphic_identity='manager') + mapper( + Person, + people, + with_polymorphic=("*", poly_union), + polymorphic_identity="person", + polymorphic_on=people.c.type, + properties={ + "colleagues": relationship( + Person, + primaryjoin=people.c.colleague_id + == people.c.person_id, + remote_side=people.c.colleague_id, + uselist=True, + ) + }, + ) + + mapper( + Manager, + managers, + inherits=Person, + inherit_condition=people.c.person_id == managers.c.person_id, + polymorphic_identity="manager", + ) sess = create_session() - p = Person(name='person1') - p2 = 
Person(name='person2') - p3 = Person(name='person3') - m = Manager(name='manager1') + p = Person(name="person1") + p2 = Person(name="person2") + p3 = Person(name="person3") + m = Manager(name="manager1") p.colleagues.append(p2) m.colleagues.append(p3) if usedata: - p.data = Data('ps data') - m.data = Data('ms data') + p.data = Data("ps data") + m.data = Data("ms data") sess.add(m) sess.add(p) @@ -351,12 +472,14 @@ def _generate_test(jointype="join1", usedata=False): assert p.colleagues == [p2] assert m.colleagues == [p3] if usedata: - assert p.data.data == 'ps data' - assert m.data.data == 'ms data' + assert p.data.data == "ps data" + assert m.data.data == "ms data" do_test = function_named( - _do_test, 'test_relationship_on_base_class_%s_%s' % ( - jointype, data and "nodata" or "data")) + _do_test, + "test_relationship_on_base_class_%s_%s" + % (jointype, data and "nodata" or "data"), + ) return do_test @@ -371,27 +494,53 @@ class RelationshipTest4(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global people, engineers, managers, cars - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('longer_status', String(70))) - - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('owner', Integer, ForeignKey('people.person_id'))) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("longer_status", String(70)), + ) + + cars = Table( + "cars", + metadata, + Column( + "car_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("owner", Integer, ForeignKey("people.person_id")), + ) def test_many_to_one_polymorphic(self): """in this test, the polymorphic union is between two subclasses, but @@ -412,13 +561,14 @@ class RelationshipTest4(fixtures.MappedTest): class Engineer(Person): def __repr__(self): - return "Engineer %s, status %s" % \ - (self.name, self.status) + return "Engineer %s, status %s" % (self.name, self.status) class Manager(Person): def __repr__(self): - return "Manager %s, status %s" % \ - (self.name, self.longer_status) + return "Manager %s, status %s" % ( + self.name, + self.longer_status, + ) class Car(object): def __init__(self, **kwargs): @@ -431,40 +581,51 @@ class RelationshipTest4(fixtures.MappedTest): # create a union that represents both types of joins. 
employee_join = polymorphic_union( { - 'engineer': people.join(engineers), - 'manager': people.join(managers), - }, "type", 'employee_join') - - person_mapper = mapper(Person, people, - with_polymorphic=('*', employee_join), - polymorphic_on=employee_join.c.type, - polymorphic_identity='person') - engineer_mapper = mapper(Engineer, engineers, - inherits=person_mapper, - polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, - inherits=person_mapper, - polymorphic_identity='manager') - car_mapper = mapper(Car, cars, - properties={'employee': - relationship(person_mapper)}) + "engineer": people.join(engineers), + "manager": people.join(managers), + }, + "type", + "employee_join", + ) + + person_mapper = mapper( + Person, + people, + with_polymorphic=("*", employee_join), + polymorphic_on=employee_join.c.type, + polymorphic_identity="person", + ) + engineer_mapper = mapper( + Engineer, + engineers, + inherits=person_mapper, + polymorphic_identity="engineer", + ) + manager_mapper = mapper( + Manager, + managers, + inherits=person_mapper, + polymorphic_identity="manager", + ) + car_mapper = mapper( + Car, cars, properties={"employee": relationship(person_mapper)} + ) session = create_session() # creating 5 managers named from M1 to E5 for i in range(1, 5): - session.add(Manager(name="M%d" % i, - longer_status="YYYYYYYYY")) + session.add(Manager(name="M%d" % i, longer_status="YYYYYYYYY")) # creating 5 engineers named from E1 to E5 for i in range(1, 5): session.add(Engineer(name="E%d" % i, status="X")) session.flush() - engineer4 = session.query(Engineer).\ - filter(Engineer.name == "E4").first() - manager3 = session.query(Manager).\ - filter(Manager.name == "M3").first() + engineer4 = ( + session.query(Engineer).filter(Engineer.name == "E4").first() + ) + manager3 = session.query(Manager).filter(Manager.name == "M3").first() car1 = Car(employee=engineer4) session.add(car1) @@ -475,10 +636,13 @@ class RelationshipTest4(fixtures.MappedTest): session.expunge_all() def go(): - testcar = session.query(Car).options( - joinedload('employee') - ).get(car1.car_id) + testcar = ( + session.query(Car) + .options(joinedload("employee")) + .get(car1.car_id) + ) assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testing.db, go, 1) car1 = session.query(Car).get(car1.car_id) @@ -493,10 +657,13 @@ class RelationshipTest4(fixtures.MappedTest): # and now for the lightning round, eager ! 
def go(): - testcar = session.query(Car).options( - joinedload('employee') - ).get(car1.car_id) + testcar = ( + session.query(Car) + .options(joinedload("employee")) + .get(car1.car_id) + ) assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testing.db, go, 1) session.expunge_all() @@ -509,28 +676,54 @@ class RelationshipTest5(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global people, engineers, managers, cars - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('longer_status', String(70))) - - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('owner', Integer, ForeignKey('people.person_id'))) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(50)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("longer_status", String(70)), + ) + + cars = Table( + "cars", + metadata, + Column( + "car_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("owner", Integer, ForeignKey("people.person_id")), + ) def test_eager_empty(self): """test parent object with child relationship to an inheriting mapper, @@ -546,13 +739,14 @@ class RelationshipTest5(fixtures.MappedTest): class Engineer(Person): def __repr__(self): - return "Engineer %s, status %s" % \ - (self.name, self.status) + return "Engineer %s, status %s" % (self.name, self.status) class Manager(Person): def __repr__(self): - return "Manager %s, status %s" % \ - (self.name, self.longer_status) + return "Manager %s, status %s" % ( + self.name, + self.longer_status, + ) class Car(object): def __init__(self, **kwargs): @@ -562,18 +756,31 @@ class RelationshipTest5(fixtures.MappedTest): def __repr__(self): return "Car number %d" % self.car_id - person_mapper = mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') - engineer_mapper = mapper(Engineer, engineers, - inherits=person_mapper, - polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, - inherits=person_mapper, - polymorphic_identity='manager') - car_mapper = mapper(Car, cars, properties={ - 'manager': relationship( - manager_mapper, lazy='joined')}) + person_mapper = mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + engineer_mapper = mapper( + Engineer, + engineers, + inherits=person_mapper, + polymorphic_identity="engineer", + ) + manager_mapper = mapper( + Manager, + managers, + inherits=person_mapper, + polymorphic_identity="manager", + ) + car_mapper = mapper( + Car, + cars, + properties={ + "manager": relationship(manager_mapper, lazy="joined") + }, + ) sess = create_session() car1 = Car() @@ 
-596,20 +803,30 @@ class RelationshipTest6(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global people, managers, data - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - ) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('colleague_id', Integer, - ForeignKey('managers.person_id')), - Column('status', String(30)), - ) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("colleague_id", Integer, ForeignKey("managers.person_id")), + Column("status", String(30)), + ) def test_basic(self): class Person(AttrSettable): @@ -620,19 +837,25 @@ class RelationshipTest6(fixtures.MappedTest): mapper(Person, people) - mapper(Manager, managers, inherits=Person, - inherit_condition=people.c.person_id == - managers.c.person_id, - properties={ - 'colleague': relationship( - Manager, - primaryjoin=managers.c.colleague_id == - managers.c.person_id, - lazy='select', uselist=False)}) + mapper( + Manager, + managers, + inherits=Person, + inherit_condition=people.c.person_id == managers.c.person_id, + properties={ + "colleague": relationship( + Manager, + primaryjoin=managers.c.colleague_id + == managers.c.person_id, + lazy="select", + uselist=False, + ) + }, + ) sess = create_session() - m = Manager(name='manager1') - m2 = Manager(name='manager2') + m = Manager(name="manager1") + m2 = Manager(name="manager2") m.colleague = m2 sess.add(m) sess.flush() @@ -647,34 +870,68 @@ class RelationshipTest7(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global people, engineers, managers, cars, offroad_cars - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30))) - - offroad_cars = Table('offroad_cars', metadata, - Column('car_id', Integer, - ForeignKey('cars.car_id'), - nullable=False, primary_key=True)) - - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('car_id', Integer, ForeignKey('cars.car_id'), - nullable=False), - Column('name', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('field', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('category', String(70))) + cars = Table( + "cars", + metadata, + Column( + "car_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(30)), + ) + + offroad_cars = Table( + "offroad_cars", + metadata, + Column( + "car_id", + Integer, + ForeignKey("cars.car_id"), + nullable=False, + primary_key=True, + ), + ) + + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column( + "car_id", Integer, ForeignKey("cars.car_id"), nullable=False + ), + Column("name", String(50)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + 
primary_key=True, + ), + Column("field", String(30)), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("category", String(70)), + ) def test_manytoone_lazyload(self): """test that lazy load clause to a polymorphic child mapper generates @@ -695,56 +952,76 @@ class RelationshipTest7(fixtures.MappedTest): class Engineer(Person): def __repr__(self): - return "Engineer %s, field %s" % (self.name, - self.field) + return "Engineer %s, field %s" % (self.name, self.field) class Manager(Person): def __repr__(self): - return "Manager %s, category %s" % (self.name, - self.category) + return "Manager %s, category %s" % (self.name, self.category) class Car(PersistentObject): def __repr__(self): - return "Car number %d, name %s" % \ - (self.car_id, self.name) + return "Car number %d, name %s" % (self.car_id, self.name) class Offraod_Car(Car): def __repr__(self): - return "Offroad Car number %d, name %s" % \ - (self.car_id, self.name) + return "Offroad Car number %d, name %s" % ( + self.car_id, + self.name, + ) employee_join = polymorphic_union( { - 'engineer': people.join(engineers), - 'manager': people.join(managers), - }, "type", 'employee_join') + "engineer": people.join(engineers), + "manager": people.join(managers), + }, + "type", + "employee_join", + ) car_join = polymorphic_union( { - 'car': cars.outerjoin(offroad_cars). - select(offroad_cars.c.car_id == None).reduce_columns(), # noqa - 'offroad': cars.join(offroad_cars) - }, "type", 'car_join') - - car_mapper = mapper(Car, cars, - with_polymorphic=('*', car_join), - polymorphic_on=car_join.c.type, - polymorphic_identity='car') - offroad_car_mapper = mapper(Offraod_Car, offroad_cars, - inherits=car_mapper, - polymorphic_identity='offroad') - person_mapper = mapper(Person, people, - with_polymorphic=('*', employee_join), - polymorphic_on=employee_join.c.type, - polymorphic_identity='person', - properties={ - 'car': relationship(car_mapper)}) - engineer_mapper = mapper(Engineer, engineers, - inherits=person_mapper, - polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, - inherits=person_mapper, - polymorphic_identity='manager') + "car": cars.outerjoin(offroad_cars) + .select(offroad_cars.c.car_id == None) + .reduce_columns(), # noqa + "offroad": cars.join(offroad_cars), + }, + "type", + "car_join", + ) + + car_mapper = mapper( + Car, + cars, + with_polymorphic=("*", car_join), + polymorphic_on=car_join.c.type, + polymorphic_identity="car", + ) + offroad_car_mapper = mapper( + Offraod_Car, + offroad_cars, + inherits=car_mapper, + polymorphic_identity="offroad", + ) + person_mapper = mapper( + Person, + people, + with_polymorphic=("*", employee_join), + polymorphic_on=employee_join.c.type, + polymorphic_identity="person", + properties={"car": relationship(car_mapper)}, + ) + engineer_mapper = mapper( + Engineer, + engineers, + inherits=person_mapper, + polymorphic_identity="engineer", + ) + manager_mapper = mapper( + Manager, + managers, + inherits=person_mapper, + polymorphic_identity="manager", + ) session = create_session() basic_car = Car(name="basic") @@ -755,8 +1032,7 @@ class RelationshipTest7(fixtures.MappedTest): car = Car() else: car = Offraod_Car() - session.add(Manager(name="M%d" % i, - category="YYYYYYYYY", car=car)) + session.add(Manager(name="M%d" % i, category="YYYYYYYYY", car=car)) session.add(Engineer(name="E%d" % i, field="X", car=car)) session.flush() session.expunge_all() @@ -770,17 +1046,21 @@ class 
RelationshipTest8(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global taggable, users - taggable = Table('taggable', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(30)), - Column('owner_id', Integer, - ForeignKey('taggable.id')), - ) - users = Table('users', metadata, - Column('id', Integer, ForeignKey('taggable.id'), - primary_key=True), - Column('data', String(50))) + taggable = Table( + "taggable", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(30)), + Column("owner_id", Integer, ForeignKey("taggable.id")), + ) + users = Table( + "users", + metadata, + Column("id", Integer, ForeignKey("taggable.id"), primary_key=True), + Column("data", String(50)), + ) def test_selfref_onjoined(self): class Taggable(fixtures.ComparableEntity): @@ -789,20 +1069,29 @@ class RelationshipTest8(fixtures.MappedTest): class User(Taggable): pass - mapper(Taggable, taggable, - polymorphic_on=taggable.c.type, - polymorphic_identity='taggable', - properties={ - 'owner': relationship( - User, - primaryjoin=taggable.c.owner_id == taggable.c.id, - remote_side=taggable.c.id)}) + mapper( + Taggable, + taggable, + polymorphic_on=taggable.c.type, + polymorphic_identity="taggable", + properties={ + "owner": relationship( + User, + primaryjoin=taggable.c.owner_id == taggable.c.id, + remote_side=taggable.c.id, + ) + }, + ) - mapper(User, users, inherits=Taggable, - polymorphic_identity='user', - inherit_condition=users.c.id == taggable.c.id) + mapper( + User, + users, + inherits=Taggable, + polymorphic_identity="user", + inherit_condition=users.c.id == taggable.c.id, + ) - u1 = User(data='u1') + u1 = User(data="u1") t1 = Taggable(owner=u1) sess = create_session() sess.add(t1) @@ -811,7 +1100,7 @@ class RelationshipTest8(fixtures.MappedTest): sess.expunge_all() eq_( sess.query(Taggable).order_by(Taggable.id).all(), - [User(data='u1'), Taggable(owner=User(data='u1'))] + [User(data="u1"), Taggable(owner=User(data="u1"))], ) @@ -828,46 +1117,82 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): global metadata, status, people, engineers, managers, cars metadata = MetaData(testing.db) # table definitions - status = Table('status', metadata, - Column('status_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(20))) + status = Table( + "status", + metadata, + Column( + "status_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(20)), + ) people = Table( - 'people', metadata, + "people", + metadata, Column( - 'person_id', Integer, primary_key=True, - test_needs_autoincrement=True), + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), Column( - 'status_id', Integer, ForeignKey('status.status_id'), - nullable=False), - Column('name', String(50))) + "status_id", + Integer, + ForeignKey("status.status_id"), + nullable=False, + ), + Column("name", String(50)), + ) engineers = Table( - 'engineers', metadata, + "engineers", + metadata, Column( - 'person_id', Integer, ForeignKey('people.person_id'), - primary_key=True), - Column('field', String(30))) + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("field", String(30)), + ) managers = Table( - 'managers', metadata, + "managers", + metadata, Column( - 'person_id', Integer, ForeignKey('people.person_id'), - primary_key=True), - Column('category', String(70))) + 
"person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("category", String(70)), + ) cars = Table( - 'cars', metadata, + "cars", + metadata, Column( - 'car_id', Integer, primary_key=True, - test_needs_autoincrement=True), + "car_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), Column( - 'status_id', Integer, ForeignKey('status.status_id'), - nullable=False), + "status_id", + Integer, + ForeignKey("status.status_id"), + nullable=False, + ), Column( - 'owner', Integer, ForeignKey('people.person_id'), - nullable=False)) + "owner", + Integer, + ForeignKey("people.person_id"), + nullable=False, + ), + ) metadata.create_all() @@ -898,12 +1223,18 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): class Engineer(Person): def __repr__(self): return "Engineer %s, field %s, status %s" % ( - self.name, self.field, self.status) + self.name, + self.field, + self.status, + ) class Manager(Person): def __repr__(self): return "Manager %s, category %s, status %s" % ( - self.name, self.category, self.status) + self.name, + self.category, + self.status, + ) class Car(PersistentObject): def __repr__(self): @@ -912,24 +1243,42 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): # create a union that represents both types of joins. employee_join = polymorphic_union( { - 'engineer': people.join(engineers), - 'manager': people.join(managers), - }, "type", 'employee_join') + "engineer": people.join(engineers), + "manager": people.join(managers), + }, + "type", + "employee_join", + ) status_mapper = mapper(Status, status) person_mapper = mapper( - Person, people, with_polymorphic=('*', employee_join), - polymorphic_on=employee_join.c.type, polymorphic_identity='person', - properties={'status': relationship(status_mapper)}) - engineer_mapper = mapper(Engineer, engineers, - inherits=person_mapper, - polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, - inherits=person_mapper, - polymorphic_identity='manager') - car_mapper = mapper(Car, cars, properties={ - 'employee': relationship(person_mapper), - 'status': relationship(status_mapper)}) + Person, + people, + with_polymorphic=("*", employee_join), + polymorphic_on=employee_join.c.type, + polymorphic_identity="person", + properties={"status": relationship(status_mapper)}, + ) + engineer_mapper = mapper( + Engineer, + engineers, + inherits=person_mapper, + polymorphic_identity="engineer", + ) + manager_mapper = mapper( + Manager, + managers, + inherits=person_mapper, + polymorphic_identity="manager", + ) + car_mapper = mapper( + Car, + cars, + properties={ + "employee": relationship(person_mapper), + "status": relationship(status_mapper), + }, + ) session = create_session() @@ -951,8 +1300,9 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): st = active else: st = dead - session.add(Manager(name="M%d" % i, - category="YYYYYYYYY", status=st)) + session.add( + Manager(name="M%d" % i, category="YYYYYYYYY", status=st) + ) session.add(Engineer(name="E%d" % i, field="X", status=st)) session.flush() @@ -972,23 +1322,38 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): e = exists([Car.owner], Car.owner == employee_join.c.person_id) Query(Person)._adapt_clause(employee_join, False, False) - r = session.query(Person).filter(Person.name.like('%2')).\ - join('status').\ - filter_by(name="active").\ - order_by(Person.person_id) - eq_(str(list(r)), "[Manager M2, category YYYYYYYYY, status " + r = ( + session.query(Person) + 
.filter(Person.name.like("%2")) + .join("status") + .filter_by(name="active") + .order_by(Person.person_id) + ) + eq_( + str(list(r)), + "[Manager M2, category YYYYYYYYY, status " "Status active, Engineer E2, field X, " - "status Status active]") - r = session.query(Engineer).join('status').\ - filter(Person.name.in_( - ['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & - (status.c.name == "active")).order_by(Person.name) - eq_(str(list(r)), "[Engineer E2, field X, status Status " + "status Status active]", + ) + r = ( + session.query(Engineer) + .join("status") + .filter( + Person.name.in_(["E2", "E3", "E4", "M4", "M2", "M1"]) + & (status.c.name == "active") + ) + .order_by(Person.name) + ) + eq_( + str(list(r)), + "[Engineer E2, field X, status Status " "active, Engineer E3, field X, status " - "Status active]") + "Status active]", + ) - r = session.query(Person).filter(exists([1], - Car.owner == Person.person_id)) + r = session.query(Person).filter( + exists([1], Car.owner == Person.person_id) + ) eq_(str(list(r)), "[Engineer E4, field X, status Status dead]") @@ -996,23 +1361,32 @@ class MultiLevelTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global table_Employee, table_Engineer, table_Manager - table_Employee = Table('Employee', metadata, - Column('name', type_=String(100), ), - Column('id', primary_key=True, type_=Integer, - test_needs_autoincrement=True), - Column('atype', type_=String(100), ), - ) + table_Employee = Table( + "Employee", + metadata, + Column("name", type_=String(100)), + Column( + "id", + primary_key=True, + type_=Integer, + test_needs_autoincrement=True, + ), + Column("atype", type_=String(100)), + ) table_Engineer = Table( - 'Engineer', metadata, Column('machine', type_=String(100),), - Column( - 'id', Integer, ForeignKey('Employee.id',), - primary_key=True),) + "Engineer", + metadata, + Column("machine", type_=String(100)), + Column("id", Integer, ForeignKey("Employee.id"), primary_key=True), + ) - table_Manager = Table('Manager', metadata, - Column('duties', type_=String(100),), - Column('id', Integer, ForeignKey('Engineer.id'), - primary_key=True)) + table_Manager = Table( + "Manager", + metadata, + Column("duties", type_=String(100)), + Column("id", Integer, ForeignKey("Engineer.id"), primary_key=True), + ) def test_threelevels(self): class Employee(object): @@ -1022,7 +1396,8 @@ class MultiLevelTest(fixtures.MappedTest): return me def __str__(me): - return str(me.__class__.__name__) + ':' + str(me.name) + return str(me.__class__.__name__) + ":" + str(me.name) + __repr__ = __str__ class Engineer(Employee): @@ -1031,50 +1406,67 @@ class MultiLevelTest(fixtures.MappedTest): class Manager(Engineer): pass - pu_Employee = polymorphic_union({ - 'Manager': table_Employee.join( - table_Engineer).join(table_Manager), - 'Engineer': select([table_Employee, - table_Engineer.c.machine], - table_Employee.c.atype == 'Engineer', - from_obj=[ - table_Employee.join(table_Engineer)]), - 'Employee': table_Employee.select( - table_Employee.c.atype == 'Employee') - }, None, 'pu_employee') - - mapper_Employee = mapper(Employee, table_Employee, - polymorphic_identity='Employee', - polymorphic_on=pu_Employee.c.atype, - with_polymorphic=('*', pu_Employee), - ) - - pu_Engineer = polymorphic_union({ - 'Manager': table_Employee.join(table_Engineer). 
- join(table_Manager), - 'Engineer': select([table_Employee, - table_Engineer.c.machine], - table_Employee.c.atype == 'Engineer', - from_obj=[ - table_Employee.join(table_Engineer)]) - }, None, 'pu_engineer') - mapper_Engineer = mapper(Engineer, table_Engineer, - inherit_condition=table_Engineer.c.id == - table_Employee.c.id, - inherits=mapper_Employee, - polymorphic_identity='Engineer', - polymorphic_on=pu_Engineer.c.atype, - with_polymorphic=('*', pu_Engineer)) - - mapper_Manager = mapper(Manager, table_Manager, - inherit_condition=table_Manager.c.id == - table_Engineer.c.id, - inherits=mapper_Engineer, - polymorphic_identity='Manager') - - a = Employee().set(name='one') - b = Engineer().set(egn='two', machine='any') - c = Manager().set(name='head', machine='fast', duties='many') + pu_Employee = polymorphic_union( + { + "Manager": table_Employee.join(table_Engineer).join( + table_Manager + ), + "Engineer": select( + [table_Employee, table_Engineer.c.machine], + table_Employee.c.atype == "Engineer", + from_obj=[table_Employee.join(table_Engineer)], + ), + "Employee": table_Employee.select( + table_Employee.c.atype == "Employee" + ), + }, + None, + "pu_employee", + ) + + mapper_Employee = mapper( + Employee, + table_Employee, + polymorphic_identity="Employee", + polymorphic_on=pu_Employee.c.atype, + with_polymorphic=("*", pu_Employee), + ) + + pu_Engineer = polymorphic_union( + { + "Manager": table_Employee.join(table_Engineer).join( + table_Manager + ), + "Engineer": select( + [table_Employee, table_Engineer.c.machine], + table_Employee.c.atype == "Engineer", + from_obj=[table_Employee.join(table_Engineer)], + ), + }, + None, + "pu_engineer", + ) + mapper_Engineer = mapper( + Engineer, + table_Engineer, + inherit_condition=table_Engineer.c.id == table_Employee.c.id, + inherits=mapper_Employee, + polymorphic_identity="Engineer", + polymorphic_on=pu_Engineer.c.atype, + with_polymorphic=("*", pu_Engineer), + ) + + mapper_Manager = mapper( + Manager, + table_Manager, + inherit_condition=table_Manager.c.id == table_Engineer.c.id, + inherits=mapper_Engineer, + polymorphic_identity="Manager", + ) + + a = Employee().set(name="one") + b = Engineer().set(egn="two", machine="any") + c = Manager().set(name="head", machine="fast", duties="many") session = create_session() session.add(a) @@ -1089,30 +1481,40 @@ class MultiLevelTest(fixtures.MappedTest): class ManyToManyPolyTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - global base_item_table, item_table, base_item_collection_table, \ - collection_table + global base_item_table, item_table, base_item_collection_table, collection_table base_item_table = Table( - 'base_item', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('child_name', String(255), default=None)) + "base_item", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("child_name", String(255), default=None), + ) item_table = Table( - 'item', metadata, - Column('id', Integer, ForeignKey('base_item.id'), - primary_key=True), - Column('dummy', Integer, default=0)) + "item", + metadata, + Column( + "id", Integer, ForeignKey("base_item.id"), primary_key=True + ), + Column("dummy", Integer, default=0), + ) base_item_collection_table = Table( - 'base_item_collection', metadata, - Column('item_id', Integer, ForeignKey('base_item.id')), - Column('collection_id', Integer, ForeignKey('collection.id'))) + "base_item_collection", + metadata, + Column("item_id", Integer, 
ForeignKey("base_item.id")), + Column("collection_id", Integer, ForeignKey("collection.id")), + ) collection_table = Table( - 'collection', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', Unicode(255))) + "collection", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", Unicode(255)), + ) def test_pjoin_compile(self): """test that remote_side columns in the secondary join table @@ -1127,25 +1529,36 @@ class ManyToManyPolyTest(fixtures.MappedTest): class Collection(object): pass - item_join = polymorphic_union({ - 'BaseItem': base_item_table.select( - base_item_table.c.child_name == 'BaseItem'), - 'Item': base_item_table.join(item_table), - }, None, 'item_join') + + item_join = polymorphic_union( + { + "BaseItem": base_item_table.select( + base_item_table.c.child_name == "BaseItem" + ), + "Item": base_item_table.join(item_table), + }, + None, + "item_join", + ) mapper( - BaseItem, base_item_table, with_polymorphic=('*', item_join), + BaseItem, + base_item_table, + with_polymorphic=("*", item_join), polymorphic_on=base_item_table.c.child_name, - polymorphic_identity='BaseItem', + polymorphic_identity="BaseItem", properties=dict( collections=relationship( - Collection, secondary=base_item_collection_table, - backref="items"))) + Collection, + secondary=base_item_collection_table, + backref="items", + ) + ), + ) mapper( - Item, item_table, - inherits=BaseItem, - polymorphic_identity='Item') + Item, item_table, inherits=BaseItem, polymorphic_identity="Item" + ) mapper(Collection, collection_table) @@ -1156,16 +1569,22 @@ class CustomPKTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global t1, t2 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(30), nullable=False), - Column('data', String(30))) + t1 = Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(30), nullable=False), + Column("data", String(30)), + ) # note that the primary key column in t2 is named differently - t2 = Table('t2', metadata, - Column('t2id', Integer, ForeignKey( - 't1.id'), primary_key=True), - Column('t2data', String(30))) + t2 = Table( + "t2", + metadata, + Column("t2id", Integer, ForeignKey("t1.id"), primary_key=True), + Column("t2data", String(30)), + ) def test_custompk(self): """test that the primary_key attribute is propagated to the @@ -1183,15 +1602,19 @@ class CustomPKTest(fixtures.MappedTest): # a 2-col pk in any case but the leading select has a NULL for the # "t2id" column d = util.OrderedDict() - d['t1'] = t1.select(t1.c.type == 't1') - d['t2'] = t1.join(t2) - pjoin = polymorphic_union(d, None, 'pjoin') - - mapper(T1, t1, polymorphic_on=t1.c.type, - polymorphic_identity='t1', - with_polymorphic=('*', pjoin), - primary_key=[pjoin.c.id]) - mapper(T2, t2, inherits=T1, polymorphic_identity='t2') + d["t1"] = t1.select(t1.c.type == "t1") + d["t2"] = t1.join(t2) + pjoin = polymorphic_union(d, None, "pjoin") + + mapper( + T1, + t1, + polymorphic_on=t1.c.type, + polymorphic_identity="t1", + with_polymorphic=("*", pjoin), + primary_key=[pjoin.c.id], + ) + mapper(T2, t2, inherits=T1, polymorphic_identity="t2") ot1 = T1() ot2 = T2() sess = create_session() @@ -1206,7 +1629,7 @@ class CustomPKTest(fixtures.MappedTest): assert sess.query(T1).get(ot1.id).id == ot1.id ot1 = sess.query(T1).get(ot1.id) - ot1.data = 'hi' + ot1.data = 
"hi" sess.flush() def test_pk_collapses(self): @@ -1225,14 +1648,18 @@ class CustomPKTest(fixtures.MappedTest): # a 2-col pk in any case but the leading select has a NULL for the # "t2id" column d = util.OrderedDict() - d['t1'] = t1.select(t1.c.type == 't1') - d['t2'] = t1.join(t2) - pjoin = polymorphic_union(d, None, 'pjoin') - - mapper(T1, t1, polymorphic_on=t1.c.type, - polymorphic_identity='t1', - with_polymorphic=('*', pjoin)) - mapper(T2, t2, inherits=T1, polymorphic_identity='t2') + d["t1"] = t1.select(t1.c.type == "t1") + d["t2"] = t1.join(t2) + pjoin = polymorphic_union(d, None, "pjoin") + + mapper( + T1, + t1, + polymorphic_on=t1.c.type, + polymorphic_identity="t1", + with_polymorphic=("*", pjoin), + ) + mapper(T2, t2, inherits=T1, polymorphic_identity="t2") assert len(class_mapper(T1).primary_key) == 1 ot1 = T1() @@ -1249,7 +1676,7 @@ class CustomPKTest(fixtures.MappedTest): assert sess.query(T1).get(ot1.id).id == ot1.id ot1 = sess.query(T1).get(ot1.id) - ot1.data = 'hi' + ot1.data = "hi" sess.flush() @@ -1258,25 +1685,36 @@ class InheritingEagerTest(fixtures.MappedTest): def define_tables(cls, metadata): global people, employees, tags, peopleTags - people = Table('people', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('_type', String(30), nullable=False)) + people = Table( + "people", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("_type", String(30), nullable=False), + ) - employees = Table('employees', metadata, - Column('id', Integer, ForeignKey('people.id'), - primary_key=True)) + employees = Table( + "employees", + metadata, + Column("id", Integer, ForeignKey("people.id"), primary_key=True), + ) - tags = Table('tags', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('label', String(50), nullable=False)) + tags = Table( + "tags", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("label", String(50), nullable=False), + ) - peopleTags = Table('peopleTags', metadata, - Column('person_id', Integer, - ForeignKey('people.id')), - Column('tag_id', Integer, - ForeignKey('tags.id'))) + peopleTags = Table( + "peopleTags", + metadata, + Column("person_id", Integer, ForeignKey("people.id")), + Column("tag_id", Integer, ForeignKey("tags.id")), + ) def test_basic(self): """test that Query uses the full set of mapper._eager_loaders @@ -1286,20 +1724,30 @@ class InheritingEagerTest(fixtures.MappedTest): pass class Employee(Person): - def __init__(self, name='bob'): + def __init__(self, name="bob"): self.name = name class Tag(fixtures.ComparableEntity): def __init__(self, label): self.label = label - mapper(Person, people, polymorphic_on=people.c._type, - polymorphic_identity='person', properties={ - 'tags': relationship(Tag, - secondary=peopleTags, - backref='people', lazy='joined')}) - mapper(Employee, employees, inherits=Person, - polymorphic_identity='employee') + mapper( + Person, + people, + polymorphic_on=people.c._type, + polymorphic_identity="person", + properties={ + "tags": relationship( + Tag, secondary=peopleTags, backref="people", lazy="joined" + ) + }, + ) + mapper( + Employee, + employees, + inherits=Person, + polymorphic_identity="employee", + ) mapper(Tag, tags) session = create_session() @@ -1307,41 +1755,52 @@ class InheritingEagerTest(fixtures.MappedTest): bob = Employee() session.add(bob) - tag = Tag('crazy') + tag = Tag("crazy") bob.tags.append(tag) - tag = Tag('funny') + 
tag = Tag("funny") bob.tags.append(tag) session.flush() session.expunge_all() # query from Employee with limit, query needs to apply eager limiting # subquery - instance = session.query(Employee).\ - filter_by(id=1).limit(1).first() + instance = session.query(Employee).filter_by(id=1).limit(1).first() assert len(instance.tags) == 2 class MissingPolymorphicOnTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - tablea = Table('tablea', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('adata', String(50))) - tableb = Table('tableb', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('aid', Integer, ForeignKey('tablea.id')), - Column('data', String(50))) - tablec = Table('tablec', metadata, - Column('id', Integer, ForeignKey('tablea.id'), - primary_key=True), - Column('cdata', String(50))) - tabled = Table('tabled', metadata, - Column('id', Integer, ForeignKey('tablec.id'), - primary_key=True), - Column('ddata', String(50))) + tablea = Table( + "tablea", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("adata", String(50)), + ) + tableb = Table( + "tableb", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("aid", Integer, ForeignKey("tablea.id")), + Column("data", String(50)), + ) + tablec = Table( + "tablec", + metadata, + Column("id", Integer, ForeignKey("tablea.id"), primary_key=True), + Column("cdata", String(50)), + ) + tabled = Table( + "tabled", + metadata, + Column("id", Integer, ForeignKey("tablec.id"), primary_key=True), + Column("ddata", String(50)), + ) @classmethod def setup_classes(cls): @@ -1358,24 +1817,36 @@ class MissingPolymorphicOnTest(fixtures.MappedTest): pass def test_polyon_col_setsup(self): - tablea, tableb, tablec, tabled = self.tables.tablea, \ - self.tables.tableb, self.tables.tablec, self.tables.tabled - A, B, C, D = self.classes.A, self.classes.B, self.classes.C, \ - self.classes.D + tablea, tableb, tablec, tabled = ( + self.tables.tablea, + self.tables.tableb, + self.tables.tablec, + self.tables.tabled, + ) + A, B, C, D = ( + self.classes.A, + self.classes.B, + self.classes.C, + self.classes.D, + ) poly_select = select( - [tablea, tableb.c.data.label('discriminator')], - from_obj=tablea.join(tableb)).alias('poly') + [tablea, tableb.c.data.label("discriminator")], + from_obj=tablea.join(tableb), + ).alias("poly") mapper(B, tableb) - mapper(A, tablea, - with_polymorphic=('*', poly_select), - polymorphic_on=poly_select.c.discriminator, - properties={'b': relationship(B, uselist=False)}) - mapper(C, tablec, inherits=A, polymorphic_identity='c') - mapper(D, tabled, inherits=C, polymorphic_identity='d') - - c = C(cdata='c1', adata='a1', b=B(data='c')) - d = D(cdata='c2', adata='a2', ddata='d2', b=B(data='d')) + mapper( + A, + tablea, + with_polymorphic=("*", poly_select), + polymorphic_on=poly_select.c.discriminator, + properties={"b": relationship(B, uselist=False)}, + ) + mapper(C, tablec, inherits=A, polymorphic_identity="c") + mapper(D, tabled, inherits=C, polymorphic_identity="d") + + c = C(cdata="c1", adata="a1", b=B(data="c")) + d = D(cdata="c2", adata="a2", ddata="d2", b=B(data="d")) sess = create_session() sess.add(c) sess.add(d) @@ -1383,26 +1854,32 @@ class MissingPolymorphicOnTest(fixtures.MappedTest): sess.expunge_all() eq_( sess.query(A).all(), - [ - C(cdata='c1', adata='a1'), - D(cdata='c2', adata='a2', ddata='d2') - ] + [C(cdata="c1", 
adata="a1"), D(cdata="c2", adata="a2", ddata="d2")], ) class JoinedInhAdjacencyTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('people', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(30))) - Table('users', metadata, - Column('id', Integer, ForeignKey('people.id'), - primary_key=True), - Column('supervisor_id', Integer, ForeignKey('people.id'))) - Table('dudes', metadata, - Column('id', Integer, ForeignKey('users.id'), primary_key=True)) + Table( + "people", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(30)), + ) + Table( + "users", + metadata, + Column("id", Integer, ForeignKey("people.id"), primary_key=True), + Column("supervisor_id", Integer, ForeignKey("people.id")), + ) + Table( + "dudes", + metadata, + Column("id", Integer, ForeignKey("users.id"), primary_key=True), + ) @classmethod def setup_classes(cls): @@ -1441,16 +1918,24 @@ class JoinedInhAdjacencyTest(fixtures.MappedTest): people, users = self.tables.people, self.tables.users Person, User = self.classes.Person, self.classes.User - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') - mapper(User, users, inherits=Person, - polymorphic_identity='user', - inherit_condition=(users.c.id == people.c.id), - properties={ - 'supervisor': relationship( - Person, - primaryjoin=users.c.supervisor_id == people.c.id)}) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + mapper( + User, + users, + inherits=Person, + polymorphic_identity="user", + inherit_condition=(users.c.id == people.c.id), + properties={ + "supervisor": relationship( + Person, primaryjoin=users.c.supervisor_id == people.c.id + ) + }, + ) assert User.supervisor.property.direction is MANYTOONE self._roundtrip() @@ -1459,44 +1944,70 @@ class JoinedInhAdjacencyTest(fixtures.MappedTest): people, users = self.tables.people, self.tables.users Person, User = self.classes.Person, self.classes.User - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') - mapper(User, users, inherits=Person, - polymorphic_identity='user', - inherit_condition=(users.c.id == people.c.id), - properties={ - 'supervisor': relationship( - User, - primaryjoin=users.c.supervisor_id == people.c.id, - remote_side=people.c.id, - foreign_keys=[ - users.c.supervisor_id])}) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + mapper( + User, + users, + inherits=Person, + polymorphic_identity="user", + inherit_condition=(users.c.id == people.c.id), + properties={ + "supervisor": relationship( + User, + primaryjoin=users.c.supervisor_id == people.c.id, + remote_side=people.c.id, + foreign_keys=[users.c.supervisor_id], + ) + }, + ) assert User.supervisor.property.direction is MANYTOONE self._roundtrip() def test_joined_subclass_to_superclass(self): - people, users, dudes = self.tables.people, self.tables.users, \ - self.tables.dudes - Person, User, Dude = self.classes.Person, self.classes.User, \ - self.classes.Dude - - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') - mapper(User, users, inherits=Person, - polymorphic_identity='user', - inherit_condition=(users.c.id == people.c.id)) - mapper(Dude, dudes, inherits=User, - polymorphic_identity='dude', - inherit_condition=(dudes.c.id == users.c.id), - properties={ - 'supervisor': relationship( - 
User, - primaryjoin=users.c.supervisor_id == people.c.id, - remote_side=people.c.id, - foreign_keys=[ - users.c.supervisor_id])}) + people, users, dudes = ( + self.tables.people, + self.tables.users, + self.tables.dudes, + ) + Person, User, Dude = ( + self.classes.Person, + self.classes.User, + self.classes.Dude, + ) + + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + mapper( + User, + users, + inherits=Person, + polymorphic_identity="user", + inherit_condition=(users.c.id == people.c.id), + ) + mapper( + Dude, + dudes, + inherits=User, + polymorphic_identity="dude", + inherit_condition=(dudes.c.id == users.c.id), + properties={ + "supervisor": relationship( + User, + primaryjoin=users.c.supervisor_id == people.c.id, + remote_side=people.c.id, + foreign_keys=[users.c.supervisor_id], + ) + }, + ) assert Dude.supervisor.property.direction is MANYTOONE self._dude_roundtrip() @@ -1511,82 +2022,84 @@ class Ticket2419Test(fixtures.DeclarativeMappedTest): class A(Base): __tablename__ = "a" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) class B(Base): __tablename__ = "b" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) ds = relationship("D") es = relationship("E") class C(A): __tablename__ = "c" - id = Column(Integer, ForeignKey('a.id'), primary_key=True) - b_id = Column(Integer, ForeignKey('b.id')) + id = Column(Integer, ForeignKey("a.id"), primary_key=True) + b_id = Column(Integer, ForeignKey("b.id")) b = relationship("B", primaryjoin=b_id == B.id) class D(Base): __tablename__ = "d" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - b_id = Column(Integer, ForeignKey('b.id')) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + b_id = Column(Integer, ForeignKey("b.id")) class E(Base): - __tablename__ = 'e' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - b_id = Column(Integer, ForeignKey('b.id')) - - @testing.fails_on(["oracle", "mssql"], - "Oracle / SQL server engines can't handle this, " - "not clear if there's an expression-level bug on our " - "end though") + __tablename__ = "e" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + b_id = Column(Integer, ForeignKey("b.id")) + + @testing.fails_on( + ["oracle", "mssql"], + "Oracle / SQL server engines can't handle this, " + "not clear if there's an expression-level bug on our " + "end though", + ) def test_join_w_eager_w_any(self): - A, B, C, D, E = (self.classes.A, - self.classes.B, - self.classes.C, - self.classes.D, - self.classes.E) + A, B, C, D, E = ( + self.classes.A, + self.classes.B, + self.classes.C, + self.classes.D, + self.classes.E, + ) s = Session(testing.db) b = B(ds=[D()]) - s.add_all([ - C( - b=b - ) - - ]) + s.add_all([C(b=b)]) s.commit() q = s.query(B, B.ds.any(D.id == 1)).options(joinedload_all("es")) q = q.join(C, C.b_id == B.id) q = q.limit(5) - eq_( - q.all(), - [(b, True)] - ) + eq_(q.all(), [(b, True)]) -class ColSubclassTest(fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): +class ColSubclassTest( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): """Test [ticket:2918]'s test case.""" run_create_tables = run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): from sqlalchemy.schema 
import Column + Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) @@ -1594,9 +2107,9 @@ class ColSubclassTest(fixtures.DeclarativeMappedTest, pass class B(A): - __tablename__ = 'b' + __tablename__ = "b" - id = Column(ForeignKey('a.id'), primary_key=True) + id = Column(ForeignKey("a.id"), primary_key=True) x = MySpecialColumn(String) def test_polymorphic_adaptation(self): @@ -1604,8 +2117,8 @@ class ColSubclassTest(fixtures.DeclarativeMappedTest, s = Session() self.assert_compile( - s.query(A).join(B).filter(B.x == 'test'), + s.query(A).join(B).filter(B.x == "test"), "SELECT a.id AS a_id FROM a JOIN " "(a AS a_1 JOIN b AS b_1 ON a_1.id = b_1.id) " - "ON a.id = b_1.id WHERE b_1.x = :x_1" + "ON a.id = b_1.id WHERE b_1.x = :x_1", ) diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 7fd9329f9d..d871a6a98c 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -19,25 +19,33 @@ from sqlalchemy.testing.util import gc_collect class O2MTest(fixtures.MappedTest): """deals with inheritance and one-to-many relationships""" + @classmethod def define_tables(cls, metadata): global foo, bar, blub - foo = Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(20))) - - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey( - 'foo.id'), primary_key=True), - Column('bar_data', String(20))) - - blub = Table('blub', metadata, - Column('id', Integer, ForeignKey( - 'bar.id'), primary_key=True), - Column('foo_id', Integer, ForeignKey( - 'foo.id'), nullable=False), - Column('blub_data', String(20))) + foo = Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), + ) + + bar = Table( + "bar", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + Column("bar_data", String(20)), + ) + + blub = Table( + "blub", + metadata, + Column("id", Integer, ForeignKey("bar.id"), primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id"), nullable=False), + Column("blub_data", String(20)), + ) def test_basic(self): class Foo(object): @@ -46,6 +54,7 @@ class O2MTest(fixtures.MappedTest): def __repr__(self): return "Foo id %d, data %s" % (self.id, self.data) + mapper(Foo, foo) class Bar(Foo): @@ -58,9 +67,12 @@ class O2MTest(fixtures.MappedTest): def __repr__(self): return "Blub id %d, data %s" % (self.id, self.data) - mapper(Blub, blub, inherits=Bar, properties={ - 'parent_foo': relationship(Foo) - }) + mapper( + Blub, + blub, + inherits=Bar, + properties={"parent_foo": relationship(Foo)}, + ) sess = create_session() b1 = Blub("blub #1") @@ -72,59 +84,64 @@ class O2MTest(fixtures.MappedTest): b1.parent_foo = f b2.parent_foo = f sess.flush() - compare = ','.join([repr(b1), repr(b2), repr(b1.parent_foo), - repr(b2.parent_foo)]) + compare = ",".join( + [repr(b1), repr(b2), repr(b1.parent_foo), repr(b2.parent_foo)] + ) sess.expunge_all() result = sess.query(Blub).all() - result_str = ','.join([repr(result[0]), repr(result[1]), - repr(result[0].parent_foo), - repr(result[1].parent_foo)]) + result_str = ",".join( + [ + repr(result[0]), + repr(result[1]), + repr(result[0].parent_foo), + repr(result[1].parent_foo), + ] + ) eq_(compare, result_str) - eq_(result[0].parent_foo.data, 'foo #1') - eq_(result[1].parent_foo.data, 'foo #1') + eq_(result[0].parent_foo.data, "foo #1") + 
eq_(result[1].parent_foo.data, "foo #1") class PolyExpressionEagerLoad(fixtures.DeclarativeMappedTest): - run_setup_mappers = 'once' - __dialect__ = 'default' + run_setup_mappers = "once" + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class A(fixtures.ComparableEntity, Base): - __tablename__ = 'a' + __tablename__ = "a" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) discriminator = Column(String(50), nullable=False) - child_id = Column(Integer, ForeignKey('a.id')) - child = relationship('A') + child_id = Column(Integer, ForeignKey("a.id")) + child = relationship("A") - p_a = case([ - (discriminator == "a", "a"), - ], else_="b") + p_a = case([(discriminator == "a", "a")], else_="b") __mapper_args__ = { - 'polymorphic_identity': 'a', + "polymorphic_identity": "a", "polymorphic_on": p_a, } class B(A): - __mapper_args__ = { - 'polymorphic_identity': 'b' - } + __mapper_args__ = {"polymorphic_identity": "b"} @classmethod def insert_data(cls): A = cls.classes.A session = Session(testing.db) - session.add_all([ - A(id=1, discriminator='a'), - A(id=2, discriminator='b', child_id=1), - A(id=3, discriminator='c', child_id=1), - ]) + session.add_all( + [ + A(id=1, discriminator="a"), + A(id=2, discriminator="b", child_id=1), + A(id=3, discriminator="c", child_id=1), + ] + ) session.commit() def test_joinedload(self): @@ -132,46 +149,49 @@ class PolyExpressionEagerLoad(fixtures.DeclarativeMappedTest): B = self.classes.B session = Session(testing.db) - result = session.query(A).filter_by(child_id=None).\ - options(joinedload('child')).one() - - eq_( - result, - A(id=1, discriminator='a', child=[B(id=2), B(id=3)]), + result = ( + session.query(A) + .filter_by(child_id=None) + .options(joinedload("child")) + .one() ) + eq_(result, A(id=1, discriminator="a", child=[B(id=2), B(id=3)])) + -class PolymorphicResolutionMultiLevel(fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - run_setup_mappers = 'once' - __dialect__ = 'default' +class PolymorphicResolutionMultiLevel( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + run_setup_mappers = "once" + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(A): - __tablename__ = 'b' - id = Column(Integer, ForeignKey('a.id'), primary_key=True) + __tablename__ = "b" + id = Column(Integer, ForeignKey("a.id"), primary_key=True) class C(A): - __tablename__ = 'c' - id = Column(Integer, ForeignKey('a.id'), primary_key=True) + __tablename__ = "c" + id = Column(Integer, ForeignKey("a.id"), primary_key=True) class D(B): - __tablename__ = 'd' - id = Column(Integer, ForeignKey('b.id'), primary_key=True) + __tablename__ = "d" + id = Column(Integer, ForeignKey("b.id"), primary_key=True) def test_ordered_b_d(self): a_mapper = inspect(self.classes.A) eq_( a_mapper._mappers_from_spec( - [self.classes.B, self.classes.D], None), - [a_mapper, inspect(self.classes.B), inspect(self.classes.D)] + [self.classes.B, self.classes.D], None + ), + [a_mapper, inspect(self.classes.B), inspect(self.classes.D)], ) def test_a(self): @@ -183,10 +203,9 @@ class PolymorphicResolutionMultiLevel(fixtures.DeclarativeMappedTest, spec = [self.classes.D, self.classes.B] eq_( a_mapper._mappers_from_spec( - spec, - self.classes.B.__table__.join(self.classes.D.__table__) + spec, 
self.classes.B.__table__.join(self.classes.D.__table__) ), - [inspect(self.classes.B), inspect(self.classes.D)] + [inspect(self.classes.B), inspect(self.classes.D)], ) def test_d_selectable(self): @@ -194,10 +213,9 @@ class PolymorphicResolutionMultiLevel(fixtures.DeclarativeMappedTest, spec = [self.classes.D] eq_( a_mapper._mappers_from_spec( - spec, - self.classes.B.__table__.join(self.classes.D.__table__) + spec, self.classes.B.__table__.join(self.classes.D.__table__) ), - [inspect(self.classes.D)] + [inspect(self.classes.D)], ) def test_reverse_d_b(self): @@ -205,52 +223,63 @@ class PolymorphicResolutionMultiLevel(fixtures.DeclarativeMappedTest, spec = [self.classes.D, self.classes.B] eq_( a_mapper._mappers_from_spec(spec, None), - [a_mapper, inspect(self.classes.B), inspect(self.classes.D)] + [a_mapper, inspect(self.classes.B), inspect(self.classes.D)], ) mappers, selectable = a_mapper._with_polymorphic_args(spec=spec) - self.assert_compile(selectable, - "a LEFT OUTER JOIN b ON a.id = b.id " - "LEFT OUTER JOIN d ON b.id = d.id") + self.assert_compile( + selectable, + "a LEFT OUTER JOIN b ON a.id = b.id " + "LEFT OUTER JOIN d ON b.id = d.id", + ) def test_d_b_missing(self): a_mapper = inspect(self.classes.A) spec = [self.classes.D] eq_( a_mapper._mappers_from_spec(spec, None), - [a_mapper, inspect(self.classes.B), inspect(self.classes.D)] + [a_mapper, inspect(self.classes.B), inspect(self.classes.D)], ) mappers, selectable = a_mapper._with_polymorphic_args(spec=spec) - self.assert_compile(selectable, - "a LEFT OUTER JOIN b ON a.id = b.id " - "LEFT OUTER JOIN d ON b.id = d.id") + self.assert_compile( + selectable, + "a LEFT OUTER JOIN b ON a.id = b.id " + "LEFT OUTER JOIN d ON b.id = d.id", + ) def test_d_c_b(self): a_mapper = inspect(self.classes.A) spec = [self.classes.D, self.classes.C, self.classes.B] ms = a_mapper._mappers_from_spec(spec, None) - eq_( - ms[-1], inspect(self.classes.D) - ) + eq_(ms[-1], inspect(self.classes.D)) eq_(ms[0], a_mapper) - eq_( - set(ms[1:3]), set(a_mapper._inheriting_mappers) - ) + eq_(set(ms[1:3]), set(a_mapper._inheriting_mappers)) class PolymorphicOnNotLocalTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x', String(10)), - Column('q', String(10))) - t2 = Table('t2', metadata, - Column('t2id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('y', String(10)), - Column('xid', ForeignKey('t1.id'))) + t1 = Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", String(10)), + Column("q", String(10)), + ) + t2 = Table( + "t2", + metadata, + Column( + "t2id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("y", String(10)), + Column("xid", ForeignKey("t1.id")), + ) @classmethod def setup_classes(cls): @@ -269,7 +298,9 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): "value 'im not a column' - no " "attribute is mapped to this name.", mapper, - Parent, t2, polymorphic_on="im not a column" + Parent, + t2, + polymorphic_on="im not a column", ) def test_polymorphic_on_non_expr_prop(self): @@ -279,15 +310,15 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias() def go(): - interface_m = mapper(Parent, t2, - polymorphic_on=lambda: "hi", - polymorphic_identity=0) + interface_m = mapper( + Parent, t2, polymorphic_on=lambda: "hi", 
polymorphic_identity=0 + ) assert_raises_message( sa_exc.ArgumentError, "Only direct column-mapped property or " "SQL expression can be passed for polymorphic_on", - go + go, ) def test_polymorphic_on_not_present_col(self): @@ -297,15 +328,19 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): def go(): t1t2_join_2 = select([t1.c.q], from_obj=[t1.join(t2)]).alias() - interface_m = mapper(Parent, t2, - polymorphic_on=t1t2_join.c.x, - with_polymorphic=('*', t1t2_join_2), - polymorphic_identity=0) + interface_m = mapper( + Parent, + t2, + polymorphic_on=t1t2_join.c.x, + with_polymorphic=("*", t1t2_join_2), + polymorphic_identity=0, + ) + assert_raises_message( sa_exc.InvalidRequestError, "Could not map polymorphic_on column 'x' to the mapped table - " "polymorphic loads will not function properly", - go + go, ) def test_polymorphic_on_only_in_with_poly(self): @@ -313,10 +348,13 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): Parent = self.classes.Parent t1t2_join = select([t1.c.x], from_obj=[t1.join(t2)]).alias() # if its in the with_polymorphic, then its OK - mapper(Parent, t2, - polymorphic_on=t1t2_join.c.x, - with_polymorphic=('*', t1t2_join), - polymorphic_identity=0) + mapper( + Parent, + t2, + polymorphic_on=t1t2_join.c.x, + with_polymorphic=("*", t1t2_join), + polymorphic_identity=0, + ) def test_polymorpic_on_not_in_with_poly(self): t2, t1 = self.tables.t2, self.tables.t1 @@ -327,123 +365,112 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): # if with_polymorphic, but its not present, not OK def go(): t1t2_join_2 = select([t1.c.q], from_obj=[t1.join(t2)]).alias() - interface_m = mapper(Parent, t2, - polymorphic_on=t1t2_join.c.x, - with_polymorphic=('*', t1t2_join_2), - polymorphic_identity=0) + interface_m = mapper( + Parent, + t2, + polymorphic_on=t1t2_join.c.x, + with_polymorphic=("*", t1t2_join_2), + polymorphic_identity=0, + ) + assert_raises_message( sa_exc.InvalidRequestError, "Could not map polymorphic_on column 'x' " "to the mapped table - " "polymorphic loads will not function properly", - go + go, ) def test_polymorphic_on_expr_explicit_map(self): t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]) - mapper(Parent, t1, properties={ - "discriminator": column_property(expr) - }, polymorphic_identity="parent", - polymorphic_on=expr) - mapper(Child, t2, inherits=Parent, - polymorphic_identity="child") - - self._roundtrip(parent_ident='p', child_ident='c') + expr = case([(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")]) + mapper( + Parent, + t1, + properties={"discriminator": column_property(expr)}, + polymorphic_identity="parent", + polymorphic_on=expr, + ) + mapper(Child, t2, inherits=Parent, polymorphic_identity="child") + + self._roundtrip(parent_ident="p", child_ident="c") def test_polymorphic_on_expr_implicit_map_no_label_joined(self): t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]) - mapper(Parent, t1, polymorphic_identity="parent", - polymorphic_on=expr) + expr = case([(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")]) + mapper(Parent, t1, polymorphic_identity="parent", polymorphic_on=expr) mapper(Child, t2, inherits=Parent, polymorphic_identity="child") - self._roundtrip(parent_ident='p', child_ident='c') + self._roundtrip(parent_ident="p", child_ident="c") def 
test_polymorphic_on_expr_implicit_map_w_label_joined(self): t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]).label(None) - mapper(Parent, t1, polymorphic_identity="parent", - polymorphic_on=expr) + expr = case( + [(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")] + ).label(None) + mapper(Parent, t1, polymorphic_identity="parent", polymorphic_on=expr) mapper(Child, t2, inherits=Parent, polymorphic_identity="child") - self._roundtrip(parent_ident='p', child_ident='c') + self._roundtrip(parent_ident="p", child_ident="c") def test_polymorphic_on_expr_implicit_map_no_label_single(self): """test that single_table_criterion is propagated with a standalone expr""" t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]) - mapper(Parent, t1, polymorphic_identity="parent", - polymorphic_on=expr) + expr = case([(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")]) + mapper(Parent, t1, polymorphic_identity="parent", polymorphic_on=expr) mapper(Child, inherits=Parent, polymorphic_identity="child") - self._roundtrip(parent_ident='p', child_ident='c') + self._roundtrip(parent_ident="p", child_ident="c") def test_polymorphic_on_expr_implicit_map_w_label_single(self): """test that single_table_criterion is propagated with a standalone expr""" t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]).label(None) - mapper(Parent, t1, polymorphic_identity="parent", - polymorphic_on=expr) + expr = case( + [(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")] + ).label(None) + mapper(Parent, t1, polymorphic_identity="parent", polymorphic_on=expr) mapper(Child, inherits=Parent, polymorphic_identity="child") - self._roundtrip(parent_ident='p', child_ident='c') + self._roundtrip(parent_ident="p", child_ident="c") def test_polymorphic_on_column_prop(self): t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]) + expr = case([(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")]) cprop = column_property(expr) - mapper(Parent, t1, properties={ - "discriminator": cprop - }, polymorphic_identity="parent", - polymorphic_on=cprop) - mapper(Child, t2, inherits=Parent, - polymorphic_identity="child") + mapper( + Parent, + t1, + properties={"discriminator": cprop}, + polymorphic_identity="parent", + polymorphic_on=cprop, + ) + mapper(Child, t2, inherits=Parent, polymorphic_identity="child") - self._roundtrip(parent_ident='p', child_ident='c') + self._roundtrip(parent_ident="p", child_ident="c") def test_polymorphic_on_column_str_prop(self): t2, t1 = self.tables.t2, self.tables.t1 Parent, Child = self.classes.Parent, self.classes.Child - expr = case([ - (t1.c.x == "p", "parent"), - (t1.c.x == "c", "child"), - ]) + expr = case([(t1.c.x == "p", "parent"), (t1.c.x == "c", "child")]) cprop = column_property(expr) - mapper(Parent, t1, properties={ - "discriminator": cprop - }, polymorphic_identity="parent", - polymorphic_on="discriminator") - mapper(Child, t2, inherits=Parent, - polymorphic_identity="child") + mapper( + Parent, + t1, + properties={"discriminator": cprop}, + polymorphic_identity="parent", + polymorphic_on="discriminator", + ) + mapper(Child, 
t2, inherits=Parent, polymorphic_identity="child") - self._roundtrip(parent_ident='p', child_ident='c') + self._roundtrip(parent_ident="p", child_ident="c") def test_polymorphic_on_synonym(self): t2, t1 = self.tables.t2, self.tables.t1 @@ -453,14 +480,17 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): sa_exc.ArgumentError, "Only direct column-mapped property or " "SQL expression can be passed for polymorphic_on", - mapper, Parent, t1, properties={ - "discriminator": cprop, - "discrim_syn": synonym(cprop) - }, polymorphic_identity="parent", - polymorphic_on="discrim_syn") - - def _roundtrip(self, set_event=True, parent_ident='parent', - child_ident='child'): + mapper, + Parent, + t1, + properties={"discriminator": cprop, "discrim_syn": synonym(cprop)}, + polymorphic_identity="parent", + polymorphic_on="discrim_syn", + ) + + def _roundtrip( + self, set_event=True, parent_ident="parent", child_ident="child" + ): Parent, Child = self.classes.Parent, self.classes.Child # locate the "polymorphic_on" ColumnProperty. This isn't @@ -471,55 +501,57 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): break else: prop = parent_mapper._columntoproperty[ - parent_mapper.polymorphic_on] + parent_mapper.polymorphic_on + ] # then make sure the column we will query on matches. - is_( - parent_mapper.polymorphic_on, - prop.columns[0] - ) + is_(parent_mapper.polymorphic_on, prop.columns[0]) if set_event: + @event.listens_for(Parent, "init", propagate=True) def set_identity(instance, *arg, **kw): ident = object_mapper(instance).polymorphic_identity - if ident == 'parent': + if ident == "parent": instance.x = parent_ident - elif ident == 'child': + elif ident == "child": instance.x = child_ident else: assert False, "Got unexpected identity %r" % ident s = Session(testing.db) - s.add_all([ - Parent(q="p1"), - Child(q="c1", y="c1"), - Parent(q="p2"), - ]) + s.add_all([Parent(q="p1"), Child(q="c1", y="c1"), Parent(q="p2")]) s.commit() s.close() eq_( [type(t) for t in s.query(Parent).order_by(Parent.id)], - [Parent, Child, Parent] + [Parent, Child, Parent], ) - eq_( - [type(t) for t in s.query(Child).all()], - [Child] - ) + eq_([type(t) for t in s.query(Child).all()], [Child]) class SortOnlyOnImportantFKsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('b_id', Integer, - ForeignKey('b.id', use_alter=True, name='b_fk'))) - Table('b', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True)) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "b_id", + Integer, + ForeignKey("b.id", use_alter=True, name="b_fk"), + ), + ) + Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) @classmethod def setup_classes(cls): @@ -528,16 +560,17 @@ class SortOnlyOnImportantFKsTest(fixtures.MappedTest): class A(Base): __tablename__ = "a" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - b_id = Column(Integer, ForeignKey('b.id')) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + b_id = Column(Integer, ForeignKey("b.id")) class B(A): __tablename__ = "b" - id = Column(Integer, ForeignKey('a.id'), primary_key=True) + id = Column(Integer, ForeignKey("a.id"), primary_key=True) - __mapper_args__ = {'inherit_condition': id == A.id} + __mapper_args__ = {"inherit_condition": id == A.id} cls.classes.A = A cls.classes.B = B @@ 
-552,10 +585,14 @@ class FalseDiscriminatorTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global t1 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', Boolean, nullable=False)) + t1 = Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", Boolean, nullable=False), + ) def test_false_on_sub(self): class Foo(object): @@ -563,6 +600,7 @@ class FalseDiscriminatorTest(fixtures.MappedTest): class Bar(Foo): pass + mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=True) mapper(Bar, inherits=Foo, polymorphic_identity=False) sess = create_session() @@ -579,6 +617,7 @@ class FalseDiscriminatorTest(fixtures.MappedTest): class Bat(Ding): pass + mapper(Ding, t1, polymorphic_on=t1.c.type, polymorphic_identity=False) mapper(Bat, inherits=Ding, polymorphic_identity=True) sess = create_session() @@ -594,15 +633,21 @@ class PolymorphicSynonymTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global t1, t2 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(10), nullable=False), - Column('info', String(255))) - t2 = Table('t2', metadata, - Column('id', Integer, ForeignKey('t1.id'), - primary_key=True), - Column('data', String(10), nullable=False)) + t1 = Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(10), nullable=False), + Column("info", String(255)), + ) + t2 = Table( + "t2", + metadata, + Column("id", Integer, ForeignKey("t1.id"), primary_key=True), + Column("data", String(10), nullable=False), + ) def test_polymorphic_synonym(self): class T1(fixtures.ComparableEntity): @@ -611,50 +656,66 @@ class PolymorphicSynonymTest(fixtures.MappedTest): def _set_info(self, x): self._info = x + info = property(info, _set_info) class T2(T1): pass - mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', - properties={'info': synonym('_info', map_column=True)}) - mapper(T2, t2, inherits=T1, polymorphic_identity='t2') + mapper( + T1, + t1, + polymorphic_on=t1.c.type, + polymorphic_identity="t1", + properties={"info": synonym("_info", map_column=True)}, + ) + mapper(T2, t2, inherits=T1, polymorphic_identity="t2") sess = create_session() - at1 = T1(info='at1') - at2 = T2(info='at2', data='t2 data') + at1 = T1(info="at1") + at2 = T2(info="at2", data="t2 data") sess.add(at1) sess.add(at2) sess.flush() sess.expunge_all() - eq_(sess.query(T2).filter(T2.info == 'at2').one(), at2) + eq_(sess.query(T2).filter(T2.info == "at2").one(), at2) eq_(at2.info, "THE INFO IS:at2") class PolymorphicAttributeManagementTest(fixtures.MappedTest): """Test polymorphic_on can be assigned, can be mirrored, etc.""" - run_setup_mappers = 'once' + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - Table('table_a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('class_name', String(50))) - Table('table_b', metadata, - Column('id', Integer, ForeignKey('table_a.id'), - primary_key=True), - Column('class_name', String(50))) - Table('table_c', metadata, - Column('id', Integer, ForeignKey('table_b.id'), - primary_key=True), - Column('data', String(10))) + Table( + "table_a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("class_name", String(50)), + ) + Table( + 
"table_b", + metadata, + Column("id", Integer, ForeignKey("table_a.id"), primary_key=True), + Column("class_name", String(50)), + ) + Table( + "table_c", + metadata, + Column("id", Integer, ForeignKey("table_b.id"), primary_key=True), + Column("data", String(10)), + ) @classmethod def setup_classes(cls): - table_b, table_c, table_a = (cls.tables.table_b, - cls.tables.table_c, - cls.tables.table_a) + table_b, table_c, table_a = ( + cls.tables.table_b, + cls.tables.table_c, + cls.tables.table_a, + ) class A(cls.Basic): pass @@ -668,15 +729,24 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): class D(B): pass - mapper(A, table_a, - polymorphic_on=table_a.c.class_name, - polymorphic_identity='a') - mapper(B, table_b, inherits=A, polymorphic_on=table_b.c.class_name, - polymorphic_identity='b', - properties=dict( - class_name=[table_a.c.class_name, table_b.c.class_name])) - mapper(C, table_c, inherits=B, polymorphic_identity='c') - mapper(D, inherits=B, polymorphic_identity='d') + mapper( + A, + table_a, + polymorphic_on=table_a.c.class_name, + polymorphic_identity="a", + ) + mapper( + B, + table_b, + inherits=A, + polymorphic_on=table_b.c.class_name, + polymorphic_identity="b", + properties=dict( + class_name=[table_a.c.class_name, table_b.c.class_name] + ), + ) + mapper(C, table_c, inherits=B, polymorphic_identity="c") + mapper(D, inherits=B, polymorphic_identity="d") def test_poly_configured_immediate(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) @@ -684,9 +754,9 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): a = A() b = B() c = C() - eq_(a.class_name, 'a') - eq_(b.class_name, 'b') - eq_(c.class_name, 'c') + eq_(a.class_name, "a") + eq_(b.class_name, "b") + eq_(c.class_name, "c") def test_base_class(self): A, C, B = (self.classes.A, self.classes.C, self.classes.B) @@ -710,7 +780,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): sess = Session() b1 = B() - b1.class_name = 'd' + b1.class_name = "d" sess.add(b1) sess.commit() sess.close() @@ -724,14 +794,14 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): sess = Session() c1 = C() - c1.class_name = 'b' + c1.class_name = "b" sess.add(c1) assert_raises_message( sa_exc.SAWarning, "Flushing object %s with incompatible " "polymorphic identity 'b'; the object may not " "refresh and/or load correctly" % instance_str(c1), - sess.flush + sess.flush, ) def test_invalid_assignment_upwards(self): @@ -743,14 +813,14 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): sess = Session() b1 = B() - b1.class_name = 'c' + b1.class_name = "c" sess.add(b1) assert_raises_message( sa_exc.SAWarning, "Flushing object %s with incompatible " "polymorphic identity 'c'; the object may not " "refresh and/or load correctly" % instance_str(b1), - sess.flush + sess.flush, ) def test_entirely_oob_assignment(self): @@ -760,14 +830,14 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): sess = Session() b1 = B() - b1.class_name = 'xyz' + b1.class_name = "xyz" sess.add(b1) assert_raises_message( sa_exc.SAWarning, "Flushing object %s with incompatible " "polymorphic identity 'xyz'; the object may not " "refresh and/or load correctly" % instance_str(b1), - sess.flush + sess.flush, ) def test_not_set_on_upate(self): @@ -779,7 +849,7 @@ class PolymorphicAttributeManagementTest(fixtures.MappedTest): sess.commit() sess.expire(c1) - c1.data = 'foo' + c1.data = "foo" sess.flush() def test_validate_on_upate(self): @@ -791,13 +861,13 @@ class 
PolymorphicAttributeManagementTest(fixtures.MappedTest): sess.commit() sess.expire(c1) - c1.class_name = 'b' + c1.class_name = "b" assert_raises_message( sa_exc.SAWarning, "Flushing object %s with incompatible " "polymorphic identity 'b'; the object may not " "refresh and/or load correctly" % instance_str(c1), - sess.flush + sess.flush, ) @@ -809,27 +879,41 @@ class CascadeTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global t1, t2, t3, t4 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - - t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('t1id', Integer, ForeignKey('t1.id')), - Column('type', String(30)), - Column('data', String(30))) - t3 = Table('t3', metadata, - Column('id', Integer, ForeignKey('t2.id'), - primary_key=True), - Column('moredata', String(30))) - - t4 = Table('t4', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('t3id', Integer, ForeignKey('t3.id')), - Column('data', String(30))) + t1 = Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + + t2 = Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("t1id", Integer, ForeignKey("t1.id")), + Column("type", String(30)), + Column("data", String(30)), + ) + t3 = Table( + "t3", + metadata, + Column("id", Integer, ForeignKey("t2.id"), primary_key=True), + Column("moredata", String(30)), + ) + + t4 = Table( + "t4", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("t3id", Integer, ForeignKey("t3.id")), + Column("data", String(30)), + ) def test_cascade(self): class T1(fixtures.BasicEntity): @@ -844,25 +928,27 @@ class CascadeTest(fixtures.MappedTest): class T4(fixtures.BasicEntity): pass - mapper(T1, t1, properties={ - 't2s': relationship(T2, cascade="all") - }) - mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') - mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={ - 't4s': relationship(T4, cascade="all") - }) + mapper(T1, t1, properties={"t2s": relationship(T2, cascade="all")}) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity="t2") + mapper( + T3, + t3, + inherits=T2, + polymorphic_identity="t3", + properties={"t4s": relationship(T4, cascade="all")}, + ) mapper(T4, t4) sess = create_session() - t1_1 = T1(data='t1') + t1_1 = T1(data="t1") - t3_1 = T3(data='t3', moredata='t3') - t2_1 = T2(data='t2') + t3_1 = T3(data="t3", moredata="t3") + t2_1 = T2(data="t2") t1_1.t2s.append(t2_1) t1_1.t2s.append(t3_1) - t4_1 = T4(data='t4') + t4_1 = T4(data="t4") t3_1.t4s.append(t4_1) sess.add(t1_1) @@ -878,21 +964,34 @@ class CascadeTest(fixtures.MappedTest): class M2OUseGetTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('base', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(30))) - Table('sub', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True)) - Table('related', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('sub_id', Integer, ForeignKey('sub.id'))) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(30)), + ) + Table( + "sub", 
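
The hunks above reformat PolymorphicAttributeManagementTest, whose assertions pin down the SAWarning emitted when an object reaches flush carrying a discriminator value that disagrees with its mapped polymorphic identity. A minimal standalone sketch of that behavior, using hypothetical names and the 1.3-era declarative API rather than the test's classic mappers:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Animal(Base):
        __tablename__ = "animal"
        id = Column(Integer, primary_key=True)
        kind = Column(String(20))
        __mapper_args__ = {
            "polymorphic_on": kind,
            "polymorphic_identity": "animal",
        }

    class Dog(Animal):
        # single-table subclass; instances are stamped kind="dog" at init
        __mapper_args__ = {"polymorphic_identity": "dog"}

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session = Session(engine)
    dog = Dog()
    dog.kind = "animal"  # overwrite the stamp with an incompatible identity
    session.add(dog)
    session.flush()  # warns about the incompatible identity; row still inserts
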
+ metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) + Table( + "related", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("sub_id", Integer, ForeignKey("sub.id")), + ) def test_use_get(self): - base, sub, related = (self.tables.base, - self.tables.sub, - self.tables.related) + base, sub, related = ( + self.tables.base, + self.tables.sub, + self.tables.related, + ) # test [ticket:1186] class Base(fixtures.BasicEntity): @@ -903,22 +1002,27 @@ class M2OUseGetTest(fixtures.MappedTest): class Related(Base): pass - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='b') - mapper(Sub, sub, inherits=Base, polymorphic_identity='s') - mapper(Related, related, properties={ - # previously, this was needed for the comparison to occur: - # the 'primaryjoin' looks just like "Sub"'s "get" clause - # (based on the Base id), and foreign_keys since that join - # condition doesn't actually have any fks in it - # 'sub':relationship(Sub, primaryjoin=base.c.id==related.c.sub_id, - # foreign_keys=related.c.sub_id) - # now we can use this: - 'sub': relationship(Sub) - }) + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="b" + ) + mapper(Sub, sub, inherits=Base, polymorphic_identity="s") + mapper( + Related, + related, + properties={ + # previously, this was needed for the comparison to occur: + # the 'primaryjoin' looks just like "Sub"'s "get" clause + # (based on the Base id), and foreign_keys since that join + # condition doesn't actually have any fks in it + # 'sub':relationship(Sub, primaryjoin=base.c.id==related.c.sub_id, + # foreign_keys=related.c.sub_id) + # now we can use this: + "sub": relationship(Sub) + }, + ) - assert class_mapper(Related).get_property('sub').strategy.use_get + assert class_mapper(Related).get_property("sub").strategy.use_get sess = create_session() s1 = Sub() @@ -932,6 +1036,7 @@ class M2OUseGetTest(fixtures.MappedTest): def go(): assert r1.sub + self.assert_sql_count(testing.db, go, 0) @@ -939,23 +1044,36 @@ class GetTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global foo, bar, blub - foo = Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(30)), - Column('data', String(20))) - - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey( - 'foo.id'), primary_key=True), - Column('bar_data', String(20))) - - blub = Table('blub', metadata, - Column('blub_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('blub_data', String(20))) + foo = Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(30)), + Column("data", String(20)), + ) + + bar = Table( + "bar", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + Column("bar_data", String(20)), + ) + + blub = Table( + "blub", + metadata, + Column( + "blub_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("foo_id", Integer, ForeignKey("foo.id")), + Column("bar_id", Integer, ForeignKey("bar.id")), + Column("blub_data", String(20)), + ) @classmethod def setup_classes(cls): @@ -975,18 +1093,21 @@ class GetTest(fixtures.MappedTest): self._do_get_test(False) def _do_get_test(self, polymorphic): - foo, Bar, Blub, blub, bar, Foo = (self.tables.foo, - 
self.classes.Bar, - self.classes.Blub, - self.tables.blub, - self.tables.bar, - self.classes.Foo) + foo, Bar, Blub, blub, bar, Foo = ( + self.tables.foo, + self.classes.Bar, + self.classes.Blub, + self.tables.blub, + self.tables.bar, + self.classes.Foo, + ) if polymorphic: - mapper(Foo, foo, polymorphic_on=foo.c.type, - polymorphic_identity='foo') - mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar') - mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub') + mapper( + Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity="foo" + ) + mapper(Bar, bar, inherits=Foo, polymorphic_identity="bar") + mapper(Blub, blub, inherits=Bar, polymorphic_identity="blub") else: mapper(Foo, foo) mapper(Bar, bar, inherits=Foo) @@ -1002,6 +1123,7 @@ class GetTest(fixtures.MappedTest): sess.flush() if polymorphic: + def go(): assert sess.query(Foo).get(f.id) is f assert sess.query(Foo).get(b.id) is b @@ -1048,18 +1170,27 @@ class EagerLazyTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global foo, bar, bar_foo - foo = Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey( - 'foo.id'), primary_key=True), - Column('bar_data', String(30))) - - bar_foo = Table('bar_foo', metadata, - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('foo_id', Integer, ForeignKey('foo.id'))) + foo = Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + bar = Table( + "bar", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + Column("bar_data", String(30)), + ) + + bar_foo = Table( + "bar_foo", + metadata, + Column("bar_id", Integer, ForeignKey("bar.id")), + Column("foo_id", Integer, ForeignKey("foo.id")), + ) def test_basic(self): class Foo(object): @@ -1070,17 +1201,17 @@ class EagerLazyTest(fixtures.MappedTest): foos = mapper(Foo, foo) bars = mapper(Bar, bar, inherits=foos) - bars.add_property('lazy', relationship(foos, bar_foo, lazy='select')) - bars.add_property('eager', relationship(foos, bar_foo, lazy='joined')) + bars.add_property("lazy", relationship(foos, bar_foo, lazy="select")) + bars.add_property("eager", relationship(foos, bar_foo, lazy="joined")) - foo.insert().execute(data='foo1') - bar.insert().execute(id=1, data='bar1') + foo.insert().execute(data="foo1") + bar.insert().execute(id=1, data="bar1") - foo.insert().execute(data='foo2') - bar.insert().execute(id=2, data='bar2') + foo.insert().execute(data="foo2") + bar.insert().execute(id=2, data="bar2") - foo.insert().execute(data='foo3') # 3 - foo.insert().execute(data='foo4') # 4 + foo.insert().execute(data="foo3") # 3 + foo.insert().execute(data="foo4") # 4 bar_foo.insert().execute(bar_id=1, foo_id=3) bar_foo.insert().execute(bar_id=2, foo_id=4) @@ -1097,16 +1228,21 @@ class EagerTargetingTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('a_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('type', String(30), nullable=False), - Column('parent_id', Integer, ForeignKey('a_table.id'))) + Table( + "a_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("type", String(30), nullable=False), + Column("parent_id", Integer, ForeignKey("a_table.id")), + ) - Table('b_table', metadata, - Column('id', Integer, ForeignKey( - 'a_table.id'), 
primary_key=True), - Column('b_data', String(50))) + Table( + "b_table", + metadata, + Column("id", Integer, ForeignKey("a_table.id"), primary_key=True), + Column("b_data", String(50)), + ) def test_adapt_stringency(self): b_table, a_table = self.tables.b_table, self.tables.a_table @@ -1118,21 +1254,30 @@ class EagerTargetingTest(fixtures.MappedTest): pass mapper( - A, a_table, polymorphic_on=a_table.c.type, - polymorphic_identity='A', - properties={'children': relationship(A, order_by=a_table.c.name)}) + A, + a_table, + polymorphic_on=a_table.c.type, + polymorphic_identity="A", + properties={"children": relationship(A, order_by=a_table.c.name)}, + ) - mapper(B, b_table, inherits=A, polymorphic_identity='B', properties={ - 'b_derived': column_property(b_table.c.b_data + "DATA") - }) + mapper( + B, + b_table, + inherits=A, + polymorphic_identity="B", + properties={ + "b_derived": column_property(b_table.c.b_data + "DATA") + }, + ) sess = create_session() - b1 = B(id=1, name='b1', b_data='i') + b1 = B(id=1, name="b1", b_data="i") sess.add(b1) sess.flush() - b2 = B(id=2, name='b2', b_data='l', parent_id=1) + b2 = B(id=2, name="b2", b_data="l", parent_id=1) sess.add(b2) sess.flush() @@ -1140,14 +1285,18 @@ class EagerTargetingTest(fixtures.MappedTest): sess.expunge_all() node = sess.query(B).filter(B.id == bid).all()[0] - eq_(node, B(id=1, name='b1', b_data='i')) - eq_(node.children[0], B(id=2, name='b2', b_data='l')) + eq_(node, B(id=1, name="b1", b_data="i")) + eq_(node.children[0], B(id=2, name="b2", b_data="l")) sess.expunge_all() - node = sess.query(B).options(joinedload(B.children))\ - .filter(B.id == bid).all()[0] - eq_(node, B(id=1, name='b1', b_data='i')) - eq_(node.children[0], B(id=2, name='b2', b_data='l')) + node = ( + sess.query(B) + .options(joinedload(B.children)) + .filter(B.id == bid) + .all()[0] + ) + eq_(node, B(id=1, name="b1", b_data="i")) + eq_(node.children[0], B(id=2, name="b2", b_data="l")) class FlushTest(fixtures.MappedTest): @@ -1155,34 +1304,55 @@ class FlushTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('email', String(128)), - Column('password', String(16))) - - Table('roles', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('description', String(32))) - - Table('user_roles', metadata, - Column('user_id', Integer, ForeignKey( - 'users.id'), primary_key=True), - Column('role_id', Integer, ForeignKey( - 'roles.id'), primary_key=True) - ) - - Table('admins', metadata, - Column('admin_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.id'))) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("email", String(128)), + Column("password", String(16)), + ) + + Table( + "roles", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("description", String(32)), + ) + + Table( + "user_roles", + metadata, + Column( + "user_id", Integer, ForeignKey("users.id"), primary_key=True + ), + Column( + "role_id", Integer, ForeignKey("roles.id"), primary_key=True + ), + ) + + Table( + "admins", + metadata, + Column( + "admin_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("user_id", Integer, ForeignKey("users.id")), + ) def test_one(self): - admins, users, roles, user_roles = 
(self.tables.admins, - self.tables.users, - self.tables.roles, - self.tables.user_roles) + admins, users, roles, user_roles = ( + self.tables.admins, + self.tables.users, + self.tables.roles, + self.tables.user_roles, + ) class User(object): pass @@ -1192,9 +1362,17 @@ class FlushTest(fixtures.MappedTest): class Admin(User): pass + role_mapper = mapper(Role, roles) - user_mapper = mapper(User, users, properties={ - 'roles': relationship(Role, secondary=user_roles, lazy='joined')}) + user_mapper = mapper( + User, + users, + properties={ + "roles": relationship( + Role, secondary=user_roles, lazy="joined" + ) + }, + ) admin_mapper = mapper(Admin, admins, inherits=user_mapper) sess = create_session() adminrole = Role() @@ -1207,17 +1385,19 @@ class FlushTest(fixtures.MappedTest): # off and insert the many to many row twice. a = Admin() a.roles.append(adminrole) - a.password = 'admin' + a.password = "admin" sess.add(a) sess.flush() - eq_(select([func.count('*')]).select_from(user_roles).scalar(), 1) + eq_(select([func.count("*")]).select_from(user_roles).scalar(), 1) def test_two(self): - admins, users, roles, user_roles = (self.tables.admins, - self.tables.users, - self.tables.roles, - self.tables.user_roles) + admins, users, roles, user_roles = ( + self.tables.admins, + self.tables.users, + self.tables.roles, + self.tables.user_roles, + ) class User(object): def __init__(self, email=None, password=None): @@ -1232,51 +1412,64 @@ class FlushTest(fixtures.MappedTest): pass role_mapper = mapper(Role, roles) - user_mapper = mapper(User, users, properties={ - 'roles': relationship(Role, secondary=user_roles, lazy='joined')}) + user_mapper = mapper( + User, + users, + properties={ + "roles": relationship( + Role, secondary=user_roles, lazy="joined" + ) + }, + ) admin_mapper = mapper(Admin, admins, inherits=user_mapper) # create roles - adminrole = Role('admin') + adminrole = Role("admin") sess = create_session() sess.add(adminrole) sess.flush() # create admin user - a = Admin(email='tim', password='admin') + a = Admin(email="tim", password="admin") a.roles.append(adminrole) sess.add(a) sess.flush() - a.password = 'sadmin' + a.password = "sadmin" sess.flush() - eq_(select([func.count('*')]).select_from(user_roles).scalar(), 1) + eq_(select([func.count("*")]).select_from(user_roles).scalar(), 1) class PassiveDeletesTest(fixtures.MappedTest): - __requires__ = ('foreign_keys',) + __requires__ = ("foreign_keys",) @classmethod def define_tables(cls, metadata): Table( - "a", metadata, - Column('id', Integer, primary_key=True), - Column('type', String(30)) + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("type", String(30)), ) Table( - "b", metadata, + "b", + metadata, Column( - 'id', Integer, ForeignKey('a.id', ondelete="CASCADE"), - primary_key=True), - Column('data', String(10)) + "id", + Integer, + ForeignKey("a.id", ondelete="CASCADE"), + primary_key=True, + ), + Column("data", String(10)), ) Table( - "c", metadata, - Column('cid', Integer, primary_key=True), - Column('bid', ForeignKey('b.id', ondelete="CASCADE")) + "c", + metadata, + Column("cid", Integer, primary_key=True), + Column("bid", ForeignKey("b.id", ondelete="CASCADE")), ) @classmethod @@ -1295,12 +1488,14 @@ class PassiveDeletesTest(fixtures.MappedTest): a, b, c = self.tables("a", "b", "c") mapper( - A, a, passive_deletes=a_p, - polymorphic_on=a.c.type, polymorphic_identity='a') - mapper( - B, b, inherits=A, passive_deletes=b_p, polymorphic_identity='b') - mapper( - C, c, inherits=B, passive_deletes=c_p, 
polymorphic_identity='c') + A, + a, + passive_deletes=a_p, + polymorphic_on=a.c.type, + polymorphic_identity="a", + ) + mapper(B, b, inherits=A, passive_deletes=b_p, polymorphic_identity="b") + mapper(C, c, inherits=B, passive_deletes=c_p, polymorphic_identity="c") def test_none(self): A, B, C = self.classes("A", "B", "C") @@ -1319,22 +1514,11 @@ class PassiveDeletesTest(fixtures.MappedTest): s.flush() asserter.assert_( RegexSQL( - "SELECT .* " - "FROM c WHERE :param_1 = c.bid", - [{'param_1': 3}] - ), - CompiledSQL( - "DELETE FROM c WHERE c.cid = :cid", - [{'cid': 1}] - ), - CompiledSQL( - "DELETE FROM b WHERE b.id = :id", - [{'id': 3}] + "SELECT .* " "FROM c WHERE :param_1 = c.bid", [{"param_1": 3}] ), - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 3}] - ) + CompiledSQL("DELETE FROM c WHERE c.cid = :cid", [{"cid": 1}]), + CompiledSQL("DELETE FROM b WHERE b.id = :id", [{"id": 3}]), + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 3}]), ) def test_c_only(self): @@ -1354,12 +1538,9 @@ class PassiveDeletesTest(fixtures.MappedTest): CompiledSQL( "SELECT a.id AS a_id, a.type AS a_type " "FROM a WHERE a.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 1}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 1}]), ) b1.id @@ -1367,14 +1548,8 @@ class PassiveDeletesTest(fixtures.MappedTest): with self.sql_execution_asserter(testing.db) as asserter: s.flush() asserter.assert_( - CompiledSQL( - "DELETE FROM b WHERE b.id = :id", - [{'id': 2}] - ), - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 2}] - ) + CompiledSQL("DELETE FROM b WHERE b.id = :id", [{"id": 2}]), + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 2}]), ) # want to see if the 'C' table loads even though @@ -1384,14 +1559,8 @@ class PassiveDeletesTest(fixtures.MappedTest): with self.sql_execution_asserter(testing.db) as asserter: s.flush() asserter.assert_( - CompiledSQL( - "DELETE FROM b WHERE b.id = :id", - [{'id': 3}] - ), - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 3}] - ) + CompiledSQL("DELETE FROM b WHERE b.id = :id", [{"id": 3}]), + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 3}]), ) def test_b_only(self): @@ -1411,12 +1580,9 @@ class PassiveDeletesTest(fixtures.MappedTest): CompiledSQL( "SELECT a.id AS a_id, a.type AS a_type " "FROM a WHERE a.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 1}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 1}]), ) b1.id @@ -1424,10 +1590,7 @@ class PassiveDeletesTest(fixtures.MappedTest): with self.sql_execution_asserter(testing.db) as asserter: s.flush() asserter.assert_( - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 2}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 2}]) ) c1.id @@ -1435,10 +1598,7 @@ class PassiveDeletesTest(fixtures.MappedTest): with self.sql_execution_asserter(testing.db) as asserter: s.flush() asserter.assert_( - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 3}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 3}]) ) def test_a_only(self): @@ -1458,12 +1618,9 @@ class PassiveDeletesTest(fixtures.MappedTest): CompiledSQL( "SELECT a.id AS a_id, a.type AS a_type " "FROM a WHERE a.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 1}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 1}]), ) b1.id @@ -1471,10 +1628,7 @@ 
class PassiveDeletesTest(fixtures.MappedTest): with self.sql_execution_asserter(testing.db) as asserter: s.flush() asserter.assert_( - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 2}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 2}]) ) # want to see if the 'C' table loads even though @@ -1484,10 +1638,7 @@ class PassiveDeletesTest(fixtures.MappedTest): with self.sql_execution_asserter(testing.db) as asserter: s.flush() asserter.assert_( - CompiledSQL( - "DELETE FROM a WHERE a.id = :id", - [{'id': 3}] - ) + CompiledSQL("DELETE FROM a WHERE a.id = :id", [{"id": 3}]) ) @@ -1497,14 +1648,17 @@ class OptimizedGetOnDeferredTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - "a", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True) + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), ) Table( - "b", metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('data', String(10)) + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("data", String(10)), ) @classmethod @@ -1521,39 +1675,48 @@ class OptimizedGetOnDeferredTest(fixtures.MappedTest): a, b = cls.tables("a", "b") mapper(A, a) - mapper(B, b, inherits=A, properties={ - 'data': deferred(b.c.data), - 'expr': column_property(b.c.data + 'q', deferred=True) - }) + mapper( + B, + b, + inherits=A, + properties={ + "data": deferred(b.c.data), + "expr": column_property(b.c.data + "q", deferred=True), + }, + ) def test_column_property(self): A, B = self.classes("A", "B") sess = Session() - b1 = B(data='x') + b1 = B(data="x") sess.add(b1) sess.flush() - eq_(b1.expr, 'xq') + eq_(b1.expr, "xq") def test_expired_column(self): A, B = self.classes("A", "B") sess = Session() - b1 = B(data='x') + b1 = B(data="x") sess.add(b1) sess.flush() - sess.expire(b1, ['data']) + sess.expire(b1, ["data"]) - eq_(b1.data, 'x') + eq_(b1.data, "x") class JoinedNoFKSortingTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("a", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table("b", metadata, Column('id', Integer, primary_key=True)) - Table("c", metadata, Column('id', Integer, primary_key=True)) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table("b", metadata, Column("id", Integer, primary_key=True)) + Table("c", metadata, Column("id", Integer, primary_key=True)) @classmethod def setup_classes(cls): @@ -1570,12 +1733,20 @@ class JoinedNoFKSortingTest(fixtures.MappedTest): def setup_mappers(cls): A, B, C = cls.classes.A, cls.classes.B, cls.classes.C mapper(A, cls.tables.a) - mapper(B, cls.tables.b, inherits=A, - inherit_condition=cls.tables.a.c.id == cls.tables.b.c.id, - inherit_foreign_keys=cls.tables.b.c.id) - mapper(C, cls.tables.c, inherits=A, - inherit_condition=cls.tables.a.c.id == cls.tables.c.c.id, - inherit_foreign_keys=cls.tables.c.c.id) + mapper( + B, + cls.tables.b, + inherits=A, + inherit_condition=cls.tables.a.c.id == cls.tables.b.c.id, + inherit_foreign_keys=cls.tables.b.c.id, + ) + mapper( + C, + cls.tables.c, + inherits=A, + inherit_condition=cls.tables.a.c.id == cls.tables.c.c.id, + inherit_foreign_keys=cls.tables.c.c.id, + ) def test_ordering(self): B, C = self.classes.B, self.classes.C @@ -1590,41 +1761,52 @@ class JoinedNoFKSortingTest(fixtures.MappedTest): CompiledSQL("INSERT INTO a () VALUES ()", {}), AllOf( CompiledSQL( 
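
JoinedNoFKSortingTest above maps a joined-inheritance pair whose child tables carry no ForeignKey at all, so inherit_condition supplies the join and inherit_foreign_keys tells the unit of work which side depends on which; that dependency is what forces the parent-table INSERT ahead of the child-table INSERTs asserted here. A condensed sketch of the mapping style, with hypothetical table names:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.orm import configure_mappers, mapper

    metadata = MetaData()
    parent = Table("parent", metadata, Column("id", Integer, primary_key=True))
    child = Table("child", metadata, Column("id", Integer, primary_key=True))

    class Parent(object):
        pass

    class Child(Parent):
        pass

    mapper(Parent, parent)
    mapper(
        Child,
        child,
        inherits=Parent,
        # no ForeignKey exists, so spell out both the join condition and
        # the column the unit of work should treat as the dependent side
        inherit_condition=parent.c.id == child.c.id,
        inherit_foreign_keys=child.c.id,
    )
    configure_mappers()
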
- "INSERT INTO b (id) VALUES (:id)", - [{"id": 1}, {"id": 3}] + "INSERT INTO b (id) VALUES (:id)", [{"id": 1}, {"id": 3}] ), CompiledSQL( - "INSERT INTO c (id) VALUES (:id)", - [{"id": 2}, {"id": 4}] - ) - ) + "INSERT INTO c (id) VALUES (:id)", [{"id": 2}, {"id": 4}] + ), + ), ) class VersioningTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('base', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=False), - Column('value', String(40)), - Column('discriminator', Integer, nullable=False)) - Table('subtable', metadata, - Column('id', None, ForeignKey('base.id'), primary_key=True), - Column('subdata', String(50))) - Table('stuff', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent', Integer, ForeignKey('base.id'))) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer, nullable=False), + Column("value", String(40)), + Column("discriminator", Integer, nullable=False), + ) + Table( + "subtable", + metadata, + Column("id", None, ForeignKey("base.id"), primary_key=True), + Column("subdata", String(50)), + ) + Table( + "stuff", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent", Integer, ForeignKey("base.id")), + ) @testing.emits_warning(r".*updated rowcount") @testing.requires.sane_rowcount_w_returning @engines.close_open_connections def test_save_update(self): - subtable, base, stuff = (self.tables.subtable, - self.tables.base, - self.tables.stuff) + subtable, base, stuff = ( + self.tables.subtable, + self.tables.base, + self.tables.stuff, + ) class Base(fixtures.BasicEntity): pass @@ -1634,19 +1816,22 @@ class VersioningTest(fixtures.MappedTest): class Stuff(Base): pass + mapper(Stuff, stuff) - mapper(Base, base, - polymorphic_on=base.c.discriminator, - version_id_col=base.c.version_id, - polymorphic_identity=1, properties={ - 'stuff': relationship(Stuff) - }) + mapper( + Base, + base, + polymorphic_on=base.c.discriminator, + version_id_col=base.c.version_id, + polymorphic_identity=1, + properties={"stuff": relationship(Stuff)}, + ) mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) sess = create_session() - b1 = Base(value='b1') - s1 = Sub(value='sub1', subdata='some subdata') + b1 = Base(value="b1") + s1 = Sub(value="sub1", subdata="some subdata") sess.add(b1) sess.add(s1) @@ -1654,15 +1839,17 @@ class VersioningTest(fixtures.MappedTest): sess2 = create_session() s2 = sess2.query(Base).get(s1.id) - s2.subdata = 'sess2 subdata' + s2.subdata = "sess2 subdata" - s1.subdata = 'sess1 subdata' + s1.subdata = "sess1 subdata" sess.flush() - assert_raises(orm_exc.StaleDataError, - sess2.query(Base).with_lockmode('read').get, - s1.id) + assert_raises( + orm_exc.StaleDataError, + sess2.query(Base).with_lockmode("read").get, + s1.id, + ) if not testing.db.dialect.supports_sane_rowcount: sess2.flush() @@ -1671,8 +1858,8 @@ class VersioningTest(fixtures.MappedTest): sess2.refresh(s2) if testing.db.dialect.supports_sane_rowcount: - assert s2.subdata == 'sess1 subdata' - s2.subdata = 'sess2 subdata' + assert s2.subdata == "sess1 subdata" + s2.subdata = "sess2 subdata" sess2.flush() @testing.emits_warning(r".*(update|delete)d rowcount") @@ -1686,16 +1873,20 @@ class VersioningTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, - polymorphic_on=base.c.discriminator, - 
version_id_col=base.c.version_id, polymorphic_identity=1) + mapper( + Base, + base, + polymorphic_on=base.c.discriminator, + version_id_col=base.c.version_id, + polymorphic_identity=1, + ) mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) sess = create_session() - b1 = Base(value='b1') - s1 = Sub(value='sub1', subdata='some subdata') - s2 = Sub(value='sub2', subdata='some other subdata') + b1 = Base(value="b1") + s1 = Sub(value="sub1", subdata="some subdata") + s2 = Sub(value="sub2", subdata="some other subdata") sess.add(b1) sess.add(s1) sess.add(s2) @@ -1707,15 +1898,12 @@ class VersioningTest(fixtures.MappedTest): sess2.delete(s3) sess2.flush() - s2.subdata = 'some new subdata' + s2.subdata = "some new subdata" sess.flush() - s1.subdata = 'some new subdata' + s1.subdata = "some new subdata" if testing.db.dialect.supports_sane_rowcount: - assert_raises( - orm_exc.StaleDataError, - sess.flush - ) + assert_raises(orm_exc.StaleDataError, sess.flush) else: sess.flush() @@ -1724,24 +1912,31 @@ class DistinctPKTest(fixtures.MappedTest): """test the construction of mapper.primary_key when an inheriting relationship joins on a column other than primary key column.""" - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): global person_table, employee_table, Person, Employee - person_table = Table("persons", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("name", String(80))) + person_table = Table( + "persons", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(80)), + ) - employee_table = Table("employees", metadata, - Column("eid", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("salary", Integer), - Column("person_id", Integer, - ForeignKey("persons.id"))) + employee_table = Table( + "employees", + metadata, + Column( + "eid", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("salary", Integer), + Column("person_id", Integer, ForeignKey("persons.id")), + ) class Person(object): def __init__(self, name): @@ -1753,8 +1948,8 @@ class DistinctPKTest(fixtures.MappedTest): @classmethod def insert_data(cls): person_insert = person_table.insert() - person_insert.execute(id=1, name='alice') - person_insert.execute(id=2, name='bob') + person_insert.execute(id=1, name="alice") + person_insert.execute(id=2, name="bob") employee_insert = employee_table.insert() employee_insert.execute(id=2, salary=250, person_id=1) # alice @@ -1767,29 +1962,42 @@ class DistinctPKTest(fixtures.MappedTest): def test_explicit_props(self): person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper, - properties={'pid': person_table.c.id, - 'eid': employee_table.c.eid}) + mapper( + Employee, + employee_table, + inherits=person_mapper, + properties={"pid": person_table.c.id, "eid": employee_table.c.eid}, + ) self._do_test(False) def test_explicit_composite_pk(self): person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, - inherits=person_mapper, - properties=dict(id=[employee_table.c.eid, person_table.c.id]), - primary_key=[person_table.c.id, employee_table.c.eid]) + mapper( + Employee, + employee_table, + inherits=person_mapper, + properties=dict(id=[employee_table.c.eid, person_table.c.id]), + primary_key=[person_table.c.id, employee_table.c.eid], + ) assert_raises_message( sa_exc.SAWarning, r"On mapper Mapper\|Employee\|employees, " 
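
The warning text asserted here fires when an inheriting mapper would fold two distinct primary key columns into a single attribute. The remedy the message points at is the explicit-properties form already exercised by test_explicit_props above, essentially (reusing that test's own names):

    mapper(
        Employee,
        employee_table,
        inherits=person_mapper,
        properties={
            "pid": person_table.c.id,     # parent PK under its own attribute
            "eid": employee_table.c.eid,  # child PK under a distinct attribute
        },
    )
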
"primary key column 'persons.id' is being " "combined with distinct primary key column 'employees.eid' " "in attribute 'id'. Use explicit properties to give " - "each column its own mapped attribute name.", self._do_test, True) + "each column its own mapped attribute name.", + self._do_test, + True, + ) def test_explicit_pk(self): person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper, - primary_key=[person_table.c.id]) + mapper( + Employee, + employee_table, + inherits=person_mapper, + primary_key=[person_table.c.id], + ) self._do_test(False) def _do_test(self, composite): @@ -1805,8 +2013,8 @@ class DistinctPKTest(fixtures.MappedTest): bob = query.get(2) alice2 = query.get(1) - assert alice1.name == alice2.name == 'alice' - assert bob.name == 'bob' + assert alice1.name == alice2.name == "alice" + assert bob.name == "bob" class SyncCompileTest(fixtures.MappedTest): @@ -1816,28 +2024,42 @@ class SyncCompileTest(fixtures.MappedTest): def define_tables(cls, metadata): global _a_table, _b_table, _c_table - _a_table = Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data1', String(128))) + _a_table = Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data1", String(128)), + ) - _b_table = Table('b', metadata, - Column('a_id', Integer, ForeignKey( - 'a.id'), primary_key=True), - Column('data2', String(128))) + _b_table = Table( + "b", + metadata, + Column("a_id", Integer, ForeignKey("a.id"), primary_key=True), + Column("data2", String(128)), + ) - _c_table = Table('c', metadata, - # Column('a_id', Integer, ForeignKey('b.a_id'), - # primary_key=True), #works - Column('b_a_id', Integer, ForeignKey( - 'b.a_id'), primary_key=True), - Column('data3', String(128))) + _c_table = Table( + "c", + metadata, + # Column('a_id', Integer, ForeignKey('b.a_id'), + # primary_key=True), #works + Column("b_a_id", Integer, ForeignKey("b.a_id"), primary_key=True), + Column("data3", String(128)), + ) def test_joins(self): - for j1 in (None, _b_table.c.a_id == _a_table.c.id, _a_table.c.id == - _b_table.c.a_id): - for j2 in (None, _b_table.c.a_id == _c_table.c.b_a_id, - _c_table.c.b_a_id == _b_table.c.a_id): + for j1 in ( + None, + _b_table.c.a_id == _a_table.c.id, + _a_table.c.id == _b_table.c.a_id, + ): + for j2 in ( + None, + _b_table.c.a_id == _c_table.c.b_a_id, + _c_table.c.b_a_id == _b_table.c.a_id, + ): self._do_test(j1, j2) for t in reversed(_a_table.metadata.sorted_tables): t.delete().execute().close() @@ -1855,22 +2077,18 @@ class SyncCompileTest(fixtures.MappedTest): pass mapper(A, _a_table) - mapper(B, _b_table, inherits=A, - inherit_condition=j1 - ) - mapper(C, _c_table, inherits=B, - inherit_condition=j2 - ) + mapper(B, _b_table, inherits=A, inherit_condition=j1) + mapper(C, _c_table, inherits=B, inherit_condition=j2) session = create_session() - a = A(data1='a1') + a = A(data1="a1") session.add(a) - b = B(data1='b1', data2='b2') + b = B(data1="b1", data2="b2") session.add(b) - c = C(data1='c1', data2='c2', data3='c3') + c = C(data1="c1", data2="c2", data3="c3") session.add(c) session.flush() @@ -1888,21 +2106,37 @@ class OverrideColKeyTest(fixtures.MappedTest): def define_tables(cls, metadata): global base, subtable, subtable_two - base = Table('base', metadata, - Column('base_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(255)), - Column('sqlite_fixer', String(10))) - - subtable = 
Table('subtable', metadata, - Column('base_id', Integer, ForeignKey( - 'base.base_id'), primary_key=True), - Column('subdata', String(255))) - subtable_two = Table('subtable_two', metadata, - Column('base_id', Integer, primary_key=True), - Column('fk_base_id', Integer, - ForeignKey('base.base_id')), - Column('subdata', String(255))) + base = Table( + "base", + metadata, + Column( + "base_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("data", String(255)), + Column("sqlite_fixer", String(10)), + ) + + subtable = Table( + "subtable", + metadata, + Column( + "base_id", + Integer, + ForeignKey("base.base_id"), + primary_key=True, + ), + Column("subdata", String(255)), + ) + subtable_two = Table( + "subtable_two", + metadata, + Column("base_id", Integer, primary_key=True), + Column("fk_base_id", Integer, ForeignKey("base.base_id")), + Column("subdata", String(255)), + ) def test_plain(self): # control case @@ -1918,8 +2152,8 @@ class OverrideColKeyTest(fixtures.MappedTest): # Sub gets a "base_id" property using the "base_id" # column of both tables. eq_( - class_mapper(Sub).get_property('base_id').columns, - [subtable.c.base_id, base.c.base_id] + class_mapper(Sub).get_property("base_id").columns, + [subtable.c.base_id, base.c.base_id], ) def test_override_explicit(self): @@ -1933,18 +2167,21 @@ class OverrideColKeyTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, properties={ - 'id': base.c.base_id - }) - mapper(Sub, subtable, inherits=Base, properties={ - # this is the manual way to do it, is not really - # possible in declarative - 'id': [base.c.base_id, subtable.c.base_id] - }) + mapper(Base, base, properties={"id": base.c.base_id}) + mapper( + Sub, + subtable, + inherits=Base, + properties={ + # this is the manual way to do it, is not really + # possible in declarative + "id": [base.c.base_id, subtable.c.base_id] + }, + ) eq_( - class_mapper(Sub).get_property('id').columns, - [base.c.base_id, subtable.c.base_id] + class_mapper(Sub).get_property("id").columns, + [base.c.base_id, subtable.c.base_id], ) s1 = Sub() @@ -1961,19 +2198,14 @@ class OverrideColKeyTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, properties={ - 'id': base.c.base_id - }) + mapper(Base, base, properties={"id": base.c.base_id}) mapper(Sub, subtable, inherits=Base) - eq_( - class_mapper(Sub).get_property('id').columns, - [base.c.base_id] - ) + eq_(class_mapper(Sub).get_property("id").columns, [base.c.base_id]) eq_( - class_mapper(Sub).get_property('base_id').columns, - [subtable.c.base_id] + class_mapper(Sub).get_property("base_id").columns, + [subtable.c.base_id], ) s1 = Sub() @@ -2003,14 +2235,16 @@ class OverrideColKeyTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, properties={ - 'id': base.c.base_id - }) + mapper(Base, base, properties={"id": base.c.base_id}) def go(): - mapper(Sub, subtable, inherits=Base, properties={ - 'id': subtable.c.base_id - }) + mapper( + Sub, + subtable, + inherits=Base, + properties={"id": subtable.c.base_id}, + ) + # Sub mapper compilation needs to detect that "base.c.base_id" # is renamed in the inherited mapper as "id", even though # it has its own "id" property. 
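
The comment above is describing the rename case OverrideColKeyTest probes: the base mapper exposes base.c.base_id under the attribute "id", so the inheriting mapper must recognize that rename before it can decide what the subclass's own base_id column means. The unambiguous spelling, taken from test_override_explicit earlier in this hunk and made standalone here as a sketch:

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
    from sqlalchemy.orm import class_mapper, mapper

    metadata = MetaData()
    base = Table(
        "base",
        metadata,
        Column("base_id", Integer, primary_key=True),
        Column("data", String(255)),
    )
    subtable = Table(
        "subtable",
        metadata,
        Column(
            "base_id", Integer, ForeignKey("base.base_id"), primary_key=True
        ),
        Column("subdata", String(255)),
    )

    class Base(object):
        pass

    class Sub(Base):
        pass

    mapper(Base, base, properties={"id": base.c.base_id})
    mapper(
        Sub,
        subtable,
        inherits=Base,
        # name both columns under the one key explicitly; Sub.id then
        # populates the primary key of both tables on flush
        properties={"id": [base.c.base_id, subtable.c.base_id]},
    )

    assert class_mapper(Sub).get_property("id").columns == [
        base.c.base_id,
        subtable.c.base_id,
    ]
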
It then generates @@ -2028,11 +2262,12 @@ class OverrideColKeyTest(fixtures.MappedTest): def go(): mapper(Sub, subtable_two, inherits=Base) + assert_raises_message( sa_exc.SAWarning, "Implicitly combining column base.base_id with " "column subtable_two.base_id under attribute 'base_id'", - go + go, ) def test_plain_descriptor(self): @@ -2138,24 +2373,37 @@ class OptimizedLoadTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('base', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('type', String(50)), - Column('counter', Integer, server_default="1")) - Table('sub', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('sub', String(50)), - Column('subcounter', Integer, server_default="1"), - Column('subcounter2', Integer, server_default="1")) - Table('subsub', metadata, - Column('id', Integer, ForeignKey('sub.id'), primary_key=True), - Column('subsubcounter2', Integer, server_default="1")) - Table('with_comp', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('a', String(10)), - Column('b', String(10))) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("type", String(50)), + Column("counter", Integer, server_default="1"), + ) + Table( + "sub", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("sub", String(50)), + Column("subcounter", Integer, server_default="1"), + Column("subcounter2", Integer, server_default="1"), + ) + Table( + "subsub", + metadata, + Column("id", Integer, ForeignKey("sub.id"), primary_key=True), + Column("subsubcounter2", Integer, server_default="1"), + ) + Table( + "with_comp", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("a", String(10)), + Column("b", String(10)), + ) def test_no_optimize_on_map_to_join(self): base, sub = self.tables.base, self.tables.sub @@ -2170,13 +2418,20 @@ class OptimizedLoadTest(fixtures.MappedTest): pass mapper(Base, base) - mapper(JoinBase, base.outerjoin(sub), properties=util.OrderedDict( - [('id', [base.c.id, sub.c.id]), - ('counter', [base.c.counter, sub.c.subcounter])])) + mapper( + JoinBase, + base.outerjoin(sub), + properties=util.OrderedDict( + [ + ("id", [base.c.id, sub.c.id]), + ("counter", [base.c.counter, sub.c.subcounter]), + ] + ), + ) mapper(SubJoinBase, inherits=JoinBase) sess = Session() - sess.add(Base(data='data')) + sess.add(Base(data="data")) sess.commit() sjb = sess.query(SubJoinBase).one() @@ -2186,10 +2441,11 @@ class OptimizedLoadTest(fixtures.MappedTest): # this should not use the optimized load, # which assumes discrete tables def go(): - eq_(sjb.data, 'data') + eq_(sjb.data, "data") self.assert_sql_execution( - testing.db, go, + testing.db, + go, CompiledSQL( "SELECT base.id AS base_id, sub.id AS sub_id, " "base.counter AS base_counter, " @@ -2197,7 +2453,10 @@ class OptimizedLoadTest(fixtures.MappedTest): "base.data AS base_data, base.type AS base_type, " "sub.sub AS sub_sub, sub.subcounter2 AS sub_subcounter2 " "FROM base LEFT OUTER JOIN sub ON base.id = sub.id " - "WHERE base.id = :param_1", {'param_1': sjb_id})) + "WHERE base.id = :param_1", + {"param_1": sjb_id}, + ), + ) def test_optimized_passes(self): """"test that the 'optimized load' routine doesn't crash when @@ -2211,16 +2470,21 @@ class OptimizedLoadTest(fixtures.MappedTest): class Sub(Base): pass - 
mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) # redefine Sub's "id" to favor the "id" col in the subtable. # "id" is also part of the primary join condition - mapper(Sub, sub, inherits=Base, - polymorphic_identity='sub', - properties={'id': [sub.c.id, base.c.id]}) + mapper( + Sub, + sub, + inherits=Base, + polymorphic_identity="sub", + properties={"id": [sub.c.id, base.c.id]}, + ) sess = sessionmaker()() - s1 = Sub(data='s1data', sub='s1sub') + s1 = Sub(data="s1data", sub="s1sub") sess.add(s1) sess.commit() sess.expunge_all() @@ -2231,7 +2495,7 @@ class OptimizedLoadTest(fixtures.MappedTest): # unloaded. the optimized load needs to return "None" so regular # full-row loading proceeds s1 = sess.query(Base).first() - assert s1.sub == 's1sub' + assert s1.sub == "s1sub" def test_column_expression(self): base, sub = self.tables.base, self.tables.sub @@ -2241,18 +2505,26 @@ class OptimizedLoadTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') - mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', - properties={ - 'concat': column_property(sub.c.sub + "|" + sub.c.sub)}) + + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) + mapper( + Sub, + sub, + inherits=Base, + polymorphic_identity="sub", + properties={ + "concat": column_property(sub.c.sub + "|" + sub.c.sub) + }, + ) sess = sessionmaker()() - s1 = Sub(data='s1data', sub='s1sub') + s1 = Sub(data="s1data", sub="s1sub") sess.add(s1) sess.commit() sess.expunge_all() s1 = sess.query(Base).first() - assert s1.concat == 's1sub|s1sub' + assert s1.concat == "s1sub|s1sub" def test_column_expression_joined(self): base, sub = self.tables.base, self.tables.sub @@ -2262,15 +2534,23 @@ class OptimizedLoadTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') - mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', - properties={ - 'concat': column_property(base.c.data + "|" + sub.c.sub)}) + + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) + mapper( + Sub, + sub, + inherits=Base, + polymorphic_identity="sub", + properties={ + "concat": column_property(base.c.data + "|" + sub.c.sub) + }, + ) sess = sessionmaker()() - s1 = Sub(data='s1data', sub='s1sub') - s2 = Sub(data='s2data', sub='s2sub') - s3 = Sub(data='s3data', sub='s3sub') + s1 = Sub(data="s1data", sub="s1sub") + s2 = Sub(data="s2data", sub="s2sub") + s3 = Sub(data="s3data", sub="s3sub") sess.add_all([s1, s2, s3]) sess.commit() sess.expunge_all() @@ -2281,10 +2561,10 @@ class OptimizedLoadTest(fixtures.MappedTest): eq_( sess.query(Base).order_by(Base.id).all(), [ - Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'), - Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'), - Sub(data='s3data', sub='s3sub', concat='s3data|s3sub') - ] + Sub(data="s1data", sub="s1sub", concat="s1data|s1sub"), + Sub(data="s2data", sub="s2sub", concat="s2data|s2sub"), + Sub(data="s3data", sub="s3sub", concat="s3data|s3sub"), + ], ) def test_composite_column_joined(self): @@ -2306,22 +2586,28 @@ class OptimizedLoadTest(fixtures.MappedTest): def __eq__(self, other): return (self.a == other.a) and (self.b == other.b) - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') - mapper(WithComp, with_comp, inherits=Base, polymorphic_identity='wc', - 
properties={'comp': composite(Comp, - with_comp.c.a, with_comp.c.b)}) + + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) + mapper( + WithComp, + with_comp, + inherits=Base, + polymorphic_identity="wc", + properties={"comp": composite(Comp, with_comp.c.a, with_comp.c.b)}, + ) sess = sessionmaker()() - s1 = WithComp(data='s1data', comp=Comp('ham', 'cheese')) - s2 = WithComp(data='s2data', comp=Comp('bacon', 'eggs')) + s1 = WithComp(data="s1data", comp=Comp("ham", "cheese")) + s2 = WithComp(data="s2data", comp=Comp("bacon", "eggs")) sess.add_all([s1, s2]) sess.commit() sess.expunge_all() s1test, s2test = sess.query(Base).order_by(Base.id).all() assert s1test.comp assert s2test.comp - eq_(s1test.comp, Comp('ham', 'cheese')) - eq_(s2test.comp, Comp('bacon', 'eggs')) + eq_(s1test.comp, Comp("ham", "cheese")) + eq_(s2test.comp, Comp("bacon", "eggs")) def test_load_expired_on_pending(self): base, sub = self.tables.base, self.tables.sub @@ -2331,35 +2617,41 @@ class OptimizedLoadTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') - mapper(Sub, sub, inherits=Base, polymorphic_identity='sub') + + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) + mapper(Sub, sub, inherits=Base, polymorphic_identity="sub") sess = Session() - s1 = Sub(data='s1') + s1 = Sub(data="s1") sess.add(s1) self.assert_sql_execution( testing.db, sess.flush, CompiledSQL( "INSERT INTO base (data, type) VALUES (:data, :type)", - [{'data': 's1', 'type': 'sub'}] + [{"data": "s1", "type": "sub"}], ), CompiledSQL( "INSERT INTO sub (id, sub) VALUES (:id, :sub)", - lambda ctx: {'id': s1.id, 'sub': None} + lambda ctx: {"id": s1.id, "sub": None}, ), ) def go(): eq_(s1.subcounter2, 1) + self.assert_sql_execution( - testing.db, go, + testing.db, + go, CompiledSQL( "SELECT base.counter AS base_counter, " "sub.subcounter AS sub_subcounter, " "sub.subcounter2 AS sub_subcounter2 FROM base JOIN sub " "ON base.id = sub.id WHERE base.id = :param_1", - lambda ctx: {'param_1': s1.id})) + lambda ctx: {"param_1": s1.id}, + ), + ) def test_dont_generate_on_none(self): base, sub = self.tables.base, self.tables.sub @@ -2369,30 +2661,46 @@ class OptimizedLoadTest(fixtures.MappedTest): class Sub(Base): pass - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') - m = mapper(Sub, sub, inherits=Base, polymorphic_identity='sub') + + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) + m = mapper(Sub, sub, inherits=Base, polymorphic_identity="sub") s1 = Sub() - assert m._optimized_get_statement(attributes.instance_state(s1), - ['subcounter2']) is None + assert ( + m._optimized_get_statement( + attributes.instance_state(s1), ["subcounter2"] + ) + is None + ) # loads s1.id as None eq_(s1.id, None) # this now will come up with a value of None for id - should reject - assert m._optimized_get_statement(attributes.instance_state(s1), - ['subcounter2']) is None + assert ( + m._optimized_get_statement( + attributes.instance_state(s1), ["subcounter2"] + ) + is None + ) s1.id = 1 attributes.instance_state(s1)._commit_all(s1.__dict__, None) - assert m._optimized_get_statement(attributes.instance_state(s1), - ['subcounter2']) is not None + assert ( + m._optimized_get_statement( + attributes.instance_state(s1), ["subcounter2"] + ) + is not None + ) def test_load_expired_on_pending_twolevel(self): - base, sub, subsub = (self.tables.base, - self.tables.sub, - self.tables.subsub) 
+ base, sub, subsub = ( + self.tables.base, + self.tables.sub, + self.tables.subsub, + ) class Base(fixtures.BasicEntity): pass @@ -2403,12 +2711,13 @@ class OptimizedLoadTest(fixtures.MappedTest): class SubSub(Sub): pass - mapper(Base, base, polymorphic_on=base.c.type, - polymorphic_identity='base') - mapper(Sub, sub, inherits=Base, polymorphic_identity='sub') - mapper(SubSub, subsub, inherits=Sub, polymorphic_identity='subsub') + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="base" + ) + mapper(Sub, sub, inherits=Base, polymorphic_identity="sub") + mapper(SubSub, subsub, inherits=Sub, polymorphic_identity="subsub") sess = Session() - s1 = SubSub(data='s1', counter=1, subcounter=2) + s1 = SubSub(data="s1", counter=1, subcounter=2) sess.add(s1) self.assert_sql_execution( testing.db, @@ -2416,23 +2725,22 @@ class OptimizedLoadTest(fixtures.MappedTest): CompiledSQL( "INSERT INTO base (data, type, counter) VALUES " "(:data, :type, :counter)", - [{'data': 's1', 'type': 'subsub', 'counter': 1}] + [{"data": "s1", "type": "subsub", "counter": 1}], ), CompiledSQL( "INSERT INTO sub (id, sub, subcounter) VALUES " "(:id, :sub, :subcounter)", - lambda ctx: [{'subcounter': 2, 'sub': None, 'id': s1.id}] + lambda ctx: [{"subcounter": 2, "sub": None, "id": s1.id}], ), CompiledSQL( "INSERT INTO subsub (id) VALUES (:id)", - lambda ctx: {'id': s1.id} + lambda ctx: {"id": s1.id}, ), ) def go(): - eq_( - s1.subcounter2, 1 - ) + eq_(s1.subcounter2, 1) + self.assert_sql_execution( testing.db, go, @@ -2441,27 +2749,28 @@ class OptimizedLoadTest(fixtures.MappedTest): "SELECT subsub.subsubcounter2 AS subsub_subsubcounter2, " "sub.subcounter2 AS sub_subcounter2 FROM subsub, sub " "WHERE :param_1 = sub.id AND sub.id = subsub.id", - lambda ctx: {'param_1': s1.id} + lambda ctx: {"param_1": s1.id}, ), CompiledSQL( "SELECT sub.subcounter2 AS sub_subcounter2, " "subsub.subsubcounter2 AS subsub_subsubcounter2 " "FROM sub, subsub " "WHERE :param_1 = sub.id AND sub.id = subsub.id", - lambda ctx: {'param_1': s1.id} + lambda ctx: {"param_1": s1.id}, ), - ) + ), ) class NoPKOnSubTableWarningTest(fixtures.TestBase): - def _fixture(self): metadata = MetaData() - parent = Table('parent', metadata, - Column('id', Integer, primary_key=True)) - child = Table('child', metadata, - Column('id', Integer, ForeignKey('parent.id'))) + parent = Table( + "parent", metadata, Column("id", Integer, primary_key=True) + ) + child = Table( + "child", metadata, Column("id", Integer, ForeignKey("parent.id")) + ) return parent, child def tearDown(self): @@ -2481,7 +2790,10 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase): sa_exc.SAWarning, "Could not assemble any primary keys for locally mapped " "table 'child' - no rows will be persisted in this Table.", - mapper, C, child, inherits=P + mapper, + C, + child, + inherits=P, ) def test_no_warning_with_explicit(self): @@ -2501,13 +2813,15 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase): class InhCondTest(fixtures.TestBase): def test_inh_cond_nonexistent_table_unrelated(self): metadata = MetaData() - base_table = Table("base", metadata, - Column("id", Integer, primary_key=True)) - derived_table = Table("derived", metadata, - Column("id", Integer, ForeignKey( - "base.id"), primary_key=True), - Column("owner_id", Integer, - ForeignKey("owner.owner_id"))) + base_table = Table( + "base", metadata, Column("id", Integer, primary_key=True) + ) + derived_table = Table( + "derived", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + 
Column("owner_id", Integer, ForeignKey("owner.owner_id")), + ) class Base(object): pass @@ -2517,23 +2831,23 @@ class InhCondTest(fixtures.TestBase): mapper(Base, base_table) # succeeds, despite "owner" table not configured yet - m2 = mapper(Derived, derived_table, - inherits=Base) + m2 = mapper(Derived, derived_table, inherits=Base) assert m2.inherit_condition.compare( base_table.c.id == derived_table.c.id ) def test_inh_cond_nonexistent_col_unrelated(self): m = MetaData() - base_table = Table("base", m, - Column("id", Integer, primary_key=True)) - derived_table = Table("derived", m, - Column("id", Integer, ForeignKey('base.id'), - primary_key=True), - Column('order_id', Integer, - ForeignKey('order.foo'))) - order_table = Table('order', m, Column( - 'id', Integer, primary_key=True)) + base_table = Table("base", m, Column("id", Integer, primary_key=True)) + derived_table = Table( + "derived", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("order_id", Integer, ForeignKey("order.foo")), + ) + order_table = Table( + "order", m, Column("id", Integer, primary_key=True) + ) class Base(object): pass @@ -2551,10 +2865,12 @@ class InhCondTest(fixtures.TestBase): def test_inh_cond_no_fk(self): metadata = MetaData() - base_table = Table("base", metadata, - Column("id", Integer, primary_key=True)) - derived_table = Table("derived", metadata, - Column("id", Integer, primary_key=True)) + base_table = Table( + "base", metadata, Column("id", Integer, primary_key=True) + ) + derived_table = Table( + "derived", metadata, Column("id", Integer, primary_key=True) + ) class Base(object): pass @@ -2568,17 +2884,20 @@ class InhCondTest(fixtures.TestBase): "Can't find any foreign key relationships between " "'base' and 'derived'.", mapper, - Derived, derived_table, inherits=Base + Derived, + derived_table, + inherits=Base, ) def test_inh_cond_nonexistent_table_related(self): m1 = MetaData() m2 = MetaData() - base_table = Table("base", m1, - Column("id", Integer, primary_key=True)) - derived_table = Table("derived", m2, - Column("id", Integer, ForeignKey('base.id'), - primary_key=True)) + base_table = Table("base", m1, Column("id", Integer, primary_key=True)) + derived_table = Table( + "derived", + m2, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) class Base(object): pass @@ -2598,16 +2917,19 @@ class InhCondTest(fixtures.TestBase): "could not find table 'base' with which to generate " "a foreign key to target column 'id'", mapper, - Derived, derived_table, inherits=Base + Derived, + derived_table, + inherits=Base, ) def test_inh_cond_nonexistent_col_related(self): m = MetaData() - base_table = Table("base", m, - Column("id", Integer, primary_key=True)) - derived_table = Table("derived", m, - Column("id", Integer, ForeignKey('base.q'), - primary_key=True)) + base_table = Table("base", m, Column("id", Integer, primary_key=True)) + derived_table = Table( + "derived", + m, + Column("id", Integer, ForeignKey("base.q"), primary_key=True), + ) class Base(object): pass @@ -2623,23 +2945,31 @@ class InhCondTest(fixtures.TestBase): "'base.q' on table " "'derived': table 'base' has no column named 'q'", mapper, - Derived, derived_table, inherits=Base + Derived, + derived_table, + inherits=Base, ) class PKDiscriminatorTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - parents = Table('parents', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(60))) + parents = Table( + "parents", + 
metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(60)), + ) - children = Table('children', metadata, - Column('id', Integer, ForeignKey('parents.id'), - primary_key=True), - Column('type', Integer, primary_key=True), - Column('name', String(60))) + children = Table( + "children", + metadata, + Column("id", Integer, ForeignKey("parents.id"), primary_key=True), + Column("type", Integer, primary_key=True), + Column("name", String(60)), + ) def test_pk_as_discriminator(self): parents, children = self.tables.parents, self.tables.children @@ -2655,17 +2985,23 @@ class PKDiscriminatorTest(fixtures.MappedTest): class A(Child): pass - mapper(Parent, parents, properties={ - 'children': relationship(Child, backref='parent'), - }) - mapper(Child, children, polymorphic_on=children.c.type, - polymorphic_identity=1) + mapper( + Parent, + parents, + properties={"children": relationship(Child, backref="parent")}, + ) + mapper( + Child, + children, + polymorphic_on=children.c.type, + polymorphic_identity=1, + ) mapper(A, inherits=Child, polymorphic_identity=2) s = create_session() - p = Parent('p1') - a = A('a1') + p = Parent("p1") + a = A("a1") p.children.append(a) s.add(p) s.flush() @@ -2673,22 +3009,26 @@ class PKDiscriminatorTest(fixtures.MappedTest): assert a.id assert a.type == 2 - p.name = 'p1new' - a.name = 'a1new' + p.name = "p1new" + a.name = "a1new" s.flush() s.expire_all() - assert a.name == 'a1new' - assert p.name == 'p1new' + assert a.name == "a1new" + assert p.name == "p1new" class NoPolyIdentInMiddleTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('base', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(50), nullable=False)) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(50), nullable=False), + ) @classmethod def setup_classes(cls): @@ -2709,18 +3049,20 @@ class NoPolyIdentInMiddleTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - A, C, B, E, D, base = (cls.classes.A, - cls.classes.C, - cls.classes.B, - cls.classes.E, - cls.classes.D, - cls.tables.base) + A, C, B, E, D, base = ( + cls.classes.A, + cls.classes.C, + cls.classes.B, + cls.classes.E, + cls.classes.D, + cls.tables.base, + ) mapper(A, base, polymorphic_on=base.c.type) - mapper(B, inherits=A, ) - mapper(C, inherits=B, polymorphic_identity='c') - mapper(D, inherits=B, polymorphic_identity='d') - mapper(E, inherits=A, polymorphic_identity='e') + mapper(B, inherits=A) + mapper(C, inherits=B, polymorphic_identity="c") + mapper(D, inherits=B, polymorphic_identity="d") + mapper(E, inherits=A, polymorphic_identity="e") def test_load_from_middle(self): C, B = self.classes.C, self.classes.B @@ -2728,7 +3070,7 @@ class NoPolyIdentInMiddleTest(fixtures.MappedTest): s = Session() s.add(C()) o = s.query(B).first() - eq_(o.type, 'c') + eq_(o.type, "c") assert isinstance(o, C) def test_load_from_base(self): @@ -2737,30 +3079,27 @@ class NoPolyIdentInMiddleTest(fixtures.MappedTest): s = Session() s.add(C()) o = s.query(A).first() - eq_(o.type, 'c') + eq_(o.type, "c") assert isinstance(o, C) def test_discriminator(self): - C, B, base = (self.classes.C, - self.classes.B, - self.tables.base) + C, B, base = (self.classes.C, self.classes.B, self.tables.base) assert class_mapper(B).polymorphic_on is base.c.type assert class_mapper(C).polymorphic_on is base.c.type def 
test_load_multiple_from_middle(self): - C, B, E, D, base = (self.classes.C, - self.classes.B, - self.classes.E, - self.classes.D, - self.tables.base) + C, B, E, D, base = ( + self.classes.C, + self.classes.B, + self.classes.E, + self.classes.D, + self.tables.base, + ) s = Session() s.add_all([C(), D(), E()]) - eq_( - s.query(B).order_by(base.c.type).all(), - [C(), D()] - ) + eq_(s.query(B).order_by(base.c.type).all(), [C(), D()]) class DeleteOrphanTest(fixtures.MappedTest): @@ -2775,19 +3114,27 @@ class DeleteOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): global single, parent - single = Table('single', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(50), nullable=False), - Column('data', String(50)), - Column('parent_id', Integer, ForeignKey( - 'parent.id'), nullable=False), - ) - - parent = Table('parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + single = Table( + "single", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(50), nullable=False), + Column("data", String(50)), + Column( + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) + + parent = Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) def test_orphan_message(self): class Base(fixtures.BasicEntity): @@ -2799,61 +3146,77 @@ class DeleteOrphanTest(fixtures.MappedTest): class Parent(fixtures.BasicEntity): pass - mapper(Base, single, polymorphic_on=single.c.type, - polymorphic_identity='base') - mapper(SubClass, inherits=Base, polymorphic_identity='sub') - mapper(Parent, parent, properties={ - 'related': relationship(Base, cascade="all, delete-orphan") - }) + mapper( + Base, + single, + polymorphic_on=single.c.type, + polymorphic_identity="base", + ) + mapper(SubClass, inherits=Base, polymorphic_identity="sub") + mapper( + Parent, + parent, + properties={ + "related": relationship(Base, cascade="all, delete-orphan") + }, + ) sess = create_session() - s1 = SubClass(data='s1') + s1 = SubClass(data="s1") sess.add(s1) assert_raises(sa_exc.DBAPIError, sess.flush) class PolymorphicUnionTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self): - t1 = table('t1', column('c1', Integer), - column('c2', Integer), - column('c3', Integer)) - t2 = table('t2', column('c1', Integer), column('c2', Integer), - column('c3', Integer), - column('c4', Integer)) - t3 = table('t3', column('c1', Integer), - column('c3', Integer), - column('c5', Integer)) + t1 = table( + "t1", + column("c1", Integer), + column("c2", Integer), + column("c3", Integer), + ) + t2 = table( + "t2", + column("c1", Integer), + column("c2", Integer), + column("c3", Integer), + column("c4", Integer), + ) + t3 = table( + "t3", + column("c1", Integer), + column("c3", Integer), + column("c5", Integer), + ) return t1, t2, t3 def test_type_col_present(self): t1, t2, t3 = self._fixture() self.assert_compile( polymorphic_union( - util.OrderedDict([("a", t1), ("b", t2), ("c", t3)]), - 'q1' + util.OrderedDict([("a", t1), ("b", t2), ("c", t3)]), "q1" ), "SELECT t1.c1, t1.c2, t1.c3, CAST(NULL AS INTEGER) AS c4, " "CAST(NULL AS INTEGER) AS c5, 'a' AS q1 FROM t1 UNION ALL " "SELECT t2.c1, t2.c2, t2.c3, t2.c4, CAST(NULL AS INTEGER) AS c5, " "'b' AS q1 FROM t2 UNION ALL SELECT t3.c1, 
" "CAST(NULL AS INTEGER) AS c2, t3.c3, CAST(NULL AS INTEGER) AS c4, " - "t3.c5, 'c' AS q1 FROM t3" + "t3.c5, 'c' AS q1 FROM t3", ) def test_type_col_non_present(self): t1, t2, t3 = self._fixture() self.assert_compile( polymorphic_union( - util.OrderedDict([("a", t1), ("b", t2), ("c", t3)]), - None + util.OrderedDict([("a", t1), ("b", t2), ("c", t3)]), None ), "SELECT t1.c1, t1.c2, t1.c3, CAST(NULL AS INTEGER) AS c4, " "CAST(NULL AS INTEGER) AS c5 FROM t1 UNION ALL SELECT t2.c1, " "t2.c2, t2.c3, t2.c4, CAST(NULL AS INTEGER) AS c5 FROM t2 " "UNION ALL SELECT t3.c1, CAST(NULL AS INTEGER) AS c2, t3.c3, " - "CAST(NULL AS INTEGER) AS c4, t3.c5 FROM t3" + "CAST(NULL AS INTEGER) AS c4, t3.c5 FROM t3", ) def test_no_cast_null(self): @@ -2861,26 +3224,33 @@ class PolymorphicUnionTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( polymorphic_union( util.OrderedDict([("a", t1), ("b", t2), ("c", t3)]), - 'q1', cast_nulls=False + "q1", + cast_nulls=False, ), "SELECT t1.c1, t1.c2, t1.c3, NULL AS c4, NULL AS c5, 'a' AS q1 " "FROM t1 UNION ALL SELECT t2.c1, t2.c2, t2.c3, t2.c4, NULL AS c5, " "'b' AS q1 FROM t2 UNION ALL SELECT t3.c1, NULL AS c2, t3.c3, " - "NULL AS c4, t3.c5, 'c' AS q1 FROM t3" + "NULL AS c4, t3.c5, 'c' AS q1 FROM t3", ) class NameConflictTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - content = Table('content', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(30))) - foo = Table('foo', metadata, - Column('id', Integer, ForeignKey('content.id'), - primary_key=True), - Column('content_type', String(30))) + content = Table( + "content", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(30)), + ) + foo = Table( + "foo", + metadata, + Column("id", Integer, ForeignKey("content.id"), primary_key=True), + Column("content_type", String(30)), + ) def test_name_conflict(self): class Content(object): @@ -2888,15 +3258,20 @@ class NameConflictTest(fixtures.MappedTest): class Foo(Content): pass - mapper(Content, self.tables.content, - polymorphic_on=self.tables.content.c.type) - mapper(Foo, self.tables.foo, inherits=Content, - polymorphic_identity='foo') + + mapper( + Content, + self.tables.content, + polymorphic_on=self.tables.content.c.type, + ) + mapper( + Foo, self.tables.foo, inherits=Content, polymorphic_identity="foo" + ) sess = create_session() f = Foo() - f.content_type = 'bar' + f.content_type = "bar" sess.add(f) sess.flush() f_id = f.id sess.expunge_all() - assert sess.query(Content).get(f_id).content_type == 'bar' + assert sess.query(Content).get(f_id).content_type == "bar" diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index edf2c4bdc3..3c1b4b08eb 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -1,5 +1,4 @@ -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message +from sqlalchemy.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc @@ -15,51 +14,55 @@ from sqlalchemy.testing import mock class Employee(object): - def __init__(self, name): self.name = name def __repr__(self): - return self.__class__.__name__ + ' ' + self.name + return self.__class__.__name__ + " " + self.name class Manager(Employee): - def __init__(self, name, manager_data): self.name = name self.manager_data = manager_data def 
__repr__(self): - return self.__class__.__name__ + ' ' + self.name + ' ' \ - + self.manager_data + return ( + self.__class__.__name__ + " " + self.name + " " + self.manager_data + ) class Engineer(Employee): - def __init__(self, name, engineer_info): self.name = name self.engineer_info = engineer_info def __repr__(self): - return self.__class__.__name__ + ' ' + self.name + ' ' \ + return ( + self.__class__.__name__ + + " " + + self.name + + " " + self.engineer_info + ) class Hacker(Engineer): - - def __init__( - self, - name, - nickname, - engineer_info, - ): + def __init__(self, name, nickname, engineer_info): self.name = name self.nickname = nickname self.engineer_info = engineer_info def __repr__(self): - return self.__class__.__name__ + ' ' + self.name + " '" \ - + self.nickname + "' " + self.engineer_info + return ( + self.__class__.__name__ + + " " + + self.name + + " '" + + self.nickname + + "' " + + self.engineer_info + ) class Company(object): @@ -67,118 +70,174 @@ class Company(object): class ConcreteTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - global managers_table, engineers_table, hackers_table, \ - companies, employees_table - companies = Table('companies', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - employees_table = Table('employees', metadata, - Column('employee_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('company_id', Integer, - ForeignKey('companies.id'))) + global managers_table, engineers_table, hackers_table, companies, employees_table + companies = Table( + "companies", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + ) + employees_table = Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("company_id", Integer, ForeignKey("companies.id")), + ) managers_table = Table( - 'managers', + "managers", metadata, - Column('employee_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('company_id', Integer, ForeignKey('companies.id'))) + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("manager_data", String(50)), + Column("company_id", Integer, ForeignKey("companies.id")), + ) engineers_table = Table( - 'engineers', + "engineers", metadata, - Column('employee_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('engineer_info', String(50)), - Column('company_id', Integer, ForeignKey('companies.id'))) + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("engineer_info", String(50)), + Column("company_id", Integer, ForeignKey("companies.id")), + ) hackers_table = Table( - 'hackers', + "hackers", metadata, - Column('employee_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('engineer_info', String(50)), - Column('company_id', Integer, ForeignKey('companies.id')), - Column('nickname', String(50))) + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("engineer_info", String(50)), + 
Column("company_id", Integer, ForeignKey("companies.id")), + Column("nickname", String(50)), + ) def test_basic(self): pjoin = polymorphic_union( - {'manager': managers_table, 'engineer': engineers_table}, - 'type', 'pjoin') - employee_mapper = mapper(Employee, pjoin, - polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='engineer') + {"manager": managers_table, "engineer": engineers_table}, + "type", + "pjoin", + ) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper( + Manager, + managers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="manager", + ) + engineer_mapper = mapper( + Engineer, + engineers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="engineer", + ) session = create_session() - session.add(Manager('Tom', 'knows how to manage things')) - session.add(Engineer('Kurt', 'knows how to hack')) + session.add(Manager("Tom", "knows how to manage things")) + session.add(Engineer("Kurt", "knows how to hack")) session.flush() session.expunge_all() - assert set([repr(x) for x in session.query(Employee)]) \ - == set(['Engineer Kurt knows how to hack', - 'Manager Tom knows how to manage things']) - assert set([repr(x) for x in session.query(Manager)]) \ - == set(['Manager Tom knows how to manage things']) - assert set([repr(x) for x in session.query(Engineer)]) \ - == set(['Engineer Kurt knows how to hack']) + assert set([repr(x) for x in session.query(Employee)]) == set( + [ + "Engineer Kurt knows how to hack", + "Manager Tom knows how to manage things", + ] + ) + assert set([repr(x) for x in session.query(Manager)]) == set( + ["Manager Tom knows how to manage things"] + ) + assert set([repr(x) for x in session.query(Engineer)]) == set( + ["Engineer Kurt knows how to hack"] + ) manager = session.query(Manager).one() - session.expire(manager, ['manager_data']) - eq_(manager.manager_data, 'knows how to manage things') + session.expire(manager, ["manager_data"]) + eq_(manager.manager_data, "knows how to manage things") def test_multi_level_no_base(self): pjoin = polymorphic_union( - {'manager': managers_table, 'engineer': engineers_table, - 'hacker': hackers_table}, - 'type', 'pjoin') - pjoin2 = polymorphic_union({'engineer': engineers_table, - 'hacker': hackers_table}, 'type', - 'pjoin2') - employee_mapper = mapper(Employee, pjoin, - polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='manager') + { + "manager": managers_table, + "engineer": engineers_table, + "hacker": hackers_table, + }, + "type", + "pjoin", + ) + pjoin2 = polymorphic_union( + {"engineer": engineers_table, "hacker": hackers_table}, + "type", + "pjoin2", + ) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper( + Manager, + managers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="manager", + ) engineer_mapper = mapper( Engineer, engineers_table, - with_polymorphic=('*', pjoin2), + with_polymorphic=("*", pjoin2), polymorphic_on=pjoin2.c.type, inherits=employee_mapper, concrete=True, - polymorphic_identity='engineer') - hacker_mapper = mapper(Hacker, hackers_table, - inherits=engineer_mapper, concrete=True, - 
polymorphic_identity='hacker') + polymorphic_identity="engineer", + ) + hacker_mapper = mapper( + Hacker, + hackers_table, + inherits=engineer_mapper, + concrete=True, + polymorphic_identity="hacker", + ) session = create_session() - tom = Manager('Tom', 'knows how to manage things') + tom = Manager("Tom", "knows how to manage things") assert_raises_message( AttributeError, "does not implement attribute .?'type' at the instance level.", - setattr, tom, "type", "sometype") + setattr, + tom, + "type", + "sometype", + ) - jerry = Engineer('Jerry', 'knows how to program') - hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + jerry = Engineer("Jerry", "knows how to program") + hacker = Hacker("Kurt", "Badass", "knows how to hack") assert_raises_message( AttributeError, "does not implement attribute .?'type' at the instance level.", - setattr, hacker, "type", "sometype") + setattr, + hacker, + "type", + "sometype", + ) session.add_all((tom, jerry, hacker)) session.flush() @@ -186,49 +245,69 @@ class ConcreteTest(fixtures.MappedTest): # ensure "readonly" on save logic didn't pollute the # expired_attributes collection - assert 'nickname' \ - not in attributes.instance_state(jerry).expired_attributes - assert 'name' \ + assert ( + "nickname" not in attributes.instance_state(jerry).expired_attributes - assert 'name' \ - not in attributes.instance_state(hacker).expired_attributes - assert 'nickname' \ + ) + assert ( + "name" not in attributes.instance_state(jerry).expired_attributes + ) + assert ( + "name" not in attributes.instance_state(hacker).expired_attributes + ) + assert ( + "nickname" not in attributes.instance_state(hacker).expired_attributes + ) def go(): - eq_(jerry.name, 'Jerry') - eq_(hacker.nickname, 'Badass') + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") self.assert_sql_count(testing.db, go, 0) session.expunge_all() - assert repr(session.query(Employee).filter(Employee.name == 'Tom') - .one()) \ - == 'Manager Tom knows how to manage things' - assert repr(session.query(Manager) - .filter(Manager.name == 'Tom').one()) \ - == 'Manager Tom knows how to manage things' - assert set([repr(x) for x in session.query(Employee).all()]) \ - == set(['Engineer Jerry knows how to program', - 'Manager Tom knows how to manage things', - "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Manager).all()]) \ - == set(['Manager Tom knows how to manage things']) - assert set([repr(x) for x in session.query(Engineer).all()]) \ - == set(['Engineer Jerry knows how to program', - "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Hacker).all()]) \ - == set(["Hacker Kurt 'Badass' knows how to hack"]) + assert ( + repr(session.query(Employee).filter(Employee.name == "Tom").one()) + == "Manager Tom knows how to manage things" + ) + assert ( + repr(session.query(Manager).filter(Manager.name == "Tom").one()) + == "Manager Tom knows how to manage things" + ) + assert set([repr(x) for x in session.query(Employee).all()]) == set( + [ + "Engineer Jerry knows how to program", + "Manager Tom knows how to manage things", + "Hacker Kurt 'Badass' knows how to hack", + ] + ) + assert set([repr(x) for x in session.query(Manager).all()]) == set( + ["Manager Tom knows how to manage things"] + ) + assert set([repr(x) for x in session.query(Engineer).all()]) == set( + [ + "Engineer Jerry knows how to program", + "Hacker Kurt 'Badass' knows how to hack", + ] + ) + assert set([repr(x) for x in session.query(Hacker).all()]) == set( + 
["Hacker Kurt 'Badass' knows how to hack"] + ) def test_multi_level_no_base_w_hybrid(self): pjoin = polymorphic_union( - {'manager': managers_table, 'engineer': engineers_table, - 'hacker': hackers_table}, - 'type', 'pjoin') + { + "manager": managers_table, + "engineer": engineers_table, + "hacker": hackers_table, + }, + "type", + "pjoin", + ) test_calls = mock.Mock() class ManagerWHybrid(Employee): - def __init__(self, name, manager_data): self.name = name self.manager_data = manager_data @@ -243,30 +322,30 @@ class ConcreteTest(fixtures.MappedTest): test_calls.engineer_info_class() return cls.manager_data - employee_mapper = mapper(Employee, pjoin, - polymorphic_on=pjoin.c.type) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) mapper( - ManagerWHybrid, managers_table, + ManagerWHybrid, + managers_table, inherits=employee_mapper, concrete=True, - polymorphic_identity='manager') + polymorphic_identity="manager", + ) mapper( Engineer, engineers_table, inherits=employee_mapper, concrete=True, - polymorphic_identity='engineer') + polymorphic_identity="engineer", + ) session = create_session() - tom = ManagerWHybrid('Tom', 'mgrdata') + tom = ManagerWHybrid("Tom", "mgrdata") # mapping did not impact the engineer_info # hybrid in any way eq_(test_calls.mock_calls, []) - eq_( - tom.engineer_info, "mgrdata" - ) + eq_(tom.engineer_info, "mgrdata") eq_(test_calls.mock_calls, [mock.call.engineer_info_instance()]) session.add(tom) @@ -274,56 +353,75 @@ class ConcreteTest(fixtures.MappedTest): session.close() - tom = session.query(ManagerWHybrid).filter( - ManagerWHybrid.engineer_info == 'mgrdata').one() + tom = ( + session.query(ManagerWHybrid) + .filter(ManagerWHybrid.engineer_info == "mgrdata") + .one() + ) eq_( test_calls.mock_calls, [ mock.call.engineer_info_instance(), - mock.call.engineer_info_class()] - ) - eq_( - tom.engineer_info, "mgrdata" + mock.call.engineer_info_class(), + ], ) + eq_(tom.engineer_info, "mgrdata") def test_multi_level_with_base(self): - pjoin = polymorphic_union({ - 'employee': employees_table, - 'manager': managers_table, - 'engineer': engineers_table, - 'hacker': hackers_table, - }, 'type', 'pjoin') - pjoin2 = polymorphic_union({'engineer': engineers_table, - 'hacker': hackers_table}, 'type', - 'pjoin2') - employee_mapper = mapper(Employee, employees_table, - with_polymorphic=('*', pjoin), - polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='manager') + pjoin = polymorphic_union( + { + "employee": employees_table, + "manager": managers_table, + "engineer": engineers_table, + "hacker": hackers_table, + }, + "type", + "pjoin", + ) + pjoin2 = polymorphic_union( + {"engineer": engineers_table, "hacker": hackers_table}, + "type", + "pjoin2", + ) + employee_mapper = mapper( + Employee, + employees_table, + with_polymorphic=("*", pjoin), + polymorphic_on=pjoin.c.type, + ) + manager_mapper = mapper( + Manager, + managers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="manager", + ) engineer_mapper = mapper( Engineer, engineers_table, - with_polymorphic=('*', pjoin2), + with_polymorphic=("*", pjoin2), polymorphic_on=pjoin2.c.type, inherits=employee_mapper, concrete=True, - polymorphic_identity='engineer') - hacker_mapper = mapper(Hacker, hackers_table, - inherits=engineer_mapper, concrete=True, - polymorphic_identity='hacker') + polymorphic_identity="engineer", + ) + hacker_mapper = mapper( + Hacker, + hackers_table, + 
inherits=engineer_mapper, + concrete=True, + polymorphic_identity="hacker", + ) session = create_session() - tom = Manager('Tom', 'knows how to manage things') - jerry = Engineer('Jerry', 'knows how to program') - hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + tom = Manager("Tom", "knows how to manage things") + jerry = Engineer("Jerry", "knows how to program") + hacker = Hacker("Kurt", "Badass", "knows how to hack") session.add_all((tom, jerry, hacker)) session.flush() def go(): - eq_(jerry.name, 'Jerry') - eq_(hacker.nickname, 'Badass') + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") self.assert_sql_count(testing.db, go, 0) session.expunge_all() @@ -333,157 +431,273 @@ class ConcreteTest(fixtures.MappedTest): # is not rendered in the statement which is only against # Employee's "pjoin" - assert len(testing.db.execute(session.query( - Employee).with_labels().statement).fetchall()) == 3 - assert set([repr(x) for x in session.query(Employee)]) \ - == set(['Engineer Jerry knows how to program', - 'Manager Tom knows how to manage things', - "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Manager)]) \ - == set(['Manager Tom knows how to manage things']) - assert set([repr(x) for x in session.query(Engineer)]) \ - == set(['Engineer Jerry knows how to program', - "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Hacker)]) \ - == set(["Hacker Kurt 'Badass' knows how to hack"]) + assert ( + len( + testing.db.execute( + session.query(Employee).with_labels().statement + ).fetchall() + ) + == 3 + ) + assert set([repr(x) for x in session.query(Employee)]) == set( + [ + "Engineer Jerry knows how to program", + "Manager Tom knows how to manage things", + "Hacker Kurt 'Badass' knows how to hack", + ] + ) + assert set([repr(x) for x in session.query(Manager)]) == set( + ["Manager Tom knows how to manage things"] + ) + assert set([repr(x) for x in session.query(Engineer)]) == set( + [ + "Engineer Jerry knows how to program", + "Hacker Kurt 'Badass' knows how to hack", + ] + ) + assert set([repr(x) for x in session.query(Hacker)]) == set( + ["Hacker Kurt 'Badass' knows how to hack"] + ) def test_without_default_polymorphic(self): - pjoin = polymorphic_union({ - 'employee': employees_table, - 'manager': managers_table, - 'engineer': engineers_table, - 'hacker': hackers_table, - }, 'type', 'pjoin') - pjoin2 = polymorphic_union({'engineer': engineers_table, - 'hacker': hackers_table}, 'type', - 'pjoin2') - employee_mapper = mapper(Employee, employees_table, - polymorphic_identity='employee') - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='engineer') - hacker_mapper = mapper(Hacker, hackers_table, - inherits=engineer_mapper, concrete=True, - polymorphic_identity='hacker') + pjoin = polymorphic_union( + { + "employee": employees_table, + "manager": managers_table, + "engineer": engineers_table, + "hacker": hackers_table, + }, + "type", + "pjoin", + ) + pjoin2 = polymorphic_union( + {"engineer": engineers_table, "hacker": hackers_table}, + "type", + "pjoin2", + ) + employee_mapper = mapper( + Employee, employees_table, polymorphic_identity="employee" + ) + manager_mapper = mapper( + Manager, + managers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="manager", + ) + engineer_mapper = 
mapper( + Engineer, + engineers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="engineer", + ) + hacker_mapper = mapper( + Hacker, + hackers_table, + inherits=engineer_mapper, + concrete=True, + polymorphic_identity="hacker", + ) session = create_session() - jdoe = Employee('Jdoe') - tom = Manager('Tom', 'knows how to manage things') - jerry = Engineer('Jerry', 'knows how to program') - hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + jdoe = Employee("Jdoe") + tom = Manager("Tom", "knows how to manage things") + jerry = Engineer("Jerry", "knows how to program") + hacker = Hacker("Kurt", "Badass", "knows how to hack") session.add_all((jdoe, tom, jerry, hacker)) session.flush() - eq_(len(testing.db.execute(session.query(Employee).with_polymorphic( - '*', pjoin, pjoin.c.type).with_labels().statement).fetchall()), 4) + eq_( + len( + testing.db.execute( + session.query(Employee) + .with_polymorphic("*", pjoin, pjoin.c.type) + .with_labels() + .statement + ).fetchall() + ), + 4, + ) eq_(session.query(Employee).get(jdoe.employee_id), jdoe) eq_(session.query(Engineer).get(jerry.employee_id), jerry) - eq_(set([repr(x) for x in - session.query(Employee).with_polymorphic('*', pjoin, - pjoin.c.type)]), - set(['Employee Jdoe', - 'Engineer Jerry knows how to program', - 'Manager Tom knows how to manage things', - "Hacker Kurt 'Badass' knows how to hack"])) - eq_(set([repr(x) for x in session.query(Manager)]), - set(['Manager Tom knows how to manage things'])) - eq_(set([repr(x) for x in - session.query(Engineer).with_polymorphic('*', - pjoin2, - pjoin2.c.type)]), - set(['Engineer Jerry knows how to program', - "Hacker Kurt 'Badass' knows how to hack"])) - eq_(set([repr(x) for x in session.query(Hacker)]), - set(["Hacker Kurt 'Badass' knows how to hack"])) + eq_( + set( + [ + repr(x) + for x in session.query(Employee).with_polymorphic( + "*", pjoin, pjoin.c.type + ) + ] + ), + set( + [ + "Employee Jdoe", + "Engineer Jerry knows how to program", + "Manager Tom knows how to manage things", + "Hacker Kurt 'Badass' knows how to hack", + ] + ), + ) + eq_( + set([repr(x) for x in session.query(Manager)]), + set(["Manager Tom knows how to manage things"]), + ) + eq_( + set( + [ + repr(x) + for x in session.query(Engineer).with_polymorphic( + "*", pjoin2, pjoin2.c.type + ) + ] + ), + set( + [ + "Engineer Jerry knows how to program", + "Hacker Kurt 'Badass' knows how to hack", + ] + ), + ) + eq_( + set([repr(x) for x in session.query(Hacker)]), + set(["Hacker Kurt 'Badass' knows how to hack"]), + ) # test adaption of the column by wrapping the query in a # subquery - eq_(len(testing.db.execute(session.query(Engineer).with_polymorphic( - '*', pjoin2, pjoin2.c.type).from_self().statement).fetchall()), 2) - eq_(set([repr(x) for x in - session.query(Engineer) - .with_polymorphic('*', pjoin2, pjoin2.c.type) - .from_self()]), - set(['Engineer Jerry knows how to program', - "Hacker Kurt 'Badass' knows how to hack"])) + eq_( + len( + testing.db.execute( + session.query(Engineer) + .with_polymorphic("*", pjoin2, pjoin2.c.type) + .from_self() + .statement + ).fetchall() + ), + 2, + ) + eq_( + set( + [ + repr(x) + for x in session.query(Engineer) + .with_polymorphic("*", pjoin2, pjoin2.c.type) + .from_self() + ] + ), + set( + [ + "Engineer Jerry knows how to program", + "Hacker Kurt 'Badass' knows how to hack", + ] + ), + ) def test_relationship(self): pjoin = polymorphic_union( - {'manager': managers_table, 'engineer': engineers_table}, - 'type', 'pjoin') - mapper(Company, companies, 
properties={ - 'employees': relationship(Employee)}) - employee_mapper = mapper(Employee, pjoin, - polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, - inherits=employee_mapper, - concrete=True, - polymorphic_identity='engineer') + {"manager": managers_table, "engineer": engineers_table}, + "type", + "pjoin", + ) + mapper( + Company, + companies, + properties={"employees": relationship(Employee)}, + ) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper( + Manager, + managers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="manager", + ) + engineer_mapper = mapper( + Engineer, + engineers_table, + inherits=employee_mapper, + concrete=True, + polymorphic_identity="engineer", + ) session = create_session() c = Company() - c.employees.append(Manager('Tom', 'knows how to manage things')) - c.employees.append(Engineer('Kurt', 'knows how to hack')) + c.employees.append(Manager("Tom", "knows how to manage things")) + c.employees.append(Engineer("Kurt", "knows how to hack")) session.add(c) session.flush() session.expunge_all() def go(): c2 = session.query(Company).get(c.id) - assert set([repr(x) for x in c2.employees]) \ - == set(['Engineer Kurt knows how to hack', - 'Manager Tom knows how to manage things']) + assert set([repr(x) for x in c2.employees]) == set( + [ + "Engineer Kurt knows how to hack", + "Manager Tom knows how to manage things", + ] + ) self.assert_sql_count(testing.db, go, 2) session.expunge_all() def go(): - c2 = \ - session.query(Company).options( - joinedload(Company.employees)).get(c.id) - assert set([repr(x) for x in c2.employees]) \ - == set(['Engineer Kurt knows how to hack', - 'Manager Tom knows how to manage things']) + c2 = ( + session.query(Company) + .options(joinedload(Company.employees)) + .get(c.id) + ) + assert set([repr(x) for x in c2.employees]) == set( + [ + "Engineer Kurt knows how to hack", + "Manager Tom knows how to manage things", + ] + ) self.assert_sql_count(testing.db, go, 1) class PropertyInheritanceTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('a_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('some_dest_id', Integer, ForeignKey('dest_table.id')), - Column('aname', String(50))) - Table('b_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('some_dest_id', Integer, ForeignKey('dest_table.id')), - Column('bname', String(50))) - - Table('c_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('some_dest_id', Integer, ForeignKey('dest_table.id')), - Column('cname', String(50))) - - Table('dest_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) + Table( + "a_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("some_dest_id", Integer, ForeignKey("dest_table.id")), + Column("aname", String(50)), + ) + Table( + "b_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("some_dest_id", Integer, ForeignKey("dest_table.id")), + Column("bname", String(50)), + ) + + Table( + "c_table", + metadata, + Column( + "id", Integer, primary_key=True, 
test_needs_autoincrement=True + ), + Column("some_dest_id", Integer, ForeignKey("dest_table.id")), + Column("cname", String(50)), + ) + + Table( + "dest_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + ) @classmethod def setup_classes(cls): - class A(cls.Comparable): pass @@ -497,61 +711,78 @@ class PropertyInheritanceTest(fixtures.MappedTest): pass def test_noninherited_warning(self): - A, B, b_table, a_table, Dest, dest_table = (self.classes.A, - self.classes.B, - self.tables.b_table, - self.tables.a_table, - self.classes.Dest, - self.tables.dest_table) - - mapper(A, a_table, properties={'some_dest': relationship(Dest)}) + A, B, b_table, a_table, Dest, dest_table = ( + self.classes.A, + self.classes.B, + self.tables.b_table, + self.tables.a_table, + self.classes.Dest, + self.tables.dest_table, + ) + + mapper(A, a_table, properties={"some_dest": relationship(Dest)}) mapper(B, b_table, inherits=A, concrete=True) mapper(Dest, dest_table) b = B() dest = Dest() - assert_raises(AttributeError, setattr, b, 'some_dest', dest) + assert_raises(AttributeError, setattr, b, "some_dest", dest) clear_mappers() - mapper(A, a_table, properties={'a_id': a_table.c.id}) + mapper(A, a_table, properties={"a_id": a_table.c.id}) mapper(B, b_table, inherits=A, concrete=True) mapper(Dest, dest_table) b = B() - assert_raises(AttributeError, setattr, b, 'a_id', 3) + assert_raises(AttributeError, setattr, b, "a_id", 3) clear_mappers() - mapper(A, a_table, properties={'a_id': a_table.c.id}) + mapper(A, a_table, properties={"a_id": a_table.c.id}) mapper(B, b_table, inherits=A, concrete=True) mapper(Dest, dest_table) def test_inheriting(self): - A, B, b_table, a_table, Dest, dest_table = (self.classes.A, - self.classes.B, - self.tables.b_table, - self.tables.a_table, - self.classes.Dest, - self.tables.dest_table) - - mapper(A, a_table, properties={ - 'some_dest': relationship(Dest, back_populates='many_a') - }) - mapper(B, b_table, inherits=A, concrete=True, - properties={ - 'some_dest': relationship(Dest, back_populates='many_b') - }) - - mapper(Dest, dest_table, properties={ - 'many_a': relationship(A, back_populates='some_dest'), - 'many_b': relationship(B, back_populates='some_dest') - }) + A, B, b_table, a_table, Dest, dest_table = ( + self.classes.A, + self.classes.B, + self.tables.b_table, + self.tables.a_table, + self.classes.Dest, + self.tables.dest_table, + ) + + mapper( + A, + a_table, + properties={ + "some_dest": relationship(Dest, back_populates="many_a") + }, + ) + mapper( + B, + b_table, + inherits=A, + concrete=True, + properties={ + "some_dest": relationship(Dest, back_populates="many_b") + }, + ) + + mapper( + Dest, + dest_table, + properties={ + "many_a": relationship(A, back_populates="some_dest"), + "many_b": relationship(B, back_populates="some_dest"), + }, + ) sess = sessionmaker()() - dest1 = Dest(name='c1') - dest2 = Dest(name='c2') - a1 = A(some_dest=dest1, aname='a1') - a2 = A(some_dest=dest2, aname='a2') - b1 = B(some_dest=dest1, bname='b1') - b2 = B(some_dest=dest1, bname='b2') - assert_raises(AttributeError, setattr, b1, 'aname', 'foo') - assert_raises(AttributeError, getattr, A, 'bname') + dest1 = Dest(name="c1") + dest2 = Dest(name="c2") + a1 = A(some_dest=dest1, aname="a1") + a2 = A(some_dest=dest2, aname="a2") + b1 = B(some_dest=dest1, bname="b1") + b2 = B(some_dest=dest1, bname="b2") + assert_raises(AttributeError, setattr, b1, "aname", "foo") + assert_raises(AttributeError, getattr, A, "bname") assert 
dest2.many_a == [a2] assert dest1.many_a == [a1] assert dest1.many_b == [b1, b2] @@ -561,7 +792,7 @@ class PropertyInheritanceTest(fixtures.MappedTest): assert dest2.many_a == [a2] assert dest1.many_a == [a1] assert dest1.many_b == [b1, b2] - assert sess.query(B).filter(B.bname == 'b1').one() is b1 + assert sess.query(B).filter(B.bname == "b1").one() is b1 def test_overlapping_backref_relationship(self): A, B, b_table, a_table, Dest, dest_table = ( @@ -570,15 +801,20 @@ class PropertyInheritanceTest(fixtures.MappedTest): self.tables.b_table, self.tables.a_table, self.classes.Dest, - self.tables.dest_table) + self.tables.dest_table, + ) # test issue #3630, no error or warning is generated mapper(A, a_table) mapper(B, b_table, inherits=A, concrete=True) - mapper(Dest, dest_table, properties={ - 'a': relationship(A, backref='dest'), - 'a1': relationship(B, backref='dest') - }) + mapper( + Dest, + dest_table, + properties={ + "a": relationship(A, backref="dest"), + "a1": relationship(B, backref="dest"), + }, + ) configure_mappers() def test_overlapping_forwards_relationship(self): @@ -588,16 +824,21 @@ class PropertyInheritanceTest(fixtures.MappedTest): self.tables.b_table, self.tables.a_table, self.classes.Dest, - self.tables.dest_table) + self.tables.dest_table, + ) # this is the opposite mapping as that of #3630, never generated # an error / warning - mapper(A, a_table, properties={ - 'dest': relationship(Dest, backref='a') - }) - mapper(B, b_table, inherits=A, concrete=True, properties={ - 'dest': relationship(Dest, backref='a1') - }) + mapper( + A, a_table, properties={"dest": relationship(Dest, backref="a")} + ) + mapper( + B, + b_table, + inherits=A, + concrete=True, + properties={"dest": relationship(Dest, backref="a1")}, + ) mapper(Dest, dest_table) configure_mappers() @@ -606,20 +847,27 @@ class PropertyInheritanceTest(fixtures.MappedTest): attribute.""" A, C, B, c_table, b_table, a_table, Dest, dest_table = ( - self.classes.A, self.classes.C, self.classes.B, self.tables. - c_table, self.tables.b_table, self.tables.a_table, self.classes. 
- Dest, self.tables.dest_table) + self.classes.A, + self.classes.C, + self.classes.B, + self.tables.c_table, + self.tables.b_table, + self.tables.a_table, + self.classes.Dest, + self.tables.dest_table, + ) - ajoin = polymorphic_union({'a': a_table, 'b': b_table, 'c': c_table}, - 'type', 'ajoin') + ajoin = polymorphic_union( + {"a": a_table, "b": b_table, "c": c_table}, "type", "ajoin" + ) mapper( A, a_table, - with_polymorphic=('*', ajoin), + with_polymorphic=("*", ajoin), polymorphic_on=ajoin.c.type, - polymorphic_identity='a', + polymorphic_identity="a", properties={ - 'some_dest': relationship(Dest, back_populates='many_a') + "some_dest": relationship(Dest, back_populates="many_a") }, ) mapper( @@ -627,9 +875,10 @@ class PropertyInheritanceTest(fixtures.MappedTest): b_table, inherits=A, concrete=True, - polymorphic_identity='b', + polymorphic_identity="b", properties={ - 'some_dest': relationship(Dest, back_populates='many_a')}, + "some_dest": relationship(Dest, back_populates="many_a") + }, ) mapper( @@ -637,25 +886,31 @@ class PropertyInheritanceTest(fixtures.MappedTest): c_table, inherits=A, concrete=True, - polymorphic_identity='c', + polymorphic_identity="c", properties={ - 'some_dest': relationship(Dest, back_populates='many_a')}, + "some_dest": relationship(Dest, back_populates="many_a") + }, ) - mapper(Dest, dest_table, properties={ - 'many_a': relationship(A, - back_populates='some_dest', - order_by=ajoin.c.id)}) + mapper( + Dest, + dest_table, + properties={ + "many_a": relationship( + A, back_populates="some_dest", order_by=ajoin.c.id + ) + }, + ) sess = sessionmaker()() - dest1 = Dest(name='c1') - dest2 = Dest(name='c2') - a1 = A(some_dest=dest1, aname='a1', id=1) - a2 = A(some_dest=dest2, aname='a2', id=2) - b1 = B(some_dest=dest1, bname='b1', id=3) - b2 = B(some_dest=dest1, bname='b2', id=4) - c1 = C(some_dest=dest1, cname='c1', id=5) - c2 = C(some_dest=dest2, cname='c2', id=6) + dest1 = Dest(name="c1") + dest2 = Dest(name="c2") + a1 = A(some_dest=dest1, aname="a1", id=1) + a2 = A(some_dest=dest2, aname="a2", id=2) + b1 = B(some_dest=dest1, bname="b1", id=3) + b2 = B(some_dest=dest1, bname="b2", id=4) + c1 = C(some_dest=dest1, cname="c1", id=5) + c2 = C(some_dest=dest2, cname="c2", id=6) eq_([a2, c2], dest2.many_a) eq_([a1, b1, b2, c1], dest1.many_a) @@ -673,42 +928,58 @@ class PropertyInheritanceTest(fixtures.MappedTest): def go(): eq_( [ - Dest(many_a=[A(aname='a1'), - B(bname='b1'), - B(bname='b2'), - C(cname='c1')]), - Dest(many_a=[A(aname='a2'), C(cname='c2')])], - sess.query(Dest).options(joinedload(Dest.many_a)) - .order_by(Dest.id).all()) + Dest( + many_a=[ + A(aname="a1"), + B(bname="b1"), + B(bname="b2"), + C(cname="c1"), + ] + ), + Dest(many_a=[A(aname="a2"), C(cname="c2")]), + ], + sess.query(Dest) + .options(joinedload(Dest.many_a)) + .order_by(Dest.id) + .all(), + ) self.assert_sql_count(testing.db, go, 1) def test_merge_w_relationship(self): A, C, B, c_table, b_table, a_table, Dest, dest_table = ( - self.classes.A, self.classes.C, self.classes.B, self.tables. - c_table, self.tables.b_table, self.tables.a_table, self.classes. 
- Dest, self.tables.dest_table) + self.classes.A, + self.classes.C, + self.classes.B, + self.tables.c_table, + self.tables.b_table, + self.tables.a_table, + self.classes.Dest, + self.tables.dest_table, + ) - ajoin = polymorphic_union({'a': a_table, 'b': b_table, 'c': c_table}, - 'type', 'ajoin') + ajoin = polymorphic_union( + {"a": a_table, "b": b_table, "c": c_table}, "type", "ajoin" + ) mapper( A, a_table, - with_polymorphic=('*', ajoin), + with_polymorphic=("*", ajoin), polymorphic_on=ajoin.c.type, - polymorphic_identity='a', + polymorphic_identity="a", properties={ - 'some_dest': relationship(Dest, back_populates='many_a') - } + "some_dest": relationship(Dest, back_populates="many_a") + }, ) mapper( B, b_table, inherits=A, concrete=True, - polymorphic_identity='b', + polymorphic_identity="b", properties={ - 'some_dest': relationship(Dest, back_populates='many_a')} + "some_dest": relationship(Dest, back_populates="many_a") + }, ) mapper( @@ -716,57 +987,89 @@ class PropertyInheritanceTest(fixtures.MappedTest): c_table, inherits=A, concrete=True, - polymorphic_identity='c', + polymorphic_identity="c", properties={ - 'some_dest': relationship(Dest, back_populates='many_a')} + "some_dest": relationship(Dest, back_populates="many_a") + }, ) - mapper(Dest, dest_table, properties={ - 'many_a': relationship(A, - back_populates='some_dest', - order_by=ajoin.c.id) - }) + mapper( + Dest, + dest_table, + properties={ + "many_a": relationship( + A, back_populates="some_dest", order_by=ajoin.c.id + ) + }, + ) assert C.some_dest.property.parent is class_mapper(C) assert B.some_dest.property.parent is class_mapper(B) assert A.some_dest.property.parent is class_mapper(A) sess = sessionmaker()() - dest1 = Dest(name='d1') - dest2 = Dest(name='d2') - a1 = A(some_dest=dest2, aname='a1') - b1 = B(some_dest=dest1, bname='b1') - c1 = C(some_dest=dest2, cname='c1') + dest1 = Dest(name="d1") + dest2 = Dest(name="d2") + a1 = A(some_dest=dest2, aname="a1") + b1 = B(some_dest=dest1, bname="b1") + c1 = C(some_dest=dest2, cname="c1") sess.add_all([dest1, dest2, c1, a1, b1]) sess.commit() sess2 = sessionmaker()() merged_c1 = sess2.merge(c1) - eq_(merged_c1.some_dest.name, 'd2') + eq_(merged_c1.some_dest.name, "d2") eq_(merged_c1.some_dest_id, c1.some_dest_id) class ManyToManyTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('base', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table('sub', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table('base_mtom', metadata, - Column('base_id', Integer, ForeignKey('base.id'), - primary_key=True), - Column('related_id', Integer, ForeignKey('related.id'), - primary_key=True)) - Table('sub_mtom', metadata, - Column('base_id', Integer, ForeignKey('sub.id'), - primary_key=True), - Column('related_id', Integer, ForeignKey('related.id'), - primary_key=True)) - Table('related', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "sub", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "base_mtom", + metadata, + Column( + "base_id", Integer, ForeignKey("base.id"), primary_key=True + ), + Column( + "related_id", + Integer, + ForeignKey("related.id"), + primary_key=True, + ), + ) + Table( + "sub_mtom", + metadata, + Column("base_id", Integer, ForeignKey("sub.id"), 
primary_key=True), + Column( + "related_id", + Integer, + ForeignKey("related.id"), + primary_key=True, + ), + ) + Table( + "related", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) @classmethod def setup_classes(cls): @@ -781,22 +1084,45 @@ class ManyToManyTest(fixtures.MappedTest): def test_selective_relationships(self): sub, base_mtom, Related, Base, related, sub_mtom, base, Sub = ( - self.tables.sub, self.tables.base_mtom, self.classes.Related, self. - classes.Base, self.tables.related, self.tables.sub_mtom, self. - tables.base, self.classes.Sub) - - mapper(Base, base, properties={'related': relationship( - Related, secondary=base_mtom, backref='bases', - order_by=related.c.id)}) - mapper(Sub, sub, inherits=Base, concrete=True, - properties={'related': relationship(Related, - secondary=sub_mtom, - backref='subs', - order_by=related.c.id)}) + self.tables.sub, + self.tables.base_mtom, + self.classes.Related, + self.classes.Base, + self.tables.related, + self.tables.sub_mtom, + self.tables.base, + self.classes.Sub, + ) + + mapper( + Base, + base, + properties={ + "related": relationship( + Related, + secondary=base_mtom, + backref="bases", + order_by=related.c.id, + ) + }, + ) + mapper( + Sub, + sub, + inherits=Base, + concrete=True, + properties={ + "related": relationship( + Related, + secondary=sub_mtom, + backref="subs", + order_by=related.c.id, + ) + }, + ) mapper(Related, related) sess = sessionmaker()() - b1, s1, r1, r2, r3 = Base(), Sub(), Related(), Related(), \ - Related() + b1, s1, r1, r2, r3 = Base(), Sub(), Related(), Related(), Related() b1.related.append(r1) b1.related.append(r2) s1.related.append(r2) @@ -808,38 +1134,49 @@ class ManyToManyTest(fixtures.MappedTest): class ColKeysTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): global offices_table, refugees_table refugees_table = Table( - 'refugee', metadata, + "refugee", + metadata, Column( - 'refugee_fid', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('refugee_name', String(30), - key='name')) + "refugee_fid", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("refugee_name", String(30), key="name"), + ) offices_table = Table( - 'office', metadata, + "office", + metadata, Column( - 'office_fid', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('office_name', String(30), - key='name')) + "office_fid", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("office_name", String(30), key="name"), + ) @classmethod def insert_data(cls): refugees_table.insert().execute( - dict(refugee_fid=1, name='refugee1'), - dict(refugee_fid=2, name='refugee2')) + dict(refugee_fid=1, name="refugee1"), + dict(refugee_fid=2, name="refugee2"), + ) offices_table.insert().execute( - dict(office_fid=1, name='office1'), - dict(office_fid=2, name='office2')) + dict(office_fid=1, name="office1"), + dict(office_fid=2, name="office2"), + ) def test_keys(self): pjoin = polymorphic_union( - {'refugee': refugees_table, 'office': offices_table}, - 'type', 'pjoin') + {"refugee": refugees_table, "office": offices_table}, + "type", + "pjoin", + ) class Location(object): pass @@ -850,18 +1187,28 @@ class ColKeysTest(fixtures.MappedTest): class Office(Location): pass - location_mapper = mapper(Location, pjoin, - polymorphic_on=pjoin.c.type, - polymorphic_identity='location') - office_mapper = mapper(Office, offices_table, - inherits=location_mapper, concrete=True, - 
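polymorphic_union(), reformatted throughout this hunk, is what makes the concrete classes here loadable polymorphically: it builds a UNION ALL over the per-class tables, padding missing columns with NULL and adding a literal discriminator column named by its second argument. A condensed, runnable version of this test's mapping follows (an in-memory SQLite engine is assumed):

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
    from sqlalchemy.orm import Session, mapper, polymorphic_union

    metadata = MetaData()
    refugee = Table(
        "refugee",
        metadata,
        Column("refugee_fid", Integer, primary_key=True),
        Column("refugee_name", String(30), key="name"),
    )
    office = Table(
        "office",
        metadata,
        Column("office_fid", Integer, primary_key=True),
        Column("office_name", String(30), key="name"),
    )

    # UNION ALL of both tables; "type" holds the dict key per branch
    pjoin = polymorphic_union(
        {"refugee": refugee, "office": office}, "type", "pjoin"
    )

    class Location(object):
        pass

    class Refugee(Location):
        pass

    class Office(Location):
        pass

    location_mapper = mapper(
        Location,
        pjoin,
        polymorphic_on=pjoin.c.type,
        polymorphic_identity="location",
    )
    mapper(
        Refugee,
        refugee,
        inherits=location_mapper,
        concrete=True,
        polymorphic_identity="refugee",
    )
    mapper(
        Office,
        office,
        inherits=location_mapper,
        concrete=True,
        polymorphic_identity="office",
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    engine.execute(refugee.insert(), [{"refugee_fid": 1, "name": "r1"}])
    engine.execute(office.insert(), [{"office_fid": 1, "name": "o1"}])

    session = Session(engine)
    # a single SELECT against the pjoin union returns both subtypes
    assert len(session.query(Location).all()) == 2

Note that both classes address the column as "name" because the Column key, not the database column name, becomes the mapped attribute; that is exactly what ColKeysTest asserts.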
polymorphic_identity='office') - refugee_mapper = mapper(Refugee, refugees_table, - inherits=location_mapper, - concrete=True, - polymorphic_identity='refugee') + location_mapper = mapper( + Location, + pjoin, + polymorphic_on=pjoin.c.type, + polymorphic_identity="location", + ) + office_mapper = mapper( + Office, + offices_table, + inherits=location_mapper, + concrete=True, + polymorphic_identity="office", + ) + refugee_mapper = mapper( + Refugee, + refugees_table, + inherits=location_mapper, + concrete=True, + polymorphic_identity="refugee", + ) sess = create_session() - eq_(sess.query(Refugee).get(1).name, 'refugee1') - eq_(sess.query(Refugee).get(2).name, 'refugee2') - eq_(sess.query(Office).get(1).name, 'office1') - eq_(sess.query(Office).get(2).name, 'office2') + eq_(sess.query(Refugee).get(1).name, "refugee1") + eq_(sess.query(Refugee).get(2).name, "refugee2") + eq_(sess.query(Office).get(1).name, "office1") + eq_(sess.query(Office).get(2).name, "office2") diff --git a/test/orm/inheritance/test_manytomany.py b/test/orm/inheritance/test_manytomany.py index 2ac873451e..c7196a365c 100644 --- a/test/orm/inheritance/test_manytomany.py +++ b/test/orm/inheritance/test_manytomany.py @@ -8,6 +8,7 @@ from sqlalchemy.testing import fixtures class InheritTest(fixtures.MappedTest): """deals with inheritance and many-to-many relationships""" + @classmethod def define_tables(cls, metadata): global principals @@ -15,34 +16,59 @@ class InheritTest(fixtures.MappedTest): global groups global user_group_map - principals = Table('principals', metadata, - Column('principal_id', Integer, - Sequence('principal_id_seq', optional=False), - primary_key=True), - Column('name', String(50), nullable=False)) - - users = Table('prin_users', metadata, - Column('principal_id', Integer, - ForeignKey('principals.principal_id'), - primary_key=True), - Column('password', String(50), nullable=False), - Column('email', String(50), nullable=False), - Column('login_id', String(50), nullable=False)) - - groups = Table('prin_groups', metadata, - Column( - 'principal_id', Integer, - ForeignKey('principals.principal_id'), - primary_key=True)) + principals = Table( + "principals", + metadata, + Column( + "principal_id", + Integer, + Sequence("principal_id_seq", optional=False), + primary_key=True, + ), + Column("name", String(50), nullable=False), + ) + + users = Table( + "prin_users", + metadata, + Column( + "principal_id", + Integer, + ForeignKey("principals.principal_id"), + primary_key=True, + ), + Column("password", String(50), nullable=False), + Column("email", String(50), nullable=False), + Column("login_id", String(50), nullable=False), + ) + + groups = Table( + "prin_groups", + metadata, + Column( + "principal_id", + Integer, + ForeignKey("principals.principal_id"), + primary_key=True, + ), + ) user_group_map = Table( - 'prin_user_group_map', metadata, + "prin_user_group_map", + metadata, Column( - 'user_id', Integer, ForeignKey("prin_users.principal_id"), - primary_key=True), + "user_id", + Integer, + ForeignKey("prin_users.principal_id"), + primary_key=True, + ), Column( - 'group_id', Integer, ForeignKey("prin_groups.principal_id"), - primary_key=True),) + "group_id", + Integer, + ForeignKey("prin_groups.principal_id"), + primary_key=True, + ), + ) def test_basic(self): class Principal(object): @@ -59,16 +85,29 @@ class InheritTest(fixtures.MappedTest): mapper(Principal, principals) mapper(User, users, inherits=Principal) - mapper(Group, groups, inherits=Principal, properties={ - 'users': relationship(User, 
secondary=user_group_map, - lazy='select', backref="groups") - }) + mapper( + Group, + groups, + inherits=Principal, + properties={ + "users": relationship( + User, + secondary=user_group_map, + lazy="select", + backref="groups", + ) + }, + ) g = Group(name="group1") g.users.append( User( - name="user1", password="pw", email="foo@bar.com", - login_id="lg1")) + name="user1", + password="pw", + email="foo@bar.com", + login_id="lg1", + ) + ) sess = create_session() sess.add(g) sess.flush() @@ -77,22 +116,34 @@ class InheritTest(fixtures.MappedTest): class InheritTest2(fixtures.MappedTest): """deals with inheritance and many-to-many relationships""" + @classmethod def define_tables(cls, metadata): global foo, bar, foo_bar - foo = Table('foo', metadata, - Column('id', Integer, - Sequence('foo_id_seq', optional=True), - primary_key=True), - Column('data', String(20))) - - bar = Table('bar', metadata, - Column('bid', Integer, ForeignKey('foo.id'), - primary_key=True)) - - foo_bar = Table('foo_bar', metadata, - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('bar_id', Integer, ForeignKey('bar.bid'))) + foo = Table( + "foo", + metadata, + Column( + "id", + Integer, + Sequence("foo_id_seq", optional=True), + primary_key=True, + ), + Column("data", String(20)), + ) + + bar = Table( + "bar", + metadata, + Column("bid", Integer, ForeignKey("foo.id"), primary_key=True), + ) + + foo_bar = Table( + "foo_bar", + metadata, + Column("foo_id", Integer, ForeignKey("foo.id")), + Column("bar_id", Integer, ForeignKey("bar.bid")), + ) def test_get(self): class Foo(object): @@ -106,7 +157,7 @@ class InheritTest2(fixtures.MappedTest): mapper(Bar, bar, inherits=Foo) print(foo.join(bar).primary_key) print(class_mapper(Bar).primary_key) - b = Bar('somedata') + b = Bar("somedata") sess = create_session() sess.add(b) sess.flush() @@ -126,17 +177,22 @@ class InheritTest2(fixtures.MappedTest): class Bar(Foo): pass - mapper(Bar, bar, inherits=Foo, properties={ - 'foos': relationship(Foo, secondary=foo_bar, lazy='joined') - }) + mapper( + Bar, + bar, + inherits=Foo, + properties={ + "foos": relationship(Foo, secondary=foo_bar, lazy="joined") + }, + ) sess = create_session() - b = Bar('barfoo') + b = Bar("barfoo") sess.add(b) sess.flush() - f1 = Foo('subfoo1') - f2 = Foo('subfoo2') + f1 = Foo("subfoo1") + f2 = Foo("subfoo2") b.foos.append(f1) b.foos.append(f2) @@ -146,48 +202,78 @@ class InheritTest2(fixtures.MappedTest): result = sess.query(Bar).all() print(result[0]) print(result[0].foos) - self.assert_unordered_result(result, Bar, - {'id': b.id, - 'data': 'barfoo', - 'foos': ( - Foo, [{'id': f1.id, - 'data': 'subfoo1'}, - {'id': f2.id, - 'data': 'subfoo2'}])}) + self.assert_unordered_result( + result, + Bar, + { + "id": b.id, + "data": "barfoo", + "foos": ( + Foo, + [ + {"id": f1.id, "data": "subfoo1"}, + {"id": f2.id, "data": "subfoo2"}, + ], + ), + }, + ) class InheritTest3(fixtures.MappedTest): """deals with inheritance and many-to-many relationships""" + @classmethod def define_tables(cls, metadata): global foo, bar, blub, bar_foo, blub_bar, blub_foo # the 'data' columns are to appease SQLite which cant handle a blank # INSERT - foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), - Column('data', String(20))) - - bar = Table('bar', metadata, Column('id', Integer, ForeignKey( - 'foo.id'), primary_key=True), Column('bar_data', String(20))) - - blub = Table('blub', metadata, - Column('id', Integer, ForeignKey('bar.id'), - primary_key=True), - 
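These InheritTest fixtures all combine joined-table inheritance with many-to-many association tables. Boiled down to a self-contained sketch (illustrative names, in-memory SQLite, 1.x classic mappings), the pattern is:

    from sqlalchemy import (
        Column,
        ForeignKey,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
    )
    from sqlalchemy.orm import Session, mapper, relationship

    metadata = MetaData()
    foo = Table(
        "foo",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", String(20)),
    )
    # joined-table inheritance: bar's primary key references foo
    bar = Table(
        "bar",
        metadata,
        Column("bid", ForeignKey("foo.id"), primary_key=True),
    )
    # association table routing the many-to-many back to the base
    foo_bar = Table(
        "foo_bar",
        metadata,
        Column("foo_id", ForeignKey("foo.id")),
        Column("bar_id", ForeignKey("bar.bid")),
    )

    class Foo(object):
        def __init__(self, data=None):
            self.data = data

    class Bar(Foo):
        pass

    mapper(Foo, foo)
    mapper(
        Bar,
        bar,
        inherits=Foo,
        properties={"foos": relationship(Foo, secondary=foo_bar)},
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    session = Session(engine)
    b = Bar("barfoo")
    b.foos.append(Foo("subfoo1"))
    session.add(b)  # the appended Foo follows via save-update cascade
    session.commit()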
Column('blub_data', String(20))) - - bar_foo = Table('bar_foo', metadata, - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('foo_id', Integer, ForeignKey('foo.id'))) - - blub_bar = Table('bar_blub', metadata, - Column('blub_id', Integer, ForeignKey('blub.id')), - Column('bar_id', Integer, ForeignKey('bar.id'))) - - blub_foo = Table('blub_foo', metadata, - Column('blub_id', Integer, ForeignKey('blub.id')), - Column('foo_id', Integer, ForeignKey('foo.id'))) + foo = Table( + "foo", + metadata, + Column( + "id", + Integer, + Sequence("foo_seq", optional=True), + primary_key=True, + ), + Column("data", String(20)), + ) + + bar = Table( + "bar", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + Column("bar_data", String(20)), + ) + + blub = Table( + "blub", + metadata, + Column("id", Integer, ForeignKey("bar.id"), primary_key=True), + Column("blub_data", String(20)), + ) + + bar_foo = Table( + "bar_foo", + metadata, + Column("bar_id", Integer, ForeignKey("bar.id")), + Column("foo_id", Integer, ForeignKey("foo.id")), + ) + + blub_bar = Table( + "bar_blub", + metadata, + Column("blub_id", Integer, ForeignKey("blub.id")), + Column("bar_id", Integer, ForeignKey("bar.id")), + ) + + blub_foo = Table( + "blub_foo", + metadata, + Column("blub_id", Integer, ForeignKey("blub.id")), + Column("foo_id", Integer, ForeignKey("foo.id")), + ) def test_basic(self): class Foo(object): @@ -196,18 +282,24 @@ class InheritTest3(fixtures.MappedTest): def __repr__(self): return "Foo id %d, data %s" % (self.id, self.data) + mapper(Foo, foo) class Bar(Foo): def __repr__(self): return "Bar id %d, data %s" % (self.id, self.data) - mapper(Bar, bar, inherits=Foo, properties={ - 'foos': relationship(Foo, secondary=bar_foo, lazy='select') - }) + mapper( + Bar, + bar, + inherits=Foo, + properties={ + "foos": relationship(Foo, secondary=bar_foo, lazy="select") + }, + ) sess = create_session() - b = Bar('bar #1') + b = Bar("bar #1") sess.add(b) b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) @@ -226,23 +318,33 @@ class InheritTest3(fixtures.MappedTest): def __repr__(self): return "Foo id %d, data %s" % (self.id, self.data) + mapper(Foo, foo) class Bar(Foo): def __repr__(self): return "Bar id %d, data %s" % (self.id, self.data) + mapper(Bar, bar, inherits=Foo) class Blub(Bar): def __repr__(self): return "Blub id %d, data %s, bars %s, foos %s" % ( - self.id, self.data, repr([b for b in self.bars]), - repr([f for f in self.foos])) - - mapper(Blub, blub, inherits=Bar, properties={ - 'bars': relationship(Bar, secondary=blub_bar, lazy='joined'), - 'foos': relationship(Foo, secondary=blub_foo, lazy='joined'), - }) + self.id, + self.data, + repr([b for b in self.bars]), + repr([f for f in self.foos]), + ) + + mapper( + Blub, + blub, + inherits=Bar, + properties={ + "bars": relationship(Bar, secondary=blub_bar, lazy="joined"), + "foos": relationship(Foo, secondary=blub_foo, lazy="joined"), + }, + ) sess = create_session() f1 = Foo("foo #1") diff --git a/test/orm/inheritance/test_poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py index a0fe291703..55e48f2508 100644 --- a/test/orm/inheritance/test_poly_linked_list.py +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -7,41 +7,45 @@ from sqlalchemy.testing.schema import Table, Column class PolymorphicCircularTest(fixtures.MappedTest): - run_setup_mappers = 'once' + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - global Table1, Table1B, Table2, Table3, Data + global Table1, Table1B, Table2, 
Table3, Data table1 = Table( - 'table1', metadata, + "table1", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column( - 'related_id', Integer, ForeignKey('table1.id'), - nullable=True), - Column('type', String(30)), - Column('name', String(30))) + "related_id", Integer, ForeignKey("table1.id"), nullable=True + ), + Column("type", String(30)), + Column("name", String(30)), + ) table2 = Table( - 'table2', metadata, - Column( - 'id', Integer, ForeignKey('table1.id'), - primary_key=True),) + "table2", + metadata, + Column("id", Integer, ForeignKey("table1.id"), primary_key=True), + ) table3 = Table( - 'table3', metadata, - Column( - 'id', Integer, ForeignKey('table1.id'), - primary_key=True),) + "table3", + metadata, + Column("id", Integer, ForeignKey("table1.id"), primary_key=True), + ) data = Table( - 'data', metadata, + "data", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('node_id', Integer, ForeignKey('table1.id')), - Column('data', String(30))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("node_id", Integer, ForeignKey("table1.id")), + Column("data", String(30)), + ) # join = polymorphic_union( # { @@ -50,7 +54,7 @@ class PolymorphicCircularTest(fixtures.MappedTest): # 'table1' : table1.select(table1.c.type.in_(['table1', 'table1b'])), # }, None, 'pjoin') - join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin') + join = table1.outerjoin(table2).outerjoin(table3).alias("pjoin") # join = None class Table1(object): @@ -61,8 +65,11 @@ class PolymorphicCircularTest(fixtures.MappedTest): def __repr__(self): return "%s(%s, %s, %s)" % ( - self.__class__.__name__, self.id, repr(str(self.name)), - repr(self.data)) + self.__class__.__name__, + self.id, + repr(str(self.name)), + repr(self.data), + ) class Table1B(Table1): pass @@ -79,25 +86,32 @@ class PolymorphicCircularTest(fixtures.MappedTest): def __repr__(self): return "%s(%s, %s)" % ( - self.__class__.__name__, self.id, repr(str(self.data))) + self.__class__.__name__, + self.id, + repr(str(self.data)), + ) try: # this is how the mapping used to work. ensure that this raises an # error now table1_mapper = mapper( - Table1, table1, select_table=join, + Table1, + table1, + select_table=join, polymorphic_on=table1.c.type, - polymorphic_identity='table1', + polymorphic_identity="table1", properties={ - 'nxt': relationship( + "nxt": relationship( Table1, - backref=backref('prev', - foreignkey=join.c.id, - uselist=False), + backref=backref( + "prev", foreignkey=join.c.id, uselist=False + ), uselist=False, - primaryjoin=join.c.id == join.c.related_id), - 'data': relationship(mapper(Data, data)) - }) + primaryjoin=join.c.id == join.c.related_id, + ), + "data": relationship(mapper(Data, data)), + }, + ) configure_mappers() assert False except Exception: @@ -114,36 +128,48 @@ class PolymorphicCircularTest(fixtures.MappedTest): # NOTE: using "nxt" instead of "next" to avoid 2to3 turning it into # __next__() for some reason. 
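The table1 mapping being re-wrapped here chains rows into a linked list through a self-referential one-to-one. With the polymorphic machinery stripped away, the essential ingredient is remote_side on the self-join plus a scalar backref; a minimal declarative sketch with a hypothetical Node class (not part of this changeset):

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, backref, relationship

    Base = declarative_base()

    class Node(Base):
        __tablename__ = "node"
        id = Column(Integer, primary_key=True)
        related_id = Column(ForeignKey("node.id"), nullable=True)
        name = Column(String(30))

        # remote_side marks which side of the self-join is the "parent";
        # uselist=False on the backref makes "nxt" a scalar, so each
        # node points at no more than one successor.
        prev = relationship(
            "Node",
            remote_side=[id],
            backref=backref("nxt", uselist=False),
        )

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    n1, n2 = Node(name="n1"), Node(name="n2")
    n1.nxt = n2  # populates n2.related_id with n1's id at flush
    session.add(n1)  # n2 follows via cascade
    session.commit()
    assert session.query(Node).filter_by(name="n2").one().prev is n1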
table1_mapper = mapper( - Table1, table1, + Table1, + table1, # select_table=join, polymorphic_on=table1.c.type, - polymorphic_identity='table1', + polymorphic_identity="table1", properties={ - 'nxt': relationship( + "nxt": relationship( Table1, backref=backref( - 'prev', remote_side=table1.c.id, uselist=False), + "prev", remote_side=table1.c.id, uselist=False + ), uselist=False, - primaryjoin=table1.c.id == table1.c.related_id), - 'data': relationship(mapper(Data, data), lazy='joined', - order_by=data.c.id) - } + primaryjoin=table1.c.id == table1.c.related_id, + ), + "data": relationship( + mapper(Data, data), lazy="joined", order_by=data.c.id + ), + }, ) table1b_mapper = mapper( - Table1B, inherits=table1_mapper, polymorphic_identity='table1b') + Table1B, inherits=table1_mapper, polymorphic_identity="table1b" + ) - table2_mapper = mapper(Table2, table2, - inherits=table1_mapper, - polymorphic_identity='table2') + table2_mapper = mapper( + Table2, + table2, + inherits=table1_mapper, + polymorphic_identity="table2", + ) table3_mapper = mapper( - Table3, table3, inherits=table1_mapper, - polymorphic_identity='table3') + Table3, + table3, + inherits=table1_mapper, + polymorphic_identity="table3", + ) configure_mappers() assert table1_mapper.primary_key == ( - table1.c.id,), table1_mapper.primary_key + table1.c.id, + ), table1_mapper.primary_key def test_one(self): self._testlist([Table1, Table2, Table1, Table2]) @@ -152,16 +178,29 @@ class PolymorphicCircularTest(fixtures.MappedTest): self._testlist([Table3]) def test_three(self): - self._testlist([Table2, Table1, Table1B, Table3, - Table3, Table1B, Table1B, Table2, Table1]) + self._testlist( + [ + Table2, + Table1, + Table1B, + Table3, + Table3, + Table1B, + Table1B, + Table2, + Table1, + ] + ) def test_four(self): - self._testlist([ - Table2('t2', [Data('data1'), Data('data2')]), - Table1('t1', []), - Table3('t3', [Data('data3')]), - Table1B('t1b', [Data('data4'), Data('data5')]) - ]) + self._testlist( + [ + Table2("t2", [Data("data1"), Data("data2")]), + Table1("t1", []), + Table3("t3", [Data("data3")]), + Table1B("t1b", [Data("data4"), Data("data5")]), + ] + ) def _testlist(self, classes): sess = create_session() @@ -171,7 +210,7 @@ class PolymorphicCircularTest(fixtures.MappedTest): obj = None for c in classes: if isinstance(c, type): - newobj = c('item %d' % count) + newobj = c("item %d" % count) count += 1 else: newobj = c @@ -188,7 +227,7 @@ class PolymorphicCircularTest(fixtures.MappedTest): # string version of the saved list assertlist = [] node = t - while (node): + while node: assertlist.append(node) n = node.nxt if n is not None: @@ -198,10 +237,14 @@ class PolymorphicCircularTest(fixtures.MappedTest): # clear and query forwards sess.expunge_all() - node = sess.query(Table1).order_by(Table1.id).\ - filter(Table1.id == t.id).first() + node = ( + sess.query(Table1) + .order_by(Table1.id) + .filter(Table1.id == t.id) + .first() + ) assertlist = [] - while (node): + while node: assertlist.append(node) n = node.nxt if n is not None: @@ -211,10 +254,14 @@ class PolymorphicCircularTest(fixtures.MappedTest): # clear and query backwards sess.expunge_all() - node = sess.query(Table1).order_by(Table1.id).\ - filter(Table1.id == obj.id).first() + node = ( + sess.query(Table1) + .order_by(Table1.id) + .filter(Table1.id == obj.id) + .first() + ) assertlist = [] - while (node): + while node: assertlist.insert(0, node) n = node.prev if n is not None: diff --git a/test/orm/inheritance/test_poly_loading.py 
b/test/orm/inheritance/test_poly_loading.py index 48117593be..8f4abae41a 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -1,12 +1,25 @@ from sqlalchemy import String, Integer, Column, ForeignKey -from sqlalchemy.orm import relationship, Session, joinedload, \ - selectin_polymorphic, selectinload, with_polymorphic, backref +from sqlalchemy.orm import ( + relationship, + Session, + joinedload, + selectin_polymorphic, + selectinload, + with_polymorphic, + backref, +) from sqlalchemy.testing import fixtures from sqlalchemy import testing from sqlalchemy.testing import eq_ from sqlalchemy.testing.assertsql import AllOf, CompiledSQL, EachOf, Or -from ._poly_fixtures import Company, Person, Engineer, Manager, \ - _Polymorphic, GeometryFixtureBase +from ._poly_fixtures import ( + Company, + Person, + Engineer, + Manager, + _Polymorphic, + GeometryFixtureBase, +) class BaseAndSubFixture(object): @@ -17,7 +30,7 @@ class BaseAndSubFixture(object): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) adata = Column(String(50)) bs = relationship("B") @@ -25,43 +38,48 @@ class BaseAndSubFixture(object): __mapper_args__ = { "polymorphic_on": type, - "polymorphic_identity": "a" + "polymorphic_identity": "a", } class ASub(A): - __tablename__ = 'asub' - id = Column(ForeignKey('a.id'), primary_key=True) + __tablename__ = "asub" + id = Column(ForeignKey("a.id"), primary_key=True) asubdata = Column(String(50)) cs = relationship("C") if cls.use_options: - __mapper_args__ = { - "polymorphic_identity": "asub" - } + __mapper_args__ = {"polymorphic_identity": "asub"} else: __mapper_args__ = { "polymorphic_load": "selectin", - "polymorphic_identity": "asub" + "polymorphic_identity": "asub", } class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) + a_id = Column(ForeignKey("a.id")) class C(Base): - __tablename__ = 'c' + __tablename__ = "c" id = Column(Integer, primary_key=True) - a_sub_id = Column(ForeignKey('asub.id')) + a_sub_id = Column(ForeignKey("asub.id")) @classmethod def insert_data(cls): A, B, ASub, C = cls.classes("A", "B", "ASub", "C") s = Session() - s.add(A(id=1, adata='adata', bs=[B(), B()])) - s.add(ASub(id=2, adata='adata', asubdata='asubdata', - bs=[B(), B()], cs=[C(), C()])) + s.add(A(id=1, adata="adata", bs=[B(), B()])) + s.add( + ASub( + id=2, + adata="adata", + asubdata="asubdata", + bs=[B(), B()], + cs=[C(), C()], + ) + ) s.commit() @@ -79,7 +97,7 @@ class BaseAndSubFixture(object): CompiledSQL( "SELECT a.id AS a_id, a.adata AS a_adata, " "a.type AS a_type FROM a ORDER BY a.id", - {} + {}, ), AllOf( EachOf( @@ -88,7 +106,7 @@ class BaseAndSubFixture(object): "asub.asubdata AS asub_asubdata FROM a JOIN asub " "ON a.id = asub.id WHERE a.id IN ([EXPANDING_primary_keys]) " "ORDER BY a.id", - {"primary_keys": [2]} + {"primary_keys": [2]}, ), CompiledSQL( # note this links c.a_sub_id to a.id, even though @@ -100,55 +118,60 @@ class BaseAndSubFixture(object): "c.id AS c_id " "FROM c WHERE c.a_sub_id " "IN ([EXPANDING_primary_keys]) ORDER BY c.a_sub_id", - {"primary_keys": [2]} + {"primary_keys": [2]}, ), ), CompiledSQL( "SELECT b.a_id AS b_a_id, b.id AS b_id FROM b " "WHERE b.a_id IN ([EXPANDING_primary_keys]) " "ORDER BY b.a_id", - {"primary_keys": [1, 2]} - ) - ) - + {"primary_keys": [1, 2]}, + ), + ), ) - self.assert_sql_execution( - testing.db, - lambda: self._run_query(result), - ) + 
self.assert_sql_execution(testing.db, lambda: self._run_query(result)) class LoadBaseAndSubWEagerRelOpt( - BaseAndSubFixture, fixtures.DeclarativeMappedTest, - testing.AssertsExecutionResults): + BaseAndSubFixture, + fixtures.DeclarativeMappedTest, + testing.AssertsExecutionResults, +): use_options = True def test_load(self): A, B, ASub, C = self.classes("A", "B", "ASub", "C") s = Session() - q = s.query(A).order_by(A.id).options( - selectin_polymorphic(A, [ASub]), - selectinload(ASub.cs), - selectinload(A.bs) + q = ( + s.query(A) + .order_by(A.id) + .options( + selectin_polymorphic(A, [ASub]), + selectinload(ASub.cs), + selectinload(A.bs), + ) ) self._assert_all_selectin(q) class LoadBaseAndSubWEagerRelMapped( - BaseAndSubFixture, fixtures.DeclarativeMappedTest, - testing.AssertsExecutionResults): + BaseAndSubFixture, + fixtures.DeclarativeMappedTest, + testing.AssertsExecutionResults, +): use_options = False def test_load(self): A, B, ASub, C = self.classes("A", "B", "ASub", "C") s = Session() - q = s.query(A).order_by(A.id).options( - selectinload(ASub.cs), - selectinload(A.bs) + q = ( + s.query(A) + .order_by(A.id) + .options(selectinload(ASub.cs), selectinload(A.bs)) ) self._assert_all_selectin(q) @@ -158,7 +181,8 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): def test_person_selectin_subclasses(self): s = Session() q = s.query(Person).options( - selectin_polymorphic(Person, [Engineer, Manager])) + selectin_polymorphic(Person, [Engineer, Manager]) + ) result = self.assert_sql_execution( testing.db, @@ -168,7 +192,7 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "people.company_id AS people_company_id, " "people.name AS people_name, " "people.type AS people_type FROM people", - {} + {}, ), AllOf( CompiledSQL( @@ -182,7 +206,7 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "ON people.person_id = engineers.person_id " "WHERE people.person_id IN ([EXPANDING_primary_keys]) " "ORDER BY people.person_id", - {"primary_keys": [1, 2, 5]} + {"primary_keys": [1, 2, 5]}, ), CompiledSQL( "SELECT managers.person_id AS managers_person_id, " @@ -194,18 +218,23 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "ON people.person_id = managers.person_id " "WHERE people.person_id IN ([EXPANDING_primary_keys]) " "ORDER BY people.person_id", - {"primary_keys": [3, 4]} - ) + {"primary_keys": [3, 4]}, + ), ), ) eq_(result, self.all_employees) def test_load_company_plus_employees(self): s = Session() - q = s.query(Company).options( - selectinload(Company.employees). 
- selectin_polymorphic([Engineer, Manager]) - ).order_by(Company.company_id) + q = ( + s.query(Company) + .options( + selectinload(Company.employees).selectin_polymorphic( + [Engineer, Manager] + ) + ) + .order_by(Company.company_id) + ) result = self.assert_sql_execution( testing.db, @@ -214,7 +243,7 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name FROM companies " "ORDER BY companies.company_id", - {} + {}, ), CompiledSQL( "SELECT people.company_id AS people_company_id, " @@ -223,7 +252,7 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "FROM people WHERE people.company_id " "IN ([EXPANDING_primary_keys]) " "ORDER BY people.company_id, people.person_id", - {"primary_keys": [1, 2]} + {"primary_keys": [1, 2]}, ), AllOf( CompiledSQL( @@ -237,7 +266,7 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "ON people.person_id = managers.person_id " "WHERE people.person_id IN ([EXPANDING_primary_keys]) " "ORDER BY people.person_id", - {"primary_keys": [3, 4]} + {"primary_keys": [3, 4]}, ), CompiledSQL( "SELECT engineers.person_id AS engineers_person_id, " @@ -251,34 +280,37 @@ class FixtureLoadTest(_Polymorphic, testing.AssertsExecutionResults): "ON people.person_id = engineers.person_id " "WHERE people.person_id IN ([EXPANDING_primary_keys]) " "ORDER BY people.person_id", - {"primary_keys": [1, 2, 5]} - ) - ) + {"primary_keys": [1, 2, 5]}, + ), + ), ) eq_(result, [self.c1, self.c2]) class TestGeometries(GeometryFixtureBase): - def test_threelevel_selectin_to_inline_mapped(self): - self._fixture_from_geometry({ - "a": { - "subclasses": { - "b": {"polymorphic_load": "selectin"}, - "c": { - "subclasses": { - "d": { - "polymorphic_load": "inline", "single": True - }, - "e": { - "polymorphic_load": "inline", "single": True + self._fixture_from_geometry( + { + "a": { + "subclasses": { + "b": {"polymorphic_load": "selectin"}, + "c": { + "subclasses": { + "d": { + "polymorphic_load": "inline", + "single": True, + }, + "e": { + "polymorphic_load": "inline", + "single": True, + }, }, + "polymorphic_load": "selectin", }, - "polymorphic_load": "selectin", } } } - }) + ) a, b, c, d, e = self.classes("a", "b", "c", "d", "e") sess = Session() @@ -293,7 +325,7 @@ class TestGeometries(GeometryFixtureBase): CompiledSQL( "SELECT a.type AS a_type, a.id AS a_id, " "a.a_data AS a_a_data FROM a", - {} + {}, ), Or( CompiledSQL( @@ -302,7 +334,7 @@ class TestGeometries(GeometryFixtureBase): "c.d_data AS c_d_data " "FROM a JOIN c ON a.id = c.id " "WHERE a.id IN ([EXPANDING_primary_keys]) ORDER BY a.id", - [{'primary_keys': [1, 2]}] + [{"primary_keys": [1, 2]}], ), CompiledSQL( "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " @@ -310,34 +342,29 @@ class TestGeometries(GeometryFixtureBase): "c.d_data AS c_d_data, c.e_data AS c_e_data " "FROM a JOIN c ON a.id = c.id " "WHERE a.id IN ([EXPANDING_primary_keys]) ORDER BY a.id", - [{'primary_keys': [1, 2]}] - ) - ) + [{"primary_keys": [1, 2]}], + ), + ), ) with self.assert_statement_count(testing.db, 0): - eq_( - result, - [d(d_data="d1"), e(e_data="e1")] - ) + eq_(result, [d(d_data="d1"), e(e_data="e1")]) def test_threelevel_selectin_to_inline_options(self): - self._fixture_from_geometry({ - "a": { - "subclasses": { - "b": {}, - "c": { - "subclasses": { - "d": { - "single": True - }, - "e": { - "single": True - }, + self._fixture_from_geometry( + { + "a": { + "subclasses": { + "b": {}, + "c": { + "subclasses": { 
+ "d": {"single": True}, + "e": {"single": True}, + } }, } } } - }) + ) a, b, c, d, e = self.classes("a", "b", "c", "d", "e") sess = Session() @@ -345,9 +372,7 @@ class TestGeometries(GeometryFixtureBase): sess.commit() c_alias = with_polymorphic(c, (d, e)) - q = sess.query(a).options( - selectin_polymorphic(a, [b, c_alias]) - ) + q = sess.query(a).options(selectin_polymorphic(a, [b, c_alias])) result = self.assert_sql_execution( testing.db, @@ -355,7 +380,7 @@ class TestGeometries(GeometryFixtureBase): CompiledSQL( "SELECT a.type AS a_type, a.id AS a_id, " "a.a_data AS a_a_data FROM a", - {} + {}, ), Or( CompiledSQL( @@ -364,7 +389,7 @@ class TestGeometries(GeometryFixtureBase): "c.d_data AS c_d_data " "FROM a JOIN c ON a.id = c.id " "WHERE a.id IN ([EXPANDING_primary_keys]) ORDER BY a.id", - [{'primary_keys': [1, 2]}] + [{"primary_keys": [1, 2]}], ), CompiledSQL( "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " @@ -372,30 +397,24 @@ class TestGeometries(GeometryFixtureBase): "c.e_data AS c_e_data " "FROM a JOIN c ON a.id = c.id " "WHERE a.id IN ([EXPANDING_primary_keys]) ORDER BY a.id", - [{'primary_keys': [1, 2]}] + [{"primary_keys": [1, 2]}], ), - ) + ), ) with self.assert_statement_count(testing.db, 0): - eq_( - result, - [d(d_data="d1"), e(e_data="e1")] - ) + eq_(result, [d(d_data="d1"), e(e_data="e1")]) def test_threelevel_selectin_to_inline_awkward_alias_options(self): - self._fixture_from_geometry({ - "a": { - "subclasses": { - "b": {}, - "c": { - "subclasses": { - "d": {}, - "e": {}, - }, + self._fixture_from_geometry( + { + "a": { + "subclasses": { + "b": {}, + "c": {"subclasses": {"d": {}, "e": {}}}, } } } - }) + ) a, b, c, d, e = self.classes("a", "b", "c", "d", "e") sess = Session() @@ -406,16 +425,21 @@ class TestGeometries(GeometryFixtureBase): a_table, c_table, d_table, e_table = self.tables("a", "c", "d", "e") - poly = select([ - a_table.c.id, a_table.c.type, c_table, d_table, e_table - ]).select_from( - a_table.join(c_table).outerjoin(d_table).outerjoin(e_table) - ).apply_labels().alias('poly') + poly = ( + select([a_table.c.id, a_table.c.type, c_table, d_table, e_table]) + .select_from( + a_table.join(c_table).outerjoin(d_table).outerjoin(e_table) + ) + .apply_labels() + .alias("poly") + ) c_alias = with_polymorphic(c, (d, e), poly) - q = sess.query(a).options( - selectin_polymorphic(a, [b, c_alias]) - ).order_by(a.id) + q = ( + sess.query(a) + .options(selectin_polymorphic(a, [b, c_alias])) + .order_by(a.id) + ) result = self.assert_sql_execution( testing.db, @@ -423,7 +447,7 @@ class TestGeometries(GeometryFixtureBase): CompiledSQL( "SELECT a.type AS a_type, a.id AS a_id, " "a.a_data AS a_a_data FROM a ORDER BY a.id", - {} + {}, ), Or( # here, the test is that the adaptation of "a" takes place @@ -442,7 +466,7 @@ class TestGeometries(GeometryFixtureBase): "LEFT OUTER JOIN e ON c.id = e.id) AS poly " "WHERE poly.a_id IN ([EXPANDING_primary_keys]) " "ORDER BY poly.a_id", - [{'primary_keys': [1, 2]}] + [{"primary_keys": [1, 2]}], ), CompiledSQL( "SELECT poly.a_type AS poly_a_type, " @@ -458,27 +482,26 @@ class TestGeometries(GeometryFixtureBase): "LEFT OUTER JOIN e ON c.id = e.id) AS poly " "WHERE poly.a_id IN ([EXPANDING_primary_keys]) " "ORDER BY poly.a_id", - [{'primary_keys': [1, 2]}] - ) - ) + [{"primary_keys": [1, 2]}], + ), + ), ) with self.assert_statement_count(testing.db, 0): - eq_( - result, - [d(d_data="d1"), e(e_data="e1")] - ) + eq_(result, [d(d_data="d1"), e(e_data="e1")]) def test_partial_load_no_invoke_eagers(self): # test issue #4199 - 
self._fixture_from_geometry({ - "a": { - "subclasses": { - "a1": {"polymorphic_load": "selectin"}, - "a2": {"polymorphic_load": "selectin"} + self._fixture_from_geometry( + { + "a": { + "subclasses": { + "a1": {"polymorphic_load": "selectin"}, + "a2": {"polymorphic_load": "selectin"}, + } } } - }) + ) a, a1, a2 = self.classes("a", "a1", "a2") sess = Session() @@ -499,47 +522,49 @@ class TestGeometries(GeometryFixtureBase): class LoaderOptionsTest( - fixtures.DeclarativeMappedTest, testing.AssertsExecutionResults): + fixtures.DeclarativeMappedTest, testing.AssertsExecutionResults +): @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Parent(fixtures.ComparableEntity, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) class Child(fixtures.ComparableEntity, Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) - parent_id = Column(Integer, ForeignKey('parent.id')) - parent = relationship('Parent', backref=backref('children')) + parent_id = Column(Integer, ForeignKey("parent.id")) + parent = relationship("Parent", backref=backref("children")) type = Column(String(50), nullable=False) - __mapper_args__ = { - 'polymorphic_on': type, - } + __mapper_args__ = {"polymorphic_on": type} class ChildSubclass1(Child): - __tablename__ = 'child_subclass1' - id = Column(Integer, ForeignKey('child.id'), primary_key=True) + __tablename__ = "child_subclass1" + id = Column(Integer, ForeignKey("child.id"), primary_key=True) __mapper_args__ = { - 'polymorphic_identity': 'subclass1', - 'polymorphic_load': 'selectin' + "polymorphic_identity": "subclass1", + "polymorphic_load": "selectin", } class Other(fixtures.ComparableEntity, Base): - __tablename__ = 'other' + __tablename__ = "other" id = Column(Integer, primary_key=True) - child_subclass_id = Column(Integer, - ForeignKey('child_subclass1.id')) - child_subclass = relationship('ChildSubclass1', - backref=backref('others')) + child_subclass_id = Column( + Integer, ForeignKey("child_subclass1.id") + ) + child_subclass = relationship( + "ChildSubclass1", backref=backref("others") + ) @classmethod def insert_data(cls): Parent, ChildSubclass1, Other = cls.classes( - "Parent", "ChildSubclass1", "Other") + "Parent", "ChildSubclass1", "Other" + ) session = Session() parent = Parent(id=1) @@ -556,12 +581,14 @@ class LoaderOptionsTest( def _test_options_dont_pollute(self, enable_baked): Parent, ChildSubclass1, Other = self.classes( - "Parent", "ChildSubclass1", "Other") + "Parent", "ChildSubclass1", "Other" + ) session = Session(enable_baked_queries=enable_baked) def no_opt(): q = session.query(Parent).options( - joinedload(Parent.children.of_type(ChildSubclass1))) + joinedload(Parent.children.of_type(ChildSubclass1)) + ) return self.assert_sql_execution( testing.db, @@ -581,7 +608,7 @@ class LoaderOptionsTest( "LEFT OUTER JOIN child_subclass1 " "ON child.id = child_subclass1.id) AS anon_1 " "ON parent.id = anon_1.child_parent_id", - {} + {}, ), CompiledSQL( "SELECT child_subclass1.id AS child_subclass1_id, " @@ -592,22 +619,20 @@ class LoaderOptionsTest( "ON child.id = child_subclass1.id " "WHERE child.id IN ([EXPANDING_primary_keys]) " "ORDER BY child.id", - [{'primary_keys': [1]}] + [{"primary_keys": [1]}], ), ) result = no_opt() with self.assert_statement_count(testing.db, 1): - eq_( - result, - [Parent(children=[ChildSubclass1(others=[Other()])])] - ) + eq_(result, [Parent(children=[ChildSubclass1(others=[Other()])])]) session.expunge_all() q = 
session.query(Parent).options( - joinedload(Parent.children.of_type(ChildSubclass1)) - .joinedload(ChildSubclass1.others) + joinedload(Parent.children.of_type(ChildSubclass1)).joinedload( + ChildSubclass1.others + ) ) result = self.assert_sql_execution( @@ -631,7 +656,7 @@ class LoaderOptionsTest( "ON parent.id = anon_1.child_parent_id " "LEFT OUTER JOIN other AS other_1 " "ON anon_1.child_subclass1_id = other_1.child_subclass_id", - {} + {}, ), CompiledSQL( "SELECT child_subclass1.id AS child_subclass1_id, " @@ -644,21 +669,15 @@ class LoaderOptionsTest( "ON child_subclass1.id = other_1.child_subclass_id " "WHERE child.id IN ([EXPANDING_primary_keys]) " "ORDER BY child.id", - [{'primary_keys': [1]}] - ) + [{"primary_keys": [1]}], + ), ) with self.assert_statement_count(testing.db, 0): - eq_( - result, - [Parent(children=[ChildSubclass1(others=[Other()])])] - ) + eq_(result, [Parent(children=[ChildSubclass1(others=[Other()])])]) session.expunge_all() result = no_opt() with self.assert_statement_count(testing.db, 1): - eq_( - result, - [Parent(children=[ChildSubclass1(others=[Other()])])] - ) + eq_(result, [Parent(children=[ChildSubclass1(others=[Other()])])]) diff --git a/test/orm/inheritance/test_poly_persistence.py b/test/orm/inheritance/test_poly_persistence.py index 3713ca24f8..67ae78fb45 100644 --- a/test/orm/inheritance/test_poly_persistence.py +++ b/test/orm/inheritance/test_poly_persistence.py @@ -37,45 +37,75 @@ class PolymorphTest(fixtures.MappedTest): def define_tables(cls, metadata): global companies, people, engineers, managers, boss - companies = Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) + companies = Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) people = Table( - 'people', metadata, + "people", + metadata, Column( - 'person_id', Integer, primary_key=True, - test_needs_autoincrement=True), + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), Column( - 'company_id', Integer, ForeignKey('companies.company_id'), - nullable=False), - Column('name', String(50)), - Column('type', String(30))) + "company_id", + Integer, + ForeignKey("companies.company_id"), + nullable=False, + ), + Column("name", String(50)), + Column("type", String(30)), + ) engineers = Table( - 'engineers', metadata, + "engineers", + metadata, Column( - 'person_id', Integer, ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('engineer_name', String(50)), - Column('primary_language', String(50))) + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("engineer_name", String(50)), + Column("primary_language", String(50)), + ) managers = Table( - 'managers', metadata, + "managers", + metadata, Column( - 'person_id', Integer, ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50))) + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("manager_name", String(50)), + ) boss = Table( - 'boss', metadata, + "boss", + metadata, Column( - 'boss_id', Integer, ForeignKey('managers.person_id'), - primary_key=True), - Column('golf_swing', String(30))) + "boss_id", + Integer, + ForeignKey("managers.person_id"), + primary_key=True, + ), + 
Column("golf_swing", String(30)), + ) metadata.create_all() @@ -87,43 +117,73 @@ class InsertOrderTest(PolymorphTest): person_join = polymorphic_union( { - 'engineer': people.join(engineers), - 'manager': people.join(managers), - 'person': people.select(people.c.type == 'person'), - }, None, 'pjoin') - - person_mapper = mapper(Person, people, - with_polymorphic=('*', person_join), - polymorphic_on=person_join.c.type, - polymorphic_identity='person') - - mapper(Engineer, engineers, inherits=person_mapper, - polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, - polymorphic_identity='manager') - mapper(Company, companies, properties={ - 'employees': relationship(Person, - backref='company', - order_by=person_join.c.person_id) - }) + "engineer": people.join(engineers), + "manager": people.join(managers), + "person": people.select(people.c.type == "person"), + }, + None, + "pjoin", + ) + + person_mapper = mapper( + Person, + people, + with_polymorphic=("*", person_join), + polymorphic_on=person_join.c.type, + polymorphic_identity="person", + ) + + mapper( + Engineer, + engineers, + inherits=person_mapper, + polymorphic_identity="engineer", + ) + mapper( + Manager, + managers, + inherits=person_mapper, + polymorphic_identity="manager", + ) + mapper( + Company, + companies, + properties={ + "employees": relationship( + Person, backref="company", order_by=person_join.c.person_id + ) + }, + ) session = create_session() - c = Company(name='company1') + c = Company(name="company1") c.employees.append( Manager( - status='AAB', manager_name='manager1', - name='pointy haired boss')) - c.employees.append(Engineer(status='BBA', - engineer_name='engineer1', - primary_language='java', name='dilbert')) - c.employees.append(Person(status='HHH', name='joesmith')) - c.employees.append(Engineer(status='CGG', - engineer_name='engineer2', - primary_language='python', name='wally')) + status="AAB", + manager_name="manager1", + name="pointy haired boss", + ) + ) c.employees.append( - Manager( - status='ABA', manager_name='manager2', - name='jsmith')) + Engineer( + status="BBA", + engineer_name="engineer1", + primary_language="java", + name="dilbert", + ) + ) + c.employees.append(Person(status="HHH", name="joesmith")) + c.employees.append( + Engineer( + status="CGG", + engineer_name="engineer2", + primary_language="python", + name="wally", + ) + ) + c.employees.append( + Manager(status="ABA", manager_name="manager2", name="jsmith") + ) session.add(c) session.flush() session.expunge_all() @@ -134,8 +194,9 @@ class RoundTripTest(PolymorphTest): pass -def _generate_round_trip_test(include_base, lazy_relationship, - redefine_colprop, with_polymorphic): +def _generate_round_trip_test( + include_base, lazy_relationship, redefine_colprop, with_polymorphic +): """generates a round trip test. 
include_base - whether or not to include the base 'person' type in @@ -151,84 +212,124 @@ def _generate_round_trip_test(include_base, lazy_relationship, """ def test_roundtrip(self): - if with_polymorphic == 'unions': + if with_polymorphic == "unions": if include_base: person_join = polymorphic_union( { - 'engineer': people.join(engineers), - 'manager': people.join(managers), - 'person': people.select(people.c.type == 'person'), - }, None, 'pjoin') + "engineer": people.join(engineers), + "manager": people.join(managers), + "person": people.select(people.c.type == "person"), + }, + None, + "pjoin", + ) else: person_join = polymorphic_union( { - 'engineer': people.join(engineers), - 'manager': people.join(managers), - }, None, 'pjoin') + "engineer": people.join(engineers), + "manager": people.join(managers), + }, + None, + "pjoin", + ) manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ['*', person_join] - manager_with_polymorphic = ['*', manager_join] - elif with_polymorphic == 'joins': - person_join = people.outerjoin(engineers).outerjoin(managers).\ - outerjoin(boss) + person_with_polymorphic = ["*", person_join] + manager_with_polymorphic = ["*", manager_join] + elif with_polymorphic == "joins": + person_join = ( + people.outerjoin(engineers).outerjoin(managers).outerjoin(boss) + ) manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ['*', person_join] - manager_with_polymorphic = ['*', manager_join] - elif with_polymorphic == 'auto': - person_with_polymorphic = '*' - manager_with_polymorphic = '*' + person_with_polymorphic = ["*", person_join] + manager_with_polymorphic = ["*", manager_join] + elif with_polymorphic == "auto": + person_with_polymorphic = "*" + manager_with_polymorphic = "*" else: person_with_polymorphic = None manager_with_polymorphic = None if redefine_colprop: - person_mapper = mapper(Person, people, - with_polymorphic=person_with_polymorphic, - polymorphic_on=people.c.type, - polymorphic_identity='person', - properties={'person_name': people.c.name}) + person_mapper = mapper( + Person, + people, + with_polymorphic=person_with_polymorphic, + polymorphic_on=people.c.type, + polymorphic_identity="person", + properties={"person_name": people.c.name}, + ) else: - person_mapper = mapper(Person, people, - with_polymorphic=person_with_polymorphic, - polymorphic_on=people.c.type, - polymorphic_identity='person') - - mapper(Engineer, engineers, inherits=person_mapper, - polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, - with_polymorphic=manager_with_polymorphic, - polymorphic_identity='manager') - - mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') - - mapper(Company, companies, properties={ - 'employees': relationship(Person, lazy=lazy_relationship, - cascade="all, delete-orphan", - backref="company", - order_by=people.c.person_id) - }) + person_mapper = mapper( + Person, + people, + with_polymorphic=person_with_polymorphic, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + + mapper( + Engineer, + engineers, + inherits=person_mapper, + polymorphic_identity="engineer", + ) + mapper( + Manager, + managers, + inherits=person_mapper, + with_polymorphic=manager_with_polymorphic, + polymorphic_identity="manager", + ) + + mapper(Boss, boss, inherits=Manager, polymorphic_identity="boss") + + mapper( + Company, + companies, + properties={ + "employees": relationship( + Person, + lazy=lazy_relationship, + cascade="all, delete-orphan", + backref="company", + 
order_by=people.c.person_id, + ) + }, + ) if redefine_colprop: - person_attribute_name = 'person_name' + person_attribute_name = "person_name" else: - person_attribute_name = 'name' + person_attribute_name = "name" employees = [ - Manager(status='AAB', manager_name='manager1', - **{person_attribute_name: 'pointy haired boss'}), - Engineer(status='BBA', engineer_name='engineer1', - primary_language='java', - **{person_attribute_name: 'dilbert'}), + Manager( + status="AAB", + manager_name="manager1", + **{person_attribute_name: "pointy haired boss"} + ), + Engineer( + status="BBA", + engineer_name="engineer1", + primary_language="java", + **{person_attribute_name: "dilbert"} + ), ] if include_base: - employees.append(Person(**{person_attribute_name: 'joesmith'})) + employees.append(Person(**{person_attribute_name: "joesmith"})) employees += [ - Engineer(status='CGG', engineer_name='engineer2', - primary_language='python', - **{person_attribute_name: 'wally'}), - Manager(status='ABA', manager_name='manager2', - **{person_attribute_name: 'jsmith'}) + Engineer( + status="CGG", + engineer_name="engineer2", + primary_language="python", + **{person_attribute_name: "wally"} + ), + Manager( + status="ABA", + manager_name="manager2", + **{person_attribute_name: "jsmith"} + ), ] pointy = employees[0] @@ -236,7 +337,7 @@ def _generate_round_trip_test(include_base, lazy_relationship, dilbert = employees[1] session = create_session() - c = Company(name='company1') + c = Company(name="company1") c.employees = employees session.add(c) @@ -246,9 +347,12 @@ def _generate_round_trip_test(include_base, lazy_relationship, eq_(session.query(Person).get(dilbert.person_id), dilbert) session.expunge_all() - eq_(session.query(Person).filter( - Person.person_id == dilbert.person_id).one(), - dilbert) + eq_( + session.query(Person) + .filter(Person.person_id == dilbert.person_id) + .one(), + dilbert, + ) session.expunge_all() def go(): @@ -256,13 +360,13 @@ def _generate_round_trip_test(include_base, lazy_relationship, eq_(cc.employees, employees) if not lazy_relationship: - if with_polymorphic != 'none': + if with_polymorphic != "none": self.assert_sql_count(testing.db, go, 1) else: self.assert_sql_count(testing.db, go, 5) else: - if with_polymorphic != 'none': + if with_polymorphic != "none": self.assert_sql_count(testing.db, go, 2) else: self.assert_sql_count(testing.db, go, 6) @@ -272,21 +376,24 @@ def _generate_round_trip_test(include_base, lazy_relationship, # in the case of the polymorphic Person query, # the "people" selectable should be adapted to be "person_join" eq_( - session.query(Person).filter( - getattr(Person, person_attribute_name) == 'dilbert' - ).first(), - dilbert + session.query(Person) + .filter(getattr(Person, person_attribute_name) == "dilbert") + .first(), + dilbert, ) - assert session.query(Person).filter( - getattr(Person, person_attribute_name) == 'dilbert' - ).first().person_id + assert ( + session.query(Person) + .filter(getattr(Person, person_attribute_name) == "dilbert") + .first() + .person_id + ) eq_( - session.query(Engineer).filter( - getattr(Person, person_attribute_name) == 'dilbert' - ).first(), - dilbert + session.query(Engineer) + .filter(getattr(Person, person_attribute_name) == "dilbert") + .first(), + dilbert, ) # test selecting from the query, joining against @@ -297,27 +404,36 @@ def _generate_round_trip_test(include_base, lazy_relationship, dilbert = session.query(Person).get(dilbert.person_id) is_( dilbert, - session.query(Person).filter( - (palias.c.name == 'dilbert') 
& - (palias.c.person_id == Person.person_id)).first() + session.query(Person) + .filter( + (palias.c.name == "dilbert") + & (palias.c.person_id == Person.person_id) + ) + .first(), ) is_( dilbert, - session.query(Engineer).filter( - (palias.c.name == 'dilbert') & - (palias.c.person_id == Person.person_id)).first() + session.query(Engineer) + .filter( + (palias.c.name == "dilbert") + & (palias.c.person_id == Person.person_id) + ) + .first(), ) is_( dilbert, - session.query(Person).filter( - (Engineer.engineer_name == "engineer1") & - (engineers.c.person_id == people.c.person_id) - ).first() + session.query(Person) + .filter( + (Engineer.engineer_name == "engineer1") + & (engineers.c.person_id == people.c.person_id) + ) + .first(), ) is_( dilbert, - session.query(Engineer). - filter(Engineer.engineer_name == "engineer1")[0] + session.query(Engineer).filter( + Engineer.engineer_name == "engineer1" + )[0], ) session.flush() @@ -325,61 +441,83 @@ def _generate_round_trip_test(include_base, lazy_relationship, def go(): session.query(Person).filter( - getattr(Person, person_attribute_name) == 'dilbert').first() + getattr(Person, person_attribute_name) == "dilbert" + ).first() + self.assert_sql_count(testing.db, go, 1) session.expunge_all() - dilbert = session.query(Person).filter( - getattr(Person, person_attribute_name) == 'dilbert').first() + dilbert = ( + session.query(Person) + .filter(getattr(Person, person_attribute_name) == "dilbert") + .first() + ) def go(): # assert that only primary table is queried for # already-present-in-session - d = session.query(Person).filter( - getattr(Person, person_attribute_name) == 'dilbert').first() + d = ( + session.query(Person) + .filter(getattr(Person, person_attribute_name) == "dilbert") + .first() + ) + self.assert_sql_count(testing.db, go, 1) # test standalone orphans - daboss = Boss(status='BBB', - manager_name='boss', - golf_swing='fore', - **{person_attribute_name: 'daboss'}) + daboss = Boss( + status="BBB", + manager_name="boss", + golf_swing="fore", + **{person_attribute_name: "daboss"} + ) session.add(daboss) assert_raises(sa_exc.DBAPIError, session.flush) c = session.query(Company).first() daboss.company = c - manager_list = [e for e in c.employees - if isinstance(e, Manager)] + manager_list = [e for e in c.employees if isinstance(e, Manager)] session.flush() session.expunge_all() - eq_(session.query(Manager).order_by(Manager.person_id).all(), - manager_list) + eq_( + session.query(Manager).order_by(Manager.person_id).all(), + manager_list, + ) c = session.query(Company).first() session.delete(c) session.flush() - eq_(select([func.count('*')]).select_from(people).scalar(), 0) + eq_(select([func.count("*")]).select_from(people).scalar(), 0) test_roundtrip = function_named( - test_roundtrip, "test_%s%s%s_%s" % ( + test_roundtrip, + "test_%s%s%s_%s" + % ( (lazy_relationship and "lazy" or "eager"), (include_base and "_inclbase" or ""), (redefine_colprop and "_redefcol" or ""), - with_polymorphic)) + with_polymorphic, + ), + ) setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip) for lazy_relationship in [True, False]: for redefine_colprop in [True, False]: - for with_polymorphic in ['unions', 'joins', 'auto', 'none']: - if with_polymorphic == 'unions': + for with_polymorphic in ["unions", "joins", "auto", "none"]: + if with_polymorphic == "unions": for include_base in [True, False]: _generate_round_trip_test( - include_base, lazy_relationship, redefine_colprop, - with_polymorphic) + include_base, + lazy_relationship, + redefine_colprop, 
+ with_polymorphic, + ) else: - _generate_round_trip_test(False, - lazy_relationship, - redefine_colprop, with_polymorphic) + _generate_round_trip_test( + False, + lazy_relationship, + redefine_colprop, + with_polymorphic, + ) diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index d46448355f..a0279f9e71 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -1,16 +1,34 @@ from sqlalchemy import func, desc, select -from sqlalchemy.orm import (interfaces, create_session, joinedload, - joinedload_all, subqueryload, subqueryload_all, - aliased, class_mapper, with_polymorphic) +from sqlalchemy.orm import ( + interfaces, + create_session, + joinedload, + joinedload_all, + subqueryload, + subqueryload_all, + aliased, + class_mapper, + with_polymorphic, +) from sqlalchemy import exc as sa_exc from sqlalchemy import testing from sqlalchemy.testing import assert_raises, eq_ -from ._poly_fixtures import Company, Person, Engineer, Manager, Boss, \ - Machine, Paperwork, _Polymorphic,\ - _PolymorphicPolymorphic, _PolymorphicUnions, _PolymorphicJoins,\ - _PolymorphicAliasedJoins +from ._poly_fixtures import ( + Company, + Person, + Engineer, + Manager, + Boss, + Machine, + Paperwork, + _Polymorphic, + _PolymorphicPolymorphic, + _PolymorphicUnions, + _PolymorphicJoins, + _PolymorphicAliasedJoins, +) class _PolymorphicTestBase(object): @@ -21,11 +39,15 @@ class _PolymorphicTestBase(object): super(_PolymorphicTestBase, cls).setup_mappers() global people, engineers, managers, boss global companies, paperwork, machines - people, engineers, managers, boss,\ - companies, paperwork, machines = \ - cls.tables.people, cls.tables.engineers, \ - cls.tables.managers, cls.tables.boss,\ - cls.tables.companies, cls.tables.paperwork, cls.tables.machines + people, engineers, managers, boss, companies, paperwork, machines = ( + cls.tables.people, + cls.tables.engineers, + cls.tables.managers, + cls.tables.boss, + cls.tables.companies, + cls.tables.paperwork, + cls.tables.machines, + ) @classmethod def insert_data(cls): @@ -33,11 +55,14 @@ class _PolymorphicTestBase(object): global all_employees, c1_employees, c2_employees global c1, c2, e1, e2, e3, b1, m1 - c1, c2, all_employees, c1_employees, c2_employees = \ - cls.c1, cls.c2, cls.all_employees, \ - cls.c1_employees, cls.c2_employees - e1, e2, e3, b1, m1 = \ - cls.e1, cls.e2, cls.e3, cls.b1, cls.m1 + c1, c2, all_employees, c1_employees, c2_employees = ( + cls.c1, + cls.c2, + cls.all_employees, + cls.c1_employees, + cls.c2_employees, + ) + e1, e2, e3, b1, m1 = cls.e1, cls.e2, cls.e3, cls.b1, cls.m1 def test_loads_at_once(self): """ @@ -50,8 +75,10 @@ class _PolymorphicTestBase(object): def go(): eq_( sess.query(Person).order_by(Person.person_id).all(), - all_employees) - count = {'': 14, 'Polymorphic': 9}.get(self.select_type, 10) + all_employees, + ) + + count = {"": 14, "Polymorphic": 9}.get(self.select_type, 10) self.assert_sql_count(testing.db, go, count) def test_primary_eager_aliasing_one(self): @@ -61,20 +88,29 @@ class _PolymorphicTestBase(object): sess = create_session() def go(): - eq_(sess.query(Person).order_by(Person.person_id) - .options(joinedload(Engineer.machines))[1:3], - all_employees[1:3]) - count = {'': 6, 'Polymorphic': 3}.get(self.select_type, 4) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .options(joinedload(Engineer.machines))[1:3], + all_employees[1:3], + ) + + count = {"": 6, "Polymorphic": 3}.get(self.select_type, 4) 
self.assert_sql_count(testing.db, go, count) def test_primary_eager_aliasing_two(self): sess = create_session() def go(): - eq_(sess.query(Person).order_by(Person.person_id) - .options(subqueryload(Engineer.machines)).all(), - all_employees) - count = {'': 14, 'Polymorphic': 7}.get(self.select_type, 8) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .options(subqueryload(Engineer.machines)) + .all(), + all_employees, + ) + + count = {"": 14, "Polymorphic": 7}.get(self.select_type, 8) self.assert_sql_count(testing.db, go, count) def test_primary_eager_aliasing_three(self): @@ -84,19 +120,31 @@ class _PolymorphicTestBase(object): sess = create_session() def go(): - eq_(sess.query(Person).with_polymorphic('*') - .order_by(Person.person_id) - .options(joinedload(Engineer.machines))[1:3], - all_employees[1:3]) + eq_( + sess.query(Person) + .with_polymorphic("*") + .order_by(Person.person_id) + .options(joinedload(Engineer.machines))[1:3], + all_employees[1:3], + ) + self.assert_sql_count(testing.db, go, 3) eq_( - select([func.count('*')]).select_from( - sess.query(Person).with_polymorphic('*') + select([func.count("*")]) + .select_from( + sess.query(Person) + .with_polymorphic("*") .options(joinedload(Engineer.machines)) - .order_by(Person.person_id).limit(2).offset(1) - .with_labels().subquery() - ).scalar(), 2) + .order_by(Person.person_id) + .limit(2) + .offset(1) + .with_labels() + .subquery() + ) + .scalar(), + 2, + ) def test_get_one(self): """ @@ -104,51 +152,64 @@ class _PolymorphicTestBase(object): just the "person_id" column. """ sess = create_session() - eq_(sess.query(Person).get(e1.person_id), - Engineer(name="dilbert", primary_language="java")) + eq_( + sess.query(Person).get(e1.person_id), + Engineer(name="dilbert", primary_language="java"), + ) def test_get_two(self): sess = create_session() - eq_(sess.query(Engineer).get(e1.person_id), - Engineer(name="dilbert", primary_language="java")) + eq_( + sess.query(Engineer).get(e1.person_id), + Engineer(name="dilbert", primary_language="java"), + ) def test_get_three(self): sess = create_session() - eq_(sess.query(Manager).get(b1.person_id), - Boss(name="pointy haired boss", golf_swing="fore")) + eq_( + sess.query(Manager).get(b1.person_id), + Boss(name="pointy haired boss", golf_swing="fore"), + ) def test_multi_join(self): sess = create_session() e = aliased(Person) c = aliased(Company) - q = sess.query(Company, Person, c, e)\ - .join(Person, Company.employees)\ - .join(e, c.employees)\ - .filter(Person.name == 'dilbert')\ - .filter(e.name == 'wally') + q = ( + sess.query(Company, Person, c, e) + .join(Person, Company.employees) + .join(e, c.employees) + .filter(Person.name == "dilbert") + .filter(e.name == "wally") + ) eq_(q.count(), 1) - eq_(q.all(), [ - ( - Company(company_id=1, name='MegaCorp, Inc.'), - Engineer( - status='regular engineer', - engineer_name='dilbert', - name='dilbert', - company_id=1, - primary_language='java', - person_id=1, - type='engineer'), - Company(company_id=1, name='MegaCorp, Inc.'), - Engineer( - status='regular engineer', - engineer_name='wally', - name='wally', - company_id=1, - primary_language='c++', - person_id=2, - type='engineer') - ) - ]) + eq_( + q.all(), + [ + ( + Company(company_id=1, name="MegaCorp, Inc."), + Engineer( + status="regular engineer", + engineer_name="dilbert", + name="dilbert", + company_id=1, + primary_language="java", + person_id=1, + type="engineer", + ), + Company(company_id=1, name="MegaCorp, Inc."), + Engineer( + status="regular engineer", + 
engineer_name="wally", + name="wally", + company_id=1, + primary_language="c++", + person_id=2, + type="engineer", + ), + ) + ], + ) def test_filter_on_subclass_one(self): sess = create_session() @@ -160,376 +221,454 @@ class _PolymorphicTestBase(object): def test_filter_on_subclass_three(self): sess = create_session() - eq_(sess.query(Engineer) - .filter(Engineer.person_id == e1.person_id).first(), - Engineer(name="dilbert")) + eq_( + sess.query(Engineer) + .filter(Engineer.person_id == e1.person_id) + .first(), + Engineer(name="dilbert"), + ) def test_filter_on_subclass_four(self): sess = create_session() - eq_(sess.query(Manager) - .filter(Manager.person_id == m1.person_id).one(), - Manager(name="dogbert")) + eq_( + sess.query(Manager) + .filter(Manager.person_id == m1.person_id) + .one(), + Manager(name="dogbert"), + ) def test_filter_on_subclass_five(self): sess = create_session() - eq_(sess.query(Manager) - .filter(Manager.person_id == b1.person_id).one(), - Boss(name="pointy haired boss")) + eq_( + sess.query(Manager) + .filter(Manager.person_id == b1.person_id) + .one(), + Boss(name="pointy haired boss"), + ) def test_filter_on_subclass_six(self): sess = create_session() - eq_(sess.query(Boss) - .filter(Boss.person_id == b1.person_id).one(), - Boss(name="pointy haired boss")) + eq_( + sess.query(Boss).filter(Boss.person_id == b1.person_id).one(), + Boss(name="pointy haired boss"), + ) def test_join_from_polymorphic_nonaliased_one(self): sess = create_session() - eq_(sess.query(Person) - .join('paperwork', aliased=False) - .filter(Paperwork.description.like('%review%')).all(), - [b1, m1]) + eq_( + sess.query(Person) + .join("paperwork", aliased=False) + .filter(Paperwork.description.like("%review%")) + .all(), + [b1, m1], + ) def test_join_from_polymorphic_nonaliased_two(self): sess = create_session() - eq_(sess.query(Person) - .order_by(Person.person_id) - .join('paperwork', aliased=False) - .filter(Paperwork.description.like('%#2%')).all(), - [e1, m1]) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .join("paperwork", aliased=False) + .filter(Paperwork.description.like("%#2%")) + .all(), + [e1, m1], + ) def test_join_from_polymorphic_nonaliased_three(self): sess = create_session() - eq_(sess.query(Engineer) - .order_by(Person.person_id) - .join('paperwork', aliased=False) - .filter(Paperwork.description.like('%#2%')).all(), - [e1]) + eq_( + sess.query(Engineer) + .order_by(Person.person_id) + .join("paperwork", aliased=False) + .filter(Paperwork.description.like("%#2%")) + .all(), + [e1], + ) def test_join_from_polymorphic_nonaliased_four(self): sess = create_session() - eq_(sess.query(Person) - .order_by(Person.person_id) - .join('paperwork', aliased=False) - .filter(Person.name.like('%dog%')) - .filter(Paperwork.description.like('%#2%')).all(), - [m1]) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .join("paperwork", aliased=False) + .filter(Person.name.like("%dog%")) + .filter(Paperwork.description.like("%#2%")) + .all(), + [m1], + ) def test_join_from_polymorphic_aliased_one(self): sess = create_session() - eq_(sess.query(Person) - .order_by(Person.person_id) - .join('paperwork', aliased=True) - .filter(Paperwork.description.like('%review%')).all(), - [b1, m1]) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Paperwork.description.like("%review%")) + .all(), + [b1, m1], + ) def test_join_from_polymorphic_aliased_two(self): sess = create_session() - eq_(sess.query(Person) - 
.order_by(Person.person_id) - .join('paperwork', aliased=True) - .filter(Paperwork.description.like('%#2%')).all(), - [e1, m1]) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Paperwork.description.like("%#2%")) + .all(), + [e1, m1], + ) def test_join_from_polymorphic_aliased_three(self): sess = create_session() - eq_(sess.query(Engineer) - .order_by(Person.person_id) - .join('paperwork', aliased=True) - .filter(Paperwork.description.like('%#2%')).all(), - [e1]) + eq_( + sess.query(Engineer) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Paperwork.description.like("%#2%")) + .all(), + [e1], + ) def test_join_from_polymorphic_aliased_four(self): sess = create_session() - eq_(sess.query(Person) - .order_by(Person.person_id) - .join('paperwork', aliased=True) - .filter(Person.name.like('%dog%')) - .filter(Paperwork.description.like('%#2%')).all(), - [m1]) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Person.name.like("%dog%")) + .filter(Paperwork.description.like("%#2%")) + .all(), + [m1], + ) def test_join_from_with_polymorphic_nonaliased_one(self): sess = create_session() - eq_(sess.query(Person) - .with_polymorphic(Manager) - .order_by(Person.person_id) - .join('paperwork') - .filter(Paperwork.description.like('%review%')).all(), - [b1, m1]) + eq_( + sess.query(Person) + .with_polymorphic(Manager) + .order_by(Person.person_id) + .join("paperwork") + .filter(Paperwork.description.like("%review%")) + .all(), + [b1, m1], + ) def test_join_from_with_polymorphic_nonaliased_two(self): sess = create_session() - eq_(sess.query(Person) - .with_polymorphic([Manager, Engineer]) - .order_by(Person.person_id) - .join('paperwork') - .filter(Paperwork.description.like('%#2%')).all(), - [e1, m1]) + eq_( + sess.query(Person) + .with_polymorphic([Manager, Engineer]) + .order_by(Person.person_id) + .join("paperwork") + .filter(Paperwork.description.like("%#2%")) + .all(), + [e1, m1], + ) def test_join_from_with_polymorphic_nonaliased_three(self): sess = create_session() - eq_(sess.query(Person) - .with_polymorphic([Manager, Engineer]) - .order_by(Person.person_id) - .join('paperwork') - .filter(Person.name.like('%dog%')) - .filter(Paperwork.description.like('%#2%')).all(), - [m1]) + eq_( + sess.query(Person) + .with_polymorphic([Manager, Engineer]) + .order_by(Person.person_id) + .join("paperwork") + .filter(Person.name.like("%dog%")) + .filter(Paperwork.description.like("%#2%")) + .all(), + [m1], + ) def test_join_from_with_polymorphic_aliased_one(self): sess = create_session() - eq_(sess.query(Person) - .with_polymorphic(Manager) - .join('paperwork', aliased=True) - .filter(Paperwork.description.like('%review%')).all(), - [b1, m1]) + eq_( + sess.query(Person) + .with_polymorphic(Manager) + .join("paperwork", aliased=True) + .filter(Paperwork.description.like("%review%")) + .all(), + [b1, m1], + ) def test_join_from_with_polymorphic_aliased_two(self): sess = create_session() - eq_(sess.query(Person) - .with_polymorphic([Manager, Engineer]) - .order_by(Person.person_id) - .join('paperwork', aliased=True) - .filter(Paperwork.description.like('%#2%')).all(), - [e1, m1]) + eq_( + sess.query(Person) + .with_polymorphic([Manager, Engineer]) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Paperwork.description.like("%#2%")) + .all(), + [e1, m1], + ) def test_join_from_with_polymorphic_aliased_three(self): sess = create_session() - 
eq_(sess.query(Person) - .with_polymorphic([Manager, Engineer]) - .order_by(Person.person_id) - .join('paperwork', aliased=True) - .filter(Person.name.like('%dog%')) - .filter(Paperwork.description.like('%#2%')).all(), - [m1]) + eq_( + sess.query(Person) + .with_polymorphic([Manager, Engineer]) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Person.name.like("%dog%")) + .filter(Paperwork.description.like("%#2%")) + .all(), + [m1], + ) def test_join_to_polymorphic_nonaliased(self): sess = create_session() - eq_(sess.query(Company) - .join('employees') - .filter(Person.name == 'vlad').one(), - c2) + eq_( + sess.query(Company) + .join("employees") + .filter(Person.name == "vlad") + .one(), + c2, + ) def test_join_to_polymorphic_aliased(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', aliased=True) - .filter(Person.name == 'vlad').one(), - c2) + eq_( + sess.query(Company) + .join("employees", aliased=True) + .filter(Person.name == "vlad") + .one(), + c2, + ) def test_polymorphic_any_one(self): sess = create_session() - any_ = Company.employees.any(Person.name == 'vlad') + any_ = Company.employees.any(Person.name == "vlad") eq_(sess.query(Company).filter(any_).all(), [c2]) def test_polymorphic_any_two(self): sess = create_session() # test that the aliasing on "Person" does not bleed into the # EXISTS clause generated by any() - any_ = Company.employees.any(Person.name == 'wally') - eq_(sess.query(Company) - .join(Company.employees, aliased=True) - .filter(Person.name == 'dilbert') - .filter(any_).all(), - [c1]) + any_ = Company.employees.any(Person.name == "wally") + eq_( + sess.query(Company) + .join(Company.employees, aliased=True) + .filter(Person.name == "dilbert") + .filter(any_) + .all(), + [c1], + ) def test_polymorphic_any_three(self): sess = create_session() - any_ = Company.employees.any(Person.name == 'vlad') - eq_(sess.query(Company) - .join(Company.employees, aliased=True) - .filter(Person.name == 'dilbert') - .filter(any_).all(), - []) + any_ = Company.employees.any(Person.name == "vlad") + eq_( + sess.query(Company) + .join(Company.employees, aliased=True) + .filter(Person.name == "dilbert") + .filter(any_) + .all(), + [], + ) def test_polymorphic_any_eight(self): sess = create_session() - any_ = Engineer.machines.any( - Machine.name == "Commodore 64") + any_ = Engineer.machines.any(Machine.name == "Commodore 64") eq_( sess.query(Person).order_by(Person.person_id).filter(any_).all(), - [e2, e3]) + [e2, e3], + ) def test_polymorphic_any_nine(self): sess = create_session() - any_ = Person.paperwork.any( - Paperwork.description == "review #2") + any_ = Person.paperwork.any(Paperwork.description == "review #2") eq_( sess.query(Person).order_by(Person.person_id).filter(any_).all(), - [m1]) + [m1], + ) def test_join_from_columns_or_subclass_one(self): sess = create_session() - expected = [ - ('dogbert',), - ('pointy haired boss',)] - eq_(sess.query(Manager.name) - .order_by(Manager.name).all(), - expected) + expected = [("dogbert",), ("pointy haired boss",)] + eq_(sess.query(Manager.name).order_by(Manager.name).all(), expected) def test_join_from_columns_or_subclass_two(self): sess = create_session() - expected = [ - ('dogbert',), - ('dogbert',), - ('pointy haired boss',)] - eq_(sess.query(Manager.name) - .join(Paperwork, Manager.paperwork) - .order_by(Manager.name).all(), - expected) + expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] + eq_( + sess.query(Manager.name) + .join(Paperwork, Manager.paperwork) + 
.order_by(Manager.name) + .all(), + expected, + ) def test_join_from_columns_or_subclass_three(self): sess = create_session() expected = [ - ('dilbert',), - ('dilbert',), - ('dogbert',), - ('dogbert',), - ('pointy haired boss',), - ('vlad',), - ('wally',), - ('wally',)] - eq_(sess.query(Person.name) - .join(Paperwork, Person.paperwork) - .order_by(Person.name).all(), - expected) + ("dilbert",), + ("dilbert",), + ("dogbert",), + ("dogbert",), + ("pointy haired boss",), + ("vlad",), + ("wally",), + ("wally",), + ] + eq_( + sess.query(Person.name) + .join(Paperwork, Person.paperwork) + .order_by(Person.name) + .all(), + expected, + ) def test_join_from_columns_or_subclass_four(self): sess = create_session() # Load Person.name, joining from Person -> paperwork, get all # the people. expected = [ - ('dilbert',), - ('dilbert',), - ('dogbert',), - ('dogbert',), - ('pointy haired boss',), - ('vlad',), - ('wally',), - ('wally',)] - eq_(sess.query(Person.name) - .join(paperwork, - Person.person_id == paperwork.c.person_id) - .order_by(Person.name).all(), - expected) + ("dilbert",), + ("dilbert",), + ("dogbert",), + ("dogbert",), + ("pointy haired boss",), + ("vlad",), + ("wally",), + ("wally",), + ] + eq_( + sess.query(Person.name) + .join(paperwork, Person.person_id == paperwork.c.person_id) + .order_by(Person.name) + .all(), + expected, + ) def test_join_from_columns_or_subclass_five(self): sess = create_session() # same, on manager. get only managers. - expected = [ - ('dogbert',), - ('dogbert',), - ('pointy haired boss',)] - eq_(sess.query(Manager.name) - .join(paperwork, - Manager.person_id == paperwork.c.person_id) - .order_by(Person.name).all(), - expected) + expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] + eq_( + sess.query(Manager.name) + .join(paperwork, Manager.person_id == paperwork.c.person_id) + .order_by(Person.name) + .all(), + expected, + ) def test_join_from_columns_or_subclass_six(self): sess = create_session() - if self.select_type == '': + if self.select_type == "": # this now raises, due to [ticket:1892]. Manager.person_id # is now the "person_id" column on Manager. SQL is incorrect. assert_raises( sa_exc.DBAPIError, sess.query(Person.name) - .join(paperwork, - Manager.person_id == paperwork.c.person_id) - .order_by(Person.name).all) - elif self.select_type == 'Unions': + .join(paperwork, Manager.person_id == paperwork.c.person_id) + .order_by(Person.name) + .all, + ) + elif self.select_type == "Unions": # with the union, not something anyone would really be using # here, it joins to the full result set. This is 0.6's # behavior and is more or less wrong. expected = [ - ('dilbert',), - ('dilbert',), - ('dogbert',), - ('dogbert',), - ('pointy haired boss',), - ('vlad',), - ('wally',), - ('wally',)] - eq_(sess.query(Person.name) - .join(paperwork, - Manager.person_id == paperwork.c.person_id) - .order_by(Person.name).all(), - expected) + ("dilbert",), + ("dilbert",), + ("dogbert",), + ("dogbert",), + ("pointy haired boss",), + ("vlad",), + ("wally",), + ("wally",), + ] + eq_( + sess.query(Person.name) + .join(paperwork, Manager.person_id == paperwork.c.person_id) + .order_by(Person.name) + .all(), + expected, + ) else: # when a join is present and managers.person_id is available, # you get the managers. 
- expected = [ - ('dogbert',), - ('dogbert',), - ('pointy haired boss',)] - eq_(sess.query(Person.name) - .join(paperwork, - Manager.person_id == paperwork.c.person_id) - .order_by(Person.name).all(), - expected) + expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] + eq_( + sess.query(Person.name) + .join(paperwork, Manager.person_id == paperwork.c.person_id) + .order_by(Person.name) + .all(), + expected, + ) def test_join_from_columns_or_subclass_seven(self): sess = create_session() - eq_(sess.query(Manager) - .join(Paperwork, Manager.paperwork) - .order_by(Manager.name).all(), - [m1, b1]) + eq_( + sess.query(Manager) + .join(Paperwork, Manager.paperwork) + .order_by(Manager.name) + .all(), + [m1, b1], + ) def test_join_from_columns_or_subclass_eight(self): sess = create_session() - expected = [ - ('dogbert',), - ('dogbert',), - ('pointy haired boss',)] - eq_(sess.query(Manager.name) - .join(paperwork, - Manager.person_id == paperwork.c.person_id) - .order_by(Manager.name).all(), - expected) + expected = [("dogbert",), ("dogbert",), ("pointy haired boss",)] + eq_( + sess.query(Manager.name) + .join(paperwork, Manager.person_id == paperwork.c.person_id) + .order_by(Manager.name) + .all(), + expected, + ) def test_join_from_columns_or_subclass_nine(self): sess = create_session() - eq_(sess.query(Manager.person_id) - .join(paperwork, - Manager.person_id == paperwork.c.person_id) - .order_by(Manager.name).all(), - [(4,), (4,), (3,)]) + eq_( + sess.query(Manager.person_id) + .join(paperwork, Manager.person_id == paperwork.c.person_id) + .order_by(Manager.name) + .all(), + [(4,), (4,), (3,)], + ) def test_join_from_columns_or_subclass_ten(self): sess = create_session() expected = [ - ('pointy haired boss', 'review #1'), - ('dogbert', 'review #2'), - ('dogbert', 'review #3')] - eq_(sess.query(Manager.name, Paperwork.description) - .join(Paperwork, - Manager.person_id == Paperwork.person_id) - .order_by(Paperwork.paperwork_id).all(), - expected) + ("pointy haired boss", "review #1"), + ("dogbert", "review #2"), + ("dogbert", "review #3"), + ] + eq_( + sess.query(Manager.name, Paperwork.description) + .join(Paperwork, Manager.person_id == Paperwork.person_id) + .order_by(Paperwork.paperwork_id) + .all(), + expected, + ) def test_join_from_columns_or_subclass_eleven(self): sess = create_session() - expected = [ - ('pointy haired boss',), - ('dogbert',), - ('dogbert',)] + expected = [("pointy haired boss",), ("dogbert",), ("dogbert",)] malias = aliased(Manager) - eq_(sess.query(malias.name) - .join(paperwork, - malias.person_id == paperwork.c.person_id) - .all(), - expected) + eq_( + sess.query(malias.name) + .join(paperwork, malias.person_id == paperwork.c.person_id) + .all(), + expected, + ) def test_subclass_option_pathing(self): from sqlalchemy.orm import defer + sess = create_session() - dilbert = sess.query(Person).\ - options(defer(Engineer.machines, Machine.name)).\ - filter(Person.name == 'dilbert').first() + dilbert = ( + sess.query(Person) + .options(defer(Engineer.machines, Machine.name)) + .filter(Person.name == "dilbert") + .first() + ) m = dilbert.machines[0] - assert 'name' not in m.__dict__ - eq_(m.name, 'IBM ThinkPad') + assert "name" not in m.__dict__ + eq_(m.name, "IBM ThinkPad") def test_expire(self): """ @@ -539,56 +678,70 @@ class _PolymorphicTestBase(object): sess = create_session() - name = 'dogbert' + name = "dogbert" m1 = sess.query(Manager).filter(Manager.name == name).one() sess.expire(m1) - assert m1.status == 'regular manager' + assert m1.status == 
"regular manager" - name = 'pointy haired boss' + name = "pointy haired boss" m2 = sess.query(Manager).filter(Manager.name == name).one() - sess.expire(m2, ['manager_name', 'golf_swing']) - assert m2.golf_swing == 'fore' + sess.expire(m2, ["manager_name", "golf_swing"]) + assert m2.golf_swing == "fore" def test_with_polymorphic_one(self): sess = create_session() def go(): - eq_(sess.query(Person) - .with_polymorphic(Engineer) - .filter(Engineer.primary_language == 'java').all(), - self._emps_wo_relationships_fixture()[0:1]) + eq_( + sess.query(Person) + .with_polymorphic(Engineer) + .filter(Engineer.primary_language == "java") + .all(), + self._emps_wo_relationships_fixture()[0:1], + ) + self.assert_sql_count(testing.db, go, 1) def test_with_polymorphic_two(self): sess = create_session() def go(): - eq_(sess.query(Person) - .with_polymorphic('*').order_by(Person.person_id).all(), - self._emps_wo_relationships_fixture()) + eq_( + sess.query(Person) + .with_polymorphic("*") + .order_by(Person.person_id) + .all(), + self._emps_wo_relationships_fixture(), + ) + self.assert_sql_count(testing.db, go, 1) def test_with_polymorphic_three(self): sess = create_session() def go(): - eq_(sess.query(Person) - .with_polymorphic(Engineer). - order_by(Person.person_id).all(), - self._emps_wo_relationships_fixture()) + eq_( + sess.query(Person) + .with_polymorphic(Engineer) + .order_by(Person.person_id) + .all(), + self._emps_wo_relationships_fixture(), + ) + self.assert_sql_count(testing.db, go, 3) def test_with_polymorphic_four(self): sess = create_session() def go(): - eq_(sess.query(Person) - .with_polymorphic( - Engineer, - people.outerjoin(engineers)) - .order_by(Person.person_id) - .all(), - self._emps_wo_relationships_fixture()) + eq_( + sess.query(Person) + .with_polymorphic(Engineer, people.outerjoin(engineers)) + .order_by(Person.person_id) + .all(), + self._emps_wo_relationships_fixture(), + ) + self.assert_sql_count(testing.db, go, 3) def test_with_polymorphic_five(self): @@ -597,28 +750,43 @@ class _PolymorphicTestBase(object): def go(): # limit the polymorphic join down to just "Person", # overriding select_table - eq_(sess.query(Person) - .with_polymorphic(Person).all(), - self._emps_wo_relationships_fixture()) + eq_( + sess.query(Person).with_polymorphic(Person).all(), + self._emps_wo_relationships_fixture(), + ) + self.assert_sql_count(testing.db, go, 6) def test_with_polymorphic_six(self): sess = create_session() - assert_raises(sa_exc.InvalidRequestError, - sess.query(Person).with_polymorphic, Paperwork) - assert_raises(sa_exc.InvalidRequestError, - sess.query(Engineer).with_polymorphic, Boss) - assert_raises(sa_exc.InvalidRequestError, - sess.query(Engineer).with_polymorphic, Person) + assert_raises( + sa_exc.InvalidRequestError, + sess.query(Person).with_polymorphic, + Paperwork, + ) + assert_raises( + sa_exc.InvalidRequestError, + sess.query(Engineer).with_polymorphic, + Boss, + ) + assert_raises( + sa_exc.InvalidRequestError, + sess.query(Engineer).with_polymorphic, + Person, + ) def test_with_polymorphic_seven(self): sess = create_session() # compare to entities without related collections to prevent # additional lazy SQL from firing on loaded entities - eq_(sess.query(Person).with_polymorphic('*'). 
-            order_by(Person.person_id).all(),
-            self._emps_wo_relationships_fixture())
+        eq_(
+            sess.query(Person)
+            .with_polymorphic("*")
+            .order_by(Person.person_id)
+            .all(),
+            self._emps_wo_relationships_fixture(),
+        )
 
     def test_relationship_to_polymorphic_one(self):
         expected = self._company_with_emps_machines_fixture()
@@ -627,7 +795,8 @@ class _PolymorphicTestBase(object):
         def go():
             # test load Companies with lazy load to 'employees'
             eq_(sess.query(Company).all(), expected)
-        count = {'': 10, 'Polymorphic': 5}.get(self.select_type, 6)
+
+        count = {"": 10, "Polymorphic": 5}.get(self.select_type, 6)
         self.assert_sql_count(testing.db, go, count)
 
     def test_relationship_to_polymorphic_two(self):
@@ -638,12 +807,16 @@
             # with #2438, of_type() is recognized. This
             # overrides the with_polymorphic of the mapper
             # and we get a consistent 3 queries now.
-            eq_(sess.query(Company)
-                .options(joinedload_all(
-                    Company.employees.of_type(Engineer),
-                    Engineer.machines))
-                .all(),
-                expected)
+            eq_(
+                sess.query(Company)
+                .options(
+                    joinedload_all(
+                        Company.employees.of_type(Engineer), Engineer.machines
+                    )
+                )
+                .all(),
+                expected,
+            )
 
         # in the old case, we would get this
         # count = {'':7, 'Polymorphic':1}.get(self.select_type, 2)
@@ -661,12 +834,16 @@
         sess = create_session()
 
         def go():
-            eq_(sess.query(Company)
-                .options(subqueryload_all(
-                    Company.employees.of_type(Engineer),
-                    Engineer.machines))
-                .all(),
-                expected)
+            eq_(
+                sess.query(Company)
+                .options(
+                    subqueryload_all(
+                        Company.employees.of_type(Engineer), Engineer.machines
+                    )
+                )
+                .all(),
+                expected,
+            )
 
         # the old case where subqueryload_all
         # didn't work with of_type
@@ -693,15 +870,22 @@
                 status="regular engineer",
                 machines=[
                     Machine(name="IBM ThinkPad"),
-                    Machine(name="IPhone")])]
+                    Machine(name="IPhone"),
+                ],
+            )
+        ]
 
         def go():
             # test load People with joinedload to engineers + machines
-            eq_(sess.query(Person)
-                .with_polymorphic('*')
-                .options(joinedload(Engineer.machines))
-                .filter(Person.name == 'dilbert').all(),
-                expected)
+            eq_(
+                sess.query(Person)
+                .with_polymorphic("*")
+                .options(joinedload(Engineer.machines))
+                .filter(Person.name == "dilbert")
+                .all(),
+                expected,
+            )
+
         self.assert_sql_count(testing.db, go, 1)
 
     def test_subqueryload_on_subclass(self):
@@ -714,239 +898,314 @@
                 status="regular engineer",
                 machines=[
                     Machine(name="IBM ThinkPad"),
-                    Machine(name="IPhone")])]
+                    Machine(name="IPhone"),
+                ],
+            )
+        ]
 
         def go():
             # test load People with subqueryload to engineers + machines
-            eq_(sess.query(Person)
-                .with_polymorphic('*')
-                .options(subqueryload(Engineer.machines))
-                .filter(Person.name == 'dilbert').all(),
-                expected)
+            eq_(
+                sess.query(Person)
+                .with_polymorphic("*")
+                .options(subqueryload(Engineer.machines))
+                .filter(Person.name == "dilbert")
+                .all(),
+                expected,
+            )
+
         self.assert_sql_count(testing.db, go, 2)
 
     def test_query_subclass_join_to_base_relationship(self):
         sess = create_session()
         # non-polymorphic
-        eq_(sess.query(Engineer)
-            .join(Person.paperwork).all(),
-            [e1, e2, e3])
+        eq_(sess.query(Engineer).join(Person.paperwork).all(), [e1, e2, e3])
 
     def test_join_to_subclass(self):
         sess = create_session()
-        eq_(sess.query(Company)
-            .join(people.join(engineers), 'employees')
-            .filter(Engineer.primary_language == 'java').all(),
-            [c1])
+        eq_(
+            sess.query(Company)
+            .join(people.join(engineers), "employees")
+            .filter(Engineer.primary_language == "java")
+            .all(),
+
[c1], + ) def test_join_to_subclass_one(self): sess = create_session() - eq_(sess.query(Company) - .select_from(companies.join(people).join(engineers)) - .filter(Engineer.primary_language == 'java').all(), - [c1]) + eq_( + sess.query(Company) + .select_from(companies.join(people).join(engineers)) + .filter(Engineer.primary_language == "java") + .all(), + [c1], + ) def test_join_to_subclass_two(self): sess = create_session() - eq_(sess.query(Company) - .join(people.join(engineers), 'employees') - .filter(Engineer.primary_language == 'java').all(), - [c1]) + eq_( + sess.query(Company) + .join(people.join(engineers), "employees") + .filter(Engineer.primary_language == "java") + .all(), + [c1], + ) def test_join_to_subclass_three(self): sess = create_session() ealias = aliased(Engineer) - eq_(sess.query(Company) - .join(ealias, 'employees') - .filter(ealias.primary_language == 'java').all(), - [c1]) + eq_( + sess.query(Company) + .join(ealias, "employees") + .filter(ealias.primary_language == "java") + .all(), + [c1], + ) def test_join_to_subclass_six(self): sess = create_session() - eq_(sess.query(Company) - .join(people.join(engineers), 'employees') - .join(Engineer.machines).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join(people.join(engineers), "employees") + .join(Engineer.machines) + .all(), + [c1, c2], + ) def test_join_to_subclass_six_point_five(self): sess = create_session() - eq_(sess.query(Company) - .join(people.join(engineers), 'employees') - .join(Engineer.machines) - .filter(Engineer.name == 'dilbert').all(), - [c1]) + eq_( + sess.query(Company) + .join(people.join(engineers), "employees") + .join(Engineer.machines) + .filter(Engineer.name == "dilbert") + .all(), + [c1], + ) def test_join_to_subclass_seven(self): sess = create_session() - eq_(sess.query(Company) - .join(people.join(engineers), 'employees') - .join(Engineer.machines) - .filter(Machine.name.ilike("%thinkpad%")).all(), - [c1]) + eq_( + sess.query(Company) + .join(people.join(engineers), "employees") + .join(Engineer.machines) + .filter(Machine.name.ilike("%thinkpad%")) + .all(), + [c1], + ) def test_join_to_subclass_eight(self): sess = create_session() - eq_(sess.query(Person) - .join(Engineer.machines).all(), - [e1, e2, e3]) + eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) def test_join_to_subclass_nine(self): sess = create_session() - eq_(sess.query(Company) - .select_from(companies.join(people).join(engineers)) - .filter(Engineer.primary_language == 'java').all(), - [c1]) + eq_( + sess.query(Company) + .select_from(companies.join(people).join(engineers)) + .filter(Engineer.primary_language == "java") + .all(), + [c1], + ) def test_join_to_subclass_ten(self): sess = create_session() - eq_(sess.query(Company) - .join('employees') - .filter(Engineer.primary_language == 'java').all(), - [c1]) + eq_( + sess.query(Company) + .join("employees") + .filter(Engineer.primary_language == "java") + .all(), + [c1], + ) def test_join_to_subclass_eleven(self): sess = create_session() - eq_(sess.query(Company) - .select_from(companies.join(people).join(engineers)) - .filter(Engineer.primary_language == 'java').all(), - [c1]) + eq_( + sess.query(Company) + .select_from(companies.join(people).join(engineers)) + .filter(Engineer.primary_language == "java") + .all(), + [c1], + ) def test_join_to_subclass_twelve(self): sess = create_session() - eq_(sess.query(Person) - .join(Engineer.machines).all(), - [e1, e2, e3]) + eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) def 
test_join_to_subclass_thirteen(self): sess = create_session() - eq_(sess.query(Person) - .join(Engineer.machines) - .filter(Machine.name.ilike("%ibm%")).all(), - [e1, e3]) + eq_( + sess.query(Person) + .join(Engineer.machines) + .filter(Machine.name.ilike("%ibm%")) + .all(), + [e1, e3], + ) def test_join_to_subclass_fourteen(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', Engineer.machines).all(), - [c1, c2]) + eq_( + sess.query(Company).join("employees", Engineer.machines).all(), + [c1, c2], + ) def test_join_to_subclass_fifteen(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', Engineer.machines) - .filter(Machine.name.ilike("%thinkpad%")).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", Engineer.machines) + .filter(Machine.name.ilike("%thinkpad%")) + .all(), + [c1], + ) def test_join_to_subclass_sixteen(self): sess = create_session() # non-polymorphic - eq_(sess.query(Engineer) - .join(Engineer.machines).all(), - [e1, e2, e3]) + eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) def test_join_to_subclass_seventeen(self): sess = create_session() - eq_(sess.query(Engineer) - .join(Engineer.machines) - .filter(Machine.name.ilike("%ibm%")).all(), - [e1, e3]) + eq_( + sess.query(Engineer) + .join(Engineer.machines) + .filter(Machine.name.ilike("%ibm%")) + .all(), + [e1, e3], + ) def test_join_through_polymorphic_nonaliased_one(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=False) - .filter(Paperwork.description.like('%#2%')).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=False) + .filter(Paperwork.description.like("%#2%")) + .all(), + [c1], + ) def test_join_through_polymorphic_nonaliased_two(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=False) - .filter(Paperwork.description.like('%#%')).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=False) + .filter(Paperwork.description.like("%#%")) + .all(), + [c1, c2], + ) def test_join_through_polymorphic_nonaliased_three(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=False) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .filter(Paperwork.description.like('%#2%')).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=False) + .filter(Person.name.in_(["dilbert", "vlad"])) + .filter(Paperwork.description.like("%#2%")) + .all(), + [c1], + ) def test_join_through_polymorphic_nonaliased_four(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=False) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .filter(Paperwork.description.like('%#%')).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=False) + .filter(Person.name.in_(["dilbert", "vlad"])) + .filter(Paperwork.description.like("%#%")) + .all(), + [c1, c2], + ) def test_join_through_polymorphic_nonaliased_five(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', aliased=aliased) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .join('paperwork', from_joinpoint=True, aliased=False) - .filter(Paperwork.description.like('%#2%')).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", aliased=aliased) + .filter(Person.name.in_(["dilbert", "vlad"])) + .join("paperwork", from_joinpoint=True, aliased=False) + 
.filter(Paperwork.description.like("%#2%")) + .all(), + [c1], + ) def test_join_through_polymorphic_nonaliased_six(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', aliased=aliased) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .join('paperwork', from_joinpoint=True, aliased=False) - .filter(Paperwork.description.like('%#%')).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join("employees", aliased=aliased) + .filter(Person.name.in_(["dilbert", "vlad"])) + .join("paperwork", from_joinpoint=True, aliased=False) + .filter(Paperwork.description.like("%#%")) + .all(), + [c1, c2], + ) def test_join_through_polymorphic_aliased_one(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=True) - .filter(Paperwork.description.like('%#2%')).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=True) + .filter(Paperwork.description.like("%#2%")) + .all(), + [c1], + ) def test_join_through_polymorphic_aliased_two(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=True) - .filter(Paperwork.description.like('%#%')).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=True) + .filter(Paperwork.description.like("%#%")) + .all(), + [c1, c2], + ) def test_join_through_polymorphic_aliased_three(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=True) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .filter(Paperwork.description.like('%#2%')).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=True) + .filter(Person.name.in_(["dilbert", "vlad"])) + .filter(Paperwork.description.like("%#2%")) + .all(), + [c1], + ) def test_join_through_polymorphic_aliased_four(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', 'paperwork', aliased=True) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .filter(Paperwork.description.like('%#%')).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join("employees", "paperwork", aliased=True) + .filter(Person.name.in_(["dilbert", "vlad"])) + .filter(Paperwork.description.like("%#%")) + .all(), + [c1, c2], + ) def test_join_through_polymorphic_aliased_five(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', aliased=aliased) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .join('paperwork', from_joinpoint=True, aliased=True) - .filter(Paperwork.description.like('%#2%')).all(), - [c1]) + eq_( + sess.query(Company) + .join("employees", aliased=aliased) + .filter(Person.name.in_(["dilbert", "vlad"])) + .join("paperwork", from_joinpoint=True, aliased=True) + .filter(Paperwork.description.like("%#2%")) + .all(), + [c1], + ) def test_join_through_polymorphic_aliased_six(self): sess = create_session() - eq_(sess.query(Company) - .join('employees', aliased=aliased) - .filter(Person.name.in_(['dilbert', 'vlad'])) - .join('paperwork', from_joinpoint=True, aliased=True) - .filter(Paperwork.description.like('%#%')).all(), - [c1, c2]) + eq_( + sess.query(Company) + .join("employees", aliased=aliased) + .filter(Person.name.in_(["dilbert", "vlad"])) + .join("paperwork", from_joinpoint=True, aliased=True) + .filter(Paperwork.description.like("%#%")) + .all(), + [c1, c2], + ) def test_explicit_polymorphic_join_one(self): sess = create_session() @@ -955,10 +1214,13 @@ class _PolymorphicTestBase(object): # ORMJoin using regular table foreign key connections. 
Engineer # is expressed as "(select * people join engineers) as anon_1" # so the join is contained. - eq_(sess.query(Company) - .join(Engineer) - .filter(Engineer.engineer_name == 'vlad').one(), - c2) + eq_( + sess.query(Company) + .join(Engineer) + .filter(Engineer.engineer_name == "vlad") + .one(), + c2, + ) def test_explicit_polymorphic_join_two(self): sess = create_session() @@ -966,53 +1228,70 @@ class _PolymorphicTestBase(object): # same, using explicit join condition. Query.join() must # adapt the on clause here to match the subquery wrapped around # "people join engineers". - eq_(sess.query(Company) - .join(Engineer, Company.company_id == Engineer.company_id) - .filter(Engineer.engineer_name == 'vlad').one(), - c2) + eq_( + sess.query(Company) + .join(Engineer, Company.company_id == Engineer.company_id) + .filter(Engineer.engineer_name == "vlad") + .one(), + c2, + ) def test_filter_on_baseclass(self): sess = create_session() eq_(sess.query(Person).order_by(Person.person_id).all(), all_employees) eq_( sess.query(Person).order_by(Person.person_id).first(), - all_employees[0]) - eq_(sess.query(Person).order_by(Person.person_id) - .filter(Person.person_id == e2.person_id).one(), - e2) + all_employees[0], + ) + eq_( + sess.query(Person) + .order_by(Person.person_id) + .filter(Person.person_id == e2.person_id) + .one(), + e2, + ) def test_from_alias(self): sess = create_session() palias = aliased(Person) - eq_(sess.query(palias) - .order_by(palias.person_id) - .filter(palias.name.in_(['dilbert', 'wally'])).all(), - [e1, e2]) + eq_( + sess.query(palias) + .order_by(palias.person_id) + .filter(palias.name.in_(["dilbert", "wally"])) + .all(), + [e1, e2], + ) def test_self_referential_one(self): sess = create_session() palias = aliased(Person) expected = [(m1, e1), (m1, e2), (m1, b1)] - eq_(sess.query(Person, palias) - .filter(Person.company_id == palias.company_id) - .filter(Person.name == 'dogbert') - .filter(Person.person_id > palias.person_id) - .order_by(Person.person_id, palias.person_id).all(), - expected) + eq_( + sess.query(Person, palias) + .filter(Person.company_id == palias.company_id) + .filter(Person.name == "dogbert") + .filter(Person.person_id > palias.person_id) + .order_by(Person.person_id, palias.person_id) + .all(), + expected, + ) def test_self_referential_two(self): sess = create_session() palias = aliased(Person) expected = [(m1, e1), (m1, e2), (m1, b1)] - eq_(sess.query(Person, palias) - .filter(Person.company_id == palias.company_id) - .filter(Person.name == 'dogbert') - .filter(Person.person_id > palias.person_id) - .from_self() - .order_by(Person.person_id, palias.person_id).all(), - expected) + eq_( + sess.query(Person, palias) + .filter(Person.company_id == palias.company_id) + .filter(Person.name == "dogbert") + .filter(Person.person_id > palias.person_id) + .from_self() + .order_by(Person.person_id, palias.person_id) + .all(), + expected, + ) def test_nesting_queries(self): # query.statement places a flag "no_adapt" on the returned @@ -1021,267 +1300,342 @@ class _PolymorphicTestBase(object): # subquery and usually results in recursion overflow errors # within the adaption. 
sess = create_session() - subq = (sess.query(engineers.c.person_id) - .filter(Engineer.primary_language == 'java') - .statement.as_scalar()) - eq_(sess.query(Person) - .filter(Person.person_id.in_(subq)).one(), - e1) + subq = ( + sess.query(engineers.c.person_id) + .filter(Engineer.primary_language == "java") + .statement.as_scalar() + ) + eq_(sess.query(Person).filter(Person.person_id.in_(subq)).one(), e1) def test_mixed_entities_one(self): sess = create_session() expected = [ - (Engineer( - status='regular engineer', - engineer_name='dilbert', - name='dilbert', - company_id=1, - primary_language='java', - person_id=1, - type='engineer'), - 'MegaCorp, Inc.'), - (Engineer( - status='regular engineer', - engineer_name='wally', - name='wally', - company_id=1, - primary_language='c++', - person_id=2, - type='engineer'), - 'MegaCorp, Inc.'), - (Engineer( - status='elbonian engineer', - engineer_name='vlad', - name='vlad', - company_id=2, - primary_language='cobol', - person_id=5, - type='engineer'), - 'Elbonia, Inc.')] - eq_(sess.query(Engineer, Company.name) - .join(Company.employees) - .order_by(Person.person_id) - .filter(Person.type == 'engineer').all(), - expected) + ( + Engineer( + status="regular engineer", + engineer_name="dilbert", + name="dilbert", + company_id=1, + primary_language="java", + person_id=1, + type="engineer", + ), + "MegaCorp, Inc.", + ), + ( + Engineer( + status="regular engineer", + engineer_name="wally", + name="wally", + company_id=1, + primary_language="c++", + person_id=2, + type="engineer", + ), + "MegaCorp, Inc.", + ), + ( + Engineer( + status="elbonian engineer", + engineer_name="vlad", + name="vlad", + company_id=2, + primary_language="cobol", + person_id=5, + type="engineer", + ), + "Elbonia, Inc.", + ), + ] + eq_( + sess.query(Engineer, Company.name) + .join(Company.employees) + .order_by(Person.person_id) + .filter(Person.type == "engineer") + .all(), + expected, + ) def test_mixed_entities_two(self): sess = create_session() expected = [ - ('java', 'MegaCorp, Inc.'), - ('cobol', 'Elbonia, Inc.'), - ('c++', 'MegaCorp, Inc.')] - eq_(sess.query(Engineer.primary_language, Company.name) - .join(Company.employees) - .filter(Person.type == 'engineer') - .order_by(desc(Engineer.primary_language)).all(), - expected) + ("java", "MegaCorp, Inc."), + ("cobol", "Elbonia, Inc."), + ("c++", "MegaCorp, Inc."), + ] + eq_( + sess.query(Engineer.primary_language, Company.name) + .join(Company.employees) + .filter(Person.type == "engineer") + .order_by(desc(Engineer.primary_language)) + .all(), + expected, + ) def test_mixed_entities_three(self): sess = create_session() palias = aliased(Person) - expected = [( - Engineer( - status='elbonian engineer', - engineer_name='vlad', - name='vlad', - primary_language='cobol'), - 'Elbonia, Inc.', - Engineer( - status='regular engineer', - engineer_name='dilbert', - name='dilbert', - company_id=1, - primary_language='java', - person_id=1, - type='engineer'))] - eq_(sess.query(Person, Company.name, palias) - .join(Company.employees) - .filter(Company.name == 'Elbonia, Inc.') - .filter(palias.name == 'dilbert').all(), - expected) + expected = [ + ( + Engineer( + status="elbonian engineer", + engineer_name="vlad", + name="vlad", + primary_language="cobol", + ), + "Elbonia, Inc.", + Engineer( + status="regular engineer", + engineer_name="dilbert", + name="dilbert", + company_id=1, + primary_language="java", + person_id=1, + type="engineer", + ), + ) + ] + eq_( + sess.query(Person, Company.name, palias) + .join(Company.employees) + 
.filter(Company.name == "Elbonia, Inc.") + .filter(palias.name == "dilbert") + .all(), + expected, + ) def test_mixed_entities_four(self): sess = create_session() palias = aliased(Person) - expected = [( - Engineer( - status='regular engineer', - engineer_name='dilbert', - name='dilbert', - company_id=1, - primary_language='java', - person_id=1, - type='engineer'), - 'Elbonia, Inc.', - Engineer( - status='elbonian engineer', - engineer_name='vlad', - name='vlad', - primary_language='cobol'),)] - eq_(sess.query(palias, Company.name, Person) - .join(Company.employees) - .filter(Company.name == 'Elbonia, Inc.') - .filter(palias.name == 'dilbert').all(), - expected) + expected = [ + ( + Engineer( + status="regular engineer", + engineer_name="dilbert", + name="dilbert", + company_id=1, + primary_language="java", + person_id=1, + type="engineer", + ), + "Elbonia, Inc.", + Engineer( + status="elbonian engineer", + engineer_name="vlad", + name="vlad", + primary_language="cobol", + ), + ) + ] + eq_( + sess.query(palias, Company.name, Person) + .join(Company.employees) + .filter(Company.name == "Elbonia, Inc.") + .filter(palias.name == "dilbert") + .all(), + expected, + ) def test_mixed_entities_five(self): sess = create_session() palias = aliased(Person) - expected = [('vlad', 'Elbonia, Inc.', 'dilbert')] - eq_(sess.query(Person.name, Company.name, palias.name) - .join(Company.employees) - .filter(Company.name == 'Elbonia, Inc.') - .filter(palias.name == 'dilbert').all(), - expected) + expected = [("vlad", "Elbonia, Inc.", "dilbert")] + eq_( + sess.query(Person.name, Company.name, palias.name) + .join(Company.employees) + .filter(Company.name == "Elbonia, Inc.") + .filter(palias.name == "dilbert") + .all(), + expected, + ) def test_mixed_entities_six(self): sess = create_session() palias = aliased(Person) expected = [ - ('manager', 'dogbert', 'engineer', 'dilbert'), - ('manager', 'dogbert', 'engineer', 'wally'), - ('manager', 'dogbert', 'boss', 'pointy haired boss')] - eq_(sess.query(Person.type, Person.name, palias.type, palias.name) - .filter(Person.company_id == palias.company_id) - .filter(Person.name == 'dogbert') - .filter(Person.person_id > palias.person_id) - .order_by(Person.person_id, palias.person_id).all(), - expected) + ("manager", "dogbert", "engineer", "dilbert"), + ("manager", "dogbert", "engineer", "wally"), + ("manager", "dogbert", "boss", "pointy haired boss"), + ] + eq_( + sess.query(Person.type, Person.name, palias.type, palias.name) + .filter(Person.company_id == palias.company_id) + .filter(Person.name == "dogbert") + .filter(Person.person_id > palias.person_id) + .order_by(Person.person_id, palias.person_id) + .all(), + expected, + ) def test_mixed_entities_seven(self): sess = create_session() expected = [ - ('dilbert', 'tps report #1'), - ('dilbert', 'tps report #2'), - ('dogbert', 'review #2'), - ('dogbert', 'review #3'), - ('pointy haired boss', 'review #1'), - ('vlad', 'elbonian missive #3'), - ('wally', 'tps report #3'), - ('wally', 'tps report #4')] - eq_(sess.query(Person.name, Paperwork.description) - .filter(Person.person_id == Paperwork.person_id) - .order_by(Person.name, Paperwork.description).all(), - expected) + ("dilbert", "tps report #1"), + ("dilbert", "tps report #2"), + ("dogbert", "review #2"), + ("dogbert", "review #3"), + ("pointy haired boss", "review #1"), + ("vlad", "elbonian missive #3"), + ("wally", "tps report #3"), + ("wally", "tps report #4"), + ] + eq_( + sess.query(Person.name, Paperwork.description) + .filter(Person.person_id == 
Paperwork.person_id) + .order_by(Person.name, Paperwork.description) + .all(), + expected, + ) def test_mixed_entities_eight(self): sess = create_session() - eq_(sess.query(func.count(Person.person_id)) - .filter(Engineer.primary_language == 'java').all(), - [(1,)]) + eq_( + sess.query(func.count(Person.person_id)) + .filter(Engineer.primary_language == "java") + .all(), + [(1,)], + ) def test_mixed_entities_nine(self): sess = create_session() - expected = [('Elbonia, Inc.', 1), ('MegaCorp, Inc.', 4)] - eq_(sess.query(Company.name, func.count(Person.person_id)) - .filter(Company.company_id == Person.company_id) - .group_by(Company.name) - .order_by(Company.name).all(), - expected) + expected = [("Elbonia, Inc.", 1), ("MegaCorp, Inc.", 4)] + eq_( + sess.query(Company.name, func.count(Person.person_id)) + .filter(Company.company_id == Person.company_id) + .group_by(Company.name) + .order_by(Company.name) + .all(), + expected, + ) def test_mixed_entities_ten(self): sess = create_session() - expected = [('Elbonia, Inc.', 1), ('MegaCorp, Inc.', 4)] - eq_(sess.query(Company.name, func.count(Person.person_id)) - .join(Company.employees) - .group_by(Company.name) - .order_by(Company.name).all(), - expected) + expected = [("Elbonia, Inc.", 1), ("MegaCorp, Inc.", 4)] + eq_( + sess.query(Company.name, func.count(Person.person_id)) + .join(Company.employees) + .group_by(Company.name) + .order_by(Company.name) + .all(), + expected, + ) # def test_mixed_entities(self): # sess = create_session() - # TODO: I think raise error on these for now. different - # inheritance/loading schemes have different results here, - # all incorrect - # - # eq_( - # sess.query(Person.name, Engineer.primary_language).all(), - # []) + # TODO: I think raise error on these for now. different + # inheritance/loading schemes have different results here, + # all incorrect + # + # eq_( + # sess.query(Person.name, Engineer.primary_language).all(), + # []) # def test_mixed_entities(self): # sess = create_session() - # eq_(sess.query( - # Person.name, - # Engineer.primary_language, - # Manager.manager_name) - # .all(), - # []) + # eq_(sess.query( + # Person.name, + # Engineer.primary_language, + # Manager.manager_name) + # .all(), + # []) def test_mixed_entities_eleven(self): sess = create_session() - expected = [('java',), ('c++',), ('cobol',)] - eq_(sess.query(Engineer.primary_language) - .filter(Person.type == 'engineer').all(), - expected) + expected = [("java",), ("c++",), ("cobol",)] + eq_( + sess.query(Engineer.primary_language) + .filter(Person.type == "engineer") + .all(), + expected, + ) def test_mixed_entities_twelve(self): sess = create_session() - expected = [('vlad', 'Elbonia, Inc.')] - eq_(sess.query(Person.name, Company.name) - .join(Company.employees) - .filter(Company.name == 'Elbonia, Inc.').all(), - expected) + expected = [("vlad", "Elbonia, Inc.")] + eq_( + sess.query(Person.name, Company.name) + .join(Company.employees) + .filter(Company.name == "Elbonia, Inc.") + .all(), + expected, + ) def test_mixed_entities_thirteen(self): sess = create_session() - expected = [('pointy haired boss', 'fore')] + expected = [("pointy haired boss", "fore")] eq_(sess.query(Boss.name, Boss.golf_swing).all(), expected) def test_mixed_entities_fourteen(self): sess = create_session() - expected = [ - ('dilbert', 'java'), - ('wally', 'c++'), - ('vlad', 'cobol')] - eq_(sess.query(Engineer.name, Engineer.primary_language).all(), - expected) + expected = [("dilbert", "java"), ("wally", "c++"), ("vlad", "cobol")] + eq_( + 
sess.query(Engineer.name, Engineer.primary_language).all(), + expected, + ) def test_mixed_entities_fifteen(self): sess = create_session() - expected = [( - 'Elbonia, Inc.', - Engineer( - status='elbonian engineer', - engineer_name='vlad', - name='vlad', - primary_language='cobol'))] - eq_(sess.query(Company.name, Person) - .join(Company.employees) - .filter(Company.name == 'Elbonia, Inc.').all(), - expected) + expected = [ + ( + "Elbonia, Inc.", + Engineer( + status="elbonian engineer", + engineer_name="vlad", + name="vlad", + primary_language="cobol", + ), + ) + ] + eq_( + sess.query(Company.name, Person) + .join(Company.employees) + .filter(Company.name == "Elbonia, Inc.") + .all(), + expected, + ) def test_mixed_entities_sixteen(self): sess = create_session() - expected = [( - Engineer( - status='elbonian engineer', - engineer_name='vlad', - name='vlad', - primary_language='cobol'), - 'Elbonia, Inc.')] - eq_(sess.query(Person, Company.name) - .join(Company.employees) - .filter(Company.name == 'Elbonia, Inc.').all(), - expected) + expected = [ + ( + Engineer( + status="elbonian engineer", + engineer_name="vlad", + name="vlad", + primary_language="cobol", + ), + "Elbonia, Inc.", + ) + ] + eq_( + sess.query(Person, Company.name) + .join(Company.employees) + .filter(Company.name == "Elbonia, Inc.") + .all(), + expected, + ) def test_mixed_entities_seventeen(self): sess = create_session() - expected = [('pointy haired boss',), ('dogbert',)] + expected = [("pointy haired boss",), ("dogbert",)] eq_(sess.query(Manager.name).all(), expected) def test_mixed_entities_eighteen(self): sess = create_session() - expected = [('pointy haired boss foo',), ('dogbert foo',)] + expected = [("pointy haired boss foo",), ("dogbert foo",)] eq_(sess.query(Manager.name + " foo").all(), expected) def test_mixed_entities_nineteen(self): sess = create_session() - row = sess.query(Engineer.name, Engineer.primary_language) \ - .filter(Engineer.name == 'dilbert').first() - assert row.name == 'dilbert' - assert row.primary_language == 'java' + row = ( + sess.query(Engineer.name, Engineer.primary_language) + .filter(Engineer.name == "dilbert") + .first() + ) + assert row.name == "dilbert" + assert row.primary_language == "java" def test_correlation_one(self): sess = create_session() @@ -1290,79 +1644,115 @@ class _PolymorphicTestBase(object): # PolymorphicUnions, which was due to the no_replacement_traverse # annotation added to query.statement which then went into as_scalar(). # this is removed as of :ticket:`4304` so now works. - eq_(sess.query(Person.name) - .filter( - sess.query(Company.name). - filter(Company.company_id == Person.company_id). - correlate(Person).as_scalar() == "Elbonia, Inc.").all(), - [(e3.name, )]) + eq_( + sess.query(Person.name) + .filter( + sess.query(Company.name) + .filter(Company.company_id == Person.company_id) + .correlate(Person) + .as_scalar() + == "Elbonia, Inc." + ) + .all(), + [(e3.name,)], + ) def test_correlation_two(self): sess = create_session() paliased = aliased(Person) - eq_(sess.query(paliased.name) - .filter( - sess.query(Company.name). - filter(Company.company_id == paliased.company_id). - correlate(paliased).as_scalar() == "Elbonia, Inc.").all(), - [(e3.name, )]) + eq_( + sess.query(paliased.name) + .filter( + sess.query(Company.name) + .filter(Company.company_id == paliased.company_id) + .correlate(paliased) + .as_scalar() + == "Elbonia, Inc." 
+ ) + .all(), + [(e3.name,)], + ) def test_correlation_three(self): sess = create_session() paliased = aliased(Person, flat=True) - eq_(sess.query(paliased.name) - .filter( - sess.query(Company.name). - filter(Company.company_id == paliased.company_id). - correlate(paliased).as_scalar() == "Elbonia, Inc.").all(), - [(e3.name, )]) + eq_( + sess.query(paliased.name) + .filter( + sess.query(Company.name) + .filter(Company.company_id == paliased.company_id) + .correlate(paliased) + .as_scalar() + == "Elbonia, Inc." + ) + .all(), + [(e3.name,)], + ) class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): def test_join_to_subclass_four(self): sess = create_session() - eq_(sess.query(Person) - .select_from(people.join(engineers)) - .join(Engineer.machines).all(), - [e1, e2, e3]) + eq_( + sess.query(Person) + .select_from(people.join(engineers)) + .join(Engineer.machines) + .all(), + [e1, e2, e3], + ) def test_join_to_subclass_five(self): sess = create_session() - eq_(sess.query(Person) - .select_from(people.join(engineers)) - .join(Engineer.machines) - .filter(Machine.name.ilike("%ibm%")).all(), - [e1, e3]) + eq_( + sess.query(Person) + .select_from(people.join(engineers)) + .join(Engineer.machines) + .filter(Machine.name.ilike("%ibm%")) + .all(), + [e1, e3], + ) def test_correlation_w_polymorphic(self): sess = create_session() - p_poly = with_polymorphic(Person, '*') + p_poly = with_polymorphic(Person, "*") - eq_(sess.query(p_poly.name) - .filter( - sess.query(Company.name). - filter(Company.company_id == p_poly.company_id). - correlate(p_poly).as_scalar() == "Elbonia, Inc.").all(), - [(e3.name, )]) + eq_( + sess.query(p_poly.name) + .filter( + sess.query(Company.name) + .filter(Company.company_id == p_poly.company_id) + .correlate(p_poly) + .as_scalar() + == "Elbonia, Inc." + ) + .all(), + [(e3.name,)], + ) def test_correlation_w_polymorphic_flat(self): sess = create_session() - p_poly = with_polymorphic(Person, '*', flat=True) + p_poly = with_polymorphic(Person, "*", flat=True) - eq_(sess.query(p_poly.name) - .filter( - sess.query(Company.name). - filter(Company.company_id == p_poly.company_id). - correlate(p_poly).as_scalar() == "Elbonia, Inc.").all(), - [(e3.name, )]) + eq_( + sess.query(p_poly.name) + .filter( + sess.query(Company.name) + .filter(Company.company_id == p_poly.company_id) + .correlate(p_poly) + .as_scalar() + == "Elbonia, Inc." + ) + .all(), + [(e3.name,)], + ) def test_join_to_subclass_ten(self): pass @@ -1381,8 +1771,9 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): class PolymorphicPolymorphicTest( - _PolymorphicTestBase, _PolymorphicPolymorphic): - __dialect__ = 'default' + _PolymorphicTestBase, _PolymorphicPolymorphic +): + __dialect__ = "default" def test_aliased_not_polluted_by_join(self): # aliased(polymorphic) will normally do the old-school @@ -1391,8 +1782,10 @@ class PolymorphicPolymorphicTest( sess = create_session() palias = aliased(Person) self.assert_compile( - sess.query(palias, Company.name).order_by(palias.person_id). 
- join(Person, Company.employees).filter(palias.name == 'dilbert'), + sess.query(palias, Company.name) + .order_by(palias.person_id) + .join(Person, Company.employees) + .filter(palias.name == "dilbert"), "SELECT anon_1.people_person_id AS anon_1_people_person_id, " "anon_1.people_company_id AS anon_1_people_company_id, " "anon_1.people_name AS anon_1_people_name, " @@ -1433,15 +1826,18 @@ class PolymorphicPolymorphicTest( "LEFT OUTER JOIN boss ON managers.person_id = boss.boss_id) " "ON companies.company_id = people.company_id " "WHERE anon_1.people_name = :people_name_1 " - "ORDER BY anon_1.people_person_id") + "ORDER BY anon_1.people_person_id", + ) def test_flat_aliased_w_select_from(self): sess = create_session() palias = aliased(Person, flat=True) self.assert_compile( - sess.query(palias, Company.name). - select_from(palias).order_by(palias.person_id).join( - Person, Company.employees).filter(palias.name == 'dilbert'), + sess.query(palias, Company.name) + .select_from(palias) + .order_by(palias.person_id) + .join(Person, Company.employees) + .filter(palias.name == "dilbert"), "SELECT people_1.person_id AS people_1_person_id, " "people_1.company_id AS people_1_company_id, " "people_1.name AS people_1_name, people_1.type AS people_1_type, " @@ -1468,7 +1864,8 @@ class PolymorphicPolymorphicTest( "ON people.person_id = managers.person_id " "LEFT OUTER JOIN boss ON managers.person_id = boss.boss_id) " "ON companies.company_id = people.company_id " - "WHERE people_1.name = :name_1 ORDER BY people_1.person_id") + "WHERE people_1.name = :name_1 ORDER BY people_1.person_id", + ) class PolymorphicUnionsTest(_PolymorphicTestBase, _PolymorphicUnions): @@ -1476,7 +1873,8 @@ class PolymorphicUnionsTest(_PolymorphicTestBase, _PolymorphicUnions): class PolymorphicAliasedJoinsTest( - _PolymorphicTestBase, _PolymorphicAliasedJoins): + _PolymorphicTestBase, _PolymorphicAliasedJoins +): pass diff --git a/test/orm/inheritance/test_productspec.py b/test/orm/inheritance/test_productspec.py index 54cc51e82a..24ed47f702 100644 --- a/test/orm/inheritance/test_productspec.py +++ b/test/orm/inheritance/test_productspec.py @@ -10,49 +10,80 @@ from sqlalchemy.testing.schema import Table, Column class InheritTest(fixtures.MappedTest): """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" + @classmethod def define_tables(cls, metadata): global products_table, specification_table, documents_table global Product, Detail, Assembly, SpecLine, Document, RasterDocument products_table = Table( - 'products', metadata, - Column('product_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('product_type', String(128)), - Column('name', String(128)), - Column('mark', String(128)),) + "products", + metadata, + Column( + "product_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("product_type", String(128)), + Column("name", String(128)), + Column("mark", String(128)), + ) specification_table = Table( - 'specification', metadata, - Column('spec_line_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('master_id', Integer, ForeignKey("products.product_id"), - nullable=True), - Column('slave_id', Integer, ForeignKey("products.product_id"), - nullable=True), - Column('quantity', Float, default=1.)) + "specification", + metadata, + Column( + "spec_line_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column( + "master_id", + Integer, + 
ForeignKey("products.product_id"), + nullable=True, + ), + Column( + "slave_id", + Integer, + ForeignKey("products.product_id"), + nullable=True, + ), + Column("quantity", Float, default=1.0), + ) documents_table = Table( - 'documents', metadata, - Column('document_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('document_type', String(128)), - Column('product_id', Integer, ForeignKey('products.product_id')), - Column('create_date', DateTime, default=lambda: datetime.now()), - Column('last_updated', DateTime, default=lambda: datetime.now(), - onupdate=lambda: datetime.now()), - Column('name', String(128)), - Column('data', LargeBinary), - Column('size', Integer, default=0)) + "documents", + metadata, + Column( + "document_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("document_type", String(128)), + Column("product_id", Integer, ForeignKey("products.product_id")), + Column("create_date", DateTime, default=lambda: datetime.now()), + Column( + "last_updated", + DateTime, + default=lambda: datetime.now(), + onupdate=lambda: datetime.now(), + ), + Column("name", String(128)), + Column("data", LargeBinary), + Column("size", Integer, default=0), + ) class Product(object): - def __init__(self, name, mark=''): + def __init__(self, name, mark=""): self.name = name self.mark = mark def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.name) + return "<%s %s>" % (self.__class__.__name__, self.name) class Detail(Product): def __init__(self, name): @@ -60,9 +91,16 @@ class InheritTest(fixtures.MappedTest): class Assembly(Product): def __repr__(self): - return Product.__repr__(self) + " " + " ".join( - [x + "=" + repr(getattr(self, x, None)) - for x in ['specification', 'documents']]) + return ( + Product.__repr__(self) + + " " + + " ".join( + [ + x + "=" + repr(getattr(self, x, None)) + for x in ["specification", "documents"] + ] + ) + ) class SpecLine(object): def __init__(self, master=None, slave=None, quantity=1): @@ -71,10 +109,11 @@ class InheritTest(fixtures.MappedTest): self.quantity = quantity def __repr__(self): - return '<%s %.01f %s>' % ( + return "<%s %.01f %s>" % ( self.__class__.__name__, - self.quantity or 0., - repr(self.slave)) + self.quantity or 0.0, + repr(self.slave), + ) class Document(object): def __init__(self, name, data=None): @@ -82,45 +121,60 @@ class InheritTest(fixtures.MappedTest): self.data = data def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.name) + return "<%s %s>" % (self.__class__.__name__, self.name) class RasterDocument(Document): pass def test_one(self): - product_mapper = mapper(Product, products_table, - polymorphic_on=products_table.c.product_type, - polymorphic_identity='product') + product_mapper = mapper( + Product, + products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity="product", + ) - detail_mapper = mapper(Detail, inherits=product_mapper, - polymorphic_identity='detail') + detail_mapper = mapper( + Detail, inherits=product_mapper, polymorphic_identity="detail" + ) - assembly_mapper = mapper(Assembly, inherits=product_mapper, - polymorphic_identity='assembly') + assembly_mapper = mapper( + Assembly, inherits=product_mapper, polymorphic_identity="assembly" + ) specification_mapper = mapper( - SpecLine, specification_table, + SpecLine, + specification_table, properties=dict( master=relationship( - Assembly, foreign_keys=[specification_table.c.master_id], - primaryjoin=specification_table.c.master_id == - 
                    products_table.c.product_id, lazy='select',
-                    backref=backref('specification'),
-                    uselist=False),
+                    Assembly,
+                    foreign_keys=[specification_table.c.master_id],
+                    primaryjoin=specification_table.c.master_id
+                    == products_table.c.product_id,
+                    lazy="select",
+                    backref=backref("specification"),
+                    uselist=False,
+                ),
                 slave=relationship(
-                    Product, foreign_keys=[specification_table.c.slave_id],
-                    primaryjoin=specification_table.c.slave_id ==
-                    products_table.c.product_id, lazy='select', uselist=False),
-                quantity=specification_table.c.quantity))
+                    Product,
+                    foreign_keys=[specification_table.c.slave_id],
+                    primaryjoin=specification_table.c.slave_id
+                    == products_table.c.product_id,
+                    lazy="select",
+                    uselist=False,
+                ),
+                quantity=specification_table.c.quantity,
+            ),
+        )
 
         session = create_session()
-        a1 = Assembly(name='a1')
+        a1 = Assembly(name="a1")
 
-        p1 = Product(name='p1')
+        p1 = Product(name="p1")
         a1.specification.append(SpecLine(slave=p1))
 
-        d1 = Detail(name='d1')
+        d1 = Detail(name="d1")
         a1.specification.append(SpecLine(slave=d1))
 
         session.add(a1)
@@ -128,34 +182,46 @@ class InheritTest(fixtures.MappedTest):
         session.flush()
         session.expunge_all()
 
-        a1 = session.query(Product).filter_by(name='a1').one()
+        a1 = session.query(Product).filter_by(name="a1").one()
         new = repr(a1)
         print(orig)
         print(new)
-        assert orig == new == '<Assembly a1> specification=[<SpecLine 1.0 <Product p1>>, <SpecLine 1.0 <Detail d1>>] documents=None'
+        assert (
+            orig == new == "<Assembly a1> specification=[<SpecLine 1.0 <Product p1>>, <SpecLine 1.0 <Detail d1>>] documents=None"
+        )
 
     def test_two(self):
-        product_mapper = mapper(Product, products_table,
-                                polymorphic_on=products_table.c.product_type,
-                                polymorphic_identity='product')
+        product_mapper = mapper(
+            Product,
+            products_table,
+            polymorphic_on=products_table.c.product_type,
+            polymorphic_identity="product",
+        )
 
-        detail_mapper = mapper(Detail, inherits=product_mapper,
-                               polymorphic_identity='detail')
+        detail_mapper = mapper(
+            Detail, inherits=product_mapper, polymorphic_identity="detail"
+        )
 
         specification_mapper = mapper(
-            SpecLine, specification_table,
+            SpecLine,
+            specification_table,
             properties=dict(
                 slave=relationship(
-                    Product, foreign_keys=[specification_table.c.slave_id],
-                    primaryjoin=specification_table.c.slave_id ==
-                    products_table.c.product_id, lazy='select',
-                    uselist=False)))
+                    Product,
+                    foreign_keys=[specification_table.c.slave_id],
+                    primaryjoin=specification_table.c.slave_id
+                    == products_table.c.product_id,
+                    lazy="select",
+                    uselist=False,
+                )
+            ),
+        )
 
         session = create_session()
-        s = SpecLine(slave=Product(name='p1'))
-        s2 = SpecLine(slave=Detail(name='d1'))
+        s = SpecLine(slave=Product(name="p1"))
+        s2 = SpecLine(slave=Detail(name="d1"))
         session.add(s)
         session.add(s2)
         orig = repr([s, s2])
@@ -164,68 +230,93 @@ class InheritTest(fixtures.MappedTest):
         new = repr(session.query(SpecLine).all())
         print(orig)
         print(new)
-        assert orig == new == '[<SpecLine 1.0 <Product p1>>, ' \
-            '<SpecLine 1.0 <Detail d1>>]'
+        assert (
+            orig == new == "[<SpecLine 1.0 <Product p1>>, "
+            "<SpecLine 1.0 <Detail d1>>]"
+        )
 
     def test_three(self):
-        product_mapper = mapper(Product, products_table,
-                                polymorphic_on=products_table.c.product_type,
-                                polymorphic_identity='product')
-        detail_mapper = mapper(Detail, inherits=product_mapper,
-                               polymorphic_identity='detail')
-        assembly_mapper = mapper(Assembly, inherits=product_mapper,
-                                 polymorphic_identity='assembly')
+        product_mapper = mapper(
+            Product,
+            products_table,
+            polymorphic_on=products_table.c.product_type,
+            polymorphic_identity="product",
+        )
+        detail_mapper = mapper(
+            Detail, inherits=product_mapper, polymorphic_identity="detail"
+        )
+        assembly_mapper = mapper(
+            Assembly, inherits=product_mapper, polymorphic_identity="assembly"
+        )
 
         specification_mapper = mapper(
-            SpecLine, specification_table,
+            SpecLine,
+            specification_table,
             properties=dict(
                 master=relationship(
-                    Assembly, lazy='joined', uselist=False,
+                    Assembly,
+                    lazy="joined",
+                    uselist=False,
                     foreign_keys=[specification_table.c.master_id],
-                    primaryjoin=specification_table.c.master_id ==
-                    products_table.c.product_id,
+                    primaryjoin=specification_table.c.master_id
+                    == products_table.c.product_id,
                     backref=backref(
-                        'specification', cascade="all, delete-orphan")),
+                        "specification", cascade="all, delete-orphan"
+                    ),
+                ),
                 slave=relationship(
-                    Product, lazy='joined', uselist=False,
+                    Product,
+                    lazy="joined",
+                    uselist=False,
                     foreign_keys=[specification_table.c.slave_id],
-                    primaryjoin=specification_table.c.slave_id ==
-                    products_table.c.product_id,),
-                quantity=specification_table.c.quantity))
+                    primaryjoin=specification_table.c.slave_id
+                    == products_table.c.product_id,
+                ),
+                quantity=specification_table.c.quantity,
+            ),
+        )
 
         document_mapper = mapper(
-            Document, documents_table,
+            Document,
+            documents_table,
             polymorphic_on=documents_table.c.document_type,
-            polymorphic_identity='document',
+            polymorphic_identity="document",
             properties=dict(
                 name=documents_table.c.name,
                 data=deferred(documents_table.c.data),
                 product=relationship(
-                    Product, lazy='select',
-                    backref=backref(
-                        'documents', cascade="all, delete-orphan"))))
+                    Product,
+                    lazy="select",
+                    backref=backref("documents", cascade="all, delete-orphan"),
+                ),
+            ),
+        )
 
         raster_document_mapper = mapper(
-            RasterDocument, inherits=document_mapper,
-            polymorphic_identity='raster_document')
+            RasterDocument,
+            inherits=document_mapper,
+            polymorphic_identity="raster_document",
+        )
 
         session = create_session()
 
-        a1 = Assembly(name='a1')
-        a1.specification.append(SpecLine(slave=Detail(name='d1')))
-        a1.documents.append(Document('doc1'))
-        a1.documents.append(RasterDocument('doc2'))
+        a1 = Assembly(name="a1")
+        a1.specification.append(SpecLine(slave=Detail(name="d1")))
+        a1.documents.append(Document("doc1"))
+        a1.documents.append(RasterDocument("doc2"))
         session.add(a1)
         orig = repr(a1)
         session.flush()
         session.expunge_all()
 
-        a1 = session.query(Product).filter_by(name='a1').one()
+        a1 = session.query(Product).filter_by(name="a1").one()
         new = repr(a1)
         print(orig)
         print(new)
-        assert orig == new == '<Assembly a1> specification=' \
-            '[<SpecLine 1.0 <Detail d1>>] ' \
-            'documents=[<Document doc1>, <RasterDocument doc2>]'
+        assert (
+            orig == new == "<Assembly a1> specification="
+            "[<SpecLine 1.0 <Detail d1>>] "
+            "documents=[<Document doc1>, <RasterDocument doc2>]"
+        )
 
     def test_four(self):
         """this tests the RasterDocument being attached to the Assembly, but
@@ -233,111 +324,150 @@ class InheritTest(fixtures.MappedTest):
         corresponding to an inheriting mapper but not the base mapper,
         is created.
         """
-        product_mapper = mapper(Product, products_table,
-                                polymorphic_on=products_table.c.product_type,
-                                polymorphic_identity='product')
-        detail_mapper = mapper(Detail, inherits=product_mapper,
-                               polymorphic_identity='detail')
-        assembly_mapper = mapper(Assembly, inherits=product_mapper,
-                                 polymorphic_identity='assembly')
+        product_mapper = mapper(
+            Product,
+            products_table,
+            polymorphic_on=products_table.c.product_type,
+            polymorphic_identity="product",
+        )
+        detail_mapper = mapper(
+            Detail, inherits=product_mapper, polymorphic_identity="detail"
+        )
+        assembly_mapper = mapper(
+            Assembly, inherits=product_mapper, polymorphic_identity="assembly"
+        )
 
         document_mapper = mapper(
-            Document, documents_table,
+            Document,
+            documents_table,
             polymorphic_on=documents_table.c.document_type,
-            polymorphic_identity='document',
+            polymorphic_identity="document",
             properties=dict(
                 name=documents_table.c.name,
                 data=deferred(documents_table.c.data),
                 product=relationship(
-                    Product, lazy='select',
-                    backref=backref(
-                        'documents', cascade="all, delete-orphan"))))
+                    Product,
+                    lazy="select",
+                    backref=backref("documents", cascade="all, delete-orphan"),
+                ),
+            ),
+        )
 
         raster_document_mapper = mapper(
-            RasterDocument, inherits=document_mapper,
-            polymorphic_identity='raster_document')
+            RasterDocument,
+            inherits=document_mapper,
+            polymorphic_identity="raster_document",
+        )
 
         session = create_session()
 
-        a1 = Assembly(name='a1')
-        a1.documents.append(RasterDocument('doc2'))
+        a1 = Assembly(name="a1")
+        a1.documents.append(RasterDocument("doc2"))
         session.add(a1)
         orig = repr(a1)
         session.flush()
         session.expunge_all()
 
-        a1 = session.query(Product).filter_by(name='a1').one()
+        a1 = session.query(Product).filter_by(name="a1").one()
         new = repr(a1)
         print(orig)
         print(new)
-        assert orig == new == '<Assembly a1> specification=None documents=' \
-            '[<RasterDocument doc2>]'
+        assert (
+            orig == new == "<Assembly a1> specification=None documents="
+            "[<RasterDocument doc2>]"
+        )
 
         del a1.documents[0]
         session.flush()
         session.expunge_all()
 
-        a1 = session.query(Product).filter_by(name='a1').one()
+        a1 = session.query(Product).filter_by(name="a1").one()
         assert len(session.query(Document).all()) == 0
 
     def test_five(self):
         """tests the late compilation of mappers"""
 
         specification_mapper = mapper(
-            SpecLine, specification_table,
+            SpecLine,
+            specification_table,
             properties=dict(
                 master=relationship(
-                    Assembly, lazy='joined', uselist=False,
+                    Assembly,
+                    lazy="joined",
+                    uselist=False,
                     foreign_keys=[specification_table.c.master_id],
-                    primaryjoin=specification_table.c.master_id ==
-                    products_table.c.product_id,
-                    backref=backref('specification')),
+                    primaryjoin=specification_table.c.master_id
+                    == products_table.c.product_id,
+                    backref=backref("specification"),
+                ),
                 slave=relationship(
-                    Product, lazy='joined', uselist=False,
+                    Product,
+                    lazy="joined",
+                    uselist=False,
                     foreign_keys=[specification_table.c.slave_id],
-                    primaryjoin=specification_table.c.slave_id ==
-                    products_table.c.product_id,),
-                quantity=specification_table.c.quantity))
+                    primaryjoin=specification_table.c.slave_id
+                    == products_table.c.product_id,
+                ),
+                quantity=specification_table.c.quantity,
+            ),
+        )
 
         product_mapper = mapper(
-            Product, products_table,
+            Product,
+            products_table,
             polymorphic_on=products_table.c.product_type,
-            polymorphic_identity='product', properties={
-                'documents': relationship(Document, lazy='select',
-                                          backref='product',
-                                          cascade='all, delete-orphan')})
-
-        detail_mapper = mapper(Detail, inherits=Product,
-                               polymorphic_identity='detail')
+            polymorphic_identity="product",
+            properties={
+                "documents": relationship(
+                    Document,
+                    lazy="select",
+                    backref="product",
+                    cascade="all, delete-orphan",
+                )
+            },
+        )
+
+        detail_mapper = mapper(
+            Detail, inherits=Product, polymorphic_identity="detail"
+        )
 
         document_mapper = mapper(
-            Document, documents_table,
+            Document,
+            documents_table,
             polymorphic_on=documents_table.c.document_type,
-            polymorphic_identity='document',
+            polymorphic_identity="document",
             properties=dict(
                 name=documents_table.c.name,
-                data=deferred(documents_table.c.data)))
+                data=deferred(documents_table.c.data),
+            ),
+        )
 
-        raster_document_mapper = mapper(RasterDocument, inherits=Document,
-                                        polymorphic_identity='raster_document')
+        raster_document_mapper = mapper(
+            RasterDocument,
+            inherits=Document,
+            polymorphic_identity="raster_document",
+        )
 
-        assembly_mapper = mapper(Assembly, inherits=Product,
-                                 polymorphic_identity='assembly')
+        assembly_mapper = mapper(
+            Assembly, inherits=Product, polymorphic_identity="assembly"
+        )
 
         session = create_session()
 
-        a1 = Assembly(name='a1')
-        a1.specification.append(SpecLine(slave=Detail(name='d1')))
-        a1.documents.append(Document('doc1'))
-        a1.documents.append(RasterDocument('doc2'))
+        a1 = Assembly(name="a1")
+        a1.specification.append(SpecLine(slave=Detail(name="d1")))
+        a1.documents.append(Document("doc1"))
+        a1.documents.append(RasterDocument("doc2"))
         session.add(a1)
         orig = repr(a1)
         session.flush()
         session.expunge_all()
 
-        a1 = session.query(Product).filter_by(name='a1').one()
+        a1 = session.query(Product).filter_by(name="a1").one()
         new = repr(a1)
         print(orig)
         print(new)
-        assert orig == new == '<Assembly a1> specification=' \
-            '[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, ' \
-            '<RasterDocument doc2>]'
+        assert (
+            orig == new == "<Assembly a1> specification="
+            "[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, "
+            "<RasterDocument doc2>]"
+        )
diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py
index 246cf214b8..a16d9dc346 100644
--- a/test/orm/inheritance/test_relationship.py
+++ b/test/orm/inheritance/test_relationship.py
@@ -1,6 +1,17 @@
-from sqlalchemy.orm import create_session, relationship, mapper, \
-    contains_eager, joinedload, subqueryload, subqueryload_all,\
-    Session, aliased, with_polymorphic, joinedload_all, backref
+from sqlalchemy.orm import (
+    create_session,
+    relationship,
+    mapper,
+    contains_eager,
+    joinedload,
+    subqueryload,
+    subqueryload_all,
+    Session,
+    aliased,
+    with_polymorphic,
+    joinedload_all,
+    backref,
+)
 from sqlalchemy import Integer, String, ForeignKey, select, func
 from sqlalchemy.engine import default
@@ -41,108 +52,154 @@ class Paperwork(fixtures.ComparableEntity):
 
 class SelfReferentialTestJoinedToBase(fixtures.MappedTest):
-    run_setup_mappers = 'once'
+    run_setup_mappers = "once"
 
     @classmethod
     def define_tables(cls, metadata):
-        Table('people', metadata,
-              Column('person_id', Integer,
-                     primary_key=True,
-                     test_needs_autoincrement=True),
-              Column('name', String(50)),
-              Column('type', String(30)))
-
-        Table('engineers', metadata,
-              Column('person_id', Integer,
-                     ForeignKey('people.person_id'),
-                     primary_key=True),
-              Column('primary_language', String(50)),
-              Column('reports_to_id', Integer,
-                     ForeignKey('people.person_id')))
+        Table(
+            "people",
+            metadata,
+            Column(
+                "person_id",
+                Integer,
+                primary_key=True,
+                test_needs_autoincrement=True,
+            ),
+            Column("name", String(50)),
+            Column("type", String(30)),
+        )
+
+        Table(
+            "engineers",
+            metadata,
+            Column(
+                "person_id",
+                Integer,
+                ForeignKey("people.person_id"),
+                primary_key=True,
+            ),
+            Column("primary_language", String(50)),
+            Column("reports_to_id", Integer, ForeignKey("people.person_id")),
+        )
@classmethod def setup_mappers(cls): engineers, people = cls.tables.engineers, cls.tables.people - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) - mapper(Engineer, engineers, - inherits=Person, - inherit_condition=engineers.c.person_id == people.c.person_id, - polymorphic_identity='engineer', - properties={ - 'reports_to': relationship( - Person, - primaryjoin=( - people.c.person_id == engineers.c.reports_to_id))}) + mapper( + Engineer, + engineers, + inherits=Person, + inherit_condition=engineers.c.person_id == people.c.person_id, + polymorphic_identity="engineer", + properties={ + "reports_to": relationship( + Person, + primaryjoin=( + people.c.person_id == engineers.c.reports_to_id + ), + ) + }, + ) def test_has(self): - p1 = Person(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) + p1 = Person(name="dogbert") + e1 = Engineer(name="dilbert", primary_language="java", reports_to=p1) sess = create_session() sess.add(p1) sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Engineer) - .filter(Engineer.reports_to.has(Person.name == 'dogbert')) - .first(), - Engineer(name='dilbert')) + eq_( + sess.query(Engineer) + .filter(Engineer.reports_to.has(Person.name == "dogbert")) + .first(), + Engineer(name="dilbert"), + ) def test_oftype_aliases_in_exists(self): - e1 = Engineer(name='dilbert', primary_language='java') - e2 = Engineer(name='wally', primary_language='c++', reports_to=e1) + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++", reports_to=e1) sess = create_session() sess.add_all([e1, e2]) sess.flush() - eq_(sess.query(Engineer) - .filter(Engineer.reports_to - .of_type(Engineer) - .has(Engineer.name == 'dilbert')) - .first(), - e2) + eq_( + sess.query(Engineer) + .filter( + Engineer.reports_to.of_type(Engineer).has( + Engineer.name == "dilbert" + ) + ) + .first(), + e2, + ) def test_join(self): - p1 = Person(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) + p1 = Person(name="dogbert") + e1 = Engineer(name="dilbert", primary_language="java", reports_to=p1) sess = create_session() sess.add(p1) sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Engineer) - .join('reports_to', aliased=True) - .filter(Person.name == 'dogbert').first(), - Engineer(name='dilbert')) + eq_( + sess.query(Engineer) + .join("reports_to", aliased=True) + .filter(Person.name == "dogbert") + .first(), + Engineer(name="dilbert"), + ) class SelfReferentialJ2JTest(fixtures.MappedTest): - run_setup_mappers = 'once' + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - people = Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('primary_language', String(50)), - Column('reports_to_id', Integer, - ForeignKey('managers.person_id'))) - - managers = Table('managers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True),) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) 
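test_join in SelfReferentialTestJoinedToBase above spells the self-referential join with the legacy join('reports_to', aliased=True) flag. The same query written with an explicit alias, the style test_join_explicit_alias uses later in this file, would look roughly like the following (a sketch only; the session and fixture objects are assumed):

    from sqlalchemy.orm import aliased

    # Alias the supertype so the self-referential join targets a distinct
    # copy of the people table rather than the one Engineer inherits from.
    reports_to_alias = aliased(Person)
    result = (
        session.query(Engineer)
        .join(reports_to_alias, Engineer.reports_to)
        .filter(reports_to_alias.name == "dogbert")
        .first()
    )

The parenthesized chain with leading dots is also exactly the shape black gives these multi-line queries at 79 columns.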
+ + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("primary_language", String(50)), + Column("reports_to_id", Integer, ForeignKey("managers.person_id")), + ) + + managers = Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + ) @classmethod def setup_mappers(cls): @@ -150,58 +207,72 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): managers = cls.tables.managers people = cls.tables.people - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) - mapper(Manager, managers, - inherits=Person, - polymorphic_identity='manager') + mapper( + Manager, managers, inherits=Person, polymorphic_identity="manager" + ) - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer', - properties={ - 'reports_to': relationship( - Manager, - primaryjoin=( - managers.c.person_id == engineers.c.reports_to_id), - backref='engineers')}) + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "reports_to": relationship( + Manager, + primaryjoin=( + managers.c.person_id == engineers.c.reports_to_id + ), + backref="engineers", + ) + }, + ) def test_has(self): - m1 = Manager(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) + m1 = Manager(name="dogbert") + e1 = Engineer(name="dilbert", primary_language="java", reports_to=m1) sess = create_session() sess.add(m1) sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Engineer) - .filter(Engineer.reports_to.has(Manager.name == 'dogbert')) - .first(), - Engineer(name='dilbert')) + eq_( + sess.query(Engineer) + .filter(Engineer.reports_to.has(Manager.name == "dogbert")) + .first(), + Engineer(name="dilbert"), + ) def test_join(self): - m1 = Manager(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) + m1 = Manager(name="dogbert") + e1 = Engineer(name="dilbert", primary_language="java", reports_to=m1) sess = create_session() sess.add(m1) sess.add(e1) sess.flush() sess.expunge_all() - eq_(sess.query(Engineer) - .join('reports_to', aliased=True) - .filter(Manager.name == 'dogbert').first(), - Engineer(name='dilbert')) + eq_( + sess.query(Engineer) + .join("reports_to", aliased=True) + .filter(Manager.name == "dogbert") + .first(), + Engineer(name="dilbert"), + ) def test_filter_aliasing(self): - m1 = Manager(name='dogbert') - m2 = Manager(name='foo') - e1 = Engineer(name='wally', primary_language='java', reports_to=m1) - e2 = Engineer(name='dilbert', primary_language='c++', reports_to=m2) - e3 = Engineer(name='etc', primary_language='c++') + m1 = Manager(name="dogbert") + m2 = Manager(name="foo") + e1 = Engineer(name="wally", primary_language="java", reports_to=m1) + e2 = Engineer(name="dilbert", primary_language="c++", reports_to=m2) + e3 = Engineer(name="etc", primary_language="c++") sess = create_session() sess.add_all([m1, m2, e1, e2, e3]) @@ -209,27 +280,36 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): sess.expunge_all() # filter aliasing applied to Engineer doesn't whack Manager - eq_(sess.query(Manager) - .join(Manager.engineers) - .filter(Manager.name == 'dogbert').all(), - [m1]) + eq_( + sess.query(Manager) + .join(Manager.engineers) + .filter(Manager.name == "dogbert") + .all(), + [m1], + ) - 
eq_(sess.query(Manager) - .join(Manager.engineers) - .filter(Engineer.name == 'dilbert').all(), - [m2]) + eq_( + sess.query(Manager) + .join(Manager.engineers) + .filter(Engineer.name == "dilbert") + .all(), + [m2], + ) - eq_(sess.query(Manager, Engineer) - .join(Manager.engineers) - .order_by(Manager.name.desc()).all(), - [(m2, e2), (m1, e1)]) + eq_( + sess.query(Manager, Engineer) + .join(Manager.engineers) + .order_by(Manager.name.desc()) + .all(), + [(m2, e2), (m1, e1)], + ) def test_relationship_compare(self): - m1 = Manager(name='dogbert') - m2 = Manager(name='foo') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) - e2 = Engineer(name='wally', primary_language='c++', reports_to=m2) - e3 = Engineer(name='etc', primary_language='c++') + m1 = Manager(name="dogbert") + m2 = Manager(name="foo") + e1 = Engineer(name="dilbert", primary_language="java", reports_to=m1) + e2 = Engineer(name="wally", primary_language="c++", reports_to=m2) + e3 = Engineer(name="etc", primary_language="c++") sess = create_session() sess.add(m1) @@ -240,60 +320,88 @@ class SelfReferentialJ2JTest(fixtures.MappedTest): sess.flush() sess.expunge_all() - eq_(sess.query(Manager) - .join(Manager.engineers) - .filter(Engineer.reports_to == None).all(), # noqa - []) + eq_( + sess.query(Manager) + .join(Manager.engineers) + .filter(Engineer.reports_to == None) + .all(), # noqa + [], + ) - eq_(sess.query(Manager) - .join(Manager.engineers) - .filter(Engineer.reports_to == m1).all(), - [m1]) + eq_( + sess.query(Manager) + .join(Manager.engineers) + .filter(Engineer.reports_to == m1) + .all(), + [m1], + ) class SelfReferentialJ2JSelfTest(fixtures.MappedTest): - run_setup_mappers = 'once' + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - people = Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('reports_to_id', Integer, - ForeignKey('engineers.person_id'))) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column( + "reports_to_id", Integer, ForeignKey("engineers.person_id") + ), + ) @classmethod def setup_mappers(cls): engineers = cls.tables.engineers people = cls.tables.people - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') - - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer', - properties={ - 'reports_to': relationship( - Engineer, - primaryjoin=( - engineers.c.person_id == engineers.c.reports_to_id), - backref='engineers', - remote_side=engineers.c.person_id)}) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "reports_to": relationship( + Engineer, + primaryjoin=( + engineers.c.person_id == engineers.c.reports_to_id + ), + backref="engineers", + remote_side=engineers.c.person_id, + ) + }, + ) def _two_obj_fixture(self): - e1 = Engineer(name='wally') - e2 = Engineer(name='dilbert', reports_to=e1) + 
e1 = Engineer(name="wally") + e2 = Engineer(name="dilbert", reports_to=e1) sess = Session() sess.add_all([e1, e2]) sess.commit() @@ -301,9 +409,7 @@ class SelfReferentialJ2JSelfTest(fixtures.MappedTest): def _five_obj_fixture(self): sess = Session() - e1, e2, e3, e4, e5 = [ - Engineer(name='e%d' % (i + 1)) for i in range(5) - ] + e1, e2, e3, e4, e5 = [Engineer(name="e%d" % (i + 1)) for i in range(5)] e3.reports_to = e1 e4.reports_to = e2 sess.add_all([e1, e2, e3, e4, e5]) @@ -312,86 +418,122 @@ class SelfReferentialJ2JSelfTest(fixtures.MappedTest): def test_has(self): sess = self._two_obj_fixture() - eq_(sess.query(Engineer) - .filter(Engineer.reports_to.has(Engineer.name == 'wally')) - .first(), - Engineer(name='dilbert')) + eq_( + sess.query(Engineer) + .filter(Engineer.reports_to.has(Engineer.name == "wally")) + .first(), + Engineer(name="dilbert"), + ) def test_join_explicit_alias(self): sess = self._five_obj_fixture() ea = aliased(Engineer) - eq_(sess.query(Engineer) - .join(ea, Engineer.engineers) - .filter(Engineer.name == 'e1').all(), - [Engineer(name='e1')]) + eq_( + sess.query(Engineer) + .join(ea, Engineer.engineers) + .filter(Engineer.name == "e1") + .all(), + [Engineer(name="e1")], + ) def test_join_aliased_flag_one(self): sess = self._two_obj_fixture() - eq_(sess.query(Engineer) - .join('reports_to', aliased=True) - .filter(Engineer.name == 'wally').first(), - Engineer(name='dilbert')) + eq_( + sess.query(Engineer) + .join("reports_to", aliased=True) + .filter(Engineer.name == "wally") + .first(), + Engineer(name="dilbert"), + ) def test_join_aliased_flag_two(self): sess = self._five_obj_fixture() - eq_(sess.query(Engineer) - .join(Engineer.engineers, aliased=True) - .filter(Engineer.name == 'e4').all(), - [Engineer(name='e2')]) + eq_( + sess.query(Engineer) + .join(Engineer.engineers, aliased=True) + .filter(Engineer.name == "e4") + .all(), + [Engineer(name="e2")], + ) def test_relationship_compare(self): sess = self._five_obj_fixture() - e1 = sess.query(Engineer).filter_by(name='e1').one() - e2 = sess.query(Engineer).filter_by(name='e2').one() + e1 = sess.query(Engineer).filter_by(name="e1").one() + e2 = sess.query(Engineer).filter_by(name="e2").one() - eq_(sess.query(Engineer) - .join(Engineer.engineers, aliased=True) - .filter(Engineer.reports_to == None).all(), # noqa - []) + eq_( + sess.query(Engineer) + .join(Engineer.engineers, aliased=True) + .filter(Engineer.reports_to == None) + .all(), # noqa + [], + ) - eq_(sess.query(Engineer) - .join(Engineer.engineers, aliased=True) - .filter(Engineer.reports_to == e1).all(), - [e1]) + eq_( + sess.query(Engineer) + .join(Engineer.engineers, aliased=True) + .filter(Engineer.reports_to == e1) + .all(), + [e1], + ) - eq_(sess.query(Engineer) - .join(Engineer.engineers, aliased=True) - .filter(Engineer.reports_to != None).all(), # noqa - [e1, e2]) + eq_( + sess.query(Engineer) + .join(Engineer.engineers, aliased=True) + .filter(Engineer.reports_to != None) + .all(), # noqa + [e1, e2], + ) class M2MFilterTest(fixtures.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - organizations = Table('organizations', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - engineers_to_org = Table('engineers_to_org', metadata, - Column('org_id', Integer, - ForeignKey('organizations.id')), - Column('engineer_id', Integer, - 
ForeignKey('engineers.person_id'))) - - people = Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('primary_language', String(50))) + organizations = Table( + "organizations", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + ) + + engineers_to_org = Table( + "engineers_to_org", + metadata, + Column("org_id", Integer, ForeignKey("organizations.id")), + Column("engineer_id", Integer, ForeignKey("engineers.person_id")), + ) + + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("primary_language", String(50)), + ) @classmethod def setup_mappers(cls): @@ -403,30 +545,41 @@ class M2MFilterTest(fixtures.MappedTest): class Organization(cls.Comparable): pass - mapper(Organization, organizations, - properties={ - 'engineers': relationship( - Engineer, - secondary=engineers_to_org, - backref='organizations')}) + mapper( + Organization, + organizations, + properties={ + "engineers": relationship( + Engineer, + secondary=engineers_to_org, + backref="organizations", + ) + }, + ) - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person') + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer') + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + ) @classmethod def insert_data(cls): Organization = cls.classes.Organization - e1 = Engineer(name='e1') - e2 = Engineer(name='e2') - e3 = Engineer(name='e3') - e4 = Engineer(name='e4') - org1 = Organization(name='org1', engineers=[e1, e2]) - org2 = Organization(name='org2', engineers=[e3, e4]) + e1 = Engineer(name="e1") + e2 = Engineer(name="e2") + e3 = Engineer(name="e3") + e4 = Engineer(name="e4") + org1 = Organization(name="org1", engineers=[e1, e2]) + org2 = Organization(name="org2", engineers=[e3, e4]) sess = create_session() sess.add(org1) sess.add(org2) @@ -435,38 +588,44 @@ class M2MFilterTest(fixtures.MappedTest): def test_not_contains(self): Organization = self.classes.Organization sess = create_session() - e1 = sess.query(Person).filter(Engineer.name == 'e1').one() + e1 = sess.query(Person).filter(Engineer.name == "e1").one() - eq_(sess.query(Organization) - .filter(~Organization.engineers - .of_type(Engineer) - .contains(e1)) - .all(), - [Organization(name='org2')]) + eq_( + sess.query(Organization) + .filter(~Organization.engineers.of_type(Engineer).contains(e1)) + .all(), + [Organization(name="org2")], + ) # this had a bug - eq_(sess.query(Organization) - .filter(~Organization.engineers - .contains(e1)) + eq_( + sess.query(Organization) + .filter(~Organization.engineers.contains(e1)) .all(), - [Organization(name='org2')]) + [Organization(name="org2")], + ) def test_any(self): sess = create_session() Organization = self.classes.Organization - eq_(sess.query(Organization) - .filter(Organization.engineers - .of_type(Engineer) - 
.any(Engineer.name == 'e1')) - .all(), - [Organization(name='org1')]) + eq_( + sess.query(Organization) + .filter( + Organization.engineers.of_type(Engineer).any( + Engineer.name == "e1" + ) + ) + .all(), + [Organization(name="org1")], + ) - eq_(sess.query(Organization) - .filter(Organization.engineers - .any(Engineer.name == 'e1')) - .all(), - [Organization(name='org1')]) + eq_( + sess.query(Organization) + .filter(Organization.engineers.any(Engineer.name == "e1")) + .all(), + [Organization(name="org1")], + ) class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): @@ -474,29 +633,37 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - Table('secondary', metadata, - Column('left_id', Integer, - ForeignKey('parent.id'), - nullable=False), - Column('right_id', Integer, - ForeignKey('parent.id'), - nullable=False)) - - Table('parent', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('cls', String(50))) - - Table('child1', metadata, - Column('id', Integer, - ForeignKey('parent.id'), - primary_key=True)) - - Table('child2', metadata, - Column('id', Integer, - ForeignKey('parent.id'), - primary_key=True)) + Table( + "secondary", + metadata, + Column( + "left_id", Integer, ForeignKey("parent.id"), nullable=False + ), + Column( + "right_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) + + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("cls", String(50)), + ) + + Table( + "child1", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + ) + + Table( + "child2", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + ) @classmethod def setup_classes(cls): @@ -519,24 +686,26 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): Child2 = cls.classes.Child2 secondary = cls.tables.secondary - mapper(Parent, parent, - polymorphic_on=parent.c.cls) - - mapper(Child1, child1, - inherits=Parent, - polymorphic_identity='child1', - properties={ - 'left_child2': relationship( - Child2, - secondary=secondary, - primaryjoin=parent.c.id == secondary.c.right_id, - secondaryjoin=parent.c.id == secondary.c.left_id, - uselist=False, - backref="right_children")}) - - mapper(Child2, child2, - inherits=Parent, - polymorphic_identity='child2') + mapper(Parent, parent, polymorphic_on=parent.c.cls) + + mapper( + Child1, + child1, + inherits=Parent, + polymorphic_identity="child1", + properties={ + "left_child2": relationship( + Child2, + secondary=secondary, + primaryjoin=parent.c.id == secondary.c.right_id, + secondaryjoin=parent.c.id == secondary.c.left_id, + uselist=False, + backref="right_children", + ) + }, + ) + + mapper(Child2, child2, inherits=Parent, polymorphic_identity="child2") def test_query_crit(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 @@ -550,22 +719,33 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): sess.flush() # test that the join to Child2 doesn't alias Child1 in the select - eq_(set(sess.query(Child1).join(Child1.left_child2)), - set([c11, c12, c13])) + eq_( + set(sess.query(Child1).join(Child1.left_child2)), + set([c11, c12, c13]), + ) - eq_(set(sess.query(Child1, Child2).join(Child1.left_child2)), - set([(c11, c22), (c12, c22), (c13, c23)])) + eq_( + set(sess.query(Child1, Child2).join(Child1.left_child2)), + set([(c11, c22), (c12, c22), (c13, c23)]), + ) # test 
__eq__() on property is annotating correctly - eq_(set(sess.query(Child2) - .join(Child2.right_children) - .filter(Child1.left_child2 == c22)), - set([c22])) + eq_( + set( + sess.query(Child2) + .join(Child2.right_children) + .filter(Child1.left_child2 == c22) + ), + set([c22]), + ) # test the same again self.assert_compile( - sess.query(Child2).join(Child2.right_children). - filter(Child1.left_child2 == c22).with_labels().statement, + sess.query(Child2) + .join(Child2.right_children) + .filter(Child1.left_child2 == c22) + .with_labels() + .statement, "SELECT child2.id AS child2_id, parent.id AS parent_id, " "parent.cls AS parent_cls FROM secondary AS secondary_1, " "parent JOIN child2 ON parent.id = child2.id JOIN secondary AS " @@ -574,7 +754,8 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): "ON parent_1.id = child1_1.id) " "ON parent_1.id = secondary_2.right_id WHERE " "parent_1.id = secondary_1.right_id AND :param_1 = " - "secondary_1.left_id") + "secondary_1.left_id", + ) def test_eager_join(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 @@ -586,7 +767,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): # test that the splicing of the join works here, doesn't break in # the middle of "parent join child1" - q = sess.query(Child1).options(joinedload('left_child2')) + q = sess.query(Child1).options(joinedload("left_child2")) self.assert_compile( q.limit(1).with_labels().statement, "SELECT child1.id AS child1_id, parent.id AS parent_id, " @@ -599,15 +780,15 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): "ON parent_1.id = secondary_1.left_id) " "ON parent.id = secondary_1.right_id " "LIMIT :param_1", - checkparams={'param_1': 1} + checkparams={"param_1": 1}, ) # another way to check eq_( - select([func.count('*')]).select_from( - q.limit(1).with_labels().subquery() - ).scalar(), - 1 + select([func.count("*")]) + .select_from(q.limit(1).with_labels().subquery()) + .scalar(), + 1, ) assert q.first() is c1 @@ -620,7 +801,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): sess.flush() sess.expunge_all() - query_ = sess.query(Child1).options(subqueryload('left_child2')) + query_ = sess.query(Child1).options(subqueryload("left_child2")) for row in query_.all(): assert row.left_child2 @@ -628,41 +809,50 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): class EagerToSubclassTest(fixtures.MappedTest): """Test eager loads to subclass mappers""" - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('data', String(10))) - - Table('base', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('type', String(10)), - Column('related_id', Integer, - ForeignKey('related.id'))) - - Table('sub', metadata, - Column('id', Integer, - ForeignKey('base.id'), - primary_key=True), - Column('data', String(10)), - Column('parent_id', Integer, - ForeignKey('parent.id'), - nullable=False)) - - Table('related', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('data', String(10))) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", 
String(10)), + ) + + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(10)), + Column("related_id", Integer, ForeignKey("related.id")), + ) + + Table( + "sub", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("data", String(10)), + Column( + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) + + Table( + "related", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(10)), + ) @classmethod def setup_classes(cls): @@ -689,17 +879,21 @@ class EagerToSubclassTest(fixtures.MappedTest): related = cls.tables.related Related = cls.classes.Related - mapper(Parent, parent, - properties={'children': relationship(Sub, order_by=sub.c.data)}) + mapper( + Parent, + parent, + properties={"children": relationship(Sub, order_by=sub.c.data)}, + ) - mapper(Base, base, - polymorphic_on=base.c.type, - polymorphic_identity='b', - properties={'related': relationship(Related)}) + mapper( + Base, + base, + polymorphic_on=base.c.type, + polymorphic_identity="b", + properties={"related": relationship(Related)}, + ) - mapper(Sub, sub, - inherits=Base, - polymorphic_identity='s') + mapper(Sub, sub, inherits=Base, polymorphic_identity="s") mapper(Related, related) @@ -711,14 +905,14 @@ class EagerToSubclassTest(fixtures.MappedTest): Sub = cls.classes.Sub Related = cls.classes.Related sess = Session() - r1, r2 = Related(data='r1'), Related(data='r2') - s1 = Sub(data='s1', related=r1) - s2 = Sub(data='s2', related=r2) - s3 = Sub(data='s3') - s4 = Sub(data='s4', related=r2) - s5 = Sub(data='s5') - p1 = Parent(data='p1', children=[s1, s2, s3]) - p2 = Parent(data='p2', children=[s4, s5]) + r1, r2 = Related(data="r1"), Related(data="r2") + s1 = Sub(data="s1", related=r1) + s2 = Sub(data="s2", related=r2) + s3 = Sub(data="s3") + s4 = Sub(data="s4", related=r2) + s5 = Sub(data="s5") + p1 = Parent(data="p1", children=[s1, s2, s3]) + p2 = Parent(data="p2", children=[s4, s5]) sess.add(p1) sess.add(p2) sess.commit() @@ -728,9 +922,11 @@ class EagerToSubclassTest(fixtures.MappedTest): sess = Session() def go(): - eq_(sess.query(Parent) - .options(joinedload(Parent.children)).all(), - [p1, p2]) + eq_( + sess.query(Parent).options(joinedload(Parent.children)).all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 1) def test_contains_eager(self): @@ -739,11 +935,15 @@ class EagerToSubclassTest(fixtures.MappedTest): sess = Session() def go(): - eq_(sess.query(Parent) - .join(Parent.children) - .options(contains_eager(Parent.children)) - .order_by(Parent.data, Sub.data).all(), - [p1, p2]) + eq_( + sess.query(Parent) + .join(Parent.children) + .options(contains_eager(Parent.children)) + .order_by(Parent.data, Sub.data) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 1) def test_subq_through_related(self): @@ -752,10 +952,14 @@ class EagerToSubclassTest(fixtures.MappedTest): sess = Session() def go(): - eq_(sess.query(Parent) - .options(subqueryload_all(Parent.children, Base.related)) - .order_by(Parent.data).all(), - [p1, p2]) + eq_( + sess.query(Parent) + .options(subqueryload_all(Parent.children, Base.related)) + .order_by(Parent.data) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 3) def test_subq_through_related_aliased(self): @@ -765,49 +969,64 @@ class EagerToSubclassTest(fixtures.MappedTest): sess = Session() def go(): - eq_(sess.query(pa) - .options(subqueryload_all(pa.children, 
Base.related)) - .order_by(pa.data).all(), - [p1, p2]) + eq_( + sess.query(pa) + .options(subqueryload_all(pa.children, Base.related)) + .order_by(pa.data) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 3) class SubClassEagerToSubClassTest(fixtures.MappedTest): """Test joinedloads from subclass to subclass mappers""" - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('type', String(10))) - - Table('subparent', metadata, - Column('id', Integer, - ForeignKey('parent.id'), - primary_key=True), - Column('data', String(10))) - - Table('base', metadata, - Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('type', String(10))) - - Table('sub', metadata, - Column('id', Integer, - ForeignKey('base.id'), - primary_key=True), - Column('data', String(10)), - Column('subparent_id', Integer, - ForeignKey('subparent.id'), - nullable=False)) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(10)), + ) + + Table( + "subparent", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + Column("data", String(10)), + ) + + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(10)), + ) + + Table( + "sub", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("data", String(10)), + Column( + "subparent_id", + Integer, + ForeignKey("subparent.id"), + nullable=False, + ), + ) @classmethod def setup_classes(cls): @@ -834,23 +1053,26 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): subparent = cls.tables.subparent Subparent = cls.classes.Subparent - mapper(Parent, parent, - polymorphic_on=parent.c.type, - polymorphic_identity='b') + mapper( + Parent, + parent, + polymorphic_on=parent.c.type, + polymorphic_identity="b", + ) - mapper(Subparent, subparent, - inherits=Parent, - polymorphic_identity='s', - properties={ - 'children': relationship(Sub, order_by=base.c.id)}) + mapper( + Subparent, + subparent, + inherits=Parent, + polymorphic_identity="s", + properties={"children": relationship(Sub, order_by=base.c.id)}, + ) - mapper(Base, base, - polymorphic_on=base.c.type, - polymorphic_identity='b') + mapper( + Base, base, polymorphic_on=base.c.type, polymorphic_identity="b" + ) - mapper(Sub, sub, - inherits=Base, - polymorphic_identity='s') + mapper(Sub, sub, inherits=Base, polymorphic_identity="s") @classmethod def insert_data(cls): @@ -859,11 +1081,10 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): Sub, Subparent = cls.classes.Sub, cls.classes.Subparent sess = create_session() p1 = Subparent( - data='p1', - children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]) - p2 = Subparent( - data='p2', - children=[Sub(data='s4'), Sub(data='s5')]) + data="p1", + children=[Sub(data="s1"), Sub(data="s2"), Sub(data="s3")], + ) + p2 = Subparent(data="p2", children=[Sub(data="s4"), Sub(data="s5")]) sess.add(p1) sess.add(p2) sess.flush() @@ -874,17 +1095,23 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): sess = create_session() def go(): - eq_(sess.query(Subparent) - .options(joinedload(Subparent.children)).all(), - [p1, p2]) + eq_( + 
sess.query(Subparent) + .options(joinedload(Subparent.children)) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - eq_(sess.query(Subparent) - .options(joinedload("children")).all(), - [p1, p2]) + eq_( + sess.query(Subparent).options(joinedload("children")).all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 1) def test_contains_eager(self): @@ -893,19 +1120,27 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): sess = create_session() def go(): - eq_(sess.query(Subparent) - .join(Subparent.children) - .options(contains_eager(Subparent.children)).all(), - [p1, p2]) + eq_( + sess.query(Subparent) + .join(Subparent.children) + .options(contains_eager(Subparent.children)) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - eq_(sess.query(Subparent) - .join(Subparent.children) - .options(contains_eager("children")).all(), - [p1, p2]) + eq_( + sess.query(Subparent) + .join(Subparent.children) + .options(contains_eager("children")) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 1) def test_subqueryload(self): @@ -914,17 +1149,23 @@ class SubClassEagerToSubClassTest(fixtures.MappedTest): sess = create_session() def go(): - eq_(sess.query(Subparent) - .options(subqueryload(Subparent.children)).all(), - [p1, p2]) + eq_( + sess.query(Subparent) + .options(subqueryload(Subparent.children)) + .all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 2) sess.expunge_all() def go(): - eq_(sess.query(Subparent) - .options(subqueryload("children")).all(), - [p1, p2]) + eq_( + sess.query(Subparent).options(subqueryload("children")).all(), + [p1, p2], + ) + self.assert_sql_count(testing.db, go, 2) @@ -935,31 +1176,51 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): #2614 """ - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' + + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(10))) - Table('b', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True)) - Table('btod', metadata, - Column('bid', Integer, ForeignKey('b.id'), nullable=False), - Column('did', Integer, ForeignKey('d.id'), nullable=False) - ) - Table('c', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True)) - Table('ctod', metadata, - Column('cid', Integer, ForeignKey('c.id'), nullable=False), - Column('did', Integer, ForeignKey('d.id'), nullable=False)) - Table('d', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(10)), + ) + Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) + Table( + "btod", + metadata, + Column("bid", Integer, ForeignKey("b.id"), nullable=False), + Column("did", Integer, ForeignKey("d.id"), nullable=False), + ) + Table( + "c", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) + Table( + "ctod", + metadata, + Column("cid", Integer, ForeignKey("c.id"), nullable=False), + Column("did", Integer, ForeignKey("d.id"), nullable=False), + ) + Table( + "d", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) 
@classmethod def setup_classes(cls): @@ -983,14 +1244,20 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): D = cls.classes.D mapper(A, cls.tables.a, polymorphic_on=cls.tables.a.c.type) - mapper(B, cls.tables.b, inherits=A, polymorphic_identity='b', - properties={ - 'related': relationship(D, secondary=cls.tables.btod) - }) - mapper(C, cls.tables.c, inherits=A, polymorphic_identity='c', - properties={ - 'related': relationship(D, secondary=cls.tables.ctod) - }) + mapper( + B, + cls.tables.b, + inherits=A, + polymorphic_identity="b", + properties={"related": relationship(D, secondary=cls.tables.btod)}, + ) + mapper( + C, + cls.tables.c, + inherits=A, + polymorphic_identity="c", + properties={"related": relationship(D, secondary=cls.tables.ctod)}, + ) mapper(D, cls.tables.d) @classmethod @@ -1002,10 +1269,7 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): session = Session() d = D() - session.add_all([ - B(related=[d]), - C(related=[d]) - ]) + session.add_all([B(related=[d]), C(related=[d])]) session.commit() def test_free_w_poly_subquery(self): @@ -1019,11 +1283,11 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): a_poly = with_polymorphic(A, [B, C]) def go(): - for a in session.query(a_poly).\ - options( - subqueryload(a_poly.B.related), - subqueryload(a_poly.C.related)): + for a in session.query(a_poly).options( + subqueryload(a_poly.B.related), subqueryload(a_poly.C.related) + ): eq_(a.related, [d]) + self.assert_sql_count(testing.db, go, 3) def test_fixed_w_poly_subquery(self): @@ -1036,9 +1300,13 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): d = session.query(D).one() def go(): - for a in session.query(A).with_polymorphic([B, C]).\ - options(subqueryload(B.related), subqueryload(C.related)): + for a in ( + session.query(A) + .with_polymorphic([B, C]) + .options(subqueryload(B.related), subqueryload(C.related)) + ): eq_(a.related, [d]) + self.assert_sql_count(testing.db, go, 3) def test_free_w_poly_joined(self): @@ -1052,11 +1320,11 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): a_poly = with_polymorphic(A, [B, C]) def go(): - for a in session.query(a_poly).\ - options( - joinedload(a_poly.B.related), - joinedload(a_poly.C.related)): + for a in session.query(a_poly).options( + joinedload(a_poly.B.related), joinedload(a_poly.C.related) + ): eq_(a.related, [d]) + self.assert_sql_count(testing.db, go, 1) def test_fixed_w_poly_joined(self): @@ -1069,9 +1337,13 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): d = session.query(D).one() def go(): - for a in session.query(A).with_polymorphic([B, C]).\ - options(joinedload(B.related), joinedload(C.related)): + for a in ( + session.query(A) + .with_polymorphic([B, C]) + .options(joinedload(B.related), joinedload(C.related)) + ): eq_(a.related, [d]) + self.assert_sql_count(testing.db, go, 1) @@ -1079,26 +1351,41 @@ class SubClassToSubClassFromParentTest(fixtures.MappedTest): """test #2617 """ - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' + + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('z', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(10)), - Column('z_id', Integer, ForeignKey('z.id'))) - Table('b', 
metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True)) - Table('d', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('b_id', Integer, ForeignKey('b.id'))) + Table( + "z", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(10)), + Column("z_id", Integer, ForeignKey("z.id")), + ) + Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) + Table( + "d", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("b_id", Integer, ForeignKey("b.id")), + ) @classmethod def setup_classes(cls): @@ -1122,18 +1409,27 @@ class SubClassToSubClassFromParentTest(fixtures.MappedTest): D = cls.classes.D mapper(Z, cls.tables.z) - mapper(A, cls.tables.a, polymorphic_on=cls.tables.a.c.type, - with_polymorphic='*', - properties={ - 'zs': relationship(Z, lazy="subquery") - }) - mapper(B, cls.tables.b, inherits=A, polymorphic_identity='b', - properties={ - 'related': relationship(D, lazy="subquery", - primaryjoin=cls.tables.d.c.b_id == - cls.tables.b.c.id) - }) - mapper(D, cls.tables.d, inherits=A, polymorphic_identity='d') + mapper( + A, + cls.tables.a, + polymorphic_on=cls.tables.a.c.type, + with_polymorphic="*", + properties={"zs": relationship(Z, lazy="subquery")}, + ) + mapper( + B, + cls.tables.b, + inherits=A, + polymorphic_identity="b", + properties={ + "related": relationship( + D, + lazy="subquery", + primaryjoin=cls.tables.d.c.b_id == cls.tables.b.c.id, + ) + }, + ) + mapper(D, cls.tables.d, inherits=A, polymorphic_identity="d") @classmethod def insert_data(cls): @@ -1150,6 +1446,7 @@ class SubClassToSubClassFromParentTest(fixtures.MappedTest): def go(): a1 = session.query(A).first() eq_(a1.related, []) + self.assert_sql_count(testing.db, go, 3) @@ -1167,41 +1464,67 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): run_create_tables = None run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - Table('base1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - Table('sub1', metadata, - Column('id', Integer, ForeignKey('base1.id'), primary_key=True), - Column('parent_id', ForeignKey('parent.id')), - Column('subdata', String(30))) - - Table('base2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('base1_id', ForeignKey('base1.id')), - Column('data', String(30))) - Table('sub2', metadata, - Column('id', Integer, ForeignKey('base2.id'), primary_key=True), - Column('subdata', String(30))) - Table('ep1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('base2_id', Integer, ForeignKey('base2.id')), - Column('data', String(30))) - Table('ep2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('base2_id', Integer, ForeignKey('base2.id')), - Column('data', String(30))) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + Table( + "base1", + metadata, + Column( + "id", Integer, primary_key=True, 
test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + Table( + "sub1", + metadata, + Column("id", Integer, ForeignKey("base1.id"), primary_key=True), + Column("parent_id", ForeignKey("parent.id")), + Column("subdata", String(30)), + ) + + Table( + "base2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("base1_id", ForeignKey("base1.id")), + Column("data", String(30)), + ) + Table( + "sub2", + metadata, + Column("id", Integer, ForeignKey("base2.id"), primary_key=True), + Column("subdata", String(30)), + ) + Table( + "ep1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("base2_id", Integer, ForeignKey("base2.id")), + Column("data", String(30)), + ) + Table( + "ep2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("base2_id", Integer, ForeignKey("base2.id")), + Column("data", String(30)), + ) @classmethod def setup_classes(cls): @@ -1228,26 +1551,32 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): @classmethod def _classes(cls): - return cls.classes.Parent, cls.classes.Base1,\ - cls.classes.Base2, cls.classes.Sub1,\ - cls.classes.Sub2, cls.classes.EP1,\ - cls.classes.EP2 + return ( + cls.classes.Parent, + cls.classes.Base1, + cls.classes.Base2, + cls.classes.Sub1, + cls.classes.Sub2, + cls.classes.EP1, + cls.classes.EP2, + ) @classmethod def setup_mappers(cls): Parent, Base1, Base2, Sub1, Sub2, EP1, EP2 = cls._classes() - mapper(Parent, cls.tables.parent, properties={ - 'sub1': relationship(Sub1) - }) - mapper(Base1, cls.tables.base1, properties={ - 'sub2': relationship(Sub2) - }) + mapper( + Parent, cls.tables.parent, properties={"sub1": relationship(Sub1)} + ) + mapper( + Base1, cls.tables.base1, properties={"sub2": relationship(Sub2)} + ) mapper(Sub1, cls.tables.sub1, inherits=Base1) - mapper(Base2, cls.tables.base2, properties={ - 'ep1': relationship(EP1), - 'ep2': relationship(EP2) - }) + mapper( + Base2, + cls.tables.base2, + properties={"ep1": relationship(EP1), "ep2": relationship(EP2)}, + ) mapper(Sub2, cls.tables.sub2, inherits=Base2) mapper(EP1, cls.tables.ep1) mapper(EP2, cls.tables.ep2) @@ -1257,9 +1586,10 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s = Session() self.assert_compile( - s.query(Parent).join(Parent.sub1, Sub1.sub2). - join(Sub2.ep1). - join(Sub2.ep2), + s.query(Parent) + .join(Parent.sub1, Sub1.sub2) + .join(Sub2.ep1) + .join(Sub2.ep2), "SELECT parent.id AS parent_id, parent.data AS parent_data " "FROM parent JOIN (base1 JOIN sub1 ON base1.id = sub1.id) " "ON parent.id = sub1.parent_id JOIN " @@ -1267,7 +1597,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): "ON base2.id = sub2.id) " "ON base1.id = base2.base1_id " "JOIN ep1 ON base2.id = ep1.base2_id " - "JOIN ep2 ON base2.id = ep2.base2_id" + "JOIN ep2 ON base2.id = ep2.base2_id", ) def test_two(self): @@ -1277,14 +1607,13 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s = Session() self.assert_compile( - s.query(Parent).join(Parent.sub1). 
- join(s2a, Sub1.sub2), + s.query(Parent).join(Parent.sub1).join(s2a, Sub1.sub2), "SELECT parent.id AS parent_id, parent.data AS parent_data " "FROM parent JOIN (base1 JOIN sub1 ON base1.id = sub1.id) " "ON parent.id = sub1.parent_id JOIN " "(base2 AS base2_1 JOIN sub2 AS sub2_1 " "ON base2_1.id = sub2_1.id) " - "ON base1.id = base2_1.base1_id" + "ON base1.id = base2_1.base1_id", ) def test_three(self): @@ -1292,15 +1621,13 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s = Session() self.assert_compile( - s.query(Base1).join(Base1.sub2). - join(Sub2.ep1). - join(Sub2.ep2), + s.query(Base1).join(Base1.sub2).join(Sub2.ep1).join(Sub2.ep2), "SELECT base1.id AS base1_id, base1.data AS base1_data " "FROM base1 JOIN (base2 JOIN sub2 " "ON base2.id = sub2.id) ON base1.id = " "base2.base1_id " "JOIN ep1 ON base2.id = ep1.base2_id " - "JOIN ep2 ON base2.id = ep2.base2_id" + "JOIN ep2 ON base2.id = ep2.base2_id", ) def test_four(self): @@ -1308,16 +1635,17 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s = Session() self.assert_compile( - s.query(Sub2).join(Base1, Base1.id == Sub2.base1_id). - join(Sub2.ep1). - join(Sub2.ep2), + s.query(Sub2) + .join(Base1, Base1.id == Sub2.base1_id) + .join(Sub2.ep1) + .join(Sub2.ep2), "SELECT sub2.id AS sub2_id, base2.id AS base2_id, " "base2.base1_id AS base2_base1_id, base2.data AS base2_data, " "sub2.subdata AS sub2_subdata " "FROM base2 JOIN sub2 ON base2.id = sub2.id " "JOIN base1 ON base1.id = base2.base1_id " "JOIN ep1 ON base2.id = ep1.base2_id " - "JOIN ep2 ON base2.id = ep2.base2_id" + "JOIN ep2 ON base2.id = ep2.base2_id", ) def test_five(self): @@ -1325,9 +1653,10 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s = Session() self.assert_compile( - s.query(Sub2).join(Sub1, Sub1.id == Sub2.base1_id). - join(Sub2.ep1). - join(Sub2.ep2), + s.query(Sub2) + .join(Sub1, Sub1.id == Sub2.base1_id) + .join(Sub2.ep1) + .join(Sub2.ep2), "SELECT sub2.id AS sub2_id, base2.id AS base2_id, " "base2.base1_id AS base2_base1_id, base2.data AS base2_data, " "sub2.subdata AS sub2_subdata " @@ -1336,7 +1665,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): "(base1 JOIN sub1 ON base1.id = sub1.id) " "ON sub1.id = base2.base1_id " "JOIN ep1 ON base2.id = ep1.base2_id " - "JOIN ep2 ON base2.id = ep2.base2_id" + "JOIN ep2 ON base2.id = ep2.base2_id", ) def test_six(self): @@ -1344,9 +1673,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): s = Session() self.assert_compile( - s.query(Sub2).from_self(). - join(Sub2.ep1). - join(Sub2.ep2), + s.query(Sub2).from_self().join(Sub2.ep1).join(Sub2.ep2), "SELECT anon_1.sub2_id AS anon_1_sub2_id, " "anon_1.base2_id AS anon_1_base2_id, " "anon_1.base2_base1_id AS anon_1_base2_base1_id, " @@ -1357,7 +1684,7 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): "sub2.subdata AS sub2_subdata " "FROM base2 JOIN sub2 ON base2.id = sub2.id) AS anon_1 " "JOIN ep1 ON anon_1.base2_id = ep1.base2_id " - "JOIN ep2 ON anon_1.base2_id = ep2.base2_id" + "JOIN ep2 ON anon_1.base2_id = ep2.base2_id", ) def test_seven(self): @@ -1368,10 +1695,12 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): # adding Sub2 to the entities list helps it, # otherwise the joins for Sub2.ep1/ep2 don't have columns # to latch onto. Can't really make it better than this - s.query(Parent, Sub2).join(Parent.sub1).\ - join(Sub1.sub2).from_self().\ - join(Sub2.ep1). 
- join(Sub2.ep2), + s.query(Parent, Sub2) + .join(Parent.sub1) + .join(Sub1.sub2) + .from_self() + .join(Sub2.ep1) + .join(Sub2.ep2), "SELECT anon_1.parent_id AS anon_1_parent_id, " "anon_1.parent_data AS anon_1_parent_data, " "anon_1.sub2_id AS anon_1_sub2_id, " @@ -1390,50 +1719,50 @@ class SubClassToSubClassMultiTest(AssertsCompiledSQL, fixtures.MappedTest): "(base2 JOIN sub2 ON base2.id = sub2.id) " "ON base1.id = base2.base1_id) AS anon_1 " "JOIN ep1 ON anon_1.base2_id = ep1.base2_id " - "JOIN ep2 ON anon_1.base2_id = ep2.base2_id" + "JOIN ep2 ON anon_1.base2_id = ep2.base2_id", ) class JoinedloadSinglePolysubSingle( - fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): """exercise issue #3611, using the test from dupe issue 3614""" run_define_tables = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class User(Base): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) class UserRole(Base): - __tablename__ = 'user_roles' + __tablename__ = "user_roles" id = Column(Integer, primary_key=True) row_type = Column(String(50), nullable=False) - __mapper_args__ = {'polymorphic_on': row_type} + __mapper_args__ = {"polymorphic_on": row_type} - user_id = Column(Integer, ForeignKey('users.id'), nullable=False) - user = relationship('User', lazy=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + user = relationship("User", lazy=False) class Admin(UserRole): - __tablename__ = 'admins' - __mapper_args__ = {'polymorphic_identity': 'admin'} + __tablename__ = "admins" + __mapper_args__ = {"polymorphic_identity": "admin"} - id = Column(Integer, ForeignKey('user_roles.id'), primary_key=True) + id = Column(Integer, ForeignKey("user_roles.id"), primary_key=True) class Thing(Base): - __tablename__ = 'things' + __tablename__ = "things" id = Column(Integer, primary_key=True) - admin_id = Column(Integer, ForeignKey('admins.id')) - admin = relationship('Admin', lazy=False) + admin_id = Column(Integer, ForeignKey("admins.id")) + admin = relationship("Admin", lazy=False) def test_query(self): Thing = self.classes.Thing @@ -1450,69 +1779,71 @@ class JoinedloadSinglePolysubSingle( "AS admins_1 ON user_roles_1.id = admins_1.id) ON " "admins_1.id = things.admin_id " "LEFT OUTER JOIN users AS " - "users_1 ON users_1.id = user_roles_1.user_id" + "users_1 ON users_1.id = user_roles_1.user_id", ) class JoinedloadOverWPolyAliased( - fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): """exercise issues in #3593 and #3611""" - run_setup_mappers = 'each' - run_setup_classes = 'each' - run_define_tables = 'each' - __dialect__ = 'default' + run_setup_mappers = "each" + run_setup_classes = "each" + run_define_tables = "each" + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Owner(Base): - __tablename__ = 'owner' + __tablename__ = "owner" id = Column(Integer, primary_key=True) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'with_polymorphic': ('*', None), + "polymorphic_on": type, + "with_polymorphic": ("*", None), } class SubOwner(Owner): - __mapper_args__ = {'polymorphic_identity': 'so'} + __mapper_args__ = {"polymorphic_identity": "so"} class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) type = 
Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'with_polymorphic': ('*', None), + "polymorphic_on": type, + "with_polymorphic": ("*", None), } class Sub1(Parent): - __mapper_args__ = {'polymorphic_identity': 's1'} + __mapper_args__ = {"polymorphic_identity": "s1"} class Link(Base): - __tablename__ = 'link' + __tablename__ = "link" parent_id = Column( - Integer, ForeignKey('parent.id'), primary_key=True) + Integer, ForeignKey("parent.id"), primary_key=True + ) child_id = Column( - Integer, ForeignKey('parent.id'), primary_key=True) + Integer, ForeignKey("parent.id"), primary_key=True + ) def _fixture_from_base(self): Parent = self.classes.Parent Link = self.classes.Link Link.child = relationship( - Parent, primaryjoin=Link.child_id == Parent.id) + Parent, primaryjoin=Link.child_id == Parent.id + ) Parent.links = relationship( - Link, - primaryjoin=Parent.id == Link.parent_id, + Link, primaryjoin=Parent.id == Link.parent_id ) return Parent @@ -1521,12 +1852,10 @@ class JoinedloadOverWPolyAliased( Link = self.classes.Link Parent = self.classes.Parent Link.child = relationship( - Parent, primaryjoin=Link.child_id == Parent.id) - - Sub1.links = relationship( - Link, - primaryjoin=Sub1.id == Link.parent_id, + Parent, primaryjoin=Link.child_id == Parent.id ) + + Sub1.links = relationship(Link, primaryjoin=Sub1.id == Link.parent_id) return Sub1 def _fixture_to_subclass_to_base(self): @@ -1537,10 +1866,9 @@ class JoinedloadOverWPolyAliased( # Link -> Sub1 -> Owner - Link.child = relationship( - Sub1, primaryjoin=Link.child_id == Sub1.id) + Link.child = relationship(Sub1, primaryjoin=Link.child_id == Sub1.id) - Parent.owner_id = Column(ForeignKey('owner.id')) + Parent.owner_id = Column(ForeignKey("owner.id")) Parent.owner = relationship(Owner) return Parent @@ -1553,9 +1881,10 @@ class JoinedloadOverWPolyAliased( # Link -> Parent -> Owner Link.child = relationship( - Parent, primaryjoin=Link.child_id == Parent.id) + Parent, primaryjoin=Link.child_id == Parent.id + ) - Parent.owner_id = Column(ForeignKey('owner.id')) + Parent.owner_id = Column(ForeignKey("owner.id")) Parent.owner = relationship(Owner) return Parent @@ -1578,11 +1907,7 @@ class JoinedloadOverWPolyAliased( session = Session() q = session.query(cls).options( - joinedload_all( - cls.links, - Link.child, - cls.links - ) + joinedload_all(cls.links, Link.child, cls.links) ) if cls is self.classes.Sub1: extra = " WHERE parent.type IN (:type_1)" @@ -1602,7 +1927,7 @@ class JoinedloadOverWPolyAliased( "LEFT OUTER JOIN parent " "AS parent_1 ON link_1.child_id = parent_1.id " "LEFT OUTER JOIN link AS link_2 " - "ON parent_1.id = link_2.parent_id" + extra + "ON parent_1.id = link_2.parent_id" + extra, ) def _test_single_poly_poly(self, fn): @@ -1611,10 +1936,7 @@ class JoinedloadOverWPolyAliased( session = Session() q = session.query(Link).options( - joinedload_all( - Link.child, - parent_cls.owner - ) + joinedload_all(Link.child, parent_cls.owner) ) if Link.child.property.mapper.class_ is self.classes.Sub1: @@ -1630,9 +1952,10 @@ class JoinedloadOverWPolyAliased( "parent_1.owner_id AS parent_1_owner_id, " "owner_1.id AS owner_1_id, owner_1.type AS owner_1_type " "FROM link LEFT OUTER JOIN parent AS parent_1 " - "ON link.child_id = parent_1.id " + extra + - "LEFT OUTER JOIN owner AS owner_1 " - "ON owner_1.id = parent_1.owner_id" + "ON link.child_id = parent_1.id " + + extra + + "LEFT OUTER JOIN owner AS owner_1 " + "ON owner_1.id = parent_1.owner_id", ) def test_local_wpoly(self): @@ -1644,9 +1967,9 @@ class 
JoinedloadOverWPolyAliased(
         session = Session()
 
         q = session.query(poly).options(
-            joinedload(poly.Sub1.links).
-            joinedload(Link.child.of_type(Sub1)).
-            joinedload(poly.Sub1.links)
+            joinedload(poly.Sub1.links)
+            .joinedload(Link.child.of_type(Sub1))
+            .joinedload(poly.Sub1.links)
         )
         self.assert_compile(
             q,
@@ -1659,7 +1982,7 @@ class JoinedloadOverWPolyAliased(
             "LEFT OUTER JOIN link AS link_1 ON parent.id = link_1.parent_id "
             "LEFT OUTER JOIN parent AS parent_1 "
             "ON link_1.child_id = parent_1.id "
-            "LEFT OUTER JOIN link AS link_2 ON parent_1.id = link_2.parent_id"
+            "LEFT OUTER JOIN link AS link_2 ON parent_1.id = link_2.parent_id",
         )
 
     def test_local_wpoly_innerjoins(self):
@@ -1672,9 +1995,9 @@ class JoinedloadOverWPolyAliased(
         session = Session()
 
         q = session.query(poly).options(
-            joinedload(poly.Sub1.links, innerjoin=True).
-            joinedload(Link.child.of_type(Sub1), innerjoin=True).
-            joinedload(poly.Sub1.links, innerjoin=True)
+            joinedload(poly.Sub1.links, innerjoin=True)
+            .joinedload(Link.child.of_type(Sub1), innerjoin=True)
+            .joinedload(poly.Sub1.links, innerjoin=True)
         )
         self.assert_compile(
             q,
@@ -1687,7 +2010,7 @@ class JoinedloadOverWPolyAliased(
             "LEFT OUTER JOIN link AS link_1 ON parent.id = link_1.parent_id "
             "LEFT OUTER JOIN parent AS parent_1 "
             "ON link_1.child_id = parent_1.id "
-            "LEFT OUTER JOIN link AS link_2 ON parent_1.id = link_2.parent_id"
+            "LEFT OUTER JOIN link AS link_2 ON parent_1.id = link_2.parent_id",
         )
 
     def test_local_wpoly_innerjoins_roundtrip(self):
@@ -1697,10 +2020,7 @@ class JoinedloadOverWPolyAliased(
         Link = self.classes.Link
 
         session = Session()
-        session.add_all([
-            Parent(),
-            Parent()
-        ])
+        session.add_all([Parent(), Parent()])
 
         # represents "Parent" and "Sub1" rows
         poly = with_polymorphic(Parent, [Sub1])
 
         # to be cancelled because the Parent rows
         # would be omitted
         q = session.query(poly).options(
-            joinedload(poly.Sub1.links, innerjoin=True).
-            joinedload(Link.child.of_type(Sub1), innerjoin=True)
+            joinedload(poly.Sub1.links, innerjoin=True).joinedload(
+                Link.child.of_type(Sub1), innerjoin=True
+            )
         )
         eq_(len(q.all()), 2)
 
 
-class JoinAcrossJoinedInhMultiPath(fixtures.DeclarativeMappedTest,
-                                   testing.AssertsCompiledSQL):
+class JoinAcrossJoinedInhMultiPath(
+    fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL
+):
     """test long join paths with a joined-inh in the middle, where
     we go multiple times across the same joined-inh to the same target
     but with other classes in the middle.    E.g. 
test [ticket:2908] """ - run_setup_mappers = 'once' - __dialect__ = 'default' + run_setup_mappers = "once" + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Root(Base): - __tablename__ = 'root' + __tablename__ = "root" id = Column(Integer, primary_key=True) - sub1_id = Column(Integer, ForeignKey('sub1.id')) + sub1_id = Column(Integer, ForeignKey("sub1.id")) intermediate = relationship("Intermediate") sub1 = relationship("Sub1") class Intermediate(Base): - __tablename__ = 'intermediate' + __tablename__ = "intermediate" id = Column(Integer, primary_key=True) - sub1_id = Column(Integer, ForeignKey('sub1.id')) - root_id = Column(Integer, ForeignKey('root.id')) + sub1_id = Column(Integer, ForeignKey("sub1.id")) + root_id = Column(Integer, ForeignKey("root.id")) sub1 = relationship("Sub1") class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) class Sub1(Parent): - __tablename__ = 'sub1' - id = Column(Integer, ForeignKey('parent.id'), - primary_key=True) + __tablename__ = "sub1" + id = Column(Integer, ForeignKey("parent.id"), primary_key=True) target = relationship("Target") class Target(Base): - __tablename__ = 'target' + __tablename__ = "target" id = Column(Integer, primary_key=True) - sub1_id = Column(Integer, ForeignKey('sub1.id')) + sub1_id = Column(Integer, ForeignKey("sub1.id")) def test_join(self): - Root, Intermediate, Sub1, Target = \ - self.classes.Root, self.classes.Intermediate, \ - self.classes.Sub1, self.classes.Target + Root, Intermediate, Sub1, Target = ( + self.classes.Root, + self.classes.Intermediate, + self.classes.Sub1, + self.classes.Target, + ) s1_alias = aliased(Sub1) s2_alias = aliased(Sub1) t1_alias = aliased(Target) t2_alias = aliased(Target) sess = Session() - q = sess.query(Root).\ - join(s1_alias, Root.sub1).join(t1_alias, s1_alias.target).\ - join(Root.intermediate).join(s2_alias, Intermediate.sub1).\ - join(t2_alias, s2_alias.target) + q = ( + sess.query(Root) + .join(s1_alias, Root.sub1) + .join(t1_alias, s1_alias.target) + .join(Root.intermediate) + .join(s2_alias, Intermediate.sub1) + .join(t2_alias, s2_alias.target) + ) self.assert_compile( q, "SELECT root.id AS root_id, root.sub1_id AS root_sub1_id " @@ -1789,22 +2117,30 @@ class JoinAcrossJoinedInhMultiPath(fixtures.DeclarativeMappedTest, "JOIN (SELECT parent.id AS parent_id, sub1.id AS sub1_id " "FROM parent JOIN sub1 ON parent.id = sub1.id) AS anon_2 " "ON anon_2.sub1_id = intermediate.sub1_id " - "JOIN target AS target_2 ON anon_2.sub1_id = target_2.sub1_id") + "JOIN target AS target_2 ON anon_2.sub1_id = target_2.sub1_id", + ) def test_join_flat(self): - Root, Intermediate, Sub1, Target = \ - self.classes.Root, self.classes.Intermediate, \ - self.classes.Sub1, self.classes.Target + Root, Intermediate, Sub1, Target = ( + self.classes.Root, + self.classes.Intermediate, + self.classes.Sub1, + self.classes.Target, + ) s1_alias = aliased(Sub1, flat=True) s2_alias = aliased(Sub1, flat=True) t1_alias = aliased(Target) t2_alias = aliased(Target) sess = Session() - q = sess.query(Root).\ - join(s1_alias, Root.sub1).join(t1_alias, s1_alias.target).\ - join(Root.intermediate).join(s2_alias, Intermediate.sub1).\ - join(t2_alias, s2_alias.target) + q = ( + sess.query(Root) + .join(s1_alias, Root.sub1) + .join(t1_alias, s1_alias.target) + .join(Root.intermediate) + .join(s2_alias, Intermediate.sub1) + .join(t2_alias, s2_alias.target) + ) self.assert_compile( q, "SELECT root.id AS root_id, root.sub1_id AS 
root_sub1_id " @@ -1817,19 +2153,24 @@ class JoinAcrossJoinedInhMultiPath(fixtures.DeclarativeMappedTest, "JOIN (parent AS parent_2 JOIN sub1 AS sub1_2 " "ON parent_2.id = sub1_2.id) " "ON sub1_2.id = intermediate.sub1_id " - "JOIN target AS target_2 ON sub1_2.id = target_2.sub1_id") + "JOIN target AS target_2 ON sub1_2.id = target_2.sub1_id", + ) def test_joinedload(self): - Root, Intermediate, Sub1, Target = \ - self.classes.Root, self.classes.Intermediate, \ - self.classes.Sub1, self.classes.Target + Root, Intermediate, Sub1, Target = ( + self.classes.Root, + self.classes.Intermediate, + self.classes.Sub1, + self.classes.Target, + ) sess = Session() - q = sess.query(Root).\ - options( + q = sess.query(Root).options( joinedload(Root.sub1).joinedload(Sub1.target), - joinedload(Root.intermediate).joinedload(Intermediate.sub1). - joinedload(Sub1.target)) + joinedload(Root.intermediate) + .joinedload(Intermediate.sub1) + .joinedload(Sub1.target), + ) self.assert_compile( q, "SELECT root.id AS root_id, root.sub1_id AS root_sub1_id, " @@ -1853,28 +2194,42 @@ class JoinAcrossJoinedInhMultiPath(fixtures.DeclarativeMappedTest, "LEFT OUTER JOIN (parent AS parent_2 JOIN sub1 AS sub1_2 " "ON parent_2.id = sub1_2.id) ON sub1_2.id = root.sub1_id " "LEFT OUTER JOIN target AS target_2 " - "ON sub1_2.id = target_2.sub1_id") + "ON sub1_2.id = target_2.sub1_id", + ) class MultipleAdaptUsesEntityOverTableTest( - AssertsCompiledSQL, fixtures.MappedTest): - __dialect__ = 'default' + AssertsCompiledSQL, fixtures.MappedTest +): + __dialect__ = "default" run_create_tables = None run_deletes = None @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('name', String)) - Table('b', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True)) - Table('c', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('bid', Integer, ForeignKey('b.id'))) - Table('d', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('cid', Integer, ForeignKey('c.id'))) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) + Table( + "c", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) + Table( + "d", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("cid", Integer, ForeignKey("c.id")), + ) @classmethod def setup_classes(cls): @@ -1900,12 +2255,19 @@ class MultipleAdaptUsesEntityOverTableTest( mapper(D, d, inherits=A) def _two_join_fixture(self): - A, B, C, D = (self.classes.A, self.classes.B, self.classes.C, - self.classes.D) + A, B, C, D = ( + self.classes.A, + self.classes.B, + self.classes.C, + self.classes.D, + ) s = Session() - return s.query(B.name, C.name, D.name).select_from(B).\ - join(C, C.bid == B.id).\ - join(D, D.cid == C.id) + return ( + s.query(B.name, C.name, D.name) + .select_from(B) + .join(C, C.bid == B.id) + .join(D, D.cid == C.id) + ) def test_two_joins_adaption(self): a, b, c, d = self.tables.a, self.tables.b, self.tables.c, self.tables.d @@ -1944,34 +2306,38 @@ class MultipleAdaptUsesEntityOverTableTest( "FROM a JOIN b ON a.id = b.id JOIN " "(a AS a_1 JOIN c AS c_1 ON a_1.id = c_1.id) ON c_1.bid = b.id " "JOIN (a AS a_2 JOIN d AS d_1 ON a_2.id = d_1.id) " - "ON d_1.cid = c_1.id") + "ON d_1.cid = c_1.id", + ) class 
SameNameOnJoined(fixtures.MappedTest): - run_setup_mappers = 'once' + run_setup_mappers = "once" run_inserts = None run_deletes = None @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, + "a", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('t', String(5)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("t", String(5)), ) Table( - 'a_sub', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True) + "a_sub", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), ) Table( - 'b', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('a_id', Integer, ForeignKey('a.id')) - + "b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", Integer, ForeignKey("a.id")), ) @classmethod @@ -1986,31 +2352,29 @@ class SameNameOnJoined(fixtures.MappedTest): pass mapper( - A, cls.tables.a, polymorphic_on=cls.tables.a.c.t, - polymorphic_identity='a', - properties={ - 'bs': relationship(B, cascade="all, delete-orphan") - } + A, + cls.tables.a, + polymorphic_on=cls.tables.a.c.t, + polymorphic_identity="a", + properties={"bs": relationship(B, cascade="all, delete-orphan")}, ) mapper( - ASub, cls.tables.a_sub, inherits=A, - polymorphic_identity='asub', properties={ - 'bs': relationship(B, cascade="all, delete-orphan") - } + ASub, + cls.tables.a_sub, + inherits=A, + polymorphic_identity="asub", + properties={"bs": relationship(B, cascade="all, delete-orphan")}, ) mapper(B, cls.tables.b) def test_persist(self): - A, ASub, B = self.classes('A', 'ASub', 'B') + A, ASub, B = self.classes("A", "ASub", "B") s = Session(testing.db) - s.add_all([ - A(bs=[B(), B(), B()]), - ASub(bs=[B(), B(), B()]) - ]) + s.add_all([A(bs=[B(), B(), B()]), ASub(bs=[B(), B(), B()])]) s.commit() eq_(s.query(B).count(), 6) @@ -2025,44 +2389,47 @@ class SameNameOnJoined(fixtures.MappedTest): class BetweenSubclassJoinWExtraJoinedLoad( - fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): """test for [ticket:3884]""" run_define_tables = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Person(Base): - __tablename__ = 'people' + __tablename__ = "people" id = Column(Integer, primary_key=True) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on': discriminator} + discriminator = Column("type", String(50)) + __mapper_args__ = {"polymorphic_on": discriminator} class Manager(Person): - __tablename__ = 'managers' - __mapper_args__ = {'polymorphic_identity': 'manager'} - id = Column(Integer, ForeignKey('people.id'), primary_key=True) + __tablename__ = "managers" + __mapper_args__ = {"polymorphic_identity": "manager"} + id = Column(Integer, ForeignKey("people.id"), primary_key=True) class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity': 'engineer'} - id = Column(Integer, ForeignKey('people.id'), primary_key=True) + __tablename__ = "engineers" + __mapper_args__ = {"polymorphic_identity": "engineer"} + id = Column(Integer, ForeignKey("people.id"), primary_key=True) primary_language = Column(String(50)) - manager_id = Column(Integer, ForeignKey('managers.id')) + manager_id = Column(Integer, ForeignKey("managers.id")) manager = relationship( - Manager, primaryjoin=(Manager.id == manager_id)) + Manager, 
primaryjoin=(Manager.id == manager_id) + ) class LastSeen(Base): - __tablename__ = 'seen' - id = Column(Integer, ForeignKey('people.id'), primary_key=True) + __tablename__ = "seen" + id = Column(Integer, ForeignKey("people.id"), primary_key=True) timestamp = Column(Integer) taggable = relationship( - Person, primaryjoin=(Person.id == id), - backref=backref("last_seen", lazy=False)) + Person, + primaryjoin=(Person.id == id), + backref=backref("last_seen", lazy=False), + ) def test_query(self): Engineer, Manager = self.classes("Engineer", "Manager") @@ -2089,5 +2456,5 @@ class BetweenSubclassJoinWExtraJoinedLoad( "ON people_1.id = managers_1.id) " "ON managers_1.id = engineers.manager_id LEFT OUTER JOIN " "seen AS seen_1 ON people.id = seen_1.id LEFT OUTER JOIN " - "seen AS seen_2 ON people_1.id = seen_2.id" + "seen AS seen_2 ON people_1.id = seen_2.id", ) diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py index 710418b24b..cd7e3be40b 100644 --- a/test/orm/inheritance/test_selects.py +++ b/test/orm/inheritance/test_selects.py @@ -10,20 +10,23 @@ from sqlalchemy.testing.schema import Table, Column class InheritingSelectablesTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - foo = Table('foo', metadata, - Column('a', String(30), primary_key=1), - Column('b', String(30), nullable=0)) + foo = Table( + "foo", + metadata, + Column("a", String(30), primary_key=1), + Column("b", String(30), nullable=0), + ) - cls.tables.bar = foo.select(foo.c.b == 'bar').alias('bar') - cls.tables.baz = foo.select(foo.c.b == 'baz').alias('baz') + cls.tables.bar = foo.select(foo.c.b == "bar").alias("bar") + cls.tables.baz = foo.select(foo.c.b == "baz").alias("baz") def test_load(self): foo, bar, baz = self.tables.foo, self.tables.bar, self.tables.baz # TODO: add persistence test also - testing.db.execute(foo.insert(), a='not bar', b='baz') - testing.db.execute(foo.insert(), a='also not bar', b='baz') - testing.db.execute(foo.insert(), a='i am bar', b='bar') - testing.db.execute(foo.insert(), a='also bar', b='bar') + testing.db.execute(foo.insert(), a="not bar", b="baz") + testing.db.execute(foo.insert(), a="also not bar", b="baz") + testing.db.execute(foo.insert(), a="i am bar", b="bar") + testing.db.execute(foo.insert(), a="also bar", b="bar") class Foo(fixtures.ComparableEntity): pass @@ -36,24 +39,37 @@ class InheritingSelectablesTest(fixtures.MappedTest): mapper(Foo, foo, polymorphic_on=foo.c.b) - mapper(Baz, baz, - with_polymorphic=('*', - foo.join(baz, foo.c.b == 'baz').alias('baz')), - inherits=Foo, inherit_condition=(foo.c.a == baz.c.a), - inherit_foreign_keys=[baz.c.a], - polymorphic_identity='baz') - - mapper(Bar, bar, - with_polymorphic=('*', - foo.join(bar, foo.c.b == 'bar').alias('bar')), - inherits=Foo, inherit_condition=(foo.c.a == bar.c.a), - inherit_foreign_keys=[bar.c.a], - polymorphic_identity='bar') + mapper( + Baz, + baz, + with_polymorphic=( + "*", + foo.join(baz, foo.c.b == "baz").alias("baz"), + ), + inherits=Foo, + inherit_condition=(foo.c.a == baz.c.a), + inherit_foreign_keys=[baz.c.a], + polymorphic_identity="baz", + ) + + mapper( + Bar, + bar, + with_polymorphic=( + "*", + foo.join(bar, foo.c.b == "bar").alias("bar"), + ), + inherits=Foo, + inherit_condition=(foo.c.a == bar.c.a), + inherit_foreign_keys=[bar.c.a], + polymorphic_identity="bar", + ) s = Session() - assert [Baz(), Baz(), Bar(), Bar()] == s.query( - Foo).order_by(Foo.b.desc()).all() + assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).order_by( + Foo.b.desc() + 
).all() assert [Bar(), Bar()] == s.query(Bar).all() @@ -62,16 +78,24 @@ class JoinFromSelectPersistenceTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('base', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(50))) - Table('child', metadata, - # 1. name of column must be different, so that we rely on - # mapper._table_to_equated to link the two cols - Column('child_id', Integer, ForeignKey( - 'base.id'), primary_key=True), - Column('name', String(50))) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(50)), + ) + Table( + "child", + metadata, + # 1. name of column must be different, so that we rely on + # mapper._table_to_equated to link the two cols + Column( + "child_id", Integer, ForeignKey("base.id"), primary_key=True + ), + Column("name", String(50)), + ) @classmethod def setup_classes(cls): @@ -86,20 +110,23 @@ class JoinFromSelectPersistenceTest(fixtures.MappedTest): base, child = self.tables.base, self.tables.child base_select = select([base]).alias() - mapper(Base, base_select, polymorphic_on=base_select.c.type, - polymorphic_identity='base') - mapper(Child, child, inherits=Base, - polymorphic_identity='child') + mapper( + Base, + base_select, + polymorphic_on=base_select.c.type, + polymorphic_identity="base", + ) + mapper(Child, child, inherits=Base, polymorphic_identity="child") sess = Session() # 2. use an id other than "1" here so can't rely on # the two inserts having the same id - c1 = Child(id=12, name='c1') + c1 = Child(id=12, name="c1") sess.add(c1) sess.commit() sess.close() c1 = sess.query(Child).one() - eq_(c1.name, 'c1') + eq_(c1.name, "c1") diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 2416fdc294..13e1937a4d 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -8,24 +8,39 @@ from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy.testing.schema import Table, Column from sqlalchemy import inspect + class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('employees', metadata, - Column('employee_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('engineer_info', String(50)), - Column('type', String(20))) - - Table('reports', metadata, - Column('report_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('employee_id', ForeignKey('employees.employee_id')), - Column('name', String(50)),) + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("manager_data", String(50)), + Column("engineer_info", String(50)), + Column("type", String(20)), + ) + + Table( + "reports", + metadata, + Column( + "report_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("employee_id", ForeignKey("employees.employee_id")), + Column("name", String(50)), + ) @classmethod def setup_classes(cls): @@ -46,27 +61,35 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): @classmethod def setup_mappers(cls): Employee, Manager, JuniorEngineer, employees, Engineer = ( - cls.classes.Employee, cls.classes.Manager, 
cls.classes. - JuniorEngineer, cls.tables.employees, cls.classes.Engineer) + cls.classes.Employee, + cls.classes.Manager, + cls.classes.JuniorEngineer, + cls.tables.employees, + cls.classes.Engineer, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Manager, inherits=Employee, polymorphic_identity='manager') - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(JuniorEngineer, inherits=Engineer, - polymorphic_identity='juniorengineer') + mapper(Manager, inherits=Employee, polymorphic_identity="manager") + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") + mapper( + JuniorEngineer, + inherits=Engineer, + polymorphic_identity="juniorengineer", + ) def _fixture_one(self): Employee, JuniorEngineer, Manager, Engineer = ( self.classes.Employee, self.classes.JuniorEngineer, self.classes.Manager, - self.classes.Engineer) + self.classes.Engineer, + ) session = create_session() - m1 = Manager(name='Tom', manager_data='knows how to manage things') - e1 = Engineer(name='Kurt', engineer_info='knows how to hack') - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + m1 = Manager(name="Tom", manager_data="knows how to manage things") + e1 = Engineer(name="Kurt", engineer_info="knows how to hack") + e2 = JuniorEngineer(name="Ed", engineer_info="oh that ed") session.add_all([m1, e1, e2]) session.flush() return session, m1, e1, e2 @@ -76,7 +99,8 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.classes.Employee, self.classes.JuniorEngineer, self.classes.Manager, - self.classes.Engineer) + self.classes.Engineer, + ) session, m1, e1, e2 = self._fixture_one() @@ -86,50 +110,51 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): assert session.query(JuniorEngineer).all() == [e2] m1 = session.query(Manager).one() - session.expire(m1, ['manager_data']) + session.expire(m1, ["manager_data"]) eq_(m1.manager_data, "knows how to manage things") - row = session.query(Engineer.name, Engineer.employee_id).filter( - Engineer.name == 'Kurt').first() - assert row.name == 'Kurt' + row = ( + session.query(Engineer.name, Engineer.employee_id) + .filter(Engineer.name == "Kurt") + .first() + ) + assert row.name == "Kurt" assert row.employee_id == e1.employee_id def test_multi_qualification(self): JuniorEngineer, Manager, Engineer = ( self.classes.JuniorEngineer, self.classes.Manager, - self.classes.Engineer) + self.classes.Engineer, + ) session, m1, e1, e2 = self._fixture_one() ealias = aliased(Engineer) - eq_( - session.query(Manager, ealias).all(), - [(m1, e1), (m1, e2)] - ) + eq_(session.query(Manager, ealias).all(), [(m1, e1), (m1, e2)]) - eq_( - session.query(Manager.name).all(), - [("Tom",)] - ) + eq_(session.query(Manager.name).all(), [("Tom",)]) eq_( session.query(Manager.name, ealias.name).all(), - [("Tom", "Kurt"), ("Tom", "Ed")] + [("Tom", "Kurt"), ("Tom", "Ed")], ) - eq_(session.query(func.upper(Manager.name), - func.upper(ealias.name)).all(), - [("TOM", "KURT"), ("TOM", "ED")]) + eq_( + session.query( + func.upper(Manager.name), func.upper(ealias.name) + ).all(), + [("TOM", "KURT"), ("TOM", "ED")], + ) eq_( session.query(Manager).add_entity(ealias).all(), - [(m1, e1), (m1, e2)] + [(m1, e1), (m1, e2)], ) eq_( session.query(Manager.name).add_column(ealias.name).all(), - [("Tom", "Kurt"), ("Tom", "Ed")] + [("Tom", "Kurt"), ("Tom", "Ed")], ) # TODO: I think raise error on this for now @@ -144,7 +169,8 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, 
fixtures.MappedTest): self.classes.Employee, self.classes.JuniorEngineer, self.classes.Manager, - self.classes.Engineer) + self.classes.Engineer, + ) session, m1, e1, e2 = self._fixture_one() @@ -153,38 +179,25 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): def scalar(q): return [x for x, in q] - eq_( - scalar(session.query(Employee.employee_id)), - [m1id, e1id, e2id] - ) + eq_(scalar(session.query(Employee.employee_id)), [m1id, e1id, e2id]) - eq_( - scalar(session.query(Engineer.employee_id)), - [e1id, e2id] - ) + eq_(scalar(session.query(Engineer.employee_id)), [e1id, e2id]) - eq_( - scalar(session.query(Manager.employee_id)), [m1id] - ) + eq_(scalar(session.query(Manager.employee_id)), [m1id]) # this currently emits "WHERE type IN (?, ?) AND type IN (?, ?)", # so no result. - eq_( - session.query(Manager.employee_id, Engineer.employee_id).all(), - [] - ) + eq_(session.query(Manager.employee_id, Engineer.employee_id).all(), []) - eq_( - scalar(session.query(JuniorEngineer.employee_id)), - [e2id] - ) + eq_(scalar(session.query(JuniorEngineer.employee_id)), [e2id]) def test_bundle_qualification(self): Employee, JuniorEngineer, Manager, Engineer = ( self.classes.Employee, self.classes.JuniorEngineer, self.classes.Manager, - self.classes.Engineer) + self.classes.Engineer, + ) session, m1, e1, e2 = self._fixture_one() @@ -195,17 +208,15 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): eq_( scalar(session.query(Bundle("name", Employee.employee_id))), - [m1id, e1id, e2id] + [m1id, e1id, e2id], ) eq_( scalar(session.query(Bundle("name", Engineer.employee_id))), - [e1id, e2id] + [e1id, e2id], ) - eq_( - scalar(session.query(Bundle("name", Manager.employee_id))), [m1id] - ) + eq_(scalar(session.query(Bundle("name", Manager.employee_id))), [m1id]) # this currently emits "WHERE type IN (?, ?) AND type IN (?, ?)", # so no result. 
@@ -213,39 +224,41 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): session.query( Bundle("name", Manager.employee_id, Engineer.employee_id) ).all(), - [] + [], ) eq_( scalar(session.query(Bundle("name", JuniorEngineer.employee_id))), - [e2id] + [e2id], ) def test_from_self(self): Engineer = self.classes.Engineer sess = create_session() - self.assert_compile(sess.query(Engineer).from_self(), - 'SELECT anon_1.employees_employee_id AS ' - 'anon_1_employees_employee_id, ' - 'anon_1.employees_name AS ' - 'anon_1_employees_name, ' - 'anon_1.employees_manager_data AS ' - 'anon_1_employees_manager_data, ' - 'anon_1.employees_engineer_info AS ' - 'anon_1_employees_engineer_info, ' - 'anon_1.employees_type AS ' - 'anon_1_employees_type FROM (SELECT ' - 'employees.employee_id AS ' - 'employees_employee_id, employees.name AS ' - 'employees_name, employees.manager_data AS ' - 'employees_manager_data, ' - 'employees.engineer_info AS ' - 'employees_engineer_info, employees.type ' - 'AS employees_type FROM employees WHERE ' - 'employees.type IN (:type_1, :type_2)) AS ' - 'anon_1', - use_default_dialect=True) + self.assert_compile( + sess.query(Engineer).from_self(), + "SELECT anon_1.employees_employee_id AS " + "anon_1_employees_employee_id, " + "anon_1.employees_name AS " + "anon_1_employees_name, " + "anon_1.employees_manager_data AS " + "anon_1_employees_manager_data, " + "anon_1.employees_engineer_info AS " + "anon_1_employees_engineer_info, " + "anon_1.employees_type AS " + "anon_1_employees_type FROM (SELECT " + "employees.employee_id AS " + "employees_employee_id, employees.name AS " + "employees_name, employees.manager_data AS " + "employees_manager_data, " + "employees.engineer_info AS " + "employees_engineer_info, employees.type " + "AS employees_type FROM employees WHERE " + "employees.type IN (:type_1, :type_2)) AS " + "anon_1", + use_default_dialect=True, + ) def test_select_from_aliased_w_subclass(self): Engineer = self.classes.Engineer @@ -261,17 +274,17 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): ) self.assert_compile( - sess.query(literal('1')).select_from(a1), + sess.query(literal("1")).select_from(a1), "SELECT :param_1 AS param_1 FROM employees AS employees_1 " - "WHERE employees_1.type IN (:type_1, :type_2)" + "WHERE employees_1.type IN (:type_1, :type_2)", ) def test_union_modifiers(self): Engineer, Manager = self.classes("Engineer", "Manager") sess = create_session() - q1 = sess.query(Engineer).filter(Engineer.engineer_info == 'foo') - q2 = sess.query(Manager).filter(Manager.manager_data == 'bar') + q1 = sess.query(Engineer).filter(Engineer.engineer_info == "foo") + q2 = sess.query(Manager).filter(Manager.manager_data == "bar") assert_sql = ( "SELECT anon_1.employees_employee_id AS " @@ -294,7 +307,8 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): "employees.engineer_info AS employees_engineer_info, " "employees.type AS employees_type FROM employees " "WHERE employees.manager_data = :manager_data_1 " - "AND employees.type IN (:type_3)) AS anon_1") + "AND employees.type IN (:type_3)) AS anon_1" + ) for meth, token in [ (q1.union, "UNION"), @@ -308,58 +322,61 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): meth(q2), assert_sql % {"token": token}, checkparams={ - 'manager_data_1': 'bar', - 'type_2': 'juniorengineer', - 'type_3': 'manager', - 'engineer_info_1': 'foo', - 'type_1': 'engineer'}, + "manager_data_1": "bar", + "type_2": "juniorengineer", + 
"type_3": "manager", + "engineer_info_1": "foo", + "type_1": "engineer", + }, ) def test_from_self_count(self): Engineer = self.classes.Engineer sess = create_session() - col = func.count(literal_column('*')) + col = func.count(literal_column("*")) self.assert_compile( sess.query(Engineer.employee_id).from_self(col), "SELECT count(*) AS count_1 " "FROM (SELECT employees.employee_id AS employees_employee_id " "FROM employees " "WHERE employees.type IN (:type_1, :type_2)) AS anon_1", - use_default_dialect=True + use_default_dialect=True, ) def test_select_from_count(self): Manager, Engineer = (self.classes.Manager, self.classes.Engineer) sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') - e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + m1 = Manager(name="Tom", manager_data="data1") + e1 = Engineer(name="Kurt", engineer_info="knows how to hack") sess.add_all([m1, e1]) sess.flush() - eq_( - sess.query(func.count(1)).select_from(Manager).all(), - [(1, )] - ) + eq_(sess.query(func.count(1)).select_from(Manager).all(), [(1,)]) def test_select_from_subquery(self): Manager, JuniorEngineer, employees, Engineer = ( self.classes.Manager, self.classes.JuniorEngineer, self.tables.employees, - self.classes.Engineer) + self.classes.Engineer, + ) sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') - m2 = Manager(name='Tom2', manager_data='data2') - e1 = Engineer(name='Kurt', engineer_info='knows how to hack') - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + m1 = Manager(name="Tom", manager_data="data1") + m2 = Manager(name="Tom2", manager_data="data2") + e1 = Engineer(name="Kurt", engineer_info="knows how to hack") + e2 = JuniorEngineer(name="Ed", engineer_info="oh that ed") sess.add_all([m1, m2, e1, e2]) sess.flush() - eq_(sess.query(Manager).select_from( - employees.select().limit(10)).all(), [m1, m2]) + eq_( + sess.query(Manager) + .select_from(employees.select().limit(10)) + .all(), + [m1, m2], + ) def test_count(self): Employee = self.classes.Employee @@ -368,10 +385,10 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): Engineer = self.classes.Engineer sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') - m2 = Manager(name='Tom2', manager_data='data2') - e1 = Engineer(name='Kurt', engineer_info='data3') - e2 = JuniorEngineer(name='marvin', engineer_info='data4') + m1 = Manager(name="Tom", manager_data="data1") + m2 = Manager(name="Tom2", manager_data="data2") + e1 = Engineer(name="Kurt", engineer_info="data3") + e2 = JuniorEngineer(name="marvin", engineer_info="data4") sess.add_all([m1, m2, e1, e2]) sess.flush() @@ -379,8 +396,8 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): eq_(sess.query(Engineer).count(), 2) eq_(sess.query(Employee).count(), 4) - eq_(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) - eq_(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) + eq_(sess.query(Manager).filter(Manager.name.like("%m%")).count(), 2) + eq_(sess.query(Employee).filter(Employee.name.like("%m%")).count(), 3) def test_exists_standalone(self): Engineer = self.classes.Engineer @@ -389,50 +406,63 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.assert_compile( sess.query( - sess.query(Engineer).filter(Engineer.name == 'foo').exists()), + sess.query(Engineer).filter(Engineer.name == "foo").exists() + ), "SELECT EXISTS (SELECT 1 FROM employees WHERE " "employees.name = :name_1 AND 
employees.type " - "IN (:type_1, :type_2)) AS anon_1" + "IN (:type_1, :type_2)) AS anon_1", ) def test_type_filtering(self): - Employee, Manager, reports, Engineer = (self.classes.Employee, - self.classes.Manager, - self.tables.reports, - self.classes.Engineer) + Employee, Manager, reports, Engineer = ( + self.classes.Employee, + self.classes.Manager, + self.tables.reports, + self.classes.Engineer, + ) class Report(fixtures.ComparableEntity): pass - mapper(Report, reports, properties={ - 'employee': relationship(Employee, backref='reports')}) + mapper( + Report, + reports, + properties={"employee": relationship(Employee, backref="reports")}, + ) sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') + m1 = Manager(name="Tom", manager_data="data1") r1 = Report(employee=m1) sess.add_all([m1, r1]) sess.flush() rq = sess.query(Report) - assert len(rq.filter(Report.employee.of_type(Manager).has()) - .all()) == 1 - assert len(rq.filter(Report.employee.of_type(Engineer).has()) - .all()) == 0 + assert ( + len(rq.filter(Report.employee.of_type(Manager).has()).all()) == 1 + ) + assert ( + len(rq.filter(Report.employee.of_type(Engineer).has()).all()) == 0 + ) def test_type_joins(self): - Employee, Manager, reports, Engineer = (self.classes.Employee, - self.classes.Manager, - self.tables.reports, - self.classes.Engineer) + Employee, Manager, reports, Engineer = ( + self.classes.Employee, + self.classes.Manager, + self.tables.reports, + self.classes.Engineer, + ) class Report(fixtures.ComparableEntity): pass - mapper(Report, reports, properties={ - 'employee': relationship(Employee, backref='reports')}) + mapper( + Report, + reports, + properties={"employee": relationship(Employee, backref="reports")}, + ) sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') + m1 = Manager(name="Tom", manager_data="data1") r1 = Report(employee=m1) sess.add_all([m1, r1]) sess.flush() @@ -444,20 +474,29 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): class RelationshipFromSingleTest( - testing.AssertsCompiledSQL, fixtures.MappedTest): + testing.AssertsCompiledSQL, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): - Table('employee', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(20))) - - Table('employee_stuff', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('employee_id', Integer, ForeignKey('employee.id')), - Column('name', String(50))) + Table( + "employee", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("type", String(20)), + ) + + Table( + "employee_stuff", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("employee_id", Integer, ForeignKey("employee.id")), + Column("name", String(50)), + ) @classmethod def setup_classes(cls): @@ -472,56 +511,91 @@ class RelationshipFromSingleTest( def test_subquery_load(self): employee, employee_stuff, Employee, Stuff, Manager = ( - self.tables.employee, self.tables.employee_stuff, self.classes. 
- Employee, self.classes.Stuff, self.classes.Manager) + self.tables.employee, + self.tables.employee_stuff, + self.classes.Employee, + self.classes.Stuff, + self.classes.Manager, + ) - mapper(Employee, employee, polymorphic_on=employee.c.type, - polymorphic_identity='employee') - mapper(Manager, inherits=Employee, polymorphic_identity='manager', - properties={'stuff': relationship(Stuff)}) + mapper( + Employee, + employee, + polymorphic_on=employee.c.type, + polymorphic_identity="employee", + ) + mapper( + Manager, + inherits=Employee, + polymorphic_identity="manager", + properties={"stuff": relationship(Stuff)}, + ) mapper(Stuff, employee_stuff) sess = create_session() - context = sess.query(Manager).options( - subqueryload('stuff'))._compile_context() - subq = context.attributes[('subquery', (class_mapper( - Manager), class_mapper(Manager).attrs.stuff))] - - self.assert_compile(subq, - 'SELECT employee_stuff.id AS ' - 'employee_stuff_id, employee_stuff.employee' - '_id AS employee_stuff_employee_id, ' - 'employee_stuff.name AS ' - 'employee_stuff_name, anon_1.employee_id ' - 'AS anon_1_employee_id FROM (SELECT ' - 'employee.id AS employee_id FROM employee ' - 'WHERE employee.type IN (:type_1)) AS anon_1 ' - 'JOIN employee_stuff ON anon_1.employee_id ' - '= employee_stuff.employee_id ORDER BY ' - 'anon_1.employee_id', - use_default_dialect=True) + context = ( + sess.query(Manager) + .options(subqueryload("stuff")) + ._compile_context() + ) + subq = context.attributes[ + ( + "subquery", + (class_mapper(Manager), class_mapper(Manager).attrs.stuff), + ) + ] + + self.assert_compile( + subq, + "SELECT employee_stuff.id AS " + "employee_stuff_id, employee_stuff.employee" + "_id AS employee_stuff_employee_id, " + "employee_stuff.name AS " + "employee_stuff_name, anon_1.employee_id " + "AS anon_1_employee_id FROM (SELECT " + "employee.id AS employee_id FROM employee " + "WHERE employee.type IN (:type_1)) AS anon_1 " + "JOIN employee_stuff ON anon_1.employee_id " + "= employee_stuff.employee_id ORDER BY " + "anon_1.employee_id", + use_default_dialect=True, + ) class RelationshipToSingleTest( - testing.AssertsCompiledSQL, fixtures.MappedTest): - __dialect__ = 'default' + testing.AssertsCompiledSQL, fixtures.MappedTest +): + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('employees', metadata, - Column('employee_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('engineer_info', String(50)), - Column('type', String(20)), - Column('company_id', Integer, - ForeignKey('companies.company_id'))) - - Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)),) + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("manager_data", String(50)), + Column("engineer_info", String(50)), + Column("type", String(20)), + Column("company_id", Integer, ForeignKey("companies.company_id")), + ) + + Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) @classmethod def setup_classes(cls): @@ -541,64 +615,85 @@ class RelationshipToSingleTest( pass def test_of_type(self): - JuniorEngineer, Company, companies, Manager,\ - Employee, employees, Engineer = (self.classes.JuniorEngineer, - 
self.classes.Company, - self.tables.companies, - self.classes.Manager, - self.classes.Employee, - self.tables.employees, - self.classes.Engineer) - - mapper(Company, companies, properties={ - 'employees': relationship(Employee, backref='company') - }) + JuniorEngineer, Company, companies, Manager, Employee, employees, Engineer = ( + self.classes.JuniorEngineer, + self.classes.Company, + self.tables.companies, + self.classes.Manager, + self.classes.Employee, + self.tables.employees, + self.classes.Engineer, + ) + + mapper( + Company, + companies, + properties={ + "employees": relationship(Employee, backref="company") + }, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Manager, inherits=Employee, polymorphic_identity='manager') - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(JuniorEngineer, inherits=Engineer, - polymorphic_identity='juniorengineer') + mapper(Manager, inherits=Employee, polymorphic_identity="manager") + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") + mapper( + JuniorEngineer, + inherits=Engineer, + polymorphic_identity="juniorengineer", + ) sess = sessionmaker()() - c1 = Company(name='c1') - c2 = Company(name='c2') + c1 = Company(name="c1") + c2 = Company(name="c2") - m1 = Manager(name='Tom', manager_data='data1', company=c1) - m2 = Manager(name='Tom2', manager_data='data2', company=c2) - e1 = Engineer(name='Kurt', engineer_info='knows how to hack', - company=c2) - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1) + m1 = Manager(name="Tom", manager_data="data1", company=c1) + m2 = Manager(name="Tom2", manager_data="data2", company=c2) + e1 = Engineer( + name="Kurt", engineer_info="knows how to hack", company=c2 + ) + e2 = JuniorEngineer(name="Ed", engineer_info="oh that ed", company=c1) sess.add_all([c1, c2, m1, m2, e1, e2]) sess.commit() sess.expunge_all() - eq_(sess.query(Company).filter(Company.employees.of_type( - JuniorEngineer).any()).all(), [Company(name='c1'), ]) + eq_( + sess.query(Company) + .filter(Company.employees.of_type(JuniorEngineer).any()) + .all(), + [Company(name="c1")], + ) - eq_(sess.query(Company).join(Company.employees.of_type( - JuniorEngineer)).all(), [Company(name='c1'), ]) + eq_( + sess.query(Company) + .join(Company.employees.of_type(JuniorEngineer)) + .all(), + [Company(name="c1")], + ) def test_of_type_aliased_fromjoinpoint(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'employee': relationship(Employee) - }) + mapper( + Company, companies, properties={"employee": relationship(Employee)} + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") sess = create_session() self.assert_compile( sess.query(Company).outerjoin( Company.employee.of_type(Engineer), - aliased=True, from_joinpoint=True), + aliased=True, + from_joinpoint=True, + ), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name FROM companies " "LEFT OUTER JOIN employees AS employees_1 ON " "companies.company_id = employees_1.company_id " - "AND employees_1.type IN (:type_1)" + "AND employees_1.type IN (:type_1)", ) def 
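# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). test_of_type depends on
# narrowing a polymorphic relationship to a subtype at query time via
# of_type(). A compact, self-contained version of the pattern, with
# illustrative names in the same classical-mapping style as the fixtures:
from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
from sqlalchemy.orm import Session, mapper, relationship

metadata = MetaData()
companies = Table(
    "companies",
    metadata,
    Column("company_id", Integer, primary_key=True),
)
employees = Table(
    "employees",
    metadata,
    Column("employee_id", Integer, primary_key=True),
    Column("type", String(20)),
    Column("company_id", Integer, ForeignKey("companies.company_id")),
)

class Company(object):
    pass

class Employee(object):
    pass

class Engineer(Employee):
    pass

mapper(
    Company,
    companies,
    properties={"employees": relationship(Employee, backref="company")},
)
mapper(Employee, employees, polymorphic_on=employees.c.type)
mapper(Engineer, inherits=Employee, polymorphic_identity="engineer")

session = Session()
# EXISTS restricted to just the Engineer subtype:
print(
    session.query(Company).filter(
        Company.employees.of_type(Engineer).any()
    )
)
# JOIN through the relationship; the discriminator lands in the ON clause:
print(session.query(Company).join(Company.employees.of_type(Engineer)))
# ---------------------------------------------------------------------------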
test_join_explicit_onclause_no_discriminator(self): @@ -606,37 +701,45 @@ class RelationshipToSingleTest( Company, Employee, Engineer = ( self.classes.Company, self.classes.Employee, - self.classes.Engineer) + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'employees': relationship(Employee) - }) + mapper( + Company, + companies, + properties={"employees": relationship(Employee)}, + ) mapper(Employee, employees) mapper(Engineer, inherits=Employee) sess = create_session() self.assert_compile( sess.query(Company, Engineer.name).join( - Engineer, Company.company_id == Engineer.company_id), + Engineer, Company.company_id == Engineer.company_id + ), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name, " "employees.name AS employees_name " "FROM companies JOIN " - "employees ON companies.company_id = employees.company_id" + "employees ON companies.company_id = employees.company_id", ) def test_outer_join_prop(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") sess = create_session() self.assert_compile( @@ -645,47 +748,59 @@ class RelationshipToSingleTest( "companies.name AS companies_name, " "employees.name AS employees_name " "FROM companies LEFT OUTER JOIN employees ON companies.company_id " - "= employees.company_id AND employees.type IN (:type_1)") + "= employees.company_id AND employees.type IN (:type_1)", + ) def test_outer_join_prop_alias(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") eng_alias = aliased(Engineer) sess = create_session() self.assert_compile( sess.query(Company, eng_alias.name).outerjoin( - eng_alias, Company.engineers), + eng_alias, Company.engineers + ), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name, employees_1.name AS " "employees_1_name FROM companies LEFT OUTER " "JOIN employees AS employees_1 ON companies.company_id " - "= employees_1.company_id AND employees_1.type IN (:type_1)") + "= employees_1.company_id AND employees_1.type IN (:type_1)", + ) def test_outer_join_literal_onclause(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + 
self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") sess = create_session() self.assert_compile( sess.query(Company, Engineer).outerjoin( - Engineer, Company.company_id == Engineer.company_id), + Engineer, Company.company_id == Engineer.company_id + ), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name, " "employees.employee_id AS employees_employee_id, " @@ -696,26 +811,31 @@ class RelationshipToSingleTest( "employees.company_id AS employees_company_id FROM companies " "LEFT OUTER JOIN employees ON " "companies.company_id = employees.company_id " - "AND employees.type IN (:type_1)" + "AND employees.type IN (:type_1)", ) def test_outer_join_literal_onclause_alias(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") eng_alias = aliased(Engineer) sess = create_session() self.assert_compile( sess.query(Company, eng_alias).outerjoin( - eng_alias, Company.company_id == eng_alias.company_id), + eng_alias, Company.company_id == eng_alias.company_id + ), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name, " "employees_1.employee_id AS employees_1_employee_id, " @@ -726,25 +846,28 @@ class RelationshipToSingleTest( "employees_1.company_id AS employees_1_company_id " "FROM companies LEFT OUTER JOIN employees AS employees_1 ON " "companies.company_id = employees_1.company_id " - "AND employees_1.type IN (:type_1)" + "AND employees_1.type IN (:type_1)", ) def test_outer_join_no_onclause(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") sess = create_session() self.assert_compile( - sess.query(Company, Engineer).outerjoin( - Engineer), + sess.query(Company, Engineer).outerjoin(Engineer), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name, " "employees.employee_id AS employees_employee_id, " @@ -755,26 +878,29 @@ class RelationshipToSingleTest( 
"employees.company_id AS employees_company_id " "FROM companies LEFT OUTER JOIN employees ON " "companies.company_id = employees.company_id " - "AND employees.type IN (:type_1)" + "AND employees.type IN (:type_1)", ) def test_outer_join_no_onclause_alias(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") eng_alias = aliased(Engineer) sess = create_session() self.assert_compile( - sess.query(Company, eng_alias).outerjoin( - eng_alias), + sess.query(Company, eng_alias).outerjoin(eng_alias), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name, " "employees_1.employee_id AS employees_1_employee_id, " @@ -785,31 +911,34 @@ class RelationshipToSingleTest( "employees_1.company_id AS employees_1_company_id " "FROM companies LEFT OUTER JOIN employees AS employees_1 ON " "companies.company_id = employees_1.company_id " - "AND employees_1.type IN (:type_1)" + "AND employees_1.type IN (:type_1)", ) def test_correlated_column_select(self): - Company, Employee, Engineer = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer) + Company, Employee, Engineer = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + ) companies, employees = self.tables.companies, self.tables.employees mapper(Company, companies) mapper( - Employee, employees, + Employee, + employees, polymorphic_on=employees.c.type, - properties={ - 'company': relationship(Company) - } + properties={"company": relationship(Company)}, ) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") sess = create_session() - engineer_count = sess.query(func.count(Engineer.employee_id)) \ - .select_from(Engineer) \ - .filter(Engineer.company_id == Company.company_id) \ - .correlate(Company) \ + engineer_count = ( + sess.query(func.count(Engineer.employee_id)) + .select_from(Engineer) + .filter(Engineer.company_id == Company.company_id) + .correlate(Company) .as_scalar() + ) self.assert_compile( sess.query(Company.company_id, engineer_count), @@ -817,47 +946,63 @@ class RelationshipToSingleTest( "(SELECT count(employees.employee_id) AS count_1 " "FROM employees WHERE employees.company_id = " "companies.company_id AND employees.type IN (:type_1)) AS anon_1 " - "FROM companies" + "FROM companies", ) def test_no_aliasing_from_overlap(self): # test [ticket:3233] - Company, Employee, Engineer, Manager = (self.classes.Company, - self.classes.Employee, - self.classes.Engineer, - self.classes.Manager) + Company, Employee, Engineer, Manager = ( + self.classes.Company, + self.classes.Employee, + self.classes.Engineer, + self.classes.Manager, + ) companies, employees = self.tables.companies, self.tables.employees - mapper(Company, companies, properties={ - 'employees': relationship(Employee, backref="company") - }) + mapper( + Company, + companies, + properties={ + "employees": 
relationship(Employee, backref="company") + }, + ) mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") + mapper(Manager, inherits=Employee, polymorphic_identity="manager") s = create_session() - q1 = s.query(Engineer).\ - join(Engineer.company).\ - join(Manager, Company.employees) + q1 = ( + s.query(Engineer) + .join(Engineer.company) + .join(Manager, Company.employees) + ) - q2 = s.query(Engineer).\ - join(Engineer.company).\ - join(Manager, Company.company_id == Manager.company_id) + q2 = ( + s.query(Engineer) + .join(Engineer.company) + .join(Manager, Company.company_id == Manager.company_id) + ) - q3 = s.query(Engineer).\ - join(Engineer.company).\ - join(Manager, Company.employees.of_type(Manager)) + q3 = ( + s.query(Engineer) + .join(Engineer.company) + .join(Manager, Company.employees.of_type(Manager)) + ) - q4 = s.query(Engineer).\ - join(Company, Company.company_id == Engineer.company_id).\ - join(Manager, Company.employees.of_type(Manager)) + q4 = ( + s.query(Engineer) + .join(Company, Company.company_id == Engineer.company_id) + .join(Manager, Company.employees.of_type(Manager)) + ) - q5 = s.query(Engineer).\ - join(Company, Company.company_id == Engineer.company_id).\ - join(Manager, Company.company_id == Manager.company_id) + q5 = ( + s.query(Engineer) + .join(Company, Company.company_id == Engineer.company_id) + .join(Manager, Company.company_id == Manager.company_id) + ) # note that the query is incorrect SQL; we JOIN to # employees twice. However, this is what's expected so we seek @@ -877,38 +1022,49 @@ class RelationshipToSingleTest( "JOIN employees " "ON companies.company_id = employees.company_id " "AND employees.type IN (:type_1) " - "WHERE employees.type IN (:type_2)" + "WHERE employees.type IN (:type_2)", ) def test_relationship_to_subclass(self): - JuniorEngineer, Company, companies, Manager, \ - Employee, employees, Engineer = (self.classes.JuniorEngineer, - self.classes.Company, - self.tables.companies, - self.classes.Manager, - self.classes.Employee, - self.tables.employees, - self.classes.Engineer) - - mapper(Company, companies, properties={ - 'engineers': relationship(Engineer) - }) - mapper(Employee, employees, polymorphic_on=employees.c.type, - properties={'company': relationship(Company)}) - mapper(Manager, inherits=Employee, polymorphic_identity='manager') - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(JuniorEngineer, inherits=Engineer, - polymorphic_identity='juniorengineer') + JuniorEngineer, Company, companies, Manager, Employee, employees, Engineer = ( + self.classes.JuniorEngineer, + self.classes.Company, + self.tables.companies, + self.classes.Manager, + self.classes.Employee, + self.tables.employees, + self.classes.Engineer, + ) + + mapper( + Company, + companies, + properties={"engineers": relationship(Engineer)}, + ) + mapper( + Employee, + employees, + polymorphic_on=employees.c.type, + properties={"company": relationship(Company)}, + ) + mapper(Manager, inherits=Employee, polymorphic_identity="manager") + mapper(Engineer, inherits=Employee, polymorphic_identity="engineer") + mapper( + JuniorEngineer, + inherits=Engineer, + polymorphic_identity="juniorengineer", + ) sess = sessionmaker()() - c1 = Company(name='c1') - c2 = Company(name='c2') + c1 = Company(name="c1") + c2 = Company(name="c2") - m1 = 
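# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). test_correlated_column_
# select builds a correlated scalar subquery with .correlate().as_scalar();
# a self-contained version under the same assumptions (illustrative names,
# 1.x-era as_scalar() API):
from sqlalchemy import (
    Column, ForeignKey, Integer, MetaData, String, Table, func,
)
from sqlalchemy.orm import Session, mapper, relationship

metadata = MetaData()
companies = Table(
    "companies",
    metadata,
    Column("company_id", Integer, primary_key=True),
)
employees = Table(
    "employees",
    metadata,
    Column("employee_id", Integer, primary_key=True),
    Column("type", String(20)),
    Column("company_id", Integer, ForeignKey("companies.company_id")),
)

class Company(object):
    pass

class Employee(object):
    pass

class Engineer(Employee):
    pass

mapper(Company, companies)
mapper(
    Employee,
    employees,
    polymorphic_on=employees.c.type,
    properties={"company": relationship(Company)},
)
mapper(Engineer, inherits=Employee, polymorphic_identity="engineer")

session = Session()
engineer_count = (
    session.query(func.count(Engineer.employee_id))
    .select_from(Engineer)
    .filter(Engineer.company_id == Company.company_id)
    .correlate(Company)
    .as_scalar()
)
# the count renders as a correlated scalar subquery in the columns clause,
# still carrying the single-inheritance "type IN" criterion:
print(session.query(Company.company_id, engineer_count))
# ---------------------------------------------------------------------------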
Manager(name='Tom', manager_data='data1', company=c1) - m2 = Manager(name='Tom2', manager_data='data2', company=c2) - e1 = Engineer(name='Kurt', engineer_info='knows how to hack', - company=c2) - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1) + m1 = Manager(name="Tom", manager_data="data1", company=c1) + m2 = Manager(name="Tom2", manager_data="data2", company=c2) + e1 = Engineer( + name="Kurt", engineer_info="knows how to hack", company=c2 + ) + e2 = JuniorEngineer(name="Ed", engineer_info="oh that ed", company=c1) sess.add_all([c1, c2, m1, m2, e1, e2]) sess.commit() @@ -916,42 +1072,64 @@ class RelationshipToSingleTest( eq_(c2.engineers, [e1]) sess.expunge_all() - eq_(sess.query(Company).order_by(Company.name).all(), + eq_( + sess.query(Company).order_by(Company.name).all(), [ - Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), - Company(name='c2', engineers=[Engineer(name='Kurt')])]) + Company(name="c1", engineers=[JuniorEngineer(name="Ed")]), + Company(name="c2", engineers=[Engineer(name="Kurt")]), + ], + ) # eager load join should limit to only "Engineer" sess.expunge_all() - eq_(sess.query(Company).options(joinedload('engineers')). - order_by(Company.name).all(), - [Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), - Company(name='c2', engineers=[Engineer(name='Kurt')])]) + eq_( + sess.query(Company) + .options(joinedload("engineers")) + .order_by(Company.name) + .all(), + [ + Company(name="c1", engineers=[JuniorEngineer(name="Ed")]), + Company(name="c2", engineers=[Engineer(name="Kurt")]), + ], + ) # join() to Company.engineers, Employee as the requested entity sess.expunge_all() - eq_(sess.query(Company, Employee) + eq_( + sess.query(Company, Employee) .join(Company.engineers) .order_by(Company.name) .all(), - [(Company(name='c1'), JuniorEngineer(name='Ed')), - (Company(name='c2'), Engineer(name='Kurt'))]) + [ + (Company(name="c1"), JuniorEngineer(name="Ed")), + (Company(name="c2"), Engineer(name="Kurt")), + ], + ) # join() to Company.engineers, Engineer as the requested entity. # this actually applies the IN criterion twice which is less than # ideal. sess.expunge_all() - eq_(sess.query(Company, Engineer) + eq_( + sess.query(Company, Engineer) .join(Company.engineers) .order_by(Company.name) .all(), - [(Company(name='c1'), JuniorEngineer(name='Ed')), - (Company(name='c2'), Engineer(name='Kurt'))]) + [ + (Company(name="c1"), JuniorEngineer(name="Ed")), + (Company(name="c2"), Engineer(name="Kurt")), + ], + ) # join() to Company.engineers without any Employee/Engineer entity sess.expunge_all() - eq_(sess.query(Company).join(Company.engineers).filter( - Engineer.name.in_(['Tom', 'Kurt'])).all(), [Company(name='c2')]) + eq_( + sess.query(Company) + .join(Company.engineers) + .filter(Engineer.name.in_(["Tom", "Kurt"])) + .all(), + [Company(name="c2")], + ) # this however fails as it does not limit the subtypes to just # "Engineer". 
with joins constructed by filter(), we seem to be @@ -964,31 +1142,48 @@ class RelationshipToSingleTest( @testing.fails_on_everything_except() def go(): sess.expunge_all() - eq_(sess.query(Company).filter( - Company.company_id == Engineer.company_id).filter( - Engineer.name.in_(['Tom', 'Kurt'])).all(), - [Company(name='c2')]) + eq_( + sess.query(Company) + .filter(Company.company_id == Engineer.company_id) + .filter(Engineer.name.in_(["Tom", "Kurt"])) + .all(), + [Company(name="c2")], + ) + go() class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table('m2m', metadata, - Column('parent_id', Integer, - ForeignKey('parent.id'), primary_key=True), - Column('child_id', Integer, - ForeignKey('child.id'), primary_key=True)) - Table('child', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('discriminator', String(20)), - Column('name', String(20))) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "m2m", + metadata, + Column( + "parent_id", Integer, ForeignKey("parent.id"), primary_key=True + ), + Column( + "child_id", Integer, ForeignKey("child.id"), primary_key=True + ), + ) + Table( + "child", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("discriminator", String(20)), + Column("name", String(20)), + ) @classmethod def setup_classes(cls): @@ -1006,19 +1201,35 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod def setup_mappers(cls): - mapper(cls.classes.Parent, cls.tables.parent, properties={ - "s1": relationship(cls.classes.SubChild1, - secondary=cls.tables.m2m, - uselist=False), - "s2": relationship(cls.classes.SubChild2, - secondary=cls.tables.m2m) - }) - mapper(cls.classes.Child, cls.tables.child, - polymorphic_on=cls.tables.child.c.discriminator) - mapper(cls.classes.SubChild1, inherits=cls.classes.Child, - polymorphic_identity='sub1') - mapper(cls.classes.SubChild2, inherits=cls.classes.Child, - polymorphic_identity='sub2') + mapper( + cls.classes.Parent, + cls.tables.parent, + properties={ + "s1": relationship( + cls.classes.SubChild1, + secondary=cls.tables.m2m, + uselist=False, + ), + "s2": relationship( + cls.classes.SubChild2, secondary=cls.tables.m2m + ), + }, + ) + mapper( + cls.classes.Child, + cls.tables.child, + polymorphic_on=cls.tables.child.c.discriminator, + ) + mapper( + cls.classes.SubChild1, + inherits=cls.classes.Child, + polymorphic_identity="sub1", + ) + mapper( + cls.classes.SubChild2, + inherits=cls.classes.Child, + polymorphic_identity="sub2", + ) @classmethod def insert_data(cls): @@ -1026,9 +1237,14 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): SubChild1 = cls.classes.SubChild1 SubChild2 = cls.classes.SubChild2 s = Session() - s.add_all([ - Parent(s1=SubChild1(name='sc1_1'), - s2=[SubChild2(name="sc2_1"), SubChild2(name="sc2_2")])]) + s.add_all( + [ + Parent( + s1=SubChild1(name="sc1_1"), + s2=[SubChild2(name="sc2_1"), SubChild2(name="sc2_2")], + ) + ] + ) s.commit() def test_eager_join(self): @@ -1038,7 +1254,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): s = Session() p1 = s.query(Parent).options(joinedload(Parent.s1)).all()[0] - eq_(p1.__dict__['s1'], 
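# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). ManyToManyToSingleTest
# checks that when a many-to-many relationship targets a single-inheritance
# subtype, the discriminator criterion moves inside the ON clause of the
# nested secondary join. A minimal version (names are illustrative):
from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
from sqlalchemy.orm import Session, mapper, relationship

metadata = MetaData()
parent = Table(
    "parent", metadata, Column("id", Integer, primary_key=True)
)
m2m = Table(
    "m2m",
    metadata,
    Column("parent_id", Integer, ForeignKey("parent.id"), primary_key=True),
    Column("child_id", Integer, ForeignKey("child.id"), primary_key=True),
)
child = Table(
    "child",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("discriminator", String(20)),
)

class Parent(object):
    pass

class Child(object):
    pass

class SubChild1(Child):
    pass

mapper(
    Parent,
    parent,
    properties={"s1": relationship(SubChild1, secondary=m2m, uselist=False)},
)
mapper(Child, child, polymorphic_on=child.c.discriminator)
mapper(SubChild1, inherits=Child, polymorphic_identity="sub1")

session = Session()
# LEFT OUTER JOIN (m2m JOIN child ON ... AND child.discriminator IN (...)):
print(session.query(Parent).outerjoin(Parent.s1))
# ---------------------------------------------------------------------------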
SubChild1(name='sc1_1')) + eq_(p1.__dict__["s1"], SubChild1(name="sc1_1")) def test_manual_join(self): Parent = self.classes.Parent @@ -1048,7 +1264,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): s = Session() p1, c1 = s.query(Parent, Child).outerjoin(Parent.s1).all()[0] - eq_(c1, SubChild1(name='sc1_1')) + eq_(c1, SubChild1(name="sc1_1")) def test_assert_join_sql(self): Parent = self.classes.Parent @@ -1064,7 +1280,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM parent LEFT OUTER JOIN (m2m AS m2m_1 " "JOIN child ON child.id = m2m_1.child_id " "AND child.discriminator IN (:discriminator_1)) " - "ON parent.id = m2m_1.parent_id" + "ON parent.id = m2m_1.parent_id", ) def test_assert_joinedload_sql(self): @@ -1081,7 +1297,7 @@ class ManyToManyToSingleTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM parent LEFT OUTER JOIN " "(m2m AS m2m_1 JOIN child AS child_1 " "ON child_1.id = m2m_1.child_id AND child_1.discriminator " - "IN (:discriminator_1)) ON parent.id = m2m_1.parent_id" + "IN (:discriminator_1)) ON parent.id = m2m_1.parent_id", ) @@ -1090,18 +1306,31 @@ class SingleOnJoinedTest(fixtures.MappedTest): def define_tables(cls, metadata): global persons_table, employees_table - persons_table = Table('persons', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(20), nullable=False)) + persons_table = Table( + "persons", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(20), nullable=False), + ) - employees_table = Table('employees', metadata, - Column('person_id', Integer, - ForeignKey('persons.person_id'), - primary_key=True), - Column('employee_data', String(50)), - Column('manager_data', String(50)),) + employees_table = Table( + "employees", + metadata, + Column( + "person_id", + Integer, + ForeignKey("persons.person_id"), + primary_key=True, + ), + Column("employee_data", String(50)), + Column("manager_data", String(50)), + ) def test_single_on_joined(self): class Person(fixtures.ComparableEntity): @@ -1113,44 +1342,67 @@ class SingleOnJoinedTest(fixtures.MappedTest): class Manager(Employee): pass - mapper(Person, persons_table, polymorphic_on=persons_table.c.type, - polymorphic_identity='person') - mapper(Employee, employees_table, inherits=Person, - polymorphic_identity='engineer') - mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper( + Person, + persons_table, + polymorphic_on=persons_table.c.type, + polymorphic_identity="person", + ) + mapper( + Employee, + employees_table, + inherits=Person, + polymorphic_identity="engineer", + ) + mapper(Manager, inherits=Employee, polymorphic_identity="manager") sess = create_session() - sess.add(Person(name='p1')) - sess.add(Employee(name='e1', employee_data='ed1')) - sess.add(Manager(name='m1', employee_data='ed2', manager_data='md1')) + sess.add(Person(name="p1")) + sess.add(Employee(name="e1", employee_data="ed1")) + sess.add(Manager(name="m1", employee_data="ed2", manager_data="md1")) sess.flush() sess.expunge_all() - eq_(sess.query(Person).order_by(Person.person_id).all(), [ - Person(name='p1'), - Employee(name='e1', employee_data='ed1'), - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) + eq_( + sess.query(Person).order_by(Person.person_id).all(), + [ + Person(name="p1"), + Employee(name="e1", employee_data="ed1"), + 
Manager(name="m1", employee_data="ed2", manager_data="md1"), + ], + ) sess.expunge_all() - eq_(sess.query(Employee).order_by(Person.person_id).all(), [ - Employee(name='e1', employee_data='ed1'), - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) + eq_( + sess.query(Employee).order_by(Person.person_id).all(), + [ + Employee(name="e1", employee_data="ed1"), + Manager(name="m1", employee_data="ed2", manager_data="md1"), + ], + ) sess.expunge_all() - eq_(sess.query(Manager).order_by(Person.person_id).all(), [ - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) + eq_( + sess.query(Manager).order_by(Person.person_id).all(), + [Manager(name="m1", employee_data="ed2", manager_data="md1")], + ) sess.expunge_all() def go(): - eq_(sess.query(Person).with_polymorphic('*').order_by( - Person.person_id).all(), - [Person(name='p1'), - Employee(name='e1', employee_data='ed1'), - Manager( - name='m1', employee_data='ed2', manager_data='md1')]) + eq_( + sess.query(Person) + .with_polymorphic("*") + .order_by(Person.person_id) + .all(), + [ + Person(name="p1"), + Employee(name="e1", employee_data="ed1"), + Manager( + name="m1", employee_data="ed2", manager_data="md1" + ), + ], + ) + self.assert_sql_count(testing.db, go, 1) @@ -1162,7 +1414,8 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): class Foo(Base): __tablename__ = "foo" id = Column( - Integer, primary_key=True, test_needs_autoincrement=True) + Integer, primary_key=True, test_needs_autoincrement=True + ) type = Column(String(50)) created_at = Column(Integer, server_default="5") @@ -1170,7 +1423,7 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): "polymorphic_on": type, "polymorphic_identity": "foo", "eager_defaults": True, - "with_polymorphic": with_polymorphic + "with_polymorphic": with_polymorphic, } class Bar(Foo): @@ -1178,9 +1431,7 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): if include_sub_defaults: bat = Column(Integer, server_default="10") - __mapper_args__ = { - "polymorphic_identity": "bar", - } + __mapper_args__ = {"polymorphic_identity": "bar"} def test_persist_foo(self): Foo = self.classes.Foo @@ -1191,9 +1442,9 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): session.add(foo) session.flush() - eq_(foo.__dict__['created_at'], 5) + eq_(foo.__dict__["created_at"], 5) - assert 'bat' not in foo.__dict__ + assert "bat" not in foo.__dict__ session.close() @@ -1204,10 +1455,10 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): session.add(bar) session.flush() - eq_(bar.__dict__['created_at'], 5) + eq_(bar.__dict__["created_at"], 5) - if 'bat' in inspect(Bar).attrs: - eq_(bar.__dict__['bat'], 10) + if "bat" in inspect(Bar).attrs: + eq_(bar.__dict__["bat"], 10) session.close() @@ -1216,11 +1467,13 @@ class EagerDefaultEvalTestSubDefaults(EagerDefaultEvalTest): @classmethod def setup_classes(cls): super(EagerDefaultEvalTestSubDefaults, cls).setup_classes( - include_sub_defaults=True) + include_sub_defaults=True + ) class EagerDefaultEvalTestPolymorphic(EagerDefaultEvalTest): @classmethod def setup_classes(cls): super(EagerDefaultEvalTestPolymorphic, cls).setup_classes( - with_polymorphic="*") + with_polymorphic="*" + ) diff --git a/test/orm/inheritance/test_with_poly.py b/test/orm/inheritance/test_with_poly.py index fb6b5a881f..014eb15345 100644 --- a/test/orm/inheritance/test_with_poly.py +++ b/test/orm/inheritance/test_with_poly.py @@ -1,8 +1,19 @@ from sqlalchemy import Integer, String, ForeignKey, func, desc, and_, or_ -from sqlalchemy.orm 
import interfaces, relationship, mapper, \ - clear_mappers, create_session, joinedload, joinedload_all, \ - subqueryload, subqueryload_all, polymorphic_union, aliased,\ - class_mapper, with_polymorphic +from sqlalchemy.orm import ( + interfaces, + relationship, + mapper, + clear_mappers, + create_session, + joinedload, + joinedload_all, + subqueryload, + subqueryload_all, + polymorphic_union, + aliased, + class_mapper, + with_polymorphic, +) from sqlalchemy import exc as sa_exc from sqlalchemy.engine import default @@ -11,10 +22,21 @@ from sqlalchemy import testing from sqlalchemy.testing.schema import Table, Column from sqlalchemy.testing import assert_raises, eq_ -from ._poly_fixtures import Company, Person, Engineer, Manager, Boss, \ - Machine, Paperwork, _PolymorphicFixtureBase, _Polymorphic,\ - _PolymorphicPolymorphic, _PolymorphicUnions, _PolymorphicJoins,\ - _PolymorphicAliasedJoins +from ._poly_fixtures import ( + Company, + Person, + Engineer, + Manager, + Boss, + Machine, + Paperwork, + _PolymorphicFixtureBase, + _Polymorphic, + _PolymorphicPolymorphic, + _PolymorphicUnions, + _PolymorphicJoins, + _PolymorphicAliasedJoins, +) class _WithPolymorphicBase(_PolymorphicFixtureBase): @@ -23,24 +45,33 @@ class _WithPolymorphicBase(_PolymorphicFixtureBase): pa = with_polymorphic(Person, [Engineer]) def go(): - eq_(sess.query(pa) - .filter(pa.Engineer.primary_language == 'java').all(), - self._emps_wo_relationships_fixture()[0:1]) + eq_( + sess.query(pa) + .filter(pa.Engineer.primary_language == "java") + .all(), + self._emps_wo_relationships_fixture()[0:1], + ) + self.assert_sql_count(testing.db, go, 1) def test_col_expression_base_plus_two_subs(self): sess = create_session() pa = with_polymorphic(Person, [Engineer, Manager]) - eq_(sess.query( - pa.name, pa.Engineer.primary_language, - pa.Manager.manager_name).filter( - or_( - pa.Engineer.primary_language == 'java', pa.Manager. 
- manager_name - == 'dogbert')).order_by(pa.Engineer.type).all(), - [('dilbert', 'java', None), - ('dogbert', None, 'dogbert'), ]) + eq_( + sess.query( + pa.name, pa.Engineer.primary_language, pa.Manager.manager_name + ) + .filter( + or_( + pa.Engineer.primary_language == "java", + pa.Manager.manager_name == "dogbert", + ) + ) + .order_by(pa.Engineer.type) + .all(), + [("dilbert", "java", None), ("dogbert", None, "dogbert")], + ) def test_join_to_join_entities(self): sess = create_session() @@ -48,24 +79,29 @@ class _WithPolymorphicBase(_PolymorphicFixtureBase): pa_alias = with_polymorphic(Person, [Engineer], aliased=True) eq_( - [(p1.name, type(p1), p2.name, type(p2)) for (p1, p2) in sess.query( - pa, pa_alias - ).join(pa_alias, - or_( - pa.Engineer.primary_language == - pa_alias.Engineer.primary_language, - and_( - pa.Engineer.primary_language == None, # noqa - pa_alias.Engineer.primary_language == None, - pa.person_id > pa_alias.person_id - )) - ).order_by(pa.name, pa_alias.name)], [ - ('dilbert', Engineer, 'dilbert', Engineer), - ('dogbert', Manager, 'pointy haired boss', Boss), - ('vlad', Engineer, 'vlad', Engineer), - ('wally', Engineer, 'wally', Engineer) - ] + (p1.name, type(p1), p2.name, type(p2)) + for (p1, p2) in sess.query(pa, pa_alias) + .join( + pa_alias, + or_( + pa.Engineer.primary_language + == pa_alias.Engineer.primary_language, + and_( + pa.Engineer.primary_language == None, # noqa + pa_alias.Engineer.primary_language == None, + pa.person_id > pa_alias.person_id, + ), + ), + ) + .order_by(pa.name, pa_alias.name) + ], + [ + ("dilbert", Engineer, "dilbert", Engineer), + ("dogbert", Manager, "pointy haired boss", Boss), + ("vlad", Engineer, "vlad", Engineer), + ("wally", Engineer, "wally", Engineer), + ], ) def test_join_to_join_columns(self): @@ -74,25 +110,34 @@ class _WithPolymorphicBase(_PolymorphicFixtureBase): pa_alias = with_polymorphic(Person, [Engineer], aliased=True) eq_( - [row for row in sess.query( - pa.name, pa.Engineer.primary_language, - pa_alias.name, pa_alias.Engineer.primary_language - ).join(pa_alias, - or_( - pa.Engineer.primary_language == - pa_alias.Engineer.primary_language, - and_( - pa.Engineer.primary_language == None, # noqa - pa_alias.Engineer.primary_language == None, - pa.person_id > pa_alias.person_id - )) - ).order_by(pa.name, pa_alias.name)], [ - ('dilbert', 'java', 'dilbert', 'java'), - ('dogbert', None, 'pointy haired boss', None), - ('vlad', 'cobol', 'vlad', 'cobol'), - ('wally', 'c++', 'wally', 'c++') - ] + row + for row in sess.query( + pa.name, + pa.Engineer.primary_language, + pa_alias.name, + pa_alias.Engineer.primary_language, + ) + .join( + pa_alias, + or_( + pa.Engineer.primary_language + == pa_alias.Engineer.primary_language, + and_( + pa.Engineer.primary_language == None, # noqa + pa_alias.Engineer.primary_language == None, + pa.person_id > pa_alias.person_id, + ), + ), + ) + .order_by(pa.name, pa_alias.name) + ], + [ + ("dilbert", "java", "dilbert", "java"), + ("dogbert", None, "pointy haired boss", None), + ("vlad", "cobol", "vlad", "cobol"), + ("wally", "c++", "wally", "c++"), + ], ) @@ -100,8 +145,9 @@ class PolymorphicTest(_WithPolymorphicBase, _Polymorphic): pass -class PolymorphicPolymorphicTest(_WithPolymorphicBase, - _PolymorphicPolymorphic): +class PolymorphicPolymorphicTest( + _WithPolymorphicBase, _PolymorphicPolymorphic +): pass @@ -109,8 +155,9 @@ class PolymorphicUnionsTest(_WithPolymorphicBase, _PolymorphicUnions): pass -class PolymorphicAliasedJoinsTest(_WithPolymorphicBase, - _PolymorphicAliasedJoins): +class 
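# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). The _WithPolymorphicBase
# tests use with_polymorphic() to load a base entity and chosen subclasses in
# one statement, reaching subclass attributes through a sub-namespace. A
# self-contained joined-inheritance version (illustrative names):
from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
from sqlalchemy.orm import Session, mapper, with_polymorphic

metadata = MetaData()
people = Table(
    "people",
    metadata,
    Column("person_id", Integer, primary_key=True),
    Column("name", String(50)),
    Column("type", String(20)),
)
engineers = Table(
    "engineers",
    metadata,
    Column(
        "person_id", Integer, ForeignKey("people.person_id"),
        primary_key=True,
    ),
    Column("primary_language", String(50)),
)

class Person(object):
    pass

class Engineer(Person):
    pass

mapper(
    Person,
    people,
    polymorphic_on=people.c.type,
    polymorphic_identity="person",
)
mapper(Engineer, engineers, inherits=Person,
       polymorphic_identity="engineer")

session = Session()
pa = with_polymorphic(Person, [Engineer])
# subclass attributes are reached through the .Engineer namespace; the
# statement renders people LEFT OUTER JOIN engineers:
print(
    session.query(pa.name, pa.Engineer.primary_language).filter(
        pa.Engineer.primary_language == "java"
    )
)
# ---------------------------------------------------------------------------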
PolymorphicAliasedJoinsTest( + _WithPolymorphicBase, _PolymorphicAliasedJoins +): pass diff --git a/test/orm/test_association.py b/test/orm/test_association.py index 288e781c60..823f000253 100644 --- a/test/orm/test_association.py +++ b/test/orm/test_association.py @@ -1,4 +1,3 @@ - from sqlalchemy import testing from sqlalchemy import Integer, String, ForeignKey, func, select from sqlalchemy.testing.schema import Table, Column @@ -8,23 +7,40 @@ from sqlalchemy.testing import eq_ class AssociationTest(fixtures.MappedTest): - run_setup_classes = 'once' - run_setup_mappers = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - Table('items', metadata, - Column('item_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40))) - Table('item_keywords', metadata, - Column('item_id', Integer, ForeignKey('items.item_id')), - Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')), - Column('data', String(40))) - Table('keywords', metadata, - Column('keyword_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40))) + Table( + "items", + metadata, + Column( + "item_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(40)), + ) + Table( + "item_keywords", + metadata, + Column("item_id", Integer, ForeignKey("items.item_id")), + Column("keyword_id", Integer, ForeignKey("keywords.keyword_id")), + Column("data", String(40)), + ) + Table( + "keywords", + metadata, + Column( + "keyword_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(40)), + ) @classmethod def setup_classes(cls): @@ -34,7 +50,10 @@ class AssociationTest(fixtures.MappedTest): def __repr__(self): return "Item id=%d name=%s keywordassoc=%r" % ( - self.item_id, self.name, self.keywords) + self.item_id, + self.name, + self.keywords, + ) class Keyword(cls.Basic): def __init__(self, name): @@ -50,41 +69,60 @@ class AssociationTest(fixtures.MappedTest): def __repr__(self): return "KeywordAssociation itemid=%d keyword=%r data=%s" % ( - self.item_id, self.keyword, self.data) + self.item_id, + self.keyword, + self.data, + ) @classmethod def setup_mappers(cls): - KeywordAssociation, Item, Keyword = (cls.classes.KeywordAssociation, - cls.classes.Item, - cls.classes.Keyword) + KeywordAssociation, Item, Keyword = ( + cls.classes.KeywordAssociation, + cls.classes.Item, + cls.classes.Keyword, + ) items, item_keywords, keywords = cls.tables.get_all( - 'items', 'item_keywords', 'keywords') + "items", "item_keywords", "keywords" + ) mapper(Keyword, keywords) - mapper(KeywordAssociation, item_keywords, properties={ - 'keyword': relationship(Keyword, lazy='joined')}, - primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id]) - - mapper(Item, items, properties={ - 'keywords': relationship(KeywordAssociation, - order_by=item_keywords.c.data, - cascade="all, delete-orphan") - }) + mapper( + KeywordAssociation, + item_keywords, + properties={"keyword": relationship(Keyword, lazy="joined")}, + primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], + ) + + mapper( + Item, + items, + properties={ + "keywords": relationship( + KeywordAssociation, + order_by=item_keywords.c.data, + cascade="all, delete-orphan", + ) + }, + ) def test_insert(self): - KeywordAssociation, Item, Keyword = (self.classes.KeywordAssociation, - self.classes.Item, - self.classes.Keyword) + KeywordAssociation, Item, Keyword = ( + 
self.classes.KeywordAssociation, + self.classes.Item, + self.classes.Keyword, + ) sess = create_session() - item1 = Item('item1') - item2 = Item('item2') - item1.keywords.append(KeywordAssociation( - Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - item2.keywords.append(KeywordAssociation( - Keyword('green'), 'green_assoc')) + item1 = Item("item1") + item2 = Item("item2") + item1.keywords.append( + KeywordAssociation(Keyword("blue"), "blue_assoc") + ) + item1.keywords.append(KeywordAssociation(Keyword("red"), "red_assoc")) + item2.keywords.append( + KeywordAssociation(Keyword("green"), "green_assoc") + ) sess.add_all((item1, item2)) sess.flush() saved = repr([item1, item2]) @@ -94,21 +132,24 @@ class AssociationTest(fixtures.MappedTest): eq_(saved, loaded) def test_replace(self): - KeywordAssociation, Item, Keyword = (self.classes.KeywordAssociation, - self.classes.Item, - self.classes.Keyword) + KeywordAssociation, Item, Keyword = ( + self.classes.KeywordAssociation, + self.classes.Item, + self.classes.Keyword, + ) sess = create_session() - item1 = Item('item1') - item1.keywords.append(KeywordAssociation( - Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) + item1 = Item("item1") + item1.keywords.append( + KeywordAssociation(Keyword("blue"), "blue_assoc") + ) + item1.keywords.append(KeywordAssociation(Keyword("red"), "red_assoc")) sess.add(item1) sess.flush() red_keyword = item1.keywords[1].keyword del item1.keywords[1] - item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) + item1.keywords.append(KeywordAssociation(red_keyword, "new_red_assoc")) sess.flush() saved = repr([item1]) sess.expunge_all() @@ -117,32 +158,39 @@ class AssociationTest(fixtures.MappedTest): eq_(saved, loaded) def test_modify(self): - KeywordAssociation, Item, Keyword = (self.classes.KeywordAssociation, - self.classes.Item, - self.classes.Keyword) + KeywordAssociation, Item, Keyword = ( + self.classes.KeywordAssociation, + self.classes.Item, + self.classes.Keyword, + ) sess = create_session() - item1 = Item('item1') - item2 = Item('item2') - item1.keywords.append(KeywordAssociation( - Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - item2.keywords.append(KeywordAssociation( - Keyword('green'), 'green_assoc')) + item1 = Item("item1") + item2 = Item("item2") + item1.keywords.append( + KeywordAssociation(Keyword("blue"), "blue_assoc") + ) + item1.keywords.append(KeywordAssociation(Keyword("red"), "red_assoc")) + item2.keywords.append( + KeywordAssociation(Keyword("green"), "green_assoc") + ) sess.add_all((item1, item2)) sess.flush() red_keyword = item1.keywords[1].keyword del item1.keywords[0] del item1.keywords[0] - purple_keyword = Keyword('purple') - item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) - item2.keywords.append(KeywordAssociation( - purple_keyword, 'purple_item2_assoc')) - item1.keywords.append(KeywordAssociation( - purple_keyword, 'purple_item1_assoc')) - item1.keywords.append(KeywordAssociation( - Keyword('yellow'), 'yellow_assoc')) + purple_keyword = Keyword("purple") + item1.keywords.append(KeywordAssociation(red_keyword, "new_red_assoc")) + item2.keywords.append( + KeywordAssociation(purple_keyword, "purple_item2_assoc") + ) + item1.keywords.append( + KeywordAssociation(purple_keyword, "purple_item1_assoc") + ) + item1.keywords.append( + KeywordAssociation(Keyword("yellow"), "yellow_assoc") 
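# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). AssociationTest uses the
# association-object pattern: the many-to-many link table is mapped to its
# own class so each link can carry extra data. A minimal self-contained
# version of the same mapping (illustrative names):
from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
from sqlalchemy.orm import mapper, relationship

metadata = MetaData()
items = Table(
    "items",
    metadata,
    Column("item_id", Integer, primary_key=True),
    Column("name", String(40)),
)
keywords = Table(
    "keywords",
    metadata,
    Column("keyword_id", Integer, primary_key=True),
    Column("name", String(40)),
)
item_keywords = Table(
    "item_keywords",
    metadata,
    Column("item_id", Integer, ForeignKey("items.item_id")),
    Column("keyword_id", Integer, ForeignKey("keywords.keyword_id")),
    Column("data", String(40)),  # per-link payload lives on the association
)

class Item(object):
    pass

class Keyword(object):
    def __init__(self, name):
        self.name = name

class KeywordAssociation(object):
    def __init__(self, keyword, data):
        self.keyword = keyword
        self.data = data

mapper(Keyword, keywords)
mapper(
    KeywordAssociation,
    item_keywords,
    properties={"keyword": relationship(Keyword, lazy="joined")},
    primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id],
)
mapper(
    Item,
    items,
    properties={
        "keywords": relationship(
            KeywordAssociation, cascade="all, delete-orphan"
        )
    },
)

# links are created through the association class, never directly:
item = Item()
item.keywords.append(KeywordAssociation(Keyword("blue"), "blue_assoc"))
# ---------------------------------------------------------------------------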
+ ) sess.flush() saved = repr([item1, item2]) @@ -158,18 +206,20 @@ class AssociationTest(fixtures.MappedTest): Keyword = self.classes.Keyword sess = create_session() - item1 = Item('item1') - item2 = Item('item2') - item1.keywords.append(KeywordAssociation( - Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - item2.keywords.append(KeywordAssociation( - Keyword('green'), 'green_assoc')) + item1 = Item("item1") + item2 = Item("item2") + item1.keywords.append( + KeywordAssociation(Keyword("blue"), "blue_assoc") + ) + item1.keywords.append(KeywordAssociation(Keyword("red"), "red_assoc")) + item2.keywords.append( + KeywordAssociation(Keyword("green"), "green_assoc") + ) sess.add_all((item1, item2)) sess.flush() - eq_(select([func.count('*')]).select_from(item_keywords).scalar(), 3) + eq_(select([func.count("*")]).select_from(item_keywords).scalar(), 3) sess.delete(item1) sess.delete(item2) sess.flush() - eq_(select([func.count('*')]).select_from(item_keywords).scalar(), 0) + eq_(select([func.count("*")]).select_from(item_keywords).scalar(), 0) diff --git a/test/orm/test_assorted_eager.py b/test/orm/test_assorted_eager.py index affa14c0e9..711dd2a991 100644 --- a/test/orm/test_assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import mapper, relationship, backref, create_session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures + class EagerTest(fixtures.MappedTest): run_deletes = None run_inserts = "once" @@ -24,31 +25,57 @@ class EagerTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('owners', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - - Table('categories', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(20))) - - Table('tests', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('owner_id', Integer, ForeignKey('owners.id'), - nullable=False), - Column('category_id', Integer, ForeignKey('categories.id'), - nullable=False)) - - Table('options', metadata, - Column('test_id', Integer, ForeignKey('tests.id'), - primary_key=True), - Column('owner_id', Integer, ForeignKey('owners.id'), - primary_key=True), - Column('someoption', sa.Boolean, server_default=sa.false(), - nullable=False)) + Table( + "owners", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + + Table( + "categories", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(20)), + ) + + Table( + "tests", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "owner_id", Integer, ForeignKey("owners.id"), nullable=False + ), + Column( + "category_id", + Integer, + ForeignKey("categories.id"), + nullable=False, + ), + ) + + Table( + "options", + metadata, + Column( + "test_id", Integer, ForeignKey("tests.id"), primary_key=True + ), + Column( + "owner_id", Integer, ForeignKey("owners.id"), primary_key=True + ), + Column( + "someoption", + sa.Boolean, + server_default=sa.false(), + nullable=False, + ), + ) @classmethod def setup_classes(cls): @@ -74,50 +101,73 @@ class EagerTest(fixtures.MappedTest): cls.classes.Thing, cls.classes.Owner, cls.tables.options, - cls.tables.categories) + cls.tables.categories, + ) mapper(Owner, 
owners) mapper(Category, categories) - mapper(Option, options, properties=dict( - owner=relationship(Owner, viewonly=True), - test=relationship(Thing, viewonly=True))) + mapper( + Option, + options, + properties=dict( + owner=relationship(Owner, viewonly=True), + test=relationship(Thing, viewonly=True), + ), + ) - mapper(Thing, tests, properties=dict( - owner=relationship(Owner, backref='tests'), - category=relationship(Category), - owner_option=relationship( - Option, primaryjoin=sa.and_( - tests.c.id == options.c.test_id, - tests.c.owner_id == options.c.owner_id), - foreign_keys=[options.c.test_id, options.c.owner_id], - uselist=False))) + mapper( + Thing, + tests, + properties=dict( + owner=relationship(Owner, backref="tests"), + category=relationship(Category), + owner_option=relationship( + Option, + primaryjoin=sa.and_( + tests.c.id == options.c.test_id, + tests.c.owner_id == options.c.owner_id, + ), + foreign_keys=[options.c.test_id, options.c.owner_id], + uselist=False, + ), + ), + ) @classmethod def insert_data(cls): - Owner, Category, Option, Thing = (cls.classes.Owner, - cls.classes.Category, - cls.classes.Option, - cls.classes.Thing) + Owner, Category, Option, Thing = ( + cls.classes.Owner, + cls.classes.Category, + cls.classes.Option, + cls.classes.Thing, + ) session = create_session() o = Owner() - c = Category(name='Some Category') - session.add_all(( - Thing(owner=o, category=c), - Thing(owner=o, category=c, owner_option=Option(someoption=True)), - Thing(owner=o, category=c, owner_option=Option()))) + c = Category(name="Some Category") + session.add_all( + ( + Thing(owner=o, category=c), + Thing( + owner=o, category=c, owner_option=Option(someoption=True) + ), + Thing(owner=o, category=c, owner_option=Option()), + ) + ) session.flush() def test_noorm(self): """test the control case""" - tests, options, categories = (self.tables.tests, - self.tables.options, - self.tables.categories) + tests, options, categories = ( + self.tables.tests, + self.tables.options, + self.tables.categories, + ) # I want to display a list of tests owned by owner 1 # if someoption is false or they haven't specified it yet (null) @@ -130,36 +180,64 @@ class EagerTest(fixtures.MappedTest): # not orm style correct query print("Obtaining correct results without orm") - result = sa.select( - [tests.c.id, categories.c.name], - sa.and_(tests.c.owner_id == 1, - sa.or_(options.c.someoption == None, # noqa - options.c.someoption == False)), - order_by=[tests.c.id], - from_obj=[tests.join(categories).outerjoin(options, sa.and_( - tests.c.id == options.c.test_id, - tests.c.owner_id == options.c.owner_id))] - ).execute().fetchall() - eq_(result, [(1, 'Some Category'), (3, 'Some Category')]) + result = ( + sa.select( + [tests.c.id, categories.c.name], + sa.and_( + tests.c.owner_id == 1, + sa.or_( + options.c.someoption == None, # noqa + options.c.someoption == False, + ), + ), + order_by=[tests.c.id], + from_obj=[ + tests.join(categories).outerjoin( + options, + sa.and_( + tests.c.id == options.c.test_id, + tests.c.owner_id == options.c.owner_id, + ), + ) + ], + ) + .execute() + .fetchall() + ) + eq_(result, [(1, "Some Category"), (3, "Some Category")]) def test_withoutjoinedload(self): - Thing, tests, options = (self.classes.Thing, - self.tables.tests, - self.tables.options) + Thing, tests, options = ( + self.classes.Thing, + self.tables.tests, + self.tables.options, + ) s = create_session() - result = (s.query(Thing) - .select_from(tests.outerjoin( - options, - sa.and_(tests.c.id == options.c.test_id, - 
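# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above). EagerTest's
# "owner_option" relationship joins on a composite, non-FK condition, so it
# must spell out primaryjoin and foreign_keys explicitly. A self-contained
# version of that mapping (illustrative names):
from sqlalchemy import (
    Boolean, Column, ForeignKey, Integer, MetaData, Table, and_,
)
from sqlalchemy.orm import Session, mapper, relationship

metadata = MetaData()
tests = Table(
    "tests",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("owner_id", Integer, nullable=False),
)
options = Table(
    "options",
    metadata,
    Column("test_id", Integer, ForeignKey("tests.id"), primary_key=True),
    Column("owner_id", Integer, primary_key=True),
    Column("someoption", Boolean, nullable=False),
)

class Thing(object):
    pass

class Option(object):
    pass

mapper(Option, options)
mapper(
    Thing,
    tests,
    properties={
        "owner_option": relationship(
            Option,
            # join on both columns, not just the declared FK:
            primaryjoin=and_(
                tests.c.id == options.c.test_id,
                tests.c.owner_id == options.c.owner_id,
            ),
            # both remote columns are treated as the "foreign" side:
            foreign_keys=[options.c.test_id, options.c.owner_id],
            uselist=False,
        )
    },
)

session = Session()
# the composite condition appears verbatim in the ON clause:
print(session.query(Thing).outerjoin(Thing.owner_option))
# ---------------------------------------------------------------------------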
tests.c.owner_id == options.c.owner_id))) - .filter(sa.and_( - tests.c.owner_id == 1, - sa.or_(options.c.someoption == None, # noqa - options.c.someoption == False)))) + result = ( + s.query(Thing) + .select_from( + tests.outerjoin( + options, + sa.and_( + tests.c.id == options.c.test_id, + tests.c.owner_id == options.c.owner_id, + ), + ) + ) + .filter( + sa.and_( + tests.c.owner_id == 1, + sa.or_( + options.c.someoption == None, # noqa + options.c.someoption == False, + ), + ) + ) + ) result_str = ["%d %s" % (t.id, t.category.name) for t in result] - eq_(result_str, ['1 Some Category', '3 Some Category']) + eq_(result_str, ["1 Some Category", "3 Some Category"]) def test_withjoinedload(self): """ @@ -169,90 +247,122 @@ class EagerTest(fixtures.MappedTest): """ - Thing, tests, options = (self.classes.Thing, - self.tables.tests, - self.tables.options) + Thing, tests, options = ( + self.classes.Thing, + self.tables.tests, + self.tables.options, + ) s = create_session() - q = s.query(Thing).options(sa.orm.joinedload('category')) + q = s.query(Thing).options(sa.orm.joinedload("category")) - result = (q.select_from(tests.outerjoin(options, - sa.and_(tests.c.id == - options.c.test_id, - tests.c.owner_id == - options.c.owner_id))). - filter(sa.and_(tests.c.owner_id == 1, - sa.or_(options.c.someoption == None, # noqa - options.c.someoption == False)))) + result = q.select_from( + tests.outerjoin( + options, + sa.and_( + tests.c.id == options.c.test_id, + tests.c.owner_id == options.c.owner_id, + ), + ) + ).filter( + sa.and_( + tests.c.owner_id == 1, + sa.or_( + options.c.someoption == None, # noqa + options.c.someoption == False, + ), + ) + ) result_str = ["%d %s" % (t.id, t.category.name) for t in result] - eq_(result_str, ['1 Some Category', '3 Some Category']) + eq_(result_str, ["1 Some Category", "3 Some Category"]) def test_dslish(self): """test the same as withjoinedload except using generative""" - Thing, tests, options = (self.classes.Thing, - self.tables.tests, - self.tables.options) + Thing, tests, options = ( + self.classes.Thing, + self.tables.tests, + self.tables.options, + ) s = create_session() - q = s.query(Thing).options(sa.orm.joinedload('category')) + q = s.query(Thing).options(sa.orm.joinedload("category")) result = q.filter( - sa.and_(tests.c.owner_id == 1, - sa.or_(options.c.someoption == None, # noqa - options.c.someoption == False)) - ).outerjoin('owner_option') + sa.and_( + tests.c.owner_id == 1, + sa.or_( + options.c.someoption == None, # noqa + options.c.someoption == False, + ), + ) + ).outerjoin("owner_option") result_str = ["%d %s" % (t.id, t.category.name) for t in result] - eq_(result_str, ['1 Some Category', '3 Some Category']) + eq_(result_str, ["1 Some Category", "3 Some Category"]) - @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') + @testing.crashes("sybase", "FIXME: unknown, verify not fails_on") def test_without_outerjoin_literal(self): - Thing, tests= (self.classes.Thing, - self.tables.tests) + Thing, tests = (self.classes.Thing, self.tables.tests) s = create_session() - q = s.query(Thing).options(sa.orm.joinedload('category')) - result = (q.filter( - (tests.c.owner_id == 1) & - text( - 'options.someoption is null or options.someoption=:opt' - ).bindparams(opt=False)).join('owner_option')) + q = s.query(Thing).options(sa.orm.joinedload("category")) + result = q.filter( + (tests.c.owner_id == 1) + & text( + "options.someoption is null or options.someoption=:opt" + ).bindparams(opt=False) + ).join("owner_option") result_str = ["%d %s" % (t.id, 
t.category.name) for t in result] - eq_(result_str, ['3 Some Category']) + eq_(result_str, ["3 Some Category"]) def test_withoutouterjoin(self): - Thing, tests, options = (self.classes.Thing, - self.tables.tests, - self.tables.options) + Thing, tests, options = ( + self.classes.Thing, + self.tables.tests, + self.tables.options, + ) s = create_session() - q = s.query(Thing).options(sa.orm.joinedload('category')) + q = s.query(Thing).options(sa.orm.joinedload("category")) result = q.filter( - (tests.c.owner_id == 1) & - ((options.c.someoption == None) | (options.c.someoption == False)) # noqa - ).join('owner_option') + (tests.c.owner_id == 1) + & ( + (options.c.someoption == None) + | (options.c.someoption == False) + ) # noqa + ).join("owner_option") result_str = ["%d %s" % (t.id, t.category.name) for t in result] - eq_(result_str, ['3 Some Category']) + eq_(result_str, ["3 Some Category"]) class EagerTest2(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('left', metadata, - Column('id', Integer, ForeignKey('middle.id'), primary_key=True), - Column('data', String(50), primary_key=True)) - - Table('middle', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) - - Table('right', metadata, - Column('id', Integer, ForeignKey('middle.id'), primary_key=True), - Column('data', String(50), primary_key=True)) + Table( + "left", + metadata, + Column("id", Integer, ForeignKey("middle.id"), primary_key=True), + Column("data", String(50), primary_key=True), + ) + + Table( + "middle", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + Table( + "right", + metadata, + Column("id", Integer, ForeignKey("middle.id"), primary_key=True), + Column("data", String(50), primary_key=True), + ) @classmethod def setup_classes(cls): @@ -270,23 +380,34 @@ class EagerTest2(fixtures.MappedTest): @classmethod def setup_mappers(cls): - Right, Middle, middle, right, left, Left = (cls.classes.Right, - cls.classes.Middle, - cls.tables.middle, - cls.tables.right, - cls.tables.left, - cls.classes.Left) + Right, Middle, middle, right, left, Left = ( + cls.classes.Right, + cls.classes.Middle, + cls.tables.middle, + cls.tables.right, + cls.tables.left, + cls.classes.Left, + ) # set up bi-directional eager loads mapper(Left, left) mapper(Right, right) - mapper(Middle, middle, properties=dict( - left=relationship(Left, - lazy='joined', - backref=backref('middle', lazy='joined')), - right=relationship(Right, - lazy='joined', - backref=backref('middle', lazy='joined')))), + mapper( + Middle, + middle, + properties=dict( + left=relationship( + Left, + lazy="joined", + backref=backref("middle", lazy="joined"), + ), + right=relationship( + Right, + lazy="joined", + backref=backref("middle", lazy="joined"), + ), + ), + ), def test_eager_terminate(self): """Eager query generation does not include the same mapper's table twice. 
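(Aside, not part of the patch: the EagerTest2 mappers above set up a bi-directional `lazy="joined"` pair, which is exactly the case where eager query generation must terminate rather than chase the backref. A minimal sketch of that pattern, using hypothetical `parent`/`child` tables and `Parent`/`Child` classes rather than anything in this commit, formatted black-style at 79 columns:

    import sqlalchemy as sa
    from sqlalchemy.orm import backref, create_session, mapper, relationship

    metadata = sa.MetaData()

    parent = sa.Table(
        "parent",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
    )

    child = sa.Table(
        "child",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("parent_id", sa.Integer, sa.ForeignKey("parent.id")),
    )

    class Parent(object):
        pass

    class Child(object):
        pass

    mapper(Child, child)
    mapper(
        Parent,
        parent,
        properties=dict(
            children=relationship(
                Child,
                lazy="joined",
                # bi-directional joined eager load, as in EagerTest2
                backref=backref("parent", lazy="joined"),
            )
        ),
    )

    engine = sa.create_engine("sqlite://")
    metadata.create_all(engine)
    session = create_session(engine)
    # emits a single SELECT with one LEFT OUTER JOIN to "child"; the
    # Parent -> children -> parent cycle is detected and cut off instead
    # of including the same mapper's table twice
    session.query(Parent).all()

The diff resumes below with the body of that test.)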
@@ -296,19 +417,21 @@ class EagerTest2(fixtures.MappedTest): """ - Middle, Right, Left = (self.classes.Middle, - self.classes.Right, - self.classes.Left) + Middle, Right, Left = ( + self.classes.Middle, + self.classes.Right, + self.classes.Left, + ) - p = Middle('m1') - p.left.append(Left('l1')) - p.right.append(Right('r1')) + p = Middle("m1") + p.left.append(Left("l1")) + p.right.append(Right("r1")) session = create_session() session.add(p) session.flush() session.expunge_all() - obj = session.query(Left).filter_by(data='l1').one() + obj = session.query(Left).filter_by(data="l1").one() class EagerTest3(fixtures.MappedTest): @@ -317,21 +440,33 @@ class EagerTest3(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('datas', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('a', Integer, nullable=False)) - - Table('foo', metadata, - Column('data_id', Integer, ForeignKey('datas.id'), - primary_key=True), - Column('bar', Integer)) - - Table('stats', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data_id', Integer, ForeignKey('datas.id')), - Column('somedata', Integer, nullable=False)) + Table( + "datas", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a", Integer, nullable=False), + ) + + Table( + "foo", + metadata, + Column( + "data_id", Integer, ForeignKey("datas.id"), primary_key=True + ), + Column("bar", Integer), + ) + + Table( + "stats", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data_id", Integer, ForeignKey("datas.id")), + Column("somedata", Integer, nullable=False), + ) @classmethod def setup_classes(cls): @@ -345,61 +480,76 @@ class EagerTest3(fixtures.MappedTest): pass def test_nesting_with_functions(self): - Stat, Foo, stats, foo, Data, datas = (self.classes.Stat, - self.classes.Foo, - self.tables.stats, - self.tables.foo, - self.classes.Data, - self.tables.datas) + Stat, Foo, stats, foo, Data, datas = ( + self.classes.Stat, + self.classes.Foo, + self.tables.stats, + self.tables.foo, + self.classes.Data, + self.tables.datas, + ) mapper(Data, datas) - mapper(Foo, foo, properties={ - 'data': relationship(Data, backref=backref('foo', uselist=False))}) + mapper( + Foo, + foo, + properties={ + "data": relationship( + Data, backref=backref("foo", uselist=False) + ) + }, + ) - mapper(Stat, stats, properties={ - 'data': relationship(Data)}) + mapper(Stat, stats, properties={"data": relationship(Data)}) session = create_session() data = [Data(a=x) for x in range(5)] session.add_all(data) - session.add_all(( - Stat(data=data[0], somedata=1), - Stat(data=data[1], somedata=2), - Stat(data=data[2], somedata=3), - Stat(data=data[3], somedata=4), - Stat(data=data[4], somedata=5), - Stat(data=data[0], somedata=6), - Stat(data=data[1], somedata=7), - Stat(data=data[2], somedata=8), - Stat(data=data[3], somedata=9), - Stat(data=data[4], somedata=10))) + session.add_all( + ( + Stat(data=data[0], somedata=1), + Stat(data=data[1], somedata=2), + Stat(data=data[2], somedata=3), + Stat(data=data[3], somedata=4), + Stat(data=data[4], somedata=5), + Stat(data=data[0], somedata=6), + Stat(data=data[1], somedata=7), + Stat(data=data[2], somedata=8), + Stat(data=data[3], somedata=9), + Stat(data=data[4], somedata=10), + ) + ) session.flush() arb_data = sa.select( - [stats.c.data_id, sa.func.max(stats.c.somedata).label('max')], + [stats.c.data_id, 
sa.func.max(stats.c.somedata).label("max")], stats.c.data_id <= 5, - group_by=[stats.c.data_id]) + group_by=[stats.c.data_id], + ) arb_result = arb_data.execute().fetchall() # order the result list descending based on 'max' - arb_result.sort(key=lambda a: a['max'], reverse=True) + arb_result.sort(key=lambda a: a["max"], reverse=True) # extract just the "data_id" from it - arb_result = [row['data_id'] for row in arb_result] + arb_result = [row["data_id"] for row in arb_result] - arb_data = arb_data.alias('arb') + arb_data = arb_data.alias("arb") # now query for Data objects using that above select, adding the # "order by max desc" separately - q = (session.query(Data). - options(sa.orm.joinedload('foo')). - select_from(datas.join(arb_data, - arb_data.c.data_id == datas.c.id)). - order_by(sa.desc(arb_data.c.max)). - limit(10)) + q = ( + session.query(Data) + .options(sa.orm.joinedload("foo")) + .select_from( + datas.join(arb_data, arb_data.c.data_id == datas.c.id) + ) + .order_by(sa.desc(arb_data.c.max)) + .limit(10) + ) # extract "data_id" from the list of result objects verify_result = [d.id for d in q] @@ -408,20 +558,36 @@ class EagerTest3(fixtures.MappedTest): class EagerTest4(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('departments', metadata, - Column('department_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - Table('employees', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('department_id', Integer, - ForeignKey('departments.department_id'))) + Table( + "departments", + metadata, + Column( + "department_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + Table( + "employees", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column( + "department_id", + Integer, + ForeignKey("departments.department_id"), + ), + ) @classmethod def setup_classes(cls): @@ -433,32 +599,42 @@ class EagerTest4(fixtures.MappedTest): def test_basic(self): Department, Employee, employees, departments = ( - self.classes.Department, self.classes.Employee, - self.tables.employees, self.tables.departments) + self.classes.Department, + self.classes.Employee, + self.tables.employees, + self.tables.departments, + ) mapper(Employee, employees) - mapper(Department, departments, properties=dict( - employees=relationship(Employee, - lazy='joined', - backref='department'))) - - d1 = Department(name='One') - for e in 'Jim', 'Jack', 'John', 'Susan': + mapper( + Department, + departments, + properties=dict( + employees=relationship( + Employee, lazy="joined", backref="department" + ) + ), + ) + + d1 = Department(name="One") + for e in "Jim", "Jack", "John", "Susan": d1.employees.append(Employee(name=e)) - d2 = Department(name='Two') - for e in 'Joe', 'Bob', 'Mary', 'Wally': + d2 = Department(name="Two") + for e in "Joe", "Bob", "Mary", "Wally": d2.employees.append(Employee(name=e)) sess = create_session() sess.add_all((d1, d2)) sess.flush() - q = (sess.query(Department). - join('employees'). - filter(Employee.name.startswith('J')). - distinct(). 
- order_by(sa.desc(Department.name))) + q = ( + sess.query(Department) + .join("employees") + .filter(Employee.name.startswith("J")) + .distinct() + .order_by(sa.desc(Department.name)) + ) eq_(q.count(), 2) assert q[0] is d2 @@ -470,28 +646,40 @@ class EagerTest5(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('base', metadata, - Column('uid', String(30), primary_key=True), - Column('x', String(30))) - - Table('derived', metadata, - Column('uid', String(30), - ForeignKey('base.uid'), - primary_key=True), - Column('y', String(30))) - - Table('derivedII', metadata, - Column('uid', String(30), - ForeignKey('base.uid'), - primary_key=True), - Column('z', String(30))) - - Table('comments', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('uid', String(30), - ForeignKey('base.uid')), - Column('comment', String(30))) + Table( + "base", + metadata, + Column("uid", String(30), primary_key=True), + Column("x", String(30)), + ) + + Table( + "derived", + metadata, + Column( + "uid", String(30), ForeignKey("base.uid"), primary_key=True + ), + Column("y", String(30)), + ) + + Table( + "derivedII", + metadata, + Column( + "uid", String(30), ForeignKey("base.uid"), primary_key=True + ), + Column("z", String(30)), + ) + + Table( + "comments", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("uid", String(30), ForeignKey("base.uid")), + Column("comment", String(30)), + ) @classmethod def setup_classes(cls): @@ -518,31 +706,38 @@ class EagerTest5(fixtures.MappedTest): self.comment = comment def test_basic(self): - Comment, Derived, derived, comments, \ - DerivedII, Base, base, derivedII = (self.classes.Comment, - self.classes.Derived, - self.tables.derived, - self.tables.comments, - self.classes.DerivedII, - self.classes.Base, - self.tables.base, - self.tables.derivedII) + Comment, Derived, derived, comments, DerivedII, Base, base, derivedII = ( + self.classes.Comment, + self.classes.Derived, + self.tables.derived, + self.tables.comments, + self.classes.DerivedII, + self.classes.Base, + self.tables.base, + self.tables.derivedII, + ) commentMapper = mapper(Comment, comments) - baseMapper = mapper(Base, base, properties=dict( - comments=relationship(Comment, lazy='joined', - cascade='all, delete-orphan'))) + baseMapper = mapper( + Base, + base, + properties=dict( + comments=relationship( + Comment, lazy="joined", cascade="all, delete-orphan" + ) + ), + ) mapper(Derived, derived, inherits=baseMapper) mapper(DerivedII, derivedII, inherits=baseMapper) sess = create_session() - d = Derived('uid1', 'x', 'y') - d.comments = [Comment('uid1', 'comment')] - d2 = DerivedII('uid2', 'xx', 'z') - d2.comments = [Comment('uid2', 'comment')] + d = Derived("uid1", "x", "y") + d.comments = [Comment("uid1", "comment")] + d2 = DerivedII("uid2", "xx", "z") + d2.comments = [Comment("uid2", "comment")] sess.add_all((d, d2)) sess.flush() sess.expunge_all() @@ -550,7 +745,7 @@ class EagerTest5(fixtures.MappedTest): # this eager load sets up an AliasedClauses for the "comment" # relationship, then stores it in clauses_by_lead_mapper[mapper for # Derived] - d = sess.query(Derived).get('uid1') + d = sess.query(Derived).get("uid1") sess.expunge_all() assert len([c for c in d.comments]) == 1 @@ -558,7 +753,7 @@ class EagerTest5(fixtures.MappedTest): # relationship, and should store it in clauses_by_lead_mapper[mapper # for DerivedII]. 
the bug was that the previous AliasedClause create # prevented this population from occurring. - d2 = sess.query(DerivedII).get('uid2') + d2 = sess.query(DerivedII).get("uid2") sess.expunge_all() # object is not in the session; therefore the lazy load cant trigger @@ -567,31 +762,64 @@ class EagerTest5(fixtures.MappedTest): class EagerTest6(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('design_types', metadata, - Column('design_type_id', Integer, primary_key=True, - test_needs_autoincrement=True)) - - Table('design', metadata, - Column('design_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('design_type_id', Integer, - ForeignKey('design_types.design_type_id'))) - - Table('parts', metadata, - Column('part_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('design_id', Integer, ForeignKey('design.design_id')), - Column('design_type_id', Integer, - ForeignKey('design_types.design_type_id'))) - - Table('inherited_part', metadata, - Column('ip_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('part_id', Integer, ForeignKey('parts.part_id')), - Column('design_id', Integer, ForeignKey('design.design_id'))) + Table( + "design_types", + metadata, + Column( + "design_type_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + ) + + Table( + "design", + metadata, + Column( + "design_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column( + "design_type_id", + Integer, + ForeignKey("design_types.design_type_id"), + ), + ) + + Table( + "parts", + metadata, + Column( + "part_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("design_id", Integer, ForeignKey("design.design_id")), + Column( + "design_type_id", + Integer, + ForeignKey("design_types.design_type_id"), + ), + ) + + Table( + "inherited_part", + metadata, + Column( + "ip_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("part_id", Integer, ForeignKey("parts.part_id")), + Column("design_id", Integer, ForeignKey("design.design_id")), + ) @classmethod def setup_classes(cls): @@ -608,35 +836,51 @@ class EagerTest6(fixtures.MappedTest): pass def test_one(self): - Part, inherited_part, design_types, DesignType, \ - parts, design, Design, InheritedPart = (self.classes.Part, - self.tables.inherited_part, - self.tables.design_types, - self.classes.DesignType, - self.tables.parts, - self.tables.design, - self.classes.Design, - self.classes.InheritedPart) + Part, inherited_part, design_types, DesignType, parts, design, Design, InheritedPart = ( + self.classes.Part, + self.tables.inherited_part, + self.tables.design_types, + self.classes.DesignType, + self.tables.parts, + self.tables.design, + self.classes.Design, + self.classes.InheritedPart, + ) p_m = mapper(Part, parts) - mapper(InheritedPart, inherited_part, properties=dict( - part=relationship(Part, lazy='joined'))) - - d_m = mapper(Design, design, properties=dict( - inheritedParts=relationship(InheritedPart, - cascade="all, delete-orphan", - backref="design"))) + mapper( + InheritedPart, + inherited_part, + properties=dict(part=relationship(Part, lazy="joined")), + ) + + d_m = mapper( + Design, + design, + properties=dict( + inheritedParts=relationship( + InheritedPart, + cascade="all, delete-orphan", + backref="design", + ) + ), + ) mapper(DesignType, design_types) d_m.add_property( - "type", relationship(DesignType, lazy='joined', backref="designs")) + "type", 
relationship(DesignType, lazy="joined", backref="designs") + ) p_m.add_property( - "design", relationship( - Design, lazy='joined', - backref=backref("parts", cascade="all, delete-orphan"))) + "design", + relationship( + Design, + lazy="joined", + backref=backref("parts", cascade="all, delete-orphan"), + ), + ) d = Design() sess = create_session() @@ -650,32 +894,57 @@ class EagerTest6(fixtures.MappedTest): class EagerTest7(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_name', String(40))) - - Table('addresses', metadata, - Column('address_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_id', Integer, - ForeignKey("companies.company_id")), - Column('address', String(40))) - - Table('phone_numbers', metadata, - Column('phone_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', Integer, - ForeignKey('addresses.address_id')), - Column('type', String(20)), - Column('number', String(10))) - - Table('invoices', metadata, - Column('invoice_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_id', Integer, - ForeignKey("companies.company_id")), - Column('date', sa.DateTime)) + Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_name", String(40)), + ) + + Table( + "addresses", + metadata, + Column( + "address_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_id", Integer, ForeignKey("companies.company_id")), + Column("address", String(40)), + ) + + Table( + "phone_numbers", + metadata, + Column( + "phone_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("address_id", Integer, ForeignKey("addresses.address_id")), + Column("type", String(20)), + Column("number", String(10)), + ) + + Table( + "invoices", + metadata, + Column( + "invoice_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_id", Integer, ForeignKey("companies.company_id")), + Column("date", sa.DateTime), + ) @classmethod def setup_classes(cls): @@ -701,20 +970,31 @@ class EagerTest7(fixtures.MappedTest): """ addresses, invoices, Company, companies, Invoice, Address = ( - self.tables.addresses, self.tables.invoices, self.classes.Company, - self.tables.companies, self.classes.Invoice, self.classes.Address) + self.tables.addresses, + self.tables.invoices, + self.classes.Company, + self.tables.companies, + self.classes.Invoice, + self.classes.Address, + ) mapper(Address, addresses) - mapper(Company, companies, properties={ - 'addresses': relationship(Address, lazy='joined')}) - - mapper(Invoice, invoices, properties={ - 'company': relationship(Company, lazy='joined')}) + mapper( + Company, + companies, + properties={"addresses": relationship(Address, lazy="joined")}, + ) - a1 = Address(address='a1 address') - a2 = Address(address='a2 address') - c1 = Company(company_name='company 1', addresses=[a1, a2]) + mapper( + Invoice, + invoices, + properties={"company": relationship(Company, lazy="joined")}, + ) + + a1 = Address(address="a1 address") + a2 = Address(address="a2 address") + c1 = Company(company_name="company 1", addresses=[a1, a2]) i1 = Invoice(date=datetime.datetime.now(), company=c1) session = create_session() @@ -733,65 +1013,93 @@ class EagerTest7(fixtures.MappedTest): def 
go(): eq_(c, i.company) eq_(c.addresses, i.company.addresses) + self.assert_sql_count(testing.db, go, 0) class EagerTest8(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('prj', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('created', sa.DateTime), - Column('title', sa.String(100))) - - Table('task', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('status_id', Integer, ForeignKey('task_status.id'), - nullable=False), - Column('title', sa.String(100)), - Column('task_type_id', Integer, ForeignKey('task_type.id'), - nullable=False), - Column('prj_id', Integer, ForeignKey('prj.id'), - nullable=False)) - - Table('task_status', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - - Table('task_type', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - - Table('msg', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('posted', sa.DateTime, index=True,), - Column('type_id', Integer, ForeignKey('msg_type.id')), - Column('task_id', Integer, ForeignKey('task.id'))) - - Table('msg_type', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', sa.String(20)), - Column('display_name', sa.String(20))) + Table( + "prj", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("created", sa.DateTime), + Column("title", sa.String(100)), + ) + + Table( + "task", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "status_id", + Integer, + ForeignKey("task_status.id"), + nullable=False, + ), + Column("title", sa.String(100)), + Column( + "task_type_id", + Integer, + ForeignKey("task_type.id"), + nullable=False, + ), + Column("prj_id", Integer, ForeignKey("prj.id"), nullable=False), + ) + + Table( + "task_status", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + + Table( + "task_type", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + + Table( + "msg", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("posted", sa.DateTime, index=True), + Column("type_id", Integer, ForeignKey("msg_type.id")), + Column("task_id", Integer, ForeignKey("task.id")), + ) + + Table( + "msg_type", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", sa.String(20)), + Column("display_name", sa.String(20)), + ) @classmethod def fixtures(cls): return dict( - prj=(('id',), - (1,)), - - task_status=(('id',), - (1,)), - - task_type=(('id',), - (1,),), - - task=(('title', 'task_type_id', 'status_id', 'prj_id'), - ('task 1', 1, 1, 1))) + prj=(("id",), (1,)), + task_status=(("id",), (1,)), + task_type=(("id",), (1,)), + task=( + ("title", "task_type_id", "status_id", "prj_id"), + ("task 1", 1, 1, 1), + ), + ) @classmethod def setup_classes(cls): @@ -802,12 +1110,14 @@ class EagerTest8(fixtures.MappedTest): pass def test_nested_joins(self): - task, Task_Type, Joined, prj, task_type, msg = (self.tables.task, - self.classes.Task_Type, - self.classes.Joined, - self.tables.prj, - self.tables.task_type, - self.tables.msg) + task, Task_Type, Joined, prj, task_type, msg = ( + self.tables.task, + self.classes.Task_Type, + self.classes.Joined, + self.tables.prj, + 
self.tables.task_type, + self.tables.msg, + ) # this is testing some subtle column resolution stuff, # concerning corresponding_column() being extremely accurate @@ -818,19 +1128,28 @@ class EagerTest8(fixtures.MappedTest): tsk_cnt_join = sa.outerjoin(prj, task, task.c.prj_id == prj.c.id) j = sa.outerjoin(task, msg, task.c.id == msg.c.task_id) - jj = sa.select([task.c.id.label('task_id'), - sa.func.count(msg.c.id).label('props_cnt')], - from_obj=[j], - group_by=[task.c.id]).alias('prop_c_s') + jj = sa.select( + [ + task.c.id.label("task_id"), + sa.func.count(msg.c.id).label("props_cnt"), + ], + from_obj=[j], + group_by=[task.c.id], + ).alias("prop_c_s") jjj = sa.join(task, jj, task.c.id == jj.c.task_id) - mapper(Joined, jjj, properties=dict( - type=relationship(Task_Type, lazy='joined'))) + mapper( + Joined, + jjj, + properties=dict(type=relationship(Task_Type, lazy="joined")), + ) session = create_session() - eq_(session.query(Joined).limit(10).offset(0).one(), - Joined(id=1, title='task 1', props_cnt=0)) + eq_( + session.query(Joined).limit(10).offset(0).one(), + Joined(id=1, title="task 1", props_cnt=0), + ) class EagerTest9(fixtures.MappedTest): @@ -841,25 +1160,50 @@ class EagerTest9(fixtures.MappedTest): throughout the query setup/mapper instances process. """ + @classmethod def define_tables(cls, metadata): - Table('accounts', metadata, - Column('account_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40))) - - Table('transactions', metadata, - Column('transaction_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40))) - - Table('entries', metadata, - Column('entry_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40)), - Column('account_id', Integer, ForeignKey('accounts.account_id')), - Column('transaction_id', Integer, - ForeignKey('transactions.transaction_id'))) + Table( + "accounts", + metadata, + Column( + "account_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(40)), + ) + + Table( + "transactions", + metadata, + Column( + "transaction_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(40)), + ) + + Table( + "entries", + metadata, + Column( + "entry_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(40)), + Column("account_id", Integer, ForeignKey("accounts.account_id")), + Column( + "transaction_id", + Integer, + ForeignKey("transactions.transaction_id"), + ), + ) @classmethod def setup_classes(cls): @@ -875,45 +1219,58 @@ class EagerTest9(fixtures.MappedTest): @classmethod def setup_mappers(cls): Account, Transaction, transactions, accounts, entries, Entry = ( - cls.classes.Account, cls.classes.Transaction, cls.tables. - transactions, cls.tables.accounts, cls.tables.entries, cls.classes. 
- Entry) + cls.classes.Account, + cls.classes.Transaction, + cls.tables.transactions, + cls.tables.accounts, + cls.tables.entries, + cls.classes.Entry, + ) mapper(Account, accounts) mapper(Transaction, transactions) mapper( - Entry, entries, + Entry, + entries, properties=dict( account=relationship( - Account, uselist=False, + Account, + uselist=False, backref=backref( - 'entries', lazy='select', - order_by=entries.c.entry_id)), + "entries", lazy="select", order_by=entries.c.entry_id + ), + ), transaction=relationship( - Transaction, uselist=False, + Transaction, + uselist=False, backref=backref( - 'entries', lazy='joined', - order_by=entries.c.entry_id)))) + "entries", lazy="joined", order_by=entries.c.entry_id + ), + ), + ), + ) def test_joinedload_on_path(self): - Entry, Account, Transaction = (self.classes.Entry, - self.classes.Account, - self.classes.Transaction) + Entry, Account, Transaction = ( + self.classes.Entry, + self.classes.Account, + self.classes.Transaction, + ) session = create_session() - tx1 = Transaction(name='tx1') - tx2 = Transaction(name='tx2') + tx1 = Transaction(name="tx1") + tx2 = Transaction(name="tx2") - acc1 = Account(name='acc1') - ent11 = Entry(name='ent11', account=acc1, transaction=tx1) - ent12 = Entry(name='ent12', account=acc1, transaction=tx2) + acc1 = Account(name="acc1") + ent11 = Entry(name="ent11", account=acc1, transaction=tx1) + ent12 = Entry(name="ent12", account=acc1, transaction=tx2) - acc2 = Account(name='acc2') - ent21 = Entry(name='ent21', account=acc2, transaction=tx1) - ent22 = Entry(name='ent22', account=acc2, transaction=tx2) + acc2 = Account(name="acc2") + ent21 = Entry(name="ent21", account=acc2, transaction=tx1) + ent22 = Entry(name="ent22", account=acc2, transaction=tx2) session.add(acc1) session.flush() @@ -924,15 +1281,20 @@ class EagerTest9(fixtures.MappedTest): # all objects saved thus far, but will not eagerly load the # "accounts" off the immediate "entries"; only the "accounts" off # the entries->transaction->entries - acc = (session.query(Account).options( - sa.orm.joinedload_all( - 'entries.transaction.entries.account')).order_by( - Account.account_id)).first() + acc = ( + session.query(Account) + .options( + sa.orm.joinedload_all( + "entries.transaction.entries.account" + ) + ) + .order_by(Account.account_id) + ).first() # no sql occurs - eq_(acc.name, 'acc1') - eq_(acc.entries[0].transaction.entries[0].account.name, 'acc1') - eq_(acc.entries[0].transaction.entries[1].account.name, 'acc2') + eq_(acc.name, "acc1") + eq_(acc.entries[0].transaction.entries[0].account.name, "acc1") + eq_(acc.entries[0].transaction.entries[1].account.name, "acc2") # lazyload triggers but no sql occurs because many-to-one uses # cached query.get() diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index a830b39b81..9dd3889893 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -3,8 +3,15 @@ from sqlalchemy.orm import attributes, instrumentation, exc as orm_exc from sqlalchemy.orm.collections import collection from sqlalchemy.orm.interfaces import AttributeExtension from sqlalchemy import exc as sa_exc -from sqlalchemy.testing import eq_, ne_, assert_raises, \ - assert_raises_message, is_true, is_false, is_ +from sqlalchemy.testing import ( + eq_, + ne_, + assert_raises, + assert_raises_message, + is_true, + is_false, + is_, +) from sqlalchemy.testing import fixtures from sqlalchemy.testing.util import gc_collect, all_partial_orderings from sqlalchemy.util import jython @@ -20,7 +27,8 @@ MyTest2 = None 
def _set_callable(state, dict_, key, callable_): fn = InstanceState._instance_level_callable_processor( - state.manager, callable_, key) + state.manager, callable_, key + ) fn(state, dict_, None) @@ -31,6 +39,7 @@ class AttributeImplAPITest(fixtures.MappedTest): class B(object): pass + instrumentation.register_class(A) instrumentation.register_class(B) attributes.register_attribute(A, "b", uselist=False, useobject=True) @@ -42,6 +51,7 @@ class AttributeImplAPITest(fixtures.MappedTest): class B(object): pass + instrumentation.register_class(A) instrumentation.register_class(B) attributes.register_attribute(A, "b", uselist=True, useobject=True) @@ -56,7 +66,9 @@ class AttributeImplAPITest(fixtures.MappedTest): A.b.impl.append( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b is b1 @@ -67,7 +79,9 @@ class AttributeImplAPITest(fixtures.MappedTest): "associated with on attribute 'b'", A.b.impl.remove, attributes.instance_state(a1), - attributes.instance_dict(a1), b2, None + attributes.instance_dict(a1), + b2, + None, ) def test_scalar_obj_pop_invalid(self): @@ -79,14 +93,18 @@ class AttributeImplAPITest(fixtures.MappedTest): A.b.impl.append( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b is b1 A.b.impl.pop( attributes.instance_state(a1), - attributes.instance_dict(a1), b2, None + attributes.instance_dict(a1), + b2, + None, ) assert a1.b is b1 @@ -98,14 +116,18 @@ class AttributeImplAPITest(fixtures.MappedTest): A.b.impl.append( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b is b1 A.b.impl.pop( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b is None @@ -118,7 +140,9 @@ class AttributeImplAPITest(fixtures.MappedTest): A.b.impl.append( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b == [b1] @@ -128,7 +152,9 @@ class AttributeImplAPITest(fixtures.MappedTest): r"list.remove\(.*?\): .* not in list", A.b.impl.remove, attributes.instance_state(a1), - attributes.instance_dict(a1), b2, None + attributes.instance_dict(a1), + b2, + None, ) def test_collection_obj_pop_invalid(self): @@ -140,14 +166,18 @@ class AttributeImplAPITest(fixtures.MappedTest): A.b.impl.append( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b == [b1] A.b.impl.pop( attributes.instance_state(a1), - attributes.instance_dict(a1), b2, None + attributes.instance_dict(a1), + b2, + None, ) assert a1.b == [b1] @@ -159,14 +189,18 @@ class AttributeImplAPITest(fixtures.MappedTest): A.b.impl.append( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b == [b1] A.b.impl.pop( attributes.instance_state(a1), - attributes.instance_dict(a1), b1, None + attributes.instance_dict(a1), + b1, + None, ) assert a1.b == [] @@ -190,53 +224,75 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(User) - attributes.register_attribute(User, 'user_id', uselist=False, - useobject=False) - attributes.register_attribute(User, 'user_name', uselist=False, - useobject=False) - attributes.register_attribute(User, 'email_address', - uselist=False, useobject=False) + 
attributes.register_attribute( + User, "user_id", uselist=False, useobject=False + ) + attributes.register_attribute( + User, "user_name", uselist=False, useobject=False + ) + attributes.register_attribute( + User, "email_address", uselist=False, useobject=False + ) u = User() u.user_id = 7 - u.user_name = 'john' - u.email_address = 'lala@123.com' - self.assert_(u.user_id == 7 and u.user_name == 'john' - and u.email_address == 'lala@123.com') + u.user_name = "john" + u.email_address = "lala@123.com" + self.assert_( + u.user_id == 7 + and u.user_name == "john" + and u.email_address == "lala@123.com" + ) attributes.instance_state(u)._commit_all(attributes.instance_dict(u)) - self.assert_(u.user_id == 7 and u.user_name == 'john' - and u.email_address == 'lala@123.com') - u.user_name = 'heythere' - u.email_address = 'foo@bar.com' - self.assert_(u.user_id == 7 and u.user_name == 'heythere' - and u.email_address == 'foo@bar.com') + self.assert_( + u.user_id == 7 + and u.user_name == "john" + and u.email_address == "lala@123.com" + ) + u.user_name = "heythere" + u.email_address = "foo@bar.com" + self.assert_( + u.user_id == 7 + and u.user_name == "heythere" + and u.email_address == "foo@bar.com" + ) def test_pickleness(self): instrumentation.register_class(MyTest) instrumentation.register_class(MyTest2) - attributes.register_attribute(MyTest, 'user_id', uselist=False, - useobject=False) - attributes.register_attribute(MyTest, 'user_name', - uselist=False, useobject=False) - attributes.register_attribute(MyTest, 'email_address', - uselist=False, useobject=False) - attributes.register_attribute(MyTest2, 'a', uselist=False, - useobject=False) - attributes.register_attribute(MyTest2, 'b', uselist=False, - useobject=False) + attributes.register_attribute( + MyTest, "user_id", uselist=False, useobject=False + ) + attributes.register_attribute( + MyTest, "user_name", uselist=False, useobject=False + ) + attributes.register_attribute( + MyTest, "email_address", uselist=False, useobject=False + ) + attributes.register_attribute( + MyTest2, "a", uselist=False, useobject=False + ) + attributes.register_attribute( + MyTest2, "b", uselist=False, useobject=False + ) # shouldn't be pickling callables at the class level def somecallable(state, passive): return None - attributes.register_attribute(MyTest, 'mt2', uselist=True, - trackparent=True, callable_=somecallable, - useobject=True) + attributes.register_attribute( + MyTest, + "mt2", + uselist=True, + trackparent=True, + callable_=somecallable, + useobject=True, + ) o = MyTest() o.mt2.append(MyTest2()) o.user_id = 7 - o.mt2[0].a = 'abcde' + o.mt2[0].a = "abcde" pk_o = pickle.dumps(o) o2 = pickle.loads(pk_o) @@ -258,7 +314,7 @@ class AttributesTest(fixtures.ORMTest): self.assert_(o4.user_name is None) self.assert_(o4.email_address is None) self.assert_(len(o4.mt2) == 1) - self.assert_(o4.mt2[0].a == 'abcde') + self.assert_(o4.mt2[0].a == "abcde") self.assert_(o4.mt2[0].b is None) @testing.requires.predictable_gc @@ -273,7 +329,7 @@ class AttributesTest(fixtures.ORMTest): f = Foo() state = attributes.instance_state(f) f.bar = "foo" - eq_(state.dict, {'bar': 'foo', state.manager.STATE_ATTR: state}) + eq_(state.dict, {"bar": "foo", state.manager.STATE_ATTR: state}) del f gc_collect() assert state.obj() is None @@ -290,17 +346,16 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, - 'bars', - uselist=True, - useobject=True) + attributes.register_attribute( + Foo, 
"bars", uselist=True, useobject=True + ) assert_raises_message( orm_exc.ObjectDereferencedError, "Can't emit change event for attribute " "'Foo.bars' - parent object of type " "has been garbage collected.", - lambda: Foo().bars.append(Bar()) + lambda: Foo().bars.append(Bar()), ) def test_del_scalar_nonobject(self): @@ -308,7 +363,7 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) + attributes.register_attribute(Foo, "b", uselist=False, useobject=False) f1 = Foo() @@ -325,9 +380,7 @@ class AttributesTest(fixtures.ORMTest): del f1.b assert_raises_message( - AttributeError, - "Foo.b object does not have a value", - go + AttributeError, "Foo.b object does not have a value", go ) def test_del_scalar_object(self): @@ -339,7 +392,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'b', uselist=False, useobject=True) + attributes.register_attribute(Foo, "b", uselist=False, useobject=True) f1 = Foo() @@ -354,9 +407,7 @@ class AttributesTest(fixtures.ORMTest): del f1.b assert_raises_message( - AttributeError, - "Foo.b object does not have a value", - go + AttributeError, "Foo.b object does not have a value", go ) def test_del_collection_object(self): @@ -368,7 +419,7 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'b', uselist=True, useobject=True) + attributes.register_attribute(Foo, "b", uselist=True, useobject=True) f1 = Foo() @@ -386,7 +437,7 @@ class AttributesTest(fixtures.ORMTest): class Foo(object): pass - data = {'a': 'this is a', 'b': 12} + data = {"a": "this is a", "b": 12} def loader(state, keys): for k in keys: @@ -396,38 +447,43 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) manager = attributes.manager_of_class(Foo) manager.deferred_scalar_loader = loader - attributes.register_attribute(Foo, 'a', uselist=False, useobject=False) - attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) + attributes.register_attribute(Foo, "a", uselist=False, useobject=False) + attributes.register_attribute(Foo, "b", uselist=False, useobject=False) f = Foo() - attributes.instance_state(f)._expire(attributes.instance_dict(f), - set()) - eq_(f.a, 'this is a') + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set() + ) + eq_(f.a, "this is a") eq_(f.b, 12) - f.a = 'this is some new a' - attributes.instance_state(f)._expire(attributes.instance_dict(f), - set()) - eq_(f.a, 'this is a') + f.a = "this is some new a" + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set() + ) + eq_(f.a, "this is a") eq_(f.b, 12) - attributes.instance_state(f)._expire(attributes.instance_dict(f), - set()) - f.a = 'this is another new a' - eq_(f.a, 'this is another new a') + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set() + ) + f.a = "this is another new a" + eq_(f.a, "this is another new a") eq_(f.b, 12) - attributes.instance_state(f)._expire(attributes.instance_dict(f), - set()) - eq_(f.a, 'this is a') + attributes.instance_state(f)._expire( + attributes.instance_dict(f), set() + ) + eq_(f.a, "this is a") eq_(f.b, 12) del f.a eq_(f.a, None) eq_(f.b, 12) - attributes.instance_state(f)._commit_all(attributes.instance_dict(f), - set()) + attributes.instance_state(f)._commit_all( + 
attributes.instance_dict(f), set() + ) eq_(f.a, None) eq_(f.b, 12) def test_deferred_pickleable(self): - data = {'a': 'this is a', 'b': 12} + data = {"a": "this is a", "b": 12} def loader(state, keys): for k in keys: @@ -437,17 +493,20 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(MyTest) manager = attributes.manager_of_class(MyTest) manager.deferred_scalar_loader = loader - attributes.register_attribute(MyTest, 'a', uselist=False, - useobject=False) - attributes.register_attribute(MyTest, 'b', uselist=False, - useobject=False) + attributes.register_attribute( + MyTest, "a", uselist=False, useobject=False + ) + attributes.register_attribute( + MyTest, "b", uselist=False, useobject=False + ) m = MyTest() - attributes.instance_state(m)._expire(attributes.instance_dict(m), - set()) - assert 'a' not in m.__dict__ + attributes.instance_state(m)._expire( + attributes.instance_dict(m), set() + ) + assert "a" not in m.__dict__ m2 = pickle.loads(pickle.dumps(m)) - assert 'a' not in m2.__dict__ + assert "a" not in m2.__dict__ eq_(m2.a, "this is a") eq_(m2.b, 12) @@ -460,43 +519,58 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(User) instrumentation.register_class(Address) - attributes.register_attribute(User, 'user_id', uselist=False, - useobject=False) - attributes.register_attribute(User, 'user_name', uselist=False, - useobject=False) - attributes.register_attribute(User, 'addresses', uselist=True, - useobject=True) - attributes.register_attribute(Address, 'address_id', - uselist=False, useobject=False) - attributes.register_attribute(Address, 'email_address', - uselist=False, useobject=False) + attributes.register_attribute( + User, "user_id", uselist=False, useobject=False + ) + attributes.register_attribute( + User, "user_name", uselist=False, useobject=False + ) + attributes.register_attribute( + User, "addresses", uselist=True, useobject=True + ) + attributes.register_attribute( + Address, "address_id", uselist=False, useobject=False + ) + attributes.register_attribute( + Address, "email_address", uselist=False, useobject=False + ) u = User() u.user_id = 7 - u.user_name = 'john' + u.user_name = "john" u.addresses = [] a = Address() a.address_id = 10 - a.email_address = 'lala@123.com' + a.email_address = "lala@123.com" u.addresses.append(a) - self.assert_(u.user_id == 7 and u.user_name == 'john' - and u.addresses[0].email_address == 'lala@123.com') - (u, - attributes.instance_state(a)._commit_all(attributes.instance_dict(a))) - self.assert_(u.user_id == 7 and u.user_name == 'john' - and u.addresses[0].email_address == 'lala@123.com') + self.assert_( + u.user_id == 7 + and u.user_name == "john" + and u.addresses[0].email_address == "lala@123.com" + ) + ( + u, + attributes.instance_state(a)._commit_all( + attributes.instance_dict(a) + ), + ) + self.assert_( + u.user_id == 7 + and u.user_name == "john" + and u.addresses[0].email_address == "lala@123.com" + ) - u.user_name = 'heythere' + u.user_name = "heythere" a = Address() a.address_id = 11 - a.email_address = 'foo@bar.com' + a.email_address = "foo@bar.com" u.addresses.append(a) eq_(u.user_id, 7) - eq_(u.user_name, 'heythere') - eq_(u.addresses[0].email_address, 'lala@123.com') - eq_(u.addresses[1].email_address, 'foo@bar.com') + eq_(u.user_name, "heythere") + eq_(u.addresses[0].email_address, "lala@123.com") + eq_(u.addresses[1].email_address, "foo@bar.com") def test_extension_commit_attr(self): """test that an extension which commits attribute history @@ -534,7 +608,7 @@ class 
AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - b1, b2, b3, b4 = Bar(id='b1'), Bar(id='b2'), Bar(id='b3'), Bar(id='b4') + b1, b2, b3, b4 = Bar(id="b1"), Bar(id="b2"), Bar(id="b3"), Bar(id="b4") def loadcollection(state, passive): if passive is attributes.PASSIVE_NO_FETCH: @@ -546,40 +620,50 @@ class AttributesTest(fixtures.ORMTest): return attributes.PASSIVE_NO_RESULT return b2 - attributes.register_attribute(Foo, 'bars', - uselist=True, - useobject=True, - callable_=loadcollection, - extension=[ReceiveEvents('bars')]) + attributes.register_attribute( + Foo, + "bars", + uselist=True, + useobject=True, + callable_=loadcollection, + extension=[ReceiveEvents("bars")], + ) - attributes.register_attribute(Foo, 'bar', - uselist=False, - useobject=True, - callable_=loadscalar, - extension=[ReceiveEvents('bar')]) + attributes.register_attribute( + Foo, + "bar", + uselist=False, + useobject=True, + callable_=loadscalar, + extension=[ReceiveEvents("bar")], + ) - attributes.register_attribute(Foo, 'scalar', - uselist=False, - useobject=False, - extension=[ReceiveEvents('scalar')]) + attributes.register_attribute( + Foo, + "scalar", + uselist=False, + useobject=False, + extension=[ReceiveEvents("scalar")], + ) def create_hist(): def hist(key, fn, *arg): attributes.instance_state(f1)._commit_all( - attributes.instance_dict(f1)) + attributes.instance_dict(f1) + ) fn(*arg) histories.append(attributes.get_history(f1, key)) f1 = Foo() - hist('bars', f1.bars.append, b3) - hist('bars', f1.bars.append, b4) - hist('bars', f1.bars.remove, b2) - hist('bar', setattr, f1, 'bar', b3) - hist('bar', setattr, f1, 'bar', None) - hist('bar', setattr, f1, 'bar', b4) - hist('scalar', setattr, f1, 'scalar', 5) - hist('scalar', setattr, f1, 'scalar', None) - hist('scalar', setattr, f1, 'scalar', 4) + hist("bars", f1.bars.append, b3) + hist("bars", f1.bars.append, b4) + hist("bars", f1.bars.remove, b2) + hist("bar", setattr, f1, "bar", b3) + hist("bar", setattr, f1, "bar", None) + hist("bar", setattr, f1, "bar", b4) + hist("scalar", setattr, f1, "scalar", 5) + hist("scalar", setattr, f1, "scalar", None) + hist("scalar", setattr, f1, "scalar", 4) histories = [] commit = False @@ -624,11 +708,17 @@ class AttributesTest(fixtures.ORMTest): return [bar1, bar2, bar3] - attributes.register_attribute(Foo, 'bars', uselist=True, - callable_=func1, useobject=True, - extension=[ReceiveEvents()]) - attributes.register_attribute(Bar, 'foos', uselist=True, - useobject=True, backref='bars') + attributes.register_attribute( + Foo, + "bars", + uselist=True, + callable_=func1, + useobject=True, + extension=[ReceiveEvents()], + ) + attributes.register_attribute( + Bar, "foos", uselist=True, useobject=True, backref="bars" + ) x = Foo() assert_raises(AssertionError, Bar(id=4).foos.append, x) @@ -637,7 +727,8 @@ class AttributesTest(fixtures.ORMTest): b = Bar(id=4) b.foos.append(x) attributes.instance_state(x)._expire_attributes( - attributes.instance_dict(x), ['bars']) + attributes.instance_dict(x), ["bars"] + ) assert_raises(AssertionError, b.foos.remove, x) def test_scalar_listener(self): @@ -663,19 +754,23 @@ class AttributesTest(fixtures.ORMTest): return child instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'x', uselist=False, useobject=False, - extension=ReceiveEvents()) + attributes.register_attribute( + Foo, "x", uselist=False, useobject=False, extension=ReceiveEvents() + ) f = Foo() f.x = 5 f.x = 17 del f.x - eq_(results, [ - ('set', f, 5, 
attributes.NEVER_SET), - ('set', f, 17, 5), - ('remove', f, 17), - ]) + eq_( + results, + [ + ("set", f, 5, attributes.NEVER_SET), + ("set", f, 17, 5), + ("remove", f, 17), + ], + ) def test_lazytrackparent(self): """test that the "hasparent" flag works properly @@ -688,33 +783,51 @@ class AttributesTest(fixtures.ORMTest): class Blog(object): pass + instrumentation.register_class(Post) instrumentation.register_class(Blog) # set up instrumented attributes with backrefs - attributes.register_attribute(Post, 'blog', uselist=False, - backref='posts', - trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, - backref='blog', - trackparent=True, useobject=True) + attributes.register_attribute( + Post, + "blog", + uselist=False, + backref="posts", + trackparent=True, + useobject=True, + ) + attributes.register_attribute( + Blog, + "posts", + uselist=True, + backref="blog", + trackparent=True, + useobject=True, + ) # create objects as if they'd been freshly loaded from the database # (without history) b = Blog() p1 = Post() - _set_callable(attributes.instance_state(b), - attributes.instance_dict(b), - 'posts', lambda state, passive: [p1]) - _set_callable(attributes.instance_state(p1), - attributes.instance_dict(p1), - 'blog', lambda state, passive: b) + _set_callable( + attributes.instance_state(b), + attributes.instance_dict(b), + "posts", + lambda state, passive: [p1], + ) + _set_callable( + attributes.instance_state(p1), + attributes.instance_dict(p1), + "blog", + lambda state, passive: b, + ) p1, attributes.instance_state(b)._commit_all( - attributes.instance_dict(b)) + attributes.instance_dict(b) + ) # no orphans (called before the lazy loaders fire off) - assert attributes.has_parent(Blog, p1, 'posts', optimistic=True) - assert attributes.has_parent(Post, b, 'blog', optimistic=True) + assert attributes.has_parent(Blog, p1, "posts", optimistic=True) + assert attributes.has_parent(Post, b, "blog", optimistic=True) # assert connections assert p1.blog is b @@ -724,8 +837,8 @@ class AttributesTest(fixtures.ORMTest): b2 = Blog() p2 = Post() b2.posts.append(p2) - assert attributes.has_parent(Blog, p2, 'posts') - assert attributes.has_parent(Post, b2, 'blog') + assert attributes.has_parent(Blog, p2, "posts") + assert attributes.has_parent(Post, b2, "blog") def test_illegal_trackparent(self): class Post(object): @@ -733,19 +846,26 @@ class AttributesTest(fixtures.ORMTest): class Blog(object): pass + instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute(Post, 'blog', useobject=True) + attributes.register_attribute(Post, "blog", useobject=True) assert_raises_message( AssertionError, "This AttributeImpl is not configured to track parents.", - attributes.has_parent, Post, Blog(), 'blog' + attributes.has_parent, + Post, + Blog(), + "blog", ) assert_raises_message( AssertionError, "This AttributeImpl is not configured to track parents.", - Post.blog.impl.sethasparent, "x", "x", True + Post.blog.impl.sethasparent, + "x", + "x", + True, ) def test_inheritance(self): @@ -768,19 +888,23 @@ class AttributesTest(fixtures.ORMTest): def func3(state, passive): return "this is the shared attr" - attributes.register_attribute(Foo, 'element', uselist=False, - callable_=func1, useobject=True) - attributes.register_attribute(Foo, 'element2', uselist=False, - callable_=func3, useobject=True) - attributes.register_attribute(Bar, 'element', uselist=False, - callable_=func2, useobject=True) + + attributes.register_attribute( + Foo, 
"element", uselist=False, callable_=func1, useobject=True + ) + attributes.register_attribute( + Foo, "element2", uselist=False, callable_=func3, useobject=True + ) + attributes.register_attribute( + Bar, "element", uselist=False, callable_=func2, useobject=True + ) x = Foo() y = Bar() - assert x.element == 'this is the foo attr' - assert y.element == 'this is the bar attr' - assert x.element2 == 'this is the shared attr' - assert y.element2 == 'this is the shared attr' + assert x.element == "this is the foo attr" + assert y.element == "this is the bar attr" + assert x.element2 == "this is the shared attr" + assert y.element2 == "this is the shared attr" def test_no_double_state(self): states = set() @@ -817,17 +941,22 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'element', uselist=False, - useobject=True) + attributes.register_attribute( + Foo, "element", uselist=False, useobject=True + ) el = Element() x = Bar() x.element = el - eq_(attributes.get_state_history(attributes.instance_state(x), - 'element'), ([el], (), ())) + eq_( + attributes.get_state_history( + attributes.instance_state(x), "element" + ), + ([el], (), ()), + ) attributes.instance_state(x)._commit_all(attributes.instance_dict(x)) - added, unchanged, deleted = \ - attributes.get_state_history(attributes.instance_state(x), - 'element') + added, unchanged, deleted = attributes.get_state_history( + attributes.instance_state(x), "element" + ) assert added == () assert unchanged == [el] @@ -842,26 +971,28 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), - Bar(id=4)] + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] def func1(state, passive): - return 'this is func 1' + return "this is func 1" def func2(state, passive): return [bar1, bar2, bar3] - attributes.register_attribute(Foo, 'col1', uselist=False, - callable_=func1, useobject=True) - attributes.register_attribute(Foo, 'col2', uselist=True, - callable_=func2, useobject=True) - attributes.register_attribute(Bar, 'id', uselist=False, - useobject=True) + attributes.register_attribute( + Foo, "col1", uselist=False, callable_=func1, useobject=True + ) + attributes.register_attribute( + Foo, "col2", uselist=True, callable_=func2, useobject=True + ) + attributes.register_attribute(Bar, "id", uselist=False, useobject=True) x = Foo() attributes.instance_state(x)._commit_all(attributes.instance_dict(x)) x.col2.append(bar4) - eq_(attributes.get_state_history(attributes.instance_state(x), 'col2'), - ([bar4], [bar1, bar2, bar3], [])) + eq_( + attributes.get_state_history(attributes.instance_state(x), "col2"), + ([bar4], [bar1, bar2, bar3], []), + ) def test_parenttrack(self): class Foo(object): @@ -872,22 +1003,24 @@ class AttributesTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'element', uselist=False, - trackparent=True, useobject=True) - attributes.register_attribute(Bar, 'element', uselist=False, - trackparent=True, useobject=True) + attributes.register_attribute( + Foo, "element", uselist=False, trackparent=True, useobject=True + ) + attributes.register_attribute( + Bar, "element", uselist=False, trackparent=True, useobject=True + ) f1 = Foo() f2 = Foo() b1 = Bar() b2 = Bar() f1.element = b1 b2.element = f2 - assert attributes.has_parent(Foo, 
b1, 'element') - assert not attributes.has_parent(Foo, b2, 'element') - assert not attributes.has_parent(Foo, f2, 'element') - assert attributes.has_parent(Bar, f2, 'element') + assert attributes.has_parent(Foo, b1, "element") + assert not attributes.has_parent(Foo, b2, "element") + assert not attributes.has_parent(Foo, f2, "element") + assert attributes.has_parent(Bar, f2, "element") b2.element = None - assert not attributes.has_parent(Bar, f2, 'element') + assert not attributes.has_parent(Bar, f2, "element") # test that double assignment doesn't accidentally reset the # 'parent' flag. @@ -895,9 +1028,9 @@ class AttributesTest(fixtures.ORMTest): b3 = Bar() f4 = Foo() b3.element = f4 - assert attributes.has_parent(Bar, f4, 'element') + assert attributes.has_parent(Bar, f4, "element") b3.element = f4 - assert attributes.has_parent(Bar, f4, 'element') + assert attributes.has_parent(Bar, f4, "element") def test_descriptorattributes(self): """changeset: 1633 broke ability to use ORM to map classes with @@ -906,9 +1039,8 @@ class AttributesTest(fixtures.ORMTest): simple regression test to prevent that defect. """ class des(object): - def __get__(self, instance, owner): - raise AttributeError('fake attribute') + raise AttributeError("fake attribute") class Foo(object): A = des() @@ -917,31 +1049,35 @@ class AttributesTest(fixtures.ORMTest): instrumentation.unregister_class(Foo) def test_collectionclasses(self): - class Foo(object): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'collection', uselist=True, - typecallable=set, useobject=True) - assert attributes.manager_of_class(Foo).is_instrumented('collection' - ) + attributes.register_attribute( + Foo, "collection", uselist=True, typecallable=set, useobject=True + ) + assert attributes.manager_of_class(Foo).is_instrumented("collection") assert isinstance(Foo().collection, set) - attributes.unregister_attribute(Foo, 'collection') - assert not attributes.manager_of_class(Foo) \ - .is_instrumented('collection') + attributes.unregister_attribute(Foo, "collection") + assert not attributes.manager_of_class(Foo).is_instrumented( + "collection" + ) try: - attributes.register_attribute(Foo, 'collection', - uselist=True, typecallable=dict, - useobject=True) + attributes.register_attribute( + Foo, + "collection", + uselist=True, + typecallable=dict, + useobject=True, + ) assert False except sa_exc.ArgumentError as e: - assert str(e) \ - == 'Type InstrumentedDict must elect an appender '\ - 'method to be a collection class' + assert ( + str(e) == "Type InstrumentedDict must elect an appender " + "method to be a collection class" + ) class MyDict(dict): - @collection.appender def append(self, item): self[item.foo] = item @@ -950,26 +1086,35 @@ class AttributesTest(fixtures.ORMTest): def remove(self, item): del self[item.foo] - attributes.register_attribute(Foo, 'collection', uselist=True, - typecallable=MyDict, useobject=True) + attributes.register_attribute( + Foo, + "collection", + uselist=True, + typecallable=MyDict, + useobject=True, + ) assert isinstance(Foo().collection, MyDict) - attributes.unregister_attribute(Foo, 'collection') + attributes.unregister_attribute(Foo, "collection") class MyColl(object): pass try: - attributes.register_attribute(Foo, 'collection', - uselist=True, typecallable=MyColl, - useobject=True) + attributes.register_attribute( + Foo, + "collection", + uselist=True, + typecallable=MyColl, + useobject=True, + ) assert False except sa_exc.ArgumentError as e: - assert str(e) \ - == 'Type MyColl must 
elect an appender method to be a '\ - 'collection class' + assert ( + str(e) == "Type MyColl must elect an appender method to be a " + "collection class" + ) class MyColl(object): - @collection.iterator def __iter__(self): return iter([]) @@ -982,8 +1127,13 @@ class AttributesTest(fixtures.ORMTest): def remove(self, item): pass - attributes.register_attribute(Foo, 'collection', uselist=True, - typecallable=MyColl, useobject=True) + attributes.register_attribute( + Foo, + "collection", + uselist=True, + typecallable=MyColl, + useobject=True, + ) try: Foo().collection assert True @@ -995,52 +1145,46 @@ class AttributesTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'a', useobject=False) - attributes.register_attribute(Foo, 'b', useobject=False) - attributes.register_attribute(Foo, 'c', useobject=False) + attributes.register_attribute(Foo, "a", useobject=False) + attributes.register_attribute(Foo, "b", useobject=False) + attributes.register_attribute(Foo, "c", useobject=False) f1 = Foo() state = attributes.instance_state(f1) - f1.a = 'a1' - f1.b = 'b1' - f1.c = 'c1' + f1.a = "a1" + f1.b = "b1" + f1.c = "c1" assert not state._last_known_values - state._track_last_known_value('b') - state._track_last_known_value('c') + state._track_last_known_value("b") + state._track_last_known_value("c") eq_( state._last_known_values, - {'b': attributes.NO_VALUE, 'c': attributes.NO_VALUE}) + {"b": attributes.NO_VALUE, "c": attributes.NO_VALUE}, + ) - state._expire_attributes(state.dict, ['b']) - eq_( - state._last_known_values, - {'b': 'b1', 'c': attributes.NO_VALUE}) + state._expire_attributes(state.dict, ["b"]) + eq_(state._last_known_values, {"b": "b1", "c": attributes.NO_VALUE}) state._expire(state.dict, set()) - eq_( - state._last_known_values, - {'b': 'b1', 'c': 'c1'}) + eq_(state._last_known_values, {"b": "b1", "c": "c1"}) - f1.b = 'b2' + f1.b = "b2" - eq_( - state._last_known_values, - {'b': attributes.NO_VALUE, 'c': 'c1'}) + eq_(state._last_known_values, {"b": attributes.NO_VALUE, "c": "c1"}) - f1.c = 'c2' + f1.c = "c2" eq_( state._last_known_values, - {'b': attributes.NO_VALUE, 'c': attributes.NO_VALUE}) + {"b": attributes.NO_VALUE, "c": attributes.NO_VALUE}, + ) state._expire(state.dict, set()) - eq_( - state._last_known_values, - {'b': 'b2', 'c': 'c2'}) + eq_(state._last_known_values, {"b": "b2", "c": "c2"}) class GetNoValueTest(fixtures.ORMTest): @@ -1057,60 +1201,64 @@ class GetNoValueTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) if expected is not None: - attributes.register_attribute(Foo, - "attr", useobject=True, - uselist=False, - callable_=lazy_callable) + attributes.register_attribute( + Foo, + "attr", + useobject=True, + uselist=False, + callable_=lazy_callable, + ) else: - attributes.register_attribute(Foo, - "attr", useobject=True, - uselist=False) + attributes.register_attribute( + Foo, "attr", useobject=True, uselist=False + ) f1 = self.f1 = Foo() - return Foo.attr.impl,\ - attributes.instance_state(f1), \ - attributes.instance_dict(f1) + return ( + Foo.attr.impl, + attributes.instance_state(f1), + attributes.instance_dict(f1), + ) def test_passive_no_result(self): attr, state, dict_ = self._fixture(attributes.PASSIVE_NO_RESULT) eq_( attr.get(state, dict_, passive=attributes.PASSIVE_NO_INITIALIZE), - attributes.PASSIVE_NO_RESULT + attributes.PASSIVE_NO_RESULT, ) def test_passive_no_result_never_set(self): attr, state, dict_ = self._fixture(attributes.NEVER_SET) eq_( attr.get(state, 
dict_, passive=attributes.PASSIVE_NO_INITIALIZE), - attributes.PASSIVE_NO_RESULT + attributes.PASSIVE_NO_RESULT, ) - assert 'attr' not in dict_ + assert "attr" not in dict_ def test_passive_ret_never_set_never_set(self): attr, state, dict_ = self._fixture(attributes.NEVER_SET) eq_( - attr.get(state, dict_, - passive=attributes.PASSIVE_RETURN_NEVER_SET), - attributes.NEVER_SET + attr.get( + state, dict_, passive=attributes.PASSIVE_RETURN_NEVER_SET + ), + attributes.NEVER_SET, ) - assert 'attr' not in dict_ + assert "attr" not in dict_ def test_passive_ret_never_set_empty(self): attr, state, dict_ = self._fixture(None) eq_( - attr.get(state, dict_, - passive=attributes.PASSIVE_RETURN_NEVER_SET), - attributes.NEVER_SET + attr.get( + state, dict_, passive=attributes.PASSIVE_RETURN_NEVER_SET + ), + attributes.NEVER_SET, ) - assert 'attr' not in dict_ + assert "attr" not in dict_ def test_off_empty(self): attr, state, dict_ = self._fixture(None) - eq_( - attr.get(state, dict_, passive=attributes.PASSIVE_OFF), - None - ) - assert 'attr' not in dict_ + eq_(attr.get(state, dict_, passive=attributes.PASSIVE_OFF), None) + assert "attr" not in dict_ class UtilTest(fixtures.ORMTest): @@ -1124,7 +1272,8 @@ class UtilTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) attributes.register_attribute( - Foo, "coll", uselist=True, useobject=True) + Foo, "coll", uselist=True, useobject=True + ) f1 = Foo() b1 = Bar() @@ -1150,10 +1299,8 @@ class UtilTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute( - Foo, "a", uselist=False, useobject=False) - attributes.register_attribute( - Bar, "b", uselist=False, useobject=False) + attributes.register_attribute(Foo, "a", uselist=False, useobject=False) + attributes.register_attribute(Bar, "b", uselist=False, useobject=False) @event.listens_for(Foo.a, "set") def sync_a(target, value, oldvalue, initiator): @@ -1172,14 +1319,13 @@ class UtilTest(fixtures.ORMTest): f1.bar = b1 b1.foo = f1 - f1.a = 'x' - eq_(b1.b, 'x') - b1.b = 'y' - eq_(f1.a, 'y') + f1.a = "x" + eq_(b1.b, "x") + b1.b = "y" + eq_(f1.a, "y") class BackrefTest(fixtures.ORMTest): - def test_m2m(self): class Student(object): pass @@ -1189,10 +1335,16 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Student) instrumentation.register_class(Course) - attributes.register_attribute(Student, 'courses', uselist=True, - backref="students", useobject=True) - attributes.register_attribute(Course, 'students', uselist=True, - backref="courses", useobject=True) + attributes.register_attribute( + Student, + "courses", + uselist=True, + backref="students", + useobject=True, + ) + attributes.register_attribute( + Course, "students", uselist=True, backref="courses", useobject=True + ) s = Student() c = Course() @@ -1218,12 +1370,22 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, - backref='posts', - trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, - backref='blog', - trackparent=True, useobject=True) + attributes.register_attribute( + Post, + "blog", + uselist=False, + backref="posts", + trackparent=True, + useobject=True, + ) + attributes.register_attribute( + Blog, + "posts", + uselist=True, + backref="blog", + trackparent=True, + useobject=True, + ) b = Blog() (p1, p2, p3) = (Post(), Post(), Post()) 
b.posts.append(p1) @@ -1253,14 +1415,17 @@ class BackrefTest(fixtures.ORMTest): class Jack(object): pass + instrumentation.register_class(Port) instrumentation.register_class(Jack) - attributes.register_attribute(Port, 'jack', uselist=False, - useobject=True, backref="port") + attributes.register_attribute( + Port, "jack", uselist=False, useobject=True, backref="port" + ) - attributes.register_attribute(Jack, 'port', uselist=False, - useobject=True, backref="jack") + attributes.register_attribute( + Jack, "port", uselist=False, useobject=True, backref="jack" + ) p = Port() j = Jack() @@ -1294,19 +1459,30 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Parent) instrumentation.register_class(Child) instrumentation.register_class(SubChild) - attributes.register_attribute(Parent, 'child', uselist=False, - backref="parent", - parent_token=p_token, - useobject=True) - attributes.register_attribute(Child, 'parent', uselist=False, - backref="child", - parent_token=c_token, - useobject=True) - attributes.register_attribute(SubChild, 'parent', - uselist=False, - backref="child", - parent_token=c_token, - useobject=True) + attributes.register_attribute( + Parent, + "child", + uselist=False, + backref="parent", + parent_token=p_token, + useobject=True, + ) + attributes.register_attribute( + Child, + "parent", + uselist=False, + backref="child", + parent_token=c_token, + useobject=True, + ) + attributes.register_attribute( + SubChild, + "parent", + uselist=False, + backref="child", + parent_token=c_token, + useobject=True, + ) p1 = Parent() c1 = Child() @@ -1331,18 +1507,30 @@ class BackrefTest(fixtures.ORMTest): instrumentation.register_class(Parent) instrumentation.register_class(SubParent) instrumentation.register_class(Child) - attributes.register_attribute(Parent, 'children', uselist=True, - backref='parent', - parent_token=p_token, - useobject=True) - attributes.register_attribute(SubParent, 'children', uselist=True, - backref='parent', - parent_token=p_token, - useobject=True) - attributes.register_attribute(Child, 'parent', uselist=False, - backref='children', - parent_token=c_token, - useobject=True) + attributes.register_attribute( + Parent, + "children", + uselist=True, + backref="parent", + parent_token=p_token, + useobject=True, + ) + attributes.register_attribute( + SubParent, + "children", + uselist=True, + backref="parent", + parent_token=p_token, + useobject=True, + ) + attributes.register_attribute( + Child, + "parent", + uselist=False, + backref="children", + parent_token=c_token, + useobject=True, + ) p1 = Parent() p2 = SubParent() @@ -1372,11 +1560,14 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): b1 = B() assert_raises_message( ValueError, - 'Bidirectional attribute conflict detected: ' + "Bidirectional attribute conflict detected: " 'Passing object to attribute "C.a" ' 'triggers a modify event on attribute "C.b" ' 'via the backref "B.c".', - setattr, c1, 'a', b1 + setattr, + c1, + "a", + b1, ) def test_collection_append_type_assertion(self): @@ -1385,11 +1576,12 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): b1 = B() assert_raises_message( ValueError, - 'Bidirectional attribute conflict detected: ' + "Bidirectional attribute conflict detected: " 'Passing object to attribute "C.a" ' 'triggers a modify event on attribute "C.b" ' 'via the backref "B.c".', - c1.a.append, b1 + c1.a.append, + b1, ) def _scalar_fixture(self): @@ -1401,16 +1593,19 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): class C(object): pass + 
instrumentation.register_class(A) instrumentation.register_class(B) instrumentation.register_class(C) - attributes.register_attribute(C, 'a', backref='c', useobject=True) - attributes.register_attribute(C, 'b', backref='c', useobject=True) + attributes.register_attribute(C, "a", backref="c", useobject=True) + attributes.register_attribute(C, "b", backref="c", useobject=True) - attributes.register_attribute(A, 'c', backref='a', useobject=True, - uselist=True) - attributes.register_attribute(B, 'c', backref='b', useobject=True, - uselist=True) + attributes.register_attribute( + A, "c", backref="a", useobject=True, uselist=True + ) + attributes.register_attribute( + B, "c", backref="b", useobject=True, uselist=True + ) return A, B, C @@ -1423,17 +1618,20 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): class C(object): pass + instrumentation.register_class(A) instrumentation.register_class(B) instrumentation.register_class(C) - attributes.register_attribute(C, 'a', backref='c', useobject=True, - uselist=True) - attributes.register_attribute(C, 'b', backref='c', useobject=True, - uselist=True) + attributes.register_attribute( + C, "a", backref="c", useobject=True, uselist=True + ) + attributes.register_attribute( + C, "b", backref="c", useobject=True, uselist=True + ) - attributes.register_attribute(A, 'c', backref='a', useobject=True) - attributes.register_attribute(B, 'c', backref='b', useobject=True) + attributes.register_attribute(A, "c", backref="a", useobject=True) + attributes.register_attribute(B, "c", backref="b", useobject=True) return A, B, C @@ -1443,15 +1641,18 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): class B(object): pass + instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, 'b', backref='a1', useobject=True) - attributes.register_attribute(B, 'a1', backref='b', useobject=True, - uselist=True) + attributes.register_attribute(A, "b", backref="a1", useobject=True) + attributes.register_attribute( + B, "a1", backref="b", useobject=True, uselist=True + ) - attributes.register_attribute(B, 'a2', backref='b', useobject=True, - uselist=True) + attributes.register_attribute( + B, "a2", backref="b", useobject=True, uselist=True + ) return A, B @@ -1461,11 +1662,12 @@ class CyclicBackrefAssertionTest(fixtures.TestBase): a1 = A() assert_raises_message( ValueError, - 'Bidirectional attribute conflict detected: ' + "Bidirectional attribute conflict detected: " 'Passing object to attribute "B.a2" ' 'triggers a modify event on attribute "B.a1" ' 'via the backref "A.b".', - b1.a2.append, a1 + b1.a2.append, + a1, ) @@ -1474,6 +1676,7 @@ class PendingBackrefTest(fixtures.ORMTest): class Post(object): def __init__(self, name): self.name = name + __hash__ = None def __eq__(self, other): @@ -1482,6 +1685,7 @@ class PendingBackrefTest(fixtures.ORMTest): class Blog(object): def __init__(self, name): self.name = name + __hash__ = None def __eq__(self, other): @@ -1491,13 +1695,23 @@ class PendingBackrefTest(fixtures.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, - backref='posts', trackparent=True, - useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, - backref='blog', callable_=lazy_posts, - trackparent=True, - useobject=True) + attributes.register_attribute( + Post, + "blog", + uselist=False, + backref="posts", + trackparent=True, + useobject=True, + ) + attributes.register_attribute( + Blog, + "posts", + 
uselist=True, + backref="blog", + callable_=lazy_posts, + trackparent=True, + useobject=True, + ) return Post, Blog, lazy_posts @@ -1514,9 +1728,8 @@ class PendingBackrefTest(fixtures.ORMTest): p.blog = b eq_( - lazy_posts.mock_calls, [ - call(b1_state, attributes.PASSIVE_NO_FETCH) - ] + lazy_posts.mock_calls, + [call(b1_state, attributes.PASSIVE_NO_FETCH)], ) p = Post("post 5") @@ -1524,10 +1737,11 @@ class PendingBackrefTest(fixtures.ORMTest): # setting blog doesn't call 'posts' callable, calls with no fetch p.blog = b eq_( - lazy_posts.mock_calls, [ + lazy_posts.mock_calls, + [ call(b1_state, attributes.PASSIVE_NO_FETCH), - call(b1_state, attributes.PASSIVE_NO_FETCH) - ] + call(b1_state, attributes.PASSIVE_NO_FETCH), + ], ) lazy_posts.return_value = [p1, p2, p3] @@ -1535,11 +1749,12 @@ class PendingBackrefTest(fixtures.ORMTest): # calling backref calls the callable, populates extra posts eq_(b.posts, [p1, p2, p3, Post("post 4"), Post("post 5")]) eq_( - lazy_posts.mock_calls, [ + lazy_posts.mock_calls, + [ call(b1_state, attributes.PASSIVE_NO_FETCH), call(b1_state, attributes.PASSIVE_NO_FETCH), - call(b1_state, attributes.PASSIVE_OFF) - ] + call(b1_state, attributes.PASSIVE_OFF), + ], ) def test_lazy_history_collection(self): @@ -1557,9 +1772,12 @@ class PendingBackrefTest(fixtures.ORMTest): eq_(lazy_posts.call_count, 1) - eq_(attributes.instance_state(b). - get_history('posts', attributes.PASSIVE_OFF), - ([p, p4], [p1, p2, p3], [])) + eq_( + attributes.instance_state(b).get_history( + "posts", attributes.PASSIVE_OFF + ), + ([p, p4], [p1, p2, p3], []), + ) eq_(lazy_posts.call_count, 1) def test_passive_history_collection_never_set(self): @@ -1570,34 +1788,36 @@ class PendingBackrefTest(fixtures.ORMTest): b = Blog("blog 1") p = Post("post 1") - state, dict_ = (attributes.instance_state(b), - attributes.instance_dict(b)) + state, dict_ = ( + attributes.instance_state(b), + attributes.instance_dict(b), + ) # this sets up NEVER_SET on b.posts p.blog = b eq_(state.committed_state, {"posts": attributes.NEVER_SET}) - assert 'posts' not in dict_ + assert "posts" not in dict_ # then suppose the object was made transient again, # the lazy loader would return this lazy_posts.return_value = attributes.ATTR_EMPTY - p2 = Post('asdf') + p2 = Post("asdf") p2.blog = b eq_(state.committed_state, {"posts": attributes.NEVER_SET}) - eq_(dict_['posts'], [p2]) + eq_(dict_["posts"], [p2]) # then this would fail. 
eq_( Blog.posts.impl.get_history(state, dict_, passive=True), - ([p2], (), ()) + ([p2], (), ()), ) eq_( Blog.posts.impl.get_all_pending(state, dict_), - [(attributes.instance_state(p2), p2)] + [(attributes.instance_state(p2), p2)], ) def test_state_on_add_remove(self): @@ -1608,18 +1828,28 @@ class PendingBackrefTest(fixtures.ORMTest): b1_state = attributes.instance_state(b) p = Post("post 1") p.blog = b - eq_(lazy_posts.mock_calls, - [call(b1_state, attributes.PASSIVE_NO_FETCH)]) + eq_( + lazy_posts.mock_calls, + [call(b1_state, attributes.PASSIVE_NO_FETCH)], + ) p.blog = None - eq_(lazy_posts.mock_calls, - [call(b1_state, attributes.PASSIVE_NO_FETCH), - call(b1_state, attributes.PASSIVE_NO_FETCH)]) + eq_( + lazy_posts.mock_calls, + [ + call(b1_state, attributes.PASSIVE_NO_FETCH), + call(b1_state, attributes.PASSIVE_NO_FETCH), + ], + ) lazy_posts.return_value = [] eq_(b.posts, []) - eq_(lazy_posts.mock_calls, - [call(b1_state, attributes.PASSIVE_NO_FETCH), - call(b1_state, attributes.PASSIVE_NO_FETCH), - call(b1_state, attributes.PASSIVE_OFF)]) + eq_( + lazy_posts.mock_calls, + [ + call(b1_state, attributes.PASSIVE_NO_FETCH), + call(b1_state, attributes.PASSIVE_NO_FETCH), + call(b1_state, attributes.PASSIVE_OFF), + ], + ) def test_pending_combines_with_lazy(self): Post, Blog, lazy_posts = self._fixture() @@ -1644,13 +1874,16 @@ class PendingBackrefTest(fixtures.ORMTest): def test_normal_load(self): Post, Blog, lazy_posts = self._fixture() - lazy_posts.return_value = \ - (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + lazy_posts.return_value = (p1, p2, p3) = [ + Post("post 1"), + Post("post 2"), + Post("post 3"), + ] b = Blog("blog 1") # assign without using backref system - p2.__dict__['blog'] = b + p2.__dict__["blog"] = b eq_(b.posts, [Post("post 1"), Post("post 2"), Post("post 3")]) @@ -1663,8 +1896,7 @@ class PendingBackrefTest(fixtures.ORMTest): b_state = attributes.instance_state(b) eq_(lazy_posts.call_count, 1) - eq_(lazy_posts.mock_calls, - [call(b_state, attributes.PASSIVE_OFF)]) + eq_(lazy_posts.mock_calls, [call(b_state, attributes.PASSIVE_OFF)]) def test_commit_removes_pending(self): Post, Blog, lazy_posts = self._fixture() @@ -1681,24 +1913,29 @@ class PendingBackrefTest(fixtures.ORMTest): p1_state._commit_all(attributes.instance_dict(p1)) lazy_posts.return_value = [p1] eq_(b.posts, [Post("post 1")]) - eq_(lazy_posts.mock_calls, - [call(b_state, attributes.PASSIVE_NO_FETCH), - call(b_state, attributes.PASSIVE_OFF)]) + eq_( + lazy_posts.mock_calls, + [ + call(b_state, attributes.PASSIVE_NO_FETCH), + call(b_state, attributes.PASSIVE_OFF), + ], + ) class HistoryTest(fixtures.TestBase): - def _fixture(self, uselist, useobject, active_history, **kw): class Foo(fixtures.BasicEntity): pass instrumentation.register_class(Foo) attributes.register_attribute( - Foo, 'someattr', + Foo, + "someattr", uselist=uselist, useobject=useobject, active_history=active_history, - **kw) + **kw + ) return Foo def _two_obj_fixture(self, uselist, active_history=False): @@ -1711,75 +1948,85 @@ class HistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'someattr', uselist=uselist, - useobject=True, - active_history=active_history) + attributes.register_attribute( + Foo, + "someattr", + uselist=uselist, + useobject=True, + active_history=active_history, + ) return Foo, Bar def _someattr_history(self, f, **kw): - passive = kw.pop('passive', None) + passive = kw.pop("passive", None) if passive is True: - 
kw['passive'] = attributes.PASSIVE_NO_INITIALIZE + kw["passive"] = attributes.PASSIVE_NO_INITIALIZE elif passive is False: - kw['passive'] = attributes.PASSIVE_OFF + kw["passive"] = attributes.PASSIVE_OFF return attributes.get_state_history( - attributes.instance_state(f), - 'someattr', **kw) + attributes.instance_state(f), "someattr", **kw + ) def _commit_someattr(self, f): - attributes.instance_state(f)._commit(attributes.instance_dict(f), - ['someattr']) + attributes.instance_state(f)._commit( + attributes.instance_dict(f), ["someattr"] + ) def _someattr_committed_state(self, f): Foo = f.__class__ return Foo.someattr.impl.get_committed_value( - attributes.instance_state(f), - attributes.instance_dict(f)) + attributes.instance_state(f), attributes.instance_dict(f) + ) def test_committed_value_init(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() eq_(self._someattr_committed_state(f), None) def test_committed_value_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() f.someattr = 3 eq_(self._someattr_committed_state(f), None) def test_committed_value_set_active_hist(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() f.someattr = 3 eq_(self._someattr_committed_state(f), None) def test_committed_value_set_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() f.someattr = 3 self._commit_someattr(f) eq_(self._someattr_committed_state(f), 3) def test_scalar_init(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() eq_(self._someattr_history(f), ((), (), ())) def test_object_init(self): - Foo = self._fixture(uselist=False, useobject=True, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=True, active_history=False + ) f = Foo() eq_(self._someattr_history(f), ((), (), ())) def test_object_init_active_history(self): - Foo = self._fixture(uselist=False, useobject=True, - active_history=True) + Foo = self._fixture(uselist=False, useobject=True, active_history=True) f = Foo() eq_(self._someattr_history(f), ((), (), ())) @@ -1810,8 +2057,8 @@ class HistoryTest(fixtures.TestBase): f.someattr = b1 self._commit_someattr(f) - attributes.instance_state(f).dict.pop('someattr', None) - attributes.instance_state(f).expired_attributes.add('someattr') + attributes.instance_state(f).dict.pop("someattr", None) + attributes.instance_state(f).expired_attributes.add("someattr") f.someattr = None eq_(self._someattr_history(f), ([None], (), ())) @@ -1838,29 +2085,31 @@ class HistoryTest(fixtures.TestBase): # is db-loaded when testing if an empty "del" is valid, # because there's nothing else to look at for a related # object, there's no "expired" status - attributes.instance_state(f).key = ('foo', ) + attributes.instance_state(f).key = ("foo",) attributes.instance_state(f)._expire_attributes( - attributes.instance_dict(f), - ['someattr']) + attributes.instance_dict(f), ["someattr"] + ) del f.someattr eq_(self._someattr_history(f), ([None], (), ())) def 
test_scalar_no_init_side_effect(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() self._someattr_history(f) # no side effects - assert 'someattr' not in f.__dict__ - assert 'someattr' not in attributes.instance_state(f).committed_state + assert "someattr" not in f.__dict__ + assert "someattr" not in attributes.instance_state(f).committed_state def test_scalar_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'hi' - eq_(self._someattr_history(f), (['hi'], (), ())) + f.someattr = "hi" + eq_(self._someattr_history(f), (["hi"], (), ())) def test_scalar_set_None(self): # note - compare: @@ -1868,8 +2117,9 @@ class HistoryTest(fixtures.TestBase): # test_scalar_get_first_set_None, # test_use_object_set_None, # test_use_object_get_first_set_None - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() f.someattr = None eq_(self._someattr_history(f), ([None], (), ())) @@ -1880,11 +2130,12 @@ class HistoryTest(fixtures.TestBase): # test_scalar_get_first_set_None, # test_use_object_set_None, # test_use_object_get_first_set_None - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() f.someattr = 5 - attributes.instance_state(f).key = ('foo', ) + attributes.instance_state(f).key = ("foo",) self._commit_someattr(f) del f.someattr @@ -1896,15 +2147,16 @@ class HistoryTest(fixtures.TestBase): # test_scalar_get_first_set_None, # test_use_object_set_None, # test_use_object_get_first_set_None - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() f.someattr = 5 self._commit_someattr(f) attributes.instance_state(f)._expire_attributes( - attributes.instance_dict(f), - ['someattr']) + attributes.instance_dict(f), ["someattr"] + ) del f.someattr eq_(self._someattr_history(f), ([None], (), ())) @@ -1914,224 +2166,247 @@ class HistoryTest(fixtures.TestBase): # test_scalar_get_first_set_None, # test_use_object_set_None, # test_use_object_get_first_set_None - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() assert f.someattr is None f.someattr = None eq_(self._someattr_history(f), ([None], (), ())) def test_scalar_set_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'hi' + f.someattr = "hi" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['hi'], ())) + eq_(self._someattr_history(f), ((), ["hi"], ())) def test_scalar_set_commit_reset(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'hi' + f.someattr = "hi" self._commit_someattr(f) - f.someattr = 'there' - eq_(self._someattr_history(f), (['there'], (), ['hi'])) + f.someattr = "there" + eq_(self._someattr_history(f), (["there"], (), ["hi"])) def 
test_scalar_set_commit_reset_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'hi' + f.someattr = "hi" self._commit_someattr(f) - f.someattr = 'there' + f.someattr = "there" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['there'], ())) + eq_(self._someattr_history(f), ((), ["there"], ())) def test_scalar_set_commit_reset_commit_del(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'there' + f.someattr = "there" self._commit_someattr(f) del f.someattr - eq_(self._someattr_history(f), ((), (), ['there'])) + eq_(self._someattr_history(f), ((), (), ["there"])) def test_scalar_set_dict(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.__dict__['someattr'] = 'new' - eq_(self._someattr_history(f), ((), ['new'], ())) + f.__dict__["someattr"] = "new" + eq_(self._someattr_history(f), ((), ["new"], ())) def test_scalar_set_dict_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.__dict__['someattr'] = 'new' + f.__dict__["someattr"] = "new" self._someattr_history(f) - f.someattr = 'old' - eq_(self._someattr_history(f), (['old'], (), ['new'])) + f.someattr = "old" + eq_(self._someattr_history(f), (["old"], (), ["new"])) def test_scalar_set_dict_set_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.__dict__['someattr'] = 'new' + f.__dict__["someattr"] = "new" self._someattr_history(f) - f.someattr = 'old' + f.someattr = "old" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['old'], ())) + eq_(self._someattr_history(f), ((), ["old"], ())) def test_scalar_set_None_from_dict_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.__dict__['someattr'] = 'new' + f.__dict__["someattr"] = "new" f.someattr = None - eq_(self._someattr_history(f), ([None], (), ['new'])) + eq_(self._someattr_history(f), ([None], (), ["new"])) def test_scalar_set_twice_no_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'one' - eq_(self._someattr_history(f), (['one'], (), ())) - f.someattr = 'two' - eq_(self._someattr_history(f), (['two'], (), ())) + f.someattr = "one" + eq_(self._someattr_history(f), (["one"], (), ())) + f.someattr = "two" + eq_(self._someattr_history(f), (["two"], (), ())) def test_scalar_active_init(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() eq_(self._someattr_history(f), ((), (), ())) def test_scalar_active_no_init_side_effect(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = 
Foo() self._someattr_history(f) # no side effects - assert 'someattr' not in f.__dict__ - assert 'someattr' not in attributes.instance_state(f).committed_state + assert "someattr" not in f.__dict__ + assert "someattr" not in attributes.instance_state(f).committed_state def test_collection_never_set(self): - Foo = self._fixture(uselist=True, useobject=True, - active_history=True) + Foo = self._fixture(uselist=True, useobject=True, active_history=True) f = Foo() eq_(self._someattr_history(f, passive=True), (None, None, None)) def test_scalar_obj_never_set(self): - Foo = self._fixture(uselist=False, useobject=True, - active_history=True) + Foo = self._fixture(uselist=False, useobject=True, active_history=True) f = Foo() eq_(self._someattr_history(f, passive=True), (None, None, None)) def test_scalar_never_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() eq_(self._someattr_history(f, passive=True), (None, None, None)) def test_scalar_active_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'hi' - eq_(self._someattr_history(f), (['hi'], (), ())) + f.someattr = "hi" + eq_(self._someattr_history(f), (["hi"], (), ())) def test_scalar_active_set_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'hi' + f.someattr = "hi" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['hi'], ())) + eq_(self._someattr_history(f), ((), ["hi"], ())) def test_scalar_active_set_commit_reset(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'hi' + f.someattr = "hi" self._commit_someattr(f) - f.someattr = 'there' - eq_(self._someattr_history(f), (['there'], (), ['hi'])) + f.someattr = "there" + eq_(self._someattr_history(f), (["there"], (), ["hi"])) def test_scalar_active_set_commit_reset_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'hi' + f.someattr = "hi" self._commit_someattr(f) - f.someattr = 'there' + f.someattr = "there" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['there'], ())) + eq_(self._someattr_history(f), ((), ["there"], ())) def test_scalar_active_set_commit_reset_commit_del(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'there' + f.someattr = "there" self._commit_someattr(f) del f.someattr - eq_(self._someattr_history(f), ((), (), ['there'])) + eq_(self._someattr_history(f), ((), (), ["there"])) def test_scalar_active_set_dict(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.__dict__['someattr'] = 'new' - eq_(self._someattr_history(f), ((), ['new'], ())) + f.__dict__["someattr"] = "new" + eq_(self._someattr_history(f), ((), ["new"], ())) def test_scalar_active_set_dict_set(self): - Foo = 
self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.__dict__['someattr'] = 'new' + f.__dict__["someattr"] = "new" self._someattr_history(f) - f.someattr = 'old' - eq_(self._someattr_history(f), (['old'], (), ['new'])) + f.someattr = "old" + eq_(self._someattr_history(f), (["old"], (), ["new"])) def test_scalar_active_set_dict_set_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.__dict__['someattr'] = 'new' + f.__dict__["someattr"] = "new" self._someattr_history(f) - f.someattr = 'old' + f.someattr = "old" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['old'], ())) + eq_(self._someattr_history(f), ((), ["old"], ())) def test_scalar_active_set_None(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() f.someattr = None eq_(self._someattr_history(f), ([None], (), ())) def test_scalar_active_set_None_from_dict_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.__dict__['someattr'] = 'new' + f.__dict__["someattr"] = "new" f.someattr = None - eq_(self._someattr_history(f), ([None], (), ['new'])) + eq_(self._someattr_history(f), ([None], (), ["new"])) def test_scalar_active_set_twice_no_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'one' - eq_(self._someattr_history(f), (['one'], (), ())) - f.someattr = 'two' - eq_(self._someattr_history(f), (['two'], (), ())) + f.someattr = "one" + eq_(self._someattr_history(f), (["one"], (), ())) + f.someattr = "two" + eq_(self._someattr_history(f), (["two"], (), ())) def test_scalar_passive_flag(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=True) + Foo = self._fixture( + uselist=False, useobject=False, active_history=True + ) f = Foo() - f.someattr = 'one' - eq_(self._someattr_history(f), (['one'], (), ())) + f.someattr = "one" + eq_(self._someattr_history(f), (["one"], (), ())) self._commit_someattr(f) @@ -2139,114 +2414,126 @@ class HistoryTest(fixtures.TestBase): # do the same thing that # populators.expire.append((self.key, True)) # does in loading.py - state.dict.pop('someattr', None) - state.expired_attributes.add('someattr') + state.dict.pop("someattr", None) + state.expired_attributes.add("someattr") def scalar_loader(state, toload): - state.dict['someattr'] = 'one' + state.dict["someattr"] = "one" + state.manager.deferred_scalar_loader = scalar_loader - eq_(self._someattr_history(f), ((), ['one'], ())) + eq_(self._someattr_history(f), ((), ["one"], ())) def test_scalar_inplace_mutation_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} - eq_(self._someattr_history(f), ([{'a': 'b'}], (), ())) + f.someattr = {"a": "b"} + eq_(self._someattr_history(f), ([{"a": "b"}], (), ())) def test_scalar_inplace_mutation_set_commit(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = 
self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} + f.someattr = {"a": "b"} self._commit_someattr(f) - eq_(self._someattr_history(f), ((), [{'a': 'b'}], ())) + eq_(self._someattr_history(f), ((), [{"a": "b"}], ())) def test_scalar_inplace_mutation_set_commit_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} + f.someattr = {"a": "b"} self._commit_someattr(f) - f.someattr['a'] = 'c' - eq_(self._someattr_history(f), ((), [{'a': 'c'}], ())) + f.someattr["a"] = "c" + eq_(self._someattr_history(f), ((), [{"a": "c"}], ())) def test_scalar_inplace_mutation_set_commit_flag_modified(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} + f.someattr = {"a": "b"} self._commit_someattr(f) - attributes.flag_modified(f, 'someattr') - eq_(self._someattr_history(f), ([{'a': 'b'}], (), ())) + attributes.flag_modified(f, "someattr") + eq_(self._someattr_history(f), ([{"a": "b"}], (), ())) def test_scalar_inplace_mutation_set_commit_set_flag_modified(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} + f.someattr = {"a": "b"} self._commit_someattr(f) - f.someattr['a'] = 'c' - attributes.flag_modified(f, 'someattr') - eq_(self._someattr_history(f), ([{'a': 'c'}], (), ())) + f.someattr["a"] = "c" + attributes.flag_modified(f, "someattr") + eq_(self._someattr_history(f), ([{"a": "c"}], (), ())) def test_scalar_inplace_mutation_set_commit_flag_modified_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} + f.someattr = {"a": "b"} self._commit_someattr(f) - attributes.flag_modified(f, 'someattr') - eq_(self._someattr_history(f), ([{'a': 'b'}], (), ())) - f.someattr = ['a'] - eq_(self._someattr_history(f), ([['a']], (), ())) + attributes.flag_modified(f, "someattr") + eq_(self._someattr_history(f), ([{"a": "b"}], (), ())) + f.someattr = ["a"] + eq_(self._someattr_history(f), ([["a"]], (), ())) def test_scalar_inplace_mutation_replace_self_flag_modified_set(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = {'a': 'b'} + f.someattr = {"a": "b"} self._commit_someattr(f) - eq_(self._someattr_history(f), ((), [{'a': 'b'}], ())) + eq_(self._someattr_history(f), ((), [{"a": "b"}], ())) # set the attribute to itself; this places a copy # in committed_state f.someattr = f.someattr - attributes.flag_modified(f, 'someattr') - eq_(self._someattr_history(f), ([{'a': 'b'}], (), ())) + attributes.flag_modified(f, "someattr") + eq_(self._someattr_history(f), ([{"a": "b"}], (), ())) def test_flag_modified_but_no_value_raises(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'foo' + f.someattr = "foo" self._commit_someattr(f) - eq_(self._someattr_history(f), ((), ['foo'], ())) + 
eq_(self._someattr_history(f), ((), ["foo"], ())) attributes.instance_state(f)._expire_attributes( - attributes.instance_dict(f), - ['someattr']) + attributes.instance_dict(f), ["someattr"] + ) assert_raises_message( sa_exc.InvalidRequestError, "Can't flag attribute 'someattr' modified; it's " "not present in the object state", - attributes.flag_modified, f, 'someattr' + attributes.flag_modified, + f, + "someattr", ) def test_mark_dirty_no_attr(self): - Foo = self._fixture(uselist=False, useobject=False, - active_history=False) + Foo = self._fixture( + uselist=False, useobject=False, active_history=False + ) f = Foo() - f.someattr = 'foo' + f.someattr = "foo" attributes.instance_state(f)._commit_all(f.__dict__) - eq_(self._someattr_history(f), ((), ['foo'], ())) + eq_(self._someattr_history(f), ((), ["foo"], ())) attributes.instance_state(f)._expire_attributes( - attributes.instance_dict(f), - ['someattr']) + attributes.instance_dict(f), ["someattr"] + ) is_false(attributes.instance_state(f).modified) @@ -2263,20 +2550,20 @@ class HistoryTest(fixtures.TestBase): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() self._someattr_history(f) - assert 'someattr' not in f.__dict__ - assert 'someattr' not in attributes.instance_state(f).committed_state + assert "someattr" not in f.__dict__ + assert "someattr" not in attributes.instance_state(f).committed_state def test_use_object_set(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') + hi = Bar(name="hi") f.someattr = hi eq_(self._someattr_history(f), ([hi], (), ())) def test_use_object_set_commit(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') + hi = Bar(name="hi") f.someattr = hi self._commit_someattr(f) eq_(self._someattr_history(f), ((), [hi], ())) @@ -2284,20 +2571,20 @@ class HistoryTest(fixtures.TestBase): def test_use_object_set_commit_set(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') + hi = Bar(name="hi") f.someattr = hi self._commit_someattr(f) - there = Bar(name='there') + there = Bar(name="there") f.someattr = there eq_(self._someattr_history(f), ([there], (), [hi])) def test_use_object_set_commit_set_commit(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') + hi = Bar(name="hi") f.someattr = hi self._commit_someattr(f) - there = Bar(name='there') + there = Bar(name="there") f.someattr = there self._commit_someattr(f) eq_(self._someattr_history(f), ((), [there], ())) @@ -2305,7 +2592,7 @@ class HistoryTest(fixtures.TestBase): def test_use_object_set_commit_del(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') + hi = Bar(name="hi") f.someattr = hi self._commit_someattr(f) del f.someattr @@ -2314,27 +2601,27 @@ class HistoryTest(fixtures.TestBase): def test_use_object_set_dict(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') - f.__dict__['someattr'] = hi + hi = Bar(name="hi") + f.__dict__["someattr"] = hi eq_(self._someattr_history(f), ((), [hi], ())) def test_use_object_set_dict_set(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') - f.__dict__['someattr'] = hi + hi = Bar(name="hi") + f.__dict__["someattr"] = hi - there = Bar(name='there') + there = Bar(name="there") f.someattr = there eq_(self._someattr_history(f), ([there], (), [hi])) def test_use_object_set_dict_set_commit(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') - 
f.__dict__['someattr'] = hi + hi = Bar(name="hi") + f.__dict__["someattr"] = hi - there = Bar(name='there') + there = Bar(name="there") f.someattr = there self._commit_someattr(f) eq_(self._someattr_history(f), ((), [there], ())) @@ -2365,16 +2652,16 @@ class HistoryTest(fixtures.TestBase): def test_use_object_set_dict_set_None(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') - f.__dict__['someattr'] = hi + hi = Bar(name="hi") + f.__dict__["someattr"] = hi f.someattr = None eq_(self._someattr_history(f), ([None], (), [hi])) def test_use_object_set_value_twice(self): Foo, Bar = self._two_obj_fixture(uselist=False) f = Foo() - hi = Bar(name='hi') - there = Bar(name='there') + hi = Bar(name="hi") + there = Bar(name="there") f.someattr = hi f.someattr = there eq_(self._someattr_history(f), ([there], (), ())) @@ -2383,51 +2670,90 @@ class HistoryTest(fixtures.TestBase): # TODO: break into individual tests Foo, Bar = self._two_obj_fixture(uselist=True) - hi = Bar(name='hi') - there = Bar(name='there') - old = Bar(name='old') - new = Bar(name='new') + hi = Bar(name="hi") + there = Bar(name="there") + old = Bar(name="old") + new = Bar(name="new") # case 1. new object f = Foo() - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [], ()), + ) f.someattr = [hi] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([hi], [], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi], [], []), + ) self._commit_someattr(f) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [hi], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [hi], ()), + ) f.someattr = [there] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([there], [], [hi])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([there], [], [hi]), + ) self._commit_someattr(f) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [there], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [there], ()), + ) f.someattr = [hi] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([hi], [], [there])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi], [], [there]), + ) f.someattr = [old, new] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), - ([old, new], [], [there])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([old, new], [], [there]), + ) # case 2. 
object with direct settings (similar to a load # operation) f = Foo() - collection = attributes.init_collection(f, 'someattr') + collection = attributes.init_collection(f, "someattr") collection.append_without_event(new) attributes.instance_state(f)._commit_all(attributes.instance_dict(f)) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [new], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [new], ()), + ) f.someattr = [old] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([old], [], [new])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([old], [], [new]), + ) self._commit_someattr(f) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [old], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [old], ()), + ) def test_dict_collections(self): # TODO: break into individual tests @@ -2439,31 +2765,58 @@ class HistoryTest(fixtures.TestBase): pass from sqlalchemy.orm.collections import attribute_mapped_collection + instrumentation.register_class(Foo) instrumentation.register_class(Bar) attributes.register_attribute( - Foo, 'someattr', uselist=True, useobject=True, - typecallable=attribute_mapped_collection('name')) - hi = Bar(name='hi') - there = Bar(name='there') - old = Bar(name='old') - new = Bar(name='new') - f = Foo() - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [], ())) - f.someattr['hi'] = hi - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([hi], [], [])) - f.someattr['there'] = there - eq_(tuple([set(x) for x in - attributes.get_state_history(attributes.instance_state(f), - 'someattr')]), - (set([hi, there]), set(), set())) + Foo, + "someattr", + uselist=True, + useobject=True, + typecallable=attribute_mapped_collection("name"), + ) + hi = Bar(name="hi") + there = Bar(name="there") + old = Bar(name="old") + new = Bar(name="new") + f = Foo() + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [], ()), + ) + f.someattr["hi"] = hi + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi], [], []), + ) + f.someattr["there"] = there + eq_( + tuple( + [ + set(x) + for x in attributes.get_state_history( + attributes.instance_state(f), "someattr" + ) + ] + ), + (set([hi, there]), set(), set()), + ) self._commit_someattr(f) - eq_(tuple([set(x) for x in - attributes.get_state_history(attributes.instance_state(f), - 'someattr')]), - (set(), set([hi, there]), set())) + eq_( + tuple( + [ + set(x) + for x in attributes.get_state_history( + attributes.instance_state(f), "someattr" + ) + ] + ), + (set(), set([hi, there]), set()), + ) def test_object_collections_mutate(self): # TODO: break into individual tests @@ -2475,89 +2828,160 @@ class HistoryTest(fixtures.TestBase): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=True, - useobject=True) - attributes.register_attribute(Foo, 'id', uselist=False, - useobject=False) + attributes.register_attribute( + Foo, "someattr", uselist=True, useobject=True + ) + attributes.register_attribute( + Foo, "id", uselist=False, useobject=False + ) instrumentation.register_class(Bar) - hi = Bar(name='hi') - there = Bar(name='there') - old = Bar(name='old') - new = Bar(name='new') + hi = Bar(name="hi") + there = Bar(name="there") + old 
= Bar(name="old") + new = Bar(name="new") # case 1. new object f = Foo(id=1) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [], ()), + ) f.someattr.append(hi) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([hi], [], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi], [], []), + ) self._commit_someattr(f) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [hi], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [hi], ()), + ) f.someattr.append(there) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([there], [hi], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([there], [hi], []), + ) self._commit_someattr(f) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [hi, there], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [hi, there], ()), + ) f.someattr.remove(there) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([], [hi], [there])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([], [hi], [there]), + ) f.someattr.append(old) f.someattr.append(new) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), - ([old, new], [hi], [there])) - attributes.instance_state(f)._commit(attributes.instance_dict(f), - ['someattr']) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [hi, old, new], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([old, new], [hi], [there]), + ) + attributes.instance_state(f)._commit( + attributes.instance_dict(f), ["someattr"] + ) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [hi, old, new], ()), + ) f.someattr.pop(0) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([], [old, new], [hi])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([], [old, new], [hi]), + ) # case 2. 
object with direct settings (similar to a load # operation) f = Foo() - f.__dict__['id'] = 1 - collection = attributes.init_collection(f, 'someattr') + f.__dict__["id"] = 1 + collection = attributes.init_collection(f, "someattr") collection.append_without_event(new) attributes.instance_state(f)._commit_all(attributes.instance_dict(f)) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [new], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [new], ()), + ) f.someattr.append(old) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([old], [new], [])) - attributes.instance_state(f)._commit(attributes.instance_dict(f), - ['someattr']) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [new, old], ())) - f = Foo() - collection = attributes.init_collection(f, 'someattr') + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([old], [new], []), + ) + attributes.instance_state(f)._commit( + attributes.instance_dict(f), ["someattr"] + ) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [new, old], ()), + ) + f = Foo() + collection = attributes.init_collection(f, "someattr") collection.append_without_event(new) attributes.instance_state(f)._commit_all(attributes.instance_dict(f)) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ((), [new], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [new], ()), + ) f.id = 1 f.someattr.remove(new) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([], [], [new])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([], [], [new]), + ) # case 3. mixing appends with sets f = Foo() f.someattr.append(hi) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([hi], [], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi], [], []), + ) f.someattr.append(there) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([hi, there], [], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi, there], [], []), + ) f.someattr = [there] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), ([there], [], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([there], [], []), + ) # case 4. 
ensure duplicates show up, order is maintained @@ -2565,17 +2989,26 @@ class HistoryTest(fixtures.TestBase): f.someattr.append(hi) f.someattr.append(there) f.someattr.append(hi) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), - ([hi, there, hi], [], [])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([hi, there, hi], [], []), + ) attributes.instance_state(f)._commit_all(attributes.instance_dict(f)) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), - ((), [hi, there, hi], ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ((), [hi, there, hi], ()), + ) f.someattr = [] - eq_(attributes.get_state_history(attributes.instance_state(f), - 'someattr'), - ([], [], [hi, there, hi])) + eq_( + attributes.get_state_history( + attributes.instance_state(f), "someattr" + ), + ([], [], [hi, there, hi]), + ) def test_collections_via_backref(self): # TODO: break into individual tests @@ -2588,48 +3021,84 @@ class HistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, - backref='foo', trackparent=True, - useobject=True) - attributes.register_attribute(Bar, 'foo', uselist=False, - backref='bars', trackparent=True, - useobject=True) + attributes.register_attribute( + Foo, + "bars", + uselist=True, + backref="foo", + trackparent=True, + useobject=True, + ) + attributes.register_attribute( + Bar, + "foo", + uselist=False, + backref="bars", + trackparent=True, + useobject=True, + ) f1 = Foo() b1 = Bar() - eq_(attributes.get_state_history(attributes.instance_state(f1), - 'bars'), ((), [], ())) - eq_(attributes.get_state_history(attributes.instance_state(b1), - 'foo'), ((), (), ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), "bars" + ), + ((), [], ()), + ) + eq_( + attributes.get_state_history(attributes.instance_state(b1), "foo"), + ((), (), ()), + ) # b1.foo = f1 f1.bars.append(b1) - eq_(attributes.get_state_history(attributes.instance_state(f1), - 'bars'), ([b1], [], [])) - eq_(attributes.get_state_history(attributes.instance_state(b1), - 'foo'), ([f1], (), ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), "bars" + ), + ([b1], [], []), + ) + eq_( + attributes.get_state_history(attributes.instance_state(b1), "foo"), + ([f1], (), ()), + ) b2 = Bar() f1.bars.append(b2) - eq_(attributes.get_state_history(attributes.instance_state(f1), - 'bars'), ([b1, b2], [], [])) - eq_(attributes.get_state_history(attributes.instance_state(b1), - 'foo'), ([f1], (), ())) - eq_(attributes.get_state_history(attributes.instance_state(b2), - 'foo'), ([f1], (), ())) + eq_( + attributes.get_state_history( + attributes.instance_state(f1), "bars" + ), + ([b1, b2], [], []), + ) + eq_( + attributes.get_state_history(attributes.instance_state(b1), "foo"), + ([f1], (), ()), + ) + eq_( + attributes.get_state_history(attributes.instance_state(b2), "foo"), + ([f1], (), ()), + ) def test_deprecated_flags(self): assert_raises_message( sa_exc.SADeprecationWarning, "Passing True for 'passive' is deprecated. " "Use attributes.PASSIVE_NO_INITIALIZE", - attributes.get_history, object(), 'foo', True + attributes.get_history, + object(), + "foo", + True, ) assert_raises_message( sa_exc.SADeprecationWarning, "Passing False for 'passive' is deprecated. 
" "Use attributes.PASSIVE_OFF", - attributes.get_history, object(), 'foo', False + attributes.get_history, + object(), + "foo", + False, ) @@ -2650,34 +3119,48 @@ class LazyloadHistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, - backref='foo', trackparent=True, - callable_=lazyload, - useobject=True) - attributes.register_attribute(Bar, 'foo', uselist=False, - backref='bars', trackparent=True, - useobject=True) - bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), - Bar(id=4)] + attributes.register_attribute( + Foo, + "bars", + uselist=True, + backref="foo", + trackparent=True, + callable_=lazyload, + useobject=True, + ) + attributes.register_attribute( + Bar, + "foo", + uselist=False, + backref="bars", + trackparent=True, + useobject=True, + ) + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] lazy_load = [bar1, bar2, bar3] f = Foo() bar4 = Bar() bar4.foo = f - eq_(attributes.get_state_history(attributes.instance_state(f), - 'bars'), - ([bar4], [bar1, bar2, bar3], [])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([bar4], [bar1, bar2, bar3], []), + ) lazy_load = None f = Foo() bar4 = Bar() bar4.foo = f - eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), - ([bar4], [], [])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([bar4], [], []), + ) lazy_load = [bar1, bar2, bar3] attributes.instance_state(f)._expire_attributes( - attributes.instance_dict(f), - ['bars']) - eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), - ((), [bar1, bar2, bar3], ())) + attributes.instance_dict(f), ["bars"] + ) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ((), [bar1, bar2, bar3], ()), + ) def test_collections_via_lazyload(self): # TODO: break into individual tests @@ -2695,36 +3178,52 @@ class LazyloadHistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, - callable_=lazyload, trackparent=True, - useobject=True) - bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), - Bar(id=4)] + attributes.register_attribute( + Foo, + "bars", + uselist=True, + callable_=lazyload, + trackparent=True, + useobject=True, + ) + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] lazy_load = [bar1, bar2, bar3] f = Foo() f.bars = [] - eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), - ([], [], [bar1, bar2, bar3])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([], [], [bar1, bar2, bar3]), + ) f = Foo() f.bars.append(bar4) - eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), - ([bar4], [bar1, bar2, bar3], [])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([bar4], [bar1, bar2, bar3], []), + ) f = Foo() f.bars.remove(bar2) - eq_(attributes.get_state_history(attributes.instance_state(f), - 'bars'), ([], [bar1, bar3], [bar2])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([], [bar1, bar3], [bar2]), + ) f.bars.append(bar4) - eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), - ([bar4], [bar1, bar3], [bar2])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([bar4], [bar1, bar3], [bar2]), + ) f = Foo() del f.bars[1] - 
eq_(attributes.get_state_history(attributes.instance_state(f), - 'bars'), ([], [bar1, bar3], [bar2])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([], [bar1, bar3], [bar2]), + ) lazy_load = None f = Foo() f.bars.append(bar2) - eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), - ([bar2], [], [])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bars"), + ([bar2], [], []), + ) def test_scalar_via_lazyload(self): # TODO: break into individual tests @@ -2738,36 +3237,49 @@ class LazyloadHistoryTest(fixtures.TestBase): return lazy_load instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'bar', uselist=False, - callable_=lazyload, useobject=False) - lazy_load = 'hi' + attributes.register_attribute( + Foo, "bar", uselist=False, callable_=lazyload, useobject=False + ) + lazy_load = "hi" # with scalar non-object and active_history=False, the lazy # callable is only executed on gets, not history operations f = Foo() - eq_(f.bar, 'hi') - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), ['hi'], ())) + eq_(f.bar, "hi") + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), ["hi"], ()), + ) f = Foo() f.bar = None - eq_(attributes.get_state_history(attributes.instance_state(f), - 'bar'), ([None], (), ())) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ([None], (), ()), + ) f = Foo() - f.bar = 'there' - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - (['there'], (), ())) - f.bar = 'hi' - eq_(attributes.get_state_history(attributes.instance_state(f), - 'bar'), (['hi'], (), ())) + f.bar = "there" + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + (["there"], (), ()), + ) + f.bar = "hi" + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + (["hi"], (), ()), + ) f = Foo() - eq_(f.bar, 'hi') + eq_(f.bar, "hi") del f.bar - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), (), ['hi'])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), (), ["hi"]), + ) assert f.bar is None - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), (), ['hi'])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), (), ["hi"]), + ) def test_scalar_via_lazyload_with_active(self): # TODO: break into individual tests @@ -2781,37 +3293,54 @@ class LazyloadHistoryTest(fixtures.TestBase): return lazy_load instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'bar', uselist=False, - callable_=lazyload, useobject=False, - active_history=True) - lazy_load = 'hi' + attributes.register_attribute( + Foo, + "bar", + uselist=False, + callable_=lazyload, + useobject=False, + active_history=True, + ) + lazy_load = "hi" # active_history=True means the lazy callable is executed on set # as well as get, causing the old value to appear in the history f = Foo() - eq_(f.bar, 'hi') - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), ['hi'], ())) + eq_(f.bar, "hi") + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), ["hi"], ()), + ) f = Foo() f.bar = None - eq_(attributes.get_state_history(attributes.instance_state(f), - 'bar'), ([None], (), ['hi'])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ([None], (), ["hi"]), + ) f = Foo() - f.bar = 'there' - 
eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - (['there'], (), ['hi'])) - f.bar = 'hi' - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), ['hi'], ())) + f.bar = "there" + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + (["there"], (), ["hi"]), + ) + f.bar = "hi" + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), ["hi"], ()), + ) f = Foo() - eq_(f.bar, 'hi') + eq_(f.bar, "hi") del f.bar - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), (), ['hi'])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), (), ["hi"]), + ) assert f.bar is None - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), (), ['hi'])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), (), ["hi"]), + ) def test_scalar_object_via_lazyload(self): # TODO: break into individual tests @@ -2829,9 +3358,14 @@ class LazyloadHistoryTest(fixtures.TestBase): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'bar', uselist=False, - callable_=lazyload, trackparent=True, - useobject=True) + attributes.register_attribute( + Foo, + "bar", + uselist=False, + callable_=lazyload, + trackparent=True, + useobject=True, + ) bar1, bar2 = [Bar(id=1), Bar(id=2)] lazy_load = bar1 @@ -2839,27 +3373,39 @@ class LazyloadHistoryTest(fixtures.TestBase): # and history operations f = Foo() - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), [bar1], ())) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), [bar1], ()), + ) f = Foo() f.bar = None - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ([None], (), [bar1])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ([None], (), [bar1]), + ) f = Foo() f.bar = bar2 - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ([bar2], (), [bar1])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ([bar2], (), [bar1]), + ) f.bar = bar1 - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), [bar1], ())) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), [bar1], ()), + ) f = Foo() eq_(f.bar, bar1) del f.bar - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), (), [bar1])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), (), [bar1]), + ) assert f.bar is None - eq_(attributes.get_state_history(attributes.instance_state(f), 'bar'), - ((), (), [bar1])) + eq_( + attributes.get_state_history(attributes.instance_state(f), "bar"), + ((), (), [bar1]), + ) class ListenerTest(fixtures.ORMTest): @@ -2882,27 +3428,31 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'data', uselist=False, - useobject=False) - attributes.register_attribute(Foo, 'barlist', uselist=True, - useobject=True) - attributes.register_attribute(Foo, 'barset', typecallable=set, - uselist=True, useobject=True) - attributes.register_attribute(Bar, 'data', uselist=False, - useobject=False) - event.listen(Foo.data, 'set', on_set, retval=True) - event.listen(Foo.barlist, 'append', append, retval=True) - event.listen(Foo.barset, 'append', append, retval=True) + attributes.register_attribute( + Foo, "data", 
uselist=False, useobject=False + ) + attributes.register_attribute( + Foo, "barlist", uselist=True, useobject=True + ) + attributes.register_attribute( + Foo, "barset", typecallable=set, uselist=True, useobject=True + ) + attributes.register_attribute( + Bar, "data", uselist=False, useobject=False + ) + event.listen(Foo.data, "set", on_set, retval=True) + event.listen(Foo.barlist, "append", append, retval=True) + event.listen(Foo.barset, "append", append, retval=True) f1 = Foo() - f1.data = 'some data' - eq_(f1.data, 'some data modified') + f1.data = "some data" + eq_(f1.data, "some data modified") b1 = Bar() - b1.data = 'some bar' + b1.data = "some bar" f1.barlist.append(b1) - assert b1.data == 'some bar' - assert f1.barlist[0].data == 'some bar appended' + assert b1.data == "some bar" + assert f1.barlist[0].data == "some bar appended" f1.barset.add(b1) - assert f1.barset.pop().data == 'some bar appended' + assert f1.barset.pop().data == "some bar appended" def test_named(self): canary = Mock() @@ -2916,15 +3466,15 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) attributes.register_attribute( - Foo, 'data', uselist=False, - useobject=False) + Foo, "data", uselist=False, useobject=False + ) attributes.register_attribute( - Foo, 'barlist', uselist=True, - useobject=True) + Foo, "barlist", uselist=True, useobject=True + ) - event.listen(Foo.data, 'set', canary.set, named=True) - event.listen(Foo.barlist, 'append', canary.append, named=True) - event.listen(Foo.barlist, 'remove', canary.remove, named=True) + event.listen(Foo.data, "set", canary.set, named=True) + event.listen(Foo.barlist, "append", canary.append, named=True) + event.listen(Foo.barlist, "remove", canary.remove, named=True) f1 = Foo() b1 = Bar() @@ -2937,18 +3487,26 @@ class ListenerTest(fixtures.ORMTest): call.set( oldvalue=attributes.NO_VALUE, initiator=attributes.Event( - Foo.data.impl, attributes.OP_REPLACE), - target=f1, value=5), + Foo.data.impl, attributes.OP_REPLACE + ), + target=f1, + value=5, + ), call.append( initiator=attributes.Event( - Foo.barlist.impl, attributes.OP_APPEND), + Foo.barlist.impl, attributes.OP_APPEND + ), target=f1, - value=b1), + value=b1, + ), call.remove( initiator=attributes.Event( - Foo.barlist.impl, attributes.OP_REMOVE), + Foo.barlist.impl, attributes.OP_REMOVE + ), target=f1, - value=b1)] + value=b1, + ), + ], ) def test_collection_link_events(self): @@ -2957,10 +3515,12 @@ class ListenerTest(fixtures.ORMTest): class Bar(object): pass + instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'barlist', uselist=True, - useobject=True) + attributes.register_attribute( + Foo, "barlist", uselist=True, useobject=True + ) canary = Mock() event.listen(Foo.barlist, "init_collection", canary.init) @@ -2977,18 +3537,16 @@ class ListenerTest(fixtures.ORMTest): b2 = Bar() f1.barlist = [b2] adapter_two = f1.barlist._sa_adapter - eq_(canary.init.mock_calls, [ - call(f1, [b1], adapter_one), # note the f1.barlist that - # we saved earlier has been mutated - # in place, new as of [ticket:3913] - call(f1, [b2], adapter_two), - ]) - eq_( - canary.dispose.mock_calls, + eq_( + canary.init.mock_calls, [ - call(f1, [b1], adapter_one) - ] + call(f1, [b1], adapter_one), # note the f1.barlist that + # we saved earlier has been mutated + # in place, new as of [ticket:3913] + call(f1, [b2], adapter_two), + ], ) + eq_(canary.dispose.mock_calls, [call(f1, [b1], adapter_one)]) def 
test_none_on_collection_event(self): """test that append/remove of None in collections emits events. @@ -2996,15 +3554,18 @@ class ListenerTest(fixtures.ORMTest): This is new behavior as of 0.8. """ + class Foo(object): pass class Bar(object): pass + instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'barlist', uselist=True, - useobject=True) + attributes.register_attribute( + Foo, "barlist", uselist=True, useobject=True + ) canary = [] def append(state, child, initiator): @@ -3012,8 +3573,9 @@ class ListenerTest(fixtures.ORMTest): def remove(state, child, initiator): canary.append((state, child)) - event.listen(Foo.barlist, 'append', append) - event.listen(Foo.barlist, 'remove', remove) + + event.listen(Foo.barlist, "append", append) + event.listen(Foo.barlist, "remove", remove) b1, b2 = Bar(), Bar() f1 = Foo() @@ -3038,16 +3600,17 @@ class ListenerTest(fixtures.ORMTest): class Foo(object): pass + instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'bar') + attributes.register_attribute(Foo, "bar") event.listen(Foo.bar, "modified", canary) f1 = Foo() - f1.bar = 'hi' + f1.bar = "hi" attributes.flag_modified(f1, "bar") eq_( canary.mock_calls, - [call(f1, attributes.Event(Foo.bar.impl, attributes.OP_MODIFIED))] + [call(f1, attributes.Event(Foo.bar.impl, attributes.OP_MODIFIED))], ) def test_none_init_scalar(self): @@ -3055,8 +3618,9 @@ class ListenerTest(fixtures.ORMTest): class Foo(object): pass + instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'bar') + attributes.register_attribute(Foo, "bar") event.listen(Foo.bar, "set", canary) @@ -3070,8 +3634,9 @@ class ListenerTest(fixtures.ORMTest): class Foo(object): pass + instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'bar', useobject=True) + attributes.register_attribute(Foo, "bar", useobject=True) event.listen(Foo.bar, "set", canary) @@ -3085,8 +3650,9 @@ class ListenerTest(fixtures.ORMTest): class Foo(object): pass + instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'bar', useobject=True, uselist=True) + attributes.register_attribute(Foo, "bar", useobject=True, uselist=True) event.listen(Foo.bar, "set", canary) @@ -3102,16 +3668,19 @@ class ListenerTest(fixtures.ORMTest): def make_a(): class A(object): pass + classes[0] = A def make_b(): class B(classes[0]): pass + classes[1] = B def make_c(): class C(classes[1]): pass + classes[2] = C def instrument_a(): @@ -3124,22 +3693,25 @@ class ListenerTest(fixtures.ORMTest): instrumentation.register_class(classes[2]) def attr_a(): - attributes.register_attribute(classes[0], 'attrib', - uselist=False, useobject=False) + attributes.register_attribute( + classes[0], "attrib", uselist=False, useobject=False + ) def attr_b(): - attributes.register_attribute(classes[1], 'attrib', - uselist=False, useobject=False) + attributes.register_attribute( + classes[1], "attrib", uselist=False, useobject=False + ) def attr_c(): - attributes.register_attribute(classes[2], 'attrib', - uselist=False, useobject=False) + attributes.register_attribute( + classes[2], "attrib", uselist=False, useobject=False + ) def set(state, value, oldvalue, initiator): canary.append(value) def events_a(): - event.listen(classes[0].attrib, 'set', set, propagate=True) + event.listen(classes[0].attrib, "set", set, propagate=True) def teardown(): classes[:] = [None, None, None] @@ -3158,11 +3730,20 @@ class ListenerTest(fixtures.ORMTest): (make_c, instrument_c), (instrument_c, attr_c), (make_a, 
make_b), - (make_b, make_c) + (make_b, make_c), + ] + elements = [ + make_a, + make_b, + make_c, + instrument_a, + instrument_b, + instrument_c, + attr_a, + attr_b, + attr_c, + events_a, ] - elements = [make_a, make_b, make_c, - instrument_a, instrument_b, instrument_c, - attr_a, attr_b, attr_c, events_a] for i, series in enumerate(all_partial_orderings(ordering, elements)): for fn in series: @@ -3188,12 +3769,12 @@ class TestUnlink(fixtures.TestBase): class B(object): pass + self.A = A self.B = B instrumentation.register_class(A) instrumentation.register_class(B) - attributes.register_attribute(A, 'bs', uselist=True, - useobject=True) + attributes.register_attribute(A, "bs", uselist=True, useobject=True) def test_expired(self): A, B = self.A, self.B @@ -3202,10 +3783,7 @@ class TestUnlink(fixtures.TestBase): a1.bs.append(B()) state = attributes.instance_state(a1) state._expire(state.dict, set()) - assert_raises( - Warning, - coll.append, B() - ) + assert_raises(Warning, coll.append, B()) def test_replaced(self): A, B = self.A, self.B @@ -3226,10 +3804,7 @@ class TestUnlink(fixtures.TestBase): a1.bs.append(B()) state = attributes.instance_state(a1) state._reset(state.dict, "bs") - assert_raises( - Warning, - coll.append, B() - ) + assert_raises(Warning, coll.append, B()) def test_ad_hoc_lazy(self): A, B = self.A, self.B @@ -3238,7 +3813,4 @@ class TestUnlink(fixtures.TestBase): a1.bs.append(B()) state = attributes.instance_state(a1) _set_callable(state, state.dict, "bs", lambda: B()) - assert_raises( - Warning, - coll.append, B() - ) + assert_raises(Warning, coll.append, B()) diff --git a/test/orm/test_backref_mutations.py b/test/orm/test_backref_mutations.py index e50d3ba424..7faa650d72 100644 --- a/test/orm/test_backref_mutations.py +++ b/test/orm/test_backref_mutations.py @@ -13,8 +13,15 @@ from sqlalchemy.testing import assert_raises, assert_raises_message from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc from sqlalchemy.testing.schema import Table from sqlalchemy.testing.schema import Column -from sqlalchemy.orm import mapper, relationship, create_session, \ - class_mapper, backref, sessionmaker, Session +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + class_mapper, + backref, + sessionmaker, + Session, +) from sqlalchemy.orm import attributes, exc as orm_exc from sqlalchemy import testing from sqlalchemy.testing import eq_, is_ @@ -27,15 +34,19 @@ class O2MCollectionTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, backref="user"), - )) + mapper( + User, + users, + properties=dict(addresses=relationship(Address, backref="user")), + ) def test_collection_move_hitslazy(self): User, Address = self.classes.User, self.classes.Address @@ -44,8 +55,8 @@ class O2MCollectionTest(_fixtures.FixtureTest): a1 = Address(email_address="address1") a2 = Address(email_address="address2") a3 = Address(email_address="address3") - u1 = User(name='jack', addresses=[a1, a2, a3]) - u2 = User(name='ed') + u1 = User(name="jack", addresses=[a1, a2, a3]) + u2 = User(name="ed") sess.add_all([u1, a1, a2, a3]) sess.commit() @@ -55,6 +66,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): 
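The HistoryTest hunks above all revolve around one call shape: attributes.get_state_history(state, key) returns a History triple of (added, unchanged, deleted) members for an attribute. The same semantics are reachable through the public inspection API. A minimal sketch, using hypothetical declarative classes as stand-ins for the test fixtures (none of this code is part of the diff):

    from sqlalchemy import Column, ForeignKey, Integer, String, inspect
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        addresses = relationship("Address", backref="user")

    class Address(Base):
        __tablename__ = "addresses"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("users.id"))
        email_address = Column(String(50))

    u1 = User(name="jack")
    a1 = Address(email_address="a1")

    # nothing tracked yet: every member of the triple is empty
    hist = inspect(u1).attrs.addresses.history
    assert not hist.added and not hist.unchanged and not hist.deleted

    # a pending append surfaces in .added; a flush and commit would
    # move it over to .unchanged
    u1.addresses.append(a1)
    hist = inspect(u1).attrs.addresses.history
    assert list(hist.added) == [a1]

This triple is why the reformatted assertions above keep comparing against three-element tuples such as ((), [hi], ()) and ([old, new], [hi], [there]).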
u2.addresses.append(a1) u2.addresses.append(a2) u2.addresses.append(a3) + self.assert_sql_count(testing.db, go, 0) def test_collection_move_preloaded(self): @@ -62,9 +74,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', addresses=[a1]) + u1 = User(name="jack", addresses=[a1]) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -85,9 +97,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', addresses=[a1]) + u1 = User(name="jack", addresses=[a1]) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -106,9 +118,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', addresses=[a1]) + u1 = User(name="jack", addresses=[a1]) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -131,9 +143,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): sess = sessionmaker()() - u1 = User(name='jack') - u2 = User(name='ed') - a1 = Address(email_address='a1') + u1 = User(name="jack") + u2 = User(name="ed") + a1 = Address(email_address="a1") a1.user = u1 sess.add_all([u1, u2, a1]) sess.commit() @@ -157,9 +169,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address sess = sessionmaker()() - u1 = User(name='jack') - u2 = User(name='ed') - a1 = Address(email_address='a1') + u1 = User(name="jack") + u2 = User(name="ed") + a1 = Address(email_address="a1") a1.user = u1 sess.add_all([u1, u2, a1]) sess.commit() @@ -169,6 +181,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): # PASSIVE_NO_FETCH flag. 
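The assert_sql_count(testing.db, go, 0) pattern in these O2MCollectionTest methods is the point of the tests: once both sides are present, moving an Address between User collections is pure in-memory backref bookkeeping (the persistent cases rely on the PASSIVE_NO_FETCH flag to avoid loading the old value). A rough sketch of that in-memory wiring only, reusing the hypothetical User/Address classes from the earlier sketch; the real tests exercise persistent, expired objects, which this does not show:

    u1 = User(name="jack")
    u2 = User(name="ed")
    a1 = Address(email_address="a1", user=u1)

    # appending to the other collection fires the backref: a1.user is
    # re-pointed and a1 drops out of u1.addresses, with no SQL emitted
    # for these in-memory objects
    u2.addresses.append(a1)
    assert a1.user is u2
    assert a1 not in u1.addresses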
def go(): a1.user = u2 + self.assert_sql_count(testing.db, go, 0) assert a1 not in u1.addresses @@ -178,8 +191,8 @@ class O2MCollectionTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address sess = sessionmaker()() - u1 = User(name='jack') - a1 = Address(email_address='a1') + u1 = User(name="jack") + a1 = Address(email_address="a1") a1.user = u1 sess.add_all([u1, a1]) sess.commit() @@ -187,6 +200,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): # works for None too def go(): a1.user = None + self.assert_sql_count(testing.db, go, 0) assert a1 not in u1.addresses @@ -196,9 +210,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): sess = sessionmaker()() - u1 = User(name='jack') - u2 = User(name='ed') - a1 = Address(email_address='a1') + u1 = User(name="jack") + u2 = User(name="ed") + a1 = Address(email_address="a1") a1.user = u1 sess.add_all([u1, u2, a1]) sess.commit() @@ -216,9 +230,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): sess = sessionmaker()() - u1 = User(name='jack') - u2 = User(name='ed') - a1 = Address(email_address='a1') + u1 = User(name="jack") + u2 = User(name="ed") + a1 = Address(email_address="a1") a1.user = u1 sess.add_all([u1, u2, a1]) sess.commit() @@ -238,9 +252,9 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_collection_assignment_mutates_previous_one(self): User, Address = self.classes.User, self.classes.Address - u1 = User(name='jack') - u2 = User(name='ed') - a1 = Address(email_address='a1') + u1 = User(name="jack") + u2 = User(name="ed") + a1 = Address(email_address="a1") u1.addresses.append(a1) is_(a1.user, u1) @@ -254,8 +268,8 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_collection_assignment_mutates_previous_two(self): User, Address = self.classes.User, self.classes.Address - u1 = User(name='jack') - a1 = Address(email_address='a1') + u1 = User(name="jack") + a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -267,8 +281,8 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_del_from_collection(self): User, Address = self.classes.User, self.classes.Address - u1 = User(name='jack') - a1 = Address(email_address='a1') + u1 = User(name="jack") + a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -281,8 +295,8 @@ class O2MCollectionTest(_fixtures.FixtureTest): def test_del_from_scalar(self): User, Address = self.classes.User, self.classes.Address - u1 = User(name='jack') - a1 = Address(email_address='a1') + u1 = User(name="jack") + a1 = Address(email_address="a1") u1.addresses.append(a1) @@ -308,10 +322,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): assert a3.user is u1 - eq_( - u1.addresses, - [a1, a3, a2] - ) + eq_(u1.addresses, [a1, a3, a2]) def test_straight_remove(self): User, Address = self.classes.User, self.classes.Address @@ -328,10 +339,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): del u1.addresses[2] assert a3.user is None - eq_( - u1.addresses, - [a1, a2] - ) + eq_(u1.addresses, [a1, a2]) def test_append_del(self): User, Address = self.classes.User, self.classes.Address @@ -349,10 +357,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): del u1.addresses[1] assert a2.user is u1 - eq_( - u1.addresses, - [a1, a3, a2] - ) + eq_(u1.addresses, [a1, a3, a2]) def test_bulk_replace(self): User, Address = self.classes.User, self.classes.Address @@ -372,10 +377,7 @@ class O2MCollectionTest(_fixtures.FixtureTest): u1.addresses = [a1, a2, a1] assert a3.user is None - eq_( - u1.addresses, - [a1, a2, a1] - ) + eq_(u1.addresses, [a1, a2, a1]) class 
O2OScalarBackrefMoveTest(_fixtures.FixtureTest): @@ -383,25 +385,32 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'address': relationship(Address, backref=backref("user"), - uselist=False) - }) + mapper( + User, + users, + properties={ + "address": relationship( + Address, backref=backref("user"), uselist=False + ) + }, + ) def test_collection_move_preloaded(self): User, Address = self.classes.User, self.classes.Address sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -426,7 +435,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") a2 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) sess.add_all([u1, a1, a2]) sess.commit() # everything is expired @@ -449,9 +458,9 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -472,7 +481,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") a2 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) sess.add_all([u1, a1, a2]) sess.commit() # everything is expired @@ -492,9 +501,9 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -520,7 +529,7 @@ class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() a1 = Address(email_address="address1") a2 = Address(email_address="address2") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) sess.add_all([u1, a1, a2]) sess.commit() # everything is expired @@ -549,24 +558,28 @@ class O2OScalarMoveTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'address': relationship(Address, uselist=False) - }) + mapper( + User, + users, + properties={"address": relationship(Address, uselist=False)}, + ) def test_collection_move_commitfirst(self): User, Address = self.classes.User, self.classes.Address sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) - u2 = User(name='ed') + u2 = User(name="ed") sess.add_all([u1, u2]) sess.commit() # everything is expired @@ -589,31 +602,42 @@ 
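The O2OScalar* classes that follow exercise the same move choreography for a scalar (uselist=False) relationship. A hedged sketch of the one-to-one variant, again with hypothetical declarative classes (named User2/Address2 here only to avoid clashing with the earlier sketch; the real fixtures live in test/orm/_fixtures.py):

    class User2(Base):
        __tablename__ = "users2"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        # scalar instead of a collection: uselist=False
        address = relationship("Address2", uselist=False, backref="user")

    class Address2(Base):
        __tablename__ = "addresses2"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("users2.id"))
        email_address = Column(String(50))

    u1, u2 = User2(name="jack"), User2(name="ed")
    a1 = Address2(email_address="address1")

    u1.address = a1
    assert a1.user is u1   # backref populated both sides

    u2.address = a1        # scalar move: the backref re-points a1.user
    assert a1.user is u2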
class O2OScalarOrphanTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'address': relationship( - Address, uselist=False, - backref=backref('user', single_parent=True, - cascade="all, delete-orphan")) - }) + mapper( + User, + users, + properties={ + "address": relationship( + Address, + uselist=False, + backref=backref( + "user", + single_parent=True, + cascade="all, delete-orphan", + ), + ) + }, + ) def test_m2o_event(self): User, Address = self.classes.User, self.classes.Address sess = sessionmaker()() a1 = Address(email_address="address1") - u1 = User(name='jack', address=a1) + u1 = User(name="jack", address=a1) sess.add(u1) sess.commit() sess.expunge(u1) - u2 = User(name='ed') + u2 = User(name="ed") # the _SingleParent extension sets the backref get to "active" ! # u1 gets loaded and deleted u2.address = a1 @@ -626,17 +650,23 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - keywords, items, item_keywords, \ - Keyword, Item = (cls.tables.keywords, - cls.tables.items, - cls.tables.item_keywords, - cls.classes.Keyword, - cls.classes.Item) - - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords, - backref='items') - }) + keywords, items, item_keywords, Keyword, Item = ( + cls.tables.keywords, + cls.tables.items, + cls.tables.item_keywords, + cls.classes.Keyword, + cls.classes.Item, + ) + + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, backref="items" + ) + }, + ) mapper(Keyword, keywords) def test_add_remove_pending_backref(self): @@ -646,13 +676,13 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): session = Session(autoflush=False) - i1 = Item(description='i1') + i1 = Item(description="i1") session.add(i1) session.commit() - session.expire(i1, ['keywords']) + session.expire(i1, ["keywords"]) - k1 = Keyword(name='k1') + k1 = Keyword(name="k1") k1.items.append(i1) k1.items.remove(i1) eq_(i1.keywords, []) @@ -664,12 +694,12 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): session = Session(autoflush=False) - k1 = Keyword(name='k1') - i1 = Item(description='i1', keywords=[k1]) + k1 = Keyword(name="k1") + i1 = Item(description="i1", keywords=[k1]) session.add(i1) session.commit() - session.expire(i1, ['keywords']) + session.expire(i1, ["keywords"]) k1.items.remove(i1) k1.items.append(i1) @@ -682,9 +712,9 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): session = Session(testing.db, autoflush=False) - k1 = Keyword(name='k1') - k2 = Keyword(name='k2') - i1 = Item(description='i1', keywords=[k1]) + k1 = Keyword(name="k1") + k2 = Keyword(name="k2") + i1 = Item(description="i1", keywords=[k1]) session.add(i1) session.add(k2) session.commit() @@ -693,9 +723,12 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): # the pending # list is still here. eq_( - set(attributes.instance_state(i1). 
- _pending_mutations['keywords'].added_items), - set([k2]) + set( + attributes.instance_state(i1) + ._pending_mutations["keywords"] + .added_items + ), + set([k2]), ) # because autoflush is off, k2 is still # coming in from pending @@ -705,28 +738,28 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): eq_(session.scalar("select count(*) from item_keywords"), 1) # the pending collection was removed - assert 'keywords' not in attributes.\ - instance_state(i1).\ - _pending_mutations + assert ( + "keywords" not in attributes.instance_state(i1)._pending_mutations + ) def test_duplicate_adds(self): Item, Keyword = (self.classes.Item, self.classes.Keyword) session = Session(testing.db, autoflush=False) - k1 = Keyword(name='k1') - i1 = Item(description='i1', keywords=[k1]) + k1 = Keyword(name="k1") + i1 = Item(description="i1", keywords=[k1]) session.add(i1) session.commit() k1.items.append(i1) eq_(i1.keywords, [k1, k1]) - session.expire(i1, ['keywords']) + session.expire(i1, ["keywords"]) k1.items.append(i1) eq_(i1.keywords, [k1, k1]) - session.expire(i1, ['keywords']) + session.expire(i1, ["keywords"]) k1.items.append(i1) eq_(i1.keywords, [k1, k1]) @@ -737,11 +770,11 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): def test_bulk_replace(self): Item, Keyword = (self.classes.Item, self.classes.Keyword) - k1 = Keyword(name='k1') - k2 = Keyword(name='k2') - k3 = Keyword(name='k3') - i1 = Item(description='i1', keywords=[k1, k2]) - i2 = Item(description='i2', keywords=[k3]) + k1 = Keyword(name="k1") + k2 = Keyword(name="k2") + k3 = Keyword(name="k3") + i1 = Item(description="i1", keywords=[k1, k2]) + i2 = Item(description="i2", keywords=[k3]) i1.keywords = [k2, k3] assert i1 in k3.items @@ -754,18 +787,26 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - keywords, items, item_keywords, \ - Keyword, Item = (cls.tables.keywords, - cls.tables.items, - cls.tables.item_keywords, - cls.classes.Keyword, - cls.classes.Item) - - mapper(Item, items, properties={ - 'keyword': relationship(Keyword, secondary=item_keywords, - uselist=False, - backref=backref("item", uselist=False)) - }) + keywords, items, item_keywords, Keyword, Item = ( + cls.tables.keywords, + cls.tables.items, + cls.tables.item_keywords, + cls.classes.Keyword, + cls.classes.Item, + ) + + mapper( + Item, + items, + properties={ + "keyword": relationship( + Keyword, + secondary=item_keywords, + uselist=False, + backref=backref("item", uselist=False), + ) + }, + ) mapper(Keyword, keywords) def test_collection_move_preloaded(self): @@ -773,9 +814,9 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() - k1 = Keyword(name='k1') - i1 = Item(description='i1', keyword=k1) - i2 = Item(description='i2') + k1 = Keyword(name="k1") + i1 = Item(description="i1", keyword=k1) + i2 = Item(description="i2") sess.add_all([i1, i2, k1]) sess.commit() # everything is expired @@ -796,9 +837,9 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() - k1 = Keyword(name='k1') - i1 = Item(description='i1', keyword=k1) - i2 = Item(description='i2') + k1 = Keyword(name="k1") + i1 = Item(description="i1", keyword=k1) + i2 = Item(description="i2") sess.add_all([i1, i2, k1]) sess.commit() # everything is expired @@ -815,9 +856,9 @@ class M2MScalarMoveTest(_fixtures.FixtureTest): sess = sessionmaker()() - k1 = Keyword(name='k1') - i1 = Item(description='i1', keyword=k1) - i2 = Item(description='i2') + k1 = Keyword(name="k1") + i1 = Item(description="i1", keyword=k1) + i2 = 
Item(description="i2") sess.add_all([i1, i2, k1]) sess.commit() # everything is expired @@ -839,15 +880,19 @@ class O2MStaleBackrefTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, backref="user"), - )) + mapper( + User, + users, + properties=dict(addresses=relationship(Address, backref="user")), + ) def test_backref_pop_m2o(self): User, Address = self.classes.User, self.classes.Address @@ -870,17 +915,23 @@ class M2MStaleBackrefTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - keywords, items, item_keywords, \ - Keyword, Item = (cls.tables.keywords, - cls.tables.items, - cls.tables.item_keywords, - cls.classes.Keyword, - cls.classes.Item) - - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords, - backref='items') - }) + keywords, items, item_keywords, Keyword, Item = ( + cls.tables.keywords, + cls.tables.items, + cls.tables.item_keywords, + cls.classes.Keyword, + cls.classes.Item, + ) + + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, backref="items" + ) + }, + ) mapper(Keyword, keywords) def test_backref_pop_m2m(self): @@ -895,4 +946,3 @@ class M2MStaleBackrefTest(_fixtures.FixtureTest): i1.keywords = [] k2.items.remove(i1) assert len(k2.items) == 0 - diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index f85c3de012..5c87d9429a 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -15,10 +15,12 @@ class BindIntegrationTest(_fixtures.FixtureTest): run_inserts = None def test_mapped_binds(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) # ensure tables are unbound m2 = sa.MetaData() @@ -26,28 +28,43 @@ class BindIntegrationTest(_fixtures.FixtureTest): addresses_unbound = addresses.tometadata(m2) mapper(Address, addresses_unbound) - mapper(User, users_unbound, properties={ - 'addresses': relationship(Address, - backref=backref("user", cascade="all"), - cascade="all")}) + mapper( + User, + users_unbound, + properties={ + "addresses": relationship( + Address, + backref=backref("user", cascade="all"), + cascade="all", + ) + }, + ) - sess = Session(binds={User: self.metadata.bind, - Address: self.metadata.bind}) + sess = Session( + binds={User: self.metadata.bind, Address: self.metadata.bind} + ) - u1 = User(id=1, name='ed') + u1 = User(id=1, name="ed") sess.add(u1) - eq_(sess.query(User).filter(User.id == 1).all(), - [User(id=1, name='ed')]) + eq_( + sess.query(User).filter(User.id == 1).all(), + [User(id=1, name="ed")], + ) # test expression binding - sess.execute(users_unbound.insert(), params=dict(id=2, - name='jack')) - eq_(sess.execute(users_unbound.select(users_unbound.c.id - == 2)).fetchall(), [(2, 'jack')]) + sess.execute(users_unbound.insert(), params=dict(id=2, name="jack")) + eq_( + sess.execute( + users_unbound.select(users_unbound.c.id == 2) + ).fetchall(), + [(2, "jack")], + ) - eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(), - [(2, 'jack')]) + eq_( + 
sess.execute(users_unbound.select(User.id == 2)).fetchall(), + [(2, "jack")], + ) sess.execute(users_unbound.delete()) eq_(sess.execute(users_unbound.select()).fetchall(), []) @@ -55,10 +72,12 @@ class BindIntegrationTest(_fixtures.FixtureTest): sess.close() def test_table_binds(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) # ensure tables are unbound m2 = sa.MetaData() @@ -66,27 +85,46 @@ class BindIntegrationTest(_fixtures.FixtureTest): addresses_unbound = addresses.tometadata(m2) mapper(Address, addresses_unbound) - mapper(User, users_unbound, properties={ - 'addresses': relationship(Address, - backref=backref("user", cascade="all"), - cascade="all")}) + mapper( + User, + users_unbound, + properties={ + "addresses": relationship( + Address, + backref=backref("user", cascade="all"), + cascade="all", + ) + }, + ) - Session = sessionmaker(binds={users_unbound: self.metadata.bind, - addresses_unbound: self.metadata.bind}) + Session = sessionmaker( + binds={ + users_unbound: self.metadata.bind, + addresses_unbound: self.metadata.bind, + } + ) sess = Session() - u1 = User(id=1, name='ed') + u1 = User(id=1, name="ed") sess.add(u1) - eq_(sess.query(User).filter(User.id == 1).all(), - [User(id=1, name='ed')]) + eq_( + sess.query(User).filter(User.id == 1).all(), + [User(id=1, name="ed")], + ) - sess.execute(users_unbound.insert(), params=dict(id=2, name='jack')) + sess.execute(users_unbound.insert(), params=dict(id=2, name="jack")) - eq_(sess.execute(users_unbound.select(users_unbound.c.id - == 2)).fetchall(), [(2, 'jack')]) + eq_( + sess.execute( + users_unbound.select(users_unbound.c.id == 2) + ).fetchall(), + [(2, "jack")], + ) - eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(), - [(2, 'jack')]) + eq_( + sess.execute(users_unbound.select(User.id == 2)).fetchall(), + [(2, "jack")], + ) sess.execute(users_unbound.delete()) eq_(sess.execute(users_unbound.select()).fetchall(), []) @@ -99,20 +137,22 @@ class BindIntegrationTest(_fixtures.FixtureTest): mapper(User, users) session = create_session() - session.execute(users.insert(), dict(name='Johnny')) + session.execute(users.insert(), dict(name="Johnny")) - assert len(session.query(User).filter_by(name='Johnny').all()) == 1 + assert len(session.query(User).filter_by(name="Johnny").all()) == 1 session.execute(users.delete()) - assert len(session.query(User).filter_by(name='Johnny').all()) == 0 + assert len(session.query(User).filter_by(name="Johnny").all()) == 0 session.close() def test_bind_arguments(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) mapper(Address, addresses) @@ -130,11 +170,16 @@ class BindIntegrationTest(_fixtures.FixtureTest): assert sess.connection(mapper=Address, bind=e1).engine is e1 assert sess.connection(mapper=Address).engine is e2 assert sess.connection(clause=addresses.select()).engine is e2 - assert sess.connection(mapper=User, - clause=addresses.select()).engine is e1 - assert sess.connection(mapper=User, - clause=addresses.select(), - bind=e2).engine is e2 + assert ( + sess.connection(mapper=User, clause=addresses.select()).engine + is e1 + ) + assert ( + 
sess.connection( + mapper=User, clause=addresses.select(), bind=e2 + ).engine + is e2 + ) sess.close() @@ -144,7 +189,9 @@ class BindIntegrationTest(_fixtures.FixtureTest): assert_raises_message( sa.exc.ArgumentError, "Not an acceptable bind target: foobar", - sess.bind_mapper, "foobar", testing.db + sess.bind_mapper, + "foobar", + testing.db, ) mapper(self.classes.User, self.tables.users) @@ -153,7 +200,9 @@ class BindIntegrationTest(_fixtures.FixtureTest): assert_raises_message( sa.exc.ArgumentError, "Not an acceptable bind target: User()", - sess.bind_mapper, u_object, testing.db + sess.bind_mapper, + u_object, + testing.db, ) @engines.close_open_connections @@ -165,17 +214,22 @@ class BindIntegrationTest(_fixtures.FixtureTest): sess = create_session(bind=c) sess.begin() transaction = sess.transaction - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() - assert transaction._connection_for_bind(testing.db, None) \ - is transaction._connection_for_bind(c, None) is c - - assert_raises_message(sa.exc.InvalidRequestError, - 'Session already has a Connection ' - 'associated', - transaction._connection_for_bind, - testing.db.connect(), None) + assert ( + transaction._connection_for_bind(testing.db, None) + is transaction._connection_for_bind(c, None) + is c + ) + + assert_raises_message( + sa.exc.InvalidRequestError, + "Session already has a Connection " "associated", + transaction._connection_for_bind, + testing.db.connect(), + None, + ) transaction.rollback() assert len(sess.query(User).all()) == 0 sess.close() @@ -187,7 +241,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): c = testing.db.connect() sess = create_session(bind=c, autocommit=False) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() sess.close() @@ -195,7 +249,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): assert c.scalar("select count(1) from users") == 0 sess = create_session(bind=c, autocommit=False) - u = User(name='u2') + u = User(name="u2") sess.add(u) sess.flush() sess.commit() @@ -208,7 +262,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): trans = c.begin() sess = create_session(bind=c, autocommit=True) - u = User(name='u3') + u = User(name="u3") sess.add(u) sess.flush() assert c.in_transaction() @@ -218,13 +272,16 @@ class BindIntegrationTest(_fixtures.FixtureTest): class SessionBindTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('test_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', Integer)) + Table( + "test_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", Integer), + ) @classmethod def setup_classes(cls): @@ -238,8 +295,8 @@ class SessionBindTest(fixtures.MappedTest): meta = MetaData() test_table.tometadata(meta) - assert meta.tables['test_table'].bind is None - mapper(Foo, meta.tables['test_table']) + assert meta.tables["test_table"].bind is None + mapper(Foo, meta.tables["test_table"]) def test_session_bind(self): Foo = self.classes.Foo @@ -255,7 +312,7 @@ class SessionBindTest(fixtures.MappedTest): sess.flush() assert sess.query(Foo).get(f.id) is f finally: - if hasattr(bind, 'close'): + if hasattr(bind, "close"): bind.close() def test_session_unbound(self): @@ -265,29 +322,30 @@ class SessionBindTest(fixtures.MappedTest): sess.add(Foo()) assert_raises_message( sa.exc.UnboundExecutionError, - ('Could not locate a bind configured on Mapper|Foo|test_table ' - 'or this Session'), - sess.flush) + ( + 
"Could not locate a bind configured on Mapper|Foo|test_table " + "or this Session" + ), + sess.flush, + ) class GetBindTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): + Table("base_table", metadata, Column("id", Integer, primary_key=True)) Table( - 'base_table', metadata, - Column('id', Integer, primary_key=True) + "w_mixin_table", metadata, Column("id", Integer, primary_key=True) ) Table( - 'w_mixin_table', metadata, - Column('id', Integer, primary_key=True) + "joined_sub_table", + metadata, + Column("id", ForeignKey("base_table.id"), primary_key=True), ) Table( - 'joined_sub_table', metadata, - Column('id', ForeignKey('base_table.id'), primary_key=True) - ) - Table( - 'concrete_sub_table', metadata, - Column('id', Integer, primary_key=True) + "concrete_sub_table", + metadata, + Column("id", Integer, primary_key=True), ) @classmethod @@ -313,171 +371,118 @@ class GetBindTest(fixtures.MappedTest): mapper(cls.classes.BaseClass, cls.tables.base_table) mapper( cls.classes.JoinedSubClass, - cls.tables.joined_sub_table, inherits=cls.classes.BaseClass) + cls.tables.joined_sub_table, + inherits=cls.classes.BaseClass, + ) mapper( cls.classes.ConcreteSubClass, - cls.tables.concrete_sub_table, inherits=cls.classes.BaseClass, - concrete=True) + cls.tables.concrete_sub_table, + inherits=cls.classes.BaseClass, + concrete=True, + ) def _fixture(self, binds): return Session(binds=binds) def test_fallback_table_metadata(self): session = self._fixture({}) - is_( - session.get_bind(self.classes.BaseClass), - testing.db - ) + is_(session.get_bind(self.classes.BaseClass), testing.db) def test_bind_base_table_base_class(self): base_class_bind = Mock() - session = self._fixture({ - self.tables.base_table: base_class_bind - }) + session = self._fixture({self.tables.base_table: base_class_bind}) - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind - ) + is_(session.get_bind(self.classes.BaseClass), base_class_bind) def test_bind_base_table_joined_sub_class(self): base_class_bind = Mock() - session = self._fixture({ - self.tables.base_table: base_class_bind - }) + session = self._fixture({self.tables.base_table: base_class_bind}) - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind - ) - is_( - session.get_bind(self.classes.JoinedSubClass), - base_class_bind - ) + is_(session.get_bind(self.classes.BaseClass), base_class_bind) + is_(session.get_bind(self.classes.JoinedSubClass), base_class_bind) def test_bind_joined_sub_table_joined_sub_class(self): - base_class_bind = Mock(name='base') - joined_class_bind = Mock(name='joined') - session = self._fixture({ - self.tables.base_table: base_class_bind, - self.tables.joined_sub_table: joined_class_bind - }) - - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind + base_class_bind = Mock(name="base") + joined_class_bind = Mock(name="joined") + session = self._fixture( + { + self.tables.base_table: base_class_bind, + self.tables.joined_sub_table: joined_class_bind, + } ) + + is_(session.get_bind(self.classes.BaseClass), base_class_bind) # joined table inheritance has to query based on the base # table, so this is what we expect - is_( - session.get_bind(self.classes.JoinedSubClass), - base_class_bind - ) + is_(session.get_bind(self.classes.JoinedSubClass), base_class_bind) def test_bind_base_table_concrete_sub_class(self): base_class_bind = Mock() - session = self._fixture({ - self.tables.base_table: base_class_bind - }) + session = self._fixture({self.tables.base_table: base_class_bind}) - is_( - 
session.get_bind(self.classes.ConcreteSubClass), - testing.db - ) + is_(session.get_bind(self.classes.ConcreteSubClass), testing.db) def test_bind_sub_table_concrete_sub_class(self): - base_class_bind = Mock(name='base') - concrete_sub_bind = Mock(name='concrete') - - session = self._fixture({ - self.tables.base_table: base_class_bind, - self.tables.concrete_sub_table: concrete_sub_bind - }) - - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind - ) - is_( - session.get_bind(self.classes.ConcreteSubClass), - concrete_sub_bind + base_class_bind = Mock(name="base") + concrete_sub_bind = Mock(name="concrete") + + session = self._fixture( + { + self.tables.base_table: base_class_bind, + self.tables.concrete_sub_table: concrete_sub_bind, + } ) + is_(session.get_bind(self.classes.BaseClass), base_class_bind) + is_(session.get_bind(self.classes.ConcreteSubClass), concrete_sub_bind) + def test_bind_base_class_base_class(self): base_class_bind = Mock() - session = self._fixture({ - self.classes.BaseClass: base_class_bind - }) + session = self._fixture({self.classes.BaseClass: base_class_bind}) - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind - ) + is_(session.get_bind(self.classes.BaseClass), base_class_bind) def test_bind_mixin_class_simple_class(self): base_class_bind = Mock() - session = self._fixture({ - self.classes.MixinOne: base_class_bind - }) + session = self._fixture({self.classes.MixinOne: base_class_bind}) - is_( - session.get_bind(self.classes.ClassWMixin), - base_class_bind - ) + is_(session.get_bind(self.classes.ClassWMixin), base_class_bind) def test_bind_base_class_joined_sub_class(self): base_class_bind = Mock() - session = self._fixture({ - self.classes.BaseClass: base_class_bind - }) + session = self._fixture({self.classes.BaseClass: base_class_bind}) - is_( - session.get_bind(self.classes.JoinedSubClass), - base_class_bind - ) + is_(session.get_bind(self.classes.JoinedSubClass), base_class_bind) def test_bind_joined_sub_class_joined_sub_class(self): - base_class_bind = Mock(name='base') - joined_class_bind = Mock(name='joined') - session = self._fixture({ - self.classes.BaseClass: base_class_bind, - self.classes.JoinedSubClass: joined_class_bind - }) - - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind - ) - is_( - session.get_bind(self.classes.JoinedSubClass), - joined_class_bind + base_class_bind = Mock(name="base") + joined_class_bind = Mock(name="joined") + session = self._fixture( + { + self.classes.BaseClass: base_class_bind, + self.classes.JoinedSubClass: joined_class_bind, + } ) + is_(session.get_bind(self.classes.BaseClass), base_class_bind) + is_(session.get_bind(self.classes.JoinedSubClass), joined_class_bind) + def test_bind_base_class_concrete_sub_class(self): base_class_bind = Mock() - session = self._fixture({ - self.classes.BaseClass: base_class_bind - }) + session = self._fixture({self.classes.BaseClass: base_class_bind}) - is_( - session.get_bind(self.classes.ConcreteSubClass), - base_class_bind - ) + is_(session.get_bind(self.classes.ConcreteSubClass), base_class_bind) def test_bind_sub_class_concrete_sub_class(self): - base_class_bind = Mock(name='base') - concrete_sub_bind = Mock(name='concrete') - - session = self._fixture({ - self.classes.BaseClass: base_class_bind, - self.classes.ConcreteSubClass: concrete_sub_bind - }) - - is_( - session.get_bind(self.classes.BaseClass), - base_class_bind - ) - is_( - session.get_bind(self.classes.ConcreteSubClass), - concrete_sub_bind + base_class_bind = 
Mock(name="base") + concrete_sub_bind = Mock(name="concrete") + + session = self._fixture( + { + self.classes.BaseClass: base_class_bind, + self.classes.ConcreteSubClass: concrete_sub_bind, + } ) + + is_(session.get_bind(self.classes.BaseClass), base_class_bind) + is_(session.get_bind(self.classes.ConcreteSubClass), concrete_sub_bind) diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py index 159d2debf1..1d253a9078 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/test_bulk.py @@ -11,17 +11,21 @@ from test.orm import _fixtures class BulkTest(testing.AssertsExecutionResults): run_inserts = None - run_define_tables = 'each' + run_define_tables = "each" class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('version_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=False), - Column('value', String(40), nullable=False)) + Table( + "version_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer, nullable=False), + Column("value", String(40), nullable=False), + ) @classmethod def setup_classes(cls): @@ -40,12 +44,9 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): s = Session() - s.bulk_save_objects([Foo(value='value')]) + s.bulk_save_objects([Foo(value="value")]) - eq_( - s.query(Foo).all(), - [Foo(version_id=1, value='value')] - ) + eq_(s.query(Foo).all(), [Foo(version_id=1, value="value")]) @testing.emits_warning(r".*versioning cannot be verified") def test_bulk_update_via_save(self): @@ -53,22 +54,18 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): s = Session() - s.add(Foo(value='value')) + s.add(Foo(value="value")) s.commit() f1 = s.query(Foo).first() - f1.value = 'new value' + f1.value = "new value" s.bulk_save_objects([f1]) s.expunge_all() - eq_( - s.query(Foo).all(), - [Foo(version_id=2, value='new value')] - ) + eq_(s.query(Foo).all(), [Foo(version_id=2, value="new value")]) class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): - @classmethod def setup_mappers(cls): User, Address, Order = cls.classes("User", "Address", "Order") @@ -79,37 +76,30 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): mapper(Order, o) def test_bulk_save_return_defaults(self): - User, = self.classes("User",) + User, = self.classes("User") s = Session() - objects = [ - User(name="u1"), - User(name="u2"), - User(name="u3") - ] - assert 'id' not in objects[0].__dict__ + objects = [User(name="u1"), User(name="u2"), User(name="u3")] + assert "id" not in objects[0].__dict__ with self.sql_execution_asserter() as asserter: s.bulk_save_objects(objects, return_defaults=True) asserter.assert_( CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{'name': 'u1'}] + "INSERT INTO users (name) VALUES (:name)", [{"name": "u1"}] ), CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{'name': 'u2'}] + "INSERT INTO users (name) VALUES (:name)", [{"name": "u2"}] ), CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{'name': 'u3'}] + "INSERT INTO users (name) VALUES (:name)", [{"name": "u3"}] ), ) - eq_(objects[0].__dict__['id'], 1) + eq_(objects[0].__dict__["id"], 1) def test_bulk_save_mappings_preserve_order(self): - User, = self.classes("User", ) + User, = self.classes("User") s = Session() @@ -131,12 +121,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): from sqlalchemy import inspect def 
_bulk_save_mappings( - mapper, mappings, isupdate, isstates, - return_defaults, update_changed_only, render_nulls): + mapper, + mappings, + isupdate, + isstates, + return_defaults, + update_changed_only, + render_nulls, + ): mock_method(list(mappings), isupdate) mock_method = mock.Mock() - with mock.patch.object(s, '_bulk_save_mappings', _bulk_save_mappings): + with mock.patch.object(s, "_bulk_save_mappings", _bulk_save_mappings): s.bulk_save_objects(objects) eq_( mock_method.mock_calls, @@ -144,30 +140,26 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): mock.call([inspect(user1)], True), mock.call([inspect(user3)], False), mock.call([inspect(user2)], True), - ] + ], ) mock_method = mock.Mock() - with mock.patch.object(s, '_bulk_save_mappings', _bulk_save_mappings): + with mock.patch.object(s, "_bulk_save_mappings", _bulk_save_mappings): s.bulk_save_objects(objects, preserve_order=False) eq_( mock_method.mock_calls, [ mock.call([inspect(user3)], False), mock.call([inspect(user1), inspect(user2)], True), - ] + ], ) def test_bulk_save_no_defaults(self): - User, = self.classes("User",) + User, = self.classes("User") s = Session() - objects = [ - User(name="u1"), - User(name="u2"), - User(name="u3") - ] - assert 'id' not in objects[0].__dict__ + objects = [User(name="u1"), User(name="u2"), User(name="u3")] + assert "id" not in objects[0].__dict__ with self.sql_execution_asserter() as asserter: s.bulk_save_objects(objects) @@ -175,25 +167,21 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): asserter.assert_( CompiledSQL( "INSERT INTO users (name) VALUES (:name)", - [{'name': 'u1'}, {'name': 'u2'}, {'name': 'u3'}] - ), + [{"name": "u1"}, {"name": "u2"}, {"name": "u3"}], + ) ) - assert 'id' not in objects[0].__dict__ + assert "id" not in objects[0].__dict__ def test_bulk_save_updated_include_unchanged(self): - User, = self.classes("User",) + User, = self.classes("User") s = Session(expire_on_commit=False) - objects = [ - User(name="u1"), - User(name="u2"), - User(name="u3") - ] + objects = [User(name="u1"), User(name="u2"), User(name="u3")] s.add_all(objects) s.commit() - objects[0].name = 'u1new' - objects[2].name = 'u3new' + objects[0].name = "u1new" + objects[2].name = "u3new" s = Session() with self.sql_execution_asserter() as asserter: @@ -201,23 +189,20 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): asserter.assert_( CompiledSQL( - "UPDATE users SET name=:name WHERE " - "users.id = :users_id", - [{'users_id': 1, 'name': 'u1new'}, - {'users_id': 2, 'name': 'u2'}, - {'users_id': 3, 'name': 'u3new'}] + "UPDATE users SET name=:name WHERE " "users.id = :users_id", + [ + {"users_id": 1, "name": "u1new"}, + {"users_id": 2, "name": "u2"}, + {"users_id": 3, "name": "u3new"}, + ], ) ) def test_bulk_update(self): - User, = self.classes("User",) + User, = self.classes("User") s = Session(expire_on_commit=False) - objects = [ - User(name="u1"), - User(name="u2"), - User(name="u3") - ] + objects = [User(name="u1"), User(name="u2"), User(name="u3")] s.add_all(objects) s.commit() @@ -225,61 +210,73 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): with self.sql_execution_asserter() as asserter: s.bulk_update_mappings( User, - [{'id': 1, 'name': 'u1new'}, - {'id': 2, 'name': 'u2'}, - {'id': 3, 'name': 'u3new'}] + [ + {"id": 1, "name": "u1new"}, + {"id": 2, "name": "u2"}, + {"id": 3, "name": "u3new"}, + ], ) asserter.assert_( CompiledSQL( "UPDATE users SET name=:name WHERE users.id = :users_id", - [{'users_id': 1, 'name': 'u1new'}, - {'users_id': 2, 
'name': 'u2'}, - {'users_id': 3, 'name': 'u3new'}] + [ + {"users_id": 1, "name": "u1new"}, + {"users_id": 2, "name": "u2"}, + {"users_id": 3, "name": "u3new"}, + ], ) ) def test_bulk_insert(self): - User, = self.classes("User",) + User, = self.classes("User") s = Session() with self.sql_execution_asserter() as asserter: s.bulk_insert_mappings( User, - [{'id': 1, 'name': 'u1new'}, - {'id': 2, 'name': 'u2'}, - {'id': 3, 'name': 'u3new'}] + [ + {"id": 1, "name": "u1new"}, + {"id": 2, "name": "u2"}, + {"id": 3, "name": "u3new"}, + ], ) asserter.assert_( CompiledSQL( "INSERT INTO users (id, name) VALUES (:id, :name)", - [{'id': 1, 'name': 'u1new'}, - {'id': 2, 'name': 'u2'}, - {'id': 3, 'name': 'u3new'}] + [ + {"id": 1, "name": "u1new"}, + {"id": 2, "name": "u2"}, + {"id": 3, "name": "u3new"}, + ], ) ) def test_bulk_insert_render_nulls(self): - Order, = self.classes("Order",) + Order, = self.classes("Order") s = Session() with self.sql_execution_asserter() as asserter: s.bulk_insert_mappings( Order, - [{'id': 1, 'description': 'u1new'}, - {'id': 2, 'description': None}, - {'id': 3, 'description': 'u3new'}], - render_nulls=True + [ + {"id": 1, "description": "u1new"}, + {"id": 2, "description": None}, + {"id": 3, "description": "u3new"}, + ], + render_nulls=True, ) asserter.assert_( CompiledSQL( "INSERT INTO orders (id, description) " "VALUES (:id, :description)", - [{'id': 1, 'description': 'u1new'}, - {'id': 2, 'description': None}, - {'id': 3, 'description': 'u3new'}] + [ + {"id": 1, "description": "u1new"}, + {"id": 2, "description": None}, + {"id": 3, "description": "u3new"}, + ], ) ) @@ -288,15 +285,19 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, + "a", + metadata, Column( - 'id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('x', Integer), - Column('y', Integer, - server_default=FetchedValue(), - server_onupdate=FetchedValue())) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", Integer), + Column( + "y", + Integer, + server_default=FetchedValue(), + server_onupdate=FetchedValue(), + ), + ) @classmethod def setup_classes(cls): @@ -328,8 +329,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): eq_(a1.id, 1) # force a load a1.x = 5 - s.expire(a1, ['y']) - assert 'y' not in a1.__dict__ + s.expire(a1, ["y"]) + assert "y" not in a1.__dict__ s.bulk_save_objects([a1]) s.commit() @@ -341,25 +342,25 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'people_keys', metadata, - Column( - 'person_id', Integer, - primary_key=True, key='id'), - Column('name', String(50), key='personname')) + "people_keys", + metadata, + Column("person_id", Integer, primary_key=True, key="id"), + Column("name", String(50), key="personname"), + ) Table( - 'people_attrs', metadata, - Column( - 'person_id', Integer, - primary_key=True), - Column('name', String(50))) + "people_attrs", + metadata, + Column("person_id", Integer, primary_key=True), + Column("name", String(50)), + ) Table( - 'people_both', metadata, - Column( - 'person_id', Integer, - primary_key=True, key="id_key"), - Column('name', String(50), key='name_key')) + "people_both", + metadata, + Column("person_id", Integer, primary_key=True, key="id_key"), + Column("name", String(50), key="name_key"), + ) @classmethod def setup_classes(cls): @@ -375,20 +376,30 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): @classmethod def 
setup_mappers(cls): PersonKeys, PersonAttrs, PersonBoth = cls.classes( - "PersonKeys", "PersonAttrs", "PersonBoth") + "PersonKeys", "PersonAttrs", "PersonBoth" + ) people_keys, people_attrs, people_both = cls.tables( - "people_keys", "people_attrs", "people_both") + "people_keys", "people_attrs", "people_both" + ) mapper(PersonKeys, people_keys) - mapper(PersonAttrs, people_attrs, properties={ - 'id': people_attrs.c.person_id, - 'personname': people_attrs.c.name - }) + mapper( + PersonAttrs, + people_attrs, + properties={ + "id": people_attrs.c.person_id, + "personname": people_attrs.c.name, + }, + ) - mapper(PersonBoth, people_both, properties={ - 'id': people_both.c.id_key, - 'personname': people_both.c.name_key - }) + mapper( + PersonBoth, + people_both, + properties={ + "id": people_both.c.id_key, + "personname": people_both.c.name_key, + }, + ) def test_insert_keys(self): asserter = self._test_insert(self.classes.PersonKeys) @@ -396,8 +407,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): CompiledSQL( "INSERT INTO people_keys (person_id, name) " "VALUES (:id, :personname)", - [{'id': 5, 'personname': 'thename'}] - ), + [{"id": 5, "personname": "thename"}], + ) ) def test_insert_attrs(self): @@ -406,8 +417,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): CompiledSQL( "INSERT INTO people_attrs (person_id, name) " "VALUES (:person_id, :name)", - [{'person_id': 5, 'name': 'thename'}] - ), + [{"person_id": 5, "name": "thename"}], + ) ) def test_insert_both(self): @@ -416,8 +427,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): CompiledSQL( "INSERT INTO people_both (person_id, name) " "VALUES (:id_key, :name_key)", - [{'id_key': 5, 'name_key': 'thename'}] - ), + [{"id_key": 5, "name_key": "thename"}], + ) ) def test_update_keys(self): @@ -426,8 +437,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): CompiledSQL( "UPDATE people_keys SET name=:personname " "WHERE people_keys.person_id = :people_keys_person_id", - [{'personname': 'newname', 'people_keys_person_id': 5}] - ), + [{"personname": "newname", "people_keys_person_id": 5}], + ) ) @testing.requires.updateable_autoincrement_pks @@ -437,8 +448,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): CompiledSQL( "UPDATE people_attrs SET name=:name " "WHERE people_attrs.person_id = :people_attrs_person_id", - [{'name': 'newname', 'people_attrs_person_id': 5}] - ), + [{"name": "newname", "people_attrs_person_id": 5}], + ) ) @testing.requires.updateable_autoincrement_pks @@ -450,8 +461,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): CompiledSQL( "UPDATE people_both SET name=:name_key " "WHERE people_both.person_id = :people_both_person_id", - [{'name_key': 'newname', 'people_both_person_id': 5}] - ), + [{"name_key": "newname", "people_both_person_id": 5}], + ) ) def _test_insert(self, person_cls): @@ -463,10 +474,7 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): Person, [{"id": 5, "personname": "thename"}] ) - eq_( - s.query(Person).first(), - Person(id=5, personname="thename") - ) + eq_(s.query(Person).first(), Person(id=5, personname="thename")) return asserter @@ -482,10 +490,7 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): Person, [{"id": 5, "personname": "newname"}] ) - eq_( - s.query(Person).first(), - Person(id=5, personname="newname") - ) + eq_(s.query(Person).first(), Person(id=5, personname="newname")) return asserter @@ -494,39 +499,55 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): @classmethod def 
define_tables(cls, metadata): Table( - 'people', metadata, + "people", + metadata, Column( - 'person_id', Integer, + "person_id", + Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) Table( - 'engineers', metadata, + "engineers", + metadata, Column( - 'person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('primary_language', String(50))) + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("primary_language", String(50)), + ) Table( - 'managers', metadata, + "managers", + metadata, Column( - 'person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50))) + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("manager_name", String(50)), + ) Table( - 'boss', metadata, + "boss", + metadata, Column( - 'boss_id', Integer, - ForeignKey('managers.person_id'), - primary_key=True), - Column('golf_swing', String(30))) + "boss_id", + Integer, + ForeignKey("managers.person_id"), + primary_key=True, + ), + Column("golf_swing", String(30)), + ) @classmethod def setup_classes(cls): @@ -547,31 +568,31 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): @classmethod def setup_mappers(cls): - Person, Engineer, Manager, Boss = \ - cls.classes('Person', 'Engineer', 'Manager', 'Boss') - p, e, m, b = cls.tables('people', 'engineers', 'managers', 'boss') + Person, Engineer, Manager, Boss = cls.classes( + "Person", "Engineer", "Manager", "Boss" + ) + p, e, m, b = cls.tables("people", "engineers", "managers", "boss") mapper( - Person, p, polymorphic_on=p.c.type, - polymorphic_identity='person') - mapper(Engineer, e, inherits=Person, polymorphic_identity='engineer') - mapper(Manager, m, inherits=Person, polymorphic_identity='manager') - mapper(Boss, b, inherits=Manager, polymorphic_identity='boss') + Person, p, polymorphic_on=p.c.type, polymorphic_identity="person" + ) + mapper(Engineer, e, inherits=Person, polymorphic_identity="engineer") + mapper(Manager, m, inherits=Person, polymorphic_identity="manager") + mapper(Boss, b, inherits=Manager, polymorphic_identity="boss") def test_bulk_save_joined_inh_return_defaults(self): - Person, Engineer, Manager, Boss = \ - self.classes('Person', 'Engineer', 'Manager', 'Boss') + Person, Engineer, Manager, Boss = self.classes( + "Person", "Engineer", "Manager", "Boss" + ) s = Session() objects = [ - Manager(name='m1', status='s1', manager_name='mn1'), - Engineer(name='e1', status='s2', primary_language='l1'), - Engineer(name='e2', status='s3', primary_language='l2'), - Boss( - name='b1', status='s3', manager_name='mn2', - golf_swing='g1') + Manager(name="m1", status="s1", manager_name="mn1"), + Engineer(name="e1", status="s2", primary_language="l1"), + Engineer(name="e2", status="s3", primary_language="l2"), + Boss(name="b1", status="s3", manager_name="mn2", golf_swing="g1"), ] - assert 'person_id' not in objects[0].__dict__ + assert "person_id" not in objects[0].__dict__ with self.sql_execution_asserter() as asserter: s.bulk_save_objects(objects, return_defaults=True) @@ -579,70 +600,81 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): asserter.assert_( CompiledSQL( "INSERT INTO people (name, type) 
VALUES (:name, :type)", - [{'type': 'manager', 'name': 'm1'}] + [{"type": "manager", "name": "m1"}], ), CompiledSQL( "INSERT INTO managers (person_id, status, manager_name) " "VALUES (:person_id, :status, :manager_name)", - [{'person_id': 1, 'status': 's1', 'manager_name': 'mn1'}] + [{"person_id": 1, "status": "s1", "manager_name": "mn1"}], ), CompiledSQL( "INSERT INTO people (name, type) VALUES (:name, :type)", - [{'type': 'engineer', 'name': 'e1'}] + [{"type": "engineer", "name": "e1"}], ), CompiledSQL( "INSERT INTO people (name, type) VALUES (:name, :type)", - [{'type': 'engineer', 'name': 'e2'}] + [{"type": "engineer", "name": "e2"}], ), CompiledSQL( "INSERT INTO engineers (person_id, status, primary_language) " "VALUES (:person_id, :status, :primary_language)", - [{'person_id': 2, 'status': 's2', 'primary_language': 'l1'}, - {'person_id': 3, 'status': 's3', 'primary_language': 'l2'}] - + [ + {"person_id": 2, "status": "s2", "primary_language": "l1"}, + {"person_id": 3, "status": "s3", "primary_language": "l2"}, + ], ), CompiledSQL( "INSERT INTO people (name, type) VALUES (:name, :type)", - [{'type': 'boss', 'name': 'b1'}] + [{"type": "boss", "name": "b1"}], ), CompiledSQL( "INSERT INTO managers (person_id, status, manager_name) " "VALUES (:person_id, :status, :manager_name)", - [{'person_id': 4, 'status': 's3', 'manager_name': 'mn2'}] - + [{"person_id": 4, "status": "s3", "manager_name": "mn2"}], ), CompiledSQL( "INSERT INTO boss (boss_id, golf_swing) VALUES " "(:boss_id, :golf_swing)", - [{'boss_id': 4, 'golf_swing': 'g1'}] - ) + [{"boss_id": 4, "golf_swing": "g1"}], + ), ) - eq_(objects[0].__dict__['person_id'], 1) - eq_(objects[3].__dict__['person_id'], 4) - eq_(objects[3].__dict__['boss_id'], 4) + eq_(objects[0].__dict__["person_id"], 1) + eq_(objects[3].__dict__["person_id"], 4) + eq_(objects[3].__dict__["boss_id"], 4) def test_bulk_save_joined_inh_no_defaults(self): - Person, Engineer, Manager, Boss = \ - self.classes('Person', 'Engineer', 'Manager', 'Boss') + Person, Engineer, Manager, Boss = self.classes( + "Person", "Engineer", "Manager", "Boss" + ) s = Session() with self.sql_execution_asserter() as asserter: - s.bulk_save_objects([ - Manager( - person_id=1, - name='m1', status='s1', manager_name='mn1'), - Engineer( - person_id=2, - name='e1', status='s2', primary_language='l1'), - Engineer( - person_id=3, - name='e2', status='s3', primary_language='l2'), - Boss( - person_id=4, boss_id=4, - name='b1', status='s3', manager_name='mn2', - golf_swing='g1') - ], - + s.bulk_save_objects( + [ + Manager( + person_id=1, name="m1", status="s1", manager_name="mn1" + ), + Engineer( + person_id=2, + name="e1", + status="s2", + primary_language="l1", + ), + Engineer( + person_id=3, + name="e2", + status="s3", + primary_language="l2", + ), + Boss( + person_id=4, + boss_id=4, + name="b1", + status="s3", + manager_name="mn2", + golf_swing="g1", + ), + ] ) # the only difference here is that common classes are grouped together. 
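
The transformation these hunks apply is purely mechanical and can be reproduced programmatically. A minimal sketch follows — not part of this commit, which was produced with the CLI invocation `black -l 79` as described in the message above, and assuming black's Python API (`black.format_str` and `black.Mode`) is available:

import black

# Hand-wrapped pre-black style, condensed from the "-" lines of the
# hunk above.
before = """
s.bulk_save_objects([
    Manager(
        person_id=1,
        name='m1', status='s1', manager_name='mn1'),
])
"""

# black normalizes string quotes to double quotes and, wherever a call
# or literal would exceed the 79-column limit, explodes it one element
# per line with a trailing comma -- the same conventions visible in the
# "+" lines throughout this diff.
print(black.format_str(before, mode=black.Mode(line_length=79)))

Running this prints the post-black shape of the call; the exact wrapping depends on indentation depth, which is why deeply nested calls in these tests break across more lines than the same call would at module level.
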
@@ -652,45 +684,50 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): CompiledSQL( "INSERT INTO people (person_id, name, type) VALUES " "(:person_id, :name, :type)", - [{'person_id': 1, 'type': 'manager', 'name': 'm1'}] + [{"person_id": 1, "type": "manager", "name": "m1"}], ), CompiledSQL( "INSERT INTO managers (person_id, status, manager_name) " "VALUES (:person_id, :status, :manager_name)", - [{'status': 's1', 'person_id': 1, 'manager_name': 'mn1'}] + [{"status": "s1", "person_id": 1, "manager_name": "mn1"}], ), CompiledSQL( "INSERT INTO people (person_id, name, type) VALUES " "(:person_id, :name, :type)", - [{'person_id': 2, 'type': 'engineer', 'name': 'e1'}, - {'person_id': 3, 'type': 'engineer', 'name': 'e2'}] + [ + {"person_id": 2, "type": "engineer", "name": "e1"}, + {"person_id": 3, "type": "engineer", "name": "e2"}, + ], ), CompiledSQL( "INSERT INTO engineers (person_id, status, primary_language) " "VALUES (:person_id, :status, :primary_language)", - [{'person_id': 2, 'status': 's2', 'primary_language': 'l1'}, - {'person_id': 3, 'status': 's3', 'primary_language': 'l2'}] + [ + {"person_id": 2, "status": "s2", "primary_language": "l1"}, + {"person_id": 3, "status": "s3", "primary_language": "l2"}, + ], ), CompiledSQL( "INSERT INTO people (person_id, name, type) VALUES " "(:person_id, :name, :type)", - [{'person_id': 4, 'type': 'boss', 'name': 'b1'}] + [{"person_id": 4, "type": "boss", "name": "b1"}], ), CompiledSQL( "INSERT INTO managers (person_id, status, manager_name) " "VALUES (:person_id, :status, :manager_name)", - [{'status': 's3', 'person_id': 4, 'manager_name': 'mn2'}] + [{"status": "s3", "person_id": 4, "manager_name": "mn2"}], ), CompiledSQL( "INSERT INTO boss (boss_id, golf_swing) VALUES " "(:boss_id, :golf_swing)", - [{'boss_id': 4, 'golf_swing': 'g1'}] - ) + [{"boss_id": 4, "golf_swing": "g1"}], + ), ) def test_bulk_insert_joined_inh_return_defaults(self): - Person, Engineer, Manager, Boss = \ - self.classes('Person', 'Engineer', 'Manager', 'Boss') + Person, Engineer, Manager, Boss = self.classes( + "Person", "Engineer", "Manager", "Boss" + ) s = Session() with self.sql_execution_asserter() as asserter: @@ -698,46 +735,53 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): Boss, [ dict( - name='b1', status='s1', manager_name='mn1', - golf_swing='g1' + name="b1", + status="s1", + manager_name="mn1", + golf_swing="g1", ), dict( - name='b2', status='s2', manager_name='mn2', - golf_swing='g2' + name="b2", + status="s2", + manager_name="mn2", + golf_swing="g2", ), dict( - name='b3', status='s3', manager_name='mn3', - golf_swing='g3' + name="b3", + status="s3", + manager_name="mn3", + golf_swing="g3", ), - ], return_defaults=True + ], + return_defaults=True, ) asserter.assert_( CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{'name': 'b1'}] + "INSERT INTO people (name) VALUES (:name)", [{"name": "b1"}] ), CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{'name': 'b2'}] + "INSERT INTO people (name) VALUES (:name)", [{"name": "b2"}] ), CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{'name': 'b3'}] + "INSERT INTO people (name) VALUES (:name)", [{"name": "b3"}] ), CompiledSQL( "INSERT INTO managers (person_id, status, manager_name) " "VALUES (:person_id, :status, :manager_name)", - [{'person_id': 1, 'status': 's1', 'manager_name': 'mn1'}, - {'person_id': 2, 'status': 's2', 'manager_name': 'mn2'}, - {'person_id': 3, 'status': 's3', 'manager_name': 'mn3'}] - + [ + {"person_id": 1, "status": "s1", "manager_name": "mn1"}, 
+ {"person_id": 2, "status": "s2", "manager_name": "mn2"}, + {"person_id": 3, "status": "s3", "manager_name": "mn3"}, + ], ), CompiledSQL( "INSERT INTO boss (boss_id, golf_swing) VALUES " "(:boss_id, :golf_swing)", - [{'golf_swing': 'g1', 'boss_id': 1}, - {'golf_swing': 'g2', 'boss_id': 2}, - {'golf_swing': 'g3', 'boss_id': 3}] - ) + [ + {"golf_swing": "g1", "boss_id": 1}, + {"golf_swing": "g2", "boss_id": 2}, + {"golf_swing": "g3", "boss_id": 3}, + ], + ), ) diff --git a/test/orm/test_bundle.py b/test/orm/test_bundle.py index cf1eb40c98..9b27d247df 100644 --- a/test/orm/test_bundle.py +++ b/test/orm/test_bundle.py @@ -8,26 +8,34 @@ from sqlalchemy.sql.elements import ClauseList class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - run_inserts = 'once' - run_setup_mappers = 'once' + run_inserts = "once" + run_setup_mappers = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('data', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('d1', String(10)), - Column('d2', String(10)), - Column('d3', String(10))) - - Table('other', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data_id', ForeignKey('data.id')), - Column('o1', String(10))) + Table( + "data", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("d1", String(10)), + Column("d2", String(10)), + Column("d3", String(10)), + ) + + Table( + "other", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data_id", ForeignKey("data.id")), + Column("o1", String(10)), + ) @classmethod def setup_classes(cls): @@ -39,34 +47,38 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod def setup_mappers(cls): - mapper(cls.classes.Data, cls.tables.data, properties={ - 'others': relationship(cls.classes.Other) - }) + mapper( + cls.classes.Data, + cls.tables.data, + properties={"others": relationship(cls.classes.Other)}, + ) mapper(cls.classes.Other, cls.tables.other) @classmethod def insert_data(cls): sess = Session() - sess.add_all([ - cls.classes.Data(d1='d%dd1' % i, d2='d%dd2' % i, d3='d%dd3' % i, - others=[cls.classes.Other(o1="d%do%d" % (i, j)) - for j in range(5)]) - for i in range(10) - ]) + sess.add_all( + [ + cls.classes.Data( + d1="d%dd1" % i, + d2="d%dd2" % i, + d3="d%dd3" % i, + others=[ + cls.classes.Other(o1="d%do%d" % (i, j)) + for j in range(5) + ], + ) + for i in range(10) + ] + ) sess.commit() def test_same_named_col_clauselist(self): Data, Other = self.classes("Data", "Other") bundle = Bundle("pk", Data.id, Other.id) - self.assert_compile( - ClauseList(Data.id, Other.id), - "data.id, other.id" - ) - self.assert_compile( - bundle.__clause_element__(), - "data.id, other.id" - ) + self.assert_compile(ClauseList(Data.id, Other.id), "data.id, other.id") + self.assert_compile(bundle.__clause_element__(), "data.id, other.id") def test_same_named_col_in_orderby(self): Data, Other = self.classes("Data", "Other") @@ -79,7 +91,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): "data.d2 AS data_d2, data.d3 AS data_d3, " "other.id AS other_id, other.data_id AS other_data_id, " "other.o1 AS other_o1 " - "FROM data, other ORDER BY data.id, other.id" + "FROM data, other ORDER BY data.id, other.id", ) def test_same_named_col_in_fetch(self): @@ -88,30 +100,31 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): sess = Session() eq_( - 
sess.query(bundle).filter( - Data.id == Other.id).filter(Data.id < 3).all(), - [((1, 1),), ((2, 2),)] + sess.query(bundle) + .filter(Data.id == Other.id) + .filter(Data.id < 3) + .all(), + [((1, 1),), ((2, 2),)], ) def test_c_attr(self): Data = self.classes.Data - b1 = Bundle('b1', Data.d1, Data.d2) + b1 = Bundle("b1", Data.d1, Data.d2) self.assert_compile( - select([b1.c.d1, b1.c.d2]), - "SELECT data.d1, data.d2 FROM data" + select([b1.c.d1, b1.c.d2]), "SELECT data.d1, data.d2 FROM data" ) def test_result(self): Data = self.classes.Data sess = Session() - b1 = Bundle('b1', Data.d1, Data.d2) + b1 = Bundle("b1", Data.d1, Data.d2) eq_( - sess.query(b1).filter(b1.c.d1.between('d3d1', 'd5d1')).all(), - [(('d3d1', 'd3d2'),), (('d4d1', 'd4d2'),), (('d5d1', 'd5d2'),)] + sess.query(b1).filter(b1.c.d1.between("d3d1", "d5d1")).all(), + [(("d3d1", "d3d2"),), (("d4d1", "d4d2"),), (("d5d1", "d5d2"),)], ) def test_subclass(self): @@ -121,18 +134,19 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): class MyBundle(Bundle): def create_row_processor(self, query, procs, labels): def proc(row): - return dict( - zip(labels, (proc(row) for proc in procs)) - ) + return dict(zip(labels, (proc(row) for proc in procs))) + return proc - b1 = MyBundle('b1', Data.d1, Data.d2) + b1 = MyBundle("b1", Data.d1, Data.d2) eq_( - sess.query(b1).filter(b1.c.d1.between('d3d1', 'd5d1')).all(), - [({'d2': 'd3d2', 'd1': 'd3d1'},), - ({'d2': 'd4d2', 'd1': 'd4d1'},), - ({'d2': 'd5d2', 'd1': 'd5d1'},)] + sess.query(b1).filter(b1.c.d1.between("d3d1", "d5d1")).all(), + [ + ({"d2": "d3d2", "d1": "d3d1"},), + ({"d2": "d4d2", "d1": "d4d1"},), + ({"d2": "d5d2", "d1": "d5d1"},), + ], ) def test_multi_bundle(self): @@ -141,117 +155,136 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): d1 = aliased(Data) - b1 = Bundle('b1', d1.d1, d1.d2) - b2 = Bundle('b2', Data.d1, Other.o1) + b1 = Bundle("b1", d1.d1, d1.d2) + b2 = Bundle("b2", Data.d1, Other.o1) sess = Session() - q = sess.query(b1, b2).join(Data.others).join(d1, d1.id == Data.id).\ - filter(b1.c.d1 == 'd3d1') + q = ( + sess.query(b1, b2) + .join(Data.others) + .join(d1, d1.id == Data.id) + .filter(b1.c.d1 == "d3d1") + ) eq_( q.all(), [ - (('d3d1', 'd3d2'), ('d3d1', 'd3o0')), - (('d3d1', 'd3d2'), ('d3d1', 'd3o1')), - (('d3d1', 'd3d2'), ('d3d1', 'd3o2')), - (('d3d1', 'd3d2'), ('d3d1', 'd3o3')), - (('d3d1', 'd3d2'), ('d3d1', 'd3o4'))] + (("d3d1", "d3d2"), ("d3d1", "d3o0")), + (("d3d1", "d3d2"), ("d3d1", "d3o1")), + (("d3d1", "d3d2"), ("d3d1", "d3o2")), + (("d3d1", "d3d2"), ("d3d1", "d3o3")), + (("d3d1", "d3d2"), ("d3d1", "d3o4")), + ], ) def test_single_entity(self): Data = self.classes.Data sess = Session() - b1 = Bundle('b1', Data.d1, Data.d2, single_entity=True) + b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True) eq_( - sess.query(b1).filter(b1.c.d1.between('d3d1', 'd5d1')).all(), - [('d3d1', 'd3d2'), ('d4d1', 'd4d2'), ('d5d1', 'd5d2')] + sess.query(b1).filter(b1.c.d1.between("d3d1", "d5d1")).all(), + [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")], ) def test_single_entity_flag_but_multi_entities(self): Data = self.classes.Data sess = Session() - b1 = Bundle('b1', Data.d1, Data.d2, single_entity=True) - b2 = Bundle('b1', Data.d3, single_entity=True) + b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True) + b2 = Bundle("b1", Data.d3, single_entity=True) eq_( - sess.query(b1, b2).filter(b1.c.d1.between('d3d1', 'd5d1')).all(), + sess.query(b1, b2).filter(b1.c.d1.between("d3d1", "d5d1")).all(), [ - (('d3d1', 'd3d2'), ('d3d3',)), - (('d4d1', 'd4d2'), 
('d4d3',)), - (('d5d1', 'd5d2'), ('d5d3',)) - ] + (("d3d1", "d3d2"), ("d3d3",)), + (("d4d1", "d4d2"), ("d4d3",)), + (("d5d1", "d5d2"), ("d5d3",)), + ], ) def test_bundle_nesting(self): Data = self.classes.Data sess = Session() - b1 = Bundle('b1', Data.d1, Bundle('b2', Data.d2, Data.d3)) + b1 = Bundle("b1", Data.d1, Bundle("b2", Data.d2, Data.d3)) eq_( - sess.query(b1). - filter(b1.c.d1.between('d3d1', 'd7d1')). - filter(b1.c.b2.c.d2.between('d4d2', 'd6d2')). - all(), - [(('d4d1', ('d4d2', 'd4d3')),), (('d5d1', ('d5d2', 'd5d3')),), - (('d6d1', ('d6d2', 'd6d3')),)] + sess.query(b1) + .filter(b1.c.d1.between("d3d1", "d7d1")) + .filter(b1.c.b2.c.d2.between("d4d2", "d6d2")) + .all(), + [ + (("d4d1", ("d4d2", "d4d3")),), + (("d5d1", ("d5d2", "d5d3")),), + (("d6d1", ("d6d2", "d6d3")),), + ], ) def test_bundle_nesting_unions(self): Data = self.classes.Data sess = Session() - b1 = Bundle('b1', Data.d1, Bundle('b2', Data.d2, Data.d3)) + b1 = Bundle("b1", Data.d1, Bundle("b2", Data.d2, Data.d3)) - q1 = sess.query(b1).\ - filter(b1.c.d1.between('d3d1', 'd7d1')).\ - filter(b1.c.b2.c.d2.between('d4d2', 'd5d2')) + q1 = ( + sess.query(b1) + .filter(b1.c.d1.between("d3d1", "d7d1")) + .filter(b1.c.b2.c.d2.between("d4d2", "d5d2")) + ) - q2 = sess.query(b1).\ - filter(b1.c.d1.between('d3d1', 'd7d1')).\ - filter(b1.c.b2.c.d2.between('d5d2', 'd6d2')) + q2 = ( + sess.query(b1) + .filter(b1.c.d1.between("d3d1", "d7d1")) + .filter(b1.c.b2.c.d2.between("d5d2", "d6d2")) + ) eq_( q1.union(q2).all(), - [(('d4d1', ('d4d2', 'd4d3')),), (('d5d1', ('d5d2', 'd5d3')),), - (('d6d1', ('d6d2', 'd6d3')),)] + [ + (("d4d1", ("d4d2", "d4d3")),), + (("d5d1", ("d5d2", "d5d3")),), + (("d6d1", ("d6d2", "d6d3")),), + ], ) # naming structure is preserved row = q1.union(q2).first() - eq_(row.b1.d1, 'd4d1') - eq_(row.b1.b2.d2, 'd4d2') + eq_(row.b1.d1, "d4d1") + eq_(row.b1.b2.d2, "d4d2") def test_query_count(self): Data = self.classes.Data - b1 = Bundle('b1', Data.d1, Data.d2) + b1 = Bundle("b1", Data.d1, Data.d2) eq_(Session().query(b1).count(), 10) def test_join_relationship(self): Data = self.classes.Data sess = Session() - b1 = Bundle('b1', Data.d1, Data.d2) + b1 = Bundle("b1", Data.d1, Data.d2) q = sess.query(b1).join(Data.others) - self.assert_compile(q, - "SELECT data.d1 AS data_d1, data.d2 " - "AS data_d2 FROM data " - "JOIN other ON data.id = other.data_id") + self.assert_compile( + q, + "SELECT data.d1 AS data_d1, data.d2 " + "AS data_d2 FROM data " + "JOIN other ON data.id = other.data_id", + ) def test_join_selectable(self): Data = self.classes.Data Other = self.classes.Other sess = Session() - b1 = Bundle('b1', Data.d1, Data.d2) + b1 = Bundle("b1", Data.d1, Data.d2) q = sess.query(b1).join(Other) - self.assert_compile(q, - "SELECT data.d1 AS data_d1, data.d2 AS data_d2 " - "FROM data " - "JOIN other ON data.id = other.data_id") + self.assert_compile( + q, + "SELECT data.d1 AS data_d1, data.d2 AS data_d2 " + "FROM data " + "JOIN other ON data.id = other.data_id", + ) def test_joins_from_adapted_entities(self): Data = self.classes.Data @@ -259,7 +292,7 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): # test for #1853 in terms of bundles # specifically this exercises adapt_to_selectable() - b1 = Bundle('b1', Data.id, Data.d1, Data.d2) + b1 = Bundle("b1", Data.id, Data.d1, Data.d2) session = Session() first = session.query(b1) @@ -280,46 +313,54 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): "data.d2 AS data_d2 FROM data) AS anon_1 " "LEFT OUTER JOIN (SELECT data.id AS id FROM data) AS anon_2 " "ON 
anon_2.id = anon_1.data_id " - "ORDER BY anon_1.data_id, anon_1.data_d1, anon_1.data_d2") + "ORDER BY anon_1.data_id, anon_1.data_d1, anon_1.data_d2", + ) # tuple nesting still occurs eq_( joined.all(), - [((1, 'd0d1', 'd0d2'),), ((2, 'd1d1', 'd1d2'),), - ((3, 'd2d1', 'd2d2'),), ((4, 'd3d1', 'd3d2'),), - ((5, 'd4d1', 'd4d2'),), ((6, 'd5d1', 'd5d2'),), - ((7, 'd6d1', 'd6d2'),), ((8, 'd7d1', 'd7d2'),), - ((9, 'd8d1', 'd8d2'),), ((10, 'd9d1', 'd9d2'),)] + [ + ((1, "d0d1", "d0d2"),), + ((2, "d1d1", "d1d2"),), + ((3, "d2d1", "d2d2"),), + ((4, "d3d1", "d3d2"),), + ((5, "d4d1", "d4d2"),), + ((6, "d5d1", "d5d2"),), + ((7, "d6d1", "d6d2"),), + ((8, "d7d1", "d7d2"),), + ((9, "d8d1", "d8d2"),), + ((10, "d9d1", "d9d2"),), + ], ) def test_filter_by(self): Data = self.classes.Data - b1 = Bundle('b1', Data.id, Data.d1, Data.d2) + b1 = Bundle("b1", Data.id, Data.d1, Data.d2) sess = Session() self.assert_compile( - sess.query(b1).filter_by(d1='d1'), + sess.query(b1).filter_by(d1="d1"), "SELECT data.id AS data_id, data.d1 AS data_d1, " - "data.d2 AS data_d2 FROM data WHERE data.d1 = :d1_1" + "data.d2 AS data_d2 FROM data WHERE data.d1 = :d1_1", ) def test_clause_expansion(self): Data = self.classes.Data - b1 = Bundle('b1', Data.id, Data.d1, Data.d2) + b1 = Bundle("b1", Data.id, Data.d1, Data.d2) sess = Session() self.assert_compile( sess.query(Data).order_by(b1), "SELECT data.id AS data_id, data.d1 AS data_d1, " "data.d2 AS data_d2, data.d3 AS data_d3 FROM data " - "ORDER BY data.id, data.d1, data.d2" + "ORDER BY data.id, data.d1, data.d2", ) self.assert_compile( sess.query(func.row_number().over(order_by=b1)), "SELECT row_number() OVER (ORDER BY data.id, data.d1, data.d2) " - "AS anon_1 FROM data" + "AS anon_1 FROM data", ) diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index fecf3dcc3c..9aaf765927 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -1,12 +1,27 @@ import copy from sqlalchemy.testing import assert_raises, assert_raises_message -from sqlalchemy import Integer, String, ForeignKey, \ - exc as sa_exc, util, select, func +from sqlalchemy import ( + Integer, + String, + ForeignKey, + exc as sa_exc, + util, + select, + func, +) from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, create_session, \ - sessionmaker, class_mapper, backref, Session, util as orm_util,\ - configure_mappers +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + sessionmaker, + class_mapper, + backref, + Session, + util as orm_util, + configure_mappers, +) from sqlalchemy.orm.attributes import instance_state from sqlalchemy.orm import attributes, exc as orm_exc, object_mapper from sqlalchemy import testing @@ -22,15 +37,23 @@ class CascadeArgTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False)) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False)) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users.id")), + 
Column("email_address", String(50), nullable=False), + ) @classmethod def setup_classes(cls): @@ -44,16 +67,23 @@ class CascadeArgTest(fixtures.MappedTest): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - 'addresses': relationship(Address, - passive_deletes="all", - cascade="all, delete-orphan")}) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + passive_deletes="all", + cascade="all, delete-orphan", + ) + }, + ) mapper(Address, addresses) assert_raises_message( sa_exc.ArgumentError, "On User.addresses, can't set passive_deletes='all' " "in conjunction with 'delete' or 'delete-orphan' cascade", - configure_mappers + configure_mappers, ) def test_delete_orphan_without_delete(self): @@ -63,7 +93,9 @@ class CascadeArgTest(fixtures.MappedTest): assert_raises_message( sa_exc.SAWarning, "The 'delete-orphan' cascade option requires 'delete'.", - relationship, Address, cascade="save-update, delete-orphan" + relationship, + Address, + cascade="save-update, delete-orphan", ) def test_bad_cascade(self): @@ -73,20 +105,22 @@ class CascadeArgTest(fixtures.MappedTest): assert_raises_message( sa_exc.ArgumentError, r"Invalid cascade option\(s\): 'fake', 'fake2'", - relationship, Address, cascade="fake, all, delete-orphan, fake2" + relationship, + Address, + cascade="fake, all, delete-orphan, fake2", ) def test_cascade_repr(self): eq_( repr(orm_util.CascadeOptions("all, delete-orphan")), "CascadeOptions('delete,delete-orphan,expunge," - "merge,refresh-expire,save-update')" + "merge,refresh-expire,save-update')", ) def test_cascade_immutable(self): assert isinstance( - orm_util.CascadeOptions("all, delete-orphan"), - frozenset) + orm_util.CascadeOptions("all, delete-orphan"), frozenset + ) def test_cascade_deepcopy(self): old = orm_util.CascadeOptions("all, delete-orphan") @@ -98,32 +132,41 @@ class CascadeArgTest(fixtures.MappedTest): users, addresses = self.tables.users, self.tables.addresses rel = relationship(Address) - eq_(rel.cascade, set(['save-update', 'merge'])) + eq_(rel.cascade, set(["save-update", "merge"])) rel.cascade = "save-update, merge, expunge" - eq_(rel.cascade, set(['save-update', 'merge', 'expunge'])) + eq_(rel.cascade, set(["save-update", "merge", "expunge"])) - mapper(User, users, properties={'addresses': rel}) + mapper(User, users, properties={"addresses": rel}) am = mapper(Address, addresses) configure_mappers() - eq_(rel.cascade, set(['save-update', 'merge', 'expunge'])) + eq_(rel.cascade, set(["save-update", "merge", "expunge"])) assert ("addresses", User) not in am._delete_orphans rel.cascade = "all, delete, delete-orphan" assert ("addresses", User) in am._delete_orphans - eq_(rel.cascade, - set(['delete', 'delete-orphan', 'expunge', 'merge', - 'refresh-expire', 'save-update']) - ) + eq_( + rel.cascade, + set( + [ + "delete", + "delete-orphan", + "expunge", + "merge", + "refresh-expire", + "save-update", + ] + ), + ) def test_cascade_unicode(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses rel = relationship(Address) - rel.cascade = util.u('save-update, merge, expunge') - eq_(rel.cascade, set(['save-update', 'merge', 'expunge'])) + rel.cascade = util.u("save-update, merge, expunge") + eq_(rel.cascade, set(["save-update", "merge", "expunge"])) class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): @@ -131,26 +174,41 @@ class 
O2MCascadeDeleteOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False)) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False)) - Table('orders', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey( - 'users.id'), nullable=False), - Column('description', String(30))) - Table("dingalings", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', Integer, ForeignKey('addresses.id')), - Column('data', String(30))) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + ) + Table( + "orders", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users.id"), nullable=False), + Column("description", String(30)), + ) + Table( + "dingalings", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", Integer, ForeignKey("addresses.id")), + Column("data", String(30)), + ) @classmethod def setup_classes(cls): @@ -168,65 +226,95 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - users, Dingaling, Order, User, dingalings, Address, \ - orders, addresses = (cls.tables.users, - cls.classes.Dingaling, - cls.classes.Order, - cls.classes.User, - cls.tables.dingalings, - cls.classes.Address, - cls.tables.orders, - cls.tables.addresses) + users, Dingaling, Order, User, dingalings, Address, orders, addresses = ( + cls.tables.users, + cls.classes.Dingaling, + cls.classes.Order, + cls.classes.User, + cls.tables.dingalings, + cls.classes.Address, + cls.tables.orders, + cls.tables.addresses, + ) mapper(Address, addresses) mapper(Order, orders) - mapper(User, users, properties={ - 'addresses': relationship(Address, - cascade='all, delete-orphan', - backref='user'), - - 'orders': relationship(Order, - cascade='all, delete-orphan', - order_by=orders.c.id) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, cascade="all, delete-orphan", backref="user" + ), + "orders": relationship( + Order, cascade="all, delete-orphan", order_by=orders.c.id + ), + }, + ) - mapper(Dingaling, dingalings, properties={ - 'address': relationship(Address) - }) + mapper( + Dingaling, + dingalings, + properties={"address": relationship(Address)}, + ) def test_list_assignment_new(self): User, Order = self.classes.User, self.classes.Order sess = Session() - u = User(name='jack', orders=[ - Order(description='order 1'), - Order(description='order 2')]) + u = User( + name="jack", + orders=[ + Order(description="order 1"), + Order(description="order 2"), + ], + ) sess.add(u) sess.commit() - eq_(u, User(name='jack', - orders=[Order(description='order 1'), - Order(description='order 2')])) + eq_( + u, + User( + name="jack", + orders=[ + Order(description="order 
1"), + Order(description="order 2"), + ], + ), + ) def test_list_assignment_replace(self): User, Order = self.classes.User, self.classes.Order sess = Session() - u = User(name='jack', orders=[ - Order(description='someorder'), - Order(description='someotherorder')]) + u = User( + name="jack", + orders=[ + Order(description="someorder"), + Order(description="someotherorder"), + ], + ) sess.add(u) u.orders = [Order(description="order 3"), Order(description="order 4")] sess.commit() - eq_(u, User(name='jack', - orders=[Order(description="order 3"), - Order(description="order 4")])) + eq_( + u, + User( + name="jack", + orders=[ + Order(description="order 3"), + Order(description="order 4"), + ], + ), + ) # order 1, order 2 have been deleted - eq_(sess.query(Order).order_by(Order.id).all(), - [Order(description="order 3"), Order(description="order 4")]) + eq_( + sess.query(Order).order_by(Order.id).all(), + [Order(description="order 3"), Order(description="order 4")], + ) def test_standalone_orphan(self): Order = self.classes.Order @@ -243,9 +331,12 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): Order, User = self.classes.Order, self.classes.User sess = sessionmaker(expire_on_commit=False)() - o1, o2, o3 = Order(description='o1'), Order(description='o2'), \ - Order(description='o3') - u = User(name='jack', orders=[o1, o2]) + o1, o2, o3 = ( + Order(description="o1"), + Order(description="o2"), + Order(description="o3"), + ) + u = User(name="jack", orders=[o1, o2]) sess.add(u) sess.commit() sess.close() @@ -262,7 +353,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): sess = Session() - u = User(name='jack') + u = User(name="jack") sess.add(u) sess.commit() @@ -279,7 +370,7 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): sess = Session() - u = User(name='jack') + u = User(name="jack") o1 = Order() sess.add(o1) @@ -297,63 +388,81 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): assert o1 not in sess def test_delete(self): - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) + u = User( + name="jack", + orders=[ + Order(description="someorder"), + Order(description="someotherorder"), + ], + ) sess.add(u) sess.flush() sess.delete(u) sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 0) - eq_(select([func.count('*')]).select_from(orders).scalar(), 0) + eq_(select([func.count("*")]).select_from(users).scalar(), 0) + eq_(select([func.count("*")]).select_from(orders).scalar(), 0) def test_delete_unloaded_collections(self): """Unloaded collections are still included in a delete-cascade by default.""" - User, addresses, users, Address = (self.classes.User, - self.tables.addresses, - self.tables.users, - self.classes.Address) + User, addresses, users, Address = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + self.classes.Address, + ) sess = create_session() - u = User(name='jack', - addresses=[Address(email_address="address1"), - Address(email_address="address2")]) + u = User( + name="jack", + addresses=[ + Address(email_address="address1"), + Address(email_address="address2"), + ], + ) sess.add(u) sess.flush() sess.expunge_all() - eq_(select([func.count('*')]).select_from(addresses).scalar(), 2) - 
eq_(select([func.count('*')]).select_from(users).scalar(), 1) + eq_(select([func.count("*")]).select_from(addresses).scalar(), 2) + eq_(select([func.count("*")]).select_from(users).scalar(), 1) u = sess.query(User).get(u.id) - assert 'addresses' not in u.__dict__ + assert "addresses" not in u.__dict__ sess.delete(u) sess.flush() - eq_(select([func.count('*')]).select_from(addresses).scalar(), 0) - eq_(select([func.count('*')]).select_from(users).scalar(), 0) + eq_(select([func.count("*")]).select_from(addresses).scalar(), 0) + eq_(select([func.count("*")]).select_from(users).scalar(), 0) def test_cascades_onlycollection(self): """Cascade only reaches instances that are still part of the collection, not those that have been removed""" - User, Order, users, orders = (self.classes.User, - self.classes.Order, - self.tables.users, - self.tables.orders) + User, Order, users, orders = ( + self.classes.User, + self.classes.Order, + self.tables.users, + self.tables.orders, + ) sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) + u = User( + name="jack", + orders=[ + Order(description="someorder"), + Order(description="someotherorder"), + ], + ) sess.add(u) sess.flush() @@ -364,89 +473,104 @@ class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): assert o not in sess.deleted assert o in sess - u2 = User(name='newuser', orders=[o]) + u2 = User(name="newuser", orders=[o]) sess.add(u2) sess.flush() sess.expunge_all() - eq_(select([func.count('*')]).select_from(users).scalar(), 1) - eq_(select([func.count('*')]).select_from(orders).scalar(), 1) - eq_(sess.query(User).all(), - [User(name='newuser', - orders=[Order(description='someorder')])]) + eq_(select([func.count("*")]).select_from(users).scalar(), 1) + eq_(select([func.count("*")]).select_from(orders).scalar(), 1) + eq_( + sess.query(User).all(), + [User(name="newuser", orders=[Order(description="someorder")])], + ) def test_cascade_nosideeffects(self): """test that cascade leaves the state of unloaded scalars/collections unchanged.""" - Dingaling, User, Address = (self.classes.Dingaling, - self.classes.User, - self.classes.Address) + Dingaling, User, Address = ( + self.classes.Dingaling, + self.classes.User, + self.classes.Address, + ) sess = create_session() - u = User(name='jack') + u = User(name="jack") sess.add(u) - assert 'orders' not in u.__dict__ + assert "orders" not in u.__dict__ sess.flush() - assert 'orders' not in u.__dict__ + assert "orders" not in u.__dict__ - a = Address(email_address='foo@bar.com') + a = Address(email_address="foo@bar.com") sess.add(a) - assert 'user' not in a.__dict__ + assert "user" not in a.__dict__ a.user = u sess.flush() - d = Dingaling(data='d1') + d = Dingaling(data="d1") d.address_id = a.id sess.add(d) - assert 'address' not in d.__dict__ + assert "address" not in d.__dict__ sess.flush() assert d.address is a def test_cascade_delete_plusorphans(self): - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) + u = User( + name="jack", + orders=[ + Order(description="someorder"), + Order(description="someotherorder"), + ], + ) sess.add(u) sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 1) - 
eq_(select([func.count('*')]).select_from(orders).scalar(), 2) + eq_(select([func.count("*")]).select_from(users).scalar(), 1) + eq_(select([func.count("*")]).select_from(orders).scalar(), 2) del u.orders[0] sess.delete(u) sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 0) - eq_(select([func.count('*')]).select_from(orders).scalar(), 0) + eq_(select([func.count("*")]).select_from(users).scalar(), 0) + eq_(select([func.count("*")]).select_from(orders).scalar(), 0) def test_collection_orphans(self): - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) + u = User( + name="jack", + orders=[ + Order(description="someorder"), + Order(description="someotherorder"), + ], + ) sess.add(u) sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 1) - eq_(select([func.count('*')]).select_from(orders).scalar(), 2) + eq_(select([func.count("*")]).select_from(users).scalar(), 1) + eq_(select([func.count("*")]).select_from(orders).scalar(), 2) u.orders[:] = [] sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 1) - eq_(select([func.count('*')]).select_from(orders).scalar(), 0) + eq_(select([func.count("*")]).select_from(users).scalar(), 1) + eq_(select([func.count("*")]).select_from(orders).scalar(), 0) class O2MCascadeTest(fixtures.MappedTest): @@ -454,15 +578,23 @@ class O2MCascadeTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False)) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False)) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + ) @classmethod def setup_classes(cls): @@ -475,25 +607,29 @@ class O2MCascadeTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): users, User, Address, addresses = ( - cls.tables.users, cls.classes.User, - cls.classes.Address, cls.tables.addresses) + cls.tables.users, + cls.classes.User, + cls.classes.Address, + cls.tables.addresses, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user"), - - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) def test_none_o2m_collection_assignment(self): User, Address = self.classes.User, self.classes.Address s = Session() - u1 = User(name='u', addresses=[None]) + u1 = User(name="u", addresses=[None]) s.add(u1) eq_(u1.addresses, [None]) assert_raises_message( orm_exc.FlushError, "Can't flush None value found in collection User.addresses", - s.commit + s.commit, ) eq_(u1.addresses, [None]) @@ -501,14 +637,14 @@ class O2MCascadeTest(fixtures.MappedTest): User, Address = 
self.classes.User, self.classes.Address s = Session() - u1 = User(name='u') + u1 = User(name="u") s.add(u1) u1.addresses.append(None) eq_(u1.addresses, [None]) assert_raises_message( orm_exc.FlushError, "Can't flush None value found in collection User.addresses", - s.commit + s.commit, ) eq_(u1.addresses, [None]) @@ -518,15 +654,23 @@ class O2MCascadeDeleteNoOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30))) - Table('orders', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.id')), - Column('description', String(30))) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) + Table( + "orders", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users.id")), + Column("description", String(30)), + ) @classmethod def setup_classes(cls): @@ -538,36 +682,47 @@ class O2MCascadeDeleteNoOrphanTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - User, Order, orders, users = (cls.classes.User, - cls.classes.Order, - cls.tables.orders, - cls.tables.users) + User, Order, orders, users = ( + cls.classes.User, + cls.classes.Order, + cls.tables.orders, + cls.tables.users, + ) - mapper(User, users, properties=dict( - orders=relationship( - mapper(Order, orders), cascade="all") - )) + mapper( + User, + users, + properties=dict( + orders=relationship(mapper(Order, orders), cascade="all") + ), + ) def test_cascade_delete_noorphans(self): - User, Order, orders, users = (self.classes.User, - self.classes.Order, - self.tables.orders, - self.tables.users) + User, Order, orders, users = ( + self.classes.User, + self.classes.Order, + self.tables.orders, + self.tables.users, + ) sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) + u = User( + name="jack", + orders=[ + Order(description="someorder"), + Order(description="someotherorder"), + ], + ) sess.add(u) sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 1) - eq_(select([func.count('*')]).select_from(orders).scalar(), 2) + eq_(select([func.count("*")]).select_from(users).scalar(), 1) + eq_(select([func.count("*")]).select_from(orders).scalar(), 2) del u.orders[0] sess.delete(u) sess.flush() - eq_(select([func.count('*')]).select_from(users).scalar(), 0) - eq_(select([func.count('*')]).select_from(orders).scalar(), 1) + eq_(select([func.count("*")]).select_from(users).scalar(), 0) + eq_(select([func.count("*")]).select_from(orders).scalar(), 1) class O2OSingleParentTest(_fixtures.FixtureTest): @@ -575,25 +730,35 @@ class O2OSingleParentTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, - properties={'address': relationship( - Address, backref=backref('user', single_parent=True), - uselist=False)}) + mapper( + User, + users, + properties={ + "address": relationship( + Address, + backref=backref("user", single_parent=True), + 
uselist=False, + ) + }, + ) def test_single_parent_raise(self): User, Address = self.classes.User, self.classes.Address - a1 = Address(email_address='some address') - u1 = User(name='u1', address=a1) - assert_raises(sa_exc.InvalidRequestError, Address, - email_address='asd', user=u1) - a2 = Address(email_address='asd') + a1 = Address(email_address="some address") + u1 = User(name="u1", address=a1) + assert_raises( + sa_exc.InvalidRequestError, Address, email_address="asd", user=u1 + ) + a2 = Address(email_address="asd") u1.address = a2 assert u1.address is not a1 assert a1.user is None @@ -604,16 +769,24 @@ class O2OSingleParentNoFlushTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False)) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + ) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id'), nullable=False), - Column('email_address', String(50), nullable=False)) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id"), nullable=False), + Column("email_address", String(50), nullable=False), + ) @classmethod def setup_classes(cls): @@ -625,29 +798,41 @@ class O2OSingleParentNoFlushTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, - properties={'address': relationship( - Address, backref=backref('user', single_parent=True, - cascade="all, delete-orphan"), - uselist=False)}) + mapper( + User, + users, + properties={ + "address": relationship( + Address, + backref=backref( + "user", + single_parent=True, + cascade="all, delete-orphan", + ), + uselist=False, + ) + }, + ) def test_replace_attribute_no_flush(self): # test [ticket:2921] User, Address = self.classes.User, self.classes.Address - a1 = Address(email_address='some address') - u1 = User(name='u1', address=a1) + a1 = Address(email_address="some address") + u1 = User(name="u1", address=a1) sess = Session() sess.add(u1) sess.commit() - a2 = Address(email_address='asdf') + a2 = Address(email_address="asdf") sess.add(a2) u1.address = a2 @@ -657,41 +842,55 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): run_inserts = None - def _one_to_many_fixture(self, o2m_cascade=True, - m2o_cascade=True, - o2m=False, - m2o=False, - o2m_cascade_backrefs=True, - m2o_cascade_backrefs=True): - - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + def _one_to_many_fixture( + self, + o2m_cascade=True, + m2o_cascade=True, + o2m=False, + m2o=False, + o2m_cascade_backrefs=True, + m2o_cascade_backrefs=True, + ): + + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) if o2m: if m2o: - addresses_rel = {'addresses': relationship( - Address, - cascade_backrefs=o2m_cascade_backrefs, - cascade=o2m_cascade and 'save-update' 
or '', - backref=backref( - 'user', cascade=m2o_cascade and 'save-update' or '', - cascade_backrefs=m2o_cascade_backrefs) - )} + addresses_rel = { + "addresses": relationship( + Address, + cascade_backrefs=o2m_cascade_backrefs, + cascade=o2m_cascade and "save-update" or "", + backref=backref( + "user", + cascade=m2o_cascade and "save-update" or "", + cascade_backrefs=m2o_cascade_backrefs, + ), + ) + } else: - addresses_rel = {'addresses': relationship( - Address, - cascade=o2m_cascade and 'save-update' or '', - cascade_backrefs=o2m_cascade_backrefs, - )} + addresses_rel = { + "addresses": relationship( + Address, + cascade=o2m_cascade and "save-update" or "", + cascade_backrefs=o2m_cascade_backrefs, + ) + } user_rel = {} elif m2o: - user_rel = {'user': relationship( - User, cascade=m2o_cascade and 'save-update' or '', - cascade_backrefs=m2o_cascade_backrefs - )} + user_rel = { + "user": relationship( + User, + cascade=m2o_cascade and "save-update" or "", + cascade_backrefs=m2o_cascade_backrefs, + ) + } addresses_rel = {} else: addresses_rel = {} @@ -700,46 +899,59 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): mapper(User, users, properties=addresses_rel) mapper(Address, addresses, properties=user_rel) - def _many_to_many_fixture(self, fwd_cascade=True, - bkd_cascade=True, - fwd=False, - bkd=False, - fwd_cascade_backrefs=True, - bkd_cascade_backrefs=True): - - keywords, items, item_keywords, Keyword, Item = \ - (self.tables.keywords, - self.tables.items, - self.tables.item_keywords, - self.classes.Keyword, - self.classes.Item) + def _many_to_many_fixture( + self, + fwd_cascade=True, + bkd_cascade=True, + fwd=False, + bkd=False, + fwd_cascade_backrefs=True, + bkd_cascade_backrefs=True, + ): + + keywords, items, item_keywords, Keyword, Item = ( + self.tables.keywords, + self.tables.items, + self.tables.item_keywords, + self.classes.Keyword, + self.classes.Item, + ) if fwd: if bkd: - keywords_rel = {'keywords': relationship( - Keyword, - secondary=item_keywords, - cascade_backrefs=fwd_cascade_backrefs, - cascade=fwd_cascade and 'save-update' or '', - backref=backref( - 'items', - cascade=bkd_cascade and 'save-update' or '', - cascade_backrefs=bkd_cascade_backrefs))} + keywords_rel = { + "keywords": relationship( + Keyword, + secondary=item_keywords, + cascade_backrefs=fwd_cascade_backrefs, + cascade=fwd_cascade and "save-update" or "", + backref=backref( + "items", + cascade=bkd_cascade and "save-update" or "", + cascade_backrefs=bkd_cascade_backrefs, + ), + ) + } else: - keywords_rel = {'keywords': relationship( - Keyword, - secondary=item_keywords, - cascade=fwd_cascade and 'save-update' or '', - cascade_backrefs=fwd_cascade_backrefs)} + keywords_rel = { + "keywords": relationship( + Keyword, + secondary=item_keywords, + cascade=fwd_cascade and "save-update" or "", + cascade_backrefs=fwd_cascade_backrefs, + ) + } items_rel = {} elif bkd: - items_rel = {'items': relationship( - Item, - secondary=item_keywords, - cascade=bkd_cascade and 'save-update' or '', - cascade_backrefs=bkd_cascade_backrefs - )} + items_rel = { + "items": relationship( + Item, + secondary=item_keywords, + cascade=bkd_cascade and "save-update" or "", + cascade_backrefs=bkd_cascade_backrefs, + ) + } keywords_rel = {} else: keywords_rel = {} @@ -753,8 +965,8 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=True, m2o=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") u1.addresses.append(a1) 
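        # [editor's aside, not part of the reformat diff] This test exercises
        # the default save-update cascade: adding u1 alone is enough to make
        # the appended Address pending as well. The ``flag and "save-update"
        # or ""`` expressions in the fixture above are the pre-ternary
        # spelling of a conditional; a minimal sketch of the equivalent
        # modern form, assuming the same boolean flags:
        #
        #     cascade = "save-update" if o2m_cascade else ""  # "" disables it
        #
        # When the cascade is disabled this way, the tests below expect a
        # "not in session" warning at flush time instead of the child being
        # pulled into the Session.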
sess.add(u1) assert u1 in sess @@ -766,23 +978,21 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=True, m2o=False, o2m_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") u1.addresses.append(a1) sess.add(u1) assert u1 in sess assert a1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_o2m_only_child_persistent(self): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=False, o2m_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") sess.add(a1) sess.flush() @@ -792,17 +1002,15 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): sess.add(u1) assert u1 in sess assert a1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_o2m_backref_child_pending(self): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") u1.addresses.append(a1) sess.add(u1) assert u1 in sess @@ -812,46 +1020,42 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): def test_o2m_backref_child_transient(self): User, Address = self.classes.User, self.classes.Address - self._one_to_many_fixture(o2m=True, m2o=True, - o2m_cascade=False) + self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") u1.addresses.append(a1) sess.add(u1) assert u1 in sess assert a1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_o2m_backref_child_transient_nochange(self): User, Address = self.classes.User, self.classes.Address - self._one_to_many_fixture(o2m=True, m2o=True, - o2m_cascade=False) + self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") u1.addresses.append(a1) sess.add(u1) assert u1 in sess assert a1 not in sess - @testing.emits_warning(r'.*not in session') + @testing.emits_warning(r".*not in session") def go(): sess.commit() + go() eq_(u1.addresses, []) def test_o2m_backref_child_expunged(self): User, Address = self.classes.User, self.classes.Address - self._one_to_many_fixture(o2m=True, m2o=True, - o2m_cascade=False) + self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") sess.add(a1) sess.flush() @@ -860,18 +1064,15 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): sess.expunge(a1) assert u1 in sess assert a1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_o2m_backref_child_expunged_nochange(self): User, Address = self.classes.User, self.classes.Address - 
self._one_to_many_fixture(o2m=True, m2o=True, - o2m_cascade=False) + self._one_to_many_fixture(o2m=True, m2o=True, o2m_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") sess.add(a1) sess.flush() @@ -881,9 +1082,10 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): assert u1 in sess assert a1 not in sess - @testing.emits_warning(r'.*not in session') + @testing.emits_warning(r".*not in session") def go(): sess.commit() + go() eq_(u1.addresses, []) @@ -892,8 +1094,8 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=False, m2o=True) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) assert u1 in sess @@ -905,42 +1107,38 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=False, m2o=True, m2o_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) assert u1 not in sess assert a1 in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2o_only_child_expunged(self): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=False, m2o=True, m2o_cascade=False) sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) sess.expunge(u1) assert u1 not in sess assert a1 in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2o_backref_child_pending(self): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) assert u1 in sess @@ -952,51 +1150,48 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) sess = Session() - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) assert u1 not in sess assert a1 in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2o_backref_child_expunged(self): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) sess.expunge(u1) assert u1 not in sess assert a1 in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2o_backref_child_pending_nochange(self): User, Address = self.classes.User, self.classes.Address self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") - a1 = Address(email_address='a1') + a1 = 
Address(email_address="a1") a1.user = u1 sess.add(a1) assert u1 not in sess assert a1 in sess - @testing.emits_warning(r'.*not in session') + @testing.emits_warning(r".*not in session") def go(): sess.commit() + go() # didn't get flushed assert a1.user is None @@ -1006,20 +1201,21 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._one_to_many_fixture(o2m=True, m2o=True, m2o_cascade=False) sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") a1.user = u1 sess.add(a1) sess.expunge(u1) assert u1 not in sess assert a1 in sess - @testing.emits_warning(r'.*not in session') + @testing.emits_warning(r".*not in session") def go(): sess.commit() + go() # didn't get flushed assert a1.user is None @@ -1029,8 +1225,8 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._many_to_many_fixture(fwd=True, bkd=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") i1.keywords.append(k1) sess.add(i1) assert i1 in sess @@ -1042,23 +1238,21 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): self._many_to_many_fixture(fwd=True, bkd=False, fwd_cascade=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") i1.keywords.append(k1) sess.add(i1) assert i1 in sess assert k1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2m_only_child_persistent(self): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=False, fwd_cascade=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") sess.add(k1) sess.flush() @@ -1068,17 +1262,15 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): sess.add(i1) assert i1 in sess assert k1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2m_backref_child_pending(self): Item, Keyword = self.classes.Item, self.classes.Keyword self._many_to_many_fixture(fwd=True, bkd=True) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") i1.keywords.append(k1) sess.add(i1) assert i1 in sess @@ -1088,46 +1280,42 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): def test_m2m_backref_child_transient(self): Item, Keyword = self.classes.Item, self.classes.Keyword - self._many_to_many_fixture(fwd=True, bkd=True, - fwd_cascade=False) + self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") i1.keywords.append(k1) sess.add(i1) assert i1 in sess assert k1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2m_backref_child_transient_nochange(self): Item, Keyword = self.classes.Item, self.classes.Keyword - self._many_to_many_fixture(fwd=True, bkd=True, - fwd_cascade=False) + self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = 
Item(description="i1") + k1 = Keyword(name="k1") i1.keywords.append(k1) sess.add(i1) assert i1 in sess assert k1 not in sess - @testing.emits_warning(r'.*not in session') + @testing.emits_warning(r".*not in session") def go(): sess.commit() + go() eq_(i1.keywords, []) def test_m2m_backref_child_expunged(self): Item, Keyword = self.classes.Item, self.classes.Keyword - self._many_to_many_fixture(fwd=True, bkd=True, - fwd_cascade=False) + self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") sess.add(k1) sess.flush() @@ -1136,18 +1324,15 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): sess.expunge(k1) assert i1 in sess assert k1 not in sess - assert_raises_message( - sa_exc.SAWarning, "not in session", sess.flush - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.flush) def test_m2m_backref_child_expunged_nochange(self): Item, Keyword = self.classes.Item, self.classes.Keyword - self._many_to_many_fixture(fwd=True, bkd=True, - fwd_cascade=False) + self._many_to_many_fixture(fwd=True, bkd=True, fwd_cascade=False) sess = Session() - i1 = Item(description='i1') - k1 = Keyword(name='k1') + i1 = Item(description="i1") + k1 = Keyword(name="k1") sess.add(k1) sess.flush() @@ -1157,9 +1342,10 @@ class NoSaveCascadeFlushTest(_fixtures.FixtureTest): assert i1 in sess assert k1 not in sess - @testing.emits_warning(r'.*not in session') + @testing.emits_warning(r".*not in session") def go(): sess.commit() + go() eq_(i1.keywords, []) @@ -1169,16 +1355,23 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): when the cascade initiated from the forwards side.""" def test_unidirectional_cascade_o2m(self): - User, Order, users, orders = (self.classes.User, - self.classes.Order, - self.tables.users, - self.tables.orders) + User, Order, users, orders = ( + self.classes.User, + self.classes.Order, + self.tables.users, + self.tables.orders, + ) mapper(Order, orders) - mapper(User, users, properties=dict( - orders=relationship( - Order, backref=backref("user", cascade=None)) - )) + mapper( + User, + users, + properties=dict( + orders=relationship( + Order, backref=backref("user", cascade=None) + ) + ), + ) sess = create_session() @@ -1197,14 +1390,22 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): assert o1 in sess def test_unidirectional_cascade_m2o(self): - User, Order, users, orders = (self.classes.User, - self.classes.Order, - self.tables.users, - self.tables.orders) - - mapper(Order, orders, properties={ - 'user': relationship(User, backref=backref("orders", cascade=None)) - }) + User, Order, users, orders = ( + self.classes.User, + self.classes.Order, + self.tables.users, + self.tables.orders, + ) + + mapper( + Order, + orders, + properties={ + "user": relationship( + User, backref=backref("orders", cascade=None) + ) + }, + ) mapper(User, users) sess = create_session() @@ -1226,18 +1427,26 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): assert u1 in sess def test_unidirectional_cascade_m2m(self): - keywords, items, item_keywords, Keyword, Item = \ - (self.tables.keywords, - self.tables.items, - self.tables.item_keywords, - self.classes.Keyword, - self.classes.Item) - - mapper(Item, items, - properties={'keywords': relationship(Keyword, - secondary=item_keywords, - cascade='none', - backref='items')}) + keywords, items, item_keywords, Keyword, Item = ( + self.tables.keywords, + self.tables.items, + self.tables.item_keywords, 
+ self.classes.Keyword, + self.classes.Item, + ) + + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, + secondary=item_keywords, + cascade="none", + backref="items", + ) + }, + ) mapper(Keyword, keywords) sess = create_session() @@ -1260,30 +1469,42 @@ class NoSaveCascadeBackrefTest(_fixtures.FixtureTest): class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('extra', metadata, Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('prefs_id', Integer, ForeignKey('prefs.id'))) - Table('prefs', metadata, Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40))) Table( - 'users', + "extra", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40)), - Column('pref_id', Integer, ForeignKey('prefs.id')), - Column('foo_id', Integer, ForeignKey('foo.id')), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("prefs_id", Integer, ForeignKey("prefs.id")), + ) + Table( + "prefs", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + ) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(40)), + Column("pref_id", Integer, ForeignKey("prefs.id")), + Column("foo_id", Integer, ForeignKey("foo.id")), + ) + Table( + "foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), ) - Table('foo', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40))) @classmethod def setup_classes(cls): @@ -1301,33 +1522,48 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): @classmethod def setup_mappers(cls): - extra, foo, users, Extra, Pref, User, prefs, Foo = (cls.tables.extra, - cls.tables.foo, - cls.tables.users, - cls.classes.Extra, - cls.classes.Pref, - cls.classes.User, - cls.tables.prefs, - cls.classes.Foo) + extra, foo, users, Extra, Pref, User, prefs, Foo = ( + cls.tables.extra, + cls.tables.foo, + cls.tables.users, + cls.classes.Extra, + cls.classes.Pref, + cls.classes.User, + cls.tables.prefs, + cls.classes.Foo, + ) mapper(Extra, extra) - mapper(Pref, prefs, properties=dict( - extra=relationship(Extra, cascade='all, delete'))) - mapper(User, users, properties=dict( - pref=relationship(Pref, lazy='joined', - cascade='all, delete-orphan', - single_parent=True), - foo=relationship(Foo))) # straight m2o + mapper( + Pref, + prefs, + properties=dict(extra=relationship(Extra, cascade="all, delete")), + ) + mapper( + User, + users, + properties=dict( + pref=relationship( + Pref, + lazy="joined", + cascade="all, delete-orphan", + single_parent=True, + ), + foo=relationship(Foo), + ), + ) # straight m2o mapper(Foo, foo) @classmethod def insert_data(cls): - Pref, User, Extra = (cls.classes.Pref, - cls.classes.User, - cls.classes.Extra) + Pref, User, Extra = ( + cls.classes.Pref, + cls.classes.User, + cls.classes.Extra, + ) - u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) - u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) + u1 = User(name="ed", pref=Pref(data="pref 1", extra=[Extra()])) + u2 = User(name="jack", pref=Pref(data="pref 2", extra=[Extra()])) u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) sess = create_session() 
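        # [editor's aside, not part of the reformat diff] The mapping above is
        # the many-to-one form of delete-orphan: since many User rows could in
        # principle share one Pref, SQLAlchemy requires single_parent=True
        # before it will treat a de-parented Pref as an orphan. A minimal
        # sketch of the behavior the tests below verify, assuming these
        # fixtures:
        #
        #     jack = sess.query(User).filter_by(name="jack").one()
        #     jack.pref = None  # "pref 2" loses its only parent
        #     sess.flush()      # delete-orphan deletes the Pref, and the
        #                       # "all, delete" cascade on Pref.extra deletes
        #                       # the related Extra row too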
sess.add_all((u1, u2, u3)) @@ -1335,18 +1571,20 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): sess.close() def test_orphan(self): - prefs, User, extra = (self.tables.prefs, - self.classes.User, - self.tables.extra) + prefs, User, extra = ( + self.tables.prefs, + self.classes.User, + self.tables.extra, + ) sess = create_session() - eq_(select([func.count('*')]).select_from(prefs).scalar(), 3) - eq_(select([func.count('*')]).select_from(extra).scalar(), 3) + eq_(select([func.count("*")]).select_from(prefs).scalar(), 3) + eq_(select([func.count("*")]).select_from(extra).scalar(), 3) jack = sess.query(User).filter_by(name="jack").one() jack.pref = None sess.flush() - eq_(select([func.count('*')]).select_from(prefs).scalar(), 2) - eq_(select([func.count('*')]).select_from(extra).scalar(), 2) + eq_(select([func.count("*")]).select_from(prefs).scalar(), 2) + eq_(select([func.count("*")]).select_from(extra).scalar(), 2) def test_cascade_on_deleted(self): """test a bug introduced by r6711""" @@ -1355,7 +1593,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): sess = sessionmaker(expire_on_commit=True)() - u1 = User(name='jack', foo=Foo(data='f1')) + u1 = User(name="jack", foo=Foo(data="f1")) sess.add(u1) sess.commit() @@ -1364,10 +1602,7 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): # the error condition relies upon # these things being true assert User.foo.dispatch._active_history is False - eq_( - attributes.get_history(u1, 'foo'), - ([None], (), ()) - ) + eq_(attributes.get_history(u1, "foo"), ([None], (), ())) sess.add(u1) assert u1 in sess @@ -1380,9 +1615,9 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): Pref, User = self.classes.Pref, self.classes.User sess = sessionmaker(expire_on_commit=False)() - p1, p2 = Pref(data='p1'), Pref(data='p2') + p1, p2 = Pref(data="p1"), Pref(data="p2") - u = User(name='jack', pref=p1) + u = User(name="jack", pref=p1) sess.add(u) sess.commit() sess.close() @@ -1395,9 +1630,11 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): sess.commit() def test_orphan_on_update(self): - prefs, User, extra = (self.tables.prefs, - self.classes.User, - self.tables.extra) + prefs, User, extra = ( + self.tables.prefs, + self.classes.User, + self.tables.extra, + ) sess = create_session() jack = sess.query(User).filter_by(name="jack").one() @@ -1412,23 +1649,25 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): assert p in sess assert e in sess sess.flush() - eq_(select([func.count('*')]).select_from(prefs).scalar(), 2) - eq_(select([func.count('*')]).select_from(extra).scalar(), 2) + eq_(select([func.count("*")]).select_from(prefs).scalar(), 2) + eq_(select([func.count("*")]).select_from(extra).scalar(), 2) def test_pending_expunge(self): Pref, User = self.classes.Pref, self.classes.User sess = create_session() - someuser = User(name='someuser') + someuser = User(name="someuser") sess.add(someuser) sess.flush() - someuser.pref = p1 = Pref(data='somepref') + someuser.pref = p1 = Pref(data="somepref") assert p1 in sess - someuser.pref = Pref(data='someotherpref') + someuser.pref = Pref(data="someotherpref") assert p1 not in sess sess.flush() - eq_(sess.query(Pref).with_parent(someuser).all(), - [Pref(data="someotherpref")]) + eq_( + sess.query(Pref).with_parent(someuser).all(), + [Pref(data="someotherpref")], + ) def test_double_assignment(self): """Double assignment will not accidentally reset the 'parent' flag.""" @@ -1442,30 +1681,43 @@ class M2OCascadeDeleteOrphanTestOne(fixtures.MappedTest): jack.pref = 
newpref jack.pref = newpref sess.flush() - eq_(sess.query(Pref).order_by(Pref.id).all(), - [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) + eq_( + sess.query(Pref).order_by(Pref.id).all(), + [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")], + ) class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('t2id', Integer, ForeignKey('t2.id'))) - - Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('t3id', Integer, ForeignKey('t3.id'))) - - Table('t3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("t2id", Integer, ForeignKey("t2.id")), + ) + + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("t3id", Integer, ForeignKey("t3.id")), + ) + + Table( + "t3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @classmethod def setup_classes(cls): @@ -1480,29 +1732,43 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): @classmethod def setup_mappers(cls): - t2, T2, T3, t1, t3, T1 = (cls.tables.t2, - cls.classes.T2, - cls.classes.T3, - cls.tables.t1, - cls.tables.t3, - cls.classes.T1) - - mapper(T1, t1, properties=dict( - t2=relationship(T2, cascade='all, delete-orphan', - single_parent=True))) - mapper(T2, t2, properties=dict( - t3=relationship(T3, cascade='all, delete-orphan', - single_parent=True, - backref=backref('t2', uselist=False)))) + t2, T2, T3, t1, t3, T1 = ( + cls.tables.t2, + cls.classes.T2, + cls.classes.T3, + cls.tables.t1, + cls.tables.t3, + cls.classes.T1, + ) + + mapper( + T1, + t1, + properties=dict( + t2=relationship( + T2, cascade="all, delete-orphan", single_parent=True + ) + ), + ) + mapper( + T2, + t2, + properties=dict( + t3=relationship( + T3, + cascade="all, delete-orphan", + single_parent=True, + backref=backref("t2", uselist=False), + ) + ), + ) mapper(T3, t3) def test_cascade_delete(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -1513,12 +1779,10 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): eq_(sess.query(T3).all(), []) def test_deletes_orphans_onelevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + x2 = T1(data="t1b", t2=T2(data="t2b", t3=T3(data="t3b"))) sess.add(x2) sess.flush() x2.t2 = None @@ -1530,12 +1794,10 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): eq_(sess.query(T3).all(), []) def test_deletes_orphans_twolevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x = 
T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -1547,12 +1809,10 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): eq_(sess.query(T3).all(), []) def test_finds_orphans_twolevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -1567,24 +1827,24 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): sess = create_session() - y = T2(data='T2a') - x = T1(data='T1a', t2=y) - assert_raises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) + y = T2(data="T2a") + x = T1(data="T1a", t2=y) + assert_raises(sa_exc.InvalidRequestError, T1, data="T1b", t2=y) def test_single_parent_backref(self): T2, T3 = self.classes.T2, self.classes.T3 sess = create_session() - y = T3(data='T3a') - x = T2(data='T2a', t3=y) + y = T3(data="T3a") + x = T2(data="T2a", t3=y) # can't attach the T3 to another T2 - assert_raises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) + assert_raises(sa_exc.InvalidRequestError, T2, data="T2b", t3=y) # set via backref though is OK, unsets from previous parent # first - z = T2(data='T2b') + z = T2(data="T2b") y.t2 = z assert z.t3 is y @@ -1592,24 +1852,36 @@ class M2OCascadeDeleteOrphanTestTwo(fixtures.MappedTest): class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('t1', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('t2id', Integer, ForeignKey('t2.id'))) - - Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('t3id', Integer, ForeignKey('t3.id'))) - - Table('t3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("t2id", Integer, ForeignKey("t2.id")), + ) + + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("t3id", Integer, ForeignKey("t3.id")), + ) + + Table( + "t3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @classmethod def setup_classes(cls): @@ -1624,24 +1896,24 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - t2, T2, T3, t1, t3, T1 = (cls.tables.t2, - cls.classes.T2, - cls.classes.T3, - cls.tables.t1, - cls.tables.t3, - cls.classes.T1) - - mapper(T1, t1, properties={'t2': relationship(T2, cascade="all")}) - mapper(T2, t2, properties={'t3': relationship(T3, cascade="all")}) + t2, T2, T3, t1, t3, T1 = ( + cls.tables.t2, + cls.classes.T2, + cls.classes.T3, + cls.tables.t1, + cls.tables.t3, + cls.classes.T1, + ) + + mapper(T1, t1, properties={"t2": relationship(T2, cascade="all")}) + mapper(T2, t2, properties={"t3": relationship(T3, cascade="all")}) mapper(T3, t3) def test_cascade_delete(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x =
T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -1652,14 +1924,12 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): eq_(sess.query(T3).all(), []) def test_cascade_delete_postappend_onelevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x1 = T1(data='t1', ) - x2 = T2(data='t2') - x3 = T3(data='t3') + x1 = T1(data="t1") + x2 = T2(data="t2") + x3 = T3(data="t3") sess.add_all((x1, x2, x3)) sess.flush() @@ -1672,13 +1942,11 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): eq_(sess.query(T3).all(), []) def test_cascade_delete_postappend_twolevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x1 = T1(data='t1', t2=T2(data='t2')) - x3 = T3(data='t3') + x1 = T1(data="t1", t2=T2(data="t2")) + x3 = T3(data="t3") sess.add_all((x1, x3)) sess.flush() @@ -1690,12 +1958,10 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): eq_(sess.query(T3).all(), []) def test_preserves_orphans_onelevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + x2 = T1(data="t1b", t2=T2(data="t2b", t3=T3(data="t3b"))) sess.add(x2) sess.flush() x2.t2 = None @@ -1708,12 +1974,10 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): @testing.future def test_preserves_orphans_onelevel_postremove(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + x2 = T1(data="t1b", t2=T2(data="t2b", t3=T3(data="t3b"))) sess.add(x2) sess.flush() @@ -1725,12 +1989,10 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): eq_(sess.query(T3).all(), [T3()]) def test_preserves_orphans_twolevel(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + x = T1(data="t1a", t2=T2(data="t2a", t3=T3(data="t3a"))) sess.add(x) sess.flush() @@ -1745,26 +2007,41 @@ class M2OCascadeDeleteNoOrphanTest(fixtures.MappedTest): class M2MCascadeTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - test_needs_fk=True) - Table('b', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - test_needs_fk=True) - Table('atob', metadata, - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b.id')), - test_needs_fk=True) - Table('c', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('bid', Integer, ForeignKey('b.id')), - test_needs_fk=True) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + test_needs_fk=True, + ) + Table( + "b", + metadata, + Column( + "id", Integer, primary_key=True, 
test_needs_autoincrement=True + ), + Column("data", String(30)), + test_needs_fk=True, + ) + Table( + "atob", + metadata, + Column("aid", Integer, ForeignKey("a.id")), + Column("bid", Integer, ForeignKey("b.id")), + test_needs_fk=True, + ) + Table( + "c", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("bid", Integer, ForeignKey("b.id")), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): @@ -1778,200 +2055,266 @@ class M2MCascadeTest(fixtures.MappedTest): pass def test_delete_orphan(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) # if no backref here, delete-orphan failed until [ticket:427] # was fixed - mapper(A, a, - properties={'bs': relationship(B, secondary=atob, - cascade='all, delete-orphan', - single_parent=True)}) + mapper( + A, + a, + properties={ + "bs": relationship( + B, + secondary=atob, + cascade="all, delete-orphan", + single_parent=True, + ) + }, + ) mapper(B, b) sess = create_session() - b1 = B(data='b1') - a1 = A(data='a1', bs=[b1]) + b1 = B(data="b1") + a1 = A(data="a1", bs=[b1]) sess.add(a1) sess.flush() a1.bs.remove(b1) sess.flush() - eq_(select([func.count('*')]).select_from(atob).scalar(), 0) - eq_(select([func.count('*')]).select_from(b).scalar(), 0) - eq_(select([func.count('*')]).select_from(a).scalar(), 1) + eq_(select([func.count("*")]).select_from(atob).scalar(), 0) + eq_(select([func.count("*")]).select_from(b).scalar(), 0) + eq_(select([func.count("*")]).select_from(a).scalar(), 1) def test_delete_orphan_dynamic(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, - # if no backref here, delete-orphan - properties={'bs': relationship(B, secondary=atob, - cascade='all, delete-orphan', - single_parent=True, - lazy='dynamic')}) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + # if no backref here, delete-orphan + properties={ + "bs": relationship( + B, + secondary=atob, + cascade="all, delete-orphan", + single_parent=True, + lazy="dynamic", + ) + }, + ) # failed until [ticket:427] was fixed mapper(B, b) sess = create_session() - b1 = B(data='b1') - a1 = A(data='a1', bs=[b1]) + b1 = B(data="b1") + a1 = A(data="a1", bs=[b1]) sess.add(a1) sess.flush() a1.bs.remove(b1) sess.flush() - eq_(select([func.count('*')]).select_from(atob).scalar(), 0) - eq_(select([func.count('*')]).select_from(b).scalar(), 0) - eq_(select([func.count('*')]).select_from(a).scalar(), 1) + eq_(select([func.count("*")]).select_from(atob).scalar(), 0) + eq_(select([func.count("*")]).select_from(b).scalar(), 0) + eq_(select([func.count("*")]).select_from(a).scalar(), 1) def test_delete_orphan_cascades(self): - a, A, c, b, C, B, atob = (self.tables.a, - self.classes.A, - self.tables.c, - self.tables.b, - self.classes.C, - self.classes.B, - self.tables.atob) - - mapper(A, a, properties={ - # if no backref here, delete-orphan failed until [ticket:427] was - # fixed - 'bs': relationship(B, secondary=atob, cascade="all, delete-orphan", - single_parent=True) - }) - mapper(B, b, - properties={'cs': relationship(C, - cascade="all, delete-orphan")}) + a, A, c, b, C, B, atob = ( + self.tables.a, + self.classes.A, + self.tables.c, + self.tables.b, + 
self.classes.C, + self.classes.B, + self.tables.atob, + ) + + mapper( + A, + a, + properties={ + # if no backref here, delete-orphan failed until [ticket:427] was + # fixed + "bs": relationship( + B, + secondary=atob, + cascade="all, delete-orphan", + single_parent=True, + ) + }, + ) + mapper( + B, + b, + properties={"cs": relationship(C, cascade="all, delete-orphan")}, + ) mapper(C, c) sess = create_session() - b1 = B(data='b1', cs=[C(data='c1')]) - a1 = A(data='a1', bs=[b1]) + b1 = B(data="b1", cs=[C(data="c1")]) + a1 = A(data="a1", bs=[b1]) sess.add(a1) sess.flush() a1.bs.remove(b1) sess.flush() - eq_(select([func.count('*')]).select_from(atob).scalar(), 0) - eq_(select([func.count('*')]).select_from(b).scalar(), 0) - eq_(select([func.count('*')]).select_from(a).scalar(), 1) - eq_(select([func.count('*')]).select_from(c).scalar(), 0) + eq_(select([func.count("*")]).select_from(atob).scalar(), 0) + eq_(select([func.count("*")]).select_from(b).scalar(), 0) + eq_(select([func.count("*")]).select_from(a).scalar(), 1) + eq_(select([func.count("*")]).select_from(c).scalar(), 0) def test_cascade_delete(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, properties={ - 'bs': relationship(B, secondary=atob, cascade="all, delete-orphan", - single_parent=True) - }) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + properties={ + "bs": relationship( + B, + secondary=atob, + cascade="all, delete-orphan", + single_parent=True, + ) + }, + ) mapper(B, b) sess = create_session() - a1 = A(data='a1', bs=[B(data='b1')]) + a1 = A(data="a1", bs=[B(data="b1")]) sess.add(a1) sess.flush() sess.delete(a1) sess.flush() - eq_(select([func.count('*')]).select_from(atob).scalar(), 0) - eq_(select([func.count('*')]).select_from(b).scalar(), 0) - eq_(select([func.count('*')]).select_from(a).scalar(), 0) + eq_(select([func.count("*")]).select_from(atob).scalar(), 0) + eq_(select([func.count("*")]).select_from(b).scalar(), 0) + eq_(select([func.count("*")]).select_from(a).scalar(), 0) def test_single_parent_error(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, properties={ - 'bs': relationship(B, secondary=atob, - cascade="all, delete-orphan") - }) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + properties={ + "bs": relationship( + B, secondary=atob, cascade="all, delete-orphan" + ) + }, + ) mapper(B, b) assert_raises_message( sa_exc.ArgumentError, "On A.bs, delete-orphan cascade is not supported", - configure_mappers + configure_mappers, ) def test_single_parent_raise(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, properties={ - 'bs': relationship(B, secondary=atob, cascade="all, delete-orphan", - single_parent=True) - }) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + properties={ + "bs": relationship( + B, + secondary=atob, + cascade="all, delete-orphan", + single_parent=True, + ) + }, + ) mapper(B, b) sess = create_session() - b1 = B(data='b1') - a1 = A(data='a1', bs=[b1]) + b1 = B(data="b1") + a1 = A(data="a1", bs=[b1]) - assert_raises(sa_exc.InvalidRequestError, - A, data='a2', bs=[b1]) + 
assert_raises(sa_exc.InvalidRequestError, A, data="a2", bs=[b1]) def test_single_parent_backref(self): """test that setting m2m via a uselist=False backref bypasses the single_parent raise""" - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, properties={ - 'bs': relationship(B, - secondary=atob, - cascade="all, delete-orphan", - single_parent=True, - backref=backref('a', uselist=False)) - }) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + properties={ + "bs": relationship( + B, + secondary=atob, + cascade="all, delete-orphan", + single_parent=True, + backref=backref("a", uselist=False), + ) + }, + ) mapper(B, b) sess = create_session() - b1 = B(data='b1') - a1 = A(data='a1', bs=[b1]) + b1 = B(data="b1") + a1 = A(data="a1", bs=[b1]) - assert_raises( - sa_exc.InvalidRequestError, - A, data='a2', bs=[b1] - ) + assert_raises(sa_exc.InvalidRequestError, A, data="a2", bs=[b1]) - a2 = A(data='a2') + a2 = A(data="a2") b1.a = a2 assert b1 not in a1.bs assert b1 in a2.bs def test_none_m2m_collection_assignment(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, properties={ - 'bs': relationship(B, - secondary=atob, backref="as") - }) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + properties={"bs": relationship(B, secondary=atob, backref="as")}, + ) mapper(B, b) s = Session() @@ -1981,20 +2324,24 @@ class M2MCascadeTest(fixtures.MappedTest): assert_raises_message( orm_exc.FlushError, "Can't flush None value found in collection A.bs", - s.commit + s.commit, ) eq_(a1.bs, [None]) def test_none_m2m_collection_append(self): - a, A, B, b, atob = (self.tables.a, - self.classes.A, - self.classes.B, - self.tables.b, - self.tables.atob) - - mapper(A, a, properties={ - 'bs': relationship(B, secondary=atob, backref="as") - }) + a, A, B, b, atob = ( + self.tables.a, + self.classes.A, + self.classes.B, + self.tables.b, + self.tables.atob, + ) + + mapper( + A, + a, + properties={"bs": relationship(B, secondary=atob, backref="as")}, + ) mapper(B, b) s = Session() @@ -2005,7 +2352,7 @@ class M2MCascadeTest(fixtures.MappedTest): assert_raises_message( orm_exc.FlushError, "Can't flush None value found in collection A.bs", - s.commit + s.commit, ) eq_(a1.bs, [None]) @@ -2013,10 +2360,14 @@ class M2MCascadeTest(fixtures.MappedTest): class O2MSelfReferentialDetelOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('node', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('node.id'))) + Table( + "node", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("node.id")), + ) @classmethod def setup_classes(cls): @@ -2027,16 +2378,17 @@ class O2MSelfReferentialDetelOrphanTest(fixtures.MappedTest): def setup_mappers(cls): Node = cls.classes.Node node = cls.tables.node - mapper(Node, node, properties={ - "children": relationship( - Node, - cascade="all, delete-orphan", - backref=backref( - "parent", - remote_side=node.c.id + mapper( + Node, + node, + properties={ + "children": relationship( + Node, + cascade="all, delete-orphan", + backref=backref("parent", remote_side=node.c.id), ) - ) - }) + }, + ) 
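    # [editor's aside, not part of the reformat diff] In this self-referential
    # one-to-many, remote_side=node.c.id marks the parent's id column as the
    # "remote" side of the self-join, which is what gives the adjacency-list
    # relationship its direction. A minimal sketch of what delete-orphan means
    # here, assuming the Node fixture above:
    #
    #     n1 = Node()
    #     n1.children.append(Node())
    #     sess.add(n1)
    #     sess.flush()
    #     n1.children[:] = []  # the de-parented child is now an orphan
    #     sess.flush()         # ...so it is deleted outright rather than
    #                          # left behind with a NULL parent_id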
def test_self_referential_delete(self): Node = self.classes.Node @@ -2059,34 +2411,45 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - addresses, Dingaling, User, dingalings, Address, users = \ - (cls.tables.addresses, - cls.classes.Dingaling, - cls.classes.User, - cls.tables.dingalings, - cls.classes.Address, - cls.tables.users) + addresses, Dingaling, User, dingalings, Address, users = ( + cls.tables.addresses, + cls.classes.Dingaling, + cls.classes.User, + cls.tables.dingalings, + cls.classes.Address, + cls.tables.users, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', - cascade_backrefs=False) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", cascade_backrefs=False + ) + }, + ) - mapper(Dingaling, dingalings, properties={ - 'address': relationship(Address, backref='dingalings', - cascade_backrefs=False) - }) + mapper( + Dingaling, + dingalings, + properties={ + "address": relationship( + Address, backref="dingalings", cascade_backrefs=False + ) + }, + ) def test_o2m_basic(self): User, Address = self.classes.User, self.classes.Address sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") a1.user = u1 assert a1 not in sess @@ -2095,17 +2458,13 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") a1.user = u1 - assert_raises_message( - sa_exc.SAWarning, - "not in session", - sess.commit - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.commit) assert a1 not in sess @@ -2114,7 +2473,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): sess = Session() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") sess.add(a1) d1 = Dingaling() @@ -2129,7 +2488,7 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): sess = Session() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") d1 = Dingaling() sess.add(d1) @@ -2141,10 +2500,10 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): sess = Session() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") sess.add(a1) - u1 = User(name='u1') + u1 = User(name="u1") u1.addresses.append(a1) assert u1 in sess @@ -2153,18 +2512,14 @@ class NoBackrefCascadeTest(_fixtures.FixtureTest): sess = Session() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") d1 = Dingaling() sess.add(d1) a1.dingalings.append(d1) assert a1 not in sess - assert_raises_message( - sa_exc.SAWarning, - "not in session", - sess.commit - ) + assert_raises_message(sa_exc.SAWarning, "not in session", sess.commit) class PendingOrphanTestSingleLevel(fixtures.MappedTest): @@ -2172,21 +2527,43 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('user_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(40))) - - Table('addresses', metadata, - Column('address_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.user_id')), - Column('email_address', String(40))) - Table('orders', metadata, - Column('order_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey( - 
'users.user_id'), nullable=False)) + Table( + "users", + metadata, + Column( + "user_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(40)), + ) + + Table( + "addresses", + metadata, + Column( + "address_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("user_id", Integer, ForeignKey("users.user_id")), + Column("email_address", String(40)), + ) + Table( + "orders", + metadata, + Column( + "order_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column( + "user_id", Integer, ForeignKey("users.user_id"), nullable=False + ), + ) @classmethod def setup_classes(cls): @@ -2208,21 +2585,27 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): """ - users, orders, User, Address, Order, addresses = \ - (self.tables.users, - self.tables.orders, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.addresses) + users, orders, User, Address, Order, addresses = ( + self.tables.users, + self.tables.orders, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.addresses, + ) mapper(Order, orders) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, cascade="all,delete-orphan", - backref="user"), - orders=relationship(Order, cascade='all, delete-orphan') - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, cascade="all,delete-orphan", backref="user" + ), + orders=relationship(Order, cascade="all, delete-orphan"), + ), + ) s = Session() # the standalone Address goes in, its foreign key @@ -2251,16 +2634,23 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): """Removing a pending item from a collection expunges it from the session.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, cascade="all,delete-orphan", - backref="user") - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, cascade="all,delete-orphan", backref="user" + ) + ), + ) s = create_session() u = User() @@ -2280,25 +2670,32 @@ class PendingOrphanTestSingleLevel(fixtures.MappedTest): assert a.address_id is None, "Error: address should not be persistent" def test_nonorphans_ok(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, cascade="all,delete", - backref="user") - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, cascade="all,delete", backref="user" + ) + ), + ) s = create_session() - u = User(name='u1', addresses=[Address(email_address='ad1')]) + u = User(name="u1", addresses=[Address(email_address="ad1")]) s.add(u) a1 = u.addresses[0] u.addresses.remove(a1) assert a1 in s s.flush() s.expunge_all() - eq_(s.query(Address).all(), [Address(email_address='ad1')]) + eq_(s.query(Address).all(), [Address(email_address="ad1")]) class PendingOrphanTestTwoLevel(fixtures.MappedTest): @@ -2310,19 +2707,31 @@ class 
PendingOrphanTestTwoLevel(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('order', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table('item', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('order_id', Integer, ForeignKey( - 'order.id'), nullable=False)) - Table('attribute', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('item_id', Integer, ForeignKey('item.id'), - nullable=False)) + Table( + "order", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "item", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "order_id", Integer, ForeignKey("order.id"), nullable=False + ), + ) + Table( + "attribute", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("item_id", Integer, ForeignKey("item.id"), nullable=False), + ) @classmethod def setup_classes(cls): @@ -2336,14 +2745,20 @@ class PendingOrphanTestTwoLevel(fixtures.MappedTest): pass def test_singlelevel_remove(self): - item, Order, order, Item = (self.tables.item, - self.classes.Order, - self.tables.order, - self.classes.Item) - - mapper(Order, order, properties={ - 'items': relationship(Item, cascade="all, delete-orphan") - }) + item, Order, order, Item = ( + self.tables.item, + self.classes.Order, + self.tables.order, + self.classes.Item, + ) + + mapper( + Order, + order, + properties={ + "items": relationship(Item, cascade="all, delete-orphan") + }, + ) mapper(Item, item) s = Session() o1 = Order() @@ -2356,20 +2771,31 @@ class PendingOrphanTestTwoLevel(fixtures.MappedTest): assert i1 not in o1.items def test_multilevel_remove(self): - Item, Attribute, order, item, attribute, Order = \ - (self.classes.Item, - self.classes.Attribute, - self.tables.order, - self.tables.item, - self.tables.attribute, - self.classes.Order) - - mapper(Order, order, properties={ - 'items': relationship(Item, cascade="all, delete-orphan") - }) - mapper(Item, item, properties={ - 'attributes': relationship(Attribute, cascade="all, delete-orphan") - }) + Item, Attribute, order, item, attribute, Order = ( + self.classes.Item, + self.classes.Attribute, + self.tables.order, + self.tables.item, + self.tables.attribute, + self.classes.Order, + ) + + mapper( + Order, + order, + properties={ + "items": relationship(Item, cascade="all, delete-orphan") + }, + ) + mapper( + Item, + item, + properties={ + "attributes": relationship( + Attribute, cascade="all, delete-orphan" + ) + }, + ) mapper(Attribute, attribute) s = Session() o1 = Order() @@ -2407,28 +2833,51 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, meta): - Table('sales_reps', meta, - Column('sales_rep_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - Table('accounts', meta, - Column('account_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('balance', Integer)) - - Table('customers', meta, - Column('customer_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('sales_rep_id', Integer, - ForeignKey('sales_reps.sales_rep_id')), - Column('account_id', Integer, - ForeignKey('accounts.account_id'))) + Table( + "sales_reps", + meta, + Column( + "sales_rep_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + 
Column("name", String(50)), + ) + Table( + "accounts", + meta, + Column( + "account_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("balance", Integer), + ) + + Table( + "customers", + meta, + Column( + "customer_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column( + "sales_rep_id", Integer, ForeignKey("sales_reps.sales_rep_id") + ), + Column("account_id", Integer, ForeignKey("accounts.account_id")), + ) def _fixture(self, legacy_is_orphan, uselist): - sales_reps, customers, accounts = (self.tables.sales_reps, - self.tables.customers, - self.tables.accounts) + sales_reps, customers, accounts = ( + self.tables.sales_reps, + self.tables.customers, + self.tables.accounts, + ) class Customer(fixtures.ComparableEntity): pass @@ -2440,16 +2889,30 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): pass mapper(Customer, customers, legacy_is_orphan=legacy_is_orphan) - mapper(Account, accounts, properties=dict( - customers=relationship(Customer, - cascade="all,delete-orphan", - backref="account", - uselist=uselist))) - mapper(SalesRep, sales_reps, properties=dict( - customers=relationship(Customer, - cascade="all,delete-orphan", - backref="sales_rep", - uselist=uselist))) + mapper( + Account, + accounts, + properties=dict( + customers=relationship( + Customer, + cascade="all,delete-orphan", + backref="account", + uselist=uselist, + ) + ), + ) + mapper( + SalesRep, + sales_reps, + properties=dict( + customers=relationship( + Customer, + cascade="all,delete-orphan", + backref="sales_rep", + uselist=uselist, + ) + ), + ) s = create_session() a = Account(balance=0) @@ -2479,8 +2942,7 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): assert c in s, "Should not expunge customer yet, still has one parent" sr.customers.remove(c) - assert c not in s, \ - 'Should expunge customer when both parents are gone' + assert c not in s, "Should expunge customer when both parents are gone" def test_double_parent_expunge_o2m_current(self): """test the delete-orphan uow event for multiple delete-orphan @@ -2492,8 +2954,7 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): assert c not in s, "Should expunge customer when either parent is gone" sr.customers.remove(c) - assert c not in s, \ - 'Should expunge customer when both parents are gone' + assert c not in s, "Should expunge customer when both parents are gone" def test_double_parent_expunge_o2o_legacy(self): """test the delete-orphan uow event for multiple delete-orphan @@ -2505,8 +2966,7 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): assert c in s, "Should not expunge customer yet, still has one parent" sr.customers = None - assert c not in s, \ - 'Should expunge customer when both parents are gone' + assert c not in s, "Should expunge customer when both parents are gone" def test_double_parent_expunge_o2o_current(self): """test the delete-orphan uow event for multiple delete-orphan @@ -2518,8 +2978,7 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): assert c not in s, "Should expunge customer when either parent is gone" sr.customers = None - assert c not in s, \ - 'Should expunge customer when both parents are gone' + assert c not in s, "Should expunge customer when both parents are gone" class DoubleParentM2OOrphanTest(fixtures.MappedTest): @@ -2530,32 +2989,65 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('addresses', metadata, - Column('address_id', Integer, 
primary_key=True, - test_needs_autoincrement=True), - Column('street', String(30))) - - Table('homes', metadata, - Column('home_id', Integer, primary_key=True, key="id", - test_needs_autoincrement=True), - Column('description', String(30)), - Column('address_id', Integer, ForeignKey('addresses.address_id'), - nullable=False)) - - Table('businesses', metadata, - Column('business_id', Integer, primary_key=True, key="id", - test_needs_autoincrement=True), - Column('description', String(30), key="description"), - Column('address_id', Integer, ForeignKey('addresses.address_id'), - nullable=False)) + Table( + "addresses", + metadata, + Column( + "address_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("street", String(30)), + ) + + Table( + "homes", + metadata, + Column( + "home_id", + Integer, + primary_key=True, + key="id", + test_needs_autoincrement=True, + ), + Column("description", String(30)), + Column( + "address_id", + Integer, + ForeignKey("addresses.address_id"), + nullable=False, + ), + ) + + Table( + "businesses", + metadata, + Column( + "business_id", + Integer, + primary_key=True, + key="id", + test_needs_autoincrement=True, + ), + Column("description", String(30), key="description"), + Column( + "address_id", + Integer, + ForeignKey("addresses.address_id"), + nullable=False, + ), + ) def test_non_orphan(self): """test that an entity can have two parent delete-orphan cascades, and persists normally.""" - homes, businesses, addresses = (self.tables.homes, - self.tables.businesses, - self.tables.addresses) + homes, businesses, addresses = ( + self.tables.homes, + self.tables.businesses, + self.tables.addresses, + ) class Address(fixtures.ComparableEntity): pass @@ -2567,34 +3059,55 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): pass mapper(Address, addresses) - mapper(Home, homes, properties={'address': relationship( - Address, cascade='all,delete-orphan', single_parent=True)}) - mapper(Business, businesses, properties={'address': relationship( - Address, cascade='all,delete-orphan', single_parent=True)}) + mapper( + Home, + homes, + properties={ + "address": relationship( + Address, cascade="all,delete-orphan", single_parent=True + ) + }, + ) + mapper( + Business, + businesses, + properties={ + "address": relationship( + Address, cascade="all,delete-orphan", single_parent=True + ) + }, + ) session = create_session() - h1 = Home(description='home1', address=Address(street='address1')) - b1 = Business(description='business1', - address=Address(street='address2')) + h1 = Home(description="home1", address=Address(street="address1")) + b1 = Business( + description="business1", address=Address(street="address2") + ) session.add_all((h1, b1)) session.flush() session.expunge_all() - eq_(session.query(Home).get(h1.id), Home(description='home1', - address=Address( - street='address1'))) - eq_(session.query(Business).get(b1.id), - Business(description='business1', - address=Address(street='address2'))) + eq_( + session.query(Home).get(h1.id), + Home(description="home1", address=Address(street="address1")), + ) + eq_( + session.query(Business).get(b1.id), + Business( + description="business1", address=Address(street="address2") + ), + ) def test_orphan(self): """test that an entity can have two parent delete-orphan cascades, and is detected as an orphan when saved without a parent.""" - homes, businesses, addresses = (self.tables.homes, - self.tables.businesses, - self.tables.addresses) + homes, businesses, addresses = ( + self.tables.homes, + 
self.tables.businesses, + self.tables.addresses, + ) class Address(fixtures.ComparableEntity): pass @@ -2606,10 +3119,24 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): pass mapper(Address, addresses) - mapper(Home, homes, properties={'address': relationship( - Address, cascade='all,delete-orphan', single_parent=True)}) - mapper(Business, businesses, properties={'address': relationship( - Address, cascade='all,delete-orphan', single_parent=True)}) + mapper( + Home, + homes, + properties={ + "address": relationship( + Address, cascade="all,delete-orphan", single_parent=True + ) + }, + ) + mapper( + Business, + businesses, + properties={ + "address": relationship( + Address, cascade="all,delete-orphan", single_parent=True + ) + }, + ) session = create_session() a1 = Address() session.add(a1) @@ -2619,15 +3146,23 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): class CollectionAssignmentOrphanTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('table_a', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('name', String(30))) - Table('table_b', metadata, - Column('id', Integer, - primary_key=True, test_needs_autoincrement=True), - Column('name', String(30)), - Column('a_id', Integer, ForeignKey('table_a.id'))) + Table( + "table_a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) + Table( + "table_b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + Column("a_id", Integer, ForeignKey("table_a.id")), + ) def test_basic(self): table_b, table_a = self.tables.table_b, self.tables.table_a @@ -2638,12 +3173,14 @@ class CollectionAssignmentOrphanTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, table_a, properties={ - 'bs': relationship(B, cascade="all, delete-orphan") - }) + mapper( + A, + table_a, + properties={"bs": relationship(B, cascade="all, delete-orphan")}, + ) mapper(B, table_b) - a1 = A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]) + a1 = A(name="a1", bs=[B(name="b1"), B(name="b2"), B(name="b3")]) sess = create_session() sess.add(a1) @@ -2651,40 +3188,63 @@ class CollectionAssignmentOrphanTest(fixtures.MappedTest): sess.expunge_all() - eq_(sess.query(A).get(a1.id), - A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) + eq_( + sess.query(A).get(a1.id), + A(name="a1", bs=[B(name="b1"), B(name="b2"), B(name="b3")]), + ) a1 = sess.query(A).get(a1.id) assert not class_mapper(B)._is_orphan( - attributes.instance_state(a1.bs[0])) - a1.bs[0].foo = 'b2modified' - a1.bs[1].foo = 'b3modified' + attributes.instance_state(a1.bs[0]) + ) + a1.bs[0].foo = "b2modified" + a1.bs[1].foo = "b3modified" sess.flush() sess.expunge_all() - eq_(sess.query(A).get(a1.id), - A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) + eq_( + sess.query(A).get(a1.id), + A(name="a1", bs=[B(name="b1"), B(name="b2"), B(name="b3")]), + ) class OrphanCriterionTest(fixtures.MappedTest): @classmethod def define_tables(self, metadata): - Table("core", metadata, - Column("id", Integer, - primary_key=True, test_needs_autoincrement=True), - Column("related_one_id", Integer, ForeignKey("related_one.id")), - Column("related_two_id", Integer, ForeignKey("related_two.id"))) - - Table("related_one", metadata, - Column("id", Integer, - primary_key=True, test_needs_autoincrement=True)) - - Table("related_two", metadata, - Column("id", Integer, - 
primary_key=True, test_needs_autoincrement=True)) - - def _fixture(self, legacy_is_orphan, persistent, - r1_present, r2_present, detach_event=True): + Table( + "core", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("related_one_id", Integer, ForeignKey("related_one.id")), + Column("related_two_id", Integer, ForeignKey("related_two.id")), + ) + + Table( + "related_one", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + + Table( + "related_two", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + + def _fixture( + self, + legacy_is_orphan, + persistent, + r1_present, + r2_present, + detach_event=True, + ): class Core(object): pass @@ -2697,14 +3257,24 @@ class OrphanCriterionTest(fixtures.MappedTest): self.cores = cores mapper(Core, self.tables.core, legacy_is_orphan=legacy_is_orphan) - mapper(RelatedOne, self.tables.related_one, properties={ - 'cores': relationship(Core, cascade="all, delete-orphan", - backref="r1") - }) - mapper(RelatedTwo, self.tables.related_two, properties={ - 'cores': relationship(Core, cascade="all, delete-orphan", - backref="r2") - }) + mapper( + RelatedOne, + self.tables.related_one, + properties={ + "cores": relationship( + Core, cascade="all, delete-orphan", backref="r1" + ) + }, + ) + mapper( + RelatedTwo, + self.tables.related_two, + properties={ + "cores": relationship( + Core, cascade="all, delete-orphan", backref="r2" + ) + }, + ) c1 = Core() if detach_event: r1 = RelatedOne(cores=[c1]) @@ -2826,14 +3396,23 @@ class O2MConflictTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("parent", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True)) - Table("child", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('parent.id'), - nullable=False)) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "child", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) @classmethod def setup_classes(cls): @@ -2869,113 +3448,171 @@ class O2MConflictTest(fixtures.MappedTest): eq_(sess.query(Child).filter(Child.parent_id == p2.id).all(), [c1]) def test_o2o_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False) - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={"child": relationship(Child, uselist=False)}, + ) mapper(Child, child) self._do_move_test(True) self._do_move_test(False) def test_o2m_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=True) - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={"child": relationship(Child, uselist=True)}, + ) mapper(Child, child) self._do_move_test(True) self._do_move_test(False) def 
test_o2o_backref_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False, backref='parent') - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={ + "child": relationship(Child, uselist=False, backref="parent") + }, + ) mapper(Child, child) self._do_move_test(True) self._do_move_test(False) def test_o2o_delcascade_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False, cascade="all, delete") - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={ + "child": relationship( + Child, uselist=False, cascade="all, delete" + ) + }, + ) mapper(Child, child) self._do_move_test(True) self._do_move_test(False) def test_o2o_delorphan_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False, - cascade="all, delete, delete-orphan") - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={ + "child": relationship( + Child, uselist=False, cascade="all, delete, delete-orphan" + ) + }, + ) mapper(Child, child) self._do_move_test(True) self._do_move_test(False) def test_o2o_delorphan_backref_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False, - cascade="all, delete, delete-orphan", - backref='parent') - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={ + "child": relationship( + Child, + uselist=False, + cascade="all, delete, delete-orphan", + backref="parent", + ) + }, + ) mapper(Child, child) self._do_move_test(True) self._do_move_test(False) def test_o2o_backref_delorphan_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) mapper(Parent, parent) - mapper(Child, child, properties={ - 'parent': relationship(Parent, uselist=False, single_parent=True, - backref=backref('child', uselist=False), - cascade="all,delete,delete-orphan") - }) + mapper( + Child, + child, + properties={ + "parent": relationship( + Parent, + uselist=False, + single_parent=True, + backref=backref("child", uselist=False), + cascade="all,delete,delete-orphan", + ) + }, + ) self._do_move_test(True) self._do_move_test(False) def test_o2m_backref_delorphan_delete_old(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + 
self.tables.parent, + self.tables.child, + ) mapper(Parent, parent) - mapper(Child, child, properties={ - 'parent': relationship(Parent, uselist=False, single_parent=True, - backref=backref('child', uselist=True), - cascade="all,delete,delete-orphan") - }) + mapper( + Child, + child, + properties={ + "parent": relationship( + Parent, + uselist=False, + single_parent=True, + backref=backref("child", uselist=True), + cascade="all,delete,delete-orphan", + ) + }, + ) self._do_move_test(True) self._do_move_test(False) @@ -2983,23 +3620,38 @@ class O2MConflictTest(fixtures.MappedTest): class PartialFlushTest(fixtures.MappedTest): """test cascade behavior as it relates to object lists passed to flush(). """ + @classmethod def define_tables(cls, metadata): - Table("base", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("descr", String(50))) - - Table("noninh_child", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('base_id', Integer, ForeignKey('base.id'))) - - Table("parent", metadata, - Column("id", Integer, ForeignKey("base.id"), primary_key=True)) - Table("inh_child", metadata, - Column("id", Integer, ForeignKey("base.id"), primary_key=True), - Column("parent_id", Integer, ForeignKey("parent.id"))) + Table( + "base", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("descr", String(50)), + ) + + Table( + "noninh_child", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("base_id", Integer, ForeignKey("base.id")), + ) + + Table( + "parent", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) + Table( + "inh_child", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("parent_id", Integer, ForeignKey("parent.id")), + ) def test_o2m_m2o(self): base, noninh_child = self.tables.base, self.tables.noninh_child @@ -3010,15 +3662,17 @@ class PartialFlushTest(fixtures.MappedTest): class Child(fixtures.ComparableEntity): pass - mapper(Base, base, properties={ - 'children': relationship(Child, backref='parent') - }) + mapper( + Base, + base, + properties={"children": relationship(Child, backref="parent")}, + ) mapper(Child, noninh_child) sess = create_session() c1, c2 = Child(), Child() - b1 = Base(descr='b1', children=[c1, c2]) + b1 = Base(descr="b1", children=[c1, c2]) sess.add(b1) assert c1 in sess.new @@ -3034,7 +3688,7 @@ class PartialFlushTest(fixtures.MappedTest): sess = create_session() c1, c2 = Child(), Child() - b1 = Base(descr='b1', children=[c1, c2]) + b1 = Base(descr="b1", children=[c1, c2]) sess.add(b1) sess.flush([c1]) # m2o, otoh, doesn't cascade up the other way. @@ -3044,7 +3698,7 @@ class PartialFlushTest(fixtures.MappedTest): sess = create_session() c1, c2 = Child(), Child() - b1 = Base(descr='b1', children=[c1, c2]) + b1 = Base(descr="b1", children=[c1, c2]) sess.add(b1) sess.flush([c1, c2]) # m2o, otoh, doesn't cascade up the other way. 
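
The rewrite pattern repeating through these hunks is purely mechanical: black
normalizes string literals to double quotes and, when a call no longer fits in
79 columns, explodes it to one argument per line and adds a trailing comma.
As an illustration only -- this snippet is not part of the patch, its names
are borrowed from the O2MConflictTest mappers above -- the transformation
looks like:

    # before the reformat:
    mapper(Parent, parent, properties={
        'child': relationship(Child, uselist=False, cascade="all, delete")
    })

    # after black -l 79 (nested calls are split again once they
    # themselves pass the limit):
    mapper(
        Parent,
        parent,
        properties={
            "child": relationship(
                Child, uselist=False, cascade="all, delete"
            )
        },
    )
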
@@ -3055,9 +3709,11 @@ class PartialFlushTest(fixtures.MappedTest): def test_circular_sort(self): """test ticket 1306""" - base, inh_child, parent = (self.tables.base, - self.tables.inh_child, - self.tables.parent) + base, inh_child, parent = ( + self.tables.base, + self.tables.inh_child, + self.tables.parent, + ) class Base(fixtures.ComparableEntity): pass @@ -3070,13 +3726,18 @@ class PartialFlushTest(fixtures.MappedTest): mapper(Base, base) - mapper(Child, inh_child, - inherits=Base, - properties={'parent': relationship( - Parent, - backref='children', - primaryjoin=inh_child.c.parent_id == parent.c.id - )}) + mapper( + Child, + inh_child, + inherits=Base, + properties={ + "parent": relationship( + Parent, + backref="children", + primaryjoin=inh_child.c.parent_id == parent.c.id, + ) + }, + ) mapper(Parent, parent, inherits=Base) @@ -3099,87 +3760,82 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship("Employee", cascade="all, delete-orphan") class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(50)) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) __mapper_args__ = { - 'polymorphic_identity': 'employee', - 'polymorphic_on': type + "polymorphic_identity": "employee", + "polymorphic_on": type, } class Engineer(Employee): - __tablename__ = 'engineer' - id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("employee.id"), primary_key=True) engineer_name = Column(String(30)) languages = relationship("Language", cascade="all, delete-orphan") - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - } + __mapper_args__ = {"polymorphic_identity": "engineer"} class MavenBuild(Base): - __tablename__ = 'maven_build' + __tablename__ = "maven_build" id = Column(Integer, primary_key=True) java_language_id = Column( - ForeignKey('java_language.id'), nullable=False) + ForeignKey("java_language.id"), nullable=False + ) class Manager(Employee): - __tablename__ = 'manager' - id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __tablename__ = "manager" + id = Column(Integer, ForeignKey("employee.id"), primary_key=True) manager_name = Column(String(30)) - __mapper_args__ = { - 'polymorphic_identity': 'manager', - } + __mapper_args__ = {"polymorphic_identity": "manager"} class Language(Base): - __tablename__ = 'language' + __tablename__ = "language" id = Column(Integer, primary_key=True) - engineer_id = Column(ForeignKey('engineer.id'), nullable=False) + engineer_id = Column(ForeignKey("engineer.id"), nullable=False) name = Column(String(50)) type = Column(String(50)) __mapper_args__ = { "polymorphic_on": type, - "polymorphic_identity": "language" + "polymorphic_identity": "language", } class JavaLanguage(Language): - __tablename__ = 'java_language' - id = Column(ForeignKey('language.id'), primary_key=True) - maven_builds = relationship("MavenBuild", - cascade="all, delete-orphan") + __tablename__ = "java_language" + id = Column(ForeignKey("language.id"), primary_key=True) + maven_builds = relationship( + "MavenBuild", cascade="all, delete-orphan" + ) - __mapper_args__ = { - "polymorphic_identity": "java_language" - } + __mapper_args__ = 
{"polymorphic_identity": "java_language"} def test_cascade_iterator_polymorphic(self): - Company, Employee, Engineer, Language, JavaLanguage, MavenBuild = \ - self.classes( - 'Company', 'Employee', 'Engineer', 'Language', 'JavaLanguage', - 'MavenBuild' + Company, Employee, Engineer, Language, JavaLanguage, MavenBuild = self.classes( + "Company", + "Employee", + "Engineer", + "Language", + "JavaLanguage", + "MavenBuild", ) obj = Company( employees=[ Engineer( languages=[ - JavaLanguage( - name="java", - maven_builds=[MavenBuild()] - ) - ], - + JavaLanguage(name="java", maven_builds=[MavenBuild()]) + ] ) ] ) @@ -3188,37 +3844,31 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): maven_build = lang.maven_builds[0] from sqlalchemy import inspect + state = inspect(obj) it = inspect(Company).cascade_iterator("save-update", state) - eq_( - set([rec[0] for rec in it]), - set([eng, maven_build, lang]) - ) + eq_(set([rec[0] for rec in it]), set([eng, maven_build, lang])) state = inspect(eng) it = inspect(Employee).cascade_iterator("save-update", state) - eq_( - set([rec[0] for rec in it]), - set([maven_build, lang]) - ) + eq_(set([rec[0] for rec in it]), set([maven_build, lang])) def test_delete_orphan_round_trip(self): - Company, Employee, Engineer, Language, JavaLanguage, \ - MavenBuild = self.classes( - 'Company', 'Employee', 'Engineer', 'Language', 'JavaLanguage', - 'MavenBuild' - ) + Company, Employee, Engineer, Language, JavaLanguage, MavenBuild = self.classes( + "Company", + "Employee", + "Engineer", + "Language", + "JavaLanguage", + "MavenBuild", + ) obj = Company( employees=[ Engineer( languages=[ - JavaLanguage( - name="java", - maven_builds=[MavenBuild()] - ) - ], - + JavaLanguage(name="java", maven_builds=[MavenBuild()]) + ] ) ] ) @@ -3229,4 +3879,4 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): obj.employees = [] s.commit() - eq_(s.query(Language).count(), 0) \ No newline at end of file + eq_(s.query(Language).count(), 0) diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 58c8706457..363104596f 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -8,8 +8,13 @@ import sqlalchemy as sa from sqlalchemy import Integer, String, ForeignKey, text from sqlalchemy.testing.schema import Table, Column from sqlalchemy import util, exc as sa_exc -from sqlalchemy.orm import create_session, mapper, relationship, \ - attributes, instrumentation +from sqlalchemy.orm import ( + create_session, + mapper, + relationship, + attributes, + instrumentation, +) from sqlalchemy.testing import fixtures from sqlalchemy.testing import assert_raises, assert_raises_message from sqlalchemy import testing @@ -70,8 +75,8 @@ class CollectionsTest(fixtures.ORMTest): @classmethod def dictable_entity(cls, a=None, b=None, c=None): - id = cls._entity_id = (cls._entity_id + 1) - return cls.Entity(a or str(id), b or 'value %s' % id, c) + id = cls._entity_id = cls._entity_id + 1 + return cls.Entity(a or str(id), b or "value %s" % id, c) def _test_adapter(self, typecallable, creator=None, to_set=None): if creator is None: @@ -82,16 +87,22 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() adapter = collections.collection_adapter(obj.attr) 
direct = obj.attr if to_set is None: - def to_set(col): return set(col) + + def to_set(col): + return set(col) def assert_eq(): self.assert_(to_set(direct) == canary.data) @@ -127,10 +138,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -148,12 +163,12 @@ class CollectionsTest(fixtures.ORMTest): control.append(e) assert_eq() - if hasattr(direct, 'pop'): + if hasattr(direct, "pop"): direct.pop() control.pop() assert_eq() - if hasattr(direct, '__setitem__'): + if hasattr(direct, "__setitem__"): e = creator() direct.append(e) control.append(e) @@ -163,8 +178,14 @@ class CollectionsTest(fixtures.ORMTest): control[0] = e assert_eq() - if util.reduce(and_, [hasattr(direct, a) for a in - ('__delitem__', 'insert', '__len__')], True): + if util.reduce( + and_, + [ + hasattr(direct, a) + for a in ("__delitem__", "insert", "__len__") + ], + True, + ): values = [creator(), creator(), creator(), creator()] direct[slice(0, 1)] = values control[slice(0, 1)] = values @@ -186,9 +207,10 @@ class CollectionsTest(fixtures.ORMTest): def invalid(): direct[slice(0, 6, 2)] = [creator()] + assert_raises(ValueError, invalid) - if hasattr(direct, '__delitem__'): + if hasattr(direct, "__delitem__"): e = creator() direct.append(e) control.append(e) @@ -196,7 +218,7 @@ class CollectionsTest(fixtures.ORMTest): del control[-1] assert_eq() - if hasattr(direct, '__getslice__'): + if hasattr(direct, "__getslice__"): for e in [creator(), creator(), creator(), creator()]: direct.append(e) control.append(e) @@ -213,7 +235,7 @@ class CollectionsTest(fixtures.ORMTest): del control[::2] assert_eq() - if hasattr(direct, 'remove'): + if hasattr(direct, "remove"): e = creator() direct.append(e) control.append(e) @@ -222,7 +244,7 @@ class CollectionsTest(fixtures.ORMTest): control.remove(e) assert_eq() - if hasattr(direct, '__setitem__') or hasattr(direct, '__setslice__'): + if hasattr(direct, "__setitem__") or hasattr(direct, "__setslice__"): values = [creator(), creator()] direct[:] = values @@ -276,7 +298,7 @@ class CollectionsTest(fixtures.ORMTest): control[0:0] = values assert_eq() - if hasattr(direct, '__delitem__') or hasattr(direct, '__delslice__'): + if hasattr(direct, "__delitem__") or hasattr(direct, "__delslice__"): for i in range(1, 4): e = creator() direct.append(e) @@ -294,7 +316,7 @@ class CollectionsTest(fixtures.ORMTest): del control[:] assert_eq() - if hasattr(direct, 'clear'): + if hasattr(direct, "clear"): for i in range(1, 4): e = creator() direct.append(e) @@ -304,14 +326,14 @@ class CollectionsTest(fixtures.ORMTest): control.clear() assert_eq() - if hasattr(direct, 'extend'): + if hasattr(direct, "extend"): values = [creator(), creator(), creator()] direct.extend(values) control.extend(values) assert_eq() - if hasattr(direct, '__iadd__'): + if hasattr(direct, "__iadd__"): values = [creator(), creator(), creator()] direct += values @@ -327,7 +349,7 @@ class CollectionsTest(fixtures.ORMTest): control += values assert_eq() - if hasattr(direct, '__imul__'): + if hasattr(direct, "__imul__"): direct *= 2 control *= 2 assert_eq() @@ -345,10 +367,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() 
instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() direct = obj.attr @@ -443,13 +469,14 @@ class CollectionsTest(fixtures.ORMTest): def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): - return 'ListLike(%s)' % repr(self.data) + return "ListLike(%s)" % repr(self.data) self._test_adapter(ListLike) self._test_list(ListLike) @@ -458,10 +485,11 @@ class CollectionsTest(fixtures.ORMTest): def test_list_subclass(self): class MyList(list): pass + self._test_adapter(MyList) self._test_list(MyList) self._test_list_bulk(MyList) - self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList)) + self.assert_(getattr(MyList, "_sa_instrumented") == id(MyList)) def test_list_duck(self): class ListLike(object): @@ -485,18 +513,19 @@ class CollectionsTest(fixtures.ORMTest): def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): - return 'ListLike(%s)' % repr(self.data) + return "ListLike(%s)" % repr(self.data) self._test_adapter(ListLike) self._test_list(ListLike) self._test_list_bulk(ListLike) - self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike)) + self.assert_(getattr(ListLike, "_sa_instrumented") == id(ListLike)) def test_list_emulates(self): class ListIsh(object): @@ -522,18 +551,19 @@ class CollectionsTest(fixtures.ORMTest): def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): - return 'ListIsh(%s)' % repr(self.data) + return "ListIsh(%s)" % repr(self.data) self._test_adapter(ListIsh) self._test_list(ListIsh) self._test_list_bulk(ListIsh) - self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh)) + self.assert_(getattr(ListIsh, "_sa_instrumented") == id(ListIsh)) def _test_set(self, typecallable, creator=None): if creator is None: @@ -544,10 +574,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -576,7 +610,7 @@ class CollectionsTest(fixtures.ORMTest): addall(e) addall(e) - if hasattr(direct, 'remove'): + if hasattr(direct, "remove"): e = creator() addall(e) @@ -593,7 +627,7 @@ class CollectionsTest(fixtures.ORMTest): else: self.assert_(False) - if hasattr(direct, 'discard'): + if hasattr(direct, "discard"): e = creator() addall(e) @@ -606,7 +640,7 @@ class CollectionsTest(fixtures.ORMTest): self.assert_(e not in canary.removed) assert_eq() - if hasattr(direct, 'update'): + if hasattr(direct, "update"): zap() e = creator() addall(e) @@ -617,7 +651,7 @@ class CollectionsTest(fixtures.ORMTest): control.update(values) assert_eq() - if hasattr(direct, '__ior__'): + if hasattr(direct, "__ior__"): zap() e = creator() addall(e) @@ -659,7 +693,7 @@ class CollectionsTest(fixtures.ORMTest): control.pop() assert_eq() - if hasattr(direct, 'difference_update'): + if hasattr(direct, 
"difference_update"): zap() e = creator() addall(creator(), creator()) @@ -673,7 +707,7 @@ class CollectionsTest(fixtures.ORMTest): control.difference_update(values) assert_eq() - if hasattr(direct, '__isub__'): + if hasattr(direct, "__isub__"): zap() e = creator() addall(creator(), creator()) @@ -703,7 +737,7 @@ class CollectionsTest(fixtures.ORMTest): except TypeError: assert True - if hasattr(direct, 'intersection_update'): + if hasattr(direct, "intersection_update"): zap() e = creator() addall(e, creator(), creator()) @@ -718,7 +752,7 @@ class CollectionsTest(fixtures.ORMTest): control.intersection_update(values) assert_eq() - if hasattr(direct, '__iand__'): + if hasattr(direct, "__iand__"): zap() e = creator() addall(e, creator(), creator()) @@ -744,7 +778,7 @@ class CollectionsTest(fixtures.ORMTest): except TypeError: assert True - if hasattr(direct, 'symmetric_difference_update'): + if hasattr(direct, "symmetric_difference_update"): zap() e = creator() addall(e, creator(), creator()) @@ -766,7 +800,7 @@ class CollectionsTest(fixtures.ORMTest): control.symmetric_difference_update(values) assert_eq() - if hasattr(direct, '__ixor__'): + if hasattr(direct, "__ixor__"): zap() e = creator() addall(e, creator(), creator()) @@ -808,10 +842,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() direct = obj.attr @@ -855,10 +893,11 @@ class CollectionsTest(fixtures.ORMTest): def test_set_subclass(self): class MySet(set): pass + self._test_adapter(MySet) self._test_set(MySet) self._test_set_bulk(MySet) - self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet)) + self.assert_(getattr(MySet, "_sa_instrumented") == id(MySet)) def test_set_duck(self): class SetLike(object): @@ -885,6 +924,7 @@ class CollectionsTest(fixtures.ORMTest): def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): @@ -893,7 +933,7 @@ class CollectionsTest(fixtures.ORMTest): self._test_adapter(SetLike) self._test_set(SetLike) self._test_set_bulk(SetLike) - self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike)) + self.assert_(getattr(SetLike, "_sa_instrumented") == id(SetLike)) def test_set_emulates(self): class SetIsh(object): @@ -922,6 +962,7 @@ class CollectionsTest(fixtures.ORMTest): def clear(self): self.data.clear() + __hash__ = object.__hash__ def __eq__(self, other): @@ -930,7 +971,7 @@ class CollectionsTest(fixtures.ORMTest): self._test_adapter(SetIsh) self._test_set(SetIsh) self._test_set_bulk(SetIsh) - self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh)) + self.assert_(getattr(SetIsh, "_sa_instrumented") == id(SetIsh)) def _test_dict(self, typecallable, creator=None): if creator is None: @@ -941,10 +982,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -970,7 +1015,7 @@ class CollectionsTest(fixtures.ORMTest): # assume an 'set' method is available for tests 
addall(creator()) - if hasattr(direct, '__setitem__'): + if hasattr(direct, "__setitem__"): e = creator() direct[e.a] = e control[e.a] = e @@ -981,7 +1026,7 @@ class CollectionsTest(fixtures.ORMTest): control[e.a] = e assert_eq() - if hasattr(direct, '__delitem__'): + if hasattr(direct, "__delitem__"): e = creator() addall(e) @@ -995,7 +1040,7 @@ class CollectionsTest(fixtures.ORMTest): except KeyError: self.assert_(e not in canary.removed) - if hasattr(direct, 'clear'): + if hasattr(direct, "clear"): addall(creator(), creator(), creator()) direct.clear() @@ -1006,7 +1051,7 @@ class CollectionsTest(fixtures.ORMTest): control.clear() assert_eq() - if hasattr(direct, 'pop'): + if hasattr(direct, "pop"): e = creator() addall(e) @@ -1020,7 +1065,7 @@ class CollectionsTest(fixtures.ORMTest): except KeyError: self.assert_(e not in canary.removed) - if hasattr(direct, 'popitem'): + if hasattr(direct, "popitem"): zap() e = creator() addall(e) @@ -1029,7 +1074,7 @@ class CollectionsTest(fixtures.ORMTest): control.popitem() assert_eq() - if hasattr(direct, 'setdefault'): + if hasattr(direct, "setdefault"): e = creator() val_a = direct.setdefault(e.a, e) @@ -1042,7 +1087,7 @@ class CollectionsTest(fixtures.ORMTest): assert_eq() self.assert_(val_a is val_b) - if hasattr(direct, 'update'): + if hasattr(direct, "update"): e = creator() d = dict([(ee.a, ee) for ee in [e, creator(), creator()]]) addall(e, creator()) @@ -1065,10 +1110,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() direct = obj.attr @@ -1085,7 +1134,8 @@ class CollectionsTest(fixtures.ORMTest): self.assert_(obj.attr is not direct) self.assert_(obj.attr is not like_me) self.assert_( - set(collections.collection_adapter(obj.attr)) == set([e2])) + set(collections.collection_adapter(obj.attr)) == set([e2]) + ) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) @@ -1096,9 +1146,8 @@ class CollectionsTest(fixtures.ORMTest): real_dict = dict(keyignored1=e3) obj.attr = real_dict self.assert_(obj.attr is not real_dict) - self.assert_('keyignored1' not in obj.attr) - eq_(set(collections.collection_adapter(obj.attr)), - set([e3])) + self.assert_("keyignored1" not in obj.attr) + eq_(set(collections.collection_adapter(obj.attr)), set([e3])) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) @@ -1115,17 +1164,20 @@ class CollectionsTest(fixtures.ORMTest): def test_dict(self): assert_raises_message( sa_exc.ArgumentError, - 'Type InstrumentedDict must elect an appender ' - 'method to be a collection class', - self._test_adapter, dict, self.dictable_entity, - to_set=lambda c: set(c.values()) + "Type InstrumentedDict must elect an appender " + "method to be a collection class", + self._test_adapter, + dict, + self.dictable_entity, + to_set=lambda c: set(c.values()), ) assert_raises_message( sa_exc.ArgumentError, - 'Type InstrumentedDict must elect an appender method ' - 'to be a collection class', - self._test_dict, dict + "Type InstrumentedDict must elect an appender method " + "to be a collection class", + self._test_dict, + dict, ) def test_dict_subclass(self): @@ -1140,22 +1192,24 @@ class CollectionsTest(fixtures.ORMTest): def _remove(self, item, _sa_initiator=None): self.__delitem__(item.a, 
_sa_initiator=_sa_initiator) - self._test_adapter(MyDict, self.dictable_entity, - to_set=lambda c: set(c.values())) + self._test_adapter( + MyDict, self.dictable_entity, to_set=lambda c: set(c.values()) + ) self._test_dict(MyDict) self._test_dict_bulk(MyDict) - self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict)) + self.assert_(getattr(MyDict, "_sa_instrumented") == id(MyDict)) def test_dict_subclass2(self): class MyEasyDict(collections.MappedCollection): def __init__(self): super(MyEasyDict, self).__init__(lambda e: e.a) - self._test_adapter(MyEasyDict, self.dictable_entity, - to_set=lambda c: set(c.values())) + self._test_adapter( + MyEasyDict, self.dictable_entity, to_set=lambda c: set(c.values()) + ) self._test_dict(MyEasyDict) self._test_dict_bulk(MyEasyDict) - self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict)) + self.assert_(getattr(MyEasyDict, "_sa_instrumented") == id(MyEasyDict)) def test_dict_subclass3(self): class MyOrdered(util.OrderedDict, collections.MappedCollection): @@ -1163,11 +1217,12 @@ class CollectionsTest(fixtures.ORMTest): collections.MappedCollection.__init__(self, lambda e: e.a) util.OrderedDict.__init__(self) - self._test_adapter(MyOrdered, self.dictable_entity, - to_set=lambda c: set(c.values())) + self._test_adapter( + MyOrdered, self.dictable_entity, to_set=lambda c: set(c.values()) + ) self._test_dict(MyOrdered) self._test_dict_bulk(MyOrdered) - self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered)) + self.assert_(getattr(MyOrdered, "_sa_instrumented") == id(MyOrdered)) @testing.uses_deprecated(r".*Use the bulk_replace event handler") def test_dict_subclass4(self): @@ -1187,14 +1242,19 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=MyDict, useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=MyDict, + useobject=True, + ) f = Foo() f.attr = {"k1": 1, "k2": 2} - eq_(f.attr, {'k7': 7, 'k6': 6}) + eq_(f.attr, {"k7": 7, "k6": 6}) def test_dict_duck(self): class DictLike(object): @@ -1230,19 +1290,21 @@ class CollectionsTest(fixtures.ORMTest): @collection.iterator def itervalues(self): return iter(self.data.values()) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): - return 'DictLike(%s)' % repr(self.data) + return "DictLike(%s)" % repr(self.data) - self._test_adapter(DictLike, self.dictable_entity, - to_set=lambda c: set(c.values())) + self._test_adapter( + DictLike, self.dictable_entity, to_set=lambda c: set(c.values()) + ) self._test_dict(DictLike) self._test_dict_bulk(DictLike) - self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike)) + self.assert_(getattr(DictLike, "_sa_instrumented") == id(DictLike)) def test_dict_emulates(self): class DictIsh(object): @@ -1280,19 +1342,21 @@ class CollectionsTest(fixtures.ORMTest): @collection.iterator def itervalues(self): return iter(self.data.values()) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): - return 'DictIsh(%s)' % repr(self.data) + return "DictIsh(%s)" % repr(self.data) - self._test_adapter(DictIsh, self.dictable_entity, - to_set=lambda c: set(c.values())) + self._test_adapter( + DictIsh, self.dictable_entity, to_set=lambda c: set(c.values()) + ) self._test_dict(DictIsh) self._test_dict_bulk(DictIsh) - self.assert_(getattr(DictIsh, 
'_sa_instrumented') == id(DictIsh)) + self.assert_(getattr(DictIsh, "_sa_instrumented") == id(DictIsh)) def _test_object(self, typecallable, creator=None): if creator is None: @@ -1303,10 +1367,14 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=typecallable, - useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=typecallable, + useobject=True, + ) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -1366,6 +1434,7 @@ class CollectionsTest(fixtures.ORMTest): @collection.iterator def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): @@ -1373,8 +1442,9 @@ class CollectionsTest(fixtures.ORMTest): self._test_adapter(MyCollection) self._test_object(MyCollection) - self.assert_(getattr(MyCollection, '_sa_instrumented') == - id(MyCollection)) + self.assert_( + getattr(MyCollection, "_sa_instrumented") == id(MyCollection) + ) def test_object_emulates(self): class MyCollection2(object): @@ -1382,6 +1452,7 @@ class CollectionsTest(fixtures.ORMTest): def __init__(self): self.data = set() + # looks like a list def append(self, item): @@ -1404,6 +1475,7 @@ class CollectionsTest(fixtures.ORMTest): @collection.iterator def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): @@ -1411,8 +1483,9 @@ class CollectionsTest(fixtures.ORMTest): self._test_adapter(MyCollection2) self._test_object(MyCollection2) - self.assert_(getattr(MyCollection2, '_sa_instrumented') == - id(MyCollection2)) + self.assert_( + getattr(MyCollection2, "_sa_instrumented") == id(MyCollection2) + ) def test_recipes(self): class Custom(object): @@ -1420,7 +1493,7 @@ class CollectionsTest(fixtures.ORMTest): self.data = [] @collection.appender - @collection.adds('entity') + @collection.adds("entity") def put(self, entity): self.data.append(entity) @@ -1433,7 +1506,7 @@ class CollectionsTest(fixtures.ORMTest): def push(self, *args): self.data.append(args[0]) - @collection.removes('entity') + @collection.removes("entity") def yank(self, entity, arg): self.data.remove(entity) @@ -1452,11 +1525,17 @@ class CollectionsTest(fixtures.ORMTest): class Foo(object): pass + canary = Canary() instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, - typecallable=Custom, useobject=True) + attributes.register_attribute( + Foo, + "attr", + uselist=True, + extension=canary, + typecallable=Custom, + useobject=True, + ) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -1467,6 +1546,7 @@ class CollectionsTest(fixtures.ORMTest): self.assert_(set(direct) == canary.data) self.assert_(set(adapter) == canary.data) self.assert_(list(direct) == control) + creator = self.entity_maker e1 = creator() @@ -1492,7 +1572,7 @@ class CollectionsTest(fixtures.ORMTest): control.append(e3) assert_eq() - direct.yank(e3, 'blah') + direct.yank(e3, "blah") control.remove(e3) assert_eq() @@ -1502,7 +1582,7 @@ class CollectionsTest(fixtures.ORMTest): control.append(e4) control.append(e5) - dr1 = direct.replace('foo', e6, bar='baz') + dr1 = direct.replace("foo", e6, bar="baz") control.insert(0, e6) cr1 = control.pop() assert_eq() @@ -1514,7 +1594,7 @@ class CollectionsTest(fixtures.ORMTest): assert_eq() self.assert_(dr2 is cr2) - dr3 = direct.pop('blah') + dr3 = direct.pop("blah") cr3 = control.pop() 
assert_eq() self.assert_(dr3 is cr3) @@ -1526,8 +1606,9 @@ class CollectionsTest(fixtures.ORMTest): canary = Canary() creator = self.entity_maker instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - extension=canary, useobject=True) + attributes.register_attribute( + Foo, "attr", uselist=True, extension=canary, useobject=True + ) obj = Foo() col1 = obj.attr @@ -1559,21 +1640,29 @@ class CollectionsTest(fixtures.ORMTest): class DictHelpersTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('parents', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('label', String(128))) - Table('children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('parents.id'), - nullable=False), - Column('a', String(128)), - Column('b', String(128)), - Column('c', String(128))) + Table( + "parents", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("label", String(128)), + ) + Table( + "children", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "parent_id", Integer, ForeignKey("parents.id"), nullable=False + ), + Column("a", String(128)), + Column("b", String(128)), + Column("c", String(128)), + ) @classmethod def setup_classes(cls): @@ -1588,19 +1677,29 @@ class DictHelpersTest(fixtures.MappedTest): self.c = c def _test_scalar_mapped(self, collection_class): - parents, children, Parent, Child = (self.tables.parents, - self.tables.children, - self.classes.Parent, - self.classes.Child) + parents, children, Parent, Child = ( + self.tables.parents, + self.tables.children, + self.classes.Parent, + self.classes.Child, + ) mapper(Child, children) - mapper(Parent, parents, properties={ - 'children': relationship(Child, collection_class=collection_class, - cascade="all, delete-orphan")}) + mapper( + Parent, + parents, + properties={ + "children": relationship( + Child, + collection_class=collection_class, + cascade="all, delete-orphan", + ) + }, + ) p = Parent() - p.children['foo'] = Child('foo', 'value') - p.children['bar'] = Child('bar', 'value') + p.children["foo"] = Child("foo", "value") + p.children["bar"] = Child("bar", "value") session = create_session() session.add(p) session.flush() @@ -1609,66 +1708,83 @@ class DictHelpersTest(fixtures.MappedTest): p = session.query(Parent).get(pid) - eq_(set(p.children.keys()), set(['foo', 'bar'])) - cid = p.children['foo'].id + eq_(set(p.children.keys()), set(["foo", "bar"])) + cid = p.children["foo"].id collections.collection_adapter(p.children).append_with_event( - Child('foo', 'newvalue')) + Child("foo", "newvalue") + ) session.flush() session.expunge_all() p = session.query(Parent).get(pid) - self.assert_(set(p.children.keys()) == set(['foo', 'bar'])) - self.assert_(p.children['foo'].id != cid) + self.assert_(set(p.children.keys()) == set(["foo", "bar"])) + self.assert_(p.children["foo"].id != cid) self.assert_( - len(list(collections.collection_adapter(p.children))) == 2) + len(list(collections.collection_adapter(p.children))) == 2 + ) session.flush() session.expunge_all() p = session.query(Parent).get(pid) self.assert_( - len(list(collections.collection_adapter(p.children))) == 2) + len(list(collections.collection_adapter(p.children))) == 2 + ) collections.collection_adapter(p.children).remove_with_event( - p.children['foo']) + p.children["foo"] + ) self.assert_( - 
len(list(collections.collection_adapter(p.children))) == 1) + len(list(collections.collection_adapter(p.children))) == 1 + ) session.flush() session.expunge_all() p = session.query(Parent).get(pid) self.assert_( - len(list(collections.collection_adapter(p.children))) == 1) + len(list(collections.collection_adapter(p.children))) == 1 + ) - del p.children['bar'] + del p.children["bar"] self.assert_( - len(list(collections.collection_adapter(p.children))) == 0) + len(list(collections.collection_adapter(p.children))) == 0 + ) session.flush() session.expunge_all() p = session.query(Parent).get(pid) self.assert_( - len(list(collections.collection_adapter(p.children))) == 0) + len(list(collections.collection_adapter(p.children))) == 0 + ) def _test_composite_mapped(self, collection_class): - parents, children, Parent, Child = (self.tables.parents, - self.tables.children, - self.classes.Parent, - self.classes.Child) + parents, children, Parent, Child = ( + self.tables.parents, + self.tables.children, + self.classes.Parent, + self.classes.Child, + ) mapper(Child, children) - mapper(Parent, parents, properties={ - 'children': relationship(Child, collection_class=collection_class, - cascade="all, delete-orphan") - }) + mapper( + Parent, + parents, + properties={ + "children": relationship( + Child, + collection_class=collection_class, + cascade="all, delete-orphan", + ) + }, + ) p = Parent() - p.children[('foo', '1')] = Child('foo', '1', 'value 1') - p.children[('foo', '2')] = Child('foo', '2', 'value 2') + p.children[("foo", "1")] = Child("foo", "1", "value 1") + p.children[("foo", "2")] = Child("foo", "2", "value 2") session = create_session() session.add(p) @@ -1679,11 +1795,13 @@ class DictHelpersTest(fixtures.MappedTest): p = session.query(Parent).get(pid) self.assert_( - set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) - cid = p.children[('foo', '1')].id + set(p.children.keys()) == set([("foo", "1"), ("foo", "2")]) + ) + cid = p.children[("foo", "1")].id collections.collection_adapter(p.children).append_with_event( - Child('foo', '1', 'newvalue')) + Child("foo", "1", "newvalue") + ) session.flush() session.expunge_all() @@ -1691,11 +1809,13 @@ class DictHelpersTest(fixtures.MappedTest): p = session.query(Parent).get(pid) self.assert_( - set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) - self.assert_(p.children[('foo', '1')].id != cid) + set(p.children.keys()) == set([("foo", "1"), ("foo", "2")]) + ) + self.assert_(p.children[("foo", "1")].id != cid) self.assert_( - len(list(collections.collection_adapter(p.children))) == 2) + len(list(collections.collection_adapter(p.children))) == 2 + ) def test_mapped_collection(self): collection_class = collections.mapped_collection(lambda c: c.a) @@ -1706,7 +1826,7 @@ class DictHelpersTest(fixtures.MappedTest): self._test_composite_mapped(collection_class) def test_attr_mapped_collection(self): - collection_class = collections.attribute_mapped_collection('a') + collection_class = collections.attribute_mapped_collection("a") self._test_scalar_mapped(collection_class) def test_declarative_column_mapped(self): @@ -1720,40 +1840,45 @@ class DictHelpersTest(fixtures.MappedTest): class Foo(BaseObject): __tablename__ = "foo" id = Column(Integer(), primary_key=True) - bar_id = Column(Integer, ForeignKey('bar.id')) + bar_id = Column(Integer, ForeignKey("bar.id")) for spec, obj, expected in ( (Foo.id, Foo(id=3), 3), - ((Foo.id, Foo.bar_id), Foo(id=3, bar_id=12), (3, 12)) + ((Foo.id, Foo.bar_id), Foo(id=3, bar_id=12), (3, 12)), ): eq_( 
collections.column_mapped_collection(spec)().keyfunc(obj), - expected + expected, ) def test_column_mapped_assertions(self): - assert_raises_message(sa_exc.ArgumentError, - "Column-based expression object expected " - "for argument 'mapping_spec'; got: 'a'", - collections.column_mapped_collection, 'a') - assert_raises_message(sa_exc.ArgumentError, - "Column-based expression object expected " - "for argument 'mapping_spec'; got: 'a'", - collections.column_mapped_collection, - text('a')) + assert_raises_message( + sa_exc.ArgumentError, + "Column-based expression object expected " + "for argument 'mapping_spec'; got: 'a'", + collections.column_mapped_collection, + "a", + ) + assert_raises_message( + sa_exc.ArgumentError, + "Column-based expression object expected " + "for argument 'mapping_spec'; got: 'a'", + collections.column_mapped_collection, + text("a"), + ) def test_column_mapped_collection(self): children = self.tables.children - collection_class = collections.column_mapped_collection( - children.c.a) + collection_class = collections.column_mapped_collection(children.c.a) self._test_scalar_mapped(collection_class) def test_column_mapped_collection2(self): children = self.tables.children collection_class = collections.column_mapped_collection( - (children.c.a, children.c.b)) + (children.c.a, children.c.b) + ) self._test_composite_mapped(collection_class) def test_mixin(self): @@ -1761,6 +1886,7 @@ class DictHelpersTest(fixtures.MappedTest): def __init__(self): collections.MappedCollection.__init__(self, lambda v: v.a) util.OrderedDict.__init__(self) + collection_class = Ordered self._test_scalar_mapped(collection_class) @@ -1772,6 +1898,7 @@ class DictHelpersTest(fixtures.MappedTest): def collection_class(): return Ordered2(lambda v: (v.a, v.b)) + self._test_composite_mapped(collection_class) @@ -1784,14 +1911,20 @@ class ColumnMappedWSerialize(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column('id', Integer(), primary_key=True), - Column('b', String(128))) - Table('bar', metadata, - Column('id', Integer(), primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('bat_id', Integer), - schema="x") + Table( + "foo", + metadata, + Column("id", Integer(), primary_key=True), + Column("b", String(128)), + ) + Table( + "bar", + metadata, + Column("id", Integer(), primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id")), + Column("bat_id", Integer), + schema="x", + ) @classmethod def setup_classes(cls): @@ -1805,39 +1938,35 @@ class ColumnMappedWSerialize(fixtures.MappedTest): Foo = self.classes.Foo Bar = self.classes.Bar bar = self.tables["x.bar"] - mapper(Foo, self.tables.foo, properties={ - "foo_id": self.tables.foo.c.id - }) - mapper(Bar, bar, inherits=Foo, properties={ - "bar_id": bar.c.id, - }) + mapper( + Foo, self.tables.foo, properties={"foo_id": self.tables.foo.c.id} + ) + mapper(Bar, bar, inherits=Foo, properties={"bar_id": bar.c.id}) bar_spec = Bar(foo_id=1, bar_id=2, bat_id=3) - self._run_test([ - (Foo.foo_id, bar_spec, 1), - ((Bar.bar_id, Bar.bat_id), bar_spec, (2, 3)), - (Bar.foo_id, bar_spec, 1), - (bar.c.id, bar_spec, 2), - ]) + self._run_test( + [ + (Foo.foo_id, bar_spec, 1), + ((Bar.bar_id, Bar.bat_id), bar_spec, (2, 3)), + (Bar.foo_id, bar_spec, 1), + (bar.c.id, bar_spec, 2), + ] + ) def test_selectable_column_mapped(self): from sqlalchemy import select + s = select([self.tables.foo]).alias() Foo = self.classes.Foo mapper(Foo, s) - self._run_test([ - (Foo.b, Foo(b=5), 5), - (s.c.b, 
Foo(b=5), 5) - ]) + self._run_test([(Foo.b, Foo(b=5), 5), (s.c.b, Foo(b=5), 5)]) def _run_test(self, specs): from sqlalchemy.testing.util import picklers + for spec, obj, expected in specs: coll = collections.column_mapped_collection(spec)() - eq_( - coll.keyfunc(obj), - expected - ) + eq_(coll.keyfunc(obj), expected) # ensure we do the right thing with __reduce__ for loads, dumps in picklers(): c2 = loads(dumps(coll)) @@ -1851,20 +1980,35 @@ class CustomCollectionsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('sometable', metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - Table('someothertable', metadata, - Column('col1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('scol1', Integer, - ForeignKey('sometable.col1')), - Column('data', String(20))) + Table( + "sometable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("data", String(30)), + ) + Table( + "someothertable", + metadata, + Column( + "col1", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("scol1", Integer, ForeignKey("sometable.col1")), + Column("data", String(20)), + ) def test_basic(self): - someothertable, sometable = self.tables.someothertable, \ - self.tables.sometable + someothertable, sometable = ( + self.tables.someothertable, + self.tables.sometable, + ) class MyList(list): pass @@ -1875,9 +2019,11 @@ class CustomCollectionsTest(fixtures.MappedTest): class Bar(object): pass - mapper(Foo, sometable, properties={ - 'bars': relationship(Bar, collection_class=MyList) - }) + mapper( + Foo, + sometable, + properties={"bars": relationship(Bar, collection_class=MyList)}, + ) mapper(Bar, someothertable) f = Foo() assert isinstance(f.bars, MyList) @@ -1885,17 +2031,22 @@ class CustomCollectionsTest(fixtures.MappedTest): def test_lazyload(self): """test that a 'set' can be used as a collection and can lazyload.""" - someothertable, sometable = self.tables.someothertable, \ - self.tables.sometable + someothertable, sometable = ( + self.tables.someothertable, + self.tables.sometable, + ) class Foo(object): pass class Bar(object): pass - mapper(Foo, sometable, properties={ - 'bars': relationship(Bar, collection_class=set) - }) + + mapper( + Foo, + sometable, + properties={"bars": relationship(Bar, collection_class=set)}, + ) mapper(Bar, someothertable) f = Foo() f.bars.add(Bar()) @@ -1911,8 +2062,10 @@ class CustomCollectionsTest(fixtures.MappedTest): def test_dict(self): """test that a 'dict' can be used as a collection and can lazyload.""" - someothertable, sometable = self.tables.someothertable, \ - self.tables.sometable + someothertable, sometable = ( + self.tables.someothertable, + self.tables.sometable, + ) class Foo(object): pass @@ -1930,9 +2083,13 @@ class CustomCollectionsTest(fixtures.MappedTest): if id(item) in self: del self[id(item)] - mapper(Foo, sometable, properties={ - 'bars': relationship(Bar, collection_class=AppenderDict) - }) + mapper( + Foo, + sometable, + properties={ + "bars": relationship(Bar, collection_class=AppenderDict) + }, + ) mapper(Bar, someothertable) f = Foo() f.bars.set(Bar()) @@ -1949,26 +2106,36 @@ class CustomCollectionsTest(fixtures.MappedTest): """test that the supplied 'dict' wrapper can be used as a collection and can lazyload.""" - someothertable, sometable = self.tables.someothertable, \ - self.tables.sometable + someothertable, sometable = ( + self.tables.someothertable, 
+ self.tables.sometable, + ) class Foo(object): pass class Bar(object): - def __init__(self, data): self.data = data - - mapper(Foo, sometable, properties={ - 'bars': relationship( - Bar, collection_class=collections.column_mapped_collection( - someothertable.c.data)) - }) + def __init__(self, data): + self.data = data + + mapper( + Foo, + sometable, + properties={ + "bars": relationship( + Bar, + collection_class=collections.column_mapped_collection( + someothertable.c.data + ), + ) + }, + ) mapper(Bar, someothertable) f = Foo() col = collections.collection_adapter(f.bars) - col.append_with_event(Bar('a')) - col.append_with_event(Bar('b')) + col.append_with_event(Bar("a")) + col.append_with_event(Bar("b")) sess = create_session() sess.add(f) sess.flush() @@ -1980,8 +2147,8 @@ class CustomCollectionsTest(fixtures.MappedTest): existing = set([id(b) for b in strongref]) col = collections.collection_adapter(f.bars) - col.append_with_event(Bar('b')) - f.bars['a'] = Bar('a') + col.append_with_event(Bar("b")) + f.bars["a"] = Bar("a") sess.flush() sess.expunge_all() f = sess.query(Foo).get(f.col1) @@ -2027,19 +2194,22 @@ class CustomCollectionsTest(fixtures.MappedTest): def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): - return 'ListLike(%s)' % repr(self.data) + return "ListLike(%s)" % repr(self.data) self._test_list(ListLike) def _test_list(self, listcls): - someothertable, sometable = self.tables.someothertable, \ - self.tables.sometable + someothertable, sometable = ( + self.tables.someothertable, + self.tables.sometable, + ) class Parent(object): pass @@ -2047,9 +2217,13 @@ class CustomCollectionsTest(fixtures.MappedTest): class Child(object): pass - mapper(Parent, sometable, properties={ - 'children': relationship(Child, collection_class=listcls) - }) + mapper( + Parent, + sometable, + properties={ + "children": relationship(Child, collection_class=listcls) + }, + ) mapper(Child, someothertable) control = list() @@ -2163,8 +2337,10 @@ class CustomCollectionsTest(fixtures.MappedTest): assert control == list(p.children) def test_custom(self): - someothertable, sometable = self.tables.someothertable, \ - self.tables.sometable + someothertable, sometable = ( + self.tables.someothertable, + self.tables.sometable, + ) class Parent(object): pass @@ -2188,9 +2364,13 @@ class CustomCollectionsTest(fixtures.MappedTest): def __iter__(self): return iter(self.data) - mapper(Parent, sometable, properties={ - 'children': relationship(Child, collection_class=MyCollection) - }) + mapper( + Parent, + sometable, + properties={ + "children": relationship(Child, collection_class=MyCollection) + }, + ) mapper(Child, someothertable) control = list() @@ -2230,15 +2410,14 @@ class InstrumentationTest(fixtures.ORMTest): class Touchy(list): no_touch = DoNotTouch() - assert 'no_touch' in Touchy.__dict__ - assert not hasattr(Touchy, 'no_touch') - assert 'no_touch' in dir(Touchy) + assert "no_touch" in Touchy.__dict__ + assert not hasattr(Touchy, "no_touch") + assert "no_touch" in dir(Touchy) collections._instrument_class(Touchy) @testing.uses_deprecated(r".*Use the bulk_replace event handler") def test_name_setup(self): - class Base(object): @collection.iterator def base_iterate(self, x): @@ -2257,6 +2436,7 @@ class InstrumentationTest(fixtures.ORMTest): return "base_remove" from sqlalchemy.orm.collections import _instrument_class + _instrument_class(Base) eq_(Base._sa_remover(Base(), 5), "base_remove") @@ -2272,6 +2452,7 @@ 
class InstrumentationTest(fixtures.ORMTest): @collection.remover def sub_remove(self, x): return "sub_remove" + _instrument_class(Sub) eq_(Sub._sa_appender(Sub(), 5), "base_append") @@ -2291,8 +2472,9 @@ class InstrumentationTest(fixtures.ORMTest): pass instrumentation.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, - typecallable=Collection, useobject=True) + attributes.register_attribute( + Foo, "attr", uselist=True, typecallable=Collection, useobject=True + ) f1 = Foo() f1.attr.append(3) @@ -2305,13 +2487,13 @@ class InstrumentationTest(fixtures.ORMTest): eq_(canary, [adapter_1, f1.attr._sa_adapter, None]) def test_referenced_by_owner(self): - class Foo(object): pass instrumentation.register_class(Foo) attributes.register_attribute( - Foo, 'attr', uselist=True, useobject=True) + Foo, "attr", uselist=True, useobject=True + ) f1 = Foo() f1.attr.append(3) diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py index abe35fdb77..9ef346170b 100644 --- a/test/orm/test_compile.py +++ b/test/orm/test_compile.py @@ -15,26 +15,44 @@ class CompileTest(fixtures.ORMTest): def test_with_polymorphic(self): metadata = MetaData(testing.db) - order = Table('orders', metadata, - Column('id', Integer, primary_key=True), - Column('employee_id', Integer, ForeignKey( - 'employees.id'), nullable=False), - Column('type', Unicode(16))) - - employee = Table('employees', metadata, - Column('id', Integer, primary_key=True), - Column('name', Unicode(16), unique=True, - nullable=False)) - - product = Table('products', metadata, - Column('id', Integer, primary_key=True)) - - orderproduct = Table('orderproducts', metadata, - Column('id', Integer, primary_key=True), - Column('order_id', Integer, ForeignKey( - "orders.id"), nullable=False), - Column('product_id', Integer, ForeignKey( - "products.id"), nullable=False)) + order = Table( + "orders", + metadata, + Column("id", Integer, primary_key=True), + Column( + "employee_id", + Integer, + ForeignKey("employees.id"), + nullable=False, + ), + Column("type", Unicode(16)), + ) + + employee = Table( + "employees", + metadata, + Column("id", Integer, primary_key=True), + Column("name", Unicode(16), unique=True, nullable=False), + ) + + product = Table( + "products", metadata, Column("id", Integer, primary_key=True) + ) + + orderproduct = Table( + "orderproducts", + metadata, + Column("id", Integer, primary_key=True), + Column( + "order_id", Integer, ForeignKey("orders.id"), nullable=False + ), + Column( + "product_id", + Integer, + ForeignKey("products.id"), + nullable=False, + ), + ) class Order(object): pass @@ -48,28 +66,40 @@ class CompileTest(fixtures.ORMTest): class OrderProduct(object): pass - order_join = order.select().alias('pjoin') - - order_mapper = mapper(Order, order, - with_polymorphic=('*', order_join), - polymorphic_on=order_join.c.type, - polymorphic_identity='order', - properties={ - 'orderproducts': relationship( - OrderProduct, lazy='select', - backref='order')} - ) - - mapper(Product, product, - properties={ - 'orderproducts': relationship(OrderProduct, lazy='select', - backref='product')} - ) - - mapper(Employee, employee, - properties={ - 'orders': relationship(Order, lazy='select', - backref='employee')}) + order_join = order.select().alias("pjoin") + + order_mapper = mapper( + Order, + order, + with_polymorphic=("*", order_join), + polymorphic_on=order_join.c.type, + polymorphic_identity="order", + properties={ + "orderproducts": relationship( + OrderProduct, lazy="select", backref="order" + ) + }, + ) + + 
mapper( + Product, + product, + properties={ + "orderproducts": relationship( + OrderProduct, lazy="select", backref="product" + ) + }, + ) + + mapper( + Employee, + employee, + properties={ + "orders": relationship( + Order, lazy="select", backref="employee" + ) + }, + ) mapper(OrderProduct, orderproduct) @@ -83,20 +113,31 @@ class CompileTest(fixtures.ORMTest): metadata = MetaData(testing.db) - order = Table('orders', metadata, - Column('id', Integer, primary_key=True), - Column('type', Unicode(16))) + order = Table( + "orders", + metadata, + Column("id", Integer, primary_key=True), + Column("type", Unicode(16)), + ) - product = Table('products', metadata, - Column('id', Integer, primary_key=True)) + product = Table( + "products", metadata, Column("id", Integer, primary_key=True) + ) - orderproduct = Table('orderproducts', metadata, - Column('id', Integer, primary_key=True), - Column('order_id', Integer, - ForeignKey("orders.id"), nullable=False), - Column('product_id', Integer, - ForeignKey("products.id"), - nullable=False)) + orderproduct = Table( + "orderproducts", + metadata, + Column("id", Integer, primary_key=True), + Column( + "order_id", Integer, ForeignKey("orders.id"), nullable=False + ), + Column( + "product_id", + Integer, + ForeignKey("products.id"), + nullable=False, + ), + ) class Order(object): pass @@ -107,49 +148,59 @@ class CompileTest(fixtures.ORMTest): class OrderProduct(object): pass - order_join = order.select().alias('pjoin') - - order_mapper = mapper(Order, order, - with_polymorphic=('*', order_join), - polymorphic_on=order_join.c.type, - polymorphic_identity='order', - properties={ - 'orderproducts': relationship( - OrderProduct, lazy='select', - backref='product')} - ) + order_join = order.select().alias("pjoin") + + order_mapper = mapper( + Order, + order, + with_polymorphic=("*", order_join), + polymorphic_on=order_join.c.type, + polymorphic_identity="order", + properties={ + "orderproducts": relationship( + OrderProduct, lazy="select", backref="product" + ) + }, + ) - mapper(Product, product, - properties={ - 'orderproducts': relationship(OrderProduct, lazy='select', - backref='product')} - ) + mapper( + Product, + product, + properties={ + "orderproducts": relationship( + OrderProduct, lazy="select", backref="product" + ) + }, + ) mapper(OrderProduct, orderproduct) assert_raises_message( - sa_exc.ArgumentError, - "Error creating backref", - configure_mappers + sa_exc.ArgumentError, "Error creating backref", configure_mappers ) def test_misc_one(self): metadata = MetaData(testing.db) - node_table = Table("node", metadata, - Column('node_id', Integer, primary_key=True), - Column('name_index', Integer, nullable=True)) - node_name_table = Table("node_name", metadata, - Column('node_name_id', Integer, - primary_key=True), - Column('node_id', Integer, - ForeignKey('node.node_id')), - Column('host_id', Integer, - ForeignKey('host.host_id')), - Column('name', String(64), nullable=False)) - host_table = Table("host", metadata, - Column('host_id', Integer, primary_key=True), - Column('hostname', String(64), nullable=False, - unique=True)) + node_table = Table( + "node", + metadata, + Column("node_id", Integer, primary_key=True), + Column("name_index", Integer, nullable=True), + ) + node_name_table = Table( + "node_name", + metadata, + Column("node_name_id", Integer, primary_key=True), + Column("node_id", Integer, ForeignKey("node.node_id")), + Column("host_id", Integer, ForeignKey("host.host_id")), + Column("name", String(64), nullable=False), + ) + host_table = 
Table( + "host", + metadata, + Column("host_id", Integer, primary_key=True), + Column("hostname", String(64), nullable=False, unique=True), + ) metadata.create_all() try: node_table.insert().execute(node_id=1, node_index=5) @@ -165,12 +216,14 @@ class CompileTest(fixtures.ORMTest): node_mapper = mapper(Node, node_table) host_mapper = mapper(Host, host_table) - node_name_mapper = mapper(NodeName, node_name_table, - properties={ - 'node': relationship( - Node, backref=backref('names')), - 'host': relationship(Host), - }) + node_name_mapper = mapper( + NodeName, + node_name_table, + properties={ + "node": relationship(Node, backref=backref("names")), + "host": relationship(Host), + }, + ) sess = create_session() assert sess.query(Node).get(1).names == [] finally: @@ -179,9 +232,13 @@ class CompileTest(fixtures.ORMTest): def test_conflicting_backref_two(self): meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True)) - b = Table('b', meta, Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id'))) + a = Table("a", meta, Column("id", Integer, primary_key=True)) + b = Table( + "b", + meta, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), + ) class A(object): pass @@ -189,25 +246,23 @@ class CompileTest(fixtures.ORMTest): class B(object): pass - mapper(A, a, properties={ - 'b': relationship(B, backref='a') - }) - mapper(B, b, properties={ - 'a': relationship(A, backref='b') - }) + mapper(A, a, properties={"b": relationship(B, backref="a")}) + mapper(B, b, properties={"a": relationship(A, backref="b")}) assert_raises_message( - sa_exc.ArgumentError, - "Error creating backref", - configure_mappers + sa_exc.ArgumentError, "Error creating backref", configure_mappers ) def test_conflicting_backref_subclass(self): meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True)) - b = Table('b', meta, Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id'))) + a = Table("a", meta, Column("id", Integer, primary_key=True)) + b = Table( + "b", + meta, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), + ) class A(object): pass @@ -218,15 +273,17 @@ class CompileTest(fixtures.ORMTest): class C(B): pass - mapper(A, a, properties={ - 'b': relationship(B, backref='a'), - 'c': relationship(C, backref='a') - }) + mapper( + A, + a, + properties={ + "b": relationship(B, backref="a"), + "c": relationship(C, backref="a"), + }, + ) mapper(B, b) mapper(C, None, inherits=B) assert_raises_message( - sa_exc.ArgumentError, - "Error creating backref", - configure_mappers + sa_exc.ArgumentError, "Error creating backref", configure_mappers ) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 91a96d6fa4..5126964c98 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -1,35 +1,44 @@ from sqlalchemy.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy import testing -from sqlalchemy import Integer, String, ForeignKey, \ - select +from sqlalchemy import Integer, String, ForeignKey, select from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, \ - CompositeProperty, aliased, persistence +from sqlalchemy.orm import ( + mapper, + relationship, + CompositeProperty, + aliased, + persistence, +) from sqlalchemy.orm import composite, Session, configure_mappers from sqlalchemy.testing import eq_ from sqlalchemy.testing import 
fixtures - - class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - Table('graphs', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30))) - - Table('edges', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('graph_id', Integer, - ForeignKey('graphs.id')), - Column('x1', Integer), - Column('y1', Integer), - Column('x2', Integer), - Column('y2', Integer)) + Table( + "graphs", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) + + Table( + "edges", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("graph_id", Integer, ForeignKey("graphs.id")), + Column("x1", Integer), + Column("y1", Integer), + Column("x2", Integer), + Column("y2", Integer), + ) @classmethod def setup_mappers(cls): @@ -42,16 +51,18 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def __composite_values__(self): return [self.x, self.y] + __hash__ = None def __eq__(self, other): - return isinstance(other, Point) and \ - other.x == self.x and \ - other.y == self.y + return ( + isinstance(other, Point) + and other.x == self.x + and other.y == self.y + ) def __ne__(self, other): - return not isinstance(other, Point) or \ - not self.__eq__(other) + return not isinstance(other, Point) or not self.__eq__(other) class Graph(cls.Comparable): pass @@ -61,24 +72,31 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): if args: self.start, self.end = args - mapper(Graph, graphs, properties={ - 'edges': relationship(Edge) - }) - mapper(Edge, edges, properties={ - 'start': sa.orm.composite(Point, edges.c.x1, edges.c.y1), - 'end': sa.orm.composite(Point, edges.c.x2, edges.c.y2) - }) + mapper(Graph, graphs, properties={"edges": relationship(Edge)}) + mapper( + Edge, + edges, + properties={ + "start": sa.orm.composite(Point, edges.c.x1, edges.c.y1), + "end": sa.orm.composite(Point, edges.c.x2, edges.c.y2), + }, + ) def _fixture(self): - Graph, Edge, Point = (self.classes.Graph, - self.classes.Edge, - self.classes.Point) + Graph, Edge, Point = ( + self.classes.Graph, + self.classes.Edge, + self.classes.Point, + ) sess = Session() - g = Graph(id=1, edges=[ - Edge(Point(3, 4), Point(5, 6)), - Edge(Point(14, 5), Point(2, 7)) - ]) + g = Graph( + id=1, + edges=[ + Edge(Point(3, 4), Point(5, 6)), + Edge(Point(14, 5), Point(2, 7)), + ], + ) sess.add(g) sess.commit() return sess @@ -100,16 +118,15 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): g = sess.query(Graph).get(g1.id) eq_( [(e.start, e.end) for e in g.edges], - [ - (Point(3, 4), Point(5, 6)), - (Point(14, 5), Point(2, 7)), - ] + [(Point(3, 4), Point(5, 6)), (Point(14, 5), Point(2, 7))], ) def test_detect_change(self): - Graph, Edge, Point = (self.classes.Graph, - self.classes.Edge, - self.classes.Point) + Graph, Edge, Point = ( + self.classes.Graph, + self.classes.Edge, + self.classes.Point, + ) sess = self._fixture() @@ -121,9 +138,11 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(e.end, Point(18, 4)) def test_not_none(self): - Graph, Edge, Point = (self.classes.Graph, - self.classes.Edge, - self.classes.Point) + Graph, Edge, Point = ( + self.classes.Graph, + self.classes.Edge, + self.classes.Point, + ) # current contract. the composite is None # when hasn't been populated etc. 
on a @@ -155,57 +174,58 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess.close() def go(): - g2 = sess.query(Graph).\ - options(sa.orm.joinedload('edges')).\ - get(g.id) + g2 = ( + sess.query(Graph).options(sa.orm.joinedload("edges")).get(g.id) + ) eq_( [(e.start, e.end) for e in g2.edges], - [ - (Point(3, 4), Point(5, 6)), - (Point(14, 5), Point(2, 7)), - ] + [(Point(3, 4), Point(5, 6)), (Point(14, 5), Point(2, 7))], ) + self.assert_sql_count(testing.db, go, 1) def test_comparator(self): - Graph, Edge, Point = (self.classes.Graph, - self.classes.Edge, - self.classes.Point) + Graph, Edge, Point = ( + self.classes.Graph, + self.classes.Edge, + self.classes.Point, + ) sess = self._fixture() g = sess.query(Graph).first() - assert sess.query(Edge).\ - filter(Edge.start == Point(3, 4)).one() is \ - g.edges[0] - - assert sess.query(Edge).\ - filter(Edge.start != Point(3, 4)).first() is \ - g.edges[1] + assert ( + sess.query(Edge).filter(Edge.start == Point(3, 4)).one() + is g.edges[0] + ) - eq_( - sess.query(Edge).filter(Edge.start == None).all(), # noqa - [] + assert ( + sess.query(Edge).filter(Edge.start != Point(3, 4)).first() + is g.edges[1] ) + eq_(sess.query(Edge).filter(Edge.start == None).all(), []) # noqa + def test_comparator_aliased(self): - Graph, Edge, Point = (self.classes.Graph, - self.classes.Edge, - self.classes.Point) + Graph, Edge, Point = ( + self.classes.Graph, + self.classes.Edge, + self.classes.Point, + ) sess = self._fixture() g = sess.query(Graph).first() ea = aliased(Edge) - assert sess.query(ea).\ - filter(ea.start != Point(3, 4)).first() is \ - g.edges[1] + assert ( + sess.query(ea).filter(ea.start != Point(3, 4)).first() + is g.edges[1] + ) def test_bulk_update_sql(self): - Edge, Point = (self.classes.Edge, - self.classes.Point) + Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -215,19 +235,19 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): q = sess.query(Edge).filter(Edge.start == Point(14, 5)) bulk_ud = persistence.BulkUpdate.factory( - q, False, {Edge.end: Point(16, 10)}, {}) + q, False, {Edge.end: Point(16, 10)}, {} + ) self.assert_compile( bulk_ud, "UPDATE edges SET x2=:x2, y2=:y2 WHERE edges.x1 = :x1_1 " "AND edges.y1 = :y1_1", - params={'x2': 16, 'x1_1': 14, 'y2': 10, 'y1_1': 5}, - dialect="default" + params={"x2": 16, "x1_1": 14, "y2": 10, "y1_1": 5}, + dialect="default", ) def test_bulk_update_evaluate(self): - Edge, Point = (self.classes.Edge, - self.classes.Point) + Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -241,8 +261,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(e1.end, Point(16, 10)) def test_bulk_update_fetch(self): - Edge, Point = (self.classes.Edge, - self.classes.Point) + Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -263,14 +282,11 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): e1 = Edge() e1.start = Point(1, 2) eq_( - get_history(e1, 'start'), - ([Point(x=1, y=2)], (), [Point(x=None, y=None)]) + get_history(e1, "start"), + ([Point(x=1, y=2)], (), [Point(x=None, y=None)]), ) - eq_( - get_history(e1, 'end'), - ((), [Point(x=None, y=None)], ()) - ) + eq_(get_history(e1, "end"), ((), [Point(x=None, y=None)], ())) def test_query_cols_legacy(self): Edge = self.classes.Edge @@ -279,7 +295,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_( sess.query(Edge.start.clauses, Edge.end.clauses).all(), - [(3, 4, 5, 6), (14, 5, 2, 7)] + [(3, 
4, 5, 6), (14, 5, 2, 7)], ) def test_query_cols(self): @@ -292,7 +308,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_( sess.query(start, end).filter(start == Point(3, 4)).all(), - [(Point(3, 4), Point(5, 6))] + [(Point(3, 4), Point(5, 6))], ) def test_query_cols_labeled(self): @@ -303,8 +319,11 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): start, end = Edge.start, Edge.end - row = sess.query(start.label('s1'), end).filter( - start == Point(3, 4)).first() + row = ( + sess.query(start.label("s1"), end) + .filter(start == Point(3, 4)) + .first() + ) eq_(row.s1.x, 3) eq_(row.s1.y, 4) eq_(row.end.x, 5) @@ -324,8 +343,8 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess.query(Edge.start, Edge.end).all(), [ (Point(x=3, y=4), Point(x=5, y=6)), - (Point(x=14, y=5), Point(x=None, y=None)) - ] + (Point(x=14, y=5), Point(x=None, y=None)), + ], ) def test_save_null(self): @@ -357,7 +376,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): g = sess.query(Graph).first() e = g.edges[0] sess.expire(e) - assert 'start' not in e.__dict__ + assert "start" not in e.__dict__ assert e.start == Point(3, 4) def test_default_value(self): @@ -370,13 +389,17 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - Table('stuff', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column("a", String(30)), - Column("b", String(30)), - Column("c", String(30)), - Column("d", String(30))) + Table( + "stuff", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a", String(30)), + Column("b", String(30)), + Column("c", String(30)), + Column("d", String(30)), + ) def _fixture(self): class AB(object): @@ -393,9 +416,12 @@ class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): return (self.a, self.b) + self.cd.__composite_values__() def __eq__(self, other): - return isinstance(other, AB) and \ - self.a == other.a and self.b == other.b and \ - self.cd == other.cd + return ( + isinstance(other, AB) + and self.a == other.a + and self.b == other.b + and self.cd == other.cd + ) def __ne__(self, other): return not self.__eq__(other) @@ -409,8 +435,11 @@ class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): return (self.c, self.d) def __eq__(self, other): - return isinstance(other, CD) and \ - self.c == other.c and self.d == other.d + return ( + isinstance(other, CD) + and self.c == other.c + and self.d == other.d + ) def __ne__(self, other): return not self.__eq__(other) @@ -420,10 +449,15 @@ class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self.ab = ab stuff = self.tables.stuff - mapper(Thing, stuff, properties={ - "ab": composite( - AB.generate, stuff.c.a, stuff.c.b, stuff.c.c, stuff.c.d) - }) + mapper( + Thing, + stuff, + properties={ + "ab": composite( + AB.generate, stuff.c.a, stuff.c.b, stuff.c.c, stuff.c.d + ) + }, + ) return Thing, AB, CD def test_round_trip(self): @@ -431,23 +465,27 @@ class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): s = Session() - s.add(Thing(AB('a', 'b', CD('c', 'd')))) + s.add(Thing(AB("a", "b", CD("c", "d")))) s.commit() s.close() - t1 = s.query(Thing).filter( - Thing.ab == AB('a', 'b', CD('c', 'd'))).one() - eq_(t1.ab, AB('a', 'b', CD('c', 'd'))) + t1 = ( + s.query(Thing).filter(Thing.ab == AB("a", "b", CD("c", "d"))).one() + ) + eq_(t1.ab, 
AB("a", "b", CD("c", "d"))) + class PrimaryKeyTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('graphs', metadata, - Column('id', Integer, primary_key=True), - Column('version_id', Integer, primary_key=True, - nullable=True), - Column('name', String(30))) + Table( + "graphs", + metadata, + Column("id", Integer, primary_key=True), + Column("version_id", Integer, primary_key=True, nullable=True), + Column("name", String(30)), + ) @classmethod def setup_mappers(cls): @@ -460,12 +498,15 @@ class PrimaryKeyTest(fixtures.MappedTest): def __composite_values__(self): return (self.id, self.version) + __hash__ = None def __eq__(self, other): - return isinstance(other, Version) and \ - other.id == self.id and \ - other.version == self.version + return ( + isinstance(other, Version) + and other.id == self.id + and other.version == self.version + ) def __ne__(self, other): return not self.__eq__(other) @@ -474,9 +515,15 @@ class PrimaryKeyTest(fixtures.MappedTest): def __init__(self, version): self.version = version - mapper(Graph, graphs, properties={ - 'version': sa.orm.composite(Version, graphs.c.id, - graphs.c.version_id)}) + mapper( + Graph, + graphs, + properties={ + "version": sa.orm.composite( + Version, graphs.c.id, graphs.c.version_id + ) + }, + ) def _fixture(self): Graph, Version = self.classes.Graph, self.classes.Version @@ -533,16 +580,19 @@ class PrimaryKeyTest(fixtures.MappedTest): class DefaultsTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('foobars', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x1', Integer, default=2), - Column('x2', Integer), - Column('x3', Integer, server_default="15"), - Column('x4', Integer)) + Table( + "foobars", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x1", Integer, default=2), + Column("x2", Integer), + Column("x3", Integer, server_default="15"), + Column("x4", Integer), + ) @classmethod def setup_mappers(cls): @@ -560,28 +610,41 @@ class DefaultsTest(fixtures.MappedTest): def __composite_values__(self): return self.goofy_x1, self.x2, self.x3, self.x4 + __hash__ = None def __eq__(self, other): - return other.goofy_x1 == self.goofy_x1 and \ - other.x2 == self.x2 and \ - other.x3 == self.x3 and \ - other.x4 == self.x4 + return ( + other.goofy_x1 == self.goofy_x1 + and other.x2 == self.x2 + and other.x3 == self.x3 + and other.x4 == self.x4 + ) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "FBComposite(%r, %r, %r, %r)" % ( - self.goofy_x1, self.x2, self.x3, self.x4 + self.goofy_x1, + self.x2, + self.x3, + self.x4, ) - mapper(Foobar, foobars, properties=dict( - foob=sa.orm.composite(FBComposite, - foobars.c.x1, - foobars.c.x2, - foobars.c.x3, - foobars.c.x4) - )) + + mapper( + Foobar, + foobars, + properties=dict( + foob=sa.orm.composite( + FBComposite, + foobars.c.x1, + foobars.c.x2, + foobars.c.x3, + foobars.c.x4, + ) + ), + ) def test_attributes_with_defaults(self): Foobar, FBComposite = self.classes.Foobar, self.classes.FBComposite @@ -614,20 +677,31 @@ class DefaultsTest(fixtures.MappedTest): class MappedSelectTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('descriptions', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('d1', String(20)), - Column('d2', String(20))) - - Table('values', metadata, - Column('id', Integer, primary_key=True, - 
test_needs_autoincrement=True), - Column('description_id', Integer, - ForeignKey('descriptions.id'), - nullable=False), - Column('v1', String(20)), - Column('v2', String(20))) + Table( + "descriptions", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("d1", String(20)), + Column("d2", String(20)), + ) + + Table( + "values", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "description_id", + Integer, + ForeignKey("descriptions.id"), + nullable=False, + ), + Column("v1", String(20)), + Column("v2", String(20)), + ) @classmethod def setup_mappers(cls): @@ -648,68 +722,81 @@ class MappedSelectTest(fixtures.MappedTest): desc_values = select( [values, descriptions.c.d1, descriptions.c.d2], - descriptions.c.id == values.c.description_id - ).alias('descriptions_values') - - mapper(Descriptions, descriptions, properties={ - 'values': relationship(Values, lazy='dynamic'), - 'custom_descriptions': composite( - CustomValues, - descriptions.c.d1, - descriptions.c.d2), - - }) - - mapper(Values, desc_values, properties={ - 'custom_values': composite(CustomValues, - desc_values.c.v1, - desc_values.c.v2), + descriptions.c.id == values.c.description_id, + ).alias("descriptions_values") + + mapper( + Descriptions, + descriptions, + properties={ + "values": relationship(Values, lazy="dynamic"), + "custom_descriptions": composite( + CustomValues, descriptions.c.d1, descriptions.c.d2 + ), + }, + ) - }) + mapper( + Values, + desc_values, + properties={ + "custom_values": composite( + CustomValues, desc_values.c.v1, desc_values.c.v2 + ) + }, + ) def test_set_composite_attrs_via_selectable(self): - Values, CustomValues, values, Descriptions, descriptions = \ - (self.classes.Values, - self.classes.CustomValues, - self.tables.values, - self.classes.Descriptions, - self.tables.descriptions) + Values, CustomValues, values, Descriptions, descriptions = ( + self.classes.Values, + self.classes.CustomValues, + self.tables.values, + self.classes.Descriptions, + self.tables.descriptions, + ) session = Session() d = Descriptions( - custom_descriptions=CustomValues('Color', 'Number'), + custom_descriptions=CustomValues("Color", "Number"), values=[ - Values(custom_values=CustomValues('Red', '5')), - Values(custom_values=CustomValues('Blue', '1')) - ] + Values(custom_values=CustomValues("Red", "5")), + Values(custom_values=CustomValues("Blue", "1")), + ], ) session.add(d) session.commit() eq_( testing.db.execute(descriptions.select()).fetchall(), - [(1, 'Color', 'Number')] + [(1, "Color", "Number")], ) eq_( testing.db.execute(values.select()).fetchall(), - [(1, 1, 'Red', '5'), (2, 1, 'Blue', '1')] + [(1, 1, "Red", "5"), (2, 1, "Blue", "1")], ) class ManyToOneTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('a', - metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('b1', String(20)), - Column('b2_id', Integer, ForeignKey('b.id'))) - - Table('b', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(20))) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b1", String(20)), + Column("b2_id", Integer, ForeignKey("b.id")), + ) + + Table( + "b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(20)), + ) @classmethod def setup_mappers(cls): @@ -729,14 +816,17 @@ class 
ManyToOneTest(fixtures.MappedTest): return self.b1, self.b2 def __eq__(self, other): - return isinstance(other, C) and \ - other.b1 == self.b1 and \ - other.b2 == self.b2 - - mapper(A, a, properties={ - 'b2': relationship(B), - 'c': composite(C, 'b1', 'b2') - }) + return ( + isinstance(other, C) + and other.b1 == self.b1 + and other.b2 == self.b2 + ) + + mapper( + A, + a, + properties={"b2": relationship(B), "c": composite(C, "b1", "b2")}, + ) mapper(B, b) def test_early_configure(self): @@ -746,63 +836,55 @@ class ManyToOneTest(fixtures.MappedTest): A.c.__clause_element__() def test_persist(self): - A, C, B = (self.classes.A, - self.classes.C, - self.classes.B) + A, C, B = (self.classes.A, self.classes.C, self.classes.B) sess = Session() - sess.add(A(c=C('b1', B(data='b2')))) + sess.add(A(c=C("b1", B(data="b2")))) sess.commit() a1 = sess.query(A).one() - eq_(a1.c, C('b1', B(data='b2'))) + eq_(a1.c, C("b1", B(data="b2"))) def test_query(self): - A, C, B = (self.classes.A, - self.classes.C, - self.classes.B) + A, C, B = (self.classes.A, self.classes.C, self.classes.B) sess = Session() - b1, b2 = B(data='b1'), B(data='b2') - a1 = A(c=C('a1b1', b1)) - a2 = A(c=C('a2b1', b2)) + b1, b2 = B(data="b1"), B(data="b2") + a1 = A(c=C("a1b1", b1)) + a2 = A(c=C("a2b1", b2)) sess.add_all([a1, a2]) sess.commit() - eq_( - sess.query(A).filter(A.c == C('a2b1', b2)).one(), - a2 - ) + eq_(sess.query(A).filter(A.c == C("a2b1", b2)).one(), a2) def test_query_aliased(self): - A, C, B = (self.classes.A, - self.classes.C, - self.classes.B) + A, C, B = (self.classes.A, self.classes.C, self.classes.B) sess = Session() - b1, b2 = B(data='b1'), B(data='b2') - a1 = A(c=C('a1b1', b1)) - a2 = A(c=C('a2b1', b2)) + b1, b2 = B(data="b1"), B(data="b2") + a1 = A(c=C("a1b1", b1)) + a2 = A(c=C("a2b1", b2)) sess.add_all([a1, a2]) sess.commit() ae = aliased(A) - eq_( - sess.query(ae).filter(ae.c == C('a2b1', b2)).one(), - a2 - ) + eq_(sess.query(ae).filter(ae.c == C("a2b1", b2)).one(), a2) class ConfigurationTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('edge', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x1', Integer), - Column('y1', Integer), - Column('x2', Integer), - Column('y2', Integer)) + Table( + "edge", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x1", Integer), + Column("y1", Integer), + Column("x2", Integer), + Column("y2", Integer), + ) @classmethod def setup_mappers(cls): @@ -815,13 +897,14 @@ class ConfigurationTest(fixtures.MappedTest): return [self.x, self.y] def __eq__(self, other): - return isinstance(other, Point) and \ - other.x == self.x and \ - other.y == self.y + return ( + isinstance(other, Point) + and other.x == self.x + and other.y == self.y + ) def __ne__(self, other): - return not isinstance(other, Point) or \ - not self.__eq__(other) + return not isinstance(other, Point) or not self.__eq__(other) class Edge(cls.Comparable): pass @@ -834,64 +917,85 @@ class ConfigurationTest(fixtures.MappedTest): sess.add(e1) sess.commit() - eq_( - sess.query(Edge).one(), - Edge(start=Point(3, 4), end=Point(5, 6)) - ) + eq_(sess.query(Edge).one(), Edge(start=Point(3, 4), end=Point(5, 6))) def test_columns(self): - edge, Edge, Point = (self.tables.edge, - self.classes.Edge, - self.classes.Point) + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) - mapper(Edge, edge, properties={ - 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1), - 
'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) - }) + mapper( + Edge, + edge, + properties={ + "start": sa.orm.composite(Point, edge.c.x1, edge.c.y1), + "end": sa.orm.composite(Point, edge.c.x2, edge.c.y2), + }, + ) self._test_roundtrip() def test_attributes(self): - edge, Edge, Point = (self.tables.edge, - self.classes.Edge, - self.classes.Point) + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) m = mapper(Edge, edge) - m.add_property('start', sa.orm.composite(Point, Edge.x1, Edge.y1)) - m.add_property('end', sa.orm.composite(Point, Edge.x2, Edge.y2)) + m.add_property("start", sa.orm.composite(Point, Edge.x1, Edge.y1)) + m.add_property("end", sa.orm.composite(Point, Edge.x2, Edge.y2)) self._test_roundtrip() def test_strings(self): - edge, Edge, Point = (self.tables.edge, - self.classes.Edge, - self.classes.Point) + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) m = mapper(Edge, edge) - m.add_property('start', sa.orm.composite(Point, 'x1', 'y1')) - m.add_property('end', sa.orm.composite(Point, 'x2', 'y2')) + m.add_property("start", sa.orm.composite(Point, "x1", "y1")) + m.add_property("end", sa.orm.composite(Point, "x2", "y2")) self._test_roundtrip() def test_deferred(self): - edge, Edge, Point = (self.tables.edge, - self.classes.Edge, - self.classes.Point) - mapper(Edge, edge, properties={ - 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1, - deferred=True, group='s'), - 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2, - deferred=True) - }) + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) + mapper( + Edge, + edge, + properties={ + "start": sa.orm.composite( + Point, edge.c.x1, edge.c.y1, deferred=True, group="s" + ), + "end": sa.orm.composite( + Point, edge.c.x2, edge.c.y2, deferred=True + ), + }, + ) self._test_roundtrip() def test_check_prop_type(self): - edge, Edge, Point = (self.tables.edge, - self.classes.Edge, - self.classes.Point) - mapper(Edge, edge, properties={ - 'start': sa.orm.composite(Point, (edge.c.x1,), edge.c.y1), - }) + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) + mapper( + Edge, + edge, + properties={ + "start": sa.orm.composite(Point, (edge.c.x1,), edge.c.y1) + }, + ) assert_raises_message( sa.exc.ArgumentError, # note that we also are checking that the tuple @@ -900,22 +1004,26 @@ class ConfigurationTest(fixtures.MappedTest): r"Composite expects Column objects or mapped " r"attributes/attribute names as " r"arguments, got: \(Column", - configure_mappers + configure_mappers, ) class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('edge', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x1', Integer), - Column('y1', Integer), - Column('x2', Integer), - Column('y2', Integer)) + Table( + "edge", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x1", Integer), + Column("y1", Integer), + Column("x2", Integer), + Column("y2", Integer), + ) @classmethod def setup_mappers(cls): @@ -928,13 +1036,14 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): return [self.x, self.y] def __eq__(self, other): - return isinstance(other, Point) and \ - other.x == self.x and \ - other.y == self.y + return ( + isinstance(other, Point) + and other.x == self.x + and 
other.y == self.y + ) def __ne__(self, other): - return not isinstance(other, Point) or \ - not self.__eq__(other) + return not isinstance(other, Point) or not self.__eq__(other) class Edge(cls.Comparable): def __init__(self, start, end): @@ -942,15 +1051,17 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self.end = end def __eq__(self, other): - return isinstance(other, Edge) and \ - other.id == self.id + return isinstance(other, Edge) and other.id == self.id def _fixture(self, custom): - edge, Edge, Point = (self.tables.edge, - self.classes.Edge, - self.classes.Point) + edge, Edge, Point = ( + self.tables.edge, + self.classes.Edge, + self.classes.Point, + ) if custom: + class CustomComparator(sa.orm.CompositeProperty.Comparator): def near(self, other, d): clauses = self.__clause_element__().clauses @@ -958,16 +1069,28 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): diff_y = clauses[1] - other.y return diff_x * diff_x + diff_y * diff_y <= d * d - mapper(Edge, edge, properties={ - 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1, - comparator_factory=CustomComparator), - 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) - }) + mapper( + Edge, + edge, + properties={ + "start": sa.orm.composite( + Point, + edge.c.x1, + edge.c.y1, + comparator_factory=CustomComparator, + ), + "end": sa.orm.composite(Point, edge.c.x2, edge.c.y2), + }, + ) else: - mapper(Edge, edge, properties={ - 'start': sa.orm.composite(Point, edge.c.x1, edge.c.y1), - 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) - }) + mapper( + Edge, + edge, + properties={ + "start": sa.orm.composite(Point, edge.c.x1, edge.c.y1), + "end": sa.orm.composite(Point, edge.c.x2, edge.c.y2), + }, + ) def test_comparator_behavior_default(self): self._fixture(False) @@ -978,8 +1101,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self._test_comparator_behavior() def _test_comparator_behavior(self): - Edge, Point = (self.classes.Edge, - self.classes.Point) + Edge, Point = (self.classes.Edge, self.classes.Point) sess = Session() e1 = Edge(Point(3, 4), Point(5, 6)) @@ -987,18 +1109,11 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess.add_all([e1, e2]) sess.commit() - assert sess.query(Edge).\ - filter(Edge.start == Point(3, 4)).one() is \ - e1 + assert sess.query(Edge).filter(Edge.start == Point(3, 4)).one() is e1 - assert sess.query(Edge).\ - filter(Edge.start != Point(3, 4)).first() is \ - e2 + assert sess.query(Edge).filter(Edge.start != Point(3, 4)).first() is e2 - eq_( - sess.query(Edge).filter(Edge.start == None).all(), # noqa - [] - ) + eq_(sess.query(Edge).filter(Edge.start == None).all(), []) # noqa def test_default_comparator_factory(self): self._fixture(False) @@ -1009,26 +1124,27 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): def test_custom_comparator_factory(self): self._fixture(True) - Edge, Point = (self.classes.Edge, - self.classes.Point) + Edge, Point = (self.classes.Edge, self.classes.Point) - edge_1, edge_2 = Edge(Point(0, 0), Point(3, 5)), \ - Edge(Point(0, 1), Point(3, 5)) + edge_1, edge_2 = ( + Edge(Point(0, 0), Point(3, 5)), + Edge(Point(0, 1), Point(3, 5)), + ) sess = Session() sess.add_all([edge_1, edge_2]) sess.commit() - near_edges = sess.query(Edge).filter( - Edge.start.near(Point(1, 1), 1) - ).all() + near_edges = ( + sess.query(Edge).filter(Edge.start.near(Point(1, 1), 1)).all() + ) assert edge_1 not in near_edges assert edge_2 in near_edges - near_edges = 
sess.query(Edge).filter( - Edge.start.near(Point(0, 1), 1) - ).all() + near_edges = ( + sess.query(Edge).filter(Edge.start.near(Point(0, 1), 1)).all() + ) assert edge_1 in near_edges and edge_2 in near_edges @@ -1040,7 +1156,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): s.query(Edge).order_by(Edge.start, Edge.end), "SELECT edge.id AS edge_id, edge.x1 AS edge_x1, " "edge.y1 AS edge_y1, edge.x2 AS edge_x2, edge.y2 AS edge_y2 " - "FROM edge ORDER BY edge.x1, edge.y1, edge.x2, edge.y2" + "FROM edge ORDER BY edge.x1, edge.y1, edge.x2, edge.y2", ) def test_order_by_aliased(self): @@ -1054,17 +1170,18 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "edge_1.y1 AS edge_1_y1, edge_1.x2 AS edge_1_x2, " "edge_1.y2 AS edge_1_y2 " "FROM edge AS edge_1 ORDER BY edge_1.x1, edge_1.y1, " - "edge_1.x2, edge_1.y2" + "edge_1.x2, edge_1.y2", ) def test_clause_expansion(self): self._fixture(False) Edge = self.classes.Edge from sqlalchemy.orm import configure_mappers + configure_mappers() self.assert_compile( select([Edge]).order_by(Edge.start), "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " - "ORDER BY edge.x1, edge.y1" + "ORDER BY edge.x1, edge.y1", ) diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index 6e7702ad84..5f1c63f999 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -10,8 +10,14 @@ from sqlalchemy import event from sqlalchemy.testing import mock from sqlalchemy import Integer, String, ForeignKey from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, backref, \ - create_session, sessionmaker, Session +from sqlalchemy.orm import ( + mapper, + relationship, + backref, + create_session, + sessionmaker, + Session, +) from sqlalchemy.testing import eq_, is_ from sqlalchemy.testing.assertsql import RegexSQL, CompiledSQL, AllOf from sqlalchemy.testing import fixtures @@ -24,16 +30,24 @@ class SelfReferentialTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_c1', Integer, ForeignKey('t1.c1')), - Column('data', String(20))) - Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c1id', Integer, ForeignKey('t1.c1')), - Column('data', String(20))) + Table( + "t1", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_c1", Integer, ForeignKey("t1.c1")), + Column("data", String(20)), + ) + Table( + "t2", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c1id", Integer, ForeignKey("t1.c1")), + Column("data", String(20)), + ) @classmethod def setup_classes(cls): @@ -48,15 +62,22 @@ class SelfReferentialTest(fixtures.MappedTest): def test_single(self): C1, t1 = self.classes.C1, self.tables.t1 - mapper(C1, t1, properties={ - 'c1s': relationship(C1, cascade="all"), - 'parent': relationship(C1, - primaryjoin=t1.c.parent_c1 == t1.c.c1, - remote_side=t1.c.c1, - lazy='select', - uselist=False)}) - a = C1('head c1') - a.c1s.append(C1('another c1')) + mapper( + C1, + t1, + properties={ + "c1s": relationship(C1, cascade="all"), + "parent": relationship( + C1, + primaryjoin=t1.c.parent_c1 == t1.c.c1, + remote_side=t1.c.c1, + lazy="select", + uselist=False, + ), + }, + ) + a = C1("head c1") + a.c1s.append(C1("another c1")) sess = create_session() sess.add(a) @@ -75,10 +96,17 @@ 
class SelfReferentialTest(fixtures.MappedTest): C1, t1 = self.classes.C1, self.tables.t1 - mapper(C1, t1, properties={ - 'parent': relationship(C1, - primaryjoin=t1.c.parent_c1 == t1.c.c1, - remote_side=t1.c.c1)}) + mapper( + C1, + t1, + properties={ + "parent": relationship( + C1, + primaryjoin=t1.c.parent_c1 == t1.c.c1, + remote_side=t1.c.c1, + ) + }, + ) c1 = C1() @@ -94,22 +122,31 @@ class SelfReferentialTest(fixtures.MappedTest): assert c2.parent_c1 == c1.c1 def test_cycle(self): - C2, C1, t2, t1 = (self.classes.C2, - self.classes.C1, - self.tables.t2, - self.tables.t1) - - mapper(C1, t1, properties={ - 'c1s': relationship(C1, cascade="all"), - 'c2s': relationship(mapper(C2, t2), cascade="all, delete-orphan")}) - - a = C1('head c1') - a.c1s.append(C1('child1')) - a.c1s.append(C1('child2')) - a.c1s[0].c1s.append(C1('subchild1')) - a.c1s[0].c1s.append(C1('subchild2')) - a.c1s[1].c2s.append(C2('child2 data1')) - a.c1s[1].c2s.append(C2('child2 data2')) + C2, C1, t2, t1 = ( + self.classes.C2, + self.classes.C1, + self.tables.t2, + self.tables.t1, + ) + + mapper( + C1, + t1, + properties={ + "c1s": relationship(C1, cascade="all"), + "c2s": relationship( + mapper(C2, t2), cascade="all, delete-orphan" + ), + }, + ) + + a = C1("head c1") + a.c1s.append(C1("child1")) + a.c1s.append(C1("child2")) + a.c1s[0].c1s.append(C1("subchild1")) + a.c1s[0].c1s.append(C1("subchild2")) + a.c1s[1].c2s.append(C2("child2 data1")) + a.c1s[1].c2s.append(C2("child2 data2")) sess = create_session() sess.add(a) sess.flush() @@ -120,9 +157,7 @@ class SelfReferentialTest(fixtures.MappedTest): def test_setnull_ondelete(self): C1, t1 = self.classes.C1, self.tables.t1 - mapper(C1, t1, properties={ - 'children': relationship(C1) - }) + mapper(C1, t1, properties={"children": relationship(C1)}) sess = create_session() c1 = C1() @@ -146,12 +181,20 @@ class SelfReferentialNoPKTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('item', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('uuid', String(32), unique=True, nullable=False), - Column('parent_uuid', String(32), ForeignKey('item.uuid'), - nullable=True)) + Table( + "item", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("uuid", String(32), unique=True, nullable=False), + Column( + "parent_uuid", + String(32), + ForeignKey("item.uuid"), + nullable=True, + ), + ) @classmethod def setup_classes(cls): @@ -163,11 +206,17 @@ class SelfReferentialNoPKTest(fixtures.MappedTest): def setup_mappers(cls): item, TT = cls.tables.item, cls.classes.TT - mapper(TT, item, properties={ - 'children': relationship( - TT, - remote_side=[item.c.parent_uuid], - backref=backref('parent', remote_side=[item.c.uuid]))}) + mapper( + TT, + item, + properties={ + "children": relationship( + TT, + remote_side=[item.c.parent_uuid], + backref=backref("parent", remote_side=[item.c.uuid]), + ) + }, + ) def test_basic(self): TT = self.classes.TT @@ -202,21 +251,32 @@ class SelfReferentialNoPKTest(fixtures.MappedTest): class InheritTestOne(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("parent", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("parent_data", String(50)), - Column("type", String(10))) - - Table("child1", metadata, - Column("id", Integer, ForeignKey("parent.id"), primary_key=True), - Column("child1_data", String(50))) - - Table("child2", metadata, - Column("id", Integer, 
ForeignKey("parent.id"), primary_key=True), - Column("child1_id", Integer, ForeignKey("child1.id"), - nullable=False), - Column("child2_data", String(50))) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_data", String(50)), + Column("type", String(10)), + ) + + Table( + "child1", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + Column("child1_data", String(50)), + ) + + Table( + "child2", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + Column( + "child1_id", Integer, ForeignKey("child1.id"), nullable=False + ), + Column("child2_data", String(50)), + ) @classmethod def setup_classes(cls): @@ -231,19 +291,27 @@ class InheritTestOne(fixtures.MappedTest): @classmethod def setup_mappers(cls): - child1, child2, parent, Parent, Child1, Child2 = (cls.tables.child1, - cls.tables.child2, - cls.tables.parent, - cls.classes.Parent, - cls.classes.Child1, - cls.classes.Child2) + child1, child2, parent, Parent, Child1, Child2 = ( + cls.tables.child1, + cls.tables.child2, + cls.tables.parent, + cls.classes.Parent, + cls.classes.Child1, + cls.classes.Child2, + ) mapper(Parent, parent) mapper(Child1, child1, inherits=Parent) - mapper(Child2, child2, inherits=Parent, properties=dict( - child1=relationship( - Child1, - primaryjoin=child2.c.child1_id == child1.c.id))) + mapper( + Child2, + child2, + inherits=Parent, + properties=dict( + child1=relationship( + Child1, primaryjoin=child2.c.child1_id == child1.c.id + ) + ), + ) def test_many_to_one_only(self): """test similar to SelfReferentialTest.testmanytooneonly""" @@ -281,19 +349,29 @@ class InheritTestTwo(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('cid', Integer, ForeignKey('c.id'))) + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("cid", Integer, ForeignKey("c.id")), + ) - Table('b', metadata, - Column('id', Integer, ForeignKey("a.id"), primary_key=True)) + Table( + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) - Table('c', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('aid', Integer, - ForeignKey('a.id', name="foo"))) + Table( + "c", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("aid", Integer, ForeignKey("a.id", name="foo")), + ) @classmethod def setup_classes(cls): @@ -307,20 +385,30 @@ class InheritTestTwo(fixtures.MappedTest): pass def test_flush(self): - a, A, c, b, C, B = (self.tables.a, - self.classes.A, - self.tables.c, - self.tables.b, - self.classes.C, - self.classes.B) + a, A, c, b, C, B = ( + self.tables.a, + self.classes.A, + self.tables.c, + self.tables.b, + self.classes.C, + self.classes.B, + ) - mapper(A, a, properties={ - 'cs': relationship(C, primaryjoin=a.c.cid == c.c.id)}) + mapper( + A, + a, + properties={"cs": relationship(C, primaryjoin=a.c.cid == c.c.id)}, + ) mapper(B, b, inherits=A, inherit_condition=b.c.id == a.c.id) - mapper(C, c, properties={ - 'arel': relationship(A, primaryjoin=a.c.id == c.c.aid)}) + mapper( + C, + c, + properties={ + "arel": relationship(A, primaryjoin=a.c.id == c.c.aid) + }, + ) sess = create_session() bobj = B() @@ -331,27 +419,38 @@ class InheritTestTwo(fixtures.MappedTest): class 
BiDirectionalManyToOneTest(fixtures.MappedTest): - run_define_tables = 'each' + run_define_tables = "each" @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('t2id', Integer, ForeignKey('t2.id'))) - Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('t1id', Integer, - ForeignKey('t1.id', name="foo_fk"))) - Table('t3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), - Column('t2id', Integer, ForeignKey('t2.id'), nullable=False)) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("t2id", Integer, ForeignKey("t2.id")), + ) + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("t1id", Integer, ForeignKey("t1.id", name="foo_fk")), + ) + Table( + "t3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("t1id", Integer, ForeignKey("t1.id"), nullable=False), + Column("t2id", Integer, ForeignKey("t2.id"), nullable=False), + ) @classmethod def setup_classes(cls): @@ -366,25 +465,35 @@ class BiDirectionalManyToOneTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - t2, T2, T3, t1, t3, T1 = (cls.tables.t2, - cls.classes.T2, - cls.classes.T3, - cls.tables.t1, - cls.tables.t3, - cls.classes.T1) - - mapper(T1, t1, properties={ - 't2': relationship(T2, primaryjoin=t1.c.t2id == t2.c.id)}) - mapper(T2, t2, properties={ - 't1': relationship(T1, primaryjoin=t2.c.t1id == t1.c.id)}) - mapper(T3, t3, properties={ - 't1': relationship(T1), - 't2': relationship(T2)}) + t2, T2, T3, t1, t3, T1 = ( + cls.tables.t2, + cls.classes.T2, + cls.classes.T3, + cls.tables.t1, + cls.tables.t3, + cls.classes.T1, + ) + + mapper( + T1, + t1, + properties={ + "t2": relationship(T2, primaryjoin=t1.c.t2id == t2.c.id) + }, + ) + mapper( + T2, + t2, + properties={ + "t1": relationship(T1, primaryjoin=t2.c.t1id == t1.c.id) + }, + ) + mapper( + T3, t3, properties={"t1": relationship(T1), "t2": relationship(T2)} + ) def test_reflush(self): - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) o1 = T1() o1.t2 = T2() @@ -406,9 +515,7 @@ class BiDirectionalManyToOneTest(fixtures.MappedTest): def test_reflush_2(self): """A variant on test_reflush()""" - T2, T3, T1 = (self.classes.T2, - self.classes.T3, - self.classes.T1) + T2, T3, T1 = (self.classes.T2, self.classes.T3, self.classes.T1) o1 = T1() o1.t2 = T2() @@ -439,20 +546,27 @@ class BiDirectionalManyToOneTest(fixtures.MappedTest): class BiDirectionalOneToManyTest(fixtures.MappedTest): """tests two mappers with a one-to-many relationship to each other.""" - run_define_tables = 'each' + run_define_tables = "each" @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', Integer, ForeignKey('t2.c1'))) + Table( + "t1", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", Integer, ForeignKey("t2.c1")), + ) - Table('t2', metadata, 
- Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', Integer, - ForeignKey('t1.c1', name='t1c1_fk'))) + Table( + "t2", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", Integer, ForeignKey("t1.c1", name="t1c1_fk")), + ) @classmethod def setup_classes(cls): @@ -463,19 +577,31 @@ class BiDirectionalOneToManyTest(fixtures.MappedTest): pass def test_cycle(self): - C2, C1, t2, t1 = (self.classes.C2, - self.classes.C1, - self.tables.t2, - self.tables.t1) - - mapper(C2, t2, properties={ - 'c1s': relationship(C1, - primaryjoin=t2.c.c1 == t1.c.c2, - uselist=True)}) - mapper(C1, t1, properties={ - 'c2s': relationship(C2, - primaryjoin=t1.c.c1 == t2.c.c2, - uselist=True)}) + C2, C1, t2, t1 = ( + self.classes.C2, + self.classes.C1, + self.tables.t2, + self.tables.t1, + ) + + mapper( + C2, + t2, + properties={ + "c1s": relationship( + C1, primaryjoin=t2.c.c1 == t1.c.c2, uselist=True + ) + }, + ) + mapper( + C1, + t1, + properties={ + "c2s": relationship( + C2, primaryjoin=t1.c.c1 == t2.c.c2, uselist=True + ) + }, + ) a = C1() b = C2() @@ -495,29 +621,40 @@ class BiDirectionalOneToManyTest2(fixtures.MappedTest): """Two mappers with a one-to-many relationship to each other, with a second one-to-many on one of the mappers""" - run_define_tables = 'each' + run_define_tables = "each" @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', Integer, ForeignKey('t2.c1')), - test_needs_autoincrement=True) - - Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', Integer, - ForeignKey('t1.c1', name='t1c1_fq')), - test_needs_autoincrement=True) - - Table('t1_data', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('t1id', Integer, ForeignKey('t1.c1')), - Column('data', String(20)), - test_needs_autoincrement=True) + Table( + "t1", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", Integer, ForeignKey("t2.c1")), + test_needs_autoincrement=True, + ) + + Table( + "t2", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", Integer, ForeignKey("t1.c1", name="t1c1_fq")), + test_needs_autoincrement=True, + ) + + Table( + "t1_data", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("t1id", Integer, ForeignKey("t1.c1")), + Column("data", String(20)), + test_needs_autoincrement=True, + ) @classmethod def setup_classes(cls): @@ -532,27 +669,41 @@ class BiDirectionalOneToManyTest2(fixtures.MappedTest): @classmethod def setup_mappers(cls): - t2, t1, C1Data, t1_data, C2, C1 = (cls.tables.t2, - cls.tables.t1, - cls.classes.C1Data, - cls.tables.t1_data, - cls.classes.C2, - cls.classes.C1) - - mapper(C2, t2, properties={ - 'c1s': relationship(C1, - primaryjoin=t2.c.c1 == t1.c.c2, - uselist=True)}) - mapper(C1, t1, properties={ - 'c2s': relationship(C2, - primaryjoin=t1.c.c1 == t2.c.c2, - uselist=True), - 'data': relationship(mapper(C1Data, t1_data))}) + t2, t1, C1Data, t1_data, C2, C1 = ( + cls.tables.t2, + cls.tables.t1, + cls.classes.C1Data, + cls.tables.t1_data, + cls.classes.C2, + cls.classes.C1, + ) + + mapper( + C2, + t2, + properties={ + "c1s": relationship( + C1, primaryjoin=t2.c.c1 == t1.c.c2, uselist=True + ) + }, + ) + mapper( + C1, + t1, + properties={ + "c2s": 
relationship( + C2, primaryjoin=t1.c.c1 == t2.c.c2, uselist=True + ), + "data": relationship(mapper(C1Data, t1_data)), + }, + ) def test_cycle(self): - C2, C1, C1Data = (self.classes.C2, - self.classes.C1, - self.classes.C1Data) + C2, C1, C1Data = ( + self.classes.C2, + self.classes.C1, + self.classes.C1Data, + ) a = C1() b = C2() @@ -563,9 +714,9 @@ class BiDirectionalOneToManyTest2(fixtures.MappedTest): a.c2s.append(b) d.c1s.append(c) b.c1s.append(c) - a.data.append(C1Data(data='c1data1')) - a.data.append(C1Data(data='c1data2')) - c.data.append(C1Data(data='c1data3')) + a.data.append(C1Data(data="c1data1")) + a.data.append(C1Data(data="c1data2")) + c.data.append(C1Data(data="c1data3")) sess = create_session() sess.add_all((a, b, c, d, e, f)) sess.flush() @@ -585,22 +736,34 @@ class OneToManyManyToOneTest(fixtures.MappedTest): dependencies are sorted. """ - run_define_tables = 'each' + + run_define_tables = "each" @classmethod def define_tables(cls, metadata): - Table('ball', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('person_id', Integer, - ForeignKey('person.id', name='fk_person_id')), - Column('data', String(30))) - - Table('person', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('favorite_ball_id', Integer, ForeignKey('ball.id')), - Column('data', String(30))) + Table( + "ball", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "person_id", + Integer, + ForeignKey("person.id", name="fk_person_id"), + ), + Column("data", String(30)), + ) + + Table( + "person", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("favorite_ball_id", Integer, ForeignKey("ball.id")), + Column("data", String(30)), + ) @classmethod def setup_classes(cls): @@ -618,20 +781,30 @@ class OneToManyManyToOneTest(fixtures.MappedTest): """ - person, ball, Ball, Person = (self.tables.person, - self.tables.ball, - self.classes.Ball, - self.classes.Person) + person, ball, Ball, Person = ( + self.tables.person, + self.tables.ball, + self.classes.Ball, + self.classes.Person, + ) mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relationship(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id), - favorite=relationship( - Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=ball.c.id))) + mapper( + Person, + person, + properties=dict( + balls=relationship( + Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id, + ), + favorite=relationship( + Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=ball.c.id, + ), + ), + ) b = Ball() p = Person() @@ -641,18 +814,27 @@ class OneToManyManyToOneTest(fixtures.MappedTest): sess.flush() def test_post_update_m2o_no_cascade(self): - person, ball, Ball, Person = (self.tables.person, - self.tables.ball, - self.classes.Ball, - self.classes.Person) + person, ball, Ball, Person = ( + self.tables.person, + self.tables.ball, + self.classes.Ball, + self.classes.Person, + ) mapper(Ball, ball) - mapper(Person, person, properties=dict( - favorite=relationship( - Ball, primaryjoin=person.c.favorite_ball_id == ball.c.id, - post_update=True))) - b = Ball(data='some data') - p = Person(data='some data') + mapper( + Person, + person, + properties=dict( + favorite=relationship( + Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + post_update=True, + ) + ), + ) + b = 
Ball(data="some data") + p = Person(data="some data") p.favorite = b sess = create_session() sess.add(b) @@ -663,44 +845,54 @@ class OneToManyManyToOneTest(fixtures.MappedTest): self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx: { - 'favorite_ball_id': None, - 'person_id': p.id} - ), - CompiledSQL("DELETE FROM person WHERE person.id = :id", - lambda ctx: {'id': p.id} - ), + CompiledSQL( + "UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + lambda ctx: {"favorite_ball_id": None, "person_id": p.id}, + ), + CompiledSQL( + "DELETE FROM person WHERE person.id = :id", + lambda ctx: {"id": p.id}, + ), ) def test_post_update_m2o(self): """A cycle between two rows, with a post_update on the many-to-one""" - person, ball, Ball, Person = (self.tables.person, - self.tables.ball, - self.classes.Ball, - self.classes.Person) + person, ball, Ball, Person = ( + self.tables.person, + self.tables.ball, + self.classes.Ball, + self.classes.Person, + ) mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relationship(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id, - post_update=False, - cascade="all, delete-orphan"), - favorite=relationship( - Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=person.c.favorite_ball_id, - post_update=True))) - - b = Ball(data='some data') - p = Person(data='some data') + mapper( + Person, + person, + properties=dict( + balls=relationship( + Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id, + post_update=False, + cascade="all, delete-orphan", + ), + favorite=relationship( + Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=person.c.favorite_ball_id, + post_update=True, + ), + ), + ) + + b = Ball(data="some data") + p = Person(data="some data") p.balls.append(b) - p.balls.append(Ball(data='some data')) - p.balls.append(Ball(data='some data')) - p.balls.append(Ball(data='some data')) + p.balls.append(Ball(data="some data")) + p.balls.append(Ball(data="some data")) + p.balls.append(Ball(data="some data")) p.favorite = b sess = create_session() sess.add(b) @@ -709,21 +901,31 @@ class OneToManyManyToOneTest(fixtures.MappedTest): self.assert_sql_execution( testing.db, sess.flush, - RegexSQL("^INSERT INTO person", {'data': 'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: { - 'person_id': p.id, 'data': 'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: { - 'person_id': p.id, 'data': 'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: { - 'person_id': p.id, 'data': 'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: { - 'person_id': p.id, 'data': 'some data'}), - CompiledSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx: { - 'favorite_ball_id': p.favorite.id, - 'person_id': p.id} - ), + RegexSQL("^INSERT INTO person", {"data": "some data"}), + RegexSQL( + "^INSERT INTO ball", + lambda c: {"person_id": p.id, "data": "some data"}, + ), + RegexSQL( + "^INSERT INTO ball", + lambda c: {"person_id": p.id, "data": "some data"}, + ), + RegexSQL( + "^INSERT INTO ball", + lambda c: {"person_id": p.id, "data": "some data"}, + ), + RegexSQL( + "^INSERT INTO ball", + lambda c: {"person_id": p.id, "data": "some data"}, + ), + CompiledSQL( + "UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + 
lambda ctx: { + "favorite_ball_id": p.favorite.id, + "person_id": p.id, + }, + ), ) sess.delete(p) @@ -731,43 +933,55 @@ class OneToManyManyToOneTest(fixtures.MappedTest): self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx: {'person_id': p.id, - 'favorite_ball_id': None}), + CompiledSQL( + "UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + lambda ctx: {"person_id": p.id, "favorite_ball_id": None}, + ), # lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]) CompiledSQL("DELETE FROM ball WHERE ball.id = :id", None), - CompiledSQL("DELETE FROM person WHERE person.id = :id", - lambda ctx: [{'id': p.id}]) + CompiledSQL( + "DELETE FROM person WHERE person.id = :id", + lambda ctx: [{"id": p.id}], + ), ) def test_post_update_backref(self): """test bidirectional post_update.""" - person, ball, Ball, Person = (self.tables.person, - self.tables.ball, - self.classes.Ball, - self.classes.Person) + person, ball, Ball, Person = ( + self.tables.person, + self.tables.ball, + self.classes.Ball, + self.classes.Person, + ) mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relationship(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id, post_update=True, - backref=backref('person', post_update=True) - ), - favorite=relationship( - Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=person.c.favorite_ball_id) - )) + mapper( + Person, + person, + properties=dict( + balls=relationship( + Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id, + post_update=True, + backref=backref("person", post_update=True), + ), + favorite=relationship( + Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=person.c.favorite_ball_id, + ), + ), + ) sess = sessionmaker()() - p1 = Person(data='p1') - p2 = Person(data='p2') - p3 = Person(data='p3') + p1 = Person(data="p1") + p2 = Person(data="p2") + p3 = Person(data="p3") - b1 = Ball(data='b1') + b1 = Ball(data="b1") b1.person = p1 sess.add_all([p1, p2, p3]) @@ -778,46 +992,52 @@ class OneToManyManyToOneTest(fixtures.MappedTest): # by the fact that there's a "reverse" prop. 
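# ---------------------------------------------------------------------
# [Editor's aside; an illustrative sketch, not part of the reformat
# commit above.] The surrounding tests exercise
# relationship(post_update=True). The names below (Person, Ball) mirror
# the fixtures but the schema is simplified and hypothetical: the
# favorite_ball_id column deliberately carries no FOREIGN KEY constraint
# so the sketch runs on SQLite without use_alter, and the join is
# spelled out with foreign() instead. With a row-level cycle (each row
# referencing the other), no single INSERT order can satisfy both
# links, so post_update makes the unit of work INSERT both rows first
# and close the loop with a separate UPDATE.
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship

Base = declarative_base()


class Person(Base):
    __tablename__ = "person_sketch"
    id = Column(Integer, primary_key=True)
    data = Column(String(30))
    favorite_ball_id = Column(Integer)  # no FK constraint; see note above
    favorite = relationship(
        "Ball",
        primaryjoin="foreign(Person.favorite_ball_id) == Ball.id",
        post_update=True,  # write this column via a second UPDATE
    )


class Ball(Base):
    __tablename__ = "ball_sketch"
    id = Column(Integer, primary_key=True)
    person_id = Column(Integer, ForeignKey("person_sketch.id"))
    person = relationship(Person, backref="balls")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)

p, b = Person(data="p1"), Ball()
b.person = p
p.favorite = b  # p and b now reference each other
session.add_all([p, b])
session.flush()  # two INSERTs, then UPDATE person_sketch SET favorite_ball_id
# ---------------------------------------------------------------------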
b1.person = p2 sess.commit() - eq_( - p2, b1.person - ) + eq_(p2, b1.person) # do it the other way p3.balls.append(b1) sess.commit() - eq_( - p3, b1.person - ) + eq_(p3, b1.person) def test_post_update_o2m(self): """A cycle between two rows, with a post_update on the one-to-many""" - person, ball, Ball, Person = (self.tables.person, - self.tables.ball, - self.classes.Ball, - self.classes.Person) + person, ball, Ball, Person = ( + self.tables.person, + self.tables.ball, + self.classes.Ball, + self.classes.Person, + ) mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relationship(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id, - cascade="all, delete-orphan", - post_update=True, - backref='person'), - favorite=relationship( - Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=person.c.favorite_ball_id))) - - b = Ball(data='some data') - p = Person(data='some data') + mapper( + Person, + person, + properties=dict( + balls=relationship( + Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id, + cascade="all, delete-orphan", + post_update=True, + backref="person", + ), + favorite=relationship( + Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=person.c.favorite_ball_id, + ), + ), + ) + + b = Ball(data="some data") + p = Person(data="some data") p.balls.append(b) - b2 = Ball(data='some data') + b2 = Ball(data="some data") p.balls.append(b2) - b3 = Ball(data='some data') + b3 = Ball(data="some data") p.balls.append(b3) - b4 = Ball(data='some data') + b4 = Ball(data="some data") p.balls.append(b4) p.favorite = b sess = create_session() @@ -826,79 +1046,92 @@ class OneToManyManyToOneTest(fixtures.MappedTest): self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id': None, 'data': 'some data'}), - - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id': None, 'data': 'some data'}), - - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id': None, 'data': 'some data'}), - - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id': None, 'data': 'some data'}), - - CompiledSQL("INSERT INTO person (favorite_ball_id, data) " - "VALUES (:favorite_ball_id, :data)", - lambda ctx: {'favorite_ball_id': b.id, - 'data': 'some data'}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx: [ - {'person_id': p.id, 'ball_id': b.id}, - {'person_id': p.id, 'ball_id': b2.id}, - {'person_id': p.id, 'ball_id': b3.id}, - {'person_id': p.id, 'ball_id': b4.id} - ]), + CompiledSQL( + "INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {"person_id": None, "data": "some data"}, + ), + CompiledSQL( + "INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {"person_id": None, "data": "some data"}, + ), + CompiledSQL( + "INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {"person_id": None, "data": "some data"}, + ), + CompiledSQL( + "INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {"person_id": None, "data": "some data"}, + ), + CompiledSQL( + "INSERT INTO person (favorite_ball_id, data) " + "VALUES (:favorite_ball_id, :data)", + lambda ctx: {"favorite_ball_id": b.id, "data": "some data"}, + ), + CompiledSQL( + "UPDATE ball SET person_id=:person_id " + "WHERE 
ball.id = :ball_id", + lambda ctx: [ + {"person_id": p.id, "ball_id": b.id}, + {"person_id": p.id, "ball_id": b2.id}, + {"person_id": p.id, "ball_id": b3.id}, + {"person_id": p.id, "ball_id": b4.id}, + ], + ), ) sess.delete(p) - self.assert_sql_execution(testing.db, sess.flush, - CompiledSQL( - "UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx: [ - {'person_id': None, - 'ball_id': b.id}, - {'person_id': None, - 'ball_id': b2.id}, - {'person_id': None, - 'ball_id': b3.id}, - {'person_id': None, - 'ball_id': b4.id} - ] - ), - CompiledSQL( - "DELETE FROM person " - "WHERE person.id = :id", - lambda ctx: [{'id': p.id}]), - - CompiledSQL( - "DELETE FROM ball WHERE ball.id = :id", - lambda ctx: [{'id': b.id}, - {'id': b2.id}, - {'id': b3.id}, - {'id': b4.id}]) - ) + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL( + "UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx: [ + {"person_id": None, "ball_id": b.id}, + {"person_id": None, "ball_id": b2.id}, + {"person_id": None, "ball_id": b3.id}, + {"person_id": None, "ball_id": b4.id}, + ], + ), + CompiledSQL( + "DELETE FROM person " "WHERE person.id = :id", + lambda ctx: [{"id": p.id}], + ), + CompiledSQL( + "DELETE FROM ball WHERE ball.id = :id", + lambda ctx: [ + {"id": b.id}, + {"id": b2.id}, + {"id": b3.id}, + {"id": b4.id}, + ], + ), + ) def test_post_update_m2o_detect_none(self): person, ball, Ball, Person = ( self.tables.person, self.tables.ball, self.classes.Ball, - self.classes.Person) + self.classes.Person, + ) - mapper(Ball, ball, properties={ - 'person': relationship( - Person, post_update=True, - primaryjoin=person.c.id == ball.c.person_id) - }) + mapper( + Ball, + ball, + properties={ + "person": relationship( + Person, + post_update=True, + primaryjoin=person.c.id == ball.c.person_id, + ) + }, + ) mapper(Person, person) sess = create_session(autocommit=False, expire_on_commit=True) @@ -907,7 +1140,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): b1 = sess.query(Ball).first() # needs to be unloaded - assert 'person' not in b1.__dict__ + assert "person" not in b1.__dict__ b1.person = None self.assert_sql_execution( @@ -916,7 +1149,8 @@ class OneToManyManyToOneTest(fixtures.MappedTest): CompiledSQL( "UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", - lambda ctx: {'person_id': None, 'ball_id': b1.id}) + lambda ctx: {"person_id": None, "ball_id": b1.id}, + ), ) is_(b1.person, None) @@ -930,21 +1164,32 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('node', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('path', String(50), nullable=False), - Column('parent_id', Integer, - ForeignKey('node.id'), nullable=True), - Column('prev_sibling_id', Integer, - ForeignKey('node.id'), nullable=True), - Column('next_sibling_id', Integer, - ForeignKey('node.id'), nullable=True)) + Table( + "node", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("path", String(50), nullable=False), + Column("parent_id", Integer, ForeignKey("node.id"), nullable=True), + Column( + "prev_sibling_id", + Integer, + ForeignKey("node.id"), + nullable=True, + ), + Column( + "next_sibling_id", + Integer, + ForeignKey("node.id"), + nullable=True, + ), + ) @classmethod def setup_classes(cls): class Node(cls.Basic): - def __init__(self, path=''): + def __init__(self, path=""): self.path = path def test_one(self): 
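# ---------------------------------------------------------------------
# [Editor's aside; an illustrative sketch, not part of the reformat
# commit above.] SelfReferentialPostUpdateTest wires a doubly linked
# sibling list into a tree and needs post_update so that sibling
# pointers are written by UPDATE statements after the row INSERTs and
# DELETEs. The degenerate case, in the spirit of
# SelfReferentialPostUpdateTest2 further down, is a row that points at
# itself: it can only be persisted as INSERT-then-UPDATE. Hypothetical
# names; this is not the suite's fixture code.
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship

Base = declarative_base()


class Node(Base):
    __tablename__ = "node_sketch"
    id = Column(Integer, primary_key=True)
    path = Column(String(50))
    next_sibling_id = Column(Integer, ForeignKey("node_sketch.id"))
    # remote_side marks the id column as the "remote" side, making this
    # many-to-one; post_update defers writing next_sibling_id to a
    # second UPDATE so the flush has no circular dependency.
    next_sibling = relationship("Node", remote_side=[id], post_update=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)

n = Node(path="root")
n.next_sibling = n  # the row references itself: impossible in one INSERT
session.add(n)
session.flush()  # INSERT with next_sibling_id=NULL, then UPDATE
# ---------------------------------------------------------------------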
@@ -957,24 +1202,31 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): node, Node = self.tables.node, self.classes.Node - mapper(Node, node, properties={ - 'children': relationship( - Node, - primaryjoin=node.c.id == node.c.parent_id, - cascade="all", - backref=backref("parent", remote_side=node.c.id) - ), - 'prev_sibling': relationship( - Node, - primaryjoin=node.c.prev_sibling_id == node.c.id, - remote_side=node.c.id, - uselist=False), - 'next_sibling': relationship( - Node, - primaryjoin=node.c.next_sibling_id == node.c.id, - remote_side=node.c.id, - uselist=False, - post_update=True)}) + mapper( + Node, + node, + properties={ + "children": relationship( + Node, + primaryjoin=node.c.id == node.c.parent_id, + cascade="all", + backref=backref("parent", remote_side=node.c.id), + ), + "prev_sibling": relationship( + Node, + primaryjoin=node.c.prev_sibling_id == node.c.id, + remote_side=node.c.id, + uselist=False, + ), + "next_sibling": relationship( + Node, + primaryjoin=node.c.next_sibling_id == node.c.id, + remote_side=node.c.id, + uselist=False, + post_update=True, + ), + }, + ) session = create_session() @@ -990,20 +1242,21 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): node.prev_sibling = child.prev_sibling child.prev_sibling.next_sibling = node session.delete(child) - root = Node('root') - about = Node('about') - cats = Node('cats') - stories = Node('stories') - bruce = Node('bruce') + root = Node("root") + + about = Node("about") + cats = Node("cats") + stories = Node("stories") + bruce = Node("bruce") append_child(root, about) - assert(about.prev_sibling is None) + assert about.prev_sibling is None append_child(root, cats) - assert(cats.prev_sibling is about) - assert(cats.next_sibling is None) - assert(about.next_sibling is cats) - assert(about.prev_sibling is None) + assert cats.prev_sibling is about + assert cats.next_sibling is None + assert about.next_sibling is cats + assert about.prev_sibling is None append_child(root, stories) append_child(root, bruce) session.add(root) @@ -1017,24 +1270,32 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): testing.db, session.flush, AllOf( - CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id " - "WHERE node.id = :node_id", - lambda ctx: {'prev_sibling_id': about.id, - 'node_id': stories.id}), - - CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " - "WHERE node.id = :node_id", - lambda ctx: {'next_sibling_id': stories.id, - 'node_id': about.id}), - - CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " - "WHERE node.id = :node_id", - lambda ctx: {'next_sibling_id': None, - 'node_id': cats.id}), + CompiledSQL( + "UPDATE node SET prev_sibling_id=:prev_sibling_id " + "WHERE node.id = :node_id", + lambda ctx: { + "prev_sibling_id": about.id, + "node_id": stories.id, + }, + ), + CompiledSQL( + "UPDATE node SET next_sibling_id=:next_sibling_id " + "WHERE node.id = :node_id", + lambda ctx: { + "next_sibling_id": stories.id, + "node_id": about.id, + }, + ), + CompiledSQL( + "UPDATE node SET next_sibling_id=:next_sibling_id " + "WHERE node.id = :node_id", + lambda ctx: {"next_sibling_id": None, "node_id": cats.id}, + ), + ), + CompiledSQL( + "DELETE FROM node WHERE node.id = :id", + lambda ctx: [{"id": cats.id}], ), - - CompiledSQL("DELETE FROM node WHERE node.id = :id", - lambda ctx: [{'id': cats.id}]) ) session.delete(root) @@ -1042,30 +1303,35 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): self.assert_sql_execution( testing.db, session.flush, - CompiledSQL("UPDATE 
node SET next_sibling_id=:next_sibling_id " - "WHERE node.id = :node_id", - lambda ctx: [ - {'node_id': about.id, 'next_sibling_id': None}, - {'node_id': stories.id, 'next_sibling_id': None} - ] - ), + CompiledSQL( + "UPDATE node SET next_sibling_id=:next_sibling_id " + "WHERE node.id = :node_id", + lambda ctx: [ + {"node_id": about.id, "next_sibling_id": None}, + {"node_id": stories.id, "next_sibling_id": None}, + ], + ), AllOf( - CompiledSQL("DELETE FROM node WHERE node.id = :id", - lambda ctx: {'id': about.id} - ), - CompiledSQL("DELETE FROM node WHERE node.id = :id", - lambda ctx: {'id': stories.id} - ), - CompiledSQL("DELETE FROM node WHERE node.id = :id", - lambda ctx: {'id': bruce.id} - ), - ), - CompiledSQL("DELETE FROM node WHERE node.id = :id", - lambda ctx: {'id': root.id} - ), - ) - about = Node('about') - cats = Node('cats') + CompiledSQL( + "DELETE FROM node WHERE node.id = :id", + lambda ctx: {"id": about.id}, + ), + CompiledSQL( + "DELETE FROM node WHERE node.id = :id", + lambda ctx: {"id": stories.id}, + ), + CompiledSQL( + "DELETE FROM node WHERE node.id = :id", + lambda ctx: {"id": bruce.id}, + ), + ), + CompiledSQL( + "DELETE FROM node WHERE node.id = :id", + lambda ctx: {"id": root.id}, + ), + ) + about = Node("about") + cats = Node("cats") about.next_sibling = cats cats.prev_sibling = about session.add(about) @@ -1076,14 +1342,20 @@ class SelfReferentialPostUpdateTest(fixtures.MappedTest): class SelfReferentialPostUpdateTest2(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table("a_table", metadata, - Column("id", Integer(), primary_key=True, - test_needs_autoincrement=True), - Column("fui", String(128)), - Column("b", Integer(), ForeignKey("a_table.id"))) + Table( + "a_table", + metadata, + Column( + "id", + Integer(), + primary_key=True, + test_needs_autoincrement=True, + ), + Column("fui", String(128)), + Column("b", Integer(), ForeignKey("a_table.id")), + ) @classmethod def setup_classes(cls): @@ -1100,10 +1372,15 @@ class SelfReferentialPostUpdateTest2(fixtures.MappedTest): A, a_table = self.classes.A, self.tables.a_table - mapper(A, a_table, properties={ - 'foo': relationship(A, - remote_side=[a_table.c.id], - post_update=True)}) + mapper( + A, + a_table, + properties={ + "foo": relationship( + A, remote_side=[a_table.c.id], post_update=True + ) + }, + ) session = create_session() @@ -1127,54 +1404,76 @@ class SelfReferentialPostUpdateTest2(fixtures.MappedTest): class SelfReferentialPostUpdateTest3(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50), nullable=False), - Column('child_id', Integer, - ForeignKey('child.id', name='c1'), nullable=True)) - - Table('child', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50), nullable=False), - Column('child_id', Integer, - ForeignKey('child.id')), - Column('parent_id', Integer, - ForeignKey('parent.id'), nullable=True)) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50), nullable=False), + Column( + "child_id", + Integer, + ForeignKey("child.id", name="c1"), + nullable=True, + ), + ) + + Table( + "child", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50), nullable=False), + Column("child_id", Integer, ForeignKey("child.id")), + 
Column( + "parent_id", Integer, ForeignKey("parent.id"), nullable=True + ), + ) @classmethod def setup_classes(cls): class Parent(cls.Basic): - def __init__(self, name=''): + def __init__(self, name=""): self.name = name class Child(cls.Basic): - def __init__(self, name=''): + def __init__(self, name=""): self.name = name def test_one(self): - Child, Parent, parent, child = (self.classes.Child, - self.classes.Parent, - self.tables.parent, - self.tables.child) - - mapper(Parent, parent, properties={ - 'children': relationship( - Child, - primaryjoin=parent.c.id == child.c.parent_id), - 'child': relationship( - Child, - primaryjoin=parent.c.child_id == child.c.id, post_update=True) - }) - mapper(Child, child, properties={ - 'parent': relationship(Child, remote_side=child.c.id) - }) + Child, Parent, parent, child = ( + self.classes.Child, + self.classes.Parent, + self.tables.parent, + self.tables.child, + ) + + mapper( + Parent, + parent, + properties={ + "children": relationship( + Child, primaryjoin=parent.c.id == child.c.parent_id + ), + "child": relationship( + Child, + primaryjoin=parent.c.child_id == child.c.id, + post_update=True, + ), + }, + ) + mapper( + Child, + child, + properties={"parent": relationship(Child, remote_side=child.c.id)}, + ) session = create_session() - p1 = Parent('p1') - c1 = Child('c1') - c2 = Child('c2') + p1 = Parent("p1") + c1 = Child("c1") + c2 = Child("c2") p1.children = [c1, c2] c2.parent = c1 p1.child = c2 @@ -1182,8 +1481,8 @@ class SelfReferentialPostUpdateTest3(fixtures.MappedTest): session.add_all([p1, c1, c2]) session.flush() - p2 = Parent('p2') - c3 = Child('c3') + p2 = Parent("p2") + c3 = Child("c3") p2.children = [c3] p2.child = c3 session.add(p2) @@ -1203,55 +1502,85 @@ class PostUpdateBatchingTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50), nullable=False), - Column('c1_id', Integer, - ForeignKey('child1.id', name='c1'), nullable=True), - Column('c2_id', Integer, - ForeignKey('child2.id', name='c2'), nullable=True), - Column('c3_id', Integer, - ForeignKey('child3.id', name='c3'), nullable=True) - ) - - Table('child1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50), nullable=False), - Column('parent_id', Integer, - ForeignKey('parent.id'), nullable=False)) - - Table('child2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50), nullable=False), - Column('parent_id', Integer, - ForeignKey('parent.id'), nullable=False)) - - Table('child3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50), nullable=False), - Column('parent_id', Integer, - ForeignKey('parent.id'), nullable=False)) + Table( + "parent", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50), nullable=False), + Column( + "c1_id", + Integer, + ForeignKey("child1.id", name="c1"), + nullable=True, + ), + Column( + "c2_id", + Integer, + ForeignKey("child2.id", name="c2"), + nullable=True, + ), + Column( + "c3_id", + Integer, + ForeignKey("child3.id", name="c3"), + nullable=True, + ), + ) + + Table( + "child1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50), nullable=False), + Column( + "parent_id", Integer, 
ForeignKey("parent.id"), nullable=False + ), + ) + + Table( + "child2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50), nullable=False), + Column( + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) + + Table( + "child3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50), nullable=False), + Column( + "parent_id", Integer, ForeignKey("parent.id"), nullable=False + ), + ) @classmethod def setup_classes(cls): class Parent(cls.Basic): - def __init__(self, name=''): + def __init__(self, name=""): self.name = name class Child1(cls.Basic): - def __init__(self, name=''): + def __init__(self, name=""): self.name = name class Child2(cls.Basic): - def __init__(self, name=''): + def __init__(self, name=""): self.name = name class Child3(cls.Basic): - def __init__(self, name=''): + def __init__(self, name=""): self.name = name def test_one(self): @@ -1263,38 +1592,49 @@ class PostUpdateBatchingTest(fixtures.MappedTest): self.tables.parent, self.classes.Child1, self.classes.Child2, - self.classes.Child3) - - mapper(Parent, parent, properties={ - 'c1s': relationship( - Child1, - primaryjoin=child1.c.parent_id == parent.c.id), - 'c2s': relationship( - Child2, - primaryjoin=child2.c.parent_id == parent.c.id), - 'c3s': relationship( - Child3, primaryjoin=child3.c.parent_id == parent.c.id), - - 'c1': relationship( - Child1, - primaryjoin=child1.c.id == parent.c.c1_id, post_update=True), - 'c2': relationship( - Child2, - primaryjoin=child2.c.id == parent.c.c2_id, post_update=True), - 'c3': relationship( - Child3, - primaryjoin=child3.c.id == parent.c.c3_id, post_update=True), - }) + self.classes.Child3, + ) + + mapper( + Parent, + parent, + properties={ + "c1s": relationship( + Child1, primaryjoin=child1.c.parent_id == parent.c.id + ), + "c2s": relationship( + Child2, primaryjoin=child2.c.parent_id == parent.c.id + ), + "c3s": relationship( + Child3, primaryjoin=child3.c.parent_id == parent.c.id + ), + "c1": relationship( + Child1, + primaryjoin=child1.c.id == parent.c.c1_id, + post_update=True, + ), + "c2": relationship( + Child2, + primaryjoin=child2.c.id == parent.c.c2_id, + post_update=True, + ), + "c3": relationship( + Child3, + primaryjoin=child3.c.id == parent.c.c3_id, + post_update=True, + ), + }, + ) mapper(Child1, child1) mapper(Child2, child2) mapper(Child3, child3) sess = create_session() - p1 = Parent('p1') - c11, c12, c13 = Child1('c1'), Child1('c2'), Child1('c3') - c21, c22, c23 = Child2('c1'), Child2('c2'), Child2('c3') - c31, c32, c33 = Child3('c1'), Child3('c2'), Child3('c3') + p1 = Parent("p1") + c11, c12, c13 = Child1("c1"), Child1("c2"), Child1("c3") + c21, c22, c23 = Child2("c1"), Child2("c2"), Child2("c3") + c31, c32, c33 = Child3("c1"), Child3("c2"), Child3("c3") p1.c1s = [c11, c12, c13] p1.c2s = [c21, c22, c23] @@ -1312,9 +1652,13 @@ class PostUpdateBatchingTest(fixtures.MappedTest): CompiledSQL( "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id " "WHERE parent.id = :parent_id", - lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, - 'c1_id': c12.id, 'c3_id': c31.id} - ) + lambda ctx: { + "c2_id": c23.id, + "parent_id": p1.id, + "c1_id": c12.id, + "c3_id": c31.id, + }, + ), ) p1.c1 = p1.c2 = p1.c3 = None @@ -1325,41 +1669,45 @@ class PostUpdateBatchingTest(fixtures.MappedTest): CompiledSQL( "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id " "WHERE parent.id = :parent_id", - lambda ctx: {'c2_id': None, 
'parent_id': p1.id, - 'c1_id': None, 'c3_id': None} - ) + lambda ctx: { + "c2_id": None, + "parent_id": p1.id, + "c1_id": None, + "c3_id": None, + }, + ), ) from sqlalchemy import bindparam -class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): +class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) data = Column(Integer) - favorite_b_id = Column(ForeignKey('b.id', name="favorite_b_fk")) + favorite_b_id = Column(ForeignKey("b.id", name="favorite_b_fk")) bs = relationship("B", primaryjoin="A.id == B.a_id") favorite_b = relationship( - "B", primaryjoin="A.favorite_b_id == B.id", post_update=True) + "B", primaryjoin="A.favorite_b_id == B.id", post_update=True + ) updated = Column(Integer, onupdate=lambda: next(cls.counter)) updated_db = Column( Integer, onupdate=bindparam( - key='foo', - callable_=lambda: next(cls.db_counter) - ) + key="foo", callable_=lambda: next(cls.db_counter) + ), ) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id', name="a_fk")) + a_id = Column(ForeignKey("a.id", name="a_fk")) def setup(self): super(PostUpdateOnUpdateTest, self).setup() @@ -1402,9 +1750,9 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): eq_( canary.mock_calls, [ - mock.call.refresh_flush(a1, mock.ANY, ['updated']), - mock.call.expire(a1, ['updated_db']), - ] + mock.call.refresh_flush(a1, mock.ANY, ["updated"]), + mock.call.expire(a1, ["updated_db"]), + ], ) def test_update_defaults_refresh_flush_event_no_postupdate(self): @@ -1435,9 +1783,9 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): eq_( canary.mock_calls, [ - mock.call.refresh_flush(a1, mock.ANY, ['updated']), - mock.call.expire(a1, ['updated_db']), - ] + mock.call.refresh_flush(a1, mock.ANY, ["updated"]), + mock.call.expire(a1, ["updated_db"]), + ], ) def test_update_defaults_dont_expire_on_delete(self): @@ -1459,9 +1807,9 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): eq_( canary.mock_calls, [ - mock.call.refresh_flush(a1, mock.ANY, ['updated']), - mock.call.expire(a1, ['updated_db']), - ] + mock.call.refresh_flush(a1, mock.ANY, ["updated"]), + mock.call.expire(a1, ["updated_db"]), + ], ) # ensure that we load this value here, we want to see that it @@ -1483,11 +1831,10 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): canary.mock_calls, [ # previous flush - mock.call.refresh_flush(a1, mock.ANY, ['updated']), - mock.call.expire(a1, ['updated_db']), - + mock.call.refresh_flush(a1, mock.ANY, ["updated"]), + mock.call.expire(a1, ["updated_db"]), # nothing happened - ] + ], ) eq_(next(self.counter), 2) @@ -1520,9 +1867,9 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): eq_( canary.mock_calls, [ - mock.call.refresh_flush(a1, mock.ANY, ['updated']), - mock.call.expire(a1, ['updated_db']), - ] + mock.call.refresh_flush(a1, mock.ANY, ["updated"]), + mock.call.expire(a1, ["updated_db"]), + ], ) # ensure that we load this value here, we want to see that it @@ -1544,12 +1891,10 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): canary.mock_calls, [ # previous flush - mock.call.refresh_flush(a1, mock.ANY, ['updated']), - mock.call.expire(a1, ['updated_db']), - - + mock.call.refresh_flush(a1, mock.ANY, ["updated"]), + mock.call.expire(a1, ["updated_db"]), # nothing called for this flush - ] + ], ) def 
test_update_defaults_can_set_value(self): @@ -1571,4 +1916,3 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): eq_(a1.updated, 5) eq_(a1.updated_db, 7) - diff --git a/test/orm/test_default_strategies.py b/test/orm/test_default_strategies.py index 653665bcb6..82ba7c4bb2 100644 --- a/test/orm/test_default_strategies.py +++ b/test/orm/test_default_strategies.py @@ -7,7 +7,6 @@ from sqlalchemy.testing import eq_, assert_raises_message class DefaultStrategyOptionsTest(_fixtures.FixtureTest): - def _assert_fully_loaded(self, users): # verify everything loaded, with no additional sql needed def go(): @@ -19,8 +18,13 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # a list for any([...]) instead of any(...) to prove we've # iterated all the items with no sql. f = util.flatten_iterator - assert any([i.keywords for i in - f([o.items for o in f([u.orders for u in users])])]) + assert any( + [ + i.keywords + for i in f([o.items for o in f([u.orders for u in users])]) + ] + ) + self.assert_sql_count(testing.db, go, 0) def _assert_addresses_loaded(self, users): @@ -28,65 +32,126 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): def go(): for u, static in zip(users, self.static.user_all_result): eq_(u.addresses, static.addresses) + self.assert_sql_count(testing.db, go, 0) def _downgrade_fixture(self): - users, Keyword, items, order_items, orders, Item, User, \ - Address, keywords, item_keywords, Order, addresses = \ - self.tables.users, self.classes.Keyword, self.tables.items, \ - self.tables.order_items, self.tables.orders, \ - self.classes.Item, self.classes.User, self.classes.Address, \ - self.tables.keywords, self.tables.item_keywords, \ - self.classes.Order, self.tables.addresses + users, Keyword, items, order_items, orders, Item, User, Address, keywords, item_keywords, Order, addresses = ( + self.tables.users, + self.classes.Keyword, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.keywords, + self.tables.item_keywords, + self.classes.Order, + self.tables.addresses, + ) mapper(Address, addresses) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='subquery', - order_by=item_keywords.c.keyword_id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="subquery", + order_by=item_keywords.c.keyword_id, + ) + ), + ) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy='subquery', - order_by=order_items.c.item_id))) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy="subquery", + order_by=order_items.c.item_id, + ) + ), + ) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='joined', - order_by=addresses.c.id), - orders=relationship(Order, lazy='joined', - order_by=orders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="joined", order_by=addresses.c.id + ), + orders=relationship( + Order, lazy="joined", order_by=orders.c.id + ), + ), + ) return create_session() def _upgrade_fixture(self): - users, Keyword, items, order_items, orders, Item, User, \ - Address, keywords, item_keywords, Order, addresses = \ - self.tables.users, self.classes.Keyword, self.tables.items, \ - self.tables.order_items, self.tables.orders, \ - self.classes.Item, self.classes.User, 
self.classes.Address, \ - self.tables.keywords, self.tables.item_keywords, \ - self.classes.Order, self.tables.addresses + users, Keyword, items, order_items, orders, Item, User, Address, keywords, item_keywords, Order, addresses = ( + self.tables.users, + self.classes.Keyword, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.keywords, + self.tables.item_keywords, + self.classes.Order, + self.tables.addresses, + ) mapper(Address, addresses) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='select', - order_by=item_keywords.c.keyword_id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="select", + order_by=item_keywords.c.keyword_id, + ) + ), + ) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy=True, - order_by=order_items.c.item_id))) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy=True, + order_by=order_items.c.item_id, + ) + ), + ) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy=True, - order_by=addresses.c.id), - orders=relationship(Order, - order_by=orders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy=True, order_by=addresses.c.id + ), + orders=relationship(Order, order_by=orders.c.id), + ), + ) return create_session() @@ -99,9 +164,12 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test _downgrade_fixture mapper defaults, 3 queries (2 subquery # loads). def go(): - users[:] = sess.query(self.classes.User)\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 3) # all loaded with no additional sql @@ -120,16 +188,20 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # demonstrate that enable_eagerloads loads with only 1 sql def go(): - users[:] = sess.query(self.classes.User)\ - .enable_eagerloads(False)\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .enable_eagerloads(False) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 1) # demonstrate that users[0].orders must now be loaded with 3 sql # (need to lazyload, and 2 subquery: 3 total) def go(): users[0].orders + self.assert_sql_count(testing.db, go, 3) def test_last_one_wins(self): @@ -137,12 +209,15 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): users = [] def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.subqueryload('*'))\ - .options(sa.orm.joinedload(self.classes.User.addresses))\ - .options(sa.orm.lazyload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.subqueryload("*")) + .options(sa.orm.joinedload(self.classes.User.addresses)) + .options(sa.orm.lazyload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 1) # verify all the addresses were joined loaded (no more sql) @@ -151,26 +226,27 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): def test_star_must_be_alone(self): sess = self._downgrade_fixture() User = self.classes.User - opt = sa.orm.subqueryload('*', User.addresses) + opt = sa.orm.subqueryload("*", User.addresses) assert_raises_message( 
sa.exc.ArgumentError, "Wildcard token cannot be followed by another entity", - sess.query(User).options, opt + sess.query(User).options, + opt, ) def test_global_star_ignored_no_entities_unbound(self): sess = self._downgrade_fixture() User = self.classes.User - opt = sa.orm.lazyload('*') + opt = sa.orm.lazyload("*") q = sess.query(User.name).options(opt) - eq_(q.all(), [('jack',), ('ed',), ('fred',), ('chuck',)]) + eq_(q.all(), [("jack",), ("ed",), ("fred",), ("chuck",)]) def test_global_star_ignored_no_entities_bound(self): sess = self._downgrade_fixture() User = self.classes.User - opt = sa.orm.Load(User).lazyload('*') + opt = sa.orm.Load(User).lazyload("*") q = sess.query(User.name).options(opt) - eq_(q.all(), [('jack',), ('ed',), ('fred',), ('chuck',)]) + eq_(q.all(), [("jack",), ("ed",), ("fred",), ("chuck",)]) def test_select_with_joinedload(self): """Mapper load strategy defaults can be downgraded with @@ -181,11 +257,14 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # lazyload('*') shuts off 'orders' subquery: only 1 sql def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.lazyload('*'))\ - .options(sa.orm.joinedload(self.classes.User.addresses))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.lazyload("*")) + .options(sa.orm.joinedload(self.classes.User.addresses)) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 1) # verify all the addresses were joined loaded (no more sql) @@ -195,6 +274,7 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # (same as with test_disable_eagerloads): 3 total sql def go(): users[0].orders + self.assert_sql_count(testing.db, go, 3) def test_select_with_subqueryload(self): @@ -207,17 +287,21 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # now test 'default_strategy' option combined with 'subquery' # shuts off 'addresses' load AND orders.items load: 2 sql expected def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.lazyload('*'))\ - .options(sa.orm.subqueryload(self.classes.User.orders))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.lazyload("*")) + .options(sa.orm.subqueryload(self.classes.User.orders)) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 2) # Verify orders have already been loaded: 0 sql def go(): for u, static in zip(users, self.static.user_all_result): assert len(u.orders) == len(static.orders) + self.assert_sql_count(testing.db, go, 0) # Verify lazyload('*') prevented orders.items load @@ -226,6 +310,7 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): def go(): for i in users[0].orders[0].items: i.keywords + self.assert_sql_count(testing.db, go, 2) # lastly, make sure they actually loaded properly @@ -240,11 +325,14 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test noload('*') shuts off 'orders' subquery, only 1 sql def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.noload('*'))\ - .options(sa.orm.joinedload(self.classes.User.addresses))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.noload("*")) + .options(sa.orm.joinedload(self.classes.User.addresses)) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 1) # verify all the addresses were joined loaded (no more sql) @@ -254,6 +342,7 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): def go(): 
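# ---------------------------------------------------------------------
# [Editor's aside; an illustrative sketch, not part of the reformat
# commit above.] The DefaultStrategyOptionsTest methods in these hunks
# combine a wildcard loader option with a targeted one: the string
# token '*' rewrites the default strategy for every relationship, and a
# specific option layered on top wins for its own attribute. A runnable
# sketch under a hypothetical two-relationship schema:
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, joinedload, lazyload, relationship

Base = declarative_base()


class User(Base):
    __tablename__ = "user_sketch"
    id = Column(Integer, primary_key=True)
    name = Column(String(30))
    addresses = relationship("Address", lazy="subquery")
    orders = relationship("Order", lazy="subquery")


class Address(Base):
    __tablename__ = "address_sketch"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user_sketch.id"))


class Order(Base):
    __tablename__ = "order_sketch"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user_sketch.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(engine)
session.add(User(name="u1", addresses=[Address()], orders=[Order()]))
session.commit()

# lazyload('*') downgrades both subquery-eager relationships; the
# explicit joinedload then re-upgrades only User.addresses, so the
# whole load is one SELECT with a LEFT OUTER JOIN, and .orders loads
# lazily on first access.
users = (
    session.query(User)
    .options(lazyload("*"))
    .options(joinedload(User.addresses))
    .all()
)
# ---------------------------------------------------------------------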
for u in users: assert u.orders == [] + self.assert_sql_count(testing.db, go, 0) def test_noload_with_subqueryload(self): @@ -266,11 +355,14 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test noload('*') option combined with subqueryload() # shuts off 'addresses' load AND orders.items load: 2 sql expected def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.noload('*'))\ - .options(sa.orm.subqueryload(self.classes.User.orders))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.noload("*")) + .options(sa.orm.subqueryload(self.classes.User.orders)) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 2) def go(): @@ -282,6 +374,7 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): for u in users: for o in u.orders: assert o.items == [] + self.assert_sql_count(testing.db, go, 0) def test_joined(self): @@ -292,10 +385,13 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test upgrade all to joined: 1 sql def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.joinedload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.joinedload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 1) # verify everything loaded, with no additional sql needed @@ -307,13 +403,15 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test upgrade all to joined: 1 sql def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.joinedload('.*'))\ - .options(sa.orm.joinedload("addresses.*"))\ - .options(sa.orm.joinedload("orders.*"))\ - .options(sa.orm.joinedload("orders.items.*"))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.joinedload(".*")) + .options(sa.orm.joinedload("addresses.*")) + .options(sa.orm.joinedload("orders.*")) + .options(sa.orm.joinedload("orders.items.*")) + .order_by(self.classes.User.id) .all() + ) self.assert_sql_count(testing.db, go, 1) self._assert_fully_loaded(users) @@ -327,27 +425,33 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test joined all but 'keywords': upgraded to 1 sql def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.lazyload('orders.items.keywords'))\ - .options(sa.orm.joinedload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.lazyload("orders.items.keywords")) + .options(sa.orm.joinedload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 1) # everything (but keywords) loaded ok # (note self.static.user_all_result contains no keywords) def go(): eq_(users, self.static.user_all_result) + self.assert_sql_count(testing.db, go, 0) # verify the items were loaded, while item.keywords were not def go(): # redundant with last test, but illustrative users[0].orders[0].items[0] + self.assert_sql_count(testing.db, go, 0) def go(): users[0].orders[0].items[0].keywords + self.assert_sql_count(testing.db, go, 1) def test_joined_with_subqueryload(self): @@ -359,11 +463,14 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test upgrade all but 'addresses', which is subquery loaded (2 sql) def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.subqueryload(self.classes.User.addresses))\ - .options(sa.orm.joinedload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + 
.options(sa.orm.subqueryload(self.classes.User.addresses)) + .options(sa.orm.joinedload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 2) # verify everything loaded, with no additional sql needed @@ -377,10 +484,13 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test upgrade all to subquery: 1 sql + 4 relationships = 5 def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.subqueryload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.subqueryload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 5) # verify everything loaded, with no additional sql needed @@ -392,13 +502,16 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test upgrade all to subquery: 1 sql + 4 relationships = 5 def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.subqueryload('.*'))\ - .options(sa.orm.subqueryload('addresses.*'))\ - .options(sa.orm.subqueryload('orders.*'))\ - .options(sa.orm.subqueryload('orders.items.*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.subqueryload(".*")) + .options(sa.orm.subqueryload("addresses.*")) + .options(sa.orm.subqueryload("orders.*")) + .options(sa.orm.subqueryload("orders.items.*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 5) # verify everything loaded, with no additional sql needed @@ -413,26 +526,32 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test subquery all but 'keywords' (1 sql + 3 relationships = 4) def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.lazyload('orders.items.keywords'))\ - .options(sa.orm.subqueryload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.lazyload("orders.items.keywords")) + .options(sa.orm.subqueryload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 4) # no more sql # (note self.static.user_all_result contains no keywords) def go(): eq_(users, self.static.user_all_result) + self.assert_sql_count(testing.db, go, 0) # verify the item.keywords were not loaded def go(): users[0].orders[0].items[0] + self.assert_sql_count(testing.db, go, 0) def go(): users[0].orders[0].items[0].keywords + self.assert_sql_count(testing.db, go, 1) def test_subquery_with_joinedload(self): @@ -445,12 +564,15 @@ class DefaultStrategyOptionsTest(_fixtures.FixtureTest): # test upgrade all but 'addresses' & 'orders', which are joinedloaded # (1 sql + items + keywords = 3) def go(): - users[:] = sess.query(self.classes.User)\ - .options(sa.orm.joinedload(self.classes.User.addresses))\ - .options(sa.orm.joinedload(self.classes.User.orders))\ - .options(sa.orm.subqueryload('*'))\ - .order_by(self.classes.User.id)\ + users[:] = ( + sess.query(self.classes.User) + .options(sa.orm.joinedload(self.classes.User.addresses)) + .options(sa.orm.joinedload(self.classes.User.orders)) + .options(sa.orm.subqueryload("*")) + .order_by(self.classes.User.id) .all() + ) + self.assert_sql_count(testing.db, go, 3) # verify everything loaded, with no additional sql needed diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index 426bf9c18e..878cc430bf 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -8,72 +8,95 @@ from sqlalchemy.testing import eq_ class TriggerDefaultsTest(fixtures.MappedTest): - __requires__ = 
('row_triggers',) + __requires__ = ("row_triggers",) @classmethod def define_tables(cls, metadata): - dt = Table('dt', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col1', String(20)), - Column('col2', String(20), - server_default=sa.schema.FetchedValue()), - Column('col3', String(20), - sa.schema.FetchedValue(for_update=True)), - Column('col4', String(20), - sa.schema.FetchedValue(), - sa.schema.FetchedValue(for_update=True))) + dt = Table( + "dt", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("col1", String(20)), + Column( + "col2", String(20), server_default=sa.schema.FetchedValue() + ), + Column( + "col3", String(20), sa.schema.FetchedValue(for_update=True) + ), + Column( + "col4", + String(20), + sa.schema.FetchedValue(), + sa.schema.FetchedValue(for_update=True), + ), + ) for ins in ( - sa.DDL("CREATE TRIGGER dt_ins AFTER INSERT ON dt " - "FOR EACH ROW BEGIN " - "UPDATE dt SET col2='ins', col4='ins' " - "WHERE dt.id = NEW.id; END", - on='sqlite'), - sa.DDL("CREATE TRIGGER dt_ins ON dt AFTER INSERT AS " - "UPDATE dt SET col2='ins', col4='ins' " - "WHERE dt.id IN (SELECT id FROM inserted);", - on='mssql'), - sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT " - "ON dt " - "FOR EACH ROW " - "BEGIN " - ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;", - on='oracle'), - sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " - "FOR EACH ROW BEGIN " - "SET NEW.col2='ins'; SET NEW.col4='ins'; END", - on=lambda ddl, event, target, bind, **kw: - bind.engine.name not in ('oracle', 'mssql', 'sqlite') - ), + sa.DDL( + "CREATE TRIGGER dt_ins AFTER INSERT ON dt " + "FOR EACH ROW BEGIN " + "UPDATE dt SET col2='ins', col4='ins' " + "WHERE dt.id = NEW.id; END", + on="sqlite", + ), + sa.DDL( + "CREATE TRIGGER dt_ins ON dt AFTER INSERT AS " + "UPDATE dt SET col2='ins', col4='ins' " + "WHERE dt.id IN (SELECT id FROM inserted);", + on="mssql", + ), + sa.DDL( + "CREATE TRIGGER dt_ins BEFORE INSERT " + "ON dt " + "FOR EACH ROW " + "BEGIN " + ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;", + on="oracle", + ), + sa.DDL( + "CREATE TRIGGER dt_ins BEFORE INSERT ON dt " + "FOR EACH ROW BEGIN " + "SET NEW.col2='ins'; SET NEW.col4='ins'; END", + on=lambda ddl, event, target, bind, **kw: bind.engine.name + not in ("oracle", "mssql", "sqlite"), + ), ): - event.listen(dt, 'after_create', ins) + event.listen(dt, "after_create", ins) - event.listen(dt, 'before_drop', sa.DDL("DROP TRIGGER dt_ins")) + event.listen(dt, "before_drop", sa.DDL("DROP TRIGGER dt_ins")) for up in ( - sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt " - "FOR EACH ROW BEGIN " - "UPDATE dt SET col3='up', col4='up' " - "WHERE dt.id = OLD.id; END", - on='sqlite'), - sa.DDL("CREATE TRIGGER dt_up ON dt AFTER UPDATE AS " - "UPDATE dt SET col3='up', col4='up' " - "WHERE dt.id IN (SELECT id FROM deleted);", - on='mssql'), - sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " - "FOR EACH ROW BEGIN " - ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;", - on='oracle'), - sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " - "FOR EACH ROW BEGIN " - "SET NEW.col3='up'; SET NEW.col4='up'; END", - on=lambda ddl, event, target, bind, **kw: - bind.engine.name not in ('oracle', 'mssql', 'sqlite') - ), + sa.DDL( + "CREATE TRIGGER dt_up AFTER UPDATE ON dt " + "FOR EACH ROW BEGIN " + "UPDATE dt SET col3='up', col4='up' " + "WHERE dt.id = OLD.id; END", + on="sqlite", + ), + sa.DDL( + "CREATE TRIGGER dt_up ON dt AFTER UPDATE AS " + "UPDATE dt SET col3='up', col4='up' " + "WHERE dt.id 
IN (SELECT id FROM deleted);", + on="mssql", + ), + sa.DDL( + "CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + "FOR EACH ROW BEGIN " + ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;", + on="oracle", + ), + sa.DDL( + "CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + "FOR EACH ROW BEGIN " + "SET NEW.col3='up'; SET NEW.col4='up'; END", + on=lambda ddl, event, target, bind, **kw: bind.engine.name + not in ("oracle", "mssql", "sqlite"), + ), ): - event.listen(dt, 'after_create', up) + event.listen(dt, "after_create", up) - event.listen(dt, 'before_drop', sa.DDL("DROP TRIGGER dt_up")) + event.listen(dt, "before_drop", sa.DDL("DROP TRIGGER dt_up")) @classmethod def setup_classes(cls): @@ -101,10 +124,10 @@ class TriggerDefaultsTest(fixtures.MappedTest): session.flush() eq_(d1.col1, None) - eq_(d1.col2, 'ins') + eq_(d1.col2, "ins") eq_(d1.col3, None) # don't care which trigger fired - assert d1.col4 in ('ins', 'up') + assert d1.col4 in ("ins", "up") def test_update(self): Default = self.classes.Default @@ -114,29 +137,34 @@ class TriggerDefaultsTest(fixtures.MappedTest): session = create_session() session.add(d1) session.flush() - d1.col1 = 'set' + d1.col1 = "set" session.flush() - eq_(d1.col1, 'set') - eq_(d1.col2, 'ins') - eq_(d1.col3, 'up') - eq_(d1.col4, 'up') + eq_(d1.col1, "set") + eq_(d1.col2, "ins") + eq_(d1.col3, "up") + eq_(d1.col4, "up") class ExcludedDefaultsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - dt = Table('dt', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('col1', String(20), default="hello")) + dt = Table( + "dt", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("col1", String(20), default="hello"), + ) def test_exclude(self): dt = self.tables.dt class Foo(fixtures.BasicEntity): pass - mapper(Foo, dt, exclude_properties=('col1',)) + + mapper(Foo, dt, exclude_properties=("col1",)) f1 = Foo() sess = create_session() diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index 3f2a8a06b4..29dd6947ea 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -1,9 +1,26 @@ import sqlalchemy as sa from sqlalchemy import testing, util -from sqlalchemy.orm import mapper, deferred, defer, undefer, Load, \ - load_only, undefer_group, create_session, synonym, relationship, Session,\ - joinedload, defaultload, aliased, contains_eager, with_polymorphic, \ - query_expression, with_expression, subqueryload +from sqlalchemy.orm import ( + mapper, + deferred, + defer, + undefer, + Load, + load_only, + undefer_group, + create_session, + synonym, + relationship, + Session, + joinedload, + defaultload, + aliased, + contains_eager, + with_polymorphic, + query_expression, + with_expression, + subqueryload, +) from sqlalchemy.testing import eq_, AssertsCompiledSQL, assert_raises_message from test.orm import _fixtures from sqlalchemy.testing.schema import Column @@ -11,19 +28,29 @@ from sqlalchemy import Integer, ForeignKey from sqlalchemy.testing import fixtures -from .inheritance._poly_fixtures import Company, Person, Engineer, Manager, \ - Boss, Machine, Paperwork, _Polymorphic +from .inheritance._poly_fixtures import ( + Company, + Person, + Engineer, + Manager, + Boss, + Machine, + Paperwork, + _Polymorphic, +) class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): - def test_basic(self): """A basic deferred load.""" Order, orders = self.classes.Order, self.tables.orders - mapper(Order, orders, properties={ - 'description': 
deferred(orders.c.description)}) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) o = Order() self.assert_(o.description is None) @@ -35,30 +62,36 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): o2 = result[2] x = o2.description - self.sql_eq_(go, [ - ("SELECT orders.id AS orders_id, " - "orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.isopen AS orders_isopen " - "FROM orders ORDER BY orders.id", {}), - ("SELECT orders.description AS orders_description " - "FROM orders WHERE orders.id = :param_1", - {'param_1': 3})]) + self.sql_eq_( + go, + [ + ( + "SELECT orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", + {}, + ), + ( + "SELECT orders.description AS orders_description " + "FROM orders WHERE orders.id = :param_1", + {"param_1": 3}, + ), + ], + ) def test_defer_primary_key(self): """what happens when we try to defer the primary key?""" Order, orders = self.classes.Order, self.tables.orders - mapper(Order, orders, properties={ - 'id': deferred(orders.c.id)}) + mapper(Order, orders, properties={"id": deferred(orders.c.id)}) # right now, it's not that graceful :) q = create_session().query(Order) assert_raises_message( - sa.exc.NoSuchColumnError, - "Could not locate", - q.first + sa.exc.NoSuchColumnError, "Could not locate", q.first ) def test_unsaved(self): @@ -66,8 +99,11 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): Order, orders = self.classes.Order, self.tables.orders - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) sess = create_session() o = Order() @@ -76,15 +112,20 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): def go(): o.description = "some description" + self.sql_count_(0, go) def test_synonym_group_bug(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - 'isopen': synonym('_isopen', map_column=True), - 'description': deferred(orders.c.description, group='foo') - }) + mapper( + Order, + orders, + properties={ + "isopen": synonym("_isopen", map_column=True), + "description": deferred(orders.c.description, group="foo"), + }, + ) sess = create_session() o1 = sess.query(Order).get(1) @@ -93,8 +134,11 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): def test_unsaved_2(self): Order, orders = self.classes.Order, self.tables.orders - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) sess = create_session() o = Order() @@ -102,6 +146,7 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): def go(): o.description = "some description" + self.sql_count_(0, go) def test_unsaved_group(self): @@ -109,9 +154,14 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=dict( - description=deferred(orders.c.description, group='primary'), - opened=deferred(orders.c.isopen, group='primary'))) + mapper( + Order, + orders, + properties=dict( + description=deferred(orders.c.description, group="primary"), + opened=deferred(orders.c.isopen, group="primary"), + ), + ) sess = create_session() o = Order() 
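(For context: the "deferral group" behavior these tests exercise — several columns deferred under one group name, all loaded together on first access — can be seen in a minimal standalone sketch. The model, engine URL, and sample row below are hypothetical stand-ins for the fixture tables, not part of this patch.)

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, deferred

    Base = declarative_base()

    class Order(Base):
        __tablename__ = "orders"

        id = Column(Integer, primary_key=True)
        # both columns belong to the "primary" deferral group: neither is
        # loaded by the initial SELECT, and touching either one loads both
        # with a single supplementary SELECT
        description = deferred(Column(String(50)), group="primary")
        isopen = deferred(Column(Integer), group="primary")

    engine = create_engine("sqlite://", echo=True)
    Base.metadata.create_all(engine)

    session = Session(engine)
    session.add(Order(id=1, description="order 1", isopen=1))
    session.commit()

    o = session.query(Order).get(1)         # SELECT touches orders.id only
    assert "description" not in o.__dict__  # group still deferred
    o.description                           # one SELECT loads the whole group
    assert "isopen" in o.__dict__           # loaded alongside description
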
@@ -120,14 +170,20 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): def go(): o.description = "some description" + self.sql_count_(0, go) def test_unsaved_group_2(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=dict( - description=deferred(orders.c.description, group='primary'), - opened=deferred(orders.c.isopen, group='primary'))) + mapper( + Order, + orders, + properties=dict( + description=deferred(orders.c.description, group="primary"), + opened=deferred(orders.c.isopen, group="primary"), + ), + ) sess = create_session() o = Order() @@ -135,13 +191,17 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): def go(): o.description = "some description" + self.sql_count_(0, go) def test_save(self): Order, orders = self.classes.Order, self.tables.orders - m = mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) + m = mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) sess = create_session() o2 = sess.query(Order).get(2) @@ -153,12 +213,24 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('addrident', deferred(orders.c.address_id, group='primary')), - ('description', deferred(orders.c.description, group='primary')), - ('opened', deferred(orders.c.isopen, group='primary')) - ])) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "addrident", + deferred(orders.c.address_id, group="primary"), + ), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="primary")), + ] + ), + ) sess = create_session() q = sess.query(Order).order_by(Order.id) @@ -168,25 +240,35 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): o2 = result[2] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') - - self.sql_eq_(go, [ - ("SELECT orders.id AS orders_id " - "FROM orders ORDER BY orders.id", {}), - ("SELECT orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen " - "FROM orders WHERE orders.id = :param_1", - {'param_1': 3})]) + eq_(o2.description, "order 3") + + self.sql_eq_( + go, + [ + ( + "SELECT orders.id AS orders_id " + "FROM orders ORDER BY orders.id", + {}, + ), + ( + "SELECT orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen " + "FROM orders WHERE orders.id = :param_1", + {"param_1": 3}, + ), + ], + ) o2 = q.all()[2] - eq_(o2.description, 'order 3') + eq_(o2.description, "order 3") assert o2 not in sess.dirty - o2.description = 'order 3' + o2.description = "order 3" def go(): sess.flush() + self.sql_count_(0, go) def test_preserve_changes(self): @@ -195,21 +277,26 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - 'userident': deferred(orders.c.user_id, group='primary'), - 'description': deferred(orders.c.description, group='primary'), - 'opened': deferred(orders.c.isopen, group='primary') - }) + mapper( + Order, + orders, + properties={ + "userident": 
deferred(orders.c.user_id, group="primary"), + "description": deferred(orders.c.description, group="primary"), + "opened": deferred(orders.c.isopen, group="primary"), + }, + ) sess = create_session() o = sess.query(Order).get(3) - assert 'userident' not in o.__dict__ - o.description = 'somenewdescription' - eq_(o.description, 'somenewdescription') + assert "userident" not in o.__dict__ + o.description = "somenewdescription" + eq_(o.description, "somenewdescription") def go(): eq_(o.opened, 1) + self.assert_sql_count(testing.db, go, 1) - eq_(o.description, 'somenewdescription') + eq_(o.description, "somenewdescription") assert o in sess.dirty def test_commits_state(self): @@ -221,19 +308,24 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - 'userident': deferred(orders.c.user_id, group='primary'), - 'description': deferred(orders.c.description, group='primary'), - 'opened': deferred(orders.c.isopen, group='primary')}) + mapper( + Order, + orders, + properties={ + "userident": deferred(orders.c.user_id, group="primary"), + "description": deferred(orders.c.description, group="primary"), + "opened": deferred(orders.c.isopen, group="primary"), + }, + ) sess = create_session() o2 = sess.query(Order).get(3) # this will load the group of attributes - eq_(o2.description, 'order 3') + eq_(o2.description, "order 3") assert o2 not in sess.dirty # this will mark it as 'dirty', but nothing actually changed - o2.description = 'order 3' + o2.description = "order 3" # therefore the flush() shouldn't actually issue any SQL self.assert_sql_count(testing.db, sess.flush, 0) @@ -245,24 +337,29 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): Order, orders = self.classes.Order, self.tables.orders - order_select = sa.select([ - orders.c.id, - orders.c.user_id, - orders.c.address_id, - orders.c.description, - orders.c.isopen]).alias() - mapper(Order, order_select, properties={ - 'description': deferred(order_select.c.description) - }) + order_select = sa.select( + [ + orders.c.id, + orders.c.user_id, + orders.c.address_id, + orders.c.description, + orders.c.isopen, + ] + ).alias() + mapper( + Order, + order_select, + properties={"description": deferred(order_select.c.description)}, + ) sess = Session() o1 = sess.query(Order).order_by(Order.id).first() - assert 'description' not in o1.__dict__ - eq_(o1.description, 'order 1') + assert "description" not in o1.__dict__ + eq_(o1.description, "order 1") class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): - __dialect__ = 'default' + __dialect__ = "default" def test_options(self): """Options on a mapper to create deferred and undeferred columns""" @@ -272,142 +369,213 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) sess = create_session() - q = sess.query(Order).order_by(Order.id).options(defer('user_id')) + q = sess.query(Order).order_by(Order.id).options(defer("user_id")) def go(): q.all()[0].user_id - self.sql_eq_(go, [ - ("SELECT orders.id AS orders_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen " - "FROM orders ORDER BY orders.id", {}), - ("SELECT orders.user_id AS orders_user_id " - "FROM orders WHERE orders.id = :param_1", - {'param_1': 1})]) + self.sql_eq_( + go, + [ + ( + "SELECT orders.id AS orders_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS 
orders_description, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", + {}, + ), + ( + "SELECT orders.user_id AS orders_user_id " + "FROM orders WHERE orders.id = :param_1", + {"param_1": 1}, + ), + ], + ) sess.expunge_all() - q2 = q.options(undefer('user_id')) - self.sql_eq_(q2.all, [ - ("SELECT orders.id AS orders_id, " - "orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen " - "FROM orders ORDER BY orders.id", - {})]) + q2 = q.options(undefer("user_id")) + self.sql_eq_( + q2.all, + [ + ( + "SELECT orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", + {}, + ) + ], + ) def test_undefer_group(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('description', deferred(orders.c.description, group='primary')), - ('opened', deferred(orders.c.isopen, group='primary')) - ] - )) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="primary")), + ] + ), + ) sess = create_session() q = sess.query(Order).order_by(Order.id) def go(): - result = q.options(undefer_group('primary')).all() + result = q.options(undefer_group("primary")).all() o2 = result[2] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') + eq_(o2.description, "order 3") - self.sql_eq_(go, [ - ("SELECT orders.user_id AS orders_user_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen, " - "orders.id AS orders_id, " - "orders.address_id AS orders_address_id " - "FROM orders ORDER BY orders.id", - {})]) + self.sql_eq_( + go, + [ + ( + "SELECT orders.user_id AS orders_user_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen, " + "orders.id AS orders_id, " + "orders.address_id AS orders_address_id " + "FROM orders ORDER BY orders.id", + {}, + ) + ], + ) def test_undefer_group_multi(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('description', deferred(orders.c.description, group='primary')), - ('opened', deferred(orders.c.isopen, group='secondary')) - ] - )) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="secondary")), + ] + ), + ) sess = create_session() q = sess.query(Order).order_by(Order.id) def go(): result = q.options( - undefer_group('primary'), undefer_group('secondary')).all() + undefer_group("primary"), undefer_group("secondary") + ).all() o2 = result[2] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') + eq_(o2.description, "order 3") - self.sql_eq_(go, [ - ("SELECT orders.user_id AS orders_user_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen, " - "orders.id AS orders_id, " - "orders.address_id AS orders_address_id " - "FROM 
orders ORDER BY orders.id", - {})]) + self.sql_eq_( + go, + [ + ( + "SELECT orders.user_id AS orders_user_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen, " + "orders.id AS orders_id, " + "orders.address_id AS orders_address_id " + "FROM orders ORDER BY orders.id", + {}, + ) + ], + ) def test_undefer_group_multi_pathed(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('description', deferred(orders.c.description, group='primary')), - ('opened', deferred(orders.c.isopen, group='secondary')) - ])) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="secondary")), + ] + ), + ) sess = create_session() q = sess.query(Order).order_by(Order.id) def go(): result = q.options( - Load(Order).undefer_group('primary').undefer_group('secondary') + Load(Order).undefer_group("primary").undefer_group("secondary") ).all() o2 = result[2] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') + eq_(o2.description, "order 3") - self.sql_eq_(go, [ - ("SELECT orders.user_id AS orders_user_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen, " - "orders.id AS orders_id, " - "orders.address_id AS orders_address_id " - "FROM orders ORDER BY orders.id", - {})]) + self.sql_eq_( + go, + [ + ( + "SELECT orders.user_id AS orders_user_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen, " + "orders.id AS orders_id, " + "orders.address_id AS orders_address_id " + "FROM orders ORDER BY orders.id", + {}, + ) + ], + ) def test_undefer_group_from_relationship_lazyload(self): - users, Order, User, orders = \ - (self.tables.users, - self.classes.Order, - self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, order_by=orders.c.id))) + users, Order, User, orders = ( + self.tables.users, + self.classes.Order, + self.classes.User, + self.tables.orders, + ) + mapper( - Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('description', deferred(orders.c.description, - group='primary')), - ('opened', deferred(orders.c.isopen, group='primary')) - ]) + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="primary")), + ] + ), ) sess = create_session() - q = sess.query(User).filter(User.id == 7).options( - defaultload(User.orders).undefer_group('primary') + q = ( + sess.query(User) + .filter(User.id == 7) + .options(defaultload(User.orders).undefer_group("primary")) ) def go(): @@ -415,37 +583,59 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): o2 = result[0].orders[1] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') - self.sql_eq_(go, [ - ("SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1", {"id_1": 7}), - ("SELECT orders.user_id AS orders_user_id, orders.description " - "AS orders_description, orders.isopen AS 
orders_isopen, " - "orders.id AS orders_id, orders.address_id AS orders_address_id " - "FROM orders WHERE :param_1 = orders.user_id ORDER BY orders.id", - {'param_1': 7})]) + eq_(o2.description, "order 3") + + self.sql_eq_( + go, + [ + ( + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1", + {"id_1": 7}, + ), + ( + "SELECT orders.user_id AS orders_user_id, orders.description " + "AS orders_description, orders.isopen AS orders_isopen, " + "orders.id AS orders_id, orders.address_id AS orders_address_id " + "FROM orders WHERE :param_1 = orders.user_id ORDER BY orders.id", + {"param_1": 7}, + ), + ], + ) def test_undefer_group_from_relationship_subqueryload(self): - users, Order, User, orders = \ - (self.tables.users, - self.classes.Order, - self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, order_by=orders.c.id))) + users, Order, User, orders = ( + self.tables.users, + self.classes.Order, + self.classes.User, + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) mapper( - Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('description', deferred(orders.c.description, - group='primary')), - ('opened', deferred(orders.c.isopen, group='primary')) - ]) + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="primary")), + ] + ), ) sess = create_session() - q = sess.query(User).filter(User.id == 7).options( - subqueryload(User.orders).undefer_group('primary') + q = ( + sess.query(User) + .filter(User.id == 7) + .options(subqueryload(User.orders).undefer_group("primary")) ) def go(): @@ -453,40 +643,62 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): o2 = result[0].orders[1] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') - self.sql_eq_(go, [ - ("SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1", {"id_1": 7}), - ("SELECT orders.user_id AS orders_user_id, orders.description " - "AS orders_description, orders.isopen AS orders_isopen, " - "orders.id AS orders_id, orders.address_id AS orders_address_id, " - "anon_1.users_id AS anon_1_users_id FROM (SELECT users.id AS " - "users_id FROM users WHERE users.id = :id_1) AS anon_1 " - "JOIN orders ON anon_1.users_id = orders.user_id ORDER BY " - "anon_1.users_id, orders.id", [{'id_1': 7}])] + eq_(o2.description, "order 3") + + self.sql_eq_( + go, + [ + ( + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1", + {"id_1": 7}, + ), + ( + "SELECT orders.user_id AS orders_user_id, orders.description " + "AS orders_description, orders.isopen AS orders_isopen, " + "orders.id AS orders_id, orders.address_id AS orders_address_id, " + "anon_1.users_id AS anon_1_users_id FROM (SELECT users.id AS " + "users_id FROM users WHERE users.id = :id_1) AS anon_1 " + "JOIN orders ON anon_1.users_id = orders.user_id ORDER BY " + "anon_1.users_id, orders.id", + [{"id_1": 7}], + ), + ], ) def test_undefer_group_from_relationship_joinedload(self): - users, Order, User, orders = \ - (self.tables.users, - self.classes.Order, - self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, 
order_by=orders.c.id))) + users, Order, User, orders = ( + self.tables.users, + self.classes.Order, + self.classes.User, + self.tables.orders, + ) + mapper( - Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('description', deferred(orders.c.description, - group='primary')), - ('opened', deferred(orders.c.isopen, group='primary')) - ]) + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "description", + deferred(orders.c.description, group="primary"), + ), + ("opened", deferred(orders.c.isopen, group="primary")), + ] + ), ) sess = create_session() - q = sess.query(User).filter(User.id == 7).options( - joinedload(User.orders).undefer_group('primary') + q = ( + sess.query(User) + .filter(User.id == 7) + .options(joinedload(User.orders).undefer_group("primary")) ) def go(): @@ -494,40 +706,61 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): o2 = result[0].orders[1] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.description, 'order 3') - self.sql_eq_(go, [ - ("SELECT users.id AS users_id, users.name AS users_name, " - "orders_1.user_id AS orders_1_user_id, orders_1.description AS " - "orders_1_description, orders_1.isopen AS orders_1_isopen, " - "orders_1.id AS orders_1_id, orders_1.address_id AS " - "orders_1_address_id FROM users " - "LEFT OUTER JOIN orders AS orders_1 ON users.id = " - "orders_1.user_id WHERE users.id = :id_1 " - "ORDER BY orders_1.id", {"id_1": 7})] + eq_(o2.description, "order 3") + + self.sql_eq_( + go, + [ + ( + "SELECT users.id AS users_id, users.name AS users_name, " + "orders_1.user_id AS orders_1_user_id, orders_1.description AS " + "orders_1_description, orders_1.isopen AS orders_1_isopen, " + "orders_1.id AS orders_1_id, orders_1.address_id AS " + "orders_1_address_id FROM users " + "LEFT OUTER JOIN orders AS orders_1 ON users.id = " + "orders_1.user_id WHERE users.id = :id_1 " + "ORDER BY orders_1.id", + {"id_1": 7}, + ) + ], ) def test_undefer_group_from_relationship_joinedload_colexpr(self): - users, Order, User, orders = \ - (self.tables.users, - self.classes.Order, - self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, order_by=orders.c.id))) + users, Order, User, orders = ( + self.tables.users, + self.classes.Order, + self.classes.User, + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) mapper( - Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id, group='primary')), - ('lower_desc', deferred( - sa.func.lower(orders.c.description).label(None), - group='primary')), - ('opened', deferred(orders.c.isopen, group='primary')) - ]) + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id, group="primary")), + ( + "lower_desc", + deferred( + sa.func.lower(orders.c.description).label(None), + group="primary", + ), + ), + ("opened", deferred(orders.c.isopen, group="primary")), + ] + ), ) sess = create_session() - q = sess.query(User).filter(User.id == 7).options( - joinedload(User.orders).undefer_group('primary') + q = ( + sess.query(User) + .filter(User.id == 7) + .options(joinedload(User.orders).undefer_group("primary")) ) def go(): @@ -535,37 +768,52 @@ class DeferredOptionsTest(AssertsCompiledSQL, 
_fixtures.FixtureTest): o2 = result[0].orders[1] eq_(o2.opened, 1) eq_(o2.userident, 7) - eq_(o2.lower_desc, 'order 3') - self.sql_eq_(go, [ - ("SELECT users.id AS users_id, users.name AS users_name, " - "orders_1.user_id AS orders_1_user_id, " - "lower(orders_1.description) AS lower_1, " - "orders_1.isopen AS orders_1_isopen, orders_1.id AS orders_1_id, " - "orders_1.address_id AS orders_1_address_id, " - "orders_1.description AS orders_1_description FROM users " - "LEFT OUTER JOIN orders AS orders_1 ON users.id = " - "orders_1.user_id WHERE users.id = :id_1 " - "ORDER BY orders_1.id", {"id_1": 7})] + eq_(o2.lower_desc, "order 3") + + self.sql_eq_( + go, + [ + ( + "SELECT users.id AS users_id, users.name AS users_name, " + "orders_1.user_id AS orders_1_user_id, " + "lower(orders_1.description) AS lower_1, " + "orders_1.isopen AS orders_1_isopen, orders_1.id AS orders_1_id, " + "orders_1.address_id AS orders_1_address_id, " + "orders_1.description AS orders_1_description FROM users " + "LEFT OUTER JOIN orders AS orders_1 ON users.id = " + "orders_1.user_id WHERE users.id = :id_1 " + "ORDER BY orders_1.id", + {"id_1": 7}, + ) + ], ) def test_undefer_star(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties=util.OrderedDict([ - ('userident', deferred(orders.c.user_id)), - ('description', deferred(orders.c.description)), - ('opened', deferred(orders.c.isopen)) - ])) + mapper( + Order, + orders, + properties=util.OrderedDict( + [ + ("userident", deferred(orders.c.user_id)), + ("description", deferred(orders.c.description)), + ("opened", deferred(orders.c.isopen)), + ] + ), + ) sess = create_session() - q = sess.query(Order).options(Load(Order).undefer('*')) - self.assert_compile(q, - "SELECT orders.user_id AS orders_user_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen, " - "orders.id AS orders_id, " - "orders.address_id AS orders_address_id " - "FROM orders") + q = sess.query(Order).options(Load(Order).undefer("*")) + self.assert_compile( + q, + "SELECT orders.user_id AS orders_user_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen, " + "orders.id AS orders_id, " + "orders.address_id AS orders_address_id " + "FROM orders", + ) def test_locates_col(self): """changed in 1.0 - we don't search for deferred cols in the result @@ -573,16 +821,23 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) sess = create_session() - o1 = (sess.query(Order). - order_by(Order.id). - add_column(orders.c.description).first())[0] + o1 = ( + sess.query(Order) + .order_by(Order.id) + .add_column(orders.c.description) + .first() + )[0] def go(): - eq_(o1.description, 'order 1') + eq_(o1.description, "order 1") + # prior to 1.0 we'd search in the result for this column # self.sql_count_(0, go) self.sql_count_(1, go) @@ -599,36 +854,49 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) sess = create_session() stmt = sa.select([Order]).order_by(Order.id) - o1 = (sess.query(Order). 
- from_statement(stmt).all())[0] + o1 = (sess.query(Order).from_statement(stmt).all())[0] def go(): - eq_(o1.description, 'order 1') + eq_(o1.description, "order 1") + # prior to 1.0 we'd search in the result for this column # self.sql_count_(0, go) self.sql_count_(1, go) def test_deep_options(self): - users, items, order_items, Order, Item, User, orders = \ - (self.tables.users, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.orders) - - mapper(Item, items, properties=dict( - description=deferred(items.c.description))) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items))) - mapper(User, users, properties=dict( - orders=relationship(Order, order_by=orders.c.id))) + users, items, order_items, Order, Item, User, orders = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.orders, + ) + + mapper( + Item, + items, + properties=dict(description=deferred(items.c.description)), + ) + mapper( + Order, + orders, + properties=dict(items=relationship(Item, secondary=order_items)), + ) + mapper( + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) sess = create_session() q = sess.query(User).order_by(User.id) @@ -636,18 +904,20 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): item = result[0].orders[1].items[1] def go(): - eq_(item.description, 'item 4') + eq_(item.description, "item 4") + self.sql_count_(1, go) - eq_(item.description, 'item 4') + eq_(item.description, "item 4") sess.expunge_all() - result = q.options(undefer('orders.items.description')).all() + result = q.options(undefer("orders.items.description")).all() item = result[0].orders[1].items[1] def go(): - eq_(item.description, 'item 4') + eq_(item.description, "item 4") + self.sql_count_(0, go) - eq_(item.description, 'item 4') + eq_(item.description, "item 4") def test_path_entity(self): """test the legacy *addl_attrs argument.""" @@ -661,53 +931,64 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): items = self.tables.items order_items = self.tables.order_items - mapper(User, users, properties={ - "orders": relationship(Order, lazy="joined") - }) - mapper(Order, orders, properties={ - "items": relationship(Item, secondary=order_items, lazy="joined") - }) + mapper( + User, + users, + properties={"orders": relationship(Order, lazy="joined")}, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, lazy="joined" + ) + }, + ) mapper(Item, items) sess = create_session() - exp = ("SELECT users.id AS users_id, users.name AS users_name, " - "items_1.id AS items_1_id, orders_1.id AS orders_1_id, " - "orders_1.user_id AS orders_1_user_id, orders_1.address_id " - "AS orders_1_address_id, orders_1.description AS " - "orders_1_description, orders_1.isopen AS orders_1_isopen " - "FROM users LEFT OUTER JOIN orders AS orders_1 " - "ON users.id = orders_1.user_id LEFT OUTER JOIN " - "(order_items AS order_items_1 JOIN items AS items_1 " - "ON items_1.id = order_items_1.item_id) " - "ON orders_1.id = order_items_1.order_id") + exp = ( + "SELECT users.id AS users_id, users.name AS users_name, " + "items_1.id AS items_1_id, orders_1.id AS orders_1_id, " + "orders_1.user_id AS orders_1_user_id, orders_1.address_id " + "AS orders_1_address_id, orders_1.description AS " + "orders_1_description, orders_1.isopen AS 
orders_1_isopen " + "FROM users LEFT OUTER JOIN orders AS orders_1 " + "ON users.id = orders_1.user_id LEFT OUTER JOIN " + "(order_items AS order_items_1 JOIN items AS items_1 " + "ON items_1.id = order_items_1.item_id) " + "ON orders_1.id = order_items_1.order_id" + ) q = sess.query(User).options( - defer(User.orders, Order.items, Item.description)) + defer(User.orders, Order.items, Item.description) + ) self.assert_compile(q, exp) def test_chained_multi_col_options(self): users, User = self.tables.users, self.classes.User orders, Order = self.tables.orders, self.classes.Order - mapper(User, users, properties={ - "orders": relationship(Order) - }) + mapper(User, users, properties={"orders": relationship(Order)}) mapper(Order, orders) sess = create_session() q = sess.query(User).options( joinedload(User.orders).defer("description").defer("isopen") ) - self.assert_compile(q, - "SELECT users.id AS users_id, " - "users.name AS users_name, " - "orders_1.id AS orders_1_id, " - "orders_1.user_id AS orders_1_user_id, " - "orders_1.address_id AS orders_1_address_id " - "FROM users " - "LEFT OUTER JOIN orders AS orders_1 " - "ON users.id = orders_1.user_id") + self.assert_compile( + q, + "SELECT users.id AS users_id, " + "users.name AS users_name, " + "orders_1.id AS orders_1_id, " + "orders_1.user_id AS orders_1_user_id, " + "orders_1.address_id AS orders_1_address_id " + "FROM users " + "LEFT OUTER JOIN orders AS orders_1 " + "ON users.id = orders_1.user_id", + ) def test_load_only_no_pk(self): orders, Order = self.tables.orders, self.classes.Order @@ -716,10 +997,12 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): sess = create_session() q = sess.query(Order).options(load_only("isopen", "description")) - self.assert_compile(q, - "SELECT orders.id AS orders_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen FROM orders") + self.assert_compile( + q, + "SELECT orders.id AS orders_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen FROM orders", + ) def test_load_only_no_pk_rt(self): orders, Order = self.tables.orders, self.classes.Order @@ -727,27 +1010,33 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): mapper(Order, orders) sess = create_session() - q = sess.query(Order).order_by(Order.id).\ - options(load_only("isopen", "description")) + q = ( + sess.query(Order) + .order_by(Order.id) + .options(load_only("isopen", "description")) + ) eq_(q.first(), Order(id=1)) def test_load_only_w_deferred(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - "description": deferred(orders.c.description) - }) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) sess = create_session() q = sess.query(Order).options( - load_only("isopen", "description"), - undefer("user_id") + load_only("isopen", "description"), undefer("user_id") + ) + self.assert_compile( + q, + "SELECT orders.description AS orders_description, " + "orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.isopen AS orders_isopen FROM orders", ) - self.assert_compile(q, - "SELECT orders.description AS orders_description, " - "orders.id AS orders_id, " - "orders.user_id AS orders_user_id, " - "orders.isopen AS orders_isopen FROM orders") def test_load_only_propagate_unbound(self): self._test_load_only_propagate(False) @@ -762,29 +1051,36 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): users = 
self.tables.users addresses = self.tables.addresses - mapper(User, users, properties={ - "addresses": relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sess = create_session() expected = [ - ("SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id IN (:id_1, :id_2)", {'id_2': 8, - 'id_1': 7}), - ("SELECT addresses.id AS addresses_id, " + ( + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id IN (:id_1, :id_2)", + {"id_2": 8, "id_1": 7}, + ), + ( + "SELECT addresses.id AS addresses_id, " "addresses.email_address AS addresses_email_address " "FROM addresses WHERE :param_1 = addresses.user_id", - {'param_1': 7}), - ("SELECT addresses.id AS addresses_id, " + {"param_1": 7}, + ), + ( + "SELECT addresses.id AS addresses_id, " "addresses.email_address AS addresses_email_address " "FROM addresses WHERE :param_1 = addresses.user_id", - {'param_1': 8}), + {"param_1": 8}, + ), ] if use_load: - opt = Load(User).defaultload( - User.addresses).load_only("id", "email_address") + opt = ( + Load(User) + .defaultload(User.addresses) + .load_only("id", "email_address") + ) else: opt = defaultload(User.addresses).load_only("id", "email_address") q = sess.query(User).options(opt).filter(User.id.in_([7, 8])) @@ -812,7 +1108,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): q = sess.query(User, Order, Address).options( Load(User).load_only("name"), Load(Order).load_only("id"), - Load(Address).load_only("id", "email_address") + Load(Address).load_only("id", "email_address"), ) self.assert_compile( @@ -822,7 +1118,8 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): "orders.id AS orders_id, " "addresses.id AS addresses_id, " "addresses.email_address AS addresses_email_address " - "FROM users, orders, addresses") + "FROM users, orders, addresses", + ) def test_load_only_path_specific(self): User = self.classes.User @@ -833,10 +1130,16 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): addresses = self.tables.addresses orders = self.tables.orders - mapper(User, users, properties=util.OrderedDict([ - ("addresses", relationship(Address, lazy="joined")), - ("orders", relationship(Order, lazy="joined")) - ])) + mapper( + User, + users, + properties=util.OrderedDict( + [ + ("addresses", relationship(Address, lazy="joined")), + ("orders", relationship(Order, lazy="joined")), + ] + ), + ) mapper(Address, addresses) mapper(Order, orders) @@ -844,9 +1147,10 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): sess = create_session() q = sess.query(User).options( - load_only("name").defaultload( - "addresses").load_only("id", "email_address"), - defaultload("orders").load_only("id") + load_only("name") + .defaultload("addresses") + .load_only("id", "email_address"), + defaultload("orders").load_only("id"), ) # hmmmm joinedload seems to be forcing users.id into here... 
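(For context: the load_only()/defaultload() chaining whose SQL these tests compile can be seen in a minimal standalone sketch; it mirrors the option path used just above. The toy User/Address models and engine URL are hypothetical stand-ins for the fixtures, not part of this patch.)

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, defaultload, load_only, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"

        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        addresses = relationship("Address")

    class Address(Base):
        __tablename__ = "addresses"

        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("users.id"))
        email_address = Column(String(50))
        purpose = Column(String(16))

    engine = create_engine("sqlite://", echo=True)
    Base.metadata.create_all(engine)

    session = Session(engine)

    # load_only() trims the User SELECT to the primary key plus "name";
    # the chained defaultload(...).load_only(...) applies the same idea to
    # the Address rows loaded for User.addresses, skipping "purpose".
    users = (
        session.query(User)
        .options(
            load_only("name"),
            defaultload(User.addresses).load_only("email_address"),
        )
        .all()
    )
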
@@ -858,7 +1162,7 @@ class DeferredOptionsTest(AssertsCompiledSQL, _fixtures.FixtureTest): "orders_1.id AS orders_1_id FROM users " "LEFT OUTER JOIN addresses AS addresses_1 " "ON users.id = addresses_1.user_id " - "LEFT OUTER JOIN orders AS orders_1 ON users.id = orders_1.user_id" + "LEFT OUTER JOIN orders AS orders_1 ON users.id = orders_1.user_id", ) @@ -870,11 +1174,11 @@ class SelfReferentialMultiPathTest(testing.fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class Node(Base): - __tablename__ = 'node' + __tablename__ = "node" id = sa.Column(sa.Integer, primary_key=True) - parent_id = sa.Column(sa.ForeignKey('node.id')) - parent = relationship('Node', remote_side=[id]) + parent_id = sa.Column(sa.ForeignKey("node.id")) + parent = relationship("Node", remote_side=[id]) name = sa.Column(sa.String(10)) @classmethod @@ -882,11 +1186,13 @@ class SelfReferentialMultiPathTest(testing.fixtures.DeclarativeMappedTest): Node = cls.classes.Node session = Session() - session.add_all([ - Node(id=1, name='name'), - Node(id=2, parent_id=1, name='name'), - Node(id=3, parent_id=1, name='name') - ]) + session.add_all( + [ + Node(id=1, name="name"), + Node(id=2, parent_id=1, name="name"), + Node(id=3, parent_id=1, name="name"), + ] + ) session.commit() def test_present_overrides_deferred(self): @@ -905,24 +1211,28 @@ class SelfReferentialMultiPathTest(testing.fixtures.DeclarativeMappedTest): def go(): for node in nodes: - eq_(node.name, 'name') + eq_(node.name, "name") self.assert_sql_count(testing.db, go, 0) class InheritanceTest(_Polymorphic): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_mappers(cls): super(InheritanceTest, cls).setup_mappers() from sqlalchemy import inspect + inspect(Company).add_property("managers", relationship(Manager)) def test_load_only_subclass(self): s = Session() - q = s.query(Manager).order_by(Manager.person_id).\ - options(load_only("status", "manager_name")) + q = ( + s.query(Manager) + .order_by(Manager.person_id) + .options(load_only("status", "manager_name")) + ) self.assert_compile( q, "SELECT managers.person_id AS managers_person_id, " @@ -932,13 +1242,16 @@ class InheritanceTest(_Polymorphic): "managers.manager_name AS managers_manager_name " "FROM people JOIN managers " "ON people.person_id = managers.person_id " - "ORDER BY managers.person_id" + "ORDER BY managers.person_id", ) def test_load_only_subclass_bound(self): s = Session() - q = s.query(Manager).order_by(Manager.person_id).\ - options(Load(Manager).load_only("status", "manager_name")) + q = ( + s.query(Manager) + .order_by(Manager.person_id) + .options(Load(Manager).load_only("status", "manager_name")) + ) self.assert_compile( q, "SELECT managers.person_id AS managers_person_id, " @@ -948,13 +1261,16 @@ class InheritanceTest(_Polymorphic): "managers.manager_name AS managers_manager_name " "FROM people JOIN managers " "ON people.person_id = managers.person_id " - "ORDER BY managers.person_id" + "ORDER BY managers.person_id", ) def test_load_only_subclass_and_superclass(self): s = Session() - q = s.query(Boss).order_by(Person.person_id).\ - options(load_only("status", "manager_name")) + q = ( + s.query(Boss) + .order_by(Person.person_id) + .options(load_only("status", "manager_name")) + ) self.assert_compile( q, "SELECT managers.person_id AS managers_person_id, " @@ -964,13 +1280,16 @@ class InheritanceTest(_Polymorphic): "managers.manager_name AS managers_manager_name " "FROM people JOIN managers " "ON people.person_id = managers.person_id JOIN boss " - "ON 
managers.person_id = boss.boss_id ORDER BY people.person_id" + "ON managers.person_id = boss.boss_id ORDER BY people.person_id", ) def test_load_only_subclass_and_superclass_bound(self): s = Session() - q = s.query(Boss).order_by(Person.person_id).\ - options(Load(Boss).load_only("status", "manager_name")) + q = ( + s.query(Boss) + .order_by(Person.person_id) + .options(Load(Boss).load_only("status", "manager_name")) + ) self.assert_compile( q, "SELECT managers.person_id AS managers_person_id, " @@ -980,14 +1299,17 @@ class InheritanceTest(_Polymorphic): "managers.manager_name AS managers_manager_name " "FROM people JOIN managers " "ON people.person_id = managers.person_id JOIN boss " - "ON managers.person_id = boss.boss_id ORDER BY people.person_id" + "ON managers.person_id = boss.boss_id ORDER BY people.person_id", ) def test_load_only_alias_subclass(self): s = Session() m1 = aliased(Manager, flat=True) - q = s.query(m1).order_by(m1.person_id).\ - options(load_only("status", "manager_name")) + q = ( + s.query(m1) + .order_by(m1.person_id) + .options(load_only("status", "manager_name")) + ) self.assert_compile( q, "SELECT managers_1.person_id AS managers_1_person_id, " @@ -997,14 +1319,17 @@ class InheritanceTest(_Polymorphic): "managers_1.manager_name AS managers_1_manager_name " "FROM people AS people_1 JOIN managers AS " "managers_1 ON people_1.person_id = managers_1.person_id " - "ORDER BY managers_1.person_id" + "ORDER BY managers_1.person_id", ) def test_load_only_alias_subclass_bound(self): s = Session() m1 = aliased(Manager, flat=True) - q = s.query(m1).order_by(m1.person_id).\ - options(Load(m1).load_only("status", "manager_name")) + q = ( + s.query(m1) + .order_by(m1.person_id) + .options(Load(m1).load_only("status", "manager_name")) + ) self.assert_compile( q, "SELECT managers_1.person_id AS managers_1_person_id, " @@ -1014,15 +1339,20 @@ class InheritanceTest(_Polymorphic): "managers_1.manager_name AS managers_1_manager_name " "FROM people AS people_1 JOIN managers AS " "managers_1 ON people_1.person_id = managers_1.person_id " - "ORDER BY managers_1.person_id" + "ORDER BY managers_1.person_id", ) def test_load_only_subclass_from_relationship_polymorphic(self): s = Session() wp = with_polymorphic(Person, [Manager], flat=True) - q = s.query(Company).join(Company.employees.of_type(wp)).options( - contains_eager(Company.employees.of_type(wp)). - load_only(wp.Manager.status, wp.Manager.manager_name) + q = ( + s.query(Company) + .join(Company.employees.of_type(wp)) + .options( + contains_eager(Company.employees.of_type(wp)).load_only( + wp.Manager.status, wp.Manager.manager_name + ) + ) ) self.assert_compile( q, @@ -1036,15 +1366,20 @@ class InheritanceTest(_Polymorphic): "FROM companies JOIN (people AS people_1 LEFT OUTER JOIN " "managers AS managers_1 ON people_1.person_id = " "managers_1.person_id) ON companies.company_id = " - "people_1.company_id" + "people_1.company_id", ) def test_load_only_subclass_from_relationship_polymorphic_bound(self): s = Session() wp = with_polymorphic(Person, [Manager], flat=True) - q = s.query(Company).join(Company.employees.of_type(wp)).options( - Load(Company).contains_eager(Company.employees.of_type(wp)). 
- load_only(wp.Manager.status, wp.Manager.manager_name) + q = ( + s.query(Company) + .join(Company.employees.of_type(wp)) + .options( + Load(Company) + .contains_eager(Company.employees.of_type(wp)) + .load_only(wp.Manager.status, wp.Manager.manager_name) + ) ) self.assert_compile( q, @@ -1058,14 +1393,19 @@ class InheritanceTest(_Polymorphic): "FROM companies JOIN (people AS people_1 LEFT OUTER JOIN " "managers AS managers_1 ON people_1.person_id = " "managers_1.person_id) ON companies.company_id = " - "people_1.company_id" + "people_1.company_id", ) def test_load_only_subclass_from_relationship(self): s = Session() - q = s.query(Company).join(Company.managers).options( - contains_eager(Company.managers). - load_only("status", "manager_name") + q = ( + s.query(Company) + .join(Company.managers) + .options( + contains_eager(Company.managers).load_only( + "status", "manager_name" + ) + ) ) self.assert_compile( q, @@ -1077,14 +1417,19 @@ class InheritanceTest(_Polymorphic): "managers.status AS managers_status, " "managers.manager_name AS managers_manager_name " "FROM companies JOIN (people JOIN managers ON people.person_id = " - "managers.person_id) ON companies.company_id = people.company_id" + "managers.person_id) ON companies.company_id = people.company_id", ) def test_load_only_subclass_from_relationship_bound(self): s = Session() - q = s.query(Company).join(Company.managers).options( - Load(Company).contains_eager(Company.managers). - load_only("status", "manager_name") + q = ( + s.query(Company) + .join(Company.managers) + .options( + Load(Company) + .contains_eager(Company.managers) + .load_only("status", "manager_name") + ) ) self.assert_compile( q, @@ -1096,7 +1441,7 @@ class InheritanceTest(_Polymorphic): "managers.status AS managers_status, " "managers.manager_name AS managers_manager_name " "FROM companies JOIN (people JOIN managers ON people.person_id = " - "managers.person_id) ON companies.company_id = people.company_id" + "managers.person_id) ON companies.company_id = people.company_id", ) def test_defer_on_wildcard_subclass(self): @@ -1106,13 +1451,16 @@ class InheritanceTest(_Polymorphic): # TODO: what is ".*"? this is not documented anywhere, how did this # get implemented without docs ? 
see #4390 s = Session() - q = s.query(Manager).order_by(Person.person_id).options( - defer(".*"), undefer("status")) + q = ( + s.query(Manager) + .order_by(Person.person_id) + .options(defer(".*"), undefer("status")) + ) self.assert_compile( q, "SELECT managers.status AS managers_status " "FROM people JOIN managers ON " - "people.person_id = managers.person_id ORDER BY people.person_id" + "people.person_id = managers.person_id ORDER BY people.person_id", ) # note this doesn't apply to "bound" loaders since they don't seem @@ -1130,13 +1478,16 @@ class InheritanceTest(_Polymorphic): "managers.manager_name AS managers_manager_name " "FROM people JOIN managers " "ON people.person_id = managers.person_id " - "ORDER BY people.person_id" + "ORDER BY people.person_id", ) def test_defer_super_name_on_subclass_bound(self): s = Session() - q = s.query(Manager).order_by(Person.person_id).options( - Load(Manager).defer("name")) + q = ( + s.query(Manager) + .order_by(Person.person_id) + .options(Load(Manager).defer("name")) + ) self.assert_compile( q, "SELECT managers.person_id AS managers_person_id, " @@ -1146,18 +1497,17 @@ class InheritanceTest(_Polymorphic): "managers.manager_name AS managers_manager_name " "FROM people JOIN managers " "ON people.person_id = managers.person_id " - "ORDER BY people.person_id" + "ORDER BY people.person_id", ) - class WithExpressionTest(fixtures.DeclarativeMappedTest): @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class A(fixtures.ComparableEntity, Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) x = Column(Integer) y = Column(Integer) @@ -1167,9 +1517,9 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): bs = relationship("B", order_by="B.id") class B(fixtures.ComparableEntity, Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) + a_id = Column(ForeignKey("a.id")) p = Column(Integer) q = Column(Integer) @@ -1180,12 +1530,14 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): A, B = cls.classes("A", "B") s = Session() - s.add_all([ - A(id=1, x=1, y=2, bs=[B(id=1, p=1, q=2), B(id=2, p=4, q=8)]), - A(id=2, x=2, y=3), - A(id=3, x=5, y=10, bs=[B(id=3, p=5, q=0)]), - A(id=4, x=2, y=10, bs=[B(id=4, p=19, q=8), B(id=5, p=5, q=5)]), - ]) + s.add_all( + [ + A(id=1, x=1, y=2, bs=[B(id=1, p=1, q=2), B(id=2, p=4, q=8)]), + A(id=2, x=2, y=3), + A(id=3, x=5, y=10, bs=[B(id=3, p=5, q=0)]), + A(id=4, x=2, y=10, bs=[B(id=4, p=19, q=8), B(id=5, p=5, q=5)]), + ] + ) s.commit() @@ -1193,17 +1545,15 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): A = self.classes.A s = Session() - a1 = s.query(A).options( - with_expression(A.my_expr, A.x + A.y)).filter(A.x > 1).\ - order_by(A.id) - - eq_( - a1.all(), - [ - A(my_expr=5), A(my_expr=15), A(my_expr=12) - ] + a1 = ( + s.query(A) + .options(with_expression(A.my_expr, A.x + A.y)) + .filter(A.x > 1) + .order_by(A.id) ) + eq_(a1.all(), [A(my_expr=5), A(my_expr=15), A(my_expr=12)]) + def test_reuse_expr(self): A = self.classes.A @@ -1213,30 +1563,29 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): # but that means Query or Core has to post-modify the statement # after construction. 
expr = A.x + A.y - a1 = s.query(A).options( - with_expression(A.my_expr, expr)).filter(expr > 10).\ - order_by(expr) - - eq_( - a1.all(), - [A(my_expr=12), A(my_expr=15)] + a1 = ( + s.query(A) + .options(with_expression(A.my_expr, expr)) + .filter(expr > 10) + .order_by(expr) ) + eq_(a1.all(), [A(my_expr=12), A(my_expr=15)]) + def test_in_joinedload(self): A, B = self.classes("A", "B") s = Session() - q = s.query(A).options( - joinedload(A.bs).with_expression(B.b_expr, B.p * A.x) - ).filter(A.id.in_([3, 4])).order_by(A.id) + q = ( + s.query(A) + .options(joinedload(A.bs).with_expression(B.b_expr, B.p * A.x)) + .filter(A.id.in_([3, 4])) + .order_by(A.id) + ) eq_( - q.all(), - [ - A(bs=[B(b_expr=25)]), - A(bs=[B(b_expr=38), B(b_expr=10)]) - ] + q.all(), [A(bs=[B(b_expr=25)]), A(bs=[B(b_expr=38), B(b_expr=10)])] ) def test_no_sql_not_set_up(self): @@ -1254,22 +1603,28 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): A = self.classes.A s = Session() - q = s.query(A).options( - with_expression(A.my_expr, A.x + A.y)).filter(A.x > 1).\ - order_by(A.id) + q = ( + s.query(A) + .options(with_expression(A.my_expr, A.x + A.y)) + .filter(A.x > 1) + .order_by(A.id) + ) a1 = q.first() eq_(a1.my_expr, 5) - s.expire(a1, ['my_expr']) + s.expire(a1, ["my_expr"]) eq_(a1.my_expr, None) # comes back - q = s.query(A).options( - with_expression(A.my_expr, A.x + A.y)).filter(A.x > 1).\ - order_by(A.id) + q = ( + s.query(A) + .options(with_expression(A.my_expr, A.x + A.y)) + .filter(A.x > 1) + .order_by(A.id) + ) q.first() eq_(a1.my_expr, 5) @@ -1277,9 +1632,12 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): A = self.classes.A s = Session() - q = s.query(A).options( - with_expression(A.my_expr, A.x + A.y)).filter(A.x > 1).\ - order_by(A.id) + q = ( + s.query(A) + .options(with_expression(A.my_expr, A.x + A.y)) + .filter(A.x > 1) + .order_by(A.id) + ) a1 = q.first() @@ -1290,9 +1648,11 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): eq_(a1.my_expr, None) # comes back - q = s.query(A).options( - with_expression(A.my_expr, A.x + A.y)).filter(A.x > 1).\ - order_by(A.id) + q = ( + s.query(A) + .options(with_expression(A.my_expr, A.x + A.y)) + .filter(A.x > 1) + .order_by(A.id) + ) q.first() eq_(a1.my_expr, 5) - diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index e916e8985b..fa051e8840 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -49,21 +49,27 @@ class QueryAlternativesTest(fixtures.MappedTest): ''' - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('users_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(64))) - - Table('addresses_table', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('users_table.id')), - Column('email_address', String(128)), - Column('purpose', String(16)), - Column('bounces', Integer, default=0)) + Table( + "users_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(64)), + ) + + Table( + "addresses_table", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", Integer, ForeignKey("users_table.id")), + Column("email_address", String(128)), + Column("purpose", String(16)), + Column("bounces", Integer, default=0), + ) @classmethod def setup_classes(cls): @@ -75,33 +81,38 @@ class QueryAlternativesTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - addresses_table, User, users_table, 
Address = \ - (cls.tables.addresses_table, - cls.classes.User, - cls.tables.users_table, - cls.classes.Address) - - mapper(User, users_table, properties=dict( - addresses=relationship(Address, backref='user'), - )) + addresses_table, User, users_table, Address = ( + cls.tables.addresses_table, + cls.classes.User, + cls.tables.users_table, + cls.classes.Address, + ) + + mapper( + User, + users_table, + properties=dict(addresses=relationship(Address, backref="user")), + ) mapper(Address, addresses_table) @classmethod def fixtures(cls): return dict( users_table=( - ('id', 'name'), - (1, 'jack'), - (2, 'ed'), - (3, 'fred'), - (4, 'chuck')), - + ("id", "name"), + (1, "jack"), + (2, "ed"), + (3, "fred"), + (4, "chuck"), + ), addresses_table=( - ('id', 'user_id', 'email_address', 'purpose', 'bounces'), - (1, 1, 'jack@jack.home', 'Personal', 0), - (2, 1, 'jack@jack.bizz', 'Work', 1), - (3, 2, 'ed@foo.bar', 'Personal', 0), - (4, 3, 'fred@the.fred', 'Personal', 10))) + ("id", "user_id", "email_address", "purpose", "bounces"), + (1, 1, "jack@jack.home", "Personal", 0), + (2, 1, "jack@jack.bizz", "Work", 1), + (3, 2, "ed@foo.bar", "Personal", 0), + (4, 3, "fred@the.fred", "Personal", 10), + ), + ) ###################################################################### @@ -115,6 +126,7 @@ class QueryAlternativesTest(fixtures.MappedTest): Address = self.classes.Address from sqlalchemy.orm.query import Query + cache = {} class MyQuery(Query): @@ -232,11 +244,14 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - num = session.query(Address).filter_by(purpose='Personal').count() + num = session.query(Address).filter_by(purpose="Personal").count() assert num == 3, num - num = (session.query(User).join('addresses'). - filter(Address.purpose == 'Personal')).count() + num = ( + session.query(User) + .join("addresses") + .filter(Address.purpose == "Personal") + ).count() assert num == 3, num def test_count_whereclause(self): @@ -281,17 +296,24 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - user = session.query(User).filter_by(name='ed').first() - assert user.name == 'ed' - - user = (session.query(User).join('addresses'). 
- filter(Address.email_address == 'fred@the.fred')).first() - assert user.name == 'fred' + user = session.query(User).filter_by(name="ed").first() + assert user.name == "ed" - user = session.query(User).filter( - User.addresses.any(Address.email_address == 'fred@the.fred') + user = ( + session.query(User) + .join("addresses") + .filter(Address.email_address == "fred@the.fred") ).first() - assert user.name == 'fred' + assert user.name == "fred" + + user = ( + session.query(User) + .filter( + User.addresses.any(Address.email_address == "fred@the.fred") + ) + .first() + ) + assert user.name == "fred" def test_instances_entities(self): """Query.instances(cursor, *mappers_or_columns, **kwargs) @@ -301,11 +323,12 @@ class QueryAlternativesTest(fixtures.MappedTest): """ - addresses_table, User, users_table, Address = \ - (self.tables.addresses_table, - self.classes.User, - self.tables.users_table, - self.classes.Address) + addresses_table, User, users_table, Address = ( + self.tables.addresses_table, + self.classes.User, + self.tables.users_table, + self.classes.Address, + ) session = create_session() @@ -396,18 +419,26 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - users = session.query(User).filter_by(name='fred').all() + users = session.query(User).filter_by(name="fred").all() assert len(users) == 1 - users = session.query(User).filter(User.name == 'fred').all() + users = session.query(User).filter(User.name == "fred").all() assert len(users) == 1 - users = (session.query(User).join('addresses'). - filter_by(email_address='fred@the.fred')).all() + users = ( + session.query(User) + .join("addresses") + .filter_by(email_address="fred@the.fred") + ).all() assert len(users) == 1 - users = session.query(User).filter(User.addresses.any( - Address.email_address == 'fred@the.fred')).all() + users = ( + session.query(User) + .filter( + User.addresses.any(Address.email_address == "fred@the.fred") + ) + .all() + ) assert len(users) == 1 def test_selectfirst(self): @@ -442,17 +473,22 @@ class QueryAlternativesTest(fixtures.MappedTest): onebounce = session.query(Address).filter_by(bounces=1).first() assert onebounce.bounces == 1 - onebounce_user = (session.query(User).join('addresses'). - filter_by(bounces=1)).first() - assert onebounce_user.name == 'jack' + onebounce_user = ( + session.query(User).join("addresses").filter_by(bounces=1) + ).first() + assert onebounce_user.name == "jack" - onebounce_user = (session.query(User).join('addresses'). 
- filter(Address.bounces == 1)).first() - assert onebounce_user.name == 'jack' + onebounce_user = ( + session.query(User).join("addresses").filter(Address.bounces == 1) + ).first() + assert onebounce_user.name == "jack" - onebounce_user = session.query(User).filter(User.addresses.any( - Address.bounces == 1)).first() - assert onebounce_user.name == 'jack' + onebounce_user = ( + session.query(User) + .filter(User.addresses.any(Address.bounces == 1)) + .first() + ) + assert onebounce_user.name == "jack" def test_selectone(self): """Query.selectone(arg=None, **kwargs) @@ -465,7 +501,7 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - ed = session.query(User).filter(User.name == 'jack').one() + ed = session.query(User).filter(User.name == "jack").one() def test_selectone_by(self): """Query.selectone_by @@ -481,15 +517,22 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - ed = session.query(User).filter_by(name='jack').one() + ed = session.query(User).filter_by(name="jack").one() - ed = session.query(User).filter(User.name == 'jack').one() + ed = session.query(User).filter(User.name == "jack").one() - ed = session.query(User).join('addresses').filter( - Address.email_address == 'ed@foo.bar').one() + ed = ( + session.query(User) + .join("addresses") + .filter(Address.email_address == "ed@foo.bar") + .one() + ) - ed = session.query(User).filter(User.addresses.any( - Address.email_address == 'ed@foo.bar')).one() + ed = ( + session.query(User) + .filter(User.addresses.any(Address.email_address == "ed@foo.bar")) + .one() + ) def test_select_statement(self): """Query.select_statement(statement, **params) @@ -516,8 +559,11 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - users = (session.query(User). 
- from_statement(text('SELECT * FROM users_table'))).all() + users = ( + session.query(User).from_statement( + text("SELECT * FROM users_table") + ) + ).all() assert len(users) == 4 def test_select_whereclause(self): @@ -533,8 +579,8 @@ class QueryAlternativesTest(fixtures.MappedTest): session = create_session() - users = session.query(User).filter(User.name == 'ed').all() - assert len(users) == 1 and users[0].name == 'ed' + users = session.query(User).filter(User.name == "ed").all() + assert len(users) == 1 and users[0].name == "ed" users = session.query(User).filter(text("name='ed'")).all() - assert len(users) == 1 and users[0].name == 'ed' + assert len(users) == 1 and users[0].name == "ed" diff --git a/test/orm/test_descriptor.py b/test/orm/test_descriptor.py index f5bc629e9d..798d2af544 100644 --- a/test/orm/test_descriptor.py +++ b/test/orm/test_descriptor.py @@ -10,8 +10,9 @@ from sqlalchemy.testing import eq_ class TestDescriptor(descriptor_props.DescriptorProperty): - def __init__(self, cls, key, descriptor=None, doc=None, - comparator_factory=None): + def __init__( + self, cls, key, descriptor=None, doc=None, comparator_factory=None + ): self.parent = cls.__mapper__ self.key = key self.doc = doc @@ -27,7 +28,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): Base = declarative_base() class Foo(Base): - __tablename__ = 'foo' + __tablename__ = "foo" id = Column(Integer, primary_key=True) return Foo @@ -35,7 +36,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): def test_fixture(self): Foo = self._fixture() - d = TestDescriptor(Foo, 'foo') + d = TestDescriptor(Foo, "foo") d.instrument_class(Foo.__mapper__) assert Foo.foo @@ -45,7 +46,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): prop = property(lambda self: None) Foo.foo = prop - d = TestDescriptor(Foo, 'foo') + d = TestDescriptor(Foo, "foo") d.instrument_class(Foo.__mapper__) assert Foo().foo is None @@ -55,7 +56,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): Foo = self._fixture() class myprop(property): - attr = 'bar' + attr = "bar" def method1(self): return "method1" @@ -63,19 +64,19 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): prop = myprop(lambda self: None) Foo.foo = prop - d = TestDescriptor(Foo, 'foo') + d = TestDescriptor(Foo, "foo") d.instrument_class(Foo.__mapper__) assert Foo().foo is None assert Foo.foo is not prop - assert Foo.foo.attr == 'bar' - assert Foo.foo.method1() == 'method1' + assert Foo.foo.attr == "bar" + assert Foo.foo.method1() == "method1" def test_comparator(self): class Comparator(PropComparator): __hash__ = None - attr = 'bar' + attr = "bar" def method1(self): return "method1" @@ -84,46 +85,41 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): return "method2" def __getitem__(self, key): - return 'value' + return "value" def __eq__(self, other): - return column('foo') == func.upper(other) + return column("foo") == func.upper(other) Foo = self._fixture() - d = TestDescriptor(Foo, 'foo', comparator_factory=Comparator) + d = TestDescriptor(Foo, "foo", comparator_factory=Comparator) d.instrument_class(Foo.__mapper__) eq_(Foo.foo.method1(), "method1") - eq_(Foo.foo.method2('x'), "method2") - assert Foo.foo.attr == 'bar' - assert Foo.foo['bar'] == 'value' - eq_( - (Foo.foo == 'bar').__str__(), - "foo = upper(:upper_1)" - ) + eq_(Foo.foo.method2("x"), "method2") + assert Foo.foo.attr == "bar" + assert Foo.foo["bar"] == "value" + eq_((Foo.foo == "bar").__str__(), "foo = upper(:upper_1)") def test_aliased_comparator(self): class 
Comparator(ColumnProperty.Comparator): __hash__ = None def __eq__(self, other): - return func.foobar(self.__clause_element__()) ==\ - func.foobar(other) + return func.foobar(self.__clause_element__()) == func.foobar( + other + ) Foo = self._fixture() - Foo._name = Column('name', String) + Foo._name = Column("name", String) def comparator_factory(self, mapper): - prop = mapper._props['_name'] + prop = mapper._props["_name"] return Comparator(prop, mapper) - d = TestDescriptor(Foo, 'foo', comparator_factory=comparator_factory) + d = TestDescriptor(Foo, "foo", comparator_factory=comparator_factory) d.instrument_class(Foo.__mapper__) + eq_(str(Foo.foo == "ed"), "foobar(foo.name) = foobar(:foobar_1)") eq_( - str(Foo.foo == 'ed'), - "foobar(foo.name) = foobar(:foobar_1)" - ) - eq_( - str(aliased(Foo).foo == 'ed'), - "foobar(foo_1.name) = foobar(:foobar_1)" + str(aliased(Foo).foo == "ed"), + "foobar(foo_1.name) = foobar(:foobar_1)", ) diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index 5dfb3fde51..09252dda32 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -1,54 +1,85 @@ from sqlalchemy import testing, desc, select, func, exc, cast, Integer from sqlalchemy.orm import ( - mapper, relationship, create_session, Query, attributes, exc as orm_exc, - Session, backref, configure_mappers) + mapper, + relationship, + create_session, + Query, + attributes, + exc as orm_exc, + Session, + backref, + configure_mappers, +) from sqlalchemy.orm.dynamic import AppenderMixin from sqlalchemy.testing import ( - AssertsCompiledSQL, assert_raises_message, assert_raises, eq_, is_) + AssertsCompiledSQL, + assert_raises_message, + assert_raises, + eq_, + is_, +) from test.orm import _fixtures from sqlalchemy.testing.assertsql import CompiledSQL class _DynamicFixture(object): def _user_address_fixture(self, addresses_args={}): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, lazy="dynamic", **addresses_args)}) + User, + users, + properties={ + "addresses": relationship( + Address, lazy="dynamic", **addresses_args + ) + }, + ) mapper(Address, addresses) return User, Address def _order_item_fixture(self, items_args={}): - items, Order, orders, order_items, Item = (self.tables.items, - self.classes.Order, - self.tables.orders, - self.tables.order_items, - self.classes.Item) + items, Order, orders, order_items, Item = ( + self.tables.items, + self.classes.Order, + self.tables.orders, + self.tables.order_items, + self.classes.Item, + ) mapper( - Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, lazy="dynamic", - **items_args)}) + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, lazy="dynamic", **items_args + ) + }, + ) mapper(Item, items) return Order, Item class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): - def test_basic(self): User, Address = self._user_address_fixture() sess = create_session() q = sess.query(User) - eq_([User(id=7, - addresses=[Address(id=1, email_address='jack@bean.com')])], - q.filter(User.id == 7).all()) + eq_( + [ + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ) + ], + q.filter(User.id == 7).all(), + ) eq_(self.static.user_address_result, q.all()) 
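
A minimal, self-contained sketch of the "dynamic" relationship pattern these
tests exercise, for readers outside the test suite; the names (DynUser,
DynAddress) and the in-memory SQLite engine are illustrative assumptions and
are not part of this commit:

    # illustrative only -- not part of the reformatted test suite
    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class DynUser(Base):
        __tablename__ = "dyn_users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        # lazy="dynamic" yields an AppenderQuery rather than a loaded
        # list, so the collection can be filtered before emitting SQL
        addresses = relationship("DynAddress", lazy="dynamic")

    class DynAddress(Base):
        __tablename__ = "dyn_addresses"
        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("dyn_users.id"))
        email_address = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(
        DynUser(
            name="jack",
            addresses=[DynAddress(email_address="jack@bean.com")],
        )
    )
    session.commit()

    jack = session.query(DynUser).first()
    # count() and filter_by() are emitted as SQL against the collection
    assert jack.addresses.count() == 1
    assert (
        jack.addresses.filter_by(email_address="jack@bean.com").one().user_id
        == jack.id
    )
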
def test_statement(self): @@ -65,7 +96,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): "SELECT addresses.id, addresses.user_id, addresses.email_address " "FROM " "addresses WHERE :param_1 = addresses.user_id", - use_default_dialect=True + use_default_dialect=True, ) def test_detached_raise(self): @@ -76,35 +107,40 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): assert_raises( orm_exc.DetachedInstanceError, u.addresses.filter_by, - email_address='e' + email_address="e", ) def test_no_uselist_false(self): User, Address = self._user_address_fixture( - addresses_args={"uselist": False}) + addresses_args={"uselist": False} + ) assert_raises_message( exc.InvalidRequestError, "On relationship User.addresses, 'dynamic' loaders cannot be " "used with many-to-one/one-to-one relationships and/or " "uselist=False.", - configure_mappers + configure_mappers, ) def test_no_m2o(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - Address, addresses, properties={ - 'user': relationship(User, lazy='dynamic')}) + Address, + addresses, + properties={"user": relationship(User, lazy="dynamic")}, + ) mapper(User, users) assert_raises_message( exc.InvalidRequestError, "On relationship Address.user, 'dynamic' loaders cannot be " "used with many-to-one/one-to-one relationships and/or " "uselist=False.", - configure_mappers + configure_mappers, ) def test_order_by(self): @@ -114,46 +150,49 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): eq_( list(u.addresses.order_by(desc(Address.email_address))), [ - Address(email_address='ed@wood.com'), - Address(email_address='ed@lala.com'), - Address(email_address='ed@bettyboop.com') - ] + Address(email_address="ed@wood.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@bettyboop.com"), + ], ) def test_configured_order_by(self): addresses = self.tables.addresses User, Address = self._user_address_fixture( - addresses_args={"order_by": addresses.c.email_address.desc()}) + addresses_args={"order_by": addresses.c.email_address.desc()} + ) sess = create_session() u = sess.query(User).get(8) eq_( list(u.addresses), [ - Address(email_address='ed@wood.com'), - Address(email_address='ed@lala.com'), - Address(email_address='ed@bettyboop.com') - ] + Address(email_address="ed@wood.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@bettyboop.com"), + ], ) # test cancellation of None, replacement with something else eq_( list(u.addresses.order_by(None).order_by(Address.email_address)), [ - Address(email_address='ed@bettyboop.com'), - Address(email_address='ed@lala.com'), - Address(email_address='ed@wood.com') - ] + Address(email_address="ed@bettyboop.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@wood.com"), + ], ) # test cancellation of None, replacement with nothing eq_( set(u.addresses.order_by(None)), - set([ - Address(email_address='ed@bettyboop.com'), - Address(email_address='ed@lala.com'), - Address(email_address='ed@wood.com') - ]) + set( + [ + Address(email_address="ed@bettyboop.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@wood.com"), + ] + ), ) def test_count(self): @@ -163,15 +202,22 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, 
AssertsCompiledSQL): eq_(u.addresses.count(), 1) def test_dynamic_on_backref(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(Address, addresses, properties={ - 'user': relationship(User, - backref=backref('addresses', lazy='dynamic')) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + Address, + addresses, + properties={ + "user": relationship( + User, backref=backref("addresses", lazy="dynamic") + ) + }, + ) mapper(User, users) sess = create_session() @@ -179,6 +225,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def go(): ad.user = None + self.assert_sql_count(testing.db, go, 0) sess.flush() u = sess.query(User).get(7) @@ -197,8 +244,14 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): q.filter(User.id == 7).all(), [ User( - id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])]) + id=7, + addresses=[ + Address(id=1, email_address="jack@bean.com") + ], + ) + ], + ) + self.assert_sql_count(testing.db, go, 2) def test_no_populate(self): @@ -207,12 +260,16 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): assert_raises_message( NotImplementedError, "Dynamic attributes don't support collection population.", - attributes.set_committed_value, u1, 'addresses', [] + attributes.set_committed_value, + u1, + "addresses", + [], ) def test_m2m(self): Order, Item = self._order_item_fixture( - items_args={"backref": backref("orders", lazy="dynamic")}) + items_args={"backref": backref("orders", lazy="dynamic")} + ) sess = create_session() o1 = Order(id=15, description="order 10") @@ -225,21 +282,32 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): assert i1 in o1.items.all() @testing.exclude( - 'mysql', 'between', ((5, 1, 49), (5, 1, 52)), - 'https://bugs.launchpad.net/ubuntu/+source/mysql-5.1/+bug/706988') + "mysql", + "between", + ((5, 1, 49), (5, 1, 52)), + "https://bugs.launchpad.net/ubuntu/+source/mysql-5.1/+bug/706988", + ) def test_association_nonaliased(self): - items, Order, orders, order_items, Item = (self.tables.items, - self.classes.Order, - self.tables.orders, - self.tables.order_items, - self.classes.Item) - - mapper(Order, orders, properties={ - 'items': relationship(Item, - secondary=order_items, - lazy="dynamic", - order_by=order_items.c.item_id) - }) + items, Order, orders, order_items, Item = ( + self.tables.items, + self.classes.Order, + self.tables.orders, + self.tables.order_items, + self.classes.Item, + ) + + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="dynamic", + order_by=order_items.c.item_id, + ) + }, + ) mapper(Item, items) sess = create_session() @@ -252,31 +320,32 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): " order_items WHERE :param_1 = order_items.order_id AND " "items.id = order_items.item_id" " ORDER BY order_items.item_id", - use_default_dialect=True + use_default_dialect=True, ) # filter criterion against the secondary table # works - eq_( - o.items.filter(order_items.c.item_id == 2).all(), - [Item(id=2)] - ) + eq_(o.items.filter(order_items.c.item_id == 2).all(), [Item(id=2)]) def test_secondary_as_join(self): # test [ticket:4349] User, users = self.classes.User, self.tables.users - items, orders, order_items, Item = 
(self.tables.items, - self.tables.orders, - self.tables.order_items, - self.classes.Item) - - mapper(User, users, properties={ - 'items': relationship( - Item, - secondary=order_items.join(orders), - lazy="dynamic" - ) - }) + items, orders, order_items, Item = ( + self.tables.items, + self.tables.orders, + self.tables.order_items, + self.classes.Item, + ) + + mapper( + User, + users, + properties={ + "items": relationship( + Item, secondary=order_items.join(orders), lazy="dynamic" + ) + }, + ) mapper(Item, items) sess = create_session() @@ -290,7 +359,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): "ON orders.id = order_items.order_id " "WHERE :param_1 = orders.user_id " "AND items.id = order_items.item_id", - use_default_dialect=True + use_default_dialect=True, ) def test_secondary_doesnt_interfere_w_join_to_fromlist(self): @@ -304,21 +373,32 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): items, order_items, Item = ( self.tables.items, self.tables.order_items, - self.classes.Item) + self.classes.Item, + ) item_keywords = self.tables.item_keywords class ItemKeyword(object): pass - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='dynamic'), - }) mapper( - ItemKeyword, item_keywords, - primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id]) - mapper(Item, items, properties={ - 'item_keywords': relationship(ItemKeyword) - }) + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, lazy="dynamic" + ) + }, + ) + mapper( + ItemKeyword, + item_keywords, + primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], + ) + mapper( + Item, + items, + properties={"item_keywords": relationship(ItemKeyword)}, + ) sess = create_session() order = sess.query(Order).first() @@ -331,7 +411,7 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): "JOIN item_keywords ON items.id = item_keywords.item_id " "WHERE :param_1 = order_items.order_id " "AND items.id = order_items.item_id", - use_default_dialect=True + use_default_dialect=True, ) def test_transient_count(self): @@ -349,8 +429,10 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): def test_custom_query(self): class MyQuery(Query): pass + User, Address = self._user_address_fixture( - addresses_args={"query_class": MyQuery}) + addresses_args={"query_class": MyQuery} + ) sess = create_session() u = User() @@ -359,14 +441,14 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): col = u.addresses assert isinstance(col, Query) assert isinstance(col, MyQuery) - assert hasattr(col, 'append') - eq_(type(col).__name__, 'AppenderMyQuery') + assert hasattr(col, "append") + eq_(type(col).__name__, "AppenderMyQuery") q = col.limit(1) assert isinstance(q, Query) assert isinstance(q, MyQuery) - assert not hasattr(q, 'append') - eq_(type(q).__name__, 'MyQuery') + assert not hasattr(q, "append") + eq_(type(q).__name__, "MyQuery") def test_custom_query_with_custom_mixin(self): class MyAppenderMixin(AppenderMixin): @@ -384,7 +466,8 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): query_class = MyQuery User, Address = self._user_address_fixture( - addresses_args={"query_class": MyAppenderQuery}) + addresses_args={"query_class": MyAppenderQuery} + ) sess = create_session() u = User() @@ -393,21 +476,21 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): col = 
u.addresses assert isinstance(col, Query) assert isinstance(col, MyQuery) - assert hasattr(col, 'append') - assert hasattr(col, 'add') - eq_(type(col).__name__, 'MyAppenderQuery') + assert hasattr(col, "append") + assert hasattr(col, "add") + eq_(type(col).__name__, "MyAppenderQuery") q = col.limit(1) assert isinstance(q, Query) assert isinstance(q, MyQuery) - assert not hasattr(q, 'append') - assert not hasattr(q, 'add') - eq_(type(q).__name__, 'MyQuery') + assert not hasattr(q, "append") + assert not hasattr(q, "add") + eq_(type(q).__name__, "MyQuery") class UOWTest( - _DynamicFixture, _fixtures.FixtureTest, - testing.AssertsExecutionResults): + _DynamicFixture, _fixtures.FixtureTest, testing.AssertsExecutionResults +): run_inserts = None @@ -416,17 +499,19 @@ class UOWTest( User, Address = self._user_address_fixture() sess = create_session() - u1 = User(name='jack') - a1 = Address(email_address='foo') + u1 = User(name="jack") + a1 = Address(email_address="foo") sess.add_all([u1, a1]) sess.flush() eq_( testing.db.scalar( - select( - [func.count(cast(1, Integer))]). - where(addresses.c.user_id != None)), # noqa - 0) + select([func.count(cast(1, Integer))]).where( + addresses.c.user_id != None + ) + ), # noqa + 0, + ) u1 = sess.query(User).get(u1.id) u1.addresses.append(a1) sess.flush() @@ -435,17 +520,18 @@ class UOWTest( testing.db.execute( select([addresses]).where(addresses.c.user_id != None) # noqa ).fetchall(), - [(a1.id, u1.id, 'foo')] + [(a1.id, u1.id, "foo")], ) u1.addresses.remove(a1) sess.flush() eq_( testing.db.scalar( - select( - [func.count(cast(1, Integer))]). - where(addresses.c.user_id != None)), # noqa - 0 + select([func.count(cast(1, Integer))]).where( + addresses.c.user_id != None + ) + ), # noqa + 0, ) u1.addresses.append(a1) @@ -454,10 +540,10 @@ class UOWTest( testing.db.execute( select([addresses]).where(addresses.c.user_id != None) # noqa ).fetchall(), - [(a1.id, u1.id, 'foo')] + [(a1.id, u1.id, "foo")], ) - a2 = Address(email_address='bar') + a2 = Address(email_address="bar") u1.addresses.remove(a1) u1.addresses.append(a2) sess.flush() @@ -465,18 +551,19 @@ class UOWTest( testing.db.execute( select([addresses]).where(addresses.c.user_id != None) # noqa ).fetchall(), - [(a2.id, u1.id, 'bar')] + [(a2.id, u1.id, "bar")], ) def test_merge(self): addresses = self.tables.addresses User, Address = self._user_address_fixture( - addresses_args={"order_by": addresses.c.email_address}) + addresses_args={"order_by": addresses.c.email_address} + ) sess = create_session() - u1 = User(name='jack') - a1 = Address(email_address='a1') - a2 = Address(email_address='a2') - a3 = Address(email_address='a3') + u1 = User(name="jack") + a1 = Address(email_address="a1") + a2 = Address(email_address="a2") + a3 = Address(email_address="a3") u1.addresses.append(a2) u1.addresses.append(a3) @@ -484,42 +571,36 @@ class UOWTest( sess.add_all([u1, a1]) sess.flush() - u1 = User(id=u1.id, name='jack') + u1 = User(id=u1.id, name="jack") u1.addresses.append(a1) u1.addresses.append(a3) u1 = sess.merge(u1) - eq_(attributes.get_history(u1, 'addresses'), ( - [a1], - [a3], - [a2] - )) + eq_(attributes.get_history(u1, "addresses"), ([a1], [a3], [a2])) sess.flush() - eq_( - list(u1.addresses), - [a1, a3] - ) + eq_(list(u1.addresses), [a1, a3]) def test_hasattr(self): User, Address = self._user_address_fixture() - u1 = User(name='jack') + u1 = User(name="jack") - assert 'addresses' not in u1.__dict__ - u1.addresses = [Address(email_address='test')] - assert 'addresses' in u1.__dict__ + assert "addresses" not 
in u1.__dict__ + u1.addresses = [Address(email_address="test")] + assert "addresses" in u1.__dict__ def test_collection_set(self): addresses = self.tables.addresses User, Address = self._user_address_fixture( - addresses_args={"order_by": addresses.c.email_address}) + addresses_args={"order_by": addresses.c.email_address} + ) sess = create_session(autoflush=True, autocommit=False) - u1 = User(name='jack') - a1 = Address(email_address='a1') - a2 = Address(email_address='a2') - a3 = Address(email_address='a3') - a4 = Address(email_address='a4') + u1 = User(name="jack") + a1 = Address(email_address="a1") + a2 = Address(email_address="a2") + a3 = Address(email_address="a3") + a4 = Address(email_address="a4") sess.add(u1) u1.addresses = [a1, a3] @@ -544,7 +625,7 @@ class UOWTest( u1_id = u1.id sess.expire_all() - u1.addresses.append(Address(email_address='a2')) + u1.addresses.append(Address(email_address="a2")) self.assert_sql_execution( testing.db, @@ -552,12 +633,13 @@ class UOWTest( CompiledSQL( "SELECT users.id AS users_id, users.name AS users_name " "FROM users WHERE users.id = :param_1", - lambda ctx: [{"param_1": u1_id}]), + lambda ctx: [{"param_1": u1_id}], + ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", - lambda ctx: [{'email_address': 'a2', 'user_id': u1_id}] - ) + lambda ctx: [{"email_address": "a2", "user_id": u1_id}], + ), ) def test_noload_remove(self): @@ -567,7 +649,7 @@ class UOWTest( sess = Session() u1 = User(name="jack", addresses=[Address(email_address="a1")]) - a2 = Address(email_address='a2') + a2 = Address(email_address="a2") u1.addresses.append(a2) sess.add(u1) sess.commit() @@ -585,41 +667,40 @@ class UOWTest( "SELECT addresses.id AS addresses_id, addresses.email_address " "AS addresses_email_address FROM addresses " "WHERE addresses.id = :param_1", - lambda ctx: [{'param_1': a2_id}] + lambda ctx: [{"param_1": a2_id}], ), CompiledSQL( "UPDATE addresses SET user_id=:user_id WHERE addresses.id = " ":addresses_id", - lambda ctx: [{'addresses_id': a2_id, 'user_id': None}] + lambda ctx: [{"addresses_id": a2_id, "user_id": None}], ), CompiledSQL( "SELECT users.id AS users_id, users.name AS users_name " "FROM users WHERE users.id = :param_1", - lambda ctx: [{"param_1": u1_id}]), + lambda ctx: [{"param_1": u1_id}], + ), ) def test_rollback(self): User, Address = self._user_address_fixture() sess = create_session( - expire_on_commit=False, autocommit=False, autoflush=True) - u1 = User(name='jack') - u1.addresses.append(Address(email_address='lala@hoho.com')) + expire_on_commit=False, autocommit=False, autoflush=True + ) + u1 = User(name="jack") + u1.addresses.append(Address(email_address="lala@hoho.com")) sess.add(u1) sess.flush() sess.commit() - u1.addresses.append(Address(email_address='foo@bar.com')) + u1.addresses.append(Address(email_address="foo@bar.com")) eq_( u1.addresses.order_by(Address.id).all(), [ - Address(email_address='lala@hoho.com'), - Address(email_address='foo@bar.com') - ] + Address(email_address="lala@hoho.com"), + Address(email_address="foo@bar.com"), + ], ) sess.rollback() - eq_( - u1.addresses.all(), - [Address(email_address='lala@hoho.com')] - ) + eq_(u1.addresses.all(), [Address(email_address="lala@hoho.com")]) def _test_delete_cascade(self, expected): addresses = self.tables.addresses @@ -627,25 +708,29 @@ class UOWTest( addresses_args={ "order_by": addresses.c.id, "backref": "user", - "cascade": "save-update" if expected else "all, delete"}) + "cascade": "save-update" if expected else "all, 
delete", + } + ) sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed') + u = User(name="ed") u.addresses.extend( - [Address(email_address=letter) for letter in 'abcdef'] + [Address(email_address=letter) for letter in "abcdef"] ) sess.add(u) sess.commit() eq_( testing.db.scalar( - select([func.count('*')]).where( - addresses.c.user_id == None)), # noqa - 0) + select([func.count("*")]).where(addresses.c.user_id == None) + ), # noqa + 0, + ) eq_( testing.db.scalar( - select([func.count('*')]).where( - addresses.c.user_id != None)), # noqa - 6) + select([func.count("*")]).where(addresses.c.user_id != None) + ), # noqa + 6, + ) sess.delete(u) @@ -654,26 +739,27 @@ class UOWTest( if expected: eq_( testing.db.scalar( - select([func.count('*')]).where( + select([func.count("*")]).where( addresses.c.user_id == None # noqa ) ), - 6 + 6, ) eq_( testing.db.scalar( - select([func.count('*')]).where( + select([func.count("*")]).where( addresses.c.user_id != None # noqa ) ), - 0 + 0, ) else: eq_( testing.db.scalar( - select([func.count('*')]).select_from(addresses) + select([func.count("*")]).select_from(addresses) ), - 0) + 0, + ) def test_delete_nocascade(self): self._test_delete_cascade(True) @@ -685,9 +771,14 @@ class UOWTest( Node, nodes = self.classes.Node, self.tables.nodes mapper( - Node, nodes, properties={ - 'children': relationship( - Node, lazy="dynamic", order_by=nodes.c.id)}) + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="dynamic", order_by=nodes.c.id + ) + }, + ) sess = Session() n2, n3 = Node(), Node() @@ -703,32 +794,36 @@ class UOWTest( addresses_args={ "order_by": addresses.c.id, "backref": "user", - "cascade": "all, delete-orphan"}) + "cascade": "all, delete-orphan", + } + ) sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed') + u = User(name="ed") u.addresses.extend( - [Address(email_address=letter) for letter in 'abcdef'] + [Address(email_address=letter) for letter in "abcdef"] ) sess.add(u) for a in u.addresses.filter( - Address.email_address.in_(['c', 'e', 'f'])): + Address.email_address.in_(["c", "e", "f"]) + ): u.addresses.remove(a) eq_( set(ad for ad, in sess.query(Address.email_address)), - set(['a', 'b', 'd']) + set(["a", "b", "d"]), ) def _backref_test(self, autoflush, saveuser): User, Address = self._user_address_fixture( - addresses_args={"backref": "user"}) + addresses_args={"backref": "user"} + ) sess = create_session(autoflush=autoflush, autocommit=False) - u = User(name='buffy') + u = User(name="buffy") - a = Address(email_address='foo@bar.com') + a = Address(email_address="foo@bar.com") a.user = u if saveuser: @@ -766,7 +861,8 @@ class UOWTest( def test_backref_events(self): User, Address = self._user_address_fixture( - addresses_args={"backref": "user"}) + addresses_args={"backref": "user"} + ) u1 = User() a1 = Address() @@ -775,15 +871,16 @@ class UOWTest( def test_no_deref(self): User, Address = self._user_address_fixture( - addresses_args={"backref": "user", }) + addresses_args={"backref": "user"} + ) session = create_session() user = User() - user.name = 'joe' - user.fullname = 'Joe User' - user.password = 'Joe\'s secret' + user.name = "joe" + user.fullname = "Joe User" + user.password = "Joe's secret" address = Address() - address.email_address = 'joe@joesdomain.example' + address.email_address = "joe@joesdomain.example" address.user = user session.add(user) session.flush() @@ -802,9 +899,9 @@ class UOWTest( session = create_session(testing.db) return 
session.query(User).first().addresses.all() - eq_(query1(), [Address(email_address='joe@joesdomain.example')]) - eq_(query2(), [Address(email_address='joe@joesdomain.example')]) - eq_(query3(), [Address(email_address='joe@joesdomain.example')]) + eq_(query1(), [Address(email_address="joe@joesdomain.example")]) + eq_(query2(), [Address(email_address="joe@joesdomain.example")]) + eq_(query3(), [Address(email_address="joe@joesdomain.example")]) class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): @@ -812,7 +909,8 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): def _transient_fixture(self, addresses_args={}): User, Address = self._user_address_fixture( - addresses_args=addresses_args) + addresses_args=addresses_args + ) u1 = User() a1 = Address() @@ -820,10 +918,11 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): def _persistent_fixture(self, autoflush=True, addresses_args={}): User, Address = self._user_address_fixture( - addresses_args=addresses_args) + addresses_args=addresses_args + ) - u1 = User(name='u1') - a1 = Address(email_address='a1') + u1 = User(name="u1") + a1 = Address(email_address="a1") s = Session(autoflush=autoflush) s.add(u1) s.flush() @@ -845,55 +944,47 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): elif isinstance(obj, self.classes.Order): attrname = "items" - eq_( - attributes.get_history(obj, attrname), - compare - ) + eq_(attributes.get_history(obj, attrname), compare) if compare_passive is None: compare_passive = compare eq_( - attributes.get_history(obj, attrname, - attributes.LOAD_AGAINST_COMMITTED), - compare_passive + attributes.get_history( + obj, attrname, attributes.LOAD_AGAINST_COMMITTED + ), + compare_passive, ) def test_append_transient(self): u1, a1 = self._transient_fixture() u1.addresses.append(a1) - self._assert_history(u1, - ([a1], [], [])) + self._assert_history(u1, ([a1], [], [])) def test_append_persistent(self): u1, a1, s = self._persistent_fixture() u1.addresses.append(a1) - self._assert_history(u1, - ([a1], [], []) - ) + self._assert_history(u1, ([a1], [], [])) def test_remove_transient(self): u1, a1 = self._transient_fixture() u1.addresses.append(a1) u1.addresses.remove(a1) - self._assert_history(u1, - ([], [], [])) + self._assert_history(u1, ([], [], [])) def test_backref_pop_transient(self): u1, a1 = self._transient_fixture(addresses_args={"backref": "user"}) u1.addresses.append(a1) - self._assert_history(u1, - ([a1], [], [])) + self._assert_history(u1, ([a1], [], [])) a1.user = None # removed from added - self._assert_history(u1, - ([], [], [])) + self._assert_history(u1, ([], [], [])) def test_remove_persistent(self): u1, a1, s = self._persistent_fixture() @@ -903,50 +994,49 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): u1.addresses.remove(a1) - self._assert_history(u1, - ([], [], [a1])) + self._assert_history(u1, ([], [], [a1])) def test_backref_pop_persistent_autoflush_o2m_active_hist(self): u1, a1, s = self._persistent_fixture( - addresses_args={"backref": backref("user", active_history=True)}) + addresses_args={"backref": backref("user", active_history=True)} + ) u1.addresses.append(a1) s.flush() s.expire_all() a1.user = None - self._assert_history(u1, - ([], [], [a1])) + self._assert_history(u1, ([], [], [a1])) def test_backref_pop_persistent_autoflush_m2m(self): o1, i1, s = self._persistent_m2m_fixture( - items_args={"backref": "orders"}) + items_args={"backref": "orders"} + ) o1.items.append(i1) s.flush() s.expire_all() i1.orders.remove(o1) - self._assert_history(o1, - ([], 
[], [i1])) + self._assert_history(o1, ([], [], [i1])) def test_backref_pop_persistent_noflush_m2m(self): o1, i1, s = self._persistent_m2m_fixture( - items_args={"backref": "orders"}, autoflush=False) + items_args={"backref": "orders"}, autoflush=False + ) o1.items.append(i1) s.flush() s.expire_all() i1.orders.remove(o1) - self._assert_history(o1, - ([], [], [i1])) + self._assert_history(o1, ([], [], [i1])) def test_unchanged_persistent(self): Address = self.classes.Address u1, a1, s = self._persistent_fixture() - a2, a3 = Address(email_address='a2'), Address(email_address='a3') + a2, a3 = Address(email_address="a2"), Address(email_address="a3") u1.addresses.append(a1) u1.addresses.append(a2) @@ -955,52 +1045,61 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): u1.addresses.append(a3) u1.addresses.remove(a2) - self._assert_history(u1, - ([a3], [a1], [a2]), - compare_passive=([a3], [], [a2])) + self._assert_history( + u1, ([a3], [a1], [a2]), compare_passive=([a3], [], [a2]) + ) def test_replace_transient(self): Address = self.classes.Address u1, a1 = self._transient_fixture() - a2, a3, a4, a5 = Address(email_address='a2'), \ - Address(email_address='a3'), Address(email_address='a4'), \ - Address(email_address='a5') + a2, a3, a4, a5 = ( + Address(email_address="a2"), + Address(email_address="a3"), + Address(email_address="a4"), + Address(email_address="a5"), + ) u1.addresses = [a1, a2] u1.addresses = [a2, a3, a4, a5] - self._assert_history(u1, - ([a2, a3, a4, a5], [], [])) + self._assert_history(u1, ([a2, a3, a4, a5], [], [])) def test_replace_persistent_noflush(self): Address = self.classes.Address u1, a1, s = self._persistent_fixture(autoflush=False) - a2, a3, a4, a5 = Address(email_address='a2'), \ - Address(email_address='a3'), Address(email_address='a4'), \ - Address(email_address='a5') + a2, a3, a4, a5 = ( + Address(email_address="a2"), + Address(email_address="a3"), + Address(email_address="a4"), + Address(email_address="a5"), + ) u1.addresses = [a1, a2] u1.addresses = [a2, a3, a4, a5] - self._assert_history(u1, - ([a2, a3, a4, a5], [], [])) + self._assert_history(u1, ([a2, a3, a4, a5], [], [])) def test_replace_persistent_autoflush(self): Address = self.classes.Address u1, a1, s = self._persistent_fixture(autoflush=True) - a2, a3, a4, a5 = Address(email_address='a2'), \ - Address(email_address='a3'), Address(email_address='a4'), \ - Address(email_address='a5') + a2, a3, a4, a5 = ( + Address(email_address="a2"), + Address(email_address="a3"), + Address(email_address="a4"), + Address(email_address="a5"), + ) u1.addresses = [a1, a2] u1.addresses = [a2, a3, a4, a5] - self._assert_history(u1, - ([a3, a4, a5], [a2], [a1]), - compare_passive=([a3, a4, a5], [], [a1])) + self._assert_history( + u1, + ([a3, a4, a5], [a2], [a1]), + compare_passive=([a3, a4, a5], [], [a1]), + ) def test_persistent_but_readded_noflush(self): u1, a1, s = self._persistent_fixture(autoflush=False) @@ -1009,9 +1108,9 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): u1.addresses.append(a1) - self._assert_history(u1, - ([], [a1], []), - compare_passive=([a1], [], [])) + self._assert_history( + u1, ([], [a1], []), compare_passive=([a1], [], []) + ) def test_persistent_but_readded_autoflush(self): u1, a1, s = self._persistent_fixture(autoflush=True) @@ -1020,9 +1119,9 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): u1.addresses.append(a1) - self._assert_history(u1, - ([], [a1], []), - compare_passive=([a1], [], [])) + self._assert_history( + u1, ([], [a1], []), compare_passive=([a1], 
[], []) + ) def test_missing_but_removed_noflush(self): u1, a1, s = self._persistent_fixture(autoflush=False) diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index a2a110b8c8..b38f23529d 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -3,14 +3,37 @@ from sqlalchemy.testing import eq_, is_, is_not_, in_ import sqlalchemy as sa from sqlalchemy import testing -from sqlalchemy.orm import joinedload, deferred, undefer, \ - joinedload_all, backref, Session,\ - defaultload, Load, load_only, contains_eager -from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, \ - func, text +from sqlalchemy.orm import ( + joinedload, + deferred, + undefer, + joinedload_all, + backref, + Session, + defaultload, + Load, + load_only, + contains_eager, +) +from sqlalchemy import ( + Integer, + String, + Date, + ForeignKey, + and_, + select, + func, + text, +) from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, create_session, \ - lazyload, aliased, column_property +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + lazyload, + aliased, + column_property, +) from sqlalchemy.sql import operators from sqlalchemy.testing import assert_raises, assert_raises_message from sqlalchemy.testing.assertsql import CompiledSQL @@ -21,27 +44,41 @@ import datetime class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): - run_inserts = 'once' + run_inserts = "once" run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" def test_basic(self): users, Address, addresses, User = ( self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), lazy='joined', order_by=Address.id) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="joined", + order_by=Address.id, + ) + }, + ) sess = create_session() q = sess.query(User) - eq_([User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])], - q.filter(User.id == 7).all()) + eq_( + [ + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ) + ], + q.filter(User.id == 7).all(), + ) eq_(self.static.user_address_result, q.order_by(User.id).all()) def test_late_compile(self): @@ -49,7 +86,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.classes.Address, self.tables.addresses, - self.tables.users) + self.tables.users, + ) m = mapper(User, users) sess = create_session() @@ -60,11 +98,20 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_( - [User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])], - sess.query(User).options( - joinedload('addresses')).filter(User.id == 7).all() + [ + User( + id=7, + addresses=[ + Address(id=1, email_address="jack@bean.com") + ], + ) + ], + sess.query(User) + .options(joinedload("addresses")) + .filter(User.id == 7) + .all(), ) + self.assert_sql_count(testing.db, go, 1) def test_no_orphan(self): @@ -74,163 +121,209 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship( - Address, cascade="all,delete-orphan", lazy='joined') - }) + mapper( + User, + users, + 
properties={ + "addresses": relationship( + Address, cascade="all,delete-orphan", lazy="joined" + ) + }, + ) mapper(Address, addresses) sess = create_session() user = sess.query(User).get(7) - assert getattr(User, 'addresses').\ - hasparent( - sa.orm.attributes.instance_state( - user.addresses[0]), optimistic=True) - assert not sa.orm.class_mapper(Address).\ - _is_orphan( - sa.orm.attributes.instance_state(user.addresses[0])) + assert getattr(User, "addresses").hasparent( + sa.orm.attributes.instance_state(user.addresses[0]), + optimistic=True, + ) + assert not sa.orm.class_mapper(Address)._is_orphan( + sa.orm.attributes.instance_state(user.addresses[0]) + ) def test_orderby(self): users, Address, addresses, User = ( self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - lazy='joined', order_by=addresses.c.email_address), - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="joined", + order_by=addresses.c.email_address, + ) + }, + ) q = create_session().query(User) - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], q.order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=2, email_address="ed@wood.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + q.order_by(User.id).all(), + ) def test_orderby_multi(self): users, Address, addresses, User = ( self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - lazy='joined', - order_by=[addresses.c.email_address, addresses.c.id]), - }) + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="joined", + order_by=[addresses.c.email_address, addresses.c.id], + ) + }, + ) q = create_session().query(User) - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], q.order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=2, email_address="ed@wood.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + q.order_by(User.id).all(), + ) def test_orderby_related(self): """A regular mapper select on a single table can order by a relationship to a second table""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, 
properties=dict( - addresses=relationship( - Address, lazy='joined', order_by=addresses.c.id), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="joined", order_by=addresses.c.id + ) + ), + ) q = create_session().query(User) - result = q.filter(User.id == Address.user_id).order_by( - Address.email_address).all() - - eq_([ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=7, addresses=[ - Address(id=1) - ]), - ], result) + result = ( + q.filter(User.id == Address.user_id) + .order_by(Address.email_address) + .all() + ) + + eq_( + [ + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=7, addresses=[Address(id=1)]), + ], + result, + ) def test_orderby_desc(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='joined', - order_by=[sa.desc(addresses.c.email_address)]), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="joined", + order_by=[sa.desc(addresses.c.email_address)], + ) + ), + ) sess = create_session() - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=3, email_address='ed@bettyboop.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], sess.query(User).order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=3, email_address="ed@bettyboop.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + sess.query(User).order_by(User.id).all(), + ) def test_no_ad_hoc_orderby(self): """part of #2992; make sure string label references can't access an eager loader, else an eager load can corrupt the query. 
""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship( - Address), - )) + mapper(User, users, properties=dict(addresses=relationship(Address))) sess = create_session() - q = sess.query(User).\ - join("addresses").\ - options(joinedload("addresses")).\ - order_by("email_address") + q = ( + sess.query(User) + .join("addresses") + .options(joinedload("addresses")) + .order_by("email_address") + ) self.assert_compile( q, @@ -240,11 +333,14 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "addresses_1_email_address FROM users JOIN addresses " "ON users.id = addresses.user_id LEFT OUTER JOIN addresses " "AS addresses_1 ON users.id = addresses_1.user_id " - "ORDER BY addresses.email_address" + "ORDER BY addresses.email_address", ) - q = sess.query(User).options(joinedload("addresses")).\ - order_by("email_address") + q = ( + sess.query(User) + .options(joinedload("addresses")) + .order_by("email_address") + ) with expect_warnings("Can't resolve label reference 'email_address'"): self.assert_compile( @@ -254,7 +350,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "addresses_1_user_id, addresses_1.email_address AS " "addresses_1_email_address FROM users LEFT OUTER JOIN " "addresses AS addresses_1 ON users.id = addresses_1.user_id " - "ORDER BY email_address" + "ORDER BY email_address", ) def test_deferred_fk_col(self): @@ -264,30 +360,39 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.tables.dingalings, self.classes.Address, - self.tables.addresses) + self.tables.addresses, + ) - mapper(Address, addresses, properties={ - 'user_id': deferred(addresses.c.user_id), - 'user': relationship(User, lazy='joined') - }) + mapper( + Address, + addresses, + properties={ + "user_id": deferred(addresses.c.user_id), + "user": relationship(User, lazy="joined"), + }, + ) mapper(User, users) sess = create_session() for q in [ - sess.query(Address).filter( - Address.id.in_([1, 4, 5]) - ).order_by(Address.id), - sess.query(Address).filter( - Address.id.in_([1, 4, 5]) - ).order_by(Address.id).limit(3) + sess.query(Address) + .filter(Address.id.in_([1, 4, 5])) + .order_by(Address.id), + sess.query(Address) + .filter(Address.id.in_([1, 4, 5])) + .order_by(Address.id) + .limit(3), ]: sess.expunge_all() - eq_(q.all(), - [Address(id=1, user=User(id=7)), - Address(id=4, user=User(id=8)), - Address(id=5, user=User(id=9))] - ) + eq_( + q.all(), + [ + Address(id=1, user=User(id=7)), + Address(id=4, user=User(id=8)), + Address(id=5, user=User(id=9)), + ], + ) sess.expunge_all() a = sess.query(Address).filter(Address.id == 1).all()[0] @@ -296,6 +401,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): # if the user wants a column undeferred, add the option. def go(): eq_(a.user_id, 7) + # self.assert_sql_count(testing.db, go, 0) self.assert_sql_count(testing.db, go, 1) @@ -304,6 +410,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_(a.user_id, 7) + # same, 1.0 doesn't check these # self.assert_sql_count(testing.db, go, 0) self.assert_sql_count(testing.db, go, 1) @@ -314,105 +421,145 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): # trigger, etc.) 
sa.orm.clear_mappers() - mapper(Address, addresses, properties={ - 'user_id': deferred(addresses.c.user_id), - }) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='joined')}) + mapper( + Address, + addresses, + properties={"user_id": deferred(addresses.c.user_id)}, + ) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="joined")}, + ) for q in [ sess.query(User).filter(User.id == 7), - sess.query(User).filter(User.id == 7).limit(1) + sess.query(User).filter(User.id == 7).limit(1), ]: sess.expunge_all() - eq_(q.all(), - [User(id=7, addresses=[Address(id=1)])] - ) + eq_(q.all(), [User(id=7, addresses=[Address(id=1)])]) sess.expunge_all() u = sess.query(User).get(7) def go(): eq_(u.addresses[0].user_id, 7) + # assert that the eager loader didn't have to affect 'user_id' here # and that its still deferred self.assert_sql_count(testing.db, go, 1) sa.orm.clear_mappers() - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='joined', - order_by=addresses.c.id)}) - mapper(Address, addresses, properties={ - 'user_id': deferred(addresses.c.user_id), - 'dingalings': relationship(Dingaling, lazy='joined')}) - mapper(Dingaling, dingalings, properties={ - 'address_id': deferred(dingalings.c.address_id)}) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, lazy="joined", order_by=addresses.c.id + ) + }, + ) + mapper( + Address, + addresses, + properties={ + "user_id": deferred(addresses.c.user_id), + "dingalings": relationship(Dingaling, lazy="joined"), + }, + ) + mapper( + Dingaling, + dingalings, + properties={"address_id": deferred(dingalings.c.address_id)}, + ) sess.expunge_all() def go(): u = sess.query(User).get(8) - eq_(User(id=8, - addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), - Address(id=3), - Address(id=4)]), - u) + eq_( + User( + id=8, + addresses=[ + Address(id=2, dingalings=[Dingaling(id=1)]), + Address(id=3), + Address(id=4), + ], + ), + u, + ) + self.assert_sql_count(testing.db, go, 1) def test_options_pathing(self): - users, Keyword, orders, items, order_items, \ - Order, Item, User, keywords, item_keywords = ( - self.tables.users, - self.classes.Keyword, - self.tables.orders, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.keywords, - self.tables.item_keywords) - - mapper(User, users, properties={ - 'orders': relationship(Order, order_by=orders.c.id), # o2m, m2o - }) - mapper(Order, orders, properties={ - 'items': relationship( - Item, - secondary=order_items, order_by=items.c.id), # m2m - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, - secondary=item_keywords, - order_by=keywords.c.id) # m2m - }) + users, Keyword, orders, items, order_items, Order, Item, User, keywords, item_keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) + + mapper( + User, + users, + properties={ + "orders": relationship(Order, order_by=orders.c.id) # o2m, m2o + }, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) # m2m + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, order_by=keywords.c.id + ) # m2m + }, + ) mapper(Keyword, keywords) for opt, count in [ - 
(( - joinedload(User.orders, Order.items), - ), 10), - ((joinedload("orders.items"), ), 10), - (( - joinedload(User.orders, ), - joinedload(User.orders, Order.items), - joinedload(User.orders, Order.items, Item.keywords), - ), 1), - (( - joinedload(User.orders, Order.items, Item.keywords), - ), 10), - (( - joinedload(User.orders, Order.items), - joinedload(User.orders, Order.items, Item.keywords), - ), 5), + ((joinedload(User.orders, Order.items),), 10), + ((joinedload("orders.items"),), 10), + ( + ( + joinedload(User.orders), + joinedload(User.orders, Order.items), + joinedload(User.orders, Order.items, Item.keywords), + ), + 1, + ), + ((joinedload(User.orders, Order.items, Item.keywords),), 10), + ( + ( + joinedload(User.orders, Order.items), + joinedload(User.orders, Order.items, Item.keywords), + ), + 5, + ), ]: sess = create_session() def go(): eq_( sess.query(User).options(*opt).order_by(User.id).all(), - self.static.user_item_keyword_result + self.static.user_item_keyword_result, ) + self.assert_sql_count(testing.db, go, count) def test_disable_dynamic(self): @@ -422,11 +569,14 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy="dynamic") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="dynamic")}, + ) mapper(Address, addresses) sess = create_session() assert_raises_message( @@ -442,28 +592,48 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='joined', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="joined", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): eq_(self.static.item_keyword_result, q.all()) + self.assert_sql_count(testing.db, go, 1) def go(): - eq_(self.static.item_keyword_result[0:2], - q.join('keywords').filter(Keyword.name == 'red').all()) + eq_( + self.static.item_keyword_result[0:2], + q.join("keywords").filter(Keyword.name == "red").all(), + ) + self.assert_sql_count(testing.db, go, 1) def go(): - eq_(self.static.item_keyword_result[0:2], - (q.join('keywords', aliased=True). - filter(Keyword.name == 'red')).all()) + eq_( + self.static.item_keyword_result[0:2], + ( + q.join("keywords", aliased=True).filter( + Keyword.name == "red" + ) + ).all(), + ) + self.assert_sql_count(testing.db, go, 1) def test_eager_option(self): @@ -472,46 +642,69 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship( - Keyword, secondary=item_keywords, lazy='select', - order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="select", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item) def go(): - eq_(self.static.item_keyword_result[0:2], - (q.options( - joinedload('keywords') - ).join('keywords'). 
- filter(keywords.c.name == 'red')).order_by(Item.id).all()) + eq_( + self.static.item_keyword_result[0:2], + ( + q.options(joinedload("keywords")) + .join("keywords") + .filter(keywords.c.name == "red") + ) + .order_by(Item.id) + .all(), + ) self.assert_sql_count(testing.db, go, 1) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='joined', - backref=sa.orm.backref('user', lazy='joined'), - order_by=Address.id) - )) - eq_(sa.orm.class_mapper(User).get_property('addresses').lazy, 'joined') - eq_(sa.orm.class_mapper(Address).get_property('user').lazy, 'joined') + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="joined", + backref=sa.orm.backref("user", lazy="joined"), + order_by=Address.id, + ) + ), + ) + eq_(sa.orm.class_mapper(User).get_property("addresses").lazy, "joined") + eq_(sa.orm.class_mapper(Address).get_property("user").lazy, "joined") sess = create_session() eq_( self.static.user_address_result, - sess.query(User).order_by(User.id).all()) + sess.query(User).order_by(User.id).all(), + ) def test_double(self): """Eager loading with two relationships simultaneously, @@ -523,10 +716,11 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.classes.Address, self.classes.Order, - self.tables.addresses) + self.tables.addresses, + ) - openorders = sa.alias(orders, 'openorders') - closedorders = sa.alias(orders, 'closedorders') + openorders = sa.alias(orders, "openorders") + closedorders = sa.alias(orders, "closedorders") mapper(Address, addresses) mapper(Order, orders) @@ -534,124 +728,173 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): open_mapper = mapper(Order, openorders, non_primary=True) closed_mapper = mapper(Order, closedorders, non_primary=True) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='joined', order_by=addresses.c.id), - open_orders=relationship( - open_mapper, - primaryjoin=sa.and_(openorders.c.isopen == 1, - users.c.id == openorders.c.user_id), - lazy='joined', order_by=openorders.c.id), - closed_orders=relationship( - closed_mapper, - primaryjoin=sa.and_(closedorders.c.isopen == 0, - users.c.id == closedorders.c.user_id), - lazy='joined', order_by=closedorders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="joined", order_by=addresses.c.id + ), + open_orders=relationship( + open_mapper, + primaryjoin=sa.and_( + openorders.c.isopen == 1, + users.c.id == openorders.c.user_id, + ), + lazy="joined", + order_by=openorders.c.id, + ), + closed_orders=relationship( + closed_mapper, + primaryjoin=sa.and_( + closedorders.c.isopen == 0, + users.c.id == closedorders.c.user_id, + ), + lazy="joined", + order_by=closedorders.c.id, + ), + ), + ) q = create_session().query(User).order_by(User.id) def go(): - eq_([ - User( - id=7, - addresses=[Address(id=1)], - open_orders=[Order(id=3)], - closed_orders=[Order(id=1), Order(id=5)] - ), - User( - id=8, - addresses=[Address(id=2), Address(id=3), Address(id=4)], - open_orders=[], - closed_orders=[] - ), - User( - id=9, - addresses=[Address(id=5)], - 
open_orders=[Order(id=4)], - closed_orders=[Order(id=2)] - ), - User(id=10) + eq_( + [ + User( + id=7, + addresses=[Address(id=1)], + open_orders=[Order(id=3)], + closed_orders=[Order(id=1), Order(id=5)], + ), + User( + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + open_orders=[], + closed_orders=[], + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders=[Order(id=4)], + closed_orders=[Order(id=2)], + ), + User(id=10), + ], + q.all(), + ) - ], q.all()) self.assert_sql_count(testing.db, go, 1) def test_double_same_mappers(self): """Eager loading with two relationships simultaneously, from the same table, using aliases.""" - addresses, items, order_items, orders, \ - Item, User, Address, Order, users = ( - self.tables.addresses, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.users) + addresses, items, order_items, orders, Item, User, Address, Order, users = ( + self.tables.addresses, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.users, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='joined', - order_by=items.c.id)}) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="joined", + order_by=items.c.id, + ) + }, + ) mapper(Item, items) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='joined', order_by=addresses.c.id), - open_orders=relationship( - Order, - primaryjoin=sa.and_(orders.c.isopen == 1, - users.c.id == orders.c.user_id), - lazy='joined', order_by=orders.c.id), - closed_orders=relationship( - Order, - primaryjoin=sa.and_(orders.c.isopen == 0, - users.c.id == orders.c.user_id), - lazy='joined', order_by=orders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="joined", order_by=addresses.c.id + ), + open_orders=relationship( + Order, + primaryjoin=sa.and_( + orders.c.isopen == 1, users.c.id == orders.c.user_id + ), + lazy="joined", + order_by=orders.c.id, + ), + closed_orders=relationship( + Order, + primaryjoin=sa.and_( + orders.c.isopen == 0, users.c.id == orders.c.user_id + ), + lazy="joined", + order_by=orders.c.id, + ), + ), + ) q = create_session().query(User).order_by(User.id) def go(): - eq_([ - User(id=7, - addresses=[ - Address(id=1)], - open_orders=[Order(id=3, - items=[ - Item(id=3), - Item(id=4), - Item(id=5)])], - closed_orders=[Order(id=1, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)]), - Order(id=5, - items=[ - Item(id=5)])]), - User(id=8, - addresses=[ - Address(id=2), - Address(id=3), - Address(id=4)], - open_orders=[], - closed_orders=[]), - User(id=9, - addresses=[ - Address(id=5)], - open_orders=[ - Order(id=4, - items=[ - Item(id=1), - Item(id=5)])], - closed_orders=[ - Order(id=2, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)])]), - User(id=10) - ], q.all()) + eq_( + [ + User( + id=7, + addresses=[Address(id=1)], + open_orders=[ + Order( + id=3, + items=[Item(id=3), Item(id=4), Item(id=5)], + ) + ], + closed_orders=[ + Order( + id=1, + items=[Item(id=1), Item(id=2), Item(id=3)], + ), + Order(id=5, items=[Item(id=5)]), + ], + ), + User( + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + open_orders=[], + closed_orders=[], + ), + 
User( + id=9, + addresses=[Address(id=5)], + open_orders=[ + Order(id=4, items=[Item(id=1), Item(id=5)]) + ], + closed_orders=[ + Order( + id=2, + items=[Item(id=1), Item(id=2), Item(id=3)], + ) + ], + ), + User(id=10), + ], + q.all(), + ) + self.assert_sql_count(testing.db, go, 1) def test_no_false_hits(self): @@ -664,12 +907,17 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.classes.Address, self.classes.Order, - self.tables.users) + self.tables.users, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='joined'), - 'orders': relationship(Order, lazy='joined') - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, lazy="joined"), + "orders": relationship(Order, lazy="joined"), + }, + ) mapper(Address, addresses) mapper(Order, orders) @@ -679,37 +927,57 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): # eager loaders have aliases which should not hit on those columns, # they should be required to locate only their aliased/fully table # qualified column name. - noeagers = create_session().query(User).\ - from_statement(text("select * from users")).all() - assert 'orders' not in noeagers[0].__dict__ - assert 'addresses' not in noeagers[0].__dict__ + noeagers = ( + create_session() + .query(User) + .from_statement(text("select * from users")) + .all() + ) + assert "orders" not in noeagers[0].__dict__ + assert "addresses" not in noeagers[0].__dict__ def test_limit(self): """Limit operations combined with lazy-load relationships.""" - users, items, order_items, orders, Item, \ - User, Address, Order, addresses = ( - self.tables.users, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.addresses) + users, items, order_items, orders, Item, User, Address, Order, addresses = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.addresses, + ) mapper(Item, items) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='joined', - order_by=items.c.id) - }) - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - lazy='joined', order_by=addresses.c.id), - 'orders': relationship(Order, lazy='select', order_by=orders.c.id) - }) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="joined", + order_by=items.c.id, + ) + }, + ) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="joined", + order_by=addresses.c.id, + ), + "orders": relationship( + Order, lazy="select", order_by=orders.c.id + ), + }, + ) sess = create_session() q = sess.query(User) @@ -718,33 +986,48 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_(self.static.user_all_result[1:3], result) def test_distinct(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) # this is an involved 3x union of the users table to get a lot of rows. # then see if the "distinct" works its way out. 
you actually get # the same result with or without the distinct, just via less or # more rows. - u2 = users.alias('u2') + u2 = users.alias("u2") s = sa.union_all( - u2.select(use_labels=True), u2.select(use_labels=True), - u2.select(use_labels=True)).alias('u') + u2.select(use_labels=True), + u2.select(use_labels=True), + u2.select(use_labels=True), + ).alias("u") - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - lazy='joined', order_by=addresses.c.id), - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="joined", + order_by=addresses.c.id, + ) + }, + ) sess = create_session() q = sess.query(User) def go(): - result = q.filter(s.c.u2_id == User.id).distinct().\ - order_by(User.id).all() + result = ( + q.filter(s.c.u2_id == User.id) + .distinct() + .order_by(User.id) + .all() + ) eq_(self.static.user_address_result, result) + self.assert_sql_count(testing.db, go, 1) def test_limit_2(self): @@ -753,21 +1036,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship( - Keyword, secondary=item_keywords, - lazy='joined', order_by=[keywords.c.id]), - )) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="joined", + order_by=[keywords.c.id], + ) + ), + ) sess = create_session() q = sess.query(Item) - result = q.filter((Item.description == 'item 2') | - (Item.description == 'item 5') | - (Item.description == 'item 3')).\ - order_by(Item.id).limit(2).all() + result = ( + q.filter( + (Item.description == "item 2") + | (Item.description == "item 5") + | (Item.description == "item 3") + ) + .order_by(Item.id) + .limit(2) + .all() + ) eq_(self.static.item_keyword_result[1:3], result) @@ -777,94 +1074,145 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): 'wrapped' select statement resulting from the combination of eager loading and limit/offset clauses.""" - addresses, items, order_items, orders, \ - Item, User, Address, Order, users = ( - self.tables.addresses, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.users) + addresses, items, order_items, orders, Item, User, Address, Order, users = ( + self.tables.addresses, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.users, + ) mapper(Item, items) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy='joined') - )) + mapper( + Order, + orders, + properties=dict( + items=relationship(Item, secondary=order_items, lazy="joined") + ), + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='joined', order_by=addresses.c.id), - orders=relationship(Order, lazy='joined', order_by=orders.c.id), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="joined", order_by=addresses.c.id + ), + orders=relationship( + Order, lazy="joined", order_by=orders.c.id + ), + ), + ) sess = create_session() q = sess.query(User) - if not testing.against('mssql'): - result = q.join('orders').order_by( 
- Order.user_id.desc()).limit(2).offset(1) - eq_([ - User(id=9, - orders=[Order(id=2), Order(id=4)], - addresses=[Address(id=5)] - ), - User(id=7, - orders=[Order(id=1), Order(id=3), Order(id=5)], - addresses=[Address(id=1)] - ) - ], result.all()) - - result = q.join('addresses').order_by( - Address.email_address.desc()).limit(1).offset(0) - eq_([ - User(id=7, - orders=[Order(id=1), Order(id=3), Order(id=5)], - addresses=[Address(id=1)] - ) - ], result.all()) + if not testing.against("mssql"): + result = ( + q.join("orders") + .order_by(Order.user_id.desc()) + .limit(2) + .offset(1) + ) + eq_( + [ + User( + id=9, + orders=[Order(id=2), Order(id=4)], + addresses=[Address(id=5)], + ), + User( + id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + addresses=[Address(id=1)], + ), + ], + result.all(), + ) + + result = ( + q.join("addresses") + .order_by(Address.email_address.desc()) + .limit(1) + .offset(0) + ) + eq_( + [ + User( + id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + addresses=[Address(id=1)], + ) + ], + result.all(), + ) def test_limit_4(self): - User, Order, addresses, users, orders = (self.classes.User, - self.classes.Order, - self.tables.addresses, - self.tables.users, - self.tables.orders) + User, Order, addresses, users, orders = ( + self.classes.User, + self.classes.Order, + self.tables.addresses, + self.tables.users, + self.tables.orders, + ) # tests the LIMIT/OFFSET aliasing on a mapper # against a select. original issue from ticket #904 - sel = sa.select([users, addresses.c.email_address], - users.c.id == addresses.c.user_id).alias('useralias') - mapper(User, sel, properties={ - 'orders': relationship( - Order, primaryjoin=sel.c.id == orders.c.user_id, - lazy='joined', order_by=orders.c.id) - }) + sel = sa.select( + [users, addresses.c.email_address], + users.c.id == addresses.c.user_id, + ).alias("useralias") + mapper( + User, + sel, + properties={ + "orders": relationship( + Order, + primaryjoin=sel.c.id == orders.c.user_id, + lazy="joined", + order_by=orders.c.id, + ) + }, + ) mapper(Order, orders) sess = create_session() - eq_(sess.query(User).first(), - User(name='jack', orders=[ - Order( - address_id=1, - description='order 1', - isopen=0, - user_id=7, - id=1), - Order( - address_id=1, - description='order 3', - isopen=1, - user_id=7, - id=3), - Order( - address_id=None, description='order 5', isopen=0, - user_id=7, id=5)], - email_address='jack@bean.com', id=7) - ) + eq_( + sess.query(User).first(), + User( + name="jack", + orders=[ + Order( + address_id=1, + description="order 1", + isopen=0, + user_id=7, + id=1, + ), + Order( + address_id=1, + description="order 3", + isopen=1, + user_id=7, + id=3, + ), + Order( + address_id=None, + description="order 5", + isopen=0, + user_id=7, + id=5, + ), + ], + email_address="jack@bean.com", + id=7, + ), + ) def test_useget_cancels_eager(self): """test that a one to many lazyload cancels the unnecessary @@ -874,26 +1222,34 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, lazy='joined', backref='addresses') - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User, lazy="joined", backref="addresses") + }, + ) sess = create_session() u1 = sess.query(User).filter(User.id == 8).one() def go(): eq_(u1.addresses[0].user, u1) + self.assert_sql_execution( - testing.db, go, + 
testing.db, + go, CompiledSQL( "SELECT addresses.id AS addresses_id, addresses.user_id AS " "addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE :param_1 = " "addresses.user_id", - {'param_1': 8}) + {"param_1": 8}, + ), ) def test_useget_cancels_eager_propagated_present(self): @@ -905,12 +1261,17 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, lazy='joined', backref='addresses') - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User, lazy="joined", backref="addresses") + }, + ) from sqlalchemy.orm.interfaces import MapperOption @@ -918,45 +1279,64 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): propagate_to_loaders = True sess = create_session() - u1 = sess.query(User).options( - MyBogusOption()).filter(User.id == 8).one() + u1 = ( + sess.query(User) + .options(MyBogusOption()) + .filter(User.id == 8) + .one() + ) def go(): eq_(u1.addresses[0].user, u1) + self.assert_sql_execution( - testing.db, go, + testing.db, + go, CompiledSQL( "SELECT addresses.id AS addresses_id, addresses.user_id AS " "addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE :param_1 = " "addresses.user_id", - {'param_1': 8}) + {"param_1": 8}, + ), ) def test_manytoone_limit(self): """test that the subquery wrapping only occurs with limit/offset and m2m or o2m joins present.""" - users, items, order_items, Order, Item, User, \ - Address, orders, addresses = ( - self.tables.users, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.tables.orders, - self.tables.addresses) - - mapper(User, users, properties=odict( - orders=relationship(Order, backref='user') - )) - mapper(Order, orders, properties=odict([ - ('items', relationship(Item, secondary=order_items, - backref='orders')), - ('address', relationship(Address)) - ])) + users, items, order_items, Order, Item, User, Address, orders, addresses = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.orders, + self.tables.addresses, + ) + + mapper( + User, + users, + properties=odict(orders=relationship(Order, backref="user")), + ) + mapper( + Order, + orders, + properties=odict( + [ + ( + "items", + relationship( + Item, secondary=order_items, backref="orders" + ), + ), + ("address", relationship(Address)), + ] + ), + ) mapper(Address, addresses) mapper(Item, items) @@ -973,7 +1353,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "FROM users " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN orders AS " "orders_1 ON anon_1.users_id = orders_1.user_id", - {'param_1': 10} + {"param_1": 10}, ) self.assert_compile( @@ -985,12 +1365,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM orders LEFT OUTER JOIN users AS " "users_1 ON users_1.id = orders.user_id LIMIT :param_1", - {'param_1': 10} + {"param_1": 10}, ) self.assert_compile( - sess.query(Order).options( - joinedload(Order.user, innerjoin=True)).limit(10), + sess.query(Order) + .options(joinedload(Order.user, innerjoin=True)) + .limit(10), "SELECT orders.id AS 
orders_id, orders.user_id AS orders_user_id, " "orders.address_id AS " "orders_address_id, orders.description AS orders_description, " @@ -998,12 +1379,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM orders JOIN users AS " "users_1 ON users_1.id = orders.user_id LIMIT :param_1", - {'param_1': 10} + {"param_1": 10}, ) self.assert_compile( - sess.query(User).options( - joinedload_all("orders.address")).limit(10), + sess.query(User) + .options(joinedload_all("orders.address")) + .limit(10), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " "addresses_1.id AS addresses_1_id, " @@ -1019,12 +1401,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "LEFT OUTER JOIN orders AS orders_1 " "ON anon_1.users_id = orders_1.user_id LEFT OUTER JOIN " "addresses AS addresses_1 ON addresses_1.id = orders_1.address_id", - {'param_1': 10} + {"param_1": 10}, ) self.assert_compile( - sess.query(User).options(joinedload_all("orders.items"), - joinedload("orders.address")), + sess.query(User).options( + joinedload_all("orders.items"), joinedload("orders.address") + ), "SELECT users.id AS users_id, users.name AS users_name, " "items_1.id AS items_1_id, " "items_1.description AS items_1_description, " @@ -1042,15 +1425,16 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " "ON orders_1.id = order_items_1.order_id " "LEFT OUTER JOIN addresses AS addresses_1 " - "ON addresses_1.id = orders_1.address_id" + "ON addresses_1.id = orders_1.address_id", ) self.assert_compile( - sess.query(User).options( + sess.query(User) + .options( joinedload("orders"), - joinedload( - "orders.address", - innerjoin=True)).limit(10), + joinedload("orders.address", innerjoin=True), + ) + .limit(10), "SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name " "AS anon_1_users_name, addresses_1.id AS addresses_1_id, " "addresses_1.user_id AS addresses_1_user_id, " @@ -1065,13 +1449,16 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "(orders AS orders_1 JOIN addresses AS addresses_1 " "ON addresses_1.id = orders_1.address_id) ON " "anon_1.users_id = orders_1.user_id", - {'param_1': 10} + {"param_1": 10}, ) self.assert_compile( - sess.query(User).options( + sess.query(User) + .options( joinedload("orders", innerjoin=True), - joinedload("orders.address", innerjoin=True)).limit(10), + joinedload("orders.address", innerjoin=True), + ) + .limit(10), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " "addresses_1.id AS addresses_1_id, " @@ -1088,36 +1475,51 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "AS orders_1 ON anon_1.users_id = " "orders_1.user_id JOIN addresses AS addresses_1 " "ON addresses_1.id = orders_1.address_id", - {'param_1': 10} + {"param_1": 10}, ) def test_one_to_many_scalar(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(User, users, properties=dict( - address=relationship(mapper(Address, addresses), - lazy='joined', uselist=False) - )) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=dict( + address=relationship( + mapper(Address, addresses), lazy="joined", uselist=False + ) + ), + ) q = 
create_session().query(User) def go(): result = q.filter(users.c.id == 7).all() eq_([User(id=7, address=Address(id=1))], result) + self.assert_sql_count(testing.db, go, 1) def test_one_to_many_scalar_subq_wrapping(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(User, users, properties=dict( - address=relationship(mapper(Address, addresses), - lazy='joined', uselist=False) - )) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=dict( + address=relationship( + mapper(Address, addresses), lazy="joined", uselist=False + ) + ), + ) q = create_session().query(User) q = q.filter(users.c.id == 7).limit(1) @@ -1131,7 +1533,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "ON users.id = addresses_1.user_id " "WHERE users.id = :id_1 " "LIMIT :param_1", - checkparams={'id_1': 7, 'param_1': 1} + checkparams={"id_1": 7, "param_1": 1}, ) def test_many_to_one(self): @@ -1139,11 +1541,16 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.classes.User, + ) - mapper(Address, addresses, properties=dict( - user=relationship(mapper(User, users), lazy='joined') - )) + mapper( + Address, + addresses, + properties=dict( + user=relationship(mapper(User, users), lazy="joined") + ), + ) sess = create_session() q = sess.query(Address) @@ -1152,6 +1559,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): is_not_(a.user, None) u1 = sess.query(User).get(7) is_(a.user, u1) + self.assert_sql_count(testing.db, go, 1) def test_many_to_one_null(self): @@ -1160,30 +1568,40 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): """ - Order, Address, addresses, orders = (self.classes.Order, - self.classes.Address, - self.tables.addresses, - self.tables.orders) + Order, Address, addresses, orders = ( + self.classes.Order, + self.classes.Address, + self.tables.addresses, + self.tables.orders, + ) # use a primaryjoin intended to defeat SA's usage of # query.get() for a many-to-one lazyload - mapper(Order, orders, properties=dict( - address=relationship( - mapper(Address, addresses), - primaryjoin=and_( - addresses.c.id == orders.c.address_id, - addresses.c.email_address != None # noqa - ), - - lazy='joined') - )) + mapper( + Order, + orders, + properties=dict( + address=relationship( + mapper(Address, addresses), + primaryjoin=and_( + addresses.c.id == orders.c.address_id, + addresses.c.email_address != None, # noqa + ), + lazy="joined", + ) + ), + ) sess = create_session() def go(): - o1 = sess.query(Order).options( - lazyload('address')).filter( - Order.id == 5).one() + o1 = ( + sess.query(Order) + .options(lazyload("address")) + .filter(Order.id == 5) + .one() + ) eq_(o1.address, None) + self.assert_sql_count(testing.db, go, 2) sess.expunge_all() @@ -1191,6 +1609,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): o1 = sess.query(Order).filter(Order.id == 5).one() eq_(o1.address, None) + self.assert_sql_count(testing.db, go, 1) def test_one_and_many(self): @@ -1204,121 +1623,164 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.orders, self.classes.Item, self.classes.User, - self.classes.Order) + self.classes.Order, + ) - mapper(User, users, properties={ - 'orders': 
relationship(Order, lazy='joined', order_by=orders.c.id) - }) + mapper( + User, + users, + properties={ + "orders": relationship( + Order, lazy="joined", order_by=orders.c.id + ) + }, + ) mapper(Item, items) - mapper(Order, orders, properties=dict( - items=relationship( - Item, - secondary=order_items, - lazy='joined', - order_by=items.c.id) - )) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy="joined", + order_by=items.c.id, + ) + ), + ) q = create_session().query(User) result = q.filter(text("users.id in (7, 8, 9)")).order_by( - text("users.id")) + text("users.id") + ) def go(): eq_(self.static.user_order_result[0:3], result.all()) + self.assert_sql_count(testing.db, go, 1) def test_double_with_aggregate(self): - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) - max_orders_by_user = sa.select([ - sa.func.max(orders.c.id).label('order_id')], - group_by=[orders.c.user_id] - ).alias('max_orders_by_user') + max_orders_by_user = sa.select( + [sa.func.max(orders.c.id).label("order_id")], + group_by=[orders.c.user_id], + ).alias("max_orders_by_user") max_orders = orders.select( - orders.c.id == max_orders_by_user.c.order_id).\ - alias('max_orders') + orders.c.id == max_orders_by_user.c.order_id + ).alias("max_orders") mapper(Order, orders) - mapper(User, users, properties={ - 'orders': relationship(Order, backref='user', lazy='joined', - order_by=orders.c.id), - 'max_order': relationship( - mapper(Order, max_orders, non_primary=True), - lazy='joined', uselist=False) - }) + mapper( + User, + users, + properties={ + "orders": relationship( + Order, backref="user", lazy="joined", order_by=orders.c.id + ), + "max_order": relationship( + mapper(Order, max_orders, non_primary=True), + lazy="joined", + uselist=False, + ), + }, + ) q = create_session().query(User) def go(): - eq_([ - User(id=7, orders=[ - Order(id=1), - Order(id=3), - Order(id=5), + eq_( + [ + User( + id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + max_order=Order(id=5), + ), + User(id=8, orders=[]), + User( + id=9, + orders=[Order(id=2), Order(id=4)], + max_order=Order(id=4), + ), + User(id=10), ], - max_order=Order(id=5) - ), - User(id=8, orders=[]), - User(id=9, orders=[Order(id=2), Order(id=4)], - max_order=Order(id=4) - ), - User(id=10), - ], q.order_by(User.id).all()) + q.order_by(User.id).all(), + ) + self.assert_sql_count(testing.db, go, 1) def test_uselist_false_warning(self): """test that multiple rows received by a uselist=False raises a warning.""" - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) - mapper(User, users, properties={ - 'order': relationship(Order, uselist=False) - }) + mapper( + User, + users, + properties={"order": relationship(Order, uselist=False)}, + ) mapper(Order, orders) s = create_session() - assert_raises(sa.exc.SAWarning, - s.query(User).options(joinedload(User.order)).all) + assert_raises( + sa.exc.SAWarning, s.query(User).options(joinedload(User.order)).all + ) def test_wide(self): - users, items, order_items, Order, Item, \ - User, Address, orders, addresses = ( - self.tables.users, - self.tables.items, - self.tables.order_items, - self.classes.Order, - 
self.classes.Item, - self.classes.User, - self.classes.Address, - self.tables.orders, - self.tables.addresses) + users, items, order_items, Order, Item, User, Address, orders, addresses = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.orders, + self.tables.addresses, + ) mapper( - Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, lazy='joined', - order_by=items.c.id)}) + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="joined", + order_by=items.c.id, + ) + }, + ) mapper(Item, items) - mapper(User, users, properties=dict( - addresses=relationship( - mapper( - Address, - addresses), - lazy=False, - order_by=addresses.c.id), - orders=relationship(Order, lazy=False, order_by=orders.c.id), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), + lazy=False, + order_by=addresses.c.id, + ), + orders=relationship(Order, lazy=False, order_by=orders.c.id), + ), + ) q = create_session().query(User) def go(): eq_(self.static.user_all_result, q.order_by(User.id).all()) + self.assert_sql_count(testing.db, go, 1) def test_against_select(self): @@ -1331,64 +1793,93 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.orders, self.classes.Item, self.classes.User, - self.classes.Order) + self.classes.Order, + ) - s = sa.select([orders], orders.c.isopen == 1).alias('openorders') + s = sa.select([orders], orders.c.isopen == 1).alias("openorders") - mapper(Order, s, properties={ - 'user': relationship(User, lazy='joined') - }) + mapper( + Order, s, properties={"user": relationship(User, lazy="joined")} + ) mapper(User, users) mapper(Item, items) q = create_session().query(Order) - eq_([ - Order(id=3, user=User(id=7)), - Order(id=4, user=User(id=9)) - ], q.all()) + eq_( + [Order(id=3, user=User(id=7)), Order(id=4, user=User(id=9))], + q.all(), + ) q = q.select_from(s.join(order_items).join(items)).filter( - ~Item.id.in_([1, 2, 5])) - eq_([ - Order(id=3, user=User(id=7)), - ], q.all()) + ~Item.id.in_([1, 2, 5]) + ) + eq_([Order(id=3, user=User(id=7))], q.all()) def test_aliasing(self): """test that eager loading uses aliases to insulate the eager load from regular criterion against those tables.""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), - lazy='joined', order_by=addresses.c.id) - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), + lazy="joined", + order_by=addresses.c.id, + ) + ), + ) q = create_session().query(User) - result = q.filter(addresses.c.email_address == 'ed@lala.com').filter( - Address.user_id == User.id).order_by(User.id) + result = ( + q.filter(addresses.c.email_address == "ed@lala.com") + .filter(Address.user_id == User.id) + .order_by(User.id) + ) eq_(self.static.user_address_result[1:2], result.all()) def test_inner_join(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), 
lazy='joined', - innerjoin=True, order_by=addresses.c.id) - )) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), + lazy="joined", + innerjoin=True, + order_by=addresses.c.id, + ) + ), + ) sess = create_session() eq_( - [User(id=7, addresses=[Address(id=1)]), - User(id=8, - addresses=[Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), ]), - User(id=9, addresses=[Address(id=5)])], sess.query(User).all() + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + ], + sess.query(User).all(), ) self.assert_compile( sess.query(User), @@ -1398,7 +1889,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "addresses_1.email_address AS addresses_1_email_address " "FROM users JOIN " "addresses AS addresses_1 ON users.id = addresses_1.user_id " - "ORDER BY addresses_1.id") + "ORDER BY addresses_1.id", + ) def test_inner_join_unnested_chaining_options(self): users, items, order_items, Order, Item, User, orders = ( @@ -1408,16 +1900,28 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, innerjoin="unnested", - lazy=False) - )) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy=False, - innerjoin="unnested") - )) + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict( + orders=relationship(Order, innerjoin="unnested", lazy=False) + ), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy=False, + innerjoin="unnested", + ) + ), + ) mapper(Item, items) sess = create_session() @@ -1436,7 +1940,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "users.id = orders_1.user_id JOIN order_items AS order_items_1 " "ON orders_1.id = " "order_items_1.order_id JOIN items AS items_1 ON items_1.id = " - "order_items_1.item_id" + "order_items_1.item_id", ) self.assert_compile( @@ -1454,15 +1958,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "ON users.id = orders_1.user_id " "LEFT OUTER JOIN (order_items AS order_items_1 " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " - "ON orders_1.id = order_items_1.order_id" + "ON orders_1.id = order_items_1.order_id", ) self.assert_compile( sess.query(User).options( - joinedload( - User.orders, - Order.items, - innerjoin=False)), + joinedload(User.orders, Order.items, innerjoin=False) + ), "SELECT users.id AS users_id, users.name AS users_name, " "items_1.id AS " "items_1_id, items_1.description AS items_1_description, " @@ -1476,8 +1978,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "users.id = orders_1.user_id " "LEFT OUTER JOIN (order_items AS order_items_1 " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " - "ON orders_1.id = order_items_1.order_id" - + "ON orders_1.id = order_items_1.order_id", ) def test_inner_join_nested_chaining_negative_options(self): @@ -1488,16 +1989,31 @@ class 
EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, innerjoin=True, - lazy=False, order_by=orders.c.id) - )) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy=False, - innerjoin=True, order_by=items.c.id) - )) + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict( + orders=relationship( + Order, innerjoin=True, lazy=False, order_by=orders.c.id + ) + ), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy=False, + innerjoin=True, + order_by=items.c.id, + ) + ), + ) mapper(Item, items) sess = create_session() @@ -1516,7 +2032,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "users.id = orders_1.user_id JOIN order_items " "AS order_items_1 ON orders_1.id = " "order_items_1.order_id JOIN items AS items_1 ON items_1.id = " - "order_items_1.item_id ORDER BY orders_1.id, items_1.id" + "order_items_1.item_id ORDER BY orders_1.id, items_1.id", ) q = sess.query(User).options(joinedload(User.orders, innerjoin=False)) @@ -1535,43 +2051,42 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "(orders AS orders_1 JOIN order_items AS order_items_1 " "ON orders_1.id = order_items_1.order_id " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " - "ON users.id = orders_1.user_id ORDER BY orders_1.id, items_1.id" + "ON users.id = orders_1.user_id ORDER BY orders_1.id, items_1.id", ) eq_( [ - User(id=7, - orders=[ - Order( - id=1, items=[ - Item( - id=1), Item( - id=2), Item( - id=3)]), - Order( - id=3, items=[ - Item( - id=3), Item( - id=4), Item( - id=5)]), - Order(id=5, items=[Item(id=5)])]), + User( + id=7, + orders=[ + Order( + id=1, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + Order( + id=3, items=[Item(id=3), Item(id=4), Item(id=5)] + ), + Order(id=5, items=[Item(id=5)]), + ], + ), User(id=8, orders=[]), - User(id=9, orders=[ - Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)]), - Order(id=4, items=[Item(id=1), Item(id=5)]) - ] + User( + id=9, + orders=[ + Order( + id=2, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ], ), - User(id=10, orders=[]) + User(id=10, orders=[]), ], - q.order_by(User.id).all() + q.order_by(User.id).all(), ) self.assert_compile( sess.query(User).options( - joinedload( - User.orders, - Order.items, - innerjoin=False)), + joinedload(User.orders, Order.items, innerjoin=False) + ), "SELECT users.id AS users_id, users.name AS users_name, " "items_1.id AS " "items_1_id, items_1.description AS items_1_description, " @@ -1586,8 +2101,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "LEFT OUTER JOIN (order_items AS order_items_1 " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " "ON orders_1.id = order_items_1.order_id ORDER BY " - "orders_1.id, items_1.id" - + "orders_1.id, items_1.id", ) def test_inner_join_nested_chaining_positive_options(self): @@ -1598,23 +2112,30 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, order_by=orders.c.id) - )) - mapper(Order, orders, properties=dict( - items=relationship( - Item, - secondary=order_items, - order_by=items.c.id) - )) + self.tables.orders, 
+ ) + + mapper( + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, secondary=order_items, order_by=items.c.id + ) + ), + ) mapper(Item, items) sess = create_session() q = sess.query(User).options( - joinedload("orders", innerjoin=False). - joinedload("items", innerjoin=True) + joinedload("orders", innerjoin=False).joinedload( + "items", innerjoin=True + ) ) self.assert_compile( @@ -1633,35 +2154,36 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "JOIN items AS " "items_1 ON items_1.id = order_items_1.item_id) " "ON users.id = orders_1.user_id " - "ORDER BY orders_1.id, items_1.id" + "ORDER BY orders_1.id, items_1.id", ) eq_( [ - User(id=7, - orders=[ - Order( - id=1, items=[ - Item( - id=1), Item( - id=2), Item( - id=3)]), - Order( - id=3, items=[ - Item( - id=3), Item( - id=4), Item( - id=5)]), - Order(id=5, items=[Item(id=5)])]), + User( + id=7, + orders=[ + Order( + id=1, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + Order( + id=3, items=[Item(id=3), Item(id=4), Item(id=5)] + ), + Order(id=5, items=[Item(id=5)]), + ], + ), User(id=8, orders=[]), - User(id=9, orders=[ - Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)]), - Order(id=4, items=[Item(id=1), Item(id=5)]) - ] + User( + id=9, + orders=[ + Order( + id=2, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ], ), - User(id=10, orders=[]) + User(id=10, orders=[]), ], - q.order_by(User.id).all() + q.order_by(User.id).all(), ) def test_unnested_outerjoin_propagation_only_on_correct_path(self): @@ -1671,17 +2193,22 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): Order, orders = self.classes.Order, self.tables.orders Address, addresses = self.classes.Address, self.tables.addresses - mapper(User, users, properties=odict([ - ('orders', relationship(Order)), - ('addresses', relationship(Address)) - ])) + mapper( + User, + users, + properties=odict( + [ + ("orders", relationship(Order)), + ("addresses", relationship(Address)), + ] + ), + ) mapper(Order, orders) mapper(Address, addresses) sess = create_session() q = sess.query(User).options( - joinedload("orders"), - joinedload("addresses", innerjoin="unnested"), + joinedload("orders"), joinedload("addresses", innerjoin="unnested") ) self.assert_compile( @@ -1697,7 +2224,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "addresses_1.email_address AS addresses_1_email_address " "FROM users LEFT OUTER JOIN orders AS orders_1 " "ON users.id = orders_1.user_id JOIN addresses AS addresses_1 " - "ON users.id = addresses_1.user_id" + "ON users.id = addresses_1.user_id", ) def test_nested_outerjoin_propagation_only_on_correct_path(self): @@ -1707,17 +2234,22 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): Order, orders = self.classes.Order, self.tables.orders Address, addresses = self.classes.Address, self.tables.addresses - mapper(User, users, properties=odict([ - ('orders', relationship(Order)), - ('addresses', relationship(Address)) - ])) + mapper( + User, + users, + properties=odict( + [ + ("orders", relationship(Order)), + ("addresses", relationship(Address)), + ] + ), + ) mapper(Order, orders) mapper(Address, addresses) sess = create_session() q = sess.query(User).options( - joinedload("orders"), - joinedload("addresses", innerjoin=True), + joinedload("orders"), joinedload("addresses", innerjoin=True) ) self.assert_compile( @@ 
-1733,42 +2265,60 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "addresses_1.email_address AS addresses_1_email_address " "FROM users LEFT OUTER JOIN orders AS orders_1 " "ON users.id = orders_1.user_id JOIN addresses AS addresses_1 " - "ON users.id = addresses_1.user_id" + "ON users.id = addresses_1.user_id", + ) + + def test_catch_the_right_target(self): + # test eager join chaining to the "nested" join on the left, + # a new feature as of [ticket:2369] + + users, Keyword, orders, items, order_items, Order, Item, User, keywords, item_keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) + + mapper( + User, + users, + properties={ + "orders": relationship(Order, backref="user") # o2m, m2o + }, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) # m2m + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, order_by=keywords.c.id + ) # m2m + }, ) - - def test_catch_the_right_target(self): - # test eager join chaining to the "nested" join on the left, - # a new feature as of [ticket:2369] - - users, Keyword, orders, items, order_items, Order, Item, \ - User, keywords, item_keywords = ( - self.tables.users, - self.classes.Keyword, - self.tables.orders, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.keywords, - self.tables.item_keywords) - - mapper(User, users, properties={ - 'orders': relationship(Order, backref='user'), # o2m, m2o - }) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, - order_by=items.c.id), # m2m - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords, - order_by=keywords.c.id) # m2m - }) mapper(Keyword, keywords) sess = create_session() - q = sess.query(User).join(User.orders).join(Order.items).\ - options(joinedload_all("orders.items.keywords")) + q = ( + sess.query(User) + .join(User.orders) + .join(Order.items) + .options(joinedload_all("orders.items.keywords")) + ) # here, the eager join for keywords can catch onto # join(Order.items) or the nested (orders LEFT OUTER JOIN items), @@ -1798,7 +2348,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "JOIN keywords AS keywords_1 ON keywords_1.id = " "item_keywords_1.keyword_id) " "ON items_1.id = item_keywords_1.item_id " - "ORDER BY items_1.id, keywords_1.id" + "ORDER BY items_1.id, keywords_1.id", ) def test_inner_join_unnested_chaining_fixed(self): @@ -1809,15 +2359,26 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, lazy=False) - )) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy=False, - innerjoin="unnested") - )) + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict(orders=relationship(Order, lazy=False)), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy=False, + innerjoin="unnested", + ) + ), + ) mapper(Item, items) sess = create_session() @@ -1839,7 +2400,7 @@ class 
EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "(order_items AS order_items_1 JOIN items AS items_1 ON " "items_1.id = " "order_items_1.item_id) ON orders_1.id = " - "order_items_1.order_id" + "order_items_1.order_id", ) # joining just from Order, innerjoin=True can be respected @@ -1851,7 +2412,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "AS items_1_id, items_1.description AS items_1_description FROM " "orders JOIN order_items AS order_items_1 ON orders.id = " "order_items_1.order_id JOIN items AS items_1 ON items_1.id = " - "order_items_1.item_id" + "order_items_1.item_id", ) def test_inner_join_nested_chaining_fixed(self): @@ -1862,15 +2423,23 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, lazy=False) - )) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy=False, - innerjoin='nested') - )) + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict(orders=relationship(Order, lazy=False)), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, secondary=order_items, lazy=False, innerjoin="nested" + ) + ), + ) mapper(Item, items) sess = create_session() @@ -1890,7 +2459,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "(orders AS orders_1 JOIN order_items AS order_items_1 " "ON orders_1.id = order_items_1.order_id " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " - "ON users.id = orders_1.user_id" + "ON users.id = orders_1.user_id", ) def test_inner_join_options(self): @@ -1901,18 +2470,29 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, backref=backref('user', innerjoin=True), - order_by=orders.c.id) - )) - mapper(Order, orders, properties=dict( - items=relationship( - Item, - secondary=order_items, - order_by=items.c.id) - )) + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict( + orders=relationship( + Order, + backref=backref("user", innerjoin=True), + order_by=orders.c.id, + ) + ), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, secondary=order_items, order_by=items.c.id + ) + ), + ) mapper(Item, items) sess = create_session() self.assert_compile( @@ -1924,11 +2504,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "orders_1.description AS orders_1_description, orders_1.isopen " "AS orders_1_isopen " "FROM users JOIN orders AS orders_1 ON users.id = " - "orders_1.user_id ORDER BY orders_1.id") + "orders_1.user_id ORDER BY orders_1.id", + ) self.assert_compile( sess.query(User).options( - joinedload_all(User.orders, Order.items, innerjoin=True)), + joinedload_all(User.orders, Order.items, innerjoin=True) + ), "SELECT users.id AS users_id, users.name AS users_name, " "items_1.id AS items_1_id, " "items_1.description AS items_1_description, " @@ -1942,48 +2524,53 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "order_items_1 ON orders_1.id = order_items_1.order_id " "JOIN items AS items_1 ON " "items_1.id = order_items_1.item_id ORDER BY orders_1.id, " - "items_1.id") + "items_1.id", + ) def go(): eq_( - sess.query(User).options( + sess.query(User) + .options( joinedload(User.orders, 
innerjoin=True), - joinedload(User.orders, Order.items, innerjoin=True)). - order_by(User.id).all(), - - [User(id=7, - orders=[ - Order( - id=1, items=[ - Item( - id=1), Item( - id=2), Item( - id=3)]), - Order( - id=3, items=[ - Item( - id=3), Item( - id=4), Item( - id=5)]), - Order(id=5, items=[Item(id=5)])]), - User(id=9, orders=[ - Order( - id=2, items=[ - Item( - id=1), Item( - id=2), Item( - id=3)]), - Order(id=4, items=[Item(id=1), Item(id=5)])]) - ] + joinedload(User.orders, Order.items, innerjoin=True), + ) + .order_by(User.id) + .all(), + [ + User( + id=7, + orders=[ + Order( + id=1, + items=[Item(id=1), Item(id=2), Item(id=3)], + ), + Order( + id=3, + items=[Item(id=3), Item(id=4), Item(id=5)], + ), + Order(id=5, items=[Item(id=5)]), + ], + ), + User( + id=9, + orders=[ + Order( + id=2, + items=[Item(id=1), Item(id=2), Item(id=3)], + ), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ], + ), + ], ) + self.assert_sql_count(testing.db, go, 1) # test that default innerjoin setting is used for options self.assert_compile( - sess.query(Order).options( - joinedload( - Order.user)).filter( - Order.description == 'foo'), + sess.query(Order) + .options(joinedload(Order.user)) + .filter(Order.description == "foo"), "SELECT orders.id AS orders_id, orders.user_id AS orders_user_id, " "orders.address_id AS " "orders_address_id, orders.description AS orders_description, " @@ -1991,7 +2578,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "orders_isopen, users_1.id AS users_1_id, users_1.name " "AS users_1_name " "FROM orders JOIN users AS users_1 ON users_1.id = orders.user_id " - "WHERE orders.description = :description_1" + "WHERE orders.description = :description_1", ) def test_propagated_lazyload_wildcard_unbound(self): @@ -2008,14 +2595,21 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order, lazy="select") - )) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy="joined") - )) + self.tables.orders, + ) + + mapper( + User, + users, + properties=dict(orders=relationship(Order, lazy="select")), + ) + mapper( + Order, + orders, + properties=dict( + items=relationship(Item, secondary=order_items, lazy="joined") + ), + ) mapper(Item, items) sess = create_session() @@ -2031,62 +2625,80 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): for u in q: u.orders - self.sql_eq_(go, [ - ("SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1", {"id_1": 7}), - ("SELECT orders.id AS orders_id, " - "orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen FROM orders " - "WHERE :param_1 = orders.user_id", {"param_1": 7}), - ]) + self.sql_eq_( + go, + [ + ( + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1", + {"id_1": 7}, + ), + ( + "SELECT orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen FROM orders " + "WHERE :param_1 = orders.user_id", + {"param_1": 7}, + ), + ], + ) class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" __backend__ = True # exercise 
hardcore join nesting on backends @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True) - ) - - Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id')), - Column('value', String(10)), - ) - Table('c1', metadata, - Column('id', Integer, primary_key=True), - Column('b_id', Integer, ForeignKey('b.id')), - Column('value', String(10)), - ) - Table('c2', metadata, - Column('id', Integer, primary_key=True), - Column('b_id', Integer, ForeignKey('b.id')), - Column('value', String(10)), - ) - Table('d1', metadata, - Column('id', Integer, primary_key=True), - Column('c1_id', Integer, ForeignKey('c1.id')), - Column('value', String(10)), - ) - Table('d2', metadata, - Column('id', Integer, primary_key=True), - Column('c2_id', Integer, ForeignKey('c2.id')), - Column('value', String(10)), - ) - Table('e1', metadata, - Column('id', Integer, primary_key=True), - Column('d1_id', Integer, ForeignKey('d1.id')), - Column('value', String(10)), - ) + Table("a", metadata, Column("id", Integer, primary_key=True)) + + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), + Column("value", String(10)), + ) + Table( + "c1", + metadata, + Column("id", Integer, primary_key=True), + Column("b_id", Integer, ForeignKey("b.id")), + Column("value", String(10)), + ) + Table( + "c2", + metadata, + Column("id", Integer, primary_key=True), + Column("b_id", Integer, ForeignKey("b.id")), + Column("value", String(10)), + ) + Table( + "d1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1_id", Integer, ForeignKey("c1.id")), + Column("value", String(10)), + ) + Table( + "d2", + metadata, + Column("id", Integer, primary_key=True), + Column("c2_id", Integer, ForeignKey("c2.id")), + Column("value", String(10)), + ) + Table( + "e1", + metadata, + Column("id", Integer, primary_key=True), + Column("d1_id", Integer, ForeignKey("d1.id")), + Column("value", String(10)), + ) @classmethod def setup_classes(cls): - class A(cls.Comparable): pass @@ -2111,75 +2723,112 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): @classmethod def setup_mappers(cls): A, B, C1, C2, D1, D2, E1 = ( - cls.classes.A, cls.classes.B, cls.classes.C1, - cls.classes.C2, cls.classes.D1, cls.classes.D2, cls.classes.E1) - mapper(A, cls.tables.a, properties={ - 'bs': relationship(B) - }) - mapper(B, cls.tables.b, properties=odict([ - ('c1s', relationship(C1, order_by=cls.tables.c1.c.id)), - ('c2s', relationship(C2, order_by=cls.tables.c2.c.id)) - ])) - mapper(C1, cls.tables.c1, properties={ - 'd1s': relationship(D1, order_by=cls.tables.d1.c.id) - }) - mapper(C2, cls.tables.c2, properties={ - 'd2s': relationship(D2, order_by=cls.tables.d2.c.id) - }) - mapper(D1, cls.tables.d1, properties={ - 'e1s': relationship(E1, order_by=cls.tables.e1.c.id) - }) + cls.classes.A, + cls.classes.B, + cls.classes.C1, + cls.classes.C2, + cls.classes.D1, + cls.classes.D2, + cls.classes.E1, + ) + mapper(A, cls.tables.a, properties={"bs": relationship(B)}) + mapper( + B, + cls.tables.b, + properties=odict( + [ + ("c1s", relationship(C1, order_by=cls.tables.c1.c.id)), + ("c2s", relationship(C2, order_by=cls.tables.c2.c.id)), + ] + ), + ) + mapper( + C1, + cls.tables.c1, + properties={"d1s": relationship(D1, order_by=cls.tables.d1.c.id)}, + ) + mapper( + C2, + cls.tables.c2, + properties={"d2s": relationship(D2, order_by=cls.tables.d2.c.id)}, + ) + mapper( + D1, + cls.tables.d1, 
+ properties={"e1s": relationship(E1, order_by=cls.tables.e1.c.id)}, + ) mapper(D2, cls.tables.d2) mapper(E1, cls.tables.e1) @classmethod def _fixture_data(cls): A, B, C1, C2, D1, D2, E1 = ( - cls.classes.A, cls.classes.B, cls.classes.C1, - cls.classes.C2, cls.classes.D1, cls.classes.D2, cls.classes.E1) + cls.classes.A, + cls.classes.B, + cls.classes.C1, + cls.classes.C2, + cls.classes.D1, + cls.classes.D2, + cls.classes.E1, + ) return [ - A(id=1, bs=[ - B( - id=1, - c1s=[C1( - id=1, value='C11', - d1s=[ - D1(id=1, e1s=[E1(id=1)]), D1(id=2, e1s=[E1(id=2)]) - ] + A( + id=1, + bs=[ + B( + id=1, + c1s=[ + C1( + id=1, + value="C11", + d1s=[ + D1(id=1, e1s=[E1(id=1)]), + D1(id=2, e1s=[E1(id=2)]), + ], + ) + ], + c2s=[ + C2(id=1, value="C21", d2s=[D2(id=3)]), + C2(id=2, value="C22", d2s=[D2(id=4)]), + ], + ), + B( + id=2, + c1s=[ + C1( + id=4, + value="C14", + d1s=[ + D1( + id=3, + e1s=[ + E1(id=3, value="E13"), + E1(id=4, value="E14"), + ], + ), + D1(id=4, e1s=[E1(id=5)]), + ], + ) + ], + c2s=[C2(id=4, value="C24", d2s=[])], + ), + ], + ), + A( + id=2, + bs=[ + B( + id=3, + c1s=[ + C1( + id=8, + d1s=[D1(id=5, value="D15", e1s=[E1(id=6)])], + ) + ], + c2s=[C2(id=8, d2s=[D2(id=6, value="D26")])], ) - ], - c2s=[C2(id=1, value='C21', d2s=[D2(id=3)]), - C2(id=2, value='C22', d2s=[D2(id=4)])] - ), - B( - id=2, - c1s=[ - C1( - id=4, value='C14', - d1s=[D1( - id=3, e1s=[ - E1(id=3, value='E13'), - E1(id=4, value="E14") - ]), - D1(id=4, e1s=[E1(id=5)]) - ] - ) - ], - c2s=[C2(id=4, value='C24', d2s=[])] - ), - ]), - A(id=2, bs=[ - B( - id=3, - c1s=[ - C1( - id=8, - d1s=[D1(id=5, value='D15', e1s=[E1(id=6)])] - ) - ], - c2s=[C2(id=8, d2s=[D2(id=6, value='D26')])] - ) - ]) + ], + ), ] @classmethod @@ -2189,24 +2838,25 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): s.commit() def _assert_result(self, query): - eq_( - query.all(), - self._fixture_data() - ) + eq_(query.all(), self._fixture_data()) def test_nested_innerjoin_propagation_multiple_paths_one(self): A, B, C1, C2 = ( - self.classes.A, self.classes.B, self.classes.C1, - self.classes.C2) + self.classes.A, + self.classes.B, + self.classes.C1, + self.classes.C2, + ) s = Session() q = s.query(A).options( - joinedload(A.bs, innerjoin=False). - joinedload(B.c1s, innerjoin=True). - joinedload(C1.d1s, innerjoin=True), - defaultload(A.bs).joinedload(B.c2s, innerjoin=True). 
- joinedload(C2.d2s, innerjoin=False) + joinedload(A.bs, innerjoin=False) + .joinedload(B.c1s, innerjoin=True) + .joinedload(C1.d1s, innerjoin=True), + defaultload(A.bs) + .joinedload(B.c2s, innerjoin=True) + .joinedload(C2.d2s, innerjoin=False), ) self.assert_compile( q, @@ -2224,7 +2874,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id " "JOIN d1 AS d1_1 ON c1_1.id = d1_1.c1_id) ON a.id = b_1.a_id " "LEFT OUTER JOIN d2 AS d2_1 ON c2_1.id = d2_1.c2_id " - "ORDER BY c1_1.id, d1_1.id, c2_1.id, d2_1.id" + "ORDER BY c1_1.id, d1_1.id, c2_1.id, d2_1.id", ) self._assert_result(q) @@ -2235,10 +2885,10 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): s = Session() q = s.query(A).options( - joinedload('bs'), - joinedload('bs.c2s', innerjoin=True), - joinedload('bs.c1s', innerjoin=True), - joinedload('bs.c1s.d1s') + joinedload("bs"), + joinedload("bs.c2s", innerjoin=True), + joinedload("bs.c1s", innerjoin=True), + joinedload("bs.c1s.d1s"), ) self.assert_compile( q, @@ -2253,7 +2903,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "(b AS b_1 JOIN c2 AS c2_1 ON b_1.id = c2_1.b_id " "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id) ON a.id = b_1.a_id " "LEFT OUTER JOIN d1 AS d1_1 ON c1_1.id = d1_1.c1_id " - "ORDER BY c1_1.id, d1_1.id, c2_1.id" + "ORDER BY c1_1.id, d1_1.id, c2_1.id", ) self._assert_result(q) @@ -2263,12 +2913,12 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): s = Session() q = s.query(A).options( - joinedload('bs', innerjoin=False), - joinedload('bs.c1s', innerjoin=True), - joinedload('bs.c2s', innerjoin=True), - joinedload('bs.c1s.d1s', innerjoin=False), - joinedload('bs.c2s.d2s'), - joinedload('bs.c1s.d1s.e1s', innerjoin=True) + joinedload("bs", innerjoin=False), + joinedload("bs.c1s", innerjoin=True), + joinedload("bs.c2s", innerjoin=True), + joinedload("bs.c1s.d1s", innerjoin=False), + joinedload("bs.c2s.d2s"), + joinedload("bs.c1s.d1s.e1s", innerjoin=True), ) self.assert_compile( @@ -2289,7 +2939,7 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "d1 AS d1_1 JOIN e1 AS e1_1 ON d1_1.id = e1_1.d1_id) " "ON c1_1.id = d1_1.c1_id " "LEFT OUTER JOIN d2 AS d2_1 ON c2_1.id = d2_1.c2_id " - "ORDER BY c1_1.id, d1_1.id, e1_1.id, c2_1.id, d2_1.id" + "ORDER BY c1_1.id, d1_1.id, e1_1.id, c2_1.id, d2_1.id", ) self._assert_result(q) @@ -2305,25 +2955,26 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): weird_selectable = b_table.outerjoin(c1_table) b_np = mapper( - B, weird_selectable, non_primary=True, properties=odict([ - # note we need to make this fixed with lazy=False until - # [ticket:3348] is resolved - ('c1s', relationship(C1, lazy=False, innerjoin=True)), - ('c_id', c1_table.c.id), - ('b_value', b_table.c.value), - ]) + B, + weird_selectable, + non_primary=True, + properties=odict( + [ + # note we need to make this fixed with lazy=False until + # [ticket:3348] is resolved + ("c1s", relationship(C1, lazy=False, innerjoin=True)), + ("c_id", c1_table.c.id), + ("b_value", b_table.c.value), + ] + ), ) a_mapper = inspect(A) - a_mapper.add_property( - "bs_np", relationship(b_np) - ) + a_mapper.add_property("bs_np", relationship(b_np)) s = Session() - q = s.query(A).options( - joinedload('bs_np', innerjoin=False) - ) + q = s.query(A).options(joinedload("bs_np", innerjoin=False)) self.assert_compile( q, "SELECT a.id AS a_id, c1_1.id AS c1_1_id, c1_1.b_id AS c1_1_b_id, " @@ 
-2333,45 +2984,44 @@ class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "c1_2.value AS c1_2_value " "FROM a LEFT OUTER JOIN " "(b AS b_1 LEFT OUTER JOIN c1 AS c1_2 ON b_1.id = c1_2.b_id " - "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id) ON a.id = b_1.a_id" + "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id) ON a.id = b_1.a_id", ) class InnerJoinSplicingWSecondaryTest( - fixtures.MappedTest, testing.AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.MappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" __backend__ = True # exercise hardcore join nesting on backends @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, - Column('id', Integer, primary_key=True), - Column('bid', ForeignKey('b.id')) + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", ForeignKey("b.id")), ) Table( - 'b', metadata, - Column('id', Integer, primary_key=True), - Column('cid', ForeignKey('c.id')) + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("cid", ForeignKey("c.id")), ) + Table("c", metadata, Column("id", Integer, primary_key=True)) + Table( - 'c', metadata, - Column('id', Integer, primary_key=True), + "ctod", + metadata, + Column("cid", ForeignKey("c.id"), primary_key=True), + Column("did", ForeignKey("d.id"), primary_key=True), ) - - Table('ctod', metadata, - Column('cid', ForeignKey('c.id'), primary_key=True), - Column('did', ForeignKey('d.id'), primary_key=True), - ) - Table('d', metadata, - Column('id', Integer, primary_key=True), - ) + Table("d", metadata, Column("id", Integer, primary_key=True)) @classmethod def setup_classes(cls): - class A(cls.Comparable): pass @@ -2387,48 +3037,44 @@ class InnerJoinSplicingWSecondaryTest( @classmethod def setup_mappers(cls): A, B, C, D = ( - cls.classes.A, cls.classes.B, cls.classes.C, - cls.classes.D) - mapper(A, cls.tables.a, properties={ - 'b': relationship(B) - }) - mapper(B, cls.tables.b, properties=odict([ - ('c', relationship(C)), - ])) - mapper(C, cls.tables.c, properties=odict([ - ('ds', relationship(D, secondary=cls.tables.ctod, - order_by=cls.tables.d.c.id)), - ])) + cls.classes.A, + cls.classes.B, + cls.classes.C, + cls.classes.D, + ) + mapper(A, cls.tables.a, properties={"b": relationship(B)}) + mapper(B, cls.tables.b, properties=odict([("c", relationship(C))])) + mapper( + C, + cls.tables.c, + properties=odict( + [ + ( + "ds", + relationship( + D, + secondary=cls.tables.ctod, + order_by=cls.tables.d.c.id, + ), + ) + ] + ), + ) mapper(D, cls.tables.d) @classmethod def _fixture_data(cls): A, B, C, D = ( - cls.classes.A, cls.classes.B, cls.classes.C, - cls.classes.D) + cls.classes.A, + cls.classes.B, + cls.classes.C, + cls.classes.D, + ) d1, d2, d3 = D(id=1), D(id=2), D(id=3) return [ - A( - id=1, - b=B( - id=1, - c=C( - id=1, - ds=[d1, d2] - ) - ) - ), - A( - id=2, - b=B( - id=2, - c=C( - id=2, - ds=[d2, d3] - ) - ) - ) + A(id=1, b=B(id=1, c=C(id=1, ds=[d1, d2]))), + A(id=2, b=B(id=2, c=C(id=2, ds=[d2, d3]))), ] @classmethod @@ -2439,26 +3085,19 @@ class InnerJoinSplicingWSecondaryTest( def _assert_result(self, query): def go(): - eq_( - query.all(), - self._fixture_data() - ) + eq_(query.all(), self._fixture_data()) - self.assert_sql_count( - testing.db, - go, - 1 - ) + self.assert_sql_count(testing.db, go, 1) def test_joined_across(self): A = self.classes.A s = Session() - q = s.query(A) \ - .options( - joinedload('b'). - joinedload('c', innerjoin=True). 
- joinedload('ds', innerjoin=True)) + q = s.query(A).options( + joinedload("b") + .joinedload("c", innerjoin=True) + .joinedload("ds", innerjoin=True) + ) self.assert_compile( q, "SELECT a.id AS a_id, a.bid AS a_bid, d_1.id AS d_1_id, " @@ -2468,7 +3107,7 @@ class InnerJoinSplicingWSecondaryTest( "(c AS c_1 JOIN ctod AS ctod_1 ON c_1.id = ctod_1.cid) " "ON c_1.id = b_1.cid " "JOIN d AS d_1 ON d_1.id = ctod_1.did) ON b_1.id = a.bid " - "ORDER BY d_1.id" + "ORDER BY d_1.id", ) self._assert_result(q) @@ -2477,24 +3116,23 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """test #2188""" - __dialect__ = 'default' + __dialect__ = "default" run_create_tables = None @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True) - ) + Table("a", metadata, Column("id", Integer, primary_key=True)) - Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id')), - Column('value', Integer), - ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), + Column("value", Integer), + ) @classmethod def setup_classes(cls): - class A(cls.Comparable): pass @@ -2505,92 +3143,99 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): A, B = self.classes.A, self.classes.B b_table, a_table = self.tables.b, self.tables.a mapper(A, a_table, properties=props) - mapper(B, b_table, properties={ - 'a': relationship(A, backref="bs") - }) + mapper(B, b_table, properties={"a": relationship(A, backref="bs")}) def test_column_property(self): A = self.classes.A b_table, a_table = self.tables.b, self.tables.a - cp = select([func.sum(b_table.c.value)]).\ - where(b_table.c.a_id == a_table.c.id) + cp = select([func.sum(b_table.c.value)]).where( + b_table.c.a_id == a_table.c.id + ) - self._fixture({ - 'summation': column_property(cp) - }) + self._fixture({"summation": column_property(cp)}) self.assert_compile( - create_session().query(A).options(joinedload_all('bs')). - order_by(A.summation). - limit(50), + create_session() + .query(A) + .options(joinedload_all("bs")) + .order_by(A.summation) + .limit(50), "SELECT anon_1.anon_2 AS anon_1_anon_2, anon_1.a_id " "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS " "b_1_a_id, b_1.value AS b_1_value FROM (SELECT " "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) " "AS anon_2, a.id AS a_id FROM a ORDER BY anon_2 " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON " - "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2" + "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2", ) def test_column_property_desc(self): A = self.classes.A b_table, a_table = self.tables.b, self.tables.a - cp = select([func.sum(b_table.c.value)]).\ - where(b_table.c.a_id == a_table.c.id) + cp = select([func.sum(b_table.c.value)]).where( + b_table.c.a_id == a_table.c.id + ) - self._fixture({ - 'summation': column_property(cp) - }) + self._fixture({"summation": column_property(cp)}) self.assert_compile( - create_session().query(A).options(joinedload_all('bs')). - order_by(A.summation.desc()). 
- limit(50), + create_session() + .query(A) + .options(joinedload_all("bs")) + .order_by(A.summation.desc()) + .limit(50), "SELECT anon_1.anon_2 AS anon_1_anon_2, anon_1.a_id " "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS " "b_1_a_id, b_1.value AS b_1_value FROM (SELECT " "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) " "AS anon_2, a.id AS a_id FROM a ORDER BY anon_2 DESC " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON " - "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2 DESC" + "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2 DESC", ) def test_column_property_correlated(self): A = self.classes.A b_table, a_table = self.tables.b, self.tables.a - cp = select([func.sum(b_table.c.value)]).\ - where(b_table.c.a_id == a_table.c.id).\ - correlate(a_table) + cp = ( + select([func.sum(b_table.c.value)]) + .where(b_table.c.a_id == a_table.c.id) + .correlate(a_table) + ) - self._fixture({ - 'summation': column_property(cp) - }) + self._fixture({"summation": column_property(cp)}) self.assert_compile( - create_session().query(A).options(joinedload_all('bs')). - order_by(A.summation). - limit(50), + create_session() + .query(A) + .options(joinedload_all("bs")) + .order_by(A.summation) + .limit(50), "SELECT anon_1.anon_2 AS anon_1_anon_2, anon_1.a_id " "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS " "b_1_a_id, b_1.value AS b_1_value FROM (SELECT " "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) " "AS anon_2, a.id AS a_id FROM a ORDER BY anon_2 " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON " - "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2" + "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2", ) def test_standalone_subquery_unlabeled(self): A = self.classes.A b_table, a_table = self.tables.b, self.tables.a self._fixture({}) - cp = select([func.sum(b_table.c.value)]).\ - where(b_table.c.a_id == a_table.c.id).\ - correlate(a_table).as_scalar() + cp = ( + select([func.sum(b_table.c.value)]) + .where(b_table.c.a_id == a_table.c.id) + .correlate(a_table) + .as_scalar() + ) # up until 0.8, this was ordering by a new subquery. # the removal of a separate _make_proxy() from ScalarSelect # fixed that. self.assert_compile( - create_session().query(A).options(joinedload_all('bs')). - order_by(cp). - limit(50), + create_session() + .query(A) + .options(joinedload_all("bs")) + .order_by(cp) + .limit(50), "SELECT anon_1.a_id AS anon_1_a_id, anon_1.anon_2 " "AS anon_1_anon_2, b_1.id AS b_1_id, b_1.a_id AS " "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id " @@ -2598,20 +3243,26 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "b.a_id = a.id) AS anon_2 FROM a ORDER BY (SELECT " "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 " - "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2" + "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2", ) def test_standalone_subquery_labeled(self): A = self.classes.A b_table, a_table = self.tables.b, self.tables.a self._fixture({}) - cp = select([func.sum(b_table.c.value)]).\ - where(b_table.c.a_id == a_table.c.id).\ - correlate(a_table).as_scalar().label('foo') + cp = ( + select([func.sum(b_table.c.value)]) + .where(b_table.c.a_id == a_table.c.id) + .correlate(a_table) + .as_scalar() + .label("foo") + ) self.assert_compile( - create_session().query(A).options(joinedload_all('bs')). - order_by(cp). 
- limit(50), + create_session() + .query(A) + .options(joinedload_all("bs")) + .order_by(cp) + .limit(50), "SELECT anon_1.a_id AS anon_1_a_id, anon_1.foo " "AS anon_1_foo, b_1.id AS b_1_id, b_1.a_id AS " "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id " @@ -2619,22 +3270,26 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "b.a_id = a.id) AS foo FROM a ORDER BY foo " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 " "ON anon_1.a_id = b_1.a_id ORDER BY " - "anon_1.foo" + "anon_1.foo", ) def test_standalone_negated(self): A = self.classes.A b_table, a_table = self.tables.b, self.tables.a self._fixture({}) - cp = select([func.sum(b_table.c.value)]).\ - where(b_table.c.a_id == a_table.c.id).\ - correlate(a_table).\ - as_scalar() + cp = ( + select([func.sum(b_table.c.value)]) + .where(b_table.c.a_id == a_table.c.id) + .correlate(a_table) + .as_scalar() + ) # test a different unary operator self.assert_compile( - create_session().query(A).options(joinedload_all('bs')). - order_by(~cp). - limit(50), + create_session() + .query(A) + .options(joinedload_all("bs")) + .order_by(~cp) + .limit(50), "SELECT anon_1.a_id AS anon_1_a_id, anon_1.anon_2 " "AS anon_1_anon_2, b_1.id AS b_1_id, b_1.a_id AS " "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id " @@ -2642,7 +3297,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "WHERE b.a_id = a.id) FROM a ORDER BY NOT (SELECT " "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) " "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 " - "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2" + "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2", ) @@ -2650,32 +3305,46 @@ class LoadOnExistingTest(_fixtures.FixtureTest): """test that loaders from a base Query fully populate.""" - run_inserts = 'once' + run_inserts = "once" run_deletes = None def _collection_to_scalar_fixture(self): - User, Address, Dingaling = self.classes.User, \ - self.classes.Address, self.classes.Dingaling - mapper(User, self.tables.users, properties={ - 'addresses': relationship(Address), - }) - mapper(Address, self.tables.addresses, properties={ - 'dingaling': relationship(Dingaling) - }) + User, Address, Dingaling = ( + self.classes.User, + self.classes.Address, + self.classes.Dingaling, + ) + mapper( + User, + self.tables.users, + properties={"addresses": relationship(Address)}, + ) + mapper( + Address, + self.tables.addresses, + properties={"dingaling": relationship(Dingaling)}, + ) mapper(Dingaling, self.tables.dingalings) sess = Session(autoflush=False) return User, Address, Dingaling, sess def _collection_to_collection_fixture(self): - User, Order, Item = self.classes.User, \ - self.classes.Order, self.classes.Item - mapper(User, self.tables.users, properties={ - 'orders': relationship(Order), - }) - mapper(Order, self.tables.orders, properties={ - 'items': relationship(Item, secondary=self.tables.order_items), - }) + User, Order, Item = ( + self.classes.User, + self.classes.Order, + self.classes.Item, + ) + mapper( + User, self.tables.users, properties={"orders": relationship(Order)} + ) + mapper( + Order, + self.tables.orders, + properties={ + "items": relationship(Item, secondary=self.tables.order_items) + }, + ) mapper(Item, self.tables.items) sess = Session(autoflush=False) @@ -2683,9 +3352,11 @@ class LoadOnExistingTest(_fixtures.FixtureTest): def _eager_config_fixture(self): User, Address = self.classes.User, self.classes.Address - mapper(User, self.tables.users, properties={ - 'addresses': relationship(Address, 
lazy="joined"), - }) + mapper( + User, + self.tables.users, + properties={"addresses": relationship(Address, lazy="joined")}, + ) mapper(Address, self.tables.addresses) sess = Session(autoflush=False) return User, Address, sess @@ -2694,13 +3365,14 @@ class LoadOnExistingTest(_fixtures.FixtureTest): User, Address, sess = self._eager_config_fixture() u1 = sess.query(User).get(8) - assert 'addresses' in u1.__dict__ + assert "addresses" in u1.__dict__ sess.expire(u1) def go(): eq_(u1.id, 8) + self.assert_sql_count(testing.db, go, 1) - assert 'addresses' not in u1.__dict__ + assert "addresses" not in u1.__dict__ def test_loads_second_level_collection_to_scalar(self): User, Address, Dingaling, sess = self._collection_to_scalar_fixture() @@ -2709,17 +3381,18 @@ class LoadOnExistingTest(_fixtures.FixtureTest): a1 = Address() u1.addresses.append(a1) a2 = u1.addresses[0] - a2.email_address = 'foo' - sess.query(User).options(joinedload_all("addresses.dingaling")).\ - filter_by(id=8).all() + a2.email_address = "foo" + sess.query(User).options( + joinedload_all("addresses.dingaling") + ).filter_by(id=8).all() assert u1.addresses[-1] is a1 for a in u1.addresses: if a is not a1: - assert 'dingaling' in a.__dict__ + assert "dingaling" in a.__dict__ else: - assert 'dingaling' not in a.__dict__ + assert "dingaling" not in a.__dict__ if a is a2: - eq_(a2.email_address, 'foo') + eq_(a2.email_address, "foo") def test_loads_second_level_collection_to_collection(self): User, Order, Item, sess = self._collection_to_collection_fixture() @@ -2728,146 +3401,160 @@ class LoadOnExistingTest(_fixtures.FixtureTest): u1.orders o1 = Order() u1.orders.append(o1) - sess.query(User).options(joinedload_all("orders.items")).\ - filter_by(id=7).all() + sess.query(User).options(joinedload_all("orders.items")).filter_by( + id=7 + ).all() for o in u1.orders: if o is not o1: - assert 'items' in o.__dict__ + assert "items" in o.__dict__ else: - assert 'items' not in o.__dict__ + assert "items" not in o.__dict__ def test_load_two_levels_collection_to_scalar(self): User, Address, Dingaling, sess = self._collection_to_scalar_fixture() - u1 = sess.query(User).filter_by( - id=8).options( - joinedload("addresses")).one() - sess.query(User).filter_by( - id=8).options( - joinedload_all("addresses.dingaling")).first() - assert 'dingaling' in u1.addresses[0].__dict__ + u1 = ( + sess.query(User) + .filter_by(id=8) + .options(joinedload("addresses")) + .one() + ) + sess.query(User).filter_by(id=8).options( + joinedload_all("addresses.dingaling") + ).first() + assert "dingaling" in u1.addresses[0].__dict__ def test_load_two_levels_collection_to_collection(self): User, Order, Item, sess = self._collection_to_collection_fixture() - u1 = sess.query(User).filter_by( - id=7).options( - joinedload("orders")).one() - sess.query(User).filter_by( - id=7).options( - joinedload_all("orders.items")).first() - assert 'items' in u1.orders[0].__dict__ + u1 = ( + sess.query(User) + .filter_by(id=7) + .options(joinedload("orders")) + .one() + ) + sess.query(User).filter_by(id=7).options( + joinedload_all("orders.items") + ).first() + assert "items" in u1.orders[0].__dict__ class AddEntityTest(_fixtures.FixtureTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None def _assert_result(self): - Item, Address, Order, User = (self.classes.Item, - self.classes.Address, - self.classes.Order, - self.classes.User) + Item, Address, Order, User = ( + self.classes.Item, + self.classes.Address, + self.classes.Order, + self.classes.User, + ) return [ ( - 
User(id=7, - addresses=[Address(id=1)] - ), - Order(id=1, - items=[Item(id=1), Item(id=2), Item(id=3)] - ), + User(id=7, addresses=[Address(id=1)]), + Order(id=1, items=[Item(id=1), Item(id=2), Item(id=3)]), ), ( - User(id=7, - addresses=[Address(id=1)] - ), - Order(id=3, - items=[Item(id=3), Item(id=4), Item(id=5)] - ), + User(id=7, addresses=[Address(id=1)]), + Order(id=3, items=[Item(id=3), Item(id=4), Item(id=5)]), ), ( - User(id=7, - addresses=[Address(id=1)] - ), - Order(id=5, - items=[Item(id=5)] - ), + User(id=7, addresses=[Address(id=1)]), + Order(id=5, items=[Item(id=5)]), ), ( - User(id=9, - addresses=[Address(id=5)] - ), - Order(id=2, - items=[Item(id=1), Item(id=2), Item(id=3)] - ), + User(id=9, addresses=[Address(id=5)]), + Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)]), ), ( - User(id=9, - addresses=[Address(id=5)] - ), - Order(id=4, - items=[Item(id=1), Item(id=5)] - ), - ) + User(id=9, addresses=[Address(id=5)]), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ), ] def test_mapper_configured(self): - users, items, order_items, Order, \ - Item, User, Address, orders, addresses = ( - self.tables.users, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.tables.orders, - self.tables.addresses) - - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='joined'), - 'orders': relationship(Order) - }) + users, items, order_items, Order, Item, User, Address, orders, addresses = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.orders, + self.tables.addresses, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship(Address, lazy="joined"), + "orders": relationship(Order), + }, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, lazy='joined', - order_by=items.c.id) - }) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="joined", + order_by=items.c.id, + ) + }, + ) mapper(Item, items) sess = create_session() oalias = sa.orm.aliased(Order) def go(): - ret = sess.query(User, oalias).join(oalias, 'orders').\ - order_by(User.id, oalias.id).all() + ret = ( + sess.query(User, oalias) + .join(oalias, "orders") + .order_by(User.id, oalias.id) + .all() + ) eq_(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 1) def test_options(self): - users, items, order_items, Order,\ - Item, User, Address, orders, addresses = ( - self.tables.users, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.tables.orders, - self.tables.addresses) - - mapper(User, users, properties={ - 'addresses': relationship(Address), - 'orders': relationship(Order) - }) + users, items, order_items, Order, Item, User, Address, orders, addresses = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.orders, + self.tables.addresses, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship(Address), + "orders": relationship(Order), + }, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, order_by=items.c.id) - }) + mapper( + 
Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) + }, + ) mapper(Item, items) sess = create_session() @@ -2875,78 +3562,80 @@ class AddEntityTest(_fixtures.FixtureTest): oalias = sa.orm.aliased(Order) def go(): - ret = sess.query(User, oalias).options(joinedload('addresses')).\ - join(oalias, 'orders').\ - order_by(User.id, oalias.id).all() + ret = ( + sess.query(User, oalias) + .options(joinedload("addresses")) + .join(oalias, "orders") + .order_by(User.id, oalias.id) + .all() + ) eq_(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 6) sess.expunge_all() def go(): - ret = sess.query(User, oalias).\ - options(joinedload('addresses'), - joinedload(oalias.items)).\ - join(oalias, 'orders').\ - order_by(User.id, oalias.id).all() + ret = ( + sess.query(User, oalias) + .options(joinedload("addresses"), joinedload(oalias.items)) + .join(oalias, "orders") + .order_by(User.id, oalias.id) + .all() + ) eq_(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 1) class OrderBySecondaryTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('m2m', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b.id'))) - - Table('a', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) - Table('b', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "m2m", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("aid", Integer, ForeignKey("a.id")), + Column("bid", Integer, ForeignKey("b.id")), + ) + + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @classmethod def fixtures(cls): return dict( - a=(('id', 'data'), - (1, 'a1'), - (2, 'a2')), - - b=(('id', 'data'), - (1, 'b1'), - (2, 'b2'), - (3, 'b3'), - (4, 'b4')), - - m2m=(('id', 'aid', 'bid'), - (2, 1, 1), - (4, 2, 4), - (1, 1, 3), - (6, 2, 2), - (3, 1, 2), - (5, 2, 3))) + a=(("id", "data"), (1, "a1"), (2, "a2")), + b=(("id", "data"), (1, "b1"), (2, "b2"), (3, "b3"), (4, "b4")), + m2m=( + ("id", "aid", "bid"), + (2, 1, 1), + (4, 2, 4), + (1, 1, 3), + (6, 2, 2), + (3, 1, 2), + (5, 2, 3), + ), + ) def test_ordering(self): - a, m2m, b = ( - self.tables.a, - self.tables.m2m, - self.tables.b) + a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b) class A(fixtures.ComparableEntity): pass @@ -2954,106 +3643,140 @@ class OrderBySecondaryTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, a, properties={ - 'bs': relationship( - B, secondary=m2m, lazy='joined', order_by=m2m.c.id) - }) + mapper( + A, + a, + properties={ + "bs": relationship( + B, secondary=m2m, lazy="joined", order_by=m2m.c.id + ) + }, + ) mapper(B, b) sess = create_session() - eq_(sess.query(A).all(), + eq_( + sess.query(A).all(), [ - A(data='a1', bs=[B(data='b3'), B(data='b1'), B(data='b2')]), - A(bs=[B(data='b4'), B(data='b3'), B(data='b2')]) - ]) + A(data="a1", bs=[B(data="b3"), B(data="b1"), B(data="b2")]), + A(bs=[B(data="b4"), B(data="b3"), B(data="b2")]), + ], + ) class 
SelfReferentialEagerTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('nodes', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) + Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + ) def test_basic(self): nodes = self.tables.nodes class Node(fixtures.ComparableEntity): - def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, - lazy='joined', - join_depth=3, order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="joined", join_depth=3, order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter_by(data='n1').all()[0] - eq_(Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), d) + d = sess.query(Node).filter_by(data="n1").all()[0] + eq_( + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + d, + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - d = sess.query(Node).filter_by(data='n1').first() - eq_(Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), d) + d = sess.query(Node).filter_by(data="n1").first() + eq_( + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + d, + ) + self.assert_sql_count(testing.db, go, 1) def test_lazy_fallback_doesnt_affect_eager(self): nodes = self.tables.nodes class Node(fixtures.ComparableEntity): - def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='joined', join_depth=1, - order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="joined", join_depth=1, order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() @@ 
-3069,170 +3792,216 @@ class SelfReferentialEagerTest(fixtures.MappedTest): def go(): allnodes = sess.query(Node).order_by(Node.data).all() n12 = allnodes[2] - eq_(n12.data, 'n12') - eq_([ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ], list(n12.children)) + eq_(n12.data, "n12") + eq_( + [Node(data="n121"), Node(data="n122"), Node(data="n123")], + list(n12.children), + ) + self.assert_sql_count(testing.db, go, 1) def test_with_deferred(self): nodes = self.tables.nodes class Node(fixtures.ComparableEntity): - def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='joined', join_depth=3, - order_by=nodes.c.id), - 'data': deferred(nodes.c.data) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="joined", join_depth=3, order_by=nodes.c.id + ), + "data": deferred(nodes.c.data), + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) sess.add(n1) sess.flush() sess.expunge_all() def go(): eq_( - Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), sess.query(Node).order_by(Node.id).first(), ) + self.assert_sql_count(testing.db, go, 4) sess.expunge_all() def go(): - eq_(Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node). - options(undefer('data')).order_by(Node.id).first()) + eq_( + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), + sess.query(Node) + .options(undefer("data")) + .order_by(Node.id) + .first(), + ) + self.assert_sql_count(testing.db, go, 3) sess.expunge_all() def go(): - eq_(Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node).options(undefer('data'), - undefer('children.data')).first()) + eq_( + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), + sess.query(Node) + .options(undefer("data"), undefer("children.data")) + .first(), + ) + self.assert_sql_count(testing.db, go, 1) def test_options(self): nodes = self.tables.nodes class Node(fixtures.ComparableEntity): - def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='select', order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="select", order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter_by(data='n1').\ - order_by(Node.id).\ - options(joinedload('children.children')).first() - eq_(Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), d) + d = ( + sess.query(Node) + .filter_by(data="n1") + .order_by(Node.id) + .options(joinedload("children.children")) + .first() + ) + eq_( + Node( + data="n1", + 
children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + d, + ) + self.assert_sql_count(testing.db, go, 2) def go(): - sess.query(Node).order_by(Node.id).filter_by(data='n1').\ - options(joinedload('children.children')).first() + sess.query(Node).order_by(Node.id).filter_by(data="n1").options( + joinedload("children.children") + ).first() # test that the query isn't wrapping the initial query for eager # loading. self.assert_sql_execution( - testing.db, go, + testing.db, + go, CompiledSQL( "SELECT nodes.id AS nodes_id, nodes.parent_id AS " "nodes_parent_id, nodes.data AS nodes_data FROM nodes " "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT :param_1", - {'data_1': 'n1'} - ) + {"data_1": "n1"}, + ), ) def test_no_depth(self): nodes = self.tables.nodes class Node(fixtures.ComparableEntity): - def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='joined') - }) + mapper( + Node, + nodes, + properties={"children": relationship(Node, lazy="joined")}, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter_by(data='n1').first() - eq_(Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), d) + d = sess.query(Node).filter_by(data="n1").first() + eq_( + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + d, + ) + self.assert_sql_count(testing.db, go, 3) class MixedSelfReferentialEagerTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('a_table', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True) - ) - - Table('b_table', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('parent_b1_id', Integer, ForeignKey('b_table.id')), - Column('parent_a_id', Integer, ForeignKey('a_table.id')), - Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) + Table( + "a_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + + Table( + "b_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_b1_id", Integer, ForeignKey("b_table.id")), + Column("parent_a_id", Integer, ForeignKey("a_table.id")), + Column("parent_b2_id", Integer, ForeignKey("b_table.id")), + ) @classmethod def setup_mappers(cls): @@ -3245,21 +4014,25 @@ class MixedSelfReferentialEagerTest(fixtures.MappedTest): pass mapper(A, a_table) - mapper(B, b_table, properties={ - 'parent_b1': relationship( - B, - remote_side=[b_table.c.id], - primaryjoin=(b_table.c.parent_b1_id == b_table.c.id), - order_by=b_table.c.id - ), - 'parent_z': relationship(A, 
lazy=True), - 'parent_b2': relationship( - B, - remote_side=[b_table.c.id], - primaryjoin=(b_table.c.parent_b2_id == b_table.c.id), - order_by=b_table.c.id - ) - }) + mapper( + B, + b_table, + properties={ + "parent_b1": relationship( + B, + remote_side=[b_table.c.id], + primaryjoin=(b_table.c.parent_b1_id == b_table.c.id), + order_by=b_table.c.id, + ), + "parent_z": relationship(A, lazy=True), + "parent_b2": relationship( + B, + remote_side=[b_table.c.id], + primaryjoin=(b_table.c.parent_b2_id == b_table.c.id), + order_by=b_table.c.id, + ), + }, + ) @classmethod def insert_data(cls): @@ -3290,49 +4063,59 @@ class MixedSelfReferentialEagerTest(fixtures.MappedTest): def go(): eq_( - session.query(B). - options( - joinedload('parent_b1'), - joinedload('parent_b2'), - joinedload('parent_z') - ). - filter(B.id.in_([2, 8, 11])).order_by(B.id).all(), + session.query(B) + .options( + joinedload("parent_b1"), + joinedload("parent_b2"), + joinedload("parent_z"), + ) + .filter(B.id.in_([2, 8, 11])) + .order_by(B.id) + .all(), [ - B(id=2, - parent_z=A(id=1), + B( + id=2, + parent_z=A(id=1), parent_b1=B(id=1), - parent_b2=None), - B(id=8, - parent_z=A(id=2), + parent_b2=None, + ), + B( + id=8, + parent_z=A(id=2), parent_b1=B(id=1), - parent_b2=B(id=2)), - B(id=11, - parent_z=A(id=3), + parent_b2=B(id=2), + ), + B( + id=11, + parent_z=A(id=3), parent_b1=B(id=1), - parent_b2=B(id=8)) - ] + parent_b2=B(id=8), + ), + ], ) + self.assert_sql_count(testing.db, go, 1) class SelfReferentialM2MEagerTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('widget', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', sa.String(40), nullable=False, unique=True), - ) - - Table('widget_rel', metadata, - Column('parent_id', Integer, ForeignKey('widget.id')), - Column('child_id', Integer, ForeignKey('widget.id')), - sa.UniqueConstraint('parent_id', 'child_id'), - ) + Table( + "widget", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", sa.String(40), nullable=False, unique=True), + ) + + Table( + "widget_rel", + metadata, + Column("parent_id", Integer, ForeignKey("widget.id")), + Column("child_id", Integer, ForeignKey("widget.id")), + sa.UniqueConstraint("parent_id", "child_id"), + ) def test_basic(self): widget, widget_rel = self.tables.widget, self.tables.widget_rel @@ -3340,72 +4123,96 @@ class SelfReferentialM2MEagerTest(fixtures.MappedTest): class Widget(fixtures.ComparableEntity): pass - mapper(Widget, widget, properties={ - 'children': relationship( - Widget, secondary=widget_rel, - primaryjoin=widget_rel.c.parent_id == widget.c.id, - secondaryjoin=widget_rel.c.child_id == widget.c.id, - lazy='joined', join_depth=1, - ) - }) + mapper( + Widget, + widget, + properties={ + "children": relationship( + Widget, + secondary=widget_rel, + primaryjoin=widget_rel.c.parent_id == widget.c.id, + secondaryjoin=widget_rel.c.child_id == widget.c.id, + lazy="joined", + join_depth=1, + ) + }, + ) sess = create_session() - w1 = Widget(name='w1') - w2 = Widget(name='w2') + w1 = Widget(name="w1") + w2 = Widget(name="w2") w1.children.append(w2) sess.add(w1) sess.flush() sess.expunge_all() - eq_([Widget(name='w1', children=[Widget(name='w2')])], - sess.query(Widget).filter(Widget.name == 'w1').all()) + eq_( + [Widget(name="w1", children=[Widget(name="w2")])], + sess.query(Widget).filter(Widget.name == "w1").all(), + ) class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): 
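# ---------------------------------------------------------------------------
# [editor's sketch, not part of the commit diff] The SelfReferentialM2MEagerTest
# hunk above wires up a self-referential many-to-many loaded with
# lazy="joined". A minimal runnable sketch of the same configuration,
# rewritten with the declarative API for brevity; Widget/widget_rel follow
# the test's names, everything else is an illustrative assumption rather
# than code from the suite.

from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    String,
    Table,
    create_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship

Base = declarative_base()

widget_rel = Table(
    "widget_rel",
    Base.metadata,
    Column("parent_id", ForeignKey("widget.id"), primary_key=True),
    Column("child_id", ForeignKey("widget.id"), primary_key=True),
)


class Widget(Base):
    __tablename__ = "widget"

    id = Column(Integer, primary_key=True)
    name = Column(String(40), nullable=False, unique=True)

    # parent and child rows live in the same table, so both join
    # conditions against the association table must be spelled out;
    # join_depth=1 keeps the eager self-join from recursing forever
    children = relationship(
        "Widget",
        secondary=widget_rel,
        primaryjoin=id == widget_rel.c.parent_id,
        secondaryjoin=id == widget_rel.c.child_id,
        lazy="joined",
        join_depth=1,
    )


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

session = Session(engine)
w1 = Widget(name="w1")
w1.children.append(Widget(name="w2"))
session.add(w1)
session.commit()

# a single SELECT with an aliased self-join returns w1 and its children
# together; accessing .children afterwards emits no further SQL
loaded = session.query(Widget).filter_by(name="w1").one()
print([child.name for child in loaded.children])
# ---------------------------------------------------------------------------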
- run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" - __prefer_backends__ = ('postgresql', 'mysql', 'oracle') + __prefer_backends__ = ("postgresql", "mysql", "oracle") @classmethod def setup_mappers(cls): - users, Keyword, items, order_items, orders, \ - Item, User, Address, keywords, Order, \ - item_keywords, addresses = ( - cls.tables.users, - cls.classes.Keyword, - cls.tables.items, - cls.tables.order_items, - cls.tables.orders, - cls.classes.Item, - cls.classes.User, - cls.classes.Address, - cls.tables.keywords, - cls.classes.Order, - cls.tables.item_keywords, - cls.tables.addresses) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user'), - 'orders': relationship(Order, backref='user'), # o2m, m2o - }) + users, Keyword, items, order_items, orders, Item, User, Address, keywords, Order, item_keywords, addresses = ( + cls.tables.users, + cls.classes.Keyword, + cls.tables.items, + cls.tables.order_items, + cls.tables.orders, + cls.classes.Item, + cls.classes.User, + cls.classes.Address, + cls.tables.keywords, + cls.classes.Order, + cls.tables.item_keywords, + cls.tables.addresses, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship(Address, backref="user"), + "orders": relationship(Order, backref="user"), # o2m, m2o + }, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, order_by=items.c.id), # m2m - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords) # m2m - }) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) # m2m + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords + ) # m2m + }, + ) mapper(Keyword, keywords) def test_two_entities(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() @@ -3413,46 +4220,62 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_( [ - (User(id=9, addresses=[Address(id=5)]), - Order(id=2, items=[ - Item(id=1), Item(id=2), Item(id=3)])), - (User(id=9, addresses=[Address(id=5)]), - Order(id=4, items=[ - Item(id=1), Item(id=5)])), + ( + User(id=9, addresses=[Address(id=5)]), + Order( + id=2, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=9, addresses=[Address(id=5)]), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ), ], - sess.query(User, Order).filter(User.id == Order.user_id). - options(joinedload(User.addresses), joinedload(Order.items)). - filter(User.id == 9). 
- order_by(User.id, Order.id).all(), + sess.query(User, Order) + .filter(User.id == Order.user_id) + .options(joinedload(User.addresses), joinedload(Order.items)) + .filter(User.id == 9) + .order_by(User.id, Order.id) + .all(), ) + self.assert_sql_count(testing.db, go, 1) # one FROM clause def go(): eq_( [ - (User(id=9, addresses=[Address(id=5)]), - Order(id=2, items=[ - Item(id=1), Item(id=2), Item(id=3)])), - (User(id=9, addresses=[Address(id=5)]), - Order(id=4, items=[ - Item(id=1), Item(id=5)])), + ( + User(id=9, addresses=[Address(id=5)]), + Order( + id=2, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=9, addresses=[Address(id=5)]), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ), ], - sess.query(User, Order).join(User.orders). - options(joinedload(User.addresses), joinedload(Order.items)). - filter(User.id == 9). - order_by(User.id, Order.id).all(), + sess.query(User, Order) + .join(User.orders) + .options(joinedload(User.addresses), joinedload(Order.items)) + .filter(User.id == 9) + .order_by(User.id, Order.id) + .all(), ) + self.assert_sql_count(testing.db, go, 1) def test_two_entities_with_joins(self): # early versions of SQLite could not handle this test # however as of 2018 and probably for some years before that # it has no issue with this. - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() @@ -3463,92 +4286,104 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_( [ ( - User(addresses=[ - Address(email_address='fred@fred.com')], - name='fred'), - Order(description='order 2', isopen=0, - items=[ - Item(description='item 1'), - Item(description='item 2'), - Item(description='item 3')]), - User(addresses=[ - Address(email_address='jack@bean.com')], - name='jack'), - Order(description='order 3', isopen=1, - items=[ - Item(description='item 3'), - Item(description='item 4'), - Item(description='item 5')]) + User( + addresses=[Address(email_address="fred@fred.com")], + name="fred", + ), + Order( + description="order 2", + isopen=0, + items=[ + Item(description="item 1"), + Item(description="item 2"), + Item(description="item 3"), + ], + ), + User( + addresses=[Address(email_address="jack@bean.com")], + name="jack", + ), + Order( + description="order 3", + isopen=1, + items=[ + Item(description="item 3"), + Item(description="item 4"), + Item(description="item 5"), + ], + ), ), - ( User( - addresses=[ - Address( - email_address='fred@fred.com')], - name='fred'), + addresses=[Address(email_address="fred@fred.com")], + name="fred", + ), Order( - description='order 2', isopen=0, items=[ - Item( - description='item 1'), Item( - description='item 2'), Item( - description='item 3')]), + description="order 2", + isopen=0, + items=[ + Item(description="item 1"), + Item(description="item 2"), + Item(description="item 3"), + ], + ), User( - addresses=[ - Address( - email_address='jack@bean.com')], - name='jack'), + addresses=[Address(email_address="jack@bean.com")], + name="jack", + ), Order( address_id=None, - description='order 5', + description="order 5", isopen=0, - items=[ - Item( - description='item 5')]) + items=[Item(description="item 5")], + ), ), - ( User( - addresses=[ - Address( - email_address='fred@fred.com')], - name='fred'), + addresses=[Address(email_address="fred@fred.com")], + name="fred", + ), Order( - 
description='order 4', isopen=1, items=[ - Item( - description='item 1'), Item( - description='item 5')]), + description="order 4", + isopen=1, + items=[ + Item(description="item 1"), + Item(description="item 5"), + ], + ), User( - addresses=[ - Address( - email_address='jack@bean.com')], - name='jack'), + addresses=[Address(email_address="jack@bean.com")], + name="jack", + ), Order( address_id=None, - description='order 5', + description="order 5", isopen=0, - items=[ - Item( - description='item 5')]) + items=[Item(description="item 5")], + ), ), ], - sess.query(User, Order, u1, o1). - join(Order, User.orders). - options(joinedload(User.addresses), - joinedload(Order.items)).filter(User.id == 9). - join(o1, u1.orders). - options(joinedload(u1.addresses), - joinedload(o1.items)).filter(u1.id == 7). - filter(Order.id < o1.id). - order_by(User.id, Order.id, u1.id, o1.id).all(), + sess.query(User, Order, u1, o1) + .join(Order, User.orders) + .options(joinedload(User.addresses), joinedload(Order.items)) + .filter(User.id == 9) + .join(o1, u1.orders) + .options(joinedload(u1.addresses), joinedload(o1.items)) + .filter(u1.id == 7) + .filter(Order.id < o1.id) + .order_by(User.id, Order.id, u1.id, o1.id) + .all(), ) + self.assert_sql_count(testing.db, go, 1) def test_aliased_entity_one(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() @@ -3559,31 +4394,33 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_( [ ( - User( - id=9, addresses=[ - Address( - id=5)]), Order( - id=2, items=[ - Item( - id=1), Item( - id=2), Item( - id=3)])), - (User(id=9, addresses=[Address(id=5)]), Order( - id=4, items=[Item(id=1), Item(id=5)])), + User(id=9, addresses=[Address(id=5)]), + Order( + id=2, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=9, addresses=[Address(id=5)]), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ), ], - sess.query(User, oalias).filter(User.id == oalias.user_id). - options( - joinedload(User.addresses), - joinedload(oalias.items)).filter(User.id == 9). - order_by(User.id, oalias.id).all(), + sess.query(User, oalias) + .filter(User.id == oalias.user_id) + .options(joinedload(User.addresses), joinedload(oalias.items)) + .filter(User.id == 9) + .order_by(User.id, oalias.id) + .all(), ) + self.assert_sql_count(testing.db, go, 1) def test_aliased_entity_two(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() @@ -3594,30 +4431,28 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_( [ ( - User( - id=9, addresses=[ - Address( - id=5)]), Order( - id=2, items=[ - Item( - id=1), Item( - id=2), Item( - id=3)])), - (User(id=9, addresses=[Address(id=5)]), Order( - id=4, items=[Item(id=1), Item(id=5)])), + User(id=9, addresses=[Address(id=5)]), + Order( + id=2, items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=9, addresses=[Address(id=5)]), + Order(id=4, items=[Item(id=1), Item(id=5)]), + ), ], - sess.query(User, oalias).join(oalias, User.orders). - options(joinedload(User.addresses), - joinedload(oalias.items)). - filter(User.id == 9). 
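The loop above drives one mapping through three variants: an explicitly
labeled expression, an anonymously labeled one, and (in the else branch
continuing below) a plain scalar subquery via .as_scalar(). As a
standalone sketch of the column_property pattern being exercised, with
illustrative names and the select([...]) calling style this diff
already uses:

    from sqlalchemy import (
        Column,
        Float,
        ForeignKey,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        func,
        select,
    )
    from sqlalchemy.orm import (
        Session,
        column_property,
        mapper,
        relationship,
    )

    metadata = MetaData()

    users_t = Table(
        "users_t",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(16)),
    )
    tags_t = Table(
        "tags_t",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("user_id", ForeignKey("users_t.id")),
        Column("score1", Float),
        Column("score2", Float),
    )

    class User(object):
        pass

    class Tag(object):
        pass

    # correlated scalar subquery: per-user total, computed in SQL
    user_score = (
        select([func.sum(tags_t.c.score1 * tags_t.c.score2)])
        .where(tags_t.c.user_id == users_t.c.id)
        .as_scalar()
    )

    mapper(
        Tag,
        tags_t,
        properties={
            "query_score": column_property(
                tags_t.c.score1 * tags_t.c.score2
            )
        },
    )
    mapper(
        User,
        users_t,
        properties={
            "tags": relationship(Tag, backref="user", lazy="joined"),
            "query_score": column_property(user_score),
        },
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    session = Session(engine)

    u, t = User(), Tag()
    t.score1, t.score2 = 5.0, 3.0
    u.tags = [t]
    session.add(u)
    session.flush()
    session.expire_all()
    assert session.query(User).first().query_score == 15.0

The point of the test is that whether the expression is labeled,
anonymously labeled, or left scalar, the joined eager load of tags does
not collide with the subquery's labels.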
- order_by(User.id, oalias.id).all(), + sess.query(User, oalias) + .join(oalias, User.orders) + .options(joinedload(User.addresses), joinedload(oalias.items)) + .filter(User.id == 9) + .order_by(User.id, oalias.id) + .all(), ) + self.assert_sql_count(testing.db, go, 1) def test_aliased_entity_three(self): - Order, User = ( - self.classes.Order, - self.classes.User) + Order, User = (self.classes.Order, self.classes.User) sess = create_session() @@ -3627,8 +4462,11 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): # orders alias. this should create two FROM clauses even though the # query has a from_clause set up via the join self.assert_compile( - sess.query(User, oalias).join(User.orders). - options(joinedload(oalias.items)).with_labels().statement, + sess.query(User, oalias) + .join(User.orders) + .options(joinedload(oalias.items)) + .with_labels() + .statement, "SELECT users.id AS users_id, users.name AS users_name, " "orders_1.id AS orders_1_id, " "orders_1.user_id AS orders_1_user_id, " @@ -3639,33 +4477,32 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "JOIN orders ON users.id = orders.user_id, " "orders AS orders_1 LEFT OUTER JOIN (order_items AS order_items_1 " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id) " - "ON orders_1.id = order_items_1.order_id ORDER BY items_1.id" + "ON orders_1.id = order_items_1.order_id ORDER BY items_1.id", ) class SubqueryTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('users_table', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(16)) - ) - - Table('tags_table', metadata, - Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey("users_table.id")), - Column('score1', sa.Float), - Column('score2', sa.Float), - ) + Table( + "users_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(16)), + ) + + Table( + "tags_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users_table.id")), + Column("score1", sa.Float), + Column("score2", sa.Float), + ) def test_label_anonymizing(self): """Eager loading works with subqueries with labels, @@ -3681,29 +4518,33 @@ class SubqueryTest(fixtures.MappedTest): """ - tags_table, users_table = self.tables.tags_table, \ - self.tables.users_table + tags_table, users_table = ( + self.tables.tags_table, + self.tables.users_table, + ) class User(fixtures.ComparableEntity): - @property def prop_score(self): return sum([tag.prop_score for tag in self.tags]) class Tag(fixtures.ComparableEntity): - @property def prop_score(self): return self.score1 * self.score2 - for labeled, labelname in [(True, 'score'), (True, None), - (False, None)]: + for labeled, labelname in [ + (True, "score"), + (True, None), + (False, None), + ]: sa.orm.clear_mappers() - tag_score = (tags_table.c.score1 * tags_table.c.score2) - user_score = sa.select([sa.func.sum(tags_table.c.score1 * - tags_table.c.score2)], - tags_table.c.user_id == users_table.c.id) + tag_score = tags_table.c.score1 * tags_table.c.score2 + user_score = sa.select( + [sa.func.sum(tags_table.c.score1 * tags_table.c.score2)], + tags_table.c.user_id == users_table.c.id, + ) if labeled: tag_score = tag_score.label(labelname) @@ -3711,21 +4552,41 @@ class SubqueryTest(fixtures.MappedTest): else: 
user_score = user_score.as_scalar() - mapper(Tag, tags_table, properties={ - 'query_score': sa.orm.column_property(tag_score), - }) + mapper( + Tag, + tags_table, + properties={"query_score": sa.orm.column_property(tag_score)}, + ) - mapper(User, users_table, properties={ - 'tags': relationship(Tag, backref='user', lazy='joined'), - 'query_score': sa.orm.column_property(user_score), - }) + mapper( + User, + users_table, + properties={ + "tags": relationship(Tag, backref="user", lazy="joined"), + "query_score": sa.orm.column_property(user_score), + }, + ) session = create_session() - session.add(User(name='joe', tags=[Tag(score1=5.0, score2=3.0), - Tag(score1=55.0, score2=1.0)])) - session.add(User(name='bar', tags=[Tag(score1=5.0, score2=4.0), - Tag(score1=50.0, score2=1.0), - Tag(score1=15.0, score2=2.0)])) + session.add( + User( + name="joe", + tags=[ + Tag(score1=5.0, score2=3.0), + Tag(score1=55.0, score2=1.0), + ], + ) + ) + session.add( + User( + name="bar", + tags=[ + Tag(score1=5.0, score2=4.0), + Tag(score1=50.0, score2=1.0), + Tag(score1=15.0, score2=2.0), + ], + ) + ) session.flush() session.expunge_all() @@ -3733,8 +4594,9 @@ class SubqueryTest(fixtures.MappedTest): eq_(user.query_score, user.prop_score) def go(): - u = session.query(User).filter_by(name='joe').one() + u = session.query(User).filter_by(name="joe").one() eq_(u.query_score, u.prop_score) + self.assert_sql_count(testing.db, go, 1) for t in (tags_table, users_table): @@ -3754,84 +4616,83 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): # another argument for joinedload learning about inner joins - __requires__ = ('correlated_outer_joins', ) + __requires__ = ("correlated_outer_joins",) @classmethod def define_tables(cls, metadata): Table( - 'users', metadata, + "users", + metadata, Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), ) Table( - 'stuff', metadata, + "stuff", + metadata, Column( - 'id', - Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('date', Date), - Column('user_id', Integer, ForeignKey('users.id'))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date", Date), + Column("user_id", Integer, ForeignKey("users.id")), + ) @classmethod def insert_data(cls): stuff, users = cls.tables.stuff, cls.tables.users users.insert().execute( - {'id': 1, 'name': 'user1'}, - {'id': 2, 'name': 'user2'}, - {'id': 3, 'name': 'user3'}, + {"id": 1, "name": "user1"}, + {"id": 2, "name": "user2"}, + {"id": 3, "name": "user3"}, ) stuff.insert().execute( - {'id': 1, 'user_id': 1, 'date': datetime.date(2007, 10, 15)}, - {'id': 2, 'user_id': 1, 'date': datetime.date(2007, 12, 15)}, - {'id': 3, 'user_id': 1, 'date': datetime.date(2007, 11, 15)}, - {'id': 4, 'user_id': 2, 'date': datetime.date(2008, 1, 15)}, - {'id': 5, 'user_id': 3, 'date': datetime.date(2007, 6, 15)}, - {'id': 6, 'user_id': 3, 'date': datetime.date(2007, 3, 15)}, + {"id": 1, "user_id": 1, "date": datetime.date(2007, 10, 15)}, + {"id": 2, "user_id": 1, "date": datetime.date(2007, 12, 15)}, + {"id": 3, "user_id": 1, "date": datetime.date(2007, 11, 15)}, + {"id": 4, "user_id": 2, "date": datetime.date(2008, 1, 15)}, + {"id": 5, "user_id": 3, "date": datetime.date(2007, 6, 15)}, + {"id": 6, "user_id": 3, "date": datetime.date(2007, 3, 15)}, ) def test_labeled_on_date_noalias(self): - self._do_test('label', True, False) + self._do_test("label", True, 
False) def test_scalar_on_date_noalias(self): - self._do_test('scalar', True, False) + self._do_test("scalar", True, False) def test_plain_on_date_noalias(self): - self._do_test('none', True, False) + self._do_test("none", True, False) def test_labeled_on_limitid_noalias(self): - self._do_test('label', False, False) + self._do_test("label", False, False) def test_scalar_on_limitid_noalias(self): - self._do_test('scalar', False, False) + self._do_test("scalar", False, False) def test_plain_on_limitid_noalias(self): - self._do_test('none', False, False) + self._do_test("none", False, False) def test_labeled_on_date_alias(self): - self._do_test('label', True, True) + self._do_test("label", True, True) def test_scalar_on_date_alias(self): - self._do_test('scalar', True, True) + self._do_test("scalar", True, True) def test_plain_on_date_alias(self): - self._do_test('none', True, True) + self._do_test("none", True, True) def test_labeled_on_limitid_alias(self): - self._do_test('label', False, True) + self._do_test("label", False, True) def test_scalar_on_limitid_alias(self): - self._do_test('scalar', False, True) + self._do_test("scalar", False, True) def test_plain_on_limitid_alias(self): - self._do_test('none', False, True) + self._do_test("none", False, True) def _do_test(self, labeled, ondate, aliasstuff): stuff, users = self.tables.stuff, self.tables.users @@ -3856,15 +4717,22 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): if ondate: # the more 'relational' way to do this, join on the max date - stuff_view = select([func.max(salias.c.date).label('max_date')]).\ - where(salias.c.user_id == users.c.id).correlate(users) + stuff_view = ( + select([func.max(salias.c.date).label("max_date")]) + .where(salias.c.user_id == users.c.id) + .correlate(users) + ) else: # a common method with the MySQL crowd, which actually might # perform better in some # cases - subquery does a limit with order by DESC, join on the id - stuff_view = select([salias.c.id]).\ - where(salias.c.user_id == users.c.id).\ - correlate(users).order_by(salias.c.date.desc()).limit(1) + stuff_view = ( + select([salias.c.id]) + .where(salias.c.user_id == users.c.id) + .correlate(users) + .order_by(salias.c.date.desc()) + .limit(1) + ) # can't win on this one if testing.against("mssql"): @@ -3872,39 +4740,56 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): else: operator = operators.eq - if labeled == 'label': - stuff_view = stuff_view.label('foo') + if labeled == "label": + stuff_view = stuff_view.label("foo") operator = operators.eq - elif labeled == 'scalar': + elif labeled == "scalar": stuff_view = stuff_view.as_scalar() if ondate: - mapper(User, users, properties={ - 'stuff': relationship( - Stuff, - primaryjoin=and_(users.c.id == stuff.c.user_id, - operator(stuff.c.date, stuff_view))) - }) + mapper( + User, + users, + properties={ + "stuff": relationship( + Stuff, + primaryjoin=and_( + users.c.id == stuff.c.user_id, + operator(stuff.c.date, stuff_view), + ), + ) + }, + ) else: - mapper(User, users, properties={ - 'stuff': relationship( - Stuff, - primaryjoin=and_(users.c.id == stuff.c.user_id, - operator(stuff.c.id, stuff_view))) - }) + mapper( + User, + users, + properties={ + "stuff": relationship( + Stuff, + primaryjoin=and_( + users.c.id == stuff.c.user_id, + operator(stuff.c.id, stuff_view), + ), + ) + }, + ) sess = create_session() def go(): eq_( - sess.query(User).order_by(User.name).options( - joinedload('stuff')).all(), + sess.query(User) + .order_by(User.name) + .options(joinedload("stuff")) + .all(), [ - 
User(name='user1', stuff=[Stuff(id=2)]), - User(name='user2', stuff=[Stuff(id=4)]), - User(name='user3', stuff=[Stuff(id=5)]) - ] + User(name="user1", stuff=[Stuff(id=2)]), + User(name="user2", stuff=[Stuff(id=4)]), + User(name="user3", stuff=[Stuff(id=5)]), + ], ) + self.assert_sql_count(testing.db, go, 1) sess = create_session() @@ -3912,50 +4797,61 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): def go(): eq_( sess.query(User).order_by(User.name).first(), - User(name='user1', stuff=[Stuff(id=2)]) + User(name="user1", stuff=[Stuff(id=2)]), ) + self.assert_sql_count(testing.db, go, 2) sess = create_session() def go(): eq_( - sess.query(User).order_by(User.name).options( - joinedload('stuff')).first(), - User(name='user1', stuff=[Stuff(id=2)]) + sess.query(User) + .order_by(User.name) + .options(joinedload("stuff")) + .first(), + User(name="user1", stuff=[Stuff(id=2)]), ) + self.assert_sql_count(testing.db, go, 1) sess = create_session() def go(): eq_( - sess.query(User).filter(User.id == 2).options( - joinedload('stuff')).one(), - User(name='user2', stuff=[Stuff(id=4)]) + sess.query(User) + .filter(User.id == 2) + .options(joinedload("stuff")) + .one(), + User(name="user2", stuff=[Stuff(id=4)]), ) + self.assert_sql_count(testing.db, go, 1) class CyclicalInheritingEagerTestOne(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): Table( - 't1', metadata, + "t1", + metadata, Column( - 'c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', String(30)), - Column('type', String(30)) + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", String(30)), + Column("type", String(30)), ) - Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', String(30)), - Column('type', String(30)), - Column('t1.id', Integer, ForeignKey('t1.c1'))) + Table( + "t2", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", String(30)), + Column("type", String(30)), + Column("t1.id", Integer, ForeignKey("t1.c1")), + ) def test_basic(self): t2, t1 = self.tables.t2, self.tables.t1 @@ -3972,43 +4868,51 @@ class CyclicalInheritingEagerTestOne(fixtures.MappedTest): class SubT2(T2): pass - mapper(T, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1') + mapper(T, t1, polymorphic_on=t1.c.type, polymorphic_identity="t1") mapper( - SubT, None, inherits=T, polymorphic_identity='subt1', + SubT, + None, + inherits=T, + polymorphic_identity="subt1", properties={ - 't2s': relationship( - SubT2, lazy='joined', - backref=sa.orm.backref('subt', lazy='joined')) - }) - mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') - mapper(SubT2, None, inherits=T2, polymorphic_identity='subt2') + "t2s": relationship( + SubT2, + lazy="joined", + backref=sa.orm.backref("subt", lazy="joined"), + ) + }, + ) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity="t2") + mapper(SubT2, None, inherits=T2, polymorphic_identity="subt2") # testing a particular endless loop condition in eager load setup create_session().query(SubT).all() -class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' +class CyclicalInheritingEagerTestTwo( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class PersistentObject(Base): - __tablename__ = 'persistent' - id = Column(Integer, 
primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "persistent" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) class Movie(PersistentObject): - __tablename__ = 'movie' - id = Column(Integer, ForeignKey('persistent.id'), primary_key=True) - director_id = Column(Integer, ForeignKey('director.id')) + __tablename__ = "movie" + id = Column(Integer, ForeignKey("persistent.id"), primary_key=True) + director_id = Column(Integer, ForeignKey("director.id")) title = Column(String(50)) class Director(PersistentObject): - __tablename__ = 'director' - id = Column(Integer, ForeignKey('persistent.id'), primary_key=True) + __tablename__ = "director" + id = Column(Integer, ForeignKey("persistent.id"), primary_key=True) movies = relationship("Movie", foreign_keys=Movie.director_id) name = Column(String(50)) @@ -4017,7 +4921,7 @@ class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, s = create_session() self.assert_compile( - s.query(Director).options(joinedload('*')), + s.query(Director).options(joinedload("*")), "SELECT director.id AS director_id, " "persistent.id AS persistent_id, " "director.name AS director_name, movie_1.id AS movie_1_id, " @@ -4028,7 +4932,7 @@ class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, "LEFT OUTER JOIN " "(persistent AS persistent_1 JOIN movie AS movie_1 " "ON persistent_1.id = movie_1.id) " - "ON director.id = movie_1.director_id" + "ON director.id = movie_1.director_id", ) def test_integrate(self): @@ -4045,13 +4949,14 @@ class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, session.commit() session.close_all() - self.d = session.query(Director).options(joinedload('*')).first() + self.d = session.query(Director).options(joinedload("*")).first() assert len(list(session)) == 3 -class CyclicalInheritingEagerTestThree(fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' +class CyclicalInheritingEagerTestThree( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" run_create_tables = None @classmethod @@ -4059,20 +4964,23 @@ class CyclicalInheritingEagerTestThree(fixtures.DeclarativeMappedTest, Base = cls.DeclarativeBasic class PersistentObject(Base): - __tablename__ = 'persistent' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "persistent" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) - __mapper_args__ = {'with_polymorphic': "*"} + __mapper_args__ = {"with_polymorphic": "*"} class Director(PersistentObject): - __tablename__ = 'director' - id = Column(Integer, ForeignKey('persistent.id'), primary_key=True) - other_id = Column(Integer, ForeignKey('persistent.id')) + __tablename__ = "director" + id = Column(Integer, ForeignKey("persistent.id"), primary_key=True) + other_id = Column(Integer, ForeignKey("persistent.id")) name = Column(String(50)) - other = relationship(PersistentObject, - primaryjoin=other_id == PersistentObject.id, - lazy=False) + other = relationship( + PersistentObject, + primaryjoin=other_id == PersistentObject.id, + lazy=False, + ) __mapper_args__ = {"inherit_condition": id == PersistentObject.id} def test_gen_query_nodepth(self): @@ -4084,7 +4992,7 @@ class CyclicalInheritingEagerTestThree(fixtures.DeclarativeMappedTest, "director.id AS director_id," " director.other_id AS director_other_id, " "director.name AS director_name FROM persistent " - "LEFT OUTER JOIN director ON director.id = persistent.id" + "LEFT OUTER JOIN 
director ON director.id = persistent.id", ) def test_gen_query_depth(self): @@ -4105,13 +5013,14 @@ class CyclicalInheritingEagerTestThree(fixtures.DeclarativeMappedTest, "LEFT OUTER JOIN (persistent AS persistent_1 " "LEFT OUTER JOIN director AS director_1 ON " "director_1.id = persistent_1.id) " - "ON director.other_id = persistent_1.id" + "ON director.other_id = persistent_1.id", ) class EnsureColumnsAddedTest( - fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" run_create_tables = None @classmethod @@ -4119,28 +5028,35 @@ class EnsureColumnsAddedTest( Base = cls.DeclarativeBasic class Parent(Base): - __tablename__ = 'parent' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "parent" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) arb = Column(Integer, unique=True) data = Column(Integer) o2mchild = relationship("O2MChild") - m2mchild = relationship("M2MChild", secondary=Table( - 'parent_to_m2m', Base.metadata, - Column('parent_id', ForeignKey('parent.arb')), - Column('child_id', ForeignKey('m2mchild.id')) - )) + m2mchild = relationship( + "M2MChild", + secondary=Table( + "parent_to_m2m", + Base.metadata, + Column("parent_id", ForeignKey("parent.arb")), + Column("child_id", ForeignKey("m2mchild.id")), + ), + ) class O2MChild(Base): - __tablename__ = 'o2mchild' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - parent_id = Column(ForeignKey('parent.arb')) + __tablename__ = "o2mchild" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + parent_id = Column(ForeignKey("parent.arb")) class M2MChild(Base): - __tablename__ = 'm2mchild' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "m2mchild" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) def test_joinedload_defered_pk_limit_o2m(self): Parent = self.classes.Parent @@ -4148,9 +5064,9 @@ class EnsureColumnsAddedTest( s = Session() self.assert_compile( - s.query(Parent).options( - load_only('data'), - joinedload(Parent.o2mchild)).limit(10), + s.query(Parent) + .options(load_only("data"), joinedload(Parent.o2mchild)) + .limit(10), "SELECT anon_1.parent_id AS anon_1_parent_id, " "anon_1.parent_data AS anon_1_parent_data, " "anon_1.parent_arb AS anon_1_parent_arb, " @@ -4159,7 +5075,7 @@ class EnsureColumnsAddedTest( "FROM (SELECT parent.id AS parent_id, parent.data AS parent_data, " "parent.arb AS parent_arb FROM parent LIMIT :param_1) AS anon_1 " "LEFT OUTER JOIN o2mchild AS o2mchild_1 " - "ON anon_1.parent_arb = o2mchild_1.parent_id" + "ON anon_1.parent_arb = o2mchild_1.parent_id", ) def test_joinedload_defered_pk_limit_m2m(self): @@ -4168,9 +5084,9 @@ class EnsureColumnsAddedTest( s = Session() self.assert_compile( - s.query(Parent).options( - load_only('data'), - joinedload(Parent.m2mchild)).limit(10), + s.query(Parent) + .options(load_only("data"), joinedload(Parent.m2mchild)) + .limit(10), "SELECT anon_1.parent_id AS anon_1_parent_id, " "anon_1.parent_data AS anon_1_parent_data, " "anon_1.parent_arb AS anon_1_parent_arb, " @@ -4181,7 +5097,7 @@ class EnsureColumnsAddedTest( "LEFT OUTER JOIN (parent_to_m2m AS parent_to_m2m_1 " "JOIN m2mchild AS m2mchild_1 " "ON m2mchild_1.id = parent_to_m2m_1.child_id) " - "ON anon_1.parent_arb = parent_to_m2m_1.parent_id" + "ON anon_1.parent_arb = parent_to_m2m_1.parent_id", ) def 
test_joinedload_defered_pk_o2m(self): @@ -4191,13 +5107,13 @@ class EnsureColumnsAddedTest( self.assert_compile( s.query(Parent).options( - load_only('data'), - joinedload(Parent.o2mchild)), + load_only("data"), joinedload(Parent.o2mchild) + ), "SELECT parent.id AS parent_id, parent.data AS parent_data, " "parent.arb AS parent_arb, o2mchild_1.id AS o2mchild_1_id, " "o2mchild_1.parent_id AS o2mchild_1_parent_id " "FROM parent LEFT OUTER JOIN o2mchild AS o2mchild_1 " - "ON parent.arb = o2mchild_1.parent_id" + "ON parent.arb = o2mchild_1.parent_id", ) def test_joinedload_defered_pk_m2m(self): @@ -4207,14 +5123,14 @@ class EnsureColumnsAddedTest( self.assert_compile( s.query(Parent).options( - load_only('data'), - joinedload(Parent.m2mchild)), + load_only("data"), joinedload(Parent.m2mchild) + ), "SELECT parent.id AS parent_id, parent.data AS parent_data, " "parent.arb AS parent_arb, m2mchild_1.id AS m2mchild_1_id " "FROM parent LEFT OUTER JOIN (parent_to_m2m AS parent_to_m2m_1 " "JOIN m2mchild AS m2mchild_1 " "ON m2mchild_1.id = parent_to_m2m_1.child_id) " - "ON parent.arb = parent_to_m2m_1.parent_id" + "ON parent.arb = parent_to_m2m_1.parent_id", ) @@ -4226,49 +5142,48 @@ class EntityViaMultiplePathTestOne(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) - c_id = Column(ForeignKey('c.id')) + b_id = Column(ForeignKey("b.id")) + c_id = Column(ForeignKey("c.id")) b = relationship("B") c = relationship("C") class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - c_id = Column(ForeignKey('c.id')) + c_id = Column(ForeignKey("c.id")) c = relationship("C") class C(Base): - __tablename__ = 'c' + __tablename__ = "c" id = Column(Integer, primary_key=True) - d_id = Column(ForeignKey('d.id')) + d_id = Column(ForeignKey("d.id")) d = relationship("D") class D(Base): - __tablename__ = 'd' + __tablename__ = "d" id = Column(Integer, primary_key=True) @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, - Column('id', Integer, primary_key=True), - Column('bid', ForeignKey('b.id')) + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", ForeignKey("b.id")), ) def test_multi_path_load(self): - A, B, C, D = self.classes('A', 'B', 'C', 'D') + A, B, C, D = self.classes("A", "B", "C", "D") s = Session() c = C(d=D()) - s.add( - A(b=B(c=c), c=c) - ) + s.add(A(b=B(c=c), c=c)) s.commit() c_alias_1 = aliased(C) @@ -4277,9 +5192,10 @@ class EntityViaMultiplePathTestOne(fixtures.DeclarativeMappedTest): q = s.query(A) q = q.join(A.b).join(c_alias_1, B.c).join(c_alias_1.d) q = q.options( - contains_eager(A.b). - contains_eager(B.c, alias=c_alias_1). - contains_eager(C.d)) + contains_eager(A.b) + .contains_eager(B.c, alias=c_alias_1) + .contains_eager(C.d) + ) q = q.join(c_alias_2, A.c) q = q.options(contains_eager(A.c, alias=c_alias_2)) @@ -4287,7 +5203,7 @@ class EntityViaMultiplePathTestOne(fixtures.DeclarativeMappedTest): # ensure 'd' key was populated in dict. 
Varies based on # PYTHONHASHSEED - in_('d', a1.c.__dict__) + in_("d", a1.c.__dict__) class EntityViaMultiplePathTestTwo(fixtures.DeclarativeMappedTest): @@ -4298,7 +5214,7 @@ class EntityViaMultiplePathTestTwo(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class User(Base): - __tablename__ = 'cs_user' + __tablename__ = "cs_user" id = Column(Integer, primary_key=True) data = Column(Integer) @@ -4306,34 +5222,34 @@ class EntityViaMultiplePathTestTwo(fixtures.DeclarativeMappedTest): class LD(Base): """Child. The column we reference 'A' with is an integer.""" - __tablename__ = 'cs_ld' + __tablename__ = "cs_ld" id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey('cs_user.id')) + user_id = Column(Integer, ForeignKey("cs_user.id")) user = relationship(User, primaryjoin=user_id == User.id) class A(Base): """Child. The column we reference 'A' with is an integer.""" - __tablename__ = 'cs_a' + __tablename__ = "cs_a" id = Column(Integer, primary_key=True) - ld_id = Column(Integer, ForeignKey('cs_ld.id')) + ld_id = Column(Integer, ForeignKey("cs_ld.id")) ld = relationship(LD, primaryjoin=ld_id == LD.id) class LDA(Base): """Child. The column we reference 'A' with is an integer.""" - __tablename__ = 'cs_lda' + __tablename__ = "cs_lda" id = Column(Integer, primary_key=True) - ld_id = Column(Integer, ForeignKey('cs_ld.id')) - a_id = Column(Integer, ForeignKey('cs_a.id')) + ld_id = Column(Integer, ForeignKey("cs_ld.id")) + a_id = Column(Integer, ForeignKey("cs_a.id")) a = relationship(A, primaryjoin=a_id == A.id) ld = relationship(LD, primaryjoin=ld_id == LD.id) def test_multi_path_load(self): - User, LD, A, LDA = self.classes('User', 'LD', 'A', 'LDA') + User, LD, A, LDA = self.classes("User", "LD", "A", "LDA") s = Session() @@ -4341,27 +5257,27 @@ class EntityViaMultiplePathTestTwo(fixtures.DeclarativeMappedTest): l0 = LD(user=u0) z0 = A(ld=l0) lz0 = LDA(ld=l0, a=z0) - s.add_all([ - u0, l0, z0, lz0 - ]) + s.add_all([u0, l0, z0, lz0]) s.commit() l_ac = aliased(LD) u_ac = aliased(User) - lz_test = (s.query(LDA) - .join('ld') - .options(contains_eager('ld')) - .join('a', (l_ac, 'ld'), (u_ac, 'user')) - .options(contains_eager('a') - .contains_eager('ld', alias=l_ac) - .contains_eager('user', alias=u_ac)) - .first()) - - in_( - 'user', lz_test.a.ld.__dict__ + lz_test = ( + s.query(LDA) + .join("ld") + .options(contains_eager("ld")) + .join("a", (l_ac, "ld"), (u_ac, "user")) + .options( + contains_eager("a") + .contains_eager("ld", alias=l_ac) + .contains_eager("user", alias=u_ac) + ) + .first() ) + in_("user", lz_test.a.ld.__dict__) + class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): """test for [ticket:3963]""" @@ -4371,20 +5287,20 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) bs = relationship("B") class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) + a_id = Column(ForeignKey("a.id")) cs = relationship("C") class C(Base): - __tablename__ = 'c' + __tablename__ = "c" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) @classmethod def insert_data(cls): @@ -4399,42 +5315,55 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): for a, _ in query: for b in a.bs: b.cs + self.assert_sql_count(testing.db, go, expected) def 
test_string_options_aliased_whatever(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options( - joinedload("bs").joinedload("cs")) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .options(joinedload("bs").joinedload("cs")) + ) self._run_tests(q, 1) def test_string_options_unaliased_whatever(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(A, aa).filter( - aa.id == 2).filter(A.id == 1).options( - joinedload("bs").joinedload("cs")) + q = ( + s.query(A, aa) + .filter(aa.id == 2) + .filter(A.id == 1) + .options(joinedload("bs").joinedload("cs")) + ) self._run_tests(q, 1) def test_lazyload_aliased_abs_bcs_one(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options( - joinedload(A.bs).joinedload(B.cs)) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .options(joinedload(A.bs).joinedload(B.cs)) + ) self._run_tests(q, 3) def test_lazyload_aliased_abs_bcs_two(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options( - defaultload(A.bs).joinedload(B.cs)) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .options(defaultload(A.bs).joinedload(B.cs)) + ) self._run_tests(q, 3) def test_pathed_lazyload_aliased_abs_bcs(self): @@ -4443,8 +5372,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): aa = aliased(A) opt = Load(A).joinedload(A.bs).joinedload(B.cs) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options(opt) + q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt) self._run_tests(q, 3) def test_pathed_lazyload_plus_joined_aliased_abs_bcs(self): @@ -4453,8 +5381,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): aa = aliased(A) opt = Load(aa).defaultload(aa.bs).joinedload(B.cs) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options(opt) + q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt) self._run_tests(q, 2) def test_pathed_joinedload_aliased_abs_bcs(self): @@ -4463,62 +5390,79 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): aa = aliased(A) opt = Load(aa).joinedload(aa.bs).joinedload(B.cs) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options(opt) + q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt) self._run_tests(q, 1) def test_lazyload_plus_joined_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options( - defaultload(aa.bs).joinedload(B.cs)) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .options(defaultload(aa.bs).joinedload(B.cs)) + ) self._run_tests(q, 2) def test_joinedload_aliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(aa, A).filter( - aa.id == 1).filter(A.id == 2).options( - joinedload(aa.bs).joinedload(B.cs)) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .options(joinedload(aa.bs).joinedload(B.cs)) + ) self._run_tests(q, 1) def test_lazyload_unaliased_abs_bcs_one(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(A, aa).filter( - aa.id == 2).filter(A.id == 1).options( - joinedload(aa.bs).joinedload(B.cs)) + q = ( + s.query(A, aa) + 
.filter(aa.id == 2) + .filter(A.id == 1) + .options(joinedload(aa.bs).joinedload(B.cs)) + ) self._run_tests(q, 3) def test_lazyload_unaliased_abs_bcs_two(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(A, aa).filter( - aa.id == 2).filter(A.id == 1).options( - defaultload(aa.bs).joinedload(B.cs)) + q = ( + s.query(A, aa) + .filter(aa.id == 2) + .filter(A.id == 1) + .options(defaultload(aa.bs).joinedload(B.cs)) + ) self._run_tests(q, 3) def test_lazyload_plus_joined_unaliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(A, aa).filter( - aa.id == 2).filter(A.id == 1).options( - defaultload(A.bs).joinedload(B.cs)) + q = ( + s.query(A, aa) + .filter(aa.id == 2) + .filter(A.id == 1) + .options(defaultload(A.bs).joinedload(B.cs)) + ) self._run_tests(q, 2) def test_joinedload_unaliased_abs_bcs(self): A, B, C = self.classes("A", "B", "C") s = Session() aa = aliased(A) - q = s.query(A, aa).filter( - aa.id == 2).filter(A.id == 1).options( - joinedload(A.bs).joinedload(B.cs)) + q = ( + s.query(A, aa) + .filter(aa.id == 2) + .filter(A.id == 1) + .options(joinedload(A.bs).joinedload(B.cs)) + ) self._run_tests(q, 1) @@ -4530,25 +5474,31 @@ class EntityViaMultiplePathTestThree(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - parent_id = Column(Integer, ForeignKey('a.id')) + parent_id = Column(Integer, ForeignKey("a.id")) parent = relationship("A", remote_side=id, lazy="raise") def test_multi_path_load_lazy_none(self): A = self.classes.A s = Session() - s.add_all([ - A(id=1, parent_id=None), - A(id=2, parent_id=2), - A(id=4, parent_id=None), - A(id=3, parent_id=4), - ]) + s.add_all( + [ + A(id=1, parent_id=None), + A(id=2, parent_id=2), + A(id=4, parent_id=None), + A(id=3, parent_id=4), + ] + ) s.commit() - q1 = s.query(A).order_by(A.id).\ - filter(A.id.in_([1, 2])).options(joinedload(A.parent)) + q1 = ( + s.query(A) + .order_by(A.id) + .filter(A.id.in_([1, 2])) + .options(joinedload(A.parent)) + ) def go(): for a in q1: @@ -4559,8 +5509,12 @@ class EntityViaMultiplePathTestThree(fixtures.DeclarativeMappedTest): self.assert_sql_count(testing.db, go, 1) - q1 = s.query(A).order_by(A.id).\ - filter(A.id.in_([3, 4])).options(joinedload(A.parent)) + q1 = ( + s.query(A) + .order_by(A.id) + .filter(A.id.in_([3, 4])) + .options(joinedload(A.parent)) + ) def go(): for a in q1: diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py index fca050ccf0..651c81a858 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/test_evaluator.py @@ -24,9 +24,13 @@ def eval_eq(clause, testcases=None): evaluator = compiler.process(clause) def testeval(obj=None, expected_result=None): - assert evaluator(obj) == expected_result, \ - "%s != %r for %s with %r" % ( - evaluator(obj), expected_result, clause, obj) + assert evaluator(obj) == expected_result, "%s != %r for %s with %r" % ( + evaluator(obj), + expected_result, + clause, + obj, + ) + if testcases: for an_obj, result in testcases: testeval(an_obj, result) @@ -36,10 +40,13 @@ def eval_eq(clause, testcases=None): class EvaluateTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(64)), - Column('othername', String(64))) + Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(64)), + Column("othername", 
String(64)), + ) @classmethod def setup_classes(cls): @@ -55,35 +62,43 @@ class EvaluateTest(fixtures.MappedTest): def test_compare_to_value(self): User = self.classes.User - eval_eq(User.name == 'foo', testcases=[ - (User(name='foo'), True), - (User(name='bar'), False), - (User(name=None), None), - ]) + eval_eq( + User.name == "foo", + testcases=[ + (User(name="foo"), True), + (User(name="bar"), False), + (User(name=None), None), + ], + ) - eval_eq(User.id < 5, testcases=[ - (User(id=3), True), - (User(id=5), False), - (User(id=None), None), - ]) + eval_eq( + User.id < 5, + testcases=[ + (User(id=3), True), + (User(id=5), False), + (User(id=None), None), + ], + ) def test_compare_to_callable_bind(self): User = self.classes.User eval_eq( - User.name == bindparam('x', callable_=lambda: 'foo'), + User.name == bindparam("x", callable_=lambda: "foo"), testcases=[ - (User(name='foo'), True), - (User(name='bar'), False), + (User(name="foo"), True), + (User(name="bar"), False), (User(name=None), None), - ] + ], ) def test_compare_to_none(self): User = self.classes.User - eval_eq(User.name == None, # noqa - testcases=[(User(name='foo'), False), (User(name=None), True)]) + eval_eq( + User.name == None, # noqa + testcases=[(User(name="foo"), False), (User(name=None), True)], + ) def test_warn_on_unannotated_matched_column(self): User = self.classes.User @@ -92,8 +107,9 @@ class EvaluateTest(fixtures.MappedTest): with expect_warnings( r"Evaluating non-mapped column expression 'othername' " - "onto ORM instances; this is a deprecated use case."): - meth = compiler.process(User.name == Column('othername', String)) + "onto ORM instances; this is a deprecated use case." + ): + meth = compiler.process(User.name == Column("othername", String)) u1 = User(id=5) meth(u1) @@ -106,7 +122,8 @@ class EvaluateTest(fixtures.MappedTest): assert_raises_message( evaluator.UnevaluatableError, "Cannot evaluate column: foo", - compiler.process, User.id == Column('foo', Integer) + compiler.process, + User.id == Column("foo", Integer), ) # if we let the above method through as we did @@ -121,58 +138,70 @@ class EvaluateTest(fixtures.MappedTest): eval_eq( User.name == False, # noqa testcases=[ - (User(name='foo'), False), + (User(name="foo"), False), (User(name=True), False), (User(name=False), True), - ] + ], ) eval_eq( User.name == True, # noqa testcases=[ - (User(name='foo'), False), + (User(name="foo"), False), (User(name=True), True), (User(name=False), False), - ] + ], ) def test_boolean_ops(self): User = self.classes.User - eval_eq(and_(User.name == 'foo', User.id == 1), testcases=[ - (User(id=1, name='foo'), True), - (User(id=2, name='foo'), False), - (User(id=1, name='bar'), False), - (User(id=2, name='bar'), False), - (User(id=1, name=None), None), - ]) - - eval_eq(or_(User.name == 'foo', User.id == 1), testcases=[ - (User(id=1, name='foo'), True), - (User(id=2, name='foo'), True), - (User(id=1, name='bar'), True), - (User(id=2, name='bar'), False), - (User(id=1, name=None), True), - (User(id=2, name=None), None), - ]) - - eval_eq(not_(User.id == 1), testcases=[ - (User(id=1), False), - (User(id=2), True), - (User(id=None), None), - ]) + eval_eq( + and_(User.name == "foo", User.id == 1), + testcases=[ + (User(id=1, name="foo"), True), + (User(id=2, name="foo"), False), + (User(id=1, name="bar"), False), + (User(id=2, name="bar"), False), + (User(id=1, name=None), None), + ], + ) + + eval_eq( + or_(User.name == "foo", User.id == 1), + testcases=[ + (User(id=1, name="foo"), True), + (User(id=2, name="foo"), True), + 
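These truth tables mirror SQL's three-valued logic on the Python side:
comparing against a None attribute yields None rather than False, and
the boolean operators only decide a result when some operand forces it.
The evaluator exists so that bulk UPDATE/DELETE with the "evaluate"
synchronization strategy can apply the same criteria to objects already
in the Session, which is what M2OEvaluateTest's delete("evaluate")
relies on further down. A self-contained sketch of that public entry
point, with an illustrative model rather than the test fixture:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(64))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session = Session(engine)
    session.add_all([User(id=1, name="foo"), User(id=2, name="bar")])
    session.flush()
    u1 = session.query(User).get(1)

    # The criterion is evaluated in Python against in-session objects,
    # so u1 is updated in place instead of being expired and reloaded.
    session.query(User).filter(User.name == "foo").update(
        {"name": "foo2"}, synchronize_session="evaluate"
    )
    assert u1.name == "foo2"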
(User(id=1, name="bar"), True), + (User(id=2, name="bar"), False), + (User(id=1, name=None), True), + (User(id=2, name=None), None), + ], + ) + + eval_eq( + not_(User.id == 1), + testcases=[ + (User(id=1), False), + (User(id=2), True), + (User(id=None), None), + ], + ) def test_null_propagation(self): User = self.classes.User - eval_eq((User.name == 'foo') == (User.id == 1), testcases=[ - (User(id=1, name='foo'), True), - (User(id=2, name='foo'), False), - (User(id=1, name='bar'), False), - (User(id=2, name='bar'), True), - (User(id=None, name='foo'), None), - (User(id=None, name=None), None), - ]) + eval_eq( + (User.name == "foo") == (User.id == 1), + testcases=[ + (User(id=1, name="foo"), True), + (User(id=2, name="foo"), False), + (User(id=1, name="bar"), False), + (User(id=2, name="bar"), True), + (User(id=None, name="foo"), None), + (User(id=None, name=None), None), + ], + ) class M2OEvaluateTest(fixtures.DeclarativeMappedTest): @@ -187,12 +216,13 @@ class M2OEvaluateTest(fixtures.DeclarativeMappedTest): class Child(Base): __tablename__ = "child" _id_parent = Column( - "id_parent", Integer, ForeignKey(Parent.id), primary_key=True) + "id_parent", Integer, ForeignKey(Parent.id), primary_key=True + ) name = Column(String(50), primary_key=True) parent = relationship(Parent) def test_delete(self): - Parent, Child = self.classes('Parent', 'Child') + Parent, Child = self.classes("Parent", "Child") session = Session() @@ -206,6 +236,4 @@ class M2OEvaluateTest(fixtures.DeclarativeMappedTest): session.query(Child).filter(Child.parent == p).delete("evaluate") - is_( - inspect(c).deleted, True - ) + is_(inspect(c).deleted, True) diff --git a/test/orm/test_events.py b/test/orm/test_events.py index de193f1f03..1390da4979 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -4,10 +4,19 @@ from sqlalchemy import testing from sqlalchemy import Integer, String from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, \ - create_session, class_mapper, \ - Mapper, column_property, query, \ - Session, sessionmaker, attributes, configure_mappers +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + class_mapper, + Mapper, + column_property, + query, + Session, + sessionmaker, + attributes, + configure_mappers, +) from sqlalchemy.orm.instrumentation import ClassManager from sqlalchemy.orm import instrumentation, events from sqlalchemy.orm import EXT_SKIP @@ -22,7 +31,6 @@ from sqlalchemy.testing.mock import Mock, call, ANY class _RemoveListeners(object): - def teardown(self): events.MapperEvents._clear() events.InstanceEvents._clear() @@ -38,8 +46,8 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): @classmethod def define_tables(cls, metadata): super(MapperEventsTest, cls).define_tables(metadata) - metadata.tables['users'].append_column( - Column('extra', Integer, default=5, onupdate=10) + metadata.tables["users"].append_column( + Column("extra", Integer, default=5, onupdate=10) ) def test_instance_event_listen(self): @@ -56,38 +64,47 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): pass mapper(A, users) - mapper(B, addresses, inherits=A, - properties={'address_id': addresses.c.id}) + mapper( + B, addresses, inherits=A, properties={"address_id": addresses.c.id} + ) def init_a(target, args, kwargs): - canary.append(('init_a', target)) + canary.append(("init_a", target)) def init_b(target, args, kwargs): - canary.append(('init_b', 
target)) + canary.append(("init_b", target)) def init_c(target, args, kwargs): - canary.append(('init_c', target)) + canary.append(("init_c", target)) def init_d(target, args, kwargs): - canary.append(('init_d', target)) + canary.append(("init_d", target)) def init_e(target, args, kwargs): - canary.append(('init_e', target)) + canary.append(("init_e", target)) - event.listen(mapper, 'init', init_a) - event.listen(Mapper, 'init', init_b) - event.listen(class_mapper(A), 'init', init_c) - event.listen(A, 'init', init_d) - event.listen(A, 'init', init_e, propagate=True) + event.listen(mapper, "init", init_a) + event.listen(Mapper, "init", init_b) + event.listen(class_mapper(A), "init", init_c) + event.listen(A, "init", init_d) + event.listen(A, "init", init_e, propagate=True) a = A() - eq_(canary, [('init_a', a), ('init_b', a), - ('init_c', a), ('init_d', a), ('init_e', a)]) + eq_( + canary, + [ + ("init_a", a), + ("init_b", a), + ("init_c", a), + ("init_d", a), + ("init_e", a), + ], + ) # test propagate flag canary[:] = [] b = B() - eq_(canary, [('init_a', b), ('init_b', b), ('init_e', b)]) + eq_(canary, [("init_a", b), ("init_b", b), ("init_e", b)]) def listen_all(self, mapper, **kw): canary = [] @@ -95,21 +112,22 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(meth): def go(*args, **kwargs): canary.append(meth) + return go for meth in [ - 'init', - 'init_failure', - 'load', - 'refresh', - 'refresh_flush', - 'expire', - 'before_insert', - 'after_insert', - 'before_update', - 'after_update', - 'before_delete', - 'after_delete' + "init", + "init_failure", + "load", + "refresh", + "refresh_flush", + "expire", + "before_insert", + "after_insert", + "before_update", + "after_update", + "before_delete", + "after_delete", ]: event.listen(mapper, meth, evt(meth), **kw) return canary @@ -118,44 +136,39 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - @event.listens_for(User, 'init') + @event.listens_for(User, "init") def add_name(obj, args, kwargs): - kwargs['name'] = 'ed' + kwargs["name"] = "ed" u1 = User() - eq_(u1.name, 'ed') + eq_(u1.name, "ed") def test_init_failure_hook(self): users = self.tables.users class Thing(object): def __init__(self, **kw): - if kw.get('fail'): + if kw.get("fail"): raise Exception("failure") mapper(Thing, users) canary = Mock() - event.listen(Thing, 'init_failure', canary) + event.listen(Thing, "init_failure", canary) Thing() eq_(canary.mock_calls, []) - assert_raises_message( - Exception, - "failure", - Thing, fail=True - ) - eq_( - canary.mock_calls, - [call(ANY, (), {'fail': True})] - ) + assert_raises_message(Exception, "failure", Thing, fail=True) + eq_(canary.mock_calls, [call(ANY, (), {"fail": True})]) def test_listen_doesnt_force_compile(self): User, users = self.classes.User, self.tables.users - m = mapper(User, users, properties={ - 'addresses': relationship(lambda: ImNotAClass) - }) + m = mapper( + User, + users, + properties={"addresses": relationship(lambda: ImNotAClass)}, + ) event.listen(User, "before_insert", lambda *a, **kw: None) assert not m.configured @@ -167,25 +180,31 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): named_canary = self.listen_all(User, named=True) sess = create_session() - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() sess.expire(u) u = sess.query(User).get(u.id) sess.expunge_all() u = sess.query(User).get(u.id) - u.name = 'u1 changed' + u.name = "u1 changed" sess.flush() 
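The test underway here walks a single instance through the whole mapper
event sequence, init through after_delete, and checks the recorded
order. A compact standalone sketch of the same mechanics through the
public @event.listens_for decorator; the model is illustrative, and
(mapper, connection, target) is the documented signature for the
flush-time hooks:

    from sqlalchemy import Column, Integer, String, create_engine, event
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(64))

    seen = []

    @event.listens_for(User, "before_insert")
    def before_insert(mapper, connection, target):
        seen.append(("before_insert", target.name))

    @event.listens_for(User, "after_insert")
    def after_insert(mapper, connection, target):
        seen.append(("after_insert", target.name))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session = Session(engine)
    session.add(User(name="u1"))
    session.flush()
    assert seen == [("before_insert", "u1"), ("after_insert", "u1")]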
sess.delete(u) sess.flush() expected = [ - 'init', 'before_insert', - 'refresh_flush', - 'after_insert', 'expire', - 'refresh', - 'load', - 'before_update', 'refresh_flush', 'after_update', 'before_delete', - 'after_delete'] + "init", + "before_insert", + "refresh_flush", + "after_insert", + "expire", + "refresh", + "load", + "before_update", + "refresh_flush", + "after_update", + "before_delete", + "after_delete", + ] eq_(canary, expected) eq_(named_canary, expected) @@ -205,7 +224,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( canary.mock_calls, - [call.listen4(), call.listen2(), call.listen1(), call.listen3()] + [call.listen4(), call.listen2(), call.listen1(), call.listen3()], ) def test_insert_flags(self): @@ -217,10 +236,11 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): arg = Mock() - event.listen(m, "before_insert", canary.listen1, ) + event.listen(m, "before_insert", canary.listen1) event.listen(m, "before_insert", canary.listen2, insert=True) - event.listen(m, "before_insert", canary.listen3, - propagate=True, insert=True) + event.listen( + m, "before_insert", canary.listen3, propagate=True, insert=True + ) event.listen(m, "load", canary.listen4) event.listen(m, "load", canary.listen5, insert=True) event.listen(m, "load", canary.listen6, propagate=True, insert=True) @@ -237,8 +257,8 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): call.listen1(arg, arg, arg.obj()), call.listen6(arg.obj(), arg), call.listen5(arg.obj(), arg), - call.listen4(arg.obj(), arg) - ] + call.listen4(arg.obj(), arg), + ], ) def test_merge(self): @@ -249,65 +269,95 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): canary = [] def load(obj, ctx): - canary.append('load') - event.listen(mapper, 'load', load) + canary.append("load") + + event.listen(mapper, "load", load) s = Session() - u = User(name='u1') + u = User(name="u1") s.add(u) s.commit() s = Session() u2 = s.merge(u) s = Session() - u2 = s.merge(User(name='u2')) # noqa + u2 = s.merge(User(name="u2")) # noqa s.commit() s.query(User).order_by(User.id).first() - eq_(canary, ['load', 'load', 'load']) + eq_(canary, ["load", "load", "load"]) def test_inheritance(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) class AdminUser(User): pass mapper(User, users) - mapper(AdminUser, addresses, inherits=User, - properties={'address_id': addresses.c.id}) + mapper( + AdminUser, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) canary1 = self.listen_all(User, propagate=True) canary2 = self.listen_all(User) canary3 = self.listen_all(AdminUser) sess = create_session() - am = AdminUser(name='au1', email_address='au1@e1') + am = AdminUser(name="au1", email_address="au1@e1") sess.add(am) sess.flush() am = sess.query(AdminUser).populate_existing().get(am.id) sess.expunge_all() am = sess.query(AdminUser).get(am.id) - am.name = 'au1 changed' + am.name = "au1 changed" sess.flush() sess.delete(am) sess.flush() - eq_(canary1, ['init', 'before_insert', 'refresh_flush', 'after_insert', - 'refresh', 'load', - 'before_update', 'refresh_flush', - 'after_update', 'before_delete', - 'after_delete']) + eq_( + canary1, + [ + "init", + "before_insert", + "refresh_flush", + "after_insert", + "refresh", + "load", + "before_update", + "refresh_flush", + "after_update", + "before_delete", + "after_delete", + ], + ) eq_(canary2, []) - 
eq_(canary3, ['init', 'before_insert', 'refresh_flush', 'after_insert', - 'refresh', - 'load', - 'before_update', 'refresh_flush', - 'after_update', 'before_delete', - 'after_delete']) + eq_( + canary3, + [ + "init", + "before_insert", + "refresh_flush", + "after_insert", + "refresh", + "load", + "before_update", + "refresh_flush", + "after_update", + "before_delete", + "after_delete", + ], + ) def test_inheritance_subclass_deferred(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) @@ -316,32 +366,59 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): class AdminUser(User): pass - mapper(AdminUser, addresses, inherits=User, - properties={'address_id': addresses.c.id}) + + mapper( + AdminUser, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) canary3 = self.listen_all(AdminUser) sess = create_session() - am = AdminUser(name='au1', email_address='au1@e1') + am = AdminUser(name="au1", email_address="au1@e1") sess.add(am) sess.flush() am = sess.query(AdminUser).populate_existing().get(am.id) sess.expunge_all() am = sess.query(AdminUser).get(am.id) - am.name = 'au1 changed' + am.name = "au1 changed" sess.flush() sess.delete(am) sess.flush() - eq_(canary1, ['init', 'before_insert', 'refresh_flush', 'after_insert', - 'refresh', 'load', - 'before_update', 'refresh_flush', - 'after_update', 'before_delete', - 'after_delete']) + eq_( + canary1, + [ + "init", + "before_insert", + "refresh_flush", + "after_insert", + "refresh", + "load", + "before_update", + "refresh_flush", + "after_update", + "before_delete", + "after_delete", + ], + ) eq_(canary2, []) - eq_(canary3, ['init', 'before_insert', 'refresh_flush', 'after_insert', - 'refresh', 'load', - 'before_update', 'refresh_flush', - 'after_update', 'before_delete', - 'after_delete']) + eq_( + canary3, + [ + "init", + "before_insert", + "refresh_flush", + "after_insert", + "refresh", + "load", + "before_update", + "refresh_flush", + "after_update", + "before_delete", + "after_delete", + ], + ) def test_before_after_only_collection(self): """before_update is called on parent for collection modifications, @@ -354,10 +431,16 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords)}) + mapper( + Item, + items, + properties={ + "keywords": relationship(Keyword, secondary=item_keywords) + }, + ) mapper(Keyword, keywords) canary1 = self.listen_all(Item) @@ -369,19 +452,15 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess.add(i1) sess.add(k1) sess.flush() - eq_(canary1, - ['init', - 'before_insert', 'after_insert']) - eq_(canary2, - ['init', - 'before_insert', 'after_insert']) + eq_(canary1, ["init", "before_insert", "after_insert"]) + eq_(canary2, ["init", "before_insert", "after_insert"]) canary1[:] = [] canary2[:] = [] i1.keywords.append(k1) sess.flush() - eq_(canary1, ['before_update', 'after_update']) + eq_(canary1, ["before_update", "after_update"]) eq_(canary2, []) def test_before_after_configured_warn_on_non_mapper(self): @@ -395,7 +474,10 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): r"before_configured' and 'after_configured' ORM events only " r"invoke with the mapper\(\) 
function or Mapper class as " r"the target.", - event.listen, User, 'before_configured', m1 + event.listen, + User, + "before_configured", + m1, ) assert_raises_message( @@ -403,7 +485,10 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): r"before_configured' and 'after_configured' ORM events only " r"invoke with the mapper\(\) function or Mapper class as " r"the target.", - event.listen, User, 'after_configured', m1 + event.listen, + User, + "after_configured", + m1, ) def test_before_after_configured(self): @@ -424,17 +509,19 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_(m2.mock_calls, [call()]) def test_instrument_event(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) canary = [] def instrument_class(mapper, cls): canary.append(cls) - event.listen(Mapper, 'instrument_class', instrument_class) + event.listen(Mapper, "instrument_class", instrument_class) mapper(User, users) eq_(canary, [User]) @@ -472,9 +559,9 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): [ call.instrument_class(MyClass), call.class_instrument(MyClass), - call.init() + call.init(), ], - canary.mock_calls + canary.mock_calls, ) def test_before_mapper_configured_event(self): @@ -497,10 +584,11 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): AnotherBase = declarative_base() class Animal(AnotherBase): - __tablename__ = 'animal' + __tablename__ = "animal" species = Column(String(30), primary_key=True) __mapper_args__ = dict( - polymorphic_on='species', polymorphic_identity='Animal') + polymorphic_on="species", polymorphic_identity="Animal" + ) # Register the first classes and create their Mappers: configure_mappers() @@ -511,7 +599,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): # Declare a subclass, table and mapper, which refers to one that has # not been loaded yet (Employer), and therefore cannot be configured: class Mammal(Animal): - nonexistent = relationship('Nonexistent') + nonexistent = relationship("Nonexistent") # These new classes should not be configured at this point: unconfigured = [m for m in _mapper_registry if not m.configured] @@ -523,20 +611,25 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def probe(): s = Session() s.query(User) + assert_raises(sa.exc.InvalidRequestError, probe) # If we disable configuring mappers while querying, then it succeeds: @event.listens_for( - AnotherBase, "before_mapper_configured", propagate=True, - retval=True) + AnotherBase, + "before_mapper_configured", + propagate=True, + retval=True, + ) def disable_configure_mappers(mapper, cls): return EXT_SKIP probe() -class DeclarativeEventListenTest(_RemoveListeners, - fixtures.DeclarativeMappedTest): +class DeclarativeEventListenTest( + _RemoveListeners, fixtures.DeclarativeMappedTest +): run_setup_classes = "each" run_deletes = None @@ -544,7 +637,7 @@ class DeclarativeEventListenTest(_RemoveListeners, # test [ticket:2949] class A(self.DeclarativeBasic): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(A): @@ -565,10 +658,7 @@ class DeclarativeEventListenTest(_RemoveListeners, m3.dispatch.load(c1._sa_instance_state, "c") m2.dispatch.load(b1._sa_instance_state, "b") m1.dispatch.load(a1._sa_instance_state, "a") - eq_( - listen.mock_calls, - [call(c1, 
"c"), call(b1, "b"), call(a1, "a")] - ) + eq_(listen.mock_calls, [call(c1, "c"), call(b1, "b"), call(a1, "a")]) class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): @@ -579,6 +669,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): it has to get all of these, too. """ + run_inserts = None def test_deferred_map_event(self): @@ -588,13 +679,13 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire should receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) canary = [] def evt(x, y, z): canary.append(x) + event.listen(User, "before_insert", evt, raw=True) m = mapper(User, users) @@ -608,8 +699,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire should receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -621,18 +711,17 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x, y, z): canary.append(x) + event.listen(User, "before_insert", canary, propagate=True, raw=True) m = mapper(SubUser, users) m.dispatch.before_insert(5, 6, 7) - eq_(canary.mock_calls, - [call(5, 6, 7)]) + eq_(canary.mock_calls, [call(5, 6, 7)]) m2 = mapper(SubSubUser, users) m2.dispatch.before_insert(8, 9, 10) - eq_(canary.mock_calls, - [call(5, 6, 7), call(8, 9, 10)]) + eq_(canary.mock_calls, [call(5, 6, 7), call(8, 9, 10)]) def test_deferred_map_event_subclass_no_propagate(self): """ @@ -641,8 +730,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire should not receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -651,6 +739,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x, y, z): canary.append(x) + event.listen(User, "before_insert", evt, propagate=False) m = mapper(SubUser, users) @@ -664,8 +753,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire should receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -676,6 +764,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x, y, z): canary.append(x) + event.listen(User, "before_insert", evt, propagate=True, raw=True) m.dispatch.before_insert(5, 6, 7) @@ -688,8 +777,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire should receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -717,8 +805,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire should receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -729,6 +816,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(User, "load", evt, propagate=True, raw=True) m.class_manager.dispatch.load(5) @@ -741,13 +829,13 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. 
event fire should receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) canary = [] def evt(x): canary.append(x) + event.listen(User, "load", evt, raw=True) m = mapper(User, users) @@ -761,8 +849,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 3. event fire on each class should receive one and only one event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -774,6 +861,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(User, "load", evt, propagate=True, raw=True) m = mapper(SubUser, users) @@ -795,8 +883,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 5. map 2nd subclass 6. event fire on 2nd subclass should receive one and only one event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -821,8 +908,10 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): m3 = mapper(SubUser2, users) m3.class_manager.dispatch.load(instance) - eq_(canary.mock_calls, [call(instance.obj()), - call(instance.obj()), call(instance.obj())]) + eq_( + canary.mock_calls, + [call(instance.obj()), call(instance.obj()), call(instance.obj())], + ) def test_deferred_instance_event_subclass_no_propagate(self): """ @@ -830,8 +919,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): 2. map subclass 3. event fire on subclass should not receive event """ - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -840,6 +928,7 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(User, "load", evt, propagate=False) m = mapper(SubUser, users) @@ -853,10 +942,12 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(User, "attribute_instrument", evt) - instrumentation._instrumentation_factory.\ - dispatch.attribute_instrument(User) + instrumentation._instrumentation_factory.dispatch.attribute_instrument( + User + ) eq_(canary, [User]) def test_isolation_instrument_event(self): @@ -869,10 +960,12 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(Bar, "attribute_instrument", evt) - instrumentation._instrumentation_factory.dispatch.\ - attribute_instrument(User) + instrumentation._instrumentation_factory.dispatch.attribute_instrument( + User + ) eq_(canary, []) @testing.requires.predictable_gc @@ -902,15 +995,16 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(User, "attribute_instrument", evt, propagate=True) - instrumentation._instrumentation_factory.dispatch.\ - attribute_instrument(SubUser) + instrumentation._instrumentation_factory.dispatch.attribute_instrument( + SubUser + ) eq_(canary, [SubUser]) def test_deferred_instrument_event_subclass_no_propagate(self): - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) class SubUser(User): pass @@ -919,11 +1013,13 @@ class DeferredMapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): def evt(x): canary.append(x) + event.listen(User, 
"attribute_instrument", evt, propagate=False) mapper(SubUser, users) - instrumentation._instrumentation_factory.dispatch.\ - attribute_instrument(5) + instrumentation._instrumentation_factory.dispatch.attribute_instrument( + 5 + ) eq_(canary, []) @@ -958,13 +1054,13 @@ class LoadTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.close() sess.query(User).first() - eq_(canary, ['load']) + eq_(canary, ["load"]) def test_repeated_rows(self): User = self.classes.User @@ -973,13 +1069,13 @@ class LoadTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.close() sess.query(User).union_all(sess.query(User)).all() - eq_(canary, ['load']) + eq_(canary, ["load"]) class RemovalTest(_fixtures.FixtureTest): @@ -988,28 +1084,34 @@ class RemovalTest(_fixtures.FixtureTest): def test_attr_propagated(self): User = self.classes.User - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) class AdminUser(User): pass mapper(User, users) - mapper(AdminUser, addresses, inherits=User, - properties={'address_id': addresses.c.id}) + mapper( + AdminUser, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) fn = Mock() event.listen(User.name, "set", fn, propagate=True) au = AdminUser() - au.name = 'ed' + au.name = "ed" eq_(fn.call_count, 1) event.remove(User.name, "set", fn) - au.name = 'jack' + au.name = "jack" eq_(fn.call_count, 1) @@ -1041,6 +1143,7 @@ class RemovalTest(_fixtures.FixtureTest): # the _HoldEvents is also cleaned out class Bar(Foo): pass + m = mapper(Bar, users) b1 = Bar() m.dispatch.before_insert(m, None, attributes.instance_state(b1)) @@ -1098,7 +1201,7 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() @@ -1117,32 +1220,26 @@ class RefreshTest(_fixtures.FixtureTest): @event.listens_for(User, "load") def canary1(obj, context): - obj.name = 'new name!' + obj.name = "new name!" @event.listens_for(User, "refresh") def canary2(obj, context, props): - obj.name = 'refreshed name!' + obj.name = "refreshed name!" 
sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.close() u1 = sess.query(User).first() - eq_( - attributes.get_history(u1, "name"), - ((), ['new name!'], ()) - ) + eq_(attributes.get_history(u1, "name"), ((), ["new name!"], ())) assert "name" not in attributes.instance_state(u1).committed_state assert u1 not in sess.dirty sess.expire(u1) u1.id - eq_( - attributes.get_history(u1, "name"), - ((), ['refreshed name!'], ()) - ) + eq_(attributes.get_history(u1, "name"), ((), ["refreshed name!"], ())) assert "name" not in attributes.instance_state(u1).committed_state assert u1 in sess.dirty @@ -1153,12 +1250,12 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.query(User).union_all(sess.query(User)).all() - eq_(canary, [('refresh', set(['id', 'name']))]) + eq_(canary, [("refresh", set(["id", "name"]))]) def test_via_refresh_state(self): User = self.classes.User @@ -1167,12 +1264,12 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() u1.name - eq_(canary, [('refresh', set(['id', 'name']))]) + eq_(canary, [("refresh", set(["id", "name"]))]) def test_was_expired(self): User = self.classes.User @@ -1181,13 +1278,13 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() sess.expire(u1) sess.query(User).first() - eq_(canary, [('refresh', set(['id', 'name']))]) + eq_(canary, [("refresh", set(["id", "name"]))]) def test_was_expired_via_commit(self): User = self.classes.User @@ -1196,12 +1293,12 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.query(User).first() - eq_(canary, [('refresh', set(['id', 'name']))]) + eq_(canary, [("refresh", set(["id", "name"]))]) def test_was_expired_attrs(self): User = self.classes.User @@ -1210,13 +1307,13 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() - sess.expire(u1, ['name']) + sess.expire(u1, ["name"]) sess.query(User).first() - eq_(canary, [('refresh', set(['name']))]) + eq_(canary, [("refresh", set(["name"]))]) def test_populate_existing(self): User = self.classes.User @@ -1225,12 +1322,12 @@ class RefreshTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.query(User).populate_existing().first() - eq_(canary, [('refresh', None)]) + eq_(canary, [("refresh", None)]) class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): @@ -1240,7 +1337,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def my_listener(*arg, **kw): pass - event.listen(Session, 'before_flush', my_listener) + event.listen(Session, "before_flush", my_listener) s = Session() assert my_listener in s.dispatch.before_flush @@ -1258,8 +1355,8 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): S1 = sessionmaker() S2 = sessionmaker() - event.listen(Session, 'before_flush', my_listener_one) - event.listen(S1, 'before_flush', my_listener_two) + event.listen(Session, "before_flush", my_listener_one) + event.listen(S1, "before_flush", my_listener_two) s1 = S1() assert my_listener_one in s1.dispatch.before_flush @@ -1281,7 +1378,10 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): sa.exc.ArgumentError, "Session 
event listen on a scoped_session requires that its " "creation callable is associated with the Session class.", - event.listen, scope, "before_flush", my_listener_one + event.listen, + scope, + "before_flush", + my_listener_one, ) def test_scoped_session_invalid_class(self): @@ -1291,7 +1391,6 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): pass class NotASession(object): - def __call__(self): return Session() @@ -1301,7 +1400,10 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): sa.exc.ArgumentError, "Session event listen on a scoped_session requires that its " "creation callable is associated with the Session class.", - event.listen, scope, "before_flush", my_listener_one + event.listen, + scope, + "before_flush", + my_listener_one, ) def test_scoped_session_listen(self): @@ -1321,25 +1423,26 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def listener(name): def go(*arg, **kw): canary.append(name) + return go sess = Session(**kw) for evt in [ - 'after_transaction_create', - 'after_transaction_end', - 'before_commit', - 'after_commit', - 'after_rollback', - 'after_soft_rollback', - 'before_flush', - 'after_flush', - 'after_flush_postexec', - 'after_begin', - 'before_attach', - 'after_attach', - 'after_bulk_update', - 'after_bulk_delete' + "after_transaction_create", + "after_transaction_end", + "before_commit", + "after_commit", + "after_rollback", + "after_soft_rollback", + "before_flush", + "after_flush", + "after_flush_postexec", + "after_begin", + "before_attach", + "after_attach", + "after_bulk_update", + "after_bulk_delete", ]: event.listen(sess, evt, listener(evt)) @@ -1351,18 +1454,26 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): mapper(User, users) sess, canary = self._listener_fixture( - autoflush=False, - autocommit=True, expire_on_commit=False) + autoflush=False, autocommit=True, expire_on_commit=False + ) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() eq_( canary, - ['before_attach', 'after_attach', 'before_flush', - 'after_transaction_create', 'after_begin', - 'after_flush', 'after_flush_postexec', - 'before_commit', 'after_commit', 'after_transaction_end'] + [ + "before_attach", + "after_attach", + "before_flush", + "after_transaction_create", + "after_begin", + "after_flush", + "after_flush_postexec", + "before_commit", + "after_commit", + "after_transaction_end", + ], ) def test_rollback_hook(self): @@ -1370,30 +1481,43 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess, canary = self._listener_fixture() mapper(User, users) - u = User(name='u1', id=1) + u = User(name="u1", id=1) sess.add(u) sess.commit() - u2 = User(name='u1', id=1) + u2 = User(name="u1", id=1) sess.add(u2) - assert_raises( - sa.orm.exc.FlushError, - sess.commit - ) + assert_raises(sa.orm.exc.FlushError, sess.commit) sess.rollback() - eq_(canary, - - ['before_attach', 'after_attach', 'before_commit', 'before_flush', - 'after_transaction_create', 'after_begin', 'after_flush', - 'after_flush_postexec', 'after_transaction_end', 'after_commit', - 'after_transaction_end', 'after_transaction_create', - 'before_attach', 'after_attach', 'before_commit', - 'before_flush', 'after_transaction_create', 'after_begin', - 'after_rollback', - 'after_transaction_end', - 'after_soft_rollback', 'after_transaction_end', - 'after_transaction_create', - 'after_soft_rollback']) + eq_( + canary, + [ + "before_attach", + "after_attach", + "before_commit", + "before_flush", + "after_transaction_create", 
+ "after_begin", + "after_flush", + "after_flush_postexec", + "after_transaction_end", + "after_commit", + "after_transaction_end", + "after_transaction_create", + "before_attach", + "after_attach", + "before_commit", + "before_flush", + "after_transaction_create", + "after_begin", + "after_rollback", + "after_transaction_end", + "after_soft_rollback", + "after_transaction_end", + "after_transaction_create", + "after_soft_rollback", + ], + ) def test_can_use_session_in_outer_rollback_hook(self): User, users = self.classes.User, self.tables.users @@ -1406,19 +1530,16 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): @event.listens_for(sess, "after_soft_rollback") def do_something(session, previous_transaction): if session.is_active: - assertions.append('name' not in u.__dict__) - assertions.append(u.name == 'u1') + assertions.append("name" not in u.__dict__) + assertions.append(u.name == "u1") - u = User(name='u1', id=1) + u = User(name="u1", id=1) sess.add(u) sess.commit() - u2 = User(name='u1', id=1) + u2 = User(name="u1", id=1) sess.add(u2) - assert_raises( - sa.orm.exc.FlushError, - sess.commit - ) + assert_raises(sa.orm.exc.FlushError, sess.commit) sess.rollback() eq_(assertions, [True, True]) @@ -1429,13 +1550,22 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): mapper(User, users) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() - eq_(canary, ['before_attach', 'after_attach', 'before_flush', - 'after_transaction_create', 'after_begin', - 'after_flush', 'after_flush_postexec', - 'after_transaction_end']) + eq_( + canary, + [ + "before_attach", + "after_attach", + "before_flush", + "after_transaction_create", + "after_begin", + "after_flush", + "after_flush_postexec", + "after_transaction_end", + ], + ) def test_flush_in_commit_hook(self): User, users = self.classes.User, self.tables.users @@ -1443,19 +1573,27 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess, canary = self._listener_fixture() mapper(User, users) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() canary[:] = [] - u.name = 'ed' + u.name = "ed" sess.commit() - eq_(canary, ['before_commit', 'before_flush', - 'after_transaction_create', 'after_flush', - 'after_flush_postexec', - 'after_transaction_end', - 'after_commit', - 'after_transaction_end', 'after_transaction_create', ]) + eq_( + canary, + [ + "before_commit", + "before_flush", + "after_transaction_create", + "after_flush", + "after_flush_postexec", + "after_transaction_end", + "after_commit", + "after_transaction_end", + "after_transaction_create", + ], + ) def test_state_before_attach(self): User, users = self.classes.User, self.tables.users @@ -1470,7 +1608,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): assert inst not in session.new mapper(User, users) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() sess.expunge(u) @@ -1489,7 +1627,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): assert inst in session.new mapper(User, users) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() sess.expunge(u) @@ -1498,9 +1636,15 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_standalone_on_commit_hook(self): sess, canary = self._listener_fixture() sess.commit() - eq_(canary, ['before_commit', 'after_commit', - 'after_transaction_end', - 'after_transaction_create']) + eq_( + canary, + [ + "before_commit", + "after_commit", + "after_transaction_end", + 
"after_transaction_create", + ], + ) def test_on_bulk_update_hook(self): User, users = self.classes.User, self.tables.users @@ -1513,29 +1657,21 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def legacy(ses, qry, ctx, res): canary.after_bulk_update_legacy(ses, qry, ctx, res) + event.listen(sess, "after_bulk_update", legacy) mapper(User, users) - sess.query(User).update({'name': 'foo'}) + sess.query(User).update({"name": "foo"}) - eq_( - canary.after_begin.call_count, - 1 - ) - eq_( - canary.after_bulk_update.call_count, - 1 - ) + eq_(canary.after_begin.call_count, 1) + eq_(canary.after_bulk_update.call_count, 1) upd = canary.after_bulk_update.mock_calls[0][1][0] - eq_( - upd.session, - sess - ) + eq_(upd.session, sess) eq_( canary.after_bulk_update_legacy.mock_calls, - [call(sess, upd.query, upd.context, upd.result)] + [call(sess, upd.query, upd.context, upd.result)], ) def test_on_bulk_delete_hook(self): @@ -1549,35 +1685,27 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def legacy(ses, qry, ctx, res): canary.after_bulk_delete_legacy(ses, qry, ctx, res) + event.listen(sess, "after_bulk_delete", legacy) mapper(User, users) sess.query(User).delete() - eq_( - canary.after_begin.call_count, - 1 - ) - eq_( - canary.after_bulk_delete.call_count, - 1 - ) + eq_(canary.after_begin.call_count, 1) + eq_(canary.after_bulk_delete.call_count, 1) upd = canary.after_bulk_delete.mock_calls[0][1][0] - eq_( - upd.session, - sess - ) + eq_(upd.session, sess) eq_( canary.after_bulk_delete_legacy.mock_calls, - [call(sess, upd.query, upd.context, upd.result)] + [call(sess, upd.query, upd.context, upd.result)], ) def test_connection_emits_after_begin(self): sess, canary = self._listener_fixture(bind=testing.db) sess.connection() - eq_(canary, ['after_begin']) + eq_(canary, ["after_begin"]) sess.close() def test_reentrant_flush(self): @@ -1589,10 +1717,11 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): session.flush() sess = Session() - event.listen(sess, 'before_flush', before_flush) - sess.add(User(name='foo')) - assert_raises_message(sa.exc.InvalidRequestError, - 'already flushing', sess.flush) + event.listen(sess, "before_flush", before_flush) + sess.add(User(name="foo")) + assert_raises_message( + sa.exc.InvalidRequestError, "already flushing", sess.flush + ) def test_before_flush_affects_flush_plan(self): users, User = self.tables.users, self.classes.User @@ -1602,50 +1731,49 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): def before_flush(session, flush_context, objects): for obj in list(session.new) + list(session.dirty): if isinstance(obj, User): - session.add(User(name='another %s' % obj.name)) + session.add(User(name="another %s" % obj.name)) for obj in list(session.deleted): if isinstance(obj, User): - x = session.query(User).filter( - User.name == 'another %s' % obj.name).one() + x = ( + session.query(User) + .filter(User.name == "another %s" % obj.name) + .one() + ) session.delete(x) sess = Session() - event.listen(sess, 'before_flush', before_flush) + event.listen(sess, "before_flush", before_flush) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() - eq_(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - User(name='u1') - ] + eq_( + sess.query(User).order_by(User.name).all(), + [User(name="another u1"), User(name="u1")], ) sess.flush() - eq_(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - User(name='u1') - ] + eq_( + 
sess.query(User).order_by(User.name).all(), + [User(name="another u1"), User(name="u1")], ) - u.name = 'u2' + u.name = "u2" sess.flush() - eq_(sess.query(User).order_by(User.name).all(), + eq_( + sess.query(User).order_by(User.name).all(), [ - User(name='another u1'), - User(name='another u2'), - User(name='u2') - ] + User(name="another u1"), + User(name="another u2"), + User(name="u2"), + ], ) sess.delete(u) sess.flush() - eq_(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - ] + eq_( + sess.query(User).order_by(User.name).all(), + [User(name="another u1")], ) def test_before_flush_affects_dirty(self): @@ -1658,23 +1786,19 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): obj.name += " modified" sess = Session(autoflush=True) - event.listen(sess, 'before_flush', before_flush) + event.listen(sess, "before_flush", before_flush) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() - eq_(sess.query(User).order_by(User.name).all(), - [User(name='u1')] - ) + eq_(sess.query(User).order_by(User.name).all(), [User(name="u1")]) - sess.add(User(name='u2')) + sess.add(User(name="u2")) sess.flush() sess.expunge_all() - eq_(sess.query(User).order_by(User.name).all(), - [ - User(name='u1 modified'), - User(name='u2') - ] + eq_( + sess.query(User).order_by(User.name).all(), + [User(name="u1 modified"), User(name="u2")], ) def test_snapshot_still_present_after_commit(self): @@ -1684,7 +1808,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -1693,11 +1817,11 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): @event.listens_for(sess, "after_commit") def assert_state(session): - assert 'name' in u1.__dict__ - eq_(u1.name, 'u1') + assert "name" in u1.__dict__ + eq_(u1.name, "u1") sess.commit() - assert 'name' not in u1.__dict__ + assert "name" not in u1.__dict__ def test_snapshot_still_present_after_rollback(self): users, User = self.tables.users, self.classes.User @@ -1706,7 +1830,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -1715,11 +1839,11 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): @event.listens_for(sess, "after_rollback") def assert_state(session): - assert 'name' in u1.__dict__ - eq_(u1.name, 'u1') + assert "name" in u1.__dict__ + eq_(u1.name, "u1") sess.rollback() - assert 'name' not in u1.__dict__ + assert "name" not in u1.__dict__ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): @@ -1730,10 +1854,15 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): if include_address: addresses, Address = self.tables.addresses, self.classes.Address - mapper(User, users, properties={ - "addresses": relationship( - Address, cascade="all, delete-orphan") - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, cascade="all, delete-orphan" + ) + }, + ) mapper(Address, addresses) else: mapper(User, users) @@ -1744,30 +1873,39 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): def start_events(): event.listen( - sess, "transient_to_pending", listener.transient_to_pending) + sess, "transient_to_pending", listener.transient_to_pending + ) event.listen( - sess, "pending_to_transient", listener.pending_to_transient) + sess, "pending_to_transient", listener.pending_to_transient + ) 
event.listen( - sess, "persistent_to_transient", - listener.persistent_to_transient) + sess, + "persistent_to_transient", + listener.persistent_to_transient, + ) event.listen( - sess, "pending_to_persistent", listener.pending_to_persistent) + sess, "pending_to_persistent", listener.pending_to_persistent + ) event.listen( - sess, "detached_to_persistent", - listener.detached_to_persistent) + sess, "detached_to_persistent", listener.detached_to_persistent + ) event.listen( - sess, "loaded_as_persistent", listener.loaded_as_persistent) + sess, "loaded_as_persistent", listener.loaded_as_persistent + ) event.listen( - sess, "persistent_to_detached", - listener.persistent_to_detached) + sess, "persistent_to_detached", listener.persistent_to_detached + ) event.listen( - sess, "deleted_to_detached", listener.deleted_to_detached) + sess, "deleted_to_detached", listener.deleted_to_detached + ) event.listen( - sess, "persistent_to_deleted", listener.persistent_to_deleted) + sess, "persistent_to_deleted", listener.persistent_to_deleted + ) event.listen( - sess, "deleted_to_persistent", listener.deleted_to_persistent) + sess, "deleted_to_persistent", listener.deleted_to_persistent + ) return listener if include_address: @@ -1785,21 +1923,18 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): assert instance in session listener.flag_checked(instance) - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) eq_( listener.mock_calls, - [ - call.transient_to_pending(sess, u1), - call.flag_checked(u1) - ] + [call.transient_to_pending(sess, u1), call.flag_checked(u1)], ) def test_pending_to_transient_via_rollback(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) listener = start_events() @@ -1814,16 +1949,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.pending_to_transient(sess, u1), - call.flag_checked(u1) - ] + [call.pending_to_transient(sess, u1), call.flag_checked(u1)], ) def test_pending_to_transient_via_expunge(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) listener = start_events() @@ -1838,16 +1970,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.pending_to_transient(sess, u1), - call.flag_checked(u1) - ] + [call.pending_to_transient(sess, u1), call.flag_checked(u1)], ) def test_pending_to_persistent(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) listener = start_events() @@ -1863,10 +1992,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.pending_to_persistent(sess, u1), - call.flag_checked(u1) - ] + [call.pending_to_persistent(sess, u1), call.flag_checked(u1)], ) def test_pending_to_persistent_del(self): @@ -1879,7 +2005,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): # we have a strong ref internally is_not_(None, instance) - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) u1_inst_state = u1._sa_instance_state @@ -1895,15 +2021,14 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): listener.mock_calls, [ call.flag_checked(u1_inst_state.obj()), - call.pending_to_persistent( - sess, u1_inst_state.obj()), - ] + call.pending_to_persistent(sess, u1_inst_state.obj()), + ], ) def test_persistent_to_deleted_del(self): sess, User, 
start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() @@ -1926,14 +2051,14 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): listener.mock_calls, [ call.persistent_to_deleted(sess, u1_inst_state.obj()), - call.flag_checked(u1_inst_state.obj()) - ] + call.flag_checked(u1_inst_state.obj()), + ], ) def test_detached_to_persistent(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() @@ -1951,16 +2076,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.detached_to_persistent(sess, u1), - call.flag_checked() - ] + [call.detached_to_persistent(sess, u1), call.flag_checked()], ) def test_loaded_as_persistent(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.close() @@ -1977,20 +2099,17 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): assert instance._sa_instance_state.persistent listener.flag_checked(instance) - u1 = sess.query(User).filter_by(name='u1').one() + u1 = sess.query(User).filter_by(name="u1").one() eq_( listener.mock_calls, - [ - call.loaded_as_persistent(sess, u1), - call.flag_checked(u1) - ] + [call.loaded_as_persistent(sess, u1), call.flag_checked(u1)], ) def test_detached_to_persistent_via_deleted(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() sess.close() @@ -2020,10 +2139,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.detached_to_persistent(sess, u1), - call.dtp_flag_checked(u1) - ] + [call.detached_to_persistent(sess, u1), call.dtp_flag_checked(u1)], ) sess.flush() @@ -2035,15 +2151,15 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): call.dtp_flag_checked(u1), call.persistent_to_deleted(sess, u1), call.ptd_flag_checked(u1), - ] + ], ) def test_detached_to_persistent_via_cascaded_delete(self): sess, User, Address, start_events = self._fixture(include_address=True) - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) - a1 = Address(email_address='e1') + a1 = Address(email_address="e1") u1.addresses.append(a1) sess.commit() u1.addresses # ensure u1.addresses refers to a1 before detachment @@ -2071,7 +2187,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): call.flag_checked(u1), call.detached_to_persistent(sess, a1), call.flag_checked(a1), - ] + ], ) sess.flush() @@ -2079,7 +2195,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): def test_persistent_to_deleted(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -2097,26 +2213,20 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess.delete(u1) assert u1 in sess.deleted - eq_( - listener.mock_calls, - [] - ) + eq_(listener.mock_calls, []) sess.flush() assert u1 not in sess eq_( listener.mock_calls, - [ - call.persistent_to_deleted(sess, u1), - call.flag_checked(u1) - ] + [call.persistent_to_deleted(sess, u1), call.flag_checked(u1)], ) def test_persistent_to_detached_via_expunge(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() @@ -2137,16 +2247,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, 
_fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.persistent_to_detached(sess, u1), - call.flag_checked(u1) - ] + [call.persistent_to_detached(sess, u1), call.flag_checked(u1)], ) def test_persistent_to_detached_via_expunge_all(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() @@ -2167,16 +2274,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.persistent_to_detached(sess, u1), - call.flag_checked(u1) - ] + [call.persistent_to_detached(sess, u1), call.flag_checked(u1)], ) def test_persistent_to_transient_via_rollback(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() @@ -2196,16 +2300,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.persistent_to_transient(sess, u1), - call.flag_checked(u1) - ] + [call.persistent_to_transient(sess, u1), call.flag_checked(u1)], ) def test_deleted_to_persistent_via_rollback(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -2237,16 +2338,13 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.deleted_to_persistent(sess, u1), - call.flag_checked(u1) - ] + [call.deleted_to_persistent(sess, u1), call.flag_checked(u1)], ) def test_deleted_to_detached_via_commit(self): sess, User, start_events = self._fixture() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -2276,10 +2374,7 @@ class SessionLifecycleEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_( listener.mock_calls, - [ - call.deleted_to_detached(sess, u1), - call.flag_checked(u1) - ] + [call.deleted_to_detached(sess, u1), call.flag_checked(u1)], ) @@ -2294,47 +2389,48 @@ class MapperExtensionTest(_fixtures.FixtureTest): methods = [] class Ext(sa.orm.MapperExtension): - def instrument_class(self, mapper, cls): - methods.append('instrument_class') + methods.append("instrument_class") return sa.orm.EXT_CONTINUE def init_instance( - self, mapper, class_, oldinit, instance, args, kwargs): - methods.append('init_instance') + self, mapper, class_, oldinit, instance, args, kwargs + ): + methods.append("init_instance") return sa.orm.EXT_CONTINUE def init_failed( - self, mapper, class_, oldinit, instance, args, kwargs): - methods.append('init_failed') + self, mapper, class_, oldinit, instance, args, kwargs + ): + methods.append("init_failed") return sa.orm.EXT_CONTINUE def reconstruct_instance(self, mapper, instance): - methods.append('reconstruct_instance') + methods.append("reconstruct_instance") return sa.orm.EXT_CONTINUE def before_insert(self, mapper, connection, instance): - methods.append('before_insert') + methods.append("before_insert") return sa.orm.EXT_CONTINUE def after_insert(self, mapper, connection, instance): - methods.append('after_insert') + methods.append("after_insert") return sa.orm.EXT_CONTINUE def before_update(self, mapper, connection, instance): - methods.append('before_update') + methods.append("before_update") return sa.orm.EXT_CONTINUE def after_update(self, mapper, connection, instance): - methods.append('after_update') + methods.append("after_update") return sa.orm.EXT_CONTINUE def before_delete(self, mapper, connection, instance): - methods.append('before_delete') + methods.append("before_delete") return sa.orm.EXT_CONTINUE def 
after_delete(self, mapper, connection, instance): - methods.append('after_delete') + methods.append("after_delete") return sa.orm.EXT_CONTINUE return Ext, methods @@ -2348,26 +2444,37 @@ class MapperExtensionTest(_fixtures.FixtureTest): mapper(User, users, extension=Ext()) sess = create_session() - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() u = sess.query(User).populate_existing().get(u.id) sess.expunge_all() u = sess.query(User).get(u.id) - u.name = 'u1 changed' + u.name = "u1 changed" sess.flush() sess.delete(u) sess.flush() - eq_(methods, - ['instrument_class', 'init_instance', 'before_insert', - 'after_insert', - 'reconstruct_instance', - 'before_update', 'after_update', 'before_delete', 'after_delete']) + eq_( + methods, + [ + "instrument_class", + "init_instance", + "before_insert", + "after_insert", + "reconstruct_instance", + "before_update", + "after_update", + "before_delete", + "after_delete", + ], + ) def test_inheritance(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) Ext, methods = self.extension() @@ -2375,26 +2482,39 @@ class MapperExtensionTest(_fixtures.FixtureTest): pass mapper(User, users, extension=Ext()) - mapper(AdminUser, addresses, inherits=User, - properties={'address_id': addresses.c.id}) + mapper( + AdminUser, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) sess = create_session() - am = AdminUser(name='au1', email_address='au1@e1') + am = AdminUser(name="au1", email_address="au1@e1") sess.add(am) sess.flush() am = sess.query(AdminUser).populate_existing().get(am.id) sess.expunge_all() am = sess.query(AdminUser).get(am.id) - am.name = 'au1 changed' + am.name = "au1 changed" sess.flush() sess.delete(am) sess.flush() - eq_(methods, - ['instrument_class', 'instrument_class', 'init_instance', - 'before_insert', 'after_insert', - 'reconstruct_instance', - 'before_update', 'after_update', 'before_delete', - 'after_delete']) + eq_( + methods, + [ + "instrument_class", + "instrument_class", + "init_instance", + "before_insert", + "after_insert", + "reconstruct_instance", + "before_update", + "after_update", + "before_delete", + "after_delete", + ], + ) def test_before_after_only_collection(self): """before_update is called on parent for collection modifications, @@ -2407,13 +2527,20 @@ class MapperExtensionTest(_fixtures.FixtureTest): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) Ext1, methods1 = self.extension() Ext2, methods2 = self.extension() - mapper(Item, items, extension=Ext1(), properties={ - 'keywords': relationship(Keyword, secondary=item_keywords)}) + mapper( + Item, + items, + extension=Ext1(), + properties={ + "keywords": relationship(Keyword, secondary=item_keywords) + }, + ) mapper(Keyword, keywords, extension=Ext2()) sess = create_session() @@ -2422,26 +2549,40 @@ class MapperExtensionTest(_fixtures.FixtureTest): sess.add(i1) sess.add(k1) sess.flush() - eq_(methods1, - ['instrument_class', 'init_instance', - 'before_insert', 'after_insert']) - eq_(methods2, - ['instrument_class', 'init_instance', - 'before_insert', 'after_insert']) + eq_( + methods1, + [ + "instrument_class", + "init_instance", + "before_insert", + "after_insert", + ], + ) + eq_( + methods2, + [ + "instrument_class", + "init_instance", + "before_insert", + "after_insert", + ], + ) del methods1[:] del methods2[:] 
i1.keywords.append(k1) sess.flush() - eq_(methods1, ['before_update', 'after_update']) + eq_(methods1, ["before_update", "after_update"]) eq_(methods2, []) def test_inheritance_with_dupes(self): """Inheritance with the same extension instance on both mappers.""" - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) Ext, methods = self.extension() @@ -2450,8 +2591,13 @@ class MapperExtensionTest(_fixtures.FixtureTest): ext = Ext() mapper(User, users, extension=ext) - mapper(AdminUser, addresses, inherits=User, extension=ext, - properties={'address_id': addresses.c.id}) + mapper( + AdminUser, + addresses, + inherits=User, + extension=ext, + properties={"address_id": addresses.c.id}, + ) sess = create_session() am = AdminUser(name="au1", email_address="au1@e1") @@ -2460,27 +2606,36 @@ class MapperExtensionTest(_fixtures.FixtureTest): am = sess.query(AdminUser).populate_existing().get(am.id) sess.expunge_all() am = sess.query(AdminUser).get(am.id) - am.name = 'au1 changed' + am.name = "au1 changed" sess.flush() sess.delete(am) sess.flush() - eq_(methods, - ['instrument_class', 'instrument_class', 'init_instance', - 'before_insert', 'after_insert', - 'reconstruct_instance', - 'before_update', 'after_update', 'before_delete', - 'after_delete']) + eq_( + methods, + [ + "instrument_class", + "instrument_class", + "init_instance", + "before_insert", + "after_insert", + "reconstruct_instance", + "before_update", + "after_update", + "before_delete", + "after_delete", + ], + ) def test_unnecessary_methods_not_evented(self): users = self.tables.users class MyExtension(sa.orm.MapperExtension): - def before_insert(self, mapper, connection, instance): pass class Foo(object): pass + m = mapper(Foo, users, extension=MyExtension()) assert not m.class_manager.dispatch.load assert not m.dispatch.before_update @@ -2488,16 +2643,15 @@ class MapperExtensionTest(_fixtures.FixtureTest): class AttributeExtensionTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('t1', - metadata, - Column('id', Integer, primary_key=True), - Column('type', String(40)), - Column('data', String(50)) - - ) + Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("type", String(40)), + Column("data", String(50)), + ) def test_cascading_extensions(self): t1 = self.tables.t1 @@ -2505,13 +2659,11 @@ class AttributeExtensionTest(fixtures.MappedTest): ext_msg = [] class Ex1(sa.orm.AttributeExtension): - def set(self, state, value, oldvalue, initiator): ext_msg.append("Ex1 %r" % value) return "ex1" + value class Ex2(sa.orm.AttributeExtension): - def set(self, state, value, oldvalue, initiator): ext_msg.append("Ex2 %r" % value) return "ex2" + value @@ -2526,33 +2678,46 @@ class AttributeExtensionTest(fixtures.MappedTest): pass mapper( - A, t1, polymorphic_on=t1.c.type, polymorphic_identity='a', - properties={ - 'data': column_property(t1.c.data, extension=Ex1()) - } + A, + t1, + polymorphic_on=t1.c.type, + polymorphic_identity="a", + properties={"data": column_property(t1.c.data, extension=Ex1())}, + ) + mapper(B, polymorphic_identity="b", inherits=A) + mapper( + C, + polymorphic_identity="c", + inherits=B, + properties={"data": column_property(t1.c.data, extension=Ex2())}, ) - mapper(B, polymorphic_identity='b', inherits=A) - mapper(C, polymorphic_identity='c', inherits=B, properties={ - 'data': column_property(t1.c.data, extension=Ex2()) - }) - a1 = 
A(data='a1') - b1 = B(data='b1') - c1 = C(data='c1') + a1 = A(data="a1") + b1 = B(data="b1") + c1 = C(data="c1") - eq_(a1.data, 'ex1a1') - eq_(b1.data, 'ex1b1') - eq_(c1.data, 'ex2c1') + eq_(a1.data, "ex1a1") + eq_(b1.data, "ex1b1") + eq_(c1.data, "ex2c1") - a1.data = 'a2' - b1.data = 'b2' - c1.data = 'c2' - eq_(a1.data, 'ex1a2') - eq_(b1.data, 'ex1b2') - eq_(c1.data, 'ex2c2') + a1.data = "a2" + b1.data = "b2" + c1.data = "c2" + eq_(a1.data, "ex1a2") + eq_(b1.data, "ex1b2") + eq_(c1.data, "ex2c2") - eq_(ext_msg, ["Ex1 'a1'", "Ex1 'b1'", "Ex2 'c1'", - "Ex1 'a2'", "Ex1 'b2'", "Ex2 'c2'"]) + eq_( + ext_msg, + [ + "Ex1 'a1'", + "Ex1 'b1'", + "Ex2 'c1'", + "Ex1 'a2'", + "Ex1 'b2'", + "Ex2 'c2'", + ], + ) class SessionExtensionTest(_fixtures.FixtureTest): @@ -2565,82 +2730,86 @@ class SessionExtensionTest(_fixtures.FixtureTest): log = [] class MyExt(sa.orm.session.SessionExtension): - def before_commit(self, session): - log.append('before_commit') + log.append("before_commit") def after_commit(self, session): - log.append('after_commit') + log.append("after_commit") def after_rollback(self, session): - log.append('after_rollback') + log.append("after_rollback") def before_flush(self, session, flush_context, objects): - log.append('before_flush') + log.append("before_flush") def after_flush(self, session, flush_context): - log.append('after_flush') + log.append("after_flush") def after_flush_postexec(self, session, flush_context): - log.append('after_flush_postexec') + log.append("after_flush_postexec") def after_begin(self, session, transaction, connection): - log.append('after_begin') + log.append("after_begin") def after_attach(self, session, instance): - log.append('after_attach') + log.append("after_attach") - def after_bulk_update( - self, - session, query, query_context, result - ): - log.append('after_bulk_update') + def after_bulk_update(self, session, query, query_context, result): + log.append("after_bulk_update") - def after_bulk_delete( - self, - session, query, query_context, result - ): - log.append('after_bulk_delete') + def after_bulk_delete(self, session, query, query_context, result): + log.append("after_bulk_delete") sess = create_session(extension=MyExt()) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() assert log == [ - 'after_attach', - 'before_flush', - 'after_begin', - 'after_flush', - 'after_flush_postexec', - 'before_commit', - 'after_commit', + "after_attach", + "before_flush", + "after_begin", + "after_flush", + "after_flush_postexec", + "before_commit", + "after_commit", ] log = [] sess = create_session(autocommit=False, extension=MyExt()) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() - assert log == ['after_attach', 'before_flush', 'after_begin', - 'after_flush', 'after_flush_postexec'] + assert log == [ + "after_attach", + "before_flush", + "after_begin", + "after_flush", + "after_flush_postexec", + ] log = [] - u.name = 'ed' + u.name = "ed" sess.commit() - assert log == ['before_commit', 'before_flush', 'after_flush', - 'after_flush_postexec', 'after_commit'] + assert log == [ + "before_commit", + "before_flush", + "after_flush", + "after_flush_postexec", + "after_commit", + ] log = [] sess.commit() - assert log == ['before_commit', 'after_commit'] + assert log == ["before_commit", "after_commit"] log = [] sess.query(User).delete() - assert log == ['after_begin', 'after_bulk_delete'] + assert log == ["after_begin", "after_bulk_delete"] log = [] - sess.query(User).update({'name': 'foo'}) - assert log == 
['after_bulk_update'] + sess.query(User).update({"name": "foo"}) + assert log == ["after_bulk_update"] log = [] - sess = create_session(autocommit=False, extension=MyExt(), - bind=testing.db) + sess = create_session( + autocommit=False, extension=MyExt(), bind=testing.db + ) sess.connection() - assert log == ['after_begin'] + assert log == ["after_begin"] sess.close() def test_multiple_extensions(self): @@ -2649,28 +2818,22 @@ class SessionExtensionTest(_fixtures.FixtureTest): log = [] class MyExt1(sa.orm.session.SessionExtension): - def before_commit(self, session): - log.append('before_commit_one') + log.append("before_commit_one") class MyExt2(sa.orm.session.SessionExtension): - def before_commit(self, session): - log.append('before_commit_two') + log.append("before_commit_two") mapper(User, users) sess = create_session(extension=[MyExt1(), MyExt2()]) - u = User(name='u1') + u = User(name="u1") sess.add(u) sess.flush() - assert log == [ - 'before_commit_one', - 'before_commit_two', - ] + assert log == ["before_commit_one", "before_commit_two"] def test_unnecessary_methods_not_evented(self): class MyExtension(sa.orm.session.SessionExtension): - def before_commit(self, session): pass @@ -2680,8 +2843,9 @@ class SessionExtensionTest(_fixtures.FixtureTest): class QueryEventsTest( - _RemoveListeners, _fixtures.FixtureTest, AssertsCompiledSQL): - __dialect__ = 'default' + _RemoveListeners, _fixtures.FixtureTest, AssertsCompiledSQL +): + __dialect__ = "default" @classmethod def setup_mappers(cls): @@ -2694,8 +2858,8 @@ class QueryEventsTest( @event.listens_for(query.Query, "before_compile", retval=True) def no_deleted(query): for desc in query.column_descriptions: - if desc['type'] is User: - entity = desc['expr'] + if desc["type"] is User: + entity = desc["expr"] query = query.filter(entity.id != 10) return query @@ -2708,7 +2872,7 @@ class QueryEventsTest( "SELECT users.id AS users_id, users.name AS users_name " "FROM users " "WHERE users.id = :id_1 AND users.id != :id_2", - checkparams={'id_2': 10, 'id_1': 7} + checkparams={"id_2": 10, "id_1": 7}, ) def test_alters_entities(self): @@ -2720,18 +2884,15 @@ class QueryEventsTest( s = Session() - q = s.query(User.id, ).filter_by(id=7) + q = s.query(User.id).filter_by(id=7) self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name " "FROM users " "WHERE users.id = :id_1", - checkparams={'id_1': 7} - ) - eq_( - q.all(), - [(7, 'jack')] + checkparams={"id_1": 7}, ) + eq_(q.all(), [(7, "jack")]) class RefreshFlushInReturningTest(fixtures.MappedTest): @@ -2749,11 +2910,13 @@ class RefreshFlushInReturningTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'test', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('prefetch_val', Integer, default=5), - Column('returning_val', Integer, server_default="5") + "test", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("prefetch_val", Integer, default=5), + Column("returning_val", Integer, server_default="5"), ) @classmethod @@ -2783,13 +2946,10 @@ class RefreshFlushInReturningTest(fixtures.MappedTest): # then we'd have hash order issues. 
eq_( mock.mock_calls, - [call(t1, ANY, ['returning_val', 'prefetch_val'])] + [call(t1, ANY, ["returning_val", "prefetch_val"])], ) else: - eq_( - mock.mock_calls, - [call(t1, ANY, ['prefetch_val'])] - ) + eq_(mock.mock_calls, [call(t1, ANY, ["prefetch_val"])]) eq_(t1.id, 1) eq_(t1.prefetch_val, 5) diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index e7ef20e7b5..875ebd31de 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -7,9 +7,21 @@ from sqlalchemy import testing from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc, FetchedValue from sqlalchemy.testing.schema import Table from sqlalchemy.testing.schema import Column -from sqlalchemy.orm import mapper, relationship, create_session, \ - attributes, deferred, exc as orm_exc, defer, undefer,\ - strategies, state, lazyload, backref, Session +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + attributes, + deferred, + exc as orm_exc, + defer, + undefer, + strategies, + state, + lazyload, + backref, + Session, +) from sqlalchemy.testing import fixtures from test.orm import _fixtures from sqlalchemy.sql import select @@ -17,48 +29,53 @@ from sqlalchemy.orm import make_transient_to_detached class ExpireTest(_fixtures.FixtureTest): - def test_expire(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(7) assert len(u.addresses) == 1 - u.name = 'foo' + u.name = "foo" del u.addresses[0] sess.expire(u) - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ def go(): - assert u.name == 'jack' + assert u.name == "jack" + self.assert_sql_count(testing.db, go, 1) - assert 'name' in u.__dict__ + assert "name" in u.__dict__ - u.name = 'foo' + u.name = "foo" sess.flush() # change the value in the DB - users.update(users.c.id == 7, values=dict(name='jack')).execute() + users.update(users.c.id == 7, values=dict(name="jack")).execute() sess.expire(u) # object isn't refreshed yet, using dict to bypass trigger - assert u.__dict__.get('name') != 'jack' - assert 'name' in attributes.instance_state(u).expired_attributes + assert u.__dict__.get("name") != "jack" + assert "name" in attributes.instance_state(u).expired_attributes sess.query(User).all() # test that it refreshed - assert u.__dict__['name'] == 'jack' - assert 'name' not in attributes.instance_state(u).expired_attributes + assert u.__dict__["name"] == "jack" + assert "name" not in attributes.instance_state(u).expired_attributes def go(): - assert u.name == 'jack' + assert u.name == "jack" + self.assert_sql_count(testing.db, go, 0) def test_persistence_check(self): @@ -69,9 +86,12 @@ class ExpireTest(_fixtures.FixtureTest): u = s.query(User).get(7) s.expunge_all() - assert_raises_message(sa_exc.InvalidRequestError, - r"is not persistent within this Session", - s.expire, u) + assert_raises_message( + sa_exc.InvalidRequestError, + r"is not persistent within this Session", + s.expire, + u, + ) def test_get_refreshes(self): users, User = self.tables.users, self.classes.User @@ -83,14 +103,17 @@ class ExpireTest(_fixtures.FixtureTest): def go(): u = 
s.query(User).get(10) # get() refreshes + self.assert_sql_count(testing.db, go, 1) def go(): - eq_(u.name, 'chuck') # attributes unexpired + eq_(u.name, "chuck") # attributes unexpired + self.assert_sql_count(testing.db, go, 0) def go(): u = s.query(User).get(10) # expire flag reset, so not expired + self.assert_sql_count(testing.db, go, 0) def test_get_on_deleted_expunges(self): @@ -124,7 +147,9 @@ class ExpireTest(_fixtures.FixtureTest): sa.orm.exc.ObjectDeletedError, "Instance '' has been " "deleted, or its row is otherwise not present.", - getattr, u, 'name' + getattr, + u, + "name", ) def test_rollback_undoes_expunge_from_deleted(self): @@ -145,7 +170,7 @@ class ExpireTest(_fixtures.FixtureTest): assert u in s # but now its back, rollback has occurred, the # _remove_newly_deleted is reverted - eq_(u.name, 'chuck') + eq_(u.name, "chuck") def test_deferred(self): """test that unloaded, deferred attributes aren't included in the @@ -153,86 +178,116 @@ class ExpireTest(_fixtures.FixtureTest): Order, orders = self.classes.Order, self.tables.orders - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) s = create_session() o1 = s.query(Order).first() - assert 'description' not in o1.__dict__ + assert "description" not in o1.__dict__ s.expire(o1) assert o1.isopen is not None - assert 'description' not in o1.__dict__ + assert "description" not in o1.__dict__ assert o1.description def test_deferred_notfound(self): users, User = self.tables.users, self.classes.User - mapper(User, users, properties={ - 'name': deferred(users.c.name) - }) + mapper(User, users, properties={"name": deferred(users.c.name)}) s = create_session(autocommit=False) u = s.query(User).get(10) - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ s.execute(users.delete().where(User.id == 10)) assert_raises_message( sa.orm.exc.ObjectDeletedError, "Instance '' has been " "deleted, or its row is otherwise not present.", - getattr, u, 'name' + getattr, + u, + "name", ) def test_lazyload_autoflushes(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, - order_by=addresses.c.email_address) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, order_by=addresses.c.email_address + ) + }, + ) mapper(Address, addresses) s = create_session(autoflush=True, autocommit=False) u = s.query(User).get(8) adlist = u.addresses - eq_(adlist, [ - Address(email_address='ed@bettyboop.com'), - Address(email_address='ed@lala.com'), - Address(email_address='ed@wood.com'), - ]) + eq_( + adlist, + [ + Address(email_address="ed@bettyboop.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@wood.com"), + ], + ) a1 = u.addresses[2] - a1.email_address = 'aaaaa' - s.expire(u, ['addresses']) - eq_(u.addresses, [ - Address(email_address='aaaaa'), - Address(email_address='ed@bettyboop.com'), - Address(email_address='ed@lala.com'), - ]) + a1.email_address = "aaaaa" + s.expire(u, ["addresses"]) + eq_( + u.addresses, + [ + Address(email_address="aaaaa"), + Address(email_address="ed@bettyboop.com"), + Address(email_address="ed@lala.com"), + ], + ) def 
test_refresh_collection_exception(self): """test graceful failure for currently unsupported immediate refresh of a collection""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, - order_by=addresses.c.email_address) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, order_by=addresses.c.email_address + ) + }, + ) mapper(Address, addresses) s = create_session(autoflush=True, autocommit=False) u = s.query(User).get(8) - assert_raises_message(sa_exc.InvalidRequestError, - "properties specified for refresh", - s.refresh, u, ['addresses']) + assert_raises_message( + sa_exc.InvalidRequestError, + "properties specified for refresh", + s.refresh, + u, + ["addresses"], + ) # in contrast to a regular query with no columns - assert_raises_message(sa_exc.InvalidRequestError, - "no columns with which to SELECT", s.query().all) + assert_raises_message( + sa_exc.InvalidRequestError, + "no columns with which to SELECT", + s.query().all, + ) def test_refresh_cancels_expire(self): users, User = self.tables.users, self.classes.User @@ -245,7 +300,8 @@ class ExpireTest(_fixtures.FixtureTest): def go(): u = s.query(User).get(7) - eq_(u.name, 'jack') + eq_(u.name, "jack") + self.assert_sql_count(testing.db, go, 0) def test_expire_doesntload_on_set(self): @@ -256,14 +312,15 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = sess.query(User).get(7) - sess.expire(u, attribute_names=['name']) + sess.expire(u, attribute_names=["name"]) def go(): - u.name = 'somenewname' + u.name = "somenewname" + self.assert_sql_count(testing.db, go, 0) sess.flush() sess.expunge_all() - assert sess.query(User).get(7).name == 'somenewname' + assert sess.query(User).get(7).name == "somenewname" def test_no_session(self): users, User = self.tables.users, self.classes.User @@ -272,9 +329,9 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = sess.query(User).get(7) - sess.expire(u, attribute_names=['name']) + sess.expire(u, attribute_names=["name"]) sess.expunge(u) - assert_raises(orm_exc.DetachedInstanceError, getattr, u, 'name') + assert_raises(orm_exc.DetachedInstanceError, getattr, u, "name") def test_pending_raises(self): users, User = self.tables.users, self.classes.User @@ -285,7 +342,7 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = User(id=15) sess.add(u) - assert_raises(sa_exc.InvalidRequestError, sess.expire, u, ['name']) + assert_raises(sa_exc.InvalidRequestError, sess.expire, u, ["name"]) def test_no_instance_key(self): User, users = self.classes.User, self.tables.users @@ -298,12 +355,12 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = sess.query(User).get(7) - sess.expire(u, attribute_names=['name']) + sess.expire(u, attribute_names=["name"]) sess.expunge(u) attributes.instance_state(u).key = None - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ sess.add(u) - assert u.name == 'jack' + assert u.name == "jack" def test_no_instance_key_no_pk(self): users, User = self.tables.users, self.classes.User @@ -314,12 +371,12 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = sess.query(User).get(7) - sess.expire(u, attribute_names=['name', 'id']) + sess.expire(u, attribute_names=["name", "id"]) 
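For reference, the partial-expiration API exercised throughout these tests can be reproduced standalone. A minimal runnable sketch (editorial, not part of this patch; the User model and the in-memory SQLite engine are illustrative stand-ins):

    from sqlalchemy import Column, Integer, String, create_engine, inspect
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(User(id=7, name="jack"))
    session.commit()

    u = session.query(User).get(7)
    session.expire(u, ["name"])      # expire just this one attribute
    assert "name" not in u.__dict__  # the value is gone from the dict
    assert "name" in inspect(u).expired_attributes
    assert u.name == "jack"          # first access emits a SELECT
    assert "name" not in inspect(u).expired_attributes

Expiring with attribute_names marks only the named attributes; everything else on the instance stays loaded, which is the distinction the surrounding tests rely on.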
sess.expunge(u) attributes.instance_state(u).key = None - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ sess.add(u) - assert_raises(sa_exc.InvalidRequestError, getattr, u, 'name') + assert_raises(sa_exc.InvalidRequestError, getattr, u, "name") def test_expire_preserves_changes(self): """test that the expire load operation doesn't revert post-expire @@ -336,12 +393,13 @@ class ExpireTest(_fixtures.FixtureTest): def go(): assert o.isopen == 1 + self.assert_sql_count(testing.db, go, 1) - assert o.description == 'order 3 modified' + assert o.description == "order 3 modified" del o.description assert "description" not in o.__dict__ - sess.expire(o, ['isopen']) + sess.expire(o, ["isopen"]) sess.query(Order).all() assert o.isopen == 1 assert "description" not in o.__dict__ @@ -349,26 +407,27 @@ class ExpireTest(_fixtures.FixtureTest): assert o.description is None o.isopen = 15 - sess.expire(o, ['isopen', 'description']) - o.description = 'some new description' + sess.expire(o, ["isopen", "description"]) + o.description = "some new description" sess.query(Order).all() assert o.isopen == 1 - assert o.description == 'some new description' + assert o.description == "some new description" - sess.expire(o, ['isopen', 'description']) + sess.expire(o, ["isopen", "description"]) sess.query(Order).all() del o.isopen def go(): assert o.isopen is None + self.assert_sql_count(testing.db, go, 0) o.isopen = 14 sess.expire(o) - o.description = 'another new description' + o.description = "another new description" sess.query(Order).all() assert o.isopen == 1 - assert o.description == 'another new description' + assert o.description == "another new description" def test_expire_committed(self): """test that the committed state of the attribute receives the most @@ -382,81 +441,104 @@ class ExpireTest(_fixtures.FixtureTest): o = sess.query(Order).get(3) sess.expire(o) - orders.update().execute(description='order 3 modified') + orders.update().execute(description="order 3 modified") assert o.isopen == 1 - assert attributes.instance_state(o) \ - .dict['description'] == 'order 3 modified' + assert ( + attributes.instance_state(o).dict["description"] + == "order 3 modified" + ) def go(): sess.flush() + self.assert_sql_count(testing.db, go, 0) def test_expire_cascade(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, cascade="all, refresh-expire") - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, cascade="all, refresh-expire" + ) + }, + ) mapper(Address, addresses) s = create_session() u = s.query(User).get(8) - assert u.addresses[0].email_address == 'ed@wood.com' + assert u.addresses[0].email_address == "ed@wood.com" - u.addresses[0].email_address = 'someotheraddress' + u.addresses[0].email_address = "someotheraddress" s.expire(u) - assert u.addresses[0].email_address == 'ed@wood.com' + assert u.addresses[0].email_address == "ed@wood.com" def test_refresh_cascade(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, cascade="all, refresh-expire") - }) + users, Address, addresses, User = ( + self.tables.users, + 
self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, cascade="all, refresh-expire" + ) + }, + ) mapper(Address, addresses) s = create_session() u = s.query(User).get(8) - assert u.addresses[0].email_address == 'ed@wood.com' + assert u.addresses[0].email_address == "ed@wood.com" - u.addresses[0].email_address = 'someotheraddress' + u.addresses[0].email_address = "someotheraddress" s.refresh(u) - assert u.addresses[0].email_address == 'ed@wood.com' + assert u.addresses[0].email_address == "ed@wood.com" def test_expire_cascade_pending_orphan(self): - cascade = 'save-update, refresh-expire, delete, delete-orphan' + cascade = "save-update, refresh-expire, delete, delete-orphan" self._test_cascade_to_pending(cascade, True) def test_refresh_cascade_pending_orphan(self): - cascade = 'save-update, refresh-expire, delete, delete-orphan' + cascade = "save-update, refresh-expire, delete, delete-orphan" self._test_cascade_to_pending(cascade, False) def test_expire_cascade_pending(self): - cascade = 'save-update, refresh-expire' + cascade = "save-update, refresh-expire" self._test_cascade_to_pending(cascade, True) def test_refresh_cascade_pending(self): - cascade = 'save-update, refresh-expire' + cascade = "save-update, refresh-expire" self._test_cascade_to_pending(cascade, False) def _test_cascade_to_pending(self, cascade, expire_or_refresh): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, cascade=cascade) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, cascade=cascade)}, + ) mapper(Address, addresses) s = create_session() u = s.query(User).get(8) - a = Address(id=12, email_address='foobar') + a = Address(id=12, email_address="foobar") u.addresses.append(a) if expire_or_refresh: @@ -472,95 +554,120 @@ class ExpireTest(_fixtures.FixtureTest): s.flush() def test_expired_lazy(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(7) sess.expire(u) - assert 'name' not in u.__dict__ - assert 'addresses' not in u.__dict__ + assert "name" not in u.__dict__ + assert "addresses" not in u.__dict__ def go(): - assert u.addresses[0].email_address == 'jack@bean.com' - assert u.name == 'jack' + assert u.addresses[0].email_address == "jack@bean.com" + assert u.name == "jack" + # two loads self.assert_sql_count(testing.db, go, 2) - assert 'name' in u.__dict__ - assert 'addresses' in u.__dict__ + assert "name" in u.__dict__ + assert "addresses" in u.__dict__ def test_expired_eager(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', lazy='joined'), - }) + users, 
Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", lazy="joined" + ) + }, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(7) sess.expire(u) - assert 'name' not in u.__dict__ - assert 'addresses' not in u.__dict__ + assert "name" not in u.__dict__ + assert "addresses" not in u.__dict__ def go(): - assert u.addresses[0].email_address == 'jack@bean.com' - assert u.name == 'jack' + assert u.addresses[0].email_address == "jack@bean.com" + assert u.name == "jack" + # two loads, since relationship() + scalar are # separate right now on per-attribute load self.assert_sql_count(testing.db, go, 2) - assert 'name' in u.__dict__ - assert 'addresses' in u.__dict__ + assert "name" in u.__dict__ + assert "addresses" in u.__dict__ - sess.expire(u, ['name', 'addresses']) - assert 'name' not in u.__dict__ - assert 'addresses' not in u.__dict__ + sess.expire(u, ["name", "addresses"]) + assert "name" not in u.__dict__ + assert "addresses" not in u.__dict__ def go(): sess.query(User).filter_by(id=7).one() - assert u.addresses[0].email_address == 'jack@bean.com' - assert u.name == 'jack' + assert u.addresses[0].email_address == "jack@bean.com" + assert u.name == "jack" + # one load, since relationship() + scalar are # together when eager load used with Query self.assert_sql_count(testing.db, go, 1) def test_relationship_changes_preserved(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', lazy='joined'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", lazy="joined" + ) + }, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(8) - sess.expire(u, ['name', 'addresses']) + sess.expire(u, ["name", "addresses"]) u.addresses - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ del u.addresses[1] u.name - assert 'name' in u.__dict__ + assert "name" in u.__dict__ assert len(u.addresses) == 2 def test_joinedload_props_dontload(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) # relationships currently have to load separately from scalar instances # the use case is: expire "addresses". then access it. lazy load @@ -572,37 +679,41 @@ class ExpireTest(_fixtures.FixtureTest): # lazyload) was issued. would prefer not to complicate lazyloading to # "figure out" that the operation should be aborted right now. 
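The comment above is the crux of this test. As a hedged standalone illustration (a toy User/Address mapping on in-memory SQLite, not the fixtures used here), the sequence looks roughly like this:

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        addresses = relationship("Address", lazy="joined")

    class Address(Base):
        __tablename__ = "addresses"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("users.id"))
        email_address = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    u = User(id=8, name="ed")
    u.addresses.append(Address(id=1, email_address="ed@wood.com"))
    session.add(u)
    session.commit()

    u = session.query(User).get(8)
    session.expire(u)                     # expire everything
    u.id                                  # scalar access refreshes the row,
    assert "addresses" not in u.__dict__  # but the collection stays expired
    u.addresses                           # and loads only when touched
    assert "addresses" in u.__dict__

That is, the per-attribute unexpire load refreshes column attributes only; an expired relationship, even a joined-eager one, waits for its own access.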
- mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', lazy='joined'), - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", lazy="joined" + ) + }, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(8) sess.expire(u) u.id - assert 'addresses' not in u.__dict__ + assert "addresses" not in u.__dict__ u.addresses - assert 'addresses' in u.__dict__ + assert "addresses" in u.__dict__ def test_expire_synonym(self): User, users = self.classes.User, self.tables.users - mapper(User, users, properties={ - 'uname': sa.orm.synonym('name') - }) + mapper(User, users, properties={"uname": sa.orm.synonym("name")}) sess = create_session() u = sess.query(User).get(7) - assert 'name' in u.__dict__ + assert "name" in u.__dict__ assert u.uname == u.name sess.expire(u) - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ - users.update(users.c.id == 7).execute(name='jack2') - assert u.name == 'jack2' - assert u.uname == 'jack2' - assert 'name' in u.__dict__ + users.update(users.c.id == 7).execute(name="jack2") + assert u.name == "jack2" + assert u.uname == "jack2" + assert "name" in u.__dict__ # this wont work unless we add API hooks through the attr. system to # provide "expire" behavior on a synonym @@ -618,79 +729,92 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() o = sess.query(Order).get(3) - sess.expire(o, attribute_names=['description']) - assert 'id' in o.__dict__ - assert 'description' not in o.__dict__ - assert attributes.instance_state(o).dict['isopen'] == 1 + sess.expire(o, attribute_names=["description"]) + assert "id" in o.__dict__ + assert "description" not in o.__dict__ + assert attributes.instance_state(o).dict["isopen"] == 1 - orders.update(orders.c.id == 3).execute(description='order 3 modified') + orders.update(orders.c.id == 3).execute(description="order 3 modified") def go(): - assert o.description == 'order 3 modified' + assert o.description == "order 3 modified" + self.assert_sql_count(testing.db, go, 1) - assert attributes.instance_state(o) \ - .dict['description'] == 'order 3 modified' + assert ( + attributes.instance_state(o).dict["description"] + == "order 3 modified" + ) o.isopen = 5 - sess.expire(o, attribute_names=['description']) - assert 'id' in o.__dict__ - assert 'description' not in o.__dict__ - assert o.__dict__['isopen'] == 5 - assert attributes.instance_state(o).committed_state['isopen'] == 1 + sess.expire(o, attribute_names=["description"]) + assert "id" in o.__dict__ + assert "description" not in o.__dict__ + assert o.__dict__["isopen"] == 5 + assert attributes.instance_state(o).committed_state["isopen"] == 1 def go(): - assert o.description == 'order 3 modified' + assert o.description == "order 3 modified" + self.assert_sql_count(testing.db, go, 1) - assert o.__dict__['isopen'] == 5 - assert attributes.instance_state(o) \ - .dict['description'] == 'order 3 modified' - assert attributes.instance_state(o).committed_state['isopen'] == 1 + assert o.__dict__["isopen"] == 5 + assert ( + attributes.instance_state(o).dict["description"] + == "order 3 modified" + ) + assert attributes.instance_state(o).committed_state["isopen"] == 1 sess.flush() - sess.expire(o, attribute_names=['id', 'isopen', 'description']) - assert 'id' not in o.__dict__ - assert 'isopen' not in o.__dict__ - assert 'description' not in o.__dict__ + sess.expire(o, attribute_names=["id", "isopen", "description"]) + assert "id" not in o.__dict__ + assert "isopen" not 
in o.__dict__ + assert "description" not in o.__dict__ def go(): - assert o.description == 'order 3 modified' + assert o.description == "order 3 modified" assert o.id == 3 assert o.isopen == 5 + self.assert_sql_count(testing.db, go, 1) def test_partial_expire_lazy(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(8) - sess.expire(u, ['name', 'addresses']) - assert 'name' not in u.__dict__ - assert 'addresses' not in u.__dict__ + sess.expire(u, ["name", "addresses"]) + assert "name" not in u.__dict__ + assert "addresses" not in u.__dict__ # hit the lazy loader. just does the lazy load, # doesn't do the overall refresh def go(): - assert u.addresses[0].email_address == 'ed@wood.com' + assert u.addresses[0].email_address == "ed@wood.com" + self.assert_sql_count(testing.db, go, 1) - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ # check that mods to expired lazy-load attributes # only do the lazy load - sess.expire(u, ['name', 'addresses']) + sess.expire(u, ["name", "addresses"]) def go(): - u.addresses = [Address(id=10, email_address='foo@bar.com')] + u.addresses = [Address(id=10, email_address="foo@bar.com")] + self.assert_sql_count(testing.db, go, 1) sess.flush() @@ -699,56 +823,70 @@ class ExpireTest(_fixtures.FixtureTest): # so the addresses collection got committed and is # longer expired def go(): - assert u.addresses[0].email_address == 'foo@bar.com' + assert u.addresses[0].email_address == "foo@bar.com" assert len(u.addresses) == 1 + self.assert_sql_count(testing.db, go, 0) # but the name attribute was never loaded and so # still loads def go(): - assert u.name == 'ed' + assert u.name == "ed" + self.assert_sql_count(testing.db, go, 1) def test_partial_expire_eager(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', lazy='joined'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", lazy="joined" + ) + }, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(8) - sess.expire(u, ['name', 'addresses']) - assert 'name' not in u.__dict__ - assert 'addresses' not in u.__dict__ + sess.expire(u, ["name", "addresses"]) + assert "name" not in u.__dict__ + assert "addresses" not in u.__dict__ def go(): - assert u.addresses[0].email_address == 'ed@wood.com' + assert u.addresses[0].email_address == "ed@wood.com" + self.assert_sql_count(testing.db, go, 1) # check that mods to expired eager-load attributes # do the refresh - sess.expire(u, ['name', 'addresses']) + sess.expire(u, ["name", "addresses"]) def go(): - u.addresses = [Address(id=10, email_address='foo@bar.com')] + u.addresses = [Address(id=10, email_address="foo@bar.com")] + self.assert_sql_count(testing.db, go, 1) sess.flush() # this should ideally trigger the whole 
load # but currently it works like the lazy case def go(): - assert u.addresses[0].email_address == 'foo@bar.com' + assert u.addresses[0].email_address == "foo@bar.com" assert len(u.addresses) == 1 + self.assert_sql_count(testing.db, go, 0) def go(): - assert u.name == 'ed' + assert u.name == "ed" + # scalar attributes have their own load self.assert_sql_count(testing.db, go, 1) # ideally, this was already loaded, but we arent @@ -756,59 +894,71 @@ class ExpireTest(_fixtures.FixtureTest): # self.assert_sql_count(testing.db, go, 0) def test_relationships_load_on_query(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(8) - assert 'name' in u.__dict__ + assert "name" in u.__dict__ u.addresses - assert 'addresses' in u.__dict__ - - sess.expire(u, ['name', 'addresses']) - assert 'name' not in u.__dict__ - assert 'addresses' not in u.__dict__ - (sess.query(User).options(sa.orm.joinedload('addresses')). - filter_by(id=8).all()) - assert 'name' in u.__dict__ - assert 'addresses' in u.__dict__ + assert "addresses" in u.__dict__ + + sess.expire(u, ["name", "addresses"]) + assert "name" not in u.__dict__ + assert "addresses" not in u.__dict__ + ( + sess.query(User) + .options(sa.orm.joinedload("addresses")) + .filter_by(id=8) + .all() + ) + assert "name" in u.__dict__ + assert "addresses" in u.__dict__ def test_partial_expire_deferred(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - 'description': sa.orm.deferred(orders.c.description) - }) + mapper( + Order, + orders, + properties={"description": sa.orm.deferred(orders.c.description)}, + ) sess = create_session() o = sess.query(Order).get(3) - sess.expire(o, ['description', 'isopen']) - assert 'isopen' not in o.__dict__ - assert 'description' not in o.__dict__ + sess.expire(o, ["description", "isopen"]) + assert "isopen" not in o.__dict__ + assert "description" not in o.__dict__ # test that expired attribute access refreshes # the deferred def go(): assert o.isopen == 1 - assert o.description == 'order 3' + assert o.description == "order 3" + self.assert_sql_count(testing.db, go, 1) - sess.expire(o, ['description', 'isopen']) - assert 'isopen' not in o.__dict__ - assert 'description' not in o.__dict__ + sess.expire(o, ["description", "isopen"]) + assert "isopen" not in o.__dict__ + assert "description" not in o.__dict__ # test that the deferred attribute triggers the full # reload def go(): - assert o.description == 'order 3' + assert o.description == "order 3" assert o.isopen == 1 + self.assert_sql_count(testing.db, go, 1) sa.orm.clear_mappers() @@ -817,70 +967,91 @@ class ExpireTest(_fixtures.FixtureTest): sess.expunge_all() # same tests, using deferred at the options level - o = sess.query(Order).options(sa.orm.defer('description')).get(3) + o = sess.query(Order).options(sa.orm.defer("description")).get(3) - assert 'description' not in o.__dict__ + assert "description" not in o.__dict__ # sanity check def go(): - assert o.description == 'order 3' + assert o.description == "order 3" + self.assert_sql_count(testing.db, 
go, 1) - assert 'description' in o.__dict__ - assert 'isopen' in o.__dict__ - sess.expire(o, ['description', 'isopen']) - assert 'isopen' not in o.__dict__ - assert 'description' not in o.__dict__ + assert "description" in o.__dict__ + assert "isopen" in o.__dict__ + sess.expire(o, ["description", "isopen"]) + assert "isopen" not in o.__dict__ + assert "description" not in o.__dict__ # test that expired attribute access refreshes # the deferred def go(): assert o.isopen == 1 - assert o.description == 'order 3' + assert o.description == "order 3" + self.assert_sql_count(testing.db, go, 1) - sess.expire(o, ['description', 'isopen']) + sess.expire(o, ["description", "isopen"]) - assert 'isopen' not in o.__dict__ - assert 'description' not in o.__dict__ + assert "isopen" not in o.__dict__ + assert "description" not in o.__dict__ # test that the deferred attribute triggers the full # reload def go(): - assert o.description == 'order 3' + assert o.description == "order 3" assert o.isopen == 1 + self.assert_sql_count(testing.db, go, 1) def test_joinedload_query_refreshes(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', lazy='joined'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", lazy="joined" + ) + }, + ) mapper(Address, addresses) sess = create_session() u = sess.query(User).get(8) assert len(u.addresses) == 3 sess.expire(u) - assert 'addresses' not in u.__dict__ + assert "addresses" not in u.__dict__ sess.query(User).filter_by(id=8).all() - assert 'addresses' in u.__dict__ + assert "addresses" in u.__dict__ assert len(u.addresses) == 3 @testing.requires.predictable_gc def test_expire_all(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', lazy='joined', - order_by=addresses.c.id), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + backref="user", + lazy="joined", + order_by=addresses.c.id, + ) + }, + ) mapper(Address, addresses) sess = create_session() @@ -910,39 +1081,39 @@ class ExpireTest(_fixtures.FixtureTest): # callable u1 = sess.query(User).options(defer(User.name)).first() assert isinstance( - attributes.instance_state(u1).callables['name'], - strategies.LoadDeferredColumns + attributes.instance_state(u1).callables["name"], + strategies.LoadDeferredColumns, ) # expire the attr, it gets the InstanceState callable - sess.expire(u1, ['name']) - assert 'name' in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + sess.expire(u1, ["name"]) + assert "name" in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables # load it, callable is gone u1.name - assert 'name' not in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" not in attributes.instance_state(u1).expired_attributes + assert "name" not in 
attributes.instance_state(u1).callables # same for expire all sess.expunge_all() u1 = sess.query(User).options(defer(User.name)).first() sess.expire(u1) - assert 'name' in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables # load over it. everything normal. sess.query(User).first() - assert 'name' not in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" not in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables sess.expunge_all() u1 = sess.query(User).first() # for non present, still expires the same way del u1.name sess.expire(u1) - assert 'name' in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables def test_state_deferred_to_col(self): """Behavioral test to verify the current activity of loader callables @@ -950,22 +1121,22 @@ class ExpireTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User - mapper(User, users, properties={'name': deferred(users.c.name)}) + mapper(User, users, properties={"name": deferred(users.c.name)}) sess = create_session() u1 = sess.query(User).options(undefer(User.name)).first() - assert 'name' not in attributes.instance_state(u1).callables + assert "name" not in attributes.instance_state(u1).callables # mass expire, the attribute was loaded, # the attribute gets the callable sess.expire(u1) - assert 'name' in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables # load it u1.name - assert 'name' not in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" not in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables # mass expire, attribute was loaded but then deleted, # the callable goes away - the state wants to flip @@ -974,15 +1145,15 @@ class ExpireTest(_fixtures.FixtureTest): u1 = sess.query(User).options(undefer(User.name)).first() del u1.name sess.expire(u1) - assert 'name' not in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + assert "name" not in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables # single attribute expire, the attribute gets the callable sess.expunge_all() u1 = sess.query(User).options(undefer(User.name)).first() - sess.expire(u1, ['name']) - assert 'name' in attributes.instance_state(u1).expired_attributes - assert 'name' not in attributes.instance_state(u1).callables + sess.expire(u1, ["name"]) + assert "name" in attributes.instance_state(u1).expired_attributes + assert "name" not in attributes.instance_state(u1).callables def test_state_noload_to_lazy(self): """Behavioral test to verify the current activity of loader callables @@ -992,55 +1163,64 @@ class ExpireTest(_fixtures.FixtureTest): self.tables.users, self.classes.Address, self.tables.addresses, - 
self.classes.User) + self.classes.User, + ) mapper( - User, users, - properties={'addresses': relationship(Address, lazy='noload')}) + User, + users, + properties={"addresses": relationship(Address, lazy="noload")}, + ) mapper(Address, addresses) sess = create_session() u1 = sess.query(User).options(lazyload(User.addresses)).first() assert isinstance( - attributes.instance_state(u1).callables['addresses'], - strategies.LoadLazyAttribute + attributes.instance_state(u1).callables["addresses"], + strategies.LoadLazyAttribute, ) # expire, it stays sess.expire(u1) - assert 'addresses' not in attributes.instance_state(u1) \ - .expired_attributes + assert ( + "addresses" not in attributes.instance_state(u1).expired_attributes + ) assert isinstance( - attributes.instance_state(u1).callables['addresses'], - strategies.LoadLazyAttribute + attributes.instance_state(u1).callables["addresses"], + strategies.LoadLazyAttribute, ) # load over it. callable goes away. sess.query(User).first() - assert 'addresses' not in attributes.instance_state(u1) \ - .expired_attributes - assert 'addresses' not in attributes.instance_state(u1).callables + assert ( + "addresses" not in attributes.instance_state(u1).expired_attributes + ) + assert "addresses" not in attributes.instance_state(u1).callables sess.expunge_all() u1 = sess.query(User).options(lazyload(User.addresses)).first() - sess.expire(u1, ['addresses']) - assert 'addresses' not in attributes.instance_state(u1) \ - .expired_attributes + sess.expire(u1, ["addresses"]) + assert ( + "addresses" not in attributes.instance_state(u1).expired_attributes + ) assert isinstance( - attributes.instance_state(u1).callables['addresses'], - strategies.LoadLazyAttribute + attributes.instance_state(u1).callables["addresses"], + strategies.LoadLazyAttribute, ) # load the attr, goes away u1.addresses - assert 'addresses' not in attributes.instance_state(u1) \ - .expired_attributes - assert 'addresses' not in attributes.instance_state(u1).callables + assert ( + "addresses" not in attributes.instance_state(u1).expired_attributes + ) + assert "addresses" not in attributes.instance_state(u1).callables def test_deferred_expire_w_transient_to_detached(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - "description": deferred(orders.c.description) - }) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) s = Session() item = Order(id=1) @@ -1048,52 +1228,69 @@ class ExpireTest(_fixtures.FixtureTest): make_transient_to_detached(item) s.add(item) item.isopen - assert 'description' not in item.__dict__ + assert "description" not in item.__dict__ def test_deferred_expire_normally(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - "description": deferred(orders.c.description) - }) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) s = Session() item = s.query(Order).first() s.expire(item) item.isopen - assert 'description' not in item.__dict__ + assert "description" not in item.__dict__ def test_deferred_expire_explicit_attrs(self): orders, Order = self.tables.orders, self.classes.Order - mapper(Order, orders, properties={ - "description": deferred(orders.c.description) - }) + mapper( + Order, + orders, + properties={"description": deferred(orders.c.description)}, + ) s = Session() item = s.query(Order).first() - s.expire(item, ['isopen', 'description']) + s.expire(item, ["isopen", "description"]) 
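The assertion just below this expire() call can also be shown standalone. A hedged sketch with a stand-in Order model on in-memory SQLite, demonstrating that a deferred column named explicitly in expire() comes back with the next unexpire load rather than staying deferred:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, deferred

    Base = declarative_base()

    class Order(Base):
        __tablename__ = "orders"
        id = Column(Integer, primary_key=True)
        isopen = Column(Integer)
        description = deferred(Column(String(50)))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(Order(id=1, isopen=1, description="order 1"))
    session.commit()

    o = session.query(Order).first()
    assert "description" not in o.__dict__  # deferred: skipped on load

    session.expire(o, ["isopen", "description"])
    o.isopen                                # unexpire load fires here
    assert "description" in o.__dict__      # named in expire(), so loaded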
item.isopen - assert 'description' in item.__dict__ + assert "description" in item.__dict__ class PolymorphicExpireTest(fixtures.MappedTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30))) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) + + engineers = Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + ) @classmethod def setup_classes(cls): @@ -1108,66 +1305,82 @@ class PolymorphicExpireTest(fixtures.MappedTest): people, engineers = cls.tables.people, cls.tables.engineers people.insert().execute( - {'person_id': 1, 'name': 'person1', 'type': 'person'}, - {'person_id': 2, 'name': 'engineer1', 'type': 'engineer'}, - {'person_id': 3, 'name': 'engineer2', 'type': 'engineer'}, + {"person_id": 1, "name": "person1", "type": "person"}, + {"person_id": 2, "name": "engineer1", "type": "engineer"}, + {"person_id": 3, "name": "engineer2", "type": "engineer"}, ) engineers.insert().execute( - {'person_id': 2, 'status': 'new engineer'}, - {'person_id': 3, 'status': 'old engineer'}, + {"person_id": 2, "status": "new engineer"}, + {"person_id": 3, "status": "old engineer"}, ) @classmethod def setup_mappers(cls): - Person, people, engineers, Engineer = (cls.classes.Person, - cls.tables.people, - cls.tables.engineers, - cls.classes.Engineer) + Person, people, engineers, Engineer = ( + cls.classes.Person, + cls.tables.people, + cls.tables.engineers, + cls.classes.Engineer, + ) - mapper(Person, people, polymorphic_on=people.c.type, - polymorphic_identity='person') - mapper(Engineer, engineers, inherits=Person, - polymorphic_identity='engineer') + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + ) + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + ) def test_poly_deferred(self): - Person, people, Engineer = (self.classes.Person, - self.tables.people, - self.classes.Engineer) + Person, people, Engineer = ( + self.classes.Person, + self.tables.people, + self.classes.Engineer, + ) sess = create_session() [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() sess.expire(p1) - sess.expire(e1, ['status']) + sess.expire(e1, ["status"]) sess.expire(e2) for p in [p1, e2]: - assert 'name' not in p.__dict__ + assert "name" not in p.__dict__ - assert 'name' in e1.__dict__ - assert 'status' not in e2.__dict__ - assert 'status' not in e1.__dict__ + assert "name" in e1.__dict__ + assert "status" not in e2.__dict__ + assert "status" not in e1.__dict__ - e1.name = 'new engineer name' + e1.name = "new engineer name" def go(): sess.query(Person).all() + self.assert_sql_count(testing.db, go, 1) for p in [p1, e1, e2]: - assert 'name' in p.__dict__ + assert "name" in p.__dict__ - assert 'status' not in e2.__dict__ - assert 'status' not in e1.__dict__ + assert "status" not in e2.__dict__ + assert "status" not in e1.__dict__ def go(): - assert e1.name == 'new engineer name' - 
assert e2.name == 'engineer2' - assert e1.status == 'new engineer' - assert e2.status == 'old engineer' + assert e1.name == "new engineer name" + assert e2.name == "engineer2" + assert e1.status == "new engineer" + assert e2.status == "old engineer" + self.assert_sql_count(testing.db, go, 2) - eq_(Engineer.name.get_history(e1), - (['new engineer name'], (), ['engineer1'])) + eq_( + Engineer.name.get_history(e1), + (["new engineer name"], (), ["engineer1"]), + ) def test_no_instance_key(self): Engineer = self.classes.Engineer @@ -1175,12 +1388,12 @@ class PolymorphicExpireTest(fixtures.MappedTest): sess = create_session() e1 = sess.query(Engineer).get(2) - sess.expire(e1, attribute_names=['name']) + sess.expire(e1, attribute_names=["name"]) sess.expunge(e1) attributes.instance_state(e1).key = None - assert 'name' not in e1.__dict__ + assert "name" not in e1.__dict__ sess.add(e1) - assert e1.name == 'engineer1' + assert e1.name == "engineer1" def test_no_instance_key(self): Engineer = self.classes.Engineer @@ -1190,56 +1403,61 @@ class PolymorphicExpireTest(fixtures.MappedTest): sess = create_session() e1 = sess.query(Engineer).get(2) - sess.expire(e1, attribute_names=['name', 'person_id']) + sess.expire(e1, attribute_names=["name", "person_id"]) sess.expunge(e1) attributes.instance_state(e1).key = None - assert 'name' not in e1.__dict__ + assert "name" not in e1.__dict__ sess.add(e1) - assert_raises(sa_exc.InvalidRequestError, getattr, e1, 'name') + assert_raises(sa_exc.InvalidRequestError, getattr, e1, "name") class ExpiredPendingTest(_fixtures.FixtureTest): - run_define_tables = 'once' - run_setup_classes = 'once' + run_define_tables = "once" + run_setup_classes = "once" run_setup_mappers = None run_inserts = None def test_expired_pending(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user'), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() - a1 = Address(email_address='a1') + a1 = Address(email_address="a1") sess.add(a1) sess.flush() - u1 = User(name='u1') + u1 = User(name="u1") a1.user = u1 sess.flush() # expire 'addresses'. backrefs # which attach to u1 will expect to be "pending" - sess.expire(u1, ['addresses']) + sess.expire(u1, ["addresses"]) # attach an Address. now its "pending" # in user.addresses - a2 = Address(email_address='a2') + a2 = Address(email_address="a2") a2.user = u1 # expire u1.addresses again. this expires # "pending" as well. 
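The "expired pending" sequence this test walks through can be sketched standalone. A rough, hedged illustration (toy models on in-memory SQLite; autoflush is disabled so the pending a2 is never inserted, mirroring the non-autoflushing session the test uses):

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        addresses = relationship("Address", backref="user")

    class Address(Base):
        __tablename__ = "addresses"
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey("users.id"))
        email_address = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine, autoflush=False)

    u1 = User(name="u1")
    a1 = Address(email_address="a1")
    a1.user = u1
    session.add(u1)
    session.flush()

    session.expire(u1, ["addresses"])
    a2 = Address(email_address="a2")
    a2.user = u1                        # a2 is now "pending" in u1.addresses
    session.expire(u1, ["addresses"])   # wipes that pending state too

    # insert a row behind the ORM's back, as the test does
    session.execute(
        Address.__table__.insert(), dict(email_address="a3", user_id=u1.id)
    )
    assert len(u1.addresses) == 2       # a1 and a3 from the DB; a2 is gone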
- sess.expire(u1, ['addresses']) + sess.expire(u1, ["addresses"]) # insert a new row - sess.execute(addresses.insert(), dict( - email_address='a3', user_id=u1.id)) + sess.execute( + addresses.insert(), dict(email_address="a3", user_id=u1.id) + ) # only two addresses pulled from the DB, no "pending" assert len(u1.addresses) == 2 @@ -1252,19 +1470,31 @@ class ExpiredPendingTest(_fixtures.FixtureTest): class LifecycleTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("data", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - Table("data_fetched", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30), FetchedValue())) - Table("data_defer", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('data2', String(30))) + Table( + "data", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + Table( + "data_fetched", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30), FetchedValue()), + ) + Table( + "data_defer", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("data2", String(30)), + ) @classmethod def setup_classes(cls): @@ -1281,9 +1511,11 @@ class LifecycleTest(fixtures.MappedTest): def setup_mappers(cls): mapper(cls.classes.Data, cls.tables.data) mapper(cls.classes.DataFetched, cls.tables.data_fetched) - mapper(cls.classes.DataDefer, cls.tables.data_defer, properties={ - "data": deferred(cls.tables.data_defer.c.data) - }) + mapper( + cls.classes.DataDefer, + cls.tables.data_defer, + properties={"data": deferred(cls.tables.data_defer.c.data)}, + ) def test_attr_not_inserted(self): Data = self.classes.Data @@ -1297,16 +1529,12 @@ class LifecycleTest(fixtures.MappedTest): # we didn't insert a value for 'data', # so its not in dict, but also when we hit it, it isn't # expired because there's no column default on it or anything like that - assert 'data' not in d1.__dict__ + assert "data" not in d1.__dict__ def go(): eq_(d1.data, None) - self.assert_sql_count( - testing.db, - go, - 0 - ) + self.assert_sql_count(testing.db, go, 0) def test_attr_not_inserted_expired(self): Data = self.classes.Data @@ -1317,7 +1545,7 @@ class LifecycleTest(fixtures.MappedTest): sess.add(d1) sess.flush() - assert 'data' not in d1.__dict__ + assert "data" not in d1.__dict__ # with an expire, we emit sess.expire(d1) @@ -1325,11 +1553,7 @@ class LifecycleTest(fixtures.MappedTest): def go(): eq_(d1.data, None) - self.assert_sql_count( - testing.db, - go, - 1 - ) + self.assert_sql_count(testing.db, go, 1) def test_attr_not_inserted_fetched(self): Data = self.classes.DataFetched @@ -1340,24 +1564,20 @@ class LifecycleTest(fixtures.MappedTest): sess.add(d1) sess.flush() - assert 'data' not in d1.__dict__ + assert "data" not in d1.__dict__ def go(): eq_(d1.data, None) # this one is marked as "fetch" so we emit SQL - self.assert_sql_count( - testing.db, - go, - 1 - ) + self.assert_sql_count(testing.db, go, 1) def test_cols_missing_in_load(self): Data = self.classes.Data sess = create_session() - d1 = Data(data='d1') + d1 = Data(data="d1") sess.add(d1) sess.flush() sess.close() @@ -1367,54 +1587,60 @@ class LifecycleTest(fixtures.MappedTest): # cols not present in the row are implicitly 
expired def go(): - eq_(d1.data, 'd1') + eq_(d1.data, "d1") - self.assert_sql_count( - testing.db, go, 1 - ) + self.assert_sql_count(testing.db, go, 1) def test_deferred_cols_missing_in_load_state_reset(self): Data = self.classes.DataDefer sess = create_session() - d1 = Data(data='d1') + d1 = Data(data="d1") sess.add(d1) sess.flush() sess.close() sess = create_session() - d1 = sess.query(Data).from_statement( - select([Data.id])).options(undefer(Data.data)).first() - d1.data = 'd2' + d1 = ( + sess.query(Data) + .from_statement(select([Data.id])) + .options(undefer(Data.data)) + .first() + ) + d1.data = "d2" # the deferred loader has to clear out any state # on the col, including that 'd2' here d1 = sess.query(Data).populate_existing().first() def go(): - eq_(d1.data, 'd1') + eq_(d1.data, "d1") - self.assert_sql_count( - testing.db, go, 1 - ) + self.assert_sql_count(testing.db, go, 1) class RefreshTest(_fixtures.FixtureTest): - def test_refresh(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user') - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) s = create_session() u = s.query(User).get(7) - u.name = 'foo' + u.name = "foo" a = Address() assert sa.orm.object_session(a) is None u.addresses.append(a) @@ -1427,20 +1653,20 @@ class RefreshTest(_fixtures.FixtureTest): assert u not in s.dirty # username is back to the DB - assert u.name == 'jack' + assert u.name == "jack" assert id(a) not in [id(x) for x in u.addresses] - u.name = 'foo' + u.name = "foo" u.addresses.append(a) # now its dirty assert u in s.dirty - assert u.name == 'foo' + assert u.name == "foo" assert id(a) in [id(x) for x in u.addresses] s.expire(u) # get the attribute, it refreshes - assert u.name == 'jack' + assert u.name == "jack" assert id(a) not in [id(x) for x in u.addresses] def test_persistence_check(self): @@ -1450,9 +1676,11 @@ class RefreshTest(_fixtures.FixtureTest): s = create_session() u = s.query(User).get(7) s.expunge_all() - assert_raises_message(sa_exc.InvalidRequestError, - r"is not persistent within this Session", - lambda: s.refresh(u)) + assert_raises_message( + sa_exc.InvalidRequestError, + r"is not persistent within this Session", + lambda: s.refresh(u), + ) def test_refresh_expired(self): User, users = self.classes.User, self.tables.users @@ -1461,43 +1689,56 @@ class RefreshTest(_fixtures.FixtureTest): s = create_session() u = s.query(User).get(7) s.expire(u) - assert 'name' not in u.__dict__ + assert "name" not in u.__dict__ s.refresh(u) - assert u.name == 'jack' + assert u.name == "jack" def test_refresh_with_lazy(self): """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesn't fire the lazy loader or create any problems""" - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) s = create_session() - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses))}) - q = 
s.query(User).options(sa.orm.lazyload('addresses')) + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) + q = s.query(User).options(sa.orm.lazyload("addresses")) u = q.filter(users.c.id == 8).first() def go(): s.refresh(u) + self.assert_sql_count(testing.db, go, 1) def test_refresh_with_eager(self): """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='joined') - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), lazy="joined" + ) + }, + ) s = create_session() u = s.query(User).get(8) @@ -1514,24 +1755,27 @@ class RefreshTest(_fixtures.FixtureTest): def test_refresh_maintains_deferred_options(self): # testing a behavior that may have changed with # [ticket:3822] - User, Address, Dingaling = self.classes( - "User", "Address", "Dingaling") + User, Address, Dingaling = self.classes("User", "Address", "Dingaling") users, addresses, dingalings = self.tables( - "users", "addresses", "dingalings") + "users", "addresses", "dingalings" + ) - mapper(User, users, properties={ - 'addresses': relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) - mapper(Address, addresses, properties={ - 'dingalings': relationship(Dingaling) - }) + mapper( + Address, + addresses, + properties={"dingalings": relationship(Dingaling)}, + ) mapper(Dingaling, dingalings) s = create_session() - q = s.query(User).filter_by(name='fred').options( - sa.orm.lazyload('addresses').joinedload("dingalings")) + q = ( + s.query(User) + .filter_by(name="fred") + .options(sa.orm.lazyload("addresses").joinedload("dingalings")) + ) u1 = q.one() @@ -1544,10 +1788,12 @@ class RefreshTest(_fixtures.FixtureTest): def go(): eq_( u1.addresses, - [Address( - email_address='fred@fred.com', - dingalings=[Dingaling(data="ding 2/5")] - )] + [ + Address( + email_address="fred@fred.com", + dingalings=[Dingaling(data="ding 2/5")], + ) + ], ) self.assert_sql_count(testing.db, go, 1) @@ -1555,28 +1801,37 @@ class RefreshTest(_fixtures.FixtureTest): def test_refresh2(self): """test a hang condition that was occurring on expire/refresh""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) s = create_session() mapper(Address, addresses) - mapper(User, users, properties=dict(addresses=relationship( - Address, cascade="all, delete-orphan", lazy='joined'))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, cascade="all, delete-orphan", lazy="joined" + ) + ), + ) u = User() - u.name = 'Justin' - a = Address(id=10, email_address='lala') + u.name = "Justin" + a = Address(id=10, email_address="lala") u.addresses.append(a) s.add(u) s.flush() s.expunge_all() - u = s.query(User).filter(User.name == 'Justin').one() + u = s.query(User).filter(User.name == "Justin").one() s.expire(u) - assert u.name == 'Justin' + assert u.name == "Justin" s.refresh(u) diff --git 
a/test/orm/test_froms.py b/test/orm/test_froms.py index 27924c4138..0640387501 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -1,14 +1,46 @@ from sqlalchemy import testing from sqlalchemy.testing import ( - fixtures, eq_, is_, assert_raises, - assert_raises_message, AssertsCompiledSQL) + fixtures, + eq_, + is_, + assert_raises, + assert_raises_message, + AssertsCompiledSQL, +) from sqlalchemy import ( - exc as sa_exc, util, Integer, Table, String, ForeignKey, select, func, - and_, asc, desc, inspect, literal_column, cast, exists, text) + exc as sa_exc, + util, + Integer, + Table, + String, + ForeignKey, + select, + func, + and_, + asc, + desc, + inspect, + literal_column, + cast, + exists, + text, +) from sqlalchemy.orm import ( - configure_mappers, Session, mapper, create_session, relationship, - column_property, joinedload_all, contains_eager, contains_alias, - joinedload, clear_mappers, backref, relation, aliased) + configure_mappers, + Session, + mapper, + create_session, + relationship, + column_property, + joinedload_all, + contains_eager, + contains_alias, + joinedload, + clear_mappers, + backref, + relation, + aliased, +) from sqlalchemy.sql import table, column from sqlalchemy.engine import default import sqlalchemy as sa @@ -20,55 +52,83 @@ from sqlalchemy.orm.util import join class QueryTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def setup_mappers(cls): - Node, composite_pk_table, users, Keyword, items, Dingaling, \ - order_items, item_keywords, Item, User, dingalings, \ - Address, keywords, CompositePk, nodes, Order, orders, \ - addresses = cls.classes.Node, \ - cls.tables.composite_pk_table, cls.tables.users, \ - cls.classes.Keyword, cls.tables.items, \ - cls.classes.Dingaling, cls.tables.order_items, \ - cls.tables.item_keywords, cls.classes.Item, \ - cls.classes.User, cls.tables.dingalings, \ - cls.classes.Address, cls.tables.keywords, \ - cls.classes.CompositePk, cls.tables.nodes, \ - cls.classes.Order, cls.tables.orders, cls.tables.addresses + Node, composite_pk_table, users, Keyword, items, Dingaling, order_items, item_keywords, Item, User, dingalings, Address, keywords, CompositePk, nodes, Order, orders, addresses = ( + cls.classes.Node, + cls.tables.composite_pk_table, + cls.tables.users, + cls.classes.Keyword, + cls.tables.items, + cls.classes.Dingaling, + cls.tables.order_items, + cls.tables.item_keywords, + cls.classes.Item, + cls.classes.User, + cls.tables.dingalings, + cls.classes.Address, + cls.tables.keywords, + cls.classes.CompositePk, + cls.tables.nodes, + cls.classes.Order, + cls.tables.orders, + cls.tables.addresses, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, backref='user', order_by=addresses.c.id), - 'orders': relationship( - Order, backref='user', order_by=orders.c.id), # o2m, m2o - }) + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", order_by=addresses.c.id + ), + "orders": relationship( + Order, backref="user", order_by=orders.c.id + ), # o2m, m2o + }, + ) mapper( - Address, addresses, properties={ - 'dingaling': relationship( - Dingaling, uselist=False, backref="address") # o2o - }) + Address, + addresses, + properties={ + "dingaling": relationship( + Dingaling, uselist=False, backref="address" + ) # o2o + }, + ) mapper(Dingaling, dingalings) mapper( - Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, 
order_by=items.c.id), # m2m - 'address': relationship(Address), # m2o - }) + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ), # m2m + "address": relationship(Address), # m2o + }, + ) mapper( - Item, items, properties={ - 'keywords': relationship( - Keyword, secondary=item_keywords)}) # m2m + Item, + items, + properties={ + "keywords": relationship(Keyword, secondary=item_keywords) + }, + ) # m2m mapper(Keyword, keywords) mapper( - Node, nodes, properties={ - 'children': relationship( - Node, backref=backref('parent', remote_side=[nodes.c.id])) - }) + Node, + nodes, + properties={ + "children": relationship( + Node, backref=backref("parent", remote_side=[nodes.c.id]) + ) + }, + ) mapper(CompositePk, composite_pk_table) @@ -78,69 +138,96 @@ class QueryTest(_fixtures.FixtureTest): class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): __dialect__ = "default" - query_correlated = "SELECT users.name AS users_name, " \ - "(SELECT count(addresses.id) AS count_1 FROM addresses " \ + query_correlated = ( + "SELECT users.name AS users_name, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " "WHERE addresses.user_id = users.id) AS anon_1 FROM users" + ) - query_not_correlated = "SELECT users.name AS users_name, " \ - "(SELECT count(addresses.id) AS count_1 FROM addresses, users " \ + query_not_correlated = ( + "SELECT users.name AS users_name, " + "(SELECT count(addresses.id) AS count_1 FROM addresses, users " "WHERE addresses.user_id = users.id) AS anon_1 FROM users" + ) def test_as_scalar_select_auto_correlate(self): addresses, users = self.tables.addresses, self.tables.users query = select( - [func.count(addresses.c.id)], - addresses.c.user_id == users.c.id).as_scalar() - query = select([users.c.name.label('users_name'), query]) + [func.count(addresses.c.id)], addresses.c.user_id == users.c.id + ).as_scalar() + query = select([users.c.name.label("users_name"), query]) self.assert_compile( - query, self.query_correlated, dialect=default.DefaultDialect()) + query, self.query_correlated, dialect=default.DefaultDialect() + ) def test_as_scalar_select_explicit_correlate(self): addresses, users = self.tables.addresses, self.tables.users - query = select( - [func.count(addresses.c.id)], - addresses.c.user_id == users.c.id).correlate(users).as_scalar() - query = select([users.c.name.label('users_name'), query]) + query = ( + select( + [func.count(addresses.c.id)], addresses.c.user_id == users.c.id + ) + .correlate(users) + .as_scalar() + ) + query = select([users.c.name.label("users_name"), query]) self.assert_compile( - query, self.query_correlated, dialect=default.DefaultDialect()) + query, self.query_correlated, dialect=default.DefaultDialect() + ) def test_as_scalar_select_correlate_off(self): addresses, users = self.tables.addresses, self.tables.users - query = select( - [func.count(addresses.c.id)], - addresses.c.user_id == users.c.id).correlate(None).as_scalar() - query = select([users.c.name.label('users_name'), query]) + query = ( + select( + [func.count(addresses.c.id)], addresses.c.user_id == users.c.id + ) + .correlate(None) + .as_scalar() + ) + query = select([users.c.name.label("users_name"), query]) self.assert_compile( - query, self.query_not_correlated, dialect=default.DefaultDialect()) + query, self.query_not_correlated, dialect=default.DefaultDialect() + ) def test_as_scalar_query_auto_correlate(self): sess = create_session() Address, User = self.classes.Address, self.classes.User - query = 
sess.query(func.count(Address.id))\ - .filter(Address.user_id == User.id)\ + query = ( + sess.query(func.count(Address.id)) + .filter(Address.user_id == User.id) .as_scalar() + ) query = sess.query(User.name, query) self.assert_compile( - query, self.query_correlated, dialect=default.DefaultDialect()) + query, self.query_correlated, dialect=default.DefaultDialect() + ) def test_as_scalar_query_explicit_correlate(self): sess = create_session() Address, User = self.classes.Address, self.classes.User - query = sess.query(func.count(Address.id)). \ - filter(Address.user_id == User.id). \ - correlate(self.tables.users).as_scalar() + query = ( + sess.query(func.count(Address.id)) + .filter(Address.user_id == User.id) + .correlate(self.tables.users) + .as_scalar() + ) query = sess.query(User.name, query) self.assert_compile( - query, self.query_correlated, dialect=default.DefaultDialect()) + query, self.query_correlated, dialect=default.DefaultDialect() + ) def test_as_scalar_query_correlate_off(self): sess = create_session() Address, User = self.classes.Address, self.classes.User - query = sess.query(func.count(Address.id)). \ - filter(Address.user_id == User.id).correlate(None).as_scalar() + query = ( + sess.query(func.count(Address.id)) + .filter(Address.user_id == User.id) + .correlate(None) + .as_scalar() + ) query = sess.query(User.name, query) self.assert_compile( - query, self.query_not_correlated, dialect=default.DefaultDialect()) + query, self.query_not_correlated, dialect=default.DefaultDialect() + ) def test_correlate_to_union(self): User = self.classes.User @@ -161,7 +248,7 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): "FROM (" "SELECT users.id AS users_id, users.name AS users_name FROM users " "UNION SELECT users.id AS users_id, users.name AS users_name " - "FROM users) AS anon_1" + "FROM users) AS anon_1", ) # only difference is "1" vs. "*" (not sure why that is) @@ -174,7 +261,7 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): "FROM (" "SELECT users.id AS users_id, users.name AS users_name FROM users " "UNION SELECT users.id AS users_id, users.name AS users_name " - "FROM users) AS anon_1" + "FROM users) AS anon_1", ) @@ -186,7 +273,8 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): thru for ClauseElement entities. """ - __dialect__ = 'default' + + __dialect__ = "default" def test_select(self): addresses, users = self.tables.addresses, self.tables.users @@ -194,68 +282,95 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): sess = create_session() self.assert_compile( - sess.query(users).select_entity_from(users.select()). - with_labels().statement, + sess.query(users) + .select_entity_from(users.select()) + .with_labels() + .statement, "SELECT users.id AS users_id, users.name AS users_name " "FROM users, " "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1", ) self.assert_compile( - sess.query(users, exists([1], from_obj=addresses)). - with_labels().statement, + sess.query(users, exists([1], from_obj=addresses)) + .with_labels() + .statement, "SELECT users.id AS users_id, users.name AS users_name, EXISTS " "(SELECT 1 FROM addresses) AS anon_1 FROM users", ) # a little tedious here, adding labels to work around Query's # auto-labelling. 
- s = sess.query( - addresses.c.id.label('id'), - addresses.c.email_address.label('email')).\ - filter(addresses.c.user_id == users.c.id).correlate(users).\ - statement.alias() + s = ( + sess.query( + addresses.c.id.label("id"), + addresses.c.email_address.label("email"), + ) + .filter(addresses.c.user_id == users.c.id) + .correlate(users) + .statement.alias() + ) self.assert_compile( - sess.query(users, s.c.email).select_entity_from( - users.join(s, s.c.id == users.c.id) - ).with_labels().statement, + sess.query(users, s.c.email) + .select_entity_from(users.join(s, s.c.id == users.c.id)) + .with_labels() + .statement, "SELECT users.id AS users_id, users.name AS users_name, " "anon_1.email AS anon_1_email " "FROM users JOIN (SELECT addresses.id AS id, " "addresses.email_address AS email FROM addresses, users " "WHERE addresses.user_id = users.id) AS anon_1 " - "ON anon_1.id = users.id",) + "ON anon_1.id = users.id", + ) - x = func.lala(users.c.id).label('foo') - self.assert_compile(sess.query(x).filter(x == 5).statement, - "SELECT lala(users.id) AS foo FROM users WHERE " - "lala(users.id) = :param_1") + x = func.lala(users.c.id).label("foo") + self.assert_compile( + sess.query(x).filter(x == 5).statement, + "SELECT lala(users.id) AS foo FROM users WHERE " + "lala(users.id) = :param_1", + ) - self.assert_compile(sess.query(func.sum(x).label('bar')).statement, - "SELECT sum(lala(users.id)) AS bar FROM users") + self.assert_compile( + sess.query(func.sum(x).label("bar")).statement, + "SELECT sum(lala(users.id)) AS bar FROM users", + ) class FromSelfTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_filter(self): User = self.classes.User eq_( [User(id=8), User(id=9)], - create_session().query(User).filter(User.id.in_([8, 9])). - from_self().all()) + create_session() + .query(User) + .filter(User.id.in_([8, 9])) + .from_self() + .all(), + ) eq_( [User(id=8), User(id=9)], - create_session().query(User).order_by(User.id).slice(1, 3). - from_self().all()) + create_session() + .query(User) + .order_by(User.id) + .slice(1, 3) + .from_self() + .all(), + ) eq_( [User(id=8)], list( - create_session().query(User).filter(User.id.in_([8, 9])). - from_self().order_by(User.id)[0:1])) + create_session() + .query(User) + .filter(User.id.in_([8, 9])) + .from_self() + .order_by(User.id)[0:1] + ), + ) def test_join(self): User, Address = self.classes.User, self.classes.Address @@ -265,27 +380,38 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): (User(id=8), Address(id=2)), (User(id=8), Address(id=3)), (User(id=8), Address(id=4)), - (User(id=9), Address(id=5))], - create_session().query(User).filter(User.id.in_([8, 9])). - from_self().join('addresses').add_entity(Address). - order_by(User.id, Address.id).all() + (User(id=9), Address(id=5)), + ], + create_session() + .query(User) + .filter(User.id.in_([8, 9])) + .from_self() + .join("addresses") + .add_entity(Address) + .order_by(User.id, Address.id) + .all(), ) def test_group_by(self): Address = self.classes.Address eq_( - create_session(). - query(Address.user_id, func.count(Address.id).label('count')). - group_by(Address.user_id).order_by(Address.user_id).all(), - [(7, 1), (8, 3), (9, 1)] + create_session() + .query(Address.user_id, func.count(Address.id).label("count")) + .group_by(Address.user_id) + .order_by(Address.user_id) + .all(), + [(7, 1), (8, 3), (9, 1)], ) eq_( - create_session().query(Address.user_id, Address.id). - from_self(Address.user_id, func.count(Address.id)). 
- group_by(Address.user_id).order_by(Address.user_id).all(), - [(7, 1), (8, 3), (9, 1)] + create_session() + .query(Address.user_id, Address.id) + .from_self(Address.user_id, func.count(Address.id)) + .group_by(Address.user_id) + .order_by(Address.user_id) + .all(), + [(7, 1), (8, 3), (9, 1)], ) def test_having(self): @@ -294,11 +420,10 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): s = create_session() self.assert_compile( - s.query(User.id).group_by(User.id).having(User.id > 5). - from_self(), + s.query(User.id).group_by(User.id).having(User.id > 5).from_self(), "SELECT anon_1.users_id AS anon_1_users_id FROM " "(SELECT users.id AS users_id FROM users GROUP " - "BY users.id HAVING users.id > :id_1) AS anon_1" + "BY users.id HAVING users.id > :id_1) AS anon_1", ) def test_no_joinedload(self): @@ -310,14 +435,16 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): s = create_session() self.assert_compile( - s.query(User).options(joinedload(User.addresses)). - from_self().statement, + s.query(User) + .options(joinedload(User.addresses)) + .from_self() + .statement, "SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, " "addresses_1.user_id, addresses_1.email_address FROM " "(SELECT users.id AS users_id, users.name AS " "users_name FROM users) AS anon_1 LEFT OUTER JOIN " "addresses AS addresses_1 ON anon_1.users_id = " - "addresses_1.user_id ORDER BY addresses_1.id" + "addresses_1.user_id ORDER BY addresses_1.id", ) def test_aliases(self): @@ -330,36 +457,46 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): ualias = aliased(User) eq_( - s.query(User, ualias).filter(User.id > ualias.id). - from_self(User.name, ualias.name). - order_by(User.name, ualias.name).all(), + s.query(User, ualias) + .filter(User.id > ualias.id) + .from_self(User.name, ualias.name) + .order_by(User.name, ualias.name) + .all(), [ - ('chuck', 'ed'), - ('chuck', 'fred'), - ('chuck', 'jack'), - ('ed', 'jack'), - ('fred', 'ed'), - ('fred', 'jack') - ] + ("chuck", "ed"), + ("chuck", "fred"), + ("chuck", "jack"), + ("ed", "jack"), + ("fred", "ed"), + ("fred", "jack"), + ], ) eq_( - s.query(User, ualias).filter(User.id > ualias.id). - from_self(User.name, ualias.name).filter(ualias.name == 'ed'). - order_by(User.name, ualias.name).all(), - [('chuck', 'ed'), ('fred', 'ed')]) + s.query(User, ualias) + .filter(User.id > ualias.id) + .from_self(User.name, ualias.name) + .filter(ualias.name == "ed") + .order_by(User.name, ualias.name) + .all(), + [("chuck", "ed"), ("fred", "ed")], + ) eq_( - s.query(User, ualias).filter(User.id > ualias.id). - from_self(ualias.name, Address.email_address). - join(ualias.addresses). - order_by(ualias.name, Address.email_address).all(), + s.query(User, ualias) + .filter(User.id > ualias.id) + .from_self(ualias.name, Address.email_address) + .join(ualias.addresses) + .order_by(ualias.name, Address.email_address) + .all(), [ - ('ed', 'fred@fred.com'), - ('jack', 'ed@bettyboop.com'), - ('jack', 'ed@lala.com'), - ('jack', 'ed@wood.com'), - ('jack', 'fred@fred.com')]) + ("ed", "fred@fred.com"), + ("jack", "ed@bettyboop.com"), + ("jack", "ed@lala.com"), + ("jack", "ed@wood.com"), + ("jack", "fred@fred.com"), + ], + ) def test_multiple_entities(self): User, Address = self.classes.User, self.classes.Address @@ -367,21 +504,26 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): sess = create_session() eq_( - sess.query(User, Address). - filter(User.id == Address.user_id). 
- filter(Address.id.in_([2, 5])).from_self().all(), - [ - (User(id=8), Address(id=2)), - (User(id=9), Address(id=5))]) + sess.query(User, Address) + .filter(User.id == Address.user_id) + .filter(Address.id.in_([2, 5])) + .from_self() + .all(), + [(User(id=8), Address(id=2)), (User(id=9), Address(id=5))], + ) eq_( - sess.query(User, Address).filter(User.id == Address.user_id). - filter(Address.id.in_([2, 5])).from_self(). - options(joinedload('addresses')).first(), + sess.query(User, Address) + .filter(User.id == Address.user_id) + .filter(Address.id.in_([2, 5])) + .from_self() + .options(joinedload("addresses")) + .first(), ( - User( - id=8, addresses=[Address(), Address(), Address()]), - Address(id=2)),) + User(id=8, addresses=[Address(), Address(), Address()]), + Address(id=2), + ), + ) def test_multiple_with_column_entities(self): User = self.classes.User @@ -389,16 +531,21 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): sess = create_session() eq_( - sess.query(User.id).from_self(). - add_column(func.count().label('foo')).group_by(User.id). - order_by(User.id).from_self().all(), [ - (7, 1), (8, 1), (9, 1), (10, 1)]) + sess.query(User.id) + .from_self() + .add_column(func.count().label("foo")) + .group_by(User.id) + .order_by(User.id) + .from_self() + .all(), + [(7, 1), (8, 1), (9, 1), (10, 1)], + ) class ColumnAccessTest(QueryTest, AssertsCompiledSQL): """test access of columns after _from_selectable has been applied""" - __dialect__ = 'default' + __dialect__ = "default" def test_from_self(self): User = self.classes.User @@ -406,11 +553,11 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): q = sess.query(User).from_self() self.assert_compile( - q.filter(User.name == 'ed'), + q.filter(User.name == "ed"), "SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name AS " "anon_1_users_name FROM (SELECT users.id AS users_id, users.name " "AS users_name FROM users) AS anon_1 WHERE anon_1.users_name = " - ":name_1" + ":name_1", ) def test_from_self_twice(self): @@ -419,13 +566,13 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): q = sess.query(User).from_self(User.id, User.name).from_self() self.assert_compile( - q.filter(User.name == 'ed'), + q.filter(User.name == "ed"), "SELECT anon_1.anon_2_users_id AS anon_1_anon_2_users_id, " "anon_1.anon_2_users_name AS anon_1_anon_2_users_name FROM " "(SELECT anon_2.users_id AS anon_2_users_id, anon_2.users_name " "AS anon_2_users_name FROM (SELECT users.id AS users_id, " "users.name AS users_name FROM users) AS anon_2) AS anon_1 " - "WHERE anon_1.anon_2_users_name = :name_1" + "WHERE anon_1.anon_2_users_name = :name_1", ) def test_select_entity_from(self): @@ -435,10 +582,10 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): q = sess.query(User) q = sess.query(User).select_entity_from(q.statement) self.assert_compile( - q.filter(User.name == 'ed'), + q.filter(User.name == "ed"), "SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name " "FROM (SELECT users.id AS id, users.name AS name FROM " - "users) AS anon_1 WHERE anon_1.name = :name_1" + "users) AS anon_1 WHERE anon_1.name = :name_1", ) def test_select_entity_from_no_entities(self): @@ -449,7 +596,9 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): sa.exc.ArgumentError, r"A selectable \(FromClause\) instance is " "expected when the base alias is being set", - sess.query(User).select_entity_from, User) + sess.query(User).select_entity_from, + User, + ) def test_select_from_no_aliasing(self): User = self.classes.User @@ -458,34 +607,34 @@ class 
ColumnAccessTest(QueryTest, AssertsCompiledSQL): q = sess.query(User) q = sess.query(User).select_from(q.statement) self.assert_compile( - q.filter(User.name == 'ed'), + q.filter(User.name == "ed"), "SELECT users.id AS users_id, users.name AS users_name " "FROM users, (SELECT users.id AS id, users.name AS name FROM " - "users) AS anon_1 WHERE users.name = :name_1" + "users) AS anon_1 WHERE users.name = :name_1", ) def test_anonymous_expression(self): from sqlalchemy.sql import column sess = create_session() - c1, c2 = column('c1'), column('c2') - q1 = sess.query(c1, c2).filter(c1 == 'dog') - q2 = sess.query(c1, c2).filter(c1 == 'cat') + c1, c2 = column("c1"), column("c2") + q1 = sess.query(c1, c2).filter(c1 == "dog") + q2 = sess.query(c1, c2).filter(c1 == "cat") q3 = q1.union(q2) self.assert_compile( q3.order_by(c1), "SELECT anon_1.c1 AS anon_1_c1, anon_1.c2 " "AS anon_1_c2 FROM (SELECT c1, c2 WHERE " "c1 = :c1_1 UNION SELECT c1, c2 " - "WHERE c1 = :c1_2) AS anon_1 ORDER BY anon_1.c1" + "WHERE c1 = :c1_2) AS anon_1 ORDER BY anon_1.c1", ) def test_anonymous_expression_from_self_twice(self): from sqlalchemy.sql import column sess = create_session() - c1, c2 = column('c1'), column('c2') - q1 = sess.query(c1, c2).filter(c1 == 'dog') + c1, c2 = column("c1"), column("c2") + q1 = sess.query(c1, c2).filter(c1 == "dog") q1 = q1.from_self().from_self() self.assert_compile( q1.order_by(c1), @@ -493,31 +642,31 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): "anon_1_anon_2_c2 FROM (SELECT anon_2.c1 AS anon_2_c1, anon_2.c2 " "AS anon_2_c2 " "FROM (SELECT c1, c2 WHERE c1 = :c1_1) AS " - "anon_2) AS anon_1 ORDER BY anon_1.anon_2_c1" + "anon_2) AS anon_1 ORDER BY anon_1.anon_2_c1", ) def test_anonymous_expression_union(self): from sqlalchemy.sql import column sess = create_session() - c1, c2 = column('c1'), column('c2') - q1 = sess.query(c1, c2).filter(c1 == 'dog') - q2 = sess.query(c1, c2).filter(c1 == 'cat') + c1, c2 = column("c1"), column("c2") + q1 = sess.query(c1, c2).filter(c1 == "dog") + q2 = sess.query(c1, c2).filter(c1 == "cat") q3 = q1.union(q2) self.assert_compile( q3.order_by(c1), "SELECT anon_1.c1 AS anon_1_c1, anon_1.c2 " "AS anon_1_c2 FROM (SELECT c1, c2 WHERE " "c1 = :c1_1 UNION SELECT c1, c2 " - "WHERE c1 = :c1_2) AS anon_1 ORDER BY anon_1.c1" + "WHERE c1 = :c1_2) AS anon_1 ORDER BY anon_1.c1", ) def test_table_anonymous_expression_from_self_twice(self): from sqlalchemy.sql import column sess = create_session() - t1 = table('t1', column('c1'), column('c2')) - q1 = sess.query(t1.c.c1, t1.c.c2).filter(t1.c.c1 == 'dog') + t1 = table("t1", column("c1"), column("c2")) + q1 = sess.query(t1.c.c1, t1.c.c2).filter(t1.c.c1 == "dog") q1 = q1.from_self().from_self() self.assert_compile( q1.order_by(t1.c.c1), @@ -527,21 +676,22 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): "FROM (SELECT anon_2.t1_c1 AS anon_2_t1_c1, " "anon_2.t1_c2 AS anon_2_t1_c2 FROM (SELECT t1.c1 AS t1_c1, t1.c2 " "AS t1_c2 FROM t1 WHERE t1.c1 = :c1_1) AS anon_2) AS anon_1 " - "ORDER BY anon_1.anon_2_t1_c1" + "ORDER BY anon_1.anon_2_t1_c1", ) def test_anonymous_labeled_expression(self): sess = create_session() - c1, c2 = column('c1'), column('c2') - q1 = sess.query(c1.label('foo'), c2.label('bar')).filter(c1 == 'dog') - q2 = sess.query(c1.label('foo'), c2.label('bar')).filter(c1 == 'cat') + c1, c2 = column("c1"), column("c2") + q1 = sess.query(c1.label("foo"), c2.label("bar")).filter(c1 == "dog") + q2 = sess.query(c1.label("foo"), c2.label("bar")).filter(c1 == "cat") q3 = q1.union(q2) self.assert_compile( 
q3.order_by(c1), "SELECT anon_1.foo AS anon_1_foo, anon_1.bar AS anon_1_bar FROM " "(SELECT c1 AS foo, c2 AS bar WHERE c1 = :c1_1 UNION SELECT " "c1 AS foo, c2 AS bar " - "WHERE c1 = :c1_2) AS anon_1 ORDER BY anon_1.foo") + "WHERE c1 = :c1_2) AS anon_1 ORDER BY anon_1.foo", + ) def test_anonymous_expression_plus_aliased_join(self): """test that the 'dont alias non-ORM' rule remains for other @@ -554,51 +704,58 @@ class ColumnAccessTest(QueryTest, AssertsCompiledSQL): sess = create_session() q1 = sess.query(User.id).filter(User.id > 5) q1 = q1.from_self() - q1 = q1.join(User.addresses, aliased=True).\ - order_by(User.id, Address.id, addresses.c.id) + q1 = q1.join(User.addresses, aliased=True).order_by( + User.id, Address.id, addresses.c.id + ) self.assert_compile( q1, "SELECT anon_1.users_id AS anon_1_users_id " "FROM (SELECT users.id AS users_id FROM users " "WHERE users.id > :id_1) AS anon_1 JOIN addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " - "ORDER BY anon_1.users_id, addresses_1.id, addresses.id" + "ORDER BY anon_1.users_id, addresses_1.id, addresses.id", ) class AddEntityEquivalenceTest(fixtures.MappedTest, AssertsCompiledSQL): - run_setup_mappers = 'once' + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, + "a", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(20)), - Column('bid', Integer, ForeignKey('b.id')) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("type", String(20)), + Column("bid", Integer, ForeignKey("b.id")), ) Table( - 'b', metadata, + "b", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(20)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("type", String(20)), ) Table( - 'c', metadata, - Column('id', Integer, ForeignKey('b.id'), primary_key=True), - Column('age', Integer)) + "c", + metadata, + Column("id", Integer, ForeignKey("b.id"), primary_key=True), + Column("age", Integer), + ) Table( - 'd', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('dede', Integer)) + "d", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("dede", Integer), + ) @classmethod def setup_classes(cls): @@ -617,14 +774,22 @@ class AddEntityEquivalenceTest(fixtures.MappedTest, AssertsCompiledSQL): pass mapper( - A, a, polymorphic_identity='a', polymorphic_on=a.c.type, - with_polymorphic=('*', None), properties={ - 'link': relation(B, uselist=False, backref='back')}) + A, + a, + polymorphic_identity="a", + polymorphic_on=a.c.type, + with_polymorphic=("*", None), + properties={"link": relation(B, uselist=False, backref="back")}, + ) mapper( - B, b, polymorphic_identity='b', polymorphic_on=b.c.type, - with_polymorphic=('*', None)) - mapper(C, c, inherits=B, polymorphic_identity='c') - mapper(D, d, inherits=A, polymorphic_identity='d') + B, + b, + polymorphic_identity="b", + polymorphic_on=b.c.type, + with_polymorphic=("*", None), + ) + mapper(C, c, inherits=B, polymorphic_identity="c") + mapper(D, d, inherits=A, polymorphic_identity="d") @classmethod def insert_data(cls): @@ -633,10 +798,12 @@ class AddEntityEquivalenceTest(fixtures.MappedTest, AssertsCompiledSQL): sess = create_session() sess.add_all( [ - B(name='b1'), - A(name='a1', link=C(name='c1', 
age=3)), - C(name='c2', age=6), - A(name='a2')]) + B(name="b1"), + A(name="a1", link=C(name="c1", age=3)), + C(name="c2", age=6), + A(name="a2"), + ] + ) sess.flush() def test_add_entity_equivalence(self): @@ -650,95 +817,125 @@ class AddEntityEquivalenceTest(fixtures.MappedTest, AssertsCompiledSQL): ]: eq_( q.all(), - [( - A(bid=2, id=1, name='a1', type='a'), - C(age=3, id=2, name='c1', type='c') - )] + [ + ( + A(bid=2, id=1, name="a1", type="a"), + C(age=3, id=2, name="c1", type="c"), + ) + ], ) for q in [ sess.query(B, A).join(B.back), sess.query(B).join(B.back).add_entity(A), - sess.query(B).add_entity(A).join(B.back) + sess.query(B).add_entity(A).join(B.back), ]: eq_( q.all(), - [( - C(age=3, id=2, name='c1', type='c'), - A(bid=2, id=1, name='a1', type='a') - )] + [ + ( + C(age=3, id=2, name="c1", type="c"), + A(bid=2, id=1, name="a1", type="a"), + ) + ], ) class InstancesTest(QueryTest, AssertsCompiledSQL): - def test_from_alias_one(self): - User, addresses, users = (self.classes.User, - self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) - query = users.select(users.c.id == 7).\ - union(users.select(users.c.id > 7)).alias('ulist').\ - outerjoin(addresses).\ - select( - use_labels=True, - order_by=[text('ulist.id'), addresses.c.id]) + query = ( + users.select(users.c.id == 7) + .union(users.select(users.c.id > 7)) + .alias("ulist") + .outerjoin(addresses) + .select( + use_labels=True, order_by=[text("ulist.id"), addresses.c.id] + ) + ) sess = create_session() q = sess.query(User) def go(): result = list( q.options( - contains_alias('ulist'), contains_eager('addresses')). - instances(query.execute())) + contains_alias("ulist"), contains_eager("addresses") + ).instances(query.execute()) + ) assert self.static.user_address_result == result + self.assert_sql_count(testing.db, go, 1) def test_from_alias_two(self): - User, addresses, users = (self.classes.User, - self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) - query = users.select(users.c.id == 7).\ - union(users.select(users.c.id > 7)).alias('ulist').\ - outerjoin(addresses). \ - select( - use_labels=True, - order_by=[text('ulist.id'), addresses.c.id]) + query = ( + users.select(users.c.id == 7) + .union(users.select(users.c.id > 7)) + .alias("ulist") + .outerjoin(addresses) + .select( + use_labels=True, order_by=[text("ulist.id"), addresses.c.id] + ) + ) sess = create_session() q = sess.query(User) def go(): - result = q.options( - contains_alias('ulist'), contains_eager('addresses')).\ - from_statement(query).all() + result = ( + q.options(contains_alias("ulist"), contains_eager("addresses")) + .from_statement(query) + .all() + ) assert self.static.user_address_result == result + self.assert_sql_count(testing.db, go, 1) def test_from_alias_three(self): - User, addresses, users = (self.classes.User, - self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) - query = users.select(users.c.id == 7).\ - union(users.select(users.c.id > 7)).alias('ulist').\ - outerjoin(addresses). 
\ - select( - use_labels=True, - order_by=[text('ulist.id'), addresses.c.id]) + query = ( + users.select(users.c.id == 7) + .union(users.select(users.c.id > 7)) + .alias("ulist") + .outerjoin(addresses) + .select( + use_labels=True, order_by=[text("ulist.id"), addresses.c.id] + ) + ) sess = create_session() # better way. use select_entity_from() def go(): - result = sess.query(User).select_entity_from(query).\ - options(contains_eager('addresses')).all() + result = ( + sess.query(User) + .select_entity_from(query) + .options(contains_eager("addresses")) + .all() + ) assert self.static.user_address_result == result + self.assert_sql_count(testing.db, go, 1) def test_from_alias_four(self): - User, addresses, users = (self.classes.User, - self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) sess = create_session() @@ -746,124 +943,164 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): # generated by select_entity_from() is wrapped within # the adapter created by contains_eager() adalias = addresses.alias() - query = users.select(users.c.id == 7).\ - union(users.select(users.c.id > 7)).\ - alias('ulist').outerjoin(adalias).\ - select(use_labels=True, order_by=[text('ulist.id'), adalias.c.id]) + query = ( + users.select(users.c.id == 7) + .union(users.select(users.c.id > 7)) + .alias("ulist") + .outerjoin(adalias) + .select(use_labels=True, order_by=[text("ulist.id"), adalias.c.id]) + ) def go(): - result = sess.query(User).select_entity_from(query).\ - options(contains_eager('addresses', alias=adalias)).all() + result = ( + sess.query(User) + .select_entity_from(query) + .options(contains_eager("addresses", alias=adalias)) + .all() + ) assert self.static.user_address_result == result + self.assert_sql_count(testing.db, go, 1) def test_contains_eager(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) sess = create_session() # test that contains_eager suppresses the normal outer join rendering - q = sess.query(User).outerjoin(User.addresses).\ - options(contains_eager(User.addresses)).\ - order_by(User.id, addresses.c.id) - self.assert_compile(q.with_labels().statement, - 'SELECT addresses.id AS addresses_id, ' - 'addresses.user_id AS addresses_user_id, ' - 'addresses.email_address AS ' - 'addresses_email_address, users.id AS ' - 'users_id, users.name AS users_name FROM ' - 'users LEFT OUTER JOIN addresses ON ' - 'users.id = addresses.user_id ORDER BY ' - 'users.id, addresses.id', - dialect=default.DefaultDialect()) + q = ( + sess.query(User) + .outerjoin(User.addresses) + .options(contains_eager(User.addresses)) + .order_by(User.id, addresses.c.id) + ) + self.assert_compile( + q.with_labels().statement, + "SELECT addresses.id AS addresses_id, " + "addresses.user_id AS addresses_user_id, " + "addresses.email_address AS " + "addresses_email_address, users.id AS " + "users_id, users.name AS users_name FROM " + "users LEFT OUTER JOIN addresses ON " + "users.id = addresses.user_id ORDER BY " + "users.id, addresses.id", + dialect=default.DefaultDialect(), + ) def go(): assert self.static.user_address_result == q.all() + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() adalias = addresses.alias() - q = sess.query(User).\ - select_entity_from(users.outerjoin(adalias)).\ - options(contains_eager(User.addresses, alias=adalias)).\ - order_by(User.id, 
adalias.c.id) + q = ( + sess.query(User) + .select_entity_from(users.outerjoin(adalias)) + .options(contains_eager(User.addresses, alias=adalias)) + .order_by(User.id, adalias.c.id) + ) def go(): eq_(self.static.user_address_result, q.all()) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() - selectquery = users.outerjoin(addresses). \ - select( - users.c.id < 10, use_labels=True, - order_by=[users.c.id, addresses.c.id]) + selectquery = users.outerjoin(addresses).select( + users.c.id < 10, + use_labels=True, + order_by=[users.c.id, addresses.c.id], + ) q = sess.query(User) def go(): result = list( - q.options(contains_eager('addresses')). - instances(selectquery.execute())) + q.options(contains_eager("addresses")).instances( + selectquery.execute() + ) + ) assert self.static.user_address_result[0:3] == result + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): result = list( - q.options(contains_eager(User.addresses)). - instances(selectquery.execute())) + q.options(contains_eager(User.addresses)).instances( + selectquery.execute() + ) + ) assert self.static.user_address_result[0:3] == result + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - result = q.options( - contains_eager('addresses')).from_statement(selectquery).all() + result = ( + q.options(contains_eager("addresses")) + .from_statement(selectquery) + .all() + ) assert self.static.user_address_result[0:3] == result + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_string_alias(self): - addresses, users, User = (self.tables.addresses, - self.tables.users, - self.classes.User) + addresses, users, User = ( + self.tables.addresses, + self.tables.users, + self.classes.User, + ) sess = create_session() q = sess.query(User) - adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias). \ - select(use_labels=True, order_by=[users.c.id, adalias.c.id]) + adalias = addresses.alias("adalias") + selectquery = users.outerjoin(adalias).select( + use_labels=True, order_by=[users.c.id, adalias.c.id] + ) # string alias name def go(): result = list( q.options( - contains_eager('addresses', alias="adalias")). - instances(selectquery.execute())) + contains_eager("addresses", alias="adalias") + ).instances(selectquery.execute()) + ) assert self.static.user_address_result == result + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_aliased_instances(self): - addresses, users, User = (self.tables.addresses, - self.tables.users, - self.classes.User) + addresses, users, User = ( + self.tables.addresses, + self.tables.users, + self.classes.User, + ) sess = create_session() q = sess.query(User) - adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).\ - select(use_labels=True, order_by=[users.c.id, adalias.c.id]) + adalias = addresses.alias("adalias") + selectquery = users.outerjoin(adalias).select( + use_labels=True, order_by=[users.c.id, adalias.c.id] + ) # expression.Alias object def go(): result = list( q.options( - contains_eager('addresses', alias=adalias)). 
- instances(selectquery.execute())) + contains_eager("addresses", alias=adalias) + ).instances(selectquery.execute()) + ) assert self.static.user_address_result == result + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_aliased(self): @@ -876,54 +1113,70 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): adalias = aliased(Address) def go(): - result = q.options( - contains_eager('addresses', alias=adalias) - ).outerjoin(adalias, User.addresses).\ - order_by(User.id, adalias.id) + result = ( + q.options(contains_eager("addresses", alias=adalias)) + .outerjoin(adalias, User.addresses) + .order_by(User.id, adalias.id) + ) assert self.static.user_address_result == result.all() + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_multi_string_alias(self): - orders, items, users, order_items, User = (self.tables.orders, - self.tables.items, - self.tables.users, - self.tables.order_items, - self.classes.User) + orders, items, users, order_items, User = ( + self.tables.orders, + self.tables.items, + self.tables.users, + self.tables.order_items, + self.classes.User, + ) sess = create_session() q = sess.query(User) - oalias = orders.alias('o1') - ialias = items.alias('i1') - query = users.outerjoin(oalias).outerjoin(order_items).\ - outerjoin(ialias).select(use_labels=True).\ - order_by(users.c.id, oalias.c.id, ialias.c.id) + oalias = orders.alias("o1") + ialias = items.alias("i1") + query = ( + users.outerjoin(oalias) + .outerjoin(order_items) + .outerjoin(ialias) + .select(use_labels=True) + .order_by(users.c.id, oalias.c.id, ialias.c.id) + ) # test using string alias with more than one level deep def go(): result = list( q.options( - contains_eager('orders', alias='o1'), - contains_eager('orders.items', alias='i1') - ).instances(query.execute())) + contains_eager("orders", alias="o1"), + contains_eager("orders.items", alias="i1"), + ).instances(query.execute()) + ) assert self.static.user_order_result == result + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_multi_alias(self): - orders, items, users, order_items, User = (self.tables.orders, - self.tables.items, - self.tables.users, - self.tables.order_items, - self.classes.User) + orders, items, users, order_items, User = ( + self.tables.orders, + self.tables.items, + self.tables.users, + self.tables.order_items, + self.classes.User, + ) sess = create_session() q = sess.query(User) - oalias = orders.alias('o1') - ialias = items.alias('i1') - query = users.outerjoin(oalias).outerjoin(order_items).\ - outerjoin(ialias).select(use_labels=True).\ - order_by(users.c.id, oalias.c.id, ialias.c.id) + oalias = orders.alias("o1") + ialias = items.alias("i1") + query = ( + users.outerjoin(oalias) + .outerjoin(order_items) + .outerjoin(ialias) + .select(use_labels=True) + .order_by(users.c.id, oalias.c.id, ialias.c.id) + ) # test using Alias with more than one level deep @@ -935,15 +1188,20 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): def go(): result = list( q.options( - contains_eager('orders', alias=oalias), - contains_eager('orders.items', alias=ialias)). 
- instances(query.execute())) + contains_eager("orders", alias=oalias), + contains_eager("orders.items", alias=ialias), + ).instances(query.execute()) + ) assert self.static.user_order_result == result + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_multi_aliased(self): Item, User, Order = ( - self.classes.Item, self.classes.User, self.classes.Order) + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = create_session() q = sess.query(User) @@ -953,25 +1211,35 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): ialias = aliased(Item) def go(): - result = q.options( - contains_eager(User.orders, alias=oalias), - contains_eager(User.orders, Order.items, alias=ialias)).\ - outerjoin(oalias, User.orders).\ - outerjoin(ialias, oalias.items).\ - order_by(User.id, oalias.id, ialias.id) + result = ( + q.options( + contains_eager(User.orders, alias=oalias), + contains_eager(User.orders, Order.items, alias=ialias), + ) + .outerjoin(oalias, User.orders) + .outerjoin(ialias, oalias.items) + .order_by(User.id, oalias.id, ialias.id) + ) assert self.static.user_order_result == result.all() + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_chaining(self): """test that contains_eager() 'chains' by default.""" - Dingaling, User, Address = (self.classes.Dingaling, - self.classes.User, - self.classes.Address) + Dingaling, User, Address = ( + self.classes.Dingaling, + self.classes.User, + self.classes.Address, + ) sess = create_session() - q = sess.query(User).join(User.addresses).join(Address.dingaling).\ - options(contains_eager(User.addresses, Address.dingaling),) + q = ( + sess.query(User) + .join(User.addresses) + .join(Address.dingaling) + .options(contains_eager(User.addresses, Address.dingaling)) + ) def go(): eq_( @@ -980,32 +1248,49 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): # have a Dingaling here due to using the inner # join for the eager load [ - User(name='ed', addresses=[ - Address(email_address='ed@wood.com', - dingaling=Dingaling(data='ding 1/2')), - ]), - User(name='fred', addresses=[ - Address(email_address='fred@fred.com', - dingaling=Dingaling(data='ding 2/5')) - ]) - ] + User( + name="ed", + addresses=[ + Address( + email_address="ed@wood.com", + dingaling=Dingaling(data="ding 1/2"), + ) + ], + ), + User( + name="fred", + addresses=[ + Address( + email_address="fred@fred.com", + dingaling=Dingaling(data="ding 2/5"), + ) + ], + ), + ], ) + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_chaining_aliased_endpoint(self): """test that contains_eager() 'chains' by default and supports an alias at the end.""" - Dingaling, User, Address = (self.classes.Dingaling, - self.classes.User, - self.classes.Address) + Dingaling, User, Address = ( + self.classes.Dingaling, + self.classes.User, + self.classes.Address, + ) sess = create_session() da = aliased(Dingaling, name="foob") - q = sess.query(User).join(User.addresses).\ - join(da, Address.dingaling).\ - options( - contains_eager(User.addresses, Address.dingaling, alias=da),) + q = ( + sess.query(User) + .join(User.addresses) + .join(da, Address.dingaling) + .options( + contains_eager(User.addresses, Address.dingaling, alias=da) + ) + ) def go(): eq_( @@ -1014,22 +1299,35 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): # have a Dingaling here due to using the inner # join for the eager load [ - User(name='ed', addresses=[ - Address(email_address='ed@wood.com', - dingaling=Dingaling(data='ding 1/2')), - ]), - User(name='fred', addresses=[ - 
Address(email_address='fred@fred.com', - dingaling=Dingaling(data='ding 2/5')) - ]) - ] + User( + name="ed", + addresses=[ + Address( + email_address="ed@wood.com", + dingaling=Dingaling(data="ding 1/2"), + ) + ], + ), + User( + name="fred", + addresses=[ + Address( + email_address="fred@fred.com", + dingaling=Dingaling(data="ding 2/5"), + ) + ], + ), + ], ) + self.assert_sql_count(testing.db, go, 1) def test_mixed_eager_contains_with_limit(self): - Order, User, Address = (self.classes.Order, - self.classes.User, - self.classes.Address) + Order, User, Address = ( + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() @@ -1043,25 +1341,46 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): # applies context.adapter to result rows. This was # [ticket:1180]. - result = q.outerjoin(User.orders).options( - joinedload(User.addresses), contains_eager(User.orders)). \ - order_by(User.id, Order.id).offset(1).limit(2).all() + result = ( + q.outerjoin(User.orders) + .options( + joinedload(User.addresses), contains_eager(User.orders) + ) + .order_by(User.id, Order.id) + .offset(1) + .limit(2) + .all() + ) eq_( - result, [ + result, + [ User( id=7, addresses=[ Address( - email_address='jack@bean.com', - user_id=7, id=1)], - name='jack', + email_address="jack@bean.com", user_id=7, id=1 + ) + ], + name="jack", orders=[ Order( - address_id=1, user_id=7, description='order 3', - isopen=1, id=3), + address_id=1, + user_id=7, + description="order 3", + isopen=1, + id=3, + ), Order( - address_id=None, user_id=7, - description='order 5', isopen=0, id=5)])]) + address_id=None, + user_id=7, + description="order 5", + isopen=0, + id=5, + ), + ], + ) + ], + ) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() @@ -1072,11 +1391,17 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): # are applied by the eager loader oalias = aliased(Order) - result = q.outerjoin(oalias, User.orders).options( - joinedload(User.addresses), - contains_eager(User.orders, alias=oalias)). 
\ - order_by(User.id, oalias.id).\ - offset(1).limit(2).all() + result = ( + q.outerjoin(oalias, User.orders) + .options( + joinedload(User.addresses), + contains_eager(User.orders, alias=oalias), + ) + .order_by(User.id, oalias.id) + .offset(1) + .limit(2) + .all() + ) eq_( result, [ @@ -1084,27 +1409,42 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): id=7, addresses=[ Address( - email_address='jack@bean.com', - user_id=7, id=1)], - name='jack', + email_address="jack@bean.com", user_id=7, id=1 + ) + ], + name="jack", orders=[ Order( - address_id=1, user_id=7, description='order 3', - isopen=1, id=3), + address_id=1, + user_id=7, + description="order 3", + isopen=1, + id=3, + ), Order( - address_id=None, user_id=7, - description='order 5', isopen=0, id=5)])]) + address_id=None, + user_id=7, + description="order 5", + isopen=0, + id=5, + ), + ], + ) + ], + ) self.assert_sql_count(testing.db, go, 1) class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_values(self): - Address, users, User = (self.classes.Address, - self.tables.users, - self.classes.User) + Address, users, User = ( + self.classes.Address, + self.tables.users, + self.classes.User, + ) sess = create_session() @@ -1113,56 +1453,91 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User) q2 = q.select_entity_from(sel).values(User.name) - eq_(list(q2), [('jack',), ('ed',)]) + eq_(list(q2), [("jack",), ("ed",)]) q = sess.query(User) - q2 = q.order_by(User.id).\ - values(User.name, User.name + " " + cast(User.id, String(50))) + q2 = q.order_by(User.id).values( + User.name, User.name + " " + cast(User.id, String(50)) + ) eq_( list(q2), [ - ('jack', 'jack 7'), ('ed', 'ed 8'), - ('fred', 'fred 9'), ('chuck', 'chuck 10')] + ("jack", "jack 7"), + ("ed", "ed 8"), + ("fred", "fred 9"), + ("chuck", "chuck 10"), + ], ) - q2 = q.join('addresses').filter(User.name.like('%e%')).\ - order_by(User.id, Address.id).\ - values(User.name, Address.email_address) + q2 = ( + q.join("addresses") + .filter(User.name.like("%e%")) + .order_by(User.id, Address.id) + .values(User.name, Address.email_address) + ) eq_( list(q2), [ - ('ed', 'ed@wood.com'), ('ed', 'ed@bettyboop.com'), - ('ed', 'ed@lala.com'), ('fred', 'fred@fred.com')]) + ("ed", "ed@wood.com"), + ("ed", "ed@bettyboop.com"), + ("ed", "ed@lala.com"), + ("fred", "fred@fred.com"), + ], + ) - q2 = q.join('addresses').filter(User.name.like('%e%')).\ - order_by(desc(Address.email_address)).\ - slice(1, 3).values(User.name, Address.email_address) - eq_(list(q2), [('ed', 'ed@wood.com'), ('ed', 'ed@lala.com')]) + q2 = ( + q.join("addresses") + .filter(User.name.like("%e%")) + .order_by(desc(Address.email_address)) + .slice(1, 3) + .values(User.name, Address.email_address) + ) + eq_(list(q2), [("ed", "ed@wood.com"), ("ed", "ed@lala.com")]) adalias = aliased(Address) - q2 = q.join(adalias, 'addresses'). 
\ - filter(User.name.like('%e%')).order_by(adalias.email_address).\ - values(User.name, adalias.email_address) - eq_(list(q2), [('ed', 'ed@bettyboop.com'), ('ed', 'ed@lala.com'), - ('ed', 'ed@wood.com'), ('fred', 'fred@fred.com')]) + q2 = ( + q.join(adalias, "addresses") + .filter(User.name.like("%e%")) + .order_by(adalias.email_address) + .values(User.name, adalias.email_address) + ) + eq_( + list(q2), + [ + ("ed", "ed@bettyboop.com"), + ("ed", "ed@lala.com"), + ("ed", "ed@wood.com"), + ("fred", "fred@fred.com"), + ], + ) q2 = q.values(func.count(User.name)) assert next(q2) == (4,) - q2 = q.select_entity_from(sel).filter(User.id == 8). \ - values(User.name, sel.c.name, User.name) - eq_(list(q2), [('ed', 'ed', 'ed')]) + q2 = ( + q.select_entity_from(sel) + .filter(User.id == 8) + .values(User.name, sel.c.name, User.name) + ) + eq_(list(q2), [("ed", "ed", "ed")]) # User.xxx is aliased against "sel", so this query returns nothing - q2 = q.select_entity_from(sel).filter(User.id == 8).\ - filter(User.id > sel.c.id).values(User.name, sel.c.name, User.name) + q2 = ( + q.select_entity_from(sel) + .filter(User.id == 8) + .filter(User.id > sel.c.id) + .values(User.name, sel.c.name, User.name) + ) eq_(list(q2), []) # whereas this uses users.c.xxx, which is not aliased and creates a new join - q2 = q.select_entity_from(sel).filter(users.c.id == 8).\ - filter(users.c.id > sel.c.id). \ - values(users.c.name, sel.c.name, User.name) - eq_(list(q2), [('ed', 'jack', 'jack')]) + q2 = ( + q.select_entity_from(sel) + .filter(users.c.id == 8) + .filter(users.c.id > sel.c.id) + .values(users.c.name, sel.c.name, User.name) + ) + eq_(list(q2), [("ed", "jack", "jack")]) def test_alias_naming(self): User = self.classes.User @@ -1174,10 +1549,10 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): self.assert_compile( q, "SELECT foobar.id AS foobar_id, " - "foobar.name AS foobar_name FROM users AS foobar" + "foobar.name AS foobar_name FROM users AS foobar", ) - @testing.fails_on('mssql', 'FIXME: unknown') + @testing.fails_on("mssql", "FIXME: unknown") def test_values_specific_order_by(self): users, User = self.tables.users, self.classes.User @@ -1188,27 +1563,40 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User) u2 = aliased(User) - q2 = q.select_entity_from(sel).filter(u2.id > 1).\ - order_by(User.id, sel.c.id, u2.id).\ - values(User.name, sel.c.name, u2.name) + q2 = ( + q.select_entity_from(sel) + .filter(u2.id > 1) + .order_by(User.id, sel.c.id, u2.id) + .values(User.name, sel.c.name, u2.name) + ) eq_( list(q2), [ - ('jack', 'jack', 'jack'), ('jack', 'jack', 'ed'), - ('jack', 'jack', 'fred'), ('jack', 'jack', 'chuck'), - ('ed', 'ed', 'jack'), ('ed', 'ed', 'ed'), - ('ed', 'ed', 'fred'), ('ed', 'ed', 'chuck')]) - - @testing.fails_on('mssql', 'FIXME: unknown') - @testing.fails_on('oracle', - "Oracle doesn't support boolean expressions as " - "columns") - @testing.fails_on('postgresql+pg8000', - "pg8000 parses the SQL itself before passing on " - "to PG, doesn't parse this") - @testing.fails_on('postgresql+zxjdbc', - "zxjdbc parses the SQL itself before passing on " - "to PG, doesn't parse this") + ("jack", "jack", "jack"), + ("jack", "jack", "ed"), + ("jack", "jack", "fred"), + ("jack", "jack", "chuck"), + ("ed", "ed", "jack"), + ("ed", "ed", "ed"), + ("ed", "ed", "fred"), + ("ed", "ed", "chuck"), + ], + ) + + @testing.fails_on("mssql", "FIXME: unknown") + @testing.fails_on( + "oracle", "Oracle doesn't support boolean expressions as " 
"columns" + ) + @testing.fails_on( + "postgresql+pg8000", + "pg8000 parses the SQL itself before passing on " + "to PG, doesn't parse this", + ) + @testing.fails_on( + "postgresql+zxjdbc", + "zxjdbc parses the SQL itself before passing on " + "to PG, doesn't parse this", + ) @testing.fails_on("firebird", "unknown") def test_values_with_boolean_selects(self): """Tests a values clause that works with select boolean @@ -1219,13 +1607,16 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess = create_session() q = sess.query(User) - q2 = q.group_by(User.name.like('%j%')).\ - order_by(desc(User.name.like('%j%'))).\ - values(User.name.like('%j%'), func.count(User.name.like('%j%'))) + q2 = ( + q.group_by(User.name.like("%j%")) + .order_by(desc(User.name.like("%j%"))) + .values(User.name.like("%j%"), func.count(User.name.like("%j%"))) + ) eq_(list(q2), [(True, 1), (False, 3)]) - q2 = q.order_by(desc(User.name.like('%j%'))). \ - values(User.name.like('%j%')) + q2 = q.order_by(desc(User.name.like("%j%"))).values( + User.name.like("%j%") + ) eq_(list(q2), [(True,), (False,), (False,), (False,)]) def test_correlated_subquery(self): @@ -1233,134 +1624,186 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): out those entities to the outermost query.""" Address, users, User = ( - self.classes.Address, self.tables.users, self.classes.User) + self.classes.Address, + self.tables.users, + self.classes.User, + ) sess = create_session() - subq = select([func.count()]).where(User.id == Address.user_id).\ - correlate(users).label('count') + subq = ( + select([func.count()]) + .where(User.id == Address.user_id) + .correlate(users) + .label("count") + ) # we don't want Address to be outside of the subquery here eq_( list(sess.query(User, subq)[0:3]), [ - (User(id=7, name='jack'), 1), (User(id=8, name='ed'), 3), - (User(id=9, name='fred'), 1)]) + (User(id=7, name="jack"), 1), + (User(id=8, name="ed"), 3), + (User(id=9, name="fred"), 1), + ], + ) # same thing without the correlate, as it should # not be needed - subq = select([func.count()]).where(User.id == Address.user_id).\ - label('count') + subq = ( + select([func.count()]) + .where(User.id == Address.user_id) + .label("count") + ) # we don't want Address to be outside of the subquery here eq_( list(sess.query(User, subq)[0:3]), [ - (User(id=7, name='jack'), 1), (User(id=8, name='ed'), 3), - (User(id=9, name='fred'), 1)]) + (User(id=7, name="jack"), 1), + (User(id=8, name="ed"), 3), + (User(id=9, name="fred"), 1), + ], + ) def test_column_queries(self): - Address, users, User = (self.classes.Address, - self.tables.users, - self.classes.User) + Address, users, User = ( + self.classes.Address, + self.tables.users, + self.classes.User, + ) sess = create_session() eq_( sess.query(User.name).all(), - [('jack',), ('ed',), ('fred',), ('chuck',)]) + [("jack",), ("ed",), ("fred",), ("chuck",)], + ) sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User.name) q2 = q.select_entity_from(sel).all() - eq_(list(q2), [('jack',), ('ed',)]) + eq_(list(q2), [("jack",), ("ed",)]) eq_( - sess.query(User.name, Address.email_address). 
- filter(User.id == Address.user_id).all(), + sess.query(User.name, Address.email_address) + .filter(User.id == Address.user_id) + .all(), [ - ('jack', 'jack@bean.com'), ('ed', 'ed@wood.com'), - ('ed', 'ed@bettyboop.com'), ('ed', 'ed@lala.com'), - ('fred', 'fred@fred.com')]) + ("jack", "jack@bean.com"), + ("ed", "ed@wood.com"), + ("ed", "ed@bettyboop.com"), + ("ed", "ed@lala.com"), + ("fred", "fred@fred.com"), + ], + ) eq_( - sess.query(User.name, func.count(Address.email_address)). - outerjoin(User.addresses).group_by(User.id, User.name). - order_by(User.id).all(), - [('jack', 1), ('ed', 3), ('fred', 1), ('chuck', 0)]) + sess.query(User.name, func.count(Address.email_address)) + .outerjoin(User.addresses) + .group_by(User.id, User.name) + .order_by(User.id) + .all(), + [("jack", 1), ("ed", 3), ("fred", 1), ("chuck", 0)], + ) eq_( - sess.query(User, func.count(Address.email_address)). - outerjoin(User.addresses).group_by(User). - order_by(User.id).all(), + sess.query(User, func.count(Address.email_address)) + .outerjoin(User.addresses) + .group_by(User) + .order_by(User.id) + .all(), [ - (User(name='jack', id=7), 1), (User(name='ed', id=8), 3), - (User(name='fred', id=9), 1), (User(name='chuck', id=10), 0)]) + (User(name="jack", id=7), 1), + (User(name="ed", id=8), 3), + (User(name="fred", id=9), 1), + (User(name="chuck", id=10), 0), + ], + ) eq_( - sess.query(func.count(Address.email_address), User). - outerjoin(User.addresses).group_by(User). - order_by(User.id).all(), + sess.query(func.count(Address.email_address), User) + .outerjoin(User.addresses) + .group_by(User) + .order_by(User.id) + .all(), [ - (1, User(name='jack', id=7)), (3, User(name='ed', id=8)), - (1, User(name='fred', id=9)), (0, User(name='chuck', id=10))]) + (1, User(name="jack", id=7)), + (3, User(name="ed", id=8)), + (1, User(name="fred", id=9)), + (0, User(name="chuck", id=10)), + ], + ) adalias = aliased(Address) eq_( - sess.query(User, func.count(adalias.email_address)). - outerjoin(adalias, 'addresses').group_by(User). - order_by(User.id).all(), + sess.query(User, func.count(adalias.email_address)) + .outerjoin(adalias, "addresses") + .group_by(User) + .order_by(User.id) + .all(), [ - (User(name='jack', id=7), 1), (User(name='ed', id=8), 3), - (User(name='fred', id=9), 1), (User(name='chuck', id=10), 0)]) + (User(name="jack", id=7), 1), + (User(name="ed", id=8), 3), + (User(name="fred", id=9), 1), + (User(name="chuck", id=10), 0), + ], + ) eq_( - sess.query(func.count(adalias.email_address), User). - outerjoin(adalias, User.addresses).group_by(User). - order_by(User.id).all(), + sess.query(func.count(adalias.email_address), User) + .outerjoin(adalias, User.addresses) + .group_by(User) + .order_by(User.id) + .all(), [ - (1, User(name='jack', id=7)), (3, User(name='ed', id=8)), - (1, User(name='fred', id=9)), (0, User(name='chuck', id=10))] + (1, User(name="jack", id=7)), + (3, User(name="ed", id=8)), + (1, User(name="fred", id=9)), + (0, User(name="chuck", id=10)), + ], ) # select from aliasing + explicit aliasing eq_( - sess.query(User, adalias.email_address, adalias.id). - outerjoin(adalias, User.addresses). - from_self(User, adalias.email_address). 
- order_by(User.id, adalias.id).all(), + sess.query(User, adalias.email_address, adalias.id) + .outerjoin(adalias, User.addresses) + .from_self(User, adalias.email_address) + .order_by(User.id, adalias.id) + .all(), [ - (User(name='jack', id=7), 'jack@bean.com'), - (User(name='ed', id=8), 'ed@wood.com'), - (User(name='ed', id=8), 'ed@bettyboop.com'), - (User(name='ed', id=8), 'ed@lala.com'), - (User(name='fred', id=9), 'fred@fred.com'), - (User(name='chuck', id=10), None) - ] + (User(name="jack", id=7), "jack@bean.com"), + (User(name="ed", id=8), "ed@wood.com"), + (User(name="ed", id=8), "ed@bettyboop.com"), + (User(name="ed", id=8), "ed@lala.com"), + (User(name="fred", id=9), "fred@fred.com"), + (User(name="chuck", id=10), None), + ], ) # anon + select from aliasing eq_( - sess.query(User).join(User.addresses, aliased=True). - filter(Address.email_address.like('%ed%')). - from_self().all(), - [ - User(name='ed', id=8), - User(name='fred', id=9), - ] + sess.query(User) + .join(User.addresses, aliased=True) + .filter(Address.email_address.like("%ed%")) + .from_self() + .all(), + [User(name="ed", id=8), User(name="fred", id=9)], ) # test eager aliasing, with/without select_entity_from aliasing for q in [ - sess.query(User, adalias.email_address). - outerjoin(adalias, User.addresses). - options(joinedload(User.addresses)). - order_by(User.id, adalias.id).limit(10), - sess.query(User, adalias.email_address, adalias.id). - outerjoin(adalias, User.addresses). - from_self(User, adalias.email_address). - options(joinedload(User.addresses)). - order_by(User.id, adalias.id).limit(10), + sess.query(User, adalias.email_address) + .outerjoin(adalias, User.addresses) + .options(joinedload(User.addresses)) + .order_by(User.id, adalias.id) + .limit(10), + sess.query(User, adalias.email_address, adalias.id) + .outerjoin(adalias, User.addresses) + .from_self(User, adalias.email_address) + .options(joinedload(User.addresses)) + .order_by(User.id, adalias.id) + .limit(10), ]: eq_( q.all(), @@ -1369,62 +1812,105 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): User( addresses=[ Address( - user_id=7, email_address='jack@bean.com', - id=1)], - name='jack', id=7), - 'jack@bean.com'), + user_id=7, + email_address="jack@bean.com", + id=1, + ) + ], + name="jack", + id=7, + ), + "jack@bean.com", + ), ( User( addresses=[ Address( - user_id=8, email_address='ed@wood.com', - id=2), + user_id=8, + email_address="ed@wood.com", + id=2, + ), Address( user_id=8, - email_address='ed@bettyboop.com', id=3), + email_address="ed@bettyboop.com", + id=3, + ), Address( - user_id=8, email_address='ed@lala.com', - id=4)], - name='ed', id=8), - 'ed@wood.com'), + user_id=8, + email_address="ed@lala.com", + id=4, + ), + ], + name="ed", + id=8, + ), + "ed@wood.com", + ), ( User( addresses=[ Address( - user_id=8, email_address='ed@wood.com', - id=2), + user_id=8, + email_address="ed@wood.com", + id=2, + ), Address( user_id=8, - email_address='ed@bettyboop.com', id=3), + email_address="ed@bettyboop.com", + id=3, + ), Address( - user_id=8, email_address='ed@lala.com', - id=4)], - name='ed', id=8), - 'ed@bettyboop.com'), + user_id=8, + email_address="ed@lala.com", + id=4, + ), + ], + name="ed", + id=8, + ), + "ed@bettyboop.com", + ), ( User( addresses=[ Address( - user_id=8, email_address='ed@wood.com', - id=2), + user_id=8, + email_address="ed@wood.com", + id=2, + ), Address( user_id=8, - email_address='ed@bettyboop.com', id=3), + email_address="ed@bettyboop.com", + id=3, + ), Address( - user_id=8, email_address='ed@lala.com', - 
id=4)], - name='ed', id=8), - 'ed@lala.com'), + user_id=8, + email_address="ed@lala.com", + id=4, + ), + ], + name="ed", + id=8, + ), + "ed@lala.com", + ), ( User( addresses=[ Address( - user_id=9, email_address='fred@fred.com', - id=5)], - name='fred', id=9), - 'fred@fred.com'), - - (User(addresses=[], name='chuck', id=10), None)]) + user_id=9, + email_address="fred@fred.com", + id=5, + ) + ], + name="fred", + id=9, + ), + "fred@fred.com", + ), + (User(addresses=[], name="chuck", id=10), None), + ], + ) def test_column_from_limited_joinedload(self): User = self.classes.User @@ -1432,9 +1918,15 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess = create_session() def go(): - results = sess.query(User).limit(1).\ - options(joinedload('addresses')).add_column(User.name).all() - eq_(results, [(User(name='jack'), 'jack')]) + results = ( + sess.query(User) + .limit(1) + .options(joinedload("addresses")) + .add_column(User.name) + .all() + ) + eq_(results, [(User(name="jack"), "jack")]) + self.assert_sql_count(testing.db, go, 1) @testing.fails_on("firebird", "unknown") @@ -1445,30 +1937,44 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): oalias = aliased(Order) for q in [ - sess.query(Order, oalias).filter(Order.user_id == oalias.user_id). - filter(Order.user_id == 7). - filter(Order.id > oalias.id).order_by(Order.id, oalias.id), - sess.query(Order, oalias).from_self(). - filter(Order.user_id == oalias.user_id).filter(Order.user_id == 7). - filter(Order.id > oalias.id).order_by(Order.id, oalias.id), - + sess.query(Order, oalias) + .filter(Order.user_id == oalias.user_id) + .filter(Order.user_id == 7) + .filter(Order.id > oalias.id) + .order_by(Order.id, oalias.id), + sess.query(Order, oalias) + .from_self() + .filter(Order.user_id == oalias.user_id) + .filter(Order.user_id == 7) + .filter(Order.id > oalias.id) + .order_by(Order.id, oalias.id), # same thing, but reversed. - sess.query(oalias, Order).from_self(). - filter(oalias.user_id == Order.user_id). - filter(oalias.user_id == 7).filter(Order.id < oalias.id). - order_by(oalias.id, Order.id), - + sess.query(oalias, Order) + .from_self() + .filter(oalias.user_id == Order.user_id) + .filter(oalias.user_id == 7) + .filter(Order.id < oalias.id) + .order_by(oalias.id, Order.id), # here we go....two layers of aliasing - sess.query(Order, oalias).filter(Order.user_id == oalias.user_id). - filter(Order.user_id == 7).filter(Order.id > oalias.id). - from_self().order_by(Order.id, oalias.id). - limit(10).options(joinedload(Order.items)), - + sess.query(Order, oalias) + .filter(Order.user_id == oalias.user_id) + .filter(Order.user_id == 7) + .filter(Order.id > oalias.id) + .from_self() + .order_by(Order.id, oalias.id) + .limit(10) + .options(joinedload(Order.items)), # gratuitous four layers - sess.query(Order, oalias).filter(Order.user_id == oalias.user_id). - filter(Order.user_id == 7).filter(Order.id > oalias.id). - from_self().from_self().from_self().order_by(Order.id, oalias.id). 
- limit(10).options(joinedload(Order.items)), + sess.query(Order, oalias) + .filter(Order.user_id == oalias.user_id) + .filter(Order.user_id == 7) + .filter(Order.id > oalias.id) + .from_self() + .from_self() + .from_self() + .order_by(Order.id, oalias.id) + .limit(10) + .options(joinedload(Order.items)), ]: eq_( @@ -1476,34 +1982,64 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): [ ( Order( - address_id=1, description='order 3', isopen=1, - user_id=7, id=3), + address_id=1, + description="order 3", + isopen=1, + user_id=7, + id=3, + ), Order( - address_id=1, description='order 1', isopen=0, - user_id=7, id=1)), + address_id=1, + description="order 1", + isopen=0, + user_id=7, + id=1, + ), + ), ( Order( - address_id=None, description='order 5', isopen=0, - user_id=7, id=5), + address_id=None, + description="order 5", + isopen=0, + user_id=7, + id=5, + ), Order( - address_id=1, description='order 1', isopen=0, - user_id=7, id=1)), + address_id=1, + description="order 1", + isopen=0, + user_id=7, + id=1, + ), + ), ( Order( - address_id=None, description='order 5', isopen=0, - user_id=7, id=5), + address_id=None, + description="order 5", + isopen=0, + user_id=7, + id=5, + ), Order( - address_id=1, description='order 3', isopen=1, - user_id=7, id=3)) - ] + address_id=1, + description="order 3", + isopen=1, + user_id=7, + id=3, + ), + ), + ], ) # ensure column expressions are taken from inside the subquery, not # restated at the top - q = sess.query( - Order.id, Order.description, - literal_column("'q'").label('foo')).\ - filter(Order.description == 'order 3').from_self() + q = ( + sess.query( + Order.id, Order.description, literal_column("'q'").label("foo") + ) + .filter(Order.description == "order 3") + .from_self() + ) self.assert_compile( q, "SELECT anon_1.orders_id AS " @@ -1514,97 +2050,122 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): "orders.description AS orders_description, " "'q' AS foo FROM orders WHERE " "orders.description = :description_1) AS " - "anon_1") - eq_( - q.all(), - [(3, 'order 3', 'q')] + "anon_1", ) + eq_(q.all(), [(3, "order 3", "q")]) def test_multi_mappers(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) test_session = create_session() (user7, user8, user9, user10) = test_session.query(User).all() - (address1, address2, address3, address4, address5) = \ - test_session.query(Address).all() + ( + address1, + address2, + address3, + address4, + address5, + ) = test_session.query(Address).all() - expected = [(user7, address1), - (user8, address2), - (user8, address3), - (user8, address4), - (user9, address5), - (user10, None)] + expected = [ + (user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None), + ] sess = create_session() - selectquery = users.outerjoin(addresses). 
\ - select(use_labels=True, order_by=[users.c.id, addresses.c.id]) + selectquery = users.outerjoin(addresses).select( + use_labels=True, order_by=[users.c.id, addresses.c.id] + ) eq_( list(sess.query(User, Address).instances(selectquery.execute())), - expected) + expected, + ) sess.expunge_all() for address_entity in (Address, aliased(Address)): - q = sess.query(User).add_entity(address_entity).\ - outerjoin(address_entity, 'addresses').\ - order_by(User.id, address_entity.id) + q = ( + sess.query(User) + .add_entity(address_entity) + .outerjoin(address_entity, "addresses") + .order_by(User.id, address_entity.id) + ) eq_(q.all(), expected) sess.expunge_all() q = sess.query(User).add_entity(address_entity) - q = q.join(address_entity, 'addresses') - q = q.filter_by(email_address='ed@bettyboop.com') + q = q.join(address_entity, "addresses") + q = q.filter_by(email_address="ed@bettyboop.com") eq_(q.all(), [(user8, address3)]) sess.expunge_all() - q = sess.query(User, address_entity). \ - join(address_entity, 'addresses'). \ - filter_by(email_address='ed@bettyboop.com') + q = ( + sess.query(User, address_entity) + .join(address_entity, "addresses") + .filter_by(email_address="ed@bettyboop.com") + ) eq_(q.all(), [(user8, address3)]) sess.expunge_all() - q = sess.query(User, address_entity). \ - join(address_entity, 'addresses').\ - options(joinedload('addresses')).\ - filter_by(email_address='ed@bettyboop.com') + q = ( + sess.query(User, address_entity) + .join(address_entity, "addresses") + .options(joinedload("addresses")) + .filter_by(email_address="ed@bettyboop.com") + ) eq_(list(util.OrderedSet(q.all())), [(user8, address3)]) sess.expunge_all() def test_aliased_multi_mappers(self): - User, addresses, users, Address = (self.classes.User, - self.tables.addresses, - self.tables.users, - self.classes.Address) + User, addresses, users, Address = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + self.classes.Address, + ) sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() - (address1, address2, address3, address4, address5) = \ - sess.query(Address).all() + (address1, address2, address3, address4, address5) = sess.query( + Address + ).all() - expected = [(user7, address1), - (user8, address2), - (user8, address3), - (user8, address4), - (user9, address5), - (user10, None)] + expected = [ + (user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None), + ] q = sess.query(User) - adalias = addresses.alias('adalias') - q = q.add_entity(Address, alias=adalias). \ - select_entity_from(users.outerjoin(adalias)) + adalias = addresses.alias("adalias") + q = q.add_entity(Address, alias=adalias).select_entity_from( + users.outerjoin(adalias) + ) result = q.order_by(User.id, adalias.c.id).all() assert result == expected sess.expunge_all() q = sess.query(User).add_entity(Address, alias=adalias) - result = q.select_entity_from(users.outerjoin(adalias)). \ - filter(adalias.c.email_address == 'ed@bettyboop.com').all() + result = ( + q.select_entity_from(users.outerjoin(adalias)) + .filter(adalias.c.email_address == "ed@bettyboop.com") + .all() + ) assert result == [(user8, address3)] def test_with_entities(self): @@ -1615,15 +2176,17 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): q = sess.query(User).filter(User.id == 7).order_by(User.name) self.assert_compile( - q.with_entities(User.id, Address). 
- filter(Address.user_id == User.id), - 'SELECT users.id AS users_id, addresses.id ' - 'AS addresses_id, addresses.user_id AS ' - 'addresses_user_id, addresses.email_address' - ' AS addresses_email_address FROM users, ' - 'addresses WHERE users.id = :id_1 AND ' - 'addresses.user_id = users.id ORDER BY ' - 'users.name') + q.with_entities(User.id, Address).filter( + Address.user_id == User.id + ), + "SELECT users.id AS users_id, addresses.id " + "AS addresses_id, addresses.user_id AS " + "addresses_user_id, addresses.email_address" + " AS addresses_email_address FROM users, " + "addresses WHERE users.id = :id_1 AND " + "addresses.user_id = users.id ORDER BY " + "users.name", + ) def test_multi_columns(self): users, User = self.tables.users, self.classes.User @@ -1637,7 +2200,8 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess.expunge_all() assert_raises( - sa_exc.InvalidRequestError, sess.query(User).add_column, object()) + sa_exc.InvalidRequestError, sess.query(User).add_column, object() + ) def test_add_multi_columns(self): """test that add_column accepts a FROM clause.""" @@ -1648,52 +2212,62 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): eq_( sess.query(User.id).add_column(users).all(), - [(7, 7, 'jack'), (8, 8, 'ed'), (9, 9, 'fred'), (10, 10, 'chuck')] + [(7, 7, "jack"), (8, 8, "ed"), (9, 9, "fred"), (10, 10, "chuck")], ) def test_multi_columns_2(self): """test aliased/nonalised joins with the usage of add_column()""" - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() - expected = [(user7, 1), - (user8, 3), - (user9, 1), - (user10, 0) - ] + expected = [(user7, 1), (user8, 3), (user9, 1), (user10, 0)] q = sess.query(User) - q = q.group_by(users).order_by(User.id).outerjoin('addresses').\ - add_column(func.count(Address.id).label('count')) + q = ( + q.group_by(users) + .order_by(User.id) + .outerjoin("addresses") + .add_column(func.count(Address.id).label("count")) + ) eq_(q.all(), expected) sess.expunge_all() adalias = aliased(Address) q = sess.query(User) - q = q.group_by(users).order_by(User.id). \ - outerjoin(adalias, 'addresses').\ - add_column(func.count(adalias.id).label('count')) + q = ( + q.group_by(users) + .order_by(User.id) + .outerjoin(adalias, "addresses") + .add_column(func.count(adalias.id).label("count")) + ) eq_(q.all(), expected) sess.expunge_all() # TODO: figure out why group_by(users) doesn't work here - s = select([users, func.count(addresses.c.id).label('count')]). \ - select_from(users.outerjoin(addresses)). 
\ - group_by(*[c for c in users.c]).order_by(User.id) + s = ( + select([users, func.count(addresses.c.id).label("count")]) + .select_from(users.outerjoin(addresses)) + .group_by(*[c for c in users.c]) + .order_by(User.id) + ) q = sess.query(User) result = q.add_column("count").from_statement(s).all() assert result == expected def test_raw_columns(self): - addresses, users, User = (self.tables.addresses, - self.tables.users, - self.classes.User) + addresses, users, User = ( + self.tables.addresses, + self.tables.users, + self.classes.User, + ) sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() @@ -1701,52 +2275,77 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): (user7, 1, "Name:jack"), (user8, 3, "Name:ed"), (user9, 1, "Name:fred"), - (user10, 0, "Name:chuck")] + (user10, 0, "Name:chuck"), + ] adalias = addresses.alias() - q = create_session().query(User).add_column(func.count(adalias.c.id))\ - .add_column(("Name:" + users.c.name))\ - .outerjoin(adalias, 'addresses')\ - .group_by(users).order_by(users.c.id) + q = ( + create_session() + .query(User) + .add_column(func.count(adalias.c.id)) + .add_column(("Name:" + users.c.name)) + .outerjoin(adalias, "addresses") + .group_by(users) + .order_by(users.c.id) + ) assert q.all() == expected # test with a straight statement s = select( [ - users, func.count(addresses.c.id).label('count'), - ("Name:" + users.c.name).label('concat')], + users, + func.count(addresses.c.id).label("count"), + ("Name:" + users.c.name).label("concat"), + ], from_obj=[users.outerjoin(addresses)], - group_by=[c for c in users.c], order_by=[users.c.id]) + group_by=[c for c in users.c], + order_by=[users.c.id], + ) q = create_session().query(User) - result = q.add_column("count").add_column("concat") \ - .from_statement(s).all() + result = ( + q.add_column("count").add_column("concat").from_statement(s).all() + ) assert result == expected sess.expunge_all() # test with select_entity_from() - q = create_session().query(User) \ - .add_column(func.count(addresses.c.id)) \ - .add_column(("Name:" + users.c.name)) \ - .select_entity_from(users.outerjoin(addresses)) \ - .group_by(users).order_by(users.c.id) + q = ( + create_session() + .query(User) + .add_column(func.count(addresses.c.id)) + .add_column(("Name:" + users.c.name)) + .select_entity_from(users.outerjoin(addresses)) + .group_by(users) + .order_by(users.c.id) + ) assert q.all() == expected sess.expunge_all() - q = create_session().query(User) \ - .add_column(func.count(addresses.c.id)) \ - .add_column(("Name:" + users.c.name)).outerjoin('addresses')\ - .group_by(users).order_by(users.c.id) + q = ( + create_session() + .query(User) + .add_column(func.count(addresses.c.id)) + .add_column(("Name:" + users.c.name)) + .outerjoin("addresses") + .group_by(users) + .order_by(users.c.id) + ) assert q.all() == expected sess.expunge_all() - q = create_session().query(User).add_column(func.count(adalias.c.id)) \ - .add_column(("Name:" + users.c.name)) \ - .outerjoin(adalias, 'addresses') \ - .group_by(users).order_by(users.c.id) + q = ( + create_session() + .query(User) + .add_column(func.count(adalias.c.id)) + .add_column(("Name:" + users.c.name)) + .outerjoin(adalias, "addresses") + .group_by(users) + .order_by(users.c.id) + ) assert q.all() == expected sess.expunge_all() @@ -1759,25 +2358,33 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): s = create_session() for crit, j, exp in [ ( - User.id + Address.id, User.addresses, + User.id + Address.id, + User.addresses, "SELECT users.id + 
addresses.id AS anon_1 " "FROM users JOIN addresses ON users.id = " - "addresses.user_id"), + "addresses.user_id", + ), ( - User.id + Address.id, Address.user, + User.id + Address.id, + Address.user, "SELECT users.id + addresses.id AS anon_1 " "FROM addresses JOIN users ON users.id = " - "addresses.user_id"), + "addresses.user_id", + ), ( - Address.id + User.id, User.addresses, + Address.id + User.id, + User.addresses, "SELECT addresses.id + users.id AS anon_1 " "FROM users JOIN addresses ON users.id = " - "addresses.user_id"), + "addresses.user_id", + ), ( - User.id + aa.id, (aa, User.addresses), + User.id + aa.id, + (aa, User.addresses), "SELECT users.id + addresses_1.id AS anon_1 " "FROM users JOIN addresses AS addresses_1 " - "ON users.id = addresses_1.user_id"), + "ON users.id = addresses_1.user_id", + ), ]: q = s.query(crit) mzero = q._entity_zero() @@ -1787,21 +2394,27 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): for crit, j, exp in [ ( - ua.id + Address.id, ua.addresses, + ua.id + Address.id, + ua.addresses, "SELECT users_1.id + addresses.id AS anon_1 " "FROM users AS users_1 JOIN addresses " - "ON users_1.id = addresses.user_id"), + "ON users_1.id = addresses.user_id", + ), ( - ua.id + aa.id, (aa, ua.addresses), + ua.id + aa.id, + (aa, ua.addresses), "SELECT users_1.id + addresses_1.id AS anon_1 " "FROM users AS users_1 JOIN addresses AS " - "addresses_1 ON users_1.id = addresses_1.user_id"), + "addresses_1 ON users_1.id = addresses_1.user_id", + ), ( - ua.id + aa.id, (ua, aa.user), + ua.id + aa.id, + (ua, aa.user), "SELECT users_1.id + addresses_1.id AS anon_1 " "FROM addresses AS addresses_1 JOIN " "users AS users_1 " - "ON users_1.id = addresses_1.user_id") + "ON users_1.id = addresses_1.user_id", + ), ]: q = s.query(crit) mzero = q._entity_zero() @@ -1815,50 +2428,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess = Session() agg_address = sess.query( Address.id, - func.sum(func.length(Address.email_address)). - label('email_address')).group_by(Address.user_id) + func.sum(func.length(Address.email_address)).label( + "email_address" + ), + ).group_by(Address.user_id) ag1 = aliased(Address, agg_address.subquery()) ag2 = aliased(Address, agg_address.subquery(), adapt_on_names=True) # first, without adapt on names, 'email_address' isn't matched up - we # get the raw "address" element in the SELECT self.assert_compile( - sess.query(User, ag1.email_address).join(ag1, User.addresses). - filter(ag1.email_address > 5), + sess.query(User, ag1.email_address) + .join(ag1, User.addresses) + .filter(ag1.email_address > 5), "SELECT users.id " "AS users_id, users.name AS users_name, addresses.email_address " "AS addresses_email_address FROM addresses, users JOIN " "(SELECT addresses.id AS id, sum(length(addresses.email_address)) " "AS email_address FROM addresses GROUP BY addresses.user_id) AS " "anon_1 ON users.id = addresses.user_id " - "WHERE addresses.email_address > :email_address_1") + "WHERE addresses.email_address > :email_address_1", + ) # second, 'email_address' matches up to the aggreagte, and we get a # smooth JOIN from users->subquery and that's it self.assert_compile( - sess.query(User, ag2.email_address).join(ag2, User.addresses). 
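# Rough sketch of the adapt_on_names=True behavior being asserted in the
# surrounding compiles (assumes the Address mapping and `sess` from these
# fixtures; the helper name is illustrative): name-based adaptation lets
# the mapped Address.email_address attribute line up with the aggregate
# label in the subquery.
from sqlalchemy import func
from sqlalchemy.orm import aliased


def _adapt_on_names_sketch(sess, Address):
    agg = (
        sess.query(
            Address.id,
            func.sum(func.length(Address.email_address)).label(
                "email_address"
            ),
        )
        .group_by(Address.user_id)
        .subquery()
    )
    return aliased(Address, agg, adapt_on_names=True)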
- filter(ag2.email_address > 5), + sess.query(User, ag2.email_address) + .join(ag2, User.addresses) + .filter(ag2.email_address > 5), "SELECT users.id AS users_id, users.name AS users_name, " "anon_1.email_address AS anon_1_email_address FROM users " "JOIN (" "SELECT addresses.id AS id, sum(length(addresses.email_address)) " "AS email_address FROM addresses GROUP BY addresses.user_id) AS " "anon_1 ON users.id = addresses.user_id " - "WHERE anon_1.email_address > :email_address_1",) + "WHERE anon_1.email_address > :email_address_1", + ) class SelectFromTest(QueryTest, AssertsCompiledSQL): run_setup_mappers = None - __dialect__ = 'default' + __dialect__ = "default" def test_replace_with_select(self): users, Address, addresses, User = ( - self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper( - User, users, properties={ - 'addresses': relationship(Address)}) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])).alias() @@ -1866,27 +2486,40 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): eq_( sess.query(User).select_entity_from(sel).all(), - [User(id=7), User(id=8)]) + [User(id=7), User(id=8)], + ) eq_( - sess.query(User).select_entity_from(sel). - filter(User.id == 8).all(), - [User(id=8)]) + sess.query(User) + .select_entity_from(sel) + .filter(User.id == 8) + .all(), + [User(id=8)], + ) eq_( - sess.query(User).select_entity_from(sel). - order_by(desc(User.name)).all(), [ - User(name='jack', id=7), User(name='ed', id=8)]) + sess.query(User) + .select_entity_from(sel) + .order_by(desc(User.name)) + .all(), + [User(name="jack", id=7), User(name="ed", id=8)], + ) eq_( - sess.query(User).select_entity_from(sel). - order_by(asc(User.name)).all(), [ - User(name='ed', id=8), User(name='jack', id=7)]) + sess.query(User) + .select_entity_from(sel) + .order_by(asc(User.name)) + .all(), + [User(name="ed", id=8), User(name="jack", id=7)], + ) eq_( - sess.query(User).select_entity_from(sel). 
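# select_entity_from(), used throughout the surrounding tests, points an
# entity at an arbitrary selectable in place of its mapped table. A
# minimal sketch under the same fixtures (users table, User class,
# create_session; helper name illustrative):
def _select_entity_from_sketch(create_session, users, User):
    sel = users.select(users.c.id.in_([7, 8])).alias()
    sess = create_session()
    # User rows now come from `sel`; User.id in the filter is adapted
    # to the alias automatically.
    return sess.query(User).select_entity_from(sel).filter(User.id == 8)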
- options(joinedload('addresses')).first(), - User(name='jack', addresses=[Address(id=1)])) + sess.query(User) + .select_entity_from(sel) + .options(joinedload("addresses")) + .first(), + User(name="jack", addresses=[Address(id=1)]), + ) def test_select_from_aliased(self): User, users = self.classes.User, self.tables.users @@ -1895,23 +2528,16 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): sess = create_session() - not_users = table('users', column('id'), column('name')) - ua = aliased( - User, - select([not_users]).alias(), - adapt_on_names=True - ) + not_users = table("users", column("id"), column("name")) + ua = aliased(User, select([not_users]).alias(), adapt_on_names=True) q = sess.query(User.name).select_entity_from(ua).order_by(User.name) self.assert_compile( q, "SELECT anon_1.name AS anon_1_name FROM (SELECT users.id AS id, " - "users.name AS name FROM users) AS anon_1 ORDER BY anon_1.name" - ) - eq_( - q.all(), - [('chuck',), ('ed',), ('fred',), ('jack',)] + "users.name AS name FROM users) AS anon_1 ORDER BY anon_1.name", ) + eq_(q.all(), [("chuck",), ("ed",), ("fred",), ("jack",)]) @testing.uses_deprecated("Mapper.order_by") def test_join_mapper_order_by(self): @@ -1926,8 +2552,8 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): eq_( sess.query(User).select_entity_from(sel).all(), - [ - User(name='jack', id=7), User(name='ed', id=8)]) + [User(name="jack", id=7), User(name="ed", id=8)], + ) def test_differentiate_self_external(self): """test some different combinations of joining a table to a subquery of @@ -1947,32 +2573,39 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name FROM " "users JOIN (SELECT users.id AS id, users.name AS name FROM users " "WHERE users.id IN (:id_1, :id_2)) " - "AS anon_1 ON users.id > anon_1.id",) + "AS anon_1 ON users.id > anon_1.id", + ) self.assert_compile( - sess.query(ualias).select_entity_from(sel). - filter(ualias.id > sel.c.id), + sess.query(ualias) + .select_entity_from(sel) + .filter(ualias.id > sel.c.id), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users AS users_1, (" "SELECT users.id AS id, users.name AS name FROM users " "WHERE users.id IN (:id_1, :id_2)) AS anon_1 " - "WHERE users_1.id > anon_1.id",) + "WHERE users_1.id > anon_1.id", + ) self.assert_compile( - sess.query(ualias).select_entity_from(sel). - join(ualias, ualias.id > sel.c.id), + sess.query(ualias) + .select_entity_from(sel) + .join(ualias, ualias.id > sel.c.id), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM (SELECT users.id AS id, users.name AS name " "FROM users WHERE users.id IN (:id_1, :id_2)) AS anon_1 " - "JOIN users AS users_1 ON users_1.id > anon_1.id") + "JOIN users AS users_1 ON users_1.id > anon_1.id", + ) self.assert_compile( - sess.query(ualias).select_entity_from(sel). 
- join(ualias, ualias.id > User.id), + sess.query(ualias) + .select_entity_from(sel) + .join(ualias, ualias.id > User.id), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM (SELECT users.id AS id, users.name AS name FROM " "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 " - "JOIN users AS users_1 ON users_1.id > anon_1.id") + "JOIN users AS users_1 ON users_1.id > anon_1.id", + ) salias = aliased(User, sel) self.assert_compile( @@ -1980,17 +2613,20 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): "SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name FROM " "(SELECT users.id AS id, users.name AS name " "FROM users WHERE users.id IN (:id_1, :id_2)) AS anon_1 " - "JOIN users AS users_1 ON users_1.id > anon_1.id",) + "JOIN users AS users_1 ON users_1.id > anon_1.id", + ) self.assert_compile( sess.query(ualias).select_entity_from( - join(sel, ualias, ualias.id > sel.c.id)), + join(sel, ualias, ualias.id > sel.c.id) + ), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM " "(SELECT users.id AS id, users.name AS name " "FROM users WHERE users.id " "IN (:id_1, :id_2)) AS anon_1 " - "JOIN users AS users_1 ON users_1.id > anon_1.id") + "JOIN users AS users_1 ON users_1.id > anon_1.id", + ) def test_aliased_class_vs_nonaliased(self): User, users = self.classes.User, self.tables.users @@ -2002,42 +2638,46 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): self.assert_compile( sess.query(User).select_from(ua).join(User, ua.name > User.name), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users AS users_1 JOIN users ON users_1.name > users.name" + "FROM users AS users_1 JOIN users ON users_1.name > users.name", ) self.assert_compile( - sess.query(User.name).select_from(ua). - join(User, ua.name > User.name), + sess.query(User.name) + .select_from(ua) + .join(User, ua.name > User.name), "SELECT users.name AS users_name FROM users AS users_1 " - "JOIN users ON users_1.name > users.name" + "JOIN users ON users_1.name > users.name", ) self.assert_compile( - sess.query(ua.name).select_from(ua). 
- join(User, ua.name > User.name), + sess.query(ua.name) + .select_from(ua) + .join(User, ua.name > User.name), "SELECT users_1.name AS users_1_name FROM users AS users_1 " - "JOIN users ON users_1.name > users.name" + "JOIN users ON users_1.name > users.name", ) self.assert_compile( sess.query(ua).select_from(User).join(ua, ua.name > User.name), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " - "FROM users JOIN users AS users_1 ON users_1.name > users.name" + "FROM users JOIN users AS users_1 ON users_1.name > users.name", ) self.assert_compile( sess.query(ua).select_from(User).join(ua, User.name > ua.name), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " - "FROM users JOIN users AS users_1 ON users.name > users_1.name" + "FROM users JOIN users AS users_1 ON users.name > users_1.name", ) # this is tested in many other places here, just adding it # here for comparison self.assert_compile( sess.query(User.name).select_entity_from( - users.select().where(users.c.id > 5)), + users.select().where(users.c.id > 5) + ), "SELECT anon_1.name AS anon_1_name FROM (SELECT users.id AS id, " - "users.name AS name FROM users WHERE users.id > :id_1) AS anon_1") + "users.name AS name FROM users WHERE users.id > :id_1) AS anon_1", + ) def test_join_no_order_by(self): User, users = self.classes.User, self.tables.users @@ -2049,163 +2689,259 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): eq_( sess.query(User).select_entity_from(sel).all(), - [User(name='jack', id=7), User(name='ed', id=8)]) + [User(name="jack", id=7), User(name="ed", id=8)], + ) def test_join_relname_from_selected_from(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={'addresses': relationship( - mapper(Address, addresses), backref='user')}) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) sess = create_session() self.assert_compile( sess.query(User).select_from(Address).join("user"), "SELECT users.id AS users_id, users.name AS users_name " - "FROM addresses JOIN users ON users.id = addresses.user_id" + "FROM addresses JOIN users ON users.id = addresses.user_id", ) def test_filter_by_selected_from(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses))}) + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) sess = create_session() self.assert_compile( - sess.query(User).select_from(Address). 
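# select_from(), exercised in the surrounding tests, fixes the leftmost
# element of the FROM list; joining to User afterwards reverses the usual
# join direction. Sketch assuming the same User/Address fixtures (helper
# name illustrative):
def _select_from_sketch(sess, User, Address):
    # renders: SELECT users.* FROM addresses
    #          JOIN users ON users.id = addresses.user_id
    return sess.query(User).select_from(Address).join(User)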
- filter_by(email_address='ed').join(User), + sess.query(User) + .select_from(Address) + .filter_by(email_address="ed") + .join(User), "SELECT users.id AS users_id, users.name AS users_name " "FROM addresses JOIN users ON users.id = addresses.user_id " - "WHERE addresses.email_address = :email_address_1" + "WHERE addresses.email_address = :email_address_1", ) def test_join_ent_selected_from(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses))}) + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) sess = create_session() self.assert_compile( sess.query(User).select_from(Address).join(User), "SELECT users.id AS users_id, users.name AS users_name " - "FROM addresses JOIN users ON users.id = addresses.user_id" + "FROM addresses JOIN users ON users.id = addresses.user_id", ) def test_join(self): users, Address, addresses, User = ( - self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={'addresses': relationship(Address)}) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])) sess = create_session() eq_( - sess.query(User).select_entity_from(sel).join('addresses'). - add_entity(Address).order_by(User.id).order_by(Address.id).all(), + sess.query(User) + .select_entity_from(sel) + .join("addresses") + .add_entity(Address) + .order_by(User.id) + .order_by(Address.id) + .all(), [ ( - User(name='jack', id=7), - Address(user_id=7, email_address='jack@bean.com', id=1)), + User(name="jack", id=7), + Address(user_id=7, email_address="jack@bean.com", id=1), + ), ( - User(name='ed', id=8), - Address(user_id=8, email_address='ed@wood.com', id=2)), + User(name="ed", id=8), + Address(user_id=8, email_address="ed@wood.com", id=2), + ), ( - User(name='ed', id=8), - Address( - user_id=8, email_address='ed@bettyboop.com', id=3)), + User(name="ed", id=8), + Address(user_id=8, email_address="ed@bettyboop.com", id=3), + ), ( - User(name='ed', id=8), - Address(user_id=8, email_address='ed@lala.com', id=4))]) + User(name="ed", id=8), + Address(user_id=8, email_address="ed@lala.com", id=4), + ), + ], + ) adalias = aliased(Address) eq_( - sess.query(User).select_entity_from(sel). - join(adalias, 'addresses').add_entity(adalias).order_by(User.id). 
- order_by(adalias.id).all(), + sess.query(User) + .select_entity_from(sel) + .join(adalias, "addresses") + .add_entity(adalias) + .order_by(User.id) + .order_by(adalias.id) + .all(), [ ( - User(name='jack', id=7), - Address(user_id=7, email_address='jack@bean.com', id=1)), + User(name="jack", id=7), + Address(user_id=7, email_address="jack@bean.com", id=1), + ), ( - User(name='ed', id=8), - Address(user_id=8, email_address='ed@wood.com', id=2)), + User(name="ed", id=8), + Address(user_id=8, email_address="ed@wood.com", id=2), + ), ( - User(name='ed', id=8), - Address( - user_id=8, email_address='ed@bettyboop.com', id=3)), + User(name="ed", id=8), + Address(user_id=8, email_address="ed@bettyboop.com", id=3), + ), ( - User(name='ed', id=8), - Address(user_id=8, email_address='ed@lala.com', id=4))]) + User(name="ed", id=8), + Address(user_id=8, email_address="ed@lala.com", id=4), + ), + ], + ) def test_more_joins(self): ( - users, Keyword, orders, items, order_items, Order, Item, User, - keywords, item_keywords) = \ - ( - self.tables.users, self.classes.Keyword, self.tables.orders, - self.tables.items, self.tables.order_items, self.classes.Order, - self.classes.Item, self.classes.User, self.tables.keywords, - self.tables.item_keywords) + users, + Keyword, + orders, + items, + order_items, + Order, + Item, + User, + keywords, + item_keywords, + ) = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) mapper( - User, users, properties={ - 'orders': relationship(Order, backref='user')}) # o2m, m2o + User, + users, + properties={"orders": relationship(Order, backref="user")}, + ) # o2m, m2o mapper( - Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, order_by=items.c.id)}) # m2m + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) + }, + ) # m2m mapper( - Item, items, properties={ - 'keywords': relationship( - Keyword, secondary=item_keywords, - order_by=keywords.c.id)}) # m2m + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, order_by=keywords.c.id + ) + }, + ) # m2m mapper(Keyword, keywords) sess = create_session() sel = users.select(users.c.id.in_([7, 8])) eq_( - sess.query(User).select_entity_from(sel). - join('orders', 'items', 'keywords'). - filter(Keyword.name.in_(['red', 'big', 'round'])).all(), - [User(name='jack', id=7)]) + sess.query(User) + .select_entity_from(sel) + .join("orders", "items", "keywords") + .filter(Keyword.name.in_(["red", "big", "round"])) + .all(), + [User(name="jack", id=7)], + ) eq_( - sess.query(User).select_entity_from(sel). - join('orders', 'items', 'keywords', aliased=True). 
- filter(Keyword.name.in_(['red', 'big', 'round'])).all(), - [User(name='jack', id=7)]) + sess.query(User) + .select_entity_from(sel) + .join("orders", "items", "keywords", aliased=True) + .filter(Keyword.name.in_(["red", "big", "round"])) + .all(), + [User(name="jack", id=7)], + ) def test_very_nested_joins_with_joinedload(self): ( - users, Keyword, orders, items, order_items, Order, Item, User, - keywords, item_keywords) = \ - ( - self.tables.users, self.classes.Keyword, self.tables.orders, - self.tables.items, self.tables.order_items, self.classes.Order, - self.classes.Item, self.classes.User, self.tables.keywords, - self.tables.item_keywords) + users, + Keyword, + orders, + items, + order_items, + Order, + Item, + User, + keywords, + item_keywords, + ) = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) mapper( - User, users, properties={ - 'orders': relationship(Order, backref='user')}) # o2m, m2o + User, + users, + properties={"orders": relationship(Order, backref="user")}, + ) # o2m, m2o mapper( - Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, order_by=items.c.id)}) # m2m + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) + }, + ) # m2m mapper( - Item, items, properties={ - 'keywords': relationship( - Keyword, secondary=item_keywords, - order_by=keywords.c.id)}) # m2m + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, order_by=keywords.c.id + ) + }, + ) # m2m mapper(Keyword, keywords) sess = create_session() @@ -2214,72 +2950,119 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): def go(): eq_( - sess.query(User).select_entity_from(sel). - options(joinedload_all('orders.items.keywords')). - join('orders', 'items', 'keywords', aliased=True). - filter(Keyword.name.in_(['red', 'big', 'round'])). 
- all(), + sess.query(User) + .select_entity_from(sel) + .options(joinedload_all("orders.items.keywords")) + .join("orders", "items", "keywords", aliased=True) + .filter(Keyword.name.in_(["red", "big", "round"])) + .all(), [ - User(name='jack', orders=[ - Order( - description='order 1', items=[ - Item( - description='item 1', keywords=[ - Keyword(name='red'), - Keyword(name='big'), - Keyword(name='round')]), - Item( - description='item 2', keywords=[ - Keyword(name='red', id=2), - Keyword(name='small', id=5), - Keyword(name='square')]), - Item( - description='item 3', keywords=[ - Keyword(name='green', id=3), - Keyword(name='big', id=4), - Keyword(name='round', id=6)])]), - Order( - description='order 3', items=[ - Item( - description='item 3', keywords=[ - Keyword(name='green', id=3), - Keyword(name='big', id=4), - Keyword(name='round', id=6)]), - Item(description='item 4', keywords=[], id=4), - Item( - description='item 5', keywords=[], id=5)]), - Order( - description='order 5', - items=[ - Item(description='item 5', keywords=[])])])]) + User( + name="jack", + orders=[ + Order( + description="order 1", + items=[ + Item( + description="item 1", + keywords=[ + Keyword(name="red"), + Keyword(name="big"), + Keyword(name="round"), + ], + ), + Item( + description="item 2", + keywords=[ + Keyword(name="red", id=2), + Keyword(name="small", id=5), + Keyword(name="square"), + ], + ), + Item( + description="item 3", + keywords=[ + Keyword(name="green", id=3), + Keyword(name="big", id=4), + Keyword(name="round", id=6), + ], + ), + ], + ), + Order( + description="order 3", + items=[ + Item( + description="item 3", + keywords=[ + Keyword(name="green", id=3), + Keyword(name="big", id=4), + Keyword(name="round", id=6), + ], + ), + Item( + description="item 4", keywords=[], id=4 + ), + Item( + description="item 5", keywords=[], id=5 + ), + ], + ), + Order( + description="order 5", + items=[ + Item(description="item 5", keywords=[]) + ], + ), + ], + ) + ], + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() sel2 = orders.select(orders.c.id.in_([1, 2, 3])) eq_( - sess.query(Order).select_entity_from(sel2). - join('items', 'keywords').filter(Keyword.name == 'red'). - order_by(Order.id).all(), + sess.query(Order) + .select_entity_from(sel2) + .join("items", "keywords") + .filter(Keyword.name == "red") + .order_by(Order.id) + .all(), [ - Order(description='order 1', id=1), - Order(description='order 2', id=2)]) + Order(description="order 1", id=1), + Order(description="order 2", id=2), + ], + ) eq_( - sess.query(Order).select_entity_from(sel2). - join('items', 'keywords', aliased=True). 
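# join() with a string relationship path plus aliased=True, as in the
# surrounding tests, aliases each joined mapper so the same path can be
# traversed more than once without table-name conflicts. Sketch assuming
# the fixture mappings (helper name illustrative):
def _aliased_join_sketch(sess, User, Item):
    return (
        sess.query(User)
        .join("orders", "items", aliased=True)
        .filter(Item.id == 4)
    )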
- filter(Keyword.name == 'red').order_by(Order.id).all(), + sess.query(Order) + .select_entity_from(sel2) + .join("items", "keywords", aliased=True) + .filter(Keyword.name == "red") + .order_by(Order.id) + .all(), [ - Order(description='order 1', id=1), - Order(description='order 2', id=2)]) + Order(description="order 1", id=1), + Order(description="order 2", id=2), + ], + ) def test_replace_with_eager(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - User, users, properties={ - 'addresses': relationship(Address, order_by=addresses.c.id)}) + User, + users, + properties={ + "addresses": relationship(Address, order_by=addresses.c.id) + }, + ) mapper(Address, addresses) sel = users.select(users.c.id.in_([7, 8])) @@ -2287,35 +3070,62 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): def go(): eq_( - sess.query(User).options(joinedload('addresses')). - select_entity_from(sel).order_by(User.id).all(), + sess.query(User) + .options(joinedload("addresses")) + .select_entity_from(sel) + .order_by(User.id) + .all(), [ User(id=7, addresses=[Address(id=1)]), User( - id=8, addresses=[Address(id=2), Address(id=3), - Address(id=4)])]) + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + ), + ], + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): eq_( - sess.query(User).options(joinedload('addresses')). - select_entity_from(sel).filter(User.id == 8).order_by(User.id). - all(), + sess.query(User) + .options(joinedload("addresses")) + .select_entity_from(sel) + .filter(User.id == 8) + .order_by(User.id) + .all(), [ User( - id=8, addresses=[Address(id=2), Address(id=3), - Address(id=4)])]) + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + ) + ], + ) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): eq_( - sess.query(User).options(joinedload('addresses')). 
- select_entity_from(sel).order_by(User.id)[1], + sess.query(User) + .options(joinedload("addresses")) + .select_entity_from(sel) + .order_by(User.id)[1], User( - id=8, addresses=[Address(id=2), Address(id=3), - Address(id=4)])) + id=8, + addresses=[Address(id=2), Address(id=3), Address(id=4)], + ), + ) + self.assert_sql_count(testing.db, go, 1) @@ -2326,41 +3136,71 @@ class CustomJoinTest(QueryTest): """test aliasing of joins with a custom join condition""" ( - addresses, items, order_items, orders, Item, User, Address, Order, - users) = \ - ( - self.tables.addresses, self.tables.items, - self.tables.order_items, self.tables.orders, self.classes.Item, - self.classes.User, self.classes.Address, self.classes.Order, - self.tables.users) + addresses, + items, + order_items, + orders, + Item, + User, + Address, + Order, + users, + ) = ( + self.tables.addresses, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.users, + ) mapper(Address, addresses) mapper( - Order, orders, properties={ - 'items': relationship( - Item, secondary=order_items, lazy='select', - order_by=items.c.id)}) + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="select", + order_by=items.c.id, + ) + }, + ) mapper(Item, items) mapper( - User, users, properties=dict( - addresses=relationship(Address, lazy='select'), + User, + users, + properties=dict( + addresses=relationship(Address, lazy="select"), open_orders=relationship( Order, primaryjoin=and_( - orders.c.isopen == 1, users.c.id == orders.c.user_id), - lazy='select'), + orders.c.isopen == 1, users.c.id == orders.c.user_id + ), + lazy="select", + ), closed_orders=relationship( Order, primaryjoin=and_( - orders.c.isopen == 0, users.c.id == orders.c.user_id), - lazy='select'))) + orders.c.isopen == 0, users.c.id == orders.c.user_id + ), + lazy="select", + ), + ), + ) q = create_session().query(User) eq_( - q.join('open_orders', 'items', aliased=True).filter(Item.id == 4). - join('closed_orders', 'items', aliased=True).filter(Item.id == 3). - all(), - [User(id=7)] + q.join("open_orders", "items", aliased=True) + .filter(Item.id == 4) + .join("closed_orders", "items", aliased=True) + .filter(Item.id == 3) + .all(), + [User(id=7)], ) @@ -2374,37 +3214,46 @@ class ExternalColumnsTest(QueryTest): assert_raises_message( sa_exc.ArgumentError, - "not represented in the mapper's table", mapper, User, users, - properties={ - 'concat': (users.c.id * 2), - }) + "not represented in the mapper's table", + mapper, + User, + users, + properties={"concat": (users.c.id * 2)}, + ) clear_mappers() def test_external_columns(self): """test querying mappings that reference external columns or selectables.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - User, users, properties={ - 'concat': column_property((users.c.id * 2)), - 'count': column_property( + User, + users, + properties={ + "concat": column_property((users.c.id * 2)), + "count": column_property( select( [func.count(addresses.c.id)], - users.c.id == addresses.c.user_id).correlate(users). 
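# The mapping here attaches a correlated scalar subquery as a
# column_property. A minimal sketch of that construct, assuming the
# users/addresses Table fixtures and the 1.x list form of select() used
# in this file (helper name illustrative):
from sqlalchemy import func, select
from sqlalchemy.orm import column_property


def _count_property_sketch(users, addresses):
    return column_property(
        select([func.count(addresses.c.id)])
        .where(users.c.id == addresses.c.user_id)
        .correlate(users)
        .as_scalar()
    )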
- as_scalar())}) + users.c.id == addresses.c.user_id, + ) + .correlate(users) + .as_scalar() + ), + }, + ) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - sess.query(Address).options(joinedload('user')).all() + sess.query(Address).options(joinedload("user")).all() eq_( sess.query(User).all(), @@ -2413,14 +3262,15 @@ class ExternalColumnsTest(QueryTest): User(id=8, concat=16, count=3), User(id=9, concat=18, count=1), User(id=10, concat=20, count=0), - ]) + ], + ) address_result = [ Address(id=1, user=User(id=7, concat=14, count=1)), Address(id=2, user=User(id=8, concat=16, count=3)), Address(id=3, user=User(id=8, concat=16, count=3)), Address(id=4, user=User(id=8, concat=16, count=3)), - Address(id=5, user=User(id=9, concat=18, count=1)) + Address(id=5, user=User(id=9, concat=18, count=1)), ] eq_(sess.query(Address).all(), address_result) @@ -2430,80 +3280,106 @@ class ExternalColumnsTest(QueryTest): def go(): eq_( - sess.query(Address).options(joinedload('user')). - order_by(Address.id).all(), - address_result) + sess.query(Address) + .options(joinedload("user")) + .order_by(Address.id) + .all(), + address_result, + ) + self.assert_sql_count(testing.db, go, 1) ualias = aliased(User) eq_( - sess.query(Address, ualias).join(ualias, 'user').all(), - [(address, address.user) for address in address_result] + sess.query(Address, ualias).join(ualias, "user").all(), + [(address, address.user) for address in address_result], ) eq_( - sess.query(Address, ualias.count).join(ualias, 'user'). - join('user', aliased=True).order_by(Address.id).all(), + sess.query(Address, ualias.count) + .join(ualias, "user") + .join("user", aliased=True) + .order_by(Address.id) + .all(), [ (Address(id=1), 1), (Address(id=2), 3), (Address(id=3), 3), (Address(id=4), 3), - (Address(id=5), 1) - ] + (Address(id=5), 1), + ], ) eq_( - sess.query(Address, ualias.concat, ualias.count). - join(ualias, 'user'). - join('user', aliased=True).order_by(Address.id).all(), + sess.query(Address, ualias.concat, ualias.count) + .join(ualias, "user") + .join("user", aliased=True) + .order_by(Address.id) + .all(), [ (Address(id=1), 14, 1), (Address(id=2), 16, 3), (Address(id=3), 16, 3), (Address(id=4), 16, 3), - (Address(id=5), 18, 1) - ] + (Address(id=5), 18, 1), + ], ) ua = aliased(User) eq_( - sess.query(Address, ua.concat, ua.count). - select_entity_from(join(Address, ua, 'user')). - options(joinedload(Address.user)).order_by(Address.id).all(), + sess.query(Address, ua.concat, ua.count) + .select_entity_from(join(Address, ua, "user")) + .options(joinedload(Address.user)) + .order_by(Address.id) + .all(), [ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), (Address(id=3, user=User(id=8, concat=16, count=3)), 16, 3), (Address(id=4, user=User(id=8, concat=16, count=3)), 16, 3), - (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1) - ]) + (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1), + ], + ) eq_( list( - sess.query(Address).join('user'). - values(Address.id, User.id, User.concat, User.count)), + sess.query(Address) + .join("user") + .values(Address.id, User.id, User.concat, User.count) + ), [ - (1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), - (5, 9, 18, 1)]) + (1, 7, 14, 1), + (2, 8, 16, 3), + (3, 8, 16, 3), + (4, 8, 16, 3), + (5, 9, 18, 1), + ], + ) eq_( list( - sess.query(Address, ua). 
- select_entity_from(join(Address, ua, 'user')). - values(Address.id, ua.id, ua.concat, ua.count)), + sess.query(Address, ua) + .select_entity_from(join(Address, ua, "user")) + .values(Address.id, ua.id, ua.concat, ua.count) + ), [ - (1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), - (5, 9, 18, 1)]) + (1, 7, 14, 1), + (2, 8, 16, 3), + (3, 8, 16, 3), + (4, 8, 16, 3), + (5, 9, 18, 1), + ], + ) def test_external_columns_joinedload(self): - users, orders, User, Address, Order, addresses = \ - (self.tables.users, - self.tables.orders, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.addresses) + users, orders, User, Address, Order, addresses = ( + self.tables.users, + self.tables.orders, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.addresses, + ) # in this test, we have a subquery on User that accesses "addresses", # underneath an joinedload for "addresses". So the "addresses" alias @@ -2512,50 +3388,76 @@ class ExternalColumnsTest(QueryTest): # standing practice of eager adapters being "chained" has been removed # since its unnecessary and breaks this exact condition. mapper( - User, users, properties={ - 'addresses': relationship( - Address, backref='user', order_by=addresses.c.id), - 'concat': column_property((users.c.id * 2)), - 'count': column_property( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", order_by=addresses.c.id + ), + "concat": column_property((users.c.id * 2)), + "count": column_property( select( [func.count(addresses.c.id)], - users.c.id == addresses.c.user_id).correlate(users))}) + users.c.id == addresses.c.user_id, + ).correlate(users) + ), + }, + ) mapper(Address, addresses) mapper( - Order, orders, properties={ - 'address': relationship(Address)}) # m2o + Order, orders, properties={"address": relationship(Address)} + ) # m2o sess = create_session() def go(): - o1 = sess.query(Order).options(joinedload_all('address.user')). \ - get(1) + o1 = ( + sess.query(Order) + .options(joinedload_all("address.user")) + .get(1) + ) eq_(o1.address.user.count, 1) + self.assert_sql_count(testing.db, go, 1) sess = create_session() def go(): - o1 = sess.query(Order).options(joinedload_all('address.user')). \ - first() + o1 = ( + sess.query(Order) + .options(joinedload_all("address.user")) + .first() + ) eq_(o1.address.user.count, 1) + self.assert_sql_count(testing.db, go, 1) def test_external_columns_compound(self): # see [ticket:2167] for background users, Address, addresses, User = ( - self.tables.users, self.classes.Address, self.tables.addresses, - self.classes.User) + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - User, users, properties={ - 'fullname': column_property(users.c.name.label('x'))}) + User, + users, + properties={"fullname": column_property(users.c.name.label("x"))}, + ) mapper( - Address, addresses, properties={ - 'username': column_property( - select([User.fullname]). 
- where(User.id == addresses.c.user_id).label('y'))}) + Address, + addresses, + properties={ + "username": column_property( + select([User.fullname]) + .where(User.id == addresses.c.user_id) + .label("y") + ) + }, + ) sess = create_session() a1 = sess.query(Address).first() eq_(a1.username, "jack") @@ -2569,30 +3471,40 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'base', metadata, + "base", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), ) Table( - 'sub1', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('data', String(50)) + "sub1", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("data", String(50)), ) Table( - 'sub2', metadata, + "sub2", + metadata, Column( - 'id', Integer, ForeignKey('base.id'), ForeignKey('sub1.id'), - primary_key=True), - Column('data', String(50)) + "id", + Integer, + ForeignKey("base.id"), + ForeignKey("sub1.id"), + primary_key=True, + ), + Column("data", String(50)), ) def test_equivs(self): base, sub2, sub1 = ( - self.tables.base, self.tables.sub2, self.tables.sub1) + self.tables.base, + self.tables.sub2, + self.tables.sub1, + ) class Base(fixtures.ComparableEntity): pass @@ -2603,20 +3515,24 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest): class Sub2(fixtures.ComparableEntity): pass - mapper(Base, base, properties={ - 'sub1': relationship(Sub1), - 'sub2': relationship(Sub2) - }) + mapper( + Base, + base, + properties={ + "sub1": relationship(Sub1), + "sub2": relationship(Sub2), + }, + ) mapper(Sub1, sub1) mapper(Sub2, sub2) sess = create_session() - s11 = Sub1(data='s11') - s12 = Sub1(data='s12') - s2 = Sub2(data='s2') - b1 = Base(data='b1', sub1=[s11], sub2=[]) - b2 = Base(data='b1', sub1=[s12], sub2=[]) + s11 = Sub1(data="s11") + s12 = Sub1(data="s12") + s2 = Sub2(data="s2") + b1 = Base(data="b1", sub1=[s11], sub2=[]) + b2 = Base(data="b1", sub1=[s12], sub2=[]) sess.add(b1) sess.add(b2) sess.flush() @@ -2626,13 +3542,16 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest): b2.sub2 = [s2] sess.flush() - q = sess.query(Base).outerjoin('sub2', aliased=True) + q = sess.query(Base).outerjoin("sub2", aliased=True) assert sub1.c.id not in q._filter_aliases.equivalents eq_( - sess.query(Base).join('sub1').outerjoin('sub2', aliased=True). 
- filter(Sub1.id == 1).one(), - b1 + sess.query(Base) + .join("sub1") + .outerjoin("sub2", aliased=True) + .filter(Sub1.id == 1) + .one(), + b1, ) @@ -2647,15 +3566,15 @@ class LabelCollideTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'foo', metadata, - Column('id', Integer, primary_key=True), - Column('bar_id', Integer) + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("bar_id", Integer), ) - Table('foo_bar', metadata, Column('id', Integer, primary_key=True)) + Table("foo_bar", metadata, Column("id", Integer, primary_key=True)) @classmethod def setup_classes(cls): - class Foo(cls.Basic): pass @@ -2670,10 +3589,7 @@ class LabelCollideTest(fixtures.MappedTest): @classmethod def insert_data(cls): s = Session() - s.add_all([ - cls.classes.Foo(id=1, bar_id=2), - cls.classes.Bar(id=3) - ]) + s.add_all([cls.classes.Foo(id=1, bar_id=2), cls.classes.Bar(id=3)]) s.commit() def test_overlap_plain(self): @@ -2684,6 +3600,7 @@ class LabelCollideTest(fixtures.MappedTest): eq_(row.Foo.id, 1) eq_(row.Foo.bar_id, 2) eq_(row.Bar.id, 3) + # all three columns are loaded independently without # overlap, no additional SQL to load all attributes self.assert_sql_count(testing.db, go, 0) @@ -2696,6 +3613,7 @@ class LabelCollideTest(fixtures.MappedTest): eq_(row.Foo.id, 1) eq_(row.Foo.bar_id, 2) eq_(row.Bar.id, 3) + # all three columns are loaded independently without # overlap, no additional SQL to load all attributes self.assert_sql_count(testing.db, go, 0) diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 21e4263f80..0e66172a3d 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -11,21 +11,23 @@ from test.orm import _fixtures class GenerativeQueryTest(fixtures.MappedTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column('id', Integer, sa.Sequence('foo_id_seq'), - primary_key=True), - Column('bar', Integer), - Column('range', Integer)) + Table( + "foo", + metadata, + Column("id", Integer, sa.Sequence("foo_id_seq"), primary_key=True), + Column("bar", Integer), + Column("range", Integer), + ) @classmethod def fixtures(cls): rows = tuple([(i, i % 10) for i in range(100)]) - foo_data = (('bar', 'range'),) + rows + foo_data = (("bar", "range"),) + rows return dict(foo=foo_data) @classmethod @@ -68,37 +70,49 @@ class GenerativeQueryTest(fixtures.MappedTest): assert query[10:20][5] == orig[10:20][5] - @testing.uses_deprecated('Call to deprecated function apply_max') + @testing.uses_deprecated("Call to deprecated function apply_max") def test_aggregate(self): foo, Foo = self.tables.foo, self.classes.Foo sess = create_session() query = sess.query(Foo) assert query.count() == 100 - assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar < 30) \ - .one() == (0,) - - assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar < 30) \ - .one() == (29,) - assert next(query.filter(foo.c.bar < 30).values( - sa.func.max(foo.c.bar)))[0] == 29 - assert next(query.filter(foo.c.bar < 30).values( - sa.func.max(foo.c.bar)))[0] == 29 + assert sess.query(func.min(foo.c.bar)).filter( + foo.c.bar < 30 + ).one() == (0,) + + assert sess.query(func.max(foo.c.bar)).filter( + foo.c.bar < 30 + ).one() == (29,) + assert ( + next(query.filter(foo.c.bar < 30).values(sa.func.max(foo.c.bar)))[ + 0 + ] + == 29 + ) + assert ( + next(query.filter(foo.c.bar < 30).values(sa.func.max(foo.c.bar)))[ + 0 + ] + == 29 + ) @testing.fails_if( - lambda: 
testing.against('mysql+mysqldb') and - testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma'), - "unknown incompatibility") + lambda: testing.against("mysql+mysqldb") + and testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, "gamma"), + "unknown incompatibility", + ) def test_aggregate_1(self): foo = self.tables.foo query = create_session().query(func.sum(foo.c.bar)) assert query.filter(foo.c.bar < 30).one() == (435,) - @testing.fails_on('firebird', 'FIXME: unknown') + @testing.fails_on("firebird", "FIXME: unknown") @testing.fails_on( - 'mssql', - 'AVG produces an average as the original column type on mssql.') + "mssql", + "AVG produces an average as the original column type on mssql.", + ) def test_aggregate_2(self): foo = self.tables.foo @@ -107,19 +121,22 @@ class GenerativeQueryTest(fixtures.MappedTest): eq_(float(round(avg, 1)), 14.5) @testing.fails_on( - 'mssql', - 'AVG produces an average as the original column type on mssql.') + "mssql", + "AVG produces an average as the original column type on mssql.", + ) def test_aggregate_3(self): foo, Foo = self.tables.foo, self.classes.Foo query = create_session().query(Foo) - avg_f = next(query.filter(foo.c.bar < 30).values( - sa.func.avg(foo.c.bar)))[0] + avg_f = next( + query.filter(foo.c.bar < 30).values(sa.func.avg(foo.c.bar)) + )[0] assert float(round(avg_f, 1)) == 14.5 - avg_o = next(query.filter(foo.c.bar < 30).values( - sa.func.avg(foo.c.bar)))[0] + avg_o = next( + query.filter(foo.c.bar < 30).values(sa.func.avg(foo.c.bar)) + )[0] assert float(round(avg_o, 1)) == 14.5 def test_filter(self): @@ -152,15 +169,15 @@ class GenerativeQueryTest(fixtures.MappedTest): class GenerativeTest2(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('table1', metadata, - Column('id', Integer, primary_key=True)) - Table('table2', metadata, - Column('t1id', Integer, ForeignKey("table1.id"), - primary_key=True), - Column('num', Integer, primary_key=True)) + Table("table1", metadata, Column("id", Integer, primary_key=True)) + Table( + "table2", + metadata, + Column("t1id", Integer, ForeignKey("table1.id"), primary_key=True), + Column("num", Integer, primary_key=True), + ) @classmethod def setup_mappers(cls): @@ -178,52 +195,71 @@ class GenerativeTest2(fixtures.MappedTest): @classmethod def fixtures(cls): return dict( - table1=(('id',), - (1,), - (2,), - (3,), - (4,)), - table2=(('num', 't1id'), - (1, 1), - (2, 1), - (3, 1), - (4, 2), - (5, 2), - (6, 3))) + table1=(("id",), (1,), (2,), (3,), (4,)), + table2=( + ("num", "t1id"), + (1, 1), + (2, 1), + (3, 1), + (4, 2), + (5, 2), + (6, 3), + ), + ) def test_distinct_count(self): - table2, Obj1, table1 = (self.tables.table2, - self.classes.Obj1, - self.tables.table1) + table2, Obj1, table1 = ( + self.tables.table2, + self.classes.Obj1, + self.tables.table1, + ) query = create_session().query(Obj1) eq_(query.count(), 4) - res = query.filter(sa.and_(table1.c.id == table2.c.t1id, - table2.c.t1id == 1)) + res = query.filter( + sa.and_(table1.c.id == table2.c.t1id, table2.c.t1id == 1) + ) eq_(res.count(), 3) - res = query.filter(sa.and_(table1.c.id == table2.c.t1id, - table2.c.t1id == 1)).distinct() + res = query.filter( + sa.and_(table1.c.id == table2.c.t1id, table2.c.t1id == 1) + ).distinct() eq_(res.count(), 1) class RelationshipsTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def setup_mappers(cls): - addresses, Order, User, Address, orders, users = 
(cls.tables.addresses, - cls.classes.Order, - cls.classes.User, - cls.classes.Address, - cls.tables.orders, - cls.tables.users) - - mapper(User, users, properties={ - 'orders': relationship(mapper(Order, orders, properties={ - 'addresses': relationship(mapper(Address, addresses))}))}) + addresses, Order, User, Address, orders, users = ( + cls.tables.addresses, + cls.classes.Order, + cls.classes.User, + cls.classes.Address, + cls.tables.orders, + cls.tables.users, + ) + + mapper( + User, + users, + properties={ + "orders": relationship( + mapper( + Order, + orders, + properties={ + "addresses": relationship( + mapper(Address, addresses) + ) + }, + ) + ) + }, + ) def test_join(self): """Query.join""" @@ -231,65 +267,81 @@ class RelationshipsTest(_fixtures.FixtureTest): User, Address = self.classes.User, self.classes.Address session = create_session() - q = (session.query(User).join('orders', 'addresses'). - filter(Address.id == 1)) + q = ( + session.query(User) + .join("orders", "addresses") + .filter(Address.id == 1) + ) eq_([User(id=7)], q.all()) def test_outer_join(self): """Query.outerjoin""" - Order, User, Address = (self.classes.Order, - self.classes.User, - self.classes.Address) + Order, User, Address = ( + self.classes.Order, + self.classes.User, + self.classes.Address, + ) session = create_session() - q = (session.query(User).outerjoin('orders', 'addresses'). - filter(sa.or_(Order.id == None, Address.id == 1))) # noqa - eq_(set([User(id=7), User(id=8), User(id=10)]), - set(q.all())) + q = ( + session.query(User) + .outerjoin("orders", "addresses") + .filter(sa.or_(Order.id == None, Address.id == 1)) + ) # noqa + eq_(set([User(id=7), User(id=8), User(id=10)]), set(q.all())) def test_outer_join_count(self): """test the join and outerjoin functions on Query""" - Order, User, Address = (self.classes.Order, - self.classes.User, - self.classes.Address) + Order, User, Address = ( + self.classes.Order, + self.classes.User, + self.classes.Address, + ) session = create_session() - q = (session.query(User).outerjoin('orders', 'addresses'). - filter(sa.or_(Order.id == None, Address.id == 1))) # noqa + q = ( + session.query(User) + .outerjoin("orders", "addresses") + .filter(sa.or_(Order.id == None, Address.id == 1)) + ) # noqa eq_(q.count(), 4) def test_from(self): - users, Order, User, Address, orders, addresses = \ - (self.tables.users, - self.classes.Order, - self.classes.User, - self.classes.Address, - self.tables.orders, - self.tables.addresses) + users, Order, User, Address, orders, addresses = ( + self.tables.users, + self.classes.Order, + self.classes.User, + self.classes.Address, + self.tables.orders, + self.tables.addresses, + ) session = create_session() sel = users.outerjoin(orders).outerjoin( - addresses, orders.c.address_id == addresses.c.id) - q = (session.query(User).select_from(sel). 
- filter(sa.or_(Order.id == None, Address.id == 1))) # noqa - eq_(set([User(id=7), User(id=8), User(id=10)]), - set(q.all())) + addresses, orders.c.address_id == addresses.c.id + ) + q = ( + session.query(User) + .select_from(sel) + .filter(sa.or_(Order.id == None, Address.id == 1)) + ) # noqa + eq_(set([User(id=7), User(id=8), User(id=10)]), set(q.all())) class CaseSensitiveTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('Table1', metadata, - Column('ID', Integer, primary_key=True)) - Table('Table2', metadata, - Column('T1ID', Integer, ForeignKey("Table1.ID"), - primary_key=True), - Column('NUM', Integer, primary_key=True)) + Table("Table1", metadata, Column("ID", Integer, primary_key=True)) + Table( + "Table2", + metadata, + Column("T1ID", Integer, ForeignKey("Table1.ID"), primary_key=True), + Column("NUM", Integer, primary_key=True), + ) @classmethod def setup_mappers(cls): @@ -307,29 +359,32 @@ class CaseSensitiveTest(fixtures.MappedTest): @classmethod def fixtures(cls): return dict( - Table1=(('ID',), - (1,), - (2,), - (3,), - (4,)), - Table2=(('NUM', 'T1ID'), - (1, 1), - (2, 1), - (3, 1), - (4, 2), - (5, 2), - (6, 3))) + Table1=(("ID",), (1,), (2,), (3,), (4,)), + Table2=( + ("NUM", "T1ID"), + (1, 1), + (2, 1), + (3, 1), + (4, 2), + (5, 2), + (6, 3), + ), + ) def test_distinct_count(self): - Table2, Obj1, Table1 = (self.tables.Table2, - self.classes.Obj1, - self.tables.Table1) + Table2, Obj1, Table1 = ( + self.tables.Table2, + self.classes.Obj1, + self.tables.Table1, + ) q = create_session(bind=testing.db).query(Obj1) assert q.count() == 4 res = q.filter( - sa.and_(Table1.c.ID == Table2.c.T1ID, Table2.c.T1ID == 1)) + sa.and_(Table1.c.ID == Table2.c.T1ID, Table2.c.T1ID == 1) + ) assert res.count() == 3 - res = q.filter(sa.and_(Table1.c.ID == Table2.c.T1ID, - Table2.c.T1ID == 1)).distinct() + res = q.filter( + sa.and_(Table1.c.ID == Table2.c.T1ID, Table2.c.T1ID == 1) + ).distinct() eq_(res.count(), 1) diff --git a/test/orm/test_hasparent.py b/test/orm/test_hasparent.py index 38dd722c20..fe5e05a174 100644 --- a/test/orm/test_hasparent.py +++ b/test/orm/test_hasparent.py @@ -2,11 +2,17 @@ from sqlalchemy.testing import assert_raises, assert_raises_message -from sqlalchemy import Integer, String, ForeignKey, Sequence, \ - exc as sa_exc +from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, create_session, \ - sessionmaker, class_mapper, backref, Session +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + sessionmaker, + class_mapper, + backref, + Session, +) from sqlalchemy.orm import attributes, exc as orm_exc from sqlalchemy import testing from sqlalchemy.testing import eq_ @@ -23,24 +29,33 @@ class ParentRemovalTest(fixtures.MappedTest): raised. 
""" + run_inserts = None @classmethod def define_tables(cls, metadata): - if testing.against('oracle'): - fk_args = dict(deferrable=True, initially='deferred') - elif testing.against('mysql'): + if testing.against("oracle"): + fk_args = dict(deferrable=True, initially="deferred") + elif testing.against("mysql"): fk_args = {} else: - fk_args = dict(onupdate='cascade') - - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True)) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.id', **fk_args))) + fk_args = dict(onupdate="cascade") + + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("users.id", **fk_args)), + ) @classmethod def setup_classes(cls): @@ -53,11 +68,15 @@ class ParentRemovalTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): mapper(cls.classes.Address, cls.tables.addresses) - mapper(cls.classes.User, cls.tables.users, properties={ - 'addresses': relationship(cls.classes.Address, - cascade='all, delete-orphan'), - - }) + mapper( + cls.classes.User, + cls.tables.users, + properties={ + "addresses": relationship( + cls.classes.Address, cascade="all, delete-orphan" + ) + }, + ) def _assert_hasparent(self, a1): assert attributes.has_parent(self.classes.User, a1, "addresses") @@ -132,7 +151,8 @@ class ParentRemovalTest(fixtures.MappedTest): assert_raises_message( orm_exc.StaleDataError, "can't be sure this is the most recent parent.", - u1.addresses.remove, a1 + u1.addresses.remove, + a1, ) # u1.addresses wasn't actually impacted, because the event was @@ -185,7 +205,8 @@ class ParentRemovalTest(fixtures.MappedTest): assert_raises_message( orm_exc.StaleDataError, "can't be sure this is the most recent parent.", - u1.addresses.remove, a1 + u1.addresses.remove, + a1, ) s.flush() diff --git a/test/orm/test_immediate_load.py b/test/orm/test_immediate_load.py index 45ebc2835d..6a3beeb792 100644 --- a/test/orm/test_immediate_load.py +++ b/test/orm/test_immediate_load.py @@ -7,43 +7,55 @@ from test.orm import _fixtures class ImmediateTest(_fixtures.FixtureTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None def test_basic_option(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) sess = create_session() - result = sess.query(User).options(immediateload( - User.addresses)).filter(users.c.id == 7).all() + result = ( + sess.query(User) + .options(immediateload(User.addresses)) + .filter(users.c.id == 7) + .all() + ) eq_(len(sess.identity_map), 2) sess.close() eq_( - [User(id=7, - addresses=[Address(id=1, email_address='jack@bean.com')])], - result + [ + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ) + ], + result, ) def test_basic(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, 
User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='immediate') - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="immediate")}, + ) sess = create_session() result = sess.query(User).filter(users.c.id == 7).all() @@ -51,7 +63,11 @@ class ImmediateTest(_fixtures.FixtureTest): sess.close() eq_( - [User(id=7, - addresses=[Address(id=1, email_address='jack@bean.com')])], - result + [ + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ) + ], + result, ) diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index 8353102081..2c641d026e 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -4,8 +4,13 @@ from sqlalchemy.testing import eq_, assert_raises_message, is_ from sqlalchemy import exc, util from sqlalchemy import inspect from test.orm import _fixtures -from sqlalchemy.orm import class_mapper, synonym, Session, aliased,\ - relationship +from sqlalchemy.orm import ( + class_mapper, + synonym, + Session, + aliased, + relationship, +) from sqlalchemy import ForeignKey from sqlalchemy.orm.attributes import instance_state, NO_VALUE from sqlalchemy import testing @@ -16,9 +21,7 @@ class TestORMInspection(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): cls._setup_stock_mapping() - inspect(cls.classes.User).add_property( - "name_syn", synonym("name") - ) + inspect(cls.classes.User).add_property("name_syn", synonym("name")) def test_class_mapper(self): User = self.classes.User @@ -29,20 +32,14 @@ class TestORMInspection(_fixtures.FixtureTest): User = self.classes.User user_table = self.tables.users insp = inspect(User) - eq_( - list(insp.columns), - [user_table.c.id, user_table.c.name] - ) - is_( - insp.columns.id, user_table.c.id - ) + eq_(list(insp.columns), [user_table.c.id, user_table.c.name]) + is_(insp.columns.id, user_table.c.id) def test_primary_key(self): User = self.classes.User user_table = self.tables.users insp = inspect(User) - eq_(insp.primary_key, - (user_table.c.id,)) + eq_(insp.primary_key, (user_table.c.id,)) def test_local_table(self): User = self.classes.User @@ -72,12 +69,16 @@ class TestORMInspection(_fixtures.FixtureTest): class Bar(Foo): pass + user_table = self.tables.users addresses_table = self.tables.addresses mapper(Foo, user_table, with_polymorphic=(Bar,)) - mapper(Bar, addresses_table, inherits=Foo, properties={ - 'address_id': addresses_table.c.id - }) + mapper( + Bar, + addresses_table, + inherits=Foo, + properties={"address_id": addresses_table.c.id}, + ) i1 = inspect(Foo) i2 = inspect(Foo) assert i1.selectable is i2.selectable @@ -98,7 +99,8 @@ class TestORMInspection(_fixtures.FixtureTest): assert_raises_message( exc.NoInspectionAvailable, "No inspection system is available for object of type", - inspect, Foo + inspect, + Foo, ) def test_not_mapped_instance(self): @@ -108,13 +110,14 @@ class TestORMInspection(_fixtures.FixtureTest): assert_raises_message( exc.NoInspectionAvailable, "No inspection system is available for object of type", - inspect, Foo() + inspect, + Foo(), ) def test_property(self): User = self.classes.User insp = inspect(User) - is_(insp.attrs.id, class_mapper(User).get_property('id')) + is_(insp.attrs.id, class_mapper(User).get_property("id")) def test_with_polymorphic(self): User = self.classes.User @@ -130,14 +133,14 @@ class TestORMInspection(_fixtures.FixtureTest): 
eq_(id_prop.columns, [user_table.c.id]) is_(id_prop.expression, user_table.c.id) - assert not hasattr(id_prop, 'mapper') + assert not hasattr(id_prop, "mapper") def test_attr_keys(self): User = self.classes.User insp = inspect(User) eq_( list(insp.attrs.keys()), - ['addresses', 'orders', 'id', 'name', 'name_syn'] + ["addresses", "orders", "id", "name", "name_syn"], ) def test_col_filter(self): @@ -145,41 +148,25 @@ class TestORMInspection(_fixtures.FixtureTest): insp = inspect(User) eq_( list(insp.column_attrs), - [insp.get_property('id'), insp.get_property('name')] - ) - eq_( - list(insp.column_attrs.keys()), - ['id', 'name'] - ) - is_( - insp.column_attrs.id, - User.id.property + [insp.get_property("id"), insp.get_property("name")], ) + eq_(list(insp.column_attrs.keys()), ["id", "name"]) + is_(insp.column_attrs.id, User.id.property) def test_synonym_filter(self): User = self.classes.User syn = inspect(User).synonyms - eq_( - list(syn.keys()), ['name_syn'] - ) + eq_(list(syn.keys()), ["name_syn"]) is_(syn.name_syn, User.name_syn.original_property) - eq_(dict(syn), { - "name_syn": User.name_syn.original_property - }) + eq_(dict(syn), {"name_syn": User.name_syn.original_property}) def test_relationship_filter(self): User = self.classes.User rel = inspect(User).relationships - eq_( - rel.addresses, - User.addresses.property - ) - eq_( - set(rel.keys()), - set(['orders', 'addresses']) - ) + eq_(rel.addresses, User.addresses.property) + eq_(set(rel.keys()), set(["orders", "addresses"])) def test_insp_relationship_prop(self): User = self.classes.User @@ -239,20 +226,26 @@ class TestORMInspection(_fixtures.FixtureTest): is_(prop.parent, class_mapper(User)) is_(prop.mapper, class_mapper(Address)) - assert not hasattr(prop, 'columns') - assert hasattr(prop, 'expression') + assert not hasattr(prop, "columns") + assert hasattr(prop, "expression") def test_extension_types(self): - from sqlalchemy.ext.associationproxy import \ - association_proxy, ASSOCIATION_PROXY - from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method, \ - HYBRID_PROPERTY, HYBRID_METHOD + from sqlalchemy.ext.associationproxy import ( + association_proxy, + ASSOCIATION_PROXY, + ) + from sqlalchemy.ext.hybrid import ( + hybrid_property, + hybrid_method, + HYBRID_PROPERTY, + HYBRID_METHOD, + ) from sqlalchemy import Table, MetaData, Integer, Column from sqlalchemy.orm import mapper from sqlalchemy.orm.interfaces import NOT_EXTENSION class SomeClass(self.classes.User): - some_assoc = association_proxy('addresses', 'email_address') + some_assoc = association_proxy("addresses", "email_address") @hybrid_property def upper_name(self): @@ -275,45 +268,43 @@ class TestORMInspection(_fixtures.FixtureTest): raise NotImplementedError() m = MetaData() - t = Table('sometable', m, - Column('id', Integer, primary_key=True)) - ta = Table('address_t', m, - Column('id', Integer, primary_key=True), - Column('s_id', ForeignKey('sometable.id')) - ) - mapper(SomeClass, t, properties={ - "addresses": relationship(Address) - }) + t = Table("sometable", m, Column("id", Integer, primary_key=True)) + ta = Table( + "address_t", + m, + Column("id", Integer, primary_key=True), + Column("s_id", ForeignKey("sometable.id")), + ) + mapper(SomeClass, t, properties={"addresses": relationship(Address)}) mapper(Address, ta) mapper(SomeSubClass, inherits=SomeClass) insp = inspect(SomeSubClass) eq_( - dict((k, v.extension_type) - for k, v in list(insp.all_orm_descriptors.items())), + dict( + (k, v.extension_type) + for k, v in 
list(insp.all_orm_descriptors.items()) + ), { - 'id': NOT_EXTENSION, - 'name': NOT_EXTENSION, - 'name_syn': NOT_EXTENSION, - 'addresses': NOT_EXTENSION, - 'orders': NOT_EXTENSION, - 'upper_name': HYBRID_PROPERTY, - 'foo': HYBRID_PROPERTY, - 'conv': HYBRID_METHOD, - 'some_assoc': ASSOCIATION_PROXY - } + "id": NOT_EXTENSION, + "name": NOT_EXTENSION, + "name_syn": NOT_EXTENSION, + "addresses": NOT_EXTENSION, + "orders": NOT_EXTENSION, + "upper_name": HYBRID_PROPERTY, + "foo": HYBRID_PROPERTY, + "conv": HYBRID_METHOD, + "some_assoc": ASSOCIATION_PROXY, + }, ) is_( insp.all_orm_descriptors.upper_name, - SomeSubClass.__dict__['upper_name'] - ) - is_( - insp.all_orm_descriptors.some_assoc, - SomeClass.some_assoc.parent + SomeSubClass.__dict__["upper_name"], ) + is_(insp.all_orm_descriptors.some_assoc, SomeClass.some_assoc.parent) is_( inspect(SomeClass).all_orm_descriptors.upper_name, - SomeClass.__dict__['upper_name'] + SomeClass.__dict__["upper_name"], ) def test_instance_state(self): @@ -326,122 +317,86 @@ class TestORMInspection(_fixtures.FixtureTest): User = self.classes.User u1 = User() insp = inspect(u1) - insp.info['some_key'] = 'value' - eq_(inspect(u1).info['some_key'], 'value') + insp.info["some_key"] = "value" + eq_(inspect(u1).info["some_key"], "value") def test_instance_state_attr(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) eq_( set(insp.attrs.keys()), - set(['id', 'name', 'name_syn', 'addresses', 'orders']) - ) - eq_( - insp.attrs.name.value, - 'ed' - ) - eq_( - insp.attrs.name.loaded_value, - 'ed' + set(["id", "name", "name_syn", "addresses", "orders"]), ) + eq_(insp.attrs.name.value, "ed") + eq_(insp.attrs.name.loaded_value, "ed") def test_instance_state_attr_passive_value_scalar(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) # value was not set, NO_VALUE - eq_( - insp.attrs.id.loaded_value, - NO_VALUE - ) + eq_(insp.attrs.id.loaded_value, NO_VALUE) # regular accessor sets it - eq_( - insp.attrs.id.value, - None - ) + eq_(insp.attrs.id.value, None) # nope, still not set - eq_( - insp.attrs.id.loaded_value, - NO_VALUE - ) + eq_(insp.attrs.id.loaded_value, NO_VALUE) def test_instance_state_attr_passive_value_collection(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) # value was not set, NO_VALUE - eq_( - insp.attrs.addresses.loaded_value, - NO_VALUE - ) + eq_(insp.attrs.addresses.loaded_value, NO_VALUE) # regular accessor sets it - eq_( - insp.attrs.addresses.value, - [] - ) + eq_(insp.attrs.addresses.value, []) # now the None is there - eq_( - insp.attrs.addresses.loaded_value, - [] - ) + eq_(insp.attrs.addresses.loaded_value, []) def test_instance_state_collection_attr_hist(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) hist = insp.attrs.addresses.history - eq_( - hist.unchanged, None - ) + eq_(hist.unchanged, None) u1.addresses hist = insp.attrs.addresses.history - eq_( - hist.unchanged, [] - ) + eq_(hist.unchanged, []) def test_instance_state_scalar_attr_hist(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") sess = Session() sess.add(u1) sess.commit() - assert 'name' not in u1.__dict__ + assert "name" not in u1.__dict__ insp = inspect(u1) hist = insp.attrs.name.history - eq_( - hist.unchanged, None - ) - assert 'name' not in u1.__dict__ + eq_(hist.unchanged, None) + assert "name" not in u1.__dict__ def 
test_instance_state_collection_attr_load_hist(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) hist = insp.attrs.addresses.load_history() - eq_( - hist.unchanged, () - ) + eq_(hist.unchanged, ()) u1.addresses hist = insp.attrs.addresses.load_history() - eq_( - hist.unchanged, [] - ) + eq_(hist.unchanged, []) def test_instance_state_scalar_attr_hist_load(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") sess = Session() sess.add(u1) sess.commit() - assert 'name' not in u1.__dict__ + assert "name" not in u1.__dict__ insp = inspect(u1) hist = insp.attrs.name.load_history() - eq_( - hist.unchanged, ['ed'] - ) - assert 'name' in u1.__dict__ + eq_(hist.unchanged, ["ed"]) + assert "name" in u1.__dict__ def test_attrs_props_prop_added_after_configure(self): class AnonClass(object): @@ -449,37 +404,37 @@ class TestORMInspection(_fixtures.FixtureTest): from sqlalchemy.orm import mapper, column_property from sqlalchemy.ext.hybrid import hybrid_property + m = mapper(AnonClass, self.tables.users) - eq_( - set(inspect(AnonClass).attrs.keys()), - set(['id', 'name'])) + eq_(set(inspect(AnonClass).attrs.keys()), set(["id", "name"])) eq_( set(inspect(AnonClass).all_orm_descriptors.keys()), - set(['id', 'name'])) + set(["id", "name"]), + ) - m.add_property('q', column_property(self.tables.users.c.name)) + m.add_property("q", column_property(self.tables.users.c.name)) def desc(self): return self.name + AnonClass.foob = hybrid_property(desc) - eq_( - set(inspect(AnonClass).attrs.keys()), - set(['id', 'name', 'q'])) + eq_(set(inspect(AnonClass).attrs.keys()), set(["id", "name", "q"])) eq_( set(inspect(AnonClass).all_orm_descriptors.keys()), - set(['id', 'name', 'q', 'foob'])) + set(["id", "name", "q", "foob"]), + ) def test_instance_state_ident_transient(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) is_(insp.identity, None) def test_instance_state_ident_persistent(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") s = Session(testing.db) s.add(u1) s.flush() @@ -489,7 +444,7 @@ class TestORMInspection(_fixtures.FixtureTest): def test_is_instance(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) assert insp.is_instance @@ -501,51 +456,44 @@ class TestORMInspection(_fixtures.FixtureTest): def test_identity_key(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") s = Session(testing.db) s.add(u1) s.flush() insp = inspect(u1) - eq_( - insp.identity_key, - identity_key(User, (u1.id, )) - ) + eq_(insp.identity_key, identity_key(User, (u1.id,))) def test_persistence_states(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) eq_( - (insp.transient, insp.pending, - insp.persistent, insp.detached), - (True, False, False, False) + (insp.transient, insp.pending, insp.persistent, insp.detached), + (True, False, False, False), ) s = Session(testing.db) s.add(u1) eq_( - (insp.transient, insp.pending, - insp.persistent, insp.detached), - (False, True, False, False) + (insp.transient, insp.pending, insp.persistent, insp.detached), + (False, True, False, False), ) s.flush() eq_( - (insp.transient, insp.pending, - insp.persistent, insp.detached), - (False, False, True, False) + (insp.transient, insp.pending, insp.persistent, insp.detached), + (False, False, True, False), ) s.expunge(u1) eq_( - (insp.transient, insp.pending, - insp.persistent, insp.detached), - 
(False, False, False, True) + (insp.transient, insp.pending, insp.persistent, insp.detached), + (False, False, False, True), ) def test_session_accessor(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) is_(insp.session, None) @@ -555,6 +503,6 @@ class TestORMInspection(_fixtures.FixtureTest): def test_object_accessor(self): User = self.classes.User - u1 = User(name='ed') + u1 = User(name="ed") insp = inspect(u1) is_(insp.object, u1) diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 5e62eeddfb..b6ed1e42b0 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -1,9 +1,16 @@ - from sqlalchemy.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy import MetaData, Integer, ForeignKey, util, event -from sqlalchemy.orm import mapper, relationship, create_session, \ - attributes, class_mapper, clear_mappers, instrumentation, events +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + attributes, + class_mapper, + clear_mappers, + instrumentation, + events, +) from sqlalchemy.testing.schema import Table from sqlalchemy.testing.schema import Column from sqlalchemy.testing import eq_, ne_ @@ -13,11 +20,14 @@ from sqlalchemy import testing class InitTest(fixtures.ORMTest): def fixture(self): - return Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('type', Integer), - Column('x', Integer), - Column('y', Integer)) + return Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column("type", Integer), + Column("x", Integer), + Column("y", Integer), + ) def register(self, cls, canary): original_init = cls.__init__ @@ -26,214 +36,244 @@ class InitTest(fixtures.ORMTest): manager = instrumentation.manager_of_class(cls) def init(state, args, kwargs): - canary.append((cls, 'init', state.class_)) - event.listen(manager, 'init', init, raw=True) + canary.append((cls, "init", state.class_)) + + event.listen(manager, "init", init, raw=True) def test_ai(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) obj = A() - eq_(inits, [(A, '__init__')]) + eq_(inits, [(A, "__init__")]) def test_A(self): inits = [] class A(object): pass + self.register(A, inits) obj = A() - eq_(inits, [(A, 'init', A)]) + eq_(inits, [(A, "init", A)]) def test_Ai(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) def test_ai_B(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) class B(A): pass + self.register(B, inits) obj = A() - eq_(inits, [(A, '__init__')]) + eq_(inits, [(A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (A, '__init__')]) + eq_(inits, [(B, "init", B), (A, "__init__")]) def test_ai_Bi(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) class B(A): def __init__(self): - inits.append((B, '__init__')) + inits.append((B, "__init__")) super(B, self).__init__() + self.register(B, inits) obj = A() - eq_(inits, [(A, '__init__')]) + eq_(inits, [(A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (B, '__init__'), (A, '__init__')]) + eq_(inits, [(B, "init", B), (B, "__init__"), (A, 
"__init__")]) def test_Ai_bi(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): def __init__(self): - inits.append((B, '__init__')) + inits.append((B, "__init__")) super(B, self).__init__() obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, '__init__'), (A, 'init', B), (A, '__init__')]) + eq_(inits, [(B, "__init__"), (A, "init", B), (A, "__init__")]) def test_Ai_Bi(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): def __init__(self): - inits.append((B, '__init__')) + inits.append((B, "__init__")) super(B, self).__init__() + self.register(B, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (B, '__init__'), (A, '__init__')]) + eq_(inits, [(B, "init", B), (B, "__init__"), (A, "__init__")]) def test_Ai_B(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): pass + self.register(B, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (A, '__init__')]) + eq_(inits, [(B, "init", B), (A, "__init__")]) def test_Ai_Bi_Ci(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): def __init__(self): - inits.append((B, '__init__')) + inits.append((B, "__init__")) super(B, self).__init__() + self.register(B, inits) class C(B): def __init__(self): - inits.append((C, '__init__')) + inits.append((C, "__init__")) super(C, self).__init__() + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (B, '__init__'), (A, '__init__')]) + eq_(inits, [(B, "init", B), (B, "__init__"), (A, "__init__")]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (C, '__init__'), (B, '__init__'), - (A, '__init__')]) + eq_( + inits, + [ + (C, "init", C), + (C, "__init__"), + (B, "__init__"), + (A, "__init__"), + ], + ) def test_Ai_bi_Ci(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): def __init__(self): - inits.append((B, '__init__')) + inits.append((B, "__init__")) super(B, self).__init__() class C(B): def __init__(self): - inits.append((C, '__init__')) + inits.append((C, "__init__")) super(C, self).__init__() + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, '__init__'), (A, 'init', B), (A, '__init__')]) + eq_(inits, [(B, "__init__"), (A, "init", B), (A, "__init__")]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (C, '__init__'), (B, '__init__'), - (A, '__init__')]) + eq_( + inits, + [ + (C, "init", C), + (C, "__init__"), + (B, "__init__"), + (A, "__init__"), + ], + ) def test_Ai_b_Ci(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): 
@@ -241,175 +281,192 @@ class InitTest(fixtures.ORMTest): class C(B): def __init__(self): - inits.append((C, '__init__')) + inits.append((C, "__init__")) super(C, self).__init__() + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(A, 'init', B), (A, '__init__')]) + eq_(inits, [(A, "init", B), (A, "__init__")]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (C, '__init__'), (A, '__init__')]) + eq_(inits, [(C, "init", C), (C, "__init__"), (A, "__init__")]) def test_Ai_B_Ci(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): pass + self.register(B, inits) class C(B): def __init__(self): - inits.append((C, '__init__')) + inits.append((C, "__init__")) super(C, self).__init__() + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (A, '__init__')]) + eq_(inits, [(B, "init", B), (A, "__init__")]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (C, '__init__'), (A, '__init__')]) + eq_(inits, [(C, "init", C), (C, "__init__"), (A, "__init__")]) def test_Ai_B_C(self): inits = [] class A(object): def __init__(self): - inits.append((A, '__init__')) + inits.append((A, "__init__")) + self.register(A, inits) class B(A): pass + self.register(B, inits) class C(B): pass + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A), (A, '__init__')]) + eq_(inits, [(A, "init", A), (A, "__init__")]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (A, '__init__')]) + eq_(inits, [(B, "init", B), (A, "__init__")]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (A, '__init__')]) + eq_(inits, [(C, "init", C), (A, "__init__")]) def test_A_Bi_C(self): inits = [] class A(object): pass + self.register(A, inits) class B(A): def __init__(self): - inits.append((B, '__init__')) + inits.append((B, "__init__")) + self.register(B, inits) class C(B): pass + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A)]) + eq_(inits, [(A, "init", A)]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B), (B, '__init__')]) + eq_(inits, [(B, "init", B), (B, "__init__")]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (B, '__init__')]) + eq_(inits, [(C, "init", C), (B, "__init__")]) def test_A_B_Ci(self): inits = [] class A(object): pass + self.register(A, inits) class B(A): pass + self.register(B, inits) class C(B): def __init__(self): - inits.append((C, '__init__')) + inits.append((C, "__init__")) + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A)]) + eq_(inits, [(A, "init", A)]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B)]) + eq_(inits, [(B, "init", B)]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C), (C, '__init__')]) + eq_(inits, [(C, "init", C), (C, "__init__")]) def test_A_B_C(self): inits = [] class A(object): pass + self.register(A, inits) class B(A): pass + self.register(B, inits) class C(B): pass + self.register(C, inits) obj = A() - eq_(inits, [(A, 'init', A)]) + eq_(inits, [(A, "init", A)]) del inits[:] obj = B() - eq_(inits, [(B, 'init', B)]) + eq_(inits, [(B, "init", B)]) del inits[:] obj = C() - eq_(inits, [(C, 'init', C)]) + eq_(inits, [(C, "init", C)]) def test_defaulted_init(self): class X(object): - def __init__(self_, a, b=123, c='abc'): + def __init__(self_, a, b=123, c="abc"): self_.a = a self_.b = b 
self_.c = c + instrumentation.register_class(X) - o = X('foo') - eq_(o.a, 'foo') + o = X("foo") + eq_(o.a, "foo") eq_(o.b, 123) - eq_(o.c, 'abc') + eq_(o.c, "abc") class Y(object): unique = object() @@ -417,7 +474,7 @@ class InitTest(fixtures.ORMTest): class OutOfScopeForEval(object): def __repr__(self_): # misleading repr - return '123' + return "123" outofscope = OutOfScopeForEval() @@ -433,13 +490,15 @@ class InitTest(fixtures.ORMTest): class MapperInitTest(fixtures.ORMTest): - def fixture(self): - return Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('type', Integer), - Column('x', Integer), - Column('y', Integer)) + return Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column("type", Integer), + Column("x", Integer), + Column("y", Integer), + ) def test_partially_mapped_inheritance(self): class A(object): @@ -472,7 +531,9 @@ class MapperInitTest(fixtures.ORMTest): r"unreachable cycles and memory leaks, as SQLAlchemy " r"instrumentation often creates reference cycles. " r"Please remove this method.", - mapper, A, self.fixture() + mapper, + A, + self.fixture(), ) @@ -493,7 +554,7 @@ class OnLoadTest(fixtures.ORMTest): try: instrumentation.register_class(A) manager = instrumentation.manager_of_class(A) - event.listen(manager, 'load', canary) + event.listen(manager, "load", canary) a = A() p_a = pickle.dumps(a) @@ -513,18 +574,23 @@ class NativeInstrumentationTest(fixtures.ORMTest): sa = instrumentation.ClassManager.STATE_ATTR ma = instrumentation.ClassManager.MANAGER_ATTR - def fails(method, attr): return assert_raises( - KeyError, getattr(manager, method), attr, property()) + def fails(method, attr): + return assert_raises( + KeyError, getattr(manager, method), attr, property() + ) - fails('install_member', sa) - fails('install_member', ma) - fails('install_descriptor', sa) - fails('install_descriptor', ma) + fails("install_member", sa) + fails("install_member", ma) + fails("install_descriptor", sa) + fails("install_descriptor", ma) def test_mapped_stateattr(self): - t = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column(instrumentation.ClassManager.STATE_ATTR, Integer)) + t = Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column(instrumentation.ClassManager.STATE_ATTR, Integer), + ) class T(object): pass @@ -532,17 +598,21 @@ class NativeInstrumentationTest(fixtures.ORMTest): assert_raises(KeyError, mapper, T, t) def test_mapped_managerattr(self): - t = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column(instrumentation.ClassManager.MANAGER_ATTR, Integer)) + t = Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column(instrumentation.ClassManager.MANAGER_ATTR, Integer), + ) class T(object): pass + assert_raises(KeyError, mapper, T, t) class Py3KFunctionInstTest(fixtures.ORMTest): - __requires__ = ("python3", ) + __requires__ = ("python3",) def _instrument(self, cls): manager = instrumentation.register_class(cls) @@ -550,6 +620,7 @@ class Py3KFunctionInstTest(fixtures.ORMTest): def check(target, args, kwargs): canary.append((args, kwargs)) + event.listen(manager, "init", check) return cls, canary @@ -557,47 +628,39 @@ class Py3KFunctionInstTest(fixtures.ORMTest): cls, canary = self._kw_only_fixture() a = cls("a", b="b", c="c") - eq_(canary, [(('a', ), {'b': 'b', 'c': 'c'})]) + eq_(canary, [(("a",), {"b": "b", "c": "c"})]) def test_kw_plus_posn_args(self): cls, canary = self._kw_plus_posn_fixture() a = cls("a", 1, 2, 3, b="b", c="c") - 
eq_(canary, [(('a', 1, 2, 3), {'b': 'b', 'c': 'c'})]) + eq_(canary, [(("a", 1, 2, 3), {"b": "b", "c": "c"})]) def test_kw_only_args_plus_opt(self): cls, canary = self._kw_opt_fixture() a = cls("a", b="b") - eq_(canary, [(('a', ), {'b': 'b', 'c': 'c'})]) + eq_(canary, [(("a",), {"b": "b", "c": "c"})]) canary[:] = [] a = cls("a", b="b", c="d") - eq_(canary, [(('a', ), {'b': 'b', 'c': 'd'})]) + eq_(canary, [(("a",), {"b": "b", "c": "d"})]) def test_kw_only_sig(self): cls, canary = self._kw_only_fixture() - assert_raises( - TypeError, - cls, "a", "b", "c" - ) + assert_raises(TypeError, cls, "a", "b", "c") def test_kw_plus_opt_sig(self): cls, canary = self._kw_only_fixture() - assert_raises( - TypeError, - cls, "a", "b", "c" - ) + assert_raises(TypeError, cls, "a", "b", "c") - assert_raises( - TypeError, - cls, "a", "b", c="c" - ) + assert_raises(TypeError, cls, "a", "b", c="c") if util.py3k: _locals = {} - exec(""" + exec( + """ def _kw_only_fixture(self): class A(object): def __init__(self, a, *, b, c): @@ -621,7 +684,9 @@ def _kw_opt_fixture(self): self.b = b self.c = c return self._instrument(A) -""", _locals) +""", + _locals, + ) for k in _locals: setattr(Py3KFunctionInstTest, k, _locals[k]) @@ -630,12 +695,16 @@ class MiscTest(fixtures.ORMTest): """Seems basic, but not directly covered elsewhere!""" def test_compileonattr(self): - t = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('x', Integer)) + t = Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) class A(object): pass + mapper(A, t) a = A() @@ -643,18 +712,25 @@ class MiscTest(fixtures.ORMTest): def test_compileonattr_rel(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.id'))) + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) + t2 = Table( + "t2", + m, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("t1.id")), + ) class A(object): pass class B(object): pass + mapper(A, t1, properties=dict(bs=relationship(B))) mapper(B, t2) @@ -666,12 +742,12 @@ class MiscTest(fixtures.ORMTest): pass manager = instrumentation.register_class(A) - attributes.register_attribute(A, 'x', uselist=False, useobject=False) + attributes.register_attribute(A, "x", uselist=False, useobject=False) assert instrumentation.manager_of_class(A) is manager instrumentation.unregister_class(A) assert instrumentation.manager_of_class(A) is None - assert not hasattr(A, 'x') + assert not hasattr(A, "x") # I prefer 'is' here but on pypy # it seems only == works @@ -679,24 +755,32 @@ class MiscTest(fixtures.ORMTest): def test_compileonattr_rel_backref_a(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.id'))) + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) + t2 = Table( + "t2", + m, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("t1.id")), + ) class Base(object): def __init__(self, *args, **kwargs): pass for base in object, Base: + class A(base): pass class B(base): pass - mapper(A, t1, properties=dict(bs=relationship(B, backref='a'))) + + mapper(A, t1, properties=dict(bs=relationship(B, backref="a"))) 
mapper(B, t2) b = B() @@ -710,12 +794,18 @@ class MiscTest(fixtures.ORMTest): def test_compileonattr_rel_backref_b(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.id'))) + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) + t2 = Table( + "t2", + m, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("t1.id")), + ) class Base(object): def __init__(self): @@ -726,13 +816,15 @@ class MiscTest(fixtures.ORMTest): pass for base in object, Base, Base_AKW: + class A(base): pass class B(base): pass + mapper(A, t1) - mapper(B, t2, properties=dict(a=relationship(A, backref='bs'))) + mapper(B, t2, properties=dict(a=relationship(A, backref="bs"))) a = A() b = B() @@ -740,4 +832,4 @@ class MiscTest(fixtures.ORMTest): session = create_session() session.add(a) - assert b in session, 'base: %s' % base + assert b in session, "base: %s" % base diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index f74851bd0e..976b5650e0 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -21,53 +21,88 @@ from sqlalchemy.orm.util import join, outerjoin, with_parent class QueryTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def setup_mappers(cls): - Node, composite_pk_table, users, Keyword, items, Dingaling, \ - order_items, item_keywords, Item, User, dingalings, \ - Address, keywords, CompositePk, nodes, Order, orders, \ - addresses = cls.classes.Node, \ - cls.tables.composite_pk_table, cls.tables.users, \ - cls.classes.Keyword, cls.tables.items, \ - cls.classes.Dingaling, cls.tables.order_items, \ - cls.tables.item_keywords, cls.classes.Item, \ - cls.classes.User, cls.tables.dingalings, \ - cls.classes.Address, cls.tables.keywords, \ - cls.classes.CompositePk, cls.tables.nodes, \ - cls.classes.Order, cls.tables.orders, cls.tables.addresses - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', - order_by=addresses.c.id), - # o2m, m2o - 'orders': relationship(Order, backref='user', order_by=orders.c.id) - }) - mapper(Address, addresses, properties={ - # o2o - 'dingaling': relationship(Dingaling, uselist=False, - backref="address") - }) + Node, composite_pk_table, users, Keyword, items, Dingaling, order_items, item_keywords, Item, User, dingalings, Address, keywords, CompositePk, nodes, Order, orders, addresses = ( + cls.classes.Node, + cls.tables.composite_pk_table, + cls.tables.users, + cls.classes.Keyword, + cls.tables.items, + cls.classes.Dingaling, + cls.tables.order_items, + cls.tables.item_keywords, + cls.classes.Item, + cls.classes.User, + cls.tables.dingalings, + cls.classes.Address, + cls.tables.keywords, + cls.classes.CompositePk, + cls.tables.nodes, + cls.classes.Order, + cls.tables.orders, + cls.tables.addresses, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", order_by=addresses.c.id + ), + # o2m, m2o + "orders": relationship( + Order, backref="user", order_by=orders.c.id + ), + }, + ) + mapper( + Address, + addresses, + properties={ + # o2o + "dingaling": relationship( + Dingaling, uselist=False, backref="address" + ) + }, + ) mapper(Dingaling, dingalings) - mapper(Order, orders, properties={ - # m2m - 'items': relationship(Item, 
secondary=order_items, - order_by=items.c.id), - 'address': relationship(Address), # m2o - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords) # m2m - }) + mapper( + Order, + orders, + properties={ + # m2m + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ), + "address": relationship(Address), # m2o + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords + ) # m2m + }, + ) mapper(Keyword, keywords) - mapper(Node, nodes, properties={ - 'children': relationship(Node, - backref=backref( - 'parent', remote_side=[nodes.c.id])) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, backref=backref("parent", remote_side=[nodes.c.id]) + ) + }, + ) mapper(CompositePk, composite_pk_table) @@ -75,54 +110,100 @@ class QueryTest(_fixtures.FixtureTest): class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): - run_setup_mappers = 'once' + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - Table('people', metadata, - Column('person_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_id', Integer, - ForeignKey('companies.company_id')), - Column('name', String(50)), - Column('type', String(30))) - - Table('engineers', metadata, - Column('person_id', Integer, ForeignKey( - 'people.person_id'), primary_key=True), - Column('status', String(30)), - Column('engineer_name', String(50)), - Column('primary_language', String(50))) - - Table('machines', metadata, - Column('machine_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('engineer_id', Integer, - ForeignKey('engineers.person_id'))) - - Table('managers', metadata, - Column('person_id', Integer, ForeignKey( - 'people.person_id'), primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50))) - - Table('boss', metadata, - Column('boss_id', Integer, ForeignKey( - 'managers.person_id'), primary_key=True), - Column('golf_swing', String(30)), - ) - - Table('paperwork', metadata, - Column('paperwork_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('description', String(50)), - Column('person_id', Integer, ForeignKey('people.person_id'))) + Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_id", Integer, ForeignKey("companies.company_id")), + Column("name", String(50)), + Column("type", String(30)), + ) + + Table( + "engineers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("engineer_name", String(50)), + Column("primary_language", String(50)), + ) + + Table( + "machines", + metadata, + Column( + "machine_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("engineer_id", Integer, ForeignKey("engineers.person_id")), + ) + + Table( + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), 
+ Column("manager_name", String(50)), + ) + + Table( + "boss", + metadata, + Column( + "boss_id", + Integer, + ForeignKey("managers.person_id"), + primary_key=True, + ), + Column("golf_swing", String(30)), + ) + + Table( + "paperwork", + metadata, + Column( + "paperwork_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("description", String(50)), + Column("person_id", Integer, ForeignKey("people.person_id")), + ) @classmethod def setup_classes(cls): @@ -133,7 +214,8 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): cls.tables.boss, cls.tables.managers, cls.tables.machines, - cls.tables.engineers) + cls.tables.engineers, + ) class Company(cls.Comparable): pass @@ -156,26 +238,42 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): class Paperwork(cls.Comparable): pass - mapper(Company, companies, properties={ - 'employees': relationship(Person, order_by=people.c.person_id) - }) + mapper( + Company, + companies, + properties={ + "employees": relationship(Person, order_by=people.c.person_id) + }, + ) mapper(Machine, machines) - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person', - properties={ - 'paperwork': relationship(Paperwork, - order_by=paperwork.c.paperwork_id) - }) - mapper(Engineer, engineers, inherits=Person, - polymorphic_identity='engineer', - properties={'machines': relationship( - Machine, order_by=machines.c.machine_id)}) - mapper(Manager, managers, - inherits=Person, polymorphic_identity='manager') - mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + properties={ + "paperwork": relationship( + Paperwork, order_by=paperwork.c.paperwork_id + ) + }, + ) + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "machines": relationship( + Machine, order_by=machines.c.machine_id + ) + }, + ) + mapper( + Manager, managers, inherits=Person, polymorphic_identity="manager" + ) + mapper(Boss, boss, inherits=Manager, polymorphic_identity="boss") mapper(Paperwork, paperwork) def test_single_prop(self): @@ -189,7 +287,8 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "companies.name AS companies_name " "FROM companies JOIN people " "ON companies.company_id = people.company_id", - use_default_dialect=True) + use_default_dialect=True, + ) def test_force_via_select_from(self): Company, Engineer = self.classes.Company, self.classes.Engineer @@ -199,25 +298,30 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): self.assert_compile( sess.query(Company) .filter(Company.company_id == Engineer.company_id) - .filter(Engineer.primary_language == 'java'), + .filter(Engineer.primary_language == "java"), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name " "FROM companies, people, engineers " "WHERE companies.company_id = people.company_id " "AND engineers.primary_language " - "= :primary_language_1", use_default_dialect=True) + "= :primary_language_1", + use_default_dialect=True, + ) self.assert_compile( - sess.query(Company).select_from(Company, Engineer) + sess.query(Company) + .select_from(Company, Engineer) .filter(Company.company_id == Engineer.company_id) - .filter(Engineer.primary_language == 'java'), + .filter(Engineer.primary_language == "java"), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name " "FROM companies, 
people JOIN engineers " "ON people.person_id = engineers.person_id " "WHERE companies.company_id = people.company_id " "AND engineers.primary_language =" - " :primary_language_1", use_default_dialect=True) + " :primary_language_1", + use_default_dialect=True, + ) def test_single_prop_of_type(self): Company, Engineer = self.classes.Company, self.classes.Engineer @@ -232,19 +336,24 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "(people JOIN engineers " "ON people.person_id = engineers.person_id) " "ON companies.company_id = people.company_id", - use_default_dialect=True) + use_default_dialect=True, + ) def test_prop_with_polymorphic_1(self): - Person, Manager, Paperwork = (self.classes.Person, - self.classes.Manager, - self.classes.Paperwork) + Person, Manager, Paperwork = ( + self.classes.Person, + self.classes.Manager, + self.classes.Paperwork, + ) sess = create_session() self.assert_compile( - sess.query(Person).with_polymorphic(Manager). - order_by(Person.person_id).join('paperwork') - .filter(Paperwork.description.like('%review%')), + sess.query(Person) + .with_polymorphic(Manager) + .order_by(Person.person_id) + .join("paperwork") + .filter(Paperwork.description.like("%review%")), "SELECT people.person_id AS people_person_id, people.company_id AS" " people_company_id, " "people.name AS people_name, people.type AS people_type, " @@ -256,19 +365,25 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "JOIN paperwork " "ON people.person_id = paperwork.person_id " "WHERE paperwork.description LIKE :description_1 " - "ORDER BY people.person_id", use_default_dialect=True) + "ORDER BY people.person_id", + use_default_dialect=True, + ) def test_prop_with_polymorphic_2(self): - Person, Manager, Paperwork = (self.classes.Person, - self.classes.Manager, - self.classes.Paperwork) + Person, Manager, Paperwork = ( + self.classes.Person, + self.classes.Manager, + self.classes.Paperwork, + ) sess = create_session() self.assert_compile( - sess.query(Person).with_polymorphic(Manager). 
- order_by(Person.person_id).join('paperwork', aliased=True) - .filter(Paperwork.description.like('%review%')), + sess.query(Person) + .with_polymorphic(Manager) + .order_by(Person.person_id) + .join("paperwork", aliased=True) + .filter(Paperwork.description.like("%review%")), "SELECT people.person_id AS people_person_id, " "people.company_id AS people_company_id, " "people.name AS people_name, people.type AS people_type, " @@ -281,7 +396,8 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "ON people.person_id = paperwork_1.person_id " "WHERE paperwork_1.description " "LIKE :description_1 ORDER BY people.person_id", - use_default_dialect=True) + use_default_dialect=True, + ) def test_explicit_polymorphic_join_one(self): Company, Engineer = self.classes.Company, self.classes.Engineer @@ -289,8 +405,9 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): sess = create_session() self.assert_compile( - sess.query(Company).join(Engineer) - .filter(Engineer.engineer_name == 'vlad'), + sess.query(Company) + .join(Engineer) + .filter(Engineer.engineer_name == "vlad"), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name " "FROM companies JOIN (people JOIN engineers " @@ -298,7 +415,8 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "ON " "companies.company_id = people.company_id " "WHERE engineers.engineer_name = :engineer_name_1", - use_default_dialect=True) + use_default_dialect=True, + ) def test_explicit_polymorphic_join_two(self): Company, Engineer = self.classes.Company, self.classes.Engineer @@ -307,7 +425,7 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): self.assert_compile( sess.query(Company) .join(Engineer, Company.company_id == Engineer.company_id) - .filter(Engineer.engineer_name == 'vlad'), + .filter(Engineer.engineer_name == "vlad"), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name " "FROM companies JOIN " @@ -316,7 +434,8 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "ON " "companies.company_id = people.company_id " "WHERE engineers.engineer_name = :engineer_name_1", - use_default_dialect=True) + use_default_dialect=True, + ) def test_multiple_adaption(self): """test that multiple filter() adapters get chained together " @@ -328,28 +447,31 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): self.classes.Machine, self.tables.engineers, self.tables.machines, - self.classes.Engineer) + self.classes.Engineer, + ) sess = create_session() self.assert_compile( sess.query(Company) .join(people.join(engineers), Company.employees) - .filter(Engineer.name == 'dilbert'), + .filter(Engineer.name == "dilbert"), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name " "FROM companies JOIN (people " "JOIN engineers ON people.person_id = " "engineers.person_id) ON companies.company_id = " "people.company_id WHERE people.name = :name_1", - use_default_dialect=True + use_default_dialect=True, ) mach_alias = machines.select() self.assert_compile( - sess.query(Company).join(people.join(engineers), Company.employees) - .join(mach_alias, Engineer.machines, from_joinpoint=True). 
- filter(Engineer.name == 'dilbert').filter(Machine.name == 'foo'), + sess.query(Company) + .join(people.join(engineers), Company.employees) + .join(mach_alias, Engineer.machines, from_joinpoint=True) + .filter(Engineer.name == "dilbert") + .filter(Machine.name == "foo"), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name " "FROM companies JOIN (people " @@ -362,20 +484,25 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM machines) AS anon_1 " "ON engineers.person_id = anon_1.engineer_id " "WHERE people.name = :name_1 AND anon_1.name = :name_2", - use_default_dialect=True + use_default_dialect=True, ) def test_auto_aliasing_multi_link(self): # test [ticket:2903] sess = create_session() - Company, Engineer, Manager, Boss = self.classes.Company, \ - self.classes.Engineer, \ - self.classes.Manager, self.classes.Boss - q = sess.query(Company).\ - join(Company.employees.of_type(Engineer)).\ - join(Company.employees.of_type(Manager)).\ - join(Company.employees.of_type(Boss)) + Company, Engineer, Manager, Boss = ( + self.classes.Company, + self.classes.Engineer, + self.classes.Manager, + self.classes.Boss, + ) + q = ( + sess.query(Company) + .join(Company.employees.of_type(Engineer)) + .join(Company.employees.of_type(Manager)) + .join(Company.employees.of_type(Boss)) + ) self.assert_compile( q, @@ -391,21 +518,26 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): "ON people_2.person_id = managers_2.person_id JOIN boss AS boss_1 " "ON managers_2.person_id = boss_1.boss_id) " "ON companies.company_id = people_2.company_id", - use_default_dialect=True) + use_default_dialect=True, + ) class JoinOnSynonymTest(_fixtures.FixtureTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_mappers(cls): User = cls.classes.User Address = cls.classes.Address users, addresses = (cls.tables.users, cls.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address), - 'ad_syn': synonym("addresses") - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address), + "ad_syn": synonym("addresses"), + }, + ) mapper(Address, addresses) def test_join_on_synonym(self): @@ -413,12 +545,12 @@ class JoinOnSynonymTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.assert_compile( Session().query(User).join(User.ad_syn), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users JOIN addresses ON users.id = addresses.user_id" + "FROM users JOIN addresses ON users.id = addresses.user_id", ) class JoinTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_single_name(self): User = self.classes.User @@ -428,12 +560,11 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.assert_compile( sess.query(User).join("orders"), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users JOIN orders ON users.id = orders.user_id" + "FROM users JOIN orders ON users.id = orders.user_id", ) assert_raises( - sa_exc.InvalidRequestError, - sess.query(User).join, "user", + sa_exc.InvalidRequestError, sess.query(User).join, "user" ) self.assert_compile( @@ -442,21 +573,21 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN orders ON users.id = orders.user_id " "JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id JOIN items " - "ON items.id = order_items_1.item_id" + "ON items.id = order_items_1.item_id", ) # test overlapping paths. 
User->orders is used by both joins, but # rendered once. self.assert_compile( - sess.query(User).join("orders", "items").join( - "orders", "address"), + sess.query(User).join("orders", "items").join("orders", "address"), "SELECT users.id AS users_id, users.name AS users_name FROM users " "JOIN orders " "ON users.id = orders.user_id " "JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id " "JOIN items ON items.id = order_items_1.item_id JOIN addresses " - "ON addresses.id = orders.address_id") + "ON addresses.id = orders.address_id", + ) def test_invalid_kwarg_join(self): User = self.classes.User @@ -464,12 +595,18 @@ class JoinTest(QueryTest, AssertsCompiledSQL): assert_raises_message( TypeError, "unknown arguments: bar, foob", - sess.query(User).join, "address", foob="bar", bar="bat" + sess.query(User).join, + "address", + foob="bar", + bar="bat", ) assert_raises_message( TypeError, "unknown arguments: bar, foob", - sess.query(User).outerjoin, "address", foob="bar", bar="bat" + sess.query(User).outerjoin, + "address", + foob="bar", + bar="bat", ) def test_left_w_no_entity(self): @@ -479,15 +616,15 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess = create_session() self.assert_compile( - sess.query(User, literal_column('x'), ).join(Address), + sess.query(User, literal_column("x")).join(Address), "SELECT users.id AS users_id, users.name AS users_name, x " - "FROM users JOIN addresses ON users.id = addresses.user_id" + "FROM users JOIN addresses ON users.id = addresses.user_id", ) self.assert_compile( - sess.query(literal_column('x'), User).join(Address), + sess.query(literal_column("x"), User).join(Address), "SELECT x, users.id AS users_id, users.name AS users_name " - "FROM users JOIN addresses ON users.id = addresses.user_id" + "FROM users JOIN addresses ON users.id = addresses.user_id", ) def test_left_is_none_and_query_has_no_entities(self): @@ -500,25 +637,26 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sa_exc.InvalidRequestError, r"No entities to join from; please use select_from\(\) to " r"establish the left entity/selectable of this join", - sess.query().join, Address + sess.query().join, + Address, ) def test_isouter_flag(self): User = self.classes.User self.assert_compile( - create_session().query(User).join('orders', isouter=True), + create_session().query(User).join("orders", isouter=True), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id" + "FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id", ) def test_full_flag(self): User = self.classes.User self.assert_compile( - create_session().query(User).outerjoin('orders', full=True), + create_session().query(User).outerjoin("orders", full=True), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users FULL OUTER JOIN orders ON users.id = orders.user_id" + "FROM users FULL OUTER JOIN orders ON users.id = orders.user_id", ) def test_multi_tuple_form(self): @@ -529,9 +667,11 @@ class JoinTest(QueryTest, AssertsCompiledSQL): """ - Item, Order, User = (self.classes.Item, - self.classes.Order, - self.classes.User) + Item, Order, User = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + ) sess = create_session() @@ -548,8 +688,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.assert_compile( sess.query(User).join( - (Order, User.id == Order.user_id), - (Item, Order.items)), + (Order, User.id == Order.user_id), (Item, Order.items) + ), "SELECT users.id AS users_id, users.name AS 
users_name " "FROM users JOIN orders ON users.id = orders.user_id " "JOIN order_items AS order_items_1 ON orders.id = " @@ -565,36 +705,42 @@ class JoinTest(QueryTest, AssertsCompiledSQL): ) def test_single_prop_1(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() self.assert_compile( sess.query(User).join(User.orders), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users JOIN orders ON users.id = orders.user_id" + "FROM users JOIN orders ON users.id = orders.user_id", ) def test_single_prop_2(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() self.assert_compile( sess.query(User).join(Order.user), "SELECT users.id AS users_id, users.name AS users_name " - "FROM orders JOIN users ON users.id = orders.user_id" + "FROM orders JOIN users ON users.id = orders.user_id", ) def test_single_prop_3(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() oalias1 = aliased(Order) @@ -602,14 +748,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.assert_compile( sess.query(User).join(oalias1.user), "SELECT users.id AS users_id, users.name AS users_name " - "FROM orders AS orders_1 JOIN users ON users.id = orders_1.user_id" + "FROM orders AS orders_1 JOIN users ON users.id = orders_1.user_id", ) def test_single_prop_4(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() oalias1 = aliased(Order) @@ -621,13 +769,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name " "FROM orders AS orders_1 JOIN users " "ON users.id = orders_1.user_id, " - "orders AS orders_2 JOIN users ON users.id = orders_2.user_id") + "orders AS orders_2 JOIN users ON users.id = orders_2.user_id", + ) def test_single_prop_5(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() self.assert_compile( @@ -636,28 +787,32 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN orders ON users.id = orders.user_id " "JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id JOIN items " - "ON items.id = order_items_1.item_id" + "ON items.id = order_items_1.item_id", ) def test_single_prop_6(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() ualias = aliased(User) self.assert_compile( sess.query(ualias).join(ualias.orders), "SELECT users_1.id AS 
users_1_id, users_1.name AS users_1_name " - "FROM users AS users_1 JOIN orders ON users_1.id = orders.user_id" + "FROM users AS users_1 JOIN orders ON users_1.id = orders.user_id", ) def test_single_prop_7(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() # this query is somewhat nonsensical. the old system didn't render a @@ -671,13 +826,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN orders ON users.id = orders.user_id, " "orders AS orders_1 JOIN order_items AS order_items_1 " "ON orders_1.id = order_items_1.order_id " - "JOIN items ON items.id = order_items_1.item_id") + "JOIN items ON items.id = order_items_1.item_id", + ) def test_single_prop_8(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() # same as before using an aliased() for User as well @@ -690,101 +848,117 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN orders ON users_1.id = orders.user_id, " "orders AS orders_1 JOIN order_items AS order_items_1 " "ON orders_1.id = order_items_1.order_id " - "JOIN items ON items.id = order_items_1.item_id") + "JOIN items ON items.id = order_items_1.item_id", + ) def test_single_prop_9(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() self.assert_compile( - sess.query(User).filter(User.name == 'ed').from_self(). - join(User.orders), + sess.query(User) + .filter(User.name == "ed") + .from_self() + .join(User.orders), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name " "FROM (SELECT users.id AS users_id, users.name AS users_name " "FROM users " "WHERE users.name = :name_1) AS anon_1 JOIN orders " - "ON anon_1.users_id = orders.user_id" + "ON anon_1.users_id = orders.user_id", ) def test_single_prop_10(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() self.assert_compile( - sess.query(User).join(User.addresses, aliased=True). - filter(Address.email_address == 'foo'), + sess.query(User) + .join(User.addresses, aliased=True) + .filter(Address.email_address == "foo"), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN addresses AS addresses_1 " "ON users.id = addresses_1.user_id " - "WHERE addresses_1.email_address = :email_address_1" + "WHERE addresses_1.email_address = :email_address_1", ) def test_single_prop_11(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() self.assert_compile( - sess.query(User).join(User.orders, Order.items, aliased=True). 
- filter(Item.id == 10), + sess.query(User) + .join(User.orders, Order.items, aliased=True) + .filter(Item.id == 10), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN orders AS orders_1 " "ON users.id = orders_1.user_id " "JOIN order_items AS order_items_1 " "ON orders_1.id = order_items_1.order_id " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " - "WHERE items_1.id = :id_1") + "WHERE items_1.id = :id_1", + ) def test_single_prop_12(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() oalias1 = aliased(Order) # test #1 for [ticket:1706] ualias = aliased(User) self.assert_compile( - sess.query(ualias). - join(oalias1, ualias.orders). - join(Address, ualias.addresses), + sess.query(ualias) + .join(oalias1, ualias.orders) + .join(Address, ualias.addresses), "SELECT users_1.id AS users_1_id, users_1.name AS " "users_1_name FROM users AS users_1 JOIN orders AS orders_1 " "ON users_1.id = orders_1.user_id JOIN addresses ON users_1.id " - "= addresses.user_id" + "= addresses.user_id", ) def test_single_prop_13(self): - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() # test #2 for [ticket:1706] ualias = aliased(User) ualias2 = aliased(User) self.assert_compile( - sess.query(ualias). - join(Address, ualias.addresses). - join(ualias2, Address.user). - join(Order, ualias.orders), + sess.query(ualias) + .join(Address, ualias.addresses) + .join(ualias2, Address.user) + .join(Order, ualias.orders), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users " "AS users_1 JOIN addresses ON users_1.id = addresses.user_id " "JOIN users AS users_2 " "ON users_2.id = addresses.user_id JOIN orders " - "ON users_1.id = orders.user_id" + "ON users_1.id = orders.user_id", ) def test_overlapping_paths(self): @@ -793,22 +967,28 @@ class JoinTest(QueryTest, AssertsCompiledSQL): for aliased in (True, False): # load a user who has an order that contains item id 3 and address # id 1 (order 3, owned by jack) - result = create_session().query(User) \ - .join('orders', 'items', aliased=aliased) \ - .filter_by(id=3) \ - .join('orders', 'address', aliased=aliased) \ - .filter_by(id=1).all() - assert [User(id=7, name='jack')] == result + result = ( + create_session() + .query(User) + .join("orders", "items", aliased=aliased) + .filter_by(id=3) + .join("orders", "address", aliased=aliased) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result def test_overlapping_paths_multilevel(self): User = self.classes.User s = Session() - q = s.query(User).\ - join('orders').\ - join('addresses').\ - join('orders', 'items').\ - join('addresses', 'dingaling') + q = ( + s.query(User) + .join("orders") + .join("addresses") + .join("orders", "items") + .join("addresses", "dingaling") + ) self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name " @@ -817,17 +997,22 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN order_items AS order_items_1 ON orders.id = " "order_items_1.order_id " "JOIN items ON items.id = order_items_1.item_id " - "JOIN dingalings ON addresses.id = dingalings.address_id" - + "JOIN 
dingalings ON addresses.id = dingalings.address_id", ) def test_overlapping_paths_outerjoin(self): User = self.classes.User - result = create_session().query(User).outerjoin('orders', 'items') \ - .filter_by(id=3).outerjoin('orders', 'address') \ - .filter_by(id=1).all() - assert [User(id=7, name='jack')] == result + result = ( + create_session() + .query(User) + .outerjoin("orders", "items") + .filter_by(id=3) + .outerjoin("orders", "address") + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result def test_raises_on_dupe_target_rel(self): User = self.classes.User @@ -836,52 +1021,65 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sa.exc.SAWarning, "Pathed join target Order.items has already been joined to; " "skipping", - lambda: create_session().query(User).outerjoin('orders', 'items'). - outerjoin('orders', 'items') + lambda: create_session() + .query(User) + .outerjoin("orders", "items") + .outerjoin("orders", "items"), ) def test_from_joinpoint(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = create_session() for oalias, ialias in [ - (True, True), - (False, False), - (True, False), - (False, True)]: + (True, True), + (False, False), + (True, False), + (False, True), + ]: eq_( - sess.query(User).join('orders', aliased=oalias) - .join('items', from_joinpoint=True, aliased=ialias) - .filter(Item.description == 'item 4').all(), - [User(name='jack')] + sess.query(User) + .join("orders", aliased=oalias) + .join("items", from_joinpoint=True, aliased=ialias) + .filter(Item.description == "item 4") + .all(), + [User(name="jack")], ) # use middle criterion eq_( - sess.query(User).join('orders', aliased=oalias) + sess.query(User) + .join("orders", aliased=oalias) .filter(Order.user_id == 9) - .join('items', from_joinpoint=True, aliased=ialias) - .filter(Item.description == 'item 4').all(), - [] + .join("items", from_joinpoint=True, aliased=ialias) + .filter(Item.description == "item 4") + .all(), + [], ) orderalias = aliased(Order) itemalias = aliased(Item) eq_( - sess.query(User).join(orderalias, 'orders') - .join(itemalias, 'items', from_joinpoint=True) - .filter(itemalias.description == 'item 4').all(), - [User(name='jack')] + sess.query(User) + .join(orderalias, "orders") + .join(itemalias, "items", from_joinpoint=True) + .filter(itemalias.description == "item 4") + .all(), + [User(name="jack")], ) eq_( - sess.query(User).join(orderalias, 'orders') - .join(itemalias, 'items', from_joinpoint=True) + sess.query(User) + .join(orderalias, "orders") + .join(itemalias, "items", from_joinpoint=True) .filter(orderalias.user_id == 9) - .filter(itemalias.description == 'item 4').all(), - [] + .filter(itemalias.description == "item 4") + .all(), + [], ) def test_join_nonmapped_column(self): @@ -893,9 +1091,9 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # intentionally join() with a non-existent "left" side self.assert_compile( - sess.query(User.id, literal_column('foo')).join(Order.user), + sess.query(User.id, literal_column("foo")).join(Order.user), "SELECT users.id AS users_id, foo FROM " - "orders JOIN users ON users.id = orders.user_id" + "orders JOIN users ON users.id = orders.user_id", ) def test_backwards_join(self): @@ -907,17 +1105,21 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess = create_session() eq_( - sess.query(User).join(Address.user) - .filter(Address.email_address == 'ed@wood.com').all(), - [User(id=8, 
name='ed')] + sess.query(User) + .join(Address.user) + .filter(Address.email_address == "ed@wood.com") + .all(), + [User(id=8, name="ed")], ) # its actually not so controversial if you view it in terms # of multiple entities. eq_( - sess.query(User, Address).join(Address.user) - .filter(Address.email_address == 'ed@wood.com').all(), - [(User(id=8, name='ed'), Address(email_address='ed@wood.com'))] + sess.query(User, Address) + .join(Address.user) + .filter(Address.email_address == "ed@wood.com") + .all(), + [(User(id=8, name="ed"), Address(email_address="ed@wood.com"))], ) # this was the controversial part. now, raise an error if the feature @@ -925,14 +1127,18 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # before the error raise was added, this would silently work..... assert_raises( sa_exc.InvalidRequestError, - sess.query(User).join, Address, Address.user, + sess.query(User).join, + Address, + Address.user, ) # but this one would silently fail adalias = aliased(Address) assert_raises( sa_exc.InvalidRequestError, - sess.query(User).join, adalias, Address.user, + sess.query(User).join, + adalias, + Address.user, ) def test_multiple_with_aliases(self): @@ -944,7 +1150,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): oalias1 = aliased(Order) oalias2 = aliased(Order) self.assert_compile( - sess.query(ualias).join(oalias1, ualias.orders) + sess.query(ualias) + .join(oalias1, ualias.orders) .join(oalias2, ualias.orders) .filter(or_(oalias1.user_id == 9, oalias2.user_id == 7)), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " @@ -954,7 +1161,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "users_1.id = orders_2.user_id " "WHERE orders_1.user_id = :user_id_1 " "OR orders_2.user_id = :user_id_2", - use_default_dialect=True) + use_default_dialect=True, + ) def test_select_from_orm_joins(self): User, Order = self.classes.User, self.classes.Order @@ -968,36 +1176,42 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.assert_compile( join(User, oalias2, User.id == oalias2.user_id), "users JOIN orders AS orders_1 ON users.id = orders_1.user_id", - use_default_dialect=True + use_default_dialect=True, ) self.assert_compile( join(ualias, oalias1, ualias.orders), "users AS users_1 JOIN orders AS orders_1 " "ON users_1.id = orders_1.user_id", - use_default_dialect=True) + use_default_dialect=True, + ) self.assert_compile( sess.query(ualias).select_from( - join(ualias, oalias1, ualias.orders)), + join(ualias, oalias1, ualias.orders) + ), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users AS users_1 " "JOIN orders AS orders_1 ON users_1.id = orders_1.user_id", - use_default_dialect=True) + use_default_dialect=True, + ) self.assert_compile( sess.query(User, ualias).select_from( - join(ualias, oalias1, ualias.orders)), + join(ualias, oalias1, ualias.orders) + ), "SELECT users.id AS users_id, users.name AS users_name, " "users_1.id AS users_1_id, " "users_1.name AS users_1_name FROM users, users AS users_1 " "JOIN orders AS orders_1 ON users_1.id = orders_1.user_id", - use_default_dialect=True) + use_default_dialect=True, + ) # this fails (and we cant quite fix right now). 
if False: self.assert_compile( - sess.query(User, ualias).join(oalias1, ualias.orders) + sess.query(User, ualias) + .join(oalias1, ualias.orders) .join(oalias2, User.id == oalias2.user_id) .filter(or_(oalias1.user_id == 9, oalias2.user_id == 7)), "SELECT users.id AS users_id, users.name AS users_name, " @@ -1008,14 +1222,17 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "ON users_1.id = orders_1.user_id " "WHERE orders_1.user_id = :user_id_1 " "OR orders_2.user_id = :user_id_2", - use_default_dialect=True) + use_default_dialect=True, + ) # this is the same thing using explicit orm.join() (which now offers # multiple again) self.assert_compile( - sess.query(User, ualias).select_from( + sess.query(User, ualias) + .select_from( join(ualias, oalias1, ualias.orders), - join(User, oalias2, User.id == oalias2.user_id),) + join(User, oalias2, User.id == oalias2.user_id), + ) .filter(or_(oalias1.user_id == 9, oalias2.user_id == 7)), "SELECT users.id AS users_id, users.name AS users_name, " "users_1.id AS users_1_id, users_1.name AS " @@ -1024,7 +1241,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "users JOIN orders AS orders_2 ON users.id = orders_2.user_id " "WHERE orders_1.user_id = :user_id_1 " "OR orders_2.user_id = :user_id_2", - use_default_dialect=True) + use_default_dialect=True, + ) def test_overlapping_backwards_joins(self): User, Order = self.classes.User, self.classes.Order @@ -1042,20 +1260,24 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "FROM orders AS orders_1 " "JOIN users ON users.id = orders_1.user_id, orders AS orders_2 " "JOIN users ON users.id = orders_2.user_id", - use_default_dialect=True,) + use_default_dialect=True, + ) def test_replace_multiple_from_clause(self): """test adding joins onto multiple FROM clauses""" - User, Order, Address = (self.classes.User, - self.classes.Order, - self.classes.Address) + User, Order, Address = ( + self.classes.User, + self.classes.Order, + self.classes.Address, + ) sess = create_session() self.assert_compile( sess.query(Address, User) - .join(Address.dingaling).join(User.orders, Order.items), + .join(Address.dingaling) + .join(User.orders, Order.items), "SELECT addresses.id AS addresses_id, " "addresses.user_id AS addresses_user_id, " "addresses.email_address AS addresses_email_address, " @@ -1066,12 +1288,11 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id JOIN items " "ON items.id = order_items_1.item_id", - use_default_dialect=True + use_default_dialect=True, ) def test_invalid_join_entity_from_single_from_clause(self): - Address, Item = ( - self.classes.Address, self.classes.Item) + Address, Item = (self.classes.Address, self.classes.Item) sess = create_session() q = sess.query(Address).select_from(Address) @@ -1081,12 +1302,12 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "Don't know how to join to .*Item.*; " "please use an ON clause to more clearly establish the " "left side of this join", - q.join, Item + q.join, + Item, ) def test_invalid_join_entity_from_no_from_clause(self): - Address, Item = ( - self.classes.Address, self.classes.Item) + Address, Item = (self.classes.Address, self.classes.Item) sess = create_session() q = sess.query(Address) @@ -1096,7 +1317,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "Don't know how to join to .*Item.*; " "please use an ON clause to more clearly establish the " "left side of this join", - q.join, Item + q.join, + Item, ) def test_invalid_join_entity_from_multiple_from_clause(self): @@ -1104,18 
+1326,21 @@ class JoinTest(QueryTest, AssertsCompiledSQL): we still need to say there's nothing to JOIN from""" User, Address, Item = ( - self.classes.User, self.classes.Address, self.classes.Item) + self.classes.User, + self.classes.Address, + self.classes.Item, + ) sess = create_session() - q = sess.query(Address, User).join(Address.dingaling).\ - join(User.orders) + q = sess.query(Address, User).join(Address.dingaling).join(User.orders) assert_raises_message( sa.exc.InvalidRequestError, "Don't know how to join to .*Item.*; " "please use an ON clause to more clearly establish the " "left side of this join", - q.join, Item + q.join, + Item, ) def test_join_explicit_left_multiple_from_clause(self): @@ -1133,28 +1358,24 @@ class JoinTest(QueryTest, AssertsCompiledSQL): # is users, the other is u1_alias. # User.addresses looks for the "users" table and can match # to both u1_alias and users if the match is not specific enough - q = sess.query(User, u1).\ - select_from(User, u1).\ - join(User.addresses) + q = sess.query(User, u1).select_from(User, u1).join(User.addresses) self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name, " "users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users AS users_1, " - "users JOIN addresses ON users.id = addresses.user_id" + "users JOIN addresses ON users.id = addresses.user_id", ) - q = sess.query(User, u1).\ - select_from(User, u1).\ - join(u1.addresses) + q = sess.query(User, u1).select_from(User, u1).join(u1.addresses) self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name, " "users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users, " - "users AS users_1 JOIN addresses ON users_1.id = addresses.user_id" + "users AS users_1 JOIN addresses ON users_1.id = addresses.user_id", ) def test_join_explicit_left_multiple_adapted(self): @@ -1178,7 +1399,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "Can't identify which entity in which to assign the " "left side of this join.", sess.query(u1, u2).select_from(u1, u2).join, - User.addresses + User.addresses, ) # more specific ON clause @@ -1187,7 +1408,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name, " "users_2.id AS users_2_id, users_2.name AS users_2_name " "FROM users AS users_1, " - "users AS users_2 JOIN addresses ON users_2.id = addresses.user_id" + "users AS users_2 JOIN addresses ON users_2.id = addresses.user_id", ) def test_join_entity_from_multiple_from_clause(self): @@ -1198,12 +1419,12 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.classes.User, self.classes.Order, self.classes.Address, - self.classes.Dingaling) + self.classes.Dingaling, + ) sess = create_session() - q = sess.query(Address, User).join(Address.dingaling).\ - join(User.orders) + q = sess.query(Address, User).join(Address.dingaling).join(User.orders) a1 = aliased(Address) @@ -1212,7 +1433,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "Can't determine which FROM clause to join from, there are " "multiple FROMS which can join to this entity. 
" "Try adding an explicit ON clause to help resolve the ambiguity.", - q.join, a1 + q.join, + a1, ) # to resolve, add an ON clause @@ -1229,7 +1451,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "users JOIN orders " "ON users.id = orders.user_id " "JOIN addresses AS addresses_1 " - "ON orders.address_id = addresses_1.id" + "ON orders.address_id = addresses_1.id", ) # the address->dingalings join is chosen to join to a1 @@ -1243,7 +1465,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "ON addresses.id = dingalings.address_id " "JOIN addresses AS addresses_1 " "ON dingalings.address_id = addresses_1.id, " - "users JOIN orders ON users.id = orders.user_id" + "users JOIN orders ON users.id = orders.user_id", ) def test_join_entity_from_multiple_entities(self): @@ -1253,7 +1475,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): Order, Address, Dingaling = ( self.classes.Order, self.classes.Address, - self.classes.Dingaling) + self.classes.Dingaling, + ) sess = create_session() @@ -1266,7 +1489,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "Can't determine which FROM clause to join from, there are " "multiple FROMS which can join to this entity. " "Try adding an explicit ON clause to help resolve the ambiguity.", - q.join, a1 + q.join, + a1, ) # to resolve, add an ON clause @@ -1282,7 +1506,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "dingalings.data AS dingalings_data " "FROM dingalings, orders " "JOIN addresses AS addresses_1 " - "ON orders.address_id = addresses_1.id" + "ON orders.address_id = addresses_1.id", ) # Dingaling is chosen to join to a1 @@ -1295,61 +1519,77 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "dingalings.address_id AS dingalings_address_id, " "dingalings.data AS dingalings_data " "FROM orders, dingalings JOIN addresses AS addresses_1 " - "ON dingalings.address_id = addresses_1.id" + "ON dingalings.address_id = addresses_1.id", ) def test_multiple_adaption(self): - Item, Order, User = (self.classes.Item, - self.classes.Order, - self.classes.User) + Item, Order, User = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + ) sess = create_session() self.assert_compile( - sess.query(User).join(User.orders, Order.items, aliased=True) - .filter(Order.id == 7).filter(Item.id == 8), + sess.query(User) + .join(User.orders, Order.items, aliased=True) + .filter(Order.id == 7) + .filter(Item.id == 8), "SELECT users.id AS users_id, users.name AS users_name FROM users " "JOIN orders AS orders_1 " "ON users.id = orders_1.user_id JOIN order_items AS order_items_1 " "ON orders_1.id = order_items_1.order_id " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " "WHERE orders_1.id = :id_1 AND items_1.id = :id_2", - use_default_dialect=True + use_default_dialect=True, ) def test_onclause_conditional_adaption(self): - Item, Order, orders, order_items, User = (self.classes.Item, - self.classes.Order, - self.tables.orders, - self.tables.order_items, - self.classes.User) + Item, Order, orders, order_items, User = ( + self.classes.Item, + self.classes.Order, + self.tables.orders, + self.tables.order_items, + self.classes.User, + ) sess = create_session() # this is now a very weird test, nobody should really # be using the aliased flag in this way. self.assert_compile( - sess.query(User).join(User.orders, aliased=True). 
- join(Item, - and_(Order.id == order_items.c.order_id, - order_items.c.item_id == Item.id), - from_joinpoint=True, aliased=True), + sess.query(User) + .join(User.orders, aliased=True) + .join( + Item, + and_( + Order.id == order_items.c.order_id, + order_items.c.item_id == Item.id, + ), + from_joinpoint=True, + aliased=True, + ), "SELECT users.id AS users_id, users.name AS users_name FROM users " "JOIN orders AS orders_1 ON users.id = orders_1.user_id " "JOIN items AS items_1 " "ON orders_1.id = order_items.order_id " "AND order_items.item_id = items_1.id", - use_default_dialect=True + use_default_dialect=True, ) oalias = orders.select() self.assert_compile( - sess.query(User).join(oalias, User.orders) - .join(Item, - and_( - Order.id == order_items.c.order_id, - order_items.c.item_id == Item.id), - from_joinpoint=True), + sess.query(User) + .join(oalias, User.orders) + .join( + Item, + and_( + Order.id == order_items.c.order_id, + order_items.c.item_id == Item.id, + ), + from_joinpoint=True, + ), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN " "(SELECT orders.id AS id, orders.user_id AS user_id, " @@ -1358,7 +1598,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "ON users.id = anon_1.user_id JOIN items " "ON anon_1.id = order_items.order_id " "AND order_items.item_id = items.id", - use_default_dialect=True) + use_default_dialect=True, + ) # query.join(<stuff>, aliased=True).join(target, sql_expression) # or: query.join(path_to_some_joined_table_mapper).join(target, # sql_expression) @@ -1372,98 +1613,125 @@ class JoinTest(QueryTest, AssertsCompiledSQL): self.assert_compile( sess.query(users).join(addresses), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users JOIN addresses ON users.id = addresses.user_id" + "FROM users JOIN addresses ON users.id = addresses.user_id", ) def test_orderby_arg_bug(self): - User, users, Order = (self.classes.User, - self.tables.users, - self.classes.Order) + User, users, Order = ( + self.classes.User, + self.tables.users, + self.classes.Order, + ) sess = create_session() # no arg error - result = sess.query(User).join('orders', aliased=True) \ - .order_by(Order.id).reset_joinpoint().order_by(users.c.id).all() + result = ( + sess.query(User) + .join("orders", aliased=True) + .order_by(Order.id) + .reset_joinpoint() + .order_by(users.c.id) + .all() + ) def test_no_onclause(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = create_session() eq_( - sess.query(User).select_from(join(User, Order) - .join(Item, Order.items)) - .filter(Item.description == 'item 4').all(), - [User(name='jack')] + sess.query(User) + .select_from(join(User, Order).join(Item, Order.items)) + .filter(Item.description == "item 4") + .all(), + [User(name="jack")], ) eq_( - sess.query(User.name).select_from(join(User, Order) - .join(Item, Order.items)) - .filter(Item.description == 'item 4').all(), - [('jack',)] + sess.query(User.name) + .select_from(join(User, Order).join(Item, Order.items)) + .filter(Item.description == "item 4") + .all(), + [("jack",)], ) eq_( - sess.query(User).join(Order).join(Item, Order.items) - .filter(Item.description == 'item 4').all(), - [User(name='jack')] + sess.query(User) + .join(Order) + .join(Item, Order.items) + .filter(Item.description == "item 4") + .all(), + [User(name="jack")], ) def test_clause_onclause(self): - Item, Order, users, order_items, User = (self.classes.Item, -
self.classes.Order, - self.tables.users, - self.tables.order_items, - self.classes.User) + Item, Order, users, order_items, User = ( + self.classes.Item, + self.classes.Order, + self.tables.users, + self.tables.order_items, + self.classes.User, + ) sess = create_session() eq_( - sess.query(User).join(Order, User.id == Order.user_id) + sess.query(User) + .join(Order, User.id == Order.user_id) .join(order_items, Order.id == order_items.c.order_id) .join(Item, order_items.c.item_id == Item.id) - .filter(Item.description == 'item 4').all(), - [User(name='jack')] + .filter(Item.description == "item 4") + .all(), + [User(name="jack")], ) eq_( - sess.query(User.name).join(Order, User.id == Order.user_id) + sess.query(User.name) + .join(Order, User.id == Order.user_id) .join(order_items, Order.id == order_items.c.order_id) .join(Item, order_items.c.item_id == Item.id) - .filter(Item.description == 'item 4').all(), - [('jack',)] + .filter(Item.description == "item 4") + .all(), + [("jack",)], ) ualias = aliased(User) eq_( - sess.query(ualias.name).join(Order, ualias.id == Order.user_id) + sess.query(ualias.name) + .join(Order, ualias.id == Order.user_id) .join(order_items, Order.id == order_items.c.order_id) .join(Item, order_items.c.item_id == Item.id) - .filter(Item.description == 'item 4').all(), - [('jack',)] + .filter(Item.description == "item 4") + .all(), + [("jack",)], ) # explicit onclause with from_self(), means # the onclause must be aliased against the query's custom # FROM object eq_( - sess.query(User).order_by(User.id).offset(2) + sess.query(User) + .order_by(User.id) + .offset(2) .from_self() .join(Order, User.id == Order.user_id) .all(), - [User(name='fred')] + [User(name="fred")], ) # same with an explicit select_from() eq_( - sess.query(User).select_entity_from(select([users]) - .order_by(User.id) - .offset(2).alias()) - .join(Order, User.id == Order.user_id).all(), - [User(name='fred')] + sess.query(User) + .select_entity_from( + select([users]).order_by(User.id).offset(2).alias() + ) + .join(Order, User.id == Order.user_id) + .all(), + [User(name="fred")], ) def test_aliased_classes(self): @@ -1472,14 +1740,17 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() - (address1, address2, address3, address4, address5) = sess \ - .query(Address).all() - expected = [(user7, address1), - (user8, address2), - (user8, address3), - (user8, address4), - (user9, address5), - (user10, None)] + (address1, address2, address3, address4, address5) = sess.query( + Address + ).all() + expected = [ + (user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None), + ] q = sess.query(User) AdAlias = aliased(Address) @@ -1490,32 +1761,50 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.expunge_all() q = sess.query(User).add_entity(AdAlias) - result = q.select_from(outerjoin(User, AdAlias)) \ - .filter(AdAlias.email_address == 'ed@bettyboop.com').all() + result = ( + q.select_from(outerjoin(User, AdAlias)) + .filter(AdAlias.email_address == "ed@bettyboop.com") + .all() + ) eq_(result, [(user8, address3)]) - result = q.select_from(outerjoin(User, AdAlias, 'addresses')) \ - .filter(AdAlias.email_address == 'ed@bettyboop.com').all() + result = ( + q.select_from(outerjoin(User, AdAlias, "addresses")) + .filter(AdAlias.email_address == "ed@bettyboop.com") + .all() + ) eq_(result, [(user8, address3)]) - result = q.select_from( - outerjoin(User, AdAlias, User.id == 
AdAlias.user_id)).filter( - AdAlias.email_address == 'ed@bettyboop.com').all() + result = ( + q.select_from(outerjoin(User, AdAlias, User.id == AdAlias.user_id)) + .filter(AdAlias.email_address == "ed@bettyboop.com") + .all() + ) eq_(result, [(user8, address3)]) # this is the first test where we are joining "backwards" - from # AdAlias to User even though # the query is against User q = sess.query(User, AdAlias) - result = q.join(AdAlias.user) \ - .filter(User.name == 'ed').order_by(User.id, AdAlias.id) - eq_(result.all(), [(user8, address2), - (user8, address3), (user8, address4), ]) + result = ( + q.join(AdAlias.user) + .filter(User.name == "ed") + .order_by(User.id, AdAlias.id) + ) + eq_( + result.all(), + [(user8, address2), (user8, address3), (user8, address4)], + ) - q = sess.query(User, AdAlias).select_from( - join(AdAlias, User, AdAlias.user)).filter(User.name == 'ed') - eq_(result.all(), [(user8, address2), - (user8, address3), (user8, address4), ]) + q = ( + sess.query(User, AdAlias) + .select_from(join(AdAlias, User, AdAlias.user)) + .filter(User.name == "ed") + ) + eq_( + result.all(), + [(user8, address2), (user8, address3), (user8, address4)], + ) def test_expression_onclauses(self): Order, User = self.classes.Order, self.classes.User @@ -1529,7 +1818,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN (SELECT users.id AS id, users.name " "AS name FROM users) AS anon_1 ON users.name = anon_1.name", - use_default_dialect=True + use_default_dialect=True, ) subq = sess.query(Order).subquery() @@ -1540,51 +1829,100 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "orders.address_id AS address_id, orders.description AS " "description, orders.isopen AS isopen FROM orders) AS " "anon_1 ON users.id = anon_1.user_id", - use_default_dialect=True + use_default_dialect=True, ) self.assert_compile( sess.query(User).join(Order, User.id == Order.user_id), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN orders ON users.id = orders.user_id", - use_default_dialect=True + use_default_dialect=True, ) def test_implicit_joins_from_aliases(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = create_session() OrderAlias = aliased(Order) - eq_(sess.query(OrderAlias).join('items') - .filter_by(description='item 3').order_by(OrderAlias.id).all(), + eq_( + sess.query(OrderAlias) + .join("items") + .filter_by(description="item 3") + .order_by(OrderAlias.id) + .all(), + [ + Order( + address_id=1, + description="order 1", + isopen=0, + user_id=7, + id=1, + ), + Order( + address_id=4, + description="order 2", + isopen=0, + user_id=9, + id=2, + ), + Order( + address_id=1, + description="order 3", + isopen=1, + user_id=7, + id=3, + ), + ], + ) + + eq_( + sess.query(User, OrderAlias, Item.description) + .join(OrderAlias, "orders") + .join("items", from_joinpoint=True) + .filter_by(description="item 3") + .order_by(User.id, OrderAlias.id) + .all(), [ - Order(address_id=1, description='order 1', isopen=0, user_id=7, - id=1), - Order(address_id=4, description='order 2', isopen=0, user_id=9, - id=2), - Order(address_id=1, description='order 3', isopen=1, user_id=7, - id=3) - ]) - - eq_(sess.query(User, OrderAlias, Item.description). - join(OrderAlias, 'orders').join('items', from_joinpoint=True). - filter_by(description='item 3').order_by(User.id, OrderAlias.id). 
- all(), - [(User(name='jack', id=7), - Order(address_id=1, description='order 1', isopen=0, user_id=7, - id=1), - 'item 3'), - (User(name='jack', id=7), - Order(address_id=1, description='order 3', isopen=1, user_id=7, - id=3), - 'item 3'), - (User(name='fred', id=9), - Order(address_id=4, description='order 2', isopen=0, user_id=9, - id=2), - 'item 3')]) + ( + User(name="jack", id=7), + Order( + address_id=1, + description="order 1", + isopen=0, + user_id=7, + id=1, + ), + "item 3", + ), + ( + User(name="jack", id=7), + Order( + address_id=1, + description="order 3", + isopen=1, + user_id=7, + id=3, + ), + "item 3", + ), + ( + User(name="fred", id=9), + Order( + address_id=4, + description="order 2", + isopen=0, + user_id=9, + id=2, + ), + "item 3", + ), + ], + ) def test_aliased_classes_m2m(self): Item, Order = self.classes.Item, self.classes.Order @@ -1609,22 +1947,22 @@ class JoinTest(QueryTest, AssertsCompiledSQL): ] q = sess.query(Order) - q = q.add_entity(Item).select_from( - join(Order, Item, 'items')).order_by(Order.id, Item.id) + q = ( + q.add_entity(Item) + .select_from(join(Order, Item, "items")) + .order_by(Order.id, Item.id) + ) result = q.all() eq_(result, expected) IAlias = aliased(Item) - q = sess.query(Order, IAlias).select_from( - join(Order, IAlias, 'items')) \ - .filter(IAlias.description == 'item 3') + q = ( + sess.query(Order, IAlias) + .select_from(join(Order, IAlias, "items")) + .filter(IAlias.description == "item 3") + ) result = q.all() - eq_(result, - [ - (order1, item3), - (order2, item3), - (order3, item3), - ]) + eq_(result, [(order1, item3), (order2, item3), (order3, item3)]) def test_joins_from_adapted_entities(self): User = self.classes.User @@ -1638,17 +1976,19 @@ class JoinTest(QueryTest, AssertsCompiledSQL): subquery = session.query(User.id).subquery() join = subquery, subquery.c.id == User.id joined = unioned.outerjoin(*join) - self.assert_compile(joined, - 'SELECT anon_1.users_id AS ' - 'anon_1_users_id, anon_1.users_name AS ' - 'anon_1_users_name FROM (SELECT users.id ' - 'AS users_id, users.name AS users_name ' - 'FROM users UNION SELECT users.id AS ' - 'users_id, users.name AS users_name FROM ' - 'users) AS anon_1 LEFT OUTER JOIN (SELECT ' - 'users.id AS id FROM users) AS anon_2 ON ' - 'anon_2.id = anon_1.users_id', - use_default_dialect=True) + self.assert_compile( + joined, + "SELECT anon_1.users_id AS " + "anon_1_users_id, anon_1.users_name AS " + "anon_1_users_name FROM (SELECT users.id " + "AS users_id, users.name AS users_name " + "FROM users UNION SELECT users.id AS " + "users_id, users.name AS users_name FROM " + "users) AS anon_1 LEFT OUTER JOIN (SELECT " + "users.id AS id FROM users) AS anon_2 ON " + "anon_2.id = anon_1.users_id", + use_default_dialect=True, + ) first = session.query(User.id) second = session.query(User.id) @@ -1656,14 +1996,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): subquery = session.query(User.id).subquery() join = subquery, subquery.c.id == User.id joined = unioned.outerjoin(*join) - self.assert_compile(joined, - 'SELECT anon_1.users_id AS anon_1_users_id ' - 'FROM (SELECT users.id AS users_id FROM ' - 'users UNION SELECT users.id AS users_id ' - 'FROM users) AS anon_1 LEFT OUTER JOIN ' - '(SELECT users.id AS id FROM users) AS ' - 'anon_2 ON anon_2.id = anon_1.users_id', - use_default_dialect=True) + self.assert_compile( + joined, + "SELECT anon_1.users_id AS anon_1_users_id " + "FROM (SELECT users.id AS users_id FROM " + "users UNION SELECT users.id AS users_id " + "FROM users) AS anon_1 LEFT OUTER JOIN " 
+ "(SELECT users.id AS id FROM users) AS " + "anon_2 ON anon_2.id = anon_1.users_id", + use_default_dialect=True, + ) def test_joins_from_adapted_entities_isouter(self): User = self.classes.User @@ -1677,17 +2019,19 @@ class JoinTest(QueryTest, AssertsCompiledSQL): subquery = session.query(User.id).subquery() join = subquery, subquery.c.id == User.id joined = unioned.join(*join, isouter=True) - self.assert_compile(joined, - 'SELECT anon_1.users_id AS ' - 'anon_1_users_id, anon_1.users_name AS ' - 'anon_1_users_name FROM (SELECT users.id ' - 'AS users_id, users.name AS users_name ' - 'FROM users UNION SELECT users.id AS ' - 'users_id, users.name AS users_name FROM ' - 'users) AS anon_1 LEFT OUTER JOIN (SELECT ' - 'users.id AS id FROM users) AS anon_2 ON ' - 'anon_2.id = anon_1.users_id', - use_default_dialect=True) + self.assert_compile( + joined, + "SELECT anon_1.users_id AS " + "anon_1_users_id, anon_1.users_name AS " + "anon_1_users_name FROM (SELECT users.id " + "AS users_id, users.name AS users_name " + "FROM users UNION SELECT users.id AS " + "users_id, users.name AS users_name FROM " + "users) AS anon_1 LEFT OUTER JOIN (SELECT " + "users.id AS id FROM users) AS anon_2 ON " + "anon_2.id = anon_1.users_id", + use_default_dialect=True, + ) first = session.query(User.id) second = session.query(User.id) @@ -1695,14 +2039,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): subquery = session.query(User.id).subquery() join = subquery, subquery.c.id == User.id joined = unioned.join(*join, isouter=True) - self.assert_compile(joined, - 'SELECT anon_1.users_id AS anon_1_users_id ' - 'FROM (SELECT users.id AS users_id FROM ' - 'users UNION SELECT users.id AS users_id ' - 'FROM users) AS anon_1 LEFT OUTER JOIN ' - '(SELECT users.id AS id FROM users) AS ' - 'anon_2 ON anon_2.id = anon_1.users_id', - use_default_dialect=True) + self.assert_compile( + joined, + "SELECT anon_1.users_id AS anon_1_users_id " + "FROM (SELECT users.id AS users_id FROM " + "users UNION SELECT users.id AS users_id " + "FROM users) AS anon_1 LEFT OUTER JOIN " + "(SELECT users.id AS id FROM users) AS " + "anon_2 ON anon_2.id = anon_1.users_id", + use_default_dialect=True, + ) def test_reset_joinpoint(self): User = self.classes.User @@ -1710,100 +2056,158 @@ class JoinTest(QueryTest, AssertsCompiledSQL): for aliased in (True, False): # load a user who has an order that contains item id 3 and address # id 1 (order 3, owned by jack) - result = create_session().query(User) \ - .join('orders', 'items', aliased=aliased) \ - .filter_by(id=3).reset_joinpoint() \ - .join('orders', 'address', aliased=aliased) \ - .filter_by(id=1).all() - assert [User(id=7, name='jack')] == result - - result = create_session().query(User) \ - .join('orders', 'items', aliased=aliased, isouter=True) \ - .filter_by(id=3).reset_joinpoint() \ - .join('orders', 'address', aliased=aliased, isouter=True) \ - .filter_by(id=1).all() - assert [User(id=7, name='jack')] == result - - result = create_session().query(User).outerjoin( - 'orders', 'items', aliased=aliased).filter_by( - id=3).reset_joinpoint().outerjoin( - 'orders', 'address', aliased=aliased).filter_by( - id=1).all() - assert [User(id=7, name='jack')] == result + result = ( + create_session() + .query(User) + .join("orders", "items", aliased=aliased) + .filter_by(id=3) + .reset_joinpoint() + .join("orders", "address", aliased=aliased) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result + + result = ( + create_session() + .query(User) + .join("orders", "items", aliased=aliased, 
isouter=True) + .filter_by(id=3) + .reset_joinpoint() + .join("orders", "address", aliased=aliased, isouter=True) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result + + result = ( + create_session() + .query(User) + .outerjoin("orders", "items", aliased=aliased) + .filter_by(id=3) + .reset_joinpoint() + .outerjoin("orders", "address", aliased=aliased) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result def test_overlap_with_aliases(self): - orders, User, users = (self.tables.orders, - self.classes.User, - self.tables.users) - - oalias = orders.alias('oalias') + orders, User, users = ( + self.tables.orders, + self.classes.User, + self.tables.users, + ) - result = create_session().query(User).select_from(users.join(oalias)) \ - .filter(oalias.c.description.in_( - ["order 1", "order 2", "order 3"])) \ - .join('orders', 'items').order_by(User.id).all() - assert [User(id=7, name='jack'), User(id=9, name='fred')] == result + oalias = orders.alias("oalias") - result = create_session().query(User).select_from(users.join(oalias)) \ - .filter(oalias.c.description.in_( - ["order 1", "order 2", "order 3"])) \ - .join('orders', 'items').filter_by(id=4).all() - assert [User(id=7, name='jack')] == result + result = ( + create_session() + .query(User) + .select_from(users.join(oalias)) + .filter( + oalias.c.description.in_(["order 1", "order 2", "order 3"]) + ) + .join("orders", "items") + .order_by(User.id) + .all() + ) + assert [User(id=7, name="jack"), User(id=9, name="fred")] == result + + result = ( + create_session() + .query(User) + .select_from(users.join(oalias)) + .filter( + oalias.c.description.in_(["order 1", "order 2", "order 3"]) + ) + .join("orders", "items") + .filter_by(id=4) + .all() + ) + assert [User(id=7, name="jack")] == result def test_aliased(self): """test automatic generation of aliased joins.""" - Item, Order, User, Address = (self.classes.Item, - self.classes.Order, - self.classes.User, - self.classes.Address) + Item, Order, User, Address = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() # test a basic aliasized path - q = sess.query(User).join('addresses', aliased=True).filter_by( - email_address='jack@bean.com') + q = ( + sess.query(User) + .join("addresses", aliased=True) + .filter_by(email_address="jack@bean.com") + ) assert [User(id=7)] == q.all() - q = sess.query(User).join('addresses', aliased=True).filter( - Address.email_address == 'jack@bean.com') + q = ( + sess.query(User) + .join("addresses", aliased=True) + .filter(Address.email_address == "jack@bean.com") + ) assert [User(id=7)] == q.all() - q = sess.query(User).join('addresses', aliased=True).filter(or_( - Address.email_address == 'jack@bean.com', - Address.email_address == 'fred@fred.com')) + q = ( + sess.query(User) + .join("addresses", aliased=True) + .filter( + or_( + Address.email_address == "jack@bean.com", + Address.email_address == "fred@fred.com", + ) + ) + ) assert [User(id=7), User(id=9)] == q.all() # test two aliasized paths, one to 'orders' and the other to # 'orders','items'. one row is returned because user 7 has order 3 and # also has order 1 which has item 1 # this tests a o2m join and a m2m join. 
- q = sess.query(User).join('orders', aliased=True) \ - .filter(Order.description == "order 3") \ - .join('orders', 'items', aliased=True) \ + q = ( + sess.query(User) + .join("orders", aliased=True) + .filter(Order.description == "order 3") + .join("orders", "items", aliased=True) .filter(Item.description == "item 1") + ) assert q.count() == 1 assert [User(id=7)] == q.all() # test the control version - same joins but not aliased. rows are not # returned because order 3 does not have item 1 - q = sess.query(User).join('orders').filter( - Order.description == "order 3").join( - 'orders', 'items').filter( - Item.description == "item 1") + q = ( + sess.query(User) + .join("orders") + .filter(Order.description == "order 3") + .join("orders", "items") + .filter(Item.description == "item 1") + ) assert [] == q.all() assert q.count() == 0 # the left half of the join condition of the any() is aliased. - q = sess.query(User).join('orders', aliased=True).filter( - Order.items.any(Item.description == 'item 4')) + q = ( + sess.query(User) + .join("orders", aliased=True) + .filter(Order.items.any(Item.description == "item 4")) + ) assert [User(id=7)] == q.all() # test that aliasing gets reset when join() is called - q = sess.query(User).join('orders', aliased=True) \ - .filter(Order.description == "order 3") \ - .join('orders', aliased=True) \ + q = ( + sess.query(User) + .join("orders", aliased=True) + .filter(Order.description == "order 3") + .join("orders", aliased=True) .filter(Order.description == "order 5") + ) assert q.count() == 1 assert [User(id=7)] == q.all() @@ -1814,16 +2218,18 @@ class JoinTest(QueryTest, AssertsCompiledSQL): ualias = aliased(User) eq_( - sess.query(User, ualias).filter(User.id > ualias.id) - .order_by(desc(ualias.id), User.name).all(), + sess.query(User, ualias) + .filter(User.id > ualias.id) + .order_by(desc(ualias.id), User.name) + .all(), [ - (User(id=10, name='chuck'), User(id=9, name='fred')), - (User(id=10, name='chuck'), User(id=8, name='ed')), - (User(id=9, name='fred'), User(id=8, name='ed')), - (User(id=10, name='chuck'), User(id=7, name='jack')), - (User(id=8, name='ed'), User(id=7, name='jack')), - (User(id=9, name='fred'), User(id=7, name='jack')) - ] + (User(id=10, name="chuck"), User(id=9, name="fred")), + (User(id=10, name="chuck"), User(id=8, name="ed")), + (User(id=9, name="fred"), User(id=8, name="ed")), + (User(id=10, name="chuck"), User(id=7, name="jack")), + (User(id=8, name="ed"), User(id=7, name="jack")), + (User(id=9, name="fred"), User(id=7, name="jack")), + ], ) def test_plain_table(self): @@ -1834,8 +2240,9 @@ class JoinTest(QueryTest, AssertsCompiledSQL): eq_( sess.query(User.name) .join(addresses, User.id == addresses.c.user_id) - .order_by(User.id).all(), - [('jack',), ('ed',), ('ed',), ('ed',), ('fred',)] + .order_by(User.id) + .all(), + [("jack",), ("ed",), ("ed",), ("ed",), ("fred",)], ) def test_no_joinpoint_expr(self): @@ -1849,13 +2256,15 @@ class JoinTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.InvalidRequestError, "Don't know how to join to .*User.* please use an ON clause to ", - sess.query(users.c.id).join, User + sess.query(users.c.id).join, + User, ) assert_raises_message( sa_exc.InvalidRequestError, "Don't know how to join to .*User.* please use an ON clause to ", - sess.query(users.c.id).select_from(users).join, User + sess.query(users.c.id).select_from(users).join, + User, ) def test_on_clause_no_right_side(self): @@ -1866,36 +2275,43 @@ class JoinTest(QueryTest, AssertsCompiledSQL): 
assert_raises_message( sa_exc.ArgumentError, "Expected mapped entity or selectable/table as join target", - sess.query(User).join, User.id == Address.user_id + sess.query(User).join, + User.id == Address.user_id, ) def test_select_from(self): """Test that the left edge of the join can be set reliably with select_from().""" - Item, Order, User = (self.classes.Item, - self.classes.Order, - self.classes.User) + Item, Order, User = ( + self.classes.Item, + self.classes.Order, + self.classes.User, + ) sess = create_session() self.assert_compile( - sess.query(Item.id).select_from(User) - .join(User.orders).join(Order.items), + sess.query(Item.id) + .select_from(User) + .join(User.orders) + .join(Order.items), "SELECT items.id AS items_id FROM users JOIN orders ON " "users.id = orders.user_id JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id JOIN items ON items.id = " "order_items_1.item_id", - use_default_dialect=True + use_default_dialect=True, ) # here, the join really wants to add a second FROM clause # for "Item". but select_from disallows that self.assert_compile( - sess.query(Item.id).select_from(User) + sess.query(Item.id) + .select_from(User) .join(Item, User.id == Item.id), "SELECT items.id AS items_id FROM users JOIN items " "ON users.id = items.id", - use_default_dialect=True) + use_default_dialect=True, + ) def test_from_self_resets_joinpaths(self): """test a join from from_self() doesn't confuse joins inside the subquery @@ -1907,7 +2323,9 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess = create_session() self.assert_compile( - sess.query(Item).join(Item.keywords).from_self(Keyword) + sess.query(Item) + .join(Item.keywords) + .from_self(Keyword) .join(Item.keywords), "SELECT keywords.id AS keywords_id, " "keywords.name AS keywords_name " @@ -1920,20 +2338,23 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "anon_1.items_id = item_keywords_2.item_id " "JOIN keywords ON " "keywords.id = item_keywords_2.keyword_id", - use_default_dialect=True) + use_default_dialect=True, + ) class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): - __dialect__ = 'default' - run_setup_mappers = 'once' + __dialect__ = "default" + run_setup_mappers = "once" @classmethod def define_tables(cls, metadata): - Table('table1', metadata, - Column('id', Integer, primary_key=True)) - Table('table2', metadata, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer)) + Table("table1", metadata, Column("id", Integer, primary_key=True)) + Table( + "table2", + metadata, + Column("id", Integer, primary_key=True), + Column("t1_id", Integer), + ) @classmethod def setup_classes(cls): @@ -1952,25 +2373,32 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) self.assert_compile( sess.query(subq.c.count, T1.id) - .select_from(subq).join(T1, subq.c.t1_id == T1.id), + .select_from(subq) + .join(T1, subq.c.t1_id == T1.id), "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id " "FROM (SELECT table2.t1_id AS t1_id, " "count(table2.id) AS count FROM table2 " "GROUP BY table2.t1_id) AS anon_1 JOIN table1 " - "ON anon_1.t1_id = table1.id" + "ON anon_1.t1_id = table1.id", ) def test_select_mapped_to_mapped_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() 
- subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) self.assert_compile( sess.query(subq.c.count, T1.id).join(T1, subq.c.t1_id == T1.id), @@ -1978,31 +2406,38 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM (SELECT table2.t1_id AS t1_id, " "count(table2.id) AS count FROM table2 " "GROUP BY table2.t1_id) AS anon_1 JOIN table1 " - "ON anon_1.t1_id = table1.id" + "ON anon_1.t1_id = table1.id", ) def test_select_mapped_to_select_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) self.assert_compile( - sess.query(subq.c.count, T1.id).select_from(T1) + sess.query(subq.c.count, T1.id) + .select_from(T1) .join(subq, subq.c.t1_id == T1.id), "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id " "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, " "count(table2.id) AS count FROM table2 GROUP BY table2.t1_id) " - "AS anon_1 ON anon_1.t1_id = table1.id" + "AS anon_1 ON anon_1.t1_id = table1.id", ) def test_select_mapped_to_select_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) # without select_from self.assert_compile( @@ -2011,85 +2446,101 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM table1 JOIN " "(SELECT table2.t1_id AS t1_id, count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) " - "AS anon_1 ON anon_1.t1_id = table1.id" + "AS anon_1 ON anon_1.t1_id = table1.id", ) # with select_from, same query self.assert_compile( - sess.query(subq.c.count, T1.id).select_from(T1). - join(subq, subq.c.t1_id == T1.id), + sess.query(subq.c.count, T1.id) + .select_from(T1) + .join(subq, subq.c.t1_id == T1.id), "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id " "FROM table1 JOIN " "(SELECT table2.t1_id AS t1_id, count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) " - "AS anon_1 ON anon_1.t1_id = table1.id" + "AS anon_1 ON anon_1.t1_id = table1.id", ) def test_mapped_select_to_mapped_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) # without select_from self.assert_compile( - sess.query(T1.id, subq.c.count). - join(T1, subq.c.t1_id == T1.id), + sess.query(T1.id, subq.c.count).join(T1, subq.c.t1_id == T1.id), "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " "FROM (SELECT table2.t1_id AS t1_id, count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) AS anon_1 " - "JOIN table1 ON anon_1.t1_id = table1.id" + "JOIN table1 ON anon_1.t1_id = table1.id", ) # with select_from, same query self.assert_compile( - sess.query(T1.id, subq.c.count).select_from(subq). 
- join(T1, subq.c.t1_id == T1.id), + sess.query(T1.id, subq.c.count) + .select_from(subq) + .join(T1, subq.c.t1_id == T1.id), "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " "FROM (SELECT table2.t1_id AS t1_id, count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) AS anon_1 " - "JOIN table1 ON anon_1.t1_id = table1.id" + "JOIN table1 ON anon_1.t1_id = table1.id", ) def test_mapped_select_to_mapped_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) self.assert_compile( - sess.query(T1.id, subq.c.count).select_from(subq) + sess.query(T1.id, subq.c.count) + .select_from(subq) .join(T1, subq.c.t1_id == T1.id), "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " "FROM (SELECT table2.t1_id AS t1_id, count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) AS anon_1 JOIN table1 " - "ON anon_1.t1_id = table1.id" + "ON anon_1.t1_id = table1.id", ) def test_mapped_select_to_select_explicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) self.assert_compile( - sess.query(T1.id, subq.c.count).select_from(T1) + sess.query(T1.id, subq.c.count) + .select_from(T1) .join(subq, subq.c.t1_id == T1.id), "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, " "count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) AS anon_1 " - "ON anon_1.t1_id = table1.id") + "ON anon_1.t1_id = table1.id", + ) def test_mapped_select_to_select_implicit_left(self): T1, T2 = self.classes.T1, self.classes.T2 sess = Session() - subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ - group_by(T2.t1_id).subquery() + subq = ( + sess.query(T2.t1_id, func.count(T2.id).label("count")) + .group_by(T2.t1_id) + .subquery() + ) self.assert_compile( sess.query(T1.id, subq.c.count).join(subq, subq.c.t1_id == T1.id), @@ -2097,34 +2548,51 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, " "count(table2.id) AS count " "FROM table2 GROUP BY table2.t1_id) AS anon_1 " - "ON anon_1.t1_id = table1.id") + "ON anon_1.t1_id = table1.id", + ) class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - - t1t2_1 = Table('t1t2_1', metadata, - Column('t1id', Integer, ForeignKey('t1.id')), - Column('t2id', Integer, ForeignKey('t2.id'))) - - t1t2_2 = Table('t1t2_2', metadata, - Column('t1id', Integer, ForeignKey('t1.id')), - Column('t2id', Integer, ForeignKey('t2.id'))) + t1 = Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + t2 = Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + + t1t2_1 = Table( + "t1t2_1", + 
metadata, + Column("t1id", Integer, ForeignKey("t1.id")), + Column("t2id", Integer, ForeignKey("t2.id")), + ) + + t1t2_2 = Table( + "t1t2_2", + metadata, + Column("t1id", Integer, ForeignKey("t1.id")), + Column("t2id", Integer, ForeignKey("t2.id")), + ) def test_basic(self): - t2, t1t2_1, t1t2_2, t1 = (self.tables.t2, - self.tables.t1t2_1, - self.tables.t1t2_2, - self.tables.t1) + t2, t1t2_1, t1t2_2, t1 = ( + self.tables.t2, + self.tables.t1t2_1, + self.tables.t1t2_2, + self.tables.t1, + ) class T1(object): pass @@ -2132,14 +2600,24 @@ class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL): class T2(object): pass - mapper(T1, t1, properties={ - 't2s_1': relationship(T2, secondary=t1t2_1), - 't2s_2': relationship(T2, secondary=t1t2_2), - }) + mapper( + T1, + t1, + properties={ + "t2s_1": relationship(T2, secondary=t1t2_1), + "t2s_2": relationship(T2, secondary=t1t2_2), + }, + ) mapper(T2, t2) - q = create_session().query(T1).join('t2s_1') \ - .filter(t2.c.id == 5).reset_joinpoint().join('t2s_2') + q = ( + create_session() + .query(T1) + .join("t2s_1") + .filter(t2.c.id == 5) + .reset_joinpoint() + .join("t2s_2") + ) self.assert_compile( q, "SELECT t1.id AS t1_id, t1.data AS t1_data FROM t1 " @@ -2148,35 +2626,48 @@ class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL): "JOIN t1t2_2 AS t1t2_2_1 " "ON t1.id = t1t2_2_1.t1id JOIN t2 ON t2.id = t1t2_2_1.t2id " "WHERE t2.id = :id_1", - use_default_dialect=True) + use_default_dialect=True, + ) class SelfRefMixedTest(fixtures.MappedTest, AssertsCompiledSQL): - run_setup_mappers = 'once' + run_setup_mappers = "once" __dialect__ = default.DefaultDialect() @classmethod def define_tables(cls, metadata): - nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id'))) + nodes = Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + ) - sub_table = Table('sub_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('node_id', Integer, ForeignKey('nodes.id'))) + sub_table = Table( + "sub_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("node_id", Integer, ForeignKey("nodes.id")), + ) - assoc_table = Table('assoc_table', metadata, - Column('left_id', Integer, ForeignKey('nodes.id')), - Column('right_id', Integer, - ForeignKey('nodes.id'))) + assoc_table = Table( + "assoc_table", + metadata, + Column("left_id", Integer, ForeignKey("nodes.id")), + Column("right_id", Integer, ForeignKey("nodes.id")), + ) @classmethod def setup_classes(cls): - nodes, assoc_table, sub_table = (cls.tables.nodes, - cls.tables.assoc_table, - cls.tables.sub_table) + nodes, assoc_table, sub_table = ( + cls.tables.nodes, + cls.tables.assoc_table, + cls.tables.sub_table, + ) class Node(cls.Comparable): pass @@ -2184,18 +2675,25 @@ class SelfRefMixedTest(fixtures.MappedTest, AssertsCompiledSQL): class Sub(cls.Comparable): pass - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='select', join_depth=3, - backref=backref( - 'parent', remote_side=[nodes.c.id]) - ), - 'subs': relationship(Sub), - 'assoc': relationship( - Node, - secondary=assoc_table, - primaryjoin=nodes.c.id == assoc_table.c.left_id, - secondaryjoin=nodes.c.id == assoc_table.c.right_id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, 
+ lazy="select", + join_depth=3, + backref=backref("parent", remote_side=[nodes.c.id]), + ), + "subs": relationship(Sub), + "assoc": relationship( + Node, + secondary=assoc_table, + primaryjoin=nodes.c.id == assoc_table.c.left_id, + secondaryjoin=nodes.c.id == assoc_table.c.right_id, + ), + }, + ) mapper(Sub, sub_table) def test_o2m_aliased_plus_o2m(self): @@ -2208,14 +2706,14 @@ class SelfRefMixedTest(fixtures.MappedTest, AssertsCompiledSQL): sess.query(Node).join(n1, Node.children).join(Sub, n1.subs), "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id " "FROM nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " - "JOIN sub_table ON nodes_1.id = sub_table.node_id" + "JOIN sub_table ON nodes_1.id = sub_table.node_id", ) self.assert_compile( sess.query(Node).join(n1, Node.children).join(Sub, Node.subs), "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id " "FROM nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " - "JOIN sub_table ON nodes.id = sub_table.node_id" + "JOIN sub_table ON nodes.id = sub_table.node_id", ) def test_m2m_aliased_plus_o2m(self): @@ -2244,22 +2742,28 @@ class SelfRefMixedTest(fixtures.MappedTest, AssertsCompiledSQL): class CreateJoinsTest(fixtures.ORMTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _inherits_fixture(self): m = MetaData() - base = Table('base', m, Column('id', Integer, primary_key=True)) - a = Table('a', m, - Column('id', Integer, ForeignKey('base.id'), - primary_key=True), - Column('b_id', Integer, ForeignKey('b.id'))) - b = Table('b', m, - Column('id', Integer, ForeignKey('base.id'), - primary_key=True), - Column('c_id', Integer, ForeignKey('c.id'))) - c = Table('c', m, - Column('id', Integer, ForeignKey('base.id'), - primary_key=True)) + base = Table("base", m, Column("id", Integer, primary_key=True)) + a = Table( + "a", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("b_id", Integer, ForeignKey("b.id")), + ) + b = Table( + "b", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("c_id", Integer, ForeignKey("c.id")), + ) + c = Table( + "c", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) class Base(object): pass @@ -2272,11 +2776,20 @@ class CreateJoinsTest(fixtures.ORMTest, AssertsCompiledSQL): class C(Base): pass + mapper(Base, base) - mapper(A, a, inherits=Base, properties={ - 'b': relationship(B, primaryjoin=a.c.b_id == b.c.id)}) - mapper(B, b, inherits=Base, properties={ - 'c': relationship(C, primaryjoin=b.c.c_id == c.c.id)}) + mapper( + A, + a, + inherits=Base, + properties={"b": relationship(B, primaryjoin=a.c.b_id == b.c.id)}, + ) + mapper( + B, + b, + inherits=Base, + properties={"c": relationship(C, primaryjoin=b.c.c_id == c.c.id)}, + ) mapper(C, c, inherits=Base) return A, B, C, Base @@ -2293,7 +2806,7 @@ class CreateJoinsTest(fixtures.ORMTest, AssertsCompiledSQL): "(SELECT 1 FROM (SELECT base.id AS base_id, c.id AS c_id " "FROM base JOIN c ON base.id = c.id) AS anon_2 " "WHERE anon_1.b_c_id = anon_2.c_id AND anon_2.c_id = :id_1" - ")))" + ")))", ) @@ -2306,19 +2819,26 @@ class JoinToNonPolyAliasesTest(fixtures.MappedTest, AssertsCompiledSQL): """ - __dialect__ = 'default' + + __dialect__ = "default" run_create_tables = None run_deletes = None @classmethod def define_tables(cls, metadata): - Table("parent", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - Table("child", metadata, - Column('id', Integer, primary_key=True), - 
Column('parent_id', Integer, ForeignKey('parent.id')), - Column('data', String(50))) + Table( + "parent", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "child", + metadata, + Column("id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("parent.id")), + Column("data", String(50)), + ) @classmethod def setup_mappers(cls): @@ -2344,27 +2864,31 @@ class JoinToNonPolyAliasesTest(fixtures.MappedTest, AssertsCompiledSQL): npc = self.npc sess = Session() self.assert_compile( - sess.query(Parent).join(Parent.npc) - .filter(self.derived.c.data == 'x'), + sess.query(Parent) + .join(Parent.npc) + .filter(self.derived.c.data == "x"), "SELECT parent.id AS parent_id, parent.data AS parent_data " "FROM parent JOIN (SELECT child.id AS id, " "child.parent_id AS parent_id, " "child.data AS data " "FROM child) AS anon_1 ON parent.id = anon_1.parent_id " - "WHERE anon_1.data = :data_1") + "WHERE anon_1.data = :data_1", + ) def test_join_parent_child_select_from(self): Parent = self.classes.Parent npc = self.npc sess = Session() self.assert_compile( - sess.query(npc).select_from(Parent).join(Parent.npc) - .filter(self.derived.c.data == 'x'), + sess.query(npc) + .select_from(Parent) + .join(Parent.npc) + .filter(self.derived.c.data == "x"), "SELECT anon_1.id AS anon_1_id, anon_1.parent_id " "AS anon_1_parent_id, anon_1.data AS anon_1_data " "FROM parent JOIN (SELECT child.id AS id, child.parent_id AS " "parent_id, child.data AS data FROM child) AS anon_1 ON " - "parent.id = anon_1.parent_id WHERE anon_1.data = :data_1" + "parent.id = anon_1.parent_id WHERE anon_1.data = :data_1", ) def test_join_select_parent_child(self): @@ -2372,29 +2896,34 @@ class JoinToNonPolyAliasesTest(fixtures.MappedTest, AssertsCompiledSQL): npc = self.npc sess = Session() self.assert_compile( - sess.query(Parent, npc).join(Parent.npc) - .filter(self.derived.c.data == 'x'), + sess.query(Parent, npc) + .join(Parent.npc) + .filter(self.derived.c.data == "x"), "SELECT parent.id AS parent_id, parent.data AS parent_data, " "anon_1.id AS anon_1_id, anon_1.parent_id AS anon_1_parent_id, " "anon_1.data AS anon_1_data FROM parent JOIN " "(SELECT child.id AS id, child.parent_id AS parent_id, " "child.data AS data FROM child) AS anon_1 ON parent.id = " - "anon_1.parent_id WHERE anon_1.data = :data_1" + "anon_1.parent_id WHERE anon_1.data = :data_1", ) class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) + Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + ) @classmethod def setup_classes(cls): @@ -2406,25 +2935,31 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): def setup_mappers(cls): Node, nodes = cls.classes.Node, cls.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='select', join_depth=3, - backref=backref( - 'parent', remote_side=[nodes.c.id]) - ), - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, + lazy="select", + join_depth=3, + backref=backref("parent", 
remote_side=[nodes.c.id]), + ) + }, + ) @classmethod def insert_data(cls): Node = cls.classes.Node sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.close() @@ -2433,34 +2968,49 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): Node = self.classes.Node sess = create_session() - node = sess.query(Node) \ - .join('children', aliased=True).filter_by(data='n122').first() - assert node.data == 'n12' + node = ( + sess.query(Node) + .join("children", aliased=True) + .filter_by(data="n122") + .first() + ) + assert node.data == "n12" def test_join_2(self): Node = self.classes.Node sess = create_session() - ret = sess.query(Node.data) \ - .join(Node.children, aliased=True).filter_by(data='n122').all() - assert ret == [('n12',)] + ret = ( + sess.query(Node.data) + .join(Node.children, aliased=True) + .filter_by(data="n122") + .all() + ) + assert ret == [("n12",)] def test_join_3(self): Node = self.classes.Node sess = create_session() - node = sess.query(Node) \ - .join('children', 'children', aliased=True) \ - .filter_by(data='n122').first() - assert node.data == 'n1' + node = ( + sess.query(Node) + .join("children", "children", aliased=True) + .filter_by(data="n122") + .first() + ) + assert node.data == "n1" def test_join_4(self): Node = self.classes.Node sess = create_session() - node = sess.query(Node) \ - .filter_by(data='n122').join('parent', aliased=True) \ - .filter_by(data='n12') \ - .join('parent', aliased=True, from_joinpoint=True) \ - .filter_by(data='n1').first() - assert node.data == 'n122' + node = ( + sess.query(Node) + .filter_by(data="n122") + .join("parent", aliased=True) + .filter_by(data="n12") + .join("parent", aliased=True, from_joinpoint=True) + .filter_by(data="n1") + .first() + ) + assert node.data == "n122" def test_string_or_prop_aliased(self): """test that join('foo') behaves the same as join(Cls.foo) in a self @@ -2471,14 +3021,21 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): Node = self.classes.Node sess = create_session() - nalias = aliased(Node, - sess.query(Node).filter_by(data='n1').subquery()) + nalias = aliased( + Node, sess.query(Node).filter_by(data="n1").subquery() + ) - q1 = sess.query(nalias).join(nalias.children, aliased=True).\ - join(Node.children, from_joinpoint=True) + q1 = ( + sess.query(nalias) + .join(nalias.children, aliased=True) + .join(Node.children, from_joinpoint=True) + ) - q2 = sess.query(nalias).join(nalias.children, aliased=True).\ - join("children", from_joinpoint=True) + q2 = ( + sess.query(nalias) + .join(nalias.children, aliased=True) + .join("children", from_joinpoint=True) + ) for q in (q1, q2): self.assert_compile( @@ -2489,16 +3046,22 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): "nodes.data AS data FROM nodes WHERE nodes.data = :data_1) " "AS anon_1 JOIN nodes AS nodes_1 ON anon_1.id = " "nodes_1.parent_id JOIN nodes ON nodes_1.id = nodes.parent_id", - use_default_dialect=True + use_default_dialect=True, ) - q1 = sess.query(Node).join(nalias.children, aliased=True).\ - 
join(Node.children, aliased=True, from_joinpoint=True).\ - join(Node.children, from_joinpoint=True) + q1 = ( + sess.query(Node) + .join(nalias.children, aliased=True) + .join(Node.children, aliased=True, from_joinpoint=True) + .join(Node.children, from_joinpoint=True) + ) - q2 = sess.query(Node).join(nalias.children, aliased=True).\ - join("children", aliased=True, from_joinpoint=True).\ - join("children", from_joinpoint=True) + q2 = ( + sess.query(Node) + .join(nalias.children, aliased=True) + .join("children", aliased=True, from_joinpoint=True) + .join("children", from_joinpoint=True) + ) for q in (q1, q2): self.assert_compile( @@ -2510,7 +3073,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): "JOIN nodes AS nodes_1 ON anon_1.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id " "JOIN nodes ON nodes_2.id = nodes.parent_id", - use_default_dialect=True + use_default_dialect=True, ) def test_from_self_inside_excludes_outside(self): @@ -2527,7 +3090,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): # n1 is not inside the from_self(), so all cols must be maintained # on the outside self.assert_compile( - sess.query(Node).filter(Node.data == 'n122') + sess.query(Node) + .filter(Node.data == "n122") .from_self(n1, Node.id), "SELECT nodes_1.id AS nodes_1_id, " "nodes_1.parent_id AS nodes_1_parent_id, " @@ -2536,15 +3100,21 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): "nodes.parent_id AS nodes_parent_id, " "nodes.data AS nodes_data FROM " "nodes WHERE nodes.data = :data_1) AS anon_1", - use_default_dialect=True) + use_default_dialect=True, + ) parent = aliased(Node) grandparent = aliased(Node) - q = sess.query(Node, parent, grandparent).\ - join(parent, Node.parent).\ - join(grandparent, parent.parent).\ - filter(Node.data == 'n122').filter(parent.data == 'n12').\ - filter(grandparent.data == 'n1').from_self().limit(1) + q = ( + sess.query(Node, parent, grandparent) + .join(parent, Node.parent) + .join(grandparent, parent.parent) + .filter(Node.data == "n122") + .filter(parent.data == "n12") + .filter(grandparent.data == "n1") + .from_self() + .limit(1) + ) # parent, grandparent *are* inside the from_self(), so they # should get aliased to the outside. 
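# [editor's sketch -- not part of the black diff] How the from_self()
# aliasing described in the comment above plays out, assuming the Node
# fixture mapping in this file ("sess" is hypothetical).  Entities that were
# part of the query *before* .from_self() are folded into an anonymous
# subquery, so the outer SELECT reaches Node, parent, and grandparent
# through anon_1's columns rather than through the nodes table directly;
# the expected SQL in the surrounding assertion verifies exactly that.
parent = aliased(Node)
grandparent = aliased(Node)
q = (
    sess.query(Node, parent, grandparent)
    .join(parent, Node.parent)
    .join(grandparent, parent.parent)
    .filter(Node.data == "n122")
    .from_self()   # everything above is now inside anon_1
    .limit(1)      # LIMIT applies to the outer, aliased SELECT
)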
@@ -2570,8 +3140,9 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): "ON nodes_2.id = nodes_1.parent_id " "WHERE nodes.data = :data_1 AND nodes_1.data = :data_2 AND " "nodes_2.data = :data_3) AS anon_1 LIMIT :param_1", - {'param_1': 1}, - use_default_dialect=True) + {"param_1": 1}, + use_default_dialect=True, + ) def test_explicit_join_1(self): Node = self.classes.Node @@ -2579,10 +3150,10 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n2 = aliased(Node) self.assert_compile( - join(Node, n1, 'children').join(n2, 'children'), + join(Node, n1, "children").join(n2, "children"), "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id", - use_default_dialect=True + use_default_dialect=True, ) def test_explicit_join_2(self): @@ -2594,7 +3165,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): join(Node, n1, Node.children).join(n2, n1.children), "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id", - use_default_dialect=True + use_default_dialect=True, ) def test_explicit_join_3(self): @@ -2605,11 +3176,12 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): # the join_to_left=False here is unfortunate. the default on this # flag should be False. self.assert_compile( - join(Node, n1, Node.children) - .join(n2, Node.children, join_to_left=False), + join(Node, n1, Node.children).join( + n2, Node.children, join_to_left=False + ), "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id", - use_default_dialect=True + use_default_dialect=True, ) def test_explicit_join_4(self): @@ -2624,7 +3196,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): "nodes.data AS nodes_data FROM nodes JOIN nodes AS nodes_1 " "ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id", - use_default_dialect=True) + use_default_dialect=True, + ) def test_explicit_join_5(self): Node = self.classes.Node @@ -2638,16 +3211,21 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): "nodes.data AS nodes_data FROM nodes JOIN nodes AS nodes_1 " "ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id", - use_default_dialect=True) + use_default_dialect=True, + ) def test_explicit_join_6(self): Node = self.classes.Node sess = create_session() n1 = aliased(Node) - node = sess.query(Node).select_from(join(Node, n1, 'children')).\ - filter(n1.data == 'n122').first() - assert node.data == 'n12' + node = ( + sess.query(Node) + .select_from(join(Node, n1, "children")) + .filter(n1.data == "n122") + .first() + ) + assert node.data == "n12" def test_explicit_join_7(self): Node = self.classes.Node @@ -2655,10 +3233,13 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n1 = aliased(Node) n2 = aliased(Node) - node = sess.query(Node).select_from( - join(Node, n1, 'children').join(n2, 'children')).\ - filter(n2.data == 'n122').first() - assert node.data == 'n1' + node = ( + sess.query(Node) + .select_from(join(Node, n1, "children").join(n2, "children")) + .filter(n2.data == "n122") + .first() + ) + assert node.data == "n1" def test_explicit_join_8(self): Node = self.classes.Node @@ -2667,10 +3248,15 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n2 = aliased(Node) # mix explicit and named onclauses - node = 
sess.query(Node).select_from( - join(Node, n1, Node.id == n1.parent_id).join(n2, 'children')).\ - filter(n2.data == 'n122').first() - assert node.data == 'n1' + node = ( + sess.query(Node) + .select_from( + join(Node, n1, Node.id == n1.parent_id).join(n2, "children") + ) + .filter(n2.data == "n122") + .first() + ) + assert node.data == "n1" def test_explicit_join_9(self): Node = self.classes.Node @@ -2678,11 +3264,15 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n1 = aliased(Node) n2 = aliased(Node) - node = sess.query(Node).select_from( - join(Node, n1, 'parent').join(n2, 'parent')).filter( - and_(Node.data == 'n122', n1.data == 'n12', n2.data == 'n1')) \ + node = ( + sess.query(Node) + .select_from(join(Node, n1, "parent").join(n2, "parent")) + .filter( + and_(Node.data == "n122", n1.data == "n12", n2.data == "n1") + ) .first() - assert node.data == 'n122' + ) + assert node.data == "n122" def test_explicit_join_10(self): Node = self.classes.Node @@ -2691,13 +3281,18 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n2 = aliased(Node) eq_( - list(sess.query(Node).select_from(join(Node, n1, 'parent') - .join(n2, 'parent')). - filter(and_(Node.data == 'n122', - n1.data == 'n12', - n2.data == 'n1')).values(Node.data, n1.data, - n2.data)), - [('n122', 'n12', 'n1')]) + list( + sess.query(Node) + .select_from(join(Node, n1, "parent").join(n2, "parent")) + .filter( + and_( + Node.data == "n122", n1.data == "n12", n2.data == "n1" + ) + ) + .values(Node.data, n1.data, n2.data) + ), + [("n122", "n12", "n1")], + ) def test_join_to_nonaliased(self): Node = self.classes.Node @@ -2707,17 +3302,27 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n1 = aliased(Node) # using 'n1.parent' implicitly joins to unaliased Node - eq_(sess.query(n1).join(n1.parent).filter(Node.data == 'n1').all(), - [Node(parent_id=1, data='n11', id=2), - Node(parent_id=1, data='n12', id=3), - Node(parent_id=1, data='n13', id=4)]) + eq_( + sess.query(n1).join(n1.parent).filter(Node.data == "n1").all(), + [ + Node(parent_id=1, data="n11", id=2), + Node(parent_id=1, data="n12", id=3), + Node(parent_id=1, data="n13", id=4), + ], + ) # explicit (new syntax) - eq_(sess.query(n1).join(Node, n1.parent).filter(Node.data - == 'n1').all(), - [Node(parent_id=1, data='n11', id=2), - Node(parent_id=1, data='n12', id=3), - Node(parent_id=1, data='n13', id=4)]) + eq_( + sess.query(n1) + .join(Node, n1.parent) + .filter(Node.data == "n1") + .all(), + [ + Node(parent_id=1, data="n11", id=2), + Node(parent_id=1, data="n12", id=3), + Node(parent_id=1, data="n13", id=4), + ], + ) def test_multiple_explicit_entities_one(self): Node = self.classes.Node @@ -2727,12 +3332,14 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): parent = aliased(Node) grandparent = aliased(Node) eq_( - sess.query(Node, parent, grandparent). - join(parent, Node.parent). - join(grandparent, parent.parent). - filter(Node.data == 'n122').filter(parent.data == 'n12'). 
- filter(grandparent.data == 'n1').first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) + sess.query(Node, parent, grandparent) + .join(parent, Node.parent) + .join(grandparent, parent.parent) + .filter(Node.data == "n122") + .filter(parent.data == "n12") + .filter(grandparent.data == "n1") + .first(), + (Node(data="n122"), Node(data="n12"), Node(data="n1")), ) def test_multiple_explicit_entities_two(self): @@ -2743,12 +3350,15 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): parent = aliased(Node) grandparent = aliased(Node) eq_( - sess.query(Node, parent, grandparent). - join(parent, Node.parent). - join(grandparent, parent.parent). - filter(Node.data == 'n122').filter(parent.data == 'n12'). - filter(grandparent.data == 'n1').from_self().first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) + sess.query(Node, parent, grandparent) + .join(parent, Node.parent) + .join(grandparent, parent.parent) + .filter(Node.data == "n122") + .filter(parent.data == "n12") + .filter(grandparent.data == "n1") + .from_self() + .first(), + (Node(data="n122"), Node(data="n12"), Node(data="n1")), ) def test_multiple_explicit_entities_three(self): @@ -2760,12 +3370,15 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): grandparent = aliased(Node) # same, change order around eq_( - sess.query(parent, grandparent, Node). - join(parent, Node.parent). - join(grandparent, parent.parent). - filter(Node.data == 'n122').filter(parent.data == 'n12'). - filter(grandparent.data == 'n1').from_self().first(), - (Node(data='n12'), Node(data='n1'), Node(data='n122')) + sess.query(parent, grandparent, Node) + .join(parent, Node.parent) + .join(grandparent, parent.parent) + .filter(Node.data == "n122") + .filter(parent.data == "n12") + .filter(grandparent.data == "n1") + .from_self() + .first(), + (Node(data="n12"), Node(data="n1"), Node(data="n122")), ) def test_multiple_explicit_entities_four(self): @@ -2776,13 +3389,15 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): parent = aliased(Node) grandparent = aliased(Node) eq_( - sess.query(Node, parent, grandparent). - join(parent, Node.parent). - join(grandparent, parent.parent). - filter(Node.data == 'n122').filter(parent.data == 'n12'). - filter(grandparent.data == 'n1'). - options(joinedload(Node.children)).first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) + sess.query(Node, parent, grandparent) + .join(parent, Node.parent) + .join(grandparent, parent.parent) + .filter(Node.data == "n122") + .filter(parent.data == "n12") + .filter(grandparent.data == "n1") + .options(joinedload(Node.children)) + .first(), + (Node(data="n122"), Node(data="n12"), Node(data="n1")), ) def test_multiple_explicit_entities_five(self): @@ -2793,85 +3408,142 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): parent = aliased(Node) grandparent = aliased(Node) eq_( - sess.query(Node, parent, grandparent). - join(parent, Node.parent). - join(grandparent, parent.parent). - filter(Node.data == 'n122').filter(parent.data == 'n12'). - filter(grandparent.data == 'n1').from_self(). 
- options(joinedload(Node.children)).first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) + sess.query(Node, parent, grandparent) + .join(parent, Node.parent) + .join(grandparent, parent.parent) + .filter(Node.data == "n122") + .filter(parent.data == "n12") + .filter(grandparent.data == "n1") + .from_self() + .options(joinedload(Node.children)) + .first(), + (Node(data="n122"), Node(data="n12"), Node(data="n1")), ) def test_any(self): Node = self.classes.Node sess = create_session() - eq_(sess.query(Node).filter(Node.children.any(Node.data == 'n1')) - .all(), []) - eq_(sess.query(Node) - .filter(Node.children.any(Node.data == 'n12')).all(), - [Node(data='n1')]) - eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id) - .all(), [Node(data='n11'), Node(data='n13'), Node(data='n121'), - Node(data='n122'), Node(data='n123'), ]) + eq_( + sess.query(Node) + .filter(Node.children.any(Node.data == "n1")) + .all(), + [], + ) + eq_( + sess.query(Node) + .filter(Node.children.any(Node.data == "n12")) + .all(), + [Node(data="n1")], + ) + eq_( + sess.query(Node) + .filter(~Node.children.any()) + .order_by(Node.id) + .all(), + [ + Node(data="n11"), + Node(data="n13"), + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ) def test_has(self): Node = self.classes.Node sess = create_session() - eq_(sess.query(Node).filter(Node.parent.has(Node.data == 'n12')) - .order_by(Node.id).all(), - [Node(data='n121'), Node(data='n122'), Node(data='n123')]) - eq_(sess.query(Node).filter(Node.parent.has(Node.data == 'n122')) - .all(), []) - eq_(sess.query(Node).filter( - ~Node.parent.has()).all(), [Node(data='n1')]) + eq_( + sess.query(Node) + .filter(Node.parent.has(Node.data == "n12")) + .order_by(Node.id) + .all(), + [Node(data="n121"), Node(data="n122"), Node(data="n123")], + ) + eq_( + sess.query(Node) + .filter(Node.parent.has(Node.data == "n122")) + .all(), + [], + ) + eq_( + sess.query(Node).filter(~Node.parent.has()).all(), + [Node(data="n1")], + ) def test_contains(self): Node = self.classes.Node sess = create_session() - n122 = sess.query(Node).filter(Node.data == 'n122').one() - eq_(sess.query(Node).filter(Node.children.contains(n122)).all(), - [Node(data='n12')]) + n122 = sess.query(Node).filter(Node.data == "n122").one() + eq_( + sess.query(Node).filter(Node.children.contains(n122)).all(), + [Node(data="n12")], + ) - n13 = sess.query(Node).filter(Node.data == 'n13').one() - eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), - [Node(data='n1')]) + n13 = sess.query(Node).filter(Node.data == "n13").one() + eq_( + sess.query(Node).filter(Node.children.contains(n13)).all(), + [Node(data="n1")], + ) def test_eq_ne(self): Node = self.classes.Node sess = create_session() - n12 = sess.query(Node).filter(Node.data == 'n12').one() - eq_(sess.query(Node).filter(Node.parent == n12).all(), - [Node(data='n121'), Node(data='n122'), Node(data='n123')]) + n12 = sess.query(Node).filter(Node.data == "n12").one() + eq_( + sess.query(Node).filter(Node.parent == n12).all(), + [Node(data="n121"), Node(data="n122"), Node(data="n123")], + ) - eq_(sess.query(Node).filter(Node.parent != n12).all(), - [Node(data='n1'), Node(data='n11'), Node(data='n12'), - Node(data='n13')]) + eq_( + sess.query(Node).filter(Node.parent != n12).all(), + [ + Node(data="n1"), + Node(data="n11"), + Node(data="n12"), + Node(data="n13"), + ], + ) class SelfReferentialM2MTest(fixtures.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" 
run_deletes = None @classmethod def define_tables(cls, metadata): - nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - - node_to_nodes = Table('node_to_nodes', metadata, - Column('left_node_id', Integer, ForeignKey( - 'nodes.id'), primary_key=True), - Column('right_node_id', Integer, ForeignKey( - 'nodes.id'), primary_key=True)) + nodes = Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + + node_to_nodes = Table( + "node_to_nodes", + metadata, + Column( + "left_node_id", + Integer, + ForeignKey("nodes.id"), + primary_key=True, + ), + Column( + "right_node_id", + Integer, + ForeignKey("nodes.id"), + primary_key=True, + ), + ) @classmethod def setup_classes(cls): @@ -2880,25 +3552,33 @@ class SelfReferentialM2MTest(fixtures.MappedTest): @classmethod def insert_data(cls): - Node, nodes, node_to_nodes = (cls.classes.Node, - cls.tables.nodes, - cls.tables.node_to_nodes) - - mapper(Node, nodes, properties={ - 'children': relationship( - Node, lazy='select', - secondary=node_to_nodes, - primaryjoin=nodes.c.id == node_to_nodes.c.left_node_id, - secondaryjoin=nodes.c.id == node_to_nodes.c.right_node_id) - }) - sess = create_session() - n1 = Node(data='n1') - n2 = Node(data='n2') - n3 = Node(data='n3') - n4 = Node(data='n4') - n5 = Node(data='n5') - n6 = Node(data='n6') - n7 = Node(data='n7') + Node, nodes, node_to_nodes = ( + cls.classes.Node, + cls.tables.nodes, + cls.tables.node_to_nodes, + ) + + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, + lazy="select", + secondary=node_to_nodes, + primaryjoin=nodes.c.id == node_to_nodes.c.left_node_id, + secondaryjoin=nodes.c.id == node_to_nodes.c.right_node_id, + ) + }, + ) + sess = create_session() + n1 = Node(data="n1") + n2 = Node(data="n2") + n3 = Node(data="n3") + n4 = Node(data="n4") + n5 = Node(data="n5") + n6 = Node(data="n6") + n7 = Node(data="n7") n1.children = [n2, n3, n4] n2.children = [n3, n6, n7] @@ -2915,23 +3595,40 @@ class SelfReferentialM2MTest(fixtures.MappedTest): Node = self.classes.Node sess = create_session() - eq_(sess.query(Node).filter(Node.children.any(Node.data == 'n3')) - .order_by(Node.data).all(), - [Node(data='n1'), Node(data='n2')]) + eq_( + sess.query(Node) + .filter(Node.children.any(Node.data == "n3")) + .order_by(Node.data) + .all(), + [Node(data="n1"), Node(data="n2")], + ) def test_contains(self): Node = self.classes.Node sess = create_session() - n4 = sess.query(Node).filter_by(data='n4').one() + n4 = sess.query(Node).filter_by(data="n4").one() - eq_(sess.query(Node).filter(Node.children.contains(n4)) - .order_by(Node.data).all(), - [Node(data='n1'), Node(data='n3')]) - eq_(sess.query(Node).filter(not_(Node.children.contains(n4))) - .order_by(Node.data).all(), - [Node(data='n2'), Node(data='n4'), Node(data='n5'), - Node(data='n6'), Node(data='n7')]) + eq_( + sess.query(Node) + .filter(Node.children.contains(n4)) + .order_by(Node.data) + .all(), + [Node(data="n1"), Node(data="n3")], + ) + eq_( + sess.query(Node) + .filter(not_(Node.children.contains(n4))) + .order_by(Node.data) + .all(), + [ + Node(data="n2"), + Node(data="n4"), + Node(data="n5"), + Node(data="n6"), + Node(data="n7"), + ], + ) def test_explicit_join(self): Node = self.classes.Node @@ -2939,74 +3636,77 @@ class SelfReferentialM2MTest(fixtures.MappedTest): sess = create_session() n1 = aliased(Node) - 
eq_(sess.query(Node).select_from(join(Node, n1, 'children')) - .filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(), - [Node(data='n1'), Node(data='n2')]) + eq_( + sess.query(Node) + .select_from(join(Node, n1, "children")) + .filter(n1.data.in_(["n3", "n7"])) + .order_by(Node.id) + .all(), + [Node(data="n1"), Node(data="n2")], + ) class AliasFromCorrectLeftTest( - fixtures.DeclarativeMappedTest, AssertsCompiledSQL): + fixtures.DeclarativeMappedTest, AssertsCompiledSQL +): run_create_tables = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Object(Base): - __tablename__ = 'object' + __tablename__ = "object" type = Column(String(30)) __mapper_args__ = { - 'polymorphic_identity': 'object', - 'polymorphic_on': type + "polymorphic_identity": "object", + "polymorphic_on": type, } id = Column(Integer, primary_key=True) name = Column(String(256)) class A(Object): - __tablename__ = 'a' + __tablename__ = "a" - __mapper_args__ = {'polymorphic_identity': 'a'} + __mapper_args__ = {"polymorphic_identity": "a"} - id = Column(Integer, ForeignKey('object.id'), primary_key=True) + id = Column(Integer, ForeignKey("object.id"), primary_key=True) b_list = relationship( - 'B', - secondary='a_b_association', - backref='a_list' + "B", secondary="a_b_association", backref="a_list" ) class B(Object): - __tablename__ = 'b' + __tablename__ = "b" - __mapper_args__ = {'polymorphic_identity': 'b'} + __mapper_args__ = {"polymorphic_identity": "b"} - id = Column(Integer, ForeignKey('object.id'), primary_key=True) + id = Column(Integer, ForeignKey("object.id"), primary_key=True) class ABAssociation(Base): - __tablename__ = 'a_b_association' + __tablename__ = "a_b_association" - a_id = Column(Integer, ForeignKey('a.id'), primary_key=True) - b_id = Column(Integer, ForeignKey('b.id'), primary_key=True) + a_id = Column(Integer, ForeignKey("a.id"), primary_key=True) + b_id = Column(Integer, ForeignKey("b.id"), primary_key=True) class X(Base): - __tablename__ = 'x' + __tablename__ = "x" id = Column(Integer, primary_key=True) name = Column(String(30)) - obj_id = Column(Integer, ForeignKey('object.id')) - obj = relationship('Object', backref='x_list') + obj_id = Column(Integer, ForeignKey("object.id")) + obj = relationship("Object", backref="x_list") def test_join_prop_to_string(self): A, B, X = self.classes("A", "B", "X") s = Session() - q = s.query(B).\ - join(B.a_list, 'x_list').filter(X.name == 'x1') + q = s.query(B).join(B.a_list, "x_list").filter(X.name == "x1") self.assert_compile( q, @@ -3019,7 +3719,7 @@ class AliasFromCorrectLeftTest( "object AS object_1 " "JOIN a AS a_1 ON object_1.id = a_1.id" ") ON a_1.id = a_b_association_1.a_id " - "JOIN x ON object_1.id = x.obj_id WHERE x.name = :name_1" + "JOIN x ON object_1.id = x.obj_id WHERE x.name = :name_1", ) def test_join_prop_to_prop(self): @@ -3029,8 +3729,7 @@ class AliasFromCorrectLeftTest( # B -> A, but both are Object. 
So when we say A.x_list, make sure # we pick the correct right side - q = s.query(B).\ - join(B.a_list, A.x_list).filter(X.name == 'x1') + q = s.query(B).join(B.a_list, A.x_list).filter(X.name == "x1") self.assert_compile( q, @@ -3043,40 +3742,51 @@ class AliasFromCorrectLeftTest( "object AS object_1 " "JOIN a AS a_1 ON object_1.id = a_1.id" ") ON a_1.id = a_b_association_1.a_id " - "JOIN x ON object_1.id = x.obj_id WHERE x.name = :name_1" + "JOIN x ON object_1.id = x.obj_id WHERE x.name = :name_1", ) + class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): __dialect__ = default.DefaultDialect(supports_native_boolean=True) run_setup_bind = None - run_setup_mappers = 'once' + run_setup_mappers = "once" run_create_tables = None @classmethod def define_tables(cls, metadata): - Table('people', metadata, - Column('people_id', Integer, primary_key=True), - Column('age', Integer), - Column('name', String(30))) - Table('bookcases', metadata, - Column('bookcase_id', Integer, primary_key=True), - Column( - 'bookcase_owner_id', - Integer, ForeignKey('people.people_id')), - Column('bookcase_shelves', Integer), - Column('bookcase_width', Integer)) - Table('books', metadata, - Column('book_id', Integer, primary_key=True), - Column( - 'bookcase_id', Integer, ForeignKey('bookcases.bookcase_id')), - Column('book_owner_id', Integer, ForeignKey('people.people_id')), - Column('book_weight', Integer)) + Table( + "people", + metadata, + Column("people_id", Integer, primary_key=True), + Column("age", Integer), + Column("name", String(30)), + ) + Table( + "bookcases", + metadata, + Column("bookcase_id", Integer, primary_key=True), + Column( + "bookcase_owner_id", Integer, ForeignKey("people.people_id") + ), + Column("bookcase_shelves", Integer), + Column("bookcase_width", Integer), + ) + Table( + "books", + metadata, + Column("book_id", Integer, primary_key=True), + Column( + "bookcase_id", Integer, ForeignKey("bookcases.bookcase_id") + ), + Column("book_owner_id", Integer, ForeignKey("people.people_id")), + Column("book_weight", Integer), + ) @classmethod def setup_classes(cls): - people, bookcases, books = cls.tables('people', 'bookcases', 'books') + people, bookcases, books = cls.tables("people", "bookcases", "books") class Person(cls.Comparable): pass @@ -3088,10 +3798,14 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): pass mapper(Person, people) - mapper(Bookcase, bookcases, properties={ - 'owner': relationship(Person), - 'books': relationship(Book) - }) + mapper( + Bookcase, + bookcases, + properties={ + "owner": relationship(Person), + "books": relationship(Book), + }, + ) mapper(Book, books) def test_select_subquery(self): @@ -3099,14 +3813,16 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): s = Session() - subq = s.query(Book.book_id).correlate(Person).filter( - Person.people_id == Book.book_owner_id - ).subquery().lateral() - - stmt = s.query(Person, subq.c.book_id).join( - subq, true() + subq = ( + s.query(Book.book_id) + .correlate(Person) + .filter(Person.people_id == Book.book_owner_id) + .subquery() + .lateral() ) + stmt = s.query(Person, subq.c.book_id).join(subq, true()) + self.assert_compile( stmt, "SELECT people.people_id AS people_people_id, " @@ -3114,7 +3830,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL): "anon_1.book_id AS anon_1_book_id " "FROM people JOIN LATERAL " "(SELECT books.book_id AS book_id FROM books " - "WHERE people.people_id = books.book_owner_id) AS anon_1 ON true" + "WHERE people.people_id = 
books.book_owner_id) AS anon_1 ON true",
        )

    # sef == select_entity_from

@@ -3125,12 +3841,17 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):

         stmt = s.query(Person).subquery()

-        subq = s.query(Book.book_id).filter(
-            Person.people_id == Book.book_owner_id
-        ).subquery().lateral()
+        subq = (
+            s.query(Book.book_id)
+            .filter(Person.people_id == Book.book_owner_id)
+            .subquery()
+            .lateral()
+        )

-        stmt = s.query(Person, subq.c.book_id).select_entity_from(stmt).join(
-            subq, true()
+        stmt = (
+            s.query(Person, subq.c.book_id)
+            .select_entity_from(stmt)
+            .join(subq, true())
         )

         self.assert_compile(
@@ -3143,7 +3864,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):
             "people.name AS name FROM people) AS anon_1 "
             "JOIN LATERAL "
             "(SELECT books.book_id AS book_id FROM books "
-            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true"
+            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true",
         )

     def test_select_subquery_sef_implicit_correlate_coreonly(self):
@@ -3153,12 +3874,16 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):

         stmt = s.query(Person).subquery()

-        subq = select([Book.book_id]).where(
-            Person.people_id == Book.book_owner_id
-        ).lateral()
+        subq = (
+            select([Book.book_id])
+            .where(Person.people_id == Book.book_owner_id)
+            .lateral()
+        )

-        stmt = s.query(Person, subq.c.book_id).select_entity_from(stmt).join(
-            subq, true()
+        stmt = (
+            s.query(Person, subq.c.book_id)
+            .select_entity_from(stmt)
+            .join(subq, true())
         )

         self.assert_compile(
@@ -3171,7 +3896,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):
             "people.name AS name FROM people) AS anon_1 "
             "JOIN LATERAL "
             "(SELECT books.book_id AS book_id FROM books "
-            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true"
+            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true",
         )

     def test_select_subquery_sef_explicit_correlate_coreonly(self):
@@ -3181,12 +3906,17 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):

         stmt = s.query(Person).subquery()

-        subq = select([Book.book_id]).correlate(Person).where(
-            Person.people_id == Book.book_owner_id
-        ).lateral()
+        subq = (
+            select([Book.book_id])
+            .correlate(Person)
+            .where(Person.people_id == Book.book_owner_id)
+            .lateral()
+        )

-        stmt = s.query(Person, subq.c.book_id).select_entity_from(stmt).join(
-            subq, true()
+        stmt = (
+            s.query(Person, subq.c.book_id)
+            .select_entity_from(stmt)
+            .join(subq, true())
         )

         self.assert_compile(
@@ -3199,7 +3929,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):
             "people.name AS name FROM people) AS anon_1 "
             "JOIN LATERAL "
             "(SELECT books.book_id AS book_id FROM books "
-            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true"
+            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true",
         )

     def test_select_subquery_sef_explicit_correlate(self):
@@ -3209,12 +3939,18 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):

         stmt = s.query(Person).subquery()

-        subq = s.query(Book.book_id).correlate(Person).filter(
-            Person.people_id == Book.book_owner_id
-        ).subquery().lateral()
+        subq = (
+            s.query(Book.book_id)
+            .correlate(Person)
+            .filter(Person.people_id == Book.book_owner_id)
+            .subquery()
+            .lateral()
+        )

-        stmt = s.query(Person, subq.c.book_id).select_entity_from(stmt).join(
-            subq, true()
+        stmt = (
+            s.query(Person, subq.c.book_id)
+            .select_entity_from(stmt)
+            .join(subq, true())
         )

         self.assert_compile(
@@ -3227,7 +3963,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):
             "people.name AS name FROM people) AS anon_1 "
             "JOIN LATERAL "
             "(SELECT books.book_id AS book_id FROM books "
-            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true"
+            "WHERE anon_1.people_id = books.book_owner_id) AS anon_2 ON true",
         )

     def test_from_function(self):
@@ -3245,7 +3981,7 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):
             "bookcases.bookcase_width AS bookcases_bookcase_width "
             "FROM bookcases JOIN "
             "LATERAL generate_series(:generate_series_1, "
-            "bookcases.bookcase_shelves) AS anon_1 ON true"
+            "bookcases.bookcase_shelves) AS anon_1 ON true",
         )

     def test_from_function_select_entity_from(self):
@@ -3270,6 +4006,5 @@ class JoinLateralTest(fixtures.MappedTest, AssertsCompiledSQL):
             "AS anon_1 "
             "JOIN LATERAL "
             "generate_series(:generate_series_1, anon_1.bookcase_shelves) "
-            "AS anon_2 ON true"
+            "AS anon_2 ON true",
         )
-
diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py
index b79aeb14b2..566317f0bf 100644
--- a/test/orm/test_lazy_relations.py
+++ b/test/orm/test_lazy_relations.py
@@ -19,8 +19,9 @@ from test.orm import _fixtures
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing import mock

+
 class LazyTest(_fixtures.FixtureTest):
-    run_inserts = 'once'
+    run_inserts = "once"
     run_deletes = None

     def test_basic(self):
@@ -28,18 +29,28 @@ class LazyTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(User, users, properties={
-            'addresses': relationship(
-                mapper(Address, addresses), lazy='select')
-        })
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses), lazy="select"
+                )
+            },
+        )
         sess = create_session()
         q = sess.query(User)
         eq_(
-            [User(id=7,
-                  addresses=[Address(id=1, email_address='jack@bean.com')])],
-            q.filter(users.c.id == 7).all()
+            [
+                User(
+                    id=7,
+                    addresses=[Address(id=1, email_address="jack@bean.com")],
+                )
+            ],
+            q.filter(users.c.id == 7).all(),
         )

     def test_needs_parent(self):
@@ -49,44 +60,56 @@ class LazyTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(User, users, properties={
-            'addresses': relationship(
-                mapper(Address, addresses), lazy='select')
-        })
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses), lazy="select"
+                )
+            },
+        )
         sess = create_session()
         q = sess.query(User)
         u = q.filter(users.c.id == 7).first()
         sess.expunge(u)
-        assert_raises(orm_exc.DetachedInstanceError, getattr, u, 'addresses')
+        assert_raises(orm_exc.DetachedInstanceError, getattr, u, "addresses")

     def test_orderby(self):
         users, Address, addresses, User = (
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(User, users, properties={
-            'addresses': relationship(
-                mapper(Address, addresses),
-                lazy='select', order_by=addresses.c.email_address),
-        })
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses),
+                    lazy="select",
+                    order_by=addresses.c.email_address,
+                )
+            },
+        )
         q = create_session().query(User)
         assert [
-            User(id=7, addresses=[
-                Address(id=1)
-            ]),
-            User(id=8, addresses=[
-                Address(id=3, email_address='ed@bettyboop.com'),
-                Address(id=4, email_address='ed@lala.com'),
-                Address(id=2, email_address='ed@wood.com')
-            ]),
-            User(id=9, addresses=[
-                Address(id=5)
-            ]),
-            User(id=10, addresses=[])
+            User(id=7, addresses=[Address(id=1)]),
+            User(
+                id=8,
+                addresses=[
+                    Address(id=3, email_address="ed@bettyboop.com"),
+                    Address(id=4, email_address="ed@lala.com"),
+                    Address(id=2, email_address="ed@wood.com"),
+                ],
+            ),
+            User(id=9, addresses=[Address(id=5)]),
+            User(id=10, addresses=[]),
         ] == q.all()

     def test_orderby_secondary(self):
@@ -97,28 +120,33 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
+            self.classes.User,
+        )

         mapper(Address, addresses)
-        mapper(User, users, properties=dict(
-            addresses=relationship(Address, lazy='select'),
-        ))
+        mapper(
+            User,
+            users,
+            properties=dict(addresses=relationship(Address, lazy="select")),
+        )
         q = create_session().query(User)
-        result = q.filter(users.c.id == addresses.c.user_id).\
-            order_by(addresses.c.email_address).all()
+        result = (
+            q.filter(users.c.id == addresses.c.user_id)
+            .order_by(addresses.c.email_address)
+            .all()
+        )
         assert [
-            User(id=8, addresses=[
-                Address(id=2, email_address='ed@wood.com'),
-                Address(id=3, email_address='ed@bettyboop.com'),
-                Address(id=4, email_address='ed@lala.com'),
-            ]),
-            User(id=9, addresses=[
-                Address(id=5)
-            ]),
-            User(id=7, addresses=[
-                Address(id=1)
-            ]),
+            User(
+                id=8,
+                addresses=[
+                    Address(id=2, email_address="ed@wood.com"),
+                    Address(id=3, email_address="ed@bettyboop.com"),
+                    Address(id=4, email_address="ed@lala.com"),
+                ],
+            ),
+            User(id=9, addresses=[Address(id=5)]),
+            User(id=7, addresses=[Address(id=1)]),
         ] == result

     def test_orderby_desc(self):
@@ -126,29 +154,35 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
+            self.classes.User,
+        )

         mapper(Address, addresses)
-        mapper(User, users, properties=dict(
-            addresses=relationship(
-                Address, lazy='select',
-                order_by=[sa.desc(addresses.c.email_address)]),
-        ))
+        mapper(
+            User,
+            users,
+            properties=dict(
+                addresses=relationship(
+                    Address,
+                    lazy="select",
+                    order_by=[sa.desc(addresses.c.email_address)],
+                )
+            ),
+        )
         sess = create_session()
         assert [
-            User(id=7, addresses=[
-                Address(id=1)
-            ]),
-            User(id=8, addresses=[
-                Address(id=2, email_address='ed@wood.com'),
-                Address(id=4, email_address='ed@lala.com'),
-                Address(id=3, email_address='ed@bettyboop.com'),
-            ]),
-            User(id=9, addresses=[
-                Address(id=5)
-            ]),
-            User(id=10, addresses=[])
+            User(id=7, addresses=[Address(id=1)]),
+            User(
+                id=8,
+                addresses=[
+                    Address(id=2, email_address="ed@wood.com"),
+                    Address(id=4, email_address="ed@lala.com"),
+                    Address(id=3, email_address="ed@bettyboop.com"),
+                ],
+            ),
+            User(id=9, addresses=[Address(id=5)]),
+            User(id=10, addresses=[]),
         ] == sess.query(User).all()

     def test_no_orphan(self):
@@ -158,50 +192,69 @@ class LazyTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(User, users, properties={
-            'addresses': relationship(
-                Address, cascade="all,delete-orphan", lazy='select')
-        })
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    Address, cascade="all,delete-orphan", lazy="select"
+                )
+            },
+        )
         mapper(Address, addresses)

         sess = create_session()
         user = sess.query(User).get(7)
-        assert getattr(User, 'addresses').hasparent(
-            attributes.instance_state(user.addresses[0]), optimistic=True)
+        assert getattr(User, "addresses").hasparent(
+            attributes.instance_state(user.addresses[0]), optimistic=True
+        )
         assert not sa.orm.class_mapper(Address)._is_orphan(
-            attributes.instance_state(user.addresses[0]))
+            attributes.instance_state(user.addresses[0])
+        )

     def test_limit(self):
         """test limit operations combined with lazy-load relationships."""

-        users, items, order_items, orders, Item, \
-            User, Address, Order, addresses = (
-                self.tables.users,
-                self.tables.items,
-                self.tables.order_items,
-                self.tables.orders,
-                self.classes.Item,
-                self.classes.User,
-                self.classes.Address,
-                self.classes.Order,
-                self.tables.addresses)
+        users, items, order_items, orders, Item, User, Address, Order, addresses = (
+            self.tables.users,
+            self.tables.items,
+            self.tables.order_items,
+            self.tables.orders,
+            self.classes.Item,
+            self.classes.User,
+            self.classes.Address,
+            self.classes.Order,
+            self.tables.addresses,
+        )

         mapper(Item, items)
-        mapper(Order, orders, properties={
-            'items': relationship(Item, secondary=order_items, lazy='select')
-        })
-        mapper(User, users, properties={
-            'addresses': relationship(
-                mapper(Address, addresses), lazy='select'),
-            'orders': relationship(Order, lazy='select')
-        })
+        mapper(
+            Order,
+            orders,
+            properties={
+                "items": relationship(
+                    Item, secondary=order_items, lazy="select"
+                )
+            },
+        )
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses), lazy="select"
+                ),
+                "orders": relationship(Order, lazy="select"),
+            },
+        )

         sess = create_session()
         q = sess.query(User)

-        if testing.against('mssql'):
+        if testing.against("mssql"):
             result = q.limit(2).all()
             assert self.static.user_all_result[:2] == result
         else:
@@ -209,38 +262,52 @@ class LazyTest(_fixtures.FixtureTest):
             assert self.static.user_all_result[1:3] == result

     def test_distinct(self):
-        users, items, order_items, orders, \
-            Item, User, Address, Order, addresses = (
-                self.tables.users,
-                self.tables.items,
-                self.tables.order_items,
-                self.tables.orders,
-                self.classes.Item,
-                self.classes.User,
-                self.classes.Address,
-                self.classes.Order,
-                self.tables.addresses)
+        users, items, order_items, orders, Item, User, Address, Order, addresses = (
+            self.tables.users,
+            self.tables.items,
+            self.tables.order_items,
+            self.tables.orders,
+            self.classes.Item,
+            self.classes.User,
+            self.classes.Address,
+            self.classes.Order,
+            self.tables.addresses,
+        )

         mapper(Item, items)
-        mapper(Order, orders, properties={
-            'items': relationship(Item, secondary=order_items, lazy='select')
-        })
-        mapper(User, users, properties={
-            'addresses': relationship(
-                mapper(Address, addresses), lazy='select'),
-            'orders': relationship(Order, lazy='select')
-        })
+        mapper(
+            Order,
+            orders,
+            properties={
+                "items": relationship(
+                    Item, secondary=order_items, lazy="select"
+                )
+            },
+        )
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses), lazy="select"
+                ),
+                "orders": relationship(Order, lazy="select"),
+            },
+        )

         sess = create_session()
         q = sess.query(User)

         # use a union all to get a lot of rows to join against
-        u2 = users.alias('u2')
+        u2 = users.alias("u2")
         s = sa.union_all(
             u2.select(use_labels=True),
-            u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
-        result = q.filter(s.c.u2_id == User.id).order_by(User.id).distinct() \
-            .all()
+            u2.select(use_labels=True),
+            u2.select(use_labels=True),
+        ).alias("u")
+        result = (
+            q.filter(s.c.u2_id == User.id).order_by(User.id).distinct().all()
+        )
         eq_(self.static.user_all_result, result)

     def test_uselist_false_warning(self):
@@ -251,43 +318,55 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.User,
             self.tables.users,
             self.tables.orders,
-            self.classes.Order)
+            self.classes.Order,
+        )

-        mapper(User, users, properties={
-            'order': relationship(Order, uselist=False)
-        })
+        mapper(
+            User,
+            users,
+            properties={"order": relationship(Order, uselist=False)},
+        )
         mapper(Order, orders)
         s = create_session()
         u1 = s.query(User).filter(User.id == 7).one()
-        assert_raises(sa.exc.SAWarning, getattr, u1, 'order')
+        assert_raises(sa.exc.SAWarning, getattr, u1, "order")

     def test_callable_bind(self):
         Address, addresses, users, User = (
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
-
-        mapper(User, users, properties=dict(
-            addresses=relationship(
-                mapper(Address, addresses),
-                lazy='select',
-                primaryjoin=and_(
-                    users.c.id == addresses.c.user_id,
-                    users.c.name == bindparam("name", callable_=lambda: "ed")
+            self.classes.User,
+        )
+
+        mapper(
+            User,
+            users,
+            properties=dict(
+                addresses=relationship(
+                    mapper(Address, addresses),
+                    lazy="select",
+                    primaryjoin=and_(
+                        users.c.id == addresses.c.user_id,
+                        users.c.name
+                        == bindparam("name", callable_=lambda: "ed"),
+                    ),
                 )
-            )
-        ))
+            ),
+        )

         s = Session()
-        ed = s.query(User).filter_by(name='ed').one()
-        eq_(ed.addresses, [
-            Address(id=2, user_id=8),
-            Address(id=3, user_id=8),
-            Address(id=4, user_id=8)
-        ])
-
-        fred = s.query(User).filter_by(name='fred').one()
+        ed = s.query(User).filter_by(name="ed").one()
+        eq_(
+            ed.addresses,
+            [
+                Address(id=2, user_id=8),
+                Address(id=3, user_id=8),
+                Address(id=4, user_id=8),
+            ],
+        )
+
+        fred = s.query(User).filter_by(name="fred").one()
         eq_(fred.addresses, [])  # fred is missing

     def test_custom_bind(self):
@@ -295,18 +374,23 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
-
-        mapper(User, users, properties=dict(
-            addresses=relationship(
-                mapper(Address, addresses),
-                lazy='select',
-                primaryjoin=and_(
-                    users.c.id == addresses.c.user_id,
-                    users.c.name == bindparam("name")
+            self.classes.User,
+        )
+
+        mapper(
+            User,
+            users,
+            properties=dict(
+                addresses=relationship(
+                    mapper(Address, addresses),
+                    lazy="select",
+                    primaryjoin=and_(
+                        users.c.id == addresses.c.user_id,
+                        users.c.name == bindparam("name"),
+                    ),
                 )
-            )
-        ))
+            ),
+        )

         canary = mock.Mock()

@@ -322,23 +406,32 @@ class LazyTest(_fixtures.FixtureTest):
             query._params = query._params.union(dict(name=self.crit))

         s = Session()
-        ed = s.query(User).options(MyOption("ed")).filter_by(name='ed').one()
-        eq_(ed.addresses, [
-            Address(id=2, user_id=8),
-            Address(id=3, user_id=8),
-            Address(id=4, user_id=8)
-        ])
+        ed = s.query(User).options(MyOption("ed")).filter_by(name="ed").one()
+        eq_(
+            ed.addresses,
+            [
+                Address(id=2, user_id=8),
+                Address(id=3, user_id=8),
+                Address(id=4, user_id=8),
+            ],
+        )
         eq_(canary.mock_calls, [mock.call()])

-        fred = s.query(User).\
-            options(MyOption("ed")).filter_by(name='fred').one()
+        fred = (
+            s.query(User).options(MyOption("ed")).filter_by(name="fred").one()
+        )
         eq_(fred.addresses, [])  # fred is missing
         eq_(canary.mock_calls, [mock.call(), mock.call()])

         # the lazy query was not cached; the option is re-applied to the
         # Fred object due to populate_existing()
-        fred = s.query(User).populate_existing().\
-            options(MyOption("fred")).filter_by(name='fred').one()
+        fred = (
+            s.query(User)
+            .populate_existing()
+            .options(MyOption("fred"))
+            .filter_by(name="fred")
+            .one()
+        )
         eq_(fred.addresses, [Address(id=5, user_id=9)])  # fred is there
         eq_(canary.mock_calls, [mock.call(), mock.call(), mock.call()])

@@ -348,12 +441,18 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(User, users, properties=dict(
-            address=relationship(
-                mapper(Address, addresses), lazy='select', uselist=False)
-        ))
+        mapper(
+            User,
+            users,
+            properties=dict(
+                address=relationship(
+                    mapper(Address, addresses), lazy="select", uselist=False
+                )
+            ),
+        )
         q = create_session().query(User)
         result = q.filter(users.c.id == 7).all()
         assert [User(id=7, address=Address(id=1))] == result
@@ -363,18 +462,29 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
-
-        mapper(Address, addresses,
-               primary_key=[addresses.c.user_id, addresses.c.email_address])
-
-        mapper(User, users, properties=dict(
-            address=relationship(
-                Address, uselist=False,
-                primaryjoin=sa.and_(
-                    users.c.id == addresses.c.user_id,
-                    addresses.c.email_address == 'ed@bettyboop.com'))
-        ))
+            self.classes.User,
+        )
+
+        mapper(
+            Address,
+            addresses,
+            primary_key=[addresses.c.user_id, addresses.c.email_address],
+        )
+
+        mapper(
+            User,
+            users,
+            properties=dict(
+                address=relationship(
+                    Address,
+                    uselist=False,
+                    primaryjoin=sa.and_(
+                        users.c.id == addresses.c.user_id,
+                        addresses.c.email_address == "ed@bettyboop.com",
+                    ),
+                )
+            ),
+        )
         q = create_session().query(User)
         eq_(
             [
@@ -383,7 +493,7 @@ class LazyTest(_fixtures.FixtureTest):
                 User(id=9, address=None),
                 User(id=10, address=None),
             ],
-            list(q)
+            list(q),
         )

     def test_double(self):
@@ -396,10 +506,11 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.User,
             self.classes.Address,
             self.classes.Order,
-            self.tables.addresses)
+            self.tables.addresses,
+        )

-        openorders = sa.alias(orders, 'openorders')
-        closedorders = sa.alias(orders, 'closedorders')
+        openorders = sa.alias(orders, "openorders")
+        closedorders = sa.alias(orders, "closedorders")

         mapper(Address, addresses)

@@ -407,19 +518,29 @@ class LazyTest(_fixtures.FixtureTest):
         open_mapper = mapper(Order, openorders, non_primary=True)
         closed_mapper = mapper(Order, closedorders, non_primary=True)

-        mapper(User, users, properties=dict(
-            addresses=relationship(Address, lazy=True),
-            open_orders=relationship(
-                open_mapper,
-                primaryjoin=sa.and_(
-                    openorders.c.isopen == 1,
-                    users.c.id == openorders.c.user_id), lazy='select'),
-            closed_orders=relationship(
-                closed_mapper,
-                primaryjoin=sa.and_(
-                    closedorders.c.isopen == 0,
-                    users.c.id == closedorders.c.user_id), lazy='select')
-        ))
+        mapper(
+            User,
+            users,
+            properties=dict(
+                addresses=relationship(Address, lazy=True),
+                open_orders=relationship(
+                    open_mapper,
+                    primaryjoin=sa.and_(
+                        openorders.c.isopen == 1,
+                        users.c.id == openorders.c.user_id,
+                    ),
+                    lazy="select",
+                ),
+                closed_orders=relationship(
+                    closed_mapper,
+                    primaryjoin=sa.and_(
+                        closedorders.c.isopen == 0,
+                        users.c.id == closedorders.c.user_id,
+                    ),
+                    lazy="select",
+                ),
+            ),
+        )

         q = create_session().query(User)

         assert [
@@ -427,35 +548,38 @@ class LazyTest(_fixtures.FixtureTest):
                 id=7,
                 addresses=[Address(id=1)],
                 open_orders=[Order(id=3)],
-                closed_orders=[Order(id=1), Order(id=5)]
+                closed_orders=[Order(id=1), Order(id=5)],
             ),
             User(
                 id=8,
                 addresses=[Address(id=2), Address(id=3), Address(id=4)],
                 open_orders=[],
-                closed_orders=[]
+                closed_orders=[],
             ),
             User(
                 id=9,
                 addresses=[Address(id=5)],
                 open_orders=[Order(id=4)],
-                closed_orders=[Order(id=2)]
+                closed_orders=[Order(id=2)],
             ),
-            User(id=10)
-
+            User(id=10),
         ] == q.all()

         sess = create_session()
         user = sess.query(User).get(7)
         eq_(
             [Order(id=1), Order(id=5)],
-            create_session().query(closed_mapper).with_parent(
-                user, property='closed_orders').all()
+            create_session()
+            .query(closed_mapper)
+            .with_parent(user, property="closed_orders")
+            .all(),
         )
         eq_(
             [Order(id=3)],
-            create_session().query(open_mapper).
-            with_parent(user, property='open_orders').all()
+            create_session()
+            .query(open_mapper)
+            .with_parent(user, property="open_orders")
+            .all(),
         )

     def test_many_to_many(self):
@@ -464,20 +588,26 @@ class LazyTest(_fixtures.FixtureTest):
             self.tables.items,
             self.tables.item_keywords,
             self.classes.Keyword,
-            self.classes.Item)
+            self.classes.Item,
+        )

         mapper(Keyword, keywords)
-        mapper(Item, items, properties=dict(
-            keywords=relationship(
-                Keyword, secondary=item_keywords, lazy='select'),
-        ))
+        mapper(
+            Item,
+            items,
+            properties=dict(
+                keywords=relationship(
+                    Keyword, secondary=item_keywords, lazy="select"
+                )
+            ),
+        )

         q = create_session().query(Item)
         assert self.static.item_keyword_result == q.all()

         eq_(
             self.static.item_keyword_result[0:2],
-            q.join('keywords').filter(keywords.c.name == 'red').all()
+            q.join("keywords").filter(keywords.c.name == "red").all(),
         )

     def test_uses_get(self):
@@ -488,23 +618,32 @@ class LazyTest(_fixtures.FixtureTest):
             self.classes.Address,
             self.tables.addresses,
             self.tables.users,
-            self.classes.User)
+            self.classes.User,
+        )

         for pj in (
             None,
             users.c.id == addresses.c.user_id,
-            addresses.c.user_id == users.c.id
+            addresses.c.user_id == users.c.id,
         ):
-            mapper(Address, addresses, properties=dict(
-                user=relationship(
-                    mapper(User, users), lazy='select', primaryjoin=pj)
-            ))
+            mapper(
+                Address,
+                addresses,
+                properties=dict(
+                    user=relationship(
+                        mapper(User, users), lazy="select", primaryjoin=pj
+                    )
+                ),
+            )

             sess = create_session()

             # load address
-            a1 = sess.query(Address).\
-                filter_by(email_address="ed@wood.com").one()
+            a1 = (
+                sess.query(Address)
+                .filter_by(email_address="ed@wood.com")
+                .one()
+            )

             # load user that is attached to the address
             u1 = sess.query(User).get(8)
@@ -512,6 +651,7 @@ class LazyTest(_fixtures.FixtureTest):
             def go():
                 # lazy load of a1.user should get it from the session
                 assert a1.user is u1
+
             self.assert_sql_count(testing.db, go, 0)

             sa.orm.clear_mappers()
@@ -539,30 +679,43 @@ class LazyTest(_fixtures.FixtureTest):
         ]:
             m = sa.MetaData()
             users = Table(
-                'users', m,
+                "users",
+                m,
                 Column(
-                    'id', Integer, primary_key=True,
-                    test_needs_autoincrement=True),
-                Column('name', String(30), nullable=False),
+                    "id",
+                    Integer,
+                    primary_key=True,
+                    test_needs_autoincrement=True,
+                ),
+                Column("name", String(30), nullable=False),
             )
             addresses = Table(
-                'addresses', m,
+                "addresses",
+                m,
                 Column(
-                    'id', Integer, primary_key=True,
-                    test_needs_autoincrement=True),
-                Column('user_id', tt, ForeignKey('users.id')),
-                Column('email_address', String(50), nullable=False),
+                    "id",
+                    Integer,
+                    primary_key=True,
+                    test_needs_autoincrement=True,
+                ),
+                Column("user_id", tt, ForeignKey("users.id")),
+                Column("email_address", String(50), nullable=False),
             )

-            mapper(Address, addresses, properties=dict(
-                user=relationship(mapper(User, users))
-            ))
+            mapper(
+                Address,
+                addresses,
+                properties=dict(user=relationship(mapper(User, users))),
+            )

             sess = create_session(bind=testing.db)

             # load address
-            a1 = sess.query(Address).\
-                filter_by(email_address="ed@wood.com").one()
+            a1 = (
+                sess.query(Address)
+                .filter_by(email_address="ed@wood.com")
+                .one()
+            )

             # load user that is attached to the address
             u1 = sess.query(User).get(8)
@@ -570,6 +723,7 @@ class LazyTest(_fixtures.FixtureTest):
             def go():
                 # lazy load of a1.user should get it from the session
                 assert a1.user is u1
+
             self.assert_sql_count(testing.db, go, 0)

             sa.orm.clear_mappers()
@@ -578,11 +732,16 @@ class LazyTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(Address, addresses, properties=dict(
-            user=relationship(mapper(User, users), lazy='select')
-        ))
+        mapper(
+            Address,
+            addresses,
+            properties=dict(
+                user=relationship(mapper(User, users), lazy="select")
+            ),
+        )
         sess = create_session()
         q = sess.query(Address)
         a = q.filter(addresses.c.id == 1).one()
@@ -598,11 +757,14 @@ class LazyTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

-        mapper(User, users, properties={
-            'addresses': relationship(Address, backref='user')
-        })
+        mapper(
+            User,
+            users,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
         mapper(Address, addresses)
         sess = create_session()
         ad = sess.query(Address).filter_by(id=1).one()
@@ -611,30 +773,35 @@ class LazyTest(_fixtures.FixtureTest):
         def go():
             ad.user = None
             assert ad.user is None
+
         self.assert_sql_count(testing.db, go, 0)

         u1 = sess.query(User).filter_by(id=7).one()

         def go():
             assert ad not in u1.addresses
+
         self.assert_sql_count(testing.db, go, 1)

-        sess.expire(u1, ['addresses'])
+        sess.expire(u1, ["addresses"])

         def go():
             assert ad in u1.addresses
+
         self.assert_sql_count(testing.db, go, 1)

-        sess.expire(u1, ['addresses'])
+        sess.expire(u1, ["addresses"])
         ad2 = Address()

         def go():
             ad2.user = u1
             assert ad2.user is u1
+
         self.assert_sql_count(testing.db, go, 0)

         def go():
             assert ad2 in u1.addresses
+
         self.assert_sql_count(testing.db, go, 1)

@@ -652,21 +819,23 @@ class GetterStateTest(_fixtures.FixtureTest):
            def process_bind_param(self, value, dialect):
                 return ";".join(
                     "%s=%s" % (k, v)
-                    for k, v in
-                    sorted(value.items(), key=lambda key: key[0]))
+                    for k, v in sorted(value.items(), key=lambda key: key[0])
+                )

             def process_result_value(self, value, dialect):
                 return dict(elem.split("=", 1) for elem in value.split(";"))

         category = Table(
-            'category', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', MyHashType())
+            "category",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", MyHashType()),
         )
         article = Table(
-            'article', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', MyHashType())
+            "article",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", MyHashType()),
         )

         class Category(fixtures.ComparableEntity):
@@ -676,13 +845,17 @@ class GetterStateTest(_fixtures.FixtureTest):
             pass

         mapper(Category, category)
-        mapper(Article, article, properties={
-            "category": relationship(
-                Category,
-                primaryjoin=orm.foreign(article.c.data) == category.c.data,
-                load_on_pending=load_on_pending
-            )
-        })
+        mapper(
+            Article,
+            article,
+            properties={
+                "category": relationship(
+                    Category,
+                    primaryjoin=orm.foreign(article.c.data) == category.c.data,
+                    load_on_pending=load_on_pending,
+                )
+            },
+        )

         metadata.create_all()
         sess = Session(autoflush=False)
@@ -703,26 +876,37 @@ class GetterStateTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
-
-        mapper(User, users, properties={
-            'addresses': relationship(Address, back_populates='user')
-        })
-        mapper(Address, addresses, properties={
-            'user': relationship(
-                User,
-                primaryjoin=and_(
-                    users.c.id == addresses.c.user_id, users.c.id != 27)
-                if dont_use_get else None,
-                back_populates='addresses'
-            )
-        })
+            self.classes.User,
+        )
+
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(Address, back_populates="user")
+            },
+        )
+        mapper(
+            Address,
+            addresses,
+            properties={
+                "user": relationship(
+                    User,
+                    primaryjoin=and_(
+                        users.c.id == addresses.c.user_id, users.c.id != 27
+                    )
+                    if dont_use_get
+                    else None,
+                    back_populates="addresses",
+                )
+            },
+        )

         sess = create_session()
-        a1 = Address(email_address='a1')
+        a1 = Address(email_address="a1")
         sess.add(a1)
         if populate_user:
-            a1.user = User(name='ed')
+            a1.user = User(name="ed")
         sess.flush()
         if populate_user:
             sess.expire_all()
@@ -735,39 +919,29 @@ class GetterStateTest(_fixtures.FixtureTest):
             eq_(a1.user, None)

         # doesn't emit SQL
-        self.assert_sql_count(
-            testing.db,
-            go,
-            0
-        )
+        self.assert_sql_count(testing.db, go, 0)

     @testing.provide_metadata
     def test_no_use_get_params_not_hashable(self):
-        Category, Article, sess, a1, c1 = \
-            self._unhashable_fixture(self.metadata)
+        Category, Article, sess, a1, c1 = self._unhashable_fixture(
+            self.metadata
+        )

         def go():
             eq_(a1.category, c1)

-        self.assert_sql_count(
-            testing.db,
-            go,
-            1
-        )
+        self.assert_sql_count(testing.db, go, 1)

     @testing.provide_metadata
     def test_no_use_get_params_not_hashable_on_pending(self):
-        Category, Article, sess, a1, c1 = \
-            self._unhashable_fixture(self.metadata, load_on_pending=True)
+        Category, Article, sess, a1, c1 = self._unhashable_fixture(
+            self.metadata, load_on_pending=True
+        )

         def go():
             eq_(a1.category, c1)

-        self.assert_sql_count(
-            testing.db,
-            go,
-            1
-        )
+        self.assert_sql_count(testing.db, go, 1)

     def test_get_empty_passive_return_never_set(self):
         User, Address, sess, a1 = self._u_ad_fixture(False)
@@ -775,11 +949,12 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_RETURN_NEVER_SET),
-            attributes.NEVER_SET
+                passive=attributes.PASSIVE_RETURN_NEVER_SET,
+            ),
+            attributes.NEVER_SET,
         )
-        assert 'user_id' not in a1.__dict__
-        assert 'user' not in a1.__dict__
+        assert "user_id" not in a1.__dict__
+        assert "user" not in a1.__dict__

     def test_history_empty_passive_return_never_set(self):
         User, Address, sess, a1 = self._u_ad_fixture(False)
@@ -787,11 +962,12 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get_history(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_RETURN_NEVER_SET),
-            ((), (), ())
+                passive=attributes.PASSIVE_RETURN_NEVER_SET,
+            ),
+            ((), (), ()),
         )
-        assert 'user_id' not in a1.__dict__
-        assert 'user' not in a1.__dict__
+        assert "user_id" not in a1.__dict__
+        assert "user" not in a1.__dict__

     def test_get_empty_passive_no_initialize(self):
         User, Address, sess, a1 = self._u_ad_fixture(False)
@@ -799,11 +975,12 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_NO_INITIALIZE),
-            attributes.PASSIVE_NO_RESULT
+                passive=attributes.PASSIVE_NO_INITIALIZE,
+            ),
+            attributes.PASSIVE_NO_RESULT,
         )
-        assert 'user_id' not in a1.__dict__
-        assert 'user' not in a1.__dict__
+        assert "user_id" not in a1.__dict__
+        assert "user" not in a1.__dict__

     def test_history_empty_passive_no_initialize(self):
         User, Address, sess, a1 = self._u_ad_fixture(False)
@@ -811,11 +988,12 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get_history(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_NO_INITIALIZE),
-            attributes.HISTORY_BLANK
+                passive=attributes.PASSIVE_NO_INITIALIZE,
+            ),
+            attributes.HISTORY_BLANK,
         )
-        assert 'user_id' not in a1.__dict__
-        assert 'user' not in a1.__dict__
+        assert "user_id" not in a1.__dict__
+        assert "user" not in a1.__dict__

     def test_get_populated_passive_no_initialize(self):
         User, Address, sess, a1 = self._u_ad_fixture(True)
@@ -823,11 +1001,12 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_NO_INITIALIZE),
-            attributes.PASSIVE_NO_RESULT
+                passive=attributes.PASSIVE_NO_INITIALIZE,
+            ),
+            attributes.PASSIVE_NO_RESULT,
         )
-        assert 'user_id' not in a1.__dict__
-        assert 'user' not in a1.__dict__
+        assert "user_id" not in a1.__dict__
+        assert "user" not in a1.__dict__

     def test_history_populated_passive_no_initialize(self):
         User, Address, sess, a1 = self._u_ad_fixture(True)
@@ -835,11 +1014,12 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get_history(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_NO_INITIALIZE),
-            attributes.HISTORY_BLANK
+                passive=attributes.PASSIVE_NO_INITIALIZE,
+            ),
+            attributes.HISTORY_BLANK,
        )
-        assert 'user_id' not in a1.__dict__
-        assert 'user' not in a1.__dict__
+        assert "user_id" not in a1.__dict__
+        assert "user" not in a1.__dict__

     def test_get_populated_passive_return_never_set(self):
         User, Address, sess, a1 = self._u_ad_fixture(True)
@@ -847,8 +1027,9 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_RETURN_NEVER_SET),
-            User(name='ed')
+                passive=attributes.PASSIVE_RETURN_NEVER_SET,
+            ),
+            User(name="ed"),
         )

     def test_history_populated_passive_return_never_set(self):
@@ -857,13 +1038,14 @@ class GetterStateTest(_fixtures.FixtureTest):
             Address.user.impl.get_history(
                 attributes.instance_state(a1),
                 attributes.instance_dict(a1),
-                passive=attributes.PASSIVE_RETURN_NEVER_SET),
-            ((), [User(name='ed'), ], ())
+                passive=attributes.PASSIVE_RETURN_NEVER_SET,
+            ),
+            ((), [User(name="ed")], ()),
         )


 class M2OGetTest(_fixtures.FixtureTest):
-    run_inserts = 'once'
+    run_inserts = "once"
     run_deletes = None

     def test_m2o_noload(self):
@@ -873,16 +1055,15 @@ class M2OGetTest(_fixtures.FixtureTest):
             self.tables.users,
             self.classes.Address,
             self.tables.addresses,
-            self.classes.User)
+            self.classes.User,
+        )

         mapper(User, users)
-        mapper(Address, addresses, properties={
-            'user': relationship(User)
-        })
+        mapper(Address, addresses, properties={"user": relationship(User)})

         sess = create_session()

-        ad1 = Address(email_address='somenewaddress', id=12)
+        ad1 = Address(email_address="somenewaddress", id=12)
         sess.add(ad1)
         sess.flush()
         sess.expunge_all()
@@ -892,40 +1073,48 @@ class M2OGetTest(_fixtures.FixtureTest):

         def go():
             # one lazy load
-            assert ad2.user.name == 'jack'
+            assert ad2.user.name == "jack"
             # no lazy load
             assert ad3.user is None
+
         self.assert_sql_count(testing.db, go, 1)


 class CorrelatedTest(fixtures.MappedTest):
-
     @classmethod
     def define_tables(self, meta):
-        Table('user_t', meta,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50)))
+        Table(
+            "user_t",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )

-        Table('stuff', meta,
-              Column('id', Integer, primary_key=True),
-              Column('date', sa.Date),
-              Column('user_id', Integer, ForeignKey('user_t.id')))
+        Table(
+            "stuff",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("date", sa.Date),
+            Column("user_id", Integer, ForeignKey("user_t.id")),
+        )

     @classmethod
     def insert_data(cls):
         stuff, user_t = cls.tables.stuff, cls.tables.user_t

         user_t.insert().execute(
-            {'id': 1, 'name': 'user1'},
-            {'id': 2, 'name': 'user2'},
-            {'id': 3, 'name': 'user3'})
+            {"id": 1, "name": "user1"},
+            {"id": 2, "name": "user2"},
+            {"id": 3, "name": "user3"},
+        )

         stuff.insert().execute(
-            {'id': 1, 'user_id': 1, 'date': datetime.date(2007, 10, 15)},
-            {'id': 2, 'user_id': 1, 'date': datetime.date(2007, 12, 15)},
-            {'id': 3, 'user_id': 1, 'date': datetime.date(2007, 11, 15)},
-            {'id': 4, 'user_id': 2, 'date': datetime.date(2008, 1, 15)},
-            {'id': 5, 'user_id': 3, 'date': datetime.date(2007, 6, 15)})
+            {"id": 1, "user_id": 1, "date": datetime.date(2007, 10, 15)},
+            {"id": 2, "user_id": 1, "date": datetime.date(2007, 12, 15)},
+            {"id": 3, "user_id": 1, "date": datetime.date(2007, 11, 15)},
+            {"id": 4, "user_id": 2, "date": datetime.date(2008, 1, 15)},
+            {"id": 5, "user_id": 3, "date": datetime.date(2007, 6, 15)},
+        )

     def test_correlated_lazyload(self):
         stuff, user_t = self.tables.stuff, self.tables.user_t
@@ -938,17 +1127,27 @@ class CorrelatedTest(fixtures.MappedTest):

         mapper(Stuff, stuff)

-        stuff_view = sa.select([stuff.c.id]).\
-            where(stuff.c.user_id == user_t.c.id).correlate(user_t).\
-            order_by(sa.desc(stuff.c.date)).limit(1)
+        stuff_view = (
+            sa.select([stuff.c.id])
+            .where(stuff.c.user_id == user_t.c.id)
+            .correlate(user_t)
+            .order_by(sa.desc(stuff.c.date))
+            .limit(1)
+        )

-        mapper(User, user_t, properties={
-            'stuff': relationship(
-                Stuff,
-                primaryjoin=sa.and_(
-                    user_t.c.id == stuff.c.user_id,
-                    stuff.c.id == (stuff_view.as_scalar())))
-        })
+        mapper(
+            User,
+            user_t,
+            properties={
+                "stuff": relationship(
+                    Stuff,
+                    primaryjoin=sa.and_(
+                        user_t.c.id == stuff.c.user_id,
+                        stuff.c.id == (stuff_view.as_scalar()),
+                    ),
+                )
+            },
+        )

         sess = create_session()

@@ -956,15 +1155,18 @@ class CorrelatedTest(fixtures.MappedTest):
             sess.query(User).all(),
             [
                 User(
-                    name='user1',
-                    stuff=[Stuff(date=datetime.date(2007, 12, 15), id=2)]),
+                    name="user1",
+                    stuff=[Stuff(date=datetime.date(2007, 12, 15), id=2)],
+                ),
                 User(
-                    name='user2',
-                    stuff=[Stuff(id=4, date=datetime.date(2008, 1, 15))]),
+                    name="user2",
+                    stuff=[Stuff(id=4, date=datetime.date(2008, 1, 15))],
+                ),
                 User(
-                    name='user3',
-                    stuff=[Stuff(id=5, date=datetime.date(2007, 6, 15))])
-            ]
+                    name="user3",
+                    stuff=[Stuff(id=5, date=datetime.date(2007, 6, 15))],
+                ),
+            ],
         )


@@ -974,14 +1176,18 @@ class O2MWOSideFixedTest(fixtures.MappedTest):

     @classmethod
     def define_tables(self, meta):
-        Table('city', meta,
-              Column('id', Integer, primary_key=True),
-              Column('deleted', Boolean),
-              )
-        Table('person', meta,
-              Column('id', Integer, primary_key=True),
-              Column('city_id', ForeignKey('city.id'))
-              )
+        Table(
+            "city",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("deleted", Boolean),
+        )
+        Table(
+            "person",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("city_id", ForeignKey("city.id")),
+        )

     @classmethod
     def setup_classes(cls):
@@ -996,35 +1202,35 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
         Person, City = cls.classes.Person, cls.classes.City
         city, person = cls.tables.city, cls.tables.person

-        mapper(Person, person, properties={
-            'city': relationship(City,
-                                 primaryjoin=and_(
-                                     person.c.city_id == city.c.id,
-                                     city.c.deleted == False),  # noqa
-                                 backref='people')
-        })
+        mapper(
+            Person,
+            person,
+            properties={
+                "city": relationship(
+                    City,
+                    primaryjoin=and_(
+                        person.c.city_id == city.c.id, city.c.deleted == False
+                    ),  # noqa
+                    backref="people",
+                )
+            },
+        )
         mapper(City, city)

     def _fixture(self, include_other):
         city, person = self.tables.city, self.tables.person

         if include_other:
-            city.insert().execute(
-                {"id": 1, "deleted": False},
-            )
+            city.insert().execute({"id": 1, "deleted": False})

             person.insert().execute(
-                {"id": 1, "city_id": 1},
-                {"id": 2, "city_id": 1},
+                {"id": 1, "city_id": 1}, {"id": 2, "city_id": 1}
             )

-        city.insert().execute(
-            {"id": 2, "deleted": True},
-        )
+        city.insert().execute({"id": 2, "deleted": True})
         person.insert().execute(
-            {"id": 3, "city_id": 2},
-            {"id": 4, "city_id": 2},
+            {"id": 3, "city_id": 2}, {"id": 4, "city_id": 2}
         )

     def test_lazyload_assert_expected_sql(self):
@@ -1034,10 +1240,7 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
         c1, c2 = sess.query(City).order_by(City.id).all()

         def go():
-            eq_(
-                [p.id for p in c2.people],
-                []
-            )
+            eq_([p.id for p in c2.people], [])

         self.assert_sql_execution(
             testing.db,
@@ -1046,8 +1249,8 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
                 "SELECT person.id AS person_id, person.city_id AS "
                 "person_city_id FROM person "
                 "WHERE person.city_id = :param_1 AND :param_2 = 0",
-                {"param_1": 2, "param_2": 1}
-            )
+                {"param_1": 2, "param_2": 1},
+            ),
         )

     def test_lazyload_people_other_exists(self):
@@ -1055,15 +1258,9 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
         City = self.classes.City
         sess = Session(testing.db)
         c1, c2 = sess.query(City).order_by(City.id).all()
-        eq_(
-            [p.id for p in c1.people],
-            [1, 2]
-        )
+        eq_([p.id for p in c1.people], [1, 2])

-        eq_(
-            [p.id for p in c2.people],
-            []
-        )
+        eq_([p.id for p in c2.people], [])

     def test_lazyload_people_no_other_exists(self):
         # note that if we revert #2948, *this still passes!*
@@ -1075,10 +1272,7 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
         sess = Session(testing.db)
         c2, = sess.query(City).order_by(City.id).all()

-        eq_(
-            [p.id for p in c2.people],
-            []
-        )
+        eq_([p.id for p in c2.people], [])


 class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest):
@@ -1094,21 +1288,24 @@ class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table(
-            'a', metadata,
-            Column('a_id', Integer, primary_key=True),
-            Column('b_id', ForeignKey('b.b_id')),
+            "a",
+            metadata,
+            Column("a_id", Integer, primary_key=True),
+            Column("b_id", ForeignKey("b.b_id")),
         )
         Table(
-            'b', metadata,
-            Column('b_id', Integer, primary_key=True),
-            Column('parent_id', ForeignKey('b.b_id')),
+            "b",
+            metadata,
+            Column("b_id", Integer, primary_key=True),
+            Column("parent_id", ForeignKey("b.b_id")),
         )
         Table(
-            'c', metadata,
-            Column('c_id', Integer, primary_key=True),
-            Column('b_id', ForeignKey('b.b_id')),
+            "c",
+            metadata,
+            Column("c_id", Integer, primary_key=True),
+            Column("b_id", ForeignKey("b.b_id")),
         )

     @classmethod
@@ -1124,14 +1321,21 @@ class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest):
     @classmethod
     def setup_mappers(cls):
-        mapper(cls.classes.A, cls.tables.a, properties={
-            "b": relationship(cls.classes.B)
-        })
-        bm = mapper(cls.classes.B, cls.tables.b, properties={
-            "parent": relationship(
-                cls.classes.B, remote_side=cls.tables.b.c.b_id),
-            "zc": relationship(cls.classes.C)
-        })
+        mapper(
+            cls.classes.A,
+            cls.tables.a,
+            properties={"b": relationship(cls.classes.B)},
+        )
+        bm = mapper(
+            cls.classes.B,
+            cls.tables.b,
+            properties={
+                "parent": relationship(
+                    cls.classes.B, remote_side=cls.tables.b.c.b_id
+                ),
+                "zc": relationship(cls.classes.C),
+            },
+        )
         mapper(cls.classes.C, cls.tables.c)

         bmp = bm._props
@@ -1155,14 +1359,15 @@ class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest):

         # If the bug is here, the next line throws an exception
         session.query(B).options(
-            sa.orm.joinedload('parent').joinedload('zc')).all()
+            sa.orm.joinedload("parent").joinedload("zc")
+        ).all()


-class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults,):
+class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults):
     """ORM-level test for [ticket:3531]"""

     # mysql is having a recursion issue in the bind_expression
-    __only_on__ = ('sqlite', 'postgresql')
+    __only_on__ = ("sqlite", "postgresql")

     class StringAsInt(TypeDecorator):
         impl = String(50)
@@ -1176,11 +1381,11 @@ class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults,):
     @classmethod
     def define_tables(cls, metadata):
         Table(
-            'person', metadata,
-            Column("id", cls.StringAsInt, primary_key=True),
+            "person", metadata, Column("id", cls.StringAsInt, primary_key=True)
         )
         Table(
-            "pets", metadata,
+            "pets",
+            metadata,
             Column("id", Integer, primary_key=True),
             Column("person_id", Integer),
         )
@@ -1195,17 +1400,22 @@ class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults,):

     @classmethod
     def setup_mappers(cls):
-        mapper(cls.classes.Person, cls.tables.person, properties=dict(
-            pets=relationship(
-                cls.classes.Pet, primaryjoin=(
-                    orm.foreign(cls.tables.pets.c.person_id) ==
-                    sa.cast(
-                        sa.type_coerce(cls.tables.person.c.id, Integer),
-                        Integer
-                    )
+        mapper(
+            cls.classes.Person,
+            cls.tables.person,
+            properties=dict(
+                pets=relationship(
+                    cls.classes.Pet,
+                    primaryjoin=(
+                        orm.foreign(cls.tables.pets.c.person_id)
+                        == sa.cast(
+                            sa.type_coerce(cls.tables.person.c.id, Integer),
+                            Integer,
+                        )
+                    ),
                 )
-            )
-        ))
+            ),
+        )

         mapper(cls.classes.Pet, cls.tables.pets)

@@ -1214,9 +1424,7 @@ class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults,):
         Pet = self.classes.Pet

         s = Session()
-        s.add_all([
-            Person(id=5), Pet(id=1, person_id=5)
-        ])
+        s.add_all([Person(id=5), Pet(id=1, person_id=5)])
         s.commit()

         p1 = s.query(Person).first()
@@ -1229,7 +1437,7 @@ class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults,):
                 "SELECT pets.id AS pets_id, pets.person_id "
                 "AS pets_person_id FROM pets "
                 "WHERE pets.person_id = CAST(:param_1 AS INTEGER)",
-                [{'param_1': 5}]
+                [{"param_1": 5}],
             )
         )


@@ -1240,25 +1448,28 @@ class CompositeSimpleM2OTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table(
-            'a', metadata,
+            "a",
+            metadata,
             Column("id1", Integer, primary_key=True),
             Column("id2", Integer, primary_key=True),
         )

         Table(
-            "b_sameorder", metadata,
+            "b_sameorder",
+            metadata,
             Column("id", Integer, primary_key=True),
-            Column('a_id1', Integer),
-            Column('a_id2', Integer),
-            ForeignKeyConstraint(['a_id1', 'a_id2'], ['a.id1', 'a.id2'])
+            Column("a_id1", Integer),
+            Column("a_id2", Integer),
+            ForeignKeyConstraint(["a_id1", "a_id2"], ["a.id1", "a.id2"]),
        )

         Table(
-            "b_differentorder", metadata,
+            "b_differentorder",
+            metadata,
             Column("id", Integer, primary_key=True),
-            Column('a_id1', Integer),
-            Column('a_id2', Integer),
-            ForeignKeyConstraint(['a_id1', 'a_id2'], ['a.id1', 'a.id2'])
+            Column("a_id1", Integer),
+            Column("a_id2", Integer),
+            ForeignKeyConstraint(["a_id1", "a_id2"], ["a.id1", "a.id2"]),
         )

     @classmethod
@@ -1271,30 +1482,41 @@ class CompositeSimpleM2OTest(fixtures.MappedTest):
     def test_use_get_sameorder(self):
         mapper(self.classes.A, self.tables.a)
-        m_b = mapper(self.classes.B, self.tables.b_sameorder, properties={
-            'a': relationship(self.classes.A)
-        })
+        m_b = mapper(
+            self.classes.B,
+            self.tables.b_sameorder,
+            properties={"a": relationship(self.classes.A)},
+        )

         configure_mappers()
         is_true(m_b.relationships.a.strategy.use_get)

     def test_use_get_reverseorder(self):
         mapper(self.classes.A, self.tables.a)
-        m_b = mapper(self.classes.B, self.tables.b_differentorder, properties={
-            'a': relationship(self.classes.A)
-        })
+        m_b = mapper(
+            self.classes.B,
+            self.tables.b_differentorder,
+            properties={"a": relationship(self.classes.A)},
+        )

         configure_mappers()
         is_true(m_b.relationships.a.strategy.use_get)

     def test_dont_use_get_pj_is_different(self):
         mapper(self.classes.A, self.tables.a)
-        m_b = mapper(self.classes.B, self.tables.b_sameorder, properties={
-            'a': relationship(self.classes.A, primaryjoin=and_(
-                self.tables.a.c.id1 == self.tables.b_sameorder.c.a_id1,
-                self.tables.a.c.id2 == 12
-            ))
-        })
+        m_b = mapper(
+            self.classes.B,
+            self.tables.b_sameorder,
+            properties={
+                "a": relationship(
+                    self.classes.A,
+                    primaryjoin=and_(
+                        self.tables.a.c.id1 == self.tables.b_sameorder.c.a_id1,
+                        self.tables.a.c.id2 == 12,
+                    ),
+                )
+            },
+        )

         configure_mappers()
         is_false(m_b.relationships.a.strategy.use_get)
diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py
index e9ffd640b4..78a22ac54c 100644
--- a/test/orm/test_load_on_fks.py
+++ b/test/orm/test_load_on_fks.py
@@ -18,18 +18,20 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
         Base = declarative_base()

         class Parent(Base):
-            __tablename__ = 'parent'
+            __tablename__ = "parent"

-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
             name = Column(String(50), nullable=False)
             children = relationship("Child", load_on_pending=True)

         class Child(Base):
-            __tablename__ = 'child'
-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
-            parent_id = Column(Integer, ForeignKey('parent.id'))
+            __tablename__ = "child"
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
+            parent_id = Column(Integer, ForeignKey("parent.id"))

         Base.metadata.create_all(engine)

@@ -58,29 +60,31 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):

         def go():
             assert p1.children == []
+
         self.assert_sql_count(testing.db, go, 0)


 class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
-
     def setUp(self):
         global Parent, Child, Base

         Base = declarative_base()

         class Parent(Base):
-            __tablename__ = 'parent'
-            __table_args__ = {'mysql_engine': 'InnoDB'}
+            __tablename__ = "parent"
+            __table_args__ = {"mysql_engine": "InnoDB"}

-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )

         class Child(Base):
-            __tablename__ = 'child'
-            __table_args__ = {'mysql_engine': 'InnoDB'}
+            __tablename__ = "child"
+            __table_args__ = {"mysql_engine": "InnoDB"}

-            id = Column(Integer, primary_key=True,
-                        test_needs_autoincrement=True)
-            parent_id = Column(Integer, ForeignKey('parent.id'))
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
+            parent_id = Column(Integer, ForeignKey("parent.id"))

             parent = relationship(Parent, backref=backref("children"))

@@ -193,6 +197,7 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):

         def go():
             assert p2.children
+
         self.assert_sql_count(testing.db, go, 1)

     def test_collection_load_from_pending_no_sql(self):
@@ -204,6 +209,7 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):

         def go():
             assert not p2.children
+
         self.assert_sql_count(testing.db, go, 0)

     def test_load_on_pending_with_set(self):
@@ -218,6 +224,7 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):

         def go():
             c3.parent = p1
+
         self.assert_sql_count(testing.db, go, 0)

     def test_backref_doesnt_double(self):
@@ -260,7 +267,7 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
                 # auto-expire of 'parent' when c1.parent_id
                 # is altered.
                 if fake_autoexpire:
-                    sess.expire(c1, ['parent'])
+                    sess.expire(c1, ["parent"])

                 # old 0.6 behavior
                 # if manualflush and (not loadrel or
@@ -315,8 +322,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
             for autoflush in (False, True):
                 for manualflush in (False, True):
                     for enable_relationship_rel in (False, True):
-                        Child.parent.property.load_on_pending = \
+                        Child.parent.property.load_on_pending = (
                             loadonpending
+                        )
                         sess.autoflush = autoflush
                         c2 = Child()
@@ -332,8 +340,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
                         if manualflush:
                             sess.flush()

-                        if (loadonpending and attach) \
-                                or enable_relationship_rel:
+                        if (
+                            loadonpending and attach
+                        ) or enable_relationship_rel:
                             assert c2.parent is p2
                         else:
                             assert c2.parent is None
diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py
index aa46bb22c9..bad4092ce2 100644
--- a/test/orm/test_loading.py
+++ b/test/orm/test_loading.py
@@ -1,19 +1,23 @@
 from . import _fixtures
 from sqlalchemy.orm import loading, Session, aliased
-from sqlalchemy.testing.assertions import eq_, \
-    assert_raises, assert_raises_message
+from sqlalchemy.testing.assertions import (
+    eq_,
+    assert_raises,
+    assert_raises_message,
+)
 from sqlalchemy.util import KeyedTuple
 from sqlalchemy.testing import mock
 from sqlalchemy import select
 from sqlalchemy import exc

+
 # class GetFromIdentityTest(_fixtures.FixtureTest):
 # class LoadOnIdentTest(_fixtures.FixtureTest):
 # class InstanceProcessorTest(_fixture.FixtureTest):


 class InstancesTest(_fixtures.FixtureTest):
-    run_setup_mappers = 'once'
-    run_inserts = 'once'
+    run_setup_mappers = "once"
+    run_inserts = "once"
     run_deletes = None

     @classmethod
@@ -31,10 +35,7 @@ class InstancesTest(_fixtures.FixtureTest):
         q._entities = [
             mock.Mock(row_processor=mock.Mock(side_effect=Exception("boom")))
         ]
-        assert_raises(
-            Exception,
-            list, loading.instances(q, cursor, ctx)
-        )
+        assert_raises(Exception, list, loading.instances(q, cursor, ctx))
         assert cursor.close.called, "Cursor wasn't closed"

     def test_row_proc_not_created(self):
@@ -47,13 +48,13 @@ class InstancesTest(_fixtures.FixtureTest):
         assert_raises_message(
             exc.NoSuchColumnError,
             "Could not locate column in row for column 'users.name'",
-            q.from_statement(stmt).all
+            q.from_statement(stmt).all,
         )


 class MergeResultTest(_fixtures.FixtureTest):
-    run_setup_mappers = 'once'
-    run_inserts = 'once'
+    run_setup_mappers = "once"
+    run_inserts = "once"
     run_deletes = None

     @classmethod
@@ -64,8 +65,12 @@ class MergeResultTest(_fixtures.FixtureTest):
         User = self.classes.User
         s = Session()
-        u1, u2, u3, u4 = User(id=1, name='u1'), User(id=2, name='u2'), \
-            User(id=7, name='u3'), User(id=8, name='u4')
+        u1, u2, u3, u4 = (
+            User(id=1, name="u1"),
+            User(id=2, name="u2"),
+            User(id=7, name="u3"),
+            User(id=8, name="u4"),
+        )

         s.query(User).filter(User.id.in_([7, 8])).all()
         s.close()
         return s, [u1, u2, u3, u4]
@@ -76,14 +81,8 @@ class MergeResultTest(_fixtures.FixtureTest):
         q = s.query(User)

         collection = [u1, u2, u3, u4]
-        it = loading.merge_result(
-            q,
-            collection
-        )
-        eq_(
-            [x.id for x in it],
-            [1, 2, 7, 8]
-        )
+        it = loading.merge_result(q, collection)
+        eq_([x.id for x in it], [1, 2, 7, 8])

     def test_single_column(self):
         User = self.classes.User
@@ -91,15 +90,9 @@ class MergeResultTest(_fixtures.FixtureTest):
         s = Session()
         q = s.query(User.id)

-        collection = [(1, ), (2, ), (7, ), (8, )]
-        it = loading.merge_result(
-            q,
-            collection
-        )
-        eq_(
-            list(it),
-            [(1, ), (2, ), (7, ), (8, )]
-        )
+        collection = [(1,), (2,), (7,), (8,)]
+        it = loading.merge_result(q, collection)
+        eq_(list(it), [(1,), (2,), (7,), (8,)])

     def test_entity_col_mix_plain_tuple(self):
         s, (u1, u2, u3, u4) = self._fixture()
@@ -107,16 +100,10 @@ class MergeResultTest(_fixtures.FixtureTest):
         q = s.query(User, User.id)

         collection = [(u1, 1), (u2, 2), (u3, 7), (u4, 8)]
-        it = loading.merge_result(
-            q,
-            collection
-        )
+        it = loading.merge_result(q, collection)
         it = list(it)
-        eq_(
-            [(x.id, y) for x, y in it],
-            [(1, 1), (2, 2), (7, 7), (8, 8)]
-        )
-        eq_(list(it[0].keys()), ['User', 'id'])
+        eq_([(x.id, y) for x, y in it], [(1, 1), (2, 2), (7, 7), (8, 8)])
+        eq_(list(it[0].keys()), ["User", "id"])

     def test_entity_col_mix_keyed_tuple(self):
         s, (u1, u2, u3, u4) = self._fixture()
@@ -125,19 +112,13 @@ class MergeResultTest(_fixtures.FixtureTest):
         q = s.query(User, User.id)

         def kt(*x):
-            return KeyedTuple(x, ['User', 'id'])
+            return KeyedTuple(x, ["User", "id"])

         collection = [kt(u1, 1), kt(u2, 2), kt(u3, 7), kt(u4, 8)]
-        it = loading.merge_result(
-            q,
-            collection
-        )
+        it = loading.merge_result(q, collection)
         it = list(it)
-        eq_(
-            [(x.id, y) for x, y in it],
-            [(1, 1), (2, 2), (7, 7), (8, 8)]
-        )
-        eq_(list(it[0].keys()), ['User', 'id'])
+        eq_([(x.id, y) for x, y in it], [(1, 1), (2, 2), (7, 7), (8, 8)])
+        eq_(list(it[0].keys()), ["User", "id"])

     def test_none_entity(self):
         s, (u1, u2, u3, u4) = self._fixture()
@@ -147,17 +128,11 @@ class MergeResultTest(_fixtures.FixtureTest):
         q = s.query(User, ua)

         def kt(*x):
-            return KeyedTuple(x, ['User', 'useralias'])
+            return KeyedTuple(x, ["User", "useralias"])

         collection = [kt(u1, u2), kt(u1, None), kt(u2, u3)]
-        it = loading.merge_result(
-            q,
-            collection
-        )
+        it = loading.merge_result(q, collection)
         eq_(
-            [
-                (x and x.id or None, y and y.id or None)
-                for x, y in it
-            ],
-            [(u1.id, u2.id), (u1.id, None), (u2.id, u3.id)]
+            [(x and x.id or None, y and y.id or None) for x, y in it],
+            [(u1.id, u2.id), (u1.id, None), (u2.id, u3.id)],
         )
diff --git a/test/orm/test_lockmode.py b/test/orm/test_lockmode.py
index 34ae52d3c7..1a05ffbe4d 100644
--- a/test/orm/test_lockmode.py
+++ b/test/orm/test_lockmode.py
@@ -54,7 +54,8 @@ class LegacyLockModeTest(_fixtures.FixtureTest):
         assert_raises_message(
             exc.ArgumentError,
             "Unknown with_lockmode argument: 'unknown_mode'",
-            sess.query(User.id).with_lockmode, 'unknown_mode'
+            sess.query(User.id).with_lockmode,
+            "unknown_mode",
         )


@@ -64,12 +65,20 @@ class ForUpdateTest(_fixtures.FixtureTest):
         User, users = cls.classes.User, cls.tables.users
         mapper(User, users)

-    def _assert(self, read=False, nowait=False, of=None, key_share=None,
-                assert_q_of=None, assert_sel_of=None):
+    def _assert(
+        self,
+        read=False,
+        nowait=False,
+        of=None,
+        key_share=None,
+        assert_q_of=None,
+        assert_sel_of=None,
+    ):
         User = self.classes.User
         s = Session()
         q = s.query(User).with_for_update(
-            read=read, nowait=nowait, of=of, key_share=key_share)
+            read=read, nowait=nowait, of=of, key_share=key_share
+        )
         sel = q._compile_context().statement

         assert q._for_update_arg.read is read
@@ -99,9 +108,7 @@ class ForUpdateTest(_fixtures.FixtureTest):
     def test_of_single_col(self):
         User, users = self.classes.User, self.tables.users
         self._assert(
-            of=User.id,
-            assert_q_of=[users.c.id],
-            assert_sel_of=[users.c.id]
+            of=User.id, assert_q_of=[users.c.id], assert_sel_of=[users.c.id]
         )


@@ -117,24 +124,21 @@ class BackendTest(_fixtures.FixtureTest):
     def setup_mappers(cls):
         User, users = cls.classes.User, cls.tables.users
         Address, addresses = cls.classes.Address, cls.tables.addresses
-        mapper(User, users, properties={
-            "addresses": relationship(Address)
-        })
+        mapper(User, users, properties={"addresses": relationship(Address)})
         mapper(Address, addresses)

     def test_inner_joinedload_w_limit(self):
         User = self.classes.User
         sess = Session()
-        q = sess.query(User).options(
-            joinedload(User.addresses, innerjoin=True)
-        ).with_for_update().limit(1)
+        q = (
+            sess.query(User)
+            .options(joinedload(User.addresses, innerjoin=True))
+            .with_for_update()
+            .limit(1)
+        )

         if testing.against("oracle"):
-            assert_raises_message(
-                exc.DatabaseError,
-                "ORA-02014",
-                q.all
-            )
+            assert_raises_message(exc.DatabaseError, "ORA-02014", q.all)
         else:
             q.all()
         sess.close()
@@ -162,11 +166,7 @@ class BackendTest(_fixtures.FixtureTest):
         q = q.limit(1)

         if testing.against("oracle"):
-            assert_raises_message(
-                exc.DatabaseError,
-                "ORA-02014",
-                q.all
-            )
+            assert_raises_message(exc.DatabaseError, "ORA-02014", q.all)
         else:
             q.all()
         sess.close()
@@ -203,15 +203,14 @@ class BackendTest(_fixtures.FixtureTest):

 class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     """run some compile tests, even though these are redundant."""
+
     run_inserts = None

     @classmethod
     def setup_mappers(cls):
         User, users = cls.classes.User, cls.tables.users
         Address, addresses = cls.classes.Address, cls.tables.addresses
-        mapper(User, users, properties={
-            "addresses": relationship(Address)
-        })
+        mapper(User, users, properties={"addresses": relationship(Address)})
         mapper(Address, addresses)

     def test_default_update(self):
@@ -220,7 +219,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(),
             "SELECT users.id AS users_id FROM users FOR UPDATE",
-            dialect=default.DefaultDialect()
+            dialect=default.DefaultDialect(),
         )

     def test_not_supported_by_dialect_should_just_use_update(self):
@@ -229,24 +228,25 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(read=True),
             "SELECT users.id AS users_id FROM users FOR UPDATE",
-            dialect=default.DefaultDialect()
+            dialect=default.DefaultDialect(),
         )

     def test_postgres_read(self):
         User = self.classes.User
         sess = Session()
-        self.assert_compile(sess.query(User.id).with_for_update(read=True),
-                            "SELECT users.id AS users_id FROM users FOR SHARE",
-                            dialect="postgresql")
+        self.assert_compile(
+            sess.query(User.id).with_for_update(read=True),
+            "SELECT users.id AS users_id FROM users FOR SHARE",
+            dialect="postgresql",
+        )

     def test_postgres_read_nowait(self):
         User = self.classes.User
         sess = Session()
         self.assert_compile(
-            sess.query(User.id).
-            with_for_update(read=True, nowait=True),
+            sess.query(User.id).with_for_update(read=True, nowait=True),
             "SELECT users.id AS users_id FROM users FOR SHARE NOWAIT",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_update(self):
@@ -255,7 +255,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(),
             "SELECT users.id AS users_id FROM users FOR UPDATE",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_update_of(self):
@@ -264,7 +264,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(of=User.id),
             "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_update_of_entity(self):
@@ -273,7 +273,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(of=User),
             "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_update_of_entity_list(self):
@@ -282,11 +282,12 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):

         sess = Session()
         self.assert_compile(
-            sess.query(User.id, Address.id).
-            with_for_update(of=[User, Address]),
+            sess.query(User.id, Address.id).with_for_update(
+                of=[User, Address]
+            ),
             "SELECT users.id AS users_id, addresses.id AS addresses_id "
             "FROM users, addresses FOR UPDATE OF users, addresses",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_for_no_key_update(self):
@@ -295,7 +296,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(key_share=True),
             "SELECT users.id AS users_id FROM users FOR NO KEY UPDATE",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_for_no_key_nowait_update(self):
@@ -304,17 +305,18 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(key_share=True, nowait=True),
             "SELECT users.id AS users_id FROM users FOR NO KEY UPDATE NOWAIT",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_update_of_list(self):
         User = self.classes.User
         sess = Session()
         self.assert_compile(
-            sess.query(User.id)
-            .with_for_update(of=[User.id, User.id, User.id]),
+            sess.query(User.id).with_for_update(
+                of=[User.id, User.id, User.id]
+            ),
             "SELECT users.id AS users_id FROM users FOR UPDATE OF users",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_postgres_update_skip_locked(self):
@@ -323,7 +325,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(skip_locked=True),
             "SELECT users.id AS users_id FROM users FOR UPDATE SKIP LOCKED",
-            dialect="postgresql"
+            dialect="postgresql",
         )

     def test_oracle_update(self):
@@ -332,7 +334,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(),
             "SELECT users.id AS users_id FROM users FOR UPDATE",
-            dialect="oracle"
+            dialect="oracle",
         )

     def test_oracle_update_skip_locked(self):
@@ -341,7 +343,7 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(skip_locked=True),
             "SELECT users.id AS users_id FROM users FOR UPDATE SKIP LOCKED",
-            dialect="oracle"
+            dialect="oracle",
         )

     def test_mysql_read(self):
@@ -350,15 +352,17 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             sess.query(User.id).with_for_update(read=True),
             "SELECT users.id AS users_id FROM users LOCK IN SHARE MODE",
-            dialect="mysql"
+            dialect="mysql",
         )

     def test_for_update_on_inner_w_joinedload(self):
         User = self.classes.User
         sess = Session()
         self.assert_compile(
-            sess.query(User).options(
-                joinedload(User.addresses)).with_for_update().limit(1),
+            sess.query(User)
+            .options(joinedload(User.addresses))
+            .with_for_update()
+            .limit(1),
             "SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name "
             "AS anon_1_users_name, addresses_1.id AS addresses_1_id, "
             "addresses_1.user_id AS addresses_1_user_id, "
@@ -367,15 +371,17 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             "FROM users LIMIT %s FOR UPDATE) AS anon_1 "
             "LEFT OUTER JOIN addresses AS addresses_1 "
             "ON anon_1.users_id = addresses_1.user_id FOR UPDATE",
-            dialect="mysql"
+            dialect="mysql",
         )

     def test_for_update_on_inner_w_joinedload_no_render_oracle(self):
         User = self.classes.User
         sess = Session()
         self.assert_compile(
-            sess.query(User).options(
-                joinedload(User.addresses)).with_for_update().limit(1),
+            sess.query(User)
+            .options(joinedload(User.addresses))
+            .with_for_update()
+            .limit(1),
             "SELECT anon_1.users_id AS anon_1_users_id, "
             "anon_1.users_name AS anon_1_users_name, "
             "addresses_1.id AS addresses_1_id, "
@@ -386,5 +392,5 @@ class CompileTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             "FROM users) WHERE ROWNUM <= :param_1) anon_1 "
             "LEFT OUTER JOIN addresses addresses_1 "
             "ON anon_1.users_id = addresses_1.user_id FOR UPDATE",
-            dialect="oracle"
+            dialect="oracle",
         )
diff --git a/test/orm/test_manytomany.py b/test/orm/test_manytomany.py
index f20c3db715..3401de7faa 100644
--- a/test/orm/test_manytomany.py
+++ b/test/orm/test_manytomany.py
@@ -1,56 +1,101 @@
-from sqlalchemy.testing import assert_raises, \
-    assert_raises_message, eq_
+from sqlalchemy.testing import assert_raises, assert_raises_message, eq_
 import sqlalchemy as sa
 from sqlalchemy import testing
 from sqlalchemy import Integer, String, ForeignKey
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.schema import Column
-from sqlalchemy.orm import mapper, relationship, Session, \
-    exc as orm_exc, sessionmaker, backref
+from sqlalchemy.orm import (
+    mapper,
+    relationship,
+    Session,
+    exc as orm_exc,
+    sessionmaker,
+    backref,
+)
 from sqlalchemy.testing import fixtures


 class M2MTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
-        Table('place', metadata,
-              Column('place_id', Integer, test_needs_autoincrement=True,
-                     primary_key=True),
-              Column('name', String(30), nullable=False),
-              test_needs_acid=True)
-
-        Table('transition', metadata,
-              Column('transition_id', Integer,
-                     test_needs_autoincrement=True, primary_key=True),
-              Column('name', String(30), nullable=False),
-              test_needs_acid=True)
-
-        Table('place_thingy', metadata,
-              Column('thingy_id', Integer, test_needs_autoincrement=True,
-                     primary_key=True),
-              Column('place_id', Integer, ForeignKey('place.place_id'),
-                     nullable=False),
-              Column('name', String(30), nullable=False),
-              test_needs_acid=True)
+        Table(
+            "place",
+            metadata,
+            Column(
+                "place_id",
+                Integer,
+                test_needs_autoincrement=True,
+                primary_key=True,
+            ),
+            Column("name", String(30), nullable=False),
+            test_needs_acid=True,
+        )
+
+        Table(
+            "transition",
+            metadata,
+            Column(
+                "transition_id",
+                Integer,
+                test_needs_autoincrement=True,
+                primary_key=True,
+            ),
+            Column("name", String(30), nullable=False),
+            test_needs_acid=True,
+        )
+
+        Table(
+            "place_thingy",
+            metadata,
+            Column(
+                "thingy_id",
+                Integer,
+                test_needs_autoincrement=True,
+                primary_key=True,
+            ),
+            Column(
+                "place_id",
+                Integer,
+                ForeignKey("place.place_id"),
+                nullable=False,
+            ),
+            Column("name", String(30), nullable=False),
+            test_needs_acid=True,
+        )

         # association table #1
-        Table('place_input', metadata,
-              Column('place_id', Integer, ForeignKey('place.place_id')),
-              Column('transition_id', Integer,
-                     ForeignKey('transition.transition_id')),
-              test_needs_acid=True)
+        Table(
+            "place_input",
+            metadata,
+            Column("place_id", Integer, ForeignKey("place.place_id")),
+            Column(
+                "transition_id",
+                Integer,
+                ForeignKey("transition.transition_id"),
+            ),
+            test_needs_acid=True,
+        )

         # association table #2
-        Table('place_output', metadata,
-              Column('place_id', Integer, ForeignKey('place.place_id')),
-              Column('transition_id', Integer,
-                     ForeignKey('transition.transition_id')),
-              test_needs_acid=True)
+        Table(
+            "place_output",
+            metadata,
+            Column("place_id", Integer, ForeignKey("place.place_id")),
+            Column(
+                "transition_id",
+                Integer,
+                ForeignKey("transition.transition_id"),
+            ),
+            test_needs_acid=True,
+        )

-        Table('place_place', metadata,
-              Column('pl1_id', Integer, ForeignKey('place.place_id')),
-              Column('pl2_id', Integer, ForeignKey('place.place_id')),
-              test_needs_acid=True)
+        Table(
+            "place_place",
+            metadata,
+            Column("pl1_id", Integer, ForeignKey("place.place_id")),
+            Column("pl2_id", Integer, ForeignKey("place.place_id")),
+            test_needs_acid=True,
+        )

     @classmethod
     def setup_classes(cls):
@@ -72,46 +117,63 @@ class M2MTest(fixtures.MappedTest):
             self.classes.Transition,
             self.tables.place_input,
             self.classes.Place,
-            self.tables.transition)
-
-        mapper(Place, place, properties={
-            'transitions': relationship(Transition,
-                                        secondary=place_input,
-                                        backref='places')
-        })
-        mapper(Transition, transition, properties={
-            'places': relationship(Place,
-                                   secondary=place_input,
-                                   backref='transitions')
-        })
-        assert_raises_message(sa.exc.ArgumentError,
-                              "property of that name exists",
-                              sa.orm.configure_mappers)
+            self.tables.transition,
+        )
+
+        mapper(
+            Place,
+            place,
+            properties={
+                "transitions": relationship(
+                    Transition, secondary=place_input, backref="places"
+                )
+            },
+        )
+        mapper(
+            Transition,
+            transition,
+            properties={
+                "places": relationship(
+                    Place, secondary=place_input, backref="transitions"
+                )
+            },
+        )
+        assert_raises_message(
+            sa.exc.ArgumentError,
+            "property of that name exists",
+            sa.orm.configure_mappers,
+        )

     def test_self_referential_roundtrip(self):
-        place, Place, place_place = (self.tables.place,
-                                     self.classes.Place,
-                                     self.tables.place_place)
+        place, Place, place_place = (
+            self.tables.place,
+            self.classes.Place,
+            self.tables.place_place,
+        )

-        mapper(Place, place, properties={
-            'places': relationship(
-                Place,
-                secondary=place_place,
-                primaryjoin=place.c.place_id == place_place.c.pl1_id,
-                secondaryjoin=place.c.place_id == place_place.c.pl2_id,
-                order_by=place_place.c.pl2_id
-            )
-        })
+        mapper(
+            Place,
+            place,
+            properties={
+                "places": relationship(
+                    Place,
+                    secondary=place_place,
+                    primaryjoin=place.c.place_id == place_place.c.pl1_id,
+                    secondaryjoin=place.c.place_id == place_place.c.pl2_id,
+                    order_by=place_place.c.pl2_id,
+                )
+            },
+        )

         sess = Session()
-        p1 = Place('place1')
-        p2 = Place('place2')
-        p3 = Place('place3')
-        p4 = Place('place4')
-        p5 = Place('place5')
-        p6 = Place('place6')
-        p7 = Place('place7')
+        p1 = Place("place1")
+        p2 = Place("place2")
+        p3 = Place("place3")
+        p4 = Place("place4")
+        p5 = Place("place5")
+        p6 = Place("place6")
+        p7 = Place("place7")
         sess.add_all((p1, p2, p3, p4, p5, p6, p7))
         p1.places.append(p2)
         p1.places.append(p3)
@@ -132,24 +194,30 @@ class M2MTest(fixtures.MappedTest):
         eq_(p2.places, [])

     def test_self_referential_bidirectional_mutation(self):
-        place, Place, place_place = (self.tables.place,
-                                     self.classes.Place,
-                                     self.tables.place_place)
-
-        mapper(Place, place, properties={
-            'child_places': relationship(
-                Place,
-                secondary=place_place,
-                primaryjoin=place.c.place_id == place_place.c.pl1_id,
-                secondaryjoin=place.c.place_id == place_place.c.pl2_id,
-                order_by=place_place.c.pl2_id,
-                backref='parent_places'
-            )
-        })
+        place, Place, place_place = (
+            self.tables.place,
+            self.classes.Place,
+            self.tables.place_place,
+        )
+
+        mapper(
+            Place,
+            place,
+            properties={
+                "child_places": relationship(
+                    Place,
+                    secondary=place_place,
+                    primaryjoin=place.c.place_id == place_place.c.pl1_id,
+                    secondaryjoin=place.c.place_id == place_place.c.pl2_id,
+                    order_by=place_place.c.pl2_id,
+                    backref="parent_places",
+                )
+            },
+        )

         sess = Session()
-        p1 = Place('place1')
-        p2 = Place('place2')
+        p1 = Place("place1")
+        p2 = Place("place2")
         p2.parent_places = [p1]
         sess.add_all([p1, p2])
         p1.parent_places.append(p2)
@@ -162,42 +230,51 @@ class M2MTest(fixtures.MappedTest):
         """test that a mapper can have two eager relationships to the same
         table, via two different association tables.  aliases are required."""

-        place_input, transition, Transition, PlaceThingy, \
-            place, place_thingy, Place, \
-            place_output = (self.tables.place_input,
-                            self.tables.transition,
-                            self.classes.Transition,
-                            self.classes.PlaceThingy,
-                            self.tables.place,
-                            self.tables.place_thingy,
-                            self.classes.Place,
-                            self.tables.place_output)
+        place_input, transition, Transition, PlaceThingy, place, place_thingy, Place, place_output = (
+            self.tables.place_input,
+            self.tables.transition,
+            self.classes.Transition,
+            self.classes.PlaceThingy,
+            self.tables.place,
+            self.tables.place_thingy,
+            self.classes.Place,
+            self.tables.place_output,
+        )

         mapper(PlaceThingy, place_thingy)
-        mapper(Place, place, properties={
-            'thingies': relationship(PlaceThingy, lazy='joined')
-        })
+        mapper(
+            Place,
+            place,
+            properties={"thingies": relationship(PlaceThingy, lazy="joined")},
+        )

-        mapper(Transition, transition, properties=dict(
-            inputs=relationship(Place, place_output, lazy='joined'),
-            outputs=relationship(Place, place_input, lazy='joined'))
+        mapper(
+            Transition,
+            transition,
+            properties=dict(
+                inputs=relationship(Place, place_output, lazy="joined"),
+                outputs=relationship(Place, place_input, lazy="joined"),
+            ),
         )

-        tran = Transition('transition1')
-        tran.inputs.append(Place('place1'))
-        tran.outputs.append(Place('place2'))
-        tran.outputs.append(Place('place3'))
+        tran = Transition("transition1")
+        tran.inputs.append(Place("place1"))
+        tran.outputs.append(Place("place2"))
+        tran.outputs.append(Place("place3"))
         sess = Session()
         sess.add(tran)
         sess.commit()

         r = sess.query(Transition).all()
-        self.assert_unordered_result(r, Transition,
-                                     {'name': 'transition1',
-                                      'inputs': (Place, [{'name': 'place1'}]),
-                                      'outputs': (Place, [{'name': 'place2'},
-                                                          {'name': 'place3'}])
-                                      })
+        self.assert_unordered_result(
+            r,
+            Transition,
+            {
+                "name": "transition1",
+                "inputs": (Place, [{"name": "place1"}]),
+                "outputs": (Place, [{"name": "place2"}, {"name": "place3"}]),
+            },
+        )

     def test_bidirectional(self):
         place_input, transition, Transition, Place, place, place_output = (
@@
-206,28 +283,39 @@ class M2MTest(fixtures.MappedTest): self.classes.Transition, self.classes.Place, self.tables.place, - self.tables.place_output) + self.tables.place_output, + ) mapper(Place, place) - mapper(Transition, transition, properties=dict( - inputs=relationship( - Place, place_output, - backref=backref('inputs', order_by=transition.c.transition_id), - order_by=Place.place_id), - outputs=relationship( - Place, place_input, - backref=backref('outputs', - order_by=transition.c.transition_id), - order_by=Place.place_id), - ) - ) - - t1 = Transition('transition1') - t2 = Transition('transition2') - t3 = Transition('transition3') - p1 = Place('place1') - p2 = Place('place2') - p3 = Place('place3') + mapper( + Transition, + transition, + properties=dict( + inputs=relationship( + Place, + place_output, + backref=backref( + "inputs", order_by=transition.c.transition_id + ), + order_by=Place.place_id, + ), + outputs=relationship( + Place, + place_input, + backref=backref( + "outputs", order_by=transition.c.transition_id + ), + order_by=Place.place_id, + ), + ), + ) + + t1 = Transition("transition1") + t2 = Transition("transition2") + t3 = Transition("transition3") + p1 = Place("place1") + p2 = Place("place2") + p3 = Place("place3") sess = Session() sess.add_all([p3, p1, t1, t2, p2, t3]) @@ -241,14 +329,21 @@ class M2MTest(fixtures.MappedTest): p1.outputs.append(t1) sess.commit() - self.assert_result([t1], - Transition, {'outputs': - (Place, [{'name': 'place3'}, - {'name': 'place1'}])}) - self.assert_result([p2], - Place, {'inputs': - (Transition, [{'name': 'transition1'}, - {'name': 'transition2'}])}) + self.assert_result( + [t1], + Transition, + {"outputs": (Place, [{"name": "place3"}, {"name": "place1"}])}, + ) + self.assert_result( + [p2], + Place, + { + "inputs": ( + Transition, + [{"name": "transition1"}, {"name": "transition2"}], + ) + }, + ) @testing.requires.updateable_autoincrement_pks @testing.requires.sane_multi_rowcount @@ -258,16 +353,22 @@ class M2MTest(fixtures.MappedTest): self.classes.Transition, self.tables.place_input, self.tables.place, - self.tables.transition) + self.tables.transition, + ) - mapper(Place, place, properties={ - 'transitions': relationship(Transition, secondary=place_input, - passive_updates=False) - }) + mapper( + Place, + place, + properties={ + "transitions": relationship( + Transition, secondary=place_input, passive_updates=False + ) + }, + ) mapper(Transition, transition) - p1 = Place('place1') - t1 = Transition('t1') + p1 = Place("place1") + t1 = Transition("t1") p1.transitions.append(t1) sess = sessionmaker()() sess.add_all([p1, t1]) @@ -283,7 +384,7 @@ class M2MTest(fixtures.MappedTest): orm_exc.StaleDataError, r"UPDATE statement on table 'place_input' expected to " r"update 1 row\(s\); Only 0 were matched.", - sess.commit + sess.commit, ) sess.rollback() @@ -295,28 +396,41 @@ class M2MTest(fixtures.MappedTest): orm_exc.StaleDataError, r"DELETE statement on table 'place_input' expected to " r"delete 1 row\(s\); Only 0 were matched.", - sess.commit + sess.commit, ) class AssortedPersistenceTests(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("left", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - - Table("right", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - - Table('secondary', metadata, - Column('left_id', Integer, ForeignKey('left.id'), - primary_key=True), - 
Column('right_id', Integer, ForeignKey('right.id'), - primary_key=True)) + Table( + "left", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + + Table( + "right", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + + Table( + "secondary", + metadata, + Column( + "left_id", Integer, ForeignKey("left.id"), primary_key=True + ), + Column( + "right_id", Integer, ForeignKey("right.id"), primary_key=True + ), + ) @classmethod def setup_classes(cls): @@ -327,24 +441,42 @@ class AssortedPersistenceTests(fixtures.MappedTest): pass def _standard_bidirectional_fixture(self): - left, secondary, right = self.tables.left, \ - self.tables.secondary, self.tables.right + left, secondary, right = ( + self.tables.left, + self.tables.secondary, + self.tables.right, + ) A, B = self.classes.A, self.classes.B - mapper(A, left, properties={ - 'bs': relationship(B, secondary=secondary, - backref='as', order_by=right.c.id) - }) + mapper( + A, + left, + properties={ + "bs": relationship( + B, secondary=secondary, backref="as", order_by=right.c.id + ) + }, + ) mapper(B, right) def _bidirectional_onescalar_fixture(self): - left, secondary, right = self.tables.left, \ - self.tables.secondary, self.tables.right + left, secondary, right = ( + self.tables.left, + self.tables.secondary, + self.tables.right, + ) A, B = self.classes.A, self.classes.B - mapper(A, left, properties={ - 'bs': relationship(B, secondary=secondary, - backref=backref('a', uselist=False), - order_by=right.c.id) - }) + mapper( + A, + left, + properties={ + "bs": relationship( + B, + secondary=secondary, + backref=backref("a", uselist=False), + order_by=right.c.id, + ) + }, + ) mapper(B, right) def test_session_delete(self): @@ -353,18 +485,17 @@ class AssortedPersistenceTests(fixtures.MappedTest): secondary = self.tables.secondary sess = Session() - sess.add_all([ - A(data='a1', bs=[B(data='b1')]), - A(data='a2', bs=[B(data='b2')]) - ]) + sess.add_all( + [A(data="a1", bs=[B(data="b1")]), A(data="a2", bs=[B(data="b2")])] + ) sess.commit() - a1 = sess.query(A).filter_by(data='a1').one() + a1 = sess.query(A).filter_by(data="a1").one() sess.delete(a1) sess.flush() eq_(sess.query(secondary).count(), 1) - a2 = sess.query(A).filter_by(data='a2').one() + a2 = sess.query(A).filter_by(data="a2").one() sess.delete(a2) sess.flush() eq_(sess.query(secondary).count(), 0) @@ -376,18 +507,16 @@ class AssortedPersistenceTests(fixtures.MappedTest): secondary = self.tables.secondary sess = Session() - sess.add_all([ - A(data='a1', bs=[B(data='b1'), B(data='b2')]), - ]) + sess.add_all([A(data="a1", bs=[B(data="b1"), B(data="b2")])]) sess.commit() - a1 = sess.query(A).filter_by(data='a1').one() - b2 = sess.query(B).filter_by(data='b2').one() + a1 = sess.query(A).filter_by(data="a1").one() + b2 = sess.query(B).filter_by(data="b2").one() assert b2.a is a1 b2.a = None sess.commit() - eq_(a1.bs, [B(data='b1')]) + eq_(a1.bs, [B(data="b1")]) eq_(b2.a, None) eq_(sess.query(secondary).count(), 1) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 487299f298..780272cf76 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -3,15 +3,35 @@ from sqlalchemy.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy import testing -from sqlalchemy import MetaData, Integer, String, \ - ForeignKey, func, util, select +from sqlalchemy import ( + MetaData, + Integer, + 
String, + ForeignKey, + func, + util, + select, +) from sqlalchemy.testing.schema import Table, Column from sqlalchemy.engine import default -from sqlalchemy.orm import mapper, relationship, backref, \ - create_session, class_mapper, configure_mappers, reconstructor, \ - aliased, deferred, synonym, attributes, \ - column_property, composite, dynamic_loader, \ - comparable_property, Session +from sqlalchemy.orm import ( + mapper, + relationship, + backref, + create_session, + class_mapper, + configure_mappers, + reconstructor, + aliased, + deferred, + synonym, + attributes, + column_property, + composite, + dynamic_loader, + comparable_property, + Session, +) from sqlalchemy.orm.persistence import _sort_states from sqlalchemy.testing import eq_, AssertsCompiledSQL, is_ from sqlalchemy.testing import fixtures @@ -22,21 +42,26 @@ import logging.handlers class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_prop_shadow(self): """A backref name may not shadow an existing property name.""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, - properties={ - 'addresses': relationship(Address, backref='email_address') - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, backref="email_address") + }, + ) assert_raises(sa.exc.ArgumentError, sa.orm.configure_mappers) def test_update_attr_keys(self): @@ -45,17 +70,25 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User, users = self.classes.User, self.tables.users - mapper(User, users, properties={ - 'foobar': users.c.name - }) + mapper(User, users, properties={"foobar": users.c.name}) - users.insert().values({User.foobar: 'name1'}).execute() - eq_(sa.select([User.foobar]).where(User.foobar == 'name1'). - execute().fetchall(), [('name1',)]) + users.insert().values({User.foobar: "name1"}).execute() + eq_( + sa.select([User.foobar]) + .where(User.foobar == "name1") + .execute() + .fetchall(), + [("name1",)], + ) - users.update().values({User.foobar: User.foobar + 'foo'}).execute() - eq_(sa.select([User.foobar]).where(User.foobar == 'name1foo'). - execute().fetchall(), [('name1foo',)]) + users.update().values({User.foobar: User.foobar + "foo"}).execute() + eq_( + sa.select([User.foobar]) + .where(User.foobar == "name1foo") + .execute() + .fetchall(), + [("name1foo",)], + ) def test_utils(self): users = self.tables.users @@ -107,6 +140,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): @property def y(self): return "something else" + m = mapper(Foo, users) a1 = aliased(Foo) @@ -114,7 +148,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): (m, "x", Foo.x), (Foo, "x", Foo.x), (a1, "x", a1.x), - (users, "name", users.c.name) + (users, "name", users.c.name), ]: assert _entity_descriptor(arg, key) is ret @@ -123,9 +157,8 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def boom(): raise Exception("it broke") - mapper(User, users, properties={ - 'addresses': relationship(boom) - }) + + mapper(User, users, properties={"addresses": relationship(boom)}) # test that QueryableAttribute.__str__() doesn't # cause a compile. 
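For reference: the fluent method-chain layout black emits above (one chained call per line inside a wrapping parenthesis) can be reproduced programmatically. A minimal sketch, assuming a 2019-era black such as 18.9b0 — later releases changed the signature to format_str(src, mode=black.Mode(line_length=79)):

    import black  # 18.9b0-era API assumed; newer versions take mode=black.Mode(...)

    src = (
        "eq_(sa.select([User.foobar]).where(User.foobar == 'name1')"
        ".execute().fetchall(), [('name1',)])\n"
    )
    # line_length=79 mirrors this commit's command-line run of `black -l 79`;
    # black splits the over-long chain at the dots and normalizes quotes
    print(black.format_str(src, line_length=79))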
@@ -138,48 +171,53 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): """ - Address, addresses, User = (self.classes.Address, - self.tables.addresses, - self.classes.User) + Address, addresses, User = ( + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) try: - hasattr(Address.user, 'property') + hasattr(Address.user, "property") except sa.orm.exc.UnmappedClassError: assert util.compat.py32 for i in range(3): - assert_raises_message(sa.exc.InvalidRequestError, - "^One or more " - "mappers failed to initialize - can't " - "proceed with initialization of other " - r"mappers. Triggering mapper\: " - r"'Mapper\|Address\|addresses'." - " Original exception was: Class " - "'test.orm._fixtures.User' is not mapped$", - configure_mappers) + assert_raises_message( + sa.exc.InvalidRequestError, + "^One or more " + "mappers failed to initialize - can't " + "proceed with initialization of other " + r"mappers. Triggering mapper\: " + r"'Mapper\|Address\|addresses'." + " Original exception was: Class " + "'test.orm._fixtures.User' is not mapped$", + configure_mappers, + ) def test_column_prefix(self): users, User = self.tables.users, self.classes.User - mapper(User, users, column_prefix='_', properties={ - 'user_name': synonym('_name') - }) + mapper( + User, + users, + column_prefix="_", + properties={"user_name": synonym("_name")}, + ) s = create_session() u = s.query(User).get(7) - eq_(u._name, 'jack') + eq_(u._name, "jack") eq_(u._id, 7) - u2 = s.query(User).filter_by(user_name='jack').one() + u2 = s.query(User).filter_by(user_name="jack").one() assert u is u2 def test_no_pks_1(self): User, users = self.classes.User, self.tables.users - s = sa.select([users.c.name]).alias('foo') + s = sa.select([users.c.name]).alias("foo") assert_raises(sa.exc.ArgumentError, mapper, User, s) def test_no_pks_2(self): @@ -192,17 +230,22 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): """A configure trigger on an already-configured mapper still triggers a check against all mappers.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) sa.orm.configure_mappers() assert sa.orm.mapperlib.Mapper._new_mappers is False - m = mapper(Address, addresses, properties={ - 'user': relationship(User, backref="addresses")}) + m = mapper( + Address, + addresses, + properties={"user": relationship(User, backref="addresses")}, + ) assert m.configured is False assert sa.orm.mapperlib.Mapper._new_mappers is True @@ -224,24 +267,36 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): mapper(Address, addresses) s = create_session() - a = s.query(Address).from_statement( - sa.select([addresses.c.id, addresses.c.user_id]). 
- order_by(addresses.c.id)).first() + a = ( + s.query(Address) + .from_statement( + sa.select([addresses.c.id, addresses.c.user_id]).order_by( + addresses.c.id + ) + ) + .first() + ) eq_(a.user_id, 7) eq_(a.id, 1) # email address auto-defers - assert 'email_addres' not in a.__dict__ - eq_(a.email_address, 'jack@bean.com') + assert "email_addres" not in a.__dict__ + eq_(a.email_address, "jack@bean.com") def test_column_not_present(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) - assert_raises_message(sa.exc.ArgumentError, - "not represented in the mapper's table", - mapper, User, users, - properties={'foo': addresses.c.user_id}) + assert_raises_message( + sa.exc.ArgumentError, + "not represented in the mapper's table", + mapper, + User, + users, + properties={"foo": addresses.c.user_id}, + ) def test_constructor_exc(self): """TypeError is raised for illegal constructor args, @@ -250,7 +305,6 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): users, addresses = self.tables.users, self.tables.addresses class Foo(object): - def __init__(self): pass @@ -269,23 +323,23 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): this. """ - class Foo(object): + class Foo(object): def __init__(self, id): self.id = id + m = MetaData() - foo_t = Table('foo', m, - Column('id', String, primary_key=True) - ) + foo_t = Table("foo", m, Column("id", String, primary_key=True)) m = mapper(Foo, foo_t) class DontCompareMeToString(int): if util.py2k: + def __lt__(self, other): assert not isinstance(other, basestring) return int(self) < other - foos = [Foo(id='f%d' % i) for i in range(5)] + foos = [Foo(id="f%d" % i) for i in range(5)] states = [attributes.instance_state(f) for f in foos] for s in states[0:3]: @@ -295,86 +349,116 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): states[2].insert_order = DontCompareMeToString(3) eq_( _sort_states(states), - [states[4], states[3], states[0], states[1], states[2]] + [states[4], states[3], states[0], states[1], states[2]], ) def test_props(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - m = mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses)) - }) - assert User.addresses.property is m.get_property('addresses') + m = mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) + assert User.addresses.property is m.get_property("addresses") def test_unicode_relationship_backref_names(self): # test [ticket:2901] - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - util.u('addresses'): relationship(Address, backref=util.u('user')) - }) + mapper( + User, + users, + properties={ + util.u("addresses"): relationship( + Address, backref=util.u("user") + ) + }, + ) u1 = User() a1 = Address() u1.addresses.append(a1) assert a1.user is u1 def test_configure_on_prop_1(self): - users, Address, addresses, User = (self.tables.users, - 
self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses)) - }) - User.addresses.any(Address.email_address == 'foo@bar.com') + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) + User.addresses.any(Address.email_address == "foo@bar.com") def test_configure_on_prop_2(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses)) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) eq_(str(User.id == 3), str(users.c.id == 3)) def test_configure_on_prop_3(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) class Foo(User): pass mapper(User, users) - mapper(Foo, addresses, inherits=User, properties={ - 'address_id': addresses.c.id - }) - assert getattr(Foo().__class__, 'name').impl is not None + mapper( + Foo, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) + assert getattr(Foo().__class__, "name").impl is not None def test_deferred_subclass_attribute_instrument(self): - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) class Foo(User): pass mapper(User, users) configure_mappers() - mapper(Foo, addresses, inherits=User, properties={ - 'address_id': addresses.c.id - }) - assert getattr(Foo().__class__, 'name').impl is not None + mapper( + Foo, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) + assert getattr(Foo().__class__, "name").impl is not None def test_class_hier_only_instrument_once_multiple_configure(self): users, addresses = (self.tables.users, self.tables.addresses) @@ -396,12 +480,10 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): with mock.patch( "sqlalchemy.orm.attributes.register_attribute_impl", - side_effect=register_attribute_impl + side_effect=register_attribute_impl, ) as some_mock: - mapper(A, users, properties={ - 'bs': relationship(B) - }) + mapper(A, users, properties={"bs": relationship(B)}) mapper(B, addresses) configure_mappers() @@ -411,9 +493,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): configure_mappers() - b_calls = [ - c for c in some_mock.mock_calls if c[1][1] == 'bs' - ] + b_calls = [c for c in some_mock.mock_calls if c[1][1] == "bs"] eq_(len(b_calls), 3) def test_check_descriptor_as_method(self): @@ -422,9 +502,9 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m = mapper(User, users) class MyClass(User): - def foo(self): pass + m._is_userland_descriptor(MyClass.foo) def test_configure_on_get_props_1(self): @@ -440,23 +520,27 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m = mapper(User, users) assert not m.configured - assert m.get_property('name') + assert m.get_property("name") assert m.configured def 
test_configure_on_get_props_3(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) m = mapper(User, users) assert not m.configured configure_mappers() - m2 = mapper(Address, addresses, properties={ - 'user': relationship(User, backref='addresses') - }) - assert m.get_property('addresses') + m2 = mapper( + Address, + addresses, + properties={"user": relationship(User, backref="addresses")}, + ) + assert m.get_property("addresses") def test_info(self): users = self.tables.users @@ -464,12 +548,13 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class MyComposite(object): pass + for constructor, args in [ (column_property, (users.c.name,)), (relationship, (Address,)), - (composite, (MyComposite, 'id', 'name')), - (synonym, 'foo'), - (comparable_property, 'foo') + (composite, (MyComposite, "id", "name")), + (synonym, "foo"), + (comparable_property, "foo"), ]: obj = constructor(info={"x": "y"}, *args) eq_(obj.info, {"x": "y"}) @@ -485,26 +570,38 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m = MetaData() # create specific tables here as we don't want # users.c.id.info to be pre-initialized - users = Table('u', m, Column('id', Integer, primary_key=True), - Column('name', String)) - addresses = Table('a', m, Column('id', Integer, primary_key=True), - Column('name', String), - Column('user_id', Integer, ForeignKey('u.id'))) + users = Table( + "u", + m, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + addresses = Table( + "a", + m, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("user_id", Integer, ForeignKey("u.id")), + ) Address = self.classes.Address User = self.classes.User - mapper(User, users, properties={ - "name_lower": column_property(func.lower(users.c.name)), - "addresses": relationship(Address) - }) + mapper( + User, + users, + properties={ + "name_lower": column_property(func.lower(users.c.name)), + "addresses": relationship(Address), + }, + ) mapper(Address, addresses) # attr.info goes down to the original Column object # for the dictionary. The annotated element needs to pass # this on. 
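For context on the assertions that follow: .info is a free-form dictionary created lazily on schema elements (which is what the "'info' not in __dict__" check below verifies), and the ORM-level attribute proxies to the Column's dict rather than copying it. A minimal standalone sketch; the table and names here are illustrative only:

    from sqlalchemy import Column, Integer, MetaData, Table

    m = MetaData()
    t = Table(
        "t", m, Column("id", Integer, primary_key=True, info={"x": "y"})
    )

    # info lives on the Column itself and is mutable in place
    assert t.c.id.info == {"x": "y"}
    t.c.id.info["z"] = "q"
    assert t.c.id.info["z"] == "q"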
- assert 'info' not in users.c.id.__dict__ + assert "info" not in users.c.id.__dict__ is_(User.id.info, users.c.id.info) - assert 'info' in users.c.id.__dict__ + assert "info" in users.c.id.__dict__ # for SQL expressions, ORM-level .info is_(User.name_lower.info, User.name_lower.property.info) @@ -513,27 +610,30 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): is_(User.addresses.info, User.addresses.property.info) def test_add_property(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) assert_col = [] class User(fixtures.ComparableEntity): - def _get_name(self): - assert_col.append(('get', self._name)) + assert_col.append(("get", self._name)) return self._name def _set_name(self, name): - assert_col.append(('set', name)) + assert_col.append(("set", name)) self._name = name + name = property(_get_name, _set_name) def _uc_name(self): if self._name is None: return None return self._name.upper() + uc_name = property(_uc_name) uc_name2 = property(_uc_name) @@ -545,36 +645,40 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def __eq__(self, other): cls = self.prop.parent.class_ - col = getattr(cls, 'name') + col = getattr(cls, "name") if other is None: return col is None else: return sa.func.upper(col) == sa.func.upper(other) - m.add_property('_name', deferred(users.c.name)) - m.add_property('name', synonym('_name')) - m.add_property('addresses', relationship(Address)) - m.add_property('uc_name', sa.orm.comparable_property(UCComparator)) - m.add_property('uc_name2', sa.orm.comparable_property( - UCComparator, User.uc_name2)) + m.add_property("_name", deferred(users.c.name)) + m.add_property("name", synonym("_name")) + m.add_property("addresses", relationship(Address)) + m.add_property("uc_name", sa.orm.comparable_property(UCComparator)) + m.add_property( + "uc_name2", sa.orm.comparable_property(UCComparator, User.uc_name2) + ) sess = create_session(autocommit=False) assert sess.query(User).get(7) - u = sess.query(User).filter_by(name='jack').one() + u = sess.query(User).filter_by(name="jack").one() def go(): - eq_(len(u.addresses), - len(self.static.user_address_result[0].addresses)) - eq_(u.name, 'jack') - eq_(u.uc_name, 'JACK') - eq_(u.uc_name2, 'JACK') - eq_(assert_col, [('get', 'jack')], str(assert_col)) + eq_( + len(u.addresses), + len(self.static.user_address_result[0].addresses), + ) + eq_(u.name, "jack") + eq_(u.uc_name, "JACK") + eq_(u.uc_name2, "JACK") + eq_(assert_col, [("get", "jack")], str(assert_col)) + self.sql_count_(2, go) - u.name = 'ed' + u.name = "ed" u3 = User() - u3.name = 'some user' + u3.name = "some user" sess.add(u3) sess.flush() sess.rollback() @@ -586,47 +690,49 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m1 = mapper(User, users) User() - m2 = mapper(Address, addresses, properties={ - 'user': relationship(User, backref="addresses") - }) + m2 = mapper( + Address, + addresses, + properties={"user": relationship(User, backref="addresses")}, + ) # configure mappers takes place when User is generated User() - assert hasattr(User, 'addresses') + assert hasattr(User, "addresses") assert "addresses" in [p.key for p in m1._polymorphic_properties] def test_replace_col_prop_w_syn(self): users, User = self.tables.users, self.classes.User m = mapper(User, users) - m.add_property('_name', users.c.name) - m.add_property('name', synonym('_name')) + m.add_property("_name", 
users.c.name) + m.add_property("name", synonym("_name")) sess = create_session() - u = sess.query(User).filter_by(name='jack').one() - eq_(u._name, 'jack') - eq_(u.name, 'jack') - u.name = 'jacko' - assert m._columntoproperty[users.c.name] is m.get_property('_name') + u = sess.query(User).filter_by(name="jack").one() + eq_(u._name, "jack") + eq_(u.name, "jack") + u.name = "jacko" + assert m._columntoproperty[users.c.name] is m.get_property("_name") sa.orm.clear_mappers() m = mapper(User, users) - m.add_property('name', synonym('_name', map_column=True)) + m.add_property("name", synonym("_name", map_column=True)) sess.expunge_all() - u = sess.query(User).filter_by(name='jack').one() - eq_(u._name, 'jack') - eq_(u.name, 'jack') - u.name = 'jacko' - assert m._columntoproperty[users.c.name] is m.get_property('_name') + u = sess.query(User).filter_by(name="jack").one() + eq_(u._name, "jack") + eq_(u.name, "jack") + u.name = "jacko" + assert m._columntoproperty[users.c.name] is m.get_property("_name") def test_replace_rel_prop_with_rel_warns(self): users, User = self.tables.users, self.classes.User addresses, Address = self.tables.addresses, self.classes.Address - m = mapper(User, users, properties={ - "addresses": relationship(Address) - }) + m = mapper( + User, users, properties={"addresses": relationship(Address)} + ) mapper(Address, addresses) assert_raises_message( @@ -635,7 +741,8 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): "with new property User.addresses; the old property will " "be discarded", m.add_property, - "addresses", relationship(Address) + "addresses", + relationship(Address), ) def test_add_column_prop_deannotate(self): @@ -644,13 +751,15 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class SubUser(User): pass + m = mapper(User, users) - m2 = mapper(SubUser, addresses, inherits=User, properties={ - 'address_id': addresses.c.id - }) - m3 = mapper(Address, addresses, properties={ - 'foo': relationship(m2) - }) + m2 = mapper( + SubUser, + addresses, + inherits=User, + properties={"address_id": addresses.c.id}, + ) + m3 = mapper(Address, addresses, properties={"foo": relationship(m2)}) # add property using annotated User.name, # needs to be deannotated m.add_property("x", column_property(User.name + "name")) @@ -669,7 +778,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM addresses JOIN (users AS users_1 JOIN addresses " "AS addresses_1 ON users_1.id = " "addresses_1.user_id) ON " - "users_1.id = addresses.user_id" + "users_1.id = addresses.user_id", ) def test_column_prop_deannotate(self): @@ -692,20 +801,21 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert User.x.property.columns[0].element.right is not expr.right assert User.y.property.columns[0] is not expr2 - assert User.y.property.columns[0].element.\ - _raw_columns[0] is users.c.name - assert User.y.property.columns[0].element.\ - _raw_columns[1] is users.c.id + assert ( + User.y.property.columns[0].element._raw_columns[0] is users.c.name + ) + assert User.y.property.columns[0].element._raw_columns[1] is users.c.id def test_synonym_replaces_backref(self): - addresses, users, User = (self.tables.addresses, - self.tables.users, - self.classes.User) + addresses, users, User = ( + self.tables.addresses, + self.tables.users, + self.classes.User, + ) assert_calls = [] class Address(object): - def _get_user(self): assert_calls.append("get") return self._user @@ -713,18 +823,19 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def 
_set_user(self, user): assert_calls.append("set") self._user = user + user = property(_get_user, _set_user) # synonym is created against nonexistent prop - mapper(Address, addresses, properties={ - 'user': synonym('_user') - }) + mapper(Address, addresses, properties={"user": synonym("_user")}) sa.orm.configure_mappers() # later, backref sets up the prop - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='_user') - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="_user")}, + ) sess = create_session() u1 = sess.query(User).get(7) @@ -737,21 +848,29 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): eq_(assert_calls, ["set", "get"]) def test_self_ref_synonym(self): - t = Table('nodes', MetaData(), - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id'))) + t = Table( + "nodes", + MetaData(), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + ) class Node(object): pass - mapper(Node, t, properties={ - '_children': relationship( - Node, backref=backref('_parent', remote_side=t.c.id)), - 'children': synonym('_children'), - 'parent': synonym('_parent') - }) + mapper( + Node, + t, + properties={ + "_children": relationship( + Node, backref=backref("_parent", remote_side=t.c.id) + ), + "children": synonym("_children"), + "parent": synonym("_parent"), + }, + ) n1 = Node() n2 = Node() @@ -766,16 +885,20 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class AddressUser(User): pass - m1 = mapper(User, users, polymorphic_identity='user') - m2 = mapper(AddressUser, addresses, inherits=User, - polymorphic_identity='address', properties={ - 'address_id': addresses.c.id - }) + + m1 = mapper(User, users, polymorphic_identity="user") + m2 = mapper( + AddressUser, + addresses, + inherits=User, + polymorphic_identity="address", + properties={"address_id": addresses.c.id}, + ) m3 = mapper(AddressUser, addresses, non_primary=True) assert m3._identity_class is m2._identity_class eq_( m2.identity_key_from_instance(AddressUser()), - m3.identity_key_from_instance(AddressUser()) + m3.identity_key_from_instance(AddressUser()), ) def test_reassign_polymorphic_identity_warns(self): @@ -784,31 +907,44 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class MyUser(User): pass - m1 = mapper(User, users, polymorphic_on=users.c.name, - polymorphic_identity='user') + + m1 = mapper( + User, + users, + polymorphic_on=users.c.name, + polymorphic_identity="user", + ) assert_raises_message( sa.exc.SAWarning, "Reassigning polymorphic association for identity 'user'", mapper, - MyUser, users, inherits=User, polymorphic_identity='user' + MyUser, + users, + inherits=User, + polymorphic_identity="user", ) def test_illegal_non_primary(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) mapper(Address, addresses) - mapper(User, users, non_primary=True, properties={ - 'addresses': relationship(Address) - }) + mapper( + User, + users, + non_primary=True, + properties={"addresses": relationship(Address)}, + ) assert_raises_message( sa.exc.ArgumentError, "Attempting to assign a new relationship 'addresses' " "to a non-primary mapper on class 
'User'", - configure_mappers + configure_mappers, ) def test_illegal_non_primary_2(self): @@ -817,7 +953,11 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert_raises_message( sa.exc.InvalidRequestError, "Configure a primary mapper first", - mapper, User, users, non_primary=True) + mapper, + User, + users, + non_primary=True, + ) def test_illegal_non_primary_3(self): users, addresses = self.tables.users, self.tables.addresses @@ -827,21 +967,30 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class Sub(Base): pass + mapper(Base, users) - assert_raises_message(sa.exc.InvalidRequestError, - "Configure a primary mapper first", - mapper, Sub, addresses, non_primary=True - ) + assert_raises_message( + sa.exc.InvalidRequestError, + "Configure a primary mapper first", + mapper, + Sub, + addresses, + non_primary=True, + ) def test_prop_filters(self): - t = Table('person', MetaData(), - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(128)), - Column('name', String(128)), - Column('employee_number', Integer), - Column('boss_id', Integer, ForeignKey('person.id')), - Column('vendor_id', Integer)) + t = Table( + "person", + MetaData(), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("type", String(128)), + Column("name", String(128)), + Column("employee_number", Integer), + Column("boss_id", Integer, ForeignKey("person.id")), + Column("vendor_id", Integer), + ) class Person(object): pass @@ -868,7 +1017,6 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): pass class HasDef(object): - def name(self): pass @@ -876,43 +1024,61 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): pass mapper( - Empty, t, properties={'empty_id': t.c.id}, - include_properties=[]) - p_m = mapper(Person, t, polymorphic_on=t.c.type, - include_properties=('id', 'type', 'name')) - e_m = mapper(Employee, inherits=p_m, - polymorphic_identity='employee', - properties={ - 'boss': relationship( - Manager, backref=backref('peon'), - remote_side=t.c.id)}, - exclude_properties=('vendor_id', )) + Empty, t, properties={"empty_id": t.c.id}, include_properties=[] + ) + p_m = mapper( + Person, + t, + polymorphic_on=t.c.type, + include_properties=("id", "type", "name"), + ) + e_m = mapper( + Employee, + inherits=p_m, + polymorphic_identity="employee", + properties={ + "boss": relationship( + Manager, backref=backref("peon"), remote_side=t.c.id + ) + }, + exclude_properties=("vendor_id",), + ) mapper( - Manager, inherits=e_m, polymorphic_identity='manager', - include_properties=('id', 'type')) + Manager, + inherits=e_m, + polymorphic_identity="manager", + include_properties=("id", "type"), + ) mapper( - Vendor, inherits=p_m, polymorphic_identity='vendor', - exclude_properties=('boss_id', 'employee_number')) - mapper(Hoho, t, include_properties=('id', 'type', 'name')) + Vendor, + inherits=p_m, + polymorphic_identity="vendor", + exclude_properties=("boss_id", "employee_number"), + ) + mapper(Hoho, t, include_properties=("id", "type", "name")) mapper( - Lala, t, exclude_properties=('vendor_id', 'boss_id'), - column_prefix="p_") + Lala, + t, + exclude_properties=("vendor_id", "boss_id"), + column_prefix="p_", + ) mapper(HasDef, t, column_prefix="h_") mapper(Fub, t, include_properties=(t.c.id, t.c.type)) mapper( - Frob, t, column_prefix='f_', - exclude_properties=( - t.c.boss_id, - 'employee_number', t.c.vendor_id)) + Frob, + t, + column_prefix="f_", + exclude_properties=(t.c.boss_id, "employee_number", 
t.c.vendor_id), + ) configure_mappers() def assert_props(cls, want): - have = set([n for n in dir(cls) if not n.startswith('_')]) + have = set([n for n in dir(cls) if not n.startswith("_")]) want = set(want) eq_(have, want) @@ -921,35 +1087,62 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): want = set(want) eq_(have, want) - assert_props(HasDef, ['h_boss_id', 'h_employee_number', 'h_id', - 'name', 'h_name', 'h_vendor_id', 'h_type']) - assert_props(Person, ['id', 'name', 'type']) - assert_instrumented(Person, ['id', 'name', 'type']) - assert_props(Employee, ['boss', 'boss_id', 'employee_number', - 'id', 'name', 'type']) - assert_instrumented(Employee, ['boss', 'boss_id', 'employee_number', - 'id', 'name', 'type']) - assert_props(Manager, ['boss', 'boss_id', 'employee_number', 'peon', - 'id', 'name', 'type']) + assert_props( + HasDef, + [ + "h_boss_id", + "h_employee_number", + "h_id", + "name", + "h_name", + "h_vendor_id", + "h_type", + ], + ) + assert_props(Person, ["id", "name", "type"]) + assert_instrumented(Person, ["id", "name", "type"]) + assert_props( + Employee, + ["boss", "boss_id", "employee_number", "id", "name", "type"], + ) + assert_instrumented( + Employee, + ["boss", "boss_id", "employee_number", "id", "name", "type"], + ) + assert_props( + Manager, + [ + "boss", + "boss_id", + "employee_number", + "peon", + "id", + "name", + "type", + ], + ) # 'peon' and 'type' are both explicitly stated properties - assert_instrumented(Manager, ['peon', 'type', 'id']) + assert_instrumented(Manager, ["peon", "type", "id"]) - assert_props(Vendor, ['vendor_id', 'id', 'name', 'type']) - assert_props(Hoho, ['id', 'name', 'type']) - assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type']) - assert_props(Fub, ['id', 'type']) - assert_props(Frob, ['f_id', 'f_type', 'f_name', ]) + assert_props(Vendor, ["vendor_id", "id", "name", "type"]) + assert_props(Hoho, ["id", "name", "type"]) + assert_props(Lala, ["p_employee_number", "p_id", "p_name", "p_type"]) + assert_props(Fub, ["id", "type"]) + assert_props(Frob, ["f_id", "f_type", "f_name"]) # putting the discriminator column in exclude_properties, # very weird. As of 0.7.4 this re-maps it. 
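The "very weird" case in the comment above is asserted immediately below: even when the discriminator column is named in exclude_properties, the mapper re-maps it so polymorphic loading keeps working. A standalone sketch of the same behavior, using an illustrative classic-mapping setup:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.orm import configure_mappers, mapper

    m = MetaData()
    person = Table(
        "person",
        m,
        Column("id", Integer, primary_key=True),
        Column("type", String(50)),
    )

    class Person(object):
        pass

    class Foo(Person):
        pass

    mapper(
        Person,
        person,
        polymorphic_on=person.c.type,
        polymorphic_identity="person",
    )
    mapper(
        Foo,
        inherits=Person,
        polymorphic_identity="foo",
        exclude_properties=("type",),
    )
    configure_mappers()

    # despite the exclusion, the discriminator stays mapped
    assert hasattr(Foo, "type")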
class Foo(Person): pass - assert_props(Empty, ['empty_id']) + + assert_props(Empty, ["empty_id"]) mapper( - Foo, inherits=Person, polymorphic_identity='foo', - exclude_properties=('type', ), + Foo, + inherits=Person, + polymorphic_identity="foo", + exclude_properties=("type",), ) assert hasattr(Foo, "type") assert Foo.type.property.columns[0] is t.c.type @@ -957,26 +1150,32 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): @testing.provide_metadata def test_prop_filters_defaults(self): metadata = self.metadata - t = Table('t', metadata, - Column( - 'id', Integer(), primary_key=True, - test_needs_autoincrement=True), - Column('x', Integer(), nullable=False, server_default='0') - ) + t = Table( + "t", + metadata, + Column( + "id", + Integer(), + primary_key=True, + test_needs_autoincrement=True, + ), + Column("x", Integer(), nullable=False, server_default="0"), + ) t.create() class A(object): pass - mapper(A, t, include_properties=['id']) + + mapper(A, t, include_properties=["id"]) s = Session() s.add(A()) s.commit() def test_we_dont_call_bool(self): class NoBoolAllowed(object): - def __bool__(self): raise Exception("nope") + mapper(NoBoolAllowed, self.tables.users) u1 = NoBoolAllowed() u1.name = "some name" @@ -987,21 +1186,22 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_we_dont_call_eq(self): class NoEqAllowed(object): - def __eq__(self, other): raise Exception("nope") addresses, users = self.tables.addresses, self.tables.users Address = self.classes.Address - mapper(NoEqAllowed, users, properties={ - 'addresses': relationship(Address, backref='user') - }) + mapper( + NoEqAllowed, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) u1 = NoEqAllowed() u1.name = "some name" - u1.addresses = [Address(id=12, email_address='a1')] + u1.addresses = [Address(id=12, email_address="a1")] s = Session(testing.db) s.add(u1) s.commit() @@ -1012,125 +1212,162 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_mapping_to_join_raises(self): """Test implicit merging of two cols raises.""" - addresses, users, User = (self.tables.addresses, - self.tables.users, - self.classes.User) + addresses, users, User = ( + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - usersaddresses = sa.join(users, addresses, - users.c.id == addresses.c.user_id) + usersaddresses = sa.join( + users, addresses, users.c.id == addresses.c.user_id + ) assert_raises_message( sa.exc.InvalidRequestError, "Implicitly", - mapper, User, usersaddresses, primary_key=[users.c.id] + mapper, + User, + usersaddresses, + primary_key=[users.c.id], ) def test_mapping_to_join_explicit_prop(self): """Mapping to a join""" - User, addresses, users = (self.classes.User, - self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) - usersaddresses = sa.join(users, addresses, users.c.id - == addresses.c.user_id) - mapper(User, usersaddresses, primary_key=[users.c.id], - properties={'add_id': addresses.c.id} - ) + usersaddresses = sa.join( + users, addresses, users.c.id == addresses.c.user_id + ) + mapper( + User, + usersaddresses, + primary_key=[users.c.id], + properties={"add_id": addresses.c.id}, + ) result = create_session().query(User).order_by(users.c.id).all() eq_(result, self.static.user_result[:3]) def test_mapping_to_join_exclude_prop(self): """Mapping to a join""" - User, addresses, users = (self.classes.User, - 
self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) - usersaddresses = sa.join(users, addresses, users.c.id - == addresses.c.user_id) - mapper(User, usersaddresses, primary_key=[users.c.id], - exclude_properties=[addresses.c.id] - ) + usersaddresses = sa.join( + users, addresses, users.c.id == addresses.c.user_id + ) + mapper( + User, + usersaddresses, + primary_key=[users.c.id], + exclude_properties=[addresses.c.id], + ) result = create_session().query(User).order_by(users.c.id).all() eq_(result, self.static.user_result[:3]) def test_mapping_to_join_no_pk(self): - email_bounces, addresses, Address = (self.tables.email_bounces, - self.tables.addresses, - self.classes.Address) - - m = mapper(Address, - addresses.join(email_bounces), - properties={'id': [addresses.c.id, email_bounces.c.id]} - ) + email_bounces, addresses, Address = ( + self.tables.email_bounces, + self.tables.addresses, + self.classes.Address, + ) + + m = mapper( + Address, + addresses.join(email_bounces), + properties={"id": [addresses.c.id, email_bounces.c.id]}, + ) configure_mappers() assert addresses in m._pks_by_table assert email_bounces not in m._pks_by_table sess = create_session() - a = Address(id=10, email_address='e1') + a = Address(id=10, email_address="e1") sess.add(a) sess.flush() - eq_( - select([func.count('*')]).select_from(addresses).scalar(), 6) - eq_( - select([func.count('*')]).select_from(email_bounces).scalar(), 5) + eq_(select([func.count("*")]).select_from(addresses).scalar(), 6) + eq_(select([func.count("*")]).select_from(email_bounces).scalar(), 5) def test_mapping_to_outerjoin(self): """Mapping to an outer join with a nullable composite primary key.""" - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users.outerjoin(addresses), - primary_key=[users.c.id, addresses.c.id], - properties=dict( - address_id=addresses.c.id)) + mapper( + User, + users.outerjoin(addresses), + primary_key=[users.c.id, addresses.c.id], + properties=dict(address_id=addresses.c.id), + ) session = create_session() result = session.query(User).order_by(User.id, User.address_id).all() - eq_(result, [ - User(id=7, address_id=1), - User(id=8, address_id=2), - User(id=8, address_id=3), - User(id=8, address_id=4), - User(id=9, address_id=5), - User(id=10, address_id=None)]) + eq_( + result, + [ + User(id=7, address_id=1), + User(id=8, address_id=2), + User(id=8, address_id=3), + User(id=8, address_id=4), + User(id=9, address_id=5), + User(id=10, address_id=None), + ], + ) def test_mapping_to_outerjoin_no_partial_pks(self): """test the allow_partial_pks=False flag.""" - users, addresses, User = (self.tables.users, - self.tables.addresses, - self.classes.User) + users, addresses, User = ( + self.tables.users, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users.outerjoin(addresses), - allow_partial_pks=False, - primary_key=[users.c.id, addresses.c.id], - properties=dict( - address_id=addresses.c.id)) + mapper( + User, + users.outerjoin(addresses), + allow_partial_pks=False, + primary_key=[users.c.id, addresses.c.id], + properties=dict(address_id=addresses.c.id), + ) session = create_session() result = session.query(User).order_by(User.id, User.address_id).all() - eq_(result, [ - User(id=7, address_id=1), - User(id=8, address_id=2), - User(id=8, address_id=3), - 
User(id=8, address_id=4), - User(id=9, address_id=5), - None]) + eq_( + result, + [ + User(id=7, address_id=1), + User(id=8, address_id=2), + User(id=8, address_id=3), + User(id=8, address_id=4), + User(id=9, address_id=5), + None, + ], + ) def test_scalar_pk_arg(self): - users, Keyword, items, Item, User, keywords = (self.tables.users, - self.classes.Keyword, - self.tables.items, - self.classes.Item, - self.classes.User, - self.tables.keywords) + users, Keyword, items, Item, User, keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.items, + self.classes.Item, + self.classes.User, + self.tables.keywords, + ) m1 = mapper(Item, items, primary_key=[items.c.id]) m2 = mapper(Keyword, keywords, primary_key=keywords.c.id) @@ -1150,22 +1387,25 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.tables.orders, self.classes.Item, self.classes.User, - self.classes.Order) + self.classes.Order, + ) mapper(Item, items) - mapper(Order, orders, properties=dict( - items=relationship(Item, order_items))) + mapper( + Order, + orders, + properties=dict(items=relationship(Item, order_items)), + ) - mapper(User, users, properties=dict( - orders=relationship(Order))) + mapper(User, users, properties=dict(orders=relationship(Order))) session = create_session() - result = (session.query(User). - select_from(users.join(orders). - join(order_items). - join(items)). - filter(items.c.description == 'item 4')).all() + result = ( + session.query(User) + .select_from(users.join(orders).join(order_items).join(items)) + .filter(items.c.description == "item 4") + ).all() eq_(result, [self.static.user_result[0]]) @@ -1175,40 +1415,66 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): mapper(User, users, order_by=users.c.name.desc()) - assert "order by users.name desc" in \ - str(create_session().query(User).statement).lower() - assert "order by" not in \ - str(create_session().query(User).order_by(None).statement).lower() - assert "order by users.name asc" in \ - str(create_session().query(User).order_by( - User.name.asc()).statement).lower() + assert ( + "order by users.name desc" + in str(create_session().query(User).statement).lower() + ) + assert ( + "order by" + not in str( + create_session().query(User).order_by(None).statement + ).lower() + ) + assert ( + "order by users.name asc" + in str( + create_session() + .query(User) + .order_by(User.name.asc()) + .statement + ).lower() + ) eq_( create_session().query(User).all(), - [User(id=7, name='jack'), User(id=9, name='fred'), - User(id=8, name='ed'), User(id=10, name='chuck')] + [ + User(id=7, name="jack"), + User(id=9, name="fred"), + User(id=8, name="ed"), + User(id=10, name="chuck"), + ], ) eq_( create_session().query(User).order_by(User.name).all(), - [User(id=10, name='chuck'), User(id=8, name='ed'), - User(id=9, name='fred'), User(id=7, name='jack')] + [ + User(id=10, name="chuck"), + User(id=8, name="ed"), + User(id=9, name="fred"), + User(id=7, name="jack"), + ], ) # 'Raises a "expression evaluation not supported" error at prepare time - @testing.fails_on('firebird', 'FIXME: unknown') + @testing.fails_on("firebird", "FIXME: unknown") def test_function(self): """Mapping to a SELECT statement that has functions in it.""" - addresses, users, User = (self.tables.addresses, - self.tables.users, - self.classes.User) + addresses, users, User = ( + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - s = sa.select([users, - (users.c.id * 2).label('concat'), - 
sa.func.count(addresses.c.id).label('count')], - users.c.id == addresses.c.user_id, - group_by=[c for c in users.c]).alias('myselect') + s = sa.select( + [ + users, + (users.c.id * 2).label("concat"), + sa.func.count(addresses.c.id).label("count"), + ], + users.c.id == addresses.c.user_id, + group_by=[c for c in users.c], + ).alias("myselect") mapper(User, s) sess = create_session() @@ -1241,66 +1507,89 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, item_keywords, lazy='select'))) + mapper( + Item, + items, + properties=dict( + keywords=relationship(Keyword, item_keywords, lazy="select") + ), + ) session = create_session() - q = (session.query(Item). - join('keywords'). - distinct(). - filter(Keyword.name == "red")) + q = ( + session.query(Item) + .join("keywords") + .distinct() + .filter(Keyword.name == "red") + ) eq_(q.count(), 2) def test_override_1(self): """Overriding a column raises an error.""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) def go(): - mapper(User, users, - properties=dict( - name=relationship(mapper(Address, addresses)))) + mapper( + User, + users, + properties=dict(name=relationship(mapper(Address, addresses))), + ) assert_raises(sa.exc.ArgumentError, go) def test_override_2(self): """exclude_properties cancels the error.""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, - exclude_properties=['name'], - properties=dict( - name=relationship(mapper(Address, addresses)))) + mapper( + User, + users, + exclude_properties=["name"], + properties=dict(name=relationship(mapper(Address, addresses))), + ) assert bool(User.name) def test_override_3(self): """The column being named elsewhere also cancels the error,""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, - properties=dict( - name=relationship(mapper(Address, addresses)), - foo=users.c.name)) + mapper( + User, + users, + properties=dict( + name=relationship(mapper(Address, addresses)), foo=users.c.name + ), + ) def test_synonym(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) assert_col = [] @@ -1308,34 +1597,40 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): attribute = 123 class User(object): - def _get_name(self): - assert_col.append(('get', self.name)) + assert_col.append(("get", self.name)) return self.name def _set_name(self, name): - assert_col.append(('set', name)) + assert_col.append(("set", name)) self.name = name + uname = extendedproperty(_get_name, _set_name) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, 
addresses), lazy='select'), - uname=synonym('name'), - adlist=synonym('addresses'), - adname=synonym('addresses') - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), lazy="select" + ), + uname=synonym("name"), + adlist=synonym("addresses"), + adname=synonym("addresses"), + ), + ) # ensure the synonym can get at the proxied comparators without # an explicit compile - User.name == 'ed' + User.name == "ed" User.adname.any() - assert hasattr(User, 'adlist') + assert hasattr(User, "adlist") # as of 0.4.2, synonyms always create a property - assert hasattr(User, 'adname') + assert hasattr(User, "adname") # test compile - assert not isinstance(User.uname == 'jack', bool) + assert not isinstance(User.uname == "jack", bool) assert User.uname.property assert User.adlist.property @@ -1346,7 +1641,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): row = sess.query(User.id, User.uname).first() assert row.uname == row[1] - u = sess.query(User).filter(User.uname == 'jack').one() + u = sess.query(User).filter(User.uname == "jack").one() fixture = self.static.user_address_result[0].addresses eq_(u.adlist, fixture) @@ -1360,35 +1655,27 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert u not in sess.dirty u.uname = "some user name" assert len(assert_col) > 0 - eq_(assert_col, [('set', 'some user name')]) + eq_(assert_col, [("set", "some user name")]) eq_(u.uname, "some user name") - eq_(assert_col, [('set', 'some user name'), ('get', 'some user name')]) + eq_(assert_col, [("set", "some user name"), ("get", "some user name")]) eq_(u.name, "some user name") assert u in sess.dirty eq_(User.uname.attribute, 123) def test_synonym_of_synonym(self): - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) - mapper(User, users, properties={ - 'x': synonym('id'), - 'y': synonym('x') - }) + mapper(User, users, properties={"x": synonym("id"), "y": synonym("x")}) s = Session() u = s.query(User).filter(User.y == 8).one() eq_(u.y, 8) def test_synonym_get_history(self): - users, User = (self.tables.users, - self.classes.User) + users, User = (self.tables.users, self.classes.User) - mapper(User, users, properties={ - 'x': synonym('id'), - 'y': synonym('x') - }) + mapper(User, users, properties={"x": synonym("id"), "y": synonym("x")}) u1 = User() eq_(attributes.instance_state(u1).attrs.x.history, (None, None, None)) @@ -1407,74 +1694,90 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): users, Address, addresses = ( self.tables.users, self.classes.Address, - self.tables.addresses) + self.tables.addresses, + ) - mapper(User, users, properties={ - 'y': synonym('x'), - 'addresses': relationship(Address) - }) + mapper( + User, + users, + properties={"y": synonym("x"), "addresses": relationship(Address)}, + ) mapper(Address, addresses) User.x = association_proxy("addresses", "email_address") assert_raises_message( sa.exc.InvalidRequestError, r'synonym\(\) attribute "User.x" only supports ORM mapped ' - 'attributes, got .*AssociationProxy', - getattr, User.y, "property" + "attributes, got .*AssociationProxy", + getattr, + User.y, + "property", ) def test_synonym_column_location(self): users, User = self.tables.users, self.classes.User def go(): - mapper(User, users, properties={ - 'not_name': synonym('_name', map_column=True)}) + mapper( + User, + users, + properties={"not_name": synonym("_name", map_column=True)}, + ) assert_raises_message( sa.exc.ArgumentError, - ("Can't 
compile synonym '_name': no column on table " - "'users' named 'not_name'"), - go) + ( + "Can't compile synonym '_name': no column on table " + "'users' named 'not_name'" + ), + go, + ) def test_column_synonyms(self): """Synonyms which automatically instrument properties, set up aliased column, etc.""" - addresses, users, Address = (self.tables.addresses, - self.tables.users, - self.classes.Address) + addresses, users, Address = ( + self.tables.addresses, + self.tables.users, + self.classes.Address, + ) assert_col = [] class User(object): - def _get_name(self): - assert_col.append(('get', self._name)) + assert_col.append(("get", self._name)) return self._name def _set_name(self, name): - assert_col.append(('set', name)) + assert_col.append(("set", name)) self._name = name + name = property(_get_name, _set_name) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='select'), - 'name': synonym('_name', map_column=True) - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, lazy="select"), + "name": synonym("_name", map_column=True), + }, + ) # test compile - assert not isinstance(User.name == 'jack', bool) + assert not isinstance(User.name == "jack", bool) - assert hasattr(User, 'name') - assert hasattr(User, '_name') + assert hasattr(User, "name") + assert hasattr(User, "_name") sess = create_session() - u = sess.query(User).filter(User.name == 'jack').one() - eq_(u.name, 'jack') - u.name = 'foo' - eq_(u.name, 'foo') - eq_(assert_col, [('get', 'jack'), ('set', 'foo'), ('get', 'foo')]) + u = sess.query(User).filter(User.name == "jack").one() + eq_(u.name, "jack") + u.name = "foo" + eq_(u.name, "foo") + eq_(assert_col, [("get", "jack"), ("set", "foo"), ("get", "foo")]) def test_synonym_map_column_conflict(self): users, User = self.tables.users, self.classes.User @@ -1482,19 +1785,27 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert_raises( sa.exc.ArgumentError, mapper, - User, users, properties=util.OrderedDict([ - ('_user_id', users.c.id), - ('id', synonym('_user_id', map_column=True)), - ]) + User, + users, + properties=util.OrderedDict( + [ + ("_user_id", users.c.id), + ("id", synonym("_user_id", map_column=True)), + ] + ), ) assert_raises( sa.exc.ArgumentError, mapper, - User, users, properties=util.OrderedDict([ - ('id', synonym('_user_id', map_column=True)), - ('_user_id', users.c.id), - ]) + User, + users, + properties=util.OrderedDict( + [ + ("id", synonym("_user_id", map_column=True)), + ("_user_id", users.c.id), + ] + ), ) def test_comparable(self): @@ -1519,7 +1830,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def __eq__(self, other): cls = self.prop.parent.class_ - col = getattr(cls, 'name') + col = getattr(cls, "name") if other is None: return col is None else: @@ -1527,18 +1838,21 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def map_(with_explicit_property): class User(object): - @extendedproperty def uc_name(self): if self.name is None: return None return self.name.upper() + if with_explicit_property: args = (UCComparator, User.uc_name) else: args = (UCComparator,) - mapper(User, users, properties=dict( - uc_name=sa.orm.comparable_property(*args))) + mapper( + User, + users, + properties=dict(uc_name=sa.orm.comparable_property(*args)), + ) return User for User in (map_(True), map_(False)): @@ -1546,22 +1860,25 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): sess.begin() q = sess.query(User) - assert hasattr(User, 'name') - 
assert hasattr(User, 'uc_name') + assert hasattr(User, "name") + assert hasattr(User, "uc_name") eq_(User.uc_name.method1(), "method1") - eq_(User.uc_name.method2('x'), "method2") + eq_(User.uc_name.method2("x"), "method2") assert_raises_message( AttributeError, "Neither 'extendedproperty' object nor 'UCComparator' " "object associated with User.uc_name has an attribute " "'nonexistent'", - getattr, User.uc_name, 'nonexistent') + getattr, + User.uc_name, + "nonexistent", + ) # test compile - assert not isinstance(User.uc_name == 'jack', bool) - u = q.filter(User.uc_name == 'JACK').one() + assert not isinstance(User.uc_name == "jack", bool) + u = q.filter(User.uc_name == "JACK").one() assert u.uc_name == "JACK" assert u not in sess.dirty @@ -1575,8 +1892,8 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): sess.expunge_all() q = sess.query(User) - u2 = q.filter(User.name == 'some user name').one() - u3 = q.filter(User.uc_name == 'SOME USER NAME').one() + u2 = q.filter(User.name == "some user name").one() + u3 = q.filter(User.uc_name == "SOME USER NAME").one() assert u2 is u3 @@ -1591,63 +1908,80 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def __eq__(self, other): # lower case comparison - return func.lower(self.__clause_element__() - ) == func.lower(other) + return func.lower(self.__clause_element__()) == func.lower( + other + ) def intersects(self, other): # non-standard comparator - return self.__clause_element__().op('&=')(other) + return self.__clause_element__().op("&=")(other) - mapper(User, users, properties={ - 'name': sa.orm.column_property(users.c.name, - comparator_factory=MyComparator) - }) + mapper( + User, + users, + properties={ + "name": sa.orm.column_property( + users.c.name, comparator_factory=MyComparator + ) + }, + ) assert_raises_message( AttributeError, "Neither 'InstrumentedAttribute' object nor " "'MyComparator' object associated with User.name has " "an attribute 'nonexistent'", - getattr, User.name, "nonexistent") + getattr, + User.name, + "nonexistent", + ) eq_( - str((User.name == 'ed').compile( - dialect=sa.engine.default.DefaultDialect())), - "lower(users.name) = lower(:lower_1)") + str( + (User.name == "ed").compile( + dialect=sa.engine.default.DefaultDialect() + ) + ), + "lower(users.name) = lower(:lower_1)", + ) eq_( - str((User.name.intersects('ed')).compile( - dialect=sa.engine.default.DefaultDialect())), - "users.name &= :name_1") + str( + (User.name.intersects("ed")).compile( + dialect=sa.engine.default.DefaultDialect() + ) + ), + "users.name &= :name_1", + ) def test_reentrant_compile(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) class MyFakeProperty(sa.orm.properties.ColumnProperty): - def post_instrument_class(self, mapper): super(MyFakeProperty, self).post_instrument_class(mapper) configure_mappers() - m1 = mapper(User, users, properties={ - 'name': MyFakeProperty(users.c.name) - }) + m1 = mapper( + User, users, properties={"name": MyFakeProperty(users.c.name)} + ) m2 = mapper(Address, addresses) configure_mappers() sa.orm.clear_mappers() class MyFakeProperty(sa.orm.properties.ColumnProperty): - def post_instrument_class(self, mapper): super(MyFakeProperty, self).post_instrument_class(mapper) configure_mappers() - m1 = mapper(User, users, properties={ - 'name': MyFakeProperty(users.c.name) - }) + m1 = 
mapper( + User, users, properties={"name": MyFakeProperty(users.c.name)} + ) m2 = mapper(Address, addresses) configure_mappers() @@ -1657,17 +1991,16 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): recon = [] class User(object): - @reconstructor def reconstruct(self): - recon.append('go') + recon.append("go") mapper(User, users) User() eq_(recon, []) create_session().query(User).first() - eq_(recon, ['go']) + eq_(recon, ["go"]) def test_reconstructor_inheritance(self): users = self.tables.users @@ -1675,30 +2008,28 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): recon = [] class A(object): - @reconstructor def reconstruct(self): assert isinstance(self, A) - recon.append('A') + recon.append("A") class B(A): - @reconstructor def reconstruct(self): assert isinstance(self, B) - recon.append('B') + recon.append("B") class C(A): - @reconstructor def reconstruct(self): assert isinstance(self, C) - recon.append('C') + recon.append("C") - mapper(A, users, polymorphic_on=users.c.name, - polymorphic_identity='jack') - mapper(B, inherits=A, polymorphic_identity='ed') - mapper(C, inherits=A, polymorphic_identity='chuck') + mapper( + A, users, polymorphic_on=users.c.name, polymorphic_identity="jack" + ) + mapper(B, inherits=A, polymorphic_identity="ed") + mapper(C, inherits=A, polymorphic_identity="chuck") A() B() @@ -1709,7 +2040,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): sess.query(A).first() sess.query(B).first() sess.query(C).first() - eq_(recon, ['A', 'B', 'C']) + eq_(recon, ["A", "B", "C"]) def test_reconstructor_init(self): @@ -1718,19 +2049,18 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): recon = [] class User(object): - @reconstructor def __init__(self): - recon.append('go') + recon.append("go") mapper(User, users) User() - eq_(recon, ['go']) + eq_(recon, ["go"]) recon[:] = [] create_session().query(User).first() - eq_(recon, ['go']) + eq_(recon, ["go"]) def test_reconstructor_init_inheritance(self): users = self.tables.users @@ -1738,43 +2068,40 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): recon = [] class A(object): - @reconstructor def __init__(self): assert isinstance(self, A) - recon.append('A') + recon.append("A") class B(A): - @reconstructor def __init__(self): assert isinstance(self, B) - recon.append('B') + recon.append("B") class C(A): - @reconstructor def __init__(self): assert isinstance(self, C) - recon.append('C') + recon.append("C") - mapper(A, users, polymorphic_on=users.c.name, - polymorphic_identity='jack') - mapper(B, inherits=A, polymorphic_identity='ed') - mapper(C, inherits=A, polymorphic_identity='chuck') + mapper( + A, users, polymorphic_on=users.c.name, polymorphic_identity="jack" + ) + mapper(B, inherits=A, polymorphic_identity="ed") + mapper(C, inherits=A, polymorphic_identity="chuck") A() B() C() - eq_(recon, ['A', 'B', 'C']) + eq_(recon, ["A", "B", "C"]) recon[:] = [] sess = create_session() sess.query(A).first() sess.query(B).first() sess.query(C).first() - eq_(recon, ['A', 'B', 'C']) - + eq_(recon, ["A", "B", "C"]) def test_unmapped_reconstructor_inheritance(self): users = self.tables.users @@ -1782,10 +2109,9 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): recon = [] class Base(object): - @reconstructor def reconstruct(self): - recon.append('go') + recon.append("go") class User(Base): pass @@ -1796,38 +2122,41 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): eq_(recon, []) create_session().query(User).first() - eq_(recon, ['go']) + eq_(recon, 
["go"]) def test_unmapped_error(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) sa.orm.clear_mappers() - mapper(User, users, properties={ - 'addresses': relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) assert_raises_message( sa.orm.exc.UnmappedClassError, "Class 'test.orm._fixtures.Address' is not mapped", - sa.orm.configure_mappers) + sa.orm.configure_mappers, + ) def test_unmapped_not_type_error(self): assert_raises_message( sa.exc.ArgumentError, "Class object expected, got '5'.", - class_mapper, 5 + class_mapper, + 5, ) def test_unmapped_not_type_error_iter_ok(self): assert_raises_message( sa.exc.ArgumentError, r"Class object expected, got '\(5, 6\)'.", - class_mapper, (5, 6) + class_mapper, + (5, 6), ) def test_attribute_error_raised_class_mapper(self): @@ -1836,16 +2165,22 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - mapper(User, users, properties={ - "addresses": relationship( - Address, - primaryjoin=lambda: users.c.id == addresses.wrong.user_id) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + primaryjoin=lambda: users.c.id == addresses.wrong.user_id, + ) + }, + ) mapper(Address, addresses) assert_raises_message( AttributeError, "'Table' object has no attribute 'wrong'", - class_mapper, Address + class_mapper, + Address, ) def test_key_error_raised_class_mapper(self): @@ -1854,17 +2189,19 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): User = self.classes.User Address = self.classes.Address - mapper(User, users, properties={ - "addresses": relationship(Address, - primaryjoin=lambda: users.c.id == - addresses.__dict__['wrong'].user_id) - }) - mapper(Address, addresses) - assert_raises_message( - KeyError, - "wrong", - class_mapper, Address + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + primaryjoin=lambda: users.c.id + == addresses.__dict__["wrong"].user_id, + ) + }, ) + mapper(Address, addresses) + assert_raises_message(KeyError, "wrong", class_mapper, Address) def test_unmapped_subclass_error_postmap(self): users = self.tables.users @@ -1880,16 +2217,14 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): # we can create new instances, set attributes. s = Sub() - s.name = 'foo' - eq_(s.name, 'foo') - eq_( - attributes.get_history(s, 'name'), - (['foo'], (), ()) - ) + s.name = "foo" + eq_(s.name, "foo") + eq_(attributes.get_history(s, "name"), (["foo"], (), ())) # using it with an ORM operation, raises - assert_raises(sa.orm.exc.UnmappedClassError, - create_session().add, Sub()) + assert_raises( + sa.orm.exc.UnmappedClassError, create_session().add, Sub() + ) def test_unmapped_subclass_error_premap(self): users = self.tables.users @@ -1906,16 +2241,14 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): # we can create new instances, set attributes. 
s = Sub() - s.name = 'foo' - eq_(s.name, 'foo') - eq_( - attributes.get_history(s, 'name'), - (['foo'], (), ()) - ) + s.name = "foo" + eq_(s.name, "foo") + eq_(attributes.get_history(s, "name"), (["foo"], (), ())) # using it with an ORM operation, raises - assert_raises(sa.orm.exc.UnmappedClassError, - create_session().add, Sub()) + assert_raises( + sa.orm.exc.UnmappedClassError, create_session().add, Sub() + ) def test_oldstyle_mixin(self): users = self.tables.users @@ -1936,25 +2269,35 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): mapper(B, users) -class DocumentTest(fixtures.TestBase): +class DocumentTest(fixtures.TestBase): def test_doc_propagate(self): metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer, primary_key=True, - doc="primary key column"), - Column('col2', String, doc="data col"), - Column('col3', String, doc="data col 2"), - Column('col4', String, doc="data col 3"), - Column('col5', String), - ) - t2 = Table('t2', metadata, - Column('col1', Integer, primary_key=True, - doc="primary key column"), - Column('col2', String, doc="data col"), - Column('col3', Integer, ForeignKey('t1.col1'), - doc="foreign key to t1.col1") - ) + t1 = Table( + "t1", + metadata, + Column( + "col1", Integer, primary_key=True, doc="primary key column" + ), + Column("col2", String, doc="data col"), + Column("col3", String, doc="data col 2"), + Column("col4", String, doc="data col 3"), + Column("col5", String), + ) + t2 = Table( + "t2", + metadata, + Column( + "col1", Integer, primary_key=True, doc="primary key column" + ), + Column("col2", String, doc="data col"), + Column( + "col3", + Integer, + ForeignKey("t1.col1"), + doc="foreign key to t1.col1", + ), + ) class Foo(object): pass @@ -1962,14 +2305,19 @@ class DocumentTest(fixtures.TestBase): class Bar(object): pass - mapper(Foo, t1, properties={ - 'bars': relationship(Bar, - doc="bar relationship", - backref=backref('foo', doc='foo relationship') - ), - 'foober': column_property(t1.c.col3, doc='alternate data col'), - 'hoho': synonym("col4", doc="syn of col4") - }) + mapper( + Foo, + t1, + properties={ + "bars": relationship( + Bar, + doc="bar relationship", + backref=backref("foo", doc="foo relationship"), + ), + "foober": column_property(t1.c.col3, doc="alternate data col"), + "hoho": synonym("col4", doc="syn of col4"), + }, + ) mapper(Bar, t2) configure_mappers() eq_(Foo.col1.__doc__, "primary key column") @@ -1983,18 +2331,13 @@ class DocumentTest(fixtures.TestBase): class ORMLoggingTest(_fixtures.FixtureTest): - def setup(self): self.buf = logging.handlers.BufferingHandler(100) - for log in [ - logging.getLogger('sqlalchemy.orm'), - ]: + for log in [logging.getLogger("sqlalchemy.orm")]: log.addHandler(self.buf) def teardown(self): - for log in [ - logging.getLogger('sqlalchemy.orm'), - ]: + for log in [logging.getLogger("sqlalchemy.orm")]: log.removeHandler(self.buf) def _current_messages(self): @@ -2005,74 +2348,108 @@ class ORMLoggingTest(_fixtures.FixtureTest): tb = users.select().alias() mapper(User, tb) s = Session() - s.add(User(name='ed')) + s.add(User(name="ed")) s.commit() for msg in self._current_messages(): - assert msg.startswith('(User|%%(%d anon)s) ' % id(tb)) + assert msg.startswith("(User|%%(%d anon)s) " % id(tb)) class OptionsTest(_fixtures.FixtureTest): - def test_synonym_options(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + 
self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), lazy='select', - order_by=addresses.c.id), - adlist=synonym('addresses'))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), + lazy="select", + order_by=addresses.c.id, + ), + adlist=synonym("addresses"), + ), + ) def go(): sess = create_session() - u = (sess.query(User). - order_by(User.id). - options(sa.orm.joinedload('adlist')). - filter_by(name='jack')).one() - eq_(u.adlist, - [self.static.user_address_result[0].addresses[0]]) + u = ( + sess.query(User) + .order_by(User.id) + .options(sa.orm.joinedload("adlist")) + .filter_by(name="jack") + ).one() + eq_(u.adlist, [self.static.user_address_result[0].addresses[0]]) + self.assert_sql_count(testing.db, go, 1) def test_eager_options(self): """A lazy relationship can be upgraded to an eager relationship.""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), - order_by=addresses.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), order_by=addresses.c.id + ) + ), + ) sess = create_session() - result = (sess.query(User). - order_by(User.id). - options(sa.orm.joinedload('addresses'))).all() + result = ( + sess.query(User) + .order_by(User.id) + .options(sa.orm.joinedload("addresses")) + ).all() def go(): eq_(result, self.static.user_address_result) + self.sql_count_(0, go) def test_eager_options_with_limit(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), lazy='select'))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), lazy="select" + ) + ), + ) sess = create_session() - u = (sess.query(User). - options(sa.orm.joinedload('addresses')). - filter_by(id=8)).one() + u = ( + sess.query(User) + .options(sa.orm.joinedload("addresses")) + .filter_by(id=8) + ).one() def go(): eq_(u.id, 8) eq_(len(u.addresses), 3) + self.sql_count_(0, go) sess.expunge_all() @@ -2082,36 +2459,58 @@ class OptionsTest(_fixtures.FixtureTest): eq_(len(u.addresses), 3) def test_lazy_options_with_limit(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), lazy='joined'))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), lazy="joined" + ) + ), + ) sess = create_session() - u = (sess.query(User). - options(sa.orm.lazyload('addresses')). 
- filter_by(id=8)).one() + u = ( + sess.query(User) + .options(sa.orm.lazyload("addresses")) + .filter_by(id=8) + ).one() def go(): eq_(u.id, 8) eq_(len(u.addresses), 3) + self.sql_count_(1, go) def test_eager_degrade(self): """An eager relationship automatically degrades to a lazy relationship if eager columns are not available""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), - lazy='joined', order_by=addresses.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), + lazy="joined", + order_by=addresses.c.id, + ) + ), + ) sess = create_session() # first test straight eager load, 1 statement @@ -2119,6 +2518,7 @@ class OptionsTest(_fixtures.FixtureTest): def go(): result = sess.query(User).order_by(User.id).all() eq_(result, self.static.user_address_result) + self.sql_count_(1, go) sess.expunge_all() @@ -2132,23 +2532,24 @@ class OptionsTest(_fixtures.FixtureTest): def go(): result = list(sess.query(User).instances(r)) eq_(result, self.static.user_address_result) + self.sql_count_(4, go) def test_eager_degrade_deep(self): - users, Keyword, items, order_items, orders, \ - Item, User, Address, keywords, item_keywords, Order, addresses = ( - self.tables.users, - self.classes.Keyword, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.tables.keywords, - self.tables.item_keywords, - self.classes.Order, - self.tables.addresses) + users, Keyword, items, order_items, orders, Item, User, Address, keywords, item_keywords, Order, addresses = ( + self.tables.users, + self.classes.Keyword, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.tables.keywords, + self.tables.item_keywords, + self.classes.Order, + self.tables.addresses, + ) # test with a deeper set of eager loads. when we first load the three # users, they will have no addresses or orders. 
the number of lazy @@ -2158,20 +2559,44 @@ class OptionsTest(_fixtures.FixtureTest): mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='joined', - order_by=item_keywords.c.keyword_id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="joined", + order_by=item_keywords.c.keyword_id, + ) + ), + ) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items, lazy='joined', - order_by=order_items.c.item_id))) + mapper( + Order, + orders, + properties=dict( + items=relationship( + Item, + secondary=order_items, + lazy="joined", + order_by=order_items.c.item_id, + ) + ), + ) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='joined', - order_by=addresses.c.id), - orders=relationship(Order, lazy='joined', - order_by=orders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="joined", order_by=addresses.c.id + ), + orders=relationship( + Order, lazy="joined", order_by=orders.c.id + ), + ), + ) sess = create_session() @@ -2179,6 +2604,7 @@ class OptionsTest(_fixtures.FixtureTest): def go(): result = sess.query(User).order_by(User.id).all() eq_(result, self.static.user_all_result) + self.assert_sql_count(testing.db, go, 1) sess.expunge_all() @@ -2190,27 +2616,39 @@ class OptionsTest(_fixtures.FixtureTest): def go(): result = list(sess.query(User).instances(r)) eq_(result, self.static.user_all_result) + self.assert_sql_count(testing.db, go, 6) def test_lazy_options(self): """An eager relationship can be upgraded to a lazy relationship.""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), lazy='joined') - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), lazy="joined" + ) + ), + ) sess = create_session() - result = (sess.query(User). - order_by(User.id). 
- options(sa.orm.lazyload('addresses'))).all() + result = ( + sess.query(User) + .order_by(User.id) + .options(sa.orm.lazyload("addresses")) + ).all() def go(): eq_(result, self.static.user_address_result) + self.sql_count_(4, go) def test_option_propagate(self): @@ -2221,14 +2659,15 @@ class OptionsTest(_fixtures.FixtureTest): self.classes.Order, self.classes.Item, self.classes.User, - self.tables.orders) - - mapper(User, users, properties=dict( - orders=relationship(Order) - )) - mapper(Order, orders, properties=dict( - items=relationship(Item, secondary=order_items) - )) + self.tables.orders, + ) + + mapper(User, users, properties=dict(orders=relationship(Order))) + mapper( + Order, + orders, + properties=dict(items=relationship(Item, secondary=order_items)), + ) mapper(Item, items) sess = create_session() @@ -2236,42 +2675,58 @@ class OptionsTest(_fixtures.FixtureTest): oalias = aliased(Order) opt1 = sa.orm.joinedload(User.orders, Order.items) opt2 = sa.orm.contains_eager(User.orders, Order.items, alias=oalias) - u1 = sess.query(User).join(oalias, User.orders).\ - options(opt1, opt2).first() + u1 = ( + sess.query(User) + .join(oalias, User.orders) + .options(opt1, opt2) + .first() + ) ustate = attributes.instance_state(u1) assert opt1 in ustate.load_options assert opt2 not in ustate.load_options class DeepOptionsTest(_fixtures.FixtureTest): - @classmethod def setup_mappers(cls): - users, Keyword, items, order_items, Order, Item, User, \ - keywords, item_keywords, orders = ( - cls.tables.users, - cls.classes.Keyword, - cls.tables.items, - cls.tables.order_items, - cls.classes.Order, - cls.classes.Item, - cls.classes.User, - cls.tables.keywords, - cls.tables.item_keywords, - cls.tables.orders) + users, Keyword, items, order_items, Order, Item, User, keywords, item_keywords, orders = ( + cls.tables.users, + cls.classes.Keyword, + cls.tables.items, + cls.tables.order_items, + cls.classes.Order, + cls.classes.Item, + cls.classes.User, + cls.tables.keywords, + cls.tables.item_keywords, + cls.tables.orders, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, item_keywords, - order_by=item_keywords.c.item_id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, item_keywords, order_by=item_keywords.c.item_id + ) + ), + ) - mapper(Order, orders, properties=dict( - items=relationship(Item, order_items, - order_by=items.c.id))) + mapper( + Order, + orders, + properties=dict( + items=relationship(Item, order_items, order_by=items.c.id) + ), + ) - mapper(User, users, properties=dict( - orders=relationship(Order, order_by=orders.c.id))) + mapper( + User, + users, + properties=dict(orders=relationship(Order, order_by=orders.c.id)), + ) def test_deep_options_1(self): User = self.classes.User @@ -2283,6 +2738,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): def go(): u[0].orders[1].items[0].keywords[1] + self.assert_sql_count(testing.db, go, 3) def test_deep_options_2(self): @@ -2292,23 +2748,28 @@ class DeepOptionsTest(_fixtures.FixtureTest): sess = create_session() - result = (sess.query(User). - order_by(User.id). - options( - sa.orm.joinedload_all('orders.items.keywords'))).all() + result = ( + sess.query(User) + .order_by(User.id) + .options(sa.orm.joinedload_all("orders.items.keywords")) + ).all() def go(): result[0].orders[1].items[0].keywords[1] + self.sql_count_(0, go) sess = create_session() - result = (sess.query(User). 
- options( - sa.orm.subqueryload_all('orders.items.keywords'))).all() + result = ( + sess.query(User).options( + sa.orm.subqueryload_all("orders.items.keywords") + ) + ).all() def go(): result[0].orders[1].items[0].keywords[1] + self.sql_count_(0, go) def test_deep_options_3(self): @@ -2317,21 +2778,26 @@ class DeepOptionsTest(_fixtures.FixtureTest): sess = create_session() # same thing, with separate options calls - q2 = (sess.query(User). - order_by(User.id). - options(sa.orm.joinedload('orders')). - options(sa.orm.joinedload('orders.items')). - options(sa.orm.joinedload('orders.items.keywords'))) + q2 = ( + sess.query(User) + .order_by(User.id) + .options(sa.orm.joinedload("orders")) + .options(sa.orm.joinedload("orders.items")) + .options(sa.orm.joinedload("orders.items.keywords")) + ) u = q2.all() def go(): u[0].orders[1].items[0].keywords[1] + self.sql_count_(0, go) def test_deep_options_4(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = create_session() @@ -2339,36 +2805,46 @@ class DeepOptionsTest(_fixtures.FixtureTest): sa.exc.ArgumentError, "Can't find property 'items' on any entity " "specified in this Query.", - sess.query(User).options, sa.orm.joinedload(Order.items)) + sess.query(User).options, + sa.orm.joinedload(Order.items), + ) # joinedload "keywords" on items. it will lazy load "orders", then # lazy load the "items" on the order, but on "items" it will eager # load the "keywords" - q3 = sess.query(User).order_by(User.id).options( - sa.orm.joinedload('orders.items.keywords')) + q3 = ( + sess.query(User) + .order_by(User.id) + .options(sa.orm.joinedload("orders.items.keywords")) + ) u = q3.all() def go(): u[0].orders[1].items[0].keywords[1] + self.sql_count_(2, go) sess = create_session() - q3 = sess.query(User).order_by(User.id).options( - sa.orm.joinedload(User.orders, Order.items, Item.keywords)) + q3 = ( + sess.query(User) + .order_by(User.id) + .options( + sa.orm.joinedload(User.orders, Order.items, Item.keywords) + ) + ) u = q3.all() def go(): u[0].orders[1].items[0].keywords[1] + self.sql_count_(2, go) class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): - def test_kwarg_accepted(self): users, Address = self.tables.users, self.classes.Address class DummyComposite(object): - def __init__(self, x, y): pass @@ -2380,12 +2856,12 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): for args in ( (column_property, users.c.name), (deferred, users.c.name), - (synonym, 'name'), + (synonym, "name"), (composite, DummyComposite, users.c.id, users.c.name), (relationship, Address), - (backref, 'address'), - (comparable_property, ), - (dynamic_loader, Address) + (backref, "address"), + (comparable_property,), + (dynamic_loader, Address), ): fn = args[0] args = args[1:] @@ -2400,22 +2876,29 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): __hash__ = None def __eq__(self, other): - return func.foobar(self.__clause_element__()) == \ - func.foobar(other) + return func.foobar(self.__clause_element__()) == func.foobar( + other + ) + mapper( - User, users, + User, + users, properties={ - 'name': column_property( - users.c.name, comparator_factory=MyFactory)}) + "name": column_property( + users.c.name, comparator_factory=MyFactory + ) + }, + ) self.assert_compile( - User.name == 'ed', + User.name == "ed", "foobar(users.name) = foobar(:foobar_1)", - dialect=default.DefaultDialect() + 
dialect=default.DefaultDialect(), ) self.assert_compile( - aliased(User).name == 'ed', + aliased(User).name == "ed", "foobar(users_1.name) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) + dialect=default.DefaultDialect(), + ) def test_synonym(self): users, User = self.tables.users, self.classes.User @@ -2426,28 +2909,38 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): __hash__ = None def __eq__(self, other): - return func.foobar(self.__clause_element__()) ==\ - func.foobar(other) + return func.foobar(self.__clause_element__()) == func.foobar( + other + ) - mapper(User, users, properties={ - 'name': synonym('_name', map_column=True, - comparator_factory=MyFactory) - }) + mapper( + User, + users, + properties={ + "name": synonym( + "_name", map_column=True, comparator_factory=MyFactory + ) + }, + ) self.assert_compile( - User.name == 'ed', + User.name == "ed", "foobar(users.name) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) + dialect=default.DefaultDialect(), + ) self.assert_compile( - aliased(User).name == 'ed', + aliased(User).name == "ed", "foobar(users_1.name) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) + dialect=default.DefaultDialect(), + ) def test_relationship(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) from sqlalchemy.orm.properties import RelationshipProperty @@ -2458,42 +2951,56 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): __hash__ = None def __eq__(self, other): - return func.foobar(self._source_selectable().c.user_id) == \ - func.foobar(other.id) + return func.foobar( + self._source_selectable().c.user_id + ) == func.foobar(other.id) class MyFactory2(RelationshipProperty.Comparator): __hash__ = None def __eq__(self, other): - return func.foobar(self._source_selectable().c.id) == \ - func.foobar(other.user_id) + return func.foobar( + self._source_selectable().c.id + ) == func.foobar(other.user_id) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship( - User, comparator_factory=MyFactory, - backref=backref("addresses", comparator_factory=MyFactory2) - ) - } + mapper( + Address, + addresses, + properties={ + "user": relationship( + User, + comparator_factory=MyFactory, + backref=backref( + "addresses", comparator_factory=MyFactory2 + ), + ) + }, ) # these are kind of nonsensical tests. 
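# ---------------------------------------------------------------------
# Aside: a self-contained sketch (not part of this diff) of the
# comparator_factory hook these tests exercise.  A custom Comparator
# redefines how an instrumented attribute renders in SQL expressions;
# this mirrors the lower-casing MyComparator used in test_column above.
from sqlalchemy import Column, Integer, MetaData, String, Table, func
from sqlalchemy.engine import default
from sqlalchemy.orm import column_property, mapper
from sqlalchemy.orm.properties import ColumnProperty

metadata = MetaData()
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(30)),
)

class User(object):
    pass

class LowerCaseComparator(ColumnProperty.Comparator):
    __hash__ = None

    def __eq__(self, other):
        # render both sides through lower() for case-insensitive equality
        return func.lower(self.__clause_element__()) == func.lower(other)

mapper(
    User,
    users,
    properties={
        "name": column_property(
            users.c.name, comparator_factory=LowerCaseComparator
        )
    },
)

expr = User.name == "ED"
print(expr.compile(dialect=default.DefaultDialect()))
# lower(users.name) = lower(:lower_1)
# ---------------------------------------------------------------------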
- self.assert_compile(Address.user == User(id=5), - "foobar(addresses.user_id) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) - self.assert_compile(User.addresses == Address(id=5, user_id=7), - "foobar(users.id) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) + self.assert_compile( + Address.user == User(id=5), + "foobar(addresses.user_id) = foobar(:foobar_1)", + dialect=default.DefaultDialect(), + ) + self.assert_compile( + User.addresses == Address(id=5, user_id=7), + "foobar(users.id) = foobar(:foobar_1)", + dialect=default.DefaultDialect(), + ) self.assert_compile( aliased(Address).user == User(id=5), "foobar(addresses_1.user_id) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) + dialect=default.DefaultDialect(), + ) self.assert_compile( aliased(User).addresses == Address(id=5, user_id=7), "foobar(users_1.id) = foobar(:foobar_1)", - dialect=default.DefaultDialect()) + dialect=default.DefaultDialect(), + ) class SecondaryOptionsTest(fixtures.MappedTest): @@ -2501,34 +3008,45 @@ class SecondaryOptionsTest(fixtures.MappedTest): """test that the contains_eager() option doesn't bleed into a secondary load.""" - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table("base", metadata, - Column('id', Integer, primary_key=True), - Column('type', String(50), nullable=False) - ) - Table("child1", metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column( - 'child2id', Integer, ForeignKey('child2.id'), nullable=False) - ) - Table("child2", metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - ) - Table('related', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - ) + Table( + "base", + metadata, + Column("id", Integer, primary_key=True), + Column("type", String(50), nullable=False), + ) + Table( + "child1", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column( + "child2id", Integer, ForeignKey("child2.id"), nullable=False + ), + ) + Table( + "child2", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) + Table( + "related", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) @classmethod def setup_mappers(cls): - child1, child2, base, related = (cls.tables.child1, - cls.tables.child2, - cls.tables.base, - cls.tables.related) + child1, child2, base, related = ( + cls.tables.child1, + cls.tables.child2, + cls.tables.base, + cls.tables.related, + ) class Base(cls.Comparable): pass @@ -2541,63 +3059,71 @@ class SecondaryOptionsTest(fixtures.MappedTest): class Related(cls.Comparable): pass - mapper(Base, base, polymorphic_on=base.c.type, properties={ - 'related': relationship(Related, uselist=False) - }) - mapper(Child1, child1, inherits=Base, - polymorphic_identity='child1', - properties={ - 'child2': relationship( - Child2, - primaryjoin=child1.c.child2id == base.c.id, - foreign_keys=child1.c.child2id) - }) - mapper(Child2, child2, inherits=Base, polymorphic_identity='child2') + + mapper( + Base, + base, + polymorphic_on=base.c.type, + properties={"related": relationship(Related, uselist=False)}, + ) + mapper( + Child1, + child1, + inherits=Base, + polymorphic_identity="child1", + properties={ + "child2": relationship( + Child2, + primaryjoin=child1.c.child2id == base.c.id, + foreign_keys=child1.c.child2id, + ) + }, + ) + mapper(Child2, child2, inherits=Base, polymorphic_identity="child2") mapper(Related, related) @classmethod def 
insert_data(cls): - child1, child2, base, related = (cls.tables.child1, - cls.tables.child2, - cls.tables.base, - cls.tables.related) - - base.insert().execute([ - {'id': 1, 'type': 'child1'}, - {'id': 2, 'type': 'child1'}, - {'id': 3, 'type': 'child1'}, - {'id': 4, 'type': 'child2'}, - {'id': 5, 'type': 'child2'}, - {'id': 6, 'type': 'child2'}, - ]) - child2.insert().execute([ - {'id': 4}, - {'id': 5}, - {'id': 6}, - ]) - child1.insert().execute([ - {'id': 1, 'child2id': 4}, - {'id': 2, 'child2id': 5}, - {'id': 3, 'child2id': 6}, - ]) - related.insert().execute([ - {'id': 1}, - {'id': 2}, - {'id': 3}, - {'id': 4}, - {'id': 5}, - {'id': 6}, - ]) + child1, child2, base, related = ( + cls.tables.child1, + cls.tables.child2, + cls.tables.base, + cls.tables.related, + ) + + base.insert().execute( + [ + {"id": 1, "type": "child1"}, + {"id": 2, "type": "child1"}, + {"id": 3, "type": "child1"}, + {"id": 4, "type": "child2"}, + {"id": 5, "type": "child2"}, + {"id": 6, "type": "child2"}, + ] + ) + child2.insert().execute([{"id": 4}, {"id": 5}, {"id": 6}]) + child1.insert().execute( + [ + {"id": 1, "child2id": 4}, + {"id": 2, "child2id": 5}, + {"id": 3, "child2id": 6}, + ] + ) + related.insert().execute( + [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}, {"id": 6}] + ) def test_contains_eager(self): Child1, Related = self.classes.Child1, self.classes.Related sess = create_session() - child1s = sess.query(Child1).\ - join(Child1.related).\ - options(sa.orm.contains_eager(Child1.related)).\ - order_by(Child1.id) + child1s = ( + sess.query(Child1) + .join(Child1.related) + .options(sa.orm.contains_eager(Child1.related)) + .order_by(Child1.id) + ) def go(): eq_( @@ -2605,9 +3131,10 @@ class SecondaryOptionsTest(fixtures.MappedTest): [ Child1(id=1, related=Related(id=1)), Child1(id=2, related=Related(id=2)), - Child1(id=3, related=Related(id=3)) - ] + Child1(id=3, related=Related(id=3)), + ], ) + self.assert_sql_count(testing.db, go, 1) c1 = child1s[0] @@ -2620,8 +3147,8 @@ class SecondaryOptionsTest(fixtures.MappedTest): "base.type AS base_type " "FROM base JOIN child2 ON base.id = child2.id " "WHERE base.id = :param_1", - {'param_1': 4} - ) + {"param_1": 4}, + ), ) def test_joinedload_on_other(self): @@ -2629,16 +3156,23 @@ class SecondaryOptionsTest(fixtures.MappedTest): sess = create_session() - child1s = sess.query(Child1).join(Child1.related).options( - sa.orm.joinedload(Child1.related)).order_by(Child1.id) + child1s = ( + sess.query(Child1) + .join(Child1.related) + .options(sa.orm.joinedload(Child1.related)) + .order_by(Child1.id) + ) def go(): eq_( child1s.all(), - [Child1(id=1, related=Related(id=1)), - Child1(id=2, related=Related(id=2)), - Child1(id=3, related=Related(id=3))] + [ + Child1(id=1, related=Related(id=1)), + Child1(id=2, related=Related(id=2)), + Child1(id=3, related=Related(id=3)), + ], ) + self.assert_sql_count(testing.db, go, 1) c1 = child1s[0] @@ -2651,29 +3185,36 @@ class SecondaryOptionsTest(fixtures.MappedTest): "base.type AS base_type " "FROM base JOIN child2 ON base.id = child2.id " "WHERE base.id = :param_1", - - {'param_1': 4} - ) + {"param_1": 4}, + ), ) def test_joinedload_on_same(self): - Child1, Child2, Related = (self.classes.Child1, - self.classes.Child2, - self.classes.Related) + Child1, Child2, Related = ( + self.classes.Child1, + self.classes.Child2, + self.classes.Related, + ) sess = create_session() - child1s = sess.query(Child1).join(Child1.related).options( - sa.orm.joinedload(Child1.child2, Child2.related) - ).order_by(Child1.id) + child1s = ( + 
sess.query(Child1) + .join(Child1.related) + .options(sa.orm.joinedload(Child1.child2, Child2.related)) + .order_by(Child1.id) + ) def go(): eq_( child1s.all(), - [Child1(id=1, related=Related(id=1)), - Child1(id=2, related=Related(id=2)), - Child1(id=3, related=Related(id=3))] + [ + Child1(id=1, related=Related(id=1)), + Child1(id=2, related=Related(id=2)), + Child1(id=3, related=Related(id=3)), + ], ) + self.assert_sql_count(testing.db, go, 4) c1 = child1s[0] @@ -2689,27 +3230,32 @@ class SecondaryOptionsTest(fixtures.MappedTest): "ON base.id = child2.id " "LEFT OUTER JOIN related AS related_1 " "ON base.id = related_1.id WHERE base.id = :param_1", - {'param_1': 4} - ) + {"param_1": 4}, + ), ) class DeferredPopulationTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table("thing", metadata, - Column( - "id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("name", String(20))) - - Table("human", metadata, - Column( - "id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("thing_id", Integer, ForeignKey("thing.id")), - Column("name", String(20))) + Table( + "thing", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(20)), + ) + + Table( + "human", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("thing_id", Integer, ForeignKey("thing.id")), + Column("name", String(20)), + ) @classmethod def setup_mappers(cls): @@ -2728,13 +3274,11 @@ class DeferredPopulationTest(fixtures.MappedTest): def insert_data(cls): thing, human = cls.tables.thing, cls.tables.human - thing.insert().execute([ - {"id": 1, "name": "Chair"}, - ]) + thing.insert().execute([{"id": 1, "name": "Chair"}]) - human.insert().execute([ - {"id": 1, "thing_id": 1, "name": "Clark Kent"}, - ]) + human.insert().execute( + [{"id": 1, "thing_id": 1, "name": "Clark Kent"}] + ) def _test(self, thing): assert "name" in attributes.instance_state(thing).dict @@ -2759,7 +3303,7 @@ class DeferredPopulationTest(fixtures.MappedTest): Thing = self.classes.Thing session = create_session() - result = session.query(Thing).first() # noqa + result = session.query(Thing).first() # noqa thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) @@ -2767,8 +3311,11 @@ class DeferredPopulationTest(fixtures.MappedTest): Thing, Human = self.classes.Thing, self.classes.Human session = create_session() - human = session.query(Human).options( # noqa - sa.orm.joinedload("thing")).first() + human = ( + session.query(Human) + .options(sa.orm.joinedload("thing")) # noqa + .first() + ) session.expunge_all() thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) @@ -2777,8 +3324,11 @@ class DeferredPopulationTest(fixtures.MappedTest): Thing, Human = self.classes.Thing, self.classes.Human session = create_session() - human = session.query(Human).options( # noqa - sa.orm.joinedload("thing")).first() + human = ( + session.query(Human) + .options(sa.orm.joinedload("thing")) # noqa + .first() + ) thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) @@ -2786,8 +3336,12 @@ class DeferredPopulationTest(fixtures.MappedTest): Thing, Human = self.classes.Thing, self.classes.Human session = create_session() - result = session.query(Human).add_entity( # noqa - Thing).join("thing").first() + result = ( + session.query(Human) + .add_entity(Thing) # noqa + .join("thing") + .first() + ) 
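# ---------------------------------------------------------------------
# Aside: a self-contained sketch (not part of this diff) of the
# deferred/undefer pattern these DeferredPopulationTest cases drive.  A
# deferred() column is omitted from the initial SELECT and fetched only
# on attribute access, unless undefer() requests it up front.
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from sqlalchemy.orm import Session, deferred, mapper, undefer

metadata = MetaData()
thing = Table(
    "thing",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(20)),
)

class Thing(object):
    pass

# map "name" as deferred: excluded from the row load, loaded on access
mapper(Thing, thing, properties={"name": deferred(thing.c.name)})

engine = create_engine("sqlite://")
metadata.create_all(engine)
engine.execute(thing.insert(), {"id": 1, "name": "Chair"})

session = Session(engine)
t = session.query(Thing).first()   # SELECT omits thing.name
t.name                             # second SELECT fetches the column

session.expunge_all()
t = session.query(Thing).options(undefer("name")).first()
t.name                             # already populated; no extra SQL
# ---------------------------------------------------------------------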
session.expunge_all() thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) @@ -2796,14 +3350,18 @@ class DeferredPopulationTest(fixtures.MappedTest): Thing, Human = self.classes.Thing, self.classes.Human session = create_session() - result = session.query(Human).add_entity( # noqa - Thing).join("thing").first() + result = ( + session.query(Human) + .add_entity(Thing) # noqa + .join("thing") + .first() + ) thing = session.query(Thing).options(sa.orm.undefer("name")).first() self._test(thing) class NoLoadTest(_fixtures.FixtureTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None def test_o2m_noload(self): @@ -2812,11 +3370,18 @@ class NoLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) + self.classes.User, + ) - m = mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), lazy='noload') - )) + m = mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), lazy="noload" + ) + ), + ) q = create_session().query(m) result = [None] @@ -2824,11 +3389,11 @@ class NoLoadTest(_fixtures.FixtureTest): x = q.filter(User.id == 7).all() x[0].addresses result[0] = x + self.assert_sql_count(testing.db, go, 1) self.assert_result( - result[0], User, - {'id': 7, 'addresses': (Address, [])}, + result[0], User, {"id": 7, "addresses": (Address, [])} ) def test_upgrade_o2m_noload_lazyload_option(self): @@ -2836,23 +3401,30 @@ class NoLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) + self.classes.User, + ) - m = mapper(User, users, properties=dict( - addresses=relationship(mapper(Address, addresses), lazy='noload') - )) - q = create_session().query(m).options(sa.orm.lazyload('addresses')) + m = mapper( + User, + users, + properties=dict( + addresses=relationship( + mapper(Address, addresses), lazy="noload" + ) + ), + ) + q = create_session().query(m).options(sa.orm.lazyload("addresses")) result = [None] def go(): x = q.filter(User.id == 7).all() x[0].addresses result[0] = x + self.sql_count_(2, go) self.assert_result( - result[0], User, - {'id': 7, 'addresses': (Address, [{'id': 1}])}, + result[0], User, {"id": 7, "addresses": (Address, [{"id": 1}])} ) def test_m2o_noload_option(self): @@ -2860,22 +3432,26 @@ class NoLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + self.classes.User, + ) + mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) s = Session() - a1 = s.query(Address).filter_by(id=1).options( - sa.orm.noload('user')).first() + a1 = ( + s.query(Address) + .filter_by(id=1) + .options(sa.orm.noload("user")) + .first() + ) def go(): eq_(a1.user, None) + self.sql_count_(0, go) class RaiseLoadTest(_fixtures.FixtureTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None def test_o2m_raiseload_mapper(self): @@ -2883,12 +3459,15 @@ class RaiseLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='raise') - )) + mapper( + User, + users, + properties=dict(addresses=relationship(Address, lazy="raise")), + ) q = create_session().query(User) result = [None] @@ -2897,67 
+3476,70 @@ class RaiseLoadTest(_fixtures.FixtureTest): assert_raises_message( sa.exc.InvalidRequestError, "'User.addresses' is not available due to lazy='raise'", - lambda: x[0].addresses) + lambda: x[0].addresses, + ) result[0] = x + self.assert_sql_count(testing.db, go, 1) - self.assert_result( - result[0], User, - {'id': 7}, - ) + self.assert_result(result[0], User, {"id": 7}) def test_o2m_raiseload_option(self): Address, addresses, users, User = ( self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address) - )) + mapper(User, users, properties=dict(addresses=relationship(Address))) q = create_session().query(User) result = [None] def go(): - x = q.options( - sa.orm.raiseload(User.addresses)).filter(User.id == 7).all() + x = ( + q.options(sa.orm.raiseload(User.addresses)) + .filter(User.id == 7) + .all() + ) assert_raises_message( sa.exc.InvalidRequestError, "'User.addresses' is not available due to lazy='raise'", - lambda: x[0].addresses) + lambda: x[0].addresses, + ) result[0] = x + self.assert_sql_count(testing.db, go, 1) - self.assert_result( - result[0], User, - {'id': 7}, - ) + self.assert_result(result[0], User, {"id": 7}) def test_o2m_raiseload_lazyload_option(self): Address, addresses, users, User = ( self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='raise') - )) - q = create_session().query(User).options(sa.orm.lazyload('addresses')) + mapper( + User, + users, + properties=dict(addresses=relationship(Address, lazy="raise")), + ) + q = create_session().query(User).options(sa.orm.lazyload("addresses")) result = [None] def go(): x = q.filter(User.id == 7).all() x[0].addresses result[0] = x + self.sql_count_(2, go) self.assert_result( - result[0], User, - {'id': 7, 'addresses': (Address, [{'id': 1}])}, + result[0], User, {"id": 7, "addresses": (Address, [{"id": 1}])} ) def test_m2o_raiseload_option(self): @@ -2965,20 +3547,24 @@ class RaiseLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + self.classes.User, + ) + mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) s = Session() - a1 = s.query(Address).filter_by(id=1).options( - sa.orm.raiseload('user')).first() + a1 = ( + s.query(Address) + .filter_by(id=1) + .options(sa.orm.raiseload("user")) + .first() + ) def go(): assert_raises_message( sa.exc.InvalidRequestError, "'Address.user' is not available due to lazy='raise'", - lambda: a1.user) + lambda: a1.user, + ) self.sql_count_(0, go) @@ -2987,29 +3573,37 @@ class RaiseLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + self.classes.User, + ) + mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) s = Session() - a1 = s.query(Address).filter_by(id=1).options( - sa.orm.raiseload('user', sql_only=True)).first() + a1 = ( + s.query(Address) + .filter_by(id=1) + .options(sa.orm.raiseload("user", sql_only=True)) + .first() + ) def go(): assert_raises_message( sa.exc.InvalidRequestError, "'Address.user' is not available 
due to lazy='raise_on_sql'", - lambda: a1.user) + lambda: a1.user, + ) self.sql_count_(0, go) s.close() u1 = s.query(User).first() - a1 = s.query(Address).filter_by(id=1).options( - sa.orm.raiseload('user', sql_only=True)).first() - assert 'user' not in a1.__dict__ + a1 = ( + s.query(Address) + .filter_by(id=1) + .options(sa.orm.raiseload("user", sql_only=True)) + .first() + ) + assert "user" not in a1.__dict__ is_(a1.user, u1) def test_m2o_non_use_get_raise_on_sql_option(self): @@ -3017,27 +3611,37 @@ class RaiseLoadTest(_fixtures.FixtureTest): self.classes.Address, self.tables.addresses, self.tables.users, - self.classes.User) - mapper(Address, addresses, properties={ - 'user': relationship( - User, - primaryjoin=sa.and_( - addresses.c.user_id == users.c.id, - users.c.name != None # noqa + self.classes.User, + ) + mapper( + Address, + addresses, + properties={ + "user": relationship( + User, + primaryjoin=sa.and_( + addresses.c.user_id == users.c.id, + users.c.name != None, # noqa + ), ) - ) - }) + }, + ) mapper(User, users) s = Session() u1 = s.query(User).first() - a1 = s.query(Address).filter_by(id=1).options( - sa.orm.raiseload('user', sql_only=True)).first() + a1 = ( + s.query(Address) + .filter_by(id=1) + .options(sa.orm.raiseload("user", sql_only=True)) + .first() + ) def go(): assert_raises_message( sa.exc.InvalidRequestError, "'Address.user' is not available due to lazy='raise_on_sql'", - lambda: a1.user) + lambda: a1.user, + ) class RequirementsTest(fixtures.MappedTest): @@ -3046,38 +3650,52 @@ class RequirementsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('ht1', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('value', String(10))) - Table('ht2', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('ht1_id', Integer, ForeignKey('ht1.id')), - Column('value', String(10))) - Table('ht3', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('value', String(10))) - Table('ht4', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht3_id', Integer, ForeignKey('ht3.id'), - primary_key=True)) - Table('ht5', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True)) - Table('ht6', metadata, - Column('ht1a_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht1b_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('value', String(10))) + Table( + "ht1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("value", String(10)), + ) + Table( + "ht2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("ht1_id", Integer, ForeignKey("ht1.id")), + Column("value", String(10)), + ) + Table( + "ht3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("value", String(10)), + ) + Table( + "ht4", + metadata, + Column("ht1_id", Integer, ForeignKey("ht1.id"), primary_key=True), + Column("ht3_id", Integer, ForeignKey("ht3.id"), primary_key=True), + ) + Table( + "ht5", + metadata, + Column("ht1_id", Integer, ForeignKey("ht1.id"), primary_key=True), + ) + Table( + "ht6", + metadata, + Column("ht1a_id", Integer, ForeignKey("ht1.id"), primary_key=True), + Column("ht1b_id", Integer, ForeignKey("ht1.id"), primary_key=True), + Column("value", String(10)), + ) if util.py2k: + def test_baseclass(self): ht1 
= self.tables.ht1 @@ -3096,8 +3714,7 @@ class RequirementsTest(fixtures.MappedTest): # sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) class _ValueBase(object): - - def __init__(self, value='abc', id=None): + def __init__(self, value="abc", id=None): self.id = id self.value = value @@ -3122,12 +3739,14 @@ class RequirementsTest(fixtures.MappedTest): test run. """ - ht6, ht5, ht4, ht3, ht2, ht1 = (self.tables.ht6, - self.tables.ht5, - self.tables.ht4, - self.tables.ht3, - self.tables.ht2, - self.tables.ht1) + ht6, ht5, ht4, ht3, ht2, ht1 = ( + self.tables.ht6, + self.tables.ht5, + self.tables.ht4, + self.tables.ht3, + self.tables.ht2, + self.tables.ht1, + ) class H1(self._ValueBase): pass @@ -3141,32 +3760,35 @@ class RequirementsTest(fixtures.MappedTest): class H6(self._ValueBase): pass - mapper(H1, ht1, properties={ - 'h2s': relationship(H2, backref='h1'), - 'h3s': relationship(H3, secondary=ht4, backref='h1s'), - 'h1s': relationship(H1, secondary=ht5, backref='parent_h1'), - 't6a': relationship(H6, backref='h1a', - primaryjoin=ht1.c.id == ht6.c.ht1a_id), - 't6b': relationship(H6, backref='h1b', - primaryjoin=ht1.c.id == ht6.c.ht1b_id), - }) + mapper( + H1, + ht1, + properties={ + "h2s": relationship(H2, backref="h1"), + "h3s": relationship(H3, secondary=ht4, backref="h1s"), + "h1s": relationship(H1, secondary=ht5, backref="parent_h1"), + "t6a": relationship( + H6, backref="h1a", primaryjoin=ht1.c.id == ht6.c.ht1a_id + ), + "t6b": relationship( + H6, backref="h1b", primaryjoin=ht1.c.id == ht6.c.ht1b_id + ), + }, + ) mapper(H2, ht2) mapper(H3, ht3) mapper(H6, ht6) s = create_session() - s.add_all([ - H1('abc'), - H1('def'), - ]) - h1 = H1('ghi') + s.add_all([H1("abc"), H1("def")]) + h1 = H1("ghi") s.add(h1) - h1.h2s.append(H2('abc')) + h1.h2s.append(H2("abc")) h1.h3s.extend([H3(), H3()]) h1.h1s.append(H1()) s.flush() - eq_(select([func.count('*')]).select_from(ht1).scalar(), 4) + eq_(select([func.count("*")]).select_from(ht1).scalar(), 4) h6 = H6() h6.h1a = h1 @@ -3177,87 +3799,98 @@ class RequirementsTest(fixtures.MappedTest): h6.h1b = x = H1() assert x in s - h6.h1b.h2s.append(H2('def')) + h6.h1b.h2s.append(H2("def")) s.flush() - h1.h2s.extend([H2('abc'), H2('def')]) + h1.h2s.extend([H2("abc"), H2("def")]) s.flush() - h1s = s.query(H1).options(sa.orm.joinedload('h2s')).all() + h1s = s.query(H1).options(sa.orm.joinedload("h2s")).all() eq_(len(h1s), 5) - self.assert_unordered_result(h1s, H1, - {'h2s': []}, - {'h2s': []}, - {'h2s': (H2, [{'value': 'abc'}, - {'value': 'def'}, - {'value': 'abc'}])}, - {'h2s': []}, - {'h2s': (H2, [{'value': 'def'}])}) + self.assert_unordered_result( + h1s, + H1, + {"h2s": []}, + {"h2s": []}, + { + "h2s": ( + H2, + [{"value": "abc"}, {"value": "def"}, {"value": "abc"}], + ) + }, + {"h2s": []}, + {"h2s": (H2, [{"value": "def"}])}, + ) - h1s = s.query(H1).options(sa.orm.joinedload('h3s')).all() + h1s = s.query(H1).options(sa.orm.joinedload("h3s")).all() eq_(len(h1s), 5) - h1s = s.query(H1).options(sa.orm.joinedload_all('t6a.h1b'), - sa.orm.joinedload('h2s'), - sa.orm.joinedload_all('h3s.h1s')).all() + h1s = ( + s.query(H1) + .options( + sa.orm.joinedload_all("t6a.h1b"), + sa.orm.joinedload("h2s"), + sa.orm.joinedload_all("h3s.h1s"), + ) + .all() + ) eq_(len(h1s), 5) def test_composite_results(self): - ht2, ht1 = (self.tables.ht2, - self.tables.ht1) + ht2, ht1 = (self.tables.ht2, self.tables.ht1) class H1(self._ValueBase): - def __init__(self, value, id, h2s): self.value = value self.id = id self.h2s = h2s class H2(self._ValueBase): - def __init__(self, value, 
id): self.value = value self.id = id - mapper(H1, ht1, properties={ - 'h2s': relationship(H2, backref='h1'), - }) + mapper(H1, ht1, properties={"h2s": relationship(H2, backref="h1")}) mapper(H2, ht2) s = Session() - s.add_all([ - H1('abc', 1, h2s=[ - H2('abc', id=1), - H2('def', id=2), - H2('def', id=3), - ]), - H1('def', 2, h2s=[ - H2('abc', id=4), - H2('abc', id=5), - H2('def', id=6), - ]), - ]) + s.add_all( + [ + H1( + "abc", + 1, + h2s=[H2("abc", id=1), H2("def", id=2), H2("def", id=3)], + ), + H1( + "def", + 2, + h2s=[H2("abc", id=4), H2("abc", id=5), H2("def", id=6)], + ), + ] + ) s.commit() eq_( - [(h1.value, h1.id, h2.value, h2.id) - for h1, h2 in - s.query(H1, H2).join(H1.h2s).order_by(H1.id, H2.id)], [ - ('abc', 1, 'abc', 1), - ('abc', 1, 'def', 2), - ('abc', 1, 'def', 3), - ('def', 2, 'abc', 4), - ('def', 2, 'abc', 5), - ('def', 2, 'def', 6), - ] + (h1.value, h1.id, h2.value, h2.id) + for h1, h2 in s.query(H1, H2) + .join(H1.h2s) + .order_by(H1.id, H2.id) + ], + [ + ("abc", 1, "abc", 1), + ("abc", 1, "def", 2), + ("abc", 1, "def", 3), + ("def", 2, "abc", 4), + ("def", 2, "abc", 5), + ("def", 2, "def", 6), + ], ) def test_nonzero_len_recursion(self): ht1 = self.tables.ht1 class H1(object): - def __len__(self): return len(self.get_value()) @@ -3266,7 +3899,6 @@ class RequirementsTest(fixtures.MappedTest): return self.value class H2(object): - def __bool__(self): return bool(self.get_value()) @@ -3287,13 +3919,14 @@ class RequirementsTest(fixtures.MappedTest): class IsUserlandTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column('id', Integer, primary_key=True), - Column('someprop', Integer) - ) + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("someprop", Integer), + ) def _test(self, value, instancelevel=None): class Foo(object): @@ -3340,26 +3973,33 @@ class IsUserlandTest(fixtures.MappedTest): def test_descriptor(self): def somefunc(self): return "hi" + self._test(property(somefunc), "hi") class MagicNamesTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('cartographers', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('alias', String(50)), - Column('quip', String(100))) - Table('maps', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('cart_id', Integer, - ForeignKey('cartographers.id')), - Column('state', String(2)), - Column('data', sa.Text)) + Table( + "cartographers", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + Column("alias", String(50)), + Column("quip", String(100)), + ) + Table( + "maps", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("cart_id", Integer, ForeignKey("cartographers.id")), + Column("state", String(2)), + Column("data", sa.Text), + ) @classmethod def setup_classes(cls): @@ -3370,59 +4010,95 @@ class MagicNamesTest(fixtures.MappedTest): pass def test_mappish(self): - maps, Cartographer, cartographers, Map = (self.tables.maps, - self.classes.Cartographer, - self.tables.cartographers, - self.classes.Map) + maps, Cartographer, cartographers, Map = ( + self.tables.maps, + self.classes.Cartographer, + self.tables.cartographers, + self.classes.Map, + ) - mapper(Cartographer, cartographers, properties=dict( - query=cartographers.c.quip)) - mapper(Map, maps, properties=dict( - 
mapper=relationship(Cartographer, backref='maps'))) + mapper( + Cartographer, + cartographers, + properties=dict(query=cartographers.c.quip), + ) + mapper( + Map, + maps, + properties=dict(mapper=relationship(Cartographer, backref="maps")), + ) - c = Cartographer(name='Lenny', alias='The Dude', - query='Where be dragons?') - Map(state='AK', mapper=c) + c = Cartographer( + name="Lenny", alias="The Dude", query="Where be dragons?" + ) + Map(state="AK", mapper=c) sess = create_session() sess.add(c) sess.flush() sess.expunge_all() - for C, M in ((Cartographer, Map), - (sa.orm.aliased(Cartographer), sa.orm.aliased(Map))): - c1 = (sess.query(C). - filter(C.alias == 'The Dude'). - filter(C.query == 'Where be dragons?')).one() + for C, M in ( + (Cartographer, Map), + (sa.orm.aliased(Cartographer), sa.orm.aliased(Map)), + ): + c1 = ( + sess.query(C) + .filter(C.alias == "The Dude") + .filter(C.query == "Where be dragons?") + ).one() sess.query(M).filter(M.mapper == c1).one() def test_direct_stateish(self): - for reserved in (sa.orm.instrumentation.ClassManager.STATE_ATTR, - sa.orm.instrumentation.ClassManager.MANAGER_ATTR): - t = Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column(reserved, Integer)) + for reserved in ( + sa.orm.instrumentation.ClassManager.STATE_ATTR, + sa.orm.instrumentation.ClassManager.MANAGER_ATTR, + ): + t = Table( + "t", + sa.MetaData(), + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column(reserved, Integer), + ) class T(object): pass + assert_raises_message( KeyError, - ('%r: requested attribute name conflicts with ' - 'instrumentation attribute of the same name.' % reserved), - mapper, T, t) + ( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." 
% reserved + ), + mapper, + T, + t, + ) def test_indirect_stateish(self): maps = self.tables.maps - for reserved in (sa.orm.instrumentation.ClassManager.STATE_ATTR, - sa.orm.instrumentation.ClassManager.MANAGER_ATTR): + for reserved in ( + sa.orm.instrumentation.ClassManager.STATE_ATTR, + sa.orm.instrumentation.ClassManager.MANAGER_ATTR, + ): + class M(object): pass assert_raises_message( KeyError, - ('requested attribute name conflicts with ' - 'instrumentation attribute of the same name'), - mapper, M, maps, properties={ - reserved: maps.c.state}) + ( + "requested attribute name conflicts with " + "instrumentation attribute of the same name" + ), + mapper, + M, + maps, + properties={reserved: maps.c.state}, + ) diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 8c13902388..9ba2e5bf3b 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -4,9 +4,22 @@ from sqlalchemy import Integer, PickleType, String, ForeignKey, Text import operator from sqlalchemy import testing from sqlalchemy.util import OrderedSet -from sqlalchemy.orm import mapper, relationship, create_session, \ - PropComparator, synonym, comparable_property, sessionmaker, \ - attributes, Session, backref, configure_mappers, foreign, deferred, defer +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + PropComparator, + synonym, + comparable_property, + sessionmaker, + attributes, + Session, + backref, + configure_mappers, + foreign, + deferred, + defer, +) from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.interfaces import MapperOption from sqlalchemy.testing import eq_, in_, not_in_ @@ -23,11 +36,13 @@ class MergeTest(_fixtures.FixtureTest): def load_tracker(self, cls, canary=None): if canary is None: + def canary(instance, *args): canary.called += 1 + canary.called = 0 - event.listen(cls, 'load', canary) + event.listen(cls, "load", canary) return canary @@ -38,15 +53,15 @@ class MergeTest(_fixtures.FixtureTest): sess = create_session() load = self.load_tracker(User) - u = User(id=7, name='fred') + u = User(id=7, name="fred") eq_(load.called, 0) u2 = sess.merge(u) eq_(load.called, 1) assert u2 in sess - eq_(u2, User(id=7, name='fred')) + eq_(u2, User(id=7, name="fred")) sess.flush() sess.expunge_all() - eq_(sess.query(User).first(), User(id=7, name='fred')) + eq_(sess.query(User).first(), User(id=7, name="fred")) def test_transient_to_pending_no_pk(self): """test that a transient object with no PK attribute @@ -56,29 +71,44 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users) sess = create_session() - u = User(name='fred') + u = User(name="fred") def go(): sess.merge(u) + self.assert_sql_count(testing.db, go, 0) def test_transient_to_pending_collection(self): - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', - collection_class=OrderedSet)}) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", collection_class=OrderedSet + ) + }, + ) mapper(Address, addresses) load = self.load_tracker(User) self.load_tracker(Address, load) - u = User(id=7, name='fred', addresses=OrderedSet([ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ])) + u = User( + id=7, + name="fred", 
+ addresses=OrderedSet( + [ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + ] + ), + ) eq_(load.called, 0) sess = create_session() @@ -92,29 +122,51 @@ class MergeTest(_fixtures.FixtureTest): sess.flush() sess.expunge_all() - eq_(sess.query(User).one(), - User(id=7, name='fred', addresses=OrderedSet([ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ]))) + eq_( + sess.query(User).one(), + User( + id=7, + name="fred", + addresses=OrderedSet( + [ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + ] + ), + ), + ) def test_transient_to_pending_collection_pk_none(self): - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', - collection_class=OrderedSet)}) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", collection_class=OrderedSet + ) + }, + ) mapper(Address, addresses) load = self.load_tracker(User) self.load_tracker(Address, load) - u = User(id=None, name='fred', addresses=OrderedSet([ - Address(id=None, email_address='fred1'), - Address(id=None, email_address='fred2'), - ])) + u = User( + id=None, + name="fred", + addresses=OrderedSet( + [ + Address(id=None, email_address="fred1"), + Address(id=None, email_address="fred2"), + ] + ), + ) eq_(load.called, 0) sess = create_session() @@ -128,11 +180,18 @@ class MergeTest(_fixtures.FixtureTest): sess.flush() sess.expunge_all() - eq_(sess.query(User).one(), - User(name='fred', addresses=OrderedSet([ - Address(email_address='fred1'), - Address(email_address='fred2'), - ]))) + eq_( + sess.query(User).one(), + User( + name="fred", + addresses=OrderedSet( + [ + Address(email_address="fred1"), + Address(email_address="fred2"), + ] + ), + ), + ) def test_transient_to_persistent(self): User, users = self.classes.User, self.tables.users @@ -141,45 +200,59 @@ class MergeTest(_fixtures.FixtureTest): load = self.load_tracker(User) sess = create_session() - u = User(id=7, name='fred') + u = User(id=7, name="fred") sess.add(u) sess.flush() sess.expunge_all() eq_(load.called, 0) - _u2 = u2 = User(id=7, name='fred jones') + _u2 = u2 = User(id=7, name="fred jones") eq_(load.called, 0) u2 = sess.merge(u2) assert u2 is not _u2 eq_(load.called, 1) sess.flush() sess.expunge_all() - eq_(sess.query(User).first(), User(id=7, name='fred jones')) + eq_(sess.query(User).first(), User(id=7, name="fred jones")) eq_(load.called, 2) def test_transient_to_persistent_collection(self): - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) - - mapper(User, users, properties={ - 'addresses': relationship(Address, - backref='user', - collection_class=OrderedSet, - order_by=addresses.c.id, - cascade="all, delete-orphan") - }) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + backref="user", + collection_class=OrderedSet, + order_by=addresses.c.id, + cascade="all, delete-orphan", + ) + }, + ) mapper(Address, addresses) load = self.load_tracker(User) self.load_tracker(Address, load) - u = User(id=7, name='fred', 
addresses=OrderedSet([ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ])) + u = User( + id=7, + name="fred", + addresses=OrderedSet( + [ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + ] + ), + ) sess = create_session() sess.add(u) sess.flush() @@ -187,10 +260,16 @@ class MergeTest(_fixtures.FixtureTest): eq_(load.called, 0) - u = User(id=7, name='fred', addresses=OrderedSet([ - Address(id=3, email_address='fred3'), - Address(id=4, email_address='fred4'), - ])) + u = User( + id=7, + name="fred", + addresses=OrderedSet( + [ + Address(id=3, email_address="fred3"), + Address(id=4, email_address="fred4"), + ] + ), + ) u = sess.merge(u) @@ -200,46 +279,72 @@ class MergeTest(_fixtures.FixtureTest): # marks as deleted, Address ids 1 and 2. eq_(load.called, 5) - eq_(u, - User(id=7, name='fred', addresses=OrderedSet([ - Address(id=3, email_address='fred3'), - Address(id=4, email_address='fred4'), - ]))) + eq_( + u, + User( + id=7, + name="fred", + addresses=OrderedSet( + [ + Address(id=3, email_address="fred3"), + Address(id=4, email_address="fred4"), + ] + ), + ), + ) sess.flush() sess.expunge_all() - eq_(sess.query(User).one(), - User(id=7, name='fred', addresses=OrderedSet([ - Address(id=3, email_address='fred3'), - Address(id=4, email_address='fred4'), - ]))) + eq_( + sess.query(User).one(), + User( + id=7, + name="fred", + addresses=OrderedSet( + [ + Address(id=3, email_address="fred3"), + Address(id=4, email_address="fred4"), + ] + ), + ), + ) def test_detached_to_persistent_collection(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, - backref='user', - order_by=addresses.c.id, - collection_class=OrderedSet)}) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + backref="user", + order_by=addresses.c.id, + collection_class=OrderedSet, + ) + }, + ) mapper(Address, addresses) load = self.load_tracker(User) self.load_tracker(Address, load) - a = Address(id=1, email_address='fred1') - u = User(id=7, name='fred', addresses=OrderedSet([ - a, - Address(id=2, email_address='fred2'), - ])) + a = Address(id=1, email_address="fred1") + u = User( + id=7, + name="fred", + addresses=OrderedSet([a, Address(id=2, email_address="fred2")]), + ) sess = create_session() sess.add(u) sess.flush() sess.expunge_all() - u.name = 'fred jones' - u.addresses.add(Address(id=3, email_address='fred3')) + u.name = "fred jones" + u.addresses.add(Address(id=3, email_address="fred3")) u.addresses.remove(a) eq_(load.called, 0) @@ -248,53 +353,91 @@ class MergeTest(_fixtures.FixtureTest): sess.flush() sess.expunge_all() - eq_(sess.query(User).first(), - User(id=7, name='fred jones', addresses=OrderedSet([ - Address(id=2, email_address='fred2'), - Address(id=3, email_address='fred3')]))) + eq_( + sess.query(User).first(), + User( + id=7, + name="fred jones", + addresses=OrderedSet( + [ + Address(id=2, email_address="fred2"), + Address(id=3, email_address="fred3"), + ] + ), + ), + ) def test_unsaved_cascade(self): """Merge of a transient entity with two child transient entities, with a bidirectional relationship.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + 
users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - cascade="all", backref="user") - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), cascade="all", backref="user" + ) + }, + ) load = self.load_tracker(User) self.load_tracker(Address, load) sess = create_session() - u = User(id=7, name='fred') - a1 = Address(email_address='foo@bar.com') - a2 = Address(email_address='hoho@bar.com') + u = User(id=7, name="fred") + a1 = Address(email_address="foo@bar.com") + a2 = Address(email_address="hoho@bar.com") u.addresses.append(a1) u.addresses.append(a2) u2 = sess.merge(u) eq_(load.called, 3) - eq_(u, - User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@bar.com')])) - eq_(u2, - User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@bar.com')])) + eq_( + u, + User( + id=7, + name="fred", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@bar.com"), + ], + ), + ) + eq_( + u2, + User( + id=7, + name="fred", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@bar.com"), + ], + ), + ) sess.flush() sess.expunge_all() u2 = sess.query(User).get(7) - eq_(u2, User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@bar.com')])) + eq_( + u2, + User( + id=7, + name="fred", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@bar.com"), + ], + ), + ) eq_(load.called, 6) def test_merge_empty_attributes(self): @@ -321,17 +464,17 @@ class MergeTest(_fixtures.FixtureTest): # value isn't whacked from the destination # dict. u3 = sess.merge(User(id=2)) - eq_(u3.__dict__['data'], "foo") + eq_(u3.__dict__["data"], "foo") # make a change. - u3.data = 'bar' + u3.data = "bar" # merge another no-"data" user. # attribute maintains modified state. # (usually autoflush would have happened # here anyway). u4 = sess.merge(User(id=2)) - eq_(u3.__dict__['data'], "bar") + eq_(u3.__dict__["data"], "bar") sess.flush() # and after the flush. @@ -352,70 +495,98 @@ class MergeTest(_fixtures.FixtureTest): # not sure if I like this - it currently is needed # for test_pickled:PickleTest.test_instance_deferred_cols u6 = sess.merge(User(id=3)) - assert 'data' not in u6.__dict__ + assert "data" not in u6.__dict__ assert u6.data == "foo" # set it to None. this is actually # a change so gets preserved. 
u6.data = None u7 = sess.merge(User(id=3)) - assert u6.__dict__['data'] is None + assert u6.__dict__["data"] is None def test_merge_irregular_collection(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - backref='user', - collection_class=attribute_mapped_collection('email_address')), - }) - u1 = User(id=7, name='fred') - u1.addresses['foo@bar.com'] = Address(email_address='foo@bar.com') + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + backref="user", + collection_class=attribute_mapped_collection( + "email_address" + ), + ) + }, + ) + u1 = User(id=7, name="fred") + u1.addresses["foo@bar.com"] = Address(email_address="foo@bar.com") sess = create_session() sess.merge(u1) sess.flush() - assert list(u1.addresses.keys()) == ['foo@bar.com'] + assert list(u1.addresses.keys()) == ["foo@bar.com"] def test_attribute_cascade(self): """Merge of a persistent entity with two child persistent entities.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user') - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) load = self.load_tracker(User) self.load_tracker(Address, load) sess = create_session() # set up data and save - u = User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@la.com')]) + u = User( + id=7, + name="fred", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@la.com"), + ], + ) sess.add(u) sess.flush() # assert data was saved sess2 = create_session() u2 = sess2.query(User).get(7) - eq_(u2, - User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@la.com')])) + eq_( + u2, + User( + id=7, + name="fred", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@la.com"), + ], + ), + ) # make local changes to data - u.name = 'fred2' - u.addresses[1].email_address = 'hoho@lalala.com' + u.name = "fred2" + u.addresses[1].email_address = "hoho@lalala.com" eq_(load.called, 3) @@ -425,9 +596,17 @@ class MergeTest(_fixtures.FixtureTest): eq_(load.called, 6) # ensure local changes are pending - eq_(u3, User(id=7, name='fred2', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@lalala.com')])) + eq_( + u3, + User( + id=7, + name="fred2", + addresses=[ + Address(email_address="foo@bar.com"), + Address(email_address="hoho@lalala.com"), + ], + ), + ) # save merged data sess3.flush() @@ -435,9 +614,17 @@ class MergeTest(_fixtures.FixtureTest): # assert modified/merged data was saved sess.expunge_all() u = sess.query(User).get(7) - eq_(u, User(id=7, name='fred2', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@lalala.com')])) + eq_( + u, + User( + id=7, + name="fred2", + addresses=[ + Address(email_address="foo@bar.com"), + 
Address(email_address="hoho@lalala.com"), + ], + ), + ) eq_(load.called, 9) # merge persistent object into another session @@ -449,6 +636,7 @@ class MergeTest(_fixtures.FixtureTest): def go(): sess4.flush() + # no changes; therefore flush should do nothing self.assert_sql_count(testing.db, go, 0) eq_(load.called, 12) @@ -462,6 +650,7 @@ class MergeTest(_fixtures.FixtureTest): def go(): sess5.flush() + # no changes; therefore flush should do nothing # but also, load=False wipes out any difference in committed state, # so no flush at all @@ -471,24 +660,26 @@ class MergeTest(_fixtures.FixtureTest): sess4 = create_session() u = sess4.merge(u, load=False) # post merge change - u.addresses[1].email_address = 'afafds' + u.addresses[1].email_address = "afafds" def go(): sess4.flush() + # afafds change flushes self.assert_sql_count(testing.db, go, 1) eq_(load.called, 18) sess5 = create_session() u2 = sess5.query(User).get(u.id) - eq_(u2.name, 'fred2') - eq_(u2.addresses[1].email_address, 'afafds') + eq_(u2.name, "fred2") + eq_(u2.addresses[1].email_address, "afafds") eq_(load.called, 21) def test_dont_send_neverset_to_get(self): # test issue #3647 CompositePk, composite_pk_table = ( - self.classes.CompositePk, self.tables.composite_pk_table + self.classes.CompositePk, + self.tables.composite_pk_table, ) mapper(CompositePk, composite_pk_table) cp1 = CompositePk(j=1, k=1) @@ -499,6 +690,7 @@ class MergeTest(_fixtures.FixtureTest): def go(): rec.append(sess.merge(cp1)) + self.assert_sql_count(testing.db, go, 0) rec[0].i = 5 sess.commit() @@ -507,19 +699,23 @@ class MergeTest(_fixtures.FixtureTest): def test_dont_send_neverset_to_get_w_relationship(self): # test issue #3647 CompositePk, composite_pk_table = ( - self.classes.CompositePk, self.tables.composite_pk_table + self.classes.CompositePk, + self.tables.composite_pk_table, ) - User, users = ( - self.classes.User, self.tables.users + User, users = (self.classes.User, self.tables.users) + mapper( + User, + users, + properties={ + "elements": relationship( + CompositePk, + primaryjoin=users.c.id == foreign(composite_pk_table.c.i), + ) + }, ) - mapper(User, users, properties={ - 'elements': relationship( - CompositePk, - primaryjoin=users.c.id == foreign(composite_pk_table.c.i)) - }) mapper(CompositePk, composite_pk_table) - u1 = User(id=5, name='some user') + u1 = User(id=5, name="some user") cp1 = CompositePk(j=1, k=1) u1.elements.append(cp1) sess = Session() @@ -528,6 +724,7 @@ class MergeTest(_fixtures.FixtureTest): def go(): rec.append(sess.merge(u1)) + self.assert_sql_count(testing.db, go, 1) u2 = rec[0] sess.commit() @@ -539,14 +736,18 @@ class MergeTest(_fixtures.FixtureTest): target that specifically doesn't include 'merge' cascade. 
""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) - mapper(Address, addresses, properties={ - 'user': relationship(User, cascade="save-update") - }) + mapper( + Address, + addresses, + properties={"user": relationship(User, cascade="save-update")}, + ) mapper(User, users) sess = create_session() u1 = User(name="fred") @@ -560,36 +761,37 @@ class MergeTest(_fixtures.FixtureTest): # no expire of the attribute - assert a2.__dict__['user'] is u1 + assert a2.__dict__["user"] is u1 # merge succeeded eq_( - sess.query(Address).all(), - [Address(id=a1.id, email_address="bar")] + sess.query(Address).all(), [Address(id=a1.id, email_address="bar")] ) # didn't touch user - eq_( - sess.query(User).all(), - [User(name="fred")] - ) + eq_(sess.query(User).all(), [User(name="fred")]) def test_one_to_many_cascade(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses))}) + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) load = self.load_tracker(User) self.load_tracker(Address, load) sess = create_session() - u = User(name='fred') - a1 = Address(email_address='foo@bar') - a2 = Address(email_address='foo@quux') + u = User(name="fred") + a1 = Address(email_address="foo@bar") + a2 = Address(email_address="foo@quux") u.addresses.extend([a1, a2]) sess.add(u) @@ -601,29 +803,29 @@ class MergeTest(_fixtures.FixtureTest): u2 = sess2.query(User).get(u.id) eq_(load.called, 1) - u.addresses[1].email_address = 'addr 2 modified' + u.addresses[1].email_address = "addr 2 modified" sess2.merge(u) - eq_(u2.addresses[1].email_address, 'addr 2 modified') + eq_(u2.addresses[1].email_address, "addr 2 modified") eq_(load.called, 3) sess3 = create_session() u3 = sess3.query(User).get(u.id) eq_(load.called, 4) - u.name = 'also fred' + u.name = "also fred" sess3.merge(u) eq_(load.called, 6) - eq_(u3.name, 'also fred') + eq_(u3.name, "also fred") def test_many_to_one_cascade(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper(Address, addresses, properties={"user": relationship(User)}) mapper(User, users) u1 = User(id=1, name="u1") @@ -638,32 +840,34 @@ class MergeTest(_fixtures.FixtureTest): sess2 = create_session() a2 = sess2.merge(a1) - eq_( - attributes.get_history(a2, 'user'), - ([u2], (), ()) - ) + eq_(attributes.get_history(a2, "user"), ([u2], (), ())) assert a2 in sess2.dirty sess.refresh(a1) sess2 = create_session() a2 = sess2.merge(a1, load=False) - eq_( - attributes.get_history(a2, 'user'), - ((), [u1], ()) - ) + eq_(attributes.get_history(a2, "user"), ((), [u1], ())) assert a2 not in sess2.dirty def test_many_to_many_cascade(self): - items, Order, orders, order_items, Item = (self.tables.items, - self.classes.Order, - self.tables.orders, - self.tables.order_items, - 
self.classes.Item) + items, Order, orders, order_items, Item = ( + self.tables.items, + self.classes.Order, + self.tables.orders, + self.tables.order_items, + self.classes.Item, + ) - mapper(Order, orders, properties={ - 'items': relationship(mapper(Item, items), - secondary=order_items)}) + mapper( + Order, + orders, + properties={ + "items": relationship( + mapper(Item, items), secondary=order_items + ) + }, + ) load = self.load_tracker(Order) self.load_tracker(Item, load) @@ -671,13 +875,13 @@ class MergeTest(_fixtures.FixtureTest): sess = create_session() i1 = Item() - i1.description = 'item 1' + i1.description = "item 1" i2 = Item() - i2.description = 'item 2' + i2.description = "item 2" o = Order() - o.description = 'order description' + o.description = "order description" o.items.append(i1) o.items.append(i2) @@ -690,30 +894,37 @@ class MergeTest(_fixtures.FixtureTest): o2 = sess2.query(Order).get(o.id) eq_(load.called, 1) - o.items[1].description = 'item 2 modified' + o.items[1].description = "item 2 modified" sess2.merge(o) - eq_(o2.items[1].description, 'item 2 modified') - eq_(load.called, 3) + eq_(o2.items[1].description, "item 2 modified") + eq_(load.called, 3) sess3 = create_session() o3 = sess3.query(Order).get(o.id) eq_(load.called, 4) - o.description = 'desc modified' + o.description = "desc modified" sess3.merge(o) eq_(load.called, 6) - eq_(o3.description, 'desc modified') + eq_(o3.description, "desc modified") def test_one_to_one_cascade(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'address': relationship(mapper(Address, addresses), - uselist=False) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "address": relationship( + mapper(Address, addresses), uselist=False + ) + }, + ) load = self.load_tracker(User) self.load_tracker(Address, load) sess = create_session() @@ -722,7 +933,7 @@ class MergeTest(_fixtures.FixtureTest): u.id = 7 u.name = "fred" a1 = Address() - a1.email_address = 'foo@bar.com' + a1.email_address = "foo@bar.com" u.address = a1 sess.add(u) @@ -733,8 +944,8 @@ class MergeTest(_fixtures.FixtureTest): sess2 = create_session() u2 = sess2.query(User).get(7) eq_(load.called, 1) - u2.name = 'fred2' - u2.address.email_address = 'hoho@lalala.com' + u2.name = "fred2" + u2.address.email_address = "hoho@lalala.com" eq_(load.called, 2) u3 = sess.merge(u2) @@ -742,18 +953,28 @@ class MergeTest(_fixtures.FixtureTest): assert u3 is u def test_value_to_none(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'address': relationship(mapper(Address, addresses), - uselist=False, backref='user') - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "address": relationship( + mapper(Address, addresses), uselist=False, backref="user" + ) + }, + ) sess = sessionmaker()() - u = User(id=7, name="fred", - address=Address(id=1, email_address='foo@bar.com')) + u = User( + id=7, + name="fred", + address=Address(id=1, email_address="foo@bar.com"), + ) sess.add(u) sess.commit() sess.close() @@ -776,46 +997,62 @@ class MergeTest(_fixtures.FixtureTest): sess = create_session() u = 
User() - assert_raises_message(sa.exc.InvalidRequestError, - "load=False option does not support", - sess.merge, u, load=False) + assert_raises_message( + sa.exc.InvalidRequestError, + "load=False option does not support", + sess.merge, + u, + load=False, + ) def test_no_load_with_backrefs(self): """load=False populates relationships in both directions without requiring a load""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user') - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) - u = User(id=7, name='fred', addresses=[ - Address(email_address='ad1'), - Address(email_address='ad2')]) + u = User( + id=7, + name="fred", + addresses=[ + Address(email_address="ad1"), + Address(email_address="ad2"), + ], + ) sess = create_session() sess.add(u) sess.flush() sess.close() - assert 'user' in u.addresses[1].__dict__ + assert "user" in u.addresses[1].__dict__ sess = create_session() u2 = sess.merge(u, load=False) - assert 'user' in u2.addresses[1].__dict__ - eq_(u2.addresses[1].user, User(id=7, name='fred')) + assert "user" in u2.addresses[1].__dict__ + eq_(u2.addresses[1].user, User(id=7, name="fred")) - sess.expire(u2.addresses[1], ['user']) - assert 'user' not in u2.addresses[1].__dict__ + sess.expire(u2.addresses[1], ["user"]) + assert "user" not in u2.addresses[1].__dict__ sess.close() sess = create_session() u = sess.merge(u2, load=False) - assert 'user' not in u.addresses[1].__dict__ - eq_(u.addresses[1].user, User(id=7, name='fred')) + assert "user" not in u.addresses[1].__dict__ + eq_(u.addresses[1].user, User(id=7, name="fred")) def test_dontload_with_eager(self): """ @@ -831,34 +1068,38 @@ class MergeTest(_fixtures.FixtureTest): """ - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses)) - }) + mapper( + User, + users, + properties={"addresses": relationship(mapper(Address, addresses))}, + ) sess = create_session() u = User() u.id = 7 u.name = "fred" a1 = Address() - a1.email_address = 'foo@bar.com' + a1.email_address = "foo@bar.com" u.addresses.append(a1) sess.add(u) sess.flush() sess2 = create_session() - u2 = sess2.query(User).\ - options(sa.orm.joinedload('addresses')).get(7) + u2 = sess2.query(User).options(sa.orm.joinedload("addresses")).get(7) sess3 = create_session() u3 = sess3.merge(u2, load=False) def go(): sess3.flush() + self.assert_sql_count(testing.db, go, 0) def test_no_load_disallows_dirty(self): @@ -878,16 +1119,17 @@ class MergeTest(_fixtures.FixtureTest): sess.add(u) sess.flush() - u.name = 'ed' + u.name = "ed" sess2 = create_session() try: sess2.merge(u, load=False) assert False except sa.exc.InvalidRequestError as e: - assert "merge() with load=False option does not support "\ - "objects marked as 'dirty'. flush() all changes on "\ - "mapped instances before merging with load=False." 
\ - in str(e) + assert ( + "merge() with load=False option does not support " + "objects marked as 'dirty'. flush() all changes on " + "mapped instances before merging with load=False." in str(e) + ) u2 = sess2.query(User).get(7) @@ -897,24 +1139,33 @@ class MergeTest(_fixtures.FixtureTest): def go(): sess3.flush() + self.assert_sql_count(testing.db, go, 0) def test_no_load_sets_backrefs(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user')}) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) sess = create_session() u = User() u.id = 7 u.name = "fred" a1 = Address() - a1.email_address = 'foo@bar.com' + a1.email_address = "foo@bar.com" u.addresses.append(a1) sess.add(u) @@ -928,6 +1179,7 @@ class MergeTest(_fixtures.FixtureTest): def go(): assert u2.addresses[0].user is u2 + self.assert_sql_count(testing.db, go, 0) def test_no_load_preserves_parents(self): @@ -945,21 +1197,30 @@ class MergeTest(_fixtures.FixtureTest): """ - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user', - cascade="all, delete-orphan")}) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + backref="user", + cascade="all, delete-orphan", + ) + }, + ) sess = create_session() u = User() u.id = 7 u.name = "fred" a1 = Address() - a1.email_address = 'foo@bar.com' + a1.email_address = "foo@bar.com" u.addresses.append(a1) sess.add(u) sess.flush() @@ -970,14 +1231,17 @@ class MergeTest(_fixtures.FixtureTest): u2 = sess2.merge(u, load=False) assert not sess2.dirty a2 = u2.addresses[0] - a2.email_address = 'somenewaddress' + a2.email_address = "somenewaddress" assert not sa.orm.object_mapper(a2)._is_orphan( - sa.orm.attributes.instance_state(a2)) + sa.orm.attributes.instance_state(a2) + ) sess2.flush() sess2.expunge_all() - eq_(sess2.query(User).get(u2.id).addresses[0].email_address, - 'somenewaddress') + eq_( + sess2.query(User).get(u2.id).addresses[0].email_address, + "somenewaddress", + ) # this use case is not supported; this is with a pending Address # on the pre-merged object, and we currently don't support @@ -998,13 +1262,16 @@ class MergeTest(_fixtures.FixtureTest): # if load=False is changed to support dirty objects, this code # needs to pass a2 = u2.addresses[0] - a2.email_address = 'somenewaddress' + a2.email_address = "somenewaddress" assert not sa.orm.object_mapper(a2)._is_orphan( - sa.orm.attributes.instance_state(a2)) + sa.orm.attributes.instance_state(a2) + ) sess2.flush() sess2.expunge_all() - eq_(sess2.query(User).get(u2.id).addresses[0].email_address, - 'somenewaddress') + eq_( + sess2.query(User).get(u2.id).addresses[0].email_address, + "somenewaddress", + ) except sa.exc.InvalidRequestError as e: assert "load=False option does not support" in str(e) @@ -1012,7 +1279,6 @@ class MergeTest(_fixtures.FixtureTest): users = self.tables.users class User(object): - 
class Comparator(PropComparator): pass @@ -1020,18 +1286,22 @@ class MergeTest(_fixtures.FixtureTest): return self._value def _setValue(self, value): - setattr(self, '_value', value) + setattr(self, "_value", value) value = property(_getValue, _setValue) - mapper(User, users, properties={ - 'uid': synonym('id'), - 'foobar': comparable_property(User.Comparator, User.value), - }) + mapper( + User, + users, + properties={ + "uid": synonym("id"), + "foobar": comparable_property(User.Comparator, User.value), + }, + ) sess = create_session() u = User() - u.name = 'ed' + u.name = "ed" sess.add(u) sess.flush() sess.expunge(u) @@ -1040,20 +1310,27 @@ class MergeTest(_fixtures.FixtureTest): def test_cascade_doesnt_blowaway_manytoone(self): """a merge test that was fixed by [ticket:1202]""" - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) s = create_session(autoflush=True, autocommit=False) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user')}) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) - a1 = Address(user=s.merge(User(id=1, name='ed')), email_address='x') + a1 = Address(user=s.merge(User(id=1, name="ed")), email_address="x") before_id = id(a1.user) - a2 = Address(user=s.merge(User(id=1, name='jack')), - email_address='x') + a2 = Address(user=s.merge(User(id=1, name="jack")), email_address="x") after_id = id(a1.user) other_id = id(a2.user) eq_(before_id, other_id) @@ -1062,48 +1339,67 @@ class MergeTest(_fixtures.FixtureTest): eq_(a1.user, a2.user) def test_cascades_dont_autoflush(self): - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) sess = create_session(autoflush=True, autocommit=False) - m = mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - backref='user')}) - user = User(id=8, name='fred', - addresses=[Address(email_address='user')]) + m = mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), backref="user" + ) + }, + ) + user = User( + id=8, name="fred", addresses=[Address(email_address="user")] + ) merged_user = sess.merge(user) assert merged_user in sess.new sess.flush() assert merged_user not in sess.new def test_cascades_dont_autoflush_2(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, - backref='user', - cascade="all, delete-orphan") - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", cascade="all, delete-orphan" + ) + }, + ) mapper(Address, addresses) - u = User(id=7, name='fred', addresses=[ - Address(id=1, email_address='fred1'), - ]) + u = User( + id=7, name="fred", addresses=[Address(id=1, email_address="fred1")] + ) sess = create_session(autoflush=True, autocommit=False) sess.add(u) sess.commit() 
sess.expunge_all() - u = User(id=7, name='fred', addresses=[ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ]) + u = User( + id=7, + name="fred", + addresses=[ + Address(id=1, email_address="fred1"), + Address(id=2, email_address="fred2"), + ], + ) sess.merge(u) assert sess.autoflush sess.commit() @@ -1121,6 +1417,7 @@ class MergeTest(_fixtures.FixtureTest): def go(): eq_(u.name, None) + self.assert_sql_count(testing.db, go, 0) def test_option_state(self): @@ -1140,10 +1437,7 @@ class MergeTest(_fixtures.FixtureTest): umapper = mapper(User, users) - sess.add_all([ - User(id=1, name='u1'), - User(id=2, name='u2'), - ]) + sess.add_all([User(id=1, name="u1"), User(id=2, name="u2")]) sess.commit() sess2 = sessionmaker()() @@ -1155,7 +1449,7 @@ class MergeTest(_fixtures.FixtureTest): for u in s1_users: ustate = attributes.instance_state(u) - eq_(ustate.load_path.path, (umapper, )) + eq_(ustate.load_path.path, (umapper,)) eq_(ustate.load_options, set()) for u in s2_users: @@ -1163,7 +1457,7 @@ class MergeTest(_fixtures.FixtureTest): for u in s1_users: ustate = attributes.instance_state(u) - eq_(ustate.load_path.path, (umapper, )) + eq_(ustate.load_path.path, (umapper,)) eq_(ustate.load_options, set([opt2])) # test 2. present options are replaced by merge options @@ -1171,7 +1465,7 @@ class MergeTest(_fixtures.FixtureTest): s1_users = sess.query(User).options(opt1).all() for u in s1_users: ustate = attributes.instance_state(u) - eq_(ustate.load_path.path, (umapper, )) + eq_(ustate.load_path.path, (umapper,)) eq_(ustate.load_options, set([opt1])) for u in s2_users: @@ -1179,28 +1473,30 @@ class MergeTest(_fixtures.FixtureTest): for u in s1_users: ustate = attributes.instance_state(u) - eq_(ustate.load_path.path, (umapper, )) + eq_(ustate.load_path.path, (umapper,)) eq_(ustate.load_options, set([opt2])) def test_resolve_conflicts_pending_doesnt_interfere_no_ident(self): User, Address, Order = ( - self.classes.User, self.classes.Address, self.classes.Order) + self.classes.User, + self.classes.Address, + self.classes.Order, + ) users, addresses, orders = ( - self.tables.users, self.tables.addresses, self.tables.orders) - - mapper(User, users, properties={ - 'orders': relationship(Order) - }) - mapper(Order, orders, properties={ - 'address': relationship(Address) - }) + self.tables.users, + self.tables.addresses, + self.tables.orders, + ) + + mapper(User, users, properties={"orders": relationship(Order)}) + mapper(Order, orders, properties={"address": relationship(Address)}) mapper(Address, addresses) - u1 = User(id=7, name='x') + u1 = User(id=7, name="x") u1.orders = [ - Order(description='o1', address=Address(email_address='a')), - Order(description='o2', address=Address(email_address='b')), - Order(description='o3', address=Address(email_address='c')) + Order(description="o1", address=Address(email_address="a")), + Order(description="o2", address=Address(email_address="b")), + Order(description="o3", address=Address(email_address="c")), ] sess = Session() @@ -1208,74 +1504,73 @@ class MergeTest(_fixtures.FixtureTest): sess.flush() eq_( - sess.query(Address.email_address).order_by( - Address.email_address).all(), - [('a', ), ('b', ), ('c', )] + sess.query(Address.email_address) + .order_by(Address.email_address) + .all(), + [("a",), ("b",), ("c",)], ) def test_resolve_conflicts_pending(self): User, Address, Order = ( - self.classes.User, self.classes.Address, self.classes.Order) + self.classes.User, + self.classes.Address, + self.classes.Order, + ) users, 
addresses, orders = ( - self.tables.users, self.tables.addresses, self.tables.orders) - - mapper(User, users, properties={ - 'orders': relationship(Order) - }) - mapper(Order, orders, properties={ - 'address': relationship(Address) - }) + self.tables.users, + self.tables.addresses, + self.tables.orders, + ) + + mapper(User, users, properties={"orders": relationship(Order)}) + mapper(Order, orders, properties={"address": relationship(Address)}) mapper(Address, addresses) - u1 = User(id=7, name='x') + u1 = User(id=7, name="x") u1.orders = [ - Order(description='o1', address=Address(id=1, email_address='a')), - Order(description='o2', address=Address(id=1, email_address='b')), - Order(description='o3', address=Address(id=1, email_address='c')) + Order(description="o1", address=Address(id=1, email_address="a")), + Order(description="o2", address=Address(id=1, email_address="b")), + Order(description="o3", address=Address(id=1, email_address="c")), ] sess = Session() sess.merge(u1) sess.flush() - eq_( - sess.query(Address).one(), - Address(id=1, email_address='c') - ) + eq_(sess.query(Address).one(), Address(id=1, email_address="c")) def test_resolve_conflicts_persistent(self): User, Address, Order = ( - self.classes.User, self.classes.Address, self.classes.Order) + self.classes.User, + self.classes.Address, + self.classes.Order, + ) users, addresses, orders = ( - self.tables.users, self.tables.addresses, self.tables.orders) - - mapper(User, users, properties={ - 'orders': relationship(Order) - }) - mapper(Order, orders, properties={ - 'address': relationship(Address) - }) + self.tables.users, + self.tables.addresses, + self.tables.orders, + ) + + mapper(User, users, properties={"orders": relationship(Order)}) + mapper(Order, orders, properties={"address": relationship(Address)}) mapper(Address, addresses) sess = Session() - sess.add(Address(id=1, email_address='z')) + sess.add(Address(id=1, email_address="z")) sess.commit() - u1 = User(id=7, name='x') + u1 = User(id=7, name="x") u1.orders = [ - Order(description='o1', address=Address(id=1, email_address='a')), - Order(description='o2', address=Address(id=1, email_address='b')), - Order(description='o3', address=Address(id=1, email_address='c')) + Order(description="o1", address=Address(id=1, email_address="a")), + Order(description="o2", address=Address(id=1, email_address="b")), + Order(description="o3", address=Address(id=1, email_address="c")), ] sess = Session() sess.merge(u1) sess.flush() - eq_( - sess.query(Address).one(), - Address(id=1, email_address='c') - ) + eq_(sess.query(Address).one(), Address(id=1, email_address="c")) class M2ONoUseGetLoadingTest(fixtures.MappedTest): @@ -1287,15 +1582,23 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('user', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - Table('address', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('user.id')), - Column('email', String(50))) + Table( + "user", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + ) + Table( + "address", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("user.id")), + Column("email", String(50)), + ) @classmethod def setup_classes(cls): @@ -1309,17 +1612,24 @@ class 
M2ONoUseGetLoadingTest(fixtures.MappedTest): def setup_mappers(cls): User, Address = cls.classes.User, cls.classes.Address user, address = cls.tables.user, cls.tables.address - mapper(User, user, properties={ - 'addresses': relationship(Address, - backref=backref( - 'user', - # needlessly complex primaryjoin so - # that the use_get flag is False - primaryjoin=and_( - user.c.id == address.c.user_id, - user.c.id == user.c.id - ))) - }) + mapper( + User, + user, + properties={ + "addresses": relationship( + Address, + backref=backref( + "user", + # needlessly complex primaryjoin so + # that the use_get flag is False + primaryjoin=and_( + user.c.id == address.c.user_id, + user.c.id == user.c.id, + ), + ), + ) + }, + ) mapper(Address, address) configure_mappers() assert Address.user.property._use_get is False @@ -1328,10 +1638,18 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): def insert_data(cls): User, Address = cls.classes.User, cls.classes.Address s = Session() - s.add_all([ - User(id=1, name='u1', addresses=[Address(id=1, email='a1'), - Address(id=2, email='a2')]) - ]) + s.add_all( + [ + User( + id=1, + name="u1", + addresses=[ + Address(id=1, email="a1"), + Address(id=2, email="a2"), + ], + ) + ] + ) s.commit() # "persistent" - we get at an Address that was already present. @@ -1345,6 +1663,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): def go(): u1 = User(id=1, addresses=[Address(id=1), Address(id=2)]) u2 = s.merge(u1) + self.assert_sql_count(testing.db, go, 2) def test_persistent_access_one(self): @@ -1356,6 +1675,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): u2 = s.merge(u1) a1 = u2.addresses[0] assert a1.user is u2 + self.assert_sql_count(testing.db, go, 3) def test_persistent_access_two(self): @@ -1369,6 +1689,7 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): assert a1.user is u2 a2 = u2.addresses[1] assert a2.user is u2 + self.assert_sql_count(testing.db, go, 4) # "pending" - we get at an Address that is new- user_id should be @@ -1381,12 +1702,18 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): s = Session() def go(): - u1 = User(id=1, - addresses=[Address(id=1), Address(id=2), - Address(id=3, email='a3')]) + u1 = User( + id=1, + addresses=[ + Address(id=1), + Address(id=2), + Address(id=3, email="a3"), + ], + ) u2 = s.merge(u1) a3 = u2.addresses[2] assert a3.user is u2 + self.assert_sql_count(testing.db, go, 3) def test_pending_access_two(self): @@ -1394,14 +1721,20 @@ class M2ONoUseGetLoadingTest(fixtures.MappedTest): s = Session() def go(): - u1 = User(id=1, - addresses=[Address(id=1), Address(id=2), - Address(id=3, email='a3')]) + u1 = User( + id=1, + addresses=[ + Address(id=1), + Address(id=2), + Address(id=3, email="a3"), + ], + ) u2 = s.merge(u1) a3 = u2.addresses[2] assert a3.user is u2 a2 = u2.addresses[1] assert a2.user is u2 + self.assert_sql_count(testing.db, go, 5) @@ -1409,11 +1742,12 @@ class DeferredMergeTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'book', metadata, - Column('id', Integer, primary_key=True), - Column('title', String(200), nullable=False), - Column('summary', String(2000)), - Column('excerpt', Text), + "book", + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(200), nullable=False), + Column("summary", String(2000)), + Column("excerpt", Text), ) @classmethod @@ -1424,41 +1758,48 @@ class DeferredMergeTest(fixtures.MappedTest): def test_deferred_column_mapping(self): # defer 'excerpt' at mapping level instead of query level Book, book = 
self.classes.Book, self.tables.book - mapper(Book, book, properties={'excerpt': deferred(book.c.excerpt)}) + mapper(Book, book, properties={"excerpt": deferred(book.c.excerpt)}) sess = sessionmaker()() b = Book( id=1, - title='Essential SQLAlchemy', - summary='some summary', - excerpt='some excerpt', + title="Essential SQLAlchemy", + summary="some summary", + excerpt="some excerpt", ) sess.add(b) sess.commit() b1 = sess.query(Book).first() - sess.expire(b1, ['summary']) + sess.expire(b1, ["summary"]) sess.close() def go(): b2 = sess.merge(b1, load=False) # should not emit load for deferred 'excerpt' - eq_(b2.summary, 'some summary') - not_in_('excerpt', b2.__dict__) + eq_(b2.summary, "some summary") + not_in_("excerpt", b2.__dict__) # now it should emit load for deferred 'excerpt' - eq_(b2.excerpt, 'some excerpt') - in_('excerpt', b2.__dict__) - - self.sql_eq_(go, [ - ("SELECT book.summary AS book_summary " - "FROM book WHERE book.id = :param_1", - {'param_1': 1}), - ("SELECT book.excerpt AS book_excerpt " - "FROM book WHERE book.id = :param_1", - {'param_1': 1}) - ]) + eq_(b2.excerpt, "some excerpt") + in_("excerpt", b2.__dict__) + + self.sql_eq_( + go, + [ + ( + "SELECT book.summary AS book_summary " + "FROM book WHERE book.id = :param_1", + {"param_1": 1}, + ), + ( + "SELECT book.excerpt AS book_excerpt " + "FROM book WHERE book.id = :param_1", + {"param_1": 1}, + ), + ], + ) def test_deferred_column_query(self): Book, book = self.classes.Book, self.tables.book @@ -1467,46 +1808,57 @@ class DeferredMergeTest(fixtures.MappedTest): b = Book( id=1, - title='Essential SQLAlchemy', - summary='some summary', - excerpt='some excerpt', + title="Essential SQLAlchemy", + summary="some summary", + excerpt="some excerpt", ) sess.add(b) sess.commit() # defer 'excerpt' at query level instead of mapping level b1 = sess.query(Book).options(defer(Book.excerpt)).first() - sess.expire(b1, ['summary']) + sess.expire(b1, ["summary"]) sess.close() def go(): b2 = sess.merge(b1, load=False) # should not emit load for deferred 'excerpt' - eq_(b2.summary, 'some summary') - not_in_('excerpt', b2.__dict__) + eq_(b2.summary, "some summary") + not_in_("excerpt", b2.__dict__) # now it should emit load for deferred 'excerpt' - eq_(b2.excerpt, 'some excerpt') - in_('excerpt', b2.__dict__) - - self.sql_eq_(go, [ - ("SELECT book.summary AS book_summary " - "FROM book WHERE book.id = :param_1", - {'param_1': 1}), - ("SELECT book.excerpt AS book_excerpt " - "FROM book WHERE book.id = :param_1", - {'param_1': 1}) - ]) + eq_(b2.excerpt, "some excerpt") + in_("excerpt", b2.__dict__) + + self.sql_eq_( + go, + [ + ( + "SELECT book.summary AS book_summary " + "FROM book WHERE book.id = :param_1", + {"param_1": 1}, + ), + ( + "SELECT book.excerpt AS book_excerpt " + "FROM book WHERE book.id = :param_1", + {"param_1": 1}, + ), + ], + ) class MutableMergeTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("data", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', PickleType(comparator=operator.eq))) + Table( + "data", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", PickleType(comparator=operator.eq)), + ) @classmethod def setup_classes(cls): @@ -1531,9 +1883,12 @@ class MutableMergeTest(fixtures.MappedTest): class CompositeNullPksTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("data", metadata, - Column('pk1', String(10), primary_key=True), - Column('pk2', 
String(10), primary_key=True)) + Table( + "data", + metadata, + Column("pk1", String(10), primary_key=True), + Column("pk2", String(10), primary_key=True), + ) @classmethod def setup_classes(cls): @@ -1550,6 +1905,7 @@ class CompositeNullPksTest(fixtures.MappedTest): def go(): return sess.merge(d1) + self.assert_sql_count(testing.db, go, 1) def test_merge_disallow_partial(self): @@ -1562,19 +1918,27 @@ class CompositeNullPksTest(fixtures.MappedTest): def go(): return sess.merge(d1) + self.assert_sql_count(testing.db, go, 0) class LoadOnPendingTest(fixtures.MappedTest): """Test interaction of merge() with load_on_pending relationships""" + @classmethod def define_tables(cls, metadata): - rocks_table = Table("rocks", metadata, - Column("id", Integer, primary_key=True), - Column("description", String(10))) - bugs_table = Table("bugs", metadata, - Column("id", Integer, primary_key=True), - Column("rockid", Integer, ForeignKey('rocks.id'))) + rocks_table = Table( + "rocks", + metadata, + Column("id", Integer, primary_key=True), + Column("description", String(10)), + ) + bugs_table = Table( + "bugs", + metadata, + Column("id", Integer, primary_key=True), + Column("rockid", Integer, ForeignKey("rocks.id")), + ) @classmethod def setup_classes(cls): @@ -1585,18 +1949,24 @@ class LoadOnPendingTest(fixtures.MappedTest): pass def _setup_delete_orphan_o2o(self): - mapper(self.classes.Rock, self.tables.rocks, - properties={'bug': relationship(self.classes.Bug, - cascade='all,delete-orphan', - load_on_pending=True, - uselist=False) - }) + mapper( + self.classes.Rock, + self.tables.rocks, + properties={ + "bug": relationship( + self.classes.Bug, + cascade="all,delete-orphan", + load_on_pending=True, + uselist=False, + ) + }, + ) mapper(self.classes.Bug, self.tables.bugs) self.sess = sessionmaker()() def _merge_delete_orphan_o2o_with(self, bug): # create a transient rock with passed bug - r = self.classes.Rock(id=0, description='moldy') + r = self.classes.Rock(id=0, description="moldy") r.bug = bug m = self.sess.merge(r) # we've already passed ticket #2374 problem since merge() returned, @@ -1625,11 +1995,18 @@ class PolymorphicOnTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('employees', metadata, - Column('employee_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('type', String(1), nullable=False), - Column('data', String(50))) + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("type", String(1), nullable=False), + Column("data", String(50)), + ) @classmethod def setup_classes(cls): @@ -1643,21 +2020,30 @@ class PolymorphicOnTest(fixtures.MappedTest): pass def _setup_polymorphic_on_mappers(self): - employee_mapper = mapper(self.classes.Employee, - self.tables.employees, - polymorphic_on=case( - value=self.tables.employees.c.type, - whens={ - 'E': 'employee', - 'M': 'manager', - 'G': 'engineer', - 'R': 'engineer', - }), - polymorphic_identity='employee') - mapper(self.classes.Manager, inherits=employee_mapper, - polymorphic_identity='manager') - mapper(self.classes.Engineer, inherits=employee_mapper, - polymorphic_identity='engineer') + employee_mapper = mapper( + self.classes.Employee, + self.tables.employees, + polymorphic_on=case( + value=self.tables.employees.c.type, + whens={ + "E": "employee", + "M": "manager", + "G": "engineer", + "R": "engineer", + }, + ), + polymorphic_identity="employee", + ) + mapper( + self.classes.Manager, + 
inherits=employee_mapper, + polymorphic_identity="manager", + ) + mapper( + self.classes.Engineer, + inherits=employee_mapper, + polymorphic_identity="engineer", + ) self.sess = sessionmaker()() def test_merge_polymorphic_on(self): @@ -1666,13 +2052,14 @@ class PolymorphicOnTest(fixtures.MappedTest): """ self._setup_polymorphic_on_mappers() - m = self.classes.Manager(employee_id=55, type='M', - data='original data') + m = self.classes.Manager( + employee_id=55, type="M", data="original data" + ) self.sess.add(m) self.sess.commit() self.sess.expunge_all() - m = self.classes.Manager(employee_id=55, data='updated data') + m = self.classes.Manager(employee_id=55, data="updated data") merged = self.sess.merge(m) # we've already passed ticket #2449 problem since diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index eb20c38158..3abc4065e0 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -3,8 +3,13 @@ Primary key changing capabilities and passive/non-passive cascading updates. """ -from sqlalchemy.testing import fixtures, eq_, ne_, assert_raises, \ - assert_raises_message +from sqlalchemy.testing import ( + fixtures, + eq_, + ne_, + assert_raises, + assert_raises_message, +) import sqlalchemy as sa from sqlalchemy import testing, Integer, String, ForeignKey from sqlalchemy.testing.schema import Table, Column @@ -15,56 +20,70 @@ from test.orm import _fixtures def _backend_specific_fk_args(): if testing.requires.deferrable_fks.enabled: - fk_args = dict(deferrable=True, initially='deferred') + fk_args = dict(deferrable=True, initially="deferred") elif not testing.requires.on_update_cascade.enabled: fk_args = dict() else: - fk_args = dict(onupdate='cascade') + fk_args = dict(onupdate="cascade") return fk_args class NaturalPKTest(fixtures.MappedTest): # MySQL 5.5 on Windows crashes (the entire server, not the client) # if you screw around with ON UPDATE CASCADE type of stuff. 
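(editorial aside, not part of the patch) The `_backend_specific_fk_args()` helper above decides which FOREIGN KEY options the natural-PK suites rely on: `deferrable=True, initially="deferred"` where the backend supports deferrable constraints, `onupdate="cascade"` where it supports ON UPDATE CASCADE, and no options otherwise. A minimal sketch of the DDL the cascade branch produces; the table layout mirrors the `users`/`addresses` fixtures but is only illustrative:

from sqlalchemy import Column, ForeignKey, MetaData, String, Table
from sqlalchemy.schema import CreateTable

metadata = MetaData()
Table("users", metadata, Column("username", String(50), primary_key=True))
addresses = Table(
    "addresses",
    metadata,
    Column("email", String(50), primary_key=True),
    # illustrative: the on_update_cascade branch of
    # _backend_specific_fk_args()
    Column(
        "username",
        String(50),
        ForeignKey("users.username", onupdate="cascade"),
    ),
)
print(CreateTable(addresses))
# renders a constraint like:
#   FOREIGN KEY(username) REFERENCES users (username) ON UPDATE cascade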
- __requires__ = 'skip_mysql_on_windows', 'on_update_or_deferrable_fks' + __requires__ = "skip_mysql_on_windows", "on_update_or_deferrable_fks" __backend__ = True @classmethod def define_tables(cls, metadata): fk_args = _backend_specific_fk_args() - Table('users', metadata, - Column('username', String(50), primary_key=True), - Column('fullname', String(100)), - test_needs_fk=True) + Table( + "users", + metadata, + Column("username", String(50), primary_key=True), + Column("fullname", String(100)), + test_needs_fk=True, + ) Table( - 'addresses', metadata, - Column('email', String(50), primary_key=True), + "addresses", + metadata, + Column("email", String(50), primary_key=True), Column( - 'username', String(50), - ForeignKey('users.username', **fk_args)), - test_needs_fk=True) + "username", String(50), ForeignKey("users.username", **fk_args) + ), + test_needs_fk=True, + ) Table( - 'items', metadata, - Column('itemname', String(50), primary_key=True), - Column('description', String(100)), - test_needs_fk=True) + "items", + metadata, + Column("itemname", String(50), primary_key=True), + Column("description", String(100)), + test_needs_fk=True, + ) Table( - 'users_to_items', metadata, + "users_to_items", + metadata, Column( - 'username', String(50), - ForeignKey('users.username', **fk_args), primary_key=True), + "username", + String(50), + ForeignKey("users.username", **fk_args), + primary_key=True, + ), Column( - 'itemname', String(50), - ForeignKey('items.itemname', **fk_args), primary_key=True), - test_needs_fk=True) + "itemname", + String(50), + ForeignKey("items.itemname", **fk_args), + primary_key=True, + ), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): - class User(cls.Comparable): pass @@ -80,24 +99,25 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) sess = create_session() - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() - assert sess.query(User).get('jack') is u1 + assert sess.query(User).get("jack") is u1 - u1.username = 'ed' + u1.username = "ed" sess.flush() def go(): - assert sess.query(User).get('ed') is u1 + assert sess.query(User).get("ed") is u1 + self.assert_sql_count(testing.db, go, 0) - assert sess.query(User).get('jack') is None + assert sess.query(User).get("jack") is None sess.expunge_all() - u1 = sess.query(User).get('ed') - eq_(User(username='ed', fullname='jack'), u1) + u1 = sess.query(User).get("ed") + eq_(User(username="ed", fullname="jack"), u1) def test_load_after_expire(self): users, User = self.tables.users, self.classes.User @@ -105,23 +125,23 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) sess = create_session() - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() - assert sess.query(User).get('jack') is u1 + assert sess.query(User).get("jack") is u1 - users.update(values={User.username: 'jack'}).execute(username='ed') + users.update(values={User.username: "jack"}).execute(username="ed") # expire/refresh works off of primary key. the PK is gone # in this case so there's no way to look it up. 
criterion- # based session invalidation could solve this [ticket:911] sess.expire(u1) - assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') + assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, "username") sess.expunge_all() - assert sess.query(User).get('jack') is None - assert sess.query(User).get('ed').fullname == 'jack' + assert sess.query(User).get("jack") is None + assert sess.query(User).get("ed").fullname == "jack" @testing.requires.returning def test_update_to_sql_expr(self): @@ -130,16 +150,16 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) sess = create_session() - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() - u1.username = User.username + ' jones' + u1.username = User.username + " jones" sess.flush() - eq_(u1.username, 'jack jones') + eq_(u1.username, "jack jones") def test_update_to_self_sql_expr(self): # SQL expression where the PK won't actually change, @@ -149,33 +169,33 @@ class NaturalPKTest(fixtures.MappedTest): mapper(User, users) sess = create_session() - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() - u1.username = User.username + '' + u1.username = User.username + "" sess.flush() - eq_(u1.username, 'jack') + eq_(u1.username, "jack") def test_flush_new_pk_after_expire(self): User, users = self.classes.User, self.tables.users mapper(User, users) sess = create_session() - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() - assert sess.query(User).get('jack') is u1 + assert sess.query(User).get("jack") is u1 sess.expire(u1) - u1.username = 'ed' + u1.username = "ed" sess.flush() sess.expunge_all() - assert sess.query(User).get('ed').fullname == 'jack' + assert sess.query(User).get("ed").fullname == "jack" @testing.requires.on_update_cascade def test_onetomany_passive(self): @@ -185,40 +205,49 @@ class NaturalPKTest(fixtures.MappedTest): self._test_onetomany(False) def _test_onetomany(self, passive_updates): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, passive_updates=passive_updates)}) + User, + users, + properties={ + "addresses": relationship( + Address, passive_updates=passive_updates + ) + }, + ) mapper(Address, addresses) sess = create_session() - u1 = User(username='jack', fullname='jack') - u1.addresses.append(Address(email='jack1')) - u1.addresses.append(Address(email='jack2')) + u1 = User(username="jack", fullname="jack") + u1.addresses.append(Address(email="jack1")) + u1.addresses.append(Address(email="jack2")) sess.add(u1) sess.flush() - assert sess.query(Address).get('jack1') is u1.addresses[0] + assert sess.query(Address).get("jack1") is u1.addresses[0] - u1.username = 'ed' + u1.username = "ed" sess.flush() - assert u1.addresses[0].username == 'ed' + assert u1.addresses[0].username == "ed" sess.expunge_all() eq_( - [Address(username='ed'), Address(username='ed')], - sess.query(Address).all()) + [Address(username="ed"), Address(username="ed")], + sess.query(Address).all(), + ) - u1 = sess.query(User).get('ed') - u1.username = 'jack' + u1 = sess.query(User).get("ed") + u1.username = "jack" def go(): sess.flush() + if not 
passive_updates: # test passive_updates=False; # load addresses, update user, update 2 addresses @@ -228,18 +257,18 @@ class NaturalPKTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 1) sess.expunge_all() assert User( - username='jack', addresses=[ - Address(username='jack'), - Address(username='jack')]) == sess.query(User).get('jack') + username="jack", + addresses=[Address(username="jack"), Address(username="jack")], + ) == sess.query(User).get("jack") - u1 = sess.query(User).get('jack') + u1 = sess.query(User).get("jack") u1.addresses = [] - u1.username = 'fred' + u1.username = "fred" sess.flush() sess.expunge_all() - assert sess.query(Address).get('jack1').username is None - u1 = sess.query(User).get('fred') - eq_(User(username='fred', fullname='jack'), u1) + assert sess.query(Address).get("jack1").username is None + u1 = sess.query(User).get("fred") + eq_(User(username="fred", fullname="jack"), u1) @testing.requires.on_update_cascade def test_manytoone_passive(self): @@ -254,57 +283,68 @@ class NaturalPKTest(fixtures.MappedTest): hasn't yet been part of a flush. """ - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) with testing.db.begin() as conn: - conn.execute(users.insert(), username='jack', fullname='jack') - conn.execute(addresses.insert(), email='jack1', username='jack') - conn.execute(addresses.insert(), email='jack2', username='jack') + conn.execute(users.insert(), username="jack", fullname="jack") + conn.execute(addresses.insert(), email="jack1", username="jack") + conn.execute(addresses.insert(), email="jack2", username="jack") mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, - passive_updates=False) - }) + mapper( + Address, + addresses, + properties={"user": relationship(User, passive_updates=False)}, + ) sess = create_session() u1 = sess.query(User).first() a1, a2 = sess.query(Address).all() - u1.username = 'ed' + u1.username = "ed" def go(): sess.flush() + self.assert_sql_count(testing.db, go, 2) def _test_manytoone(self, passive_updates): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, passive_updates=passive_updates) - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User, passive_updates=passive_updates) + }, + ) sess = create_session() - a1 = Address(email='jack1') - a2 = Address(email='jack2') + a1 = Address(email="jack1") + a2 = Address(email="jack2") - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") a1.user = u1 a2.user = u1 sess.add(a1) sess.add(a2) sess.flush() - u1.username = 'ed' + u1.username = "ed" def go(): sess.flush() + if passive_updates: self.assert_sql_count(testing.db, go, 1) else: @@ -312,13 +352,15 @@ class NaturalPKTest(fixtures.MappedTest): def go(): sess.flush() + self.assert_sql_count(testing.db, go, 0) - assert a1.username == a2.username == 'ed' + assert a1.username == a2.username == "ed" sess.expunge_all() eq_( - [Address(username='ed'), Address(username='ed')], - sess.query(Address).all()) + [Address(username="ed"), 
Address(username="ed")], + sess.query(Address).all(), + ) @testing.requires.on_update_cascade def test_onetoone_passive(self): @@ -328,43 +370,52 @@ class NaturalPKTest(fixtures.MappedTest): self._test_onetoone(False) def _test_onetoone(self, passive_updates): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper( - User, users, properties={ + User, + users, + properties={ "address": relationship( - Address, passive_updates=passive_updates, uselist=False)}) + Address, passive_updates=passive_updates, uselist=False + ) + }, + ) mapper(Address, addresses) sess = create_session() - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") sess.add(u1) sess.flush() - a1 = Address(email='jack1') + a1 = Address(email="jack1") u1.address = a1 sess.add(a1) sess.flush() - u1.username = 'ed' + u1.username = "ed" def go(): sess.flush() + if passive_updates: - sess.expire(u1, ['address']) + sess.expire(u1, ["address"]) self.assert_sql_count(testing.db, go, 1) else: self.assert_sql_count(testing.db, go, 2) def go(): sess.flush() + self.assert_sql_count(testing.db, go, 0) sess.expunge_all() - eq_([Address(username='ed')], sess.query(Address).all()) + eq_([Address(username="ed")], sess.query(Address).all()) @testing.requires.on_update_cascade def test_bidirectional_passive(self): @@ -374,50 +425,61 @@ class NaturalPKTest(fixtures.MappedTest): self._test_bidirectional(False) def _test_bidirectional(self, passive_updates): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, passive_updates=passive_updates, - backref='addresses')}) + mapper( + Address, + addresses, + properties={ + "user": relationship( + User, passive_updates=passive_updates, backref="addresses" + ) + }, + ) sess = create_session() - a1 = Address(email='jack1') - a2 = Address(email='jack2') + a1 = Address(email="jack1") + a2 = Address(email="jack2") - u1 = User(username='jack', fullname='jack') + u1 = User(username="jack", fullname="jack") a1.user = u1 a2.user = u1 sess.add(a1) sess.add(a2) sess.flush() - u1.username = 'ed' + u1.username = "ed" (ad1, ad2) = sess.query(Address).all() - eq_([Address(username='jack'), Address(username='jack')], [ad1, ad2]) + eq_([Address(username="jack"), Address(username="jack")], [ad1, ad2]) def go(): sess.flush() + if passive_updates: self.assert_sql_count(testing.db, go, 1) else: # two updates bundled self.assert_sql_count(testing.db, go, 2) - eq_([Address(username='ed'), Address(username='ed')], [ad1, ad2]) + eq_([Address(username="ed"), Address(username="ed")], [ad1, ad2]) sess.expunge_all() eq_( - [Address(username='ed'), Address(username='ed')], - sess.query(Address).all()) + [Address(username="ed"), Address(username="ed")], + sess.query(Address).all(), + ) - u1 = sess.query(User).get('ed') - assert len(u1.addresses) == 2 # load addresses - u1.username = 'fred' + u1 = sess.query(User).get("ed") + assert len(u1.addresses) == 2 # load addresses + u1.username = "fred" def go(): sess.flush() + # check that the passive_updates is on on the other side if passive_updates: 
self.assert_sql_count(testing.db, go, 1) @@ -426,8 +488,9 @@ class NaturalPKTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 2) sess.expunge_all() eq_( - [Address(username='fred'), Address(username='fred')], - sess.query(Address).all()) + [Address(username="fred"), Address(username="fred")], + sess.query(Address).all(), + ) @testing.requires.on_update_cascade def test_manytomany_passive(self): @@ -439,24 +502,33 @@ class NaturalPKTest(fixtures.MappedTest): self._test_manytomany(False) def _test_manytomany(self, passive_updates): - users, items, Item, User, users_to_items = (self.tables.users, - self.tables.items, - self.classes.Item, - self.classes.User, - self.tables.users_to_items) + users, items, Item, User, users_to_items = ( + self.tables.users, + self.tables.items, + self.classes.Item, + self.classes.User, + self.tables.users_to_items, + ) mapper( - User, users, properties={ - 'items': relationship( - Item, secondary=users_to_items, backref='users', - passive_updates=passive_updates)}) + User, + users, + properties={ + "items": relationship( + Item, + secondary=users_to_items, + backref="users", + passive_updates=passive_updates, + ) + }, + ) mapper(Item, items) sess = create_session() - u1 = User(username='jack') - u2 = User(username='fred') - i1 = Item(itemname='item1') - i2 = Item(itemname='item2') + u1 = User(username="jack") + u2 = User(username="fred") + i1 = Item(itemname="item1") + i2 = Item(itemname="item2") u1.items.append(i1) u1.items.append(i2) @@ -468,133 +540,142 @@ class NaturalPKTest(fixtures.MappedTest): r = sess.query(Item).all() # ComparableEntity can't handle a comparison with the backrefs # involved.... - eq_(Item(itemname='item1'), r[0]) - eq_(['jack'], [u.username for u in r[0].users]) - eq_(Item(itemname='item2'), r[1]) - eq_(['jack', 'fred'], [u.username for u in r[1].users]) + eq_(Item(itemname="item1"), r[0]) + eq_(["jack"], [u.username for u in r[0].users]) + eq_(Item(itemname="item2"), r[1]) + eq_(["jack", "fred"], [u.username for u in r[1].users]) - u2.username = 'ed' + u2.username = "ed" def go(): sess.flush() + go() def go(): sess.flush() + self.assert_sql_count(testing.db, go, 0) sess.expunge_all() r = sess.query(Item).all() - eq_(Item(itemname='item1'), r[0]) - eq_(['jack'], [u.username for u in r[0].users]) - eq_(Item(itemname='item2'), r[1]) - eq_(['ed', 'jack'], sorted([u.username for u in r[1].users])) + eq_(Item(itemname="item1"), r[0]) + eq_(["jack"], [u.username for u in r[0].users]) + eq_(Item(itemname="item2"), r[1]) + eq_(["ed", "jack"], sorted([u.username for u in r[1].users])) sess.expunge_all() u2 = sess.query(User).get(u2.username) - u2.username = 'wendy' + u2.username = "wendy" sess.flush() r = sess.query(Item).with_parent(u2).all() - eq_(Item(itemname='item2'), r[0]) + eq_(Item(itemname="item2"), r[0]) def test_manytoone_deferred_relationship_expr(self): """for [ticket:4359], test that updates to the columns embedded in an object expression are also updated.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship( - User, - passive_updates=testing.requires.on_update_cascade.enabled) - }) + mapper( + Address, + addresses, + properties={ + "user": relationship( + User, + passive_updates=testing.requires.on_update_cascade.enabled, + ) + 
}, + ) s = Session() - a1 = Address(email='jack1') - u1 = User(username='jack', fullname='jack') + a1 = Address(email="jack1") + u1 = User(username="jack", fullname="jack") a1.user = u1 # scenario 1. object is still transient, we get a value. expr = Address.user == u1 - eq_(expr.left.callable(), 'jack') + eq_(expr.left.callable(), "jack") # scenario 2. value has been changed while we are transient. # we get the updated value. - u1.username = 'ed' - eq_(expr.left.callable(), 'ed') + u1.username = "ed" + eq_(expr.left.callable(), "ed") s.add_all([u1, a1]) s.commit() - eq_(a1.username, 'ed') + eq_(a1.username, "ed") # scenario 3. the value is changed and flushed, we get the new value. - u1.username = 'fred' + u1.username = "fred" s.flush() - eq_(expr.left.callable(), 'fred') + eq_(expr.left.callable(), "fred") # scenario 4. the value is changed, flushed, and expired. # the callable goes out to get that value. - u1.username = 'wendy' + u1.username = "wendy" s.commit() - assert 'username' not in u1.__dict__ + assert "username" not in u1.__dict__ - eq_(expr.left.callable(), 'wendy') + eq_(expr.left.callable(), "wendy") # scenario 5. the value is changed flushed, expired, # and then when we hit the callable, we are detached. - u1.username = 'jack' + u1.username = "jack" s.commit() - assert 'username' not in u1.__dict__ + assert "username" not in u1.__dict__ s.expunge(u1) # InstanceState has a "last known values" feature we use # to pick up on this - eq_(expr.left.callable(), 'jack') + eq_(expr.left.callable(), "jack") # doesn't unexpire the attribute - assert 'username' not in u1.__dict__ + assert "username" not in u1.__dict__ # once we are persistent again, we check the DB s.add(u1) - eq_(expr.left.callable(), 'jack') - assert 'username' in u1.__dict__ + eq_(expr.left.callable(), "jack") + assert "username" in u1.__dict__ # scenario 6. 
we are using del - u2 = User(username='jack', fullname='jack') + u2 = User(username="jack", fullname="jack") expr = Address.user == u2 - eq_(expr.left.callable(), 'jack') + eq_(expr.left.callable(), "jack") del u2.username assert_raises_message( sa.exc.InvalidRequestError, "Can't resolve value for column users.username", - expr.left.callable + expr.left.callable, ) - u2.username = 'ed' - eq_(expr.left.callable(), 'ed') + u2.username = "ed" + eq_(expr.left.callable(), "ed") s.add(u2) s.commit() - eq_(expr.left.callable(), 'ed') + eq_(expr.left.callable(), "ed") del u2.username assert_raises_message( sa.exc.InvalidRequestError, "Can't resolve value for column users.username", - expr.left.callable + expr.left.callable, ) @@ -608,23 +689,25 @@ class TransientExceptionTesst(_fixtures.FixtureTest): """ - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={'user': relationship(User)}) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - u1 = User(id=5, name='u1') - ad1 = Address(email_address='e1', user=u1) + u1 = User(id=5, name="u1") + ad1 = Address(email_address="e1", user=u1) sess.add_all([u1, ad1]) sess.flush() make_transient(u1) u1.id = None - u1.username = 'u2' + u1.username = "u2" sess.add(u1) sess.flush() @@ -646,11 +729,12 @@ class ReversePKsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'user', metadata, - Column('code', Integer, autoincrement=False, primary_key=True), - Column('status', Integer, autoincrement=False, primary_key=True), - Column('username', String(50), nullable=False), - test_needs_acid=True + "user", + metadata, + Column("code", Integer, autoincrement=False, primary_key=True), + Column("status", Integer, autoincrement=False, primary_key=True), + Column("username", String(50), nullable=False), + test_needs_acid=True, ) @classmethod @@ -670,11 +754,11 @@ class ReversePKsTest(fixtures.MappedTest): session = sa.orm.sessionmaker()() - a_published = User(1, PUBLISHED, 'a') + a_published = User(1, PUBLISHED, "a") session.add(a_published) session.commit() - a_editable = User(1, EDITABLE, 'a') + a_editable = User(1, EDITABLE, "a") session.add(a_editable) session.commit() @@ -707,11 +791,11 @@ class ReversePKsTest(fixtures.MappedTest): session = sa.orm.sessionmaker()() - a_published = User(1, PUBLISHED, 'a') + a_published = User(1, PUBLISHED, "a") session.add(a_published) session.commit() - a_editable = User(1, EDITABLE, 'a') + a_editable = User(1, EDITABLE, "a") session.add(a_editable) session.commit() @@ -732,9 +816,9 @@ class ReversePKsTest(fixtures.MappedTest): class SelfReferentialTest(fixtures.MappedTest): # mssql, mysql don't allow # ON UPDATE on self-referential keys - __unsupported_on__ = ('mssql', 'mysql') + __unsupported_on__ = ("mssql", "mysql") - __requires__ = 'on_update_or_deferrable_fks', + __requires__ = ("on_update_or_deferrable_fks",) __backend__ = True @classmethod @@ -742,10 +826,12 @@ class SelfReferentialTest(fixtures.MappedTest): fk_args = _backend_specific_fk_args() Table( - 'nodes', metadata, - Column('name', String(50), primary_key=True), - Column('parent', String(50), ForeignKey('nodes.name', **fk_args)), - test_needs_fk=True) + "nodes", + metadata, + Column("name", String(50), primary_key=True), + 
Column("parent", String(50), ForeignKey("nodes.name", **fk_args)), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): @@ -756,56 +842,78 @@ class SelfReferentialTest(fixtures.MappedTest): Node, nodes = self.classes.Node, self.tables.nodes mapper( - Node, nodes, properties={ - 'children': relationship( + Node, + nodes, + properties={ + "children": relationship( Node, backref=sa.orm.backref( - 'parentnode', remote_side=nodes.c.name, - passive_updates=False), - )}) + "parentnode", + remote_side=nodes.c.name, + passive_updates=False, + ), + ) + }, + ) sess = Session() - n1 = Node(name='n1') + n1 = Node(name="n1") sess.add(n1) - n2 = Node(name='n11', parentnode=n1) - n3 = Node(name='n12', parentnode=n1) - n4 = Node(name='n13', parentnode=n1) + n2 = Node(name="n11", parentnode=n1) + n3 = Node(name="n12", parentnode=n1) + n4 = Node(name="n13", parentnode=n1) sess.add_all([n2, n3, n4]) sess.commit() - n1.name = 'new n1' + n1.name = "new n1" sess.commit() - eq_(['new n1', 'new n1', 'new n1'], - [n.parent - for n in sess.query(Node).filter( - Node.name.in_(['n11', 'n12', 'n13']))]) + eq_( + ["new n1", "new n1", "new n1"], + [ + n.parent + for n in sess.query(Node).filter( + Node.name.in_(["n11", "n12", "n13"]) + ) + ], + ) def test_one_to_many_on_o2m(self): Node, nodes = self.classes.Node, self.tables.nodes mapper( - Node, nodes, properties={ - 'children': relationship( + Node, + nodes, + properties={ + "children": relationship( Node, backref=sa.orm.backref( - 'parentnode', remote_side=nodes.c.name), - passive_updates=False)}) + "parentnode", remote_side=nodes.c.name + ), + passive_updates=False, + ) + }, + ) sess = Session() - n1 = Node(name='n1') - n1.children.append(Node(name='n11')) - n1.children.append(Node(name='n12')) - n1.children.append(Node(name='n13')) + n1 = Node(name="n1") + n1.children.append(Node(name="n11")) + n1.children.append(Node(name="n12")) + n1.children.append(Node(name="n13")) sess.add(n1) sess.commit() - n1.name = 'new n1' + n1.name = "new n1" sess.commit() - eq_(n1.children[1].parent, 'new n1') - eq_(['new n1', 'new n1', 'new n1'], - [n.parent - for n in sess.query(Node).filter( - Node.name.in_(['n11', 'n12', 'n13']))]) + eq_(n1.children[1].parent, "new n1") + eq_( + ["new n1", "new n1", "new n1"], + [ + n.parent + for n in sess.query(Node).filter( + Node.name.in_(["n11", "n12", "n13"]) + ) + ], + ) @testing.requires.on_update_cascade def test_many_to_one_passive(self): @@ -818,30 +926,38 @@ class SelfReferentialTest(fixtures.MappedTest): Node, nodes = self.classes.Node, self.tables.nodes mapper( - Node, nodes, properties={ - 'parentnode': relationship( - Node, remote_side=nodes.c.name, passive_updates=passive)} + Node, + nodes, + properties={ + "parentnode": relationship( + Node, remote_side=nodes.c.name, passive_updates=passive + ) + }, ) sess = Session() - n1 = Node(name='n1') - n11 = Node(name='n11', parentnode=n1) - n12 = Node(name='n12', parentnode=n1) - n13 = Node(name='n13', parentnode=n1) + n1 = Node(name="n1") + n11 = Node(name="n11", parentnode=n1) + n12 = Node(name="n12", parentnode=n1) + n13 = Node(name="n13", parentnode=n1) sess.add_all([n1, n11, n12, n13]) sess.commit() - n1.name = 'new n1' + n1.name = "new n1" sess.commit() eq_( - ['new n1', 'new n1', 'new n1'], + ["new n1", "new n1", "new n1"], [ - n.parent for n in sess.query(Node).filter( - Node.name.in_(['n11', 'n12', 'n13']))]) + n.parent + for n in sess.query(Node).filter( + Node.name.in_(["n11", "n12", "n13"]) + ) + ], + ) class NonPKCascadeTest(fixtures.MappedTest): - __requires__ = 
'skip_mysql_on_windows', 'on_update_or_deferrable_fks' + __requires__ = "skip_mysql_on_windows", "on_update_or_deferrable_fks" __backend__ = True @classmethod @@ -849,28 +965,31 @@ class NonPKCascadeTest(fixtures.MappedTest): fk_args = _backend_specific_fk_args() Table( - 'users', metadata, + "users", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('username', String(50), unique=True), - Column('fullname', String(100)), - test_needs_fk=True) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("username", String(50), unique=True), + Column("fullname", String(100)), + test_needs_fk=True, + ) Table( - 'addresses', metadata, + "addresses", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('email', String(50)), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("email", String(50)), Column( - 'username', String(50), - ForeignKey('users.username', **fk_args)), - test_needs_fk=True) + "username", String(50), ForeignKey("users.username", **fk_args) + ), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): - class User(cls.Comparable): pass @@ -885,48 +1004,59 @@ class NonPKCascadeTest(fixtures.MappedTest): self._test_onetomany(False) def _test_onetomany(self, passive_updates): - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, passive_updates=passive_updates)}) + User, + users, + properties={ + "addresses": relationship( + Address, passive_updates=passive_updates + ) + }, + ) mapper(Address, addresses) sess = create_session() - u1 = User(username='jack', fullname='jack') - u1.addresses.append(Address(email='jack1')) - u1.addresses.append(Address(email='jack2')) + u1 = User(username="jack", fullname="jack") + u1.addresses.append(Address(email="jack1")) + u1.addresses.append(Address(email="jack2")) sess.add(u1) sess.flush() a1 = u1.addresses[0] eq_( sa.select([addresses.c.username]).execute().fetchall(), - [('jack',), ('jack',)]) + [("jack",), ("jack",)], + ) assert sess.query(Address).get(a1.id) is u1.addresses[0] - u1.username = 'ed' + u1.username = "ed" sess.flush() - assert u1.addresses[0].username == 'ed' + assert u1.addresses[0].username == "ed" eq_( sa.select([addresses.c.username]).execute().fetchall(), - [('ed',), ('ed',)]) + [("ed",), ("ed",)], + ) sess.expunge_all() eq_( - [Address(username='ed'), Address(username='ed')], - sess.query(Address).all()) + [Address(username="ed"), Address(username="ed")], + sess.query(Address).all(), + ) u1 = sess.query(User).get(u1.id) - u1.username = 'jack' + u1.username = "jack" def go(): sess.flush() + if not passive_updates: # test passive_updates=False; load addresses, # update user, update 2 addresses (in one executemany) @@ -936,14 +1066,14 @@ class NonPKCascadeTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 1) sess.expunge_all() assert User( - username='jack', addresses=[ - Address(username='jack'), - Address(username='jack')]) == sess.query(User).get(u1.id) + username="jack", + addresses=[Address(username="jack"), Address(username="jack")], + ) == sess.query(User).get(u1.id) sess.expunge_all() u1 = sess.query(User).get(u1.id) u1.addresses = [] - u1.username = 'fred' + u1.username = "fred" 
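# editor's note (added in this rendering, not in the source):
# u1.addresses was emptied above, so this flush disassociates the two
# Address rows: their username FK is set to NULL rather than renamed
# to "fred", which is what the assertions below verify.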
sess.flush() sess.expunge_all() a1 = sess.query(Address).get(a1.id) @@ -951,38 +1081,46 @@ class NonPKCascadeTest(fixtures.MappedTest): eq_( sa.select([addresses.c.username]).execute().fetchall(), - [(None,), (None,)]) + [(None,), (None,)], + ) u1 = sess.query(User).get(u1.id) - eq_(User(username='fred', fullname='jack'), u1) + eq_(User(username="fred", fullname="jack"), u1) class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """A primary key mutation cascades onto a foreign key that is itself a primary key.""" + __backend__ = True @classmethod def define_tables(cls, metadata): fk_args = _backend_specific_fk_args() - Table('users', metadata, - Column('username', String(50), primary_key=True), - test_needs_fk=True) + Table( + "users", + metadata, + Column("username", String(50), primary_key=True), + test_needs_fk=True, + ) Table( - 'addresses', metadata, + "addresses", + metadata, Column( - 'username', String(50), - ForeignKey('users.username', **fk_args), - primary_key=True), - Column('email', String(50), primary_key=True), - Column('etc', String(50)), - test_needs_fk=True) + "username", + String(50), + ForeignKey("users.username", **fk_args), + primary_key=True, + ), + Column("email", String(50), primary_key=True), + Column("etc", String(50)), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): - class User(cls.Comparable): pass @@ -1011,26 +1149,33 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """ - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, passive_updates=passive_updates)}) + User, + users, + properties={ + "addresses": relationship( + Address, passive_updates=passive_updates + ) + }, + ) mapper(Address, addresses) sess = create_session() - a1 = Address(username='ed', email='ed@host1') - u1 = User(username='ed', addresses=[a1]) - u2 = User(username='jack') + a1 = Address(username="ed", email="ed@host1") + u1 = User(username="ed", addresses=[a1]) + u2 = User(username="jack") sess.add_all([a1, u1, u2]) sess.flush() - a1.username = 'jack' + a1.username = "jack" sess.flush() def test_o2m_move_passive(self): @@ -1045,21 +1190,28 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """ - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, passive_updates=passive_updates)}) + User, + users, + properties={ + "addresses": relationship( + Address, passive_updates=passive_updates + ) + }, + ) mapper(Address, addresses) sess = create_session() - a1 = Address(username='ed', email='ed@host1') - u1 = User(username='ed', addresses=[a1]) - u2 = User(username='jack') + a1 = Address(username="ed", email="ed@host1") + u1 = User(username="ed", addresses=[a1]) + u2 = User(username="jack") sess.add_all([a1, u1, u2]) sess.flush() @@ -1077,28 +1229,34 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self._test_change_m2o(False) def _test_change_m2o(self, passive_updates): - User, Address, users, addresses = (self.classes.User, - 
self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, passive_updates=passive_updates) - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User, passive_updates=passive_updates) + }, + ) sess = create_session() - u1 = User(username='jack') - a1 = Address(user=u1, email='foo@bar') + u1 = User(username="jack") + a1 = Address(user=u1, email="foo@bar") sess.add_all([u1, a1]) sess.flush() - u1.username = 'edmodified' + u1.username = "edmodified" sess.flush() - eq_(a1.username, 'edmodified') + eq_(a1.username, "edmodified") sess.expire_all() - eq_(a1.username, 'edmodified') + eq_(a1.username, "edmodified") def test_move_m2o_passive(self): self._test_move_m2o(True) @@ -1107,21 +1265,27 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): self._test_move_m2o(False) def _test_move_m2o(self, passive_updates): - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) # tests [ticket:1856] mapper(User, users) mapper( - Address, addresses, properties={ - 'user': relationship(User, passive_updates=passive_updates)}) + Address, + addresses, + properties={ + "user": relationship(User, passive_updates=passive_updates) + }, + ) sess = create_session() - u1 = User(username='jack') - u2 = User(username='ed') - a1 = Address(user=u1, email='foo@bar') + u1 = User(username="jack") + u2 = User(username="ed") + a1 = Address(user=u1, email="foo@bar") sess.add_all([u1, u2, a1]) sess.flush() @@ -1129,19 +1293,23 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess.flush() def test_rowswitch_doesntfire(self): - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, passive_updates=True) - }) + mapper( + Address, + addresses, + properties={"user": relationship(User, passive_updates=True)}, + ) sess = create_session() - u1 = User(username='ed') - a1 = Address(user=u1, email='ed@host1') + u1 = User(username="ed") + a1 = Address(user=u1, email="ed@host1") sess.add(u1) sess.add(a1) @@ -1150,8 +1318,8 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess.delete(u1) sess.delete(a1) - u2 = User(username='ed') - a2 = Address(user=u2, email='ed@host1', etc='foo') + u2 = User(username="ed") + a2 = Address(user=u2, email="ed@host1", etc="foo") sess.add(u2) sess.add(a2) @@ -1160,12 +1328,18 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): # test that the primary key columns of addresses are not # being updated as well, since this is a row switch. 
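# editor's note (added in this rendering, not in the source): a "row
# switch" means the DELETE of u1/a1 plus the INSERT of u2/a2 under the
# same primary keys collapse into in-place updates, so the flush only
# needs to UPDATE the changed non-PK column "etc" on addresses, as
# asserted below.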
self.assert_sql_execution( - testing.db, sess.flush, CompiledSQL( + testing.db, + sess.flush, + CompiledSQL( "UPDATE addresses SET etc=:etc WHERE " "addresses.username = :addresses_username AND" - " addresses.email = :addresses_email", { - 'etc': 'foo', 'addresses_username': 'ed', - 'addresses_email': 'ed@host1'}), + " addresses.email = :addresses_email", + { + "etc": "foo", + "addresses_username": "ed", + "addresses_email": "ed@host1", + }, + ), ) def _test_onetomany(self, passive_updates): @@ -1177,47 +1351,58 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """ - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) + User, Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) mapper( - User, users, properties={ - 'addresses': relationship( - Address, passive_updates=passive_updates)}) + User, + users, + properties={ + "addresses": relationship( + Address, passive_updates=passive_updates + ) + }, + ) mapper(Address, addresses) sess = create_session() - a1, a2 = Address(username='ed', email='ed@host1'), \ - Address(username='ed', email='ed@host2') - u1 = User(username='ed', addresses=[a1, a2]) + a1, a2 = ( + Address(username="ed", email="ed@host1"), + Address(username="ed", email="ed@host2"), + ) + u1 = User(username="ed", addresses=[a1, a2]) sess.add(u1) sess.flush() - eq_(a1.username, 'ed') - eq_(a2.username, 'ed') + eq_(a1.username, "ed") + eq_(a2.username, "ed") eq_( sa.select([addresses.c.username]).execute().fetchall(), - [('ed',), ('ed',)]) + [("ed",), ("ed",)], + ) - u1.username = 'jack' - a2.email = 'ed@host3' + u1.username = "jack" + a2.email = "ed@host3" sess.flush() - eq_(a1.username, 'jack') - eq_(a2.username, 'jack') + eq_(a1.username, "jack") + eq_(a2.username, "jack") eq_( sa.select([addresses.c.username]).execute().fetchall(), - [('jack',), ('jack', )]) + [("jack",), ("jack",)], + ) class JoinedInheritanceTest(fixtures.MappedTest): """Test cascades of pk->pk/fk on joined table inh.""" # mssql doesn't allow ON UPDATE on self-referential keys - __unsupported_on__ = ('mssql',) + __unsupported_on__ = ("mssql",) - __requires__ = 'skip_mysql_on_windows', + __requires__ = ("skip_mysql_on_windows",) __backend__ = True @classmethod @@ -1225,33 +1410,44 @@ class JoinedInheritanceTest(fixtures.MappedTest): fk_args = _backend_specific_fk_args() Table( - 'person', metadata, - Column('name', String(50), primary_key=True), - Column('type', String(50), nullable=False), - test_needs_fk=True) + "person", + metadata, + Column("name", String(50), primary_key=True), + Column("type", String(50), nullable=False), + test_needs_fk=True, + ) Table( - 'engineer', metadata, + "engineer", + metadata, Column( - 'name', String(50), ForeignKey('person.name', **fk_args), - primary_key=True), - Column('primary_language', String(50)), + "name", + String(50), + ForeignKey("person.name", **fk_args), + primary_key=True, + ), + Column("primary_language", String(50)), Column( - 'boss_name', String(50), - ForeignKey('manager.name', **fk_args)), - test_needs_fk=True + "boss_name", String(50), ForeignKey("manager.name", **fk_args) + ), + test_needs_fk=True, ) Table( - 'manager', metadata, Column('name', String(50), - ForeignKey('person.name', **fk_args), - primary_key=True), - Column('paperwork', String(50)), test_needs_fk=True + "manager", + metadata, + Column( + "name", + String(50), + ForeignKey("person.name", **fk_args), + primary_key=True, + 
), + Column("paperwork", String(50)), + test_needs_fk=True, ) @classmethod def setup_classes(cls): - class Person(cls.Comparable): pass @@ -1280,79 +1476,111 @@ class JoinedInheritanceTest(fixtures.MappedTest): def _test_pk(self, passive_updates): Person, Manager, person, manager, Engineer, engineer = ( - self.classes.Person, self.classes.Manager, self.tables.person, - self.tables.manager, self.classes.Engineer, self.tables.engineer) + self.classes.Person, + self.classes.Manager, + self.tables.person, + self.tables.manager, + self.classes.Engineer, + self.tables.engineer, + ) mapper( - Person, person, polymorphic_on=person.c.type, - polymorphic_identity='person', passive_updates=passive_updates) + Person, + person, + polymorphic_on=person.c.type, + polymorphic_identity="person", + passive_updates=passive_updates, + ) mapper( - Engineer, engineer, inherits=Person, - polymorphic_identity='engineer', properties={ - 'boss': relationship( + Engineer, + engineer, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "boss": relationship( Manager, primaryjoin=manager.c.name == engineer.c.boss_name, - passive_updates=passive_updates)}) + passive_updates=passive_updates, + ) + }, + ) mapper( - Manager, manager, inherits=Person, polymorphic_identity='manager') + Manager, manager, inherits=Person, polymorphic_identity="manager" + ) sess = sa.orm.sessionmaker()() - e1 = Engineer(name='dilbert', primary_language='java') + e1 = Engineer(name="dilbert", primary_language="java") sess.add(e1) sess.commit() - e1.name = 'wally' - e1.primary_language = 'c++' + e1.name = "wally" + e1.primary_language = "c++" sess.commit() def _test_fk(self, passive_updates): Person, Manager, person, manager, Engineer, engineer = ( - self.classes.Person, self.classes.Manager, self.tables.person, - self.tables.manager, self.classes.Engineer, self.tables.engineer) + self.classes.Person, + self.classes.Manager, + self.tables.person, + self.tables.manager, + self.classes.Engineer, + self.tables.engineer, + ) mapper( - Person, person, polymorphic_on=person.c.type, - polymorphic_identity='person', passive_updates=passive_updates) + Person, + person, + polymorphic_on=person.c.type, + polymorphic_identity="person", + passive_updates=passive_updates, + ) mapper( - Engineer, engineer, inherits=Person, - polymorphic_identity='engineer', properties={ - 'boss': relationship( + Engineer, + engineer, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "boss": relationship( Manager, primaryjoin=manager.c.name == engineer.c.boss_name, - passive_updates=passive_updates)}) + passive_updates=passive_updates, + ) + }, + ) mapper( - Manager, manager, inherits=Person, polymorphic_identity='manager') + Manager, manager, inherits=Person, polymorphic_identity="manager" + ) sess = sa.orm.sessionmaker()() - m1 = Manager(name='dogbert', paperwork='lots') - e1, e2 = Engineer(name='dilbert', primary_language='java', boss=m1),\ - Engineer(name='wally', primary_language='c++', boss=m1) - sess.add_all([ - e1, e2, m1 - ]) + m1 = Manager(name="dogbert", paperwork="lots") + e1, e2 = ( + Engineer(name="dilbert", primary_language="java", boss=m1), + Engineer(name="wally", primary_language="c++", boss=m1), + ) + sess.add_all([e1, e2, m1]) sess.commit() - eq_(e1.boss_name, 'dogbert') - eq_(e2.boss_name, 'dogbert') + eq_(e1.boss_name, "dogbert") + eq_(e2.boss_name, "dogbert") sess.expire_all() - m1.name = 'pointy haired' - e1.primary_language = 'scala' - e2.primary_language = 'cobol' + m1.name = "pointy haired" + e1.primary_language 
= "scala" + e2.primary_language = "cobol" sess.commit() - eq_(e1.boss_name, 'pointy haired') - eq_(e2.boss_name, 'pointy haired') + eq_(e1.boss_name, "pointy haired") + eq_(e2.boss_name, "pointy haired") class JoinedInheritancePKOnFKTest(fixtures.MappedTest): """Test cascades of pk->non-pk/fk on joined table inh.""" # mssql doesn't allow ON UPDATE on self-referential keys - __unsupported_on__ = ('mssql',) + __unsupported_on__ = ("mssql",) - __requires__ = 'skip_mysql_on_windows', + __requires__ = ("skip_mysql_on_windows",) __backend__ = True @classmethod @@ -1360,26 +1588,28 @@ class JoinedInheritancePKOnFKTest(fixtures.MappedTest): fk_args = _backend_specific_fk_args() Table( - 'person', metadata, - Column('name', String(50), primary_key=True), - Column('type', String(50), nullable=False), - test_needs_fk=True) + "person", + metadata, + Column("name", String(50), primary_key=True), + Column("type", String(50), nullable=False), + test_needs_fk=True, + ) Table( - 'engineer', metadata, + "engineer", + metadata, Column( - 'id', Integer, - primary_key=True, test_needs_autoincrement=True), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column( - 'person_name', String(50), - ForeignKey('person.name', **fk_args)), - Column('primary_language', String(50)), - test_needs_fk=True + "person_name", String(50), ForeignKey("person.name", **fk_args) + ), + Column("primary_language", String(50)), + test_needs_fk=True, ) @classmethod def setup_classes(cls): - class Person(cls.Comparable): pass @@ -1388,27 +1618,37 @@ class JoinedInheritancePKOnFKTest(fixtures.MappedTest): def _test_pk(self, passive_updates): Person, person, Engineer, engineer = ( - self.classes.Person, self.tables.person, - self.classes.Engineer, self.tables.engineer) + self.classes.Person, + self.tables.person, + self.classes.Engineer, + self.tables.engineer, + ) mapper( - Person, person, polymorphic_on=person.c.type, - polymorphic_identity='person', passive_updates=passive_updates) + Person, + person, + polymorphic_on=person.c.type, + polymorphic_identity="person", + passive_updates=passive_updates, + ) mapper( - Engineer, engineer, inherits=Person, - polymorphic_identity='engineer') + Engineer, + engineer, + inherits=Person, + polymorphic_identity="engineer", + ) sess = sa.orm.sessionmaker()() - e1 = Engineer(name='dilbert', primary_language='java') + e1 = Engineer(name="dilbert", primary_language="java") sess.add(e1) sess.commit() - e1.name = 'wally' - e1.primary_language = 'c++' + e1.name = "wally" + e1.primary_language = "c++" sess.flush() - eq_(e1.person_name, 'wally') + eq_(e1.person_name, "wally") sess.expire_all() eq_(e1.primary_language, "c++") diff --git a/test/orm/test_of_type.py b/test/orm/test_of_type.py index c8a042e93a..0e8757ca85 100644 --- a/test/orm/test_of_type.py +++ b/test/orm/test_of_type.py @@ -1,6 +1,14 @@ -from sqlalchemy.orm import Session, aliased, with_polymorphic, \ - contains_eager, joinedload, subqueryload, relationship,\ - subqueryload_all, joinedload_all +from sqlalchemy.orm import ( + Session, + aliased, + with_polymorphic, + contains_eager, + joinedload, + subqueryload, + relationship, + subqueryload_all, + joinedload_all, +) from sqlalchemy import and_ from sqlalchemy import testing, exc as sa_exc from sqlalchemy.testing import fixtures @@ -9,84 +17,111 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.engine import default from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy import Integer, String, ForeignKey -from .inheritance._poly_fixtures 
import (Company, Person, Engineer, Manager, - Boss, Machine, Paperwork, - _PolymorphicFixtureBase, _Polymorphic, - _PolymorphicPolymorphic, - _PolymorphicUnions, _PolymorphicJoins, - _PolymorphicAliasedJoins) +from .inheritance._poly_fixtures import ( + Company, + Person, + Engineer, + Manager, + Boss, + Machine, + Paperwork, + _PolymorphicFixtureBase, + _Polymorphic, + _PolymorphicPolymorphic, + _PolymorphicUnions, + _PolymorphicJoins, + _PolymorphicAliasedJoins, +) from sqlalchemy.testing.assertsql import AllOf, CompiledSQL class _PolymorphicTestBase(object): - __dialect__ = 'default' + __dialect__ = "default" def test_any_one(self): sess = Session() any_ = Company.employees.of_type(Engineer).any( - Engineer.primary_language == 'cobol') + Engineer.primary_language == "cobol" + ) eq_(sess.query(Company).filter(any_).one(), self.c2) def test_any_two(self): sess = Session() calias = aliased(Company) any_ = calias.employees.of_type(Engineer).any( - Engineer.primary_language == 'cobol') + Engineer.primary_language == "cobol" + ) eq_(sess.query(calias).filter(any_).one(), self.c2) def test_any_three(self): sess = Session() - any_ = Company.employees.of_type(Boss).any( - Boss.golf_swing == 'fore') + any_ = Company.employees.of_type(Boss).any(Boss.golf_swing == "fore") eq_(sess.query(Company).filter(any_).one(), self.c1) def test_any_four(self): sess = Session() any_ = Company.employees.of_type(Boss).any( - Manager.manager_name == 'pointy') + Manager.manager_name == "pointy" + ) eq_(sess.query(Company).filter(any_).one(), self.c1) def test_any_five(self): sess = Session() any_ = Company.employees.of_type(Engineer).any( - and_(Engineer.primary_language == 'cobol')) + and_(Engineer.primary_language == "cobol") + ) eq_(sess.query(Company).filter(any_).one(), self.c2) def test_join_to_subclass_one(self): sess = Session() - eq_(sess.query(Company) - .join(Company.employees.of_type(Engineer)) - .filter(Engineer.primary_language == 'java').all(), - [self.c1]) + eq_( + sess.query(Company) + .join(Company.employees.of_type(Engineer)) + .filter(Engineer.primary_language == "java") + .all(), + [self.c1], + ) def test_join_to_subclass_two(self): sess = Session() - eq_(sess.query(Company) - .join(Company.employees.of_type(Engineer), 'machines') - .filter(Machine.name.ilike("%thinkpad%")).all(), - [self.c1]) + eq_( + sess.query(Company) + .join(Company.employees.of_type(Engineer), "machines") + .filter(Machine.name.ilike("%thinkpad%")) + .all(), + [self.c1], + ) def test_join_to_subclass_three(self): sess = Session() - eq_(sess.query(Company, Engineer) - .join(Company.employees.of_type(Engineer)) - .filter(Engineer.primary_language == 'java').count(), - 1) + eq_( + sess.query(Company, Engineer) + .join(Company.employees.of_type(Engineer)) + .filter(Engineer.primary_language == "java") + .count(), + 1, + ) def test_join_to_subclass_four(self): sess = Session() # test [ticket:2093] - eq_(sess.query(Company.company_id, Engineer) - .join(Company.employees.of_type(Engineer)) - .filter(Engineer.primary_language == 'java').count(), - 1) + eq_( + sess.query(Company.company_id, Engineer) + .join(Company.employees.of_type(Engineer)) + .filter(Engineer.primary_language == "java") + .count(), + 1, + ) def test_join_to_subclass_five(self): sess = Session() - eq_(sess.query(Company) - .join(Company.employees.of_type(Engineer)) - .filter(Engineer.primary_language == 'java').count(), - 1) + eq_( + sess.query(Company) + .join(Company.employees.of_type(Engineer)) + .filter(Engineer.primary_language == "java") + .count(), + 1, 
+ ) def test_with_polymorphic_join_compile_one(self): sess = Session() @@ -94,31 +129,32 @@ class _PolymorphicTestBase(object): self.assert_compile( sess.query(Company).join( Company.employees.of_type( - with_polymorphic(Person, [Engineer, Manager], - aliased=True, flat=True) + with_polymorphic( + Person, [Engineer, Manager], aliased=True, flat=True + ) ) ), "SELECT companies.company_id AS companies_company_id, " "companies.name AS companies_name FROM companies " - "JOIN %s" - % ( - self._polymorphic_join_target([Engineer, Manager]) - ) + "JOIN %s" % (self._polymorphic_join_target([Engineer, Manager])), ) def test_with_polymorphic_join_exec_contains_eager_one(self): sess = Session() def go(): - wp = with_polymorphic(Person, [Engineer, Manager], - aliased=True, flat=True) + wp = with_polymorphic( + Person, [Engineer, Manager], aliased=True, flat=True + ) eq_( - sess.query(Company).join( - Company.employees.of_type(wp) - ).order_by(Company.company_id, wp.person_id). - options(contains_eager(Company.employees.of_type(wp))).all(), - [self.c1, self.c2] + sess.query(Company) + .join(Company.employees.of_type(wp)) + .order_by(Company.company_id, wp.person_id) + .options(contains_eager(Company.employees.of_type(wp))) + .all(), + [self.c1, self.c2], ) + self.assert_sql_count(testing.db, go, 1) def test_with_polymorphic_join_exec_contains_eager_two(self): @@ -127,24 +163,28 @@ class _PolymorphicTestBase(object): def go(): wp = with_polymorphic(Person, [Engineer, Manager], aliased=True) eq_( - sess.query(Company).join( - Company.employees.of_type(wp) - ).order_by(Company.company_id, wp.person_id). - options(contains_eager(Company.employees, alias=wp)).all(), - [self.c1, self.c2] + sess.query(Company) + .join(Company.employees.of_type(wp)) + .order_by(Company.company_id, wp.person_id) + .options(contains_eager(Company.employees, alias=wp)) + .all(), + [self.c1, self.c2], ) + self.assert_sql_count(testing.db, go, 1) def test_with_polymorphic_any(self): sess = Session() wp = with_polymorphic(Person, [Engineer], aliased=True) eq_( - sess.query(Company.company_id). - filter( + sess.query(Company.company_id) + .filter( Company.employees.of_type(wp).any( - wp.Engineer.primary_language == 'java') - ).all(), - [(1, )] + wp.Engineer.primary_language == "java" + ) + ) + .all(), + [(1,)], ) def test_subqueryload_implicit_withpoly(self): @@ -152,12 +192,13 @@ class _PolymorphicTestBase(object): def go(): eq_( - sess.query(Company). - filter_by(company_id=1). - options(subqueryload(Company.employees.of_type(Engineer))). - all(), - [self._company_with_emps_fixture()[0]] + sess.query(Company) + .filter_by(company_id=1) + .options(subqueryload(Company.employees.of_type(Engineer))) + .all(), + [self._company_with_emps_fixture()[0]], ) + self.assert_sql_count(testing.db, go, 4) def test_joinedload_implicit_withpoly(self): @@ -165,12 +206,13 @@ class _PolymorphicTestBase(object): def go(): eq_( - sess.query(Company). - filter_by(company_id=1). - options(joinedload(Company.employees.of_type(Engineer))). - all(), - [self._company_with_emps_fixture()[0]] + sess.query(Company) + .filter_by(company_id=1) + .options(joinedload(Company.employees.of_type(Engineer))) + .all(), + [self._company_with_emps_fixture()[0]], ) + self.assert_sql_count(testing.db, go, 3) def test_subqueryload_explicit_withpoly(self): @@ -179,12 +221,13 @@ class _PolymorphicTestBase(object): def go(): target = with_polymorphic(Person, Engineer) eq_( - sess.query(Company). - filter_by(company_id=1). 
- options(subqueryload(Company.employees.of_type(target))). - all(), - [self._company_with_emps_fixture()[0]] + sess.query(Company) + .filter_by(company_id=1) + .options(subqueryload(Company.employees.of_type(target))) + .all(), + [self._company_with_emps_fixture()[0]], ) + self.assert_sql_count(testing.db, go, 4) def test_joinedload_explicit_withpoly(self): @@ -193,12 +236,13 @@ class _PolymorphicTestBase(object): def go(): target = with_polymorphic(Person, Engineer, flat=True) eq_( - sess.query(Company). - filter_by(company_id=1). - options(joinedload(Company.employees.of_type(target))). - all(), - [self._company_with_emps_fixture()[0]] + sess.query(Company) + .filter_by(company_id=1) + .options(joinedload(Company.employees.of_type(target))) + .all(), + [self._company_with_emps_fixture()[0]], ) + self.assert_sql_count(testing.db, go, 3) def test_joinedload_stacked_of_type(self): @@ -206,56 +250,63 @@ class _PolymorphicTestBase(object): def go(): eq_( - sess.query(Company). - filter_by(company_id=1). - options( + sess.query(Company) + .filter_by(company_id=1) + .options( joinedload(Company.employees.of_type(Manager)), - joinedload(Company.employees.of_type(Engineer)) - ).all(), - [self._company_with_emps_fixture()[0]] + joinedload(Company.employees.of_type(Engineer)), + ) + .all(), + [self._company_with_emps_fixture()[0]], ) + self.assert_sql_count(testing.db, go, 2) -class PolymorphicPolymorphicTest(_PolymorphicTestBase, - _PolymorphicPolymorphic): +class PolymorphicPolymorphicTest( + _PolymorphicTestBase, _PolymorphicPolymorphic +): def _polymorphic_join_target(self, cls): from sqlalchemy.orm import class_mapper from sqlalchemy.sql.expression import FromGrouping + m, sel = class_mapper(Person)._with_polymorphic_args(cls) sel = FromGrouping(sel.alias(flat=True)) comp_sel = sel.compile(dialect=default.DefaultDialect()) - return \ - comp_sel.process(sel, asfrom=True).replace("\n", "") + \ - " ON companies.company_id = people_1.company_id" + return ( + comp_sel.process(sel, asfrom=True).replace("\n", "") + + " ON companies.company_id = people_1.company_id" + ) class PolymorphicUnionsTest(_PolymorphicTestBase, _PolymorphicUnions): - def _polymorphic_join_target(self, cls): from sqlalchemy.orm import class_mapper sel = class_mapper(Person)._with_polymorphic_selectable.element comp_sel = sel.compile(dialect=default.DefaultDialect()) - return \ - comp_sel.process(sel, asfrom=True).replace("\n", "") + \ - " AS anon_1 ON companies.company_id = anon_1.company_id" + return ( + comp_sel.process(sel, asfrom=True).replace("\n", "") + + " AS anon_1 ON companies.company_id = anon_1.company_id" + ) -class PolymorphicAliasedJoinsTest(_PolymorphicTestBase, - _PolymorphicAliasedJoins): +class PolymorphicAliasedJoinsTest( + _PolymorphicTestBase, _PolymorphicAliasedJoins +): def _polymorphic_join_target(self, cls): from sqlalchemy.orm import class_mapper sel = class_mapper(Person)._with_polymorphic_selectable.element comp_sel = sel.compile(dialect=default.DefaultDialect()) - return \ - comp_sel.process(sel, asfrom=True).replace("\n", "") + \ - " AS anon_1 ON companies.company_id = anon_1.people_company_id" + return ( + comp_sel.process(sel, asfrom=True).replace("\n", "") + + " AS anon_1 ON companies.company_id = anon_1.people_company_id" + ) class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins): @@ -263,32 +314,38 @@ class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins): from sqlalchemy.orm import class_mapper from sqlalchemy.sql.expression import FromGrouping - sel = 
FromGrouping(class_mapper( - Person)._with_polymorphic_selectable.alias(flat=True)) + sel = FromGrouping( + class_mapper(Person)._with_polymorphic_selectable.alias(flat=True) + ) comp_sel = sel.compile(dialect=default.DefaultDialect()) - return \ - comp_sel.process(sel, asfrom=True).replace("\n", "") + \ - " ON companies.company_id = people_1.company_id" + return ( + comp_sel.process(sel, asfrom=True).replace("\n", "") + + " ON companies.company_id = people_1.company_id" + ) def test_joinedload_explicit_with_unaliased_poly_compile(self): sess = Session() target = with_polymorphic(Person, Engineer) - q = sess.query(Company).\ - filter_by(company_id=1).\ - options(joinedload(Company.employees.of_type(target))) + q = ( + sess.query(Company) + .filter_by(company_id=1) + .options(joinedload(Company.employees.of_type(target))) + ) assert_raises_message( sa_exc.InvalidRequestError, "Detected unaliased columns when generating joined load.", - q._compile_context + q._compile_context, ) def test_joinedload_explicit_with_flataliased_poly_compile(self): sess = Session() target = with_polymorphic(Person, Engineer, flat=True) - q = sess.query(Company).\ - filter_by(company_id=1).\ - options(joinedload(Company.employees.of_type(target))) + q = ( + sess.query(Company) + .filter_by(company_id=1) + .options(joinedload(Company.employees.of_type(target))) + ) self.assert_compile( q, "SELECT companies.company_id AS companies_company_id, " @@ -307,19 +364,20 @@ class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins): "ON people_1.person_id = managers_1.person_id) " "ON companies.company_id = people_1.company_id " "WHERE companies.company_id = :company_id_1 " - "ORDER BY people_1.person_id" + "ORDER BY people_1.person_id", ) -class SubclassRelationshipTest(testing.AssertsCompiledSQL, - fixtures.DeclarativeMappedTest): +class SubclassRelationshipTest( + testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest +): """There's overlap here vs. 
the ones above.""" - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): @@ -328,40 +386,44 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, class Job(ComparableEntity, Base): __tablename__ = "job" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) type = Column(String(10)) - widget_id = Column(ForeignKey('widget.id')) + widget_id = Column(ForeignKey("widget.id")) widget = relationship("Widget") - container_id = Column(Integer, ForeignKey('data_container.id')) + container_id = Column(Integer, ForeignKey("data_container.id")) __mapper_args__ = {"polymorphic_on": type} class SubJob(Job): - __tablename__ = 'subjob' - id = Column(Integer, ForeignKey('job.id'), primary_key=True) + __tablename__ = "subjob" + id = Column(Integer, ForeignKey("job.id"), primary_key=True) attr = Column(String(10)) __mapper_args__ = {"polymorphic_identity": "sub"} class ParentThing(ComparableEntity, Base): - __tablename__ = 'parent' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - container_id = Column(Integer, ForeignKey('data_container.id')) + __tablename__ = "parent" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + container_id = Column(Integer, ForeignKey("data_container.id")) container = relationship("DataContainer") class DataContainer(ComparableEntity, Base): __tablename__ = "data_container" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(10)) jobs = relationship(Job, order_by=Job.id) class Widget(ComparableEntity, Base): __tablename__ = "widget" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(10)) @classmethod @@ -373,27 +435,30 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, @classmethod def _fixture(cls): - ParentThing, DataContainer, SubJob, Widget = \ - cls.classes.ParentThing,\ - cls.classes.DataContainer,\ - cls.classes.SubJob,\ - cls.classes.Widget + ParentThing, DataContainer, SubJob, Widget = ( + cls.classes.ParentThing, + cls.classes.DataContainer, + cls.classes.SubJob, + cls.classes.Widget, + ) return [ ParentThing( - container=DataContainer(name="d1", - jobs=[ - SubJob(attr="s1", - widget=Widget(name='w1')), - SubJob(attr="s2", - widget=Widget(name='w2'))]) + container=DataContainer( + name="d1", + jobs=[ + SubJob(attr="s1", widget=Widget(name="w1")), + SubJob(attr="s2", widget=Widget(name="w2")), + ], + ) ), ParentThing( - container=DataContainer(name="d2", - jobs=[ - SubJob(attr="s3", - widget=Widget(name='w3')), - SubJob(attr="s4", - widget=Widget(name='w4'))]) + container=DataContainer( + name="d2", + jobs=[ + SubJob(attr="s3", widget=Widget(name="w3")), + SubJob(attr="s4", widget=Widget(name="w4")), + ], + ) ), ] @@ -402,61 +467,62 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, return [p.container for p in cls._fixture()] def test_contains_eager_wpoly(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, 
SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob, aliased=True) s = Session(testing.db) - q = s.query(DataContainer).\ - join(DataContainer.jobs.of_type(Job_P)).\ - options(contains_eager(DataContainer.jobs.of_type(Job_P))) + q = ( + s.query(DataContainer) + .join(DataContainer.jobs.of_type(Job_P)) + .options(contains_eager(DataContainer.jobs.of_type(Job_P))) + ) def go(): - eq_( - q.all(), - self._dc_fixture() - ) + eq_(q.all(), self._dc_fixture()) + self.assert_sql_count(testing.db, go, 5) def test_joinedload_wpoly(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob, aliased=True) s = Session(testing.db) - q = s.query(DataContainer).\ - options(joinedload(DataContainer.jobs.of_type(Job_P))) + q = s.query(DataContainer).options( + joinedload(DataContainer.jobs.of_type(Job_P)) + ) def go(): - eq_( - q.all(), - self._dc_fixture() - ) + eq_(q.all(), self._dc_fixture()) + self.assert_sql_count(testing.db, go, 5) def test_joinedload_wsubclass(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) s = Session(testing.db) - q = s.query(DataContainer).\ - options(joinedload(DataContainer.jobs.of_type(SubJob))) + q = s.query(DataContainer).options( + joinedload(DataContainer.jobs.of_type(SubJob)) + ) def go(): - eq_( - q.all(), - self._dc_fixture() - ) + eq_(q.all(), self._dc_fixture()) + self.assert_sql_count(testing.db, go, 5) def test_lazyload(self): @@ -465,10 +531,8 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, q = s.query(DataContainer) def go(): - eq_( - q.all(), - self._dc_fixture() - ) + eq_(q.all(), self._dc_fixture()) + # SELECT data container # SELECT job * 2 container rows # SELECT subjob * 4 rows @@ -476,98 +540,90 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, self.assert_sql_count(testing.db, go, 11) def test_subquery_wsubclass(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) s = Session(testing.db) - q = s.query(DataContainer).\ - options(subqueryload(DataContainer.jobs.of_type(SubJob))) + q = s.query(DataContainer).options( + subqueryload(DataContainer.jobs.of_type(SubJob)) + ) def go(): - eq_( - q.all(), - self._dc_fixture() - ) + eq_(q.all(), self._dc_fixture()) + self.assert_sql_count(testing.db, go, 6) def test_twolevel_subqueryload_wsubclass(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) s = Session(testing.db) - q = s.query(ParentThing).\ - options( + q = s.query(ParentThing).options( 
subqueryload_all( - ParentThing.container, - DataContainer.jobs.of_type(SubJob) - )) + ParentThing.container, DataContainer.jobs.of_type(SubJob) + ) + ) def go(): - eq_( - q.all(), - self._fixture() - ) + eq_(q.all(), self._fixture()) + self.assert_sql_count(testing.db, go, 7) def test_twolevel_subqueryload_wsubclass_mapper_term(self): - DataContainer, SubJob = \ - self.classes.DataContainer,\ - self.classes.SubJob + DataContainer, SubJob = self.classes.DataContainer, self.classes.SubJob s = Session(testing.db) sj_alias = aliased(SubJob) - q = s.query(DataContainer).\ - options( + q = s.query(DataContainer).options( subqueryload_all( - DataContainer.jobs.of_type(sj_alias), - sj_alias.widget - )) + DataContainer.jobs.of_type(sj_alias), sj_alias.widget + ) + ) def go(): - eq_( - q.all(), - self._dc_fixture() - ) + eq_(q.all(), self._dc_fixture()) + self.assert_sql_count(testing.db, go, 3) def test_twolevel_joinedload_wsubclass(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) s = Session(testing.db) - q = s.query(ParentThing).\ - options( + q = s.query(ParentThing).options( joinedload_all( - ParentThing.container, - DataContainer.jobs.of_type(SubJob) - )) + ParentThing.container, DataContainer.jobs.of_type(SubJob) + ) + ) def go(): - eq_( - q.all(), - self._fixture() - ) + eq_(q.all(), self._fixture()) + self.assert_sql_count(testing.db, go, 5) def test_any_wpoly(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob, aliased=True, flat=True) s = Session() - q = s.query(Job).join(DataContainer.jobs).\ - filter( - DataContainer.jobs.of_type(Job_P). - any(Job_P.id < Job.id) + q = ( + s.query(Job) + .join(DataContainer.jobs) + .filter(DataContainer.jobs.of_type(Job_P).any(Job_P.id < Job.id)) ) self.assert_compile( @@ -582,23 +638,28 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, "FROM job AS job_1 LEFT OUTER JOIN subjob AS subjob_1 " "ON job_1.id = subjob_1.id " "WHERE data_container.id = job_1.container_id " - "AND job_1.id < job.id)" + "AND job_1.id < job.id)", ) def test_any_walias(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_A = aliased(Job) s = Session() - q = s.query(Job).join(DataContainer.jobs).\ - filter( - DataContainer.jobs.of_type(Job_A). 
- any(and_(Job_A.id < Job.id, Job_A.type == 'fred')) + q = ( + s.query(Job) + .join(DataContainer.jobs) + .filter( + DataContainer.jobs.of_type(Job_A).any( + and_(Job_A.id < Job.id, Job_A.type == "fred") + ) + ) ) self.assert_compile( q, @@ -610,34 +671,38 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, "WHERE EXISTS (SELECT 1 " "FROM job AS job_1 " "WHERE data_container.id = job_1.container_id " - "AND job_1.id < job.id AND job_1.type = :type_1)" + "AND job_1.id < job.id AND job_1.type = :type_1)", ) def test_join_wpoly(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob) s = Session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(Job_P)) - self.assert_compile(q, - "SELECT data_container.id AS data_container_id, " - "data_container.name AS data_container_name " - "FROM data_container JOIN " - "(job LEFT OUTER JOIN subjob " - "ON job.id = subjob.id) " - "ON data_container.id = job.container_id") + self.assert_compile( + q, + "SELECT data_container.id AS data_container_id, " + "data_container.name AS data_container_name " + "FROM data_container JOIN " + "(job LEFT OUTER JOIN subjob " + "ON job.id = subjob.id) " + "ON data_container.id = job.container_id", + ) def test_join_wsubclass(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) s = Session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(SubJob)) @@ -651,88 +716,101 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, "SELECT data_container.id AS data_container_id, " "data_container.name AS data_container_name " "FROM data_container JOIN (job JOIN subjob ON job.id = subjob.id) " - "ON data_container.id = job.container_id" + "ON data_container.id = job.container_id", ) def test_join_wpoly_innerjoin(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob, innerjoin=True) s = Session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(Job_P)) - self.assert_compile(q, - "SELECT data_container.id AS data_container_id, " - "data_container.name AS data_container_name " - "FROM data_container JOIN " - "(job JOIN subjob ON job.id = subjob.id) " - "ON data_container.id = job.container_id") + self.assert_compile( + q, + "SELECT data_container.id AS data_container_id, " + "data_container.name AS data_container_name " + "FROM data_container JOIN " + "(job JOIN subjob ON job.id = subjob.id) " + "ON data_container.id = job.container_id", + ) def test_join_walias(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_A = 
aliased(Job) s = Session() q = s.query(DataContainer).join(DataContainer.jobs.of_type(Job_A)) - self.assert_compile(q, - "SELECT data_container.id AS data_container_id, " - "data_container.name AS data_container_name " - "FROM data_container JOIN job AS job_1 " - "ON data_container.id = job_1.container_id") + self.assert_compile( + q, + "SELECT data_container.id AS data_container_id, " + "data_container.name AS data_container_name " + "FROM data_container JOIN job AS job_1 " + "ON data_container.id = job_1.container_id", + ) def test_join_explicit_wpoly_noalias(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob) s = Session() q = s.query(DataContainer).join(Job_P, DataContainer.jobs) - self.assert_compile(q, - "SELECT data_container.id AS data_container_id, " - "data_container.name AS data_container_name " - "FROM data_container JOIN " - "(job LEFT OUTER JOIN subjob " - "ON job.id = subjob.id) " - "ON data_container.id = job.container_id") + self.assert_compile( + q, + "SELECT data_container.id AS data_container_id, " + "data_container.name AS data_container_name " + "FROM data_container JOIN " + "(job LEFT OUTER JOIN subjob " + "ON job.id = subjob.id) " + "ON data_container.id = job.container_id", + ) def test_join_explicit_wpoly_flat(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob, flat=True) s = Session() q = s.query(DataContainer).join(Job_P, DataContainer.jobs) - self.assert_compile(q, - "SELECT data_container.id AS data_container_id, " - "data_container.name AS data_container_name " - "FROM data_container JOIN " - "(job AS job_1 LEFT OUTER JOIN subjob AS subjob_1 " - "ON job_1.id = subjob_1.id) " - "ON data_container.id = job_1.container_id") + self.assert_compile( + q, + "SELECT data_container.id AS data_container_id, " + "data_container.name AS data_container_name " + "FROM data_container JOIN " + "(job AS job_1 LEFT OUTER JOIN subjob AS subjob_1 " + "ON job_1.id = subjob_1.id) " + "ON data_container.id = job_1.container_id", + ) def test_join_explicit_wpoly_full_alias(self): - ParentThing, DataContainer, Job, SubJob = \ - self.classes.ParentThing,\ - self.classes.DataContainer,\ - self.classes.Job,\ - self.classes.SubJob + ParentThing, DataContainer, Job, SubJob = ( + self.classes.ParentThing, + self.classes.DataContainer, + self.classes.Job, + self.classes.SubJob, + ) Job_P = with_polymorphic(Job, SubJob, aliased=True) @@ -748,88 +826,89 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, "job.container_id AS job_container_id, " "subjob.id AS subjob_id, subjob.attr AS subjob_attr " "FROM job LEFT OUTER JOIN subjob ON job.id = subjob.id) " - "AS anon_1 ON data_container.id = anon_1.job_container_id" + "AS anon_1 ON data_container.id = anon_1.job_container_id", ) class SubclassRelationshipTest2( - testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest): + testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest +): - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 
'once' + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 't_a' + __tablename__ = "t_a" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) class B(Base): - __tablename__ = 't_b' + __tablename__ = "t_b" type = Column(String(2)) __mapper_args__ = { - 'polymorphic_identity': 'b', - 'polymorphic_on': type + "polymorphic_identity": "b", + "polymorphic_on": type, } - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) # Relationship to A - a_id = Column(Integer, ForeignKey('t_a.id')) - a = relationship('A', backref='bs') + a_id = Column(Integer, ForeignKey("t_a.id")) + a = relationship("A", backref="bs") class B2(B): - __tablename__ = 't_b2' + __tablename__ = "t_b2" - __mapper_args__ = { - 'polymorphic_identity': 'b2', - } + __mapper_args__ = {"polymorphic_identity": "b2"} - id = Column(Integer, ForeignKey('t_b.id'), primary_key=True) + id = Column(Integer, ForeignKey("t_b.id"), primary_key=True) class C(Base): - __tablename__ = 't_c' + __tablename__ = "t_c" type = Column(String(2)) __mapper_args__ = { - 'polymorphic_identity': 'c', - 'polymorphic_on': type + "polymorphic_identity": "c", + "polymorphic_on": type, } - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) # Relationship to B - b_id = Column(Integer, ForeignKey('t_b.id')) - b = relationship('B', backref='cs') + b_id = Column(Integer, ForeignKey("t_b.id")) + b = relationship("B", backref="cs") class C2(C): - __tablename__ = 't_c2' + __tablename__ = "t_c2" - __mapper_args__ = { - 'polymorphic_identity': 'c2', - } + __mapper_args__ = {"polymorphic_identity": "c2"} - id = Column(Integer, ForeignKey('t_c.id'), primary_key=True) + id = Column(Integer, ForeignKey("t_c.id"), primary_key=True) class D(Base): - __tablename__ = 't_d' + __tablename__ = "t_d" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) # Relationship to B - c_id = Column(Integer, ForeignKey('t_c.id')) - c = relationship('C', backref='ds') + c_id = Column(Integer, ForeignKey("t_c.id")) + c = relationship("C", backref="ds") @classmethod def insert_data(cls): @@ -840,40 +919,28 @@ class SubclassRelationshipTest2( @classmethod def _fixture(cls): - A, B, B2, C, C2, D = cls.classes('A', 'B', 'B2', 'C', 'C2', 'D') + A, B, B2, C, C2, D = cls.classes("A", "B", "B2", "C", "C2", "D") - return [ - A(bs=[B2(cs=[C2(ds=[D()])])]), - A(bs=[B2(cs=[C2(ds=[D()])])]), - ] + return [A(bs=[B2(cs=[C2(ds=[D()])])]), A(bs=[B2(cs=[C2(ds=[D()])])])] def test_all_subq_query(self): - A, B, B2, C, C2, D = self.classes('A', 'B', 'B2', 'C', 'C2', 'D') + A, B, B2, C, C2, D = self.classes("A", "B", "B2", "C", "C2", "D") session = Session(testing.db) b_b2 = with_polymorphic(B, [B2], flat=True) c_c2 = with_polymorphic(C, [C2], flat=True) - q = session.query( - A - ).options( - subqueryload( - A.bs.of_type(b_b2) - ).subqueryload( - b_b2.cs.of_type(c_c2) - ).subqueryload( - c_c2.ds - ) + q = session.query(A).options( + subqueryload(A.bs.of_type(b_b2)) + .subqueryload(b_b2.cs.of_type(c_c2)) + .subqueryload(c_c2.ds) ) 
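# The hunks above show this reformat's most visible rule: a long query or
# loader chain that used trailing-backslash continuations becomes a single
# parenthesized expression, split at the dots, one call per line. Below is a
# minimal, self-contained sketch of that rule -- hedged: it assumes a current
# black release whose format_str() accepts a Mode, and the query text is
# illustrative source text only (black never executes it, so the names need
# not resolve):
import black

src = (
    "q = sess.query(Company).filter_by(company_id=1)"
    ".options(joinedload(Company.employees.of_type(target)))"
)
print(black.format_str(src, mode=black.Mode(line_length=79)))
# should print the fluent form seen throughout these hunks:
#     q = (
#         sess.query(Company)
#         .filter_by(company_id=1)
#         .options(joinedload(Company.employees.of_type(target)))
#     )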
self.assert_sql_execution( testing.db, q.all, - CompiledSQL( - "SELECT t_a.id AS t_a_id FROM t_a", - {} - ), + CompiledSQL("SELECT t_a.id AS t_a_id FROM t_a", {}), CompiledSQL( "SELECT t_b_1.type AS t_b_1_type, t_b_1.id AS t_b_1_id, " "t_b_1.a_id AS t_b_1_a_id, t_b2_1.id AS t_b2_1_id, " @@ -882,7 +949,7 @@ class SubclassRelationshipTest2( "JOIN (t_b AS t_b_1 LEFT OUTER JOIN t_b2 AS t_b2_1 " "ON t_b_1.id = t_b2_1.id) ON anon_1.t_a_id = t_b_1.a_id " "ORDER BY anon_1.t_a_id", - {} + {}, ), CompiledSQL( "SELECT t_c_1.type AS t_c_1_type, t_c_1.id AS t_c_1_id, " @@ -893,7 +960,7 @@ class SubclassRelationshipTest2( "JOIN (t_c AS t_c_1 LEFT OUTER JOIN t_c2 AS t_c2_1 ON " "t_c_1.id = t_c2_1.id) ON t_b_1.id = t_c_1.b_id " "ORDER BY t_b_1.id", - {} + {}, ), CompiledSQL( "SELECT t_d.id AS t_d_id, t_d.c_id AS t_d_c_id, " @@ -906,55 +973,56 @@ class SubclassRelationshipTest2( "ON t_c_1.id = t_c2_1.id) " "ON t_b_1.id = t_c_1.b_id " "JOIN t_d ON t_c_1.id = t_d.c_id ORDER BY t_c_1.id", - {} - ) + {}, + ), ) class SubclassRelationshipTest3( - testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest): + testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest +): - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class _A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) type = Column(String(50), nullable=False) - b = relationship('_B', back_populates='a') + b = relationship("_B", back_populates="a") __mapper_args__ = {"polymorphic_on": type} class _B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) type = Column(String(50), nullable=False) a_id = Column(Integer, ForeignKey(_A.id)) - a = relationship(_A, back_populates='b') + a = relationship(_A, back_populates="b") __mapper_args__ = {"polymorphic_on": type} class _C(Base): - __tablename__ = 'c' + __tablename__ = "c" id = Column(Integer, primary_key=True) type = Column(String(50), nullable=False) b_id = Column(Integer, ForeignKey(_B.id)) __mapper_args__ = {"polymorphic_on": type} class A1(_A): - __mapper_args__ = {'polymorphic_identity': 'A1'} + __mapper_args__ = {"polymorphic_identity": "A1"} class B1(_B): - __mapper_args__ = {'polymorphic_identity': 'B1'} + __mapper_args__ = {"polymorphic_identity": "B1"} class C1(_C): - __mapper_args__ = {'polymorphic_identity': 'C1'} - b1 = relationship(B1, backref='c1') + __mapper_args__ = {"polymorphic_identity": "C1"} + b1 = relationship(B1, backref="c1") _query1 = ( "SELECT b.id AS b_id, b.type AS b_type, b.a_id AS b_a_id, " @@ -989,38 +1057,42 @@ class SubclassRelationshipTest3( ) def _test(self, join_of_type, of_type_for_c1, aliased_): - A1, B1, C1 = self.classes('A1', 'B1', 'C1') + A1, B1, C1 = self.classes("A1", "B1", "C1") if aliased_: - A1 = aliased(A1, name='aaa') - B1 = aliased(B1, name='bbb') - C1 = aliased(C1, name='ccc') + A1 = aliased(A1, name="aaa") + B1 = aliased(B1, name="bbb") + C1 = aliased(C1, name="ccc") sess = Session() abc = sess.query(A1) if join_of_type: - abc = abc.outerjoin(A1.b.of_type(B1)).\ - options(contains_eager(A1.b.of_type(B1))) + abc = abc.outerjoin(A1.b.of_type(B1)).options( + contains_eager(A1.b.of_type(B1)) + ) if of_type_for_c1: - abc = abc.outerjoin(B1.c1.of_type(C1)).\ - options( - contains_eager(A1.b.of_type(B1), B1.c1.of_type(C1))) + abc = 
abc.outerjoin(B1.c1.of_type(C1)).options( + contains_eager(A1.b.of_type(B1), B1.c1.of_type(C1)) + ) else: - abc = abc.outerjoin(B1.c1).\ - options(contains_eager(A1.b.of_type(B1), B1.c1)) + abc = abc.outerjoin(B1.c1).options( + contains_eager(A1.b.of_type(B1), B1.c1) + ) else: - abc = abc.outerjoin(B1, A1.b).\ - options(contains_eager(A1.b.of_type(B1))) + abc = abc.outerjoin(B1, A1.b).options( + contains_eager(A1.b.of_type(B1)) + ) if of_type_for_c1: - abc = abc.outerjoin(C1, B1.c1).\ - options( - contains_eager(A1.b.of_type(B1), B1.c1.of_type(C1))) + abc = abc.outerjoin(C1, B1.c1).options( + contains_eager(A1.b.of_type(B1), B1.c1.of_type(C1)) + ) else: - abc = abc.outerjoin(B1.c1).\ - options(contains_eager(A1.b.of_type(B1), B1.c1)) + abc = abc.outerjoin(B1.c1).options( + contains_eager(A1.b.of_type(B1), B1.c1) + ) if aliased_: if of_type_for_c1: @@ -1031,58 +1103,25 @@ class SubclassRelationshipTest3( self.assert_compile(abc, self._query1) def test_join_of_type_contains_eager_of_type_b1_c1(self): - self._test( - join_of_type=True, - of_type_for_c1=True, - aliased_=False - ) + self._test(join_of_type=True, of_type_for_c1=True, aliased_=False) def test_join_flat_contains_eager_of_type_b1_c1(self): - self._test( - join_of_type=False, - of_type_for_c1=True, - aliased_=False - ) + self._test(join_of_type=False, of_type_for_c1=True, aliased_=False) def test_join_of_type_contains_eager_of_type_b1(self): - self._test( - join_of_type=True, - of_type_for_c1=False, - aliased_=False - ) + self._test(join_of_type=True, of_type_for_c1=False, aliased_=False) def test_join_flat_contains_eager_of_type_b1(self): - self._test( - join_of_type=False, - of_type_for_c1=False, - aliased_=False - ) + self._test(join_of_type=False, of_type_for_c1=False, aliased_=False) def test_aliased_join_of_type_contains_eager_of_type_b1_c1(self): - self._test( - join_of_type=True, - of_type_for_c1=True, - aliased_=True - ) + self._test(join_of_type=True, of_type_for_c1=True, aliased_=True) def test_aliased_join_flat_contains_eager_of_type_b1_c1(self): - self._test( - join_of_type=False, - of_type_for_c1=True, - aliased_=True - ) + self._test(join_of_type=False, of_type_for_c1=True, aliased_=True) def test_aliased_join_of_type_contains_eager_of_type_b1(self): - self._test( - join_of_type=True, - of_type_for_c1=False, - aliased_=True - ) + self._test(join_of_type=True, of_type_for_c1=False, aliased_=True) def test_aliased_join_flat_contains_eager_of_type_b1(self): - self._test( - join_of_type=False, - of_type_for_c1=False, - aliased_=True - ) - + self._test(join_of_type=False, of_type_for_c1=False, aliased_=True) diff --git a/test/orm/test_onetoone.py b/test/orm/test_onetoone.py index 732ce15717..8506592c35 100644 --- a/test/orm/test_onetoone.py +++ b/test/orm/test_onetoone.py @@ -9,19 +9,27 @@ from sqlalchemy.testing import fixtures class O2OTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('jack', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('number', String(50)), - Column('status', String(20)), - Column('subroom', String(5))) + Table( + "jack", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("number", String(50)), + Column("status", String(20)), + Column("subroom", String(5)), + ) - Table('port', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30)), - Column('description', String(100)), - Column('jack_id', Integer, 
ForeignKey("jack.id"))) + Table( + "port", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + Column("description", String(100)), + Column("jack_id", Integer, ForeignKey("jack.id")), + ) @classmethod def setup_mappers(cls): @@ -32,21 +40,27 @@ class O2OTest(fixtures.MappedTest): pass def test_basic(self): - Port, port, jack, Jack = (self.classes.Port, - self.tables.port, - self.tables.jack, - self.classes.Jack) + Port, port, jack, Jack = ( + self.classes.Port, + self.tables.port, + self.tables.jack, + self.classes.Jack, + ) mapper(Port, port) - mapper(Jack, jack, - properties=dict( - port=relationship(Port, backref='jack', uselist=False))) + mapper( + Jack, + jack, + properties=dict( + port=relationship(Port, backref="jack", uselist=False) + ), + ) session = create_session() - j = Jack(number='101') + j = Jack(number="101") session.add(j) - p = Port(name='fa0/1') + p = Port(name="fa0/1") session.add(p) j.port = p diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 4e6f2f91fd..f057890d89 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1,8 +1,25 @@ from sqlalchemy import inspect -from sqlalchemy.orm import attributes, mapper, relationship, backref, \ - configure_mappers, create_session, synonym, Session, class_mapper, \ - aliased, column_property, joinedload_all, joinedload, Query,\ - util as orm_util, Load, defer, defaultload, lazyload +from sqlalchemy.orm import ( + attributes, + mapper, + relationship, + backref, + configure_mappers, + create_session, + synonym, + Session, + class_mapper, + aliased, + column_property, + joinedload_all, + joinedload, + Query, + util as orm_util, + Load, + defer, + defaultload, + lazyload, +) from sqlalchemy.orm.query import QueryContext from sqlalchemy.orm import strategy_options import sqlalchemy as sa @@ -13,9 +30,10 @@ from sqlalchemy import Column, Integer, String, ForeignKey from sqlalchemy.orm import subqueryload from sqlalchemy.testing import fixtures + class QueryTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod @@ -26,12 +44,16 @@ class QueryTest(_fixtures.FixtureTest): pass mapper( - SubItem, None, inherits=cls.classes.Item, + SubItem, + None, + inherits=cls.classes.Item, properties={ "extra_keywords": relationship( - cls.classes.Keyword, viewonly=True, - secondary=cls.tables.item_keywords) - } + cls.classes.Keyword, + viewonly=True, + secondary=cls.tables.item_keywords, + ) + }, ) @@ -59,7 +81,10 @@ class PathTest(object): for val in opt._to_bind: val._bind_loader( [ent.entity_zero for ent in q._mapper_entities], - q._current_path, attr, False) + q._current_path, + attr, + False, + ) else: opt._process(q, True) attr = q._attributes @@ -67,19 +92,18 @@ class PathTest(object): assert_paths = [k[1] for k in attr] eq_( set([p for p in assert_paths]), - set([self._make_path(p) for p in paths]) + set([self._make_path(p) for p in paths]), ) class LoadTest(PathTest, QueryTest): - def test_str(self): User = self.classes.User result = Load(User) - result.strategy = (('deferred', False), ('instrument', True)) + result.strategy = (("deferred", False), ("instrument", True)) eq_( str(result), - "Load(strategy=(('deferred', False), ('instrument', True)))" + "Load(strategy=(('deferred', False), ('instrument', True)))", ) def test_gen_path_attr_entity(self): @@ -88,9 +112,10 @@ class LoadTest(PathTest, QueryTest): result = Load(User) 
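# A second rule visible in the eq_() hunks around here: once black explodes a
# call across lines, every argument gets its own line and the last argument
# gains a trailing comma. A minimal sketch of that behavior -- the names are
# hypothetical and a current black release is assumed:
import black

src = (
    "foo(a_long_argument_name, another_long_argument_name,"
    " yet_another_long_one, and_one_more_argument)"
)
print(black.format_str(src, mode=black.Mode(line_length=79)))
# should print:
#     foo(
#         a_long_argument_name,
#         another_long_argument_name,
#         yet_another_long_one,
#         and_one_more_argument,
#     )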
eq_( - result._generate_path(inspect(User)._path_registry, - User.addresses, "relationship"), - self._make_path_registry([User, "addresses", Address]) + result._generate_path( + inspect(User)._path_registry, User.addresses, "relationship" + ), + self._make_path_registry([User, "addresses", Address]), ) def test_gen_path_attr_column(self): @@ -98,9 +123,10 @@ class LoadTest(PathTest, QueryTest): result = Load(User) eq_( - result._generate_path(inspect(User)._path_registry, - User.name, "column"), - self._make_path_registry([User, "name"]) + result._generate_path( + inspect(User)._path_registry, User.name, "column" + ), + self._make_path_registry([User, "name"]), ) def test_gen_path_string_entity(self): @@ -109,9 +135,10 @@ class LoadTest(PathTest, QueryTest): result = Load(User) eq_( - result._generate_path(inspect(User)._path_registry, - "addresses", "relationship"), - self._make_path_registry([User, "addresses", Address]) + result._generate_path( + inspect(User)._path_registry, "addresses", "relationship" + ), + self._make_path_registry([User, "addresses", Address]), ) def test_gen_path_string_column(self): @@ -120,8 +147,9 @@ class LoadTest(PathTest, QueryTest): result = Load(User) eq_( result._generate_path( - inspect(User)._path_registry, "name", "column"), - self._make_path_registry([User, "name"]) + inspect(User)._path_registry, "name", "column" + ), + self._make_path_registry([User, "name"]), ) def test_gen_path_invalid_from_col(self): @@ -133,8 +161,10 @@ class LoadTest(PathTest, QueryTest): sa.exc.ArgumentError, "Attribute 'name' of entity 'Mapper|User|users' does " "not refer to a mapped entity", - result._generate_path, result.path, User.addresses, "relationship" - + result._generate_path, + result.path, + User.addresses, + "relationship", ) def test_gen_path_attr_entity_invalid_raiseerr(self): @@ -148,7 +178,9 @@ class LoadTest(PathTest, QueryTest): "Attribute 'Order.items' does not link from element " "'Mapper|User|users'", result._generate_path, - inspect(User)._path_registry, Order.items, "relationship", + inspect(User)._path_registry, + Order.items, + "relationship", ) def test_gen_path_attr_entity_invalid_noraiseerr(self): @@ -157,21 +189,22 @@ class LoadTest(PathTest, QueryTest): result = Load(User) - eq_(result._generate_path(inspect(User)._path_registry, Order.items, - "relationship", False), - None) + eq_( + result._generate_path( + inspect(User)._path_registry, + Order.items, + "relationship", + False, + ), + None, + ) def test_set_strat_ent(self): User = self.classes.User l1 = Load(User) l2 = l1.joinedload("addresses") - eq_( - l1.context, - { - ('loader', self._make_path([User, "addresses"])): l2 - } - ) + eq_(l1.context, {("loader", self._make_path([User, "addresses"])): l2}) def test_set_strat_col(self): User = self.classes.User @@ -179,12 +212,7 @@ class LoadTest(PathTest, QueryTest): l1 = Load(User) l2 = l1.defer("name") l3 = list(l2.context.values())[0] - eq_( - l1.context, - { - ('loader', self._make_path([User, "name"])): l3 - } - ) + eq_(l1.context, {("loader", self._make_path([User, "name"])): l3}) class OfTypePathingTest(PathTest, QueryTest): @@ -196,123 +224,141 @@ class OfTypePathingTest(PathTest, QueryTest): class SubAddr(Address): pass - mapper(SubAddr, inherits=Address, properties={ - "sub_attr": column_property(address_table.c.email_address), - "dings": relationship(Dingaling) - }) + mapper( + SubAddr, + inherits=Address, + properties={ + "sub_attr": column_property(address_table.c.email_address), + "dings": relationship(Dingaling), + }, + ) return 
User, Address, SubAddr def test_oftype_only_col_attr_unbound(self): User, Address, SubAddr = self._fixture() - l1 = defaultload( - User.addresses.of_type(SubAddr)).defer(SubAddr.sub_attr) + l1 = defaultload(User.addresses.of_type(SubAddr)).defer( + SubAddr.sub_attr + ) sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'sub_attr')] + l1, + q, + [(User, "addresses"), (User, "addresses", SubAddr, "sub_attr")], ) def test_oftype_only_col_attr_bound(self): User, Address, SubAddr = self._fixture() - l1 = Load(User).defaultload( - User.addresses.of_type(SubAddr)).defer(SubAddr.sub_attr) + l1 = ( + Load(User) + .defaultload(User.addresses.of_type(SubAddr)) + .defer(SubAddr.sub_attr) + ) sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'sub_attr')] + l1, + q, + [(User, "addresses"), (User, "addresses", SubAddr, "sub_attr")], ) def test_oftype_only_col_attr_string_unbound(self): User, Address, SubAddr = self._fixture() - l1 = defaultload( - User.addresses.of_type(SubAddr)).defer("sub_attr") + l1 = defaultload(User.addresses.of_type(SubAddr)).defer("sub_attr") sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'sub_attr')] + l1, + q, + [(User, "addresses"), (User, "addresses", SubAddr, "sub_attr")], ) def test_oftype_only_col_attr_string_bound(self): User, Address, SubAddr = self._fixture() - l1 = Load(User).defaultload( - User.addresses.of_type(SubAddr)).defer("sub_attr") + l1 = ( + Load(User) + .defaultload(User.addresses.of_type(SubAddr)) + .defer("sub_attr") + ) sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'sub_attr')] + l1, + q, + [(User, "addresses"), (User, "addresses", SubAddr, "sub_attr")], ) def test_oftype_only_rel_attr_unbound(self): User, Address, SubAddr = self._fixture() - l1 = defaultload( - User.addresses.of_type(SubAddr)).joinedload(SubAddr.dings) + l1 = defaultload(User.addresses.of_type(SubAddr)).joinedload( + SubAddr.dings + ) sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'dings')] + l1, q, [(User, "addresses"), (User, "addresses", SubAddr, "dings")] ) def test_oftype_only_rel_attr_bound(self): User, Address, SubAddr = self._fixture() - l1 = Load(User).defaultload( - User.addresses.of_type(SubAddr)).joinedload(SubAddr.dings) + l1 = ( + Load(User) + .defaultload(User.addresses.of_type(SubAddr)) + .joinedload(SubAddr.dings) + ) sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'dings')] + l1, q, [(User, "addresses"), (User, "addresses", SubAddr, "dings")] ) def test_oftype_only_rel_attr_string_unbound(self): User, Address, SubAddr = self._fixture() - l1 = defaultload( - User.addresses.of_type(SubAddr)).joinedload("dings") + l1 = defaultload(User.addresses.of_type(SubAddr)).joinedload("dings") sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'dings')] + l1, q, [(User, "addresses"), (User, "addresses", SubAddr, "dings")] ) def test_oftype_only_rel_attr_string_bound(self): User, Address, SubAddr = self._fixture() - l1 = Load(User).defaultload( - User.addresses.of_type(SubAddr)).defer("sub_attr") + l1 = ( + Load(User) + 
.defaultload(User.addresses.of_type(SubAddr)) + .defer("sub_attr") + ) sess = Session() q = sess.query(User) self._assert_path_result( - l1, q, - [(User, 'addresses'), (User, 'addresses', SubAddr, 'sub_attr')] + l1, + q, + [(User, "addresses"), (User, "addresses", SubAddr, "sub_attr")], ) class OptionsTest(PathTest, QueryTest): - def _option_fixture(self, *arg): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.joinedload, arg, True, {}) + strategy_options._UnboundLoad.joinedload, arg, True, {} + ) def test_get_path_one_level_string(self): User = self.classes.User @@ -321,7 +367,7 @@ class OptionsTest(PathTest, QueryTest): q = sess.query(User) opt = self._option_fixture("addresses") - self._assert_path_result(opt, q, [(User, 'addresses')]) + self._assert_path_result(opt, q, [(User, "addresses")]) def test_get_path_one_level_attribute(self): User = self.classes.User @@ -330,7 +376,7 @@ class OptionsTest(PathTest, QueryTest): q = sess.query(User) opt = self._option_fixture(User.addresses) - self._assert_path_result(opt, q, [(User, 'addresses')]) + self._assert_path_result(opt, q, [(User, "addresses")]) def test_path_on_entity_but_doesnt_match_currentpath(self): User, Address = self.classes.User, self.classes.Address @@ -340,10 +386,11 @@ class OptionsTest(PathTest, QueryTest): # see [ticket:2098] sess = Session() q = sess.query(User) - opt = self._option_fixture('email_address', 'id') + opt = self._option_fixture("email_address", "id") q = sess.query(Address)._with_current_path( - orm_util.PathRegistry.coerce([inspect(User), - inspect(User).attrs.addresses]) + orm_util.PathRegistry.coerce( + [inspect(User), inspect(User).attrs.addresses] + ) ) self._assert_path_result(opt, q, []) @@ -356,73 +403,87 @@ class OptionsTest(PathTest, QueryTest): self._assert_path_result(opt, q, []) def test_path_multilevel_string(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(User) opt = self._option_fixture("orders.items.keywords") - self._assert_path_result(opt, q, [ - (User, 'orders'), - (User, 'orders', Order, 'items'), - (User, 'orders', Order, 'items', Item, 'keywords') - ]) + self._assert_path_result( + opt, + q, + [ + (User, "orders"), + (User, "orders", Order, "items"), + (User, "orders", Order, "items", Item, "keywords"), + ], + ) def test_path_multilevel_attribute(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(User) opt = self._option_fixture(User.orders, Order.items, Item.keywords) - self._assert_path_result(opt, q, [ - (User, 'orders'), - (User, 'orders', Order, 'items'), - (User, 'orders', Order, 'items', Item, 'keywords') - ]) + self._assert_path_result( + opt, + q, + [ + (User, "orders"), + (User, "orders", Order, "items"), + (User, "orders", Order, "items", Item, "keywords"), + ], + ) def test_with_current_matching_string(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(Item)._with_current_path( - self._make_path_registry([User, 'orders', Order, 'items']) + self._make_path_registry([User, "orders", Order, "items"]) ) opt = self._option_fixture("orders.items.keywords") - 
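# The one-line changes that dominate these hunks ('addresses' -> "addresses",
# 'sub_attr' -> "sub_attr") are black's string normalization: quotes become
# double quotes unless the switch would add escapes. A minimal sketch,
# assuming a current black release; fixture() is a stand-in name, not part of
# this test suite:
import black

src = "opt = fixture('orders.items.keywords')"
print(black.format_str(src, mode=black.Mode()))
# should print: opt = fixture("orders.items.keywords")

src = "s = 'a \"quoted\" word'"
print(black.format_str(src, mode=black.Mode()))
# should print: s = 'a "quoted" word'  (single quotes kept to avoid escapes)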
self._assert_path_result(opt, q, [ - (Item, 'keywords') - ]) + self._assert_path_result(opt, q, [(Item, "keywords")]) def test_with_current_matching_attribute(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(Item)._with_current_path( - self._make_path_registry([User, 'orders', Order, 'items']) + self._make_path_registry([User, "orders", Order, "items"]) ) opt = self._option_fixture(User.orders, Order.items, Item.keywords) - self._assert_path_result(opt, q, [ - (Item, 'keywords') - ]) + self._assert_path_result(opt, q, [(Item, "keywords")]) def test_with_current_nonmatching_string(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(Item)._with_current_path( - self._make_path_registry([User, 'orders', Order, 'items']) + self._make_path_registry([User, "orders", Order, "items"]) ) opt = self._option_fixture("keywords") @@ -432,13 +493,15 @@ class OptionsTest(PathTest, QueryTest): self._assert_path_result(opt, q, []) def test_with_current_nonmatching_attribute(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(Item)._with_current_path( - self._make_path_registry([User, 'orders', Order, 'items']) + self._make_path_registry([User, "orders", Order, "items"]) ) opt = self._option_fixture(Item.keywords) @@ -448,14 +511,17 @@ class OptionsTest(PathTest, QueryTest): self._assert_path_result(opt, q, []) def test_with_current_nonmatching_entity(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) sess = Session() q = sess.query(Item)._with_current_path( self._make_path_registry( - [inspect(aliased(User)), 'orders', Order, 'items']) + [inspect(aliased(User)), "orders", Order, "items"] + ) ) opt = self._option_fixture(User.orders) @@ -465,8 +531,7 @@ class OptionsTest(PathTest, QueryTest): self._assert_path_result(opt, q, []) q = sess.query(Item)._with_current_path( - self._make_path_registry( - [User, 'orders', Order, 'items']) + self._make_path_registry([User, "orders", Order, "items"]) ) ac = aliased(User) @@ -478,15 +543,16 @@ class OptionsTest(PathTest, QueryTest): self._assert_path_result(opt, q, []) def test_with_current_match_aliased_classes(self): - Item, User, Order = (self.classes.Item, - self.classes.User, - self.classes.Order) + Item, User, Order = ( + self.classes.Item, + self.classes.User, + self.classes.Order, + ) ac = aliased(User) sess = Session() q = sess.query(Item)._with_current_path( - self._make_path_registry( - [inspect(ac), 'orders', Order, 'items']) + self._make_path_registry([inspect(ac), "orders", Order, "items"]) ) opt = self._option_fixture(ac.orders, Order.items, Item.keywords) @@ -502,14 +568,17 @@ class OptionsTest(PathTest, QueryTest): class SubAddr(Address): pass - mapper(SubAddr, inherits=Address, properties={ - 'flub': relationship(Dingaling) - }) + + mapper( + SubAddr, + inherits=Address, + properties={"flub": relationship(Dingaling)}, + ) q = sess.query(Address) opt = self._option_fixture(SubAddr.flub) - self._assert_path_result(opt, q, [(SubAddr, 
'flub')]) + self._assert_path_result(opt, q, [(SubAddr, "flub")]) def test_from_subclass_to_subclass_attr(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address @@ -518,14 +587,17 @@ class OptionsTest(PathTest, QueryTest): class SubAddr(Address): pass - mapper(SubAddr, inherits=Address, properties={ - 'flub': relationship(Dingaling) - }) + + mapper( + SubAddr, + inherits=Address, + properties={"flub": relationship(Dingaling)}, + ) q = sess.query(SubAddr) opt = self._option_fixture(SubAddr.flub) - self._assert_path_result(opt, q, [(SubAddr, 'flub')]) + self._assert_path_result(opt, q, [(SubAddr, "flub")]) def test_from_base_to_base_attr_via_subclass(self): Dingaling, Address = self.classes.Dingaling, self.classes.Address @@ -534,15 +606,19 @@ class OptionsTest(PathTest, QueryTest): class SubAddr(Address): pass - mapper(SubAddr, inherits=Address, properties={ - 'flub': relationship(Dingaling) - }) + + mapper( + SubAddr, + inherits=Address, + properties={"flub": relationship(Dingaling)}, + ) q = sess.query(Address) opt = self._option_fixture(SubAddr.user) - self._assert_path_result(opt, q, - [(Address, inspect(Address).attrs.user)]) + self._assert_path_result( + opt, q, [(Address, inspect(Address).attrs.user)] + ) def test_of_type(self): User, Address = self.classes.User, self.classes.Address @@ -551,18 +627,29 @@ class OptionsTest(PathTest, QueryTest): class SubAddr(Address): pass + mapper(SubAddr, inherits=Address) q = sess.query(User) opt = self._option_fixture( - User.addresses.of_type(SubAddr), SubAddr.user) + User.addresses.of_type(SubAddr), SubAddr.user + ) u_mapper = inspect(User) a_mapper = inspect(Address) - self._assert_path_result(opt, q, [ - (u_mapper, u_mapper.attrs.addresses), - (u_mapper, u_mapper.attrs.addresses, a_mapper, a_mapper.attrs.user) - ]) + self._assert_path_result( + opt, + q, + [ + (u_mapper, u_mapper.attrs.addresses), + ( + u_mapper, + u_mapper.attrs.addresses, + a_mapper, + a_mapper.attrs.user, + ), + ], + ) def test_of_type_string_attr(self): User, Address = self.classes.User, self.classes.Address @@ -571,43 +658,66 @@ class OptionsTest(PathTest, QueryTest): class SubAddr(Address): pass + mapper(SubAddr, inherits=Address) q = sess.query(User) - opt = self._option_fixture( - User.addresses.of_type(SubAddr), "user") + opt = self._option_fixture(User.addresses.of_type(SubAddr), "user") u_mapper = inspect(User) a_mapper = inspect(Address) - self._assert_path_result(opt, q, [ - (u_mapper, u_mapper.attrs.addresses), - (u_mapper, u_mapper.attrs.addresses, a_mapper, a_mapper.attrs.user) - ]) + self._assert_path_result( + opt, + q, + [ + (u_mapper, u_mapper.attrs.addresses), + ( + u_mapper, + u_mapper.attrs.addresses, + a_mapper, + a_mapper.attrs.user, + ), + ], + ) def test_of_type_plus_level(self): - Dingaling, User, Address = (self.classes.Dingaling, - self.classes.User, - self.classes.Address) + Dingaling, User, Address = ( + self.classes.Dingaling, + self.classes.User, + self.classes.Address, + ) sess = Session() class SubAddr(Address): pass - mapper(SubAddr, inherits=Address, properties={ - 'flub': relationship(Dingaling) - }) + + mapper( + SubAddr, + inherits=Address, + properties={"flub": relationship(Dingaling)}, + ) q = sess.query(User) opt = self._option_fixture( - User.addresses.of_type(SubAddr), SubAddr.flub) + User.addresses.of_type(SubAddr), SubAddr.flub + ) u_mapper = inspect(User) sa_mapper = inspect(SubAddr) - self._assert_path_result(opt, q, [ - (u_mapper, u_mapper.attrs.addresses), - (u_mapper, u_mapper.attrs.addresses, sa_mapper, - 
sa_mapper.attrs.flub) - ]) + self._assert_path_result( + opt, + q, + [ + (u_mapper, u_mapper.attrs.addresses), + ( + u_mapper, + u_mapper.attrs.addresses, + sa_mapper, + sa_mapper.attrs.flub, + ), + ], + ) def test_aliased_single(self): User = self.classes.User @@ -616,7 +726,7 @@ class OptionsTest(PathTest, QueryTest): ualias = aliased(User) q = sess.query(ualias) opt = self._option_fixture(ualias.addresses) - self._assert_path_result(opt, q, [(inspect(ualias), 'addresses')]) + self._assert_path_result(opt, q, [(inspect(ualias), "addresses")]) def test_with_current_aliased_single(self): User, Address = self.classes.User, self.classes.Address @@ -624,10 +734,10 @@ class OptionsTest(PathTest, QueryTest): sess = Session() ualias = aliased(User) q = sess.query(ualias)._with_current_path( - self._make_path_registry([Address, 'user']) + self._make_path_registry([Address, "user"]) ) opt = self._option_fixture(Address.user, ualias.addresses) - self._assert_path_result(opt, q, [(inspect(ualias), 'addresses')]) + self._assert_path_result(opt, q, [(inspect(ualias), "addresses")]) def test_with_current_aliased_single_nonmatching_option(self): User, Address = self.classes.User, self.classes.Address @@ -635,7 +745,7 @@ class OptionsTest(PathTest, QueryTest): sess = Session() ualias = aliased(User) q = sess.query(User)._with_current_path( - self._make_path_registry([Address, 'user']) + self._make_path_registry([Address, "user"]) ) opt = self._option_fixture(Address.user, ualias.addresses) self._assert_path_result(opt, q, []) @@ -646,7 +756,7 @@ class OptionsTest(PathTest, QueryTest): sess = Session() ualias = aliased(User) q = sess.query(ualias)._with_current_path( - self._make_path_registry([Address, 'user']) + self._make_path_registry([Address, "user"]) ) opt = self._option_fixture(Address.user, User.addresses) self._assert_path_result(opt, q, []) @@ -682,7 +792,7 @@ class OptionsTest(PathTest, QueryTest): opt = self._option_fixture(User.orders) sess = Session() q = sess.query(Item)._with_current_path( - self._make_path_registry([User, 'orders', Order, 'items']) + self._make_path_registry([User, "orders", Order, "items"]) ) self._assert_path_result(opt, q, []) @@ -693,10 +803,9 @@ class OptionsTest(PathTest, QueryTest): sess = Session() q = sess.query(User) opt = self._option_fixture(User.orders).joinedload("items") - self._assert_path_result(opt, q, [ - (User, 'orders'), - (User, 'orders', Order, "items") - ]) + self._assert_path_result( + opt, q, [(User, "orders"), (User, "orders", Order, "items")] + ) def test_chained_plus_dotted(self): User = self.classes.User @@ -705,11 +814,15 @@ class OptionsTest(PathTest, QueryTest): sess = Session() q = sess.query(User) opt = self._option_fixture("orders.items").joinedload("keywords") - self._assert_path_result(opt, q, [ - (User, 'orders'), - (User, 'orders', Order, "items"), - (User, 'orders', Order, "items", Item, "keywords") - ]) + self._assert_path_result( + opt, + q, + [ + (User, "orders"), + (User, "orders", Order, "items"), + (User, "orders", Order, "items", Item, "keywords"), + ], + ) def test_chained_plus_multi(self): User = self.classes.User @@ -717,19 +830,24 @@ class OptionsTest(PathTest, QueryTest): Item = self.classes.Item sess = Session() q = sess.query(User) - opt = self._option_fixture( - User.orders, Order.items).joinedload("keywords") - self._assert_path_result(opt, q, [ - (User, 'orders'), - (User, 'orders', Order, "items"), - (User, 'orders', Order, "items", Item, "keywords") - ]) + opt = self._option_fixture(User.orders, 
Order.items).joinedload( + "keywords" + ) + self._assert_path_result( + opt, + q, + [ + (User, "orders"), + (User, "orders", Order, "items"), + (User, "orders", Order, "items", Item, "keywords"), + ], + ) class FromSubclassOptionsTest(PathTest, fixtures.DeclarativeMappedTest): # test for regression to #3963 - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod @@ -737,36 +855,36 @@ class FromSubclassOptionsTest(PathTest, fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class BaseCls(Base): - __tablename__ = 'basecls' + __tablename__ = "basecls" id = Column(Integer, primary_key=True) type = Column(String(30)) - related_id = Column(ForeignKey('related.id')) + related_id = Column(ForeignKey("related.id")) related = relationship("Related") class SubClass(BaseCls): - __tablename__ = 'subcls' - id = Column(ForeignKey('basecls.id'), primary_key=True) + __tablename__ = "subcls" + id = Column(ForeignKey("basecls.id"), primary_key=True) class Related(Base): - __tablename__ = 'related' + __tablename__ = "related" id = Column(Integer, primary_key=True) - sub_related_id = Column(ForeignKey('sub_related.id')) - sub_related = relationship('SubRelated') + sub_related_id = Column(ForeignKey("sub_related.id")) + sub_related = relationship("SubRelated") class SubRelated(Base): - __tablename__ = 'sub_related' + __tablename__ = "sub_related" id = Column(Integer, primary_key=True) def test_with_current_nonmatching_entity_subclasses(self): BaseCls, SubClass, Related, SubRelated = self.classes( - 'BaseCls', 'SubClass', 'Related', 'SubRelated') + "BaseCls", "SubClass", "Related", "SubRelated" + ) sess = Session() q = sess.query(Related)._with_current_path( - self._make_path_registry( - [inspect(SubClass), 'related']) + self._make_path_registry([inspect(SubClass), "related"]) ) opt = subqueryload(SubClass.related).subqueryload(Related.sub_related) @@ -787,7 +905,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): def test_option_with_mapper_basestring(self): Item = self.classes.Item - self._assert_option([Item], 'keywords') + self._assert_option([Item], "keywords") def test_option_with_mapper_PropCompatator(self): Item = self.classes.Item @@ -797,7 +915,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): def test_option_with_mapper_then_column_basestring(self): Item = self.classes.Item - self._assert_option([Item, Item.id], 'keywords') + self._assert_option([Item, Item.id], "keywords") def test_option_with_mapper_then_column_PropComparator(self): Item = self.classes.Item @@ -807,7 +925,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): def test_option_with_column_then_mapper_basestring(self): Item = self.classes.Item - self._assert_option([Item.id, Item], 'keywords') + self._assert_option([Item.id, Item], "keywords") def test_option_with_column_then_mapper_PropComparator(self): Item = self.classes.Item @@ -817,11 +935,13 @@ class OptionsNoPropTest(_fixtures.FixtureTest): def test_option_with_column_basestring(self): Item = self.classes.Item - message = \ - "Query has only expression-based entities - "\ + message = ( + "Query has only expression-based entities - " "can't find property named 'keywords'." 
- self._assert_eager_with_just_column_exception(Item.id, - 'keywords', message) + ) + self._assert_eager_with_just_column_exception( + Item.id, "keywords", message + ) def test_option_with_column_PropComparator(self): Item = self.classes.Item @@ -830,7 +950,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Item.id, Item.keywords, "Query has only expression-based entities " - "- can't find property named 'keywords'." + "- can't find property named 'keywords'.", ) def test_option_against_nonexistent_PropComparator(self): @@ -838,73 +958,75 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword], - (joinedload(Item.keywords), ), + (joinedload(Item.keywords),), r"Can't find property 'keywords' on any entity specified " r"in this Query. Note the full path from root " r"\(Mapper\|Keyword\|keywords\) to target entity must be " - r"specified." + r"specified.", ) def test_option_against_nonexistent_basestring(self): Item = self.classes.Item self._assert_eager_with_entity_exception( [Item], - (joinedload("foo"), ), + (joinedload("foo"),), r"Can't find property named 'foo' on the mapped " - r"entity Mapper\|Item\|items in this Query." + r"entity Mapper\|Item\|items in this Query.", ) def test_option_against_nonexistent_twolevel_basestring(self): Item = self.classes.Item self._assert_eager_with_entity_exception( [Item], - (joinedload("keywords.foo"), ), + (joinedload("keywords.foo"),), r"Can't find property named 'foo' on the mapped entity " - r"Mapper\|Keyword\|keywords in this Query." + r"Mapper\|Keyword\|keywords in this Query.", ) def test_option_against_nonexistent_twolevel_all(self): Item = self.classes.Item self._assert_eager_with_entity_exception( [Item], - (joinedload_all("keywords.foo"), ), + (joinedload_all("keywords.foo"),), r"Can't find property named 'foo' on the mapped entity " - r"Mapper\|Keyword\|keywords in this Query." 
+ r"Mapper\|Keyword\|keywords in this Query.", ) @testing.fails_if( lambda: True, - "PropertyOption doesn't yet check for relation/column on end result") + "PropertyOption doesn't yet check for relation/column on end result", + ) def test_option_against_non_relation_basestring(self): Item = self.classes.Item Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword, Item], - (joinedload_all("keywords"), ), + (joinedload_all("keywords"),), r"Attribute 'keywords' of entity 'Mapper\|Keyword\|keywords' " - "does not refer to a mapped entity" + "does not refer to a mapped entity", ) @testing.fails_if( lambda: True, - "PropertyOption doesn't yet check for relation/column on end result") + "PropertyOption doesn't yet check for relation/column on end result", + ) def test_option_against_multi_non_relation_basestring(self): Item = self.classes.Item Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword, Item], - (joinedload_all("keywords"), ), + (joinedload_all("keywords"),), r"Attribute 'keywords' of entity 'Mapper\|Keyword\|keywords' " - "does not refer to a mapped entity" + "does not refer to a mapped entity", ) def test_option_against_wrong_entity_type_basestring(self): Item = self.classes.Item self._assert_eager_with_entity_exception( [Item], - (joinedload_all("id", "keywords"), ), + (joinedload_all("id", "keywords"),), r"Attribute 'id' of entity 'Mapper\|Item\|items' does not " - r"refer to a mapped entity" + r"refer to a mapped entity", ) def test_option_against_multi_non_relation_twolevel_basestring(self): @@ -912,9 +1034,9 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword, Item], - (joinedload_all("id", "keywords"), ), + (joinedload_all("id", "keywords"),), r"Attribute 'id' of entity 'Mapper\|Keyword\|keywords' " - "does not refer to a mapped entity" + "does not refer to a mapped entity", ) def test_option_against_multi_nonexistent_basestring(self): @@ -922,9 +1044,9 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword, Item], - (joinedload_all("description"), ), + (joinedload_all("description"),), r"Can't find property named 'description' on the mapped " - r"entity Mapper\|Keyword\|keywords in this Query." + r"entity Mapper\|Keyword\|keywords in this Query.", ) def test_option_against_multi_no_entities_basestring(self): @@ -932,9 +1054,9 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword.id, Item.id], - (joinedload_all("keywords"), ), + (joinedload_all("keywords"),), r"Query has only expression-based entities - can't find property " - "named 'keywords'." 
+ "named 'keywords'.", ) def test_option_against_wrong_multi_entity_type_attr_one(self): @@ -942,9 +1064,9 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword, Item], - (joinedload_all(Keyword.id, Item.keywords), ), + (joinedload_all(Keyword.id, Item.keywords),), r"Attribute 'id' of entity 'Mapper\|Keyword\|keywords' " - "does not refer to a mapped entity" + "does not refer to a mapped entity", ) def test_option_against_wrong_multi_entity_type_attr_two(self): @@ -952,9 +1074,9 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword, Item], - (joinedload_all(Keyword.keywords, Item.keywords), ), + (joinedload_all(Keyword.keywords, Item.keywords),), r"Attribute 'keywords' of entity 'Mapper\|Keyword\|keywords' " - "does not refer to a mapped entity" + "does not refer to a mapped entity", ) def test_option_against_wrong_multi_entity_type_attr_three(self): @@ -962,9 +1084,9 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Keyword.id, Item.id], - (joinedload_all(Keyword.keywords, Item.keywords), ), + (joinedload_all(Keyword.keywords, Item.keywords),), r"Query has only expression-based entities - " - "can't find property named 'keywords'." + "can't find property named 'keywords'.", ) def test_wrong_type_in_option(self): @@ -972,17 +1094,17 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Item], - (joinedload_all(Keyword), ), - r"mapper option expects string key or list of attributes" + (joinedload_all(Keyword),), + r"mapper option expects string key or list of attributes", ) def test_non_contiguous_all_option(self): User = self.classes.User self._assert_eager_with_entity_exception( [User], - (joinedload_all(User.addresses, User.orders), ), + (joinedload_all(User.addresses, User.orders),), r"Attribute 'User.orders' does not link " - "from element 'Mapper|Address|addresses'" + "from element 'Mapper|Address|addresses'", ) def test_non_contiguous_all_option_of_type(self): @@ -990,21 +1112,29 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Order = self.classes.Order self._assert_eager_with_entity_exception( [User], - (joinedload_all(User.addresses, User.orders.of_type(Order)), ), + (joinedload_all(User.addresses, User.orders.of_type(Order)),), r"Attribute 'User.orders' does not link " - "from element 'Mapper|Address|addresses'" + "from element 'Mapper|Address|addresses'", ) @classmethod def setup_mappers(cls): users, User, addresses, Address, orders, Order = ( - cls.tables.users, cls.classes.User, - cls.tables.addresses, cls.classes.Address, - cls.tables.orders, cls.classes.Order) - mapper(User, users, properties={ - 'addresses': relationship(Address), - 'orders': relationship(Order) - }) + cls.tables.users, + cls.classes.User, + cls.tables.addresses, + cls.classes.Address, + cls.tables.orders, + cls.classes.Order, + ) + mapper( + User, + users, + properties={ + "addresses": relationship(Address), + "orders": relationship(Order), + }, + ) mapper(Address, addresses) mapper(Order, orders) keywords, items, item_keywords, Keyword, Item = ( @@ -1012,41 +1142,56 @@ class OptionsNoPropTest(_fixtures.FixtureTest): cls.tables.items, cls.tables.item_keywords, cls.classes.Keyword, - cls.classes.Item) - mapper(Keyword, keywords, properties={ - "keywords": column_property(keywords.c.name + "some keyword") - 
}) - mapper(Item, items, - properties=dict(keywords=relationship(Keyword, - secondary=item_keywords))) + cls.classes.Item, + ) + mapper( + Keyword, + keywords, + properties={ + "keywords": column_property(keywords.c.name + "some keyword") + }, + ) + mapper( + Item, + items, + properties=dict( + keywords=relationship(Keyword, secondary=item_keywords) + ), + ) def _assert_option(self, entity_list, option): Item = self.classes.Item - q = create_session().query(*entity_list).\ - options(joinedload(option)) - key = ('loader', (inspect(Item), inspect(Item).attrs.keywords)) + q = create_session().query(*entity_list).options(joinedload(option)) + key = ("loader", (inspect(Item), inspect(Item).attrs.keywords)) assert key in q._attributes - def _assert_eager_with_entity_exception(self, entity_list, options, - message): - assert_raises_message(sa.exc.ArgumentError, - message, - create_session().query(*entity_list).options, - *options) + def _assert_eager_with_entity_exception( + self, entity_list, options, message + ): + assert_raises_message( + sa.exc.ArgumentError, + message, + create_session().query(*entity_list).options, + *options + ) - def _assert_eager_with_just_column_exception(self, column, - eager_option, message): - assert_raises_message(sa.exc.ArgumentError, message, - create_session().query(column).options, - joinedload(eager_option)) + def _assert_eager_with_just_column_exception( + self, column, eager_option, message + ): + assert_raises_message( + sa.exc.ArgumentError, + message, + create_session().query(column).options, + joinedload(eager_option), + ) class PickleTest(PathTest, QueryTest): - def _option_fixture(self, *arg): return strategy_options._UnboundLoad._from_keys( - strategy_options._UnboundLoad.joinedload, arg, True, {}) + strategy_options._UnboundLoad.joinedload, arg, True, {} + ) def test_modern_opt_getstate(self): User = self.classes.User @@ -1058,28 +1203,31 @@ class PickleTest(PathTest, QueryTest): eq_( opt.__getstate__(), { - '_is_chain_link': False, - 'local_opts': {}, - 'is_class_strategy': False, - 'path': [(User, 'addresses', None)], - 'propagate_to_loaders': True, - '_to_bind': [opt], - 'strategy': (('lazy', 'joined'),)} + "_is_chain_link": False, + "local_opts": {}, + "is_class_strategy": False, + "path": [(User, "addresses", None)], + "propagate_to_loaders": True, + "_to_bind": [opt], + "strategy": (("lazy", "joined"),), + }, ) def test_modern_opt_setstate(self): User = self.classes.User opt = strategy_options._UnboundLoad.__new__( - strategy_options._UnboundLoad) + strategy_options._UnboundLoad + ) state = { - '_is_chain_link': False, - 'local_opts': {}, - 'is_class_strategy': False, - 'path': [(User, 'addresses', None)], - 'propagate_to_loaders': True, - '_to_bind': [opt], - 'strategy': (('lazy', 'joined'),)} + "_is_chain_link": False, + "local_opts": {}, + "is_class_strategy": False, + "path": [(User, "addresses", None)], + "propagate_to_loaders": True, + "_to_bind": [opt], + "strategy": (("lazy", "joined"),), + } opt.__setstate__(state) @@ -1087,26 +1235,33 @@ class PickleTest(PathTest, QueryTest): attr = {} load = opt._bind_loader( [ent.entity_zero for ent in query._mapper_entities], - query._current_path, attr, False) + query._current_path, + attr, + False, + ) eq_( load.path, - inspect(User)._path_registry - [User.addresses.property][inspect(self.classes.Address)]) + inspect(User)._path_registry[User.addresses.property][ + inspect(self.classes.Address) + ], + ) def test_legacy_opt_setstate(self): User = self.classes.User opt = 
strategy_options._UnboundLoad.__new__( - strategy_options._UnboundLoad) + strategy_options._UnboundLoad + ) state = { - '_is_chain_link': False, - 'local_opts': {}, - 'is_class_strategy': False, - 'path': [(User, 'addresses')], - 'propagate_to_loaders': True, - '_to_bind': [opt], - 'strategy': (('lazy', 'joined'),)} + "_is_chain_link": False, + "local_opts": {}, + "is_class_strategy": False, + "path": [(User, "addresses")], + "propagate_to_loaders": True, + "_to_bind": [opt], + "strategy": (("lazy", "joined"),), + } opt.__setstate__(state) @@ -1114,12 +1269,17 @@ class PickleTest(PathTest, QueryTest): attr = {} load = opt._bind_loader( [ent.entity_zero for ent in query._mapper_entities], - query._current_path, attr, False) + query._current_path, + attr, + False, + ) eq_( load.path, - inspect(User)._path_registry - [User.addresses.property][inspect(self.classes.Address)]) + inspect(User)._path_registry[User.addresses.property][ + inspect(self.classes.Address) + ], + ) class LocalOptsTest(PathTest, QueryTest): @@ -1130,18 +1290,13 @@ class LocalOptsTest(PathTest, QueryTest): @strategy_options.loader_option() def some_col_opt_only(loadopt, key, opts): return loadopt.set_column_strategy( - (key, ), - None, - opts, - opts_only=True + (key,), None, opts, opts_only=True ) @strategy_options.loader_option() def some_col_opt_strategy(loadopt, key, opts): return loadopt.set_column_strategy( - (key, ), - {"deferred": True, "instrument": True}, - opts + (key,), {"deferred": True, "instrument": True}, opts ) cls.some_col_opt_only = some_col_opt_only @@ -1158,17 +1313,18 @@ class LocalOptsTest(PathTest, QueryTest): for tb in opt._to_bind: tb._bind_loader( [ent.entity_zero for ent in query._mapper_entities], - query._current_path, attr, False) + query._current_path, + attr, + False, + ) else: attr.update(opt.context) key = ( - 'loader', - tuple(inspect(User)._path_registry[User.name.property])) - eq_( - attr[key].local_opts, - expected + "loader", + tuple(inspect(User)._path_registry[User.name.property]), ) + eq_(attr[key].local_opts, expected) def test_single_opt_only(self): opt = strategy_options._UnboundLoad().some_col_opt_only( @@ -1183,29 +1339,25 @@ class LocalOptsTest(PathTest, QueryTest): ), strategy_options._UnboundLoad().some_col_opt_only( "name", {"bat": "hoho"} - ) + ), ] self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"}) def test_bound_multiple_opt_only(self): User = self.classes.User opts = [ - Load(User).some_col_opt_only( - "name", {"foo": "bar"} - ).some_col_opt_only( - "name", {"bat": "hoho"} - ) + Load(User) + .some_col_opt_only("name", {"foo": "bar"}) + .some_col_opt_only("name", {"bat": "hoho"}) ] self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"}) def test_bound_strat_opt_recvs_from_optonly(self): User = self.classes.User opts = [ - Load(User).some_col_opt_only( - "name", {"foo": "bar"} - ).some_col_opt_strategy( - "name", {"bat": "hoho"} - ) + Load(User) + .some_col_opt_only("name", {"foo": "bar"}) + .some_col_opt_strategy("name", {"bat": "hoho"}) ] self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"}) @@ -1216,7 +1368,7 @@ class LocalOptsTest(PathTest, QueryTest): ), strategy_options._UnboundLoad().some_col_opt_strategy( "name", {"bat": "hoho"} - ) + ), ] self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"}) @@ -1234,11 +1386,9 @@ class LocalOptsTest(PathTest, QueryTest): def test_bound_opt_only_adds_to_strat(self): User = self.classes.User opts = [ - Load(User).some_col_opt_strategy( - "name", {"bat": "hoho"} - ).some_col_opt_only( - "name", {"foo": "bar"} - ), + 
Load(User) + .some_col_opt_strategy("name", {"bat": "hoho"}) + .some_col_opt_only("name", {"foo": "bar"}) ] self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"}) @@ -1251,21 +1401,21 @@ class CacheKeyTest(PathTest, QueryTest): def test_unbound_cache_key_included_safe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) opt = joinedload(User.orders).joinedload(Order.items) eq_( opt._generate_cache_key(query_path), - ( - ((Order, 'items', Item, ('lazy', 'joined')),) - ) + (((Order, "items", Item, ("lazy", "joined")),)), ) def test_unbound_cache_key_included_safe_multipath(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) @@ -1275,21 +1425,18 @@ class CacheKeyTest(PathTest, QueryTest): eq_( opt1._generate_cache_key(query_path), - ( - ((Order, 'items', Item, ('lazy', 'joined')),) - ) + (((Order, "items", Item, ("lazy", "joined")),)), ) eq_( opt2._generate_cache_key(query_path), - ( - ((Order, 'address', Address, ('lazy', 'joined')),) - ) + (((Order, "address", Address, ("lazy", "joined")),)), ) def test_bound_cache_key_included_safe_multipath(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) @@ -1299,61 +1446,51 @@ class CacheKeyTest(PathTest, QueryTest): eq_( opt1._generate_cache_key(query_path), - ( - ((Order, 'items', Item, ('lazy', 'joined')),) - ) + (((Order, "items", Item, ("lazy", "joined")),)), ) eq_( opt2._generate_cache_key(query_path), - ( - ((Order, 'address', Address, ('lazy', 'joined')),) - ) + (((Order, "address", Address, ("lazy", "joined")),)), ) def test_bound_cache_key_included_safe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) opt = Load(User).joinedload(User.orders).joinedload(Order.items) eq_( opt._generate_cache_key(query_path), - ( - ((Order, 'items', Item, ('lazy', 'joined')),) - ) + (((Order, "items", Item, ("lazy", "joined")),)), ) def test_unbound_cache_key_excluded_on_other(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) - query_path = self._make_path_registry( - [User, "addresses"]) + query_path = self._make_path_registry([User, "addresses"]) opt = joinedload(User.orders).joinedload(Order.items) - eq_( - opt._generate_cache_key(query_path), - None - ) + eq_(opt._generate_cache_key(query_path), None) def test_bound_cache_key_excluded_on_other(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) - query_path = self._make_path_registry( - [User, "addresses"]) + query_path = self._make_path_registry([User, "addresses"]) opt = Load(User).joinedload(User.orders).joinedload(Order.items) - eq_( - opt._generate_cache_key(query_path), - None - ) + eq_(opt._generate_cache_key(query_path), None) def test_unbound_cache_key_excluded_on_aliased(self): User, Address, Order, Item, SubItem = self.classes( - 
'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) # query of: # @@ -1364,136 +1501,148 @@ class CacheKeyTest(PathTest, QueryTest): # the path excludes our option so cache key should # be None - query_path = self._make_path_registry( - [User, "orders"]) + query_path = self._make_path_registry([User, "orders"]) opt = joinedload(aliased(User).orders).joinedload(Order.items) - eq_( - opt._generate_cache_key(query_path), - None - ) + eq_(opt._generate_cache_key(query_path), None) def test_bound_cache_key_wildcard_one(self): # do not change this test, it is testing # a specific condition in Load._chop_path(). - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") query_path = self._make_path_registry([User, "addresses"]) opt = Load(User).lazyload("*") - eq_( - opt._generate_cache_key(query_path), - None - ) + eq_(opt._generate_cache_key(query_path), None) def test_unbound_cache_key_wildcard_one(self): - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") query_path = self._make_path_registry([User, "addresses"]) opt = lazyload("*") eq_( opt._generate_cache_key(query_path), - (('relationship:_sa_default', ('lazy', 'select')),) + (("relationship:_sa_default", ("lazy", "select")),), ) def test_bound_cache_key_wildcard_two(self): User, Address, Order, Item, SubItem, Keyword = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem', "Keyword") + "User", "Address", "Order", "Item", "SubItem", "Keyword" + ) query_path = self._make_path_registry([User]) opt = Load(User).lazyload("orders").lazyload("*") eq_( opt._generate_cache_key(query_path), - (('orders', Order, ('lazy', 'select')), - ('orders', Order, 'relationship:*', ('lazy', 'select'))) + ( + ("orders", Order, ("lazy", "select")), + ("orders", Order, "relationship:*", ("lazy", "select")), + ), ) def test_unbound_cache_key_wildcard_two(self): User, Address, Order, Item, SubItem, Keyword = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem', "Keyword") + "User", "Address", "Order", "Item", "SubItem", "Keyword" + ) query_path = self._make_path_registry([User]) opt = lazyload("orders").lazyload("*") eq_( opt._generate_cache_key(query_path), - (('orders', Order, ('lazy', 'select')), - ('orders', Order, 'relationship:*', ('lazy', 'select'))) + ( + ("orders", Order, ("lazy", "select")), + ("orders", Order, "relationship:*", ("lazy", "select")), + ), ) def test_unbound_cache_key_of_type_subclass_relationship(self): User, Address, Order, Item, SubItem, Keyword = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem', "Keyword") + "User", "Address", "Order", "Item", "SubItem", "Keyword" + ) query_path = self._make_path_registry([Order, "items", Item]) - opt = subqueryload( - Order.items.of_type(SubItem)).subqueryload(SubItem.extra_keywords) + opt = subqueryload(Order.items.of_type(SubItem)).subqueryload( + SubItem.extra_keywords + ) eq_( opt._generate_cache_key(query_path), ( - (SubItem, ('lazy', 'subquery')), - ('extra_keywords', Keyword, ('lazy', 'subquery')) - ) + (SubItem, ("lazy", "subquery")), + ("extra_keywords", Keyword, ("lazy", "subquery")), + ), ) def test_unbound_cache_key_of_type_subclass_relationship_stringattr(self): User, Address, Order, Item, SubItem, Keyword = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem', "Keyword") + "User", "Address", "Order", "Item", "SubItem", "Keyword" + ) query_path = self._make_path_registry([Order, "items", Item]) - opt = subqueryload( - 
Order.items.of_type(SubItem)).subqueryload("extra_keywords") + opt = subqueryload(Order.items.of_type(SubItem)).subqueryload( + "extra_keywords" + ) eq_( opt._generate_cache_key(query_path), ( - (SubItem, ('lazy', 'subquery')), - ('extra_keywords', Keyword, ('lazy', 'subquery')) - ) + (SubItem, ("lazy", "subquery")), + ("extra_keywords", Keyword, ("lazy", "subquery")), + ), ) def test_bound_cache_key_of_type_subclass_relationship(self): User, Address, Order, Item, SubItem, Keyword = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem', "Keyword") + "User", "Address", "Order", "Item", "SubItem", "Keyword" + ) query_path = self._make_path_registry([Order, "items", Item]) - opt = Load(Order).subqueryload( - Order.items.of_type(SubItem)).subqueryload(SubItem.extra_keywords) + opt = ( + Load(Order) + .subqueryload(Order.items.of_type(SubItem)) + .subqueryload(SubItem.extra_keywords) + ) eq_( opt._generate_cache_key(query_path), ( - (SubItem, ('lazy', 'subquery')), - ('extra_keywords', Keyword, ('lazy', 'subquery')) - ) + (SubItem, ("lazy", "subquery")), + ("extra_keywords", Keyword, ("lazy", "subquery")), + ), ) def test_bound_cache_key_of_type_subclass_string_relationship(self): User, Address, Order, Item, SubItem, Keyword = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem', "Keyword") + "User", "Address", "Order", "Item", "SubItem", "Keyword" + ) query_path = self._make_path_registry([Order, "items", Item]) - opt = Load(Order).subqueryload( - Order.items.of_type(SubItem)).subqueryload("extra_keywords") + opt = ( + Load(Order) + .subqueryload(Order.items.of_type(SubItem)) + .subqueryload("extra_keywords") + ) eq_( opt._generate_cache_key(query_path), ( - (SubItem, ('lazy', 'subquery')), - ('extra_keywords', Keyword, ('lazy', 'subquery')) - ) + (SubItem, ("lazy", "subquery")), + ("extra_keywords", Keyword, ("lazy", "subquery")), + ), ) def test_unbound_cache_key_excluded_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) # query of: # # query(User).options( @@ -1507,16 +1656,15 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) - opt = subqueryload(User.orders).\ - subqueryload(Order.items.of_type(SubItem)) - eq_( - opt._generate_cache_key(query_path), - None + opt = subqueryload(User.orders).subqueryload( + Order.items.of_type(SubItem) ) + eq_(opt._generate_cache_key(query_path), None) def test_unbound_cache_key_excluded_of_type_unsafe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) # query of: # # query(User).options( @@ -1530,16 +1678,15 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) - opt = subqueryload(User.orders).\ - subqueryload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - None + opt = subqueryload(User.orders).subqueryload( + Order.items.of_type(aliased(SubItem)) ) + eq_(opt._generate_cache_key(query_path), None) def test_bound_cache_key_excluded_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) # query of: # # query(User).options( @@ -1553,16 +1700,17 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) - opt = 
Load(User).subqueryload(User.orders).\ - subqueryload(Order.items.of_type(SubItem)) - eq_( - opt._generate_cache_key(query_path), - None + opt = ( + Load(User) + .subqueryload(User.orders) + .subqueryload(Order.items.of_type(SubItem)) ) + eq_(opt._generate_cache_key(query_path), None) def test_bound_cache_key_excluded_of_type_unsafe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) # query of: # # query(User).options( @@ -1576,397 +1724,469 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) - opt = Load(User).subqueryload(User.orders).\ - subqueryload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - None + opt = ( + Load(User) + .subqueryload(User.orders) + .subqueryload(Order.items.of_type(aliased(SubItem))) ) + eq_(opt._generate_cache_key(query_path), None) def test_unbound_cache_key_included_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) opt = joinedload(User.orders).joinedload(Order.items.of_type(SubItem)) eq_( opt._generate_cache_key(query_path), - ( - (Order, 'items', SubItem, ('lazy', 'joined')), - ) + ((Order, "items", SubItem, ("lazy", "joined")),), ) def test_bound_cache_key_included_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) - opt = Load(User).joinedload(User.orders).\ - joinedload(Order.items.of_type(SubItem)) + opt = ( + Load(User) + .joinedload(User.orders) + .joinedload(Order.items.of_type(SubItem)) + ) eq_( opt._generate_cache_key(query_path), - ( - (Order, 'items', SubItem, ('lazy', 'joined')), - ) + ((Order, "items", SubItem, ("lazy", "joined")),), ) def test_unbound_cache_key_included_unsafe_option_one(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) - opt = joinedload(User.orders).\ - joinedload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - False + opt = joinedload(User.orders).joinedload( + Order.items.of_type(aliased(SubItem)) ) + eq_(opt._generate_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_option_two(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders", Order]) - opt = joinedload(User.orders).\ - joinedload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - False + opt = joinedload(User.orders).joinedload( + Order.items.of_type(aliased(SubItem)) ) + eq_(opt._generate_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_option_three(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders", Order, "items"]) - opt = joinedload(User.orders).\ - joinedload(Order.items.of_type(aliased(SubItem))) - eq_( - 
opt._generate_cache_key(query_path), - False + opt = joinedload(User.orders).joinedload( + Order.items.of_type(aliased(SubItem)) ) + eq_(opt._generate_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_query(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) au = aliased(User) query_path = self._make_path_registry([inspect(au), "orders"]) - opt = joinedload(au.orders).\ - joinedload(Order.items) - eq_( - opt._generate_cache_key(query_path), - False - ) + opt = joinedload(au.orders).joinedload(Order.items) + eq_(opt._generate_cache_key(query_path), False) def test_unbound_cache_key_included_safe_w_deferred(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "addresses"]) - opt = joinedload(User.addresses).\ - defer(Address.email_address).defer(Address.user_id) + opt = ( + joinedload(User.addresses) + .defer(Address.email_address) + .defer(Address.user_id) + ) eq_( opt._generate_cache_key(query_path), ( ( - Address, "email_address", - ('deferred', True), - ('instrument', True) + Address, + "email_address", + ("deferred", True), + ("instrument", True), ), - ( - Address, "user_id", - ('deferred', True), - ('instrument', True) - ), - ) + (Address, "user_id", ("deferred", True), ("instrument", True)), + ), ) def test_unbound_cache_key_included_safe_w_deferred_multipath(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) base = joinedload(User.orders) opt1 = base.joinedload(Order.items) - opt2 = base.joinedload(Order.address).defer(Address.email_address).\ - defer(Address.user_id) + opt2 = ( + base.joinedload(Order.address) + .defer(Address.email_address) + .defer(Address.user_id) + ) eq_( opt1._generate_cache_key(query_path), - ( - (Order, 'items', Item, ('lazy', 'joined')), - ) + ((Order, "items", Item, ("lazy", "joined")),), ) eq_( opt2._generate_cache_key(query_path), ( - (Order, 'address', Address, ('lazy', 'joined')), - (Order, 'address', Address, 'email_address', - ('deferred', True), ('instrument', True)), - (Order, 'address', Address, 'user_id', - ('deferred', True), ('instrument', True)) - ) + (Order, "address", Address, ("lazy", "joined")), + ( + Order, + "address", + Address, + "email_address", + ("deferred", True), + ("instrument", True), + ), + ( + Order, + "address", + Address, + "user_id", + ("deferred", True), + ("instrument", True), + ), + ), ) def test_bound_cache_key_included_safe_w_deferred(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "addresses"]) - opt = Load(User).joinedload(User.addresses).\ - defer(Address.email_address).defer(Address.user_id) + opt = ( + Load(User) + .joinedload(User.addresses) + .defer(Address.email_address) + .defer(Address.user_id) + ) eq_( opt._generate_cache_key(query_path), ( ( - Address, "email_address", - ('deferred', True), - ('instrument', True) - ), - ( - Address, "user_id", - ('deferred', True), - ('instrument', True) + Address, + "email_address", + ("deferred", True), + ("instrument", True), ), - ) + (Address, "user_id", ("deferred", True), 
("instrument", True)), + ), ) def test_bound_cache_key_included_safe_w_deferred_multipath(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) base = Load(User).joinedload(User.orders) opt1 = base.joinedload(Order.items) - opt2 = base.joinedload(Order.address).defer(Address.email_address).\ - defer(Address.user_id) + opt2 = ( + base.joinedload(Order.address) + .defer(Address.email_address) + .defer(Address.user_id) + ) eq_( opt1._generate_cache_key(query_path), - ( - (Order, 'items', Item, ('lazy', 'joined')), - ) + ((Order, "items", Item, ("lazy", "joined")),), ) eq_( opt2._generate_cache_key(query_path), ( - (Order, 'address', Address, ('lazy', 'joined')), - (Order, 'address', Address, 'email_address', - ('deferred', True), ('instrument', True)), - (Order, 'address', Address, 'user_id', - ('deferred', True), ('instrument', True)) - ) + (Order, "address", Address, ("lazy", "joined")), + ( + Order, + "address", + Address, + "email_address", + ("deferred", True), + ("instrument", True), + ), + ( + Order, + "address", + Address, + "user_id", + ("deferred", True), + ("instrument", True), + ), + ), ) def test_unbound_cache_key_included_safe_w_option(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) - opt = defaultload("orders").joinedload( - "items", innerjoin=True).defer("description") + opt = ( + defaultload("orders") + .joinedload("items", innerjoin=True) + .defer("description") + ) query_path = self._make_path_registry([User, "orders"]) eq_( opt._generate_cache_key(query_path), ( - (Order, 'items', Item, - ('lazy', 'joined'), ('innerjoin', True)), - (Order, 'items', Item, 'description', - ('deferred', True), ('instrument', True)) - ) + ( + Order, + "items", + Item, + ("lazy", "joined"), + ("innerjoin", True), + ), + ( + Order, + "items", + Item, + "description", + ("deferred", True), + ("instrument", True), + ), + ), ) def test_bound_cache_key_excluded_on_aliased(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) - query_path = self._make_path_registry( - [User, "orders"]) + query_path = self._make_path_registry([User, "orders"]) au = aliased(User) opt = Load(au).joinedload(au.orders).joinedload(Order.items) - eq_( - opt._generate_cache_key(query_path), - None - ) + eq_(opt._generate_cache_key(query_path), None) def test_bound_cache_key_included_unsafe_option_one(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders"]) - opt = Load(User).joinedload(User.orders).\ - joinedload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - False + opt = ( + Load(User) + .joinedload(User.orders) + .joinedload(Order.items.of_type(aliased(SubItem))) ) + eq_(opt._generate_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_option_two(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders", Order]) - opt = Load(User).joinedload(User.orders).\ - 
joinedload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - False + opt = ( + Load(User) + .joinedload(User.orders) + .joinedload(Order.items.of_type(aliased(SubItem))) ) + eq_(opt._generate_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_option_three(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "orders", Order, "items"]) - opt = Load(User).joinedload(User.orders).\ - joinedload(Order.items.of_type(aliased(SubItem))) - eq_( - opt._generate_cache_key(query_path), - False + opt = ( + Load(User) + .joinedload(User.orders) + .joinedload(Order.items.of_type(aliased(SubItem))) ) + eq_(opt._generate_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_query(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) au = aliased(User) query_path = self._make_path_registry([inspect(au), "orders"]) - opt = Load(au).joinedload(au.orders).\ - joinedload(Order.items) - eq_( - opt._generate_cache_key(query_path), - False - ) - + opt = Load(au).joinedload(au.orders).joinedload(Order.items) + eq_(opt._generate_cache_key(query_path), False) def test_bound_cache_key_included_safe_w_option(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) - opt = Load(User).defaultload("orders").joinedload( - "items", innerjoin=True).defer("description") + opt = ( + Load(User) + .defaultload("orders") + .joinedload("items", innerjoin=True) + .defer("description") + ) query_path = self._make_path_registry([User, "orders"]) eq_( opt._generate_cache_key(query_path), ( - (Order, 'items', Item, - ('lazy', 'joined'), ('innerjoin', True)), - (Order, 'items', Item, 'description', - ('deferred', True), ('instrument', True)) - ) + ( + Order, + "items", + Item, + ("lazy", "joined"), + ("innerjoin", True), + ), + ( + Order, + "items", + Item, + "description", + ("deferred", True), + ("instrument", True), + ), + ), ) def test_unbound_cache_key_included_safe_w_loadonly_strs(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "addresses"]) opt = defaultload(User.addresses).load_only("id", "email_address") eq_( opt._generate_cache_key(query_path), - ( - (Address, 'id', - ('deferred', False), ('instrument', True)), - (Address, 'email_address', - ('deferred', False), ('instrument', True)), - (Address, 'column:*', - ('deferred', True), ('instrument', True), - ('undefer_pks', True)) - ) + (Address, "id", ("deferred", False), ("instrument", True)), + ( + Address, + "email_address", + ("deferred", False), + ("instrument", True), + ), + ( + Address, + "column:*", + ("deferred", True), + ("instrument", True), + ("undefer_pks", True), + ), + ), ) def test_unbound_cache_key_included_safe_w_loadonly_props(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "addresses"]) opt = defaultload(User.addresses).load_only( - Address.id, Address.email_address) + Address.id, Address.email_address + ) eq_( 
opt._generate_cache_key(query_path), - ( - (Address, 'id', - ('deferred', False), ('instrument', True)), - (Address, 'email_address', - ('deferred', False), ('instrument', True)), - (Address, 'column:*', - ('deferred', True), ('instrument', True), - ('undefer_pks', True)) - ) + (Address, "id", ("deferred", False), ("instrument", True)), + ( + Address, + "email_address", + ("deferred", False), + ("instrument", True), + ), + ( + Address, + "column:*", + ("deferred", True), + ("instrument", True), + ("undefer_pks", True), + ), + ), ) def test_bound_cache_key_included_safe_w_loadonly(self): User, Address, Order, Item, SubItem = self.classes( - 'User', 'Address', 'Order', 'Item', 'SubItem') + "User", "Address", "Order", "Item", "SubItem" + ) query_path = self._make_path_registry([User, "addresses"]) - opt = Load(User).defaultload(User.addresses).\ - load_only("id", "email_address") + opt = ( + Load(User) + .defaultload(User.addresses) + .load_only("id", "email_address") + ) eq_( opt._generate_cache_key(query_path), - ( - (Address, 'id', - ('deferred', False), ('instrument', True)), - (Address, 'email_address', - ('deferred', False), ('instrument', True)), - (Address, 'column:*', - ('deferred', True), ('instrument', True), - ('undefer_pks', True)) - ) + (Address, "id", ("deferred", False), ("instrument", True)), + ( + Address, + "email_address", + ("deferred", False), + ("instrument", True), + ), + ( + Address, + "column:*", + ("deferred", True), + ("instrument", True), + ("undefer_pks", True), + ), + ), ) def test_unbound_cache_key_undefer_group(self): - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") query_path = self._make_path_registry([User, "addresses"]) - opt = defaultload(User.addresses).undefer_group('xyz') + opt = defaultload(User.addresses).undefer_group("xyz") eq_( opt._generate_cache_key(query_path), - - ( - (Address, 'column:*', ("undefer_group_xyz", True)), - ) + ((Address, "column:*", ("undefer_group_xyz", True)),), ) def test_bound_cache_key_undefer_group(self): - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") query_path = self._make_path_registry([User, "addresses"]) - opt = Load(User).defaultload(User.addresses).undefer_group('xyz') + opt = Load(User).defaultload(User.addresses).undefer_group("xyz") eq_( opt._generate_cache_key(query_path), - - ( - (Address, 'column:*', ("undefer_group_xyz", True)), - ) + ((Address, "column:*", ("undefer_group_xyz", True)),), ) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index ff8b9e429c..0a3da813e0 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -7,72 +7,119 @@ from sqlalchemy.testing.util import picklers from sqlalchemy.testing import assert_raises_message from sqlalchemy import Integer, String, ForeignKey, exc, MetaData from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, create_session, \ - sessionmaker, attributes, interfaces,\ - clear_mappers, exc as orm_exc,\ - configure_mappers, Session, lazyload_all,\ - lazyload, aliased, subqueryload +from sqlalchemy.orm import ( + mapper, + relationship, + create_session, + sessionmaker, + attributes, + interfaces, + clear_mappers, + exc as orm_exc, + configure_mappers, + Session, + lazyload_all, + lazyload, + aliased, + subqueryload, +) from sqlalchemy.orm import state as sa_state from sqlalchemy.orm import instrumentation -from sqlalchemy.orm.collections import attribute_mapped_collection, \ - 
column_mapped_collection +from sqlalchemy.orm.collections import ( + attribute_mapped_collection, + column_mapped_collection, +) from sqlalchemy.testing import fixtures from test.orm import _fixtures -from sqlalchemy.testing.pickleable import User, Address, Dingaling, Order, \ - Child1, Child2, Parent, Screen, EmailUser +from sqlalchemy.testing.pickleable import ( + User, + Address, + Dingaling, + Order, + Child1, + Child2, + Parent, + Screen, + EmailUser, +) from sqlalchemy.orm import with_polymorphic -from .inheritance._poly_fixtures import Company, Person, Engineer, Manager, \ - Boss, Machine, Paperwork, _Polymorphic +from .inheritance._poly_fixtures import ( + Company, + Person, + Engineer, + Manager, + Boss, + Machine, + Paperwork, + _Polymorphic, +) -class PickleTest(fixtures.MappedTest): +class PickleTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False), - test_needs_acid=True, - test_needs_fk=True) - - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False), - test_needs_acid=True, - test_needs_fk=True) - Table('orders', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('address_id', None, ForeignKey('addresses.id')), - Column('description', String(30)), - Column('isopen', Integer), - test_needs_acid=True, - test_needs_fk=True) - Table("dingalings", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', None, ForeignKey('addresses.id')), - Column('data', String(30)), - test_needs_acid=True, - test_needs_fk=True) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + Table( + "orders", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("address_id", None, ForeignKey("addresses.id")), + Column("description", String(30)), + Column("isopen", Integer), + test_needs_acid=True, + test_needs_fk=True, + ) + Table( + "dingalings", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", None, ForeignKey("addresses.id")), + Column("data", String(30)), + test_needs_acid=True, + test_needs_fk=True, + ) def test_transient(self): - users, addresses = (self.tables.users, - self.tables.addresses) + users, addresses = (self.tables.users, self.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) 
u2 = pickle.loads(pickle.dumps(u1)) sess.add(u2) @@ -86,7 +133,7 @@ class PickleTest(fixtures.MappedTest): users = self.tables.users umapper = mapper(User, users) - u1 = User(name='ed') + u1 = User(name="ed") u1_pickled = pickle.dumps(u1, -1) clear_mappers() @@ -95,13 +142,15 @@ class PickleTest(fixtures.MappedTest): orm_exc.UnmappedInstanceError, "Cannot deserialize object of type " " - no mapper()", - pickle.loads, u1_pickled) + pickle.loads, + u1_pickled, + ) def test_no_instrumentation(self): users = self.tables.users umapper = mapper(User, users) - u1 = User(name='ed') + u1 = User(name="ed") u1_pickled = pickle.dumps(u1, -1) clear_mappers() @@ -114,51 +163,63 @@ class PickleTest(fixtures.MappedTest): eq_(str(u1), "User(name='ed')") def test_class_deferred_cols(self): - addresses, users = (self.tables.addresses, - self.tables.users) - - mapper(User, users, properties={ - 'name': sa.orm.deferred(users.c.name), - 'addresses': relationship(Address, backref="user") - }) - mapper(Address, addresses, properties={ - 'email_address': sa.orm.deferred(addresses.c.email_address) - }) + addresses, users = (self.tables.addresses, self.tables.users) + + mapper( + User, + users, + properties={ + "name": sa.orm.deferred(users.c.name), + "addresses": relationship(Address, backref="user"), + }, + ) + mapper( + Address, + addresses, + properties={ + "email_address": sa.orm.deferred(addresses.c.email_address) + }, + ) sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) sess.add(u1) sess.flush() sess.expunge_all() u1 = sess.query(User).get(u1.id) - assert 'name' not in u1.__dict__ - assert 'addresses' not in u1.__dict__ + assert "name" not in u1.__dict__ + assert "addresses" not in u1.__dict__ u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() sess2.add(u2) - eq_(u2.name, 'ed') - eq_(u2, User(name='ed', addresses=[ - Address(email_address='ed@bar.com')])) + eq_(u2.name, "ed") + eq_( + u2, + User(name="ed", addresses=[Address(email_address="ed@bar.com")]), + ) u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() u2 = sess2.merge(u2, load=False) - eq_(u2.name, 'ed') - eq_(u2, User(name='ed', addresses=[ - Address(email_address='ed@bar.com')])) + eq_(u2.name, "ed") + eq_( + u2, + User(name="ed", addresses=[Address(email_address="ed@bar.com")]), + ) def test_instance_lazy_relation_loaders(self): - users, addresses = (self.tables.users, - self.tables.addresses) + users, addresses = (self.tables.users, self.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='noload') - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="noload")}, + ) mapper(Address, addresses) sess = Session() - u1 = User(name='ed', addresses=[Address(email_address='ed@bar.com')]) + u1 = User(name="ed", addresses=[Address(email_address="ed@bar.com")]) sess.add(u1) sess.commit() @@ -174,9 +235,11 @@ class PickleTest(fixtures.MappedTest): def test_invalidated_flag_pickle(self): users, addresses = (self.tables.users, self.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='noload') - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="noload")}, + ) mapper(Address, addresses) u1 = User() @@ -188,9 +251,11 @@ class PickleTest(fixtures.MappedTest): def test_invalidated_flag_deepcopy(self): users, addresses = (self.tables.users, 
self.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy='noload') - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="noload")}, + ) mapper(Address, addresses) u1 = User() @@ -203,62 +268,73 @@ class PickleTest(fixtures.MappedTest): def test_instance_deferred_cols(self): users, addresses = (self.tables.users, self.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) sess.add(u1) sess.flush() sess.expunge_all() - u1 = sess.query(User).\ - options(sa.orm.defer('name'), - sa.orm.defer('addresses.email_address')).\ - get(u1.id) - assert 'name' not in u1.__dict__ - assert 'addresses' not in u1.__dict__ + u1 = ( + sess.query(User) + .options( + sa.orm.defer("name"), sa.orm.defer("addresses.email_address") + ) + .get(u1.id) + ) + assert "name" not in u1.__dict__ + assert "addresses" not in u1.__dict__ u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() sess2.add(u2) - eq_(u2.name, 'ed') - assert 'addresses' not in u2.__dict__ + eq_(u2.name, "ed") + assert "addresses" not in u2.__dict__ ad = u2.addresses[0] - assert 'email_address' not in ad.__dict__ - eq_(ad.email_address, 'ed@bar.com') - eq_(u2, User(name='ed', addresses=[ - Address(email_address='ed@bar.com')])) + assert "email_address" not in ad.__dict__ + eq_(ad.email_address, "ed@bar.com") + eq_( + u2, + User(name="ed", addresses=[Address(email_address="ed@bar.com")]), + ) u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() u2 = sess2.merge(u2, load=False) - eq_(u2.name, 'ed') - assert 'addresses' not in u2.__dict__ + eq_(u2.name, "ed") + assert "addresses" not in u2.__dict__ ad = u2.addresses[0] # mapper options now transmit over merge(), # new as of 0.6, so email_address is deferred. 
- assert 'email_address' not in ad.__dict__ + assert "email_address" not in ad.__dict__ - eq_(ad.email_address, 'ed@bar.com') - eq_(u2, User(name='ed', addresses=[ - Address(email_address='ed@bar.com')])) + eq_(ad.email_address, "ed@bar.com") + eq_( + u2, + User(name="ed", addresses=[Address(email_address="ed@bar.com")]), + ) def test_pickle_protocols(self): users, addresses = (self.tables.users, self.tables.addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) sess = sessionmaker()() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) sess.add(u1) sess.commit() @@ -273,59 +349,62 @@ class PickleTest(fixtures.MappedTest): users = self.tables.users mapper(User, users) sess = Session() - sess.add(User(id=1, name='ed')) + sess.add(User(id=1, name="ed")) sess.commit() sess.close() - inst = User(id=1, name='ed') + inst = User(id=1, name="ed") del inst._sa_instance_state state = sa_state.InstanceState.__new__(sa_state.InstanceState) state_09 = { - 'class_': User, - 'modified': False, - 'committed_state': {}, - 'instance': inst, - 'callables': {'name': state, 'id': state}, - 'key': (User, (1,)), - 'expired': True} + "class_": User, + "modified": False, + "committed_state": {}, + "instance": inst, + "callables": {"name": state, "id": state}, + "key": (User, (1,)), + "expired": True, + } manager = instrumentation._SerializeManager.__new__( - instrumentation._SerializeManager) + instrumentation._SerializeManager + ) manager.class_ = User - state_09['manager'] = manager + state_09["manager"] = manager state.__setstate__(state_09) - eq_(state.expired_attributes, {'name', 'id'}) + eq_(state.expired_attributes, {"name", "id"}) sess = Session() sess.add(inst) - eq_(inst.name, 'ed') + eq_(inst.name, "ed") # test identity_token expansion - eq_(sa.inspect(inst).key, (User, (1, ), None)) + eq_(sa.inspect(inst).key, (User, (1,), None)) def test_11_pickle(self): users = self.tables.users mapper(User, users) sess = Session() - u1 = User(id=1, name='ed') + u1 = User(id=1, name="ed") sess.add(u1) sess.commit() sess.close() manager = instrumentation._SerializeManager.__new__( - instrumentation._SerializeManager) + instrumentation._SerializeManager + ) manager.class_ = User state_11 = { - - 'class_': User, - 'modified': False, - 'committed_state': {}, - 'instance': u1, - 'manager': manager, - 'key': (User, (1,)), - 'expired_attributes': set(), - 'expired': True} + "class_": User, + "modified": False, + "committed_state": {}, + "instance": u1, + "manager": manager, + "key": (User, (1,)), + "expired_attributes": set(), + "expired": True, + } state = sa_state.InstanceState.__new__(sa_state.InstanceState) state.__setstate__(state_11) @@ -337,9 +416,9 @@ class PickleTest(fixtures.MappedTest): users = self.tables.users mapper(User, users) - u1 = User(id=1, name='ed') + u1 = User(id=1, name="ed") - sa.inspect(u1).info['some_key'] = 'value' + sa.inspect(u1).info["some_key"] = "value" state_dict = sa.inspect(u1).__getstate__() @@ -347,24 +426,30 @@ class PickleTest(fixtures.MappedTest): state.__setstate__(state_dict) u2 = state.obj() - eq_(sa.inspect(u2).info['some_key'], 'value') + eq_(sa.inspect(u2).info["some_key"], "value") @testing.requires.non_broken_pickle def test_options_with_descriptors(self): - users, addresses, dingalings = 
(self.tables.users, - self.tables.addresses, - self.tables.dingalings) - - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user") - }) - mapper(Address, addresses, properties={ - 'dingaling': relationship(Dingaling) - }) + users, addresses, dingalings = ( + self.tables.users, + self.tables.addresses, + self.tables.dingalings, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) + mapper( + Address, + addresses, + properties={"dingaling": relationship(Dingaling)}, + ) mapper(Dingaling, dingalings) sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) + u1 = User(name="ed") + u1.addresses.append(Address(email_address="ed@bar.com")) sess.add(u1) sess.flush() sess.expunge_all() @@ -387,18 +472,26 @@ class PickleTest(fixtures.MappedTest): to not rely upon InstanceState to deserialize.""" m = MetaData() - c1 = Table('c1', m, - Column('parent_id', String, ForeignKey('p.id'), - primary_key=True)) - c2 = Table('c2', m, - Column('parent_id', String, ForeignKey('p.id'), - primary_key=True)) - p = Table('p', m, Column('id', String, primary_key=True)) - - mapper(Parent, p, properties={ - 'children1': relationship(Child1), - 'children2': relationship(Child2) - }) + c1 = Table( + "c1", + m, + Column("parent_id", String, ForeignKey("p.id"), primary_key=True), + ) + c2 = Table( + "c2", + m, + Column("parent_id", String, ForeignKey("p.id"), primary_key=True), + ) + p = Table("p", m, Column("id", String, primary_key=True)) + + mapper( + Parent, + p, + properties={ + "children1": relationship(Child1), + "children2": relationship(Child2), + }, + ) mapper(Child1, c1) mapper(Child2, c2) @@ -411,6 +504,7 @@ class PickleTest(fixtures.MappedTest): def test_exceptions(self): class Foo(object): pass + users = self.tables.users mapper(User, users) @@ -426,65 +520,80 @@ class PickleTest(fixtures.MappedTest): def test_attribute_mapped_collection(self): users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - 'addresses': relationship( - Address, - collection_class=attribute_mapped_collection('email_address') - ) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + collection_class=attribute_mapped_collection( + "email_address" + ), + ) + }, + ) mapper(Address, addresses) u1 = User() u1.addresses = {"email1": Address(email_address="email1")} for loads, dumps in picklers(): repickled = loads(dumps(u1)) eq_(u1.addresses, repickled.addresses) - eq_(repickled.addresses['email1'], - Address(email_address="email1")) + eq_(repickled.addresses["email1"], Address(email_address="email1")) def test_column_mapped_collection(self): users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - 'addresses': relationship( - Address, - collection_class=column_mapped_collection( - addresses.c.email_address) - ) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + collection_class=column_mapped_collection( + addresses.c.email_address + ), + ) + }, + ) mapper(Address, addresses) u1 = User() u1.addresses = { "email1": Address(email_address="email1"), - "email2": Address(email_address="email2") + "email2": Address(email_address="email2"), } for loads, dumps in picklers(): repickled = loads(dumps(u1)) eq_(u1.addresses, repickled.addresses) - eq_(repickled.addresses['email1'], - Address(email_address="email1")) + eq_(repickled.addresses["email1"], 
Address(email_address="email1")) def test_composite_column_mapped_collection(self): users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - 'addresses': relationship( - Address, - collection_class=column_mapped_collection([ - addresses.c.id, - addresses.c.email_address]) - ) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + collection_class=column_mapped_collection( + [addresses.c.id, addresses.c.email_address] + ), + ) + }, + ) mapper(Address, addresses) u1 = User() u1.addresses = { (1, "email1"): Address(id=1, email_address="email1"), - (2, "email2"): Address(id=2, email_address="email2") + (2, "email2"): Address(id=2, email_address="email2"), } for loads, dumps in picklers(): repickled = loads(dumps(u1)) eq_(u1.addresses, repickled.addresses) - eq_(repickled.addresses[(1, 'email1')], - Address(id=1, email_address="email1")) + eq_( + repickled.addresses[(1, "email1")], + Address(id=1, email_address="email1"), + ) class OptionsTest(_Polymorphic): @@ -495,21 +604,26 @@ class OptionsTest(_Polymorphic): for opt, serialized in [ ( sa.orm.joinedload(Company.employees.of_type(Engineer)), - [(Company, "employees", Engineer)]), + [(Company, "employees", Engineer)], + ), ( sa.orm.joinedload(Company.employees.of_type(with_poly)), - [(Company, "employees", None)]), + [(Company, "employees", None)], + ), ]: opt2 = pickle.loads(pickle.dumps(opt)) - eq_(opt.__getstate__()['path'], serialized) - eq_(opt2.__getstate__()['path'], serialized) + eq_(opt.__getstate__()["path"], serialized) + eq_(opt2.__getstate__()["path"], serialized) def test_load(self): s = Session() with_poly = with_polymorphic(Person, [Engineer, Manager], flat=True) - emp = s.query(Company).options( - subqueryload(Company.employees.of_type(with_poly))).first() + emp = ( + s.query(Company) + .options(subqueryload(Company.employees.of_type(with_poly))) + .first() + ) e2 = pickle.loads(pickle.dumps(emp)) @@ -517,26 +631,39 @@ class OptionsTest(_Polymorphic): class PolymorphicDeferredTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30)), - Column('type', String(30))) - Table('email_users', metadata, - Column('id', Integer, ForeignKey('users.id'), primary_key=True), - Column('email_address', String(30))) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + Column("type", String(30)), + ) + Table( + "email_users", + metadata, + Column("id", Integer, ForeignKey("users.id"), primary_key=True), + Column("email_address", String(30)), + ) def test_polymorphic_deferred(self): - email_users, users = (self.tables.email_users, - self.tables.users, - ) - - mapper(User, users, polymorphic_identity='user', - polymorphic_on=users.c.type) - mapper(EmailUser, email_users, inherits=User, - polymorphic_identity='emailuser') - - eu = EmailUser(name="user1", email_address='foo@bar.com') + email_users, users = (self.tables.email_users, self.tables.users) + + mapper( + User, + users, + polymorphic_identity="user", + polymorphic_on=users.c.type, + ) + mapper( + EmailUser, + email_users, + inherits=User, + polymorphic_identity="emailuser", + ) + + eu = EmailUser(name="user1", email_address="foo@bar.com") sess = create_session() sess.add(eu) sess.flush() @@ -546,8 +673,8 @@ class PolymorphicDeferredTest(fixtures.MappedTest): eu2 = 
pickle.loads(pickle.dumps(eu)) sess2 = create_session() sess2.add(eu2) - assert 'email_address' not in eu2.__dict__ - eq_(eu2.email_address, 'foo@bar.com') + assert "email_address" not in eu2.__dict__ + eq_(eu2.email_address, "foo@bar.com") class TupleLabelTest(_fixtures.FixtureTest): @@ -557,19 +684,28 @@ class TupleLabelTest(_fixtures.FixtureTest): @classmethod def setup_mappers(cls): - users, addresses, orders = (cls.tables.users, cls.tables.addresses, - cls.tables.orders) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref='user', - order_by=addresses.c.id), - # o2m, m2o - 'orders': relationship(Order, backref='user', - order_by=orders.c.id), - }) + users, addresses, orders = ( + cls.tables.users, + cls.tables.addresses, + cls.tables.orders, + ) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, backref="user", order_by=addresses.c.id + ), + # o2m, m2o + "orders": relationship( + Order, backref="user", order_by=orders.c.id + ), + }, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'address': relationship(Address), # m2o - }) + mapper( + Order, orders, properties={"address": relationship(Address)} # m2o + ) def test_tuple_labeling(self): users = self.tables.users @@ -581,22 +717,23 @@ class TupleLabelTest(_fixtures.FixtureTest): if pickled is not False: row = pickle.loads(pickle.dumps(row, pickled)) - eq_(list(row.keys()), ['User', 'Address']) + eq_(list(row.keys()), ["User", "Address"]) eq_(row.User, row[0]) eq_(row.Address, row[1]) - for row in sess.query(User.name, User.id.label('foobar')): + for row in sess.query(User.name, User.id.label("foobar")): if pickled is not False: row = pickle.loads(pickle.dumps(row, pickled)) - eq_(list(row.keys()), ['name', 'foobar']) + eq_(list(row.keys()), ["name", "foobar"]) eq_(row.name, row[0]) eq_(row.foobar, row[1]) - for row in sess.query(User).values(User.name, - User.id.label('foobar')): + for row in sess.query(User).values( + User.name, User.id.label("foobar") + ): if pickled is not False: row = pickle.loads(pickle.dumps(row, pickled)) - eq_(list(row.keys()), ['name', 'foobar']) + eq_(list(row.keys()), ["name", "foobar"]) eq_(row.name, row[0]) eq_(row.foobar, row[1]) @@ -604,23 +741,24 @@ class TupleLabelTest(_fixtures.FixtureTest): for row in sess.query(User, oalias).join(User.orders).all(): if pickled is not False: row = pickle.loads(pickle.dumps(row, pickled)) - eq_(list(row.keys()), ['User']) + eq_(list(row.keys()), ["User"]) eq_(row.User, row[0]) - oalias = aliased(Order, name='orders') - for row in sess.query(User, oalias).join(oalias, User.orders) \ - .all(): + oalias = aliased(Order, name="orders") + for row in ( + sess.query(User, oalias).join(oalias, User.orders).all() + ): if pickled is not False: row = pickle.loads(pickle.dumps(row, pickled)) - eq_(list(row.keys()), ['User', 'orders']) + eq_(list(row.keys()), ["User", "orders"]) eq_(row.User, row[0]) eq_(row.orders, row[1]) # test here that first col is not labeled, only # one name in keys, matches correctly - for row in sess.query(User.name + 'hoho', User.name): - eq_(list(row.keys()), ['name']) - eq_(row[0], row.name + 'hoho') + for row in sess.query(User.name + "hoho", User.name): + eq_(list(row.keys()), ["name"]) + eq_(row[0], row.name + "hoho") if pickled is not False: ret = sess.query(User, Address).join(User.addresses).all() @@ -630,20 +768,28 @@ class TupleLabelTest(_fixtures.FixtureTest): class CustomSetupTeardownTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - 
Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False), - test_needs_acid=True, - test_needs_fk=True) - - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False), - test_needs_acid=True, - test_needs_fk=True) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) + + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + test_needs_acid=True, + test_needs_fk=True, + ) def test_rebuild_state(self): """not much of a 'test', but illustrate how to diff --git a/test/orm/test_query.py b/test/orm/test_query.py index f287323604..3f975516f5 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1,20 +1,62 @@ from sqlalchemy import ( - testing, null, exists, text, union, literal, literal_column, func, between, - Unicode, desc, and_, bindparam, select, distinct, or_, collate, insert, - Integer, String, Boolean, exc as sa_exc, util, cast, MetaData, ForeignKey) + testing, + null, + exists, + text, + union, + literal, + literal_column, + func, + between, + Unicode, + desc, + and_, + bindparam, + select, + distinct, + or_, + collate, + insert, + Integer, + String, + Boolean, + exc as sa_exc, + util, + cast, + MetaData, + ForeignKey, +) from sqlalchemy.sql import operators, expression from sqlalchemy import column, table from sqlalchemy.engine import default from sqlalchemy.orm import ( - attributes, mapper, relationship, create_session, synonym, Session, - aliased, column_property, joinedload_all, joinedload, Query, Bundle, - subqueryload, backref, lazyload, defer) + attributes, + mapper, + relationship, + create_session, + synonym, + Session, + aliased, + column_property, + joinedload_all, + joinedload, + Query, + Bundle, + subqueryload, + backref, + lazyload, + defer, +) from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.schema import Table, Column import sqlalchemy as sa from sqlalchemy.testing.assertions import ( - eq_, assert_raises, assert_raises_message, expect_warnings, - eq_ignore_whitespace) + eq_, + assert_raises, + assert_raises_message, + expect_warnings, + eq_ignore_whitespace, +) from sqlalchemy.testing import fixtures, AssertsCompiledSQL, assert_warnings from test.orm import _fixtures from sqlalchemy.orm.util import join, with_parent @@ -24,8 +66,8 @@ from sqlalchemy import inspect class QueryTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod @@ -60,12 +102,22 @@ class OnlyReturnTuplesTest(QueryTest): def test_multiple_entity_false(self): User = self.classes.User - row = create_session().query(User.id, User).only_return_tuples(False).first() + row = ( + create_session() + .query(User.id, User) + .only_return_tuples(False) + .first() + ) assert isinstance(row, tuple) def test_multiple_entity_true(self): User = self.classes.User - row = create_session().query(User.id, User).only_return_tuples(True).first() + row = ( + create_session() + .query(User.id, User) + .only_return_tuples(True) 
+ .first() + ) assert isinstance(row, tuple) @@ -75,139 +127,198 @@ class RowTupleTest(QueryTest): def test_custom_names(self): User, users = self.classes.User, self.tables.users - mapper(User, users, properties={'uname': users.c.name}) + mapper(User, users, properties={"uname": users.c.name}) - row = create_session().query(User.id, User.uname).\ - filter(User.id == 7).first() + row = ( + create_session() + .query(User.id, User.uname) + .filter(User.id == 7) + .first() + ) assert row.id == 7 - assert row.uname == 'jack' + assert row.uname == "jack" def test_column_metadata(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) mapper(Address, addresses) sess = create_session() user_alias = aliased(User) - user_alias_id_label = user_alias.id.label('foo') - address_alias = aliased(Address, name='aalias') + user_alias_id_label = user_alias.id.label("foo") + address_alias = aliased(Address, name="aalias") fn = func.count(User.id) - name_label = User.name.label('uname') - bundle = Bundle('b1', User.id, User.name) + name_label = User.name.label("uname") + bundle = Bundle("b1", User.id, User.name) cte = sess.query(User.id).cte() for q, asserted in [ ( sess.query(User), [ { - 'name': 'User', 'type': User, 'aliased': False, - 'expr': User, 'entity': User}] + "name": "User", + "type": User, + "aliased": False, + "expr": User, + "entity": User, + } + ], ), ( sess.query(User.id, User), [ { - 'name': 'id', 'type': users.c.id.type, - 'aliased': False, 'expr': User.id, 'entity': User}, + "name": "id", + "type": users.c.id.type, + "aliased": False, + "expr": User.id, + "entity": User, + }, { - 'name': 'User', 'type': User, 'aliased': False, - 'expr': User, 'entity': User} - ] + "name": "User", + "type": User, + "aliased": False, + "expr": User, + "entity": User, + }, + ], ), ( sess.query(User.id, user_alias), [ { - 'name': 'id', 'type': users.c.id.type, - 'aliased': False, 'expr': User.id, 'entity': User}, + "name": "id", + "type": users.c.id.type, + "aliased": False, + "expr": User.id, + "entity": User, + }, { - 'name': None, 'type': User, 'aliased': True, - 'expr': user_alias, 'entity': user_alias} - ] + "name": None, + "type": User, + "aliased": True, + "expr": user_alias, + "entity": user_alias, + }, + ], ), ( sess.query(user_alias.id), [ { - 'name': 'id', 'type': users.c.id.type, - 'aliased': True, 'expr': user_alias.id, - 'entity': user_alias}, - ] + "name": "id", + "type": users.c.id.type, + "aliased": True, + "expr": user_alias.id, + "entity": user_alias, + } + ], ), ( sess.query(user_alias_id_label), [ { - 'name': 'foo', 'type': users.c.id.type, - 'aliased': True, 'expr': user_alias_id_label, - 'entity': user_alias}, - ] + "name": "foo", + "type": users.c.id.type, + "aliased": True, + "expr": user_alias_id_label, + "entity": user_alias, + } + ], ), ( sess.query(address_alias), [ { - 'name': 'aalias', 'type': Address, 'aliased': True, - 'expr': address_alias, 'entity': address_alias} - ] + "name": "aalias", + "type": Address, + "aliased": True, + "expr": address_alias, + "entity": address_alias, + } + ], ), ( sess.query(name_label, fn), [ { - 'name': 'uname', 'type': users.c.name.type, - 'aliased': False, 'expr': name_label, 'entity': User}, + "name": "uname", + "type": users.c.name.type, + "aliased": False, + "expr": name_label, + "entity": User, + }, { - 'name': None, 
'type': fn.type, 'aliased': False, - 'expr': fn, 'entity': User}, - ] + "name": None, + "type": fn.type, + "aliased": False, + "expr": fn, + "entity": User, + }, + ], ), ( sess.query(cte), [ - { - 'aliased': False, - 'expr': cte.c.id, 'type': cte.c.id.type, - 'name': 'id', 'entity': None - }] + { + "aliased": False, + "expr": cte.c.id, + "type": cte.c.id.type, + "name": "id", + "entity": None, + } + ], ), ( sess.query(users), [ - {'aliased': False, - 'expr': users.c.id, 'type': users.c.id.type, - 'name': 'id', 'entity': None}, - {'aliased': False, - 'expr': users.c.name, 'type': users.c.name.type, - 'name': 'name', 'entity': None} - ] + { + "aliased": False, + "expr": users.c.id, + "type": users.c.id.type, + "name": "id", + "entity": None, + }, + { + "aliased": False, + "expr": users.c.name, + "type": users.c.name.type, + "name": "name", + "entity": None, + }, + ], ), ( sess.query(users.c.name), - [{ - "name": "name", "type": users.c.name.type, - "aliased": False, "expr": users.c.name, "entity": None - }] + [ + { + "name": "name", + "type": users.c.name.type, + "aliased": False, + "expr": users.c.name, + "entity": None, + } + ], ), ( sess.query(bundle), [ { - 'aliased': False, - 'expr': bundle, - 'type': Bundle, - 'name': 'b1', 'entity': User + "aliased": False, + "expr": bundle, + "type": Bundle, + "name": "b1", + "entity": User, } - ] - ) + ], + ), ]: - eq_( - q.column_descriptions, - asserted - ) + eq_(q.column_descriptions, asserted) def test_unhashable_type(self): from sqlalchemy.types import TypeDecorator, Integer @@ -225,12 +336,11 @@ class RowTupleTest(QueryTest): mapper(User, users) s = Session() - q = s.query(User, type_coerce(users.c.id, MyType).label('foo')).\ - filter(User.id == 7) - row = q.first() - eq_( - row, (User(id=7), [7]) + q = s.query(User, type_coerce(users.c.id, MyType).label("foo")).filter( + User.id == 7 ) + row = q.first() + eq_(row, (User(id=7), [7])) class BindSensitiveStringifyTest(fixtures.TestBase): @@ -243,23 +353,26 @@ class BindSensitiveStringifyTest(fixtures.TestBase): m = MetaData(bind=bind_to) user_table = Table( - 'users', m, - Column('id', Integer, primary_key=True), - Column('name', String(50))) + "users", + m, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) mapper(User, user_table) return User def _dialect_fixture(self): class MyDialect(default.DefaultDialect): - default_paramstyle = 'qmark' + default_paramstyle = "qmark" from sqlalchemy.engine import base + return base.Engine(mock.Mock(), MyDialect(), mock.Mock()) def _test( - self, bound_metadata, bound_session, - session_present, expect_bound): + self, bound_metadata, bound_session, session_present, expect_bound + ): if bound_metadata or bound_session: eng = self._dialect_fixture() else: @@ -275,9 +388,10 @@ class BindSensitiveStringifyTest(fixtures.TestBase): eq_ignore_whitespace( str(q), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = ?" if expect_bound else - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1" + "FROM users WHERE users.id = ?" 
+ if expect_bound + else "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1", ) def test_query_unbound_metadata_bound_session(self): @@ -297,14 +411,13 @@ class BindSensitiveStringifyTest(fixtures.TestBase): class RawSelectTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_select_from_entity(self): User = self.classes.User self.assert_compile( - select(['*']).select_from(User), - "SELECT * FROM users" + select(["*"]).select_from(User), "SELECT * FROM users" ) def test_where_relationship(self): @@ -313,7 +426,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select([User]).where(User.addresses), "SELECT users.id, users.name FROM users, addresses " - "WHERE users.id = addresses.user_id" + "WHERE users.id = addresses.user_id", ) def test_where_m2m_relationship(self): @@ -324,23 +437,21 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): "SELECT items.id, items.description FROM items, " "item_keywords AS item_keywords_1, keywords " "WHERE items.id = item_keywords_1.item_id " - "AND keywords.id = item_keywords_1.keyword_id" + "AND keywords.id = item_keywords_1.keyword_id", ) def test_inline_select_from_entity(self): User = self.classes.User self.assert_compile( - select(['*'], from_obj=User), - "SELECT * FROM users" + select(["*"], from_obj=User), "SELECT * FROM users" ) def test_select_from_aliased_entity(self): User = self.classes.User ua = aliased(User, name="ua") self.assert_compile( - select(['*']).select_from(ua), - "SELECT * FROM users AS ua" + select(["*"]).select_from(ua), "SELECT * FROM users AS ua" ) def test_correlate_entity(self): @@ -350,14 +461,18 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select( [ - User.name, Address.id, - select([func.count(Address.id)]). - where(User.id == Address.user_id). - correlate(User).as_scalar()]), + User.name, + Address.id, + select([func.count(Address.id)]) + .where(User.id == Address.user_id) + .correlate(User) + .as_scalar(), + ] + ), "SELECT users.name, addresses.id, " "(SELECT count(addresses.id) AS count_1 " "FROM addresses WHERE users.id = addresses.user_id) AS anon_1 " - "FROM users, addresses" + "FROM users, addresses", ) def test_correlate_aliased_entity(self): @@ -368,10 +483,14 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select( [ - uu.name, Address.id, - select([func.count(Address.id)]). - where(uu.id == Address.user_id). 
- correlate(uu).as_scalar()]), + uu.name, + Address.id, + select([func.count(Address.id)]) + .where(uu.id == Address.user_id) + .correlate(uu) + .as_scalar(), + ] + ), # for a long time, "uu.id = address.user_id" was reversed; # this was resolved as of #2872 and had to do with # InstrumentedAttribute.__eq__() taking precedence over @@ -379,15 +498,14 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): "SELECT uu.name, addresses.id, " "(SELECT count(addresses.id) AS count_1 " "FROM addresses WHERE uu.id = addresses.user_id) AS anon_1 " - "FROM users AS uu, addresses" + "FROM users AS uu, addresses", ) def test_columns_clause_entity(self): User = self.classes.User self.assert_compile( - select([User]), - "SELECT users.id, users.name FROM users" + select([User]), "SELECT users.id, users.name FROM users" ) def test_columns_clause_columns(self): @@ -395,33 +513,32 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select([User.id, User.name]), - "SELECT users.id, users.name FROM users" + "SELECT users.id, users.name FROM users", ) def test_columns_clause_aliased_columns(self): User = self.classes.User - ua = aliased(User, name='ua') + ua = aliased(User, name="ua") self.assert_compile( - select([ua.id, ua.name]), - "SELECT ua.id, ua.name FROM users AS ua" + select([ua.id, ua.name]), "SELECT ua.id, ua.name FROM users AS ua" ) def test_columns_clause_aliased_entity(self): User = self.classes.User - ua = aliased(User, name='ua') + ua = aliased(User, name="ua") self.assert_compile( - select([ua]), - "SELECT ua.id, ua.name FROM users AS ua" + select([ua]), "SELECT ua.id, ua.name FROM users AS ua" ) def test_core_join(self): User = self.classes.User Address = self.classes.Address from sqlalchemy.sql import join + self.assert_compile( select([User]).select_from(join(User, Address)), "SELECT users.id, users.name FROM users " - "JOIN addresses ON users.id = addresses.user_id" + "JOIN addresses ON users.id = addresses.user_id", ) def test_insert_from_query(self): @@ -429,12 +546,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address s = Session() - q = s.query(User.id, User.name).filter_by(name='ed') + q = s.query(User.id, User.name).filter_by(name="ed") self.assert_compile( - insert(Address).from_select(('id', 'email_address'), q), + insert(Address).from_select(("id", "email_address"), q), "INSERT INTO addresses (id, email_address) " "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.name = :name_1" + "FROM users WHERE users.name = :name_1", ) def test_insert_from_query_col_attr(self): @@ -442,55 +559,54 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address s = Session() - q = s.query(User.id, User.name).filter_by(name='ed') + q = s.query(User.id, User.name).filter_by(name="ed") self.assert_compile( insert(Address).from_select( - (Address.id, Address.email_address), q), + (Address.id, Address.email_address), q + ), "INSERT INTO addresses (id, email_address) " "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.name = :name_1" + "FROM users WHERE users.name = :name_1", ) def test_update_from_entity(self): from sqlalchemy.sql import update + User = self.classes.User self.assert_compile( - update(User), - "UPDATE users SET id=:id, name=:name" + update(User), "UPDATE users SET id=:id, name=:name" ) self.assert_compile( - update(User).values(name='ed').where(User.id == 5), + update(User).values(name="ed").where(User.id == 5), "UPDATE users SET name=:name 
WHERE users.id = :id_1", - checkparams={"id_1": 5, "name": "ed"} + checkparams={"id_1": 5, "name": "ed"}, ) def test_delete_from_entity(self): from sqlalchemy.sql import delete + User = self.classes.User - self.assert_compile( - delete(User), - "DELETE FROM users" - ) + self.assert_compile(delete(User), "DELETE FROM users") self.assert_compile( delete(User).where(User.id == 5), "DELETE FROM users WHERE users.id = :id_1", - checkparams={"id_1": 5} + checkparams={"id_1": 5}, ) def test_insert_from_entity(self): from sqlalchemy.sql import insert + User = self.classes.User self.assert_compile( - insert(User), - "INSERT INTO users (id, name) VALUES (:id, :name)" + insert(User), "INSERT INTO users (id, name) VALUES (:id, :name)" ) self.assert_compile( insert(User).values(name="ed"), "INSERT INTO users (name) VALUES (:name)", - checkparams={"name": "ed"} + checkparams={"name": "ed"}, ) def test_col_prop_builtin_function(self): @@ -498,16 +614,20 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): pass mapper( - Foo, self.tables.users, properties={ - 'foob': column_property( - func.coalesce(self.tables.users.c.name)) - }) + Foo, + self.tables.users, + properties={ + "foob": column_property( + func.coalesce(self.tables.users.c.name) + ) + }, + ) self.assert_compile( - select([Foo]).where(Foo.foob == 'somename').order_by(Foo.foob), + select([Foo]).where(Foo.foob == "somename").order_by(Foo.foob), "SELECT users.id, users.name FROM users " "WHERE coalesce(users.name) = :param_1 " - "ORDER BY coalesce(users.name)" + "ORDER BY coalesce(users.name)", ) @@ -565,7 +685,7 @@ class GetTest(QueryTest): s = Session() q = s.query(User.id) - assert_raises(sa_exc.InvalidRequestError, q.get, (5, )) + assert_raises(sa_exc.InvalidRequestError, q.get, (5,)) def test_get_null_pk(self): """test that a mapping which can have None in a @@ -579,10 +699,13 @@ class GetTest(QueryTest): pass mapper( - UserThing, s, properties={ - 'id': (users.c.id, addresses.c.user_id), - 'address_id': addresses.c.id, - }) + UserThing, + s, + properties={ + "id": (users.c.id, addresses.c.user_id), + "address_id": addresses.c.id, + }, + ) sess = create_session() u10 = sess.query(UserThing).get((10, None)) eq_(u10, UserThing(id=10)) @@ -595,11 +718,13 @@ class GetTest(QueryTest): s = create_session() - q = s.query(User).join('addresses').filter(Address.user_id == 8) + q = s.query(User).join("addresses").filter(Address.user_id == 8) assert_raises(sa_exc.InvalidRequestError, q.get, 7) assert_raises( sa_exc.InvalidRequestError, - s.query(User).filter(User.id == 7).get, 19) + s.query(User).filter(User.id == 7).get, + 19, + ) # order_by()/get() doesn't raise s.query(User).order_by(User.id).get(8) @@ -614,7 +739,7 @@ class GetTest(QueryTest): s.query(User).get(7) - q = s.query(User).join('addresses').filter(Address.user_id == 8) + q = s.query(User).join("addresses").filter(Address.user_id == 8) assert_raises(sa_exc.InvalidRequestError, q.get, 7) def test_unique_param_names(self): @@ -622,12 +747,13 @@ class GetTest(QueryTest): class SomeUser(object): pass - s = users.select(users.c.id != 12).alias('users') + + s = users.select(users.c.id != 12).alias("users") m = mapper(SomeUser, s) assert s.primary_key == m.primary_key sess = create_session() - assert sess.query(SomeUser).get(7).name == 'jack' + assert sess.query(SomeUser).get(7).name == "jack" def test_load(self): User, Address = self.classes.User, self.classes.Address @@ -643,15 +769,15 @@ class GetTest(QueryTest): u2 = s.query(User).populate_existing().get(7) assert u is not u2 - u2.name = 
'some name' - a = Address(email_address='some other name') + u2.name = "some name" + a = Address(email_address="some other name") u2.addresses.append(a) assert u2 in s.dirty assert a in u2.addresses s.query(User).populate_existing().get(7) assert u2 not in s.dirty - assert u2.name == 'jack' + assert u2.name == "jack" assert a not in u2.addresses @testing.provide_metadata @@ -663,21 +789,24 @@ class GetTest(QueryTest): metadata = self.metadata table = Table( - 'unicode_data', metadata, - Column( - 'id', Unicode(40), primary_key=True), - Column('data', Unicode(40))) + "unicode_data", + metadata, + Column("id", Unicode(40), primary_key=True), + Column("data", Unicode(40)), + ) metadata.create_all() - ustring = util.b('petit voix m\xe2\x80\x99a').decode('utf-8') + ustring = util.b("petit voix m\xe2\x80\x99a").decode("utf-8") table.insert().execute(id=ustring, data=ustring) class LocalFoo(self.classes.Base): pass + mapper(LocalFoo, table) eq_( create_session().query(LocalFoo).get(ustring), - LocalFoo(id=ustring, data=ustring)) + LocalFoo(id=ustring, data=ustring), + ) def test_populate_existing(self): User, Address = self.classes.User, self.classes.Address @@ -687,8 +816,8 @@ class GetTest(QueryTest): userlist = s.query(User).all() u = userlist[0] - u.name = 'foo' - a = Address(name='ed') + u.name = "foo" + a = Address(name="ed") u.addresses.append(a) self.assert_(a in u.addresses) @@ -697,23 +826,23 @@ class GetTest(QueryTest): self.assert_(u not in s.dirty) - self.assert_(u.name == 'jack') + self.assert_(u.name == "jack") self.assert_(a not in u.addresses) - u.addresses[0].email_address = 'lala' - u.orders[1].items[2].description = 'item 12' + u.addresses[0].email_address = "lala" + u.orders[1].items[2].description = "item 12" # test that lazy load doesn't change child items s.query(User).populate_existing().all() - assert u.addresses[0].email_address == 'lala' - assert u.orders[1].items[2].description == 'item 12' + assert u.addresses[0].email_address == "lala" + assert u.orders[1].items[2].description == "item 12" # eager load does - s.query(User). \ - options(joinedload('addresses'), joinedload_all('orders.items')). 
\ - populate_existing().all() - assert u.addresses[0].email_address == 'jack@bean.com' - assert u.orders[1].items[2].description == 'item 5' + s.query(User).options( + joinedload("addresses"), joinedload_all("orders.items") + ).populate_existing().all() + assert u.addresses[0].email_address == "jack@bean.com" + assert u.orders[1].items[2].description == "item 5" class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): @@ -725,25 +854,26 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): for q in ( s.query(User).limit(2), s.query(User).offset(2), - s.query(User).limit(2).offset(2) + s.query(User).limit(2).offset(2), ): assert_raises(sa_exc.InvalidRequestError, q.join, "addresses") assert_raises( - sa_exc.InvalidRequestError, q.filter, User.name == 'ed') + sa_exc.InvalidRequestError, q.filter, User.name == "ed" + ) - assert_raises(sa_exc.InvalidRequestError, q.filter_by, name='ed') + assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed") - assert_raises(sa_exc.InvalidRequestError, q.order_by, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo") - assert_raises(sa_exc.InvalidRequestError, q.group_by, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo") - assert_raises(sa_exc.InvalidRequestError, q.having, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.having, "foo") q.enable_assertions(False).join("addresses") - q.enable_assertions(False).filter(User.name == 'ed') - q.enable_assertions(False).order_by('foo') - q.enable_assertions(False).group_by('foo') + q.enable_assertions(False).filter(User.name == "ed") + q.enable_assertions(False).order_by("foo") + q.enable_assertions(False).group_by("foo") def test_no_from(self): users, User = self.tables.users, self.classes.User @@ -753,7 +883,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): q = s.query(User).select_from(users) assert_raises(sa_exc.InvalidRequestError, q.select_from, users) - q = s.query(User).join('addresses') + q = s.query(User).join("addresses") assert_raises(sa_exc.InvalidRequestError, q.select_from, users) q = s.query(User).order_by(User.id) @@ -775,15 +905,18 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): assert_raises(sa_exc.ArgumentError, q.select_from, User.id) def test_invalid_from_statement(self): - User, addresses, users = (self.classes.User, - self.tables.addresses, - self.tables.users) + User, addresses, users = ( + self.classes.User, + self.tables.addresses, + self.tables.users, + ) s = create_session() q = s.query(User) assert_raises(sa_exc.ArgumentError, q.from_statement, User.id == 5) assert_raises( - sa_exc.ArgumentError, q.from_statement, users.join(addresses)) + sa_exc.ArgumentError, q.from_statement, users.join(addresses) + ) def test_invalid_column(self): User = self.classes.User @@ -809,8 +942,10 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): q = s.query(User).distinct() assert_raises(sa_exc.InvalidRequestError, q.select_from, User) assert_raises( - sa_exc.InvalidRequestError, q.from_statement, - text("select * from table")) + sa_exc.InvalidRequestError, + q.from_statement, + text("select * from table"), + ) assert_raises(sa_exc.InvalidRequestError, q.with_polymorphic, User) def test_order_by(self): @@ -823,8 +958,10 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): q = s.query(User).order_by(User.id) assert_raises(sa_exc.InvalidRequestError, q.select_from, User) assert_raises( - sa_exc.InvalidRequestError, q.from_statement, - text("select * from table")) + 
sa_exc.InvalidRequestError, + q.from_statement, + text("select * from table"), + ) assert_raises(sa_exc.InvalidRequestError, q.with_polymorphic, User) def test_only_full_mapper_zero(self): @@ -865,7 +1002,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): is_(q._mapper_zero(), None) is_(q._entity_zero(), None) - q1 = s.query(Bundle('b1', User.id, User.name)) + q1 = s.query(Bundle("b1", User.id, User.name)) is_(q1._mapper_zero(), inspect(User)) is_(q1._entity_zero(), inspect(User)) @@ -876,24 +1013,20 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): for meth, arg, kw in [ (Query.filter, (User.id == 5,), {}), - (Query.filter_by, (), {'id': 5}), - (Query.limit, (5, ), {}), + (Query.filter_by, (), {"id": 5}), + (Query.limit, (5,), {}), (Query.group_by, (User.name,), {}), - (Query.order_by, (User.name,), {}) + (Query.order_by, (User.name,), {}), ]: q = s.query(User) q = meth(q, *arg, **kw) assert_raises( - sa_exc.InvalidRequestError, - q.from_statement, text("x") + sa_exc.InvalidRequestError, q.from_statement, text("x") ) q = s.query(User) q = q.from_statement(text("x")) - assert_raises( - sa_exc.InvalidRequestError, - meth, q, *arg, **kw - ) + assert_raises(sa_exc.InvalidRequestError, meth, q, *arg, **kw) def test_illegal_coercions(self): User = self.classes.User @@ -901,41 +1034,44 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.ArgumentError, "Object .*User.* is not legal as a SQL literal value", - distinct, User + distinct, + User, ) ua = aliased(User) assert_raises_message( sa_exc.ArgumentError, "Object .*User.* is not legal as a SQL literal value", - distinct, ua + distinct, + ua, ) s = Session() assert_raises_message( sa_exc.ArgumentError, "Object .*User.* is not legal as a SQL literal value", - lambda: s.query(User).filter(User.name == User) + lambda: s.query(User).filter(User.name == User), ) u1 = User() assert_raises_message( sa_exc.ArgumentError, "Object .*User.* is not legal as a SQL literal value", - distinct, u1 + distinct, + u1, ) assert_raises_message( sa_exc.ArgumentError, "Object .*User.* is not legal as a SQL literal value", - lambda: s.query(User).filter(User.name == u1) + lambda: s.query(User).filter(User.name == u1), ) class OperatorTest(QueryTest, AssertsCompiledSQL): """test sql.Comparator implementation for MapperProperties""" - __dialect__ = 'default' + __dialect__ = "default" def _test(self, clause, expected, entity=None, checkparams=None): dialect = default.DefaultDialect() @@ -952,8 +1088,8 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self.assert_compile(clause, expected, checkparams=checkparams) def _test_filter_aliases( - self, - clause, expected, from_, onclause, checkparams=None): + self, clause, expected, from_, onclause, checkparams=None + ): dialect = default.DefaultDialect() sess = Session() lead = sess.query(from_).join(onclause, aliased=True) @@ -969,20 +1105,22 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): User = self.classes.User create_session().query(User) - for (py_op, sql_op) in ((operators.add, '+'), (operators.mul, '*'), - (operators.sub, '-'), - (operators.truediv, '/'), - (operators.div, '/'), - ): + for (py_op, sql_op) in ( + (operators.add, "+"), + (operators.mul, "*"), + (operators.sub, "-"), + (operators.truediv, "/"), + (operators.div, "/"), + ): for (lhs, rhs, res) in ( - (5, User.id, ':id_1 %s users.id'), - (5, literal(6), ':param_1 %s :param_2'), - (User.id, 5, 'users.id %s :id_1'), - (User.id, literal('b'), 'users.id %s :param_1'), - (User.id, 
User.id, 'users.id %s users.id'), - (literal(5), 'b', ':param_1 %s :param_2'), - (literal(5), User.id, ':param_1 %s users.id'), - (literal(5), literal(6), ':param_1 %s :param_2'), + (5, User.id, ":id_1 %s users.id"), + (5, literal(6), ":param_1 %s :param_2"), + (User.id, 5, "users.id %s :id_1"), + (User.id, literal("b"), "users.id %s :param_1"), + (User.id, User.id, "users.id %s users.id"), + (literal(5), "b", ":param_1 %s :param_2"), + (literal(5), User.id, ":param_1 %s users.id"), + (literal(5), literal(6), ":param_1 %s :param_2"), ): self._test(py_op(lhs, rhs), res % sql_op) @@ -992,37 +1130,47 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): create_session().query(User) ualias = aliased(User) - for (py_op, fwd_op, rev_op) in ((operators.lt, '<', '>'), - (operators.gt, '>', '<'), - (operators.eq, '=', '='), - (operators.ne, '!=', '!='), - (operators.le, '<=', '>='), - (operators.ge, '>=', '<=')): + for (py_op, fwd_op, rev_op) in ( + (operators.lt, "<", ">"), + (operators.gt, ">", "<"), + (operators.eq, "=", "="), + (operators.ne, "!=", "!="), + (operators.le, "<=", ">="), + (operators.ge, ">=", "<="), + ): for (lhs, rhs, l_sql, r_sql) in ( - ('a', User.id, ':id_1', 'users.id'), - ('a', literal('b'), ':param_2', ':param_1'), # note swap! - (User.id, 'b', 'users.id', ':id_1'), - (User.id, literal('b'), 'users.id', ':param_1'), - (User.id, User.id, 'users.id', 'users.id'), - (literal('a'), 'b', ':param_1', ':param_2'), - (literal('a'), User.id, ':param_1', 'users.id'), - (literal('a'), literal('b'), ':param_1', ':param_2'), - (ualias.id, literal('b'), 'users_1.id', ':param_1'), - (User.id, ualias.name, 'users.id', 'users_1.name'), - (User.name, ualias.name, 'users.name', 'users_1.name'), - (ualias.name, User.name, 'users_1.name', 'users.name'), + ("a", User.id, ":id_1", "users.id"), + ("a", literal("b"), ":param_2", ":param_1"), # note swap! + (User.id, "b", "users.id", ":id_1"), + (User.id, literal("b"), "users.id", ":param_1"), + (User.id, User.id, "users.id", "users.id"), + (literal("a"), "b", ":param_1", ":param_2"), + (literal("a"), User.id, ":param_1", "users.id"), + (literal("a"), literal("b"), ":param_1", ":param_2"), + (ualias.id, literal("b"), "users_1.id", ":param_1"), + (User.id, ualias.name, "users.id", "users_1.name"), + (User.name, ualias.name, "users.name", "users_1.name"), + (ualias.name, User.name, "users_1.name", "users.name"), ): # the compiled clause should match either (e.g.): # 'a' < 'b' -or- 'b' > 'a'. 
- compiled = str(py_op(lhs, rhs).compile( - dialect=default.DefaultDialect())) + compiled = str( + py_op(lhs, rhs).compile(dialect=default.DefaultDialect()) + ) fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) - self.assert_(compiled == fwd_sql or compiled == rev_sql, - "\n'" + compiled + "'\n does not match\n'" + - fwd_sql + "'\n or\n'" + rev_sql + "'") + self.assert_( + compiled == fwd_sql or compiled == rev_sql, + "\n'" + + compiled + + "'\n does not match\n'" + + fwd_sql + + "'\n or\n'" + + rev_sql + + "'", + ) def test_o2m_compare_to_null(self): User = self.classes.User @@ -1037,43 +1185,70 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): def test_m2o_compare_to_null(self): Address = self.classes.Address self._test(Address.user == None, "addresses.user_id IS NULL") # noqa - self._test(~(Address.user == None), # noqa - "addresses.user_id IS NOT NULL") - self._test(~(Address.user != None), # noqa - "addresses.user_id IS NULL") + self._test( + ~(Address.user == None), "addresses.user_id IS NOT NULL" # noqa + ) + self._test( + ~(Address.user != None), "addresses.user_id IS NULL" # noqa + ) self._test(None == Address.user, "addresses.user_id IS NULL") # noqa - self._test(~(None == Address.user), # noqa - "addresses.user_id IS NOT NULL") + self._test( + ~(None == Address.user), "addresses.user_id IS NOT NULL" # noqa + ) def test_o2m_compare_to_null_orm_adapt(self): User, Address = self.classes.User, self.classes.Address self._test_filter_aliases( User.id == None, # noqa - "users_1.id IS NULL", Address, Address.user), + "users_1.id IS NULL", + Address, + Address.user, + ), self._test_filter_aliases( User.id != None, # noqa - "users_1.id IS NOT NULL", Address, Address.user), + "users_1.id IS NOT NULL", + Address, + Address.user, + ), self._test_filter_aliases( ~(User.id == None), # noqa - "users_1.id IS NOT NULL", Address, Address.user), + "users_1.id IS NOT NULL", + Address, + Address.user, + ), self._test_filter_aliases( ~(User.id != None), # noqa - "users_1.id IS NULL", Address, Address.user), + "users_1.id IS NULL", + Address, + Address.user, + ), def test_m2o_compare_to_null_orm_adapt(self): User, Address = self.classes.User, self.classes.Address self._test_filter_aliases( Address.user == None, # noqa - "addresses_1.user_id IS NULL", User, User.addresses), + "addresses_1.user_id IS NULL", + User, + User.addresses, + ), self._test_filter_aliases( Address.user != None, # noqa - "addresses_1.user_id IS NOT NULL", User, User.addresses), + "addresses_1.user_id IS NOT NULL", + User, + User.addresses, + ), self._test_filter_aliases( ~(Address.user == None), # noqa - "addresses_1.user_id IS NOT NULL", User, User.addresses), + "addresses_1.user_id IS NOT NULL", + User, + User.addresses, + ), self._test_filter_aliases( ~(Address.user != None), # noqa - "addresses_1.user_id IS NULL", User, User.addresses), + "addresses_1.user_id IS NULL", + User, + User.addresses, + ), def test_o2m_compare_to_null_aliased(self): User = self.classes.User @@ -1087,8 +1262,9 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address a1 = aliased(Address) self._test(a1.user == None, "addresses_1.user_id IS NULL") # noqa - self._test(~(a1.user == None), # noqa - "addresses_1.user_id IS NOT NULL") + self._test( + ~(a1.user == None), "addresses_1.user_id IS NOT NULL" # noqa + ) self._test(a1.user != None, "addresses_1.user_id IS NOT NULL") # noqa self._test(~(a1.user != None), "addresses_1.user_id IS NULL") # noqa @@ -1108,7 +1284,7 @@ class 
OperatorTest(QueryTest, AssertsCompiledSQL): User.addresses.any(Address.id == 17), "EXISTS (SELECT 1 FROM addresses " "WHERE users.id = addresses.user_id AND addresses.id = :id_1)", - entity=User + entity=User, ) def test_o2m_any_aliased(self): @@ -1120,7 +1296,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "EXISTS (SELECT 1 FROM addresses AS addresses_1 " "WHERE users_1.id = addresses_1.user_id AND " "addresses_1.id = :id_1)", - entity=u1 + entity=u1, ) def test_o2m_any_orm_adapt(self): @@ -1129,7 +1305,8 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): User.addresses.any(Address.id == 17), "EXISTS (SELECT 1 FROM addresses " "WHERE users_1.id = addresses.user_id AND addresses.id = :id_1)", - Address, Address.user + Address, + Address.user, ) def test_m2o_compare_instance(self): @@ -1149,7 +1326,8 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self._test( Address.user != u7, "addresses.user_id != :user_id_1 OR addresses.user_id IS NULL", - checkparams={'user_id_1': 7}) + checkparams={"user_id_1": 7}, + ) def test_m2o_compare_instance_orm_adapt(self): User, Address = self.classes.User, self.classes.Address @@ -1159,8 +1337,10 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self._test_filter_aliases( Address.user == u7, - ":param_1 = addresses_1.user_id", User, User.addresses, - checkparams={'param_1': 7} + ":param_1 = addresses_1.user_id", + User, + User.addresses, + checkparams={"param_1": 7}, ) def test_m2o_compare_instance_negated_warn_on_none(self): @@ -1173,8 +1353,9 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): Address.user != u7_transient, "addresses_1.user_id != :user_id_1 " "OR addresses_1.user_id IS NULL", - User, User.addresses, - checkparams={'user_id_1': None} + User, + User.addresses, + checkparams={"user_id_1": None}, ) def test_m2o_compare_instance_negated_orm_adapt(self): @@ -1188,41 +1369,51 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self._test_filter_aliases( Address.user != u7, "addresses_1.user_id != :user_id_1 OR addresses_1.user_id IS NULL", - User, User.addresses, - checkparams={'user_id_1': 7} + User, + User.addresses, + checkparams={"user_id_1": 7}, ) self._test_filter_aliases( - ~(Address.user == u7), ":param_1 != addresses_1.user_id", - User, User.addresses, - checkparams={'param_1': 7} + ~(Address.user == u7), + ":param_1 != addresses_1.user_id", + User, + User.addresses, + checkparams={"param_1": 7}, ) self._test_filter_aliases( ~(Address.user != u7), "NOT (addresses_1.user_id != :user_id_1 " - "OR addresses_1.user_id IS NULL)", User, User.addresses, - checkparams={'user_id_1': 7} + "OR addresses_1.user_id IS NULL)", + User, + User.addresses, + checkparams={"user_id_1": 7}, ) self._test_filter_aliases( Address.user != u7_transient, "addresses_1.user_id != :user_id_1 OR addresses_1.user_id IS NULL", - User, User.addresses, - checkparams={'user_id_1': 7} + User, + User.addresses, + checkparams={"user_id_1": 7}, ) self._test_filter_aliases( - ~(Address.user == u7_transient), ":param_1 != addresses_1.user_id", - User, User.addresses, - checkparams={'param_1': 7} + ~(Address.user == u7_transient), + ":param_1 != addresses_1.user_id", + User, + User.addresses, + checkparams={"param_1": 7}, ) self._test_filter_aliases( ~(Address.user != u7_transient), "NOT (addresses_1.user_id != :user_id_1 " - "OR addresses_1.user_id IS NULL)", User, User.addresses, - checkparams={'user_id_1': 7} + "OR addresses_1.user_id IS NULL)", + User, + User.addresses, + checkparams={"user_id_1": 7}, ) def test_m2o_compare_instance_aliased(self): @@ 
-1237,23 +1428,27 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self._test( a1.user == u7, ":param_1 = addresses_1.user_id", - checkparams={'param_1': 7}) + checkparams={"param_1": 7}, + ) self._test( a1.user != u7, "addresses_1.user_id != :user_id_1 OR addresses_1.user_id IS NULL", - checkparams={'user_id_1': 7}) + checkparams={"user_id_1": 7}, + ) a1 = aliased(Address) self._test( a1.user == u7_transient, ":param_1 = addresses_1.user_id", - checkparams={'param_1': 7}) + checkparams={"param_1": 7}, + ) self._test( a1.user != u7_transient, "addresses_1.user_id != :user_id_1 OR addresses_1.user_id IS NULL", - checkparams={'user_id_1': 7}) + checkparams={"user_id_1": 7}, + ) def test_selfref_relationship(self): @@ -1263,11 +1458,11 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): # auto self-referential aliasing self._test( - Node.children.any(Node.data == 'n1'), + Node.children.any(Node.data == "n1"), "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)", entity=Node, - checkparams={'data_1': 'n1'} + checkparams={"data_1": "n1"}, ) # needs autoaliasing @@ -1276,25 +1471,25 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "NOT (EXISTS (SELECT 1 FROM nodes AS nodes_1 " "WHERE nodes.id = nodes_1.parent_id))", entity=Node, - checkparams={} + checkparams={}, ) self._test( Node.parent == None, # noqa "nodes.parent_id IS NULL", - checkparams={} + checkparams={}, ) self._test( nalias.parent == None, # noqa "nodes_1.parent_id IS NULL", - checkparams={} + checkparams={}, ) self._test( nalias.parent != None, # noqa "nodes_1.parent_id IS NOT NULL", - checkparams={} + checkparams={}, ) self._test( @@ -1302,15 +1497,15 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "NOT (EXISTS (" "SELECT 1 FROM nodes WHERE nodes_1.id = nodes.parent_id))", entity=nalias, - checkparams={} + checkparams={}, ) self._test( - nalias.children.any(Node.data == 'some data'), + nalias.children.any(Node.data == "some data"), "EXISTS (SELECT 1 FROM nodes WHERE " "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)", entity=nalias, - checkparams={'data_1': 'some data'} + checkparams={"data_1": "some data"}, ) # this fails because self-referential any() is auto-aliasing; @@ -1323,62 +1518,66 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): # ) self._test( - nalias.parent.has(Node.data == 'some data'), + nalias.parent.has(Node.data == "some data"), "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id " "AND nodes.data = :data_1)", entity=nalias, - checkparams={'data_1': 'some data'} + checkparams={"data_1": "some data"}, ) self._test( - Node.parent.has(Node.data == 'some data'), + Node.parent.has(Node.data == "some data"), "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " "nodes_1.id = nodes.parent_id AND nodes_1.data = :data_1)", entity=Node, - checkparams={'data_1': 'some data'} + checkparams={"data_1": "some data"}, ) self._test( Node.parent == Node(id=7), ":param_1 = nodes.parent_id", - checkparams={"param_1": 7} + checkparams={"param_1": 7}, ) self._test( nalias.parent == Node(id=7), ":param_1 = nodes_1.parent_id", - checkparams={"param_1": 7} + checkparams={"param_1": 7}, ) self._test( nalias.parent != Node(id=7), - 'nodes_1.parent_id != :parent_id_1 ' - 'OR nodes_1.parent_id IS NULL', - checkparams={"parent_id_1": 7} + "nodes_1.parent_id != :parent_id_1 " + "OR nodes_1.parent_id IS NULL", + checkparams={"parent_id_1": 7}, ) self._test( nalias.parent != Node(id=7), - 'nodes_1.parent_id != :parent_id_1 ' - 'OR nodes_1.parent_id IS 
NULL', - checkparams={"parent_id_1": 7} + "nodes_1.parent_id != :parent_id_1 " + "OR nodes_1.parent_id IS NULL", + checkparams={"parent_id_1": 7}, ) self._test( nalias.children.contains(Node(id=7, parent_id=12)), "nodes_1.id = :param_1", - checkparams={"param_1": 12} + checkparams={"param_1": 12}, ) def test_multilevel_any(self): - User, Address, Dingaling = \ - self.classes.User, self.classes.Address, self.classes.Dingaling + User, Address, Dingaling = ( + self.classes.User, + self.classes.Address, + self.classes.Dingaling, + ) sess = Session() q = sess.query(User).filter( User.addresses.any( - and_(Address.id == Dingaling.address_id, - Dingaling.data == 'x'))) + and_(Address.id == Dingaling.address_id, Dingaling.data == "x") + ) + ) # new since #2746 - correlate_except() now takes context into account # so its usage in any() is not as disrupting. self.assert_compile( @@ -1389,18 +1588,18 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "FROM addresses, dingalings " "WHERE users.id = addresses.user_id AND " "addresses.id = dingalings.address_id AND " - "dingalings.data = :data_1)" + "dingalings.data = :data_1)", ) def test_op(self): User = self.classes.User - self._test(User.name.op('ilike')('17'), "users.name ilike :name_1") + self._test(User.name.op("ilike")("17"), "users.name ilike :name_1") def test_in(self): User = self.classes.User - self._test(User.id.in_(['a', 'b']), "users.id IN (:id_1, :id_2)") + self._test(User.id.in_(["a", "b"]), "users.id IN (:id_1, :id_2)") def test_in_on_relationship_not_supported(self): User, Address = self.classes.User, self.classes.Address @@ -1417,14 +1616,15 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): User = self.classes.User self._test( - User.id.between('a', 'b'), "users.id BETWEEN :id_1 AND :id_2") + User.id.between("a", "b"), "users.id BETWEEN :id_1 AND :id_2" + ) def test_collate(self): User = self.classes.User - self._test(collate(User.id, 'utf8_bin'), "users.id COLLATE utf8_bin") + self._test(collate(User.id, "utf8_bin"), "users.id COLLATE utf8_bin") - self._test(User.id.collate('utf8_bin'), "users.id COLLATE utf8_bin") + self._test(User.id.collate("utf8_bin"), "users.id COLLATE utf8_bin") def test_selfref_between(self): User = self.classes.User @@ -1432,10 +1632,12 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): ualias = aliased(User) self._test( User.id.between(ualias.id, ualias.id), - "users.id BETWEEN users_1.id AND users_1.id") + "users.id BETWEEN users_1.id AND users_1.id", + ) self._test( ualias.id.between(User.id, User.id), - "users_1.id BETWEEN users.id AND users.id") + "users_1.id BETWEEN users.id AND users.id", + ) def test_clauses(self): User, Address = self.classes.User, self.classes.Address @@ -1443,8 +1645,10 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): for (expr, compare) in ( (func.max(User.id), "max(users.id)"), (User.id.desc(), "users.id DESC"), - (between(5, User.id, Address.id), - ":param_1 BETWEEN users.id AND addresses.id"), + ( + between(5, User.id, Address.id), + ":param_1 BETWEEN users.id AND addresses.id", + ), # this one would require adding compile() to # InstrumentedScalarAttribute. do we want this ? 
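# A minimal, self-contained sketch of what each (expr, compare) pair
# above asserts: str() on an ORM-instrumented expression compiles it
# with the default dialect.  The mapping below is a stand-in used only
# for illustration, not the fixture class this suite actually uses.
from sqlalchemy import Column, Integer, String, func
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


print(func.max(User.id))  # renders "max(users.id)"
print(User.id.desc())  # renders "users.id DESC"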
# (User.id, "users.id") @@ -1454,20 +1658,30 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): class ExpressionTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_deferred_instances(self): - User, addresses, Address = (self.classes.User, - self.tables.addresses, - self.classes.Address) + User, addresses, Address = ( + self.classes.User, + self.tables.addresses, + self.classes.Address, + ) session = create_session() - s = session.query(User).filter( - and_(addresses.c.email_address == bindparam('emailad'), - Address.user_id == User.id)).statement + s = ( + session.query(User) + .filter( + and_( + addresses.c.email_address == bindparam("emailad"), + Address.user_id == User.id, + ) + ) + .statement + ) result = list( - session.query(User).instances(s.execute(emailad='jack@bean.com'))) + session.query(User).instances(s.execute(emailad="jack@bean.com")) + ) eq_([User(id=7)], result) def test_aliased_sql_construct(self): @@ -1483,7 +1697,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): "addresses.id AS addresses_id, addresses.user_id AS " "addresses_user_id, addresses.email_address AS " "addresses_email_address FROM users JOIN addresses " - "ON users.id = addresses.user_id) AS anon_1" + "ON users.id = addresses.user_id) AS anon_1", ) def test_aliased_sql_construct_raises_adapt_on_names(self): @@ -1493,7 +1707,9 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.ArgumentError, "adapt_on_names only applies to ORM elements", - aliased, j, adapt_on_names=True + aliased, + j, + adapt_on_names=True, ) def test_scalar_subquery_compile_whereclause(self): @@ -1512,7 +1728,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): "AS addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE " "addresses.user_id = (SELECT users.id AS users_id " - "FROM users WHERE users.id = :id_1)" + "FROM users WHERE users.id = :id_1)", ) def test_subquery_no_eagerloads(self): @@ -1521,7 +1737,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.assert_compile( s.query(User).options(joinedload(User.addresses)).subquery(), - "SELECT users.id, users.name FROM users" + "SELECT users.id, users.name FROM users", ) def test_exists_no_eagerloads(self): @@ -1532,27 +1748,30 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): s.query( s.query(User).options(joinedload(User.addresses)).exists() ), - "SELECT EXISTS (SELECT 1 FROM users) AS anon_1" + "SELECT EXISTS (SELECT 1 FROM users) AS anon_1", ) def test_named_subquery(self): User = self.classes.User session = create_session() - a1 = session.query(User.id).filter(User.id == 7).subquery('foo1') - a2 = session.query(User.id).filter(User.id == 7).subquery(name='foo2') + a1 = session.query(User.id).filter(User.id == 7).subquery("foo1") + a2 = session.query(User.id).filter(User.id == 7).subquery(name="foo2") a3 = session.query(User.id).filter(User.id == 7).subquery() - eq_(a1.name, 'foo1') - eq_(a2.name, 'foo2') - eq_(a3.name, '%%(%d anon)s' % id(a3)) + eq_(a1.name, "foo1") + eq_(a2.name, "foo2") + eq_(a3.name, "%%(%d anon)s" % id(a3)) def test_labeled_subquery(self): User = self.classes.User session = create_session() - a1 = session.query(User.id).filter(User.id == 7). 
\ - subquery(with_labels=True) + a1 = ( + session.query(User.id) + .filter(User.id == 7) + .subquery(with_labels=True) + ) assert a1.c.users_id is not None def test_reduced_subquery(self): @@ -1560,22 +1779,27 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): ua = aliased(User) session = create_session() - a1 = session.query(User.id, ua.id, ua.name).\ - filter(User.id == ua.id).subquery(reduce_columns=True) - self.assert_compile(a1, - "SELECT users.id, users_1.name FROM " - "users, users AS users_1 " - "WHERE users.id = users_1.id") + a1 = ( + session.query(User.id, ua.id, ua.name) + .filter(User.id == ua.id) + .subquery(reduce_columns=True) + ) + self.assert_compile( + a1, + "SELECT users.id, users_1.name FROM " + "users, users AS users_1 " + "WHERE users.id = users_1.id", + ) def test_label(self): User = self.classes.User session = create_session() - q = session.query(User.id).filter(User.id == 7).label('foo') + q = session.query(User.id).filter(User.id == 7).label("foo") self.assert_compile( session.query(q), - "SELECT (SELECT users.id FROM users WHERE users.id = :id_1) AS foo" + "SELECT (SELECT users.id FROM users WHERE users.id = :id_1) AS foo", ) def test_as_scalar(self): @@ -1585,19 +1809,25 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): q = session.query(User.id).filter(User.id == 7).as_scalar() - self.assert_compile(session.query(User).filter(User.id.in_(q)), - 'SELECT users.id AS users_id, users.name ' - 'AS users_name FROM users WHERE users.id ' - 'IN (SELECT users.id FROM users WHERE ' - 'users.id = :id_1)') + self.assert_compile( + session.query(User).filter(User.id.in_(q)), + "SELECT users.id AS users_id, users.name " + "AS users_name FROM users WHERE users.id " + "IN (SELECT users.id FROM users WHERE " + "users.id = :id_1)", + ) def test_param_transfer(self): User = self.classes.User session = create_session() - q = session.query(User.id).filter(User.id == bindparam('foo')).\ - params(foo=7).subquery() + q = ( + session.query(User.id) + .filter(User.id == bindparam("foo")) + .params(foo=7) + .subquery() + ) q = session.query(User).filter(User.id.in_(q)) @@ -1607,8 +1837,12 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address session = create_session() - s = session.query(User.id).join(User.addresses).group_by(User.id).\ - having(func.count(Address.id) > 2) + s = ( + session.query(User.id) + .join(User.addresses) + .group_by(User.id) + .having(func.count(Address.id) > 2) + ) eq_(session.query(User).filter(User.id.in_(s)).all(), [User(id=8)]) def test_union(self): @@ -1616,12 +1850,13 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): s = create_session() - q1 = s.query(User).filter(User.name == 'ed').with_labels() - q2 = s.query(User).filter(User.name == 'fred').with_labels() + q1 = s.query(User).filter(User.name == "ed").with_labels() + q2 = s.query(User).filter(User.name == "fred").with_labels() eq_( - s.query(User).from_statement(union(q1, q2). 
- order_by('users_name')).all(), - [User(name='ed'), User(name='fred')] + s.query(User) + .from_statement(union(q1, q2).order_by("users_name")) + .all(), + [User(name="ed"), User(name="fred")], ) def test_select(self): @@ -1631,12 +1866,12 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): # this is actually not legal on most DBs since the subquery has no # alias - q1 = s.query(User).filter(User.name == 'ed') + q1 = s.query(User).filter(User.name == "ed") self.assert_compile( select([q1]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users WHERE users.name = :name_1)" + "users.name AS users_name FROM users WHERE users.name = :name_1)", ) def test_join(self): @@ -1646,15 +1881,19 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): # TODO: do we want aliased() to detect a query and convert to # subquery() automatically ? - q1 = s.query(Address).filter(Address.email_address == 'jack@bean.com') + q1 = s.query(Address).filter(Address.email_address == "jack@bean.com") adalias = aliased(Address, q1.subquery()) eq_( - s.query(User, adalias).join(adalias, User.id == adalias.user_id). - all(), + s.query(User, adalias) + .join(adalias, User.id == adalias.user_id) + .all(), [ ( - User(id=7, name='jack'), - Address(email_address='jack@bean.com', user_id=7, id=1))]) + User(id=7, name="jack"), + Address(email_address="jack@bean.com", user_id=7, id=1), + ) + ], + ) def test_group_by_plain(self): User = self.classes.User @@ -1664,7 +1903,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select([q1]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users GROUP BY users.name)" + "users.name AS users_name FROM users GROUP BY users.name)", ) def test_group_by_append(self): @@ -1678,7 +1917,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): select([q1.group_by(User.id)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " "users.name AS users_name FROM users " - "GROUP BY users.name, users.id)" + "GROUP BY users.name, users.id)", ) def test_group_by_cancellation(self): @@ -1690,14 +1929,14 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select([q1.group_by(None).group_by(User.id)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users GROUP BY users.id)" + "users.name AS users_name FROM users GROUP BY users.id)", ) # test cancellation by using None, replacement with nothing self.assert_compile( select([q1.group_by(None)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users)" + "users.name AS users_name FROM users)", ) def test_group_by_cancelled_still_present(self): @@ -1716,7 +1955,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select([q1]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users ORDER BY users.name)" + "users.name AS users_name FROM users ORDER BY users.name)", ) def test_order_by_append(self): @@ -1730,7 +1969,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): select([q1.order_by(User.id)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " "users.name AS users_name FROM users " - "ORDER BY users.name, users.id)" + "ORDER BY users.name, users.id)", ) def test_order_by_cancellation(self): @@ -1742,14 +1981,14 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.assert_compile( 
select([q1.order_by(None).order_by(User.id)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users ORDER BY users.id)" + "users.name AS users_name FROM users ORDER BY users.id)", ) # test cancellation by using None, replacement with nothing self.assert_compile( select([q1.order_by(None)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users)" + "users.name AS users_name FROM users)", ) def test_order_by_cancellation_false(self): @@ -1761,14 +2000,14 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select([q1.order_by(False).order_by(User.id)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users ORDER BY users.id)" + "users.name AS users_name FROM users ORDER BY users.id)", ) # test cancellation by using None, replacement with nothing self.assert_compile( select([q1.order_by(False)]), "SELECT users_id, users_name FROM (SELECT users.id AS users_id, " - "users.name AS users_name FROM users)" + "users.name AS users_name FROM users)", ) def test_order_by_cancelled_allows_assertions(self): @@ -1789,21 +2028,26 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): - __dialect__ = 'default' - run_setup_mappers = 'each' + __dialect__ = "default" + run_setup_mappers = "each" def _fixture(self, label=True, polymorphic=False): User, Address = self.classes("User", "Address") users, addresses = self.tables("users", "addresses") - stmt = select([func.max(addresses.c.email_address)]).\ - where(addresses.c.user_id == users.c.id).\ - correlate(users) + stmt = ( + select([func.max(addresses.c.email_address)]) + .where(addresses.c.user_id == users.c.id) + .correlate(users) + ) if label: stmt = stmt.label("email_ad") - mapper(User, users, properties={ - "ead": column_property(stmt) - }, with_polymorphic="*" if polymorphic else None) + mapper( + User, + users, + properties={"ead": column_property(stmt)}, + with_polymorphic="*" if polymorphic else None, + ) mapper(Address, addresses) def _func_fixture(self, label=False): @@ -1811,17 +2055,23 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): users = self.tables.users if label: - mapper(User, users, properties={ - "foobar": column_property( - func.foob(users.c.name).label(None) - ) - }) + mapper( + User, + users, + properties={ + "foobar": column_property( + func.foob(users.c.name).label(None) + ) + }, + ) else: - mapper(User, users, properties={ - "foobar": column_property( - func.foob(users.c.name) - ) - }) + mapper( + User, + users, + properties={ + "foobar": column_property(func.foob(users.c.name)) + }, + ) def test_anon_label_function_auto(self): self._func_fixture() @@ -1833,7 +2083,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.assert_compile( s.query(User.foobar, u1.foobar), "SELECT foob(users.name) AS foob_1, foob(users_1.name) AS foob_2 " - "FROM users, users AS users_1" + "FROM users, users AS users_1", ) def test_anon_label_function_manual(self): @@ -1846,7 +2096,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.assert_compile( s.query(User.foobar, u1.foobar), "SELECT foob(users.name) AS foob_1, foob(users_1.name) AS foob_2 " - "FROM users, users AS users_1" + "FROM users, users AS users_1", ) def test_anon_label_ad_hoc_labeling(self): @@ -1857,9 +2107,9 @@ class ColumnPropertyTest(_fixtures.FixtureTest, 
AssertsCompiledSQL): u1 = aliased(User) self.assert_compile( - s.query(User.foobar.label('x'), u1.foobar.label('y')), + s.query(User.foobar.label("x"), u1.foobar.label("y")), "SELECT foob(users.name) AS x, foob(users_1.name) AS y " - "FROM users, users AS users_1" + "FROM users, users AS users_1", ) def test_order_by_column_prop_string(self): @@ -1874,7 +2124,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM addresses " "WHERE addresses.user_id = users.id) AS email_ad, " "users.id AS users_id, users.name AS users_name " - "FROM users ORDER BY email_ad" + "FROM users ORDER BY email_ad", ) def test_order_by_column_prop_aliased_string(self): @@ -1892,11 +2142,12 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM addresses WHERE addresses.user_id = users_1.id) " "AS anon_1, users_1.id AS users_1_id, " "users_1.name AS users_1_name FROM users AS users_1 " - "ORDER BY email_ad" + "ORDER BY email_ad", ) + assert_warnings( - go, - ["Can't resolve label reference 'email_ad'"], regex=True) + go, ["Can't resolve label reference 'email_ad'"], regex=True + ) def test_order_by_column_labeled_prop_attr_aliased_one(self): User = self.classes.User @@ -1910,7 +2161,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 " "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, " "users_1.id AS users_1_id, users_1.name AS users_1_name " - "FROM users AS users_1 ORDER BY anon_1" + "FROM users AS users_1 ORDER BY anon_1", ) def test_order_by_column_labeled_prop_attr_aliased_two(self): @@ -1925,7 +2176,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 " "FROM addresses, " "users AS users_1 WHERE addresses.user_id = users_1.id) " - "AS anon_1 ORDER BY anon_1" + "AS anon_1 ORDER BY anon_1", ) # we're also testing that the state of "ua" is OK after the @@ -1936,7 +2187,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 " "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, " "users_1.id AS users_1_id, users_1.name AS users_1_name " - "FROM users AS users_1 ORDER BY anon_1" + "FROM users AS users_1 ORDER BY anon_1", ) def test_order_by_column_labeled_prop_attr_aliased_three(self): @@ -1952,7 +2203,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM addresses, users WHERE addresses.user_id = users.id) " "AS email_ad, (SELECT max(addresses.email_address) AS max_1 " "FROM addresses, users AS users_1 WHERE addresses.user_id = " - "users_1.id) AS anon_1 ORDER BY email_ad, anon_1" + "users_1.id) AS anon_1 ORDER BY email_ad, anon_1", ) q = s.query(User, ua).order_by(User.ead, ua.ead) @@ -1964,7 +2215,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "(SELECT max(addresses.email_address) AS max_1 FROM addresses " "WHERE addresses.user_id = users_1.id) AS anon_1, users_1.id " "AS users_1_id, users_1.name AS users_1_name FROM users, " - "users AS users_1 ORDER BY email_ad, anon_1" + "users AS users_1 ORDER BY email_ad, anon_1", ) def test_order_by_column_labeled_prop_attr_aliased_four(self): @@ -1979,7 +2230,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 FROM " "addresses WHERE addresses.user_id = users_1.id) AS anon_1, " "users_1.id AS users_1_id, users_1.name AS users_1_name, " - "users.id AS 
users_id FROM users AS users_1, users ORDER BY anon_1" + "users.id AS users_id FROM users AS users_1, users ORDER BY anon_1", ) def test_order_by_column_unlabeled_prop_attr_aliased_one(self): @@ -1994,7 +2245,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 " "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, " "users_1.id AS users_1_id, users_1.name AS users_1_name " - "FROM users AS users_1 ORDER BY anon_1" + "FROM users AS users_1 ORDER BY anon_1", ) def test_order_by_column_unlabeled_prop_attr_aliased_two(self): @@ -2009,7 +2260,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 " "FROM addresses, " "users AS users_1 WHERE addresses.user_id = users_1.id) " - "AS anon_1 ORDER BY anon_1" + "AS anon_1 ORDER BY anon_1", ) # we're also testing that the state of "ua" is OK after the @@ -2020,7 +2271,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT (SELECT max(addresses.email_address) AS max_1 " "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, " "users_1.id AS users_1_id, users_1.name AS users_1_name " - "FROM users AS users_1 ORDER BY anon_1" + "FROM users AS users_1 ORDER BY anon_1", ) def test_order_by_column_unlabeled_prop_attr_aliased_three(self): @@ -2037,7 +2288,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "AS anon_1, (SELECT max(addresses.email_address) AS max_1 " "FROM addresses, users AS users_1 " "WHERE addresses.user_id = users_1.id) AS anon_2 " - "ORDER BY anon_1, anon_2" + "ORDER BY anon_1, anon_2", ) q = s.query(User, ua).order_by(User.ead, ua.ead) @@ -2049,7 +2300,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "(SELECT max(addresses.email_address) AS max_1 FROM addresses " "WHERE addresses.user_id = users_1.id) AS anon_2, users_1.id " "AS users_1_id, users_1.name AS users_1_name FROM users, " - "users AS users_1 ORDER BY anon_1, anon_2" + "users AS users_1 ORDER BY anon_1, anon_2", ) def test_order_by_column_prop_attr(self): @@ -2067,7 +2318,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM addresses " "WHERE addresses.user_id = users.id) AS email_ad, " "users.id AS users_id, users.name AS users_name " - "FROM users ORDER BY email_ad" + "FROM users ORDER BY email_ad", ) def test_order_by_column_prop_attr_non_present(self): @@ -2082,13 +2333,14 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM users ORDER BY " "(SELECT max(addresses.email_address) AS max_1 " "FROM addresses " - "WHERE addresses.user_id = users.id)" + "WHERE addresses.user_id = users.id)", ) class ComparatorTest(QueryTest): def test_clause_element_query_resolve(self): from sqlalchemy.orm.properties import ColumnProperty + User = self.classes.User class Comparator(ColumnProperty.Comparator): @@ -2100,9 +2352,10 @@ class ComparatorTest(QueryTest): sess = Session() eq_( - sess.query(Comparator(User.id)).order_by( - Comparator(User.id)).all(), - [(7, ), (8, ), (9, ), (10, )] + sess.query(Comparator(User.id)) + .order_by(Comparator(User.id)) + .all(), + [(7,), (8,), (9,), (10,)], ) @@ -2113,8 +2366,9 @@ class SliceTest(QueryTest): assert User(id=7) == create_session().query(User).first() - assert create_session().query(User).filter(User.id == 27). 
\ - first() is None + assert ( + create_session().query(User).filter(User.id == 27).first() is None + ) def test_limit_offset_applies(self): """Test that the expected LIMIT/OFFSET is applied for slices. @@ -2131,54 +2385,87 @@ class SliceTest(QueryTest): q = sess.query(User).order_by(User.id) self.assert_sql( - testing.db, lambda: q[10:20], [ + testing.db, + lambda: q[10:20], + [ ( "SELECT users.id AS users_id, users.name " "AS users_name FROM users ORDER BY users.id " "LIMIT :param_1 OFFSET :param_2", - {'param_1': 10, 'param_2': 10})]) + {"param_1": 10, "param_2": 10}, + ) + ], + ) self.assert_sql( - testing.db, lambda: q[:20], [ + testing.db, + lambda: q[:20], + [ ( "SELECT users.id AS users_id, users.name " "AS users_name FROM users ORDER BY users.id " "LIMIT :param_1", - {'param_1': 20})]) + {"param_1": 20}, + ) + ], + ) self.assert_sql( - testing.db, lambda: q[5:], [ + testing.db, + lambda: q[5:], + [ ( "SELECT users.id AS users_id, users.name " "AS users_name FROM users ORDER BY users.id " "LIMIT -1 OFFSET :param_1", - {'param_1': 5})]) + {"param_1": 5}, + ) + ], + ) self.assert_sql(testing.db, lambda: q[2:2], []) self.assert_sql(testing.db, lambda: q[-2:-5], []) self.assert_sql( - testing.db, lambda: q[-5:-2], [ + testing.db, + lambda: q[-5:-2], + [ ( "SELECT users.id AS users_id, users.name AS users_name " - "FROM users ORDER BY users.id", {})]) + "FROM users ORDER BY users.id", + {}, + ) + ], + ) self.assert_sql( - testing.db, lambda: q[-5:], [ + testing.db, + lambda: q[-5:], + [ ( "SELECT users.id AS users_id, users.name AS users_name " - "FROM users ORDER BY users.id", {})]) + "FROM users ORDER BY users.id", + {}, + ) + ], + ) self.assert_sql( - testing.db, lambda: q[:], [ + testing.db, + lambda: q[:], + [ ( "SELECT users.id AS users_id, users.name AS users_name " - "FROM users ORDER BY users.id", {})]) + "FROM users ORDER BY users.id", + {}, + ) + ], + ) class FilterTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_basic(self): User = self.classes.User @@ -2192,11 +2479,13 @@ class FilterTest(QueryTest, AssertsCompiledSQL): sess = create_session() - assert [User(id=8), User(id=9)] == \ - sess.query(User).order_by(User.id).limit(2).offset(1).all() + assert [User(id=8), User(id=9)] == sess.query(User).order_by( + User.id + ).limit(2).offset(1).all() - assert [User(id=8), User(id=9)] == \ - list(sess.query(User).order_by(User.id)[1:3]) + assert [User(id=8), User(id=9)] == list( + sess.query(User).order_by(User.id)[1:3] + ) assert User(id=8) == sess.query(User).order_by(User.id)[1] @@ -2208,17 +2497,24 @@ class FilterTest(QueryTest, AssertsCompiledSQL): """Does a query allow bindparam for the limit?""" User = self.classes.User sess = create_session() - q1 = sess.query(self.classes.User).\ - order_by(self.classes.User.id).limit(bindparam('n')) + q1 = ( + sess.query(self.classes.User) + .order_by(self.classes.User.id) + .limit(bindparam("n")) + ) for n in range(1, 4): result = q1.params(n=n).all() eq_(len(result), n) eq_( - sess.query(User).order_by(User.id).limit(bindparam('limit')). 
- offset(bindparam('offset')).params(limit=2, offset=1).all(), - [User(id=8), User(id=9)] + sess.query(User) + .order_by(User.id) + .limit(bindparam("limit")) + .offset(bindparam("offset")) + .params(limit=2, offset=1) + .all(), + [User(id=8), User(id=9)], ) @testing.fails_on("mysql", "doesn't like CAST in the limit clause") @@ -2226,13 +2522,22 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_select_with_bindparam_offset_limit_w_cast(self): User = self.classes.User sess = create_session() - q1 = sess.query(self.classes.User).\ - order_by(self.classes.User.id).limit(bindparam('n')) + q1 = ( + sess.query(self.classes.User) + .order_by(self.classes.User.id) + .limit(bindparam("n")) + ) eq_( list( - sess.query(User).params(a=1, b=3).order_by(User.id) - [cast(bindparam('a'), Integer):cast(bindparam('b'), Integer)]), - [User(id=8), User(id=9)] + sess.query(User) + .params(a=1, b=3) + .order_by(User.id)[ + cast(bindparam("a"), Integer) : cast( + bindparam("b"), Integer + ) + ] + ), + [User(id=8), User(id=9)], ) @testing.requires.boolean_col_expressions @@ -2247,8 +2552,9 @@ class FilterTest(QueryTest, AssertsCompiledSQL): def test_one_filter(self): User = self.classes.User - assert [User(id=8), User(id=9)] == \ - create_session().query(User).filter(User.name.endswith('ed')).all() + assert [User(id=8), User(id=9)] == create_session().query(User).filter( + User.name.endswith("ed") + ).all() def test_contains(self): """test comparing a collection to an object instance.""" @@ -2257,8 +2563,9 @@ class FilterTest(QueryTest, AssertsCompiledSQL): sess = create_session() address = sess.query(Address).get(3) - assert [User(id=8)] == \ - sess.query(User).filter(User.addresses.contains(address)).all() + assert [User(id=8)] == sess.query(User).filter( + User.addresses.contains(address) + ).all() try: sess.query(User).filter(User.addresses == address) @@ -2266,12 +2573,14 @@ class FilterTest(QueryTest, AssertsCompiledSQL): except sa_exc.InvalidRequestError: assert True - assert [User(id=10)] == \ - sess.query(User).filter(User.addresses == None).all() # noqa + assert [User(id=10)] == sess.query(User).filter( + User.addresses == None + ).all() # noqa try: - assert [User(id=7), User(id=9), User(id=10)] == \ - sess.query(User).filter(User.addresses != address).all() + assert [User(id=7), User(id=9), User(id=10)] == sess.query( + User + ).filter(User.addresses != address).all() assert False except sa_exc.InvalidRequestError: assert True @@ -2285,7 +2594,7 @@ class FilterTest(QueryTest, AssertsCompiledSQL): self.assert_compile( s.query(User).filter(User.addresses), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users, addresses WHERE users.id = addresses.user_id" + "FROM users, addresses WHERE users.id = addresses.user_id", ) def test_unique_binds_join_cond(self): @@ -2296,15 +2605,15 @@ class FilterTest(QueryTest, AssertsCompiledSQL): sess = Session() a1, a2 = sess.query(Address).order_by(Address.id)[0:2] self.assert_compile( - sess.query(User).filter(User.addresses.contains(a1)).union( - sess.query(User).filter(User.addresses.contains(a2)) - ), + sess.query(User) + .filter(User.addresses.contains(a1)) + .union(sess.query(User).filter(User.addresses.contains(a2))), "SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name AS " "anon_1_users_name FROM (SELECT users.id AS users_id, " "users.name AS users_name FROM users WHERE users.id = :param_1 " "UNION SELECT users.id AS users_id, users.name AS users_name " "FROM users WHERE users.id = :param_2) AS anon_1", - 
checkparams={'param_1': 7, 'param_2': 8} + checkparams={"param_1": 7, "param_2": 8}, ) def test_any(self): @@ -2314,36 +2623,36 @@ class FilterTest(QueryTest, AssertsCompiledSQL): sess = create_session() - assert [User(id=8), User(id=9)] == \ - sess.query(User). \ - filter( - User.addresses.any(Address.email_address.like('%ed%'))).all() + assert [User(id=8), User(id=9)] == sess.query(User).filter( + User.addresses.any(Address.email_address.like("%ed%")) + ).all() - assert [User(id=8)] == \ - sess.query(User). \ - filter( - User.addresses.any( - Address.email_address.like('%ed%'), id=4)).all() + assert [User(id=8)] == sess.query(User).filter( + User.addresses.any(Address.email_address.like("%ed%"), id=4) + ).all() - assert [User(id=8)] == \ - sess.query(User). \ - filter(User.addresses.any(Address.email_address.like('%ed%'))).\ - filter(User.addresses.any(id=4)).all() + assert [User(id=8)] == sess.query(User).filter( + User.addresses.any(Address.email_address.like("%ed%")) + ).filter(User.addresses.any(id=4)).all() - assert [User(id=9)] == \ - sess.query(User). \ - filter(User.addresses.any(email_address='fred@fred.com')).all() + assert [User(id=9)] == sess.query(User).filter( + User.addresses.any(email_address="fred@fred.com") + ).all() # test that the contents are not adapted by the aliased join - assert [User(id=7), User(id=8)] == \ - sess.query(User).join("addresses", aliased=True). \ - filter( - ~User.addresses.any( - Address.email_address == 'fred@fred.com')).all() + assert ( + [User(id=7), User(id=8)] + == sess.query(User) + .join("addresses", aliased=True) + .filter( + ~User.addresses.any(Address.email_address == "fred@fred.com") + ) + .all() + ) - assert [User(id=10)] == \ - sess.query(User).outerjoin("addresses", aliased=True). \ - filter(~User.addresses.any()).all() + assert [User(id=10)] == sess.query(User).outerjoin( + "addresses", aliased=True + ).filter(~User.addresses.any()).all() def test_any_doesnt_overcorrelate(self): # see also HasAnyTest, a newer suite which tests these at the level of @@ -2353,48 +2662,70 @@ class FilterTest(QueryTest, AssertsCompiledSQL): sess = create_session() # test that any() doesn't overcorrelate - assert [User(id=7), User(id=8)] == \ - sess.query(User).join("addresses"). \ - filter( - ~User.addresses.any( - Address.email_address == 'fred@fred.com')).all() + assert ( + [User(id=7), User(id=8)] + == sess.query(User) + .join("addresses") + .filter( + ~User.addresses.any(Address.email_address == "fred@fred.com") + ) + .all() + ) def test_has(self): # see also HasAnyTest, a newer suite which tests these at the level of # SQL compilation Dingaling, User, Address = ( - self.classes.Dingaling, self.classes.User, self.classes.Address) + self.classes.Dingaling, + self.classes.User, + self.classes.Address, + ) sess = create_session() - assert [Address(id=5)] == \ - sess.query(Address).filter(Address.user.has(name='fred')).all() + assert [Address(id=5)] == sess.query(Address).filter( + Address.user.has(name="fred") + ).all() - assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] \ - == sess.query(Address). \ - filter(Address.user.has(User.name.like('%ed%'))). \ - order_by(Address.id).all() + assert ( + [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] + == sess.query(Address) + .filter(Address.user.has(User.name.like("%ed%"))) + .order_by(Address.id) + .all() + ) - assert [Address(id=2), Address(id=3), Address(id=4)] == \ - sess.query(Address). \ - filter(Address.user.has(User.name.like('%ed%'), id=8)). 
\ - order_by(Address.id).all() + assert ( + [Address(id=2), Address(id=3), Address(id=4)] + == sess.query(Address) + .filter(Address.user.has(User.name.like("%ed%"), id=8)) + .order_by(Address.id) + .all() + ) # test has() doesn't overcorrelate - assert [Address(id=2), Address(id=3), Address(id=4)] == \ - sess.query(Address).join("user"). \ - filter(Address.user.has(User.name.like('%ed%'), id=8)). \ - order_by(Address.id).all() + assert ( + [Address(id=2), Address(id=3), Address(id=4)] + == sess.query(Address) + .join("user") + .filter(Address.user.has(User.name.like("%ed%"), id=8)) + .order_by(Address.id) + .all() + ) # test has() doesn't get subquery contents adapted by aliased join - assert [Address(id=2), Address(id=3), Address(id=4)] == \ - sess.query(Address).join("user", aliased=True). \ - filter(Address.user.has(User.name.like('%ed%'), id=8)). \ - order_by(Address.id).all() + assert ( + [Address(id=2), Address(id=3), Address(id=4)] + == sess.query(Address) + .join("user", aliased=True) + .filter(Address.user.has(User.name.like("%ed%"), id=8)) + .order_by(Address.id) + .all() + ) dingaling = sess.query(Dingaling).get(2) - assert [User(id=9)] == \ - sess.query(User). \ - filter(User.addresses.any(Address.dingaling == dingaling)).all() + assert [User(id=9)] == sess.query(User).filter( + User.addresses.any(Address.dingaling == dingaling) + ).all() def test_contains_m2m(self): Item, Order = self.classes.Item, self.classes.Order @@ -2403,89 +2734,117 @@ class FilterTest(QueryTest, AssertsCompiledSQL): item = sess.query(Item).get(3) eq_( - sess.query(Order).filter(Order.items.contains(item)). - order_by(Order.id).all(), - [Order(id=1), Order(id=2), Order(id=3)] + sess.query(Order) + .filter(Order.items.contains(item)) + .order_by(Order.id) + .all(), + [Order(id=1), Order(id=2), Order(id=3)], ) eq_( - sess.query(Order).filter(~Order.items.contains(item)). - order_by(Order.id).all(), - [Order(id=4), Order(id=5)] + sess.query(Order) + .filter(~Order.items.contains(item)) + .order_by(Order.id) + .all(), + [Order(id=4), Order(id=5)], ) item2 = sess.query(Item).get(5) eq_( - sess.query(Order).filter(Order.items.contains(item)). 
- filter(Order.items.contains(item2)).all(), - [Order(id=3)] + sess.query(Order) + .filter(Order.items.contains(item)) + .filter(Order.items.contains(item2)) + .all(), + [Order(id=3)], ) def test_comparison(self): """test scalar comparison to an object instance""" Item, Order, Dingaling, User, Address = ( - self.classes.Item, self.classes.Order, self.classes.Dingaling, - self.classes.User, self.classes.Address) + self.classes.Item, + self.classes.Order, + self.classes.Dingaling, + self.classes.User, + self.classes.Address, + ) sess = create_session() user = sess.query(User).get(8) - assert [Address(id=2), Address(id=3), Address(id=4)] == \ - sess.query(Address).filter(Address.user == user).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query( + Address + ).filter(Address.user == user).all() - assert [Address(id=1), Address(id=5)] == \ - sess.query(Address).filter(Address.user != user).all() + assert [Address(id=1), Address(id=5)] == sess.query(Address).filter( + Address.user != user + ).all() # generates an IS NULL - assert [] == sess.query(Address).filter(Address.user == None).all() # noqa + assert ( + [] == sess.query(Address).filter(Address.user == None).all() + ) # noqa assert [] == sess.query(Address).filter(Address.user == null()).all() - assert [Order(id=5)] == \ - sess.query(Order).filter(Order.address == None).all() # noqa + assert [Order(id=5)] == sess.query(Order).filter( + Order.address == None + ).all() # noqa # o2o dingaling = sess.query(Dingaling).get(2) - assert [Address(id=5)] == \ - sess.query(Address).filter(Address.dingaling == dingaling).all() + assert [Address(id=5)] == sess.query(Address).filter( + Address.dingaling == dingaling + ).all() # m2m eq_( - sess.query(Item).filter(Item.keywords == None). # noqa - order_by(Item.id).all(), [Item(id=4), Item(id=5)]) + sess.query(Item) + .filter(Item.keywords == None) + .order_by(Item.id) # noqa + .all(), + [Item(id=4), Item(id=5)], + ) eq_( - sess.query(Item).filter(Item.keywords != None). # noqa - order_by(Item.id).all(), [Item(id=1), Item(id=2), Item(id=3)]) + sess.query(Item) + .filter(Item.keywords != None) + .order_by(Item.id) # noqa + .all(), + [Item(id=1), Item(id=2), Item(id=3)], + ) def test_filter_by(self): User, Address = self.classes.User, self.classes.Address sess = create_session() user = sess.query(User).get(8) - assert [Address(id=2), Address(id=3), Address(id=4)] == \ - sess.query(Address).filter_by(user=user).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query( + Address + ).filter_by(user=user).all() # many to one generates IS NULL assert [] == sess.query(Address).filter_by(user=None).all() assert [] == sess.query(Address).filter_by(user=null()).all() # one to many generates WHERE NOT EXISTS - assert [User(name='chuck')] == \ - sess.query(User).filter_by(addresses=None).all() - assert [User(name='chuck')] == \ - sess.query(User).filter_by(addresses=null()).all() + assert [User(name="chuck")] == sess.query(User).filter_by( + addresses=None + ).all() + assert [User(name="chuck")] == sess.query(User).filter_by( + addresses=null() + ).all() def test_filter_by_tables(self): users = self.tables.users addresses = self.tables.addresses sess = create_session() self.assert_compile( - sess.query(users).filter_by(name='ed'). - join(addresses, users.c.id == addresses.c.user_id). 
- filter_by(email_address='ed@ed.com'), + sess.query(users) + .filter_by(name="ed") + .join(addresses, users.c.id == addresses.c.user_id) + .filter_by(email_address="ed@ed.com"), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN addresses ON users.id = addresses.user_id " "WHERE users.name = :name_1 AND " "addresses.email_address = :email_address_1", - checkparams={'email_address_1': 'ed@ed.com', 'name_1': 'ed'} + checkparams={"email_address_1": "ed@ed.com", "name_1": "ed"}, ) def test_filter_by_no_property(self): @@ -2494,72 +2853,103 @@ class FilterTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa.exc.InvalidRequestError, "Entity 'addresses' has no property 'name'", - sess.query(addresses).filter_by, name='ed' + sess.query(addresses).filter_by, + name="ed", ) def test_none_comparison(self): Order, User, Address = ( - self.classes.Order, self.classes.User, self.classes.Address) + self.classes.Order, + self.classes.User, + self.classes.Address, + ) sess = create_session() # scalar eq_( [Order(description="order 5")], - sess.query(Order).filter(Order.address_id == None).all() # noqa + sess.query(Order).filter(Order.address_id == None).all(), # noqa ) eq_( [Order(description="order 5")], - sess.query(Order).filter(Order.address_id == null()).all() + sess.query(Order).filter(Order.address_id == null()).all(), ) # o2o eq_( [Address(id=1), Address(id=3), Address(id=4)], - sess.query(Address).filter(Address.dingaling == None). # noqa - order_by(Address.id).all()) + sess.query(Address) + .filter(Address.dingaling == None) + .order_by(Address.id) # noqa + .all(), + ) eq_( [Address(id=1), Address(id=3), Address(id=4)], - sess.query(Address).filter(Address.dingaling == null()). - order_by(Address.id).all()) + sess.query(Address) + .filter(Address.dingaling == null()) + .order_by(Address.id) + .all(), + ) eq_( [Address(id=2), Address(id=5)], - sess.query(Address).filter(Address.dingaling != None). # noqa - order_by(Address.id).all()) + sess.query(Address) + .filter(Address.dingaling != None) + .order_by(Address.id) # noqa + .all(), + ) eq_( [Address(id=2), Address(id=5)], - sess.query(Address).filter(Address.dingaling != null()). - order_by(Address.id).all()) + sess.query(Address) + .filter(Address.dingaling != null()) + .order_by(Address.id) + .all(), + ) # m2o eq_( [Order(id=5)], - sess.query(Order).filter(Order.address == None).all()) # noqa + sess.query(Order).filter(Order.address == None).all(), + ) # noqa eq_( [Order(id=1), Order(id=2), Order(id=3), Order(id=4)], - sess.query(Order).order_by(Order.id). - filter(Order.address != None).all()) # noqa + sess.query(Order) + .order_by(Order.id) + .filter(Order.address != None) + .all(), + ) # noqa # o2m eq_( [User(id=10)], - sess.query(User).filter(User.addresses == None).all()) # noqa + sess.query(User).filter(User.addresses == None).all(), + ) # noqa eq_( [User(id=7), User(id=8), User(id=9)], - sess.query(User).filter(User.addresses != None). # noqa - order_by(User.id).all()) + sess.query(User) + .filter(User.addresses != None) + .order_by(User.id) # noqa + .all(), + ) def test_blank_filter_by(self): User = self.classes.User eq_( [(7,), (8,), (9,), (10,)], - create_session().query(User.id).filter_by().order_by(User.id).all() + create_session() + .query(User.id) + .filter_by() + .order_by(User.id) + .all(), ) eq_( [(7,), (8,), (9,), (10,)], - create_session().query(User.id).filter_by(**{}). 
- order_by(User.id).all() + create_session() + .query(User.id) + .filter_by(**{}) + .order_by(User.id) + .all(), ) def test_text_coerce(self): @@ -2568,39 +2958,39 @@ class FilterTest(QueryTest, AssertsCompiledSQL): self.assert_compile( s.query(User).filter(text("name='ed'")), "SELECT users.id AS users_id, users.name " - "AS users_name FROM users WHERE name='ed'" + "AS users_name FROM users WHERE name='ed'", ) -class HasAnyTest( - fixtures.DeclarativeMappedTest, AssertsCompiledSQL): - __dialect__ = 'default' +class HasAnyTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class D(Base): - __tablename__ = 'd' + __tablename__ = "d" id = Column(Integer, primary_key=True) class C(Base): - __tablename__ = 'c' + __tablename__ = "c" id = Column(Integer, primary_key=True) d_id = Column(ForeignKey(D.id)) bs = relationship("B", back_populates="c") b_d = Table( - 'b_d', Base.metadata, - Column('bid', ForeignKey('b.id')), - Column('did', ForeignKey('d.id')) + "b_d", + Base.metadata, + Column("bid", ForeignKey("b.id")), + Column("did", ForeignKey("d.id")), ) # note we are using the ForeignKey pattern identified as a bug # in [ticket:4367] class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) c_id = Column(ForeignKey(C.id)) @@ -2609,16 +2999,17 @@ class HasAnyTest( d = relationship("D", secondary=b_d) class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) b_id = Column(ForeignKey(B.id)) d = relationship( - 'D', + "D", secondary="join(B, C)", primaryjoin="A.b_id == B.id", secondaryjoin="C.d_id == D.id", - uselist=False) + uselist=False, + ) def test_has_composite_secondary(self): A, D = self.classes("A", "D") @@ -2627,7 +3018,7 @@ class HasAnyTest( s.query(A).filter(A.d.has(D.id == 1)), "SELECT a.id AS a_id, a.b_id AS a_b_id FROM a WHERE EXISTS " "(SELECT 1 FROM d, b JOIN c ON c.id = b.c_id " - "WHERE a.b_id = b.id AND c.d_id = d.id AND d.id = :id_1)" + "WHERE a.b_id = b.id AND c.d_id = d.id AND d.id = :id_1)", ) def test_has_many_to_one(self): @@ -2636,7 +3027,7 @@ class HasAnyTest( self.assert_compile( s.query(B).filter(B.c.has(C.id == 1)), "SELECT b.id AS b_id, b.c_id AS b_c_id FROM b WHERE " - "EXISTS (SELECT 1 FROM c WHERE c.id = b.c_id AND c.id = :id_1)" + "EXISTS (SELECT 1 FROM c WHERE c.id = b.c_id AND c.id = :id_1)", ) def test_any_many_to_many(self): @@ -2646,7 +3037,7 @@ class HasAnyTest( s.query(B).filter(B.d.any(D.id == 1)), "SELECT b.id AS b_id, b.c_id AS b_c_id FROM b WHERE " "EXISTS (SELECT 1 FROM b_d, d WHERE b.id = b_d.bid " - "AND d.id = b_d.did AND d.id = :id_1)" + "AND d.id = b_d.did AND d.id = :id_1)", ) def test_any_one_to_many(self): @@ -2655,7 +3046,7 @@ class HasAnyTest( self.assert_compile( s.query(C).filter(C.bs.any(B.id == 1)), "SELECT c.id AS c_id, c.d_id AS c_d_id FROM c WHERE " - "EXISTS (SELECT 1 FROM b WHERE c.id = b.c_id AND b.id = :id_1)" + "EXISTS (SELECT 1 FROM b WHERE c.id = b.c_id AND b.id = :id_1)", ) def test_any_many_to_many_doesnt_overcorrelate(self): @@ -2668,7 +3059,7 @@ class HasAnyTest( "b JOIN b_d AS b_d_1 ON b.id = b_d_1.bid " "JOIN d ON d.id = b_d_1.did WHERE " "EXISTS (SELECT 1 FROM b_d, d WHERE b.id = b_d.bid " - "AND d.id = b_d.did AND d.id = :id_1)" + "AND d.id = b_d.did AND d.id = :id_1)", ) def test_has_doesnt_overcorrelate(self): @@ -2680,7 +3071,7 @@ class HasAnyTest( "SELECT b.id AS b_id, b.c_id AS b_c_id " "FROM b JOIN c ON c.id = b.c_id " "WHERE EXISTS " - 
"(SELECT 1 FROM c WHERE c.id = b.c_id AND c.id = :id_1)" + "(SELECT 1 FROM c WHERE c.id = b.c_id AND c.id = :id_1)", ) def test_has_doesnt_get_aliased_join_subq(self): @@ -2692,7 +3083,7 @@ class HasAnyTest( "SELECT b.id AS b_id, b.c_id AS b_c_id " "FROM b JOIN c AS c_1 ON c_1.id = b.c_id " "WHERE EXISTS " - "(SELECT 1 FROM c WHERE c.id = b.c_id AND c.id = :id_1)" + "(SELECT 1 FROM c WHERE c.id = b.c_id AND c.id = :id_1)", ) def test_any_many_to_many_doesnt_get_aliased_join_subq(self): @@ -2706,7 +3097,7 @@ class HasAnyTest( "JOIN d AS d_1 ON d_1.id = b_d_1.did " "WHERE EXISTS " "(SELECT 1 FROM b_d, d WHERE b.id = b_d.bid " - "AND d.id = b_d.did AND d.id = :id_1)" + "AND d.id = b_d.did AND d.id = :id_1)", ) @@ -2747,25 +3138,25 @@ class HasMapperEntitiesTest(QueryTest): class SetOpsTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_union(self): User = self.classes.User s = create_session() - fred = s.query(User).filter(User.name == 'fred') - ed = s.query(User).filter(User.name == 'ed') - jack = s.query(User).filter(User.name == 'jack') + fred = s.query(User).filter(User.name == "fred") + ed = s.query(User).filter(User.name == "ed") + jack = s.query(User).filter(User.name == "jack") eq_( fred.union(ed).order_by(User.name).all(), - [User(name='ed'), User(name='fred')] + [User(name="ed"), User(name="fred")], ) eq_( fred.union(ed, jack).order_by(User.name).all(), - [User(name='ed'), User(name='fred'), User(name='jack')] + [User(name="ed"), User(name="fred"), User(name="jack")], ) def test_statement_labels(self): @@ -2774,18 +3165,24 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): User, Address = self.classes.User, self.classes.Address s = create_session() - q1 = s.query(User, Address).join(User.addresses).\ - filter(Address.email_address == "ed@wood.com") - q2 = s.query(User, Address).join(User.addresses).\ - filter(Address.email_address == "jack@bean.com") + q1 = ( + s.query(User, Address) + .join(User.addresses) + .filter(Address.email_address == "ed@wood.com") + ) + q2 = ( + s.query(User, Address) + .join(User.addresses) + .filter(Address.email_address == "jack@bean.com") + ) q3 = q1.union(q2).order_by(User.name) eq_( q3.all(), [ - (User(name='ed'), Address(email_address="ed@wood.com")), - (User(name='jack'), Address(email_address="jack@bean.com")), - ] + (User(name="ed"), Address(email_address="ed@wood.com")), + (User(name="jack"), Address(email_address="jack@bean.com")), + ], ) def test_union_literal_expressions_compile(self): @@ -2807,7 +3204,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): "FROM (SELECT users.id AS users_id, users.name AS " "users_name, :param_1 AS param_1 " "FROM users UNION SELECT users.id AS users_id, " - "users.name AS users_name, 'y' FROM users) AS anon_1" + "users.name AS users_name, 'y' FROM users) AS anon_1", ) def test_union_literal_expressions_results(self): @@ -2819,30 +3216,28 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): q2 = s.query(User, literal_column("'y'")) q3 = q1.union(q2) - q4 = s.query(User, literal_column("'x'").label('foo')) + q4 = s.query(User, literal_column("'x'").label("foo")) q5 = s.query(User, literal("y")) q6 = q4.union(q5) - eq_( - [x['name'] for x in q6.column_descriptions], - ['User', 'foo'] - ) + eq_([x["name"] for x in q6.column_descriptions], ["User", "foo"]) for q in ( - q3.order_by(User.id, text("anon_1_param_1")), - q6.order_by(User.id, "foo")): + q3.order_by(User.id, text("anon_1_param_1")), + q6.order_by(User.id, "foo"), + ): eq_( q.all(), [ - (User(id=7, 
name='jack'), 'x'), - (User(id=7, name='jack'), 'y'), - (User(id=8, name='ed'), 'x'), - (User(id=8, name='ed'), 'y'), - (User(id=9, name='fred'), 'x'), - (User(id=9, name='fred'), 'y'), - (User(id=10, name='chuck'), 'x'), - (User(id=10, name='chuck'), 'y') - ] + (User(id=7, name="jack"), "x"), + (User(id=7, name="jack"), "y"), + (User(id=8, name="ed"), "x"), + (User(id=8, name="ed"), "y"), + (User(id=9, name="fred"), "x"), + (User(id=9, name="fred"), "y"), + (User(id=10, name="chuck"), "x"), + (User(id=10, name="chuck"), "y"), + ], ) def test_union_labeled_anonymous_columns(self): @@ -2850,14 +3245,13 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): s = Session() - c1, c2 = column('c1'), column('c2') - q1 = s.query(User, c1.label('foo'), c1.label('bar')) - q2 = s.query(User, c1.label('foo'), c2.label('bar')) + c1, c2 = column("c1"), column("c2") + q1 = s.query(User, c1.label("foo"), c1.label("bar")) + q2 = s.query(User, c1.label("foo"), c2.label("bar")) q3 = q1.union(q2) eq_( - [x['name'] for x in q3.column_descriptions], - ['User', 'foo', 'bar'] + [x["name"] for x in q3.column_descriptions], ["User", "foo", "bar"] ) self.assert_compile( @@ -2868,7 +3262,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): "FROM (SELECT users.id AS users_id, users.name AS users_name, " "c1 AS foo, c1 AS bar FROM users UNION SELECT users.id AS " "users_id, users.name AS users_name, c1 AS foo, c2 AS bar " - "FROM users) AS anon_1" + "FROM users) AS anon_1", ) def test_order_by_anonymous_col(self): @@ -2876,10 +3270,10 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): s = Session() - c1, c2 = column('c1'), column('c2') - f = c1.label('foo') - q1 = s.query(User, f, c2.label('bar')) - q2 = s.query(User, c1.label('foo'), c2.label('bar')) + c1, c2 = column("c1"), column("c2") + f = c1.label("foo") + q1 = s.query(User, f, c2.label("bar")) + q2 = s.query(User, c1.label("foo"), c2.label("bar")) q3 = q1.union(q2) self.assert_compile( @@ -2890,7 +3284,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): "users_name, c1 AS foo, c2 AS bar " "FROM users UNION SELECT users.id " "AS users_id, users.name AS users_name, c1 AS foo, c2 AS bar " - "FROM users) AS anon_1 ORDER BY anon_1.foo" + "FROM users) AS anon_1 ORDER BY anon_1.foo", ) self.assert_compile( @@ -2901,7 +3295,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): "users_name, c1 AS foo, c2 AS bar " "FROM users UNION SELECT users.id " "AS users_id, users.name AS users_name, c1 AS foo, c2 AS bar " - "FROM users) AS anon_1 ORDER BY anon_1.foo" + "FROM users) AS anon_1 ORDER BY anon_1.foo", ) def test_union_mapped_colnames_preserved_across_subquery(self): @@ -2916,15 +3310,12 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): q1.union(q2), "SELECT anon_1.users_name AS anon_1_users_name " "FROM (SELECT users.name AS users_name FROM users " - "UNION SELECT users.name AS users_name FROM users) AS anon_1" + "UNION SELECT users.name AS users_name FROM users) AS anon_1", ) # but in the returned named tuples, # due to [ticket:1942], this should be 'name', not 'users_name' - eq_( - [x['name'] for x in q1.union(q2).column_descriptions], - ['name'] - ) + eq_([x["name"] for x in q1.union(q2).column_descriptions], ["name"]) @testing.requires.intersect def test_intersect(self): @@ -2932,35 +3323,39 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): s = create_session() - fred = s.query(User).filter(User.name == 'fred') - ed = s.query(User).filter(User.name == 'ed') - jack = s.query(User).filter(User.name == 'jack') + fred = s.query(User).filter(User.name == "fred") + ed 
= s.query(User).filter(User.name == "ed") + jack = s.query(User).filter(User.name == "jack") eq_(fred.intersect(ed, jack).all(), []) - eq_(fred.union(ed).intersect(ed.union(jack)).all(), [User(name='ed')]) + eq_(fred.union(ed).intersect(ed.union(jack)).all(), [User(name="ed")]) def test_eager_load(self): User, Address = self.classes.User, self.classes.Address s = create_session() - fred = s.query(User).filter(User.name == 'fred') - ed = s.query(User).filter(User.name == 'ed') + fred = s.query(User).filter(User.name == "fred") + ed = s.query(User).filter(User.name == "ed") def go(): eq_( - fred.union(ed).order_by(User.name). - options(joinedload(User.addresses)).all(), [ + fred.union(ed) + .order_by(User.name) + .options(joinedload(User.addresses)) + .all(), + [ User( - name='ed', addresses=[Address(), Address(), - Address()]), - User(name='fred', addresses=[Address()])] + name="ed", addresses=[Address(), Address(), Address()] + ), + User(name="fred", addresses=[Address()]), + ], ) + self.assert_sql_count(testing.db, go, 1) class AggregateTest(QueryTest): - def test_sum(self): Order = self.classes.Order @@ -2968,31 +3363,45 @@ class AggregateTest(QueryTest): orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) eq_( next(orders.values(func.sum(Order.user_id * Order.address_id))), - (79,)) + (79,), + ) eq_(orders.value(func.sum(Order.user_id * Order.address_id)), 79) def test_apply(self): Order = self.classes.Order sess = create_session() - assert sess.query(func.sum(Order.user_id * Order.address_id)). \ - filter(Order.id.in_([2, 3, 4])).one() == (79,) + assert sess.query(func.sum(Order.user_id * Order.address_id)).filter( + Order.id.in_([2, 3, 4]) + ).one() == (79,) def test_having(self): User, Address = self.classes.User, self.classes.Address sess = create_session() - assert [User(name='ed', id=8)] == \ - sess.query(User).order_by(User.id).group_by(User). \ - join('addresses').having(func.count(Address.id) > 2).all() + assert ( + [User(name="ed", id=8)] + == sess.query(User) + .order_by(User.id) + .group_by(User) + .join("addresses") + .having(func.count(Address.id) > 2) + .all() + ) - assert [User(name='jack', id=7), User(name='fred', id=9)] == \ - sess.query(User).order_by(User.id).group_by(User). 
\ - join('addresses').having(func.count(Address.id) < 2).all() + assert ( + [User(name="jack", id=7), User(name="fred", id=9)] + == sess.query(User) + .order_by(User.id) + .group_by(User) + .join("addresses") + .having(func.count(Address.id) < 2) + .all() + ) class ExistsTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_exists(self): User = self.classes.User @@ -3001,17 +3410,15 @@ class ExistsTest(QueryTest, AssertsCompiledSQL): q1 = sess.query(User) self.assert_compile( sess.query(q1.exists()), - 'SELECT EXISTS (' - 'SELECT 1 FROM users' - ') AS anon_1' + "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1", ) - q2 = sess.query(User).filter(User.name == 'fred') + q2 = sess.query(User).filter(User.name == "fred") self.assert_compile( sess.query(q2.exists()), - 'SELECT EXISTS (' - 'SELECT 1 FROM users WHERE users.name = :name_1' - ') AS anon_1' + "SELECT EXISTS (" + "SELECT 1 FROM users WHERE users.name = :name_1" + ") AS anon_1", ) def test_exists_col_warning(self): @@ -3022,10 +3429,10 @@ class ExistsTest(QueryTest, AssertsCompiledSQL): q1 = sess.query(User, Address).filter(User.id == Address.user_id) self.assert_compile( sess.query(q1.exists()), - 'SELECT EXISTS (' - 'SELECT 1 FROM users, addresses ' - 'WHERE users.id = addresses.user_id' - ') AS anon_1' + "SELECT EXISTS (" + "SELECT 1 FROM users, addresses " + "WHERE users.id = addresses.user_id" + ") AS anon_1", ) def test_exists_w_select_from(self): @@ -3034,8 +3441,7 @@ class ExistsTest(QueryTest, AssertsCompiledSQL): q1 = sess.query().select_from(User).exists() self.assert_compile( - sess.query(q1), - 'SELECT EXISTS (SELECT 1 FROM users) AS anon_1' + sess.query(q1), "SELECT EXISTS (SELECT 1 FROM users) AS anon_1" ) @@ -3047,7 +3453,7 @@ class CountTest(QueryTest): eq_(s.query(User).count(), 4) - eq_(s.query(User).filter(users.c.name.endswith('ed')).count(), 2) + eq_(s.query(User).filter(users.c.name.endswith("ed")).count(), 2) def test_count_char(self): User = self.classes.User @@ -3057,11 +3463,14 @@ class CountTest(QueryTest): # rumors about Oracle preferring count(1) don't appear # to be well founded. self.assert_sql_execution( - testing.db, s.query(User).count, CompiledSQL( + testing.db, + s.query(User).count, + CompiledSQL( "SELECT count(*) AS count_1 FROM " "(SELECT users.id AS users_id, users.name " - "AS users_name FROM users) AS anon_1", {} - ) + "AS users_name FROM users) AS anon_1", + {}, + ), ) def test_multiple_entity(self): @@ -3112,52 +3521,58 @@ class CountTest(QueryTest): class DistinctTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_basic(self): User = self.classes.User eq_( [User(id=7), User(id=8), User(id=9), User(id=10)], - create_session().query(User).order_by(User.id).distinct().all() + create_session().query(User).order_by(User.id).distinct().all(), ) eq_( [User(id=7), User(id=9), User(id=8), User(id=10)], - create_session().query(User).distinct(). - order_by(desc(User.name)).all() + create_session() + .query(User) + .distinct() + .order_by(desc(User.name)) + .all(), ) def test_columns_augmented_roundtrip_one(self): User, Address = self.classes.User, self.classes.Address sess = create_session() - q = sess.query(User).join('addresses').distinct(). 
\ - order_by(desc(Address.email_address)) - - eq_( - [User(id=7), User(id=9), User(id=8)], - q.all() + q = ( + sess.query(User) + .join("addresses") + .distinct() + .order_by(desc(Address.email_address)) ) + eq_([User(id=7), User(id=9), User(id=8)], q.all()) + def test_columns_augmented_roundtrip_two(self): User, Address = self.classes.User, self.classes.Address sess = create_session() # test that it works on embedded joinedload/LIMIT subquery - q = sess.query(User).join('addresses').distinct(). \ - options(joinedload('addresses')).\ - order_by(desc(Address.email_address)).limit(2) + q = ( + sess.query(User) + .join("addresses") + .distinct() + .options(joinedload("addresses")) + .order_by(desc(Address.email_address)) + .limit(2) + ) def go(): assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=9, addresses=[ - Address(id=5) - ]), + User(id=7, addresses=[Address(id=1)]), + User(id=9, addresses=[Address(id=5)]), ] == q.all() + self.assert_sql_count(testing.db, go, 1) def test_columns_augmented_roundtrip_three(self): @@ -3165,28 +3580,37 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): sess = create_session() - q = sess.query(User.id, User.name.label('foo'), Address.id).\ - filter(User.name == 'jack').\ - distinct().\ - order_by(User.id, User.name, Address.email_address) + q = ( + sess.query(User.id, User.name.label("foo"), Address.id) + .filter(User.name == "jack") + .distinct() + .order_by(User.id, User.name, Address.email_address) + ) # even though columns are added, they aren't in the result eq_( q.all(), - [(7, 'jack', 3), (7, 'jack', 4), (7, 'jack', 2), - (7, 'jack', 5), (7, 'jack', 1)] + [ + (7, "jack", 3), + (7, "jack", 4), + (7, "jack", 2), + (7, "jack", 5), + (7, "jack", 1), + ], ) for row in q: - eq_(row.keys(), ['id', 'foo', 'id']) + eq_(row.keys(), ["id", "foo", "id"]) def test_columns_augmented_sql_one(self): User, Address = self.classes.User, self.classes.Address sess = create_session() - q = sess.query(User.id, User.name.label('foo'), Address.id).\ - distinct().\ - order_by(User.id, User.name, Address.email_address) + q = ( + sess.query(User.id, User.name.label("foo"), Address.id) + .distinct() + .order_by(User.id, User.name, Address.email_address) + ) # Address.email_address is added because of DISTINCT, # however User.id, User.name are not b.c. 
they're already there, @@ -3196,7 +3620,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): "SELECT DISTINCT users.id AS users_id, users.name AS foo, " "addresses.id AS addresses_id, " "addresses.email_address AS addresses_email_address FROM users, " - "addresses ORDER BY users.id, users.name, addresses.email_address" + "addresses ORDER BY users.id, users.name, addresses.email_address", ) def test_columns_augmented_sql_two(self): @@ -3204,11 +3628,13 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): sess = create_session() - q = sess.query(User).\ - options(joinedload(User.addresses)).\ - distinct().\ - order_by(User.name, Address.email_address).\ - limit(5) + q = ( + sess.query(User) + .options(joinedload(User.addresses)) + .distinct() + .order_by(User.name, Address.email_address) + .limit(5) + ) # addresses.email_address is added to inner query so that # it is available in ORDER BY @@ -3230,7 +3656,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): "addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " "ORDER BY anon_1.users_name, " - "anon_1.addresses_email_address, addresses_1.id" + "anon_1.addresses_email_address, addresses_1.id", ) def test_columns_augmented_sql_three(self): @@ -3238,9 +3664,11 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): sess = create_session() - q = sess.query(User.id, User.name.label('foo'), Address.id).\ - distinct(User.name).\ - order_by(User.id, User.name, Address.email_address) + q = ( + sess.query(User.id, User.name.label("foo"), Address.id) + .distinct(User.name) + .order_by(User.id, User.name, Address.email_address) + ) # no columns are added when DISTINCT ON is used self.assert_compile( @@ -3248,7 +3676,7 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): "SELECT DISTINCT ON (users.name) users.id AS users_id, " "users.name AS foo, addresses.id AS addresses_id FROM users, " "addresses ORDER BY users.id, users.name, addresses.email_address", - dialect='postgresql' + dialect="postgresql", ) def test_columns_augmented_sql_four(self): @@ -3256,10 +3684,14 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): sess = create_session() - q = sess.query(User).join('addresses').\ - distinct(Address.email_address). 
\ - options(joinedload('addresses')).\ - order_by(desc(Address.email_address)).limit(2) + q = ( + sess.query(User) + .join("addresses") + .distinct(Address.email_address) + .options(joinedload("addresses")) + .order_by(desc(Address.email_address)) + .limit(2) + ) # but for the subquery / eager load case, we still need to make # the inner columns available for the ORDER BY even though its @@ -3282,53 +3714,60 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): "LEFT OUTER JOIN addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " "ORDER BY anon_1.addresses_email_address DESC, addresses_1.id", - dialect='postgresql' + dialect="postgresql", ) class PrefixWithTest(QueryTest, AssertsCompiledSQL): - def test_one_prefix(self): User = self.classes.User sess = create_session() - query = sess.query(User.name)\ - .prefix_with('PREFIX_1') - expected = "SELECT PREFIX_1 "\ - "users.name AS users_name FROM users" + query = sess.query(User.name).prefix_with("PREFIX_1") + expected = "SELECT PREFIX_1 " "users.name AS users_name FROM users" self.assert_compile(query, expected, dialect=default.DefaultDialect()) def test_many_prefixes(self): User = self.classes.User sess = create_session() - query = sess.query(User.name).prefix_with('PREFIX_1', 'PREFIX_2') - expected = "SELECT PREFIX_1 PREFIX_2 "\ - "users.name AS users_name FROM users" + query = sess.query(User.name).prefix_with("PREFIX_1", "PREFIX_2") + expected = ( + "SELECT PREFIX_1 PREFIX_2 " "users.name AS users_name FROM users" + ) self.assert_compile(query, expected, dialect=default.DefaultDialect()) def test_chained_prefixes(self): User = self.classes.User sess = create_session() - query = sess.query(User.name)\ - .prefix_with('PREFIX_1')\ - .prefix_with('PREFIX_2', 'PREFIX_3') - expected = "SELECT PREFIX_1 PREFIX_2 PREFIX_3 "\ + query = ( + sess.query(User.name) + .prefix_with("PREFIX_1") + .prefix_with("PREFIX_2", "PREFIX_3") + ) + expected = ( + "SELECT PREFIX_1 PREFIX_2 PREFIX_3 " "users.name AS users_name FROM users" + ) self.assert_compile(query, expected, dialect=default.DefaultDialect()) class YieldTest(_fixtures.FixtureTest): - run_setup_mappers = 'each' - run_inserts = 'each' + run_setup_mappers = "each" + run_inserts = "each" def _eagerload_mappings(self, addresses_lazy=True, user_lazy=True): User, Address = self.classes("User", "Address") users, addresses = self.tables("users", "addresses") - mapper(User, users, properties={ - "addresses": relationship( - Address, lazy=addresses_lazy, - backref=backref("user", lazy=user_lazy) - ) - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + lazy=addresses_lazy, + backref=backref("user", lazy=user_lazy), + ) + }, + ) mapper(Address, addresses) def test_basic(self): @@ -3338,8 +3777,10 @@ class YieldTest(_fixtures.FixtureTest): sess = create_session() q = iter( - sess.query(User).yield_per(1).from_statement( - text("select * from users"))) + sess.query(User) + .yield_per(1) + .from_statement(text("select * from users")) + ) ret = [] eq_(len(sess.identity_map), 0) @@ -3362,11 +3803,12 @@ class YieldTest(_fixtures.FixtureTest): sess = create_session() q = sess.query(User).yield_per(15) - q = q.execution_options(foo='bar') + q = q.execution_options(foo="bar") assert q._yield_per eq_( q._execution_options, - {"stream_results": True, "foo": "bar", "max_row_buffer": 15}) + {"stream_results": True, "foo": "bar", "max_row_buffer": 15}, + ) def test_no_joinedload_opt(self): self._eagerload_mappings() @@ -3378,7 +3820,7 @@ class 
YieldTest(_fixtures.FixtureTest): sa_exc.InvalidRequestError, "The yield_per Query option is currently not compatible with " "joined collection eager loading. Please specify ", - q.all + q.all, ) def test_no_subqueryload_opt(self): @@ -3391,7 +3833,7 @@ class YieldTest(_fixtures.FixtureTest): sa_exc.InvalidRequestError, "The yield_per Query option is currently not compatible with " "subquery eager loading. Please specify ", - q.all + q.all, ) def test_no_subqueryload_mapping(self): @@ -3404,7 +3846,7 @@ class YieldTest(_fixtures.FixtureTest): sa_exc.InvalidRequestError, "The yield_per Query option is currently not compatible with " "subquery eager loading. Please specify ", - q.all + q.all, ) def test_joinedload_m2o_ok(self): @@ -3419,73 +3861,93 @@ class YieldTest(_fixtures.FixtureTest): User = self.classes.User sess = create_session() - q = sess.query(User).options(subqueryload("addresses")).\ - enable_eagerloads(False).yield_per(1) + q = ( + sess.query(User) + .options(subqueryload("addresses")) + .enable_eagerloads(False) + .yield_per(1) + ) q.all() - q = sess.query(User).options(joinedload("addresses")).\ - enable_eagerloads(False).yield_per(1) + q = ( + sess.query(User) + .options(joinedload("addresses")) + .enable_eagerloads(False) + .yield_per(1) + ) q.all() def test_m2o_joinedload_not_others(self): self._eagerload_mappings(addresses_lazy="joined") Address = self.classes.Address sess = create_session() - q = sess.query(Address).options( - lazyload('*'), joinedload("user")).yield_per(1).filter_by(id=1) + q = ( + sess.query(Address) + .options(lazyload("*"), joinedload("user")) + .yield_per(1) + .filter_by(id=1) + ) def go(): result = q.all() assert result[0].user + self.assert_sql_count(testing.db, go, 1) class HintsTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_hints(self): User = self.classes.User from sqlalchemy.dialects import mysql + dialect = mysql.dialect() sess = create_session() self.assert_compile( sess.query(User).with_hint( - User, 'USE INDEX (col1_index,col2_index)'), + User, "USE INDEX (col1_index,col2_index)" + ), "SELECT users.id AS users_id, users.name AS users_name " "FROM users USE INDEX (col1_index,col2_index)", - dialect=dialect + dialect=dialect, ) self.assert_compile( sess.query(User).with_hint( - User, 'WITH INDEX col1_index', 'sybase'), + User, "WITH INDEX col1_index", "sybase" + ), "SELECT users.id AS users_id, users.name AS users_name " - "FROM users", dialect=dialect + "FROM users", + dialect=dialect, ) ualias = aliased(User) self.assert_compile( - sess.query(User, ualias).with_hint( - ualias, 'USE INDEX (col1_index,col2_index)'). 
- join(ualias, ualias.id > User.id), + sess.query(User, ualias) + .with_hint(ualias, "USE INDEX (col1_index,col2_index)") + .join(ualias, ualias.id > User.id), "SELECT users.id AS users_id, users.name AS users_name, " "users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users INNER JOIN users AS users_1 " "USE INDEX (col1_index,col2_index) " - "ON users_1.id > users.id", dialect=dialect + "ON users_1.id > users.id", + dialect=dialect, ) def test_statement_hints(self): User = self.classes.User sess = create_session() - stmt = sess.query(User).\ - with_statement_hint("test hint one").\ - with_statement_hint("test hint two").\ - with_statement_hint("test hint three", "postgresql") + stmt = ( + sess.query(User) + .with_statement_hint("test hint one") + .with_statement_hint("test hint two") + .with_statement_hint("test hint three", "postgresql") + ) self.assert_compile( stmt, @@ -3497,31 +3959,41 @@ class HintsTest(QueryTest, AssertsCompiledSQL): stmt, "SELECT users.id AS users_id, users.name AS users_name " "FROM users test hint one test hint two test hint three", - dialect='postgresql' + dialect="postgresql", ) class TextTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_fulltext(self): User = self.classes.User with expect_warnings("Textual SQL"): eq_( - create_session().query(User). - from_statement("select * from users order by id").all(), - [User(id=7), User(id=8), User(id=9), User(id=10)] + create_session() + .query(User) + .from_statement("select * from users order by id") + .all(), + [User(id=7), User(id=8), User(id=9), User(id=10)], ) eq_( - create_session().query(User).from_statement( - text("select * from users order by id")).first(), User(id=7) + create_session() + .query(User) + .from_statement(text("select * from users order by id")) + .first(), + User(id=7), ) eq_( - create_session().query(User).from_statement( - text("select * from users where name='nonexistent'")).first(), - None) + create_session() + .query(User) + .from_statement( + text("select * from users where name='nonexistent'") + ) + .first(), + None, + ) def test_fragment(self): User = self.classes.User @@ -3529,17 +4001,24 @@ class TextTest(QueryTest, AssertsCompiledSQL): with expect_warnings("Textual SQL expression"): eq_( create_session().query(User).filter("id in (8, 9)").all(), - [User(id=8), User(id=9)] - + [User(id=8), User(id=9)], ) eq_( - create_session().query(User).filter("name='fred'"). - filter("id=9").all(), [User(id=9)] + create_session() + .query(User) + .filter("name='fred'") + .filter("id=9") + .all(), + [User(id=9)], ) eq_( - create_session().query(User).filter("name='fred'"). - filter(User.id == 9).all(), [User(id=9)] + create_session() + .query(User) + .filter("name='fred'") + .filter(User.id == 9) + .all(), + [User(id=9)], ) def test_binds_coerce(self): @@ -3547,8 +4026,12 @@ class TextTest(QueryTest, AssertsCompiledSQL): with expect_warnings("Textual SQL expression"): eq_( - create_session().query(User).filter("id in (:id1, :id2)"). 
- params(id1=8, id2=9).all(), [User(id=8), User(id=9)] + create_session() + .query(User) + .filter("id in (:id1, :id2)") + .params(id1=8, id2=9) + .all(), + [User(id=8), User(id=9)], ) def test_as_column(self): @@ -3556,22 +4039,26 @@ class TextTest(QueryTest, AssertsCompiledSQL): s = create_session() assert_raises( - sa_exc.InvalidRequestError, s.query, - User.id, text("users.name")) + sa_exc.InvalidRequestError, s.query, User.id, text("users.name") + ) eq_( s.query(User.id, "name").order_by(User.id).all(), - [(7, 'jack'), (8, 'ed'), (9, 'fred'), (10, 'chuck')]) + [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")], + ) def test_via_select(self): User = self.classes.User s = create_session() eq_( - s.query(User).from_statement( - select([column('id'), column('name')]). - select_from(table('users')).order_by('id'), - ).all(), - [User(id=7), User(id=8), User(id=9), User(id=10)] + s.query(User) + .from_statement( + select([column("id"), column("name")]) + .select_from(table("users")) + .order_by("id") + ) + .all(), + [User(id=7), User(id=8), User(id=9), User(id=10)], ) def test_via_textasfrom_from_statement(self): @@ -3579,10 +4066,14 @@ class TextTest(QueryTest, AssertsCompiledSQL): s = create_session() eq_( - s.query(User).from_statement( - text("select * from users order by id"). - columns(id=Integer, name=String)).all(), - [User(id=7), User(id=8), User(id=9), User(id=10)] + s.query(User) + .from_statement( + text("select * from users order by id").columns( + id=Integer, name=String + ) + ) + .all(), + [User(id=7), User(id=8), User(id=9), User(id=10)], ) def test_via_textasfrom_use_mapped_columns(self): @@ -3590,10 +4081,14 @@ class TextTest(QueryTest, AssertsCompiledSQL): s = create_session() eq_( - s.query(User).from_statement( - text("select * from users order by id"). - columns(User.id, User.name)).all(), - [User(id=7), User(id=8), User(id=9), User(id=10)] + s.query(User) + .from_statement( + text("select * from users order by id").columns( + User.id, User.name + ) + ) + .all(), + [User(id=7), User(id=8), User(id=9), User(id=10)], ) def test_via_textasfrom_select_from(self): @@ -3601,10 +4096,13 @@ class TextTest(QueryTest, AssertsCompiledSQL): s = create_session() eq_( - s.query(User).select_from( + s.query(User) + .select_from( text("select * from users").columns(id=Integer, name=String) - ).order_by(User.id).all(), - [User(id=7), User(id=8), User(id=9), User(id=10)] + ) + .order_by(User.id) + .all(), + [User(id=7), User(id=8), User(id=9), User(id=10)], ) def test_group_by_accepts_text(self): @@ -3615,16 +4113,14 @@ class TextTest(QueryTest, AssertsCompiledSQL): self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name " - "FROM users GROUP BY name" + "FROM users GROUP BY name", ) def test_orm_columns_accepts_text(self): from sqlalchemy.orm.base import _orm_columns + t = text("x") - eq_( - _orm_columns(t), - [t] - ) + eq_(_orm_columns(t), [t]) def test_order_by_w_eager_one(self): User = self.classes.User @@ -3644,8 +4140,10 @@ class TextTest(QueryTest, AssertsCompiledSQL): # with expect_warnings("Can't resolve label reference 'name';"): self.assert_compile( - s.query(User).options(joinedload("addresses")). 
- order_by(desc("name")).limit(1), + s.query(User) + .options(joinedload("addresses")) + .order_by(desc("name")) + .limit(1), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " "addresses_1.id AS addresses_1_id, " @@ -3656,7 +4154,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): "DESC LIMIT :param_1) AS anon_1 " "LEFT OUTER JOIN addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " - "ORDER BY name DESC, addresses_1.id" + "ORDER BY name DESC, addresses_1.id", ) def test_order_by_w_eager_two(self): @@ -3665,8 +4163,10 @@ class TextTest(QueryTest, AssertsCompiledSQL): with expect_warnings("Can't resolve label reference 'name';"): self.assert_compile( - s.query(User).options(joinedload("addresses")). - order_by("name").limit(1), + s.query(User) + .options(joinedload("addresses")) + .order_by("name") + .limit(1), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " "addresses_1.id AS addresses_1_id, " @@ -3677,7 +4177,7 @@ class TextTest(QueryTest, AssertsCompiledSQL): "LIMIT :param_1) AS anon_1 " "LEFT OUTER JOIN addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " - "ORDER BY name, addresses_1.id" + "ORDER BY name, addresses_1.id", ) def test_order_by_w_eager_three(self): @@ -3685,8 +4185,10 @@ class TextTest(QueryTest, AssertsCompiledSQL): s = create_session() self.assert_compile( - s.query(User).options(joinedload("addresses")). - order_by("users_name").limit(1), + s.query(User) + .options(joinedload("addresses")) + .order_by("users_name") + .limit(1), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " "addresses_1.id AS addresses_1_id, " @@ -3697,14 +4199,16 @@ class TextTest(QueryTest, AssertsCompiledSQL): "LIMIT :param_1) AS anon_1 " "LEFT OUTER JOIN addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " - "ORDER BY anon_1.users_name, addresses_1.id" + "ORDER BY anon_1.users_name, addresses_1.id", ) # however! this works (again?) eq_( - s.query(User).options(joinedload("addresses")). - order_by("users_name").first(), - User(name='chuck', addresses=[]) + s.query(User) + .options(joinedload("addresses")) + .order_by("users_name") + .first(), + User(name="chuck", addresses=[]), ) def test_order_by_w_eager_four(self): @@ -3713,8 +4217,10 @@ class TextTest(QueryTest, AssertsCompiledSQL): s = create_session() self.assert_compile( - s.query(User).options(joinedload("addresses")). - order_by(desc("users_name")).limit(1), + s.query(User) + .options(joinedload("addresses")) + .order_by(desc("users_name")) + .limit(1), "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " "addresses_1.id AS addresses_1_id, " @@ -3725,14 +4231,16 @@ class TextTest(QueryTest, AssertsCompiledSQL): "LIMIT :param_1) AS anon_1 " "LEFT OUTER JOIN addresses AS addresses_1 " "ON anon_1.users_id = addresses_1.user_id " - "ORDER BY anon_1.users_name DESC, addresses_1.id" + "ORDER BY anon_1.users_name DESC, addresses_1.id", ) # however! this works (again?) eq_( - s.query(User).options(joinedload("addresses")). 
- order_by(desc("users_name")).first(), - User(name='jack', addresses=[Address()]) + s.query(User) + .options(joinedload("addresses")) + .order_by(desc("users_name")) + .first(), + User(name="jack", addresses=[Address()]), ) def test_order_by_w_eager_five(self): @@ -3746,22 +4254,31 @@ class TextTest(QueryTest, AssertsCompiledSQL): sess = create_session() - q = sess.query(User, Address.email_address.label('email_address')) + q = sess.query(User, Address.email_address.label("email_address")) - result = q.join('addresses').options(joinedload(User.orders)).\ - order_by( - "email_address desc").limit(1).offset(0) + result = ( + q.join("addresses") + .options(joinedload(User.orders)) + .order_by("email_address desc") + .limit(1) + .offset(0) + ) with expect_warnings( - "Can't resolve label reference 'email_address desc'"): + "Can't resolve label reference 'email_address desc'" + ): eq_( [ - (User( - id=7, - orders=[Order(id=1), Order(id=3), Order(id=5)], - addresses=[Address(id=1)] - ), 'jack@bean.com') + ( + User( + id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + addresses=[Address(id=1)], + ), + "jack@bean.com", + ) ], - result.all()) + result.all(), + ) class TextWarningTest(QueryTest, AssertsCompiledSQL): @@ -3769,10 +4286,10 @@ class TextWarningTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa.exc.SAWarning, r"Textual (?:SQL|column|SQL FROM) expression %(stmt)r should be " - r"explicitly declared (?:with|as) text\(%(stmt)r\)" % { - "stmt": util.ellipses_string(offending_clause), - }, - fn, arg + r"explicitly declared (?:with|as) text\(%(stmt)r\)" + % {"stmt": util.ellipses_string(offending_clause)}, + fn, + arg, ) with expect_warnings("Textual "): @@ -3782,15 +4299,19 @@ class TextWarningTest(QueryTest, AssertsCompiledSQL): def test_filter(self): User = self.classes.User self._test( - Session().query(User.id).filter, "myid == 5", "myid == 5", - "SELECT users.id AS users_id FROM users WHERE myid == 5" + Session().query(User.id).filter, + "myid == 5", + "myid == 5", + "SELECT users.id AS users_id FROM users WHERE myid == 5", ) def test_having(self): User = self.classes.User self._test( - Session().query(User.id).having, "myid == 5", "myid == 5", - "SELECT users.id AS users_id FROM users HAVING myid == 5" + Session().query(User.id).having, + "myid == 5", + "myid == 5", + "SELECT users.id AS users_id FROM users HAVING myid == 5", ) def test_from_statement(self): @@ -3804,40 +4325,56 @@ class TextWarningTest(QueryTest, AssertsCompiledSQL): class ParentTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_o2m(self): User, orders, Order = ( - self.classes.User, self.tables.orders, self.classes.Order) + self.classes.User, + self.tables.orders, + self.classes.Order, + ) sess = create_session() q = sess.query(User) - u1 = q.filter_by(name='jack').one() + u1 = q.filter_by(name="jack").one() # test auto-lookup of property o = sess.query(Order).with_parent(u1).all() - assert [Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")] == o + assert [ + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ] == o # test with explicit property - o = sess.query(Order).with_parent(u1, property='orders').all() - assert [Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")] == o + o = sess.query(Order).with_parent(u1, property="orders").all() + assert [ + Order(description="order 1"), + Order(description="order 3"), + 
Order(description="order 5"), + ] == o o = sess.query(Order).with_parent(u1, property=User.orders).all() - assert [Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")] == o + assert [ + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ] == o o = sess.query(Order).filter(with_parent(u1, User.orders)).all() assert [ - Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")] == o + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ] == o # test generative criterion o = sess.query(Order).with_parent(u1).filter(orders.c.id > 2).all() assert [ - Order(description="order 3"), Order(description="order 5")] == o + Order(description="order 3"), + Order(description="order 5"), + ] == o # test against None for parent? this can't be done with the current # API since we don't know what mapper to use @@ -3857,7 +4394,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): "addresses.user_id AS addresses_user_id, " "addresses.email_address AS addresses_email_address " "FROM addresses WHERE :param_1 = addresses.user_id", - {'param_1': 7} + {"param_1": 7}, ) def test_from_entity_standalone_fn(self): @@ -3866,7 +4403,8 @@ class ParentTest(QueryTest, AssertsCompiledSQL): sess = create_session() u1 = sess.query(User).get(7) q = sess.query(User, Address).filter( - with_parent(u1, "addresses", from_entity=Address)) + with_parent(u1, "addresses", from_entity=Address) + ) self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name, " @@ -3875,7 +4413,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): "addresses.email_address AS addresses_email_address " "FROM users, addresses " "WHERE :param_1 = addresses.user_id", - {'param_1': 7} + {"param_1": 7}, ) def test_from_entity_query_entity(self): @@ -3884,7 +4422,8 @@ class ParentTest(QueryTest, AssertsCompiledSQL): sess = create_session() u1 = sess.query(User).get(7) q = sess.query(User, Address).with_parent( - u1, "addresses", from_entity=Address) + u1, "addresses", from_entity=Address + ) self.assert_compile( q, "SELECT users.id AS users_id, users.name AS users_name, " @@ -3893,7 +4432,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): "addresses.email_address AS addresses_email_address " "FROM users, addresses " "WHERE :param_1 = addresses.user_id", - {'param_1': 7} + {"param_1": 7}, ) def test_select_from_alias(self): @@ -3910,7 +4449,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): "addresses_1.email_address AS addresses_1_email_address " "FROM addresses AS addresses_1 " "WHERE :param_1 = addresses_1.user_id", - {'param_1': 7} + {"param_1": 7}, ) def test_select_from_alias_explicit_prop(self): @@ -3927,7 +4466,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): "addresses_1.email_address AS addresses_1_email_address " "FROM addresses AS addresses_1 " "WHERE :param_1 = addresses_1.user_id", - {'param_1': 7} + {"param_1": 7}, ) def test_noparent(self): @@ -3936,15 +4475,16 @@ class ParentTest(QueryTest, AssertsCompiledSQL): sess = create_session() q = sess.query(User) - u1 = q.filter_by(name='jack').one() + u1 = q.filter_by(name="jack").one() try: q = sess.query(Item).with_parent(u1) assert False except sa_exc.InvalidRequestError as e: - assert str(e) \ - == "Could not locate a property which relates "\ + assert ( + str(e) == "Could not locate a property which relates " "instances of class 'Item' to instances of class 'User'" + ) def test_m2m(self): Item, Keyword = 
self.classes.Item, self.classes.Keyword @@ -3953,8 +4493,10 @@ class ParentTest(QueryTest, AssertsCompiledSQL): i1 = sess.query(Item).filter_by(id=2).one() k = sess.query(Keyword).with_parent(i1).all() assert [ - Keyword(name='red'), Keyword(name='small'), - Keyword(name='square')] == k + Keyword(name="red"), + Keyword(name="small"), + Keyword(name="square"), + ] == k def test_with_transient(self): User, Order = self.classes.User, self.classes.Order @@ -3962,22 +4504,26 @@ class ParentTest(QueryTest, AssertsCompiledSQL): sess = Session() q = sess.query(User) - u1 = q.filter_by(name='jack').one() + u1 = q.filter_by(name="jack").one() utrans = User(id=u1.id) - o = sess.query(Order).with_parent(utrans, 'orders') + o = sess.query(Order).with_parent(utrans, "orders") eq_( [ - Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")], - o.all() + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ], + o.all(), ) - o = sess.query(Order).filter(with_parent(utrans, 'orders')) + o = sess.query(Order).filter(with_parent(utrans, "orders")) eq_( [ - Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")], - o.all() + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ], + o.all(), ) def test_with_pending_autoflush(self): @@ -3989,12 +4535,12 @@ class ParentTest(QueryTest, AssertsCompiledSQL): opending = Order(id=20, user_id=o1.user_id) sess.add(opending) eq_( - sess.query(User).with_parent(opending, 'user').one(), - User(id=o1.user_id) + sess.query(User).with_parent(opending, "user").one(), + User(id=o1.user_id), ) eq_( - sess.query(User).filter(with_parent(opending, 'user')).one(), - User(id=o1.user_id) + sess.query(User).filter(with_parent(opending, "user")).one(), + User(id=o1.user_id), ) def test_with_pending_no_autoflush(self): @@ -4006,8 +4552,8 @@ class ParentTest(QueryTest, AssertsCompiledSQL): opending = Order(user_id=o1.user_id) sess.add(opending) eq_( - sess.query(User).with_parent(opending, 'user').one(), - User(id=o1.user_id) + sess.query(User).with_parent(opending, "user").one(), + User(id=o1.user_id), ) def test_unique_binds_union(self): @@ -4017,8 +4563,8 @@ class ParentTest(QueryTest, AssertsCompiledSQL): sess = Session() u1, u2 = sess.query(User).order_by(User.id)[0:2] - q1 = sess.query(Address).with_parent(u1, 'addresses') - q2 = sess.query(Address).with_parent(u2, 'addresses') + q1 = sess.query(Address).with_parent(u1, "addresses") + q2 = sess.query(Address).with_parent(u2, "addresses") self.assert_compile( q1.union(q2), @@ -4033,7 +4579,7 @@ class ParentTest(QueryTest, AssertsCompiledSQL): "addresses_user_id, addresses.email_address " "AS addresses_email_address " "FROM addresses WHERE :param_2 = addresses.user_id) AS anon_1", - checkparams={'param_1': 7, 'param_2': 8}, + checkparams={"param_1": 7, "param_2": 8}, ) def test_unique_binds_or(self): @@ -4044,32 +4590,39 @@ class ParentTest(QueryTest, AssertsCompiledSQL): self.assert_compile( sess.query(Address).filter( - or_(with_parent(u1, 'addresses'), with_parent(u2, 'addresses')) + or_(with_parent(u1, "addresses"), with_parent(u2, "addresses")) ), "SELECT addresses.id AS addresses_id, addresses.user_id AS " "addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE " ":param_1 = addresses.user_id OR :param_2 = addresses.user_id", - checkparams={'param_1': 7, 'param_2': 8}, + checkparams={"param_1": 7, "param_2": 8}, ) class 
WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): run_inserts = None - __dialect__ = 'default' + __dialect__ = "default" def _fixture1(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User), - 'special_user': relationship( - User, primaryjoin=and_( - users.c.id == addresses.c.user_id, - users.c.name == addresses.c.email_address)) - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User), + "special_user": relationship( + User, + primaryjoin=and_( + users.c.id == addresses.c.user_id, + users.c.name == addresses.c.email_address, + ), + ), + }, + ) def test_filter_with_transient_dont_assume_pk(self): self._fixture1() @@ -4082,7 +4635,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): sa_exc.StatementError, "Can't resolve value for column users.id on object " ".User at .*; no value has been set for this column", - q.all + q.all, ) def test_filter_with_transient_given_pk(self): @@ -4099,7 +4652,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): "addresses.user_id AS addresses_user_id, " "addresses.email_address AS addresses_email_address " "FROM addresses WHERE :param_1 = addresses.user_id", - checkparams={'param_1': None} + checkparams={"param_1": None}, ) def test_filter_with_transient_given_pk_but_only_later(self): @@ -4121,7 +4674,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): "addresses.user_id AS addresses_user_id, " "addresses.email_address AS addresses_email_address " "FROM addresses WHERE :param_1 = addresses.user_id", - checkparams={'param_1': None} + checkparams={"param_1": None}, ) def test_filter_with_transient_warn_for_none_against_non_pk(self): @@ -4130,7 +4683,8 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): s = Session() q = s.query(Address).filter( - Address.special_user == User(id=None, name=None)) + Address.special_user == User(id=None, name=None) + ) with expect_warnings("Got None for value of column"): self.assert_compile( @@ -4140,7 +4694,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): "addresses.email_address AS addresses_email_address " "FROM addresses WHERE :param_1 = addresses.user_id " "AND :param_2 = addresses.email_address", - checkparams={"param_1": None, "param_2": None} + checkparams={"param_1": None, "param_2": None}, ) def test_with_parent_with_transient_assume_pk(self): @@ -4155,7 +4709,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): q, "SELECT users.id AS users_id, users.name AS users_name " "FROM users WHERE users.id = :param_1", - checkparams={'param_1': None} + checkparams={"param_1": None}, ) def test_with_parent_with_transient_warn_for_none_against_non_pk(self): @@ -4164,7 +4718,8 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): s = Session() q = s.query(User).with_parent( - Address(user_id=None, email_address=None), "special_user") + Address(user_id=None, email_address=None), "special_user" + ) with expect_warnings("Got None for value of column"): self.assert_compile( @@ -4172,7 +4727,7 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name " "FROM users WHERE users.id = :param_1 " "AND users.name = :param_2", - checkparams={"param_1": None, "param_2": None} + checkparams={"param_1": None, "param_2": None}, ) 
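The transformation black -l 79 applies throughout these hunks follows one pattern: a backslash-continued method chain becomes a single parenthesized chain with one call per line, string literals are normalized to double quotes, and multi-line call arguments gain a trailing comma. A minimal sketch of the before/after shape, assuming the User/Address fixtures, the open `sess`, and the `desc` import that these tests already use:

    # Before: trailing-dot chain held together with "\" continuations
    q = sess.query(User).join('addresses').distinct().\
        order_by(desc(Address.email_address))

    # After black -l 79: the whole chain is wrapped in parentheses, one
    # call per line, with double-quoted strings and a magic trailing
    # comma on multi-line argument lists; no behavioral change, layout only
    q = (
        sess.query(User)
        .join("addresses")
        .distinct()
        .order_by(desc(Address.email_address))
    )

The parenthesized form is what keeps most chains under the 79-column limit here; the remaining too-long lines (for example, long tuple-unpacking targets, which black cannot split) are left for the follow-up flake8 commit described in the commit message.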
def test_negated_contains_or_equals_plain_m2o(self): @@ -4184,14 +4739,13 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): with expect_warnings("Got None for value of column"): self.assert_compile( q, - "SELECT addresses.id AS addresses_id, " "addresses.user_id AS addresses_user_id, " "addresses.email_address AS addresses_email_address " "FROM addresses " "WHERE addresses.user_id != :user_id_1 " "OR addresses.user_id IS NULL", - checkparams={'user_id_1': None} + checkparams={"user_id_1": None}, ) def test_negated_contains_or_equals_complex_rel(self): @@ -4213,40 +4767,62 @@ class WithTransientOnNone(_fixtures.FixtureTest, AssertsCompiledSQL): "FROM users " "WHERE users.id = addresses.user_id AND " "users.name = addresses.email_address AND users.id IS NULL))", - checkparams={} + checkparams={}, ) class SynonymTest(QueryTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_mappers(cls): - users, Keyword, items, order_items, orders, Item, User, \ - Address, keywords, Order, item_keywords, addresses = \ - cls.tables.users, cls.classes.Keyword, cls.tables.items, \ - cls.tables.order_items, cls.tables.orders, \ - cls.classes.Item, cls.classes.User, cls.classes.Address, \ - cls.tables.keywords, cls.classes.Order, \ - cls.tables.item_keywords, cls.tables.addresses - - mapper(User, users, properties={ - 'name_syn': synonym('name'), - 'addresses': relationship(Address), - 'orders': relationship( - Order, backref='user', order_by=orders.c.id), # o2m, m2o - 'orders_syn': synonym('orders'), - 'orders_syn_2': synonym('orders_syn') - }) + users, Keyword, items, order_items, orders, Item, User, Address, keywords, Order, item_keywords, addresses = ( + cls.tables.users, + cls.classes.Keyword, + cls.tables.items, + cls.tables.order_items, + cls.tables.orders, + cls.classes.Item, + cls.classes.User, + cls.classes.Address, + cls.tables.keywords, + cls.classes.Order, + cls.tables.item_keywords, + cls.tables.addresses, + ) + + mapper( + User, + users, + properties={ + "name_syn": synonym("name"), + "addresses": relationship(Address), + "orders": relationship( + Order, backref="user", order_by=orders.c.id + ), # o2m, m2o + "orders_syn": synonym("orders"), + "orders_syn_2": synonym("orders_syn"), + }, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items), # m2m - 'address': relationship(Address), # m2o - 'items_syn': synonym('items') - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords) # m2m - }) + mapper( + Order, + orders, + properties={ + "items": relationship(Item, secondary=order_items), # m2m + "address": relationship(Address), # m2o + "items_syn": synonym("items"), + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords + ) # m2m + }, + ) mapper(Keyword, keywords) def test_options(self): @@ -4255,15 +4831,27 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): s = create_session() def go(): - result = s.query(User).filter_by(name='jack').\ - options(joinedload(User.orders_syn)).all() - eq_(result, [ - User(id=7, name='jack', orders=[ - Order(description='order 1'), - Order(description='order 3'), - Order(description='order 5') - ]) - ]) + result = ( + s.query(User) + .filter_by(name="jack") + .options(joinedload(User.orders_syn)) + .all() + ) + eq_( + result, + [ + User( + id=7, + name="jack", + orders=[ + Order(description="order 1"), + Order(description="order 3"), 
+ Order(description="order 5"), + ], + ) + ], + ) + self.assert_sql_count(testing.db, go, 1) def test_options_syn_of_syn(self): @@ -4272,15 +4860,27 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): s = create_session() def go(): - result = s.query(User).filter_by(name='jack').\ - options(joinedload(User.orders_syn_2)).all() - eq_(result, [ - User(id=7, name='jack', orders=[ - Order(description='order 1'), - Order(description='order 3'), - Order(description='order 5') - ]) - ]) + result = ( + s.query(User) + .filter_by(name="jack") + .options(joinedload(User.orders_syn_2)) + .all() + ) + eq_( + result, + [ + User( + id=7, + name="jack", + orders=[ + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ], + ) + ], + ) + self.assert_sql_count(testing.db, go, 1) def test_options_syn_of_syn_string(self): @@ -4289,54 +4889,69 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): s = create_session() def go(): - result = s.query(User).filter_by(name='jack').\ - options(joinedload('orders_syn_2')).all() - eq_(result, [ - User(id=7, name='jack', orders=[ - Order(description='order 1'), - Order(description='order 3'), - Order(description='order 5') - ]) - ]) + result = ( + s.query(User) + .filter_by(name="jack") + .options(joinedload("orders_syn_2")) + .all() + ) + eq_( + result, + [ + User( + id=7, + name="jack", + orders=[ + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ], + ) + ], + ) + self.assert_sql_count(testing.db, go, 1) def test_joins(self): User, Order = self.classes.User, self.classes.Order for j in ( - ['orders', 'items'], - ['orders_syn', 'items'], + ["orders", "items"], + ["orders_syn", "items"], [User.orders_syn, Order.items], - ['orders_syn_2', 'items'], - [User.orders_syn_2, 'items'], - ['orders', 'items_syn'], - ['orders_syn', 'items_syn'], - ['orders_syn_2', 'items_syn'], + ["orders_syn_2", "items"], + [User.orders_syn_2, "items"], + ["orders", "items_syn"], + ["orders_syn", "items_syn"], + ["orders_syn_2", "items_syn"], ): - result = create_session().query(User).join(*j).filter_by(id=3). 
\ - all() - assert [User(id=7, name='jack'), User(id=9, name='fred')] == result + result = ( + create_session().query(User).join(*j).filter_by(id=3).all() + ) + assert [User(id=7, name="jack"), User(id=9, name="fred")] == result def test_with_parent(self): Order, User = self.classes.Order, self.classes.User for nameprop, orderprop in ( - ('name', 'orders'), - ('name_syn', 'orders'), - ('name', 'orders_syn'), - ('name', 'orders_syn_2'), - ('name_syn', 'orders_syn'), - ('name_syn', 'orders_syn_2'), + ("name", "orders"), + ("name_syn", "orders"), + ("name", "orders_syn"), + ("name", "orders_syn_2"), + ("name_syn", "orders_syn"), + ("name_syn", "orders_syn_2"), ): sess = create_session() q = sess.query(User) - u1 = q.filter_by(**{nameprop: 'jack'}).one() + u1 = q.filter_by(**{nameprop: "jack"}).one() o = sess.query(Order).with_parent(u1, property=orderprop).all() assert [ - Order(description="order 1"), Order(description="order 3"), - Order(description="order 5")] == o + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ] == o def test_froms_aliased_col(self): Address, User = self.classes.Address, self.classes.User @@ -4344,30 +4959,30 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): sess = create_session() ua = aliased(User) - q = sess.query(ua.name_syn).join( - Address, ua.id == Address.user_id) + q = sess.query(ua.name_syn).join(Address, ua.id == Address.user_id) self.assert_compile( q, "SELECT users_1.name AS users_1_name FROM " - "users AS users_1 JOIN addresses ON users_1.id = addresses.user_id" + "users AS users_1 JOIN addresses ON users_1.id = addresses.user_id", ) class ImmediateTest(_fixtures.FixtureTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def setup_mappers(cls): - Address, addresses, users, User = (cls.classes.Address, - cls.tables.addresses, - cls.tables.users, - cls.classes.User) + Address, addresses, users, User = ( + cls.classes.Address, + cls.tables.addresses, + cls.tables.users, + cls.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address))) + mapper(User, users, properties=dict(addresses=relationship(Address))) def test_one(self): User, Address = self.classes.User, self.classes.Address @@ -4377,47 +4992,65 @@ class ImmediateTest(_fixtures.FixtureTest): assert_raises_message( sa.orm.exc.NoResultFound, r"No row was found for one\(\)", - sess.query(User).filter(User.id == 99).one) + sess.query(User).filter(User.id == 99).one, + ) eq_(sess.query(User).filter(User.id == 7).one().id, 7) assert_raises_message( sa.orm.exc.MultipleResultsFound, r"Multiple rows were found for one\(\)", - sess.query(User).one) + sess.query(User).one, + ) assert_raises( sa.orm.exc.NoResultFound, - sess.query(User.id, User.name).filter(User.id == 99).one) + sess.query(User.id, User.name).filter(User.id == 99).one, + ) - eq_(sess.query(User.id, User.name).filter(User.id == 7).one(), - (7, 'jack')) + eq_( + sess.query(User.id, User.name).filter(User.id == 7).one(), + (7, "jack"), + ) assert_raises( - sa.orm.exc.MultipleResultsFound, - sess.query(User.id, User.name).one) + sa.orm.exc.MultipleResultsFound, sess.query(User.id, User.name).one + ) assert_raises( sa.orm.exc.NoResultFound, - (sess.query(User, Address).join(User.addresses). - filter(Address.id == 99)).one) + ( + sess.query(User, Address) + .join(User.addresses) + .filter(Address.id == 99) + ).one, + ) - eq_((sess.query(User, Address). - join(User.addresses). 
- filter(Address.id == 4)).one(), - (User(id=8), Address(id=4))) + eq_( + ( + sess.query(User, Address) + .join(User.addresses) + .filter(Address.id == 4) + ).one(), + (User(id=8), Address(id=4)), + ) assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User, Address).join(User.addresses).one) + sess.query(User, Address).join(User.addresses).one, + ) # this result returns multiple rows, the first # two rows being the same. but uniquing is # not applied for a column based result. assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User.id).join(User.addresses). - filter(User.id.in_([8, 9])).order_by(User.id).one) + sess.query(User.id) + .join(User.addresses) + .filter(User.id.in_([8, 9])) + .order_by(User.id) + .one, + ) # test that a join which ultimately returns # multiple identities across many rows still @@ -4426,8 +5059,12 @@ class ImmediateTest(_fixtures.FixtureTest): # is applied ([ticket:1688]) assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User).join(User.addresses).filter(User.id.in_([8, 9])). - order_by(User.id).one) + sess.query(User) + .join(User.addresses) + .filter(User.id.in_([8, 9])) + .order_by(User.id) + .one, + ) def test_one_or_none(self): User, Address = self.classes.User, self.classes.Address @@ -4441,38 +5078,58 @@ class ImmediateTest(_fixtures.FixtureTest): assert_raises_message( sa.orm.exc.MultipleResultsFound, r"Multiple rows were found for one_or_none\(\)", - sess.query(User).one_or_none) + sess.query(User).one_or_none, + ) - eq_(sess.query(User.id, User.name).filter(User.id == 99).one_or_none(), - None) + eq_( + sess.query(User.id, User.name).filter(User.id == 99).one_or_none(), + None, + ) - eq_(sess.query(User.id, User.name).filter(User.id == 7).one_or_none(), - (7, 'jack')) + eq_( + sess.query(User.id, User.name).filter(User.id == 7).one_or_none(), + (7, "jack"), + ) assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User.id, User.name).one_or_none) + sess.query(User.id, User.name).one_or_none, + ) eq_( - (sess.query(User, Address).join(User.addresses). - filter(Address.id == 99)).one_or_none(), None) + ( + sess.query(User, Address) + .join(User.addresses) + .filter(Address.id == 99) + ).one_or_none(), + None, + ) - eq_((sess.query(User, Address). - join(User.addresses). - filter(Address.id == 4)).one_or_none(), - (User(id=8), Address(id=4))) + eq_( + ( + sess.query(User, Address) + .join(User.addresses) + .filter(Address.id == 4) + ).one_or_none(), + (User(id=8), Address(id=4)), + ) assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User, Address).join(User.addresses).one_or_none) + sess.query(User, Address).join(User.addresses).one_or_none, + ) # this result returns multiple rows, the first # two rows being the same. but uniquing is # not applied for a column based result. assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User.id).join(User.addresses). - filter(User.id.in_([8, 9])).order_by(User.id).one_or_none) + sess.query(User.id) + .join(User.addresses) + .filter(User.id.in_([8, 9])) + .order_by(User.id) + .one_or_none, + ) # test that a join which ultimately returns # multiple identities across many rows still @@ -4481,8 +5138,12 @@ class ImmediateTest(_fixtures.FixtureTest): # is applied ([ticket:1688]) assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User).join(User.addresses).filter(User.id.in_([8, 9])). 
- order_by(User.id).one_or_none) + sess.query(User) + .join(User.addresses) + .filter(User.id.in_([8, 9])) + .order_by(User.id) + .one_or_none, + ) @testing.future def test_getslice(self): @@ -4496,13 +5157,16 @@ class ImmediateTest(_fixtures.FixtureTest): eq_(sess.query(User.id).filter_by(id=7).scalar(), 7) eq_(sess.query(User.id, User.name).filter_by(id=7).scalar(), 7) eq_(sess.query(User.id).filter_by(id=0).scalar(), None) - eq_(sess.query(User).filter_by(id=7).scalar(), - sess.query(User).filter_by(id=7).one()) + eq_( + sess.query(User).filter_by(id=7).scalar(), + sess.query(User).filter_by(id=7).one(), + ) assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User).scalar) assert_raises( sa.orm.exc.MultipleResultsFound, - sess.query(User.id, User.name).scalar) + sess.query(User.id, User.name).scalar, + ) def test_value(self): User = self.classes.User @@ -4514,11 +5178,10 @@ class ImmediateTest(_fixtures.FixtureTest): eq_(sess.query(User).filter_by(id=0).value(User.id), None) sess.bind = testing.db - eq_(sess.query().value(sa.literal_column('1').label('x')), 1) + eq_(sess.query().value(sa.literal_column("1").label("x")), 1) class ExecutionOptionsTest(QueryTest): - def test_option_building(self): User = self.classes.User @@ -4526,34 +5189,35 @@ class ExecutionOptionsTest(QueryTest): q1 = sess.query(User) assert q1._execution_options == dict() - q2 = q1.execution_options(foo='bar', stream_results=True) + q2 = q1.execution_options(foo="bar", stream_results=True) # q1's options should be unchanged. assert q1._execution_options == dict() # q2 should have them set. - assert q2._execution_options == dict(foo='bar', stream_results=True) - q3 = q2.execution_options(foo='not bar', answer=42) - assert q2._execution_options == dict(foo='bar', stream_results=True) + assert q2._execution_options == dict(foo="bar", stream_results=True) + q3 = q2.execution_options(foo="not bar", answer=42) + assert q2._execution_options == dict(foo="bar", stream_results=True) - q3_options = dict(foo='not bar', stream_results=True, answer=42) + q3_options = dict(foo="not bar", stream_results=True, answer=42) assert q3._execution_options == q3_options def test_options_in_connection(self): User = self.classes.User - execution_options = dict(foo='bar', stream_results=True) + execution_options = dict(foo="bar", stream_results=True) class TQuery(Query): def instances(self, result, ctx): try: eq_( - result.connection._execution_options, - execution_options) + result.connection._execution_options, execution_options + ) finally: result.close() return iter([]) sess = create_session( - bind=testing.db, autocommit=False, query_cls=TQuery) + bind=testing.db, autocommit=False, query_cls=TQuery + ) q1 = sess.query(User).execution_options(**execution_options) q1.all() @@ -4569,52 +5233,51 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_one(self): s = Session() - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( s.query(c).filter(c), "SELECT x WHERE x", - dialect=self._dialect(True) + dialect=self._dialect(True), ) def test_two(self): s = Session() - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( s.query(c).filter(c), "SELECT x WHERE x = 1", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_three(self): s = Session() - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( s.query(c).filter(~c), "SELECT x WHERE x = 0", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_four(self): s = 
Session() - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( s.query(c).filter(~c), "SELECT x WHERE NOT x", - dialect=self._dialect(True) + dialect=self._dialect(True), ) def test_five(self): s = Session() - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( s.query(c).having(c), "SELECT x HAVING x = 1", - dialect=self._dialect(False) + dialect=self._dialect(False), ) class SessionBindTest(QueryTest): - @contextlib.contextmanager def _assert_bind_args(self, session): get_bind = mock.Mock(side_effect=session.get_bind) @@ -4622,7 +5285,7 @@ class SessionBindTest(QueryTest): yield for call_ in get_bind.mock_calls: is_(call_[1][0], inspect(self.classes.User)) - is_not_(call_[2]['clause'], None) + is_not_(call_[2]["clause"], None) def test_single_entity_q(self): User = self.classes.User @@ -4653,28 +5316,32 @@ class SessionBindTest(QueryTest): session = Session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).update( - {"name": "foob"}, synchronize_session=False) + {"name": "foob"}, synchronize_session=False + ) def test_bulk_delete_no_sync(self): User = self.classes.User session = Session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).delete( - synchronize_session=False) + synchronize_session=False + ) def test_bulk_update_fetch_sync(self): User = self.classes.User session = Session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).update( - {"name": "foob"}, synchronize_session='fetch') + {"name": "foob"}, synchronize_session="fetch" + ) def test_bulk_delete_fetch_sync(self): User = self.classes.User session = Session() with self._assert_bind_args(session): session.query(User).filter(User.id == 15).delete( - synchronize_session='fetch') + synchronize_session="fetch" + ) def test_column_property(self): User = self.classes.User @@ -4682,12 +5349,12 @@ class SessionBindTest(QueryTest): mapper = inspect(User) mapper.add_property( "score", - column_property(func.coalesce(self.tables.users.c.name, None))) + column_property(func.coalesce(self.tables.users.c.name, None)), + ) session = Session() with self._assert_bind_args(session): session.query(func.max(User.score)).scalar() - @testing.requires.nested_aggregates def test_column_property_select(self): User = self.classes.User @@ -4697,9 +5364,10 @@ class SessionBindTest(QueryTest): mapper.add_property( "score", column_property( - select([func.sum(Address.id)]). 
- where(Address.user_id == User.id).as_scalar() - ) + select([func.sum(Address.id)]) + .where(Address.user_id == User.id) + .as_scalar() + ), ) session = Session() @@ -4711,6 +5379,7 @@ class QueryClsTest(QueryTest): def _fn_fixture(self): def query(*arg, **kw): return Query(*arg, **kw) + return query def _subclass_fixture(self): @@ -4740,7 +5409,7 @@ class QueryClsTest(QueryTest): assert u is u2 def _test_o2m_lazyload(self, fixture): - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") s = Session(query_cls=fixture()) @@ -4748,7 +5417,7 @@ class QueryClsTest(QueryTest): eq_(u1.addresses, [Address(id=1)]) def _test_m2o_lazyload(self, fixture): - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") s = Session(query_cls=fixture()) @@ -4756,20 +5425,20 @@ class QueryClsTest(QueryTest): eq_(a1.user, User(id=7)) def _test_expr(self, fixture): - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") s = Session(query_cls=fixture()) - q = s.query(func.max(User.id).label('max')) + q = s.query(func.max(User.id).label("max")) eq_(q.scalar(), 10) def _test_expr_undocumented_query_constructor(self, fixture): # see #4269. not documented but already out there. - User, Address = self.classes('User', 'Address') + User, Address = self.classes("User", "Address") s = Session(query_cls=fixture()) - q = Query(func.max(User.id).label('max')).with_session(s) + q = Query(func.max(User.id).label("max")).with_session(s) eq_(q.scalar(), 10) def test_plain_get(self): @@ -4800,12 +5469,10 @@ class QueryClsTest(QueryTest): self._test_expr_undocumented_query_constructor(self._plain_fixture) def test_callable_expr_undocumented_query_constructor(self): - self._test_expr_undocumented_query_constructor( - self._callable_fixture) + self._test_expr_undocumented_query_constructor(self._callable_fixture) def test_subclass_expr_undocumented_query_constructor(self): - self._test_expr_undocumented_query_constructor( - self._subclass_fixture) + self._test_expr_undocumented_query_constructor(self._subclass_fixture) def test_fn_expr_undocumented_query_constructor(self): self._test_expr_undocumented_query_constructor(self._fn_fixture) diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 0850501e21..7aac4cecad 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -1,9 +1,25 @@ -from sqlalchemy.testing import assert_raises_message, eq_, \ - AssertsCompiledSQL, is_ +from sqlalchemy.testing import ( + assert_raises_message, + eq_, + AssertsCompiledSQL, + is_, +) from sqlalchemy.testing import fixtures from sqlalchemy.orm import relationships, foreign, remote, relationship -from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, \ - select, ForeignKeyConstraint, exc, func, and_, String, Boolean +from sqlalchemy import ( + MetaData, + Table, + Column, + ForeignKey, + Integer, + select, + ForeignKeyConstraint, + exc, + func, + and_, + String, + Boolean, +) from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY from sqlalchemy.testing import mock @@ -12,118 +28,161 @@ class _JoinFixtures(object): @classmethod def setup_class(cls): m = MetaData() - cls.left = Table('lft', m, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) - cls.right = Table('rgt', m, - Column('id', Integer, primary_key=True), - Column('lid', Integer, ForeignKey('lft.id')), - Column('x', Integer), - Column('y', Integer)) - cls.right_multi_fk 
= Table('rgt_multi_fk', m, - Column('id', Integer, primary_key=True), - Column('lid1', Integer, - ForeignKey('lft.id')), - Column('lid2', Integer, - ForeignKey('lft.id'))) - - cls.selfref = Table('selfref', m, - Column('id', Integer, primary_key=True), - Column('sid', Integer, ForeignKey('selfref.id'))) - cls.composite_selfref = Table('composite_selfref', m, - Column('id', Integer, primary_key=True), - Column('group_id', Integer, - primary_key=True), - Column('parent_id', Integer), - ForeignKeyConstraint( - ['parent_id', 'group_id'], - ['composite_selfref.id', - 'composite_selfref.group_id'])) - cls.m2mleft = Table('m2mlft', m, - Column('id', Integer, primary_key=True)) - cls.m2mright = Table('m2mrgt', m, - Column('id', Integer, primary_key=True)) - cls.m2msecondary = Table('m2msecondary', m, - Column('lid', Integer, ForeignKey( - 'm2mlft.id'), primary_key=True), - Column('rid', Integer, ForeignKey( - 'm2mrgt.id'), primary_key=True)) - cls.m2msecondary_no_fks = Table('m2msecondary_no_fks', m, - Column('lid', Integer, - primary_key=True), - Column('rid', Integer, - primary_key=True)) - cls.m2msecondary_ambig_fks = Table('m2msecondary_ambig_fks', m, - Column('lid1', Integer, ForeignKey( - 'm2mlft.id'), primary_key=True), - Column('rid1', Integer, ForeignKey( - 'm2mrgt.id'), primary_key=True), - Column('lid2', Integer, ForeignKey( - 'm2mlft.id'), primary_key=True), - Column('rid2', Integer, ForeignKey( - 'm2mrgt.id'), primary_key=True)) - cls.base_w_sub_rel = Table('base_w_sub_rel', m, - Column('id', Integer, primary_key=True), - Column('sub_id', Integer, - ForeignKey('rel_sub.id'))) - cls.rel_sub = Table('rel_sub', m, - Column('id', Integer, - ForeignKey('base_w_sub_rel.id'), - primary_key=True)) - cls.base = Table('base', m, - Column('id', Integer, primary_key=True), - Column('flag', Boolean)) - cls.sub = Table('sub', m, - Column('id', Integer, ForeignKey('base.id'), - primary_key=True)) - cls.sub_w_base_rel = Table('sub_w_base_rel', m, - Column('id', Integer, ForeignKey('base.id'), - primary_key=True), - Column('base_id', Integer, - ForeignKey('base.id'))) - cls.sub_w_sub_rel = Table('sub_w_sub_rel', m, - Column('id', Integer, ForeignKey('base.id'), - primary_key=True), - Column('sub_id', Integer, - ForeignKey('sub.id')) - ) - cls.right_w_base_rel = Table('right_w_base_rel', m, - Column('id', Integer, primary_key=True), - Column('base_id', Integer, - ForeignKey('base.id'))) - - cls.three_tab_a = Table('three_tab_a', m, - Column('id', Integer, primary_key=True)) - cls.three_tab_b = Table('three_tab_b', m, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey( - 'three_tab_a.id'))) - cls.three_tab_c = Table('three_tab_c', m, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey( - 'three_tab_a.id')), - Column('bid', Integer, ForeignKey( - 'three_tab_b.id'))) - - cls.composite_target = Table('composite_target', m, - Column('uid', Integer, primary_key=True), - Column('oid', Integer, primary_key=True)) + cls.left = Table( + "lft", + m, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + cls.right = Table( + "rgt", + m, + Column("id", Integer, primary_key=True), + Column("lid", Integer, ForeignKey("lft.id")), + Column("x", Integer), + Column("y", Integer), + ) + cls.right_multi_fk = Table( + "rgt_multi_fk", + m, + Column("id", Integer, primary_key=True), + Column("lid1", Integer, ForeignKey("lft.id")), + Column("lid2", Integer, ForeignKey("lft.id")), + ) + + cls.selfref = Table( + "selfref", + m, + 
Column("id", Integer, primary_key=True), + Column("sid", Integer, ForeignKey("selfref.id")), + ) + cls.composite_selfref = Table( + "composite_selfref", + m, + Column("id", Integer, primary_key=True), + Column("group_id", Integer, primary_key=True), + Column("parent_id", Integer), + ForeignKeyConstraint( + ["parent_id", "group_id"], + ["composite_selfref.id", "composite_selfref.group_id"], + ), + ) + cls.m2mleft = Table( + "m2mlft", m, Column("id", Integer, primary_key=True) + ) + cls.m2mright = Table( + "m2mrgt", m, Column("id", Integer, primary_key=True) + ) + cls.m2msecondary = Table( + "m2msecondary", + m, + Column("lid", Integer, ForeignKey("m2mlft.id"), primary_key=True), + Column("rid", Integer, ForeignKey("m2mrgt.id"), primary_key=True), + ) + cls.m2msecondary_no_fks = Table( + "m2msecondary_no_fks", + m, + Column("lid", Integer, primary_key=True), + Column("rid", Integer, primary_key=True), + ) + cls.m2msecondary_ambig_fks = Table( + "m2msecondary_ambig_fks", + m, + Column("lid1", Integer, ForeignKey("m2mlft.id"), primary_key=True), + Column("rid1", Integer, ForeignKey("m2mrgt.id"), primary_key=True), + Column("lid2", Integer, ForeignKey("m2mlft.id"), primary_key=True), + Column("rid2", Integer, ForeignKey("m2mrgt.id"), primary_key=True), + ) + cls.base_w_sub_rel = Table( + "base_w_sub_rel", + m, + Column("id", Integer, primary_key=True), + Column("sub_id", Integer, ForeignKey("rel_sub.id")), + ) + cls.rel_sub = Table( + "rel_sub", + m, + Column( + "id", + Integer, + ForeignKey("base_w_sub_rel.id"), + primary_key=True, + ), + ) + cls.base = Table( + "base", + m, + Column("id", Integer, primary_key=True), + Column("flag", Boolean), + ) + cls.sub = Table( + "sub", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + ) + cls.sub_w_base_rel = Table( + "sub_w_base_rel", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("base_id", Integer, ForeignKey("base.id")), + ) + cls.sub_w_sub_rel = Table( + "sub_w_sub_rel", + m, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("sub_id", Integer, ForeignKey("sub.id")), + ) + cls.right_w_base_rel = Table( + "right_w_base_rel", + m, + Column("id", Integer, primary_key=True), + Column("base_id", Integer, ForeignKey("base.id")), + ) + + cls.three_tab_a = Table( + "three_tab_a", m, Column("id", Integer, primary_key=True) + ) + cls.three_tab_b = Table( + "three_tab_b", + m, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("three_tab_a.id")), + ) + cls.three_tab_c = Table( + "three_tab_c", + m, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("three_tab_a.id")), + Column("bid", Integer, ForeignKey("three_tab_b.id")), + ) + + cls.composite_target = Table( + "composite_target", + m, + Column("uid", Integer, primary_key=True), + Column("oid", Integer, primary_key=True), + ) cls.composite_multi_ref = Table( - 'composite_multi_ref', m, - Column('uid1', Integer), - Column('uid2', Integer), - Column('oid', Integer), - ForeignKeyConstraint(("uid1", "oid"), - ("composite_target.uid", - "composite_target.oid")), - ForeignKeyConstraint(("uid2", "oid"), - ("composite_target.uid", - "composite_target.oid"))) - - cls.purely_single_col = Table('purely_single_col', m, - Column('path', String)) + "composite_multi_ref", + m, + Column("uid1", Integer), + Column("uid2", Integer), + Column("oid", Integer), + ForeignKeyConstraint( + ("uid1", "oid"), + ("composite_target.uid", "composite_target.oid"), + ), + ForeignKeyConstraint( + ("uid2", 
"oid"), + ("composite_target.uid", "composite_target.oid"), + ), + ) + + cls.purely_single_col = Table( + "purely_single_col", m, Column("path", String) + ) def _join_fixture_overlapping_three_tables(self, **kw): def _can_sync(*cols): @@ -132,6 +191,7 @@ class _JoinFixtures(object): return False else: return True + return relationships.JoinCondition( self.three_tab_a, self.three_tab_b, @@ -142,8 +202,8 @@ class _JoinFixtures(object): primaryjoin=and_( self.three_tab_a.c.id == self.three_tab_b.c.aid, self.three_tab_c.c.bid == self.three_tab_b.c.id, - self.three_tab_c.c.aid == self.three_tab_a.c.id - ) + self.three_tab_c.c.aid == self.three_tab_a.c.id, + ), ) def _join_fixture_m2m(self, **kw): @@ -162,41 +222,32 @@ class _JoinFixtures(object): """ j1 = self._join_fixture_m2m() - return j1, relationships.JoinCondition( - self.m2mright, - self.m2mleft, - self.m2mright, - self.m2mleft, - secondary=self.m2msecondary, - primaryjoin=j1.secondaryjoin_minus_local, - secondaryjoin=j1.primaryjoin_minus_local + return ( + j1, + relationships.JoinCondition( + self.m2mright, + self.m2mleft, + self.m2mright, + self.m2mleft, + secondary=self.m2msecondary, + primaryjoin=j1.secondaryjoin_minus_local, + secondaryjoin=j1.primaryjoin_minus_local, + ), ) def _join_fixture_o2m(self, **kw): return relationships.JoinCondition( - self.left, - self.right, - self.left, - self.right, - **kw + self.left, self.right, self.left, self.right, **kw ) def _join_fixture_m2o(self, **kw): return relationships.JoinCondition( - self.right, - self.left, - self.right, - self.left, - **kw + self.right, self.left, self.right, self.left, **kw ) def _join_fixture_o2m_selfref(self, **kw): return relationships.JoinCondition( - self.selfref, - self.selfref, - self.selfref, - self.selfref, - **kw + self.selfref, self.selfref, self.selfref, self.selfref, **kw ) def _join_fixture_m2o_selfref(self, **kw): @@ -224,8 +275,12 @@ class _JoinFixtures(object): self.composite_selfref, self.composite_selfref, self.composite_selfref, - remote_side=set([self.composite_selfref.c.id, - self.composite_selfref.c.group_id]), + remote_side=set( + [ + self.composite_selfref.c.id, + self.composite_selfref.c.group_id, + ] + ), **kw ) @@ -236,10 +291,10 @@ class _JoinFixtures(object): self.composite_selfref, self.composite_selfref, primaryjoin=and_( - self.composite_selfref.c.group_id == - func.foo(self.composite_selfref.c.group_id), - self.composite_selfref.c.parent_id == - self.composite_selfref.c.id + self.composite_selfref.c.group_id + == func.foo(self.composite_selfref.c.group_id), + self.composite_selfref.c.parent_id + == self.composite_selfref.c.id, ), **kw ) @@ -251,10 +306,10 @@ class _JoinFixtures(object): self.composite_selfref, self.composite_selfref, primaryjoin=and_( - self.composite_selfref.c.group_id == - func.foo(self.composite_selfref.c.group_id), - self.composite_selfref.c.parent_id == - self.composite_selfref.c.id + self.composite_selfref.c.group_id + == func.foo(self.composite_selfref.c.group_id), + self.composite_selfref.c.parent_id + == self.composite_selfref.c.id, ), remote_side=set([self.composite_selfref.c.parent_id]), **kw @@ -267,10 +322,10 @@ class _JoinFixtures(object): self.composite_selfref, self.composite_selfref, primaryjoin=and_( - remote(self.composite_selfref.c.group_id) == - func.foo(self.composite_selfref.c.group_id), - remote(self.composite_selfref.c.parent_id) == - self.composite_selfref.c.id + remote(self.composite_selfref.c.group_id) + == func.foo(self.composite_selfref.c.group_id), + 
remote(self.composite_selfref.c.parent_id) + == self.composite_selfref.c.id, ), **kw ) @@ -281,10 +336,10 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=(self.left.c.x + self.left.c.y) == - relationships.remote(relationships.foreign( - self.right.c.x * self.right.c.y - )), + primaryjoin=(self.left.c.x + self.left.c.y) + == relationships.remote( + relationships.foreign(self.right.c.x * self.right.c.y) + ), **kw ) @@ -294,10 +349,8 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=(self.left.c.x + self.left.c.y) == - relationships.foreign( - self.right.c.x * self.right.c.y - ), + primaryjoin=(self.left.c.x + self.left.c.y) + == relationships.foreign(self.right.c.x * self.right.c.y), **kw ) @@ -307,10 +360,8 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=(self.left.c.x + self.left.c.y) == - ( - self.right.c.x * self.right.c.y - ), + primaryjoin=(self.left.c.x + self.left.c.y) + == (self.right.c.x * self.right.c.y), **kw ) @@ -318,28 +369,27 @@ class _JoinFixtures(object): # see test/orm/inheritance/test_abc_inheritance:TestaTobM2O # and others there right = self.base_w_sub_rel.join( - self.rel_sub, - self.base_w_sub_rel.c.id == self.rel_sub.c.id + self.rel_sub, self.base_w_sub_rel.c.id == self.rel_sub.c.id ) return relationships.JoinCondition( self.base_w_sub_rel, right, self.base_w_sub_rel, self.rel_sub, - primaryjoin=self.base_w_sub_rel.c.sub_id == - self.rel_sub.c.id, + primaryjoin=self.base_w_sub_rel.c.sub_id == self.rel_sub.c.id, **kw ) def _join_fixture_o2m_joined_sub_to_base(self, **kw): - left = self.base.join(self.sub_w_base_rel, - self.base.c.id == self.sub_w_base_rel.c.id) + left = self.base.join( + self.sub_w_base_rel, self.base.c.id == self.sub_w_base_rel.c.id + ) return relationships.JoinCondition( left, self.base, self.sub_w_base_rel, self.base, - primaryjoin=self.sub_w_base_rel.c.base_id == self.base.c.id + primaryjoin=self.sub_w_base_rel.c.base_id == self.base.c.id, ) def _join_fixture_m2o_joined_sub_to_sub_on_base(self, **kw): @@ -347,8 +397,9 @@ class _JoinFixtures(object): # in #2491 where we join on the base cols instead. only # m2o has a problem at the time of this test. 
left = self.base.join(self.sub, self.base.c.id == self.sub.c.id) - right = self.base.join(self.sub_w_base_rel, - self.base.c.id == self.sub_w_base_rel.c.id) + right = self.base.join( + self.sub_w_base_rel, self.base.c.id == self.sub_w_base_rel.c.id + ) return relationships.JoinCondition( left, right, @@ -359,20 +410,22 @@ class _JoinFixtures(object): def _join_fixture_o2m_joined_sub_to_sub(self, **kw): left = self.base.join(self.sub, self.base.c.id == self.sub.c.id) - right = self.base.join(self.sub_w_sub_rel, - self.base.c.id == self.sub_w_sub_rel.c.id) + right = self.base.join( + self.sub_w_sub_rel, self.base.c.id == self.sub_w_sub_rel.c.id + ) return relationships.JoinCondition( left, right, self.sub, self.sub_w_sub_rel, - primaryjoin=self.sub.c.id == self.sub_w_sub_rel.c.sub_id + primaryjoin=self.sub.c.id == self.sub_w_sub_rel.c.sub_id, ) def _join_fixture_m2o_sub_to_joined_sub(self, **kw): # see test.orm.test_mapper:MapperTest.test_add_column_prop_deannotate, - right = self.base.join(self.right_w_base_rel, - self.base.c.id == self.right_w_base_rel.c.id) + right = self.base.join( + self.right_w_base_rel, self.base.c.id == self.right_w_base_rel.c.id + ) return relationships.JoinCondition( self.right_w_base_rel, right, @@ -382,28 +435,23 @@ class _JoinFixtures(object): def _join_fixture_m2o_sub_to_joined_sub_func(self, **kw): # see test.orm.test_mapper:MapperTest.test_add_column_prop_deannotate, - right = self.base.join(self.right_w_base_rel, - self.base.c.id == self.right_w_base_rel.c.id) + right = self.base.join( + self.right_w_base_rel, self.base.c.id == self.right_w_base_rel.c.id + ) return relationships.JoinCondition( self.right_w_base_rel, right, self.right_w_base_rel, self.right_w_base_rel, - primaryjoin=self.right_w_base_rel.c.base_id == - func.foo(self.base.c.id) + primaryjoin=self.right_w_base_rel.c.base_id + == func.foo(self.base.c.id), ) def _join_fixture_o2o_joined_sub_to_base(self, **kw): - left = self.base.join(self.sub, - self.base.c.id == self.sub.c.id) + left = self.base.join(self.sub, self.base.c.id == self.sub.c.id) # see test_relationships->AmbiguousJoinInterpretedAsSelfRef - return relationships.JoinCondition( - left, - self.sub, - left, - self.sub, - ) + return relationships.JoinCondition(left, self.sub, left, self.sub) def _join_fixture_o2m_to_annotated_func(self, **kw): return relationships.JoinCondition( @@ -411,8 +459,7 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=self.left.c.id == - foreign(func.foo(self.right.c.lid)), + primaryjoin=self.left.c.id == foreign(func.foo(self.right.c.lid)), **kw ) @@ -422,8 +469,7 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=self.left.c.id == - func.foo(self.right.c.lid), + primaryjoin=self.left.c.id == func.foo(self.right.c.lid), consider_as_foreign_keys=[self.right.c.lid], **kw ) @@ -434,20 +480,28 @@ class _JoinFixtures(object): self.composite_multi_ref, self.composite_target, self.composite_multi_ref, - consider_as_foreign_keys=[self.composite_multi_ref.c.uid2, - self.composite_multi_ref.c.oid], + consider_as_foreign_keys=[ + self.composite_multi_ref.c.uid2, + self.composite_multi_ref.c.oid, + ], **kw ) - cls.left = Table('lft', m, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) - cls.right = Table('rgt', m, - Column('id', Integer, primary_key=True), - Column('lid', Integer, ForeignKey('lft.id')), - Column('x', Integer), - Column('y', Integer)) + cls.left = Table( + "lft", + m, + Column("id", Integer, 
primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + cls.right = Table( + "rgt", + m, + Column("id", Integer, primary_key=True), + Column("lid", Integer, ForeignKey("lft.id")), + Column("x", Integer), + Column("y", Integer), + ) def _join_fixture_o2m_o_side_none(self, **kw): return relationships.JoinCondition( @@ -455,8 +509,9 @@ class _JoinFixtures(object): self.right, self.left, self.right, - primaryjoin=and_(self.left.c.id == self.right.c.lid, - self.left.c.x == 5), + primaryjoin=and_( + self.left.c.id == self.right.c.lid, self.left.c.x == 5 + ), **kw ) @@ -468,12 +523,8 @@ class _JoinFixtures(object): self.purely_single_col, support_sync=False, primaryjoin=self.purely_single_col.c.path.like( - remote( - foreign( - self.purely_single_col.c.path.concat('%') - ) - ) - ) + remote(foreign(self.purely_single_col.c.path.concat("%"))) + ), ) def _join_fixture_purely_single_m2o(self, **kw): @@ -484,30 +535,33 @@ class _JoinFixtures(object): self.purely_single_col, support_sync=False, primaryjoin=remote(self.purely_single_col.c.path).like( - foreign(self.purely_single_col.c.path.concat('%')) - ) + foreign(self.purely_single_col.c.path.concat("%")) + ), ) def _join_fixture_remote_local_multiple_ref(self, **kw): - def fn(a, b): return ((a == b) | (b == a)) + def fn(a, b): + return (a == b) | (b == a) + return relationships.JoinCondition( - self.selfref, self.selfref, - self.selfref, self.selfref, + self.selfref, + self.selfref, + self.selfref, + self.selfref, support_sync=False, primaryjoin=fn( # we're putting a do-nothing annotation on # "a" so that the left/right is preserved; # annotation vs. non seems to affect __eq__ behavior self.selfref.c.sid._annotate({"foo": "bar"}), - foreign(remote(self.selfref.c.sid))) + foreign(remote(self.selfref.c.sid)), + ), ) def _join_fixture_inh_selfref_w_entity(self, **kw): fake_logger = mock.Mock(info=lambda *arg, **kw: None) prop = mock.Mock( - parent=mock.Mock(), - mapper=mock.Mock(), - logger=fake_logger + parent=mock.Mock(), mapper=mock.Mock(), logger=fake_logger ) local_selectable = self.base.join(self.sub) remote_selectable = self.base.join(self.sub_w_sub_rel) @@ -516,18 +570,22 @@ class _JoinFixtures(object): # present in the columns ahead of time sub_w_sub_rel__sub_id = self.sub_w_sub_rel.c.sub_id._annotate( - {'parentmapper': prop.mapper}) - sub__id = self.sub.c.id._annotate({'parentmapper': prop.parent}) + {"parentmapper": prop.mapper} + ) + sub__id = self.sub.c.id._annotate({"parentmapper": prop.parent}) sub_w_sub_rel__flag = self.base.c.flag._annotate( - {"parentmapper": prop.mapper}) + {"parentmapper": prop.mapper} + ) return relationships.JoinCondition( - local_selectable, remote_selectable, - local_selectable, remote_selectable, + local_selectable, + remote_selectable, + local_selectable, + remote_selectable, primaryjoin=and_( sub_w_sub_rel__sub_id == sub__id, - sub_w_sub_rel__flag == True # noqa + sub_w_sub_rel__flag == True, # noqa ), - prop=prop + prop=prop, ) def _assert_non_simple_warning(self, fn): @@ -537,11 +595,12 @@ class _JoinFixtures(object): "primary join condition for property " r"None - consider using remote\(\) " "annotations to mark the remote side.", - fn + fn, ) - def _assert_raises_no_relevant_fks(self, fn, expr, relname, - primary, *arg, **kw): + def _assert_raises_no_relevant_fks( + self, fn, expr, relname, primary, *arg, **kw + ): assert_raises_message( exc.ArgumentError, r"Could not locate any relevant foreign key columns " @@ -549,14 +608,15 @@ class _JoinFixtures(object): r"Ensure that referencing 
columns are associated with " r"a ForeignKey or ForeignKeyConstraint, or are annotated " r"in the join condition with the foreign\(\) annotation." - % ( - primary, expr, relname - ), - fn, *arg, **kw + % (primary, expr, relname), + fn, + *arg, + **kw ) - def _assert_raises_no_equality(self, fn, expr, relname, - primary, *arg, **kw): + def _assert_raises_no_equality( + self, fn, expr, relname, primary, *arg, **kw + ): assert_raises_message( exc.ArgumentError, "Could not locate any simple equality expressions " @@ -566,14 +626,16 @@ class _JoinFixtures(object): "ForeignKey or ForeignKeyConstraint, or are annotated in " r"the join condition with the foreign\(\) annotation. " "To allow comparison operators other than '==', " - "the relationship can be marked as viewonly=True." % ( - primary, expr, relname - ), - fn, *arg, **kw + "the relationship can be marked as viewonly=True." + % (primary, expr, relname), + fn, + *arg, + **kw ) - def _assert_raises_ambig_join(self, fn, relname, secondary_arg, - *arg, **kw): + def _assert_raises_ambig_join( + self, fn, relname, secondary_arg, *arg, **kw + ): if secondary_arg is not None: assert_raises_message( exc.AmbiguousForeignKeysError, @@ -586,7 +648,10 @@ class _JoinFixtures(object): "containing a foreign key reference from the " "secondary table to each of the parent and child tables." % (relname, secondary_arg), - fn, *arg, **kw) + fn, + *arg, + **kw + ) else: assert_raises_message( exc.AmbiguousForeignKeysError, @@ -594,10 +659,12 @@ class _JoinFixtures(object): "parent/child tables on relationship %s - " "there are no foreign keys linking these tables. " % (relname,), - fn, *arg, **kw) + fn, + *arg, + **kw + ) - def _assert_raises_no_join(self, fn, relname, secondary_arg, - *arg, **kw): + def _assert_raises_no_join(self, fn, relname, secondary_arg, *arg, **kw): if secondary_arg is not None: assert_raises_message( exc.NoForeignKeysError, @@ -608,9 +675,11 @@ class _JoinFixtures(object): "Ensure that referencing columns are associated " "with a ForeignKey " "or ForeignKeyConstraint, or specify 'primaryjoin' and " - "'secondaryjoin' expressions" - % (relname, secondary_arg), - fn, *arg, **kw) + "'secondaryjoin' expressions" % (relname, secondary_arg), + fn, + *arg, + **kw + ) else: assert_raises_message( exc.NoForeignKeysError, @@ -620,109 +689,89 @@ class _JoinFixtures(object): "Ensure that referencing columns are associated " "with a ForeignKey " "or ForeignKeyConstraint, or specify a 'primaryjoin' " - "expression." - % (relname,), - fn, *arg, **kw) + "expression." 
% (relname,), + fn, + *arg, + **kw + ) -class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, - AssertsCompiledSQL): +class ColumnCollectionsTest( + _JoinFixtures, fixtures.TestBase, AssertsCompiledSQL +): def test_determine_local_remote_pairs_o2o_joined_sub_to_base(self): joincond = self._join_fixture_o2o_joined_sub_to_base() - eq_( - joincond.local_remote_pairs, - [(self.base.c.id, self.sub.c.id)] - ) + eq_(joincond.local_remote_pairs, [(self.base.c.id, self.sub.c.id)]) def test_determine_synchronize_pairs_o2m_to_annotated_func(self): joincond = self._join_fixture_o2m_to_annotated_func() - eq_( - joincond.synchronize_pairs, - [(self.left.c.id, self.right.c.lid)] - ) + eq_(joincond.synchronize_pairs, [(self.left.c.id, self.right.c.lid)]) def test_determine_synchronize_pairs_o2m_to_oldstyle_func(self): joincond = self._join_fixture_o2m_to_oldstyle_func() - eq_( - joincond.synchronize_pairs, - [(self.left.c.id, self.right.c.lid)] - ) + eq_(joincond.synchronize_pairs, [(self.left.c.id, self.right.c.lid)]) def test_determinelocal_remote_m2o_joined_sub_to_sub_on_base(self): joincond = self._join_fixture_m2o_joined_sub_to_sub_on_base() eq_( joincond.local_remote_pairs, - [(self.base.c.id, self.sub_w_base_rel.c.base_id)] + [(self.base.c.id, self.sub_w_base_rel.c.base_id)], ) def test_determine_local_remote_base_to_joined_sub(self): joincond = self._join_fixture_base_to_joined_sub() eq_( joincond.local_remote_pairs, - [ - (self.base_w_sub_rel.c.sub_id, self.rel_sub.c.id) - ] + [(self.base_w_sub_rel.c.sub_id, self.rel_sub.c.id)], ) def test_determine_local_remote_o2m_joined_sub_to_base(self): joincond = self._join_fixture_o2m_joined_sub_to_base() eq_( joincond.local_remote_pairs, - [ - (self.sub_w_base_rel.c.base_id, self.base.c.id) - ] + [(self.sub_w_base_rel.c.base_id, self.base.c.id)], ) def test_determine_local_remote_m2o_sub_to_joined_sub(self): joincond = self._join_fixture_m2o_sub_to_joined_sub() eq_( joincond.local_remote_pairs, - [ - (self.right_w_base_rel.c.base_id, self.base.c.id) - ] + [(self.right_w_base_rel.c.base_id, self.base.c.id)], ) def test_determine_remote_columns_o2m_joined_sub_to_sub(self): joincond = self._join_fixture_o2m_joined_sub_to_sub() eq_( joincond.local_remote_pairs, - [ - (self.sub.c.id, self.sub_w_sub_rel.c.sub_id) - ] + [(self.sub.c.id, self.sub_w_sub_rel.c.sub_id)], ) def test_determine_remote_columns_compound_1(self): - joincond = self._join_fixture_compound_expression_1( - support_sync=False) - eq_( - joincond.remote_columns, - set([self.right.c.x, self.right.c.y]) - ) + joincond = self._join_fixture_compound_expression_1(support_sync=False) + eq_(joincond.remote_columns, set([self.right.c.x, self.right.c.y])) def test_determine_local_remote_compound_1(self): - joincond = self._join_fixture_compound_expression_1( - support_sync=False) + joincond = self._join_fixture_compound_expression_1(support_sync=False) eq_( joincond.local_remote_pairs, [ (self.left.c.x, self.right.c.x), (self.left.c.x, self.right.c.y), (self.left.c.y, self.right.c.x), - (self.left.c.y, self.right.c.y) - ] + (self.left.c.y, self.right.c.y), + ], ) def test_determine_local_remote_compound_2(self): - joincond = self._join_fixture_compound_expression_2( - support_sync=False) + joincond = self._join_fixture_compound_expression_2(support_sync=False) eq_( joincond.local_remote_pairs, [ (self.left.c.x, self.right.c.x), (self.left.c.x, self.right.c.y), (self.left.c.y, self.right.c.x), - (self.left.c.y, self.right.c.y) - ] + (self.left.c.y, self.right.c.y), + ], ) def 
test_determine_local_remote_compound_3(self): @@ -734,52 +783,48 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, (self.left.c.x, self.right.c.y), (self.left.c.y, self.right.c.x), (self.left.c.y, self.right.c.y), - ] + ], ) def test_err_local_remote_compound_1(self): self._assert_raises_no_relevant_fks( self._join_fixture_compound_expression_1_non_annotated, - r'lft.x \+ lft.y = rgt.x \* rgt.y', - "None", "primary" + r"lft.x \+ lft.y = rgt.x \* rgt.y", + "None", + "primary", ) def test_determine_remote_columns_compound_2(self): - joincond = self._join_fixture_compound_expression_2( - support_sync=False) - eq_( - joincond.remote_columns, - set([self.right.c.x, self.right.c.y]) - ) + joincond = self._join_fixture_compound_expression_2(support_sync=False) + eq_(joincond.remote_columns, set([self.right.c.x, self.right.c.y])) def test_determine_remote_columns_o2m(self): joincond = self._join_fixture_o2m() - eq_( - joincond.remote_columns, - set([self.right.c.lid]) - ) + eq_(joincond.remote_columns, set([self.right.c.lid])) def test_determine_remote_columns_o2m_selfref(self): joincond = self._join_fixture_o2m_selfref() - eq_( - joincond.remote_columns, - set([self.selfref.c.sid]) - ) + eq_(joincond.remote_columns, set([self.selfref.c.sid])) def test_determine_local_remote_pairs_o2m_composite_selfref(self): joincond = self._join_fixture_o2m_composite_selfref() eq_( joincond.local_remote_pairs, [ - (self.composite_selfref.c.group_id, - self.composite_selfref.c.group_id), - (self.composite_selfref.c.id, - self.composite_selfref.c.parent_id), - ] + ( + self.composite_selfref.c.group_id, + self.composite_selfref.c.group_id, + ), + ( + self.composite_selfref.c.id, + self.composite_selfref.c.parent_id, + ), + ], ) def test_determine_local_remote_pairs_o2m_composite_selfref_func_warning( - self): + self + ): self._assert_non_simple_warning( self._join_fixture_o2m_composite_selfref_func ) @@ -794,122 +839,117 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, ) def test_determine_local_remote_pairs_o2m_composite_selfref_func_annotated( - self): + self + ): joincond = self._join_fixture_o2m_composite_selfref_func_annotated() eq_( joincond.local_remote_pairs, [ - (self.composite_selfref.c.group_id, - self.composite_selfref.c.group_id), - (self.composite_selfref.c.id, - self.composite_selfref.c.parent_id), - ] + ( + self.composite_selfref.c.group_id, + self.composite_selfref.c.group_id, + ), + ( + self.composite_selfref.c.id, + self.composite_selfref.c.parent_id, + ), + ], ) def test_determine_remote_columns_m2o_composite_selfref(self): joincond = self._join_fixture_m2o_composite_selfref() eq_( joincond.remote_columns, - set([self.composite_selfref.c.id, - self.composite_selfref.c.group_id]) + set( + [ + self.composite_selfref.c.id, + self.composite_selfref.c.group_id, + ] + ), ) def test_determine_remote_columns_m2o(self): joincond = self._join_fixture_m2o() - eq_( - joincond.remote_columns, - set([self.left.c.id]) - ) + eq_(joincond.remote_columns, set([self.left.c.id])) def test_determine_local_remote_pairs_o2m(self): joincond = self._join_fixture_o2m() - eq_( - joincond.local_remote_pairs, - [(self.left.c.id, self.right.c.lid)] - ) + eq_(joincond.local_remote_pairs, [(self.left.c.id, self.right.c.lid)]) def test_determine_synchronize_pairs_m2m(self): joincond = self._join_fixture_m2m() eq_( joincond.synchronize_pairs, - [(self.m2mleft.c.id, self.m2msecondary.c.lid)] + [(self.m2mleft.c.id, self.m2msecondary.c.lid)], ) eq_( joincond.secondary_synchronize_pairs, - 
[(self.m2mright.c.id, self.m2msecondary.c.rid)] + [(self.m2mright.c.id, self.m2msecondary.c.rid)], ) def test_determine_local_remote_pairs_o2m_backref(self): joincond = self._join_fixture_o2m() joincond2 = self._join_fixture_m2o( - primaryjoin=joincond.primaryjoin_reverse_remote, - ) - eq_( - joincond2.local_remote_pairs, - [(self.right.c.lid, self.left.c.id)] + primaryjoin=joincond.primaryjoin_reverse_remote ) + eq_(joincond2.local_remote_pairs, [(self.right.c.lid, self.left.c.id)]) def test_determine_local_remote_pairs_m2m(self): joincond = self._join_fixture_m2m() eq_( joincond.local_remote_pairs, - [(self.m2mleft.c.id, self.m2msecondary.c.lid), - (self.m2mright.c.id, self.m2msecondary.c.rid)] + [ + (self.m2mleft.c.id, self.m2msecondary.c.lid), + (self.m2mright.c.id, self.m2msecondary.c.rid), + ], ) def test_determine_local_remote_pairs_m2m_backref(self): j1, j2 = self._join_fixture_m2m_backref() eq_( j1.local_remote_pairs, - [(self.m2mleft.c.id, self.m2msecondary.c.lid), - (self.m2mright.c.id, self.m2msecondary.c.rid)] + [ + (self.m2mleft.c.id, self.m2msecondary.c.lid), + (self.m2mright.c.id, self.m2msecondary.c.rid), + ], ) eq_( j2.local_remote_pairs, [ (self.m2mright.c.id, self.m2msecondary.c.rid), (self.m2mleft.c.id, self.m2msecondary.c.lid), - ] + ], ) def test_determine_local_columns_m2m_backref(self): j1, j2 = self._join_fixture_m2m_backref() - eq_( - j1.local_columns, - set([self.m2mleft.c.id]) - ) - eq_( - j2.local_columns, - set([self.m2mright.c.id]) - ) + eq_(j1.local_columns, set([self.m2mleft.c.id])) + eq_(j2.local_columns, set([self.m2mright.c.id])) def test_determine_remote_columns_m2m_backref(self): j1, j2 = self._join_fixture_m2m_backref() eq_( j1.remote_columns, - set([self.m2msecondary.c.lid, self.m2msecondary.c.rid]) + set([self.m2msecondary.c.lid, self.m2msecondary.c.rid]), ) eq_( j2.remote_columns, - set([self.m2msecondary.c.lid, self.m2msecondary.c.rid]) + set([self.m2msecondary.c.lid, self.m2msecondary.c.rid]), ) def test_determine_remote_columns_m2o_selfref(self): joincond = self._join_fixture_m2o_selfref() - eq_( - joincond.remote_columns, - set([self.selfref.c.id]) - ) + eq_(joincond.remote_columns, set([self.selfref.c.id])) def test_determine_local_remote_cols_three_tab_viewonly(self): joincond = self._join_fixture_overlapping_three_tables() eq_( joincond.local_remote_pairs, - [(self.three_tab_a.c.id, self.three_tab_b.c.aid)] + [(self.three_tab_a.c.id, self.three_tab_b.c.aid)], ) eq_( joincond.remote_columns, - set([self.three_tab_b.c.id, self.three_tab_b.c.aid]) + set([self.three_tab_b.c.id, self.three_tab_b.c.aid]), ) def test_determine_local_remote_overlapping_composite_fks(self): @@ -918,39 +958,34 @@ class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, eq_( joincond.local_remote_pairs, [ - (self.composite_target.c.uid, - self.composite_multi_ref.c.uid2,), - (self.composite_target.c.oid, self.composite_multi_ref.c.oid,) - ] + (self.composite_target.c.uid, self.composite_multi_ref.c.uid2), + (self.composite_target.c.oid, self.composite_multi_ref.c.oid), + ], ) def test_determine_local_remote_pairs_purely_single_col_o2m(self): joincond = self._join_fixture_purely_single_o2m() eq_( joincond.local_remote_pairs, - [(self.purely_single_col.c.path, self.purely_single_col.c.path)] + [(self.purely_single_col.c.path, self.purely_single_col.c.path)], ) def test_determine_local_remote_pairs_inh_selfref_w_entities(self): joincond = self._join_fixture_inh_selfref_w_entity() eq_( joincond.local_remote_pairs, - [(self.sub.c.id, self.sub_w_sub_rel.c.sub_id)] + 
[(self.sub.c.id, self.sub_w_sub_rel.c.sub_id)], ) eq_( joincond.remote_columns, - set([self.base.c.flag, self.sub_w_sub_rel.c.sub_id]) + set([self.base.c.flag, self.sub_w_sub_rel.c.sub_id]), ) class DirectionTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): def test_determine_direction_compound_2(self): - joincond = self._join_fixture_compound_expression_2( - support_sync=False) - is_( - joincond.direction, - ONETOMANY - ) + joincond = self._join_fixture_compound_expression_2(support_sync=False) + is_(joincond.direction, ONETOMANY) def test_determine_direction_o2m(self): joincond = self._join_fixture_o2m() @@ -986,35 +1021,26 @@ class DirectionTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_determine_join_o2m(self): joincond = self._join_fixture_o2m() - self.assert_compile( - joincond.primaryjoin, - "lft.id = rgt.lid" - ) + self.assert_compile(joincond.primaryjoin, "lft.id = rgt.lid") def test_determine_join_o2m_selfref(self): joincond = self._join_fixture_o2m_selfref() - self.assert_compile( - joincond.primaryjoin, - "selfref.id = selfref.sid" - ) + self.assert_compile(joincond.primaryjoin, "selfref.id = selfref.sid") def test_determine_join_m2o_selfref(self): joincond = self._join_fixture_m2o_selfref() - self.assert_compile( - joincond.primaryjoin, - "selfref.id = selfref.sid" - ) + self.assert_compile(joincond.primaryjoin, "selfref.id = selfref.sid") def test_determine_join_o2m_composite_selfref(self): joincond = self._join_fixture_o2m_composite_selfref() self.assert_compile( joincond.primaryjoin, "composite_selfref.group_id = composite_selfref.group_id " - "AND composite_selfref.id = composite_selfref.parent_id" + "AND composite_selfref.id = composite_selfref.parent_id", ) def test_determine_join_m2o_composite_selfref(self): @@ -1022,15 +1048,12 @@ class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( joincond.primaryjoin, "composite_selfref.group_id = composite_selfref.group_id " - "AND composite_selfref.id = composite_selfref.parent_id" + "AND composite_selfref.id = composite_selfref.parent_id", ) def test_determine_join_m2o(self): joincond = self._join_fixture_m2o() - self.assert_compile( - joincond.primaryjoin, - "lft.id = rgt.lid" - ) + self.assert_compile(joincond.primaryjoin, "lft.id = rgt.lid") def test_determine_join_ambiguous_fks_o2m(self): assert_raises_message( @@ -1052,7 +1075,8 @@ class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): def test_determine_join_no_fks_o2m(self): self._assert_raises_no_join( relationships.JoinCondition, - "None", None, + "None", + None, self.left, self.selfref, self.left, @@ -1063,23 +1087,25 @@ class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): self._assert_raises_ambig_join( relationships.JoinCondition, - "None", self.m2msecondary_ambig_fks, + "None", + self.m2msecondary_ambig_fks, self.m2mleft, self.m2mright, self.m2mleft, self.m2mright, - secondary=self.m2msecondary_ambig_fks + secondary=self.m2msecondary_ambig_fks, ) def test_determine_join_no_fks_m2m(self): self._assert_raises_no_join( relationships.JoinCondition, - "None", self.m2msecondary_no_fks, + "None", + self.m2msecondary_no_fks, self.m2mleft, self.m2mright, self.m2mleft, self.m2mright, - secondary=self.m2msecondary_no_fks + secondary=self.m2msecondary_no_fks, ) def _join_fixture_fks_ambig_m2m(self): @@ -1091,122 +1117,99 @@ 
class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): secondary=self.m2msecondary_ambig_fks, consider_as_foreign_keys=[ self.m2msecondary_ambig_fks.c.lid1, - self.m2msecondary_ambig_fks.c.rid1] + self.m2msecondary_ambig_fks.c.rid1, + ], ) def test_determine_join_w_fks_ambig_m2m(self): joincond = self._join_fixture_fks_ambig_m2m() self.assert_compile( - joincond.primaryjoin, - "m2mlft.id = m2msecondary_ambig_fks.lid1" + joincond.primaryjoin, "m2mlft.id = m2msecondary_ambig_fks.lid1" ) self.assert_compile( - joincond.secondaryjoin, - "m2mrgt.id = m2msecondary_ambig_fks.rid1" + joincond.secondaryjoin, "m2mrgt.id = m2msecondary_ambig_fks.rid1" ) class AdaptedJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_join_targets_o2m_selfref(self): joincond = self._join_fixture_o2m_selfref() - left = select([joincond.parent_selectable]).alias('pj') + left = select([joincond.parent_selectable]).alias("pj") pj, sj, sec, adapter, ds = joincond.join_targets( - left, - joincond.child_selectable, - True) - self.assert_compile( - pj, "pj.id = selfref.sid" + left, joincond.child_selectable, True ) + self.assert_compile(pj, "pj.id = selfref.sid") - right = select([joincond.child_selectable]).alias('pj') + right = select([joincond.child_selectable]).alias("pj") pj, sj, sec, adapter, ds = joincond.join_targets( - joincond.parent_selectable, - right, - True) - self.assert_compile( - pj, "selfref.id = pj.sid" + joincond.parent_selectable, right, True ) + self.assert_compile(pj, "selfref.id = pj.sid") def test_join_targets_o2m_plain(self): joincond = self._join_fixture_o2m() pj, sj, sec, adapter, ds = joincond.join_targets( - joincond.parent_selectable, - joincond.child_selectable, - False) - self.assert_compile( - pj, "lft.id = rgt.lid" + joincond.parent_selectable, joincond.child_selectable, False ) + self.assert_compile(pj, "lft.id = rgt.lid") def test_join_targets_o2m_left_aliased(self): joincond = self._join_fixture_o2m() - left = select([joincond.parent_selectable]).alias('pj') + left = select([joincond.parent_selectable]).alias("pj") pj, sj, sec, adapter, ds = joincond.join_targets( - left, - joincond.child_selectable, - True) - self.assert_compile( - pj, "pj.id = rgt.lid" + left, joincond.child_selectable, True ) + self.assert_compile(pj, "pj.id = rgt.lid") def test_join_targets_o2m_right_aliased(self): joincond = self._join_fixture_o2m() - right = select([joincond.child_selectable]).alias('pj') + right = select([joincond.child_selectable]).alias("pj") pj, sj, sec, adapter, ds = joincond.join_targets( - joincond.parent_selectable, - right, - True) - self.assert_compile( - pj, "lft.id = pj.lid" + joincond.parent_selectable, right, True ) + self.assert_compile(pj, "lft.id = pj.lid") def test_join_targets_o2m_composite_selfref(self): joincond = self._join_fixture_o2m_composite_selfref() - right = select([joincond.child_selectable]).alias('pj') + right = select([joincond.child_selectable]).alias("pj") pj, sj, sec, adapter, ds = joincond.join_targets( - joincond.parent_selectable, - right, - True) + joincond.parent_selectable, right, True + ) self.assert_compile( pj, "pj.group_id = composite_selfref.group_id " - "AND composite_selfref.id = pj.parent_id" + "AND composite_selfref.id = pj.parent_id", ) def test_join_targets_m2o_composite_selfref(self): joincond = self._join_fixture_m2o_composite_selfref() - right = select([joincond.child_selectable]).alias('pj') + right = select([joincond.child_selectable]).alias("pj") 
pj, sj, sec, adapter, ds = joincond.join_targets( - joincond.parent_selectable, - right, - True) + joincond.parent_selectable, right, True + ) self.assert_compile( pj, "pj.group_id = composite_selfref.group_id " - "AND pj.id = composite_selfref.parent_id" + "AND pj.id = composite_selfref.parent_id", ) class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_lazy_clause_o2m(self): joincond = self._join_fixture_o2m() lazywhere, bind_to_col, equated_columns = joincond.create_lazy_clause() - self.assert_compile( - lazywhere, - ":param_1 = rgt.lid" - ) + self.assert_compile(lazywhere, ":param_1 = rgt.lid") def test_lazy_clause_o2m_reverse(self): joincond = self._join_fixture_o2m() - lazywhere, bind_to_col, equated_columns =\ - joincond.create_lazy_clause(reverse_direction=True) - self.assert_compile( - lazywhere, - "lft.id = :param_1" + lazywhere, bind_to_col, equated_columns = joincond.create_lazy_clause( + reverse_direction=True ) + self.assert_compile(lazywhere, "lft.id = :param_1") def test_lazy_clause_o2m_o_side_none(self): # test for #2948. When the join is "o.id == m.oid @@ -1217,18 +1220,19 @@ class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( lazywhere, ":param_1 = rgt.lid AND :param_2 = :x_1", - checkparams={'param_1': None, 'param_2': None, 'x_1': 5} + checkparams={"param_1": None, "param_2": None, "x_1": 5}, ) def test_lazy_clause_o2m_o_side_none_reverse(self): # continued test for #2948. joincond = self._join_fixture_o2m_o_side_none() lazywhere, bind_to_col, equated_columns = joincond.create_lazy_clause( - reverse_direction=True) + reverse_direction=True + ) self.assert_compile( lazywhere, "lft.id = :param_1 AND lft.x = :x_1", - checkparams={'param_1': None, 'x_1': 5} + checkparams={"param_1": None, "x_1": 5}, ) def test_lazy_clause_remote_local_multiple_ref(self): @@ -1238,32 +1242,31 @@ class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( lazywhere, ":param_1 = selfref.sid OR selfref.sid = :param_1", - checkparams={'param_1': None} + checkparams={"param_1": None}, ) class DeannotateCorrectlyTest(fixtures.TestBase): def test_pj_deannotates(self): from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) a_id = Column(ForeignKey(A.id)) a = relationship(A) eq_( B.a.property.primaryjoin.left._annotations, - {"parentmapper": A.__mapper__, "remote": True} + {"parentmapper": A.__mapper__, "remote": True}, ) eq_( B.a.property.primaryjoin.right._annotations, - {'foreign': True, 'local': True, 'parentmapper': B.__mapper__} + {"foreign": True, "local": True, "parentmapper": B.__mapper__}, ) - - diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index b60d73c8f8..3ca8d8c099 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -2,14 +2,35 @@ from sqlalchemy.testing import assert_raises, assert_raises_message import datetime import sqlalchemy as sa from sqlalchemy import testing -from sqlalchemy import Integer, String, ForeignKey, MetaData, and_, \ - select, func +from sqlalchemy import ( + Integer, + String, + ForeignKey, + MetaData, + and_, + select, + func, +) from sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, 
relationship, relation, \ - backref, create_session, configure_mappers, \ - clear_mappers, sessionmaker, attributes,\ - Session, composite, column_property, foreign,\ - remote, synonym, joinedload, subqueryload +from sqlalchemy.orm import ( + mapper, + relationship, + relation, + backref, + create_session, + configure_mappers, + clear_mappers, + sessionmaker, + attributes, + Session, + composite, + column_property, + foreign, + remote, + synonym, + joinedload, + subqueryload, +) from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE from sqlalchemy.testing import eq_, startswith_, AssertsCompiledSQL, is_, in_ from sqlalchemy.testing import fixtures @@ -22,9 +43,9 @@ from sqlalchemy.ext.declarative import declarative_base class _RelationshipErrors(object): - - def _assert_raises_no_relevant_fks(self, fn, expr, relname, - primary, *arg, **kw): + def _assert_raises_no_relevant_fks( + self, fn, expr, relname, primary, *arg, **kw + ): assert_raises_message( sa.exc.ArgumentError, "Could not locate any relevant foreign key columns " @@ -32,14 +53,15 @@ class _RelationshipErrors(object): "Ensure that referencing columns are associated with " "a ForeignKey or ForeignKeyConstraint, or are annotated " r"in the join condition with the foreign\(\) annotation." - % ( - primary, expr, relname - ), - fn, *arg, **kw + % (primary, expr, relname), + fn, + *arg, + **kw ) - def _assert_raises_no_equality(self, fn, expr, relname, - primary, *arg, **kw): + def _assert_raises_no_equality( + self, fn, expr, relname, primary, *arg, **kw + ): assert_raises_message( sa.exc.ArgumentError, "Could not locate any simple equality expressions " @@ -49,14 +71,16 @@ class _RelationshipErrors(object): "ForeignKey or ForeignKeyConstraint, or are annotated in " r"the join condition with the foreign\(\) annotation. " "To allow comparison operators other than '==', " - "the relationship can be marked as viewonly=True." % ( - primary, expr, relname - ), - fn, *arg, **kw + "the relationship can be marked as viewonly=True." + % (primary, expr, relname), + fn, + *arg, + **kw ) - def _assert_raises_ambig_join(self, fn, relname, secondary_arg, - *arg, **kw): + def _assert_raises_ambig_join( + self, fn, relname, secondary_arg, *arg, **kw + ): if secondary_arg is not None: assert_raises_message( exc.ArgumentError, @@ -69,7 +93,10 @@ class _RelationshipErrors(object): "containing a foreign key reference from the " "secondary table to each of the parent and child tables." % (relname, secondary_arg), - fn, *arg, **kw) + fn, + *arg, + **kw + ) else: assert_raises_message( exc.ArgumentError, @@ -79,12 +106,13 @@ class _RelationshipErrors(object): "paths linking the tables. Specify the " "'foreign_keys' argument, providing a list of those " "columns which should be counted as containing a " - "foreign key reference to the parent table." - % (relname,), - fn, *arg, **kw) + "foreign key reference to the parent table." 
% (relname,), + fn, + *arg, + **kw + ) - def _assert_raises_no_join(self, fn, relname, secondary_arg, - *arg, **kw): + def _assert_raises_no_join(self, fn, relname, secondary_arg, *arg, **kw): if secondary_arg is not None: assert_raises_message( exc.NoForeignKeysError, @@ -95,9 +123,11 @@ class _RelationshipErrors(object): "Ensure that referencing columns are associated with a " "ForeignKey " "or ForeignKeyConstraint, or specify 'primaryjoin' and " - "'secondaryjoin' expressions" - % (relname, secondary_arg), - fn, *arg, **kw) + "'secondaryjoin' expressions" % (relname, secondary_arg), + fn, + *arg, + **kw + ) else: assert_raises_message( exc.NoForeignKeysError, @@ -107,9 +137,11 @@ class _RelationshipErrors(object): "Ensure that referencing columns are associated with a " "ForeignKey " "or ForeignKeyConstraint, or specify a 'primaryjoin' " - "expression." - % (relname,), - fn, *arg, **kw) + "expression." % (relname,), + fn, + *arg, + **kw + ) def _assert_raises_ambiguous_direction(self, fn, relname, *arg, **kw): assert_raises_message( @@ -120,9 +152,10 @@ class _RelationshipErrors(object): "in both the parent and the child's mapped tables. " "Ensure that only those columns referring to a parent column " r"are marked as foreign, either via the foreign\(\) annotation or " - "via the foreign_keys argument." - % relname, - fn, *arg, **kw + "via the foreign_keys argument." % relname, + fn, + *arg, + **kw ) def _assert_raises_no_local_remote(self, fn, relname, *arg, **kw): @@ -133,11 +166,11 @@ class _RelationshipErrors(object): "pairs based on join condition and remote_side arguments. " r"Consider using the remote\(\) annotation to " "accurately mark those elements of the join " - "condition that are on the remote side of the relationship." % ( - relname - ), - - fn, *arg, **kw + "condition that are on the remote side of the relationship." 
+ % (relname), + fn, + *arg, + **kw ) @@ -145,33 +178,51 @@ class DependencyTwoParentTest(fixtures.MappedTest): """Test flush() when a mapper is dependent on multiple relationships""" - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table("tbl_a", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("name", String(128))) - Table("tbl_b", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("name", String(128))) - Table("tbl_c", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), - nullable=False), - Column("name", String(128))) - Table("tbl_d", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), - nullable=False), - Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), - Column("name", String(128))) + Table( + "tbl_a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(128)), + ) + Table( + "tbl_b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(128)), + ) + Table( + "tbl_c", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False + ), + Column("name", String(128)), + ) + Table( + "tbl_d", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False + ), + Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), + Column("name", String(128)), + ) @classmethod def setup_classes(cls): @@ -189,40 +240,55 @@ class DependencyTwoParentTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): - A, C, B, D, tbl_b, tbl_c, tbl_a, tbl_d = (cls.classes.A, - cls.classes.C, - cls.classes.B, - cls.classes.D, - cls.tables.tbl_b, - cls.tables.tbl_c, - cls.tables.tbl_a, - cls.tables.tbl_d) - - mapper(A, tbl_a, properties=dict( - c_rows=relationship(C, cascade="all, delete-orphan", - backref="a_row"))) + A, C, B, D, tbl_b, tbl_c, tbl_a, tbl_d = ( + cls.classes.A, + cls.classes.C, + cls.classes.B, + cls.classes.D, + cls.tables.tbl_b, + cls.tables.tbl_c, + cls.tables.tbl_a, + cls.tables.tbl_d, + ) + + mapper( + A, + tbl_a, + properties=dict( + c_rows=relationship( + C, cascade="all, delete-orphan", backref="a_row" + ) + ), + ) mapper(B, tbl_b) - mapper(C, tbl_c, properties=dict( - d_rows=relationship(D, cascade="all, delete-orphan", - backref="c_row"))) - mapper(D, tbl_d, properties=dict( - b_row=relationship(B))) + mapper( + C, + tbl_c, + properties=dict( + d_rows=relationship( + D, cascade="all, delete-orphan", backref="c_row" + ) + ), + ) + mapper(D, tbl_d, properties=dict(b_row=relationship(B))) @classmethod def insert_data(cls): - A, C, B, D = (cls.classes.A, - cls.classes.C, - cls.classes.B, - cls.classes.D) + A, C, B, D = ( + cls.classes.A, + cls.classes.C, + cls.classes.B, + cls.classes.D, + ) session = create_session() - a = A(name='a1') - b = B(name='b1') - c = C(name='c1', a_row=a) + a = A(name="a1") + b = B(name="b1") + c = C(name="c1", a_row=a) - d1 = D(name='d1', b_row=b, c_row=c) # noqa - d2 = D(name='d2', b_row=b, c_row=c) # noqa - d3 = D(name='d3', b_row=b, c_row=c) # noqa + d1 
= D(name="d1", b_row=b, c_row=c) # noqa + d2 = D(name="d2", b_row=b, c_row=c) # noqa + d3 = D(name="d3", b_row=b, c_row=c) # noqa session.add(a) session.add(b) session.flush() @@ -231,7 +297,7 @@ class DependencyTwoParentTest(fixtures.MappedTest): A = self.classes.A session = create_session() - a = session.query(A).filter_by(name='a1').one() + a = session.query(A).filter_by(name="a1").one() session.delete(a) session.flush() @@ -240,25 +306,22 @@ class DependencyTwoParentTest(fixtures.MappedTest): C = self.classes.C session = create_session() - c = session.query(C).filter_by(name='c1').one() + c = session.query(C).filter_by(name="c1").one() session.delete(c) session.flush() class M2ODontOverwriteFKTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, - Column('id', Integer, primary_key=True), - Column('bid', ForeignKey('b.id')) - ) - Table( - 'b', metadata, - Column('id', Integer, primary_key=True), + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", ForeignKey("b.id")), ) + Table("b", metadata, Column("id", Integer, primary_key=True)) def _fixture(self, uselist=False): a, b = self.tables.a, self.tables.b @@ -269,9 +332,7 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): class B(fixtures.BasicEntity): pass - mapper(A, a, properties={ - 'b': relationship(B, uselist=uselist) - }) + mapper(A, a, properties={"b": relationship(B, uselist=uselist)}) mapper(B, b) return A, B @@ -376,18 +437,17 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): """ - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('entity', metadata, - Column('path', String(100), primary_key=True) - ) + Table( + "entity", metadata, Column("path", String(100), primary_key=True) + ) @classmethod def setup_classes(cls): class Entity(cls.Basic): - def __init__(self, path): self.path = path @@ -395,14 +455,20 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): Entity = self.classes.Entity entity = self.tables.entity - m = mapper(Entity, entity, properties={ - "descendants": relationship( - Entity, - primaryjoin=remote(foreign(entity.c.path)).like( - entity.c.path.concat('/%')), - viewonly=True, - order_by=entity.c.path) - }) + m = mapper( + Entity, + entity, + properties={ + "descendants": relationship( + Entity, + primaryjoin=remote(foreign(entity.c.path)).like( + entity.c.path.concat("/%") + ), + viewonly=True, + order_by=entity.c.path, + ) + }, + ) configure_mappers() assert m.get_property("descendants").direction is ONETOMANY if data: @@ -412,14 +478,20 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): Entity = self.classes.Entity entity = self.tables.entity - m = mapper(Entity, entity, properties={ - "anscestors": relationship( - Entity, - primaryjoin=entity.c.path.like( - remote(foreign(entity.c.path)).concat('/%')), - viewonly=True, - order_by=entity.c.path) - }) + m = mapper( + Entity, + entity, + properties={ + "anscestors": relationship( + Entity, + primaryjoin=entity.c.path.like( + remote(foreign(entity.c.path)).concat("/%") + ), + viewonly=True, + order_by=entity.c.path, + ) + }, + ) configure_mappers() assert m.get_property("anscestors").direction is ONETOMANY if data: @@ -428,16 +500,18 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): def _fixture(self): Entity = self.classes.Entity sess = Session() - sess.add_all([ - Entity("/foo"), - Entity("/foo/bar1"), - Entity("/foo/bar2"), - Entity("/foo/bar2/bat1"), - 
Entity("/foo/bar2/bat2"), - Entity("/foo/bar3"), - Entity("/bar"), - Entity("/bar/bat1") - ]) + sess.add_all( + [ + Entity("/foo"), + Entity("/foo/bar1"), + Entity("/foo/bar2"), + Entity("/foo/bar2/bat1"), + Entity("/foo/bar2/bat2"), + Entity("/foo/bar3"), + Entity("/bar"), + Entity("/bar/bat1"), + ] + ) return sess def test_descendants_lazyload_clause(self): @@ -445,12 +519,12 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): Entity = self.classes.Entity self.assert_compile( Entity.descendants.property.strategy._lazywhere, - "entity.path LIKE :param_1 || :path_1" + "entity.path LIKE :param_1 || :path_1", ) self.assert_compile( Entity.descendants.property.strategy._rev_lazywhere, - ":param_1 LIKE entity.path || :path_1" + ":param_1 LIKE entity.path || :path_1", ) def test_ancestors_lazyload_clause(self): @@ -459,12 +533,12 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): # :param_1 LIKE (:param_1 || :path_1) self.assert_compile( Entity.anscestors.property.strategy._lazywhere, - ":param_1 LIKE entity.path || :path_1" + ":param_1 LIKE entity.path || :path_1", ) self.assert_compile( Entity.anscestors.property.strategy._rev_lazywhere, - "entity.path LIKE :param_1 || :path_1" + "entity.path LIKE :param_1 || :path_1", ) def test_descendants_lazyload(self): @@ -473,52 +547,73 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): e1 = sess.query(Entity).filter_by(path="/foo").first() eq_( [e.path for e in e1.descendants], - ["/foo/bar1", "/foo/bar2", "/foo/bar2/bat1", - "/foo/bar2/bat2", "/foo/bar3"] + [ + "/foo/bar1", + "/foo/bar2", + "/foo/bar2/bat1", + "/foo/bar2/bat2", + "/foo/bar3", + ], ) def test_anscestors_lazyload(self): sess = self._anscestors_fixture() Entity = self.classes.Entity e1 = sess.query(Entity).filter_by(path="/foo/bar2/bat1").first() - eq_( - [e.path for e in e1.anscestors], - ["/foo", "/foo/bar2"] - ) + eq_([e.path for e in e1.anscestors], ["/foo", "/foo/bar2"]) def test_descendants_joinedload(self): sess = self._descendants_fixture() Entity = self.classes.Entity - e1 = sess.query(Entity).filter_by(path="/foo").\ - options(joinedload(Entity.descendants)).first() + e1 = ( + sess.query(Entity) + .filter_by(path="/foo") + .options(joinedload(Entity.descendants)) + .first() + ) eq_( [e.path for e in e1.descendants], - ["/foo/bar1", "/foo/bar2", "/foo/bar2/bat1", - "/foo/bar2/bat2", "/foo/bar3"] + [ + "/foo/bar1", + "/foo/bar2", + "/foo/bar2/bat1", + "/foo/bar2/bat2", + "/foo/bar3", + ], ) def test_descendants_subqueryload(self): sess = self._descendants_fixture() Entity = self.classes.Entity - e1 = sess.query(Entity).filter_by(path="/foo").\ - options(subqueryload(Entity.descendants)).first() + e1 = ( + sess.query(Entity) + .filter_by(path="/foo") + .options(subqueryload(Entity.descendants)) + .first() + ) eq_( [e.path for e in e1.descendants], - ["/foo/bar1", "/foo/bar2", "/foo/bar2/bat1", - "/foo/bar2/bat2", "/foo/bar3"] + [ + "/foo/bar1", + "/foo/bar2", + "/foo/bar2/bat1", + "/foo/bar2/bat2", + "/foo/bar3", + ], ) def test_anscestors_joinedload(self): sess = self._anscestors_fixture() Entity = self.classes.Entity - e1 = sess.query(Entity).filter_by(path="/foo/bar2/bat1").\ - options(joinedload(Entity.anscestors)).first() - eq_( - [e.path for e in e1.anscestors], - ["/foo", "/foo/bar2"] + e1 = ( + sess.query(Entity) + .filter_by(path="/foo/bar2/bat1") + .options(joinedload(Entity.anscestors)) + .first() ) + eq_([e.path for e in e1.anscestors], ["/foo", "/foo/bar2"]) def test_plain_join_descendants(self): 
self._descendants_fixture(data=False) @@ -527,7 +622,7 @@ class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): self.assert_compile( sess.query(Entity).join(Entity.descendants, aliased=True), "SELECT entity.path AS entity_path FROM entity JOIN entity AS " - "entity_1 ON entity_1.path LIKE entity.path || :path_1" + "entity_1 ON entity_1.path LIKE entity.path || :path_1", ) @@ -541,72 +636,76 @@ class OverlappingFksSiblingTest(fixtures.TestBase): clear_mappers() def _fixture_one( - self, add_b_a=False, add_b_a_viewonly=False, add_b_amember=False, - add_bsub1_a=False, add_bsub2_a_viewonly=False): + self, + add_b_a=False, + add_b_a_viewonly=False, + add_b_amember=False, + add_bsub1_a=False, + add_bsub2_a_viewonly=False, + ): Base = declarative_base(metadata=self.metadata) class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - a_members = relationship('AMember', backref='a') + a_members = relationship("AMember", backref="a") class AMember(Base): - __tablename__ = 'a_member' + __tablename__ = "a_member" - a_id = Column(Integer, ForeignKey('a.id'), primary_key=True) + a_id = Column(Integer, ForeignKey("a.id"), primary_key=True) a_member_id = Column(Integer, primary_key=True) class B(Base): - __tablename__ = 'b' + __tablename__ = "b" - __mapper_args__ = { - 'polymorphic_on': 'type' - } + __mapper_args__ = {"polymorphic_on": "type"} id = Column(Integer, primary_key=True) type = Column(String(20)) - a_id = Column(Integer, ForeignKey('a.id'), nullable=False) + a_id = Column(Integer, ForeignKey("a.id"), nullable=False) a_member_id = Column(Integer) __table_args__ = ( ForeignKeyConstraint( - ('a_id', 'a_member_id'), - ('a_member.a_id', 'a_member.a_member_id')), + ("a_id", "a_member_id"), + ("a_member.a_id", "a_member.a_member_id"), + ), ) # if added and viewonly is not true, this relationship # writes to B.a_id, which conflicts with BSub2.a_member, # so should warn if add_b_a: - a = relationship('A', viewonly=add_b_a_viewonly) + a = relationship("A", viewonly=add_b_a_viewonly) # if added, this relationship writes to B.a_id, which conflicts # with BSub1.a if add_b_amember: - a_member = relationship('AMember') + a_member = relationship("AMember") # however, *no* warning should be emitted otherwise. 
class BSub1(B): if add_bsub1_a: - a = relationship('A') + a = relationship("A") - __mapper_args__ = {'polymorphic_identity': 'bsub1'} + __mapper_args__ = {"polymorphic_identity": "bsub1"} class BSub2(B): if add_bsub2_a_viewonly: a = relationship("A", viewonly=True) - a_member = relationship('AMember') + a_member = relationship("AMember") - __mapper_args__ = {'polymorphic_identity': 'bsub2'} + __mapper_args__ = {"polymorphic_identity": "bsub2"} configure_mappers() self.metadata.create_all() @@ -647,7 +746,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase): # everyone has a B.a relationship eq_( session.query(B, A).outerjoin(B.a).order_by(B.id).all(), - [(bsub1, a2), (bsub2, a1)] + [(bsub1, a2), (bsub2, a1)], ) @testing.provide_metadata @@ -656,7 +755,9 @@ class OverlappingFksSiblingTest(fixtures.TestBase): exc.SAWarning, r"relationship '(?:BSub1.a|BSub2.a_member|B.a)' will copy column " r"(?:a.id|a_member.a_id) to column b.a_id", - self._fixture_one, add_b_a=True, add_bsub1_a=True + self._fixture_one, + add_b_a=True, + add_bsub1_a=True, ) @testing.provide_metadata @@ -665,7 +766,9 @@ class OverlappingFksSiblingTest(fixtures.TestBase): exc.SAWarning, r"relationship '(?:BSub1.a|B.a_member)' will copy column " r"(?:a.id|a_member.a_id) to column b.a_id", - self._fixture_one, add_b_amember=True, add_bsub1_a=True + self._fixture_one, + add_b_amember=True, + add_bsub1_a=True, ) @testing.provide_metadata @@ -674,8 +777,10 @@ class OverlappingFksSiblingTest(fixtures.TestBase): exc.SAWarning, r"relationship '(?:BSub1.a|B.a_member|B.a)' will copy column " r"(?:a.id|a_member.a_id) to column b.a_id", - self._fixture_one, add_b_amember=True, add_bsub1_a=True, - add_b_a=True + self._fixture_one, + add_b_amember=True, + add_bsub1_a=True, + add_b_a=True, ) @testing.provide_metadata @@ -686,9 +791,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase): @testing.provide_metadata def test_works_two(self): - self._test_fixture_one_run( - add_b_a=True, add_bsub2_a_viewonly=True - ) + self._test_fixture_one_run(add_b_a=True, add_bsub2_a_viewonly=True) class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): @@ -714,36 +817,43 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): """ - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('company_t', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30))) - - Table('employee_t', metadata, - Column('company_id', Integer, primary_key=True), - Column('emp_id', Integer, primary_key=True), - Column('name', String(30)), - Column('reports_to_id', Integer), - sa.ForeignKeyConstraint( - ['company_id'], - ['company_t.company_id']), - sa.ForeignKeyConstraint( - ['company_id', 'reports_to_id'], - ['employee_t.company_id', 'employee_t.emp_id'])) + Table( + "company_t", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(30)), + ) + + Table( + "employee_t", + metadata, + Column("company_id", Integer, primary_key=True), + Column("emp_id", Integer, primary_key=True), + Column("name", String(30)), + Column("reports_to_id", Integer), + sa.ForeignKeyConstraint(["company_id"], ["company_t.company_id"]), + sa.ForeignKeyConstraint( + ["company_id", "reports_to_id"], + ["employee_t.company_id", "employee_t.emp_id"], + ), + ) @classmethod def setup_classes(cls): class Company(cls.Basic): - def __init__(self, name): self.name = name class 
Employee(cls.Basic): - def __init__(self, name, company, emp_id, reports_to=None): self.name = name self.company = company @@ -751,166 +861,233 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): self.reports_to = reports_to def test_explicit(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, - primaryjoin=employee_t.c.company_id == - company_t.c.company_id, - backref='employees'), - 'reports_to': relationship(Employee, primaryjoin=sa.and_( - employee_t.c.emp_id == employee_t.c.reports_to_id, - employee_t.c.company_id == employee_t.c.company_id - ), - remote_side=[employee_t.c.emp_id, employee_t.c.company_id], - foreign_keys=[ - employee_t.c.reports_to_id, employee_t.c.company_id], - backref=backref('employees', - foreign_keys=[employee_t.c.reports_to_id, - employee_t.c.company_id])) - }) + mapper( + Employee, + employee_t, + properties={ + "company": relationship( + Company, + primaryjoin=employee_t.c.company_id + == company_t.c.company_id, + backref="employees", + ), + "reports_to": relationship( + Employee, + primaryjoin=sa.and_( + employee_t.c.emp_id == employee_t.c.reports_to_id, + employee_t.c.company_id == employee_t.c.company_id, + ), + remote_side=[employee_t.c.emp_id, employee_t.c.company_id], + foreign_keys=[ + employee_t.c.reports_to_id, + employee_t.c.company_id, + ], + backref=backref( + "employees", + foreign_keys=[ + employee_t.c.reports_to_id, + employee_t.c.company_id, + ], + ), + ), + }, + ) self._test() def test_implicit(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, backref='employees'), - 'reports_to': relationship( - Employee, - remote_side=[employee_t.c.emp_id, employee_t.c.company_id], - foreign_keys=[employee_t.c.reports_to_id, - employee_t.c.company_id], - backref=backref( - 'employees', + mapper( + Employee, + employee_t, + properties={ + "company": relationship(Company, backref="employees"), + "reports_to": relationship( + Employee, + remote_side=[employee_t.c.emp_id, employee_t.c.company_id], foreign_keys=[ - employee_t.c.reports_to_id, employee_t.c.company_id]) - ) - }) + employee_t.c.reports_to_id, + employee_t.c.company_id, + ], + backref=backref( + "employees", + foreign_keys=[ + employee_t.c.reports_to_id, + employee_t.c.company_id, + ], + ), + ), + }, + ) self._test() def test_very_implicit(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, backref='employees'), - 'reports_to': relationship( - Employee, - remote_side=[employee_t.c.emp_id, employee_t.c.company_id], - 
backref='employees' - ) - }) + mapper( + Employee, + employee_t, + properties={ + "company": relationship(Company, backref="employees"), + "reports_to": relationship( + Employee, + remote_side=[employee_t.c.emp_id, employee_t.c.company_id], + backref="employees", + ), + }, + ) self._test() def test_very_explicit(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, backref='employees'), - 'reports_to': relationship( - Employee, - _local_remote_pairs=[ - (employee_t.c.reports_to_id, employee_t.c.emp_id), - (employee_t.c.company_id, employee_t.c.company_id) - ], - foreign_keys=[ - employee_t.c.reports_to_id, - employee_t.c.company_id], - backref=backref( - 'employees', + mapper( + Employee, + employee_t, + properties={ + "company": relationship(Company, backref="employees"), + "reports_to": relationship( + Employee, + _local_remote_pairs=[ + (employee_t.c.reports_to_id, employee_t.c.emp_id), + (employee_t.c.company_id, employee_t.c.company_id), + ], foreign_keys=[ - employee_t.c.reports_to_id, employee_t.c.company_id]) - ) - }) + employee_t.c.reports_to_id, + employee_t.c.company_id, + ], + backref=backref( + "employees", + foreign_keys=[ + employee_t.c.reports_to_id, + employee_t.c.company_id, + ], + ), + ), + }, + ) self._test() def test_annotated(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, backref='employees'), - 'reports_to': relationship( - Employee, - primaryjoin=sa.and_( - remote(employee_t.c.emp_id) == employee_t.c.reports_to_id, - remote(employee_t.c.company_id) == employee_t.c.company_id + mapper( + Employee, + employee_t, + properties={ + "company": relationship(Company, backref="employees"), + "reports_to": relationship( + Employee, + primaryjoin=sa.and_( + remote(employee_t.c.emp_id) + == employee_t.c.reports_to_id, + remote(employee_t.c.company_id) + == employee_t.c.company_id, + ), + backref=backref("employees"), ), - backref=backref('employees') - ) - }) + }, + ) self._assert_lazy_clauses() self._test() def test_overlapping_warning(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, backref='employees'), - 'reports_to': relationship( - Employee, - primaryjoin=sa.and_( - remote(employee_t.c.emp_id) == employee_t.c.reports_to_id, - remote(employee_t.c.company_id) == employee_t.c.company_id + mapper( + Employee, + employee_t, + properties={ + "company": relationship(Company, backref="employees"), + "reports_to": relationship( + Employee, + primaryjoin=sa.and_( + remote(employee_t.c.emp_id) + == employee_t.c.reports_to_id, + 
remote(employee_t.c.company_id) + == employee_t.c.company_id, + ), + backref=backref("employees"), ), - backref=backref('employees') - ) - }) + }, + ) assert_raises_message( exc.SAWarning, r"relationship .* will copy column .* to column " r"employee_t.company_id, which conflicts with relationship\(s\)", - configure_mappers + configure_mappers, ) def test_annotated_no_overwriting(self): - Employee, Company, employee_t, company_t = (self.classes.Employee, - self.classes.Company, - self.tables.employee_t, - self.tables.company_t) + Employee, Company, employee_t, company_t = ( + self.classes.Employee, + self.classes.Company, + self.tables.employee_t, + self.tables.company_t, + ) mapper(Company, company_t) - mapper(Employee, employee_t, properties={ - 'company': relationship(Company, backref='employees'), - 'reports_to': relationship( - Employee, - primaryjoin=sa.and_( - remote(employee_t.c.emp_id) == - foreign(employee_t.c.reports_to_id), - remote(employee_t.c.company_id) == employee_t.c.company_id + mapper( + Employee, + employee_t, + properties={ + "company": relationship(Company, backref="employees"), + "reports_to": relationship( + Employee, + primaryjoin=sa.and_( + remote(employee_t.c.emp_id) + == foreign(employee_t.c.reports_to_id), + remote(employee_t.c.company_id) + == employee_t.c.company_id, + ), + backref=backref("employees"), ), - backref=backref('employees') - ) - }) + }, + ) self._assert_lazy_clauses() self._test_no_warning() @@ -920,8 +1097,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): Employee, Company = self.classes.Employee, self.classes.Company - c1 = sess.query(Company).filter_by(name='c1').one() - e3 = sess.query(Employee).filter_by(name='emp3').one() + c1 = sess.query(Company).filter_by(name="c1").one() + e3 = sess.query(Employee).filter_by(name="emp3").one() e3.reports_to = None if expect_failure: @@ -933,7 +1110,7 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): AssertionError, "Dependency rule tried to blank-out primary key column " "'employee_t.company_id'", - sess.flush + sess.flush, ) else: sess.flush() @@ -959,13 +1136,13 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): self.assert_compile( Employee.employees.property.strategy._lazywhere, ":param_1 = employee_t.reports_to_id AND " - ":param_2 = employee_t.company_id" + ":param_2 = employee_t.company_id", ) self.assert_compile( Employee.employees.property.strategy._rev_lazywhere, "employee_t.emp_id = :param_1 AND " - "employee_t.company_id = :param_2" + "employee_t.company_id = :param_2", ) def _test_relationships(self): @@ -973,36 +1150,40 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): employee_t = self.tables.employee_t eq_( set(Employee.employees.property.local_remote_pairs), - set([ - (employee_t.c.company_id, employee_t.c.company_id), - (employee_t.c.emp_id, employee_t.c.reports_to_id), - ]) + set( + [ + (employee_t.c.company_id, employee_t.c.company_id), + (employee_t.c.emp_id, employee_t.c.reports_to_id), + ] + ), ) eq_( Employee.employees.property.remote_side, - set([employee_t.c.company_id, employee_t.c.reports_to_id]) + set([employee_t.c.company_id, employee_t.c.reports_to_id]), ) eq_( set(Employee.reports_to.property.local_remote_pairs), - set([ - (employee_t.c.company_id, employee_t.c.company_id), - (employee_t.c.reports_to_id, employee_t.c.emp_id), - ]) + set( + [ + (employee_t.c.company_id, employee_t.c.company_id), + (employee_t.c.reports_to_id, employee_t.c.emp_id), + ] + ), ) def 
_setup_data(self, sess): Employee, Company = self.classes.Employee, self.classes.Company - c1 = Company('c1') - c2 = Company('c2') + c1 = Company("c1") + c2 = Company("c2") - e1 = Employee('emp1', c1, 1) - e2 = Employee('emp2', c1, 2, e1) # noqa - e3 = Employee('emp3', c1, 3, e1) - e4 = Employee('emp4', c1, 4, e3) # noqa - e5 = Employee('emp5', c2, 1) - e6 = Employee('emp6', c2, 2, e5) # noqa - e7 = Employee('emp7', c2, 3, e5) # noqa + e1 = Employee("emp1", c1, 1) + e2 = Employee("emp2", c1, 2, e1) # noqa + e3 = Employee("emp3", c1, 3, e1) + e4 = Employee("emp4", c1, 4, e3) # noqa + e5 = Employee("emp5", c2, 1) + e6 = Employee("emp6", c2, 2, e5) # noqa + e7 = Employee("emp7", c2, 3, e5) # noqa sess.add_all((c1, c2)) sess.commit() @@ -1011,55 +1192,64 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): def _test_lazy_relations(self, sess): Employee, Company = self.classes.Employee, self.classes.Company - c1 = sess.query(Company).filter_by(name='c1').one() - c2 = sess.query(Company).filter_by(name='c2').one() - e1 = sess.query(Employee).filter_by(name='emp1').one() - e5 = sess.query(Employee).filter_by(name='emp5').one() + c1 = sess.query(Company).filter_by(name="c1").one() + c2 = sess.query(Company).filter_by(name="c2").one() + e1 = sess.query(Employee).filter_by(name="emp1").one() + e5 = sess.query(Employee).filter_by(name="emp5").one() test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) - assert test_e1.name == 'emp1', test_e1.name + assert test_e1.name == "emp1", test_e1.name test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) - assert test_e5.name == 'emp5', test_e5.name - assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] - assert sess.query(Employee).\ - get([c1.company_id, 3]).reports_to.name == 'emp1' - assert sess.query(Employee).\ - get([c2.company_id, 3]).reports_to.name == 'emp5' + assert test_e5.name == "emp5", test_e5.name + assert [x.name for x in test_e1.employees] == ["emp2", "emp3"] + assert ( + sess.query(Employee).get([c1.company_id, 3]).reports_to.name + == "emp1" + ) + assert ( + sess.query(Employee).get([c2.company_id, 3]).reports_to.name + == "emp5" + ) def _test_join_aliasing(self, sess): Employee, Company = self.classes.Employee, self.classes.Company eq_( - [n for n, in sess.query(Employee.name). - join(Employee.reports_to, aliased=True). - filter_by(name='emp5'). - reset_joinpoint(). 
- order_by(Employee.name)], - ['emp6', 'emp7'] + [ + n + for n, in sess.query(Employee.name) + .join(Employee.reports_to, aliased=True) + .filter_by(name="emp5") + .reset_joinpoint() + .order_by(Employee.name) + ], + ["emp6", "emp7"], ) class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table("parent", metadata, - Column('x', Integer, primary_key=True), - Column('y', Integer, primary_key=True), - Column('z', Integer), - ) - Table("child", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('x', Integer), - Column('y', Integer), - Column('z', Integer), - # note 'z' is not here - sa.ForeignKeyConstraint( - ["x", "y"], - ["parent.x", "parent.y"] - ) - ) + Table( + "parent", + metadata, + Column("x", Integer, primary_key=True), + Column("y", Integer, primary_key=True), + Column("z", Integer), + ) + Table( + "child", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + # note 'z' is not here + sa.ForeignKeyConstraint(["x", "y"], ["parent.x", "parent.y"]), + ) @classmethod def setup_mappers(cls): @@ -1070,13 +1260,21 @@ class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL): class Child(cls.Comparable): pass - mapper(Parent, parent, properties={ - 'children': relationship(Child, primaryjoin=and_( - parent.c.x == child.c.x, - parent.c.y == child.c.y, - parent.c.z == child.c.z, - )) - }) + + mapper( + Parent, + parent, + properties={ + "children": relationship( + Child, + primaryjoin=and_( + parent.c.x == child.c.x, + parent.c.y == child.c.y, + parent.c.z == child.c.z, + ), + ) + }, + ) mapper(Child, child) def test_joins_fully(self): @@ -1084,7 +1282,7 @@ class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL): self.assert_compile( Parent.children.property.strategy._lazywhere, - ":param_1 = child.x AND :param_2 = child.y AND :param_3 = child.z" + ":param_1 = child.x AND :param_2 = child.y AND :param_3 = child.z", ) @@ -1094,15 +1292,21 @@ class SynonymsAsFKsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("tableA", metadata, - Column("id", Integer, primary_key=True), - Column("foo", Integer,), - test_needs_fk=True) + Table( + "tableA", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + test_needs_fk=True, + ) - Table("tableB", metadata, - Column("id", Integer, primary_key=True), - Column("_a_id", Integer, key='a_id', primary_key=True), - test_needs_fk=True) + Table( + "tableB", + metadata, + Column("id", Integer, primary_key=True), + Column("_a_id", Integer, key="a_id", primary_key=True), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): @@ -1110,7 +1314,6 @@ class SynonymsAsFKsTest(fixtures.MappedTest): pass class B(cls.Basic): - @property def a_id(self): return self._a_id @@ -1119,16 +1322,27 @@ class SynonymsAsFKsTest(fixtures.MappedTest): """test that active history is enabled on a one-to-many/one that has use_get==True""" - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) - mapper(B, tableB, properties={ - 'a_id': synonym('_a_id', map_column=True)}) - mapper(A, tableA, properties={ - 'b': relationship(B, primaryjoin=(tableA.c.id == foreign(B.a_id)), - 
uselist=False)}) + mapper( + B, tableB, properties={"a_id": synonym("_a_id", map_column=True)} + ) + mapper( + A, + tableA, + properties={ + "b": relationship( + B, + primaryjoin=(tableA.c.id == foreign(B.a_id)), + uselist=False, + ) + }, + ) sess = create_session() @@ -1150,15 +1364,22 @@ class FKsAsPksTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("tableA", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("foo", Integer,), - test_needs_fk=True) + Table( + "tableA", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("foo", Integer), + test_needs_fk=True, + ) - Table("tableB", metadata, - Column("id", Integer, ForeignKey("tableA.id"), primary_key=True), - test_needs_fk=True) + Table( + "tableB", + metadata, + Column("id", Integer, ForeignKey("tableA.id"), primary_key=True), + test_needs_fk=True, + ) @classmethod def setup_classes(cls): @@ -1172,13 +1393,22 @@ class FKsAsPksTest(fixtures.MappedTest): """test that active history is enabled on a one-to-many/one that has use_get==True""" - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) - mapper(A, tableA, properties={ - 'b': relationship(B, cascade="all,delete-orphan", uselist=False)}) + mapper( + A, + tableA, + properties={ + "b": relationship( + B, cascade="all,delete-orphan", uselist=False + ) + }, + ) mapper(B, tableB) configure_mappers() @@ -1197,13 +1427,18 @@ class FKsAsPksTest(fixtures.MappedTest): def test_no_delete_PK_AtoB(self): """A cant be deleted without B because B would have no PK value.""" - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) - mapper(A, tableA, properties={ - 'bs': relationship(B, cascade="save-update")}) + mapper( + A, + tableA, + properties={"bs": relationship(B, cascade="save-update")}, + ) mapper(B, tableB) a1 = A() @@ -1217,18 +1452,23 @@ class FKsAsPksTest(fixtures.MappedTest): sess.flush() assert False except AssertionError as e: - startswith_(str(e), - "Dependency rule tried to blank-out " - "primary key column 'tableB.id' on instance ") + startswith_( + str(e), + "Dependency rule tried to blank-out " + "primary key column 'tableB.id' on instance ", + ) def test_no_delete_PK_BtoA(self): - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) - mapper(B, tableB, properties={ - 'a': relationship(A, cascade="save-update")}) + mapper( + B, tableB, properties={"a": relationship(A, cascade="save-update")} + ) mapper(A, tableA) b1 = B() @@ -1242,28 +1482,39 @@ class FKsAsPksTest(fixtures.MappedTest): sess.flush() assert False except AssertionError as e: - startswith_(str(e), - "Dependency rule tried to blank-out " - "primary key column 'tableB.id' on instance ") + startswith_( + str(e), + "Dependency rule tried to blank-out " + "primary key column 'tableB.id' on instance ", + ) @testing.fails_on_everything_except( - 'sqlite', testing.requires.mysql_non_strict) + "sqlite", testing.requires.mysql_non_strict + ) def test_nullPKsOK_BtoA(self): A, tableA = self.classes.A, self.tables.tableA # postgresql cant 
handle a nullable PK column...? tableC = Table( - 'tablec', tableA.metadata, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('tableA.id'), - primary_key=True, nullable=True)) + "tablec", + tableA.metadata, + Column("id", Integer, primary_key=True), + Column( + "a_id", + Integer, + ForeignKey("tableA.id"), + primary_key=True, + nullable=True, + ), + ) tableC.create() class C(fixtures.BasicEntity): pass - mapper(C, tableC, properties={ - 'a': relationship(A, cascade="save-update") - }) + + mapper( + C, tableC, properties={"a": relationship(A, cascade="save-update")} + ) mapper(A, tableA) c1 = C() @@ -1278,17 +1529,25 @@ class FKsAsPksTest(fixtures.MappedTest): """No 'blank the PK' error when the child is to be deleted as part of a cascade""" - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) - - for cascade in ("save-update, delete", - # "save-update, delete-orphan", - "save-update, delete, delete-orphan"): - mapper(B, tableB, properties={ - 'a': relationship(A, cascade=cascade, single_parent=True) - }) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) + + for cascade in ( + "save-update, delete", + # "save-update, delete-orphan", + "save-update, delete, delete-orphan", + ): + mapper( + B, + tableB, + properties={ + "a": relationship(A, cascade=cascade, single_parent=True) + }, + ) mapper(A, tableA) b1 = B() @@ -1308,17 +1567,21 @@ class FKsAsPksTest(fixtures.MappedTest): """No 'blank the PK' error when the child is to be deleted as part of a cascade""" - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) - - for cascade in ("save-update, delete", - # "save-update, delete-orphan", - "save-update, delete, delete-orphan"): - mapper(A, tableA, properties={ - 'bs': relationship(B, cascade=cascade) - }) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) + + for cascade in ( + "save-update, delete", + # "save-update, delete-orphan", + "save-update, delete, delete-orphan", + ): + mapper( + A, tableA, properties={"bs": relationship(B, cascade=cascade)} + ) mapper(B, tableB) a1 = A() @@ -1336,13 +1599,14 @@ class FKsAsPksTest(fixtures.MappedTest): sa.orm.clear_mappers() def test_delete_manual_AtoB(self): - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) - mapper(A, tableA, properties={ - 'bs': relationship(B, cascade="none")}) + mapper(A, tableA, properties={"bs": relationship(B, cascade="none")}) mapper(B, tableB) a1 = A() @@ -1361,13 +1625,14 @@ class FKsAsPksTest(fixtures.MappedTest): sess.expunge_all() def test_delete_manual_BtoA(self): - tableB, A, B, tableA = (self.tables.tableB, - self.classes.A, - self.classes.B, - self.tables.tableA) + tableB, A, B, tableA = ( + self.tables.tableB, + self.classes.A, + self.classes.B, + self.tables.tableA, + ) - mapper(B, tableB, properties={ - 'a': relationship(A, cascade="none")}) + mapper(B, tableB, properties={"a": relationship(A, cascade="none")}) mapper(A, tableA) b1 = B() @@ -1391,20 +1656,28 @@ class UniqueColReferenceSwitchTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("table_a", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("ident", String(10), 
nullable=False, - unique=True), - ) - - Table("table_b", metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("a_ident", String(10), - ForeignKey('table_a.ident'), - nullable=False), - ) + Table( + "table_a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("ident", String(10), nullable=False, unique=True), + ) + + Table( + "table_b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column( + "a_ident", + String(10), + ForeignKey("table_a.ident"), + nullable=False, + ), + ) @classmethod def setup_classes(cls): @@ -1415,10 +1688,12 @@ class UniqueColReferenceSwitchTest(fixtures.MappedTest): pass def test_switch_parent(self): - A, B, table_b, table_a = (self.classes.A, - self.classes.B, - self.tables.table_b, - self.tables.table_a) + A, B, table_b, table_a = ( + self.classes.A, + self.classes.B, + self.tables.table_b, + self.tables.table_a, + ) mapper(A, table_a) mapper(B, table_b, properties={"a": relationship(A, backref="bs")}) @@ -1426,9 +1701,7 @@ class UniqueColReferenceSwitchTest(fixtures.MappedTest): session = create_session() a1, a2 = A(ident="uuid1"), A(ident="uuid2") session.add_all([a1, a2]) - a1.bs = [ - B(), B() - ] + a1.bs = [B(), B()] session.flush() session.expire_all() a1, a2 = session.query(A).all() @@ -1445,15 +1718,30 @@ class RelationshipToSelectableTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('items', metadata, - Column('item_policy_num', String(10), primary_key=True, - key='policyNum'), - Column('item_policy_eff_date', sa.Date, primary_key=True, - key='policyEffDate'), - Column('item_type', String(20), primary_key=True, - key='type'), - Column('item_id', Integer, primary_key=True, - key='id', autoincrement=False)) + Table( + "items", + metadata, + Column( + "item_policy_num", + String(10), + primary_key=True, + key="policyNum", + ), + Column( + "item_policy_eff_date", + sa.Date, + primary_key=True, + key="policyEffDate", + ), + Column("item_type", String(20), primary_key=True, key="type"), + Column( + "item_id", + Integer, + primary_key=True, + key="id", + autoincrement=False, + ), + ) def test_basic(self): items = self.tables.items @@ -1467,7 +1755,7 @@ class RelationshipToSelectableTest(fixtures.MappedTest): container_select = sa.select( [items.c.policyNum, items.c.policyEffDate, items.c.type], distinct=True, - ).alias('container_select') + ).alias("container_select") mapper(LineItem, items) @@ -1477,21 +1765,22 @@ class RelationshipToSelectableTest(fixtures.MappedTest): properties=dict( lineItems=relationship( LineItem, - lazy='select', - cascade='all, delete-orphan', + lazy="select", + cascade="all, delete-orphan", order_by=sa.asc(items.c.id), primaryjoin=sa.and_( container_select.c.policyNum == items.c.policyNum, - container_select.c.policyEffDate == - items.c.policyEffDate, - container_select.c.type == items.c.type), + container_select.c.policyEffDate + == items.c.policyEffDate, + container_select.c.type == items.c.type, + ), foreign_keys=[ items.c.policyNum, items.c.policyEffDate, - items.c.type - ] + items.c.type, + ], ) - ) + ), ) session = create_session() @@ -1507,8 +1796,9 @@ class RelationshipToSelectableTest(fixtures.MappedTest): session.add(li) session.flush() session.expunge_all() - newcon = session.query(Container).\ - order_by(container_select.c.type).first() + newcon = ( + session.query(Container).order_by(container_select.c.type).first() + ) assert con.policyNum == 
newcon.policyNum assert len(newcon.lineItems) == 10 for old, new in zip(con.lineItems, newcon.lineItems): @@ -1525,17 +1815,24 @@ class FKEquatedToConstantTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('tags', metadata, Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column("data", String(50)), - ) - - Table('tag_foo', metadata, - Column("id", Integer, primary_key=True, - test_needs_autoincrement=True), - Column('tagid', Integer), - Column("data", String(50)), - ) + Table( + "tags", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + Table( + "tag_foo", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("tagid", Integer), + Column("data", String(50)), + ) def test_basic(self): tag_foo, tags = self.tables.tag_foo, self.tables.tags @@ -1546,20 +1843,27 @@ class FKEquatedToConstantTest(fixtures.MappedTest): class TagInstance(fixtures.ComparableEntity): pass - mapper(Tag, tags, properties={ - 'foo': relationship( - TagInstance, - primaryjoin=sa.and_(tag_foo.c.data == 'iplc_case', - tag_foo.c.tagid == tags.c.id), - foreign_keys=[tag_foo.c.tagid, tag_foo.c.data]), - }) + mapper( + Tag, + tags, + properties={ + "foo": relationship( + TagInstance, + primaryjoin=sa.and_( + tag_foo.c.data == "iplc_case", + tag_foo.c.tagid == tags.c.id, + ), + foreign_keys=[tag_foo.c.tagid, tag_foo.c.data], + ) + }, + ) mapper(TagInstance, tag_foo) sess = create_session() - t1 = Tag(data='some tag') - t1.foo.append(TagInstance(data='iplc_case')) - t1.foo.append(TagInstance(data='not_iplc_case')) + t1 = Tag(data="some tag") + t1.foo.append(TagInstance(data="iplc_case")) + t1.foo.append(TagInstance(data="not_iplc_case")) sess.add(t1) sess.flush() sess.expunge_all() @@ -1567,31 +1871,36 @@ class FKEquatedToConstantTest(fixtures.MappedTest): # relationship works eq_( sess.query(Tag).all(), - [Tag(data='some tag', foo=[TagInstance(data='iplc_case')])] + [Tag(data="some tag", foo=[TagInstance(data="iplc_case")])], ) # both TagInstances were persisted eq_( sess.query(TagInstance).order_by(TagInstance.data).all(), - [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')] + [TagInstance(data="iplc_case"), TagInstance(data="not_iplc_case")], ) class BackrefPropagatesForwardsArgs(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)) - ) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer), - Column('email', String(50)) - ) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(50)), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", Integer), + Column("email", String(50)), + ) @classmethod def setup_classes(cls): @@ -1602,27 +1911,35 @@ class BackrefPropagatesForwardsArgs(fixtures.MappedTest): pass def test_backref(self): - User, Address, users, addresses = (self.classes.User, - self.classes.Address, - self.tables.users, - self.tables.addresses) - - mapper(User, users, properties={ - 'addresses': relationship( - Address, - primaryjoin=addresses.c.user_id == users.c.id, - foreign_keys=addresses.c.user_id, - backref='user') - }) + User, 
Address, users, addresses = ( + self.classes.User, + self.classes.Address, + self.tables.users, + self.tables.addresses, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + primaryjoin=addresses.c.user_id == users.c.id, + foreign_keys=addresses.c.user_id, + backref="user", + ) + }, + ) mapper(Address, addresses) sess = sessionmaker()() - u1 = User(name='u1', addresses=[Address(email='a1')]) + u1 = User(name="u1", addresses=[Address(email="a1")]) sess.add(u1) sess.commit() - eq_(sess.query(Address).all(), [ - Address(email='a1', user=User(name='u1')) - ]) + eq_( + sess.query(Address).all(), + [Address(email="a1", user=User(name="u1"))], + ) class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): @@ -1640,17 +1957,23 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'subscriber', metadata, + "subscriber", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + ) Table( - 'address', metadata, + "address", + metadata, Column( - 'subscriber_id', Integer, - ForeignKey('subscriber.id'), primary_key=True), - Column('type', String(1), primary_key=True), + "subscriber_id", + Integer, + ForeignKey("subscriber.id"), + primary_key=True, + ), + Column("type", String(1), primary_key=True), ) @classmethod @@ -1659,8 +1982,11 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): subscriber_and_address = subscriber.join( address, - and_(address.c.subscriber_id == subscriber.c.id, - address.c.type.in_(['A', 'B', 'C']))) + and_( + address.c.subscriber_id == subscriber.c.id, + address.c.type.in_(["A", "B", "C"]), + ), + ) class Address(cls.Comparable): pass @@ -1670,11 +1996,16 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): mapper(Address, address) - mapper(Subscriber, subscriber_and_address, properties={ - 'id': [subscriber.c.id, address.c.subscriber_id], - 'addresses': relationship(Address, - backref=backref("customer")) - }) + mapper( + Subscriber, + subscriber_and_address, + properties={ + "id": [subscriber.c.id, address.c.subscriber_id], + "addresses": relationship( + Address, backref=backref("customer") + ), + }, + ) def test_mapping(self): Subscriber, Address = self.classes.Subscriber, self.classes.Address @@ -1683,13 +2014,10 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): assert Subscriber.addresses.property.direction is ONETOMANY assert Address.customer.property.direction is MANYTOONE - s1 = Subscriber(type='A', - addresses=[ - Address(type='D'), - Address(type='E'), - ] - ) - a1 = Address(type='B', customer=Subscriber(type='C')) + s1 = Subscriber( + type="A", addresses=[Address(type="D"), Address(type="E")] + ) + a1 = Address(type="B", customer=Subscriber(type="C")) assert s1.addresses[0].customer is s1 assert a1.customer.addresses[0] is a1 @@ -1702,10 +2030,10 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): eq_( sess.query(Subscriber).order_by(Subscriber.type).all(), [ - Subscriber(id=1, type='A'), - Subscriber(id=2, type='B'), - Subscriber(id=2, type='C') - ] + Subscriber(id=1, type="A"), + Subscriber(id=2, type="B"), + Subscriber(id=2, type="C"), + ], ) @@ -1716,23 +2044,33 @@ class ManualBackrefTest(_fixtures.FixtureTest): run_inserts = None def test_o2m(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User 
= ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, back_populates='user') - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, back_populates="user") + }, + ) - mapper(Address, addresses, properties={ - 'user': relationship(User, back_populates='addresses') - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User, back_populates="addresses") + }, + ) sess = create_session() - u1 = User(name='u1') - a1 = Address(email_address='foo') + u1 = User(name="u1") + a1 = Address(email_address="foo") u1.addresses.append(a1) assert a1.user is u1 @@ -1744,19 +2082,29 @@ class ManualBackrefTest(_fixtures.FixtureTest): assert a1 in u1.addresses def test_invalid_key(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, back_populates='userr') - }) - - mapper(Address, addresses, properties={ - 'user': relationship(User, back_populates='addresses') - }) - + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship(Address, back_populates="userr") + }, + ) + + mapper( + Address, + addresses, + properties={ + "user": relationship(User, back_populates="addresses") + }, + ) + assert_raises(sa.exc.InvalidRequestError, configure_mappers) def test_invalid_target(self): @@ -1766,24 +2114,33 @@ class ManualBackrefTest(_fixtures.FixtureTest): self.classes.User, self.tables.dingalings, self.classes.Address, - self.tables.users) + self.tables.users, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, back_populates='dingaling'), - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, back_populates="dingaling") + }, + ) mapper(Dingaling, dingalings) - mapper(Address, addresses, properties={ - 'dingaling': relationship(Dingaling) - }) + mapper( + Address, + addresses, + properties={"dingaling": relationship(Dingaling)}, + ) - assert_raises_message(sa.exc.ArgumentError, - r"reverse_property 'dingaling' on relationship " - r"User.addresses references " - r"relationship Address.dingaling, " - r"which does not " - r"reference mapper Mapper\|User\|users", - configure_mappers) + assert_raises_message( + sa.exc.ArgumentError, + r"reverse_property 'dingaling' on relationship " + r"User.addresses references " + r"relationship Address.dingaling, " + r"which does not " + r"reference mapper Mapper\|User\|users", + configure_mappers, + ) class NoLoadBackPopulates(_fixtures.FixtureTest): @@ -1792,19 +2149,24 @@ class NoLoadBackPopulates(_fixtures.FixtureTest): lazyloader to set up instrumentation""" def test_o2m(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship( - Address, back_populates='user', lazy="noload") - }) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, back_populates="user", lazy="noload" + ) + }, + ) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, 
addresses, properties={"user": relationship(User)}) u1 = User() a1 = Address() @@ -1812,20 +2174,24 @@ class NoLoadBackPopulates(_fixtures.FixtureTest): is_(a1.user, u1) def test_m2o(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship( - Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) - mapper(Address, addresses, properties={ - 'user': relationship( - User, back_populates='addresses', lazy="noload") - }) + mapper( + Address, + addresses, + properties={ + "user": relationship( + User, back_populates="addresses", lazy="noload" + ) + }, + ) u1 = User() a1 = Address() @@ -1834,48 +2200,49 @@ class NoLoadBackPopulates(_fixtures.FixtureTest): class JoinConditionErrorTest(fixtures.TestBase): - def test_clauseelement_pj(self): from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() class C1(Base): - __tablename__ = 'c1' - id = Column('id', Integer, primary_key=True) + __tablename__ = "c1" + id = Column("id", Integer, primary_key=True) class C2(Base): - __tablename__ = 'c2' - id = Column('id', Integer, primary_key=True) - c1id = Column('c1id', Integer, ForeignKey('c1.id')) + __tablename__ = "c2" + id = Column("id", Integer, primary_key=True) + c1id = Column("c1id", Integer, ForeignKey("c1.id")) c2 = relationship(C1, primaryjoin=C1.id) assert_raises(sa.exc.ArgumentError, configure_mappers) def test_clauseelement_pj_false(self): from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() class C1(Base): - __tablename__ = 'c1' - id = Column('id', Integer, primary_key=True) + __tablename__ = "c1" + id = Column("id", Integer, primary_key=True) class C2(Base): - __tablename__ = 'c2' - id = Column('id', Integer, primary_key=True) - c1id = Column('c1id', Integer, ForeignKey('c1.id')) + __tablename__ = "c2" + id = Column("id", Integer, primary_key=True) + c1id = Column("c1id", Integer, ForeignKey("c1.id")) c2 = relationship(C1, primaryjoin="x" == "y") assert_raises(sa.exc.ArgumentError, configure_mappers) def test_only_column_elements(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('t2.id')), - ) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - ) + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("t2.id")), + ) + t2 = Table("t2", m, Column("id", Integer, primary_key=True)) class C1(object): pass @@ -1883,8 +2250,11 @@ class JoinConditionErrorTest(fixtures.TestBase): class C2(object): pass - mapper(C1, t1, properties={ - 'c2': relationship(C2, primaryjoin=t1.join(t2))}) + mapper( + C1, + t1, + properties={"c2": relationship(C2, primaryjoin=t1.join(t2))}, + ) mapper(C2, t2) assert_raises(sa.exc.ArgumentError, configure_mappers) @@ -1892,47 +2262,50 @@ class JoinConditionErrorTest(fixtures.TestBase): from sqlalchemy.ext.declarative import declarative_base for argname, arg in [ - ('remote_side', ['c1.id']), - ('remote_side', ['id']), - ('foreign_keys', ['c1id']), - ('foreign_keys', ['C2.c1id']), - ('order_by', ['id']), + ("remote_side", ["c1.id"]), + ("remote_side", ["id"]), + ("foreign_keys", ["c1id"]), + ("foreign_keys", ["C2.c1id"]), + ("order_by", ["id"]), ]: clear_mappers() kw = {argname: arg} 
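            # each iteration supplies a plain string where a column-based
            # expression is expected; configure_mappers() is then expected
            # to raise ArgumentError naming the offending argument and value.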
Base = declarative_base() class C1(Base): - __tablename__ = 'c1' - id = Column('id', Integer, primary_key=True) + __tablename__ = "c1" + id = Column("id", Integer, primary_key=True) class C2(Base): - __tablename__ = 'c2' - id_ = Column('id', Integer, primary_key=True) - c1id = Column('c1id', Integer, ForeignKey('c1.id')) + __tablename__ = "c2" + id_ = Column("id", Integer, primary_key=True) + c1id = Column("c1id", Integer, ForeignKey("c1.id")) c2 = relationship(C1, **kw) assert_raises_message( sa.exc.ArgumentError, "Column-based expression object expected " - "for argument '%s'; got: '%s', type %r" % - (argname, arg[0], type(arg[0])), - configure_mappers) + "for argument '%s'; got: '%s', type %r" + % (argname, arg[0], type(arg[0])), + configure_mappers, + ) def test_fk_error_not_raised_unrelated(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('t2.nonexistent_id')), - ) - t2 = Table('t2', m, # noqa - Column('id', Integer, primary_key=True), - ) - - t3 = Table('t3', m, - Column('id', Integer, primary_key=True), - Column('t1id', Integer, ForeignKey('t1.id')) - ) + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("t2.nonexistent_id")), + ) + t2 = Table("t2", m, Column("id", Integer, primary_key=True)) # noqa + + t3 = Table( + "t3", + m, + Column("id", Integer, primary_key=True), + Column("t1id", Integer, ForeignKey("t1.id")), + ) class C1(object): pass @@ -1940,23 +2313,21 @@ class JoinConditionErrorTest(fixtures.TestBase): class C2(object): pass - mapper(C1, t1, properties={'c2': relationship(C2)}) + mapper(C1, t1, properties={"c2": relationship(C2)}) mapper(C2, t3) assert C1.c2.property.primaryjoin.compare(t1.c.id == t3.c.t1id) def test_join_error_raised(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - ) - t2 = Table('t2', m, # noqa - Column('id', Integer, primary_key=True), - ) - - t3 = Table('t3', m, - Column('id', Integer, primary_key=True), - Column('t1id', Integer) - ) + t1 = Table("t1", m, Column("id", Integer, primary_key=True)) + t2 = Table("t2", m, Column("id", Integer, primary_key=True)) # noqa + + t3 = Table( + "t3", + m, + Column("id", Integer, primary_key=True), + Column("t1id", Integer), + ) class C1(object): pass @@ -1964,7 +2335,7 @@ class JoinConditionErrorTest(fixtures.TestBase): class C2(object): pass - mapper(C1, t1, properties={'c2': relationship(C2)}) + mapper(C1, t1, properties={"c2": relationship(C2)}) mapper(C2, t3) assert_raises(sa.exc.ArgumentError, configure_mappers) @@ -1980,30 +2351,44 @@ class TypeMatchTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("a", metadata, - Column('aid', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('adata', String(30))) - Table("b", metadata, - Column('bid', Integer, primary_key=True, - test_needs_autoincrement=True), - Column("a_id", Integer, ForeignKey("a.aid")), - Column('bdata', String(30))) - Table("c", metadata, - Column('cid', Integer, primary_key=True, - test_needs_autoincrement=True), - Column("b_id", Integer, ForeignKey("b.bid")), - Column('cdata', String(30))) - Table("d", metadata, - Column('did', Integer, primary_key=True, - test_needs_autoincrement=True), - Column("a_id", Integer, ForeignKey("a.aid")), - Column('ddata', String(30))) + Table( + "a", + metadata, + Column( + "aid", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("adata", String(30)), + ) + Table( + "b", + 
metadata, + Column( + "bid", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", Integer, ForeignKey("a.aid")), + Column("bdata", String(30)), + ) + Table( + "c", + metadata, + Column( + "cid", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("b_id", Integer, ForeignKey("b.bid")), + Column("cdata", String(30)), + ) + Table( + "d", + metadata, + Column( + "did", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", Integer, ForeignKey("a.aid")), + Column("ddata", String(30)), + ) def test_o2m_oncascade(self): - a, c, b = (self.tables.a, - self.tables.c, - self.tables.b) + a, c, b = (self.tables.a, self.tables.c, self.tables.b) class A(fixtures.BasicEntity): pass @@ -2013,7 +2398,8 @@ class TypeMatchTest(fixtures.MappedTest): class C(fixtures.BasicEntity): pass - mapper(A, a, properties={'bs': relationship(B)}) + + mapper(A, a, properties={"bs": relationship(B)}) mapper(B, b) mapper(C, c) @@ -2027,14 +2413,14 @@ class TypeMatchTest(fixtures.MappedTest): sess.add(a1) assert False except AssertionError as err: - eq_(str(err), + eq_( + str(err), "Attribute 'bs' on class '%s' doesn't handle " - "objects of type '%s'" % (A, C)) + "objects of type '%s'" % (A, C), + ) def test_o2m_onflush(self): - a, c, b = (self.tables.a, - self.tables.c, - self.tables.b) + a, c, b = (self.tables.a, self.tables.c, self.tables.b) class A(fixtures.BasicEntity): pass @@ -2044,7 +2430,8 @@ class TypeMatchTest(fixtures.MappedTest): class C(fixtures.BasicEntity): pass - mapper(A, a, properties={'bs': relationship(B, cascade="none")}) + + mapper(A, a, properties={"bs": relationship(B, cascade="none")}) mapper(B, b) mapper(C, c) @@ -2057,14 +2444,12 @@ class TypeMatchTest(fixtures.MappedTest): sess.add(a1) sess.add(b1) sess.add(c1) - assert_raises_message(sa.orm.exc.FlushError, - "Attempting to flush an item", - sess.flush) + assert_raises_message( + sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush + ) def test_o2m_nopoly_onflush(self): - a, c, b = (self.tables.a, - self.tables.c, - self.tables.b) + a, c, b = (self.tables.a, self.tables.c, self.tables.b) class A(fixtures.BasicEntity): pass @@ -2074,7 +2459,8 @@ class TypeMatchTest(fixtures.MappedTest): class C(B): pass - mapper(A, a, properties={'bs': relationship(B, cascade="none")}) + + mapper(A, a, properties={"bs": relationship(B, cascade="none")}) mapper(B, b) mapper(C, c, inherits=B) @@ -2087,14 +2473,12 @@ class TypeMatchTest(fixtures.MappedTest): sess.add(a1) sess.add(b1) sess.add(c1) - assert_raises_message(sa.orm.exc.FlushError, - "Attempting to flush an item", - sess.flush) + assert_raises_message( + sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush + ) def test_m2o_nopoly_onflush(self): - a, b, d = (self.tables.a, - self.tables.b, - self.tables.d) + a, b, d = (self.tables.a, self.tables.b, self.tables.d) class A(fixtures.BasicEntity): pass @@ -2104,6 +2488,7 @@ class TypeMatchTest(fixtures.MappedTest): class D(fixtures.BasicEntity): pass + mapper(A, a) mapper(B, b, inherits=A) mapper(D, d, properties={"a": relationship(A, cascade="none")}) @@ -2113,14 +2498,12 @@ class TypeMatchTest(fixtures.MappedTest): sess = create_session() sess.add(b1) sess.add(d1) - assert_raises_message(sa.orm.exc.FlushError, - "Attempting to flush an item", - sess.flush) + assert_raises_message( + sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush + ) def test_m2o_oncascade(self): - a, b, d = (self.tables.a, - self.tables.b, - self.tables.d) + a, b, d = (self.tables.a, 
self.tables.b, self.tables.d) class A(fixtures.BasicEntity): pass @@ -2130,6 +2513,7 @@ class TypeMatchTest(fixtures.MappedTest): class D(fixtures.BasicEntity): pass + mapper(A, a) mapper(B, b) mapper(D, d, properties={"a": relationship(A)}) @@ -2137,13 +2521,12 @@ class TypeMatchTest(fixtures.MappedTest): d1 = D() d1.a = b1 sess = create_session() - assert_raises_message(AssertionError, - "doesn't handle objects of type", - sess.add, d1) + assert_raises_message( + AssertionError, "doesn't handle objects of type", sess.add, d1 + ) class TypedAssociationTable(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): class MySpecialType(sa.types.TypeDecorator): @@ -2155,31 +2538,42 @@ class TypedAssociationTable(fixtures.MappedTest): def process_result_value(self, value, dialect): return value[4:] - Table('t1', metadata, - Column('col1', MySpecialType(30), primary_key=True), - Column('col2', String(30))) - Table('t2', metadata, - Column('col1', MySpecialType(30), primary_key=True), - Column('col2', String(30))) - Table('t3', metadata, - Column('t1c1', MySpecialType(30), ForeignKey('t1.col1')), - Column('t2c1', MySpecialType(30), ForeignKey('t2.col1'))) + Table( + "t1", + metadata, + Column("col1", MySpecialType(30), primary_key=True), + Column("col2", String(30)), + ) + Table( + "t2", + metadata, + Column("col1", MySpecialType(30), primary_key=True), + Column("col2", String(30)), + ) + Table( + "t3", + metadata, + Column("t1c1", MySpecialType(30), ForeignKey("t1.col1")), + Column("t2c1", MySpecialType(30), ForeignKey("t2.col1")), + ) def test_m2m(self): """Many-to-many tables with special types for candidate keys.""" - t2, t3, t1 = (self.tables.t2, - self.tables.t3, - self.tables.t1) + t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1) class T1(fixtures.BasicEntity): pass class T2(fixtures.BasicEntity): pass + mapper(T2, t2) - mapper(T1, t1, properties={ - 't2s': relationship(T2, secondary=t3, backref='t1s')}) + mapper( + T1, + t1, + properties={"t2s": relationship(T2, secondary=t3, backref="t1s")}, + ) a = T1() a.col1 = "aid" @@ -2193,12 +2587,12 @@ class TypedAssociationTable(fixtures.MappedTest): sess.add(a) sess.flush() - eq_(select([func.count('*')]).select_from(t3).scalar(), 2) + eq_(select([func.count("*")]).select_from(t3).scalar(), 2) a.t2s.remove(c) sess.flush() - eq_(select([func.count('*')]).select_from(t3).scalar(), 1) + eq_(select([func.count("*")]).select_from(t3).scalar(), 1) class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): @@ -2207,18 +2601,22 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): run_create_tables = run_deletes = None - __dialect__ = 'default' + __dialect__ = "default" @classmethod def define_tables(cls, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('foo', String(50)) - ) - Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('foo', String(50)) - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", String(50)), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", String(50)), + ) def test_join_on_custom_op(self): class A(fixtures.BasicEntity): @@ -2227,35 +2625,47 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): class B(fixtures.BasicEntity): pass - mapper(A, self.tables.a, properties={ - 'bs': relationship(B, - primaryjoin=self.tables.a.c.foo.op( - '&*', is_comparison=True - )(foreign(self.tables.b.c.foo)), - viewonly=True - ) - }) + mapper( + A, + 
self.tables.a, + properties={ + "bs": relationship( + B, + primaryjoin=self.tables.a.c.foo.op( + "&*", is_comparison=True + )(foreign(self.tables.b.c.foo)), + viewonly=True, + ) + }, + ) mapper(B, self.tables.b) self.assert_compile( Session().query(A).join(A.bs), "SELECT a.id AS a_id, a.foo AS a_foo " - "FROM a JOIN b ON a.foo &* b.foo" + "FROM a JOIN b ON a.foo &* b.foo", ) class ViewOnlyHistoryTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table("t1", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40))) - Table("t2", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40)), - Column('t1id', Integer, ForeignKey('t1.id'))) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + ) + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + Column("t1id", Integer, ForeignKey("t1.id")), + ) def _assert_fk(self, a1, b1, is_set): s = Session(testing.db) @@ -2276,10 +2686,15 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, self.tables.t1, properties={ - "bs": relationship(B, viewonly=True, - backref=backref("a", viewonly=False)) - }) + mapper( + A, + self.tables.t1, + properties={ + "bs": relationship( + B, viewonly=True, backref=backref("a", viewonly=False) + ) + }, + ) mapper(B, self.tables.t2) a1 = A() @@ -2302,10 +2717,15 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, self.tables.t1, properties={ - "bs": relationship(B, viewonly=False, - backref=backref("a", viewonly=True)) - }) + mapper( + A, + self.tables.t1, + properties={ + "bs": relationship( + B, viewonly=False, backref=backref("a", viewonly=True) + ) + }, + ) mapper(B, self.tables.t2) a1 = A() @@ -2328,9 +2748,11 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, self.tables.t1, properties={ - "bs": relationship(B, viewonly=True) - }) + mapper( + A, + self.tables.t1, + properties={"bs": relationship(B, viewonly=True)}, + ) mapper(B, self.tables.t2) a1 = A() @@ -2348,9 +2770,9 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): pass mapper(A, self.tables.t1) - mapper(B, self.tables.t2, properties={ - 'a': relationship(A, viewonly=True) - }) + mapper( + B, self.tables.t2, properties={"a": relationship(A, viewonly=True)} + ) a1 = A() b1 = B() @@ -2361,27 +2783,33 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): class ViewOnlyM2MBackrefTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table("t1", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40))) - Table("t2", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40)), - ) - Table("t1t2", metadata, - Column('t1id', Integer, ForeignKey('t1.id'), primary_key=True), - Column('t2id', Integer, ForeignKey('t2.id'), primary_key=True), - ) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + ) + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + ) + Table( + "t1t2", + metadata, + Column("t1id", Integer, 
ForeignKey("t1.id"), primary_key=True), + Column("t2id", Integer, ForeignKey("t2.id"), primary_key=True), + ) def test_viewonly(self): - t1t2, t2, t1 = (self.tables.t1t2, - self.tables.t2, - self.tables.t1) + t1t2, t2, t1 = (self.tables.t1t2, self.tables.t2, self.tables.t1) class A(fixtures.ComparableEntity): pass @@ -2389,10 +2817,15 @@ class ViewOnlyM2MBackrefTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, t1, properties={ - 'bs': relationship(B, secondary=t1t2, - backref=backref('as_', viewonly=True)) - }) + mapper( + A, + t1, + properties={ + "bs": relationship( + B, secondary=t1t2, backref=backref("as_", viewonly=True) + ) + }, + ) mapper(B, t2) sess = create_session() @@ -2403,12 +2836,8 @@ class ViewOnlyM2MBackrefTest(fixtures.MappedTest): sess.add(a1) sess.flush() - eq_( - sess.query(A).first(), A(bs=[B(id=b1.id)]) - ) - eq_( - sess.query(B).first(), B(as_=[A(id=a1.id)]) - ) + eq_(sess.query(A).first(), A(bs=[B(id=b1.id)])) + eq_(sess.query(B).first(), B(as_=[A(id=a1.id)])) class ViewOnlyOverlappingNames(fixtures.MappedTest): @@ -2417,20 +2846,32 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("t1", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40))) - Table("t2", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40)), - Column('t1id', Integer, ForeignKey('t1.id'))) - Table("t3", metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40)), - Column('t2id', Integer, ForeignKey('t2.id'))) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + ) + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + Column("t1id", Integer, ForeignKey("t1.id")), + ) + Table( + "t3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(40)), + Column("t2id", Integer, ForeignKey("t2.id")), + ) def test_three_table_view(self): """A three table join with overlapping PK names. 
@@ -2441,9 +2882,7 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest): """ - t2, t3, t1 = (self.tables.t2, - self.tables.t3, - self.tables.t1) + t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1) class C1(fixtures.BasicEntity): pass @@ -2454,26 +2893,33 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest): class C3(fixtures.BasicEntity): pass - mapper(C1, t1, properties={ - 't2s': relationship(C2), - 't2_view': relationship( - C2, - viewonly=True, - primaryjoin=sa.and_(t1.c.id == t2.c.t1id, - t3.c.t2id == t2.c.id, - t3.c.data == t1.c.data))}) + mapper( + C1, + t1, + properties={ + "t2s": relationship(C2), + "t2_view": relationship( + C2, + viewonly=True, + primaryjoin=sa.and_( + t1.c.id == t2.c.t1id, + t3.c.t2id == t2.c.id, + t3.c.data == t1.c.data, + ), + ), + }, + ) mapper(C2, t2) - mapper(C3, t3, properties={ - 't2': relationship(C2)}) + mapper(C3, t3, properties={"t2": relationship(C2)}) c1 = C1() - c1.data = 'c1data' + c1.data = "c1data" c2a = C2() c1.t2s.append(c2a) c2b = C2() c1.t2s.append(c2b) c3 = C3() - c3.data = 'c1data' + c3.data = "c1data" c3.t2 = c2b sess = create_session() sess.add(c1) @@ -2492,20 +2938,41 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table("t1", metadata, - Column('t1id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40))) - Table("t2", metadata, - Column('t2id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40)), - Column('t1id_ref', Integer, ForeignKey('t1.t1id'))) - Table("t3", metadata, - Column('t3id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(40)), - Column('t2id_ref', Integer, ForeignKey('t2.t2id'))) + Table( + "t1", + metadata, + Column( + "t1id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("data", String(40)), + ) + Table( + "t2", + metadata, + Column( + "t2id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("data", String(40)), + Column("t1id_ref", Integer, ForeignKey("t1.t1id")), + ) + Table( + "t3", + metadata, + Column( + "t3id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("data", String(40)), + Column("t2id_ref", Integer, ForeignKey("t2.t2id")), + ) def test_three_table_view(self): """A three table join with overlapping PK names. 
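The three-table primaryjoin reformatted in the preceding hunk is worth reading on its own: the join condition reaches into a third table (t3), so it can never be expressed as column assignments during a flush, and viewonly=True is what makes the relationship legal. Extracted as a standalone sketch — the mapping code is identical to ViewOnlyOverlappingNames, with a trailing configure_mappers() standing in for the fixture machinery:

    import sqlalchemy as sa
    from sqlalchemy.orm import configure_mappers, mapper, relationship

    metadata = sa.MetaData()
    t1 = sa.Table(
        "t1",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("data", sa.String(40)),
    )
    t2 = sa.Table(
        "t2",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("data", sa.String(40)),
        sa.Column("t1id", sa.Integer, sa.ForeignKey("t1.id")),
    )
    t3 = sa.Table(
        "t3",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("data", sa.String(40)),
        sa.Column("t2id", sa.Integer, sa.ForeignKey("t2.id")),
    )

    class C1(object):
        pass

    class C2(object):
        pass

    class C3(object):
        pass

    # "t2s" is the plain, persistable one-to-many; "t2_view" layers a
    # read-only path on the same pair of classes, filtered through t3.
    mapper(
        C1,
        t1,
        properties={
            "t2s": relationship(C2),
            "t2_view": relationship(
                C2,
                viewonly=True,
                primaryjoin=sa.and_(
                    t1.c.id == t2.c.t1id,
                    t3.c.t2id == t2.c.id,
                    t3.c.data == t1.c.data,
                ),
            ),
        },
    )
    mapper(C2, t2)
    mapper(C3, t3, properties={"t2": relationship(C2)})

    configure_mappers()

Lazy-loading c1.t2_view renders the full three-way criterion, pulling t3 into the query automatically, while c1.t2s continues to handle all persistence; the test's data setup (matching "c1data" values on t1 and t3) exercises exactly that filtering.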
@@ -2515,9 +2982,7 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): """ - t2, t3, t1 = (self.tables.t2, - self.tables.t3, - self.tables.t1) + t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1) class C1(fixtures.BasicEntity): pass @@ -2528,26 +2993,33 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): class C3(fixtures.BasicEntity): pass - mapper(C1, t1, properties={ - 't2s': relationship(C2), - 't2_view': relationship( - C2, - viewonly=True, - primaryjoin=sa.and_(t1.c.t1id == t2.c.t1id_ref, - t3.c.t2id_ref == t2.c.t2id, - t3.c.data == t1.c.data))}) + mapper( + C1, + t1, + properties={ + "t2s": relationship(C2), + "t2_view": relationship( + C2, + viewonly=True, + primaryjoin=sa.and_( + t1.c.t1id == t2.c.t1id_ref, + t3.c.t2id_ref == t2.c.t2id, + t3.c.data == t1.c.data, + ), + ), + }, + ) mapper(C2, t2) - mapper(C3, t3, properties={ - 't2': relationship(C2)}) + mapper(C3, t3, properties={"t2": relationship(C2)}) c1 = C1() - c1.data = 'c1data' + c1.data = "c1data" c2a = C2() c1.t2s.append(c2a) c2b = C2() c1.t2s.append(c2b) c3 = C3() - c3.data = 'c1data' + c3.data = "c1data" c3.t2 = c2b sess = create_session() @@ -2567,32 +3039,36 @@ class ViewOnlyLocalRemoteM2M(fixtures.TestBase): def test_local_remote(self): meta = MetaData() - t1 = Table('t1', meta, - Column('id', Integer, primary_key=True), - ) - t2 = Table('t2', meta, - Column('id', Integer, primary_key=True), - ) - t12 = Table('tab', meta, - Column('t1_id', Integer, ForeignKey('t1.id',)), - Column('t2_id', Integer, ForeignKey('t2.id',)), - ) + t1 = Table("t1", meta, Column("id", Integer, primary_key=True)) + t2 = Table("t2", meta, Column("id", Integer, primary_key=True)) + t12 = Table( + "tab", + meta, + Column("t1_id", Integer, ForeignKey("t1.id")), + Column("t2_id", Integer, ForeignKey("t2.id")), + ) class A(object): pass class B(object): pass - mapper(B, t2, ) - m = mapper(A, t1, properties=dict( - b_view=relationship(B, secondary=t12, viewonly=True), - b_plain=relationship(B, secondary=t12), - ) + + mapper(B, t2) + m = mapper( + A, + t1, + properties=dict( + b_view=relationship(B, secondary=t12, viewonly=True), + b_plain=relationship(B, secondary=t12), + ), ) configure_mappers() - assert m.get_property('b_view').local_remote_pairs == \ - m.get_property('b_plain').local_remote_pairs == \ - [(t1.c.id, t12.c.t1_id), (t2.c.id, t12.c.t2_id)] + assert ( + m.get_property("b_view").local_remote_pairs + == m.get_property("b_plain").local_remote_pairs + == [(t1.c.id, t12.c.t1_id), (t2.c.id, t12.c.t2_id)] + ) class ViewOnlyNonEquijoin(fixtures.MappedTest): @@ -2601,11 +3077,13 @@ class ViewOnlyNonEquijoin(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True)) - Table('bars', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer)) + Table("foos", metadata, Column("id", Integer, primary_key=True)) + Table( + "bars", + metadata, + Column("id", Integer, primary_key=True), + Column("fid", Integer), + ) def test_viewonly_join(self): bars, foos = self.tables.bars, self.tables.foos @@ -2616,28 +3094,43 @@ class ViewOnlyNonEquijoin(fixtures.MappedTest): class Bar(fixtures.ComparableEntity): pass - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id > bars.c.fid, - foreign_keys=[bars.c.fid], - viewonly=True)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=foos.c.id > bars.c.fid, + foreign_keys=[bars.c.fid], + viewonly=True, + ) + }, + ) mapper(Bar, bars) sess = 
create_session() - sess.add_all((Foo(id=4), - Foo(id=9), - Bar(id=1, fid=2), - Bar(id=2, fid=3), - Bar(id=3, fid=6), - Bar(id=4, fid=7))) + sess.add_all( + ( + Foo(id=4), + Foo(id=9), + Bar(id=1, fid=2), + Bar(id=2, fid=3), + Bar(id=3, fid=6), + Bar(id=4, fid=7), + ) + ) sess.flush() sess = create_session() - eq_(sess.query(Foo).filter_by(id=4).one(), - Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)])) - eq_(sess.query(Foo).filter_by(id=9).one(), - Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)])) + eq_( + sess.query(Foo).filter_by(id=4).one(), + Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)]), + ) + eq_( + sess.query(Foo).filter_by(id=9).one(), + Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)]), + ) class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): @@ -2646,18 +3139,24 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('foos', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('bid1', Integer, ForeignKey('bars.id')), - Column('bid2', Integer, ForeignKey('bars.id'))) - - Table('bars', metadata, - Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "foos", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("bid1", Integer, ForeignKey("bars.id")), + Column("bid2", Integer, ForeignKey("bars.id")), + ) + + Table( + "bars", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) def test_relationship_on_or(self): bars, foos = self.tables.bars, self.tables.foos @@ -2668,18 +3167,26 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): class Bar(fixtures.ComparableEntity): pass - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=sa.or_(bars.c.id == foos.c.bid1, - bars.c.id == foos.c.bid2), - uselist=True, - viewonly=True)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=sa.or_( + bars.c.id == foos.c.bid1, bars.c.id == foos.c.bid2 + ), + uselist=True, + viewonly=True, + ) + }, + ) mapper(Bar, bars) sess = create_session() - b1 = Bar(id=1, data='b1') - b2 = Bar(id=2, data='b2') - b3 = Bar(id=3, data='b3') + b1 = Bar(id=1, data="b1") + b2 = Bar(id=2, data="b2") + b3 = Bar(id=3, data="b3") f1 = Foo(bid1=1, bid2=2) f2 = Foo(bid1=3, bid2=None) @@ -2690,10 +3197,14 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): sess.flush() sess.expunge_all() - eq_(sess.query(Foo).filter_by(id=f1.id).one(), - Foo(bars=[Bar(data='b1'), Bar(data='b2')])) - eq_(sess.query(Foo).filter_by(id=f2.id).one(), - Foo(bars=[Bar(data='b3')])) + eq_( + sess.query(Foo).filter_by(id=f1.id).one(), + Foo(bars=[Bar(data="b1"), Bar(data="b2")]), + ) + eq_( + sess.query(Foo).filter_by(id=f2.id).one(), + Foo(bars=[Bar(data="b3")]), + ) class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): @@ -2702,16 +3213,25 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "foos", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) - Table('bars', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('fid1', Integer, ForeignKey('foos.id')), - 
Column('fid2', Integer, ForeignKey('foos.id')), - Column('data', String(50))) + Table( + "bars", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("fid1", Integer, ForeignKey("foos.id")), + Column("fid2", Integer, ForeignKey("foos.id")), + Column("data", String(50)), + ) def test_relationship_on_or(self): bars, foos = self.tables.bars, self.tables.foos @@ -2722,20 +3242,28 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): class Bar(fixtures.ComparableEntity): pass - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=sa.or_(bars.c.fid1 == foos.c.id, - bars.c.fid2 == foos.c.id), - viewonly=True)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=sa.or_( + bars.c.fid1 == foos.c.id, bars.c.fid2 == foos.c.id + ), + viewonly=True, + ) + }, + ) mapper(Bar, bars) sess = create_session() - f1 = Foo(id=1, data='f1') - f2 = Foo(id=2, data='f2') - b1 = Bar(fid1=1, data='b1') - b2 = Bar(fid2=1, data='b2') - b3 = Bar(fid1=2, data='b3') - b4 = Bar(fid1=1, fid2=2, data='b4') + f1 = Foo(id=1, data="f1") + f2 = Foo(id=2, data="f2") + b1 = Bar(fid1=1, data="b1") + b2 = Bar(fid2=1, data="b2") + b3 = Bar(fid1=2, data="b3") + b4 = Bar(fid1=1, fid2=2, data="b4") sess.add_all((f1, f2)) sess.flush() @@ -2744,10 +3272,14 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): sess.flush() sess.expunge_all() - eq_(sess.query(Foo).filter_by(id=f1.id).one(), - Foo(bars=[Bar(data='b1'), Bar(data='b2'), Bar(data='b4')])) - eq_(sess.query(Foo).filter_by(id=f2.id).one(), - Foo(bars=[Bar(data='b3'), Bar(data='b4')])) + eq_( + sess.query(Foo).filter_by(id=f1.id).one(), + Foo(bars=[Bar(data="b1"), Bar(data="b2"), Bar(data="b4")]), + ) + eq_( + sess.query(Foo).filter_by(id=f2.id).one(), + Foo(bars=[Bar(data="b3"), Bar(data="b4")]), + ) class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest): @@ -2756,22 +3288,37 @@ class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) - Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('t1id', Integer, ForeignKey('t1.id'))) - Table('t3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) - Table('t2tot3', metadata, - Column('t2id', Integer, ForeignKey('t2.id')), - Column('t3id', Integer, ForeignKey('t3.id'))) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("t1id", Integer, ForeignKey("t1.id")), + ) + Table( + "t3", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "t2tot3", + metadata, + Column("t2id", Integer, ForeignKey("t2.id")), + Column("t3id", Integer, ForeignKey("t3.id")), + ) @classmethod def setup_classes(cls): @@ -2785,56 +3332,86 @@ class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest): pass def test_basic(self): - T1, t2, T2, T3, t3, t2tot3, t1 = (self.classes.T1, - self.tables.t2, - self.classes.T2, - self.classes.T3, - self.tables.t3, - self.tables.t2tot3, - self.tables.t1) 
- - mapper(T1, t1, properties={ - 't3s': relationship(T3, primaryjoin=sa.and_( - t1.c.id == t2.c.t1id, - t2.c.id == t2tot3.c.t2id, - t3.c.id == t2tot3.c.t3id), - viewonly=True, - foreign_keys=t3.c.id, remote_side=t2.c.t1id) - }) - mapper(T2, t2, properties={ - 't1': relationship(T1), - 't3s': relationship(T3, secondary=t2tot3) - }) + T1, t2, T2, T3, t3, t2tot3, t1 = ( + self.classes.T1, + self.tables.t2, + self.classes.T2, + self.classes.T3, + self.tables.t3, + self.tables.t2tot3, + self.tables.t1, + ) + + mapper( + T1, + t1, + properties={ + "t3s": relationship( + T3, + primaryjoin=sa.and_( + t1.c.id == t2.c.t1id, + t2.c.id == t2tot3.c.t2id, + t3.c.id == t2tot3.c.t3id, + ), + viewonly=True, + foreign_keys=t3.c.id, + remote_side=t2.c.t1id, + ) + }, + ) + mapper( + T2, + t2, + properties={ + "t1": relationship(T1), + "t3s": relationship(T3, secondary=t2tot3), + }, + ) mapper(T3, t3) sess = create_session() - sess.add(T2(data='t2', t1=T1(data='t1'), t3s=[T3(data='t3')])) + sess.add(T2(data="t2", t1=T1(data="t1"), t3s=[T3(data="t3")])) sess.flush() sess.expunge_all() a = sess.query(T1).first() - eq_(a.t3s, [T3(data='t3')]) + eq_(a.t3s, [T3(data="t3")]) def test_remote_side_escalation(self): - T1, t2, T2, T3, t3, t2tot3, t1 = (self.classes.T1, - self.tables.t2, - self.classes.T2, - self.classes.T3, - self.tables.t3, - self.tables.t2tot3, - self.tables.t1) - - mapper(T1, t1, properties={ - 't3s': relationship(T3, - primaryjoin=sa.and_(t1.c.id == t2.c.t1id, - t2.c.id == t2tot3.c.t2id, - t3.c.id == t2tot3.c.t3id - ), - viewonly=True, - foreign_keys=t3.c.id)}) - mapper(T2, t2, properties={ - 't1': relationship(T1), - 't3s': relationship(T3, secondary=t2tot3)}) + T1, t2, T2, T3, t3, t2tot3, t1 = ( + self.classes.T1, + self.tables.t2, + self.classes.T2, + self.classes.T3, + self.tables.t3, + self.tables.t2tot3, + self.tables.t1, + ) + + mapper( + T1, + t1, + properties={ + "t3s": relationship( + T3, + primaryjoin=sa.and_( + t1.c.id == t2.c.t1id, + t2.c.id == t2tot3.c.t2id, + t3.c.id == t2tot3.c.t3id, + ), + viewonly=True, + foreign_keys=t3.c.id, + ) + }, + ) + mapper( + T2, + t2, + properties={ + "t1": relationship(T1), + "t3s": relationship(T3, secondary=t2tot3), + }, + ) mapper(T3, t3) self._assert_raises_no_local_remote(configure_mappers, "T1.t3s") @@ -2844,36 +3421,40 @@ class FunctionAsPrimaryJoinTest(fixtures.DeclarativeMappedTest): """ - __only_on__= 'sqlite' + __only_on__ = "sqlite" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class Venue(Base): - __tablename__ = 'venue' + __tablename__ = "venue" id = Column(Integer, primary_key=True) name = Column(String) descendants = relationship( "Venue", primaryjoin=func.instr( - remote(foreign(name)), name + "/").as_comparison(1, 2) == 1, + remote(foreign(name)), name + "/" + ).as_comparison(1, 2) + == 1, viewonly=True, - order_by=name + order_by=name, ) @classmethod def insert_data(cls): Venue = cls.classes.Venue s = Session() - s.add_all([ - Venue(name="parent1"), - Venue(name="parent2"), - Venue(name="parent1/child1"), - Venue(name="parent1/child2"), - Venue(name="parent2/child1"), - ]) + s.add_all( + [ + Venue(name="parent1"), + Venue(name="parent2"), + Venue(name="parent1/child1"), + Venue(name="parent1/child2"), + Venue(name="parent2/child1"), + ] + ) s.commit() def test_lazyload(self): @@ -2882,19 +3463,25 @@ class FunctionAsPrimaryJoinTest(fixtures.DeclarativeMappedTest): v1 = s.query(Venue).filter_by(name="parent1").one() eq_( [d.name for d in v1.descendants], - ['parent1/child1', 'parent1/child2']) + ["parent1/child1", 
"parent1/child2"], + ) def test_joinedload(self): Venue = self.classes.Venue s = Session() def go(): - v1 = s.query(Venue).filter_by(name="parent1").\ - options(joinedload(Venue.descendants)).one() + v1 = ( + s.query(Venue) + .filter_by(name="parent1") + .options(joinedload(Venue.descendants)) + .one() + ) eq_( [d.name for d in v1.descendants], - ['parent1/child1', 'parent1/child2']) + ["parent1/child1", "parent1/child2"], + ) self.assert_sql_count(testing.db, go, 1) @@ -2908,6 +3495,7 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): instrumented attributes, etc. """ + @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic @@ -2915,8 +3503,9 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): class Network(fixtures.ComparableEntity, Base): __tablename__ = "network" - id = Column(sa.Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + sa.Integer, primary_key=True, test_needs_autoincrement=True + ) ip_net_addr = Column(Integer) ip_broadcast_addr = Column(Integer) @@ -2925,7 +3514,7 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): primaryjoin="remote(foreign(Address.ip_addr)).between(" "Network.ip_net_addr," "Network.ip_broadcast_addr)", - viewonly=True + viewonly=True, ) class Address(fixtures.ComparableEntity, Base): @@ -2938,26 +3527,30 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): Network, Address = cls.classes.Network, cls.classes.Address s = Session(testing.db) - s.add_all([ - Network(ip_net_addr=5, ip_broadcast_addr=10), - Network(ip_net_addr=15, ip_broadcast_addr=25), - Network(ip_net_addr=30, ip_broadcast_addr=35), - Address(ip_addr=17), Address(ip_addr=18), Address(ip_addr=9), - Address(ip_addr=27) - ]) - s.commit() - + s.add_all( + [ + Network(ip_net_addr=5, ip_broadcast_addr=10), + Network(ip_net_addr=15, ip_broadcast_addr=25), + Network(ip_net_addr=30, ip_broadcast_addr=35), + Address(ip_addr=17), + Address(ip_addr=18), + Address(ip_addr=9), + Address(ip_addr=27), + ] + ) + s.commit() + def test_col_query(self): Network, Address = self.classes.Network, self.classes.Address session = Session(testing.db) eq_( - session.query(Address.ip_addr). - select_from(Network). - join(Network.addresses). - filter(Network.ip_net_addr == 15). 
- all(), - [(17, ), (18, )] + session.query(Address.ip_addr) + .select_from(Network) + .join(Network.addresses) + .filter(Network.ip_net_addr == 15) + .all(), + [(17,), (18,)], ) def test_lazyload(self): @@ -2970,17 +3563,23 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): class ExplicitLocalRemoteTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', String(50), primary_key=True), - Column('data', String(50))) - Table('t2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('t1id', String(50))) + Table( + "t1", + metadata, + Column("id", String(50), primary_key=True), + Column("data", String(50)), + ) + Table( + "t2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("t1id", String(50)), + ) @classmethod def setup_classes(cls): @@ -2991,184 +3590,265 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest): pass def test_onetomany_funcfk_oldstyle(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) # old _local_remote_pairs - mapper(T1, t1, properties={ - 't2s': relationship( - T2, - primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id] - ) - }) + mapper( + T1, + t1, + properties={ + "t2s": relationship( + T2, + primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id], + ) + }, + ) mapper(T2, t2) self._test_onetomany() def test_onetomany_funcfk_annotated(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) # use annotation - mapper(T1, t1, properties={ - 't2s': relationship(T2, - primaryjoin=t1.c.id == - foreign(sa.func.lower(t2.c.t1id)), - )}) + mapper( + T1, + t1, + properties={ + "t2s": relationship( + T2, + primaryjoin=t1.c.id == foreign(sa.func.lower(t2.c.t1id)), + ) + }, + ) mapper(T2, t2) self._test_onetomany() def _test_onetomany(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) is_(T1.t2s.property.direction, ONETOMANY) eq_(T1.t2s.property.local_remote_pairs, [(t1.c.id, t2.c.t1id)]) sess = create_session() - a1 = T1(id='number1', data='a1') - a2 = T1(id='number2', data='a2') - b1 = T2(data='b1', t1id='NuMbEr1') - b2 = T2(data='b2', t1id='Number1') - b3 = T2(data='b3', t1id='Number2') + a1 = T1(id="number1", data="a1") + a2 = T1(id="number2", data="a2") + b1 = T2(data="b1", t1id="NuMbEr1") + b2 = T2(data="b2", t1id="Number1") + b3 = T2(data="b3", t1id="Number2") sess.add_all((a1, a2, b1, b2, b3)) sess.flush() sess.expunge_all() - eq_(sess.query(T1).first(), - T1(id='number1', data='a1', t2s=[ - T2(data='b1', t1id='NuMbEr1'), - T2(data='b2', t1id='Number1')])) + eq_( + sess.query(T1).first(), + T1( + id="number1", + data="a1", + t2s=[ + T2(data="b1", t1id="NuMbEr1"), + T2(data="b2", t1id="Number1"), + ], + ), + ) def test_manytoone_funcfk(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, 
+ self.tables.t2, + self.tables.t1, + ) mapper(T1, t1) - mapper(T2, t2, properties={ - 't1': relationship(T1, - primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t2.c.t1id, t1.c.id)], - foreign_keys=[t2.c.t1id], - uselist=True)}) + mapper( + T2, + t2, + properties={ + "t1": relationship( + T1, + primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], + uselist=True, + ) + }, + ) sess = create_session() - a1 = T1(id='number1', data='a1') - a2 = T1(id='number2', data='a2') - b1 = T2(data='b1', t1id='NuMbEr1') - b2 = T2(data='b2', t1id='Number1') - b3 = T2(data='b3', t1id='Number2') + a1 = T1(id="number1", data="a1") + a2 = T1(id="number2", data="a2") + b1 = T2(data="b1", t1id="NuMbEr1") + b2 = T2(data="b2", t1id="Number1") + b3 = T2(data="b3", t1id="Number2") sess.add_all((a1, a2, b1, b2, b3)) sess.flush() sess.expunge_all() - eq_(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), - [T2(data='b1', t1=[T1(id='number1', data='a1')]), - T2(data='b2', t1=[T1(id='number1', data='a1')])]) + eq_( + sess.query(T2).filter(T2.data.in_(["b1", "b2"])).all(), + [ + T2(data="b1", t1=[T1(id="number1", data="a1")]), + T2(data="b2", t1=[T1(id="number1", data="a1")]), + ], + ) def test_onetomany_func_referent(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) - - mapper(T1, t1, properties={ - 't2s': relationship( - T2, - primaryjoin=sa.func.lower(t1.c.id) == t2.c.t1id, - _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id])}) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) + + mapper( + T1, + t1, + properties={ + "t2s": relationship( + T2, + primaryjoin=sa.func.lower(t1.c.id) == t2.c.t1id, + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id], + ) + }, + ) mapper(T2, t2) sess = create_session() - a1 = T1(id='NuMbeR1', data='a1') - a2 = T1(id='NuMbeR2', data='a2') - b1 = T2(data='b1', t1id='number1') - b2 = T2(data='b2', t1id='number1') - b3 = T2(data='b2', t1id='number2') + a1 = T1(id="NuMbeR1", data="a1") + a2 = T1(id="NuMbeR2", data="a2") + b1 = T2(data="b1", t1id="number1") + b2 = T2(data="b2", t1id="number1") + b3 = T2(data="b2", t1id="number2") sess.add_all((a1, a2, b1, b2, b3)) sess.flush() sess.expunge_all() - eq_(sess.query(T1).first(), - T1(id='NuMbeR1', data='a1', t2s=[ - T2(data='b1', t1id='number1'), - T2(data='b2', t1id='number1')])) + eq_( + sess.query(T1).first(), + T1( + id="NuMbeR1", + data="a1", + t2s=[ + T2(data="b1", t1id="number1"), + T2(data="b2", t1id="number1"), + ], + ), + ) def test_manytoone_func_referent(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) mapper(T1, t1) - mapper(T2, t2, properties={ - 't1': relationship(T1, - primaryjoin=sa.func.lower(t1.c.id) == t2.c.t1id, - _local_remote_pairs=[(t2.c.t1id, t1.c.id)], - foreign_keys=[t2.c.t1id], uselist=True)}) + mapper( + T2, + t2, + properties={ + "t1": relationship( + T1, + primaryjoin=sa.func.lower(t1.c.id) == t2.c.t1id, + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], + uselist=True, + ) + }, + ) sess = create_session() - a1 = T1(id='NuMbeR1', data='a1') - a2 = T1(id='NuMbeR2', data='a2') - b1 = T2(data='b1', t1id='number1') - b2 = T2(data='b2', t1id='number1') - b3 = T2(data='b3', t1id='number2') + a1 = T1(id="NuMbeR1", data="a1") 
+ a2 = T1(id="NuMbeR2", data="a2") + b1 = T2(data="b1", t1id="number1") + b2 = T2(data="b2", t1id="number1") + b3 = T2(data="b3", t1id="number2") sess.add_all((a1, a2, b1, b2, b3)) sess.flush() sess.expunge_all() - eq_(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), - [T2(data='b1', t1=[T1(id='NuMbeR1', data='a1')]), - T2(data='b2', t1=[T1(id='NuMbeR1', data='a1')])]) + eq_( + sess.query(T2).filter(T2.data.in_(["b1", "b2"])).all(), + [ + T2(data="b1", t1=[T1(id="NuMbeR1", data="a1")]), + T2(data="b2", t1=[T1(id="NuMbeR1", data="a1")]), + ], + ) def test_escalation_1(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) - - mapper(T1, t1, properties={ - 't2s': relationship( - T2, - primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id], - remote_side=[t2.c.t1id])}) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) + + mapper( + T1, + t1, + properties={ + "t2s": relationship( + T2, + primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id], + remote_side=[t2.c.t1id], + ) + }, + ) mapper(T2, t2) assert_raises(sa.exc.ArgumentError, sa.orm.configure_mappers) def test_escalation_2(self): - T2, T1, t2, t1 = (self.classes.T2, - self.classes.T1, - self.tables.t2, - self.tables.t1) - - mapper(T1, t1, properties={ - 't2s': relationship( - T2, - primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t1.c.id, t2.c.t1id)])}) + T2, T1, t2, t1 = ( + self.classes.T2, + self.classes.T1, + self.tables.t2, + self.tables.t1, + ) + + mapper( + T1, + t1, + properties={ + "t2s": relationship( + T2, + primaryjoin=t1.c.id == sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + ) + }, + ) mapper(T2, t2) assert_raises(sa.exc.ArgumentError, sa.orm.configure_mappers) class InvalidRemoteSideTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t_id', Integer, ForeignKey('t1.id')) - ) + Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("t_id", Integer, ForeignKey("t1.id")), + ) @classmethod def setup_classes(cls): @@ -3178,88 +3858,104 @@ class InvalidRemoteSideTest(fixtures.MappedTest): def test_o2m_backref(self): T1, t1 = self.classes.T1, self.tables.t1 - mapper(T1, t1, properties={ - 't1s': relationship(T1, backref='parent') - }) + mapper(T1, t1, properties={"t1s": relationship(T1, backref="parent")}) assert_raises_message( sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " r"both of the same direction symbol\('ONETOMANY'\). Did you " "mean to set remote_side on the many-to-one side ?", - configure_mappers) + configure_mappers, + ) def test_m2o_backref(self): T1, t1 = self.classes.T1, self.tables.t1 - mapper(T1, t1, properties={ - 't1s': relationship(T1, - backref=backref('parent', remote_side=t1.c.id), - remote_side=t1.c.id) - }) + mapper( + T1, + t1, + properties={ + "t1s": relationship( + T1, + backref=backref("parent", remote_side=t1.c.id), + remote_side=t1.c.id, + ) + }, + ) assert_raises_message( sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " r"both of the same direction symbol\('MANYTOONE'\). 
Did you " "mean to set remote_side on the many-to-one side ?", - configure_mappers) + configure_mappers, + ) def test_o2m_explicit(self): T1, t1 = self.classes.T1, self.tables.t1 - mapper(T1, t1, properties={ - 't1s': relationship(T1, back_populates='parent'), - 'parent': relationship(T1, back_populates='t1s'), - }) + mapper( + T1, + t1, + properties={ + "t1s": relationship(T1, back_populates="parent"), + "parent": relationship(T1, back_populates="t1s"), + }, + ) # can't be sure of ordering here assert_raises_message( sa.exc.ArgumentError, r"both of the same direction symbol\('ONETOMANY'\). Did you " "mean to set remote_side on the many-to-one side ?", - configure_mappers) + configure_mappers, + ) def test_m2o_explicit(self): T1, t1 = self.classes.T1, self.tables.t1 - mapper(T1, t1, properties={ - 't1s': relationship(T1, back_populates='parent', - remote_side=t1.c.id), - 'parent': relationship(T1, back_populates='t1s', - remote_side=t1.c.id) - }) + mapper( + T1, + t1, + properties={ + "t1s": relationship( + T1, back_populates="parent", remote_side=t1.c.id + ), + "parent": relationship( + T1, back_populates="t1s", remote_side=t1.c.id + ), + }, + ) # can't be sure of ordering here assert_raises_message( sa.exc.ArgumentError, r"both of the same direction symbol\('MANYTOONE'\). Did you " "mean to set remote_side on the many-to-one side ?", - configure_mappers) + configure_mappers, + ) class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table("a", metadata, - Column('id', Integer, primary_key=True) - ) - Table("b", metadata, - Column('id', Integer, primary_key=True), - Column('aid_1', Integer, ForeignKey('a.id')), - Column('aid_2', Integer, ForeignKey('a.id')), - ) - Table("atob", metadata, - Column('aid', Integer), - Column('bid', Integer), - ) - Table("atob_ambiguous", metadata, - Column('aid1', Integer, ForeignKey('a.id')), - Column('bid1', Integer, ForeignKey('b.id')), - Column('aid2', Integer, ForeignKey('a.id')), - Column('bid2', Integer, ForeignKey('b.id')), - ) + Table("a", metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid_1", Integer, ForeignKey("a.id")), + Column("aid_2", Integer, ForeignKey("a.id")), + ) + Table("atob", metadata, Column("aid", Integer), Column("bid", Integer)) + Table( + "atob_ambiguous", + metadata, + Column("aid1", Integer, ForeignKey("a.id")), + Column("bid1", Integer, ForeignKey("b.id")), + Column("aid2", Integer, ForeignKey("a.id")), + Column("bid2", Integer, ForeignKey("b.id")), + ) @classmethod def setup_classes(cls): @@ -3272,103 +3968,89 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): def test_ambiguous_fks_o2m(self): A, B = self.classes.A, self.classes.B a, b = self.tables.a, self.tables.b - mapper(A, a, properties={ - 'bs': relationship(B) - }) + mapper(A, a, properties={"bs": relationship(B)}) mapper(B, b) - self._assert_raises_ambig_join( - configure_mappers, - "A.bs", - None - ) + self._assert_raises_ambig_join(configure_mappers, "A.bs", None) def test_with_fks_o2m(self): A, B = self.classes.A, self.classes.B a, b = self.tables.a, self.tables.b - mapper(A, a, properties={ - 'bs': relationship(B, foreign_keys=b.c.aid_1) - }) + mapper( + A, a, properties={"bs": relationship(B, foreign_keys=b.c.aid_1)} + ) mapper(B, b) sa.orm.configure_mappers() - assert A.bs.property.primaryjoin.compare( - a.c.id == b.c.aid_1 - ) - eq_( - A.bs.property._calculated_foreign_keys, - 
set([b.c.aid_1]) - ) + assert A.bs.property.primaryjoin.compare(a.c.id == b.c.aid_1) + eq_(A.bs.property._calculated_foreign_keys, set([b.c.aid_1])) def test_with_pj_o2m(self): A, B = self.classes.A, self.classes.B a, b = self.tables.a, self.tables.b - mapper(A, a, properties={ - 'bs': relationship(B, primaryjoin=a.c.id == b.c.aid_1) - }) + mapper( + A, + a, + properties={ + "bs": relationship(B, primaryjoin=a.c.id == b.c.aid_1) + }, + ) mapper(B, b) sa.orm.configure_mappers() - assert A.bs.property.primaryjoin.compare( - a.c.id == b.c.aid_1 - ) - eq_( - A.bs.property._calculated_foreign_keys, - set([b.c.aid_1]) - ) + assert A.bs.property.primaryjoin.compare(a.c.id == b.c.aid_1) + eq_(A.bs.property._calculated_foreign_keys, set([b.c.aid_1])) def test_with_annotated_pj_o2m(self): A, B = self.classes.A, self.classes.B a, b = self.tables.a, self.tables.b - mapper(A, a, properties={ - 'bs': relationship(B, primaryjoin=a.c.id == foreign(b.c.aid_1)) - }) + mapper( + A, + a, + properties={ + "bs": relationship(B, primaryjoin=a.c.id == foreign(b.c.aid_1)) + }, + ) mapper(B, b) sa.orm.configure_mappers() - assert A.bs.property.primaryjoin.compare( - a.c.id == b.c.aid_1 - ) - eq_( - A.bs.property._calculated_foreign_keys, - set([b.c.aid_1]) - ) + assert A.bs.property.primaryjoin.compare(a.c.id == b.c.aid_1) + eq_(A.bs.property._calculated_foreign_keys, set([b.c.aid_1])) def test_no_fks_m2m(self): A, B = self.classes.A, self.classes.B a, b, a_to_b = self.tables.a, self.tables.b, self.tables.atob - mapper(A, a, properties={ - 'bs': relationship(B, secondary=a_to_b) - }) + mapper(A, a, properties={"bs": relationship(B, secondary=a_to_b)}) mapper(B, b) - self._assert_raises_no_join( - sa.orm.configure_mappers, - "A.bs", a_to_b, - ) + self._assert_raises_no_join(sa.orm.configure_mappers, "A.bs", a_to_b) def test_ambiguous_fks_m2m(self): A, B = self.classes.A, self.classes.B a, b, a_to_b = self.tables.a, self.tables.b, self.tables.atob_ambiguous - mapper(A, a, properties={ - 'bs': relationship(B, secondary=a_to_b) - }) + mapper(A, a, properties={"bs": relationship(B, secondary=a_to_b)}) mapper(B, b) self._assert_raises_ambig_join( - configure_mappers, - "A.bs", - "atob_ambiguous" + configure_mappers, "A.bs", "atob_ambiguous" ) def test_with_fks_m2m(self): A, B = self.classes.A, self.classes.B a, b, a_to_b = self.tables.a, self.tables.b, self.tables.atob_ambiguous - mapper(A, a, properties={ - 'bs': relationship(B, secondary=a_to_b, - foreign_keys=[a_to_b.c.aid1, a_to_b.c.bid1]) - }) + mapper( + A, + a, + properties={ + "bs": relationship( + B, + secondary=a_to_b, + foreign_keys=[a_to_b.c.aid1, a_to_b.c.bid1], + ) + }, + ) mapper(B, b) sa.orm.configure_mappers() -class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, - testing.AssertsExecutionResults): +class SecondaryNestedJoinTest( + fixtures.MappedTest, AssertsCompiledSQL, testing.AssertsExecutionResults +): """test support for a relationship where the 'secondary' table is a compound join(). @@ -3377,38 +4059,49 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, to ensure the join renders. 
""" - run_setup_mappers = 'once' - run_inserts = 'once' + + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + Column("b_id", ForeignKey("b.id")), + ) + Table( + "b", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30)), - Column('b_id', ForeignKey('b.id')) - ) - Table('b', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30)), - Column('d_id', ForeignKey('d.id')) - ) - Table('c', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30)), - Column('a_id', ForeignKey('a.id')), - Column('d_id', ForeignKey('d.id')) - ) - Table('d', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30)), - ) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + Column("d_id", ForeignKey("d.id")), + ) + Table( + "c", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + Column("a_id", ForeignKey("a.id")), + Column("d_id", ForeignKey("d.id")), + ) + Table( + "d", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) @classmethod def setup_classes(cls): @@ -3431,35 +4124,35 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, j = sa.join(b, d, b.c.d_id == d.c.id).join(c, c.c.d_id == d.c.id) # j = join(b, d, b.c.d_id == d.c.id).join(c, c.c.d_id == d.c.id) \ # .alias() - mapper(A, a, properties={ - "b": relationship(B), - "d": relationship( - D, secondary=j, - primaryjoin=and_(a.c.b_id == b.c.id, a.c.id == c.c.a_id), - secondaryjoin=d.c.id == b.c.d_id, - # primaryjoin=and_(a.c.b_id == j.c.b_id, a.c.id == j.c.c_a_id), - # secondaryjoin=d.c.id == j.c.b_d_id, - uselist=False, - viewonly=True - ) - }) - mapper(B, b, properties={ - "d": relationship(D) - }) - mapper(C, c, properties={ - "a": relationship(A), - "d": relationship(D) - }) + mapper( + A, + a, + properties={ + "b": relationship(B), + "d": relationship( + D, + secondary=j, + primaryjoin=and_(a.c.b_id == b.c.id, a.c.id == c.c.a_id), + secondaryjoin=d.c.id == b.c.d_id, + # primaryjoin=and_(a.c.b_id == j.c.b_id, a.c.id == j.c.c_a_id), + # secondaryjoin=d.c.id == j.c.b_d_id, + uselist=False, + viewonly=True, + ), + }, + ) + mapper(B, b, properties={"d": relationship(D)}) + mapper(C, c, properties={"a": relationship(A), "d": relationship(D)}) mapper(D, d) @classmethod def insert_data(cls): A, B, C, D = cls.classes.A, cls.classes.B, cls.classes.C, cls.classes.D sess = Session() - a1, a2, a3, a4 = A(name='a1'), A(name='a2'), A(name='a3'), A(name='a4') - b1, b2, b3, b4 = B(name='b1'), B(name='b2'), B(name='b3'), B(name='b4') - c1, c2, c3, c4 = C(name='c1'), C(name='c2'), C(name='c3'), C(name='c4') - d1, d2 = D(name='d1'), D(name='d2') + a1, a2, a3, a4 = A(name="a1"), A(name="a2"), A(name="a3"), A(name="a4") + b1, b2, b3, b4 = B(name="b1"), B(name="b2"), B(name="b3"), B(name="b4") + c1, c2, c3, c4 = C(name="c1"), C(name="c2"), C(name="c3"), C(name="c4") + d1, d2 = D(name="d1"), D(name="d2") a1.b = b1 a2.b = b2 @@ -3493,7 +4186,7 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, "FROM 
a JOIN (b AS b_1 JOIN d AS d_1 ON b_1.d_id = d_1.id " "JOIN c AS c_1 ON c_1.d_id = d_1.id) ON a.b_id = b_1.id " "AND a.id = c_1.a_id JOIN d ON d.id = b_1.d_id", - dialect="postgresql" + dialect="postgresql", ) def test_render_joinedload(self): @@ -3506,7 +4199,7 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, "(b AS b_1 JOIN d AS d_2 ON b_1.d_id = d_2.id JOIN c AS c_1 " "ON c_1.d_id = d_2.id JOIN d AS d_1 ON d_1.id = b_1.d_id) " "ON a.b_id = b_1.id AND a.id = c_1.a_id", - dialect="postgresql" + dialect="postgresql", ) def test_render_lazyload(self): @@ -3514,7 +4207,7 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, A, D = self.classes.A, self.classes.D sess = Session() - a1 = sess.query(A).filter(A.name == 'a1').first() + a1 = sess.query(A).filter(A.name == "a1").first() def go(): a1.d @@ -3531,16 +4224,11 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, "JOIN d ON b.d_id = d.id JOIN c ON c.d_id = d.id " "WHERE :param_1 = b.id AND :param_2 = c.a_id " "AND d.id = b.d_id", - {'param_1': a1.id, 'param_2': a1.id} - ) + {"param_1": a1.id, "param_2": a1.id}, + ), ) - mapping = { - "a1": "d1", - "a2": None, - "a3": None, - "a4": "d2" - } + mapping = {"a1": "d1", "a2": None, "a3": None, "a4": "d2"} def test_join(self): A, D = self.classes.A, self.classes.D @@ -3567,23 +4255,35 @@ class SecondaryNestedJoinTest(fixtures.MappedTest, AssertsCompiledSQL, class InvalidRelationshipEscalationTest( - _RelationshipErrors, fixtures.MappedTest): - + _RelationshipErrors, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer)) - Table('bars', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer)) - - Table('foos_with_fks', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer, ForeignKey('foos_with_fks.id'))) - Table('bars_with_fks', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer, ForeignKey('foos_with_fks.id'))) + Table( + "foos", + metadata, + Column("id", Integer, primary_key=True), + Column("fid", Integer), + ) + Table( + "bars", + metadata, + Column("id", Integer, primary_key=True), + Column("fid", Integer), + ) + + Table( + "foos_with_fks", + metadata, + Column("id", Integer, primary_key=True), + Column("fid", Integer, ForeignKey("foos_with_fks.id")), + ) + Table( + "bars_with_fks", + metadata, + Column("id", Integer, primary_key=True), + Column("fid", Integer, ForeignKey("foos_with_fks.id")), + ) @classmethod def setup_classes(cls): @@ -3594,65 +4294,77 @@ class InvalidRelationshipEscalationTest( pass def test_no_join(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'bars': relationship(Bar)}) + mapper(Foo, foos, properties={"bars": relationship(Bar)}) mapper(Bar, bars) - self._assert_raises_no_join(sa.orm.configure_mappers, - "Foo.bars", None - ) + self._assert_raises_no_join(sa.orm.configure_mappers, "Foo.bars", None) def test_no_join_self_ref(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'foos': relationship(Foo)}) + mapper(Foo, 
foos, properties={"foos": relationship(Foo)}) mapper(Bar, bars) - self._assert_raises_no_join( - configure_mappers, - "Foo.foos", - None - ) + self._assert_raises_no_join(configure_mappers, "Foo.foos", None) def test_no_equated(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id > bars.c.fid)}) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "bars": relationship(Bar, primaryjoin=foos.c.id > bars.c.fid) + }, + ) mapper(Bar, bars) self._assert_raises_no_relevant_fks( - configure_mappers, - "foos.id > bars.fid", "Foo.bars", "primary" + configure_mappers, "foos.id > bars.fid", "Foo.bars", "primary" ) def test_no_equated_fks(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id > bars.c.fid, - foreign_keys=bars.c.fid)}) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=foos.c.id > bars.c.fid, + foreign_keys=bars.c.fid, + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_equality( sa.orm.configure_mappers, - "foos.id > bars.fid", "Foo.bars", "primary" + "foos.id > bars.fid", + "Foo.bars", + "primary", ) def test_no_equated_wo_fks_works_on_relaxed(self): @@ -3661,7 +4373,8 @@ class InvalidRelationshipEscalationTest( self.classes.Foo, self.classes.Bar, self.tables.bars_with_fks, - self.tables.foos) + self.tables.foos, + ) # very unique - the join between parent/child # has no fks, but there is an fk join between two other @@ -3670,110 +4383,149 @@ class InvalidRelationshipEscalationTest( # in this case we don't get eq_pairs, but we hit the # "works if viewonly" rule. so here we add another clause regarding # "try foreign keys". 
- mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=and_( - bars_with_fks.c.fid == foos_with_fks.c.id, - foos_with_fks.c.id == foos.c.id, - ) - )}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=and_( + bars_with_fks.c.fid == foos_with_fks.c.id, + foos_with_fks.c.id == foos.c.id, + ), + ) + }, + ) mapper(Bar, bars_with_fks) self._assert_raises_no_equality( sa.orm.configure_mappers, "bars_with_fks.fid = foos_with_fks.id " "AND foos_with_fks.id = foos.id", - "Foo.bars", "primary" + "Foo.bars", + "primary", ) def test_ambiguous_fks(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id == bars.c.fid, - foreign_keys=[foos.c.id, bars.c.fid])}) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=foos.c.id == bars.c.fid, + foreign_keys=[foos.c.id, bars.c.fid], + ) + }, + ) mapper(Bar, bars) self._assert_raises_ambiguous_direction( - sa.orm.configure_mappers, - "Foo.bars" + sa.orm.configure_mappers, "Foo.bars" ) def test_ambiguous_remoteside_o2m(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id == bars.c.fid, - foreign_keys=[bars.c.fid], - remote_side=[foos.c.id, bars.c.fid], - viewonly=True - )}) - mapper(Bar, bars) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) - self._assert_raises_no_local_remote( - configure_mappers, - "Foo.bars", + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=foos.c.id == bars.c.fid, + foreign_keys=[bars.c.fid], + remote_side=[foos.c.id, bars.c.fid], + viewonly=True, + ) + }, ) + mapper(Bar, bars) + + self._assert_raises_no_local_remote(configure_mappers, "Foo.bars") def test_ambiguous_remoteside_m2o(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id == bars.c.fid, - foreign_keys=[foos.c.id], - remote_side=[foos.c.id, bars.c.fid], - viewonly=True - )}) - mapper(Bar, bars) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) - self._assert_raises_no_local_remote( - configure_mappers, - "Foo.bars", + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + primaryjoin=foos.c.id == bars.c.fid, + foreign_keys=[foos.c.id], + remote_side=[foos.c.id, bars.c.fid], + viewonly=True, + ) + }, ) + mapper(Bar, bars) + + self._assert_raises_no_local_remote(configure_mappers, "Foo.bars") def test_no_equated_self_ref_no_fks(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'foos': relationship(Foo, - primaryjoin=foos.c.id > foos.c.fid)}) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "foos": relationship(Foo, primaryjoin=foos.c.id > foos.c.fid) + }, + ) mapper(Bar, bars) self._assert_raises_no_relevant_fks( - configure_mappers, - "foos.id > foos.fid", "Foo.foos", "primary" + 
configure_mappers, "foos.id > foos.fid", "Foo.foos", "primary" ) def test_no_equated_self_ref_no_equality(self): - bars, Foo, Bar, foos = (self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'foos': relationship(Foo, - primaryjoin=foos.c.id > foos.c.fid, - foreign_keys=[foos.c.fid])}) + bars, Foo, Bar, foos = ( + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "foos": relationship( + Foo, + primaryjoin=foos.c.id > foos.c.fid, + foreign_keys=[foos.c.fid], + ) + }, + ) mapper(Bar, bars) - self._assert_raises_no_equality(configure_mappers, - "foos.id > foos.fid", "Foo.foos", - "primary") + self._assert_raises_no_equality( + configure_mappers, "foos.id > foos.fid", "Foo.foos", "primary" + ) def test_no_equated_viewonly(self): bars, Bar, bars_with_fks, foos_with_fks, Foo, foos = ( @@ -3782,25 +4534,39 @@ class InvalidRelationshipEscalationTest( self.tables.bars_with_fks, self.tables.foos_with_fks, self.classes.Foo, - self.tables.foos) + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id > bars.c.fid, - viewonly=True)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, primaryjoin=foos.c.id > bars.c.fid, viewonly=True + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_relevant_fks( sa.orm.configure_mappers, - "foos.id > bars.fid", "Foo.bars", "primary" + "foos.id > bars.fid", + "Foo.bars", + "primary", ) sa.orm.clear_mappers() - mapper(Foo, foos_with_fks, properties={ - 'bars': relationship( - Bar, - primaryjoin=foos_with_fks.c.id > bars_with_fks.c.fid, - viewonly=True)}) + mapper( + Foo, + foos_with_fks, + properties={ + "bars": relationship( + Bar, + primaryjoin=foos_with_fks.c.id > bars_with_fks.c.fid, + viewonly=True, + ) + }, + ) mapper(Bar, bars_with_fks) sa.orm.configure_mappers() @@ -3811,36 +4577,57 @@ class InvalidRelationshipEscalationTest( self.tables.bars_with_fks, self.tables.foos_with_fks, self.classes.Foo, - self.tables.foos) + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'foos': relationship(Foo, - primaryjoin=foos.c.id > foos.c.fid, - viewonly=True)}) + mapper( + Foo, + foos, + properties={ + "foos": relationship( + Foo, primaryjoin=foos.c.id > foos.c.fid, viewonly=True + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_relevant_fks( sa.orm.configure_mappers, - "foos.id > foos.fid", "Foo.foos", "primary" + "foos.id > foos.fid", + "Foo.foos", + "primary", ) sa.orm.clear_mappers() - mapper(Foo, foos_with_fks, properties={ - 'foos': relationship( - Foo, - primaryjoin=foos_with_fks.c.id > foos_with_fks.c.fid, - viewonly=True)}) + mapper( + Foo, + foos_with_fks, + properties={ + "foos": relationship( + Foo, + primaryjoin=foos_with_fks.c.id > foos_with_fks.c.fid, + viewonly=True, + ) + }, + ) mapper(Bar, bars_with_fks) sa.orm.configure_mappers() def test_no_equated_self_ref_viewonly_fks(self): Foo, foos = self.classes.Foo, self.tables.foos - mapper(Foo, foos, properties={ - 'foos': relationship(Foo, - primaryjoin=foos.c.id > foos.c.fid, - viewonly=True, - foreign_keys=[foos.c.fid])}) + mapper( + Foo, + foos, + properties={ + "foos": relationship( + Foo, + primaryjoin=foos.c.id > foos.c.fid, + viewonly=True, + foreign_keys=[foos.c.fid], + ) + }, + ) sa.orm.configure_mappers() eq_(Foo.foos.property.local_remote_pairs, [(foos.c.id, foos.c.fid)]) @@ -3852,79 +4639,102 @@ class InvalidRelationshipEscalationTest( self.tables.bars_with_fks, 
self.tables.foos_with_fks, self.classes.Foo, - self.tables.foos) + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - primaryjoin=foos.c.id == bars.c.fid)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship(Bar, primaryjoin=foos.c.id == bars.c.fid) + }, + ) mapper(Bar, bars) self._assert_raises_no_relevant_fks( - configure_mappers, - "foos.id = bars.fid", "Foo.bars", "primary" + configure_mappers, "foos.id = bars.fid", "Foo.bars", "primary" ) sa.orm.clear_mappers() - mapper(Foo, foos_with_fks, properties={ - 'bars': relationship( - Bar, - primaryjoin=foos_with_fks.c.id == bars_with_fks.c.fid)}) + mapper( + Foo, + foos_with_fks, + properties={ + "bars": relationship( + Bar, primaryjoin=foos_with_fks.c.id == bars_with_fks.c.fid + ) + }, + ) mapper(Bar, bars_with_fks) sa.orm.configure_mappers() def test_equated_self_ref(self): Foo, foos = self.classes.Foo, self.tables.foos - mapper(Foo, foos, properties={ - 'foos': relationship(Foo, - primaryjoin=foos.c.id == foos.c.fid)}) + mapper( + Foo, + foos, + properties={ + "foos": relationship(Foo, primaryjoin=foos.c.id == foos.c.fid) + }, + ) self._assert_raises_no_relevant_fks( - configure_mappers, - "foos.id = foos.fid", "Foo.foos", "primary" + configure_mappers, "foos.id = foos.fid", "Foo.foos", "primary" ) def test_equated_self_ref_wrong_fks(self): - bars, Foo, foos = (self.tables.bars, - self.classes.Foo, - self.tables.foos) + bars, Foo, foos = ( + self.tables.bars, + self.classes.Foo, + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'foos': relationship(Foo, - primaryjoin=foos.c.id == foos.c.fid, - foreign_keys=[bars.c.id])}) + mapper( + Foo, + foos, + properties={ + "foos": relationship( + Foo, + primaryjoin=foos.c.id == foos.c.fid, + foreign_keys=[bars.c.id], + ) + }, + ) self._assert_raises_no_relevant_fks( - configure_mappers, - "foos.id = foos.fid", "Foo.foos", "primary" + configure_mappers, "foos.id = foos.fid", "Foo.foos", "primary" ) class InvalidRelationshipEscalationTestM2M( - _RelationshipErrors, fixtures.MappedTest): - + _RelationshipErrors, fixtures.MappedTest +): @classmethod def define_tables(cls, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True)) - Table('foobars', metadata, - Column('fid', Integer), Column('bid', Integer)) - Table('bars', metadata, - Column('id', Integer, primary_key=True)) - - Table('foobars_with_fks', metadata, - Column('fid', Integer, ForeignKey('foos.id')), - Column('bid', Integer, ForeignKey('bars.id')) - ) - - Table('foobars_with_many_columns', metadata, - Column('fid', Integer), - Column('bid', Integer), - Column('fid1', Integer), - Column('bid1', Integer), - Column('fid2', Integer), - Column('bid2', Integer), - ) + Table("foos", metadata, Column("id", Integer, primary_key=True)) + Table( + "foobars", metadata, Column("fid", Integer), Column("bid", Integer) + ) + Table("bars", metadata, Column("id", Integer, primary_key=True)) + + Table( + "foobars_with_fks", + metadata, + Column("fid", Integer, ForeignKey("foos.id")), + Column("bid", Integer, ForeignKey("bars.id")), + ) + + Table( + "foobars_with_many_columns", + metadata, + Column("fid", Integer), + Column("bid", Integer), + Column("fid1", Integer), + Column("bid1", Integer), + Column("fid2", Integer), + Column("bid2", Integer), + ) @classmethod def setup_classes(cls): @@ -3935,40 +4745,46 @@ class InvalidRelationshipEscalationTestM2M( pass def test_no_join(self): - foobars, bars, Foo, Bar, foos = (self.tables.foobars, - self.tables.bars, - self.classes.Foo, - 
self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, secondary=foobars)}) - mapper(Bar, bars) + foobars, bars, Foo, Bar, foos = ( + self.tables.foobars, + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) - self._assert_raises_no_join( - configure_mappers, - "Foo.bars", - "foobars" + mapper( + Foo, + foos, + properties={"bars": relationship(Bar, secondary=foobars)}, ) + mapper(Bar, bars) + + self._assert_raises_no_join(configure_mappers, "Foo.bars", "foobars") def test_no_secondaryjoin(self): - foobars, bars, Foo, Bar, foos = (self.tables.foobars, - self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - secondary=foobars, - primaryjoin=foos.c.id > foobars.c.fid)}) - mapper(Bar, bars) + foobars, bars, Foo, Bar, foos = ( + self.tables.foobars, + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) - self._assert_raises_no_join( - configure_mappers, - "Foo.bars", - "foobars" + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars, + primaryjoin=foos.c.id > foobars.c.fid, + ) + }, ) + mapper(Bar, bars) + + self._assert_raises_no_join(configure_mappers, "Foo.bars", "foobars") def test_no_fks(self): foobars_with_many_columns, bars, Bar, foobars, Foo, foos = ( @@ -3977,41 +4793,51 @@ class InvalidRelationshipEscalationTestM2M( self.classes.Bar, self.tables.foobars, self.classes.Foo, - self.tables.foos) + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, secondary=foobars, - primaryjoin=foos.c.id == foobars.c.fid, - secondaryjoin=foobars.c.bid == bars.c.id)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars, + primaryjoin=foos.c.id == foobars.c.fid, + secondaryjoin=foobars.c.bid == bars.c.id, + ) + }, + ) mapper(Bar, bars) sa.orm.configure_mappers() - eq_( - Foo.bars.property.synchronize_pairs, - [(foos.c.id, foobars.c.fid)] - ) + eq_(Foo.bars.property.synchronize_pairs, [(foos.c.id, foobars.c.fid)]) eq_( Foo.bars.property.secondary_synchronize_pairs, - [(bars.c.id, foobars.c.bid)] + [(bars.c.id, foobars.c.bid)], ) sa.orm.clear_mappers() - mapper(Foo, foos, properties={ - 'bars': relationship( - Bar, - secondary=foobars_with_many_columns, - primaryjoin=foos.c.id == - foobars_with_many_columns.c.fid, - secondaryjoin=foobars_with_many_columns.c.bid == - bars.c.id)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars_with_many_columns, + primaryjoin=foos.c.id == foobars_with_many_columns.c.fid, + secondaryjoin=foobars_with_many_columns.c.bid == bars.c.id, + ) + }, + ) mapper(Bar, bars) sa.orm.configure_mappers() eq_( Foo.bars.property.synchronize_pairs, - [(foos.c.id, foobars_with_many_columns.c.fid)] + [(foos.c.id, foobars_with_many_columns.c.fid)], ) eq_( Foo.bars.property.secondary_synchronize_pairs, - [(bars.c.id, foobars_with_many_columns.c.bid)] + [(bars.c.id, foobars_with_many_columns.c.bid)], ) def test_local_col_setup(self): @@ -4020,24 +4846,24 @@ class InvalidRelationshipEscalationTestM2M( self.tables.bars, self.classes.Bar, self.classes.Foo, - self.tables.foos) + self.tables.foos, + ) # ensure m2m backref is set up with correct annotations # [ticket:2578] - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, secondary=foobars_with_fks, - backref="foos") - }) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, 
secondary=foobars_with_fks, backref="foos" + ) + }, + ) mapper(Bar, bars) sa.orm.configure_mappers() - eq_( - Foo.bars.property._join_condition.local_columns, - set([foos.c.id]) - ) - eq_( - Bar.foos.property._join_condition.local_columns, - set([bars.c.id]) - ) + eq_(Foo.bars.property._join_condition.local_columns, set([foos.c.id])) + eq_(Bar.foos.property._join_condition.local_columns, set([bars.c.id])) def test_bad_primaryjoin(self): foobars_with_fks, bars, Bar, foobars, Foo, foos = ( @@ -4046,87 +4872,124 @@ class InvalidRelationshipEscalationTestM2M( self.classes.Bar, self.tables.foobars, self.classes.Foo, - self.tables.foos) + self.tables.foos, + ) - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - secondary=foobars, - primaryjoin=foos.c.id > foobars.c.fid, - secondaryjoin=foobars.c.bid <= bars.c.id)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars, + primaryjoin=foos.c.id > foobars.c.fid, + secondaryjoin=foobars.c.bid <= bars.c.id, + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_equality( - configure_mappers, - 'foos.id > foobars.fid', - "Foo.bars", - "primary") + configure_mappers, "foos.id > foobars.fid", "Foo.bars", "primary" + ) sa.orm.clear_mappers() - mapper(Foo, foos, properties={ - 'bars': relationship( - Bar, - secondary=foobars_with_fks, - primaryjoin=foos.c.id > foobars_with_fks.c.fid, - secondaryjoin=foobars_with_fks.c.bid <= bars.c.id)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars_with_fks, + primaryjoin=foos.c.id > foobars_with_fks.c.fid, + secondaryjoin=foobars_with_fks.c.bid <= bars.c.id, + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_equality( configure_mappers, - 'foos.id > foobars_with_fks.fid', + "foos.id > foobars_with_fks.fid", "Foo.bars", - "primary") + "primary", + ) sa.orm.clear_mappers() - mapper(Foo, foos, properties={ - 'bars': relationship( - Bar, - secondary=foobars_with_fks, - primaryjoin=foos.c.id > foobars_with_fks.c.fid, - secondaryjoin=foobars_with_fks.c.bid <= bars.c.id, - viewonly=True)}) + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars_with_fks, + primaryjoin=foos.c.id > foobars_with_fks.c.fid, + secondaryjoin=foobars_with_fks.c.bid <= bars.c.id, + viewonly=True, + ) + }, + ) mapper(Bar, bars) sa.orm.configure_mappers() def test_bad_secondaryjoin(self): - foobars, bars, Foo, Bar, foos = (self.tables.foobars, - self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - secondary=foobars, - primaryjoin=foos.c.id == foobars.c.fid, - secondaryjoin=foobars.c.bid <= bars.c.id, - foreign_keys=[foobars.c.fid])}) + foobars, bars, Foo, Bar, foos = ( + self.tables.foobars, + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars, + primaryjoin=foos.c.id == foobars.c.fid, + secondaryjoin=foobars.c.bid <= bars.c.id, + foreign_keys=[foobars.c.fid], + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_relevant_fks( configure_mappers, "foobars.bid <= bars.id", "Foo.bars", - "secondary" + "secondary", ) def test_no_equated_secondaryjoin(self): - foobars, bars, Foo, Bar, foos = (self.tables.foobars, - self.tables.bars, - self.classes.Foo, - self.classes.Bar, - self.tables.foos) - - mapper(Foo, foos, properties={ - 'bars': relationship(Bar, - secondary=foobars, - primaryjoin=foos.c.id == foobars.c.fid, - 
secondaryjoin=foobars.c.bid <= bars.c.id, - foreign_keys=[foobars.c.fid, foobars.c.bid])}) + foobars, bars, Foo, Bar, foos = ( + self.tables.foobars, + self.tables.bars, + self.classes.Foo, + self.classes.Bar, + self.tables.foos, + ) + + mapper( + Foo, + foos, + properties={ + "bars": relationship( + Bar, + secondary=foobars, + primaryjoin=foos.c.id == foobars.c.fid, + secondaryjoin=foobars.c.bid <= bars.c.id, + foreign_keys=[foobars.c.fid, foobars.c.bid], + ) + }, + ) mapper(Bar, bars) self._assert_raises_no_equality( configure_mappers, "foobars.bid <= bars.id", "Foo.bars", - "secondary" + "secondary", ) @@ -4145,40 +5008,45 @@ class ActiveHistoryFlagTest(_fixtures.FixtureTest): setattr(obj, attrname, newvalue) eq_( - attributes.get_history(obj, attrname), - ([newvalue, ], (), [oldvalue, ]) + attributes.get_history(obj, attrname), ([newvalue], (), [oldvalue]) ) def test_column_property_flag(self): User, users = self.classes.User, self.tables.users - mapper(User, users, properties={ - 'name': column_property(users.c.name, - active_history=True) - }) - u1 = User(name='jack') - self._test_attribute(u1, 'name', 'ed') + mapper( + User, + users, + properties={ + "name": column_property(users.c.name, active_history=True) + }, + ) + u1 = User(name="jack") + self._test_attribute(u1, "name", "ed") def test_relationship_property_flag(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(Address, addresses, properties={ - 'user': relationship(User, active_history=True) - }) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper( + Address, + addresses, + properties={"user": relationship(User, active_history=True)}, + ) mapper(User, users) - u1 = User(name='jack') - u2 = User(name='ed') - a1 = Address(email_address='a1', user=u1) - self._test_attribute(a1, 'user', u2) + u1 = User(name="jack") + u2 = User(name="ed") + a1 = Address(email_address="a1", user=u1) + self._test_attribute(a1, "user", u2) def test_composite_property_flag(self): Order, orders = self.classes.Order, self.tables.orders class MyComposite(object): - def __init__(self, description, isopen): self.description = description self.isopen = isopen @@ -4187,51 +5055,67 @@ class ActiveHistoryFlagTest(_fixtures.FixtureTest): return [self.description, self.isopen] def __eq__(self, other): - return isinstance(other, MyComposite) and \ - other.description == self.description - mapper(Order, orders, properties={ - 'composite': composite( - MyComposite, - orders.c.description, - orders.c.isopen, - active_history=True) - }) - o1 = Order(composite=MyComposite('foo', 1)) - self._test_attribute(o1, "composite", MyComposite('bar', 1)) + return ( + isinstance(other, MyComposite) + and other.description == self.description + ) + + mapper( + Order, + orders, + properties={ + "composite": composite( + MyComposite, + orders.c.description, + orders.c.isopen, + active_history=True, + ) + }, + ) + o1 = Order(composite=MyComposite("foo", 1)) + self._test_attribute(o1, "composite", MyComposite("bar", 1)) class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): run_inserts = None - def _run_test(self, detached, raiseload, backref, active_history, - delete): + def _run_test(self, detached, raiseload, backref, active_history, delete): if delete: assert not backref, "delete and backref are mutually exclusive" - Address, addresses, users, User = (self.classes.Address, - 
self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) opts = {} if active_history: - opts['active_history'] = True + opts["active_history"] = True if raiseload: - opts['lazy'] = "raise" + opts["lazy"] = "raise" - mapper(Address, addresses, properties={ - 'user': relationship( - User, back_populates="addresses", **opts) - }) - mapper(User, users, properties={ - "addresses": relationship(Address, back_populates="user") - }) + mapper( + Address, + addresses, + properties={ + "user": relationship(User, back_populates="addresses", **opts) + }, + ) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, back_populates="user") + }, + ) s = Session() - a1 = Address(email_address='a1') - u1 = User(name='u1', addresses=[a1]) + a1 = Address(email_address="a1") + u1 = User(name="u1", addresses=[a1]) s.add_all([a1, u1]) s.commit() @@ -4249,7 +5133,7 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): assert_raises_message( exc.InvalidRequestError, "'Address.user' is not available due to lazy='raise'", - go + go, ) return elif detached: @@ -4257,7 +5141,7 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): orm_exc.DetachedInstanceError, "lazy load operation of attribute 'user' " "cannot proceed", - go + go, ) return go() @@ -4266,9 +5150,12 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): s.expunge(a1) if delete: + def go(): del a1.user + else: + def go(): a1.user = None @@ -4277,7 +5164,7 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): assert_raises_message( exc.InvalidRequestError, "'Address.user' is not available due to lazy='raise'", - go + go, ) return elif detached: @@ -4285,7 +5172,7 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): orm_exc.DetachedInstanceError, "lazy load operation of attribute 'user' " "cannot proceed", - go + go, ) return go() @@ -4297,144 +5184,246 @@ class InactiveHistoryNoRaiseTest(_fixtures.FixtureTest): def test_replace_m2o(self): self._run_test( - detached=False, raiseload=False, - backref=False, delete=False, active_history=False) + detached=False, + raiseload=False, + backref=False, + delete=False, + active_history=False, + ) def test_replace_m2o_detached(self): self._run_test( - detached=True, raiseload=False, - backref=False, delete=False, active_history=False) + detached=True, + raiseload=False, + backref=False, + delete=False, + active_history=False, + ) def test_replace_m2o_raiseload(self): self._run_test( - detached=False, raiseload=True, - backref=False, delete=False, active_history=False) + detached=False, + raiseload=True, + backref=False, + delete=False, + active_history=False, + ) def test_replace_m2o_detached_raiseload(self): self._run_test( - detached=True, raiseload=True, - backref=False, delete=False, active_history=False) + detached=True, + raiseload=True, + backref=False, + delete=False, + active_history=False, + ) def test_replace_m2o_backref(self): self._run_test( - detached=False, raiseload=False, - backref=True, delete=False, active_history=False) + detached=False, + raiseload=False, + backref=True, + delete=False, + active_history=False, + ) def test_replace_m2o_detached_backref(self): self._run_test( - detached=True, raiseload=False, - backref=True, delete=False, active_history=False) + detached=True, + raiseload=False, + backref=True, + delete=False, + active_history=False, + ) def test_replace_m2o_raiseload_backref(self): 
self._run_test( - detached=False, raiseload=True, - backref=True, delete=False, active_history=False) + detached=False, + raiseload=True, + backref=True, + delete=False, + active_history=False, + ) def test_replace_m2o_detached_raiseload_backref(self): self._run_test( - detached=True, raiseload=True, - backref=True, delete=False, active_history=False) + detached=True, + raiseload=True, + backref=True, + delete=False, + active_history=False, + ) def test_replace_m2o_activehistory(self): self._run_test( - detached=False, raiseload=False, - backref=False, delete=False, active_history=True) + detached=False, + raiseload=False, + backref=False, + delete=False, + active_history=True, + ) def test_replace_m2o_detached_activehistory(self): self._run_test( - detached=True, raiseload=False, - backref=False, delete=False, active_history=True) + detached=True, + raiseload=False, + backref=False, + delete=False, + active_history=True, + ) def test_replace_m2o_raiseload_activehistory(self): self._run_test( - detached=False, raiseload=True, - backref=False, delete=False, active_history=True) + detached=False, + raiseload=True, + backref=False, + delete=False, + active_history=True, + ) def test_replace_m2o_detached_raiseload_activehistory(self): self._run_test( - detached=True, raiseload=True, - backref=False, delete=False, active_history=True) + detached=True, + raiseload=True, + backref=False, + delete=False, + active_history=True, + ) def test_replace_m2o_backref_activehistory(self): self._run_test( - detached=False, raiseload=False, - backref=True, delete=False, active_history=True) + detached=False, + raiseload=False, + backref=True, + delete=False, + active_history=True, + ) def test_replace_m2o_detached_backref_activehistory(self): self._run_test( - detached=True, raiseload=False, - backref=True, delete=False, active_history=True) + detached=True, + raiseload=False, + backref=True, + delete=False, + active_history=True, + ) def test_replace_m2o_raiseload_backref_activehistory(self): self._run_test( - detached=False, raiseload=True, - backref=True, delete=False, active_history=True) + detached=False, + raiseload=True, + backref=True, + delete=False, + active_history=True, + ) def test_replace_m2o_detached_raiseload_backref_activehistory(self): self._run_test( - detached=True, raiseload=True, - backref=True, delete=False, active_history=True) + detached=True, + raiseload=True, + backref=True, + delete=False, + active_history=True, + ) def test_delete_m2o(self): self._run_test( - detached=False, raiseload=False, - backref=False, delete=True, active_history=False) + detached=False, + raiseload=False, + backref=False, + delete=True, + active_history=False, + ) def test_delete_m2o_detached(self): self._run_test( - detached=True, raiseload=False, - backref=False, delete=True, active_history=False) + detached=True, + raiseload=False, + backref=False, + delete=True, + active_history=False, + ) def test_delete_m2o_raiseload(self): self._run_test( - detached=False, raiseload=True, - backref=False, delete=True, active_history=False) + detached=False, + raiseload=True, + backref=False, + delete=True, + active_history=False, + ) def test_delete_m2o_detached_raiseload(self): self._run_test( - detached=True, raiseload=True, - backref=False, delete=True, active_history=False) + detached=True, + raiseload=True, + backref=False, + delete=True, + active_history=False, + ) def test_delete_m2o_activehistory(self): self._run_test( - detached=False, raiseload=False, - backref=False, delete=True, active_history=True) + 
detached=False, + raiseload=False, + backref=False, + delete=True, + active_history=True, + ) def test_delete_m2o_detached_activehistory(self): self._run_test( - detached=True, raiseload=False, - backref=False, delete=True, active_history=True) + detached=True, + raiseload=False, + backref=False, + delete=True, + active_history=True, + ) def test_delete_m2o_raiseload_activehistory(self): self._run_test( - detached=False, raiseload=True, - backref=False, delete=True, active_history=True) + detached=False, + raiseload=True, + backref=False, + delete=True, + active_history=True, + ) def test_delete_m2o_detached_raiseload_activehistory(self): self._run_test( - detached=True, raiseload=True, - backref=False, delete=True, active_history=True) + detached=True, + raiseload=True, + backref=False, + delete=True, + active_history=True, + ) class RelationDeprecationTest(fixtures.MappedTest): """test usage of the old 'relation' function.""" - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('users_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(64))) + Table( + "users_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(64)), + ) - Table('addresses_table', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('users_table.id')), - Column('email_address', String(128)), - Column('purpose', String(16)), - Column('bounces', Integer, default=0)) + Table( + "addresses_table", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", Integer, ForeignKey("users_table.id")), + Column("email_address", String(128)), + Column("purpose", String(16)), + Column("bounces", Integer, default=0), + ) @classmethod def setup_classes(cls): @@ -4448,32 +5437,38 @@ class RelationDeprecationTest(fixtures.MappedTest): def fixtures(cls): return dict( users_table=( - ('id', 'name'), - (1, 'jack'), - (2, 'ed'), - (3, 'fred'), - (4, 'chuck')), - + ("id", "name"), + (1, "jack"), + (2, "ed"), + (3, "fred"), + (4, "chuck"), + ), addresses_table=( - ('id', 'user_id', 'email_address', 'purpose', 'bounces'), - (1, 1, 'jack@jack.home', 'Personal', 0), - (2, 1, 'jack@jack.bizz', 'Work', 1), - (3, 2, 'ed@foo.bar', 'Personal', 0), - (4, 3, 'fred@the.fred', 'Personal', 10))) + ("id", "user_id", "email_address", "purpose", "bounces"), + (1, 1, "jack@jack.home", "Personal", 0), + (2, 1, "jack@jack.bizz", "Work", 1), + (3, 2, "ed@foo.bar", "Personal", 0), + (4, 3, "fred@the.fred", "Personal", 10), + ), + ) def test_relation(self): addresses_table, User, users_table, Address = ( self.tables.addresses_table, self.classes.User, self.tables.users_table, - self.classes.Address) + self.classes.Address, + ) - mapper(User, users_table, properties=dict( - addresses=relation(Address, backref='user'), - )) + mapper( + User, + users_table, + properties=dict(addresses=relation(Address, backref="user")), + ) mapper(Address, addresses_table) session = create_session() - session.query(User).filter(User.addresses.any( - Address.email_address == 'ed@foo.bar')).one() + session.query(User).filter( + User.addresses.any(Address.email_address == "ed@foo.bar") + ).one() diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 507d98f82c..33abc1496e 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -11,17 +11,24 @@ from sqlalchemy.testing.mock import Mock class ScopedSessionTest(fixtures.MappedTest): - @classmethod def define_tables(cls, 
metadata): - Table('table1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30))) - Table('table2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('someid', None, ForeignKey('table1.id'))) + Table( + "table1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + ) + Table( + "table2", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("someid", None, ForeignKey("table1.id")), + ) def test_basic(self): table2, table1 = self.tables.table2, self.tables.table1 @@ -38,8 +45,11 @@ class ScopedSessionTest(fixtures.MappedTest): query = Session.query_property() custom_query = Session.query_property(query_cls=CustomQuery) - mapper(SomeObject, table1, properties={ - 'options': relationship(SomeOtherObject)}) + mapper( + SomeObject, + table1, + properties={"options": relationship(SomeOtherObject)}, + ) mapper(SomeOtherObject, table2) s = SomeObject(id=1, data="hello") @@ -50,15 +60,24 @@ class ScopedSessionTest(fixtures.MappedTest): Session.refresh(sso) Session.remove() - eq_(SomeObject(id=1, data="hello", - options=[SomeOtherObject(someid=1)]), - Session.query(SomeObject).one()) - eq_(SomeObject(id=1, data="hello", - options=[SomeOtherObject(someid=1)]), - SomeObject.query.one()) - eq_(SomeOtherObject(someid=1), + eq_( + SomeObject( + id=1, data="hello", options=[SomeOtherObject(someid=1)] + ), + Session.query(SomeObject).one(), + ) + eq_( + SomeObject( + id=1, data="hello", options=[SomeOtherObject(someid=1)] + ), + SomeObject.query.one(), + ) + eq_( + SomeOtherObject(someid=1), SomeOtherObject.query.filter( - SomeOtherObject.someid == sso.someid).one()) + SomeOtherObject.someid == sso.someid + ).one(), + ) assert isinstance(SomeOtherObject.query, query.Query) assert not isinstance(SomeOtherObject.query, CustomQuery) assert isinstance(SomeOtherObject.custom_query, query.Query) @@ -70,13 +89,15 @@ class ScopedSessionTest(fixtures.MappedTest): assert_raises_message( sa.exc.InvalidRequestError, "Scoped session is already present", - Session, bind=testing.db + Session, + bind=testing.db, ) assert_raises_message( sa.exc.SAWarning, "At least one scoped session is already present. 
", - Session.configure, bind=testing.db + Session.configure, + bind=testing.db, ) def test_call_with_kwargs(self): @@ -94,7 +115,8 @@ class ScopedSessionTest(fixtures.MappedTest): assert_raises_message( sa.exc.InvalidRequestError, "Scoped session is already present", - Session, autocommit=True + Session, + autocommit=True, ) mock_scope_func.return_value = 1 diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py index 28f12ac95b..50eaaa37b1 100644 --- a/test/orm/test_selectable.py +++ b/test/orm/test_selectable.py @@ -11,14 +11,19 @@ from sqlalchemy.testing import fixtures # TODO: more tests mapping to selects + class SelectableNoFromsTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - Table('common', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', Integer), - Column('extra', String(45))) + Table( + "common", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", Integer), + Column("extra", String(45)), + ) @classmethod def setup_classes(cls): @@ -35,7 +40,7 @@ class SelectableNoFromsTest(fixtures.MappedTest, AssertsCompiledSQL): Session().query(Subset), "SELECT anon_1.x AS anon_1_x, anon_1.y AS anon_1_y, " "anon_1.z AS anon_1_z FROM (SELECT x, y, z) AS anon_1", - use_default_dialect=True + use_default_dialect=True, ) def test_no_table_needs_pl(self): @@ -45,15 +50,18 @@ class SelectableNoFromsTest(fixtures.MappedTest, AssertsCompiledSQL): assert_raises_message( sa.exc.ArgumentError, "could not assemble any primary key columns", - mapper, Subset, selectable + mapper, + Subset, + selectable, ) def test_no_selects(self): Subset, common = self.classes.Subset, self.tables.common subset_select = select([common.c.id, common.c.data]) - assert_raises(sa.exc.InvalidRequestError, - mapper, Subset, subset_select) + assert_raises( + sa.exc.InvalidRequestError, mapper, Subset, subset_select + ) def test_basic(self): Subset, common = self.classes.Subset, self.tables.common @@ -70,5 +78,7 @@ class SelectableNoFromsTest(fixtures.MappedTest, AssertsCompiledSQL): eq_(sess.query(Subset).filter(Subset.data != 1).first(), None) subset_select = sa.orm.class_mapper(Subset).mapped_table - eq_(sess.query(Subset).filter(subset_select.c.data == 1).one(), - Subset(data=1)) + eq_( + sess.query(Subset).filter(subset_select.c.data == 1).one(), + Subset(data=1), + ) diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index 7891f71f0e..78a56ba319 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -2,12 +2,22 @@ from sqlalchemy.testing import eq_, is_, is_not_, is_true from sqlalchemy import testing from sqlalchemy.testing.schema import Table, Column from sqlalchemy import Integer, String, ForeignKey, bindparam -from sqlalchemy.orm import selectinload, selectinload_all, \ - mapper, relationship, clear_mappers, create_session, \ - aliased, joinedload, deferred, undefer,\ - Session, subqueryload, defaultload -from sqlalchemy.testing import assert_raises, \ - assert_raises_message +from sqlalchemy.orm import ( + selectinload, + selectinload_all, + mapper, + relationship, + clear_mappers, + create_session, + aliased, + joinedload, + deferred, + undefer, + Session, + subqueryload, + defaultload, +) +from sqlalchemy.testing import assert_raises, assert_raises_message from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing import fixtures from sqlalchemy.testing 
import mock @@ -16,43 +26,60 @@ import sqlalchemy as sa from sqlalchemy.orm import with_polymorphic -from .inheritance._poly_fixtures import _Polymorphic, Person, Engineer, \ - Paperwork, Machine, MachineType, Company +from .inheritance._poly_fixtures import ( + _Polymorphic, + Person, + Engineer, + Paperwork, + Machine, + MachineType, + Company, +) class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): - run_inserts = 'once' + run_inserts = "once" run_deletes = None def test_basic(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - order_by=Address.id) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) sess = create_session() q = sess.query(User).options(selectinload(User.addresses)) def go(): eq_( - [User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])], - q.filter(User.id == 7).all() + [ + User( + id=7, + addresses=[ + Address(id=1, email_address="jack@bean.com") + ], + ) + ], + q.filter(User.id == 7).all(), ) self.assert_sql_count(testing.db, go, 2) def go(): - eq_( - self.static.user_address_result, - q.order_by(User.id).all() - ) + eq_(self.static.user_address_result, q.order_by(User.id).all()) + self.assert_sql_count(testing.db, go, 2) def test_from_aliased(self): @@ -62,17 +89,24 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.tables.dingalings, self.classes.Address, - self.tables.addresses) + self.tables.addresses, + ) mapper(Dingaling, dingalings) - mapper(Address, addresses, properties={ - 'dingalings': relationship(Dingaling, order_by=Dingaling.id) - }) - mapper(User, users, properties={ - 'addresses': relationship( - Address, - order_by=Address.id) - }) + mapper( + Address, + addresses, + properties={ + "dingalings": relationship(Dingaling, order_by=Dingaling.id) + }, + ) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, order_by=Address.id) + }, + ) sess = create_session() u = aliased(User) @@ -81,84 +115,113 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_( - [User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])], - q.filter(u.id == 7).all() + [ + User( + id=7, + addresses=[ + Address(id=1, email_address="jack@bean.com") + ], + ) + ], + q.filter(u.id == 7).all(), ) self.assert_sql_count(testing.db, go, 2) def go(): - eq_( - self.static.user_address_result, - q.order_by(u.id).all() - ) + eq_(self.static.user_address_result, q.order_by(u.id).all()) + self.assert_sql_count(testing.db, go, 2) - q = sess.query(u).\ - options(selectinload_all(u.addresses, Address.dingalings)) + q = sess.query(u).options( + selectinload_all(u.addresses, Address.dingalings) + ) def go(): eq_( [ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com', - dingalings=[Dingaling()]), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5, dingalings=[Dingaling()]) - ]), + User( + id=8, + addresses=[ + Address( + id=2, + email_address="ed@wood.com", + dingalings=[Dingaling()], + ), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, 
email_address="ed@lala.com"), + ], + ), + User( + id=9, + addresses=[Address(id=5, dingalings=[Dingaling()])], + ), ], - q.filter(u.id.in_([8, 9])).all() + q.filter(u.id.in_([8, 9])).all(), ) + self.assert_sql_count(testing.db, go, 3) def test_from_get(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - order_by=Address.id) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) sess = create_session() q = sess.query(User).options(selectinload(User.addresses)) def go(): eq_( - User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')]), - q.get(7) + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ), + q.get(7), ) self.assert_sql_count(testing.db, go, 2) def test_from_params(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - order_by=Address.id) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) sess = create_session() q = sess.query(User).options(selectinload(User.addresses)) def go(): eq_( - User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')]), - q.filter(User.id == bindparam('foo')).params(foo=7).one() + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ), + q.filter(User.id == bindparam("foo")).params(foo=7).one(), ) self.assert_sql_count(testing.db, go, 2) @@ -166,14 +229,18 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def test_disable_dynamic(self): """test no selectin option on a dynamic.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy="dynamic") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="dynamic")}, + ) mapper(Address, addresses) sess = create_session() @@ -192,17 +259,28 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='selectin', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="selectin", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): eq_(self.static.item_keyword_result, q.all()) + self.assert_sql_count(testing.db, go, 2) def test_many_to_many_with_join(self): @@ -211,18 +289,31 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): 
self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='selectin', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="selectin", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): - eq_(self.static.item_keyword_result[0:2], - q.join('keywords').filter(Keyword.name == 'red').all()) + eq_( + self.static.item_keyword_result[0:2], + q.join("keywords").filter(Keyword.name == "red").all(), + ) + self.assert_sql_count(testing.db, go, 2) def test_many_to_many_with_join_alias(self): @@ -231,139 +322,193 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='selectin', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="selectin", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): - eq_(self.static.item_keyword_result[0:2], - (q.join('keywords', aliased=True). - filter(Keyword.name == 'red')).all()) + eq_( + self.static.item_keyword_result[0:2], + ( + q.join("keywords", aliased=True).filter( + Keyword.name == "red" + ) + ).all(), + ) + self.assert_sql_count(testing.db, go, 2) def test_orderby(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='selectin', - order_by=addresses.c.email_address), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="selectin", + order_by=addresses.c.email_address, + ) + }, + ) q = create_session().query(User) - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], q.order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=2, email_address="ed@wood.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + q.order_by(User.id).all(), + ) def test_orderby_multi(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='selectin', - order_by=[ - addresses.c.email_address, - addresses.c.id]), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + 
"addresses": relationship( + mapper(Address, addresses), + lazy="selectin", + order_by=[addresses.c.email_address, addresses.c.id], + ) + }, + ) q = create_session().query(User) - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], q.order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=2, email_address="ed@wood.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + q.order_by(User.id).all(), + ) def test_orderby_related(self): """A regular mapper select on a single table can order by a relationship to a second table""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, - lazy='selectin', - order_by=addresses.c.id), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="selectin", order_by=addresses.c.id + ) + ), + ) q = create_session().query(User) - result = q.filter(User.id == Address.user_id).\ - order_by(Address.email_address).all() - - eq_([ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=7, addresses=[ - Address(id=1) - ]), - ], result) + result = ( + q.filter(User.id == Address.user_id) + .order_by(Address.email_address) + .all() + ) + + eq_( + [ + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=7, addresses=[Address(id=1)]), + ], + result, + ) def test_orderby_desc(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='selectin', - order_by=[ - sa.desc(addresses.c.email_address) - ]), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="selectin", + order_by=[sa.desc(addresses.c.email_address)], + ) + ), + ) sess = create_session() - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=3, email_address='ed@bettyboop.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], sess.query(User).order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=3, 
email_address="ed@bettyboop.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + sess.query(User).order_by(User.id).all(), + ) _pathing_runs = [ ("lazyload", "lazyload", "lazyload", 15), @@ -382,37 +527,50 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self._do_mapper_test(self._pathing_runs) def _do_options_test(self, configs): - users, Keyword, orders, items, order_items, Order, Item, User, \ - keywords, item_keywords = (self.tables.users, - self.classes.Keyword, - self.tables.orders, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.keywords, - self.tables.item_keywords) - - mapper(User, users, properties={ - 'orders': relationship(Order, order_by=orders.c.id), # o2m, m2o - }) - mapper(Order, orders, properties={ - 'items': relationship(Item, - secondary=order_items, - order_by=items.c.id), # m2m - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, - secondary=item_keywords, - order_by=keywords.c.id) # m2m - }) + users, Keyword, orders, items, order_items, Order, Item, User, keywords, item_keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) + + mapper( + User, + users, + properties={ + "orders": relationship(Order, order_by=orders.c.id) # o2m, m2o + }, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) # m2m + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, order_by=keywords.c.id + ) # m2m + }, + ) mapper(Keyword, keywords) callables = { - 'joinedload': joinedload, - 'selectinload': selectinload, - 'subqueryload': subqueryload + "joinedload": joinedload, + "selectinload": selectinload, + "subqueryload": subqueryload, } for o, i, k, count in configs: @@ -422,46 +580,66 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): if i in callables: options.append(callables[i](User.orders, Order.items)) if k in callables: - options.append(callables[k]( - User.orders, Order.items, Item.keywords)) + options.append( + callables[k](User.orders, Order.items, Item.keywords) + ) self._do_query_tests(options, count) def _do_mapper_test(self, configs): - users, Keyword, orders, items, order_items, Order, Item, User, \ - keywords, item_keywords = (self.tables.users, - self.classes.Keyword, - self.tables.orders, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.keywords, - self.tables.item_keywords) + users, Keyword, orders, items, order_items, Order, Item, User, keywords, item_keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) opts = { - 'lazyload': 'select', - 'joinedload': 'joined', - 'selectinload': 'selectin', + "lazyload": "select", + "joinedload": "joined", + "selectinload": "selectin", } for o, i, k, count in configs: - mapper(User, users, properties={ - 'orders': relationship(Order, lazy=opts[o], - order_by=orders.c.id), - }) - mapper(Order, orders, properties={ - 'items': relationship(Item, - secondary=order_items, 
lazy=opts[i], - order_by=items.c.id), - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, - lazy=opts[k], - secondary=item_keywords, - order_by=keywords.c.id) - }) + mapper( + User, + users, + properties={ + "orders": relationship( + Order, lazy=opts[o], order_by=orders.c.id + ) + }, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy=opts[i], + order_by=items.c.id, + ) + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, + lazy=opts[k], + secondary=item_keywords, + order_by=keywords.c.id, + ) + }, + ) mapper(Keyword, keywords) try: @@ -477,69 +655,103 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_( sess.query(User).options(*opts).order_by(User.id).all(), - self.static.user_item_keyword_result + self.static.user_item_keyword_result, ) + self.assert_sql_count(testing.db, go, count) eq_( - sess.query(User).options(*opts).filter(User.name == 'fred'). - order_by(User.id).all(), - self.static.user_item_keyword_result[2:3] + sess.query(User) + .options(*opts) + .filter(User.name == "fred") + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[2:3], ) sess = create_session() eq_( - sess.query(User).options(*opts).join(User.orders). - filter(Order.id == 3). - order_by(User.id).all(), - self.static.user_item_keyword_result[0:1] + sess.query(User) + .options(*opts) + .join(User.orders) + .filter(Order.id == 3) + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[0:1], ) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='selectin', - backref=sa.orm.backref( - 'user', lazy='selectin'), - order_by=Address.id) - )) - is_(sa.orm.class_mapper(User).get_property('addresses').lazy, - 'selectin') - is_(sa.orm.class_mapper(Address).get_property('user').lazy, 'selectin') + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="selectin", + backref=sa.orm.backref("user", lazy="selectin"), + order_by=Address.id, + ) + ), + ) + is_( + sa.orm.class_mapper(User).get_property("addresses").lazy, + "selectin", + ) + is_(sa.orm.class_mapper(Address).get_property("user").lazy, "selectin") sess = create_session() - eq_(self.static.user_address_result, - sess.query(User).order_by(User.id).all()) + eq_( + self.static.user_address_result, + sess.query(User).order_by(User.id).all(), + ) def test_cyclical_explicit_join_depth(self): """A circular eager relationship breaks the cycle with a lazy loader""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='selectin', join_depth=1, - backref=sa.orm.backref( - 'user', lazy='selectin', join_depth=1), - order_by=Address.id) - )) - is_(sa.orm.class_mapper(User).get_property('addresses').lazy, - 'selectin') - 
is_(sa.orm.class_mapper(Address).get_property('user').lazy, 'selectin') + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="selectin", + join_depth=1, + backref=sa.orm.backref( + "user", lazy="selectin", join_depth=1 + ), + order_by=Address.id, + ) + ), + ) + is_( + sa.orm.class_mapper(User).get_property("addresses").lazy, + "selectin", + ) + is_(sa.orm.class_mapper(Address).get_property("user").lazy, "selectin") sess = create_session() - eq_(self.static.user_address_result, - sess.query(User).order_by(User.id).all()) + eq_( + self.static.user_address_result, + sess.query(User).order_by(User.id).all(), + ) def test_double(self): """Eager loading with two relationships simultaneously, @@ -551,10 +763,11 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.classes.Address, self.classes.Order, - self.tables.addresses) + self.tables.addresses, + ) - openorders = sa.alias(orders, 'openorders') - closedorders = sa.alias(orders, 'closedorders') + openorders = sa.alias(orders, "openorders") + closedorders = sa.alias(orders, "closedorders") mapper(Address, addresses) mapper(Order, orders) @@ -562,150 +775,217 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): open_mapper = mapper(Order, openorders, non_primary=True) closed_mapper = mapper(Order, closedorders, non_primary=True) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='selectin', - order_by=addresses.c.id), - open_orders=relationship( - open_mapper, - primaryjoin=sa.and_(openorders.c.isopen == 1, - users.c.id == openorders.c.user_id), - lazy='selectin', order_by=openorders.c.id), - closed_orders=relationship( - closed_mapper, - primaryjoin=sa.and_(closedorders.c.isopen == 0, - users.c.id == closedorders.c.user_id), - lazy='selectin', order_by=closedorders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="selectin", order_by=addresses.c.id + ), + open_orders=relationship( + open_mapper, + primaryjoin=sa.and_( + openorders.c.isopen == 1, + users.c.id == openorders.c.user_id, + ), + lazy="selectin", + order_by=openorders.c.id, + ), + closed_orders=relationship( + closed_mapper, + primaryjoin=sa.and_( + closedorders.c.isopen == 0, + users.c.id == closedorders.c.user_id, + ), + lazy="selectin", + order_by=closedorders.c.id, + ), + ), + ) q = create_session().query(User).order_by(User.id) def go(): - eq_([ - User( - id=7, - addresses=[Address(id=1)], - open_orders=[Order(id=3)], - closed_orders=[Order(id=1), Order(id=5)] - ), - User( - id=8, - addresses=[Address(id=2), Address(id=3), Address(id=4)], - open_orders=[], - closed_orders=[] - ), - User( - id=9, - addresses=[Address(id=5)], - open_orders=[Order(id=4)], - closed_orders=[Order(id=2)] - ), - User(id=10) + eq_( + [ + User( + id=7, + addresses=[Address(id=1)], + open_orders=[Order(id=3)], + closed_orders=[Order(id=1), Order(id=5)], + ), + User( + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + open_orders=[], + closed_orders=[], + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders=[Order(id=4)], + closed_orders=[Order(id=2)], + ), + User(id=10), + ], + q.all(), + ) - ], q.all()) self.assert_sql_count(testing.db, go, 4) def test_double_same_mappers(self): """Eager loading with two relationships simultaneously, from the same table, using aliases.""" - addresses, items, order_items, orders, Item, User, Address, Order, \ - users = (self.tables.addresses, - self.tables.items, - 
self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.users) + addresses, items, order_items, orders, Item, User, Address, Order, users = ( + self.tables.addresses, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.users, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='selectin', - order_by=items.c.id)}) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="selectin", + order_by=items.c.id, + ) + }, + ) mapper(Item, items) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='selectin', order_by=addresses.c.id), - open_orders=relationship( - Order, - primaryjoin=sa.and_(orders.c.isopen == 1, - users.c.id == orders.c.user_id), - lazy='selectin', order_by=orders.c.id), - closed_orders=relationship( - Order, - primaryjoin=sa.and_(orders.c.isopen == 0, - users.c.id == orders.c.user_id), - lazy='selectin', order_by=orders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="selectin", order_by=addresses.c.id + ), + open_orders=relationship( + Order, + primaryjoin=sa.and_( + orders.c.isopen == 1, users.c.id == orders.c.user_id + ), + lazy="selectin", + order_by=orders.c.id, + ), + closed_orders=relationship( + Order, + primaryjoin=sa.and_( + orders.c.isopen == 0, users.c.id == orders.c.user_id + ), + lazy="selectin", + order_by=orders.c.id, + ), + ), + ) q = create_session().query(User).order_by(User.id) def go(): - eq_([ - User(id=7, - addresses=[ - Address(id=1)], - open_orders=[Order(id=3, - items=[ - Item(id=3), - Item(id=4), - Item(id=5)])], - closed_orders=[Order(id=1, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)]), - Order(id=5, - items=[ - Item(id=5)])]), - User(id=8, - addresses=[ - Address(id=2), - Address(id=3), - Address(id=4)], - open_orders=[], - closed_orders=[]), - User(id=9, - addresses=[ - Address(id=5)], - open_orders=[ - Order(id=4, - items=[ - Item(id=1), - Item(id=5)])], - closed_orders=[ - Order(id=2, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)])]), - User(id=10) - ], q.all()) + eq_( + [ + User( + id=7, + addresses=[Address(id=1)], + open_orders=[ + Order( + id=3, + items=[Item(id=3), Item(id=4), Item(id=5)], + ) + ], + closed_orders=[ + Order( + id=1, + items=[Item(id=1), Item(id=2), Item(id=3)], + ), + Order(id=5, items=[Item(id=5)]), + ], + ), + User( + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + open_orders=[], + closed_orders=[], + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders=[ + Order(id=4, items=[Item(id=1), Item(id=5)]) + ], + closed_orders=[ + Order( + id=2, + items=[Item(id=1), Item(id=2), Item(id=3)], + ) + ], + ), + User(id=10), + ], + q.all(), + ) + self.assert_sql_count(testing.db, go, 6) def test_limit(self): """Limit operations combined with lazy-load relationships.""" - users, items, order_items, orders, Item, User, Address, Order, \ - addresses = (self.tables.users, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.addresses) + users, items, order_items, orders, Item, User, Address, Order, addresses = ( + self.tables.users, + self.tables.items, + 
self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.addresses, + ) mapper(Item, items) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='selectin', - order_by=items.c.id) - }) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='selectin', - order_by=addresses.c.id), - 'orders': relationship(Order, lazy='select', order_by=orders.c.id) - }) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="selectin", + order_by=items.c.id, + ) + }, + ) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="selectin", + order_by=addresses.c.id, + ), + "orders": relationship( + Order, lazy="select", order_by=orders.c.id + ), + }, + ) sess = create_session() q = sess.query(User) @@ -718,17 +998,24 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): @testing.uses_deprecated("Mapper.order_by") def test_mapper_order_by(self): - users, User, Address, addresses = (self.tables.users, - self.classes.User, - self.classes.Address, - self.tables.addresses) + users, User, Address, addresses = ( + self.tables.users, + self.classes.User, + self.classes.Address, + self.tables.addresses, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, - lazy='selectin', - order_by=addresses.c.id), - }, order_by=users.c.id.desc()) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, lazy="selectin", order_by=addresses.c.id + ) + }, + order_by=users.c.id.desc(), + ) sess = create_session() q = sess.query(User) @@ -737,31 +1024,45 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_(result, list(reversed(self.static.user_address_result[2:4]))) def test_one_to_many_scalar(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(User, users, properties=dict( - address=relationship(mapper(Address, addresses), - lazy='selectin', uselist=False) - )) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=dict( + address=relationship( + mapper(Address, addresses), lazy="selectin", uselist=False + ) + ), + ) q = create_session().query(User) def go(): result = q.filter(users.c.id == 7).all() eq_([User(id=7, address=Address(id=1))], result) + self.assert_sql_count(testing.db, go, 2) def test_many_to_one(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(Address, addresses, properties=dict( - user=relationship(mapper(User, users), lazy='selectin') - )) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + Address, + addresses, + properties=dict( + user=relationship(mapper(User, users), lazy="selectin") + ), + ) sess = create_session() q = sess.query(Address) @@ -770,97 +1071,135 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): is_not_(a.user, None) u1 = sess.query(User).get(7) is_(a.user, u1) + self.assert_sql_count(testing.db, go, 2) def test_double_with_aggregate(self): - User, users, orders, Order = (self.classes.User, - 
self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) - max_orders_by_user = sa.select([sa.func.max(orders.c.id) - .label('order_id')], - group_by=[orders.c.user_id]) \ - .alias('max_orders_by_user') + max_orders_by_user = sa.select( + [sa.func.max(orders.c.id).label("order_id")], + group_by=[orders.c.user_id], + ).alias("max_orders_by_user") max_orders = orders.select( - orders.c.id == max_orders_by_user.c.order_id).\ - alias('max_orders') + orders.c.id == max_orders_by_user.c.order_id + ).alias("max_orders") mapper(Order, orders) - mapper(User, users, properties={ - 'orders': relationship(Order, backref='user', lazy='selectin', - order_by=orders.c.id), - 'max_order': relationship( - mapper(Order, max_orders, non_primary=True), - lazy='selectin', uselist=False) - }) + mapper( + User, + users, + properties={ + "orders": relationship( + Order, + backref="user", + lazy="selectin", + order_by=orders.c.id, + ), + "max_order": relationship( + mapper(Order, max_orders, non_primary=True), + lazy="selectin", + uselist=False, + ), + }, + ) q = create_session().query(User) def go(): - eq_([ - User(id=7, orders=[ - Order(id=1), - Order(id=3), - Order(id=5), + eq_( + [ + User( + id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + max_order=Order(id=5), + ), + User(id=8, orders=[]), + User( + id=9, + orders=[Order(id=2), Order(id=4)], + max_order=Order(id=4), + ), + User(id=10), ], - max_order=Order(id=5) - ), - User(id=8, orders=[]), - User(id=9, orders=[Order(id=2), Order(id=4)], - max_order=Order(id=4)), - User(id=10), - ], q.order_by(User.id).all()) + q.order_by(User.id).all(), + ) + self.assert_sql_count(testing.db, go, 3) def test_uselist_false_warning(self): """test that multiple rows received by a uselist=False raises a warning.""" - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) - mapper(User, users, properties={ - 'order': relationship(Order, uselist=False) - }) + mapper( + User, + users, + properties={"order": relationship(Order, uselist=False)}, + ) mapper(Order, orders) s = create_session() - assert_raises(sa.exc.SAWarning, - s.query(User).options(selectinload(User.order)).all) + assert_raises( + sa.exc.SAWarning, + s.query(User).options(selectinload(User.order)).all, + ) class LoadOnExistingTest(_fixtures.FixtureTest): """test that loaders from a base Query fully populate.""" - run_inserts = 'once' + run_inserts = "once" run_deletes = None def _collection_to_scalar_fixture(self): - User, Address, Dingaling = self.classes.User, \ - self.classes.Address, self.classes.Dingaling - mapper(User, self.tables.users, properties={ - 'addresses': relationship(Address), - }) - mapper(Address, self.tables.addresses, properties={ - 'dingaling': relationship(Dingaling) - }) + User, Address, Dingaling = ( + self.classes.User, + self.classes.Address, + self.classes.Dingaling, + ) + mapper( + User, + self.tables.users, + properties={"addresses": relationship(Address)}, + ) + mapper( + Address, + self.tables.addresses, + properties={"dingaling": relationship(Dingaling)}, + ) mapper(Dingaling, self.tables.dingalings) sess = Session(autoflush=False) return User, Address, Dingaling, sess def _collection_to_collection_fixture(self): - User, Order, Item = self.classes.User, \ - 
self.classes.Order, self.classes.Item - mapper(User, self.tables.users, properties={ - 'orders': relationship(Order), - }) - mapper(Order, self.tables.orders, properties={ - 'items': relationship(Item, secondary=self.tables.order_items), - }) + User, Order, Item = ( + self.classes.User, + self.classes.Order, + self.classes.Item, + ) + mapper( + User, self.tables.users, properties={"orders": relationship(Order)} + ) + mapper( + Order, + self.tables.orders, + properties={ + "items": relationship(Item, secondary=self.tables.order_items) + }, + ) mapper(Item, self.tables.items) sess = Session(autoflush=False) @@ -868,19 +1207,25 @@ class LoadOnExistingTest(_fixtures.FixtureTest): def _eager_config_fixture(self): User, Address = self.classes.User, self.classes.Address - mapper(User, self.tables.users, properties={ - 'addresses': relationship(Address, lazy="selectin"), - }) + mapper( + User, + self.tables.users, + properties={"addresses": relationship(Address, lazy="selectin")}, + ) mapper(Address, self.tables.addresses) sess = Session(autoflush=False) return User, Address, sess def _deferred_config_fixture(self): User, Address = self.classes.User, self.classes.Address - mapper(User, self.tables.users, properties={ - 'name': deferred(self.tables.users.c.name), - 'addresses': relationship(Address, lazy="selectin"), - }) + mapper( + User, + self.tables.users, + properties={ + "name": deferred(self.tables.users.c.name), + "addresses": relationship(Address, lazy="selectin"), + }, + ) mapper(Address, self.tables.addresses) sess = Session(autoflush=False) return User, Address, sess @@ -889,24 +1234,26 @@ class LoadOnExistingTest(_fixtures.FixtureTest): User, Address, sess = self._eager_config_fixture() u1 = sess.query(User).get(8) - assert 'addresses' in u1.__dict__ + assert "addresses" in u1.__dict__ sess.expire(u1) def go(): eq_(u1.id, 8) + self.assert_sql_count(testing.db, go, 1) - assert 'addresses' not in u1.__dict__ + assert "addresses" not in u1.__dict__ def test_no_query_on_deferred(self): User, Address, sess = self._deferred_config_fixture() u1 = sess.query(User).get(8) - assert 'addresses' in u1.__dict__ - sess.expire(u1, ['addresses']) + assert "addresses" in u1.__dict__ + sess.expire(u1, ["addresses"]) def go(): - eq_(u1.name, 'ed') + eq_(u1.name, "ed") + self.assert_sql_count(testing.db, go, 1) - assert 'addresses' not in u1.__dict__ + assert "addresses" not in u1.__dict__ def test_populate_existing_propagate(self): User, Address, sess = self._eager_config_fixture() @@ -927,17 +1274,18 @@ class LoadOnExistingTest(_fixtures.FixtureTest): a1 = Address() u1.addresses.append(a1) a2 = u1.addresses[0] - a2.email_address = 'foo' - sess.query(User).options(selectinload_all("addresses.dingaling")).\ - filter_by(id=8).all() + a2.email_address = "foo" + sess.query(User).options( + selectinload_all("addresses.dingaling") + ).filter_by(id=8).all() assert u1.addresses[-1] is a1 for a in u1.addresses: if a is not a1: - assert 'dingaling' in a.__dict__ + assert "dingaling" in a.__dict__ else: - assert 'dingaling' not in a.__dict__ + assert "dingaling" not in a.__dict__ if a is a2: - eq_(a2.email_address, 'foo') + eq_(a2.email_address, "foo") def test_loads_second_level_collection_to_collection(self): User, Order, Item, sess = self._collection_to_collection_fixture() @@ -946,76 +1294,92 @@ class LoadOnExistingTest(_fixtures.FixtureTest): u1.orders o1 = Order() u1.orders.append(o1) - sess.query(User).options(selectinload_all("orders.items")).\ - filter_by(id=7).all() + 
sess.query(User).options(selectinload_all("orders.items")).filter_by( + id=7 + ).all() for o in u1.orders: if o is not o1: - assert 'items' in o.__dict__ + assert "items" in o.__dict__ else: - assert 'items' not in o.__dict__ + assert "items" not in o.__dict__ def test_load_two_levels_collection_to_scalar(self): User, Address, Dingaling, sess = self._collection_to_scalar_fixture() - u1 = sess.query(User).filter_by(id=8).options( - selectinload("addresses")).one() + u1 = ( + sess.query(User) + .filter_by(id=8) + .options(selectinload("addresses")) + .one() + ) sess.query(User).filter_by(id=8).options( - selectinload_all("addresses.dingaling")).first() - assert 'dingaling' in u1.addresses[0].__dict__ + selectinload_all("addresses.dingaling") + ).first() + assert "dingaling" in u1.addresses[0].__dict__ def test_load_two_levels_collection_to_collection(self): User, Order, Item, sess = self._collection_to_collection_fixture() - u1 = sess.query(User).filter_by(id=7).options( - selectinload("orders")).one() + u1 = ( + sess.query(User) + .filter_by(id=7) + .options(selectinload("orders")) + .one() + ) sess.query(User).filter_by(id=7).options( - selectinload_all("orders.items")).first() - assert 'items' in u1.orders[0].__dict__ + selectinload_all("orders.items") + ).first() + assert "items" in u1.orders[0].__dict__ class OrderBySecondaryTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('m2m', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b.id'))) - - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) - Table('b', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "m2m", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("aid", Integer, ForeignKey("a.id")), + Column("bid", Integer, ForeignKey("b.id")), + ) + + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @classmethod def fixtures(cls): return dict( - a=(('id', 'data'), - (1, 'a1'), - (2, 'a2')), - - b=(('id', 'data'), - (1, 'b1'), - (2, 'b2'), - (3, 'b3'), - (4, 'b4')), - - m2m=(('id', 'aid', 'bid'), - (2, 1, 1), - (4, 2, 4), - (1, 1, 3), - (6, 2, 2), - (3, 1, 2), - (5, 2, 3))) + a=(("id", "data"), (1, "a1"), (2, "a2")), + b=(("id", "data"), (1, "b1"), (2, "b2"), (3, "b3"), (4, "b4")), + m2m=( + ("id", "aid", "bid"), + (2, 1, 1), + (4, 2, 4), + (1, 1, 3), + (6, 2, 2), + (3, 1, 2), + (5, 2, 3), + ), + ) def test_ordering(self): - a, m2m, b = (self.tables.a, - self.tables.m2m, - self.tables.b) + a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b) class A(fixtures.ComparableEntity): pass @@ -1023,19 +1387,31 @@ class OrderBySecondaryTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, a, properties={ - 'bs': relationship(B, secondary=m2m, lazy='selectin', - order_by=m2m.c.id) - }) + mapper( + A, + a, + properties={ + "bs": relationship( + B, secondary=m2m, lazy="selectin", order_by=m2m.c.id + ) + }, + ) mapper(B, b) sess = create_session() def go(): - eq_(sess.query(A).all(), [ - A(data='a1', bs=[B(data='b3'), B(data='b1'), 
B(data='b2')]), - A(bs=[B(data='b4'), B(data='b3'), B(data='b2')]) - ]) + eq_( + sess.query(A).all(), + [ + A( + data="a1", + bs=[B(data="b3"), B(data="b1"), B(data="b2")], + ), + A(bs=[B(data="b4"), B(data="b3"), B(data="b2")]), + ], + ) + self.assert_sql_count(testing.db, go, 2) @@ -1053,28 +1429,45 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): @classmethod def define_tables(cls, metadata): - Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) + Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) # to test fully, PK of engineers table must be # named differently from that of people - Table('engineers', metadata, - Column('engineer_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('primary_language', String(50))) - - Table('paperwork', metadata, - Column('paperwork_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('description', String(50)), - Column('person_id', Integer, - ForeignKey('people.person_id'))) + Table( + "engineers", + metadata, + Column( + "engineer_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("primary_language", String(50)), + ) + + Table( + "paperwork", + metadata, + Column( + "paperwork_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("description", String(50)), + Column("person_id", Integer, ForeignKey("people.person_id")), + ) @classmethod def setup_mappers(cls): @@ -1082,16 +1475,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): engineers = cls.tables.engineers paperwork = cls.tables.paperwork - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person', - properties={ - 'paperwork': relationship( - Paperwork, order_by=paperwork.c.paperwork_id)}) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + properties={ + "paperwork": relationship( + Paperwork, order_by=paperwork.c.paperwork_id + ) + }, + ) - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer') + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + ) mapper(Paperwork, paperwork) @@ -1100,8 +1501,10 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): e1 = Engineer(primary_language="java") e2 = Engineer(primary_language="c++") - e1.paperwork = [Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")] + e1.paperwork = [ + Paperwork(description="tps report #1"), + Paperwork(description="tps report #2"), + ] e2.paperwork = [Paperwork(description="tps report #3")] sess = create_session() sess.add_all([e1, e2]) @@ -1111,16 +1514,21 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() # use Person.paperwork here just to give the least # amount of context - q = sess.query(Engineer).\ - filter(Engineer.primary_language == 'java').\ - options(selectinload(Person.paperwork)) + q = ( + sess.query(Engineer) + .filter(Engineer.primary_language == "java") + .options(selectinload(Person.paperwork)) + ) def go(): - eq_(q.all()[0].paperwork, - [Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], + eq_( + q.all()[0].paperwork, + [ + Paperwork(description="tps report #1"), + Paperwork(description="tps 
report #2"), + ], + ) - ) self.assert_sql_execution( testing.db, go, @@ -1132,7 +1540,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people JOIN engineers ON " "people.person_id = engineers.engineer_id " "WHERE engineers.primary_language = :primary_language_1", - {"primary_language_1": "java"} + {"primary_language_1": "java"}, ), CompiledSQL( "SELECT paperwork.person_id AS paperwork_person_id, " @@ -1141,26 +1549,31 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM paperwork WHERE paperwork.person_id " "IN ([EXPANDING_primary_keys]) " "ORDER BY paperwork.person_id, paperwork.paperwork_id", - [{'primary_keys': [1]}] - ) + [{"primary_keys": [1]}], + ), ) def test_correct_select_existingfrom(self): sess = create_session() # use Person.paperwork here just to give the least # amount of context - q = sess.query(Engineer).\ - filter(Engineer.primary_language == 'java').\ - join(Engineer.paperwork).\ - filter(Paperwork.description == "tps report #2").\ - options(selectinload(Person.paperwork)) + q = ( + sess.query(Engineer) + .filter(Engineer.primary_language == "java") + .join(Engineer.paperwork) + .filter(Paperwork.description == "tps report #2") + .options(selectinload(Person.paperwork)) + ) def go(): - eq_(q.one().paperwork, - [Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], + eq_( + q.one().paperwork, + [ + Paperwork(description="tps report #1"), + Paperwork(description="tps report #2"), + ], + ) - ) self.assert_sql_execution( testing.db, go, @@ -1174,8 +1587,10 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "JOIN paperwork ON people.person_id = paperwork.person_id " "WHERE engineers.primary_language = :primary_language_1 " "AND paperwork.description = :description_1", - {"primary_language_1": "java", - "description_1": "tps report #2"} + { + "primary_language_1": "java", + "description_1": "tps report #2", + }, ), CompiledSQL( "SELECT paperwork.person_id AS paperwork_person_id, " @@ -1184,8 +1599,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM paperwork WHERE paperwork.person_id " "IN ([EXPANDING_primary_keys]) " "ORDER BY paperwork.person_id, paperwork.paperwork_id", - [{'primary_keys': [1]}] - ) + [{"primary_keys": [1]}], + ), ) def test_correct_select_with_polymorphic_no_alias(self): @@ -1193,20 +1608,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() wp = with_polymorphic(Person, [Engineer]) - q = sess.query(wp).\ - options(selectinload(wp.paperwork)).\ - order_by(Engineer.primary_language.desc()) + q = ( + sess.query(wp) + .options(selectinload(wp.paperwork)) + .order_by(Engineer.primary_language.desc()) + ) def go(): - eq_(q.first(), + eq_( + q.first(), Engineer( paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - primary_language='java' + Paperwork(description="tps report #2"), + ], + primary_language="java", + ), ) - ) self.assert_sql_execution( testing.db, go, @@ -1217,7 +1636,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "engineers.primary_language AS engineers_primary_language " "FROM people LEFT OUTER JOIN engineers ON people.person_id = " "engineers.engineer_id ORDER BY engineers.primary_language " - "DESC LIMIT :param_1"), + "DESC LIMIT :param_1" + ), CompiledSQL( "SELECT paperwork.person_id AS paperwork_person_id, " "paperwork.paperwork_id AS paperwork_paperwork_id, " @@ -1225,8 +1645,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM paperwork WHERE 
paperwork.person_id " "IN ([EXPANDING_primary_keys]) " "ORDER BY paperwork.person_id, paperwork.paperwork_id", - [{'primary_keys': [1]}] - ) + [{"primary_keys": [1]}], + ), ) def test_correct_select_with_polymorphic_alias(self): @@ -1234,20 +1654,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() wp = with_polymorphic(Person, [Engineer], aliased=True) - q = sess.query(wp).\ - options(selectinload(wp.paperwork)).\ - order_by(wp.Engineer.primary_language.desc()) + q = ( + sess.query(wp) + .options(selectinload(wp.paperwork)) + .order_by(wp.Engineer.primary_language.desc()) + ) def go(): - eq_(q.first(), + eq_( + q.first(), Engineer( paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - primary_language='java' + Paperwork(description="tps report #2"), + ], + primary_language="java", + ), ) - ) self.assert_sql_execution( testing.db, go, @@ -1266,7 +1690,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people LEFT OUTER JOIN engineers ON people.person_id = " "engineers.engineer_id) AS anon_1 " "ORDER BY anon_1.engineers_primary_language DESC " - "LIMIT :param_1"), + "LIMIT :param_1" + ), CompiledSQL( "SELECT paperwork.person_id AS paperwork_person_id, " "paperwork.paperwork_id AS paperwork_paperwork_id, " @@ -1274,8 +1699,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM paperwork WHERE paperwork.person_id " "IN ([EXPANDING_primary_keys]) " "ORDER BY paperwork.person_id, paperwork.paperwork_id", - [{'primary_keys': [1]}] - ) + [{"primary_keys": [1]}], + ), ) def test_correct_select_with_polymorphic_flat_alias(self): @@ -1283,20 +1708,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() wp = with_polymorphic(Person, [Engineer], aliased=True, flat=True) - q = sess.query(wp).\ - options(selectinload(wp.paperwork)).\ - order_by(wp.Engineer.primary_language.desc()) + q = ( + sess.query(wp) + .options(selectinload(wp.paperwork)) + .order_by(wp.Engineer.primary_language.desc()) + ) def go(): - eq_(q.first(), + eq_( + q.first(), Engineer( paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - primary_language='java' + Paperwork(description="tps report #2"), + ], + primary_language="java", + ), ) - ) self.assert_sql_execution( testing.db, go, @@ -1309,7 +1738,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people AS people_1 " "LEFT OUTER JOIN engineers AS engineers_1 " "ON people_1.person_id = engineers_1.engineer_id " - "ORDER BY engineers_1.primary_language DESC LIMIT :param_1"), + "ORDER BY engineers_1.primary_language DESC LIMIT :param_1" + ), CompiledSQL( "SELECT paperwork.person_id AS paperwork_person_id, " "paperwork.paperwork_id AS paperwork_paperwork_id, " @@ -1317,9 +1747,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM paperwork WHERE paperwork.person_id " "IN ([EXPANDING_primary_keys]) " "ORDER BY paperwork.person_id, paperwork.paperwork_id", - [{'primary_keys': [1]}] - - ) + [{"primary_keys": [1]}], + ), ) @@ -1329,87 +1758,75 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) - employees = relationship('Employee', order_by="Employee.id") + employees = relationship("Employee", order_by="Employee.id") class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = 
Column(Integer, primary_key=True) type = Column(String(50)) name = Column(String(50)) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) __mapper_args__ = { - 'polymorphic_on': 'type', - 'with_polymorphic': '*', + "polymorphic_on": "type", + "with_polymorphic": "*", } class Programmer(Employee): - __tablename__ = 'programmer' - id = Column(ForeignKey('employee.id'), primary_key=True) - languages = relationship('Language') + __tablename__ = "programmer" + id = Column(ForeignKey("employee.id"), primary_key=True) + languages = relationship("Language") - __mapper_args__ = { - 'polymorphic_identity': 'programmer', - } + __mapper_args__ = {"polymorphic_identity": "programmer"} class Manager(Employee): - __tablename__ = 'manager' - id = Column(ForeignKey('employee.id'), primary_key=True) + __tablename__ = "manager" + id = Column(ForeignKey("employee.id"), primary_key=True) golf_swing_id = Column(ForeignKey("golf_swing.id")) golf_swing = relationship("GolfSwing") - __mapper_args__ = { - 'polymorphic_identity': 'manager', - } + __mapper_args__ = {"polymorphic_identity": "manager"} class Language(Base): - __tablename__ = 'language' + __tablename__ = "language" id = Column(Integer, primary_key=True) programmer_id = Column( - Integer, - ForeignKey('programmer.id'), - nullable=False, + Integer, ForeignKey("programmer.id"), nullable=False ) name = Column(String(50)) class GolfSwing(Base): - __tablename__ = 'golf_swing' + __tablename__ = "golf_swing" id = Column(Integer, primary_key=True) name = Column(String(50)) @classmethod def insert_data(cls): Company, Programmer, Manager, GolfSwing, Language = cls.classes( - "Company", "Programmer", "Manager", "GolfSwing", "Language") + "Company", "Programmer", "Manager", "GolfSwing", "Language" + ) c1 = Company( id=1, - name='Foobar Corp', - employees=[Programmer( - id=1, - name='p1', - languages=[Language(id=1, name='Python')], - ), Manager( - id=2, - name='m1', - golf_swing=GolfSwing(name="fore") - )], + name="Foobar Corp", + employees=[ + Programmer( + id=1, name="p1", languages=[Language(id=1, name="Python")] + ), + Manager(id=2, name="m1", golf_swing=GolfSwing(name="fore")), + ], ) c2 = Company( id=2, - name='bat Corp', + name="bat Corp", employees=[ - Manager( - id=3, - name='m2', - golf_swing=GolfSwing(name="clubs"), - ), Programmer( - id=4, - name='p2', - languages=[Language(id=2, name="Java")] - )], + Manager(id=3, name="m2", golf_swing=GolfSwing(name="clubs")), + Programmer( + id=4, name="p2", languages=[Language(id=2, name="Java")] + ), + ], ) sess = Session() sess.add_all([c1, c2]) @@ -1418,14 +1835,19 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): def test_one_to_many(self): Company, Programmer, Manager, GolfSwing, Language = self.classes( - "Company", "Programmer", "Manager", "GolfSwing", "Language") + "Company", "Programmer", "Manager", "GolfSwing", "Language" + ) sess = Session() - company = sess.query(Company).filter( - Company.id == 1, - ).options( - selectinload(Company.employees.of_type(Programmer)). 
- selectinload(Programmer.languages), - ).one() + company = ( + sess.query(Company) + .filter(Company.id == 1) + .options( + selectinload( + Company.employees.of_type(Programmer) + ).selectinload(Programmer.languages) + ) + .one() + ) def go(): eq_(company.employees[0].languages[0].name, "Python") @@ -1434,14 +1856,19 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): def test_many_to_one(self): Company, Programmer, Manager, GolfSwing, Language = self.classes( - "Company", "Programmer", "Manager", "GolfSwing", "Language") + "Company", "Programmer", "Manager", "GolfSwing", "Language" + ) sess = Session() - company = sess.query(Company).filter( - Company.id == 2, - ).options( - selectinload(Company.employees.of_type(Manager)). - selectinload(Manager.golf_swing), - ).one() + company = ( + sess.query(Company) + .filter(Company.id == 2) + .options( + selectinload(Company.employees.of_type(Manager)).selectinload( + Manager.golf_swing + ) + ) + .one() + ) # NOTE: we *MUST* do a SQL compare on this one because the adaption # is very sensitive @@ -1452,14 +1879,22 @@ class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): def test_both(self): Company, Programmer, Manager, GolfSwing, Language = self.classes( - "Company", "Programmer", "Manager", "GolfSwing", "Language") + "Company", "Programmer", "Manager", "GolfSwing", "Language" + ) sess = Session() - rows = sess.query(Company).options( - selectinload(Company.employees.of_type(Manager)). - selectinload(Manager.golf_swing), - defaultload(Company.employees.of_type(Programmer)). - selectinload(Programmer.languages), - ).order_by(Company.id).all() + rows = ( + sess.query(Company) + .options( + selectinload(Company.employees.of_type(Manager)).selectinload( + Manager.golf_swing + ), + defaultload( + Company.employees.of_type(Programmer) + ).selectinload(Programmer.languages), + ) + .order_by(Company.id) + .all() + ) def go(): eq_(rows[0].employees[0].languages[0].name, "Python") @@ -1482,34 +1917,37 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(fixtures.ComparableEntity, Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) bs = relationship("B", order_by="B.id") class B(fixtures.ComparableEntity, Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - a_id = Column(ForeignKey('a.id')) + a_id = Column(ForeignKey("a.id")) @classmethod def insert_data(cls): - A, B = cls.classes('A', 'B') + A, B = cls.classes("A", "B") session = Session() - session.add_all([ - A(id=i, bs=[B(id=(i * 6) + j) for j in range(1, 6)]) - for i in range(1, 101) - ]) + session.add_all( + [ + A(id=i, bs=[B(id=(i * 6) + j) for j in range(1, 6)]) + for i in range(1, 101) + ] + ) session.commit() def test_odd_number_chunks(self): - A, B = self.classes('A', 'B') + A, B = self.classes("A", "B") session = Session() def go(): with mock.patch( - "sqlalchemy.orm.strategies.SelectInLoader._chunksize", 47): + "sqlalchemy.orm.strategies.SelectInLoader._chunksize", 47 + ): q = session.query(A).options(selectinload(A.bs)).order_by(A.id) for a in q: @@ -1518,35 +1956,32 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): self.assert_sql_execution( testing.db, go, - CompiledSQL( - "SELECT a.id AS a_id FROM a ORDER BY a.id", - {} - ), + CompiledSQL("SELECT a.id AS a_id FROM a ORDER BY a.id", {}), CompiledSQL( "SELECT b.a_id AS b_a_id, b.id AS b_id " "FROM b WHERE b.a_id IN " "([EXPANDING_primary_keys]) ORDER BY b.a_id, b.id", - {"primary_keys": 
list(range(1, 48))} + {"primary_keys": list(range(1, 48))}, ), CompiledSQL( "SELECT b.a_id AS b_a_id, b.id AS b_id " "FROM b WHERE b.a_id IN " "([EXPANDING_primary_keys]) ORDER BY b.a_id, b.id", - {"primary_keys": list(range(48, 95))} + {"primary_keys": list(range(48, 95))}, ), CompiledSQL( "SELECT b.a_id AS b_a_id, b.id AS b_id " "FROM b WHERE b.a_id IN " "([EXPANDING_primary_keys]) ORDER BY b.a_id, b.id", - {"primary_keys": list(range(95, 101))} - ) + {"primary_keys": list(range(95, 101))}, + ), ) @testing.requires.independent_cursors def test_yield_per(self): # the docs make a lot of guarantees about yield_per # so test that it works - A, B = self.classes('A', 'B') + A, B = self.classes("A", "B") import random @@ -1555,22 +1990,23 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): yield_per = random.randint(8, 105) offset = random.randint(0, 19) total_rows = 100 - offset - total_expected_statements = 1 + int(total_rows / yield_per) + \ - (1 if total_rows % yield_per else 0) + total_expected_statements = ( + 1 + + int(total_rows / yield_per) + + (1 if total_rows % yield_per else 0) + ) def go(): - for a in session.query(A).\ - yield_per(yield_per).\ - offset(offset).\ - options(selectinload(A.bs)): + for a in ( + session.query(A) + .yield_per(yield_per) + .offset(offset) + .options(selectinload(A.bs)) + ): # this part fails with joined eager loading # (if you enable joined eager w/ yield_per) - eq_( - a.bs, [ - B(id=(a.id * 6) + j) for j in range(1, 6) - ] - ) + eq_(a.bs, [B(id=(a.id * 6) + j) for j in range(1, 6)]) # this part fails with subquery eager loading # (if you enable subquery eager w/ yield_per) @@ -1580,39 +2016,68 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): @classmethod def define_tables(cls, metadata): - Table('companies', metadata, - Column('company_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('company_id', ForeignKey('companies.company_id')), - Column('name', String(50)), - Column('type', String(30))) - - Table('engineers', metadata, - Column('engineer_id', ForeignKey('people.person_id'), - primary_key=True), - Column('primary_language', String(50))) - - Table('machines', metadata, - Column('machine_id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('engineer_id', ForeignKey('engineers.engineer_id')), - Column('machine_type_id', - ForeignKey('machine_type.machine_type_id'))) - - Table('machine_type', metadata, - Column('machine_type_id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) + Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_id", ForeignKey("companies.company_id")), + Column("name", String(50)), + Column("type", String(30)), + ) + + Table( + "engineers", + metadata, + Column( + "engineer_id", ForeignKey("people.person_id"), primary_key=True + ), + Column("primary_language", String(50)), + ) + + Table( + "machines", + metadata, + Column( + "machine_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + 
Column("engineer_id", ForeignKey("engineers.engineer_id")), + Column( + "machine_type_id", ForeignKey("machine_type.machine_type_id") + ), + ) + + Table( + "machine_type", + metadata, + Column( + "machine_type_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) @classmethod def setup_mappers(cls): @@ -1622,24 +2087,36 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): machines = cls.tables.machines machine_type = cls.tables.machine_type - mapper(Company, companies, properties={ - 'employees': relationship(Person, order_by=people.c.person_id) - }) - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person', - with_polymorphic='*') - - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer', properties={ - 'machines': relationship(Machine, - order_by=machines.c.machine_id) - }) - - mapper(Machine, machines, properties={ - 'type': relationship(MachineType) - }) + mapper( + Company, + companies, + properties={ + "employees": relationship(Person, order_by=people.c.person_id) + }, + ) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + with_polymorphic="*", + ) + + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "machines": relationship( + Machine, order_by=machines.c.machine_id + ) + }, + ) + + mapper( + Machine, machines, properties={"type": relationship(MachineType)} + ) mapper(MachineType, machine_type) @classmethod @@ -1651,50 +2128,53 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): @classmethod def _fixture(cls): - mt1 = MachineType(name='mt1') - mt2 = MachineType(name='mt2') + mt1 = MachineType(name="mt1") + mt2 = MachineType(name="mt2") return Company( employees=[ Engineer( - name='e1', + name="e1", machines=[ - Machine(name='m1', type=mt1), - Machine(name='m2', type=mt2) - ] + Machine(name="m1", type=mt1), + Machine(name="m2", type=mt2), + ], ), Engineer( - name='e2', + name="e2", machines=[ - Machine(name='m3', type=mt1), - Machine(name='m4', type=mt1) - ] - ) - ]) + Machine(name="m3", type=mt1), + Machine(name="m4", type=mt1), + ], + ), + ] + ) def test_chained_selectin_subclass(self): s = Session() q = s.query(Company).options( - selectinload(Company.employees.of_type(Engineer)). - selectinload(Engineer.machines). 
- selectinload(Machine.type) + selectinload(Company.employees.of_type(Engineer)) + .selectinload(Engineer.machines) + .selectinload(Machine.type) ) def go(): - eq_( - q.all(), - [self._fixture()] - ) + eq_(q.all(), [self._fixture()]) + self.assert_sql_count(testing.db, go, 4) class SelfReferentialTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) + Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + ) def test_basic(self): nodes = self.tables.nodes @@ -1703,23 +2183,27 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, - lazy='selectin', - join_depth=3, order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="selectin", join_depth=3, order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - n2 = Node(data='n2') - n2.append(Node(data='n21')) - n2.children[0].append(Node(data='n211')) - n2.children[0].append(Node(data='n212')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) + n2 = Node(data="n2") + n2.append(Node(data="n21")) + n2.children[0].append(Node(data="n211")) + n2.children[0].append(Node(data="n212")) sess.add(n1) sess.add(n2) @@ -1727,24 +2211,45 @@ class SelfReferentialTest(fixtures.MappedTest): sess.expunge_all() def go(): - d = sess.query(Node).filter(Node.data.in_(['n1', 'n2'])).\ - order_by(Node.data).all() - eq_([Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), - Node(data='n2', children=[ - Node(data='n21', children=[ - Node(data='n211'), - Node(data='n212'), - ]) - ]) - ], d) + d = ( + sess.query(Node) + .filter(Node.data.in_(["n1", "n2"])) + .order_by(Node.data) + .all() + ) + eq_( + [ + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + Node( + data="n2", + children=[ + Node( + data="n21", + children=[ + Node(data="n211"), + Node(data="n212"), + ], + ) + ], + ), + ], + d, + ) + self.assert_sql_count(testing.db, go, 4) def test_lazy_fallback_doesnt_affect_eager(self): @@ -1754,20 +2259,25 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='selectin', join_depth=1, - order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="selectin", join_depth=1, order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - 
n1.append(Node(data='n13')) - n1.children[0].append(Node(data='n111')) - n1.children[0].append(Node(data='n112')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[0].append(Node(data="n111")) + n1.children[0].append(Node(data="n112")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() @@ -1776,19 +2286,16 @@ class SelfReferentialTest(fixtures.MappedTest): allnodes = sess.query(Node).order_by(Node.data).all() n11 = allnodes[1] - eq_(n11.data, 'n11') - eq_([ - Node(data='n111'), - Node(data='n112'), - ], list(n11.children)) + eq_(n11.data, "n11") + eq_([Node(data="n111"), Node(data="n112")], list(n11.children)) n12 = allnodes[4] - eq_(n12.data, 'n12') - eq_([ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ], list(n12.children)) + eq_(n12.data, "n12") + eq_( + [Node(data="n121"), Node(data="n122"), Node(data="n123")], + list(n12.children), + ) + self.assert_sql_count(testing.db, go, 2) def test_with_deferred(self): @@ -1798,40 +2305,55 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='selectin', join_depth=3, - order_by=nodes.c.id), - 'data': deferred(nodes.c.data) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="selectin", join_depth=3, order_by=nodes.c.id + ), + "data": deferred(nodes.c.data), + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) sess.add(n1) sess.flush() sess.expunge_all() def go(): eq_( - Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), sess.query(Node).order_by(Node.id).first(), ) + self.assert_sql_count(testing.db, go, 6) sess.expunge_all() def go(): - eq_(Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node).options(undefer('data')).order_by(Node.id) - .first()) + eq_( + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), + sess.query(Node) + .options(undefer("data")) + .order_by(Node.id) + .first(), + ) + self.assert_sql_count(testing.db, go, 5) sess.expunge_all() def go(): - eq_(Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node).options(undefer('data'), - undefer('children.data')).first()) + eq_( + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), + sess.query(Node) + .options(undefer("data"), undefer("children.data")) + .first(), + ) + self.assert_sql_count(testing.db, go, 3) def test_options(self): @@ -1841,33 +2363,50 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={"children": relationship(Node, order_by=nodes.c.id)}, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - 
n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter_by(data='n1').order_by(Node.id).\ - options(selectinload_all('children.children')).first() - eq_(Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), d) + d = ( + sess.query(Node) + .filter_by(data="n1") + .order_by(Node.id) + .options(selectinload_all("children.children")) + .first() + ) + eq_( + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + d, + ) + self.assert_sql_count(testing.db, go, 3) def test_no_depth(self): @@ -1879,48 +2418,62 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='selectin') - }) + mapper( + Node, + nodes, + properties={"children": relationship(Node, lazy="selectin")}, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - n2 = Node(data='n2') - n2.append(Node(data='n21')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) + n2 = Node(data="n2") + n2.append(Node(data="n21")) sess.add(n1) sess.add(n2) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter(Node.data.in_( - ['n1', 'n2'])).order_by(Node.data).all() - eq_([ - Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), - Node(data='n2', children=[ - Node(data='n21') - ]) - ], d) + d = ( + sess.query(Node) + .filter(Node.data.in_(["n1", "n2"])) + .order_by(Node.data) + .all() + ) + eq_( + [ + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + Node(data="n2", children=[Node(data="n21")]), + ], + d, + ) + self.assert_sql_count(testing.db, go, 4) class SelfRefInheritanceAliasedTest( - fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" @classmethod def setup_classes(cls): @@ -1933,7 +2486,8 @@ class SelfRefInheritanceAliasedTest( foo_id = Column(Integer, ForeignKey("foo.id")) foo = relationship( - lambda: Foo, foreign_keys=foo_id, remote_side=id) + lambda: Foo, foreign_keys=foo_id, remote_side=id + ) __mapper_args__ = { "polymorphic_on": type, @@ -1941,13 +2495,11 @@ class SelfRefInheritanceAliasedTest( } class Bar(Foo): - __mapper_args__ = { - "polymorphic_identity": "bar", - } + __mapper_args__ = {"polymorphic_identity": "bar"} @classmethod def insert_data(cls): - Foo, Bar = 
cls.classes('Foo', 'Bar') + Foo, Bar = cls.classes("Foo", "Bar") session = Session() target = Bar(id=1) @@ -1956,15 +2508,17 @@ class SelfRefInheritanceAliasedTest( session.commit() def test_twolevel_selectin_w_polymorphic(self): - Foo, Bar = self.classes('Foo', 'Bar') + Foo, Bar = self.classes("Foo", "Bar") r = with_polymorphic(Foo, "*", aliased=True) attr1 = Foo.foo.of_type(r) attr2 = r.foo s = Session() - q = s.query(Foo).filter(Foo.id == 2).options( - selectinload(attr1).selectinload(attr2), + q = ( + s.query(Foo) + .filter(Foo.id == 2) + .options(selectinload(attr1).selectinload(attr2)) ) self.assert_sql_execution( testing.db, @@ -1972,7 +2526,7 @@ class SelfRefInheritanceAliasedTest( CompiledSQL( "SELECT foo.id AS foo_id_1, foo.type AS foo_type, " "foo.foo_id AS foo_foo_id FROM foo WHERE foo.id = :id_1", - [{'id_1': 2}] + [{"id_1": 2}], ), CompiledSQL( "SELECT foo_1.id AS foo_1_id, foo_2.id AS foo_2_id, " @@ -1981,17 +2535,16 @@ class SelfRefInheritanceAliasedTest( "ON foo_2.id = foo_1.foo_id " "WHERE foo_1.id " "IN ([EXPANDING_primary_keys]) ORDER BY foo_1.id", - {'primary_keys': [2]} + {"primary_keys": [2]}, ), CompiledSQL( - "SELECT foo_1.id AS foo_1_id, foo_2.id AS foo_2_id, " "foo_2.type AS foo_2_type, foo_2.foo_id AS foo_2_foo_id " "FROM foo AS foo_1 JOIN foo AS foo_2 " "ON foo_2.id = foo_1.foo_id " "WHERE foo_1.id IN ([EXPANDING_primary_keys]) " "ORDER BY foo_1.id", - {'primary_keys': [3]} + {"primary_keys": [3]}, ), ) @@ -2002,28 +2555,28 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) - a2_id = Column(ForeignKey('a2.id')) + b_id = Column(ForeignKey("b.id")) + a2_id = Column(ForeignKey("a2.id")) a2 = relationship("A2") b = relationship("B") class A2(Base): - __tablename__ = 'a2' + __tablename__ = "a2" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) b = relationship("B") class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - c1_m2o_id = Column(ForeignKey('c1_m2o.id')) - c2_m2o_id = Column(ForeignKey('c2_m2o.id')) + c1_m2o_id = Column(ForeignKey("c1_m2o.id")) + c2_m2o_id = Column(ForeignKey("c2_m2o.id")) c1_o2m = relationship("C1o2m") c2_o2m = relationship("C2o2m") @@ -2031,49 +2584,44 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): c2_m2o = relationship("C2m2o") class C1o2m(Base): - __tablename__ = 'c1_o2m' + __tablename__ = "c1_o2m" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) class C2o2m(Base): - __tablename__ = 'c2_o2m' + __tablename__ = "c2_o2m" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) class C1m2o(Base): - __tablename__ = 'c1_m2o' + __tablename__ = "c1_m2o" id = Column(Integer, primary_key=True) class C2m2o(Base): - __tablename__ = 'c2_m2o' + __tablename__ = "c2_m2o" id = Column(Integer, primary_key=True) @classmethod def insert_data(cls): A, A2, B, C1o2m, C2o2m, C1m2o, C2m2o = cls.classes( - 'A', 'A2', 'B', 'C1o2m', 'C2o2m', 'C1m2o', 'C2m2o' + "A", "A2", "B", "C1o2m", "C2o2m", "C1m2o", "C2m2o" ) s = Session() b = B( - c1_o2m=[C1o2m()], - c2_o2m=[C2o2m()], - c1_m2o=C1m2o(), - c2_m2o=C2m2o(), + c1_o2m=[C1o2m()], c2_o2m=[C2o2m()], c1_m2o=C1m2o(), c2_m2o=C2m2o() ) s.add(A(b=b, a2=A2(b=b))) s.commit() def test_o2m(self): - A, A2, B, C1o2m, 
C2o2m = self.classes( - 'A', 'A2', 'B', 'C1o2m', 'C2o2m' - ) + A, A2, B, C1o2m, C2o2m = self.classes("A", "A2", "B", "C1o2m", "C2o2m") s = Session() @@ -2085,18 +2633,16 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): q = s.query(A).options( joinedload(A.b).selectinload(B.c2_o2m), - joinedload(A.a2).joinedload(A2.b).selectinload(B.c1_o2m) + joinedload(A.a2).joinedload(A2.b).selectinload(B.c1_o2m), ) a1 = q.all()[0] - is_true('c1_o2m' in a1.b.__dict__) - is_true('c2_o2m' in a1.b.__dict__) + is_true("c1_o2m" in a1.b.__dict__) + is_true("c2_o2m" in a1.b.__dict__) def test_m2o(self): - A, A2, B, C1m2o, C2m2o = self.classes( - 'A', 'A2', 'B', 'C1m2o', 'C2m2o' - ) + A, A2, B, C1m2o, C2m2o = self.classes("A", "A2", "B", "C1m2o", "C2m2o") s = Session() @@ -2108,39 +2654,38 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): q = s.query(A).options( joinedload(A.b).selectinload(B.c2_m2o), - joinedload(A.a2).joinedload(A2.b).selectinload(B.c1_m2o) + joinedload(A.a2).joinedload(A2.b).selectinload(B.c1_m2o), ) a1 = q.all()[0] - is_true('c1_m2o' in a1.b.__dict__) - is_true('c2_m2o' in a1.b.__dict__) + is_true("c1_m2o" in a1.b.__dict__) + is_true("c2_m2o" in a1.b.__dict__) class SingleInhSubclassTest( - fixtures.DeclarativeMappedTest, - testing.AssertsExecutionResults): - + fixtures.DeclarativeMappedTest, testing.AssertsExecutionResults +): @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class User(Base): - __tablename__ = 'user' + __tablename__ = "user" id = Column(Integer, primary_key=True) type = Column(String(10)) - __mapper_args__ = {'polymorphic_on': type} + __mapper_args__ = {"polymorphic_on": type} class EmployerUser(User): - roles = relationship('Role', lazy='selectin') - __mapper_args__ = {'polymorphic_identity': 'employer'} + roles = relationship("Role", lazy="selectin") + __mapper_args__ = {"polymorphic_identity": "employer"} class Role(Base): - __tablename__ = 'role' + __tablename__ = "role" id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey('user.id')) + user_id = Column(Integer, ForeignKey("user.id")) @classmethod def insert_data(cls): @@ -2162,12 +2707,12 @@ class SingleInhSubclassTest( CompiledSQL( 'SELECT "user".id AS user_id, "user".type AS user_type ' 'FROM "user" WHERE "user".type IN (:type_1)', - {'type_1': 'employer'} + {"type_1": "employer"}, ), CompiledSQL( "SELECT role.user_id AS role_user_id, role.id AS role_id " "FROM role WHERE role.user_id " "IN ([EXPANDING_primary_keys]) ORDER BY role.user_id", - {'primary_keys': [1]} + {"primary_keys": [1]}, ), ) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index b141b9965a..cf908b988a 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1,18 +1,37 @@ -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, assertions, is_true, is_ +from sqlalchemy.testing import ( + eq_, + assert_raises, + assert_raises_message, + assertions, + is_true, + is_, +) from sqlalchemy.testing.util import gc_collect from sqlalchemy.testing import pickleable from sqlalchemy.util import pickle import inspect -from sqlalchemy.orm import create_session, sessionmaker, attributes, \ - make_transient, make_transient_to_detached, Session +from sqlalchemy.orm import ( + create_session, + sessionmaker, + attributes, + make_transient, + make_transient_to_detached, + Session, +) import sqlalchemy as sa from sqlalchemy.testing import engines, config from sqlalchemy import testing from sqlalchemy import Integer, String, Sequence from 
sqlalchemy.testing.schema import Table, Column -from sqlalchemy.orm import mapper, relationship, backref, joinedload, \ - exc as orm_exc, object_session, was_deleted +from sqlalchemy.orm import ( + mapper, + relationship, + backref, + joinedload, + exc as orm_exc, + object_session, + was_deleted, +) from sqlalchemy.util import pypy from sqlalchemy.testing import fixtures from test.orm import _fixtures @@ -20,6 +39,7 @@ from sqlalchemy import event, ForeignKey from sqlalchemy.util.compat import inspect_getargspec from sqlalchemy.testing import mock + class ExecutionTest(_fixtures.FixtureTest): run_inserts = None __backend__ = True @@ -40,34 +60,37 @@ class ExecutionTest(_fixtures.FixtureTest): users = self.tables.users sess = create_session(bind=self.metadata.bind) - users.insert().execute(id=7, name='jack') + users.insert().execute(id=7, name="jack") # use :bindparam style - eq_(sess.execute("select * from users where id=:id", - {'id': 7}).fetchall(), - [(7, 'jack')]) + eq_( + sess.execute( + "select * from users where id=:id", {"id": 7} + ).fetchall(), + [(7, "jack")], + ) # use :bindparam style - eq_(sess.scalar("select id from users where id=:id", {'id': 7}), 7) + eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7) def test_parameter_execute(self): users = self.tables.users sess = Session(bind=testing.db) - sess.execute(users.insert(), [ - {"id": 7, "name": "u7"}, - {"id": 8, "name": "u8"} - ]) + sess.execute( + users.insert(), [{"id": 7, "name": "u7"}, {"id": 8, "name": "u8"}] + ) sess.execute(users.insert(), {"id": 9, "name": "u9"}) eq_( - sess.execute(sa.select([users.c.id]). - order_by(users.c.id)).fetchall(), - [(7, ), (8, ), (9, )] + sess.execute( + sa.select([users.c.id]).order_by(users.c.id) + ).fetchall(), + [(7,), (8,), (9,)], ) class TransScopingTest(_fixtures.FixtureTest): run_inserts = None - __prefer_requires__ = "independent_connections", + __prefer_requires__ = ("independent_connections",) def test_no_close_on_flush(self): """Flush() doesn't close a connection the session didn't open""" @@ -79,7 +102,7 @@ class TransScopingTest(_fixtures.FixtureTest): mapper(User, users) s = create_session(bind=c) - s.add(User(name='first')) + s.add(User(name="first")) s.flush() c.execute("select * from users") @@ -93,7 +116,7 @@ class TransScopingTest(_fixtures.FixtureTest): mapper(User, users) s = create_session(bind=c) - s.add(User(name='first')) + s.add(User(name="first")) s.flush() c.execute("select * from users") s.close() @@ -109,7 +132,7 @@ class TransScopingTest(_fixtures.FixtureTest): conn2 = testing.db.connect() sess = create_session(autocommit=False, bind=conn1) - u = User(name='x') + u = User(name="x") sess.add(u) sess.flush() assert conn1.execute("select count(1) from users").scalar() == 1 @@ -117,8 +140,10 @@ class TransScopingTest(_fixtures.FixtureTest): sess.commit() assert conn1.execute("select count(1) from users").scalar() == 1 - assert testing.db.connect().execute('select count(1) from users') \ - .scalar() == 1 + assert ( + testing.db.connect().execute("select count(1) from users").scalar() + == 1 + ) sess.close() @@ -128,24 +153,16 @@ class SessionUtilTest(_fixtures.FixtureTest): def test_object_session_raises(self): User = self.classes.User - assert_raises( - orm_exc.UnmappedInstanceError, - object_session, - object() - ) + assert_raises(orm_exc.UnmappedInstanceError, object_session, object()) - assert_raises( - orm_exc.UnmappedInstanceError, - object_session, - User() - ) + assert_raises(orm_exc.UnmappedInstanceError, object_session, User()) def 
test_make_transient(self): users, User = self.tables.users, self.classes.User mapper(User, users) sess = create_session() - sess.add(User(name='test')) + sess.add(User(name="test")) sess.flush() u1 = sess.query(User).first() @@ -173,7 +190,7 @@ class SessionUtilTest(_fixtures.FixtureTest): sess.close() - u1.name = 'test2' + u1.name = "test2" sess.add(u1) sess.flush() assert u1 in sess @@ -193,7 +210,7 @@ class SessionUtilTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - u1 = User(name='test') + u1 = User(name="test") sess.add(u1) sess.commit() @@ -208,14 +225,14 @@ class SessionUtilTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - u1 = User(id=1, name='test') + u1 = User(id=1, name="test") sess.add(u1) sess.commit() sess.close() u2 = User(id=1) make_transient_to_detached(u2) - assert 'id' in u2.__dict__ + assert "id" in u2.__dict__ sess.add(u2) eq_(u2.name, "test") @@ -224,12 +241,13 @@ class SessionUtilTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - u1 = User(id=1, name='test') + u1 = User(id=1, name="test") sess.add(u1) assert_raises_message( sa.exc.InvalidRequestError, "Given object must be transient", - make_transient_to_detached, u1 + make_transient_to_detached, + u1, ) def test_make_transient_to_detached_no_key_allowed(self): @@ -237,21 +255,22 @@ class SessionUtilTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - u1 = User(id=1, name='test') + u1 = User(id=1, name="test") sess.add(u1) sess.commit() sess.expunge(u1) assert_raises_message( sa.exc.InvalidRequestError, "Given object must be transient", - make_transient_to_detached, u1 + make_transient_to_detached, + u1, ) class SessionStateTest(_fixtures.FixtureTest): run_inserts = None - __prefer_requires__ = ('independent_connections', ) + __prefer_requires__ = ("independent_connections",) def test_info(self): s = Session() @@ -271,8 +290,8 @@ class SessionStateTest(_fixtures.FixtureTest): eq_(s3.info, {"global": True, "s1": 5}) maker2 = sessionmaker() - s4 = maker2(info={'s4': 8}) - eq_(s4.info, {'s4': 8}) + s4 = maker2(info={"s4": 8}) + eq_(s4.info, {"s4": 8}) @testing.requires.independent_connections @engines.close_open_connections @@ -286,9 +305,9 @@ class SessionStateTest(_fixtures.FixtureTest): sess = create_session(bind=conn1, autocommit=False, autoflush=True) u = User() - u.name = 'ed' + u.name = "ed" sess.add(u) - u2 = sess.query(User).filter_by(name='ed').one() + u2 = sess.query(User).filter_by(name="ed").one() assert u2 is u eq_(conn1.execute("select count(1) from users").scalar(), 1) eq_(conn2.execute("select count(1) from users").scalar(), 0) @@ -304,11 +323,12 @@ class SessionStateTest(_fixtures.FixtureTest): sess = Session() u = User() - u.name = 'ed' + u.name = "ed" sess.add(u) def go(obj): assert u not in sess.query(User).all() + testing.run_as_contextmanager(sess.no_autoflush, go) assert u in sess.new assert u in sess.query(User).all() @@ -321,7 +341,7 @@ class SessionStateTest(_fixtures.FixtureTest): ZeroDivisionError, testing.run_as_contextmanager, sess.no_autoflush, - lambda obj: 1 / 0 + lambda obj: 1 / 0, ) is_true(sess.autoflush) @@ -333,7 +353,7 @@ class SessionStateTest(_fixtures.FixtureTest): sess = sessionmaker()() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -362,7 +382,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -383,8 +403,9 @@ class SessionStateTest(_fixtures.FixtureTest): # commit 
proceeds w/ warning with assertions.expect_warnings( - "DELETE statement on table 'users' " - r"expected to delete 1 row\(s\); 0 were matched."): + "DELETE statement on table 'users' " + r"expected to delete 1 row\(s\); 0 were matched." + ): sess.commit() @testing.requires.independent_connections @@ -396,19 +417,35 @@ class SessionStateTest(_fixtures.FixtureTest): try: sess = create_session(autocommit=False, autoflush=True) u = User() - u.name = 'ed' + u.name = "ed" sess.add(u) - u2 = sess.query(User).filter_by(name='ed').one() + u2 = sess.query(User).filter_by(name="ed").one() assert u2 is u - assert sess.execute('select count(1) from users', - mapper=User).scalar() == 1 - assert testing.db.connect().execute('select count(1) from users') \ - .scalar() == 0 + assert ( + sess.execute( + "select count(1) from users", mapper=User + ).scalar() + == 1 + ) + assert ( + testing.db.connect() + .execute("select count(1) from users") + .scalar() + == 0 + ) sess.commit() - assert sess.execute('select count(1) from users', - mapper=User).scalar() == 1 - assert testing.db.connect().execute('select count(1) from users') \ - .scalar() == 1 + assert ( + sess.execute( + "select count(1) from users", mapper=User + ).scalar() + == 1 + ) + assert ( + testing.db.connect() + .execute("select count(1) from users") + .scalar() + == 1 + ) sess.close() except Exception: sess.rollback() @@ -420,15 +457,16 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) conn1 = testing.db.connect() - sess = create_session(bind=conn1, autocommit=False, - autoflush=True) + sess = create_session(bind=conn1, autocommit=False, autoflush=True) u = User() - u.name = 'ed' + u.name = "ed" sess.add(u) sess.commit() - assert conn1.execute('select count(1) from users').scalar() == 1 - assert testing.db.connect().execute('select count(1) from users') \ - .scalar() == 1 + assert conn1.execute("select count(1) from users").scalar() == 1 + assert ( + testing.db.connect().execute("select count(1) from users").scalar() + == 1 + ) sess.commit() def test_autocommit_doesnt_raise_on_pending(self): @@ -437,7 +475,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) session = create_session(autocommit=True) - session.add(User(name='ed')) + session.add(User(name="ed")) session.begin() session.flush() @@ -453,21 +491,28 @@ class SessionStateTest(_fixtures.FixtureTest): @engines.close_open_connections def test_add_delete(self): - User, Address, addresses, users = (self.classes.User, - self.classes.Address, - self.tables.addresses, - self.tables.users) + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) s = create_session() - mapper(User, users, properties={ - 'addresses': relationship(Address, cascade="all, delete") - }) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, cascade="all, delete") + }, + ) mapper(Address, addresses) - user = User(name='u1') + user = User(name="u1") - assert_raises_message(sa.exc.InvalidRequestError, - 'is not persisted', s.delete, user) + assert_raises_message( + sa.exc.InvalidRequestError, "is not persisted", s.delete, user + ) s.add(user) s.flush() @@ -484,7 +529,7 @@ class SessionStateTest(_fixtures.FixtureTest): s.expunge_all() assert s.query(User).count() == 1 user = s.query(User).one() - assert user.name == 'fred' + assert user.name == "fred" # ensure its not dirty if no changes occur s.expunge_all() @@ -494,14 +539,20 @@ class SessionStateTest(_fixtures.FixtureTest): assert 
user not in s.dirty s2 = create_session() - assert_raises_message(sa.exc.InvalidRequestError, - 'is already attached to session', - s2.delete, user) + assert_raises_message( + sa.exc.InvalidRequestError, + "is already attached to session", + s2.delete, + user, + ) u2 = s2.query(User).get(user.id) s2.expunge(u2) assert_raises_message( sa.exc.InvalidRequestError, - 'another instance .* is already present', s.delete, u2) + "another instance .* is already present", + s.delete, + u2, + ) s.expire(user) s.expunge(user) assert user not in s @@ -520,14 +571,15 @@ class SessionStateTest(_fixtures.FixtureTest): s1 = Session() s2 = Session() - u1 = User(id=1, name='u1') + u1 = User(id=1, name="u1") make_transient_to_detached(u1) # shorthand for actually persisting it s1.add(u1) assert_raises_message( sa.exc.InvalidRequestError, "Object '' is already attached to session", - s2.add, u1 + s2.add, + u1, ) assert u1 not in s2 assert not s2.identity_map.keys() @@ -537,10 +589,7 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - for s in ( - create_session(), - create_session(weak_identity_map=False), - ): + for s in (create_session(), create_session(weak_identity_map=False)): users.delete().execute() u1 = User(name="ed") s.add(u1) @@ -556,7 +605,7 @@ class SessionStateTest(_fixtures.FixtureTest): "with key .*? is already " "present in this session.", s.identity_map.add, - sa.orm.attributes.instance_state(u2) + sa.orm.attributes.instance_state(u2), ) def test_pickled_update(self): @@ -565,11 +614,14 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) sess1 = create_session() sess2 = create_session() - u1 = User(name='u1') + u1 = User(name="u1") sess1.add(u1) - assert_raises_message(sa.exc.InvalidRequestError, - 'already attached to session', sess2.add, - u1) + assert_raises_message( + sa.exc.InvalidRequestError, + "already attached to session", + sess2.add, + u1, + ) u2 = pickle.loads(pickle.dumps(u1)) sess2.add(u2) @@ -580,7 +632,7 @@ class SessionStateTest(_fixtures.FixtureTest): Session = sessionmaker() sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.flush() assert u1.id is not None @@ -599,7 +651,8 @@ class SessionStateTest(_fixtures.FixtureTest): "Can't attach instance ; another instance " "with key .*? is already " "present in this session.", - sess.add, u1 + sess.add, + u1, ) sess.expunge(u2) @@ -633,6 +686,7 @@ class SessionStateTest(_fixtures.FixtureTest): def __init__(self): sess.add(self) Foo.__init__(self) + mapper(Foo, users) mapper(Bar, users) @@ -647,7 +701,7 @@ class SessionStateTest(_fixtures.FixtureTest): sess = Session() - sess.add_all([User(name='u1'), User(name='u2'), User(name='u3')]) + sess.add_all([User(name="u1"), User(name="u2"), User(name="u3")]) sess.commit() # TODO: what are we testing here ? 
that iteritems() can @@ -685,7 +739,7 @@ class SessionStateTest(_fixtures.FixtureTest): sa.exc.SAWarning, "Attribute history events accumulated on 1 previously " "clean instances", - s.commit + s.commit, ) def test_extra_dirty_state_post_flush_state(self): @@ -699,6 +753,7 @@ class SessionStateTest(_fixtures.FixtureTest): @testing.emits_warning("Attribute") def go(): s.commit() + go() eq_(canary, [False]) @@ -707,7 +762,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - sess.add(User(name='x')) + sess.add(User(name="x")) sess.commit() u1 = sess.query(User).first() @@ -728,7 +783,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - u1 = User(name='x') + u1 = User(name="x") sess.add(u1) sess.flush() @@ -747,7 +802,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) sess = Session() - sess.add(User(name='x')) + sess.add(User(name="x")) sess.commit() u1 = sess.query(User).first() @@ -772,17 +827,22 @@ class SessionStateTest(_fixtures.FixtureTest): class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): run_inserts = None - run_deletes = 'each' + run_deletes = "each" @classmethod def setup_mappers(cls): - users, Address, addresses, User = (cls.tables.users, - cls.classes.Address, - cls.tables.addresses, - cls.classes.User) + users, Address, addresses, User = ( + cls.tables.users, + cls.classes.Address, + cls.tables.addresses, + cls.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user")}) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) def test_deferred_expression_unflushed(self): @@ -794,16 +854,18 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) - eq_(sess.query(Address).filter(Address.user == u).one(), - Address(email_address='foo')) + eq_( + sess.query(Address).filter(Address.user == u).one(), + Address(email_address="foo"), + ) def test_deferred_expression_obj_was_gced(self): User, Address = self.classes("User", "Address") sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) sess.commit() @@ -812,7 +874,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): q = sess.query(Address).filter(Address.user == u) del u gc_collect() - eq_(q.one(), Address(email_address='foo')) + eq_(q.one(), Address(email_address="foo")) def test_deferred_expression_favors_immediate(self): """Test that a deferred expression will return an immediate value @@ -823,26 +885,26 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) sess.commit() q = sess.query(Address).filter(Address.user == u) sess.expire(u) sess.expunge(u) - eq_(q.one(), Address(email_address='foo')) + eq_(q.one(), Address(email_address="foo")) def test_deferred_expression_obj_was_never_flushed(self): User, Address = self.classes("User", "Address") sess = 
create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) assert_raises_message( sa.exc.InvalidRequestError, "Can't resolve value for column users.id on object " ".User.*.; no value has been set for this column", - (Address.user == u).left.callable + (Address.user == u).left.callable, ) q = sess.query(Address).filter(Address.user == u) @@ -850,13 +912,13 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): sa.exc.StatementError, "Can't resolve value for column users.id on object " ".User.*.; no value has been set for this column", - q.one + q.one, ) def test_deferred_expression_transient_but_manually_set(self): User, Address = self.classes("User", "Address") - u = User(id=5, name='ed', addresses=[Address(email_address='foo')]) + u = User(id=5, name="ed", addresses=[Address(email_address="foo")]) expr = Address.user == u eq_(expr.left.callable(), 5) @@ -865,7 +927,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): User, Address = self.classes("User", "Address") sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) q = sess.query(Address).filter(Address.user == u) @@ -873,13 +935,13 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): sess.flush() sess.expunge(u) - eq_(q.one(), Address(email_address='foo')) + eq_(q.one(), Address(email_address="foo")) def test_deferred_expression_unflushed_obj_became_detached_expired(self): User, Address = self.classes("User", "Address") sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) q = sess.query(Address).filter(Address.user == u) @@ -888,38 +950,40 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): sess.expire(u) sess.expunge(u) - eq_(q.one(), Address(email_address='foo')) + eq_(q.one(), Address(email_address="foo")) def test_deferred_expr_unflushed_obj_became_detached_expired_by_key(self): User, Address = self.classes("User", "Address") sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) + u = User(name="ed", addresses=[Address(email_address="foo")]) q = sess.query(Address).filter(Address.user == u) sess.add(u) sess.flush() - sess.expire(u, ['id']) + sess.expire(u, ["id"]) sess.expunge(u) - eq_(q.one(), Address(email_address='foo')) + eq_(q.one(), Address(email_address="foo")) def test_deferred_expression_expired_obj_became_detached_expired(self): User, Address = self.classes("User", "Address") sess = create_session( - autoflush=True, autocommit=False, expire_on_commit=True) - u = User(name='ed', addresses=[Address(email_address='foo')]) + autoflush=True, autocommit=False, expire_on_commit=True + ) + u = User(name="ed", addresses=[Address(email_address="foo")]) sess.add(u) sess.commit() - assert 'id' not in u.__dict__ # it's expired + assert "id" not in u.__dict__ # it's expired # should not emit SQL def go(): Address.user == u + self.assert_sql_count(testing.db, go, 0) # create the expression here, but note we weren't tracking 'id' @@ -931,7 +995,7 @@ class DeferredRelationshipExpressionTest(_fixtures.FixtureTest): sa.exc.StatementError, "Can't resolve value for column users.id on object " ".User.*.; the object is detached and the value was 
expired ", - q.one + q.one, ) @@ -939,42 +1003,52 @@ class SessionStateWFixtureTest(_fixtures.FixtureTest): __backend__ = True def test_autoflush_rollback(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address)}) + mapper(User, users, properties={"addresses": relationship(Address)}) sess = create_session(autocommit=False, autoflush=True) u = sess.query(User).get(8) - newad = Address(email_address='a new address') + newad = Address(email_address="a new address") u.addresses.append(newad) - u.name = 'some new name' - assert u.name == 'some new name' + u.name = "some new name" + assert u.name == "some new name" assert len(u.addresses) == 4 assert newad in u.addresses sess.rollback() - assert u.name == 'ed' + assert u.name == "ed" assert len(u.addresses) == 3 assert newad not in u.addresses # pending objects don't get expired - assert newad.email_address == 'a new address' + assert newad.email_address == "a new address" def test_expunge_cascade(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, - backref=backref("user", cascade="all"), - cascade="all")}) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + backref=backref("user", cascade="all"), + cascade="all", + ) + }, + ) session = create_session() u = session.query(User).filter_by(id=7).one() @@ -998,6 +1072,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): transient/detached. 
""" + run_inserts = None def setup(self): @@ -1027,14 +1102,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def test_transient(self): User = self.classes.User u1 = User() - u1.name = 'ed' + u1.name = "ed" self._assert_no_cycle(u1) self._assert_modified(u1) def test_transient_to_pending(self): User = self.classes.User u1 = User() - u1.name = 'ed' + u1.name = "ed" self._assert_modified(u1) self._assert_no_cycle(u1) sess = Session() @@ -1046,14 +1121,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def test_dirty_persistent_to_detached_via_expunge(self): sess, u1 = self._persistent_fixture() - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_cycle(u1) sess.expunge(u1) self._assert_no_cycle(u1) def test_dirty_persistent_to_detached_via_close(self): sess, u1 = self._persistent_fixture() - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_cycle(u1) sess.close() self._assert_no_cycle(u1) @@ -1063,14 +1138,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): self._assert_no_cycle(u1) self._assert_not_modified(u1) sess.close() - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_modified(u1) self._assert_no_cycle(u1) def test_detached_to_dirty_deleted(self): sess, u1 = self._persistent_fixture() sess.expunge(u1) - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_no_cycle(u1) sess.delete(u1) self._assert_cycle(u1) @@ -1078,7 +1153,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def test_detached_to_dirty_persistent(self): sess, u1 = self._persistent_fixture() sess.expunge(u1) - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_modified(u1) self._assert_no_cycle(u1) sess.add(u1) @@ -1104,7 +1179,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def test_move_persistent_dirty(self): sess, u1 = self._persistent_fixture() - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_cycle(u1) self._assert_modified(u1) sess.close() @@ -1117,7 +1192,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): @testing.requires.predictable_gc def test_move_gc_session_persistent_dirty(self): sess, u1 = self._persistent_fixture() - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_cycle(u1) self._assert_modified(u1) del sess @@ -1130,7 +1205,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def test_persistent_dirty_to_expired(self): sess, u1 = self._persistent_fixture() - u1.name = 'edchanged' + u1.name = "edchanged" self._assert_cycle(u1) self._assert_modified(u1) sess.expire(u1) @@ -1151,7 +1226,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): s = create_session() mapper(User, users) - s.add(User(name='ed')) + s.add(User(name="ed")) s.flush() assert not s.dirty @@ -1161,7 +1236,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): assert len(s.identity_map) == 0 user = s.query(User).one() - user.name = 'fred' + user.name = "fred" del user gc_collect() assert len(s.identity_map) == 1 @@ -1173,7 +1248,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): assert not s.identity_map user = s.query(User).one() - assert user.name == 'fred' + assert user.name == "fred" assert s.identity_map @testing.requires.predictable_gc @@ -1183,12 +1258,12 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): s = create_session() mapper(User, users) - s.add(User(name='ed')) + s.add(User(name="ed")) s.flush() assert not s.dirty user = s.query(User).one() - user.name = 'fred' + user.name = "fred" s.expunge(user) u2 = pickle.loads(pickle.dumps(user)) @@ 
-1210,15 +1285,19 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): @testing.requires.predictable_gc def test_weakref_with_cycles_o2m(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) s = sessionmaker()() - mapper(User, users, properties={ - "addresses": relationship(Address, backref="user") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) s.add(User(name="ed", addresses=[Address(email_address="ed1")])) s.commit() @@ -1232,7 +1311,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): assert len(s.identity_map) == 0 user = s.query(User).options(joinedload(User.addresses)).one() - user.addresses[0].email_address = 'ed2' + user.addresses[0].email_address = "ed2" user.addresses[0].user # lazyload del user gc_collect() @@ -1244,15 +1323,21 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): @testing.requires.predictable_gc def test_weakref_with_cycles_o2o(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) s = sessionmaker()() - mapper(User, users, properties={ - "address": relationship(Address, backref="user", uselist=False) - }) + mapper( + User, + users, + properties={ + "address": relationship(Address, backref="user", uselist=False) + }, + ) mapper(Address, addresses) s.add(User(name="ed", address=Address(email_address="ed1"))) s.commit() @@ -1266,7 +1351,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): assert len(s.identity_map) == 0 user = s.query(User).options(joinedload(User.address)).one() - user.address.email_address = 'ed2' + user.address.email_address = "ed2" user.address.user # lazyload del user @@ -1284,7 +1369,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -1294,7 +1379,8 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): assert_raises_message( sa.exc.InvalidRequestError, r".*is already attached to session", - s2.add, u1 + s2.add, + u1, ) # garbage collect sess @@ -1314,7 +1400,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): sess = Session() - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -1351,10 +1437,10 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): @event.listens_for(session, "detached_to_persistent") @event.listens_for(session, "loaded_as_persistent") def strong_ref_object(sess, instance): - if 'refs' not in sess.info: - sess.info['refs'] = refs = set() + if "refs" not in sess.info: + sess.info["refs"] = refs = set() else: - refs = sess.info['refs'] + refs = sess.info["refs"] refs.add(instance) @@ -1362,17 +1448,18 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): @event.listens_for(session, "persistent_to_deleted") @event.listens_for(session, "persistent_to_transient") def deref_object(sess, instance): - sess.info['refs'].discard(instance) + sess.info["refs"].discard(instance) def prune(): - if 'refs' not in session.info: + if "refs" not in session.info: return 0 sess_size = len(session.identity_map) - session.info['refs'].clear() + session.info["refs"].clear() gc_collect() - session.info['refs'] = set( - s.obj() 
for s in session.identity_map.all_states()) + session.info["refs"] = set( + s.obj() for s in session.identity_map.all_states() + ) return sess_size - len(session.identity_map) return session, prune @@ -1392,7 +1479,7 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): mapper(User, users) # save user - s.add(User(name='u1')) + s.add(User(name="u1")) s.flush() user = s.query(User).one() user = None @@ -1402,10 +1489,10 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): user = s.query(User).one() assert not s.identity_map._modified - user.name = 'u2' + user.name = "u2" assert s.identity_map._modified s.flush() - eq_(users.select().execute().fetchall(), [(user.id, 'u2')]) + eq_(users.select().execute().fetchall(), [(user.id, "u2")]) @testing.uses_deprecated() def test_prune_imap(self): @@ -1415,7 +1502,7 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): self._test_prune(self._event_fixture) @testing.fails_if(lambda: pypy, "pypy has a real GC") - @testing.fails_on('+zxjdbc', 'http://www.sqlalchemy.org/trac/ticket/1473') + @testing.fails_on("+zxjdbc", "http://www.sqlalchemy.org/trac/ticket/1473") def _test_prune(self, fixture): s, prune = fixture() @@ -1423,7 +1510,7 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): mapper(User, users) - for o in [User(name='u%s' % x) for x in range(10)]: + for o in [User(name="u%s" % x) for x in range(10)]: s.add(o) # o is still live after this loop... @@ -1443,7 +1530,7 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): u = s.query(User).get(id) eq_(prune(), 0) self.assert_(len(s.identity_map) == 1) - u.name = 'squiznart' + u.name = "squiznart" del u eq_(prune(), 0) self.assert_(len(s.identity_map) == 1) @@ -1451,7 +1538,7 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): eq_(prune(), 1) self.assert_(len(s.identity_map) == 0) - s.add(User(name='x')) + s.add(User(name="x")) eq_(prune(), 0) self.assert_(len(s.identity_map) == 0) s.flush() @@ -1477,7 +1564,7 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): sess = Session(weak_identity_map=False) - u1 = User(name='u1') + u1 = User(name="u1") sess.add(u1) sess.commit() @@ -1507,9 +1594,7 @@ class IsModifiedTest(_fixtures.FixtureTest): def _default_mapping_fixture(self): User, Address = self.classes.User, self.classes.Address users, addresses = self.tables.users, self.tables.addresses - mapper(User, users, properties={ - "addresses": relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) return User, Address @@ -1519,7 +1604,7 @@ class IsModifiedTest(_fixtures.FixtureTest): s = create_session() # save user - u = User(name='fred') + u = User(name="fred") s.add(u) s.flush() s.expunge_all() @@ -1527,10 +1612,10 @@ class IsModifiedTest(_fixtures.FixtureTest): user = s.query(User).one() assert user not in s.dirty assert not s.is_modified(user) - user.name = 'fred' + user.name = "fred" assert user in s.dirty assert not s.is_modified(user) - user.name = 'ed' + user.name = "ed" assert user in s.dirty assert s.is_modified(user) s.flush() @@ -1550,7 +1635,7 @@ class IsModifiedTest(_fixtures.FixtureTest): User, Address = self._default_mapping_fixture() s = Session() - u = User(name='fred', addresses=[Address(email_address='foo')]) + u = User(name="fred", addresses=[Address(email_address="foo")]) s.add(u) s.commit() @@ -1558,27 +1643,24 @@ class IsModifiedTest(_fixtures.FixtureTest): def go(): assert not s.is_modified(u) - self.assert_sql_count( - testing.db, - go, - 0 - ) + + self.assert_sql_count(testing.db, go, 0) 
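        # [editor's sketch, not part of the original commit] The pattern
        # above hinges on the difference between Session.dirty and
        # Session.is_modified(): the dirty set records that an attribute
        # set event occurred, while is_modified() inspects the recorded
        # attribute history for a net change.  A hedged, self-contained
        # illustration; "session_cls" and "User" are stand-ins for the
        # Session class and mapped User fixture this test already uses,
        # and at least one persisted User row is assumed.
        def _is_modified_sketch(session_cls, User):
            sess = session_cls()
            u = sess.query(User).first()
            u.name = u.name  # re-assign the attribute its current value
            assert u in sess.dirty  # a set event was recorded...
            assert not sess.is_modified(u)  # ...but there is no net change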
s.expire_all() - u.name = 'newname' + u.name = "newname" # can't predict result here # deterministically, depending on if # 'name' or 'addresses' is tested first mod = s.is_modified(u) - addresses_loaded = 'addresses' in u.__dict__ + addresses_loaded = "addresses" in u.__dict__ assert mod is not addresses_loaded def test_is_modified_passive_on(self): User, Address = self._default_mapping_fixture() s = Session() - u = User(name='fred', addresses=[Address(email_address='foo')]) + u = User(name="fred", addresses=[Address(email_address="foo")]) s.add(u) s.commit() @@ -1586,29 +1668,23 @@ class IsModifiedTest(_fixtures.FixtureTest): def go(): assert not s.is_modified(u, passive=True) - self.assert_sql_count( - testing.db, - go, - 0 - ) - u.name = 'newname' + self.assert_sql_count(testing.db, go, 0) + + u.name = "newname" def go(): assert s.is_modified(u, passive=True) - self.assert_sql_count( - testing.db, - go, - 0 - ) + + self.assert_sql_count(testing.db, go, 0) def test_is_modified_syn(self): User, users = self.classes.User, self.tables.users s = sessionmaker()() - mapper(User, users, properties={'uname': sa.orm.synonym('name')}) - u = User(uname='fred') + mapper(User, users, properties={"uname": sa.orm.synonym("name")}) + u = User(uname="fred") assert s.is_modified(u) s.add(u) s.commit() @@ -1616,26 +1692,32 @@ class IsModifiedTest(_fixtures.FixtureTest): class DisposedStates(fixtures.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' + run_setup_mappers = "once" + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table('t1', metadata, Column('id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "t1", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @classmethod def setup_classes(cls): class T(cls.Basic): def __init__(self, data): self.data = data + mapper(T, cls.tables.t1) def teardown(self): from sqlalchemy.orm.session import _sessions + _sessions.clear() super(DisposedStates, self).teardown() @@ -1665,15 +1747,20 @@ class DisposedStates(fixtures.MappedTest): T = self.classes.T sess = create_session(**kwargs) - data = o1, o2, o3, o4, o5 = [T('t1'), T('t2'), T('t3'), T('t4'), - T('t5')] + data = o1, o2, o3, o4, o5 = [ + T("t1"), + T("t2"), + T("t3"), + T("t4"), + T("t5"), + ] sess.add_all(data) sess.flush() - o1.data = 't1modified' - o5.data = 't5modified' + o1.data = "t1modified" + o5.data = "t5modified" self._set_imap_in_disposal(sess, o2, o4, o5) return sess @@ -1708,13 +1795,12 @@ class SessionInterface(fixtures.TestBase): # TODO: expand with message body assertions. 
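    # [editor's sketch, not part of the original commit] The guard tests
    # in this class assert that Session methods which accept an instance
    # reject objects whose class was never mapped.  A minimal
    # illustration of the behavior being verified, relying only on the
    # module-level create_session, assert_raises and sa imports already
    # used throughout this file:
    def _unmapped_guard_sketch(self):
        class Unmapped(object):
            pass

        # Session.add() probes for a mapper and raises immediately when
        # the class has no mapping configured.
        assert_raises(
            sa.orm.exc.UnmappedInstanceError,
            create_session().add,
            Unmapped(),
        )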
- _class_methods = set(( - 'connection', 'execute', 'get_bind', 'scalar')) + _class_methods = set(("connection", "execute", "get_bind", "scalar")) def _public_session_methods(self): Session = sa.orm.session.Session - blacklist = set(('begin', 'query')) + blacklist = set(("begin", "query")) ok = set() for meth in Session.public_methods: @@ -1726,9 +1812,19 @@ class SessionInterface(fixtures.TestBase): return ok def _map_it(self, cls): - return mapper(cls, Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True))) + return mapper( + cls, + Table( + "t", + sa.MetaData(), + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + ), + ) def _test_instance_guards(self, user_arg): watchdog = set() @@ -1736,48 +1832,61 @@ class SessionInterface(fixtures.TestBase): def x_raises_(obj, method, *args, **kw): watchdog.add(method) callable_ = getattr(obj, method) - assert_raises(sa.orm.exc.UnmappedInstanceError, - callable_, *args, **kw) + assert_raises( + sa.orm.exc.UnmappedInstanceError, callable_, *args, **kw + ) def raises_(method, *args, **kw): x_raises_(create_session(), method, *args, **kw) - raises_('__contains__', user_arg) + raises_("__contains__", user_arg) - raises_('add', user_arg) + raises_("add", user_arg) - raises_('add_all', (user_arg,)) + raises_("add_all", (user_arg,)) - raises_('delete', user_arg) + raises_("delete", user_arg) - raises_('expire', user_arg) + raises_("expire", user_arg) - raises_('expunge', user_arg) + raises_("expunge", user_arg) # flush will no-op without something in the unit of work def _(): class OK(object): pass + self._map_it(OK) s = create_session() s.add(OK()) - x_raises_(s, 'flush', (user_arg,)) + x_raises_(s, "flush", (user_arg,)) + _() - raises_('is_modified', user_arg) + raises_("is_modified", user_arg) - raises_('merge', user_arg) + raises_("merge", user_arg) - raises_('refresh', user_arg) + raises_("refresh", user_arg) - instance_methods = self._public_session_methods() \ - - self._class_methods - set([ - 'bulk_update_mappings', 'bulk_insert_mappings', - 'bulk_save_objects']) + instance_methods = ( + self._public_session_methods() + - self._class_methods + - set( + [ + "bulk_update_mappings", + "bulk_insert_mappings", + "bulk_save_objects", + ] + ) + ) - eq_(watchdog, instance_methods, - watchdog.symmetric_difference(instance_methods)) + eq_( + watchdog, + instance_methods, + watchdog.symmetric_difference(instance_methods), + ) def _test_class_guards(self, user_arg, is_class=True): watchdog = set() @@ -1787,22 +1896,26 @@ class SessionInterface(fixtures.TestBase): callable_ = getattr(create_session(), method) if is_class: assert_raises( - sa.orm.exc.UnmappedClassError, - callable_, *args, **kw) + sa.orm.exc.UnmappedClassError, callable_, *args, **kw + ) else: assert_raises( - sa.exc.NoInspectionAvailable, callable_, *args, **kw) + sa.exc.NoInspectionAvailable, callable_, *args, **kw + ) - raises_('connection', mapper=user_arg) + raises_("connection", mapper=user_arg) - raises_('execute', 'SELECT 1', mapper=user_arg) + raises_("execute", "SELECT 1", mapper=user_arg) - raises_('get_bind', mapper=user_arg) + raises_("get_bind", mapper=user_arg) - raises_('scalar', 'SELECT 1', mapper=user_arg) + raises_("scalar", "SELECT 1", mapper=user_arg) - eq_(watchdog, self._class_methods, - watchdog.symmetric_difference(self._class_methods)) + eq_( + watchdog, + self._class_methods, + watchdog.symmetric_difference(self._class_methods), + ) def test_unmapped_instance(self): class 
Unmapped(object): @@ -1812,7 +1925,7 @@ class SessionInterface(fixtures.TestBase): self._test_class_guards(Unmapped) def test_unmapped_primitives(self): - for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')): + for prim in ("doh", 123, ("t", "u", "p", "l", "e")): self._test_instance_guards(prim) self._test_class_guards(prim, is_class=False) @@ -1826,6 +1939,7 @@ class SessionInterface(fixtures.TestBase): def test_mapped_class_for_instance(self): class Mapped(object): pass + self._map_it(Mapped) self._test_instance_guards(Mapped) @@ -1834,6 +1948,7 @@ class SessionInterface(fixtures.TestBase): def test_missing_state(self): class Mapped(object): pass + early = Mapped() self._map_it(Mapped) @@ -1843,6 +1958,7 @@ class SessionInterface(fixtures.TestBase): def test_refresh_arg_signature(self): class Mapped(object): pass + self._map_it(Mapped) m1 = Mapped() @@ -1853,11 +1969,13 @@ class SessionInterface(fixtures.TestBase): sa.exc.ArgumentError, "with_for_update should be the boolean value True, " "or a dictionary with options", - s.refresh, m1, with_for_update={} + s.refresh, + m1, + with_for_update={}, ) with mock.patch( - "sqlalchemy.orm.session.loading.load_on_ident" + "sqlalchemy.orm.session.loading.load_on_ident" ) as load_on_ident: s.refresh(m1, with_for_update={"read": True}) s.refresh(m1, with_for_update=True) @@ -1865,26 +1983,35 @@ class SessionInterface(fixtures.TestBase): s.refresh(m1) from sqlalchemy.orm.query import LockmodeArg + eq_( [ - call[-1]['with_for_update'] - for call in load_on_ident.mock_calls], - [LockmodeArg(read=True), LockmodeArg(), None, None] + call[-1]["with_for_update"] + for call in load_on_ident.mock_calls + ], + [LockmodeArg(read=True), LockmodeArg(), None, None], ) + class TLTransactionTest(fixtures.MappedTest): - run_dispose_bind = 'once' + run_dispose_bind = "once" __backend__ = True @classmethod def setup_bind(cls): - return engines.testing_engine(options=dict(strategy='threadlocal')) + return engines.testing_engine(options=dict(strategy="threadlocal")) @classmethod def define_tables(cls, metadata): - Table('users', metadata, Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(20)), test_needs_acid=True) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(20)), + test_needs_acid=True, + ) @classmethod def setup_classes(cls): @@ -1897,33 +2024,41 @@ class TLTransactionTest(fixtures.MappedTest): mapper(User, users) - @testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown') + @testing.exclude("mysql", "<", (5, 0, 3), "FIXME: unknown") def test_session_nesting(self): User = self.classes.User sess = create_session(bind=self.bind) self.bind.begin() - u = User(name='ed') + u = User(name="ed") sess.add(u) sess.flush() self.bind.commit() class FlushWarningsTest(fixtures.MappedTest): - run_setup_mappers = 'each' + run_setup_mappers = "each" @classmethod def define_tables(cls, metadata): - Table('user', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(20))) + Table( + "user", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(20)), + ) - Table('address', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('user.id')), - Column('email', String(20))) + Table( + "address", + metadata, + Column( + "id", Integer, primary_key=True, 
test_needs_autoincrement=True + ), + Column("user_id", Integer, ForeignKey("user.id")), + Column("email", String(20)), + ) @classmethod def setup_classes(cls): @@ -1937,55 +2072,64 @@ class FlushWarningsTest(fixtures.MappedTest): def setup_mappers(cls): user, User = cls.tables.user, cls.classes.User address, Address = cls.tables.address, cls.classes.Address - mapper(User, user, properties={ - 'addresses': relationship(Address, backref="user") - }) + mapper( + User, + user, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, address) def test_o2m_cascade_add(self): Address = self.classes.Address def evt(mapper, conn, instance): - instance.addresses.append(Address(email='x1')) + instance.addresses.append(Address(email="x1")) + self._test(evt, "collection append") def test_o2m_cascade_remove(self): def evt(mapper, conn, instance): del instance.addresses[0] + self._test(evt, "collection remove") def test_m2o_cascade_add(self): User = self.classes.User def evt(mapper, conn, instance): - instance.addresses[0].user = User(name='u2') + instance.addresses[0].user = User(name="u2") + self._test(evt, "related attribute set") def test_m2o_cascade_remove(self): def evt(mapper, conn, instance): a1 = instance.addresses[0] del a1.user + self._test(evt, "related attribute delete") def test_plain_add(self): Address = self.classes.Address def evt(mapper, conn, instance): - object_session(instance).add(Address(email='x1')) + object_session(instance).add(Address(email="x1")) + self._test(evt, r"Session.add\(\)") def test_plain_merge(self): Address = self.classes.Address def evt(mapper, conn, instance): - object_session(instance).merge(Address(email='x1')) + object_session(instance).merge(Address(email="x1")) + self._test(evt, r"Session.merge\(\)") def test_plain_delete(self): Address = self.classes.Address def evt(mapper, conn, instance): - object_session(instance).delete(Address(email='x1')) + object_session(instance).delete(Address(email="x1")) + self._test(evt, r"Session.delete\(\)") def _test(self, fn, method): @@ -1995,10 +2139,8 @@ class FlushWarningsTest(fixtures.MappedTest): s = Session() event.listen(User, "after_insert", fn) - u1 = User(name='u1', addresses=[Address(name='a1')]) + u1 = User(name="u1", addresses=[Address(name="a1")]) s.add(u1) assert_raises_message( - sa.exc.SAWarning, - "Usage of the '%s'" % method, - s.commit + sa.exc.SAWarning, "Usage of the '%s'" % method, s.commit ) diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index 606beb5aa8..8af02520f3 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -2,12 +2,23 @@ from sqlalchemy.testing import eq_, is_, is_not_, is_true from sqlalchemy import testing from sqlalchemy.testing.schema import Table, Column from sqlalchemy import Integer, String, ForeignKey, bindparam, inspect -from sqlalchemy.orm import backref, subqueryload, subqueryload_all, \ - mapper, relationship, clear_mappers, create_session, lazyload, \ - aliased, joinedload, deferred, undefer, eagerload_all,\ - Session -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message +from sqlalchemy.orm import ( + backref, + subqueryload, + subqueryload_all, + mapper, + relationship, + clear_mappers, + create_session, + lazyload, + aliased, + joinedload, + deferred, + undefer, + eagerload_all, + Session, +) +from sqlalchemy.testing import eq_, assert_raises, assert_raises_message from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing 
import fixtures from sqlalchemy.testing.entities import ComparableEntity @@ -16,43 +27,61 @@ import sqlalchemy as sa from sqlalchemy.orm import with_polymorphic -from .inheritance._poly_fixtures import _Polymorphic, Person, Engineer, \ - Paperwork, Page, Machine, MachineType, Company +from .inheritance._poly_fixtures import ( + _Polymorphic, + Person, + Engineer, + Paperwork, + Page, + Machine, + MachineType, + Company, +) class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): - run_inserts = 'once' + run_inserts = "once" run_deletes = None def test_basic(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - order_by=Address.id) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) sess = create_session() q = sess.query(User).options(subqueryload(User.addresses)) def go(): eq_( - [User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])], - q.filter(User.id == 7).all() + [ + User( + id=7, + addresses=[ + Address(id=1, email_address="jack@bean.com") + ], + ) + ], + q.filter(User.id == 7).all(), ) self.assert_sql_count(testing.db, go, 2) def go(): - eq_( - self.static.user_address_result, - q.order_by(User.id).all() - ) + eq_(self.static.user_address_result, q.order_by(User.id).all()) + self.assert_sql_count(testing.db, go, 2) def test_from_aliased(self): @@ -62,17 +91,24 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.tables.dingalings, self.classes.Address, - self.tables.addresses) + self.tables.addresses, + ) mapper(Dingaling, dingalings) - mapper(Address, addresses, properties={ - 'dingalings': relationship(Dingaling, order_by=Dingaling.id) - }) - mapper(User, users, properties={ - 'addresses': relationship( - Address, - order_by=Address.id) - }) + mapper( + Address, + addresses, + properties={ + "dingalings": relationship(Dingaling, order_by=Dingaling.id) + }, + ) + mapper( + User, + users, + properties={ + "addresses": relationship(Address, order_by=Address.id) + }, + ) sess = create_session() u = aliased(User) @@ -81,84 +117,113 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_( - [User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')])], - q.filter(u.id == 7).all() + [ + User( + id=7, + addresses=[ + Address(id=1, email_address="jack@bean.com") + ], + ) + ], + q.filter(u.id == 7).all(), ) self.assert_sql_count(testing.db, go, 2) def go(): - eq_( - self.static.user_address_result, - q.order_by(u.id).all() - ) + eq_(self.static.user_address_result, q.order_by(u.id).all()) + self.assert_sql_count(testing.db, go, 2) - q = sess.query(u).\ - options(subqueryload_all(u.addresses, Address.dingalings)) + q = sess.query(u).options( + subqueryload_all(u.addresses, Address.dingalings) + ) def go(): eq_( [ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com', - dingalings=[Dingaling()]), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5, dingalings=[Dingaling()]) - ]), + User( + id=8, + addresses=[ + Address( + id=2, + email_address="ed@wood.com", + dingalings=[Dingaling()], + 
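                        # [editor's note, not part of the original commit]
                        # A hedged reading of the
                        # assert_sql_count(testing.db, go, 3) assertion
                        # below: subqueryload emits one extra SELECT per
                        # eager-loaded relationship level, so the base
                        # aliased-User query plus the addresses and
                        # dingalings subquery loads total three statements.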
), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User( + id=9, + addresses=[Address(id=5, dingalings=[Dingaling()])], + ), ], - q.filter(u.id.in_([8, 9])).all() + q.filter(u.id.in_([8, 9])).all(), ) + self.assert_sql_count(testing.db, go, 3) def test_from_get(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - order_by=Address.id) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) sess = create_session() q = sess.query(User).options(subqueryload(User.addresses)) def go(): eq_( - User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')]), - q.get(7) + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ), + q.get(7), ) self.assert_sql_count(testing.db, go, 2) def test_from_params(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship( - mapper(Address, addresses), - order_by=Address.id) - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) sess = create_session() q = sess.query(User).options(subqueryload(User.addresses)) def go(): eq_( - User(id=7, addresses=[ - Address(id=1, email_address='jack@bean.com')]), - q.filter(User.id == bindparam('foo')).params(foo=7).one() + User( + id=7, + addresses=[Address(id=1, email_address="jack@bean.com")], + ), + q.filter(User.id == bindparam("foo")).params(foo=7).one(), ) self.assert_sql_count(testing.db, go, 2) @@ -166,14 +231,18 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def test_disable_dynamic(self): """test no subquery option on a dynamic.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) - mapper(User, users, properties={ - 'addresses': relationship(Address, lazy="dynamic") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="dynamic")}, + ) mapper(Address, addresses) sess = create_session() @@ -192,17 +261,28 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='subquery', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="subquery", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): eq_(self.static.item_keyword_result, q.all()) + self.assert_sql_count(testing.db, go, 2) def test_many_to_many_with_join(self): @@ -211,18 +291,31 @@ 
class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='subquery', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="subquery", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): - eq_(self.static.item_keyword_result[0:2], - q.join('keywords').filter(Keyword.name == 'red').all()) + eq_( + self.static.item_keyword_result[0:2], + q.join("keywords").filter(Keyword.name == "red").all(), + ) + self.assert_sql_count(testing.db, go, 2) def test_many_to_many_with_join_alias(self): @@ -231,139 +324,193 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.tables.items, self.tables.item_keywords, self.classes.Keyword, - self.classes.Item) + self.classes.Item, + ) mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relationship(Keyword, secondary=item_keywords, - lazy='subquery', order_by=keywords.c.id))) + mapper( + Item, + items, + properties=dict( + keywords=relationship( + Keyword, + secondary=item_keywords, + lazy="subquery", + order_by=keywords.c.id, + ) + ), + ) q = create_session().query(Item).order_by(Item.id) def go(): - eq_(self.static.item_keyword_result[0:2], - (q.join('keywords', aliased=True). - filter(Keyword.name == 'red')).all()) + eq_( + self.static.item_keyword_result[0:2], + ( + q.join("keywords", aliased=True).filter( + Keyword.name == "red" + ) + ).all(), + ) + self.assert_sql_count(testing.db, go, 2) def test_orderby(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='subquery', - order_by=addresses.c.email_address), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="subquery", + order_by=addresses.c.email_address, + ) + }, + ) q = create_session().query(User) - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], q.order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=2, email_address="ed@wood.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + q.order_by(User.id).all(), + ) def test_orderby_multi(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='subquery', - order_by=[ - addresses.c.email_address, - addresses.c.id]), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + 
self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="subquery", + order_by=[addresses.c.email_address, addresses.c.id], + ) + }, + ) q = create_session().query(User) - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], q.order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + Address(id=2, email_address="ed@wood.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + q.order_by(User.id).all(), + ) def test_orderby_related(self): """A regular mapper select on a single table can order by a relationship to a second table""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, - lazy='subquery', - order_by=addresses.c.id), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="subquery", order_by=addresses.c.id + ) + ), + ) q = create_session().query(User) - result = q.filter(User.id == Address.user_id).\ - order_by(Address.email_address).all() - - eq_([ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=7, addresses=[ - Address(id=1) - ]), - ], result) + result = ( + q.filter(User.id == Address.user_id) + .order_by(Address.email_address) + .all() + ) + + eq_( + [ + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + Address(id=3, email_address="ed@bettyboop.com"), + Address(id=4, email_address="ed@lala.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=7, addresses=[Address(id=1)]), + ], + result, + ) def test_orderby_desc(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='subquery', - order_by=[ - sa.desc(addresses.c.email_address) - ]), - )) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="subquery", + order_by=[sa.desc(addresses.c.email_address)], + ) + ), + ) sess = create_session() - eq_([ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=3, email_address='ed@bettyboop.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ], sess.query(User).order_by(User.id).all()) + eq_( + [ + User(id=7, addresses=[Address(id=1)]), + User( + id=8, + addresses=[ + Address(id=2, email_address="ed@wood.com"), + 
Address(id=4, email_address="ed@lala.com"), + Address(id=3, email_address="ed@bettyboop.com"), + ], + ), + User(id=9, addresses=[Address(id=5)]), + User(id=10, addresses=[]), + ], + sess.query(User).order_by(User.id).all(), + ) _pathing_runs = [ ("lazyload", "lazyload", "lazyload", 15), @@ -382,37 +529,47 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self._do_mapper_test(self._pathing_runs) def _do_options_test(self, configs): - users, Keyword, orders, items, order_items, Order, Item, User, \ - keywords, item_keywords = (self.tables.users, - self.classes.Keyword, - self.tables.orders, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.keywords, - self.tables.item_keywords) - - mapper(User, users, properties={ - 'orders': relationship(Order, order_by=orders.c.id), # o2m, m2o - }) - mapper(Order, orders, properties={ - 'items': relationship(Item, - secondary=order_items, - order_by=items.c.id), # m2m - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, - secondary=item_keywords, - order_by=keywords.c.id) # m2m - }) + users, Keyword, orders, items, order_items, Order, Item, User, keywords, item_keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) + + mapper( + User, + users, + properties={ + "orders": relationship(Order, order_by=orders.c.id) # o2m, m2o + }, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ) # m2m + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, secondary=item_keywords, order_by=keywords.c.id + ) # m2m + }, + ) mapper(Keyword, keywords) - callables = { - 'joinedload': joinedload, - 'subqueryload': subqueryload - } + callables = {"joinedload": joinedload, "subqueryload": subqueryload} for o, i, k, count in configs: options = [] @@ -421,46 +578,66 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): if i in callables: options.append(callables[i](User.orders, Order.items)) if k in callables: - options.append(callables[k]( - User.orders, Order.items, Item.keywords)) + options.append( + callables[k](User.orders, Order.items, Item.keywords) + ) self._do_query_tests(options, count) def _do_mapper_test(self, configs): - users, Keyword, orders, items, order_items, Order, Item, User, \ - keywords, item_keywords = (self.tables.users, - self.classes.Keyword, - self.tables.orders, - self.tables.items, - self.tables.order_items, - self.classes.Order, - self.classes.Item, - self.classes.User, - self.tables.keywords, - self.tables.item_keywords) + users, Keyword, orders, items, order_items, Order, Item, User, keywords, item_keywords = ( + self.tables.users, + self.classes.Keyword, + self.tables.orders, + self.tables.items, + self.tables.order_items, + self.classes.Order, + self.classes.Item, + self.classes.User, + self.tables.keywords, + self.tables.item_keywords, + ) opts = { - 'lazyload': 'select', - 'joinedload': 'joined', - 'subqueryload': 'subquery', + "lazyload": "select", + "joinedload": "joined", + "subqueryload": "subquery", } for o, i, k, count in configs: - mapper(User, users, properties={ - 'orders': relationship(Order, lazy=opts[o], - order_by=orders.c.id), - }) - mapper(Order, orders, properties={ - 'items': relationship(Item, - 
secondary=order_items, lazy=opts[i], - order_by=items.c.id), - }) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, - lazy=opts[k], - secondary=item_keywords, - order_by=keywords.c.id) - }) + mapper( + User, + users, + properties={ + "orders": relationship( + Order, lazy=opts[o], order_by=orders.c.id + ) + }, + ) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy=opts[i], + order_by=items.c.id, + ) + }, + ) + mapper( + Item, + items, + properties={ + "keywords": relationship( + Keyword, + lazy=opts[k], + secondary=item_keywords, + order_by=keywords.c.id, + ) + }, + ) mapper(Keyword, keywords) try: @@ -476,87 +653,125 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def go(): eq_( sess.query(User).options(*opts).order_by(User.id).all(), - self.static.user_item_keyword_result + self.static.user_item_keyword_result, ) + self.assert_sql_count(testing.db, go, count) eq_( - sess.query(User).options(*opts).filter(User.name == 'fred'). - order_by(User.id).all(), - self.static.user_item_keyword_result[2:3] + sess.query(User) + .options(*opts) + .filter(User.name == "fred") + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[2:3], ) sess = create_session() eq_( - sess.query(User).options(*opts).join(User.orders). - filter(Order.id == 3). - order_by(User.id).all(), - self.static.user_item_keyword_result[0:1] + sess.query(User) + .options(*opts) + .join(User.orders) + .filter(Order.id == 3) + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[0:1], ) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='subquery', - backref=sa.orm.backref( - 'user', lazy='subquery',), - order_by=Address.id) - )) - is_(sa.orm.class_mapper(User).get_property('addresses').lazy, - 'subquery') - is_(sa.orm.class_mapper(Address).get_property('user').lazy, 'subquery') + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="subquery", + backref=sa.orm.backref("user", lazy="subquery"), + order_by=Address.id, + ) + ), + ) + is_( + sa.orm.class_mapper(User).get_property("addresses").lazy, + "subquery", + ) + is_(sa.orm.class_mapper(Address).get_property("user").lazy, "subquery") sess = create_session() - eq_(self.static.user_address_result, - sess.query(User).order_by(User.id).all()) + eq_( + self.static.user_address_result, + sess.query(User).order_by(User.id).all(), + ) def test_cyclical_explicit_join_depth(self): """A circular eager relationship breaks the cycle with a lazy loader""" - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='subquery', join_depth=1, - backref=sa.orm.backref( - 'user', lazy='subquery', join_depth=1), - order_by=Address.id) - )) - is_(sa.orm.class_mapper(User).get_property('addresses').lazy, - 'subquery') - 
is_(sa.orm.class_mapper(Address).get_property('user').lazy, 'subquery') + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, + lazy="subquery", + join_depth=1, + backref=sa.orm.backref( + "user", lazy="subquery", join_depth=1 + ), + order_by=Address.id, + ) + ), + ) + is_( + sa.orm.class_mapper(User).get_property("addresses").lazy, + "subquery", + ) + is_(sa.orm.class_mapper(Address).get_property("user").lazy, "subquery") sess = create_session() - eq_(self.static.user_address_result, - sess.query(User).order_by(User.id).all()) + eq_( + self.static.user_address_result, + sess.query(User).order_by(User.id).all(), + ) def test_add_arbitrary_exprs(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='subquery') - )) + mapper( + User, + users, + properties=dict(addresses=relationship(Address, lazy="subquery")), + ) sess = create_session() self.assert_compile( - sess.query(User, '1'), + sess.query(User, "1"), "SELECT users.id AS users_id, users.name AS users_name, " - "1 FROM users" + "1 FROM users", ) def test_double(self): @@ -569,10 +784,11 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): self.classes.User, self.classes.Address, self.classes.Order, - self.tables.addresses) + self.tables.addresses, + ) - openorders = sa.alias(orders, 'openorders') - closedorders = sa.alias(orders, 'closedorders') + openorders = sa.alias(orders, "openorders") + closedorders = sa.alias(orders, "closedorders") mapper(Address, addresses) mapper(Order, orders) @@ -580,150 +796,217 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): open_mapper = mapper(Order, openorders, non_primary=True) closed_mapper = mapper(Order, closedorders, non_primary=True) - mapper(User, users, properties=dict( - addresses=relationship(Address, lazy='subquery', - order_by=addresses.c.id), - open_orders=relationship( - open_mapper, - primaryjoin=sa.and_(openorders.c.isopen == 1, - users.c.id == openorders.c.user_id), - lazy='subquery', order_by=openorders.c.id), - closed_orders=relationship( - closed_mapper, - primaryjoin=sa.and_(closedorders.c.isopen == 0, - users.c.id == closedorders.c.user_id), - lazy='subquery', order_by=closedorders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="subquery", order_by=addresses.c.id + ), + open_orders=relationship( + open_mapper, + primaryjoin=sa.and_( + openorders.c.isopen == 1, + users.c.id == openorders.c.user_id, + ), + lazy="subquery", + order_by=openorders.c.id, + ), + closed_orders=relationship( + closed_mapper, + primaryjoin=sa.and_( + closedorders.c.isopen == 0, + users.c.id == closedorders.c.user_id, + ), + lazy="subquery", + order_by=closedorders.c.id, + ), + ), + ) q = create_session().query(User).order_by(User.id) def go(): - eq_([ - User( - id=7, - addresses=[Address(id=1)], - open_orders=[Order(id=3)], - closed_orders=[Order(id=1), Order(id=5)] - ), - User( - id=8, - addresses=[Address(id=2), Address(id=3), Address(id=4)], - open_orders=[], - closed_orders=[] - ), - User( - id=9, - addresses=[Address(id=5)], - open_orders=[Order(id=4)], - closed_orders=[Order(id=2)] - ), - User(id=10) + eq_( + [ + User( + id=7, + addresses=[Address(id=1)], + 
open_orders=[Order(id=3)], + closed_orders=[Order(id=1), Order(id=5)], + ), + User( + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + open_orders=[], + closed_orders=[], + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders=[Order(id=4)], + closed_orders=[Order(id=2)], + ), + User(id=10), + ], + q.all(), + ) - ], q.all()) self.assert_sql_count(testing.db, go, 4) def test_double_same_mappers(self): """Eager loading with two relationships simultaneously, from the same table, using aliases.""" - addresses, items, order_items, orders, Item, User, Address, Order, \ - users = (self.tables.addresses, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.users) + addresses, items, order_items, orders, Item, User, Address, Order, users = ( + self.tables.addresses, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.users, + ) mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='subquery', - order_by=items.c.id)}) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="subquery", + order_by=items.c.id, + ) + }, + ) mapper(Item, items) - mapper(User, users, properties=dict( - addresses=relationship( - Address, lazy='subquery', order_by=addresses.c.id), - open_orders=relationship( - Order, - primaryjoin=sa.and_(orders.c.isopen == 1, - users.c.id == orders.c.user_id), - lazy='subquery', order_by=orders.c.id), - closed_orders=relationship( - Order, - primaryjoin=sa.and_(orders.c.isopen == 0, - users.c.id == orders.c.user_id), - lazy='subquery', order_by=orders.c.id))) + mapper( + User, + users, + properties=dict( + addresses=relationship( + Address, lazy="subquery", order_by=addresses.c.id + ), + open_orders=relationship( + Order, + primaryjoin=sa.and_( + orders.c.isopen == 1, users.c.id == orders.c.user_id + ), + lazy="subquery", + order_by=orders.c.id, + ), + closed_orders=relationship( + Order, + primaryjoin=sa.and_( + orders.c.isopen == 0, users.c.id == orders.c.user_id + ), + lazy="subquery", + order_by=orders.c.id, + ), + ), + ) q = create_session().query(User).order_by(User.id) def go(): - eq_([ - User(id=7, - addresses=[ - Address(id=1)], - open_orders=[Order(id=3, - items=[ - Item(id=3), - Item(id=4), - Item(id=5)])], - closed_orders=[Order(id=1, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)]), - Order(id=5, - items=[ - Item(id=5)])]), - User(id=8, - addresses=[ - Address(id=2), - Address(id=3), - Address(id=4)], - open_orders=[], - closed_orders=[]), - User(id=9, - addresses=[ - Address(id=5)], - open_orders=[ - Order(id=4, - items=[ - Item(id=1), - Item(id=5)])], - closed_orders=[ - Order(id=2, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)])]), - User(id=10) - ], q.all()) + eq_( + [ + User( + id=7, + addresses=[Address(id=1)], + open_orders=[ + Order( + id=3, + items=[Item(id=3), Item(id=4), Item(id=5)], + ) + ], + closed_orders=[ + Order( + id=1, + items=[Item(id=1), Item(id=2), Item(id=3)], + ), + Order(id=5, items=[Item(id=5)]), + ], + ), + User( + id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4), + ], + open_orders=[], + closed_orders=[], + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders=[ + Order(id=4, items=[Item(id=1), Item(id=5)]) + ], + 
closed_orders=[ + Order( + id=2, + items=[Item(id=1), Item(id=2), Item(id=3)], + ) + ], + ), + User(id=10), + ], + q.all(), + ) + self.assert_sql_count(testing.db, go, 6) def test_limit(self): """Limit operations combined with lazy-load relationships.""" - users, items, order_items, orders, Item, User, Address, Order, \ - addresses = (self.tables.users, - self.tables.items, - self.tables.order_items, - self.tables.orders, - self.classes.Item, - self.classes.User, - self.classes.Address, - self.classes.Order, - self.tables.addresses) + users, items, order_items, orders, Item, User, Address, Order, addresses = ( + self.tables.users, + self.tables.items, + self.tables.order_items, + self.tables.orders, + self.classes.Item, + self.classes.User, + self.classes.Address, + self.classes.Order, + self.tables.addresses, + ) mapper(Item, items) - mapper(Order, orders, properties={ - 'items': relationship(Item, secondary=order_items, lazy='subquery', - order_by=items.c.id) - }) - mapper(User, users, properties={ - 'addresses': relationship(mapper(Address, addresses), - lazy='subquery', - order_by=addresses.c.id), - 'orders': relationship(Order, lazy='select', order_by=orders.c.id) - }) + mapper( + Order, + orders, + properties={ + "items": relationship( + Item, + secondary=order_items, + lazy="subquery", + order_by=items.c.id, + ) + }, + ) + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), + lazy="subquery", + order_by=addresses.c.id, + ), + "orders": relationship( + Order, lazy="select", order_by=orders.c.id + ), + }, + ) sess = create_session() q = sess.query(User) @@ -736,17 +1019,24 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): @testing.uses_deprecated("Mapper.order_by") def test_mapper_order_by(self): - users, User, Address, addresses = (self.tables.users, - self.classes.User, - self.classes.Address, - self.tables.addresses) + users, User, Address, addresses = ( + self.tables.users, + self.classes.User, + self.classes.Address, + self.tables.addresses, + ) mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses': relationship(Address, - lazy='subquery', - order_by=addresses.c.id), - }, order_by=users.c.id.desc()) + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, lazy="subquery", order_by=addresses.c.id + ) + }, + order_by=users.c.id.desc(), + ) sess = create_session() q = sess.query(User) @@ -755,31 +1045,45 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): eq_(result, list(reversed(self.static.user_address_result[2:4]))) def test_one_to_many_scalar(self): - Address, addresses, users, User = (self.classes.Address, - self.tables.addresses, - self.tables.users, - self.classes.User) - - mapper(User, users, properties=dict( - address=relationship(mapper(Address, addresses), - lazy='subquery', uselist=False) - )) + Address, addresses, users, User = ( + self.classes.Address, + self.tables.addresses, + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=dict( + address=relationship( + mapper(Address, addresses), lazy="subquery", uselist=False + ) + ), + ) q = create_session().query(User) def go(): result = q.filter(users.c.id == 7).all() eq_([User(id=7, address=Address(id=1))], result) + self.assert_sql_count(testing.db, go, 2) def test_many_to_one(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(Address, addresses, properties=dict( 
- user=relationship(mapper(User, users), lazy='subquery') - )) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + Address, + addresses, + properties=dict( + user=relationship(mapper(User, users), lazy="subquery") + ), + ) sess = create_session() q = sess.query(Address) @@ -788,97 +1092,135 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): is_not_(a.user, None) u1 = sess.query(User).get(7) is_(a.user, u1) + self.assert_sql_count(testing.db, go, 2) def test_double_with_aggregate(self): - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) - max_orders_by_user = sa.select([sa.func.max(orders.c.id) - .label('order_id')], - group_by=[orders.c.user_id]) \ - .alias('max_orders_by_user') + max_orders_by_user = sa.select( + [sa.func.max(orders.c.id).label("order_id")], + group_by=[orders.c.user_id], + ).alias("max_orders_by_user") max_orders = orders.select( - orders.c.id == max_orders_by_user.c.order_id).\ - alias('max_orders') + orders.c.id == max_orders_by_user.c.order_id + ).alias("max_orders") mapper(Order, orders) - mapper(User, users, properties={ - 'orders': relationship(Order, backref='user', lazy='subquery', - order_by=orders.c.id), - 'max_order': relationship( - mapper(Order, max_orders, non_primary=True), - lazy='subquery', uselist=False) - }) + mapper( + User, + users, + properties={ + "orders": relationship( + Order, + backref="user", + lazy="subquery", + order_by=orders.c.id, + ), + "max_order": relationship( + mapper(Order, max_orders, non_primary=True), + lazy="subquery", + uselist=False, + ), + }, + ) q = create_session().query(User) def go(): - eq_([ - User(id=7, orders=[ - Order(id=1), - Order(id=3), - Order(id=5), + eq_( + [ + User( + id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + max_order=Order(id=5), + ), + User(id=8, orders=[]), + User( + id=9, + orders=[Order(id=2), Order(id=4)], + max_order=Order(id=4), + ), + User(id=10), ], - max_order=Order(id=5) - ), - User(id=8, orders=[]), - User(id=9, orders=[Order(id=2), Order(id=4)], - max_order=Order(id=4)), - User(id=10), - ], q.order_by(User.id).all()) + q.order_by(User.id).all(), + ) + self.assert_sql_count(testing.db, go, 3) def test_uselist_false_warning(self): """test that multiple rows received by a uselist=False raises a warning.""" - User, users, orders, Order = (self.classes.User, - self.tables.users, - self.tables.orders, - self.classes.Order) + User, users, orders, Order = ( + self.classes.User, + self.tables.users, + self.tables.orders, + self.classes.Order, + ) - mapper(User, users, properties={ - 'order': relationship(Order, uselist=False) - }) + mapper( + User, + users, + properties={"order": relationship(Order, uselist=False)}, + ) mapper(Order, orders) s = create_session() - assert_raises(sa.exc.SAWarning, - s.query(User).options(subqueryload(User.order)).all) + assert_raises( + sa.exc.SAWarning, + s.query(User).options(subqueryload(User.order)).all, + ) class LoadOnExistingTest(_fixtures.FixtureTest): """test that loaders from a base Query fully populate.""" - run_inserts = 'once' + run_inserts = "once" run_deletes = None def _collection_to_scalar_fixture(self): - User, Address, Dingaling = self.classes.User, \ - self.classes.Address, self.classes.Dingaling - mapper(User, self.tables.users, properties={ 
- 'addresses': relationship(Address), - }) - mapper(Address, self.tables.addresses, properties={ - 'dingaling': relationship(Dingaling) - }) + User, Address, Dingaling = ( + self.classes.User, + self.classes.Address, + self.classes.Dingaling, + ) + mapper( + User, + self.tables.users, + properties={"addresses": relationship(Address)}, + ) + mapper( + Address, + self.tables.addresses, + properties={"dingaling": relationship(Dingaling)}, + ) mapper(Dingaling, self.tables.dingalings) sess = Session(autoflush=False) return User, Address, Dingaling, sess def _collection_to_collection_fixture(self): - User, Order, Item = self.classes.User, \ - self.classes.Order, self.classes.Item - mapper(User, self.tables.users, properties={ - 'orders': relationship(Order), - }) - mapper(Order, self.tables.orders, properties={ - 'items': relationship(Item, secondary=self.tables.order_items), - }) + User, Order, Item = ( + self.classes.User, + self.classes.Order, + self.classes.Item, + ) + mapper( + User, self.tables.users, properties={"orders": relationship(Order)} + ) + mapper( + Order, + self.tables.orders, + properties={ + "items": relationship(Item, secondary=self.tables.order_items) + }, + ) mapper(Item, self.tables.items) sess = Session(autoflush=False) @@ -886,19 +1228,25 @@ class LoadOnExistingTest(_fixtures.FixtureTest): def _eager_config_fixture(self): User, Address = self.classes.User, self.classes.Address - mapper(User, self.tables.users, properties={ - 'addresses': relationship(Address, lazy="subquery"), - }) + mapper( + User, + self.tables.users, + properties={"addresses": relationship(Address, lazy="subquery")}, + ) mapper(Address, self.tables.addresses) sess = Session(autoflush=False) return User, Address, sess def _deferred_config_fixture(self): User, Address = self.classes.User, self.classes.Address - mapper(User, self.tables.users, properties={ - 'name': deferred(self.tables.users.c.name), - 'addresses': relationship(Address, lazy="subquery"), - }) + mapper( + User, + self.tables.users, + properties={ + "name": deferred(self.tables.users.c.name), + "addresses": relationship(Address, lazy="subquery"), + }, + ) mapper(Address, self.tables.addresses) sess = Session(autoflush=False) return User, Address, sess @@ -907,24 +1255,26 @@ class LoadOnExistingTest(_fixtures.FixtureTest): User, Address, sess = self._eager_config_fixture() u1 = sess.query(User).get(8) - assert 'addresses' in u1.__dict__ + assert "addresses" in u1.__dict__ sess.expire(u1) def go(): eq_(u1.id, 8) + self.assert_sql_count(testing.db, go, 1) - assert 'addresses' not in u1.__dict__ + assert "addresses" not in u1.__dict__ def test_no_query_on_deferred(self): User, Address, sess = self._deferred_config_fixture() u1 = sess.query(User).get(8) - assert 'addresses' in u1.__dict__ - sess.expire(u1, ['addresses']) + assert "addresses" in u1.__dict__ + sess.expire(u1, ["addresses"]) def go(): - eq_(u1.name, 'ed') + eq_(u1.name, "ed") + self.assert_sql_count(testing.db, go, 1) - assert 'addresses' not in u1.__dict__ + assert "addresses" not in u1.__dict__ def test_populate_existing_propagate(self): User, Address, sess = self._eager_config_fixture() @@ -945,17 +1295,18 @@ class LoadOnExistingTest(_fixtures.FixtureTest): a1 = Address() u1.addresses.append(a1) a2 = u1.addresses[0] - a2.email_address = 'foo' - sess.query(User).options(subqueryload_all("addresses.dingaling")).\ - filter_by(id=8).all() + a2.email_address = "foo" + sess.query(User).options( + subqueryload_all("addresses.dingaling") + ).filter_by(id=8).all() assert 
u1.addresses[-1] is a1 for a in u1.addresses: if a is not a1: - assert 'dingaling' in a.__dict__ + assert "dingaling" in a.__dict__ else: - assert 'dingaling' not in a.__dict__ + assert "dingaling" not in a.__dict__ if a is a2: - eq_(a2.email_address, 'foo') + eq_(a2.email_address, "foo") def test_loads_second_level_collection_to_collection(self): User, Order, Item, sess = self._collection_to_collection_fixture() @@ -964,76 +1315,92 @@ class LoadOnExistingTest(_fixtures.FixtureTest): u1.orders o1 = Order() u1.orders.append(o1) - sess.query(User).options(subqueryload_all("orders.items")).\ - filter_by(id=7).all() + sess.query(User).options(subqueryload_all("orders.items")).filter_by( + id=7 + ).all() for o in u1.orders: if o is not o1: - assert 'items' in o.__dict__ + assert "items" in o.__dict__ else: - assert 'items' not in o.__dict__ + assert "items" not in o.__dict__ def test_load_two_levels_collection_to_scalar(self): User, Address, Dingaling, sess = self._collection_to_scalar_fixture() - u1 = sess.query(User).filter_by(id=8).options( - subqueryload("addresses")).one() + u1 = ( + sess.query(User) + .filter_by(id=8) + .options(subqueryload("addresses")) + .one() + ) sess.query(User).filter_by(id=8).options( - subqueryload_all("addresses.dingaling")).first() - assert 'dingaling' in u1.addresses[0].__dict__ + subqueryload_all("addresses.dingaling") + ).first() + assert "dingaling" in u1.addresses[0].__dict__ def test_load_two_levels_collection_to_collection(self): User, Order, Item, sess = self._collection_to_collection_fixture() - u1 = sess.query(User).filter_by(id=7).options( - subqueryload("orders")).one() + u1 = ( + sess.query(User) + .filter_by(id=7) + .options(subqueryload("orders")) + .one() + ) sess.query(User).filter_by(id=7).options( - subqueryload_all("orders.items")).first() - assert 'items' in u1.orders[0].__dict__ + subqueryload_all("orders.items") + ).first() + assert "items" in u1.orders[0].__dict__ class OrderBySecondaryTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('m2m', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b.id'))) - - Table('a', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) - Table('b', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50))) + Table( + "m2m", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("aid", Integer, ForeignKey("a.id")), + Column("bid", Integer, ForeignKey("b.id")), + ) + + Table( + "a", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "b", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @classmethod def fixtures(cls): return dict( - a=(('id', 'data'), - (1, 'a1'), - (2, 'a2')), - - b=(('id', 'data'), - (1, 'b1'), - (2, 'b2'), - (3, 'b3'), - (4, 'b4')), - - m2m=(('id', 'aid', 'bid'), - (2, 1, 1), - (4, 2, 4), - (1, 1, 3), - (6, 2, 2), - (3, 1, 2), - (5, 2, 3))) + a=(("id", "data"), (1, "a1"), (2, "a2")), + b=(("id", "data"), (1, "b1"), (2, "b2"), (3, "b3"), (4, "b4")), + m2m=( + ("id", "aid", "bid"), + (2, 1, 1), + (4, 2, 4), + (1, 1, 3), + (6, 2, 2), + (3, 1, 2), + (5, 2, 3), + ), + ) def test_ordering(self): - a, m2m, b = (self.tables.a, 
- self.tables.m2m, - self.tables.b) + a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b) class A(fixtures.ComparableEntity): pass @@ -1041,54 +1408,88 @@ class OrderBySecondaryTest(fixtures.MappedTest): class B(fixtures.ComparableEntity): pass - mapper(A, a, properties={ - 'bs': relationship(B, secondary=m2m, lazy='subquery', - order_by=m2m.c.id) - }) + mapper( + A, + a, + properties={ + "bs": relationship( + B, secondary=m2m, lazy="subquery", order_by=m2m.c.id + ) + }, + ) mapper(B, b) sess = create_session() def go(): - eq_(sess.query(A).all(), [ - A(data='a1', bs=[B(data='b3'), B(data='b1'), B(data='b2')]), - A(bs=[B(data='b4'), B(data='b3'), B(data='b2')]) - ]) + eq_( + sess.query(A).all(), + [ + A( + data="a1", + bs=[B(data="b3"), B(data="b1"), B(data="b2")], + ), + A(bs=[B(data="b4"), B(data="b3"), B(data="b2")]), + ], + ) + self.assert_sql_count(testing.db, go, 2) class BaseRelationFromJoinedSubclassTest(_Polymorphic): @classmethod def define_tables(cls, metadata): - people = Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('type', String(30))) + people = Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) # to test fully, PK of engineers table must be # named differently from that of people - engineers = Table('engineers', metadata, - Column('engineer_id', Integer, - ForeignKey('people.person_id'), - primary_key=True), - Column('primary_language', String(50))) - - paperwork = Table('paperwork', metadata, - Column('paperwork_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('description', String(50)), - Column('person_id', Integer, - ForeignKey('people.person_id'))) + engineers = Table( + "engineers", + metadata, + Column( + "engineer_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("primary_language", String(50)), + ) + + paperwork = Table( + "paperwork", + metadata, + Column( + "paperwork_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("description", String(50)), + Column("person_id", Integer, ForeignKey("people.person_id")), + ) pages = Table( - 'pages', metadata, - Column('page_id', - Integer, primary_key=True, test_needs_autoincrement=True), - Column('stuff', String(50)), - Column('paperwork_id', ForeignKey('paperwork.paperwork_id')) + "pages", + metadata, + Column( + "page_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("stuff", String(50)), + Column("paperwork_id", ForeignKey("paperwork.paperwork_id")), ) @classmethod @@ -1098,20 +1499,30 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): paperwork = cls.tables.paperwork pages = cls.tables.pages - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person', - properties={ - 'paperwork': relationship( - Paperwork, order_by=paperwork.c.paperwork_id)}) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + properties={ + "paperwork": relationship( + Paperwork, order_by=paperwork.c.paperwork_id + ) + }, + ) - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer') + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + ) - mapper(Paperwork, paperwork, properties={ - 'pages': relationship(Page, order_by=pages.c.page_id) - 
}) + mapper( + Paperwork, + paperwork, + properties={"pages": relationship(Page, order_by=pages.c.page_id)}, + ) mapper(Page, pages) @@ -1120,15 +1531,22 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): e1 = Engineer(primary_language="java") e2 = Engineer(primary_language="c++") - e1.paperwork = [Paperwork(description="tps report #1", - pages=[ - Page(stuff='report1 page1'), - Page(stuff='report1 page2') - ]), - Paperwork(description="tps report #2", - pages=[ - Page(stuff='report2 page1'), - Page(stuff='report2 page2')])] + e1.paperwork = [ + Paperwork( + description="tps report #1", + pages=[ + Page(stuff="report1 page1"), + Page(stuff="report1 page2"), + ], + ), + Paperwork( + description="tps report #2", + pages=[ + Page(stuff="report2 page1"), + Page(stuff="report2 page2"), + ], + ), + ] e2.paperwork = [Paperwork(description="tps report #3")] sess = create_session() sess.add_all([e1, e2]) @@ -1138,16 +1556,21 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() # use Person.paperwork here just to give the least # amount of context - q = sess.query(Engineer).\ - filter(Engineer.primary_language == 'java').\ - options(subqueryload(Person.paperwork)) + q = ( + sess.query(Engineer) + .filter(Engineer.primary_language == "java") + .options(subqueryload(Person.paperwork)) + ) def go(): - eq_(q.all()[0].paperwork, - [Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], + eq_( + q.all()[0].paperwork, + [ + Paperwork(description="tps report #1"), + Paperwork(description="tps report #2"), + ], + ) - ) self.assert_sql_execution( testing.db, go, @@ -1159,7 +1582,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people JOIN engineers ON " "people.person_id = engineers.engineer_id " "WHERE engineers.primary_language = :primary_language_1", - {"primary_language_1": "java"} + {"primary_language_1": "java"}, ), # ensure we get "people JOIN engineer" here, even though # primary key "people.person_id" is against "Person" @@ -1178,26 +1601,31 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "JOIN paperwork " "ON anon_1.people_person_id = paperwork.person_id " "ORDER BY anon_1.people_person_id, paperwork.paperwork_id", - {"primary_language_1": "java"} - ) + {"primary_language_1": "java"}, + ), ) def test_correct_subquery_existingfrom(self): sess = create_session() # use Person.paperwork here just to give the least # amount of context - q = sess.query(Engineer).\ - filter(Engineer.primary_language == 'java').\ - join(Engineer.paperwork).\ - filter(Paperwork.description == "tps report #2").\ - options(subqueryload(Person.paperwork)) + q = ( + sess.query(Engineer) + .filter(Engineer.primary_language == "java") + .join(Engineer.paperwork) + .filter(Paperwork.description == "tps report #2") + .options(subqueryload(Person.paperwork)) + ) def go(): - eq_(q.one().paperwork, - [Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], + eq_( + q.one().paperwork, + [ + Paperwork(description="tps report #1"), + Paperwork(description="tps report #2"), + ], + ) - ) self.assert_sql_execution( testing.db, go, @@ -1211,8 +1639,10 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "JOIN paperwork ON people.person_id = paperwork.person_id " "WHERE engineers.primary_language = :primary_language_1 " "AND paperwork.description = :description_1", - {"primary_language_1": "java", - "description_1": "tps report #2"} + { + "primary_language_1": "java", + "description_1": "tps report #2", + }, ), 
CompiledSQL( "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " @@ -1228,30 +1658,46 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "JOIN paperwork ON anon_1.people_person_id = " "paperwork.person_id " "ORDER BY anon_1.people_person_id, paperwork.paperwork_id", - {"primary_language_1": "java", - "description_1": "tps report #2"} - ) + { + "primary_language_1": "java", + "description_1": "tps report #2", + }, + ), ) def test_correct_subquery_multilevel(self): sess = create_session() # use Person.paperwork here just to give the least # amount of context - q = sess.query(Engineer).\ - filter(Engineer.primary_language == 'java').\ - options( - subqueryload(Engineer.paperwork).subqueryload(Paperwork.pages)) + q = ( + sess.query(Engineer) + .filter(Engineer.primary_language == "java") + .options( + subqueryload(Engineer.paperwork).subqueryload(Paperwork.pages) + ) + ) def go(): - eq_(q.one().paperwork, - [Paperwork(description="tps report #1", - pages=[Page(stuff='report1 page1'), - Page(stuff='report1 page2')]), - Paperwork(description="tps report #2", - pages=[Page(stuff='report2 page1'), - Page(stuff='report2 page2')])], + eq_( + q.one().paperwork, + [ + Paperwork( + description="tps report #1", + pages=[ + Page(stuff="report1 page1"), + Page(stuff="report1 page2"), + ], + ), + Paperwork( + description="tps report #2", + pages=[ + Page(stuff="report2 page1"), + Page(stuff="report2 page2"), + ], + ), + ], + ) - ) self.assert_sql_execution( testing.db, go, @@ -1263,7 +1709,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people JOIN engineers " "ON people.person_id = engineers.engineer_id " "WHERE engineers.primary_language = :primary_language_1", - {"primary_language_1": "java"} + {"primary_language_1": "java"}, ), CompiledSQL( "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " @@ -1277,7 +1723,7 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "AS anon_1 JOIN paperwork " "ON anon_1.people_person_id = paperwork.person_id " "ORDER BY anon_1.people_person_id, paperwork.paperwork_id", - {"primary_language_1": "java"} + {"primary_language_1": "java"}, ), CompiledSQL( "SELECT pages.page_id AS pages_page_id, " @@ -1292,8 +1738,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "ON anon_1.people_person_id = paperwork_1.person_id " "JOIN pages ON paperwork_1.paperwork_id = pages.paperwork_id " "ORDER BY paperwork_1.paperwork_id, pages.page_id", - {"primary_language_1": "java"} - ) + {"primary_language_1": "java"}, + ), ) def test_correct_subquery_with_polymorphic_no_alias(self): @@ -1301,20 +1747,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() wp = with_polymorphic(Person, [Engineer]) - q = sess.query(wp).\ - options(subqueryload(wp.paperwork)).\ - order_by(Engineer.primary_language.desc()) + q = ( + sess.query(wp) + .options(subqueryload(wp.paperwork)) + .order_by(Engineer.primary_language.desc()) + ) def go(): - eq_(q.first(), + eq_( + q.first(), Engineer( paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - primary_language='java' + Paperwork(description="tps report #2"), + ], + primary_language="java", + ), ) - ) self.assert_sql_execution( testing.db, go, @@ -1325,7 +1775,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "engineers.primary_language AS engineers_primary_language " "FROM people LEFT OUTER JOIN engineers ON people.person_id = " "engineers.engineer_id ORDER BY engineers.primary_language " - "DESC LIMIT :param_1"), + "DESC 
LIMIT :param_1" + ), CompiledSQL( "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " "paperwork.description AS paperwork_description, " @@ -1336,7 +1787,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "engineers.engineer_id ORDER BY engineers.primary_language " "DESC LIMIT :param_1) AS anon_1 JOIN paperwork " "ON anon_1.people_person_id = paperwork.person_id " - "ORDER BY anon_1.people_person_id, paperwork.paperwork_id") + "ORDER BY anon_1.people_person_id, paperwork.paperwork_id" + ), ) def test_correct_subquery_with_polymorphic_alias(self): @@ -1344,20 +1796,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() wp = with_polymorphic(Person, [Engineer], aliased=True) - q = sess.query(wp).\ - options(subqueryload(wp.paperwork)).\ - order_by(wp.Engineer.primary_language.desc()) + q = ( + sess.query(wp) + .options(subqueryload(wp.paperwork)) + .order_by(wp.Engineer.primary_language.desc()) + ) def go(): - eq_(q.first(), + eq_( + q.first(), Engineer( paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - primary_language='java' + Paperwork(description="tps report #2"), + ], + primary_language="java", + ), ) - ) self.assert_sql_execution( testing.db, go, @@ -1376,7 +1832,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people LEFT OUTER JOIN engineers ON people.person_id = " "engineers.engineer_id) AS anon_1 " "ORDER BY anon_1.engineers_primary_language DESC " - "LIMIT :param_1"), + "LIMIT :param_1" + ), CompiledSQL( "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " "paperwork.description AS paperwork_description, " @@ -1398,7 +1855,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "JOIN paperwork " "ON anon_1.anon_2_people_person_id = paperwork.person_id " "ORDER BY anon_1.anon_2_people_person_id, " - "paperwork.paperwork_id") + "paperwork.paperwork_id" + ), ) def test_correct_subquery_with_polymorphic_flat_alias(self): @@ -1406,20 +1864,24 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): sess = create_session() wp = with_polymorphic(Person, [Engineer], aliased=True, flat=True) - q = sess.query(wp).\ - options(subqueryload(wp.paperwork)).\ - order_by(wp.Engineer.primary_language.desc()) + q = ( + sess.query(wp) + .options(subqueryload(wp.paperwork)) + .order_by(wp.Engineer.primary_language.desc()) + ) def go(): - eq_(q.first(), + eq_( + q.first(), Engineer( paperwork=[ Paperwork(description="tps report #1"), - Paperwork(description="tps report #2")], - primary_language='java' + Paperwork(description="tps report #2"), + ], + primary_language="java", + ), ) - ) self.assert_sql_execution( testing.db, go, @@ -1432,7 +1894,8 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "FROM people AS people_1 " "LEFT OUTER JOIN engineers AS engineers_1 " "ON people_1.person_id = engineers_1.engineer_id " - "ORDER BY engineers_1.primary_language DESC LIMIT :param_1"), + "ORDER BY engineers_1.primary_language DESC LIMIT :param_1" + ), CompiledSQL( "SELECT paperwork.paperwork_id AS paperwork_paperwork_id, " "paperwork.description AS paperwork_description, " @@ -1446,46 +1909,75 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): "AS anon_1 JOIN paperwork ON anon_1.people_1_person_id = " "paperwork.person_id ORDER BY anon_1.people_1_person_id, " "paperwork.paperwork_id" - ) + ), ) class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): @classmethod def define_tables(cls, metadata): - Table('companies', metadata, - Column('company_id', Integer, - 
primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - Table('people', metadata, - Column('person_id', Integer, - primary_key=True, - test_needs_autoincrement=True), - Column('company_id', ForeignKey('companies.company_id')), - Column('name', String(50)), - Column('type', String(30))) - - Table('engineers', metadata, - Column('engineer_id', ForeignKey('people.person_id'), - primary_key=True), - Column('primary_language', String(50))) - - Table('machines', metadata, - Column('machine_id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50)), - Column('engineer_id', ForeignKey('engineers.engineer_id')), - Column('machine_type_id', - ForeignKey('machine_type.machine_type_id'))) - - Table('machine_type', metadata, - Column('machine_type_id', - Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) + Table( + "companies", + metadata, + Column( + "company_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) + + Table( + "people", + metadata, + Column( + "person_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("company_id", ForeignKey("companies.company_id")), + Column("name", String(50)), + Column("type", String(30)), + ) + + Table( + "engineers", + metadata, + Column( + "engineer_id", ForeignKey("people.person_id"), primary_key=True + ), + Column("primary_language", String(50)), + ) + + Table( + "machines", + metadata, + Column( + "machine_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + Column("engineer_id", ForeignKey("engineers.engineer_id")), + Column( + "machine_type_id", ForeignKey("machine_type.machine_type_id") + ), + ) + + Table( + "machine_type", + metadata, + Column( + "machine_type_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("name", String(50)), + ) @classmethod def setup_mappers(cls): @@ -1495,24 +1987,36 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): machines = cls.tables.machines machine_type = cls.tables.machine_type - mapper(Company, companies, properties={ - 'employees': relationship(Person, order_by=people.c.person_id) - }) - mapper(Person, people, - polymorphic_on=people.c.type, - polymorphic_identity='person', - with_polymorphic='*') - - mapper(Engineer, engineers, - inherits=Person, - polymorphic_identity='engineer', properties={ - 'machines': relationship(Machine, - order_by=machines.c.machine_id) - }) - - mapper(Machine, machines, properties={ - 'type': relationship(MachineType) - }) + mapper( + Company, + companies, + properties={ + "employees": relationship(Person, order_by=people.c.person_id) + }, + ) + mapper( + Person, + people, + polymorphic_on=people.c.type, + polymorphic_identity="person", + with_polymorphic="*", + ) + + mapper( + Engineer, + engineers, + inherits=Person, + polymorphic_identity="engineer", + properties={ + "machines": relationship( + Machine, order_by=machines.c.machine_id + ) + }, + ) + + mapper( + Machine, machines, properties={"type": relationship(MachineType)} + ) mapper(MachineType, machine_type) @classmethod @@ -1524,50 +2028,53 @@ class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic): @classmethod def _fixture(cls): - mt1 = MachineType(name='mt1') - mt2 = MachineType(name='mt2') + mt1 = MachineType(name="mt1") + mt2 = MachineType(name="mt2") return Company( employees=[ Engineer( - name='e1', + name="e1", machines=[ - 
Machine(name='m1', type=mt1), - Machine(name='m2', type=mt2) - ] + Machine(name="m1", type=mt1), + Machine(name="m2", type=mt2), + ], ), Engineer( - name='e2', + name="e2", machines=[ - Machine(name='m3', type=mt1), - Machine(name='m4', type=mt1) - ] - ) - ]) + Machine(name="m3", type=mt1), + Machine(name="m4", type=mt1), + ], + ), + ] + ) def test_chained_subq_subclass(self): s = Session() q = s.query(Company).options( - subqueryload(Company.employees.of_type(Engineer)). - subqueryload(Engineer.machines). - subqueryload(Machine.type) + subqueryload(Company.employees.of_type(Engineer)) + .subqueryload(Engineer.machines) + .subqueryload(Machine.type) ) def go(): - eq_( - q.all(), - [self._fixture()] - ) + eq_(q.all(), [self._fixture()]) + self.assert_sql_count(testing.db, go, 4) class SelfReferentialTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) + Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + ) def test_basic(self): nodes = self.tables.nodes @@ -1576,23 +2083,27 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, - lazy='subquery', - join_depth=3, order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="subquery", join_depth=3, order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - n2 = Node(data='n2') - n2.append(Node(data='n21')) - n2.children[0].append(Node(data='n211')) - n2.children[0].append(Node(data='n212')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) + n2 = Node(data="n2") + n2.append(Node(data="n21")) + n2.children[0].append(Node(data="n211")) + n2.children[0].append(Node(data="n212")) sess.add(n1) sess.add(n2) @@ -1600,24 +2111,45 @@ class SelfReferentialTest(fixtures.MappedTest): sess.expunge_all() def go(): - d = sess.query(Node).filter(Node.data.in_(['n1', 'n2'])).\ - order_by(Node.data).all() - eq_([Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), - Node(data='n2', children=[ - Node(data='n21', children=[ - Node(data='n211'), - Node(data='n212'), - ]) - ]) - ], d) + d = ( + sess.query(Node) + .filter(Node.data.in_(["n1", "n2"])) + .order_by(Node.data) + .all() + ) + eq_( + [ + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + Node( + data="n2", + children=[ + Node( + data="n21", + children=[ + Node(data="n211"), + Node(data="n212"), + ], + ) + ], + ), + ], + d, + ) + self.assert_sql_count(testing.db, go, 4) def 
test_lazy_fallback_doesnt_affect_eager(self): @@ -1627,20 +2159,25 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='subquery', join_depth=1, - order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="subquery", join_depth=1, order_by=nodes.c.id + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[0].append(Node(data='n111')) - n1.children[0].append(Node(data='n112')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[0].append(Node(data="n111")) + n1.children[0].append(Node(data="n112")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() @@ -1649,19 +2186,16 @@ class SelfReferentialTest(fixtures.MappedTest): allnodes = sess.query(Node).order_by(Node.data).all() n11 = allnodes[1] - eq_(n11.data, 'n11') - eq_([ - Node(data='n111'), - Node(data='n112'), - ], list(n11.children)) + eq_(n11.data, "n11") + eq_([Node(data="n111"), Node(data="n112")], list(n11.children)) n12 = allnodes[4] - eq_(n12.data, 'n12') - eq_([ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ], list(n12.children)) + eq_(n12.data, "n12") + eq_( + [Node(data="n121"), Node(data="n122"), Node(data="n123")], + list(n12.children), + ) + self.assert_sql_count(testing.db, go, 2) def test_with_deferred(self): @@ -1671,40 +2205,55 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='subquery', join_depth=3, - order_by=nodes.c.id), - 'data': deferred(nodes.c.data) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, lazy="subquery", join_depth=3, order_by=nodes.c.id + ), + "data": deferred(nodes.c.data), + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) sess.add(n1) sess.flush() sess.expunge_all() def go(): eq_( - Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), sess.query(Node).order_by(Node.id).first(), ) + self.assert_sql_count(testing.db, go, 6) sess.expunge_all() def go(): - eq_(Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node).options(undefer('data')).order_by(Node.id) - .first()) + eq_( + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), + sess.query(Node) + .options(undefer("data")) + .order_by(Node.id) + .first(), + ) + self.assert_sql_count(testing.db, go, 5) sess.expunge_all() def go(): - eq_(Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node).options(undefer('data'), - undefer('children.data')).first()) + eq_( + Node(data="n1", children=[Node(data="n11"), Node(data="n12")]), + sess.query(Node) + .options(undefer("data"), undefer("children.data")) + .first(), + ) + self.assert_sql_count(testing.db, go, 3) def test_options(self): @@ -1714,33 +2263,50 
@@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, order_by=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={"children": relationship(Node, order_by=nodes.c.id)}, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) sess.add(n1) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter_by(data='n1').order_by(Node.id).\ - options(subqueryload_all('children.children')).first() - eq_(Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), d) + d = ( + sess.query(Node) + .filter_by(data="n1") + .order_by(Node.id) + .options(subqueryload_all("children.children")) + .first() + ) + eq_( + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + d, + ) + self.assert_sql_count(testing.db, go, 3) def test_no_depth(self): @@ -1752,57 +2318,79 @@ class SelfReferentialTest(fixtures.MappedTest): def append(self, node): self.children.append(node) - mapper(Node, nodes, properties={ - 'children': relationship(Node, lazy='subquery') - }) + mapper( + Node, + nodes, + properties={"children": relationship(Node, lazy="subquery")}, + ) sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - n2 = Node(data='n2') - n2.append(Node(data='n21')) + n1 = Node(data="n1") + n1.append(Node(data="n11")) + n1.append(Node(data="n12")) + n1.append(Node(data="n13")) + n1.children[1].append(Node(data="n121")) + n1.children[1].append(Node(data="n122")) + n1.children[1].append(Node(data="n123")) + n2 = Node(data="n2") + n2.append(Node(data="n21")) sess.add(n1) sess.add(n2) sess.flush() sess.expunge_all() def go(): - d = sess.query(Node).filter(Node.data.in_( - ['n1', 'n2'])).order_by(Node.data).all() - eq_([ - Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]), - Node(data='n2', children=[ - Node(data='n21') - ]) - ], d) + d = ( + sess.query(Node) + .filter(Node.data.in_(["n1", "n2"])) + .order_by(Node.data) + .all() + ) + eq_( + [ + Node( + data="n1", + children=[ + Node(data="n11"), + Node( + data="n12", + children=[ + Node(data="n121"), + Node(data="n122"), + Node(data="n123"), + ], + ), + Node(data="n13"), + ], + ), + Node(data="n2", children=[Node(data="n21")]), + ], + d, + ) + self.assert_sql_count(testing.db, go, 4) class InheritanceToRelatedTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('foo', metadata, - Column("id", Integer, primary_key=True), - Column("type", String(50)), - Column("related_id", Integer, 
ForeignKey("related.id"))) - Table("bar", metadata, - Column("id", Integer, ForeignKey('foo.id'), primary_key=True)) - Table("baz", metadata, - Column("id", Integer, ForeignKey('foo.id'), primary_key=True)) - Table("related", metadata, - Column("id", Integer, primary_key=True)) + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("type", String(50)), + Column("related_id", Integer, ForeignKey("related.id")), + ) + Table( + "bar", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + ) + Table( + "baz", + metadata, + Column("id", Integer, ForeignKey("foo.id"), primary_key=True), + ) + Table("related", metadata, Column("id", Integer, primary_key=True)) @classmethod def setup_classes(cls): @@ -1822,96 +2410,116 @@ class InheritanceToRelatedTest(fixtures.MappedTest): def fixtures(cls): return dict( foo=[ - ('id', 'type', 'related_id'), - (1, 'bar', 1), - (2, 'bar', 2), - (3, 'baz', 1), - (4, 'baz', 2), + ("id", "type", "related_id"), + (1, "bar", 1), + (2, "bar", 2), + (3, "baz", 1), + (4, "baz", 2), ], - bar=[ - ('id', ), - (1,), - (2,) - ], - baz=[ - ('id', ), - (3,), - (4,) - ], - related=[ - ('id', ), - (1,), - (2,) - ] + bar=[("id",), (1,), (2,)], + baz=[("id",), (3,), (4,)], + related=[("id",), (1,), (2,)], ) @classmethod def setup_mappers(cls): - mapper(cls.classes.Foo, cls.tables.foo, properties={ - 'related': relationship(cls.classes.Related) - }, polymorphic_on=cls.tables.foo.c.type) - mapper(cls.classes.Bar, cls.tables.bar, polymorphic_identity='bar', - inherits=cls.classes.Foo) - mapper(cls.classes.Baz, cls.tables.baz, polymorphic_identity='baz', - inherits=cls.classes.Foo) + mapper( + cls.classes.Foo, + cls.tables.foo, + properties={"related": relationship(cls.classes.Related)}, + polymorphic_on=cls.tables.foo.c.type, + ) + mapper( + cls.classes.Bar, + cls.tables.bar, + polymorphic_identity="bar", + inherits=cls.classes.Foo, + ) + mapper( + cls.classes.Baz, + cls.tables.baz, + polymorphic_identity="baz", + inherits=cls.classes.Foo, + ) mapper(cls.classes.Related, cls.tables.related) def test_caches_query_per_base_subq(self): - Foo, Bar, Baz, Related = self.classes.Foo, self.classes.Bar, \ - self.classes.Baz, self.classes.Related + Foo, Bar, Baz, Related = ( + self.classes.Foo, + self.classes.Bar, + self.classes.Baz, + self.classes.Related, + ) s = Session(testing.db) def go(): eq_( - s.query(Foo).with_polymorphic([Bar, Baz]). - order_by(Foo.id). - options(subqueryload(Foo.related)).all(), + s.query(Foo) + .with_polymorphic([Bar, Baz]) + .order_by(Foo.id) + .options(subqueryload(Foo.related)) + .all(), [ Bar(id=1, related=Related(id=1)), Bar(id=2, related=Related(id=2)), Baz(id=3, related=Related(id=1)), - Baz(id=4, related=Related(id=2)) - ] + Baz(id=4, related=Related(id=2)), + ], ) + self.assert_sql_count(testing.db, go, 2) def test_caches_query_per_base_joined(self): # technically this should be in test_eager_relations - Foo, Bar, Baz, Related = self.classes.Foo, self.classes.Bar, \ - self.classes.Baz, self.classes.Related + Foo, Bar, Baz, Related = ( + self.classes.Foo, + self.classes.Bar, + self.classes.Baz, + self.classes.Related, + ) s = Session(testing.db) def go(): eq_( - s.query(Foo).with_polymorphic([Bar, Baz]). - order_by(Foo.id). 
- options(joinedload(Foo.related)).all(), + s.query(Foo) + .with_polymorphic([Bar, Baz]) + .order_by(Foo.id) + .options(joinedload(Foo.related)) + .all(), [ Bar(id=1, related=Related(id=1)), Bar(id=2, related=Related(id=2)), Baz(id=3, related=Related(id=1)), - Baz(id=4, related=Related(id=2)) - ] + Baz(id=4, related=Related(id=2)), + ], ) + self.assert_sql_count(testing.db, go, 1) class CyclicalInheritingEagerTestOne(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', String(30)), - Column('type', String(30))) - - Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', String(30)), - Column('type', String(30)), - Column('t1.id', Integer, ForeignKey('t1.c1'))) + Table( + "t1", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", String(30)), + Column("type", String(30)), + ) + + Table( + "t2", + metadata, + Column( + "c1", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("c2", String(30)), + Column("type", String(30)), + Column("t1.id", Integer, ForeignKey("t1.c1")), + ) def test_basic(self): t2, t1 = self.tables.t2, self.tables.t1 @@ -1928,40 +2536,51 @@ class CyclicalInheritingEagerTestOne(fixtures.MappedTest): class SubT2(T2): pass - mapper(T, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1') - mapper(SubT, None, inherits=T, polymorphic_identity='subt1', - properties={'t2s': relationship( - SubT2, lazy='subquery', - backref=sa.orm.backref('subt', lazy='subquery'))}) - mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') - mapper(SubT2, None, inherits=T2, polymorphic_identity='subt2') + mapper(T, t1, polymorphic_on=t1.c.type, polymorphic_identity="t1") + mapper( + SubT, + None, + inherits=T, + polymorphic_identity="subt1", + properties={ + "t2s": relationship( + SubT2, + lazy="subquery", + backref=sa.orm.backref("subt", lazy="subquery"), + ) + }, + ) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity="t2") + mapper(SubT2, None, inherits=T2, polymorphic_identity="subt2") # testing a particular endless loop condition in eager load setup create_session().query(SubT).all() -class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' +class CyclicalInheritingEagerTestTwo( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class PersistentObject(Base): - __tablename__ = 'persistent' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "persistent" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) class Movie(PersistentObject): - __tablename__ = 'movie' - id = Column(Integer, ForeignKey('persistent.id'), primary_key=True) - director_id = Column(Integer, ForeignKey('director.id')) + __tablename__ = "movie" + id = Column(Integer, ForeignKey("persistent.id"), primary_key=True) + director_id = Column(Integer, ForeignKey("director.id")) title = Column(String(50)) class Director(PersistentObject): - __tablename__ = 'director' - id = Column(Integer, ForeignKey('persistent.id'), primary_key=True) + __tablename__ = "director" + id = Column(Integer, ForeignKey("persistent.id"), primary_key=True) movies = relationship("Movie", foreign_keys=Movie.director_id) name = 
Column(String(50)) @@ -1970,26 +2589,27 @@ class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, s = create_session() - ctx = s.query(Director).options(subqueryload('*'))._compile_context() - - q = ctx.attributes[('subquery', - (inspect(Director), - inspect(Director).attrs.movies))] - self.assert_compile(q, - "SELECT movie.id AS movie_id, " - "persistent.id AS persistent_id, " - "movie.director_id AS movie_director_id, " - "movie.title AS movie_title, " - "anon_1.director_id AS anon_1_director_id " - "FROM (SELECT director.id AS director_id " - "FROM persistent JOIN director " - "ON persistent.id = director.id) AS anon_1 " - "JOIN (persistent JOIN movie " - "ON persistent.id = movie.id) " - "ON anon_1.director_id = movie.director_id " - "ORDER BY anon_1.director_id", - dialect="default" - ) + ctx = s.query(Director).options(subqueryload("*"))._compile_context() + + q = ctx.attributes[ + ("subquery", (inspect(Director), inspect(Director).attrs.movies)) + ] + self.assert_compile( + q, + "SELECT movie.id AS movie_id, " + "persistent.id AS persistent_id, " + "movie.director_id AS movie_director_id, " + "movie.title AS movie_title, " + "anon_1.director_id AS anon_1_director_id " + "FROM (SELECT director.id AS director_id " + "FROM persistent JOIN director " + "ON persistent.id = director.id) AS anon_1 " + "JOIN (persistent JOIN movie " + "ON persistent.id = movie.id) " + "ON anon_1.director_id = movie.director_id " + "ORDER BY anon_1.director_id", + dialect="default", + ) def test_integrate(self): Director = self.classes.Director @@ -2005,15 +2625,16 @@ class CyclicalInheritingEagerTestTwo(fixtures.DeclarativeMappedTest, session.commit() session.close_all() - d = session.query(Director).options(subqueryload('*')).first() + d = session.query(Director).options(subqueryload("*")).first() assert len(list(session)) == 3 -class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' +class SubqueryloadDistinctTest( + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod @@ -2021,33 +2642,37 @@ class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, Base = cls.DeclarativeBasic class Director(Base): - __tablename__ = 'director' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "director" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(50)) class DirectorPhoto(Base): - __tablename__ = 'director_photo' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + __tablename__ = "director_photo" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) path = Column(String(255)) - director_id = Column(Integer, ForeignKey('director.id')) + director_id = Column(Integer, ForeignKey("director.id")) director = relationship(Director, backref="photos") class Movie(Base): - __tablename__ = 'movie' - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) - director_id = Column(Integer, ForeignKey('director.id')) + __tablename__ = "movie" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + director_id = Column(Integer, ForeignKey("director.id")) director = relationship(Director, backref="movies") title = Column(String(50)) credits = relationship("Credit", backref="movie") class Credit(Base): - __tablename__ = 'credit' - id = Column(Integer, 
primary_key=True, - test_needs_autoincrement=True) - movie_id = Column(Integer, ForeignKey('movie.id')) + __tablename__ = "credit" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + movie_id = Column(Integer, ForeignKey("movie.id")) @classmethod def insert_data(cls): @@ -2056,11 +2681,12 @@ class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, DirectorPhoto = cls.classes.DirectorPhoto Credit = cls.classes.Credit - d = Director(name='Woody Allen') - d.photos = [DirectorPhoto(path='/1.jpg'), - DirectorPhoto(path='/2.jpg')] - d.movies = [Movie(title='Manhattan', credits=[Credit(), Credit()]), - Movie(title='Sweet and Lowdown', credits=[Credit()])] + d = Director(name="Woody Allen") + d.photos = [DirectorPhoto(path="/1.jpg"), DirectorPhoto(path="/2.jpg")] + d.movies = [ + Movie(title="Manhattan", credits=[Credit(), Credit()]), + Movie(title="Sweet and Lowdown", credits=[Credit()]), + ] sess = create_session() sess.add_all([d]) sess.flush() @@ -2073,9 +2699,7 @@ class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, self._run_test_m2o(None, True) self._run_test_m2o(None, False) - def _run_test_m2o(self, - director_strategy_level, - photo_strategy_level): + def _run_test_m2o(self, director_strategy_level, photo_strategy_level): # test where the innermost is m2o, e.g. # Movie->director @@ -2093,28 +2717,24 @@ class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, s = create_session() - q = ( - s.query(Movie) - .options( - subqueryload(Movie.director) - .subqueryload(Director.photos) - ) + q = s.query(Movie).options( + subqueryload(Movie.director).subqueryload(Director.photos) ) ctx = q._compile_context() q2 = ctx.attributes[ - ('subquery', (inspect(Movie), inspect(Movie).attrs.director)) + ("subquery", (inspect(Movie), inspect(Movie).attrs.director)) ] self.assert_compile( q2, - 'SELECT director.id AS director_id, ' - 'director.name AS director_name, ' - 'anon_1.movie_director_id AS anon_1_movie_director_id ' - 'FROM (SELECT%s movie.director_id AS movie_director_id ' - 'FROM movie) AS anon_1 ' - 'JOIN director ON director.id = anon_1.movie_director_id ' - 'ORDER BY anon_1.movie_director_id' % ( - " DISTINCT" if expect_distinct else "") + "SELECT director.id AS director_id, " + "director.name AS director_name, " + "anon_1.movie_director_id AS anon_1_movie_director_id " + "FROM (SELECT%s movie.director_id AS movie_director_id " + "FROM movie) AS anon_1 " + "JOIN director ON director.id = anon_1.movie_director_id " + "ORDER BY anon_1.movie_director_id" + % (" DISTINCT" if expect_distinct else ""), ) ctx2 = q2._compile_context() @@ -2122,48 +2742,49 @@ class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, rows = result.fetchall() if expect_distinct: - eq_(rows, [ - (1, 'Woody Allen', 1), - ]) + eq_(rows, [(1, "Woody Allen", 1)]) else: - eq_(rows, [ - (1, 'Woody Allen', 1), (1, 'Woody Allen', 1), - ]) + eq_(rows, [(1, "Woody Allen", 1), (1, "Woody Allen", 1)]) q3 = ctx2.attributes[ - ('subquery', (inspect(Director), inspect(Director).attrs.photos)) + ("subquery", (inspect(Director), inspect(Director).attrs.photos)) ] self.assert_compile( q3, - 'SELECT director_photo.id AS director_photo_id, ' - 'director_photo.path AS director_photo_path, ' - 'director_photo.director_id AS director_photo_director_id, ' - 'director_1.id AS director_1_id ' - 'FROM (SELECT%s movie.director_id AS movie_director_id ' - 'FROM movie) AS anon_1 ' - 'JOIN director AS director_1 ' - 'ON director_1.id = anon_1.movie_director_id ' - 'JOIN director_photo ' - 'ON 
director_1.id = director_photo.director_id ' - 'ORDER BY director_1.id' % ( - " DISTINCT" if expect_distinct else "") + "SELECT director_photo.id AS director_photo_id, " + "director_photo.path AS director_photo_path, " + "director_photo.director_id AS director_photo_director_id, " + "director_1.id AS director_1_id " + "FROM (SELECT%s movie.director_id AS movie_director_id " + "FROM movie) AS anon_1 " + "JOIN director AS director_1 " + "ON director_1.id = anon_1.movie_director_id " + "JOIN director_photo " + "ON director_1.id = director_photo.director_id " + "ORDER BY director_1.id" + % (" DISTINCT" if expect_distinct else ""), ) result = s.execute(q3) rows = result.fetchall() if expect_distinct: - eq_(set(tuple(t) for t in rows), set([ - (1, '/1.jpg', 1, 1), - (2, '/2.jpg', 1, 1), - ])) + eq_( + set(tuple(t) for t in rows), + set([(1, "/1.jpg", 1, 1), (2, "/2.jpg", 1, 1)]), + ) else: # oracle might not order the way we expect here - eq_(set(tuple(t) for t in rows), set([ - (1, '/1.jpg', 1, 1), - (2, '/2.jpg', 1, 1), - (1, '/1.jpg', 1, 1), - (2, '/2.jpg', 1, 1), - ])) + eq_( + set(tuple(t) for t in rows), + set( + [ + (1, "/1.jpg", 1, 1), + (2, "/2.jpg", 1, 1), + (1, "/1.jpg", 1, 1), + (2, "/2.jpg", 1, 1), + ] + ), + ) movies = q.all() @@ -2181,31 +2802,22 @@ class SubqueryloadDistinctTest(fixtures.DeclarativeMappedTest, s = create_session() - q = ( - s.query(Credit) - .options( - subqueryload(Credit.movie) - .subqueryload(Movie.director) - ) + q = s.query(Credit).options( + subqueryload(Credit.movie).subqueryload(Movie.director) ) ctx = q._compile_context() q2 = ctx.attributes[ - ('subquery', (inspect(Credit), Credit.movie.property)) + ("subquery", (inspect(Credit), Credit.movie.property)) ] ctx2 = q2._compile_context() q3 = ctx2.attributes[ - ('subquery', (inspect(Movie), Movie.director.property)) + ("subquery", (inspect(Movie), Movie.director.property)) ] result = s.execute(q3) - eq_( - result.fetchall(), - [ - (1, 'Woody Allen', 1), (1, 'Woody Allen', 1), - ] - ) + eq_(result.fetchall(), [(1, "Woody Allen", 1), (1, "Woody Allen", 1)]) class JoinedNoLoadConflictTest(fixtures.DeclarativeMappedTest): @@ -2216,27 +2828,29 @@ class JoinedNoLoadConflictTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class Parent(ComparableEntity, Base): - __tablename__ = 'parent' + __tablename__ = "parent" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(20)) - children = relationship('Child', - back_populates='parent', - lazy='noload' - ) + children = relationship( + "Child", back_populates="parent", lazy="noload" + ) class Child(ComparableEntity, Base): - __tablename__ = 'child' + __tablename__ = "child" - id = Column(Integer, primary_key=True, - test_needs_autoincrement=True) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) name = Column(String(20)) - parent_id = Column(Integer, ForeignKey('parent.id')) + parent_id = Column(Integer, ForeignKey("parent.id")) parent = relationship( - 'Parent', back_populates='children', lazy='joined') + "Parent", back_populates="children", lazy="joined" + ) @classmethod def insert_data(cls): @@ -2244,7 +2858,7 @@ class JoinedNoLoadConflictTest(fixtures.DeclarativeMappedTest): Child = cls.classes.Child s = Session() - s.add(Parent(name='parent', children=[Child(name='c1')])) + s.add(Parent(name="parent", children=[Child(name="c1")])) s.commit() def test_subqueryload_on_joined_noload(self): @@ -2257,17 
+2871,14 @@ class JoinedNoLoadConflictTest(fixtures.DeclarativeMappedTest): # Parent->subqueryload->Child->joinedload->parent->noload->children. # the actual subqueryload has to emit *after* we've started populating # Parent->subqueryload->child. - parent = s.query(Parent).options([subqueryload('children')]).first() - eq_( - parent.children, - [Child(name='c1')] - ) + parent = s.query(Parent).options([subqueryload("children")]).first() + eq_(parent.children, [Child(name="c1")]) class SelfRefInheritanceAliasedTest( - fixtures.DeclarativeMappedTest, - testing.AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL +): + __dialect__ = "default" @classmethod def setup_classes(cls): @@ -2280,7 +2891,8 @@ class SelfRefInheritanceAliasedTest( foo_id = Column(Integer, ForeignKey("foo.id")) foo = relationship( - lambda: Foo, foreign_keys=foo_id, remote_side=id) + lambda: Foo, foreign_keys=foo_id, remote_side=id + ) __mapper_args__ = { "polymorphic_on": type, @@ -2288,13 +2900,11 @@ class SelfRefInheritanceAliasedTest( } class Bar(Foo): - __mapper_args__ = { - "polymorphic_identity": "bar", - } + __mapper_args__ = {"polymorphic_identity": "bar"} @classmethod def insert_data(cls): - Foo, Bar = cls.classes('Foo', 'Bar') + Foo, Bar = cls.classes("Foo", "Bar") session = Session() target = Bar(id=1) @@ -2303,15 +2913,18 @@ class SelfRefInheritanceAliasedTest( session.commit() def test_twolevel_subquery_w_polymorphic(self): - Foo, Bar = self.classes('Foo', 'Bar') + Foo, Bar = self.classes("Foo", "Bar") r = with_polymorphic(Foo, "*", aliased=True) attr1 = Foo.foo.of_type(r) attr2 = r.foo s = Session() - q = s.query(Foo).filter(Foo.id == 2).options( - subqueryload(attr1).subqueryload(attr2)) + q = ( + s.query(Foo) + .filter(Foo.id == 2) + .options(subqueryload(attr1).subqueryload(attr2)) + ) self.assert_sql_execution( testing.db, @@ -2319,7 +2932,7 @@ class SelfRefInheritanceAliasedTest( CompiledSQL( "SELECT foo.id AS foo_id_1, foo.type AS foo_type, " "foo.foo_id AS foo_foo_id FROM foo WHERE foo.id = :id_1", - [{'id_1': 2}] + [{"id_1": 2}], ), CompiledSQL( "SELECT foo_1.id AS foo_1_id, foo_1.type AS foo_1_type, " @@ -2329,7 +2942,7 @@ class SelfRefInheritanceAliasedTest( "FROM foo WHERE foo.id = :id_1) AS anon_1 " "JOIN foo AS foo_1 ON foo_1.id = anon_1.foo_foo_id " "ORDER BY anon_1.foo_foo_id", - {'id_1': 2} + {"id_1": 2}, ), CompiledSQL( "SELECT foo.id AS foo_id_1, foo.type AS foo_type, " @@ -2338,7 +2951,7 @@ class SelfRefInheritanceAliasedTest( "WHERE foo.id = :id_1) AS anon_1 " "JOIN foo AS foo_1 ON foo_1.id = anon_1.foo_foo_id " "JOIN foo ON foo.id = foo_1.foo_id ORDER BY foo_1.foo_id", - {'id_1': 2} + {"id_1": 2}, ), ) @@ -2349,28 +2962,28 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class A(Base): - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) - a2_id = Column(ForeignKey('a2.id')) + b_id = Column(ForeignKey("b.id")) + a2_id = Column(ForeignKey("a2.id")) a2 = relationship("A2") b = relationship("B") class A2(Base): - __tablename__ = 'a2' + __tablename__ = "a2" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) b = relationship("B") class B(Base): - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) - c1_m2o_id = Column(ForeignKey('c1_m2o.id')) - c2_m2o_id = Column(ForeignKey('c2_m2o.id')) + c1_m2o_id = Column(ForeignKey("c1_m2o.id")) + 
c2_m2o_id = Column(ForeignKey("c2_m2o.id")) c1_o2m = relationship("C1o2m") c2_o2m = relationship("C2o2m") @@ -2378,49 +2991,44 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): c2_m2o = relationship("C2m2o") class C1o2m(Base): - __tablename__ = 'c1_o2m' + __tablename__ = "c1_o2m" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) class C2o2m(Base): - __tablename__ = 'c2_o2m' + __tablename__ = "c2_o2m" id = Column(Integer, primary_key=True) - b_id = Column(ForeignKey('b.id')) + b_id = Column(ForeignKey("b.id")) class C1m2o(Base): - __tablename__ = 'c1_m2o' + __tablename__ = "c1_m2o" id = Column(Integer, primary_key=True) class C2m2o(Base): - __tablename__ = 'c2_m2o' + __tablename__ = "c2_m2o" id = Column(Integer, primary_key=True) @classmethod def insert_data(cls): A, A2, B, C1o2m, C2o2m, C1m2o, C2m2o = cls.classes( - 'A', 'A2', 'B', 'C1o2m', 'C2o2m', 'C1m2o', 'C2m2o' + "A", "A2", "B", "C1o2m", "C2o2m", "C1m2o", "C2m2o" ) s = Session() b = B( - c1_o2m=[C1o2m()], - c2_o2m=[C2o2m()], - c1_m2o=C1m2o(), - c2_m2o=C2m2o(), + c1_o2m=[C1o2m()], c2_o2m=[C2o2m()], c1_m2o=C1m2o(), c2_m2o=C2m2o() ) s.add(A(b=b, a2=A2(b=b))) s.commit() def test_o2m(self): - A, A2, B, C1o2m, C2o2m = self.classes( - 'A', 'A2', 'B', 'C1o2m', 'C2o2m' - ) + A, A2, B, C1o2m, C2o2m = self.classes("A", "A2", "B", "C1o2m", "C2o2m") s = Session() @@ -2432,18 +3040,16 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): q = s.query(A).options( joinedload(A.b).subqueryload(B.c2_o2m), - joinedload(A.a2).joinedload(A2.b).subqueryload(B.c1_o2m) + joinedload(A.a2).joinedload(A2.b).subqueryload(B.c1_o2m), ) a1 = q.all()[0] - is_true('c1_o2m' in a1.b.__dict__) - is_true('c2_o2m' in a1.b.__dict__) + is_true("c1_o2m" in a1.b.__dict__) + is_true("c2_o2m" in a1.b.__dict__) def test_m2o(self): - A, A2, B, C1m2o, C2m2o = self.classes( - 'A', 'A2', 'B', 'C1m2o', 'C2m2o' - ) + A, A2, B, C1m2o, C2m2o = self.classes("A", "A2", "B", "C1m2o", "C2m2o") s = Session() @@ -2455,9 +3061,9 @@ class TestExistingRowPopulation(fixtures.DeclarativeMappedTest): q = s.query(A).options( joinedload(A.b).subqueryload(B.c2_m2o), - joinedload(A.a2).joinedload(A2.b).subqueryload(B.c1_m2o) + joinedload(A.a2).joinedload(A2.b).subqueryload(B.c1_m2o), ) a1 = q.all()[0] - is_true('c1_m2o' in a1.b.__dict__) - is_true('c2_m2o' in a1.b.__dict__) + is_true("c1_m2o" in a1.b.__dict__) + is_true("c2_m2o" in a1.b.__dict__) diff --git a/test/orm/test_sync.py b/test/orm/test_sync.py index beb30d2423..d1dad43c8b 100644 --- a/test/orm/test_sync.py +++ b/test/orm/test_sync.py @@ -4,9 +4,18 @@ from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures from sqlalchemy.testing import fixtures from sqlalchemy import Integer, String, ForeignKey, func -from sqlalchemy.orm import mapper, relationship, backref, \ - create_session, unitofwork, attributes,\ - Session, class_mapper, sync, exc as orm_exc +from sqlalchemy.orm import ( + mapper, + relationship, + backref, + create_session, + unitofwork, + attributes, + Session, + class_mapper, + sync, + exc as orm_exc, +) class AssertsUOW(object): @@ -22,17 +31,23 @@ class AssertsUOW(object): return uow -class SyncTest(fixtures.MappedTest, - testing.AssertsExecutionResults, AssertsUOW): - +class SyncTest( + fixtures.MappedTest, testing.AssertsExecutionResults, AssertsUOW +): @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer)) - Table('t2', 
metadata, - Column('id', Integer, ForeignKey('t1.id'), primary_key=True), - Column('t1id', Integer, ForeignKey('t1.id'))) + Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + ) + Table( + "t2", + metadata, + Column("id", Integer, ForeignKey("t1.id"), primary_key=True), + Column("t1id", Integer, ForeignKey("t1.id")), + ) @classmethod def setup_classes(cls): @@ -56,29 +71,32 @@ class SyncTest(fixtures.MappedTest, self.a1 = a1 = A() self.b1 = b1 = B() uowcommit = self._get_test_uow(session) - return uowcommit,\ - attributes.instance_state(a1),\ - attributes.instance_state(b1),\ - a_mapper, b_mapper + return ( + uowcommit, + attributes.instance_state(a1), + attributes.instance_state(b1), + a_mapper, + b_mapper, + ) def test_populate(self): uowcommit, a1, b1, a_mapper, b_mapper = self._fixture() pairs = [(a_mapper.c.id, b_mapper.c.id)] a1.obj().id = 7 - assert 'id' not in b1.obj().__dict__ + assert "id" not in b1.obj().__dict__ sync.populate(a1, a_mapper, b1, b_mapper, pairs, uowcommit, False) eq_(b1.obj().id, 7) - eq_(b1.obj().__dict__['id'], 7) + eq_(b1.obj().__dict__["id"], 7) assert ("pk_cascaded", b1, b_mapper.c.id) not in uowcommit.attributes def test_populate_flag_cascaded(self): uowcommit, a1, b1, a_mapper, b_mapper = self._fixture() pairs = [(a_mapper.c.id, b_mapper.c.id)] a1.obj().id = 7 - assert 'id' not in b1.obj().__dict__ + assert "id" not in b1.obj().__dict__ sync.populate(a1, a_mapper, b1, b_mapper, pairs, uowcommit, True) eq_(b1.obj().id, 7) - eq_(b1.obj().__dict__['id'], 7) + eq_(b1.obj().__dict__["id"], 7) eq_(uowcommit.attributes[("pk_cascaded", b1, b_mapper.c.id)], True) def test_populate_unmapped_source(self): @@ -94,12 +112,13 @@ class SyncTest(fixtures.MappedTest, b1, b_mapper, pairs, - uowcommit, False + uowcommit, + False, ) def test_populate_unmapped_dest(self): uowcommit, a1, b1, a_mapper, b_mapper = self._fixture() - pairs = [(a_mapper.c.id, a_mapper.c.id,)] + pairs = [(a_mapper.c.id, a_mapper.c.id)] assert_raises_message( orm_exc.UnmappedColumnError, r"Can't execute sync rule for destination " @@ -111,38 +130,45 @@ class SyncTest(fixtures.MappedTest, b1, b_mapper, pairs, - uowcommit, False + uowcommit, + False, ) def test_clear(self): uowcommit, a1, b1, a_mapper, b_mapper = self._fixture() - pairs = [(a_mapper.c.id, b_mapper.c.t1id,)] + pairs = [(a_mapper.c.id, b_mapper.c.t1id)] b1.obj().t1id = 8 - eq_(b1.obj().__dict__['t1id'], 8) + eq_(b1.obj().__dict__["t1id"], 8) sync.clear(b1, b_mapper, pairs) - eq_(b1.obj().__dict__['t1id'], None) + eq_(b1.obj().__dict__["t1id"], None) def test_clear_pk(self): uowcommit, a1, b1, a_mapper, b_mapper = self._fixture() - pairs = [(a_mapper.c.id, b_mapper.c.id,)] + pairs = [(a_mapper.c.id, b_mapper.c.id)] b1.obj().id = 8 - eq_(b1.obj().__dict__['id'], 8) + eq_(b1.obj().__dict__["id"], 8) assert_raises_message( AssertionError, "Dependency rule tried to blank-out primary key " "column 't2.id' on instance ' has a NULL " "identity key. 
If this is an auto-generated value, " "check that the database table allows generation ", - s.commit + s.commit, ) def test_dont_complain_if_no_update(self): @@ -2775,20 +3369,26 @@ class EnsurePKSortableTest(fixtures.MappedTest): class MyNotSortableEnum(SomeEnum): __members__ = OrderedDict() - one = MySortableEnum('one', 1) - two = MySortableEnum('two', 2) - three = MyNotSortableEnum('three', 3) - four = MyNotSortableEnum('four', 4) + one = MySortableEnum("one", 1) + two = MySortableEnum("two", 2) + three = MyNotSortableEnum("three", 3) + four = MyNotSortableEnum("four", 4) @classmethod def define_tables(cls, metadata): - Table('t1', metadata, - Column('id', Enum(cls.MySortableEnum), primary_key=True), - Column('data', String(10))) + Table( + "t1", + metadata, + Column("id", Enum(cls.MySortableEnum), primary_key=True), + Column("data", String(10)), + ) - Table('t2', metadata, - Column('id', Enum(cls.MyNotSortableEnum), primary_key=True), - Column('data', String(10))) + Table( + "t2", + metadata, + Column("id", Enum(cls.MyNotSortableEnum), primary_key=True), + Column("data", String(10)), + ) @classmethod def setup_classes(cls): @@ -2810,15 +3410,15 @@ class EnsurePKSortableTest(fixtures.MappedTest): s.add_all([a, b]) s.commit() - a.data = 'bar' - b.data = 'foo' + a.data = "bar" + b.data = "foo" if sa.util.py3k: assert_raises_message( sa.exc.InvalidRequestError, r"Could not sort objects by primary key; primary key values " r"must be sortable in Python \(was: '<' not supported between " r"instances of 'MyNotSortableEnum' and 'MyNotSortableEnum'\)", - s.flush + s.flush, ) else: s.flush() @@ -2831,6 +3431,6 @@ class EnsurePKSortableTest(fixtures.MappedTest): s.add_all([a, b]) s.commit() - a.data = 'bar' - b.data = 'foo' + a.data = "bar" + b.data = "foo" s.commit() diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 0d19bee75e..f5c39fee73 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -5,18 +5,32 @@ from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures from sqlalchemy import exc, util from sqlalchemy.testing import fixtures, config -from sqlalchemy import Integer, String, ForeignKey, func, \ - literal, FetchedValue, text, select -from sqlalchemy.orm import mapper, relationship, backref, \ - create_session, unitofwork, attributes,\ - Session, exc as orm_exc +from sqlalchemy import ( + Integer, + String, + ForeignKey, + func, + literal, + FetchedValue, + text, + select, +) +from sqlalchemy.orm import ( + mapper, + relationship, + backref, + create_session, + unitofwork, + attributes, + Session, + exc as orm_exc, +) from sqlalchemy.testing.mock import Mock, patch from sqlalchemy.testing.assertsql import AllOf, CompiledSQL from sqlalchemy import event class AssertsUOW(object): - def _get_test_uow(self, session): uow = unitofwork.UOWTransaction(session) deleted = set(session._deleted) @@ -36,61 +50,59 @@ class AssertsUOW(object): class UOWTest( - _fixtures.FixtureTest, - testing.AssertsExecutionResults, AssertsUOW): + _fixtures.FixtureTest, testing.AssertsExecutionResults, AssertsUOW +): run_inserts = None class RudimentaryFlushTest(UOWTest): - def test_one_to_many_save(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + 
) + + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sess = create_session() - a1, a2 = Address(email_address='a1'), Address(email_address='a2') - u1 = User(name='u1', addresses=[a1, a2]) + a1, a2 = Address(email_address="a1"), Address(email_address="a2") + u1 = User(name="u1", addresses=[a1, a2]) sess.add(u1) self.assert_sql_execution( testing.db, sess.flush, CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - {'name': 'u1'} + "INSERT INTO users (name) VALUES (:name)", {"name": "u1"} ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", - lambda ctx: {'email_address': 'a1', 'user_id': u1.id} + lambda ctx: {"email_address": "a1", "user_id": u1.id}, ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", - lambda ctx: {'email_address': 'a2', 'user_id': u1.id} + lambda ctx: {"email_address": "a2", "user_id": u1.id}, ), ) def test_one_to_many_delete_all(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sess = create_session() - a1, a2 = Address(email_address='a1'), Address(email_address='a2') - u1 = User(name='u1', addresses=[a1, a2]) + a1, a2 = Address(email_address="a1"), Address(email_address="a2") + u1 = User(name="u1", addresses=[a1, a2]) sess.add(u1) sess.flush() @@ -102,27 +114,26 @@ class RudimentaryFlushTest(UOWTest): sess.flush, CompiledSQL( "DELETE FROM addresses WHERE addresses.id = :id", - [{'id': a1.id}, {'id': a2.id}] + [{"id": a1.id}, {"id": a2.id}], ), CompiledSQL( - "DELETE FROM users WHERE users.id = :id", - {'id': u1.id} + "DELETE FROM users WHERE users.id = :id", {"id": u1.id} ), ) def test_one_to_many_delete_parent(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sess = create_session() - a1, a2 = Address(email_address='a1'), Address(email_address='a2') - u1 = User(name='u1', addresses=[a1, a2]) + a1, a2 = Address(email_address="a1"), Address(email_address="a2") + u1 = User(name="u1", addresses=[a1, a2]) sess.add(u1) sess.flush() @@ -134,67 +145,69 @@ class RudimentaryFlushTest(UOWTest): "UPDATE addresses SET user_id=:user_id WHERE " "addresses.id = :addresses_id", lambda ctx: [ - {'addresses_id': a1.id, 'user_id': None}, - {'addresses_id': a2.id, 'user_id': None} - ] + {"addresses_id": a1.id, "user_id": None}, + {"addresses_id": a2.id, "user_id": None}, + ], ), CompiledSQL( - "DELETE FROM users WHERE users.id = :id", - {'id': u1.id} + "DELETE FROM users WHERE users.id = :id", {"id": u1.id} ), ) def test_many_to_one_save(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + 
self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - u1 = User(name='u1') - a1, a2 = Address(email_address='a1', user=u1), \ - Address(email_address='a2', user=u1) + u1 = User(name="u1") + a1, a2 = ( + Address(email_address="a1", user=u1), + Address(email_address="a2", user=u1), + ) sess.add_all([a1, a2]) self.assert_sql_execution( testing.db, sess.flush, CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - {'name': 'u1'} + "INSERT INTO users (name) VALUES (:name)", {"name": "u1"} ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", - lambda ctx: {'email_address': 'a1', 'user_id': u1.id} + lambda ctx: {"email_address": "a1", "user_id": u1.id}, ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", - lambda ctx: {'email_address': 'a2', 'user_id': u1.id} + lambda ctx: {"email_address": "a2", "user_id": u1.id}, ), ) def test_many_to_one_delete_all(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - u1 = User(name='u1') - a1, a2 = Address(email_address='a1', user=u1), \ - Address(email_address='a2', user=u1) + u1 = User(name="u1") + a1, a2 = ( + Address(email_address="a1", user=u1), + Address(email_address="a2", user=u1), + ) sess.add_all([a1, a2]) sess.flush() @@ -206,29 +219,30 @@ class RudimentaryFlushTest(UOWTest): sess.flush, CompiledSQL( "DELETE FROM addresses WHERE addresses.id = :id", - [{'id': a1.id}, {'id': a2.id}] + [{"id": a1.id}, {"id": a2.id}], ), CompiledSQL( - "DELETE FROM users WHERE users.id = :id", - {'id': u1.id} + "DELETE FROM users WHERE users.id = :id", {"id": u1.id} ), ) def test_many_to_one_delete_target(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - u1 = User(name='u1') - a1, a2 = Address(email_address='a1', user=u1), \ - Address(email_address='a2', user=u1) + u1 = User(name="u1") + a1, a2 = ( + Address(email_address="a1", user=u1), + Address(email_address="a2", user=u1), + ) sess.add_all([a1, a2]) sess.flush() @@ -241,30 +255,31 @@ class RudimentaryFlushTest(UOWTest): "UPDATE addresses SET user_id=:user_id WHERE " "addresses.id = :addresses_id", lambda ctx: [ - {'addresses_id': a1.id, 'user_id': None}, - {'addresses_id': a2.id, 'user_id': None} - ] + {"addresses_id": a1.id, "user_id": None}, + {"addresses_id": a2.id, "user_id": None}, + ], ), CompiledSQL( - "DELETE FROM users WHERE users.id = :id", - {'id': u1.id} + "DELETE FROM users WHERE users.id = :id", {"id": u1.id} ), ) def test_many_to_one_delete_unloaded(self): - users, Address, addresses, User = (self.tables.users, - 
self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'parent': relationship(User) - }) + mapper(Address, addresses, properties={"parent": relationship(User)}) - parent = User(name='p1') - c1, c2 = Address(email_address='c1', parent=parent), \ - Address(email_address='c2', parent=parent) + parent = User(name="p1") + c1, c2 = ( + Address(email_address="c1", parent=parent), + Address(email_address="c2", parent=parent), + ) session = Session() session.add_all([c1, c2]) @@ -303,7 +318,7 @@ class RudimentaryFlushTest(UOWTest): "addresses_email_address FROM addresses " "WHERE addresses.id = " ":param_1", - lambda ctx: {'param_1': c1id} + lambda ctx: {"param_1": c1id}, ), CompiledSQL( "SELECT addresses.id AS addresses_id, " @@ -312,38 +327,40 @@ class RudimentaryFlushTest(UOWTest): "addresses_email_address FROM addresses " "WHERE addresses.id = " ":param_1", - lambda ctx: {'param_1': c2id} + lambda ctx: {"param_1": c2id}, ), CompiledSQL( "SELECT users.id AS users_id, users.name AS users_name " "FROM users WHERE users.id = :param_1", - lambda ctx: {'param_1': pid} + lambda ctx: {"param_1": pid}, ), CompiledSQL( "DELETE FROM addresses WHERE addresses.id = :id", - lambda ctx: [{'id': c1id}, {'id': c2id}] + lambda ctx: [{"id": c1id}, {"id": c2id}], ), CompiledSQL( "DELETE FROM users WHERE users.id = :id", - lambda ctx: {'id': pid} + lambda ctx: {"id": pid}, ), ), ) def test_many_to_one_delete_childonly_unloaded(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'parent': relationship(User) - }) + mapper(Address, addresses, properties={"parent": relationship(User)}) - parent = User(name='p1') - c1, c2 = Address(email_address='c1', parent=parent), \ - Address(email_address='c2', parent=parent) + parent = User(name="p1") + c1, c2 = ( + Address(email_address="c1", parent=parent), + Address(email_address="c2", parent=parent), + ) session = Session() session.add_all([c1, c2]) @@ -375,7 +392,7 @@ class RudimentaryFlushTest(UOWTest): "addresses_email_address FROM addresses " "WHERE addresses.id = " ":param_1", - lambda ctx: {'param_1': c1id} + lambda ctx: {"param_1": c1id}, ), CompiledSQL( "SELECT addresses.id AS addresses_id, " @@ -384,29 +401,31 @@ class RudimentaryFlushTest(UOWTest): "addresses_email_address FROM addresses " "WHERE addresses.id = " ":param_1", - lambda ctx: {'param_1': c2id} + lambda ctx: {"param_1": c2id}, ), ), CompiledSQL( "DELETE FROM addresses WHERE addresses.id = :id", - lambda ctx: [{'id': c1id}, {'id': c2id}] + lambda ctx: [{"id": c1id}, {"id": c2id}], ), ) def test_many_to_one_delete_childonly_unloaded_expired(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'parent': relationship(User) - }) + mapper(Address, addresses, properties={"parent": relationship(User)}) - parent = User(name='p1') - c1, c2 = 
Address(email_address='c1', parent=parent), \ - Address(email_address='c2', parent=parent) + parent = User(name="p1") + c1, c2 = ( + Address(email_address="c1", parent=parent), + Address(email_address="c2", parent=parent), + ) session = Session() session.add_all([c1, c2]) @@ -437,7 +456,7 @@ class RudimentaryFlushTest(UOWTest): "addresses_email_address FROM addresses " "WHERE addresses.id = " ":param_1", - lambda ctx: {'param_1': c1id} + lambda ctx: {"param_1": c1id}, ), CompiledSQL( "SELECT addresses.id AS addresses_id, " @@ -446,30 +465,32 @@ class RudimentaryFlushTest(UOWTest): "addresses_email_address FROM addresses " "WHERE addresses.id = " ":param_1", - lambda ctx: {'param_1': c2id} + lambda ctx: {"param_1": c2id}, ), ), CompiledSQL( "DELETE FROM addresses WHERE addresses.id = :id", - lambda ctx: [{'id': c1id}, {'id': c2id}] + lambda ctx: [{"id": c1id}, {"id": c2id}], ), ) def test_many_to_one_del_attr(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - u1 = User(name='u1') - a1, a2 = Address(email_address='a1', user=u1), \ - Address(email_address='a2', user=u1) + u1 = User(name="u1") + a1, a2 = ( + Address(email_address="a1", user=u1), + Address(email_address="a2", user=u1), + ) sess.add_all([a1, a2]) sess.flush() @@ -480,34 +501,34 @@ class RudimentaryFlushTest(UOWTest): CompiledSQL( "UPDATE addresses SET user_id=:user_id WHERE " "addresses.id = :addresses_id", - lambda ctx: [ - {'addresses_id': a1.id, 'user_id': None}, - ] - ) + lambda ctx: [{"addresses_id": a1.id, "user_id": None}], + ), ) def test_many_to_one_del_attr_unloaded(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User) - }) + mapper(Address, addresses, properties={"user": relationship(User)}) sess = create_session() - u1 = User(name='u1') - a1, a2 = Address(email_address='a1', user=u1), \ - Address(email_address='a2', user=u1) + u1 = User(name="u1") + a1, a2 = ( + Address(email_address="a1", user=u1), + Address(email_address="a2", user=u1), + ) sess.add_all([a1, a2]) sess.flush() # trying to guarantee that the history only includes # PASSIVE_NO_RESULT for "deleted" and nothing else sess.expunge(u1) - sess.expire(a1, ['user']) + sess.expire(a1, ["user"]) del a1.user sess.add(a1) @@ -517,30 +538,28 @@ class RudimentaryFlushTest(UOWTest): CompiledSQL( "UPDATE addresses SET user_id=:user_id WHERE " "addresses.id = :addresses_id", - lambda ctx: [ - {'addresses_id': a1.id, 'user_id': None}, - ] - ) + lambda ctx: [{"addresses_id": a1.id, "user_id": None}], + ), ) def test_natural_ordering(self): """test that unconnected items take relationship() into account regardless.""" - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) 
mapper(User, users) - mapper(Address, addresses, properties={ - 'parent': relationship(User) - }) + mapper(Address, addresses, properties={"parent": relationship(User)}) sess = create_session() - u1 = User(id=1, name='u1') - a1 = Address(id=1, user_id=1, email_address='a2') + u1 = User(id=1, name="u1") + a1 = Address(id=1, user_id=1, email_address="a2") sess.add_all([u1, a1]) self.assert_sql_execution( @@ -548,12 +567,13 @@ class RudimentaryFlushTest(UOWTest): sess.flush, CompiledSQL( "INSERT INTO users (id, name) VALUES (:id, :name)", - {'id': 1, 'name': 'u1'}), + {"id": 1, "name": "u1"}, + ), CompiledSQL( "INSERT INTO addresses (id, user_id, email_address) " "VALUES (:id, :user_id, :email_address)", - {'email_address': 'a2', 'user_id': 1, 'id': 1} - ) + {"email_address": "a2", "user_id": 1, "id": 1}, + ), ) sess.delete(u1) @@ -562,13 +582,9 @@ class RudimentaryFlushTest(UOWTest): testing.db, sess.flush, CompiledSQL( - "DELETE FROM addresses WHERE addresses.id = :id", - [{'id': 1}] + "DELETE FROM addresses WHERE addresses.id = :id", [{"id": 1}] ), - CompiledSQL( - "DELETE FROM users WHERE users.id = :id", - [{'id': 1}] - ) + CompiledSQL("DELETE FROM users WHERE users.id = :id", [{"id": 1}]), ) def test_natural_selfref(self): @@ -577,9 +593,7 @@ class RudimentaryFlushTest(UOWTest): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node) - }) + mapper(Node, nodes, properties={"children": relationship(Node)}) sess = create_session() @@ -597,25 +611,35 @@ class RudimentaryFlushTest(UOWTest): CompiledSQL( "INSERT INTO nodes (id, parent_id, data) VALUES " "(:id, :parent_id, :data)", - [{'parent_id': None, 'data': None, 'id': 1}, - {'parent_id': 1, 'data': None, 'id': 2}, - {'parent_id': 2, 'data': None, 'id': 3}] + [ + {"parent_id": None, "data": None, "id": 1}, + {"parent_id": 1, "data": None, "id": 2}, + {"parent_id": 2, "data": None, "id": 3}, + ], ), ) def test_many_to_many(self): keywords, items, item_keywords, Keyword, Item = ( - self.tables.keywords, self.tables.items, self.tables.item_keywords, - self.classes.Keyword, self.classes.Item) + self.tables.keywords, + self.tables.items, + self.tables.item_keywords, + self.classes.Keyword, + self.classes.Item, + ) - mapper(Item, items, properties={ - 'keywords': relationship(Keyword, secondary=item_keywords) - }) + mapper( + Item, + items, + properties={ + "keywords": relationship(Keyword, secondary=item_keywords) + }, + ) mapper(Keyword, keywords) sess = create_session() - k1 = Keyword(name='k1') - i1 = Item(description='i1', keywords=[k1]) + k1 = Keyword(name="k1") + i1 = Item(description="i1", keywords=[k1]) sess.add(i1) self.assert_sql_execution( testing.db, @@ -623,70 +647,76 @@ class RudimentaryFlushTest(UOWTest): AllOf( CompiledSQL( "INSERT INTO keywords (name) VALUES (:name)", - {'name': 'k1'} + {"name": "k1"}, ), CompiledSQL( "INSERT INTO items (description) VALUES (:description)", - {'description': 'i1'} + {"description": "i1"}, ), ), CompiledSQL( "INSERT INTO item_keywords (item_id, keyword_id) " "VALUES (:item_id, :keyword_id)", - lambda ctx: {'item_id': i1.id, 'keyword_id': k1.id} - ) + lambda ctx: {"item_id": i1.id, "keyword_id": k1.id}, + ), ) # test that keywords collection isn't loaded - sess.expire(i1, ['keywords']) - i1.description = 'i2' + sess.expire(i1, ["keywords"]) + i1.description = "i2" self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL("UPDATE items SET description=:description " - "WHERE items.id = :items_id", - lambda ctx: {'description': 
'i2', 'items_id': i1.id}) + CompiledSQL( + "UPDATE items SET description=:description " + "WHERE items.id = :items_id", + lambda ctx: {"description": "i2", "items_id": i1.id}, + ), ) def test_m2o_flush_size(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) mapper(User, users) - mapper(Address, addresses, properties={ - 'user': relationship(User, passive_updates=True) - }) + mapper( + Address, + addresses, + properties={"user": relationship(User, passive_updates=True)}, + ) sess = create_session() - u1 = User(name='ed') + u1 = User(name="ed") sess.add(u1) self._assert_uow_size(sess, 2) def test_o2m_flush_size(self): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address), - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sess = create_session() - u1 = User(name='ed') + u1 = User(name="ed") sess.add(u1) self._assert_uow_size(sess, 2) sess.flush() - u1.name = 'jack' + u1.name = "jack" self._assert_uow_size(sess, 2) sess.flush() - a1 = Address(email_address='foo') + a1 = Address(email_address="foo") sess.add(a1) sess.flush() @@ -698,7 +728,7 @@ class RudimentaryFlushTest(UOWTest): sess = create_session() u1 = sess.query(User).first() - u1.name = 'ed' + u1.name = "ed" self._assert_uow_size(sess, 2) u1.addresses @@ -706,63 +736,55 @@ class RudimentaryFlushTest(UOWTest): class SingleCycleTest(UOWTest): - def teardown(self): engines.testing_reaper.rollback_all() # mysql can't handle delete from nodes # since it doesn't deal with the FKs correctly, # so wipe out the parent_id first - testing.db.execute( - self.tables.nodes.update().values(parent_id=None) - ) + testing.db.execute(self.tables.nodes.update().values(parent_id=None)) super(SingleCycleTest, self).teardown() def test_one_to_many_save(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node) - }) + mapper(Node, nodes, properties={"children": relationship(Node)}) sess = create_session() - n2, n3 = Node(data='n2'), Node(data='n3') - n1 = Node(data='n1', children=[n2, n3]) + n2, n3 = Node(data="n2"), Node(data="n3") + n1 = Node(data="n1", children=[n2, n3]) sess.add(n1) self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - {'parent_id': None, 'data': 'n1'} + {"parent_id": None, "data": "n1"}, ), AllOf( CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n2'} + lambda ctx: {"parent_id": n1.id, "data": "n2"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n3'} + lambda ctx: {"parent_id": n1.id, "data": "n3"}, ), - ) + ), ) def test_one_to_many_delete_all(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node) - }) + mapper(Node, nodes, properties={"children": relationship(Node)}) sess = create_session() - n2, n3 = Node(data='n2', 
children=[]), Node(data='n3', children=[]) - n1 = Node(data='n1', children=[n2, n3]) + n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) + n1 = Node(data="n1", children=[n2, n3]) sess.add(n1) sess.flush() @@ -773,87 +795,97 @@ class SingleCycleTest(UOWTest): self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{'id': n2.id}, {'id': n3.id}]), - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {'id': n1.id}) + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: [{"id": n2.id}, {"id": n3.id}], + ), + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: {"id": n1.id}, + ), ) def test_one_to_many_delete_parent(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node) - }) + mapper(Node, nodes, properties={"children": relationship(Node)}) sess = create_session() - n2, n3 = Node(data='n2', children=[]), Node(data='n3', children=[]) - n1 = Node(data='n1', children=[n2, n3]) + n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) + n1 = Node(data="n1", children=[n2, n3]) sess.add(n1) sess.flush() sess.delete(n1) self.assert_sql_execution( - testing.db, sess.flush, AllOf( + testing.db, + sess.flush, + AllOf( CompiledSQL( "UPDATE nodes SET parent_id=:parent_id " - "WHERE nodes.id = :nodes_id", lambda ctx: [ - {'nodes_id': n3.id, 'parent_id': None}, - {'nodes_id': n2.id, 'parent_id': None} - ] + "WHERE nodes.id = :nodes_id", + lambda ctx: [ + {"nodes_id": n3.id, "parent_id": None}, + {"nodes_id": n2.id, "parent_id": None}, + ], ) ), CompiledSQL( - "DELETE FROM nodes WHERE nodes.id = :id", lambda ctx: { - 'id': n1.id})) + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: {"id": n1.id}, + ), + ) def test_many_to_one_save(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'parent': relationship(Node, remote_side=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={"parent": relationship(Node, remote_side=nodes.c.id)}, + ) sess = create_session() - n1 = Node(data='n1') - n2, n3 = Node(data='n2', parent=n1), Node(data='n3', parent=n1) + n1 = Node(data="n1") + n2, n3 = Node(data="n2", parent=n1), Node(data="n3", parent=n1) sess.add_all([n2, n3]) self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - {'parent_id': None, 'data': 'n1'} + {"parent_id": None, "data": "n1"}, ), AllOf( CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n2'} + lambda ctx: {"parent_id": n1.id, "data": "n2"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n3'} + lambda ctx: {"parent_id": n1.id, "data": "n3"}, ), - ) + ), ) def test_many_to_one_delete_all(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'parent': relationship(Node, remote_side=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={"parent": relationship(Node, remote_side=nodes.c.id)}, + ) sess = create_session() - n1 = Node(data='n1') - n2, n3 = Node(data='n2', parent=n1), Node(data='n3', parent=n1) + n1 = Node(data="n1") + n2, n3 = Node(data="n2", parent=n1), Node(data="n3", parent=n1) sess.add_all([n2, n3]) sess.flush() @@ -864,26 +896,32 @@ class SingleCycleTest(UOWTest): self.assert_sql_execution( 
testing.db, sess.flush, - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{'id': n2.id}, {'id': n3.id}]), - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {'id': n1.id}) + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: [{"id": n2.id}, {"id": n3.id}], + ), + CompiledSQL( + "DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx: {"id": n1.id}, + ), ) def test_many_to_one_set_null_unloaded(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'parent': relationship(Node, remote_side=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={"parent": relationship(Node, remote_side=nodes.c.id)}, + ) sess = create_session() - n1 = Node(data='n1') - n2 = Node(data='n2', parent=n1) + n1 = Node(data="n1") + n2 = Node(data="n2", parent=n1) sess.add_all([n1, n2]) sess.flush() sess.close() - n2 = sess.query(Node).filter_by(data='n2').one() + n2 = sess.query(Node).filter_by(data="n2").one() n2.parent = None self.assert_sql_execution( testing.db, @@ -891,20 +929,18 @@ class SingleCycleTest(UOWTest): CompiledSQL( "UPDATE nodes SET parent_id=:parent_id WHERE " "nodes.id = :nodes_id", - lambda ctx: {"parent_id": None, "nodes_id": n2.id} - ) + lambda ctx: {"parent_id": None, "nodes_id": n2.id}, + ), ) def test_cycle_rowswitch(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node) - }) + mapper(Node, nodes, properties={"children": relationship(Node)}) sess = create_session() - n2, n3 = Node(data='n2', children=[]), Node(data='n3', children=[]) - n1 = Node(data='n1', children=[n2]) + n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) + n1 = Node(data="n1", children=[n2]) sess.add(n1) sess.flush() @@ -916,15 +952,19 @@ class SingleCycleTest(UOWTest): def test_bidirectional_mutations_one(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node, - backref=backref('parent', - remote_side=nodes.c.id)) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship( + Node, backref=backref("parent", remote_side=nodes.c.id) + ) + }, + ) sess = create_session() - n2, n3 = Node(data='n2', children=[]), Node(data='n3', children=[]) - n1 = Node(data='n1', children=[n2]) + n2, n3 = Node(data="n2", children=[]), Node(data="n3", children=[]) + n1 = Node(data="n1", children=[n2]) sess.add(n1) sess.flush() sess.delete(n2) @@ -942,20 +982,20 @@ class SingleCycleTest(UOWTest): Node, nodes, properties={ - 'children': relationship( - Node, - backref=backref( - 'parent', - remote_side=nodes.c.id))}) + "children": relationship( + Node, backref=backref("parent", remote_side=nodes.c.id) + ) + }, + ) sess = create_session() - n1 = Node(data='n1') - n1.children.append(Node(data='n11')) - n12 = Node(data='n12') + n1 = Node(data="n1") + n1.children.append(Node(data="n11")) + n12 = Node(data="n12") n1.children.append(n12) - n1.children.append(Node(data='n13')) - n1.children[1].children.append(Node(data='n121')) - n1.children[1].children.append(Node(data='n122')) - n1.children[1].children.append(Node(data='n123')) + n1.children.append(Node(data="n13")) + n1.children[1].children.append(Node(data="n121")) + n1.children[1].children.append(Node(data="n122")) + n1.children[1].children.append(Node(data="n123")) sess.add(n1) self.assert_sql_execution( testing.db, @@ -963,59 +1003,57 @@ class SingleCycleTest(UOWTest): CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " 
"(:parent_id, :data)", - lambda ctx: {'parent_id': None, 'data': 'n1'} + lambda ctx: {"parent_id": None, "data": "n1"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n11'} + lambda ctx: {"parent_id": n1.id, "data": "n11"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n12'} + lambda ctx: {"parent_id": n1.id, "data": "n12"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n1.id, 'data': 'n13'} + lambda ctx: {"parent_id": n1.id, "data": "n13"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n12.id, 'data': 'n121'} + lambda ctx: {"parent_id": n12.id, "data": "n121"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n12.id, 'data': 'n122'} + lambda ctx: {"parent_id": n12.id, "data": "n122"}, ), CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " "(:parent_id, :data)", - lambda ctx: {'parent_id': n12.id, 'data': 'n123'} + lambda ctx: {"parent_id": n12.id, "data": "n123"}, ), ) def test_singlecycle_flush_size(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'children': relationship(Node) - }) + mapper(Node, nodes, properties={"children": relationship(Node)}) sess = create_session() - n1 = Node(data='ed') + n1 = Node(data="ed") sess.add(n1) self._assert_uow_size(sess, 2) sess.flush() - n1.data = 'jack' + n1.data = "jack" self._assert_uow_size(sess, 2) sess.flush() - n2 = Node(data='foo') + n2 = Node(data="foo") sess.add(n2) sess.flush() @@ -1027,7 +1065,7 @@ class SingleCycleTest(UOWTest): sess = create_session() n1 = sess.query(Node).first() - n1.data = 'ed' + n1.data = "ed" self._assert_uow_size(sess, 2) n1.children @@ -1036,9 +1074,11 @@ class SingleCycleTest(UOWTest): def test_delete_unloaded_m2o(self): Node, nodes = self.classes.Node, self.tables.nodes - mapper(Node, nodes, properties={ - 'parent': relationship(Node, remote_side=nodes.c.id) - }) + mapper( + Node, + nodes, + properties={"parent": relationship(Node, remote_side=nodes.c.id)}, + ) parent = Node() c1, c2 = Node(parent=parent), Node(parent=parent) @@ -1076,30 +1116,30 @@ class SingleCycleTest(UOWTest): "nodes_parent_id, " "nodes.data AS nodes_data FROM nodes " "WHERE nodes.id = :param_1", - lambda ctx: {'param_1': pid} + lambda ctx: {"param_1": pid}, ), CompiledSQL( "SELECT nodes.id AS nodes_id, nodes.parent_id AS " "nodes_parent_id, " "nodes.data AS nodes_data FROM nodes " "WHERE nodes.id = :param_1", - lambda ctx: {'param_1': c1id} + lambda ctx: {"param_1": c1id}, ), CompiledSQL( "SELECT nodes.id AS nodes_id, nodes.parent_id AS " "nodes_parent_id, " "nodes.data AS nodes_data FROM nodes " "WHERE nodes.id = :param_1", - lambda ctx: {'param_1': c2id} + lambda ctx: {"param_1": c2id}, ), AllOf( CompiledSQL( "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{'id': c1id}, {'id': c2id}] + lambda ctx: [{"id": c1id}, {"id": c2id}], ), CompiledSQL( "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {'id': pid} + lambda ctx: {"id": pid}, ), ), ), @@ -1107,24 +1147,28 @@ class SingleCycleTest(UOWTest): class SingleCyclePlusAttributeTest( - fixtures.MappedTest, - testing.AssertsExecutionResults, - AssertsUOW): - + fixtures.MappedTest, testing.AssertsExecutionResults, AssertsUOW +): @classmethod def define_tables(cls, 
metadata): - Table('nodes', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30)) - ) - - Table('foobars', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - ) + Table( + "nodes", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + ) + + Table( + "foobars", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("parent_id", Integer, ForeignKey("nodes.id")), + ) def test_flush_size(self): foobars, nodes = self.tables.foobars, self.tables.nodes @@ -1135,15 +1179,19 @@ class SingleCyclePlusAttributeTest( class FooBar(fixtures.ComparableEntity): pass - mapper(Node, nodes, properties={ - 'children': relationship(Node), - 'foobars': relationship(FooBar) - }) + mapper( + Node, + nodes, + properties={ + "children": relationship(Node), + "foobars": relationship(FooBar), + }, + ) mapper(FooBar, foobars) sess = create_session() - n1 = Node(data='n1') - n2 = Node(data='n2') + n1 = Node(data="n1") + n2 = Node(data="n2") n1.children.append(n2) sess.add(n1) # ensure "foobars" doesn't get yanked in here @@ -1158,28 +1206,36 @@ class SingleCyclePlusAttributeTest( sess.flush() -class SingleCycleM2MTest(fixtures.MappedTest, - testing.AssertsExecutionResults, AssertsUOW): - +class SingleCycleM2MTest( + fixtures.MappedTest, testing.AssertsExecutionResults, AssertsUOW +): @classmethod def define_tables(cls, metadata): Table( - 'nodes', metadata, + "nodes", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column( - 'data', String(30)), Column( - 'favorite_node_id', Integer, ForeignKey('nodes.id'))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("favorite_node_id", Integer, ForeignKey("nodes.id")), + ) Table( - 'node_to_nodes', metadata, + "node_to_nodes", + metadata, Column( - 'left_node_id', Integer, - ForeignKey('nodes.id'), primary_key=True), + "left_node_id", + Integer, + ForeignKey("nodes.id"), + primary_key=True, + ), Column( - 'right_node_id', Integer, - ForeignKey('nodes.id'), primary_key=True), + "right_node_id", + Integer, + ForeignKey("nodes.id"), + primary_key=True, + ), ) def test_many_to_many_one(self): @@ -1192,22 +1248,23 @@ class SingleCycleM2MTest(fixtures.MappedTest, Node, nodes, properties={ - 'children': relationship( + "children": relationship( Node, secondary=node_to_nodes, primaryjoin=nodes.c.id == node_to_nodes.c.left_node_id, secondaryjoin=nodes.c.id == node_to_nodes.c.right_node_id, - backref='parents'), - 'favorite': relationship( - Node, - remote_side=nodes.c.id)}) + backref="parents", + ), + "favorite": relationship(Node, remote_side=nodes.c.id), + }, + ) sess = create_session() - n1 = Node(data='n1') - n2 = Node(data='n2') - n3 = Node(data='n3') - n4 = Node(data='n4') - n5 = Node(data='n5') + n1 = Node(data="n1") + n2 = Node(data="n2") + n3 = Node(data="n3") + n4 = Node(data="n4") + n5 = Node(data="n5") n4.favorite = n3 n1.favorite = n5 @@ -1224,16 +1281,24 @@ class SingleCycleM2MTest(fixtures.MappedTest, # so check the end result sess.flush() eq_( - sess.query(node_to_nodes.c.left_node_id, - node_to_nodes.c.right_node_id). - order_by(node_to_nodes.c.left_node_id, - node_to_nodes.c.right_node_id). 
- all(), - sorted([ - (n1.id, n2.id), (n1.id, n3.id), (n1.id, n4.id), - (n2.id, n3.id), (n2.id, n5.id), - (n3.id, n5.id), (n3.id, n4.id) - ]) + sess.query( + node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id + ) + .order_by( + node_to_nodes.c.left_node_id, node_to_nodes.c.right_node_id + ) + .all(), + sorted( + [ + (n1.id, n2.id), + (n1.id, n3.id), + (n1.id, n4.id), + (n2.id, n3.id), + (n2.id, n5.id), + (n3.id, n5.id), + (n3.id, n4.id), + ] + ), ) sess.delete(n1) @@ -1249,21 +1314,21 @@ class SingleCycleM2MTest(fixtures.MappedTest, "nodes, node_to_nodes WHERE :param_1 = " "node_to_nodes.right_node_id AND nodes.id = " "node_to_nodes.left_node_id", - lambda ctx: {'param_1': n1.id}, + lambda ctx: {"param_1": n1.id}, ), CompiledSQL( "DELETE FROM node_to_nodes WHERE " "node_to_nodes.left_node_id = :left_node_id AND " "node_to_nodes.right_node_id = :right_node_id", lambda ctx: [ - {'right_node_id': n2.id, 'left_node_id': n1.id}, - {'right_node_id': n3.id, 'left_node_id': n1.id}, - {'right_node_id': n4.id, 'left_node_id': n1.id} - ] + {"right_node_id": n2.id, "left_node_id": n1.id}, + {"right_node_id": n3.id, "left_node_id": n1.id}, + {"right_node_id": n4.id, "left_node_id": n1.id}, + ], ), CompiledSQL( "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {'id': n1.id} + lambda ctx: {"id": n1.id}, ), ) @@ -1283,35 +1348,38 @@ class SingleCycleM2MTest(fixtures.MappedTest, "= :left_node_id AND node_to_nodes.right_node_id = " ":right_node_id", lambda ctx: [ - {'right_node_id': n5.id, 'left_node_id': n3.id}, - {'right_node_id': n4.id, 'left_node_id': n3.id}, - {'right_node_id': n3.id, 'left_node_id': n2.id}, - {'right_node_id': n5.id, 'left_node_id': n2.id} - ] + {"right_node_id": n5.id, "left_node_id": n3.id}, + {"right_node_id": n4.id, "left_node_id": n3.id}, + {"right_node_id": n3.id, "left_node_id": n2.id}, + {"right_node_id": n5.id, "left_node_id": n2.id}, + ], ), CompiledSQL( "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{'id': n4.id}, {'id': n5.id}] + lambda ctx: [{"id": n4.id}, {"id": n5.id}], ), CompiledSQL( "DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: [{'id': n2.id}, {'id': n3.id}] + lambda ctx: [{"id": n2.id}, {"id": n3.id}], ), ) class RowswitchAccountingTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer) - ) - Table('child', metadata, - Column('id', Integer, ForeignKey('parent.id'), primary_key=True), - Column('data', Integer) - ) + Table( + "parent", + metadata, + Column("id", Integer, primary_key=True), + Column("data", Integer), + ) + Table( + "child", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + Column("data", Integer), + ) def _fixture(self): parent, child = self.tables.parent, self.tables.child @@ -1322,11 +1390,18 @@ class RowswitchAccountingTest(fixtures.MappedTest): class Child(fixtures.BasicEntity): pass - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False, - cascade="all, delete-orphan", - backref="parent") - }) + mapper( + Parent, + parent, + properties={ + "child": relationship( + Child, + uselist=False, + cascade="all, delete-orphan", + backref="parent", + ) + }, + ) mapper(Child, child) return Parent, Child @@ -1343,7 +1418,7 @@ class RowswitchAccountingTest(fixtures.MappedTest): p2 = Parent(id=1, child=Child()) p3 = sess.merge(p2) - old = attributes.get_history(p3, 'child')[2][0] + old = attributes.get_history(p3, "child")[2][0] assert old in sess # 
essentially no SQL should emit here, @@ -1356,7 +1431,7 @@ class RowswitchAccountingTest(fixtures.MappedTest): p4 = Parent(id=1, child=Child()) p5 = sess.merge(p4) - old = attributes.get_history(p5, 'child')[2][0] + old = attributes.get_history(p5, "child")[2][0] assert old in sess sess.flush() @@ -1376,9 +1451,9 @@ class RowswitchAccountingTest(fixtures.MappedTest): eq_( sess.scalar( - select([func.count('*')]).select_from(self.tables.parent) + select([func.count("*")]).select_from(self.tables.parent) ), - 0 + 0, ) sess.close() @@ -1389,21 +1464,16 @@ class RowswitchM2OTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): + Table("a", metadata, Column("id", Integer, primary_key=True)) Table( - 'a', metadata, - Column('id', Integer, primary_key=True), - ) - Table( - 'b', metadata, - Column('id', Integer, primary_key=True), - Column('aid', ForeignKey('a.id')), - Column('cid', ForeignKey('c.id')), - Column('data', String(50)) - ) - Table( - 'c', metadata, - Column('id', Integer, primary_key=True), + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", ForeignKey("a.id")), + Column("cid", ForeignKey("c.id")), + Column("data", String(50)), ) + Table("c", metadata, Column("id", Integer, primary_key=True)) def _fixture(self): a, b, c = self.tables.a, self.tables.b, self.tables.c @@ -1417,12 +1487,12 @@ class RowswitchM2OTest(fixtures.MappedTest): class C(fixtures.BasicEntity): pass - mapper(A, a, properties={ - 'bs': relationship(B, cascade="all, delete-orphan") - }) - mapper(B, b, properties={ - 'c': relationship(C) - }) + mapper( + A, + a, + properties={"bs": relationship(B, cascade="all, delete-orphan")}, + ) + mapper(B, b, properties={"c": relationship(C)}) mapper(C, c) return A, B, C @@ -1439,9 +1509,7 @@ class RowswitchM2OTest(fixtures.MappedTest): A, B, C = self._fixture() sess = Session() - sess.add( - A(id=1, bs=[B(id=1, c=C(id=1))]) - ) + sess.add(A(id=1, bs=[B(id=1, c=C(id=1))])) sess.commit() a1 = sess.query(A).first() @@ -1453,9 +1521,7 @@ class RowswitchM2OTest(fixtures.MappedTest): A, B, C = self._fixture() sess = Session() - sess.add( - A(id=1, bs=[B(id=1, c=C(id=1))]) - ) + sess.add(A(id=1, bs=[B(id=1, c=C(id=1))])) sess.commit() a1 = sess.query(A).first() @@ -1474,9 +1540,7 @@ class RowswitchM2OTest(fixtures.MappedTest): A, B, C = self._fixture() sess = Session() - sess.add( - A(id=1, bs=[B(id=1, data='somedata')]) - ) + sess.add(A(id=1, bs=[B(id=1, data="somedata")])) sess.commit() a1 = sess.query(A).first() @@ -1488,9 +1552,7 @@ class RowswitchM2OTest(fixtures.MappedTest): A, B, C = self._fixture() sess = Session() - sess.add( - A(id=1, bs=[B(id=1, data='somedata')]) - ) + sess.add(A(id=1, bs=[B(id=1, data="somedata")])) sess.commit() a1 = sess.query(A).first() @@ -1503,17 +1565,20 @@ class RowswitchM2OTest(fixtures.MappedTest): class BasicStaleChecksTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True), - Column('data', Integer) - ) - Table('child', metadata, - Column('id', Integer, ForeignKey('parent.id'), primary_key=True), - Column('data', Integer) - ) + Table( + "parent", + metadata, + Column("id", Integer, primary_key=True), + Column("data", Integer), + ) + Table( + "child", + metadata, + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), + Column("data", Integer), + ) def _fixture(self, confirm_deleted_rows=True): parent, child = self.tables.parent, self.tables.child @@ -1524,11 +1589,19 @@ class 
BasicStaleChecksTest(fixtures.MappedTest): class Child(fixtures.BasicEntity): pass - mapper(Parent, parent, properties={ - 'child': relationship(Child, uselist=False, - cascade="all, delete-orphan", - backref="parent"), - }, confirm_deleted_rows=confirm_deleted_rows) + mapper( + Parent, + parent, + properties={ + "child": relationship( + Child, + uselist=False, + cascade="all, delete-orphan", + backref="parent", + ) + }, + confirm_deleted_rows=confirm_deleted_rows, + ) mapper(Child, child) return Parent, Child @@ -1547,7 +1620,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): orm_exc.StaleDataError, r"UPDATE statement on table 'parent' expected to " r"update 1 row\(s\); 0 were matched.", - sess.flush + sess.flush, ) @testing.requires.sane_rowcount @@ -1560,10 +1633,11 @@ class BasicStaleChecksTest(fixtures.MappedTest): return self.context.rowcount with patch.object( - config.db.dialect, "supports_sane_multi_rowcount", False): + config.db.dialect, "supports_sane_multi_rowcount", False + ): with patch( - "sqlalchemy.engine.result.ResultProxy.rowcount", - rowcount): + "sqlalchemy.engine.result.ResultProxy.rowcount", rowcount + ): Parent, Child = self._fixture() sess = Session() p1 = Parent(id=1, data=2) @@ -1577,7 +1651,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): orm_exc.StaleDataError, r"UPDATE statement on table 'parent' expected to " r"update 1 row\(s\); 0 were matched.", - sess.flush + sess.flush, ) def test_update_multi_missing_broken_multi_rowcount(self): @@ -1589,10 +1663,11 @@ class BasicStaleChecksTest(fixtures.MappedTest): return self.context.rowcount with patch.object( - config.db.dialect, "supports_sane_multi_rowcount", False): + config.db.dialect, "supports_sane_multi_rowcount", False + ): with patch( - "sqlalchemy.engine.result.ResultProxy.rowcount", - rowcount): + "sqlalchemy.engine.result.ResultProxy.rowcount", rowcount + ): Parent, Child = self._fixture() sess = Session() p1 = Parent(id=1, data=2) @@ -1607,10 +1682,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): sess.flush() # no exception # update occurred for remaining row - eq_( - sess.query(Parent.id, Parent.data).all(), - [(2, 4)] - ) + eq_(sess.query(Parent.id, Parent.data).all(), [(2, 4)]) def test_update_value_missing_broken_multi_rowcount(self): @util.memoized_property @@ -1621,10 +1693,11 @@ class BasicStaleChecksTest(fixtures.MappedTest): return self.context.rowcount with patch.object( - config.db.dialect, "supports_sane_multi_rowcount", False): + config.db.dialect, "supports_sane_multi_rowcount", False + ): with patch( - "sqlalchemy.engine.result.ResultProxy.rowcount", - rowcount): + "sqlalchemy.engine.result.ResultProxy.rowcount", rowcount + ): Parent, Child = self._fixture() sess = Session() p1 = Parent(id=1, data=1) @@ -1638,7 +1711,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): orm_exc.StaleDataError, r"UPDATE statement on table 'parent' expected to " r"update 1 row\(s\); 0 were matched.", - sess.flush + sess.flush, ) @testing.requires.sane_rowcount @@ -1658,7 +1731,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): exc.SAWarning, r"DELETE statement on table 'parent' expected to " r"delete 1 row\(s\); 0 were matched.", - sess.commit + sess.commit, ) @testing.requires.sane_multi_rowcount @@ -1678,7 +1751,7 @@ class BasicStaleChecksTest(fixtures.MappedTest): exc.SAWarning, r"DELETE statement on table 'parent' expected to " r"delete 2 row\(s\); 0 were matched.", - sess.flush + sess.flush, ) def test_delete_multi_missing_allow(self): @@ -1697,15 +1770,17 @@ class 
BasicStaleChecksTest(fixtures.MappedTest): class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): - @classmethod def define_tables(cls, metadata): - Table('t', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('def_', String(50), server_default='def1') - ) + Table( + "t", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("def_", String(50), server_default="def1"), + ) def test_batch_interaction(self): """test batching groups same-structured, primary @@ -1717,55 +1792,56 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): class T(fixtures.ComparableEntity): pass + mapper(T, t) sess = Session() - sess.add_all([ - T(data='t1'), - T(data='t2'), - T(id=3, data='t3'), - T(id=4, data='t4'), - T(id=5, data='t5'), - T(id=6, data=func.lower('t6')), - T(id=7, data='t7'), - T(id=8, data='t8'), - T(id=9, data='t9', def_='def2'), - T(id=10, data='t10', def_='def3'), - T(id=11, data='t11'), - ]) + sess.add_all( + [ + T(data="t1"), + T(data="t2"), + T(id=3, data="t3"), + T(id=4, data="t4"), + T(id=5, data="t5"), + T(id=6, data=func.lower("t6")), + T(id=7, data="t7"), + T(id=8, data="t8"), + T(id=9, data="t9", def_="def2"), + T(id=10, data="t10", def_="def3"), + T(id=11, data="t11"), + ] + ) self.assert_sql_execution( testing.db, sess.flush, - CompiledSQL( - "INSERT INTO t (data) VALUES (:data)", - {'data': 't1'} - ), - CompiledSQL( - "INSERT INTO t (data) VALUES (:data)", - {'data': 't2'} - ), + CompiledSQL("INSERT INTO t (data) VALUES (:data)", {"data": "t1"}), + CompiledSQL("INSERT INTO t (data) VALUES (:data)", {"data": "t2"}), CompiledSQL( "INSERT INTO t (id, data) VALUES (:id, :data)", - [{'data': 't3', 'id': 3}, - {'data': 't4', 'id': 4}, - {'data': 't5', 'id': 5}] + [ + {"data": "t3", "id": 3}, + {"data": "t4", "id": 4}, + {"data": "t5", "id": 5}, + ], ), CompiledSQL( "INSERT INTO t (id, data) VALUES (:id, lower(:lower_1))", - {'lower_1': 't6', 'id': 6} + {"lower_1": "t6", "id": 6}, ), CompiledSQL( "INSERT INTO t (id, data) VALUES (:id, :data)", - [{'data': 't7', 'id': 7}, {'data': 't8', 'id': 8}] + [{"data": "t7", "id": 7}, {"data": "t8", "id": 8}], ), CompiledSQL( "INSERT INTO t (id, data, def_) VALUES (:id, :data, :def_)", - [{'data': 't9', 'id': 9, 'def_': 'def2'}, - {'data': 't10', 'id': 10, 'def_': 'def3'}] + [ + {"data": "t9", "id": 9, "def_": "def2"}, + {"data": "t10", "id": 10, "def_": "def3"}, + ], ), CompiledSQL( "INSERT INTO t (id, data) VALUES (:id, :data)", - {'data': 't11', 'id': 11} + {"data": "t11", "id": 11}, ), ) @@ -1782,17 +1858,25 @@ class LoadersUsingCommittedTest(UOWTest): """ def _mapper_setup(self, passive_updates=True): - users, Address, addresses, User = (self.tables.users, - self.classes.Address, - self.tables.addresses, - self.classes.User) - - mapper(User, users, properties={ - 'addresses': relationship(Address, - order_by=addresses.c.email_address, - passive_updates=passive_updates, - backref='user') - }) + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + Address, + order_by=addresses.c.email_address, + passive_updates=passive_updates, + backref="user", + ) + }, + ) mapper(Address, addresses) return create_session(autocommit=False) @@ -1808,14 +1892,16 @@ class LoadersUsingCommittedTest(UOWTest): # if get 
committed is used to find target.user, then # it will be still be u1 instead of u2 assert target.user.id == target.user_id == u2.id + from sqlalchemy import event - event.listen(Address, 'before_update', before_update) - a1 = Address(email_address='a1') - u1 = User(name='u1', addresses=[a1]) + event.listen(Address, "before_update", before_update) + + a1 = Address(email_address="a1") + u1 = User(name="u1", addresses=[a1]) sess.add(u1) - u2 = User(name='u2') + u2 = User(name="u2") sess.add(u2) sess.commit() @@ -1864,7 +1950,7 @@ class LoadersUsingCommittedTest(UOWTest): # we expect no related items in the collection # since we are using passive_updates # this is a behavior change since #2350 - assert 'addresses' not in target.__dict__ + assert "addresses" not in target.__dict__ eq_(target.addresses, []) else: # in contrast with passive_updates=True, @@ -1875,16 +1961,17 @@ class LoadersUsingCommittedTest(UOWTest): # (just like they will be after the update) # collection is already loaded - assert 'addresses' in target.__dict__ - eq_([a.id for a in target.addresses], - [a.id for a in [a1, a2]]) + assert "addresses" in target.__dict__ + eq_([a.id for a in target.addresses], [a.id for a in [a1, a2]]) raise AvoidReferencialError() + from sqlalchemy import event - event.listen(User, 'before_update', before_update) - a1 = Address(email_address='jack1') - a2 = Address(email_address='jack2') - u1 = User(id=1, name='jack', addresses=[a1, a2]) + event.listen(User, "before_update", before_update) + + a1 = Address(email_address="jack1") + a2 = Address(email_address="jack2") + u1 = User(id=1, name="jack", addresses=[a1, a2]) sess.add(u1) sess.commit() @@ -1911,11 +1998,13 @@ class NoAttrEventInFlushTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'test', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('prefetch_val', Integer, default=5), - Column('returning_val', Integer, server_default="5") + "test", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("prefetch_val", Integer, default=5), + Column("returning_val", Integer, server_default="5"), ) @classmethod @@ -1952,16 +2041,18 @@ class EagerDefaultsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'test', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, server_default="3") + "test", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer, server_default="3"), ) Table( - 'test2', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer), - Column('bar', Integer, server_onupdate=FetchedValue()) + "test2", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, server_onupdate=FetchedValue()), ) @classmethod @@ -1986,10 +2077,7 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing = self.classes.Thing s = Session() - t1, t2 = ( - Thing(id=1, foo=5), - Thing(id=2, foo=10) - ) + t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=10)) s.add_all([t1, t2]) @@ -1998,7 +2086,7 @@ class EagerDefaultsTest(fixtures.MappedTest): s.flush, CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, :foo)", - [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}] + [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}], ), ) @@ -2014,7 +2102,7 @@ class EagerDefaultsTest(fixtures.MappedTest): t1, t2 = ( Thing(id=1, foo=text("2 + 5")), - Thing(id=2, foo=text("5 + 5")) + Thing(id=2, foo=text("5 + 5")), ) s.add_all([t1, t2]) 
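The hunk that follows asserts two fetch strategies for a server-generated column value after an INSERT: dialects with RETURNING (the PostgreSQL branch) get the value inline, everything else falls back to a post-flush SELECT. The standalone sketch below illustrates only the fallback path; the in-memory SQLite engine and the pared-down "test" table are assumptions chosen to mirror the table defined earlier in this class, not anything taken from the patch.

from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select

metadata = MetaData()
test = Table(
    "test",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("foo", Integer, server_default="3"),
)

engine = create_engine("sqlite://")  # no RETURNING here, so the fallback runs
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(test.insert(), {"id": 1})
    # the default never travelled back with the INSERT; a second round
    # trip fetches it, the same shape as the
    # "SELECT test.foo ... WHERE test.id = :param_1" statements asserted here
    foo = conn.execute(select([test.c.foo]).where(test.c.id == 1)).scalar()
    assert foo == 3
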
@@ -2026,15 +2114,15 @@ class EagerDefaultsTest(fixtures.MappedTest): CompiledSQL( "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) " "RETURNING test.foo", - [{'id': 1}], - dialect='postgresql' + [{"id": 1}], + dialect="postgresql", ), CompiledSQL( "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) " "RETURNING test.foo", - [{'id': 2}], - dialect='postgresql' - ) + [{"id": 2}], + dialect="postgresql", + ), ) else: @@ -2043,21 +2131,21 @@ class EagerDefaultsTest(fixtures.MappedTest): s.flush, CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)", - [{'id': 1}] + [{"id": 1}], ), CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)", - [{'id': 2}] + [{"id": 2}], ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :param_1", - [{'param_1': 2}] + [{"param_1": 2}], ), ) @@ -2071,10 +2159,7 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing = self.classes.Thing s = Session() - t1, t2 = ( - Thing(id=1), - Thing(id=2) - ) + t1, t2 = (Thing(id=1), Thing(id=2)) s.add_all([t1, t2]) @@ -2084,13 +2169,13 @@ class EagerDefaultsTest(fixtures.MappedTest): s.commit, CompiledSQL( "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", - [{'id': 1}], - dialect='postgresql' + [{"id": 1}], + dialect="postgresql", ), CompiledSQL( "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", - [{'id': 2}], - dialect='postgresql' + [{"id": 2}], + dialect="postgresql", ), ) else: @@ -2099,18 +2184,18 @@ class EagerDefaultsTest(fixtures.MappedTest): s.commit, CompiledSQL( "INSERT INTO test (id) VALUES (:id)", - [{'id': 1}, {'id': 2}] + [{"id": 1}, {"id": 2}], ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :param_1", - [{'param_1': 2}] - ) + [{"param_1": 2}], + ), ) def test_update_defaults_nonpresent(self): @@ -2121,7 +2206,7 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing2(id=1, foo=1, bar=2), Thing2(id=2, foo=2, bar=3), Thing2(id=3, foo=3, bar=4), - Thing2(id=4, foo=4, bar=5) + Thing2(id=4, foo=4, bar=5), ) s.add_all([t1, t2, t3, t4]) @@ -2142,27 +2227,27 @@ class EagerDefaultsTest(fixtures.MappedTest): "UPDATE test2 SET foo=%(foo)s " "WHERE test2.id = %(test2_id)s " "RETURNING test2.bar", - [{'foo': 5, 'test2_id': 1}], - dialect='postgresql' + [{"foo": 5, "test2_id": 1}], + dialect="postgresql", ), CompiledSQL( "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " "WHERE test2.id = %(test2_id)s", - [{'foo': 6, 'bar': 10, 'test2_id': 2}], - dialect='postgresql' + [{"foo": 6, "bar": 10, "test2_id": 2}], + dialect="postgresql", ), CompiledSQL( "UPDATE test2 SET foo=%(foo)s " "WHERE test2.id = %(test2_id)s " "RETURNING test2.bar", - [{'foo': 7, 'test2_id': 3}], - dialect='postgresql' + [{"foo": 7, "test2_id": 3}], + dialect="postgresql", ), CompiledSQL( "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " "WHERE test2.id = %(test2_id)s", - [{'foo': 8, 'bar': 12, 'test2_id': 4}], - dialect='postgresql' + [{"foo": 8, "bar": 12, "test2_id": 4}], + dialect="postgresql", ), ) else: @@ -2171,32 +2256,32 @@ class EagerDefaultsTest(fixtures.MappedTest): s.flush, CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", - [{'foo': 5, 'test2_id': 1}] + [{"foo": 5, "test2_id": 1}], ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", - [{'foo': 6, 'bar': 10, 'test2_id': 2}], + 
[{"foo": 6, "bar": 10, "test2_id": 2}], ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", - [{'foo': 7, 'test2_id': 3}] + [{"foo": 7, "test2_id": 3}], ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", - [{'foo': 8, 'bar': 12, 'test2_id': 4}], + [{"foo": 8, "bar": 12, "test2_id": 4}], ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " "WHERE test2.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " "WHERE test2.id = :param_1", - [{'param_1': 3}] - ) + [{"param_1": 3}], + ), ) def go(): @@ -2215,7 +2300,7 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing2(id=1, foo=1, bar=2), Thing2(id=2, foo=2, bar=3), Thing2(id=3, foo=3, bar=4), - Thing2(id=4, foo=4, bar=5) + Thing2(id=4, foo=4, bar=5), ) s.add_all([t1, t2, t3, t4]) @@ -2237,27 +2322,27 @@ class EagerDefaultsTest(fixtures.MappedTest): "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 " "WHERE test2.id = %(test2_id)s " "RETURNING test2.bar", - [{'foo': 5, 'test2_id': 1}], - dialect='postgresql' + [{"foo": 5, "test2_id": 1}], + dialect="postgresql", ), CompiledSQL( "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " "WHERE test2.id = %(test2_id)s", - [{'foo': 6, 'bar': 10, 'test2_id': 2}], - dialect='postgresql' + [{"foo": 6, "bar": 10, "test2_id": 2}], + dialect="postgresql", ), CompiledSQL( "UPDATE test2 SET foo=%(foo)s " "WHERE test2.id = %(test2_id)s " "RETURNING test2.bar", - [{'foo': 7, 'test2_id': 3}], - dialect='postgresql' + [{"foo": 7, "test2_id": 3}], + dialect="postgresql", ), CompiledSQL( "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 " "WHERE test2.id = %(test2_id)s RETURNING test2.bar", - [{'foo': 8, 'test2_id': 4}], - dialect='postgresql' + [{"foo": 8, "test2_id": 4}], + dialect="postgresql", ), ) else: @@ -2267,37 +2352,37 @@ class EagerDefaultsTest(fixtures.MappedTest): CompiledSQL( "UPDATE test2 SET foo=:foo, bar=1 + 1 " "WHERE test2.id = :test2_id", - [{'foo': 5, 'test2_id': 1}] + [{"foo": 5, "test2_id": 1}], ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", - [{'foo': 6, 'bar': 10, 'test2_id': 2}], + [{"foo": 6, "bar": 10, "test2_id": 2}], ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", - [{'foo': 7, 'test2_id': 3}] + [{"foo": 7, "test2_id": 3}], ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=5 + 7 " "WHERE test2.id = :test2_id", - [{'foo': 8, 'test2_id': 4}], + [{"foo": 8, "test2_id": 4}], ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " "WHERE test2.id = :param_1", - [{'param_1': 1}] + [{"param_1": 1}], ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " "WHERE test2.id = :param_1", - [{'param_1': 3}] + [{"param_1": 3}], ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " "WHERE test2.id = :param_1", - [{'param_1': 4}] - ) + [{"param_1": 4}], + ), ) def go(): @@ -2312,18 +2397,14 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing = self.classes.Thing s = Session() - mappings = [ - {"id": 1}, - {"id": 2} - ] + mappings = [{"id": 1}, {"id": 2}] self.assert_sql_execution( testing.db, lambda: s.bulk_insert_mappings(Thing, mappings), CompiledSQL( - "INSERT INTO test (id) VALUES (:id)", - [{'id': 1}, {'id': 2}] - ) + "INSERT INTO test (id) VALUES (:id)", [{"id": 1}, {"id": 2}] + ), ) def test_update_defaults_bulk_update(self): @@ -2334,7 +2415,7 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing2(id=1, foo=1, bar=2), Thing2(id=2, foo=2, bar=3), Thing2(id=3, foo=3, bar=4), - Thing2(id=4, foo=4, bar=5) + Thing2(id=4, foo=4, 
bar=5), ) s.add_all([t1, t2, t3, t4]) @@ -2344,7 +2425,7 @@ class EagerDefaultsTest(fixtures.MappedTest): {"id": 1, "foo": 5}, {"id": 2, "foo": 6, "bar": 10}, {"id": 3, "foo": 7}, - {"id": 4, "foo": 8} + {"id": 4, "foo": 8}, ] self.assert_sql_execution( @@ -2352,27 +2433,24 @@ class EagerDefaultsTest(fixtures.MappedTest): lambda: s.bulk_update_mappings(Thing2, mappings), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", - [{'foo': 5, 'test2_id': 1}] + [{"foo": 5, "test2_id": 1}], ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", - [{'foo': 6, 'bar': 10, 'test2_id': 2}] + [{"foo": 6, "bar": 10, "test2_id": 2}], ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", - [{'foo': 7, 'test2_id': 3}, {'foo': 8, 'test2_id': 4}] - ) + [{"foo": 7, "test2_id": 3}, {"foo": 8, "test2_id": 4}], + ), ) def test_update_defaults_present(self): Thing2 = self.classes.Thing2 s = Session() - t1, t2 = ( - Thing2(id=1, foo=1, bar=2), - Thing2(id=2, foo=2, bar=3) - ) + t1, t2 = (Thing2(id=1, foo=1, bar=2), Thing2(id=2, foo=2, bar=3)) s.add_all([t1, t2]) s.flush() @@ -2385,9 +2463,9 @@ class EagerDefaultsTest(fixtures.MappedTest): s.commit, CompiledSQL( "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s", - [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}], - dialect='postgresql' - ) + [{"bar": 5, "test2_id": 1}, {"bar": 10, "test2_id": 2}], + dialect="postgresql", + ), ) def test_insert_dont_fetch_nondefaults(self): @@ -2402,10 +2480,9 @@ class EagerDefaultsTest(fixtures.MappedTest): testing.db, s.flush, CompiledSQL( - "INSERT INTO test2 (id, foo, bar) " - "VALUES (:id, :foo, :bar)", - [{'id': 1, 'foo': None, 'bar': 2}] - ) + "INSERT INTO test2 (id, foo, bar) " "VALUES (:id, :foo, :bar)", + [{"id": 1, "foo": None, "bar": 2}], + ), ) def test_update_dont_fetch_nondefaults(self): @@ -2417,7 +2494,7 @@ class EagerDefaultsTest(fixtures.MappedTest): s.add(t1) s.flush() - s.expire(t1, ['foo']) + s.expire(t1, ["foo"]) t1.bar = 3 @@ -2426,8 +2503,8 @@ class EagerDefaultsTest(fixtures.MappedTest): s.flush, CompiledSQL( "UPDATE test2 SET bar=:bar WHERE test2.id = :test2_id", - [{'bar': 3, 'test2_id': 1}] - ) + [{"bar": 3, "test2_id": 1}], + ), ) @@ -2466,11 +2543,13 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): return value Table( - 'test', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('value', MyType), - Column('unrelated', String(50)) + "test", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("value", MyType), + Column("unrelated", String(50)), ) @classmethod @@ -2495,9 +2574,7 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): t1.value = None s.commit() - eq_( - s.query(Thing.value).scalar(), None - ) + eq_(s.query(Thing.value).scalar(), None) def test_update_against_something_else(self): Thing = self.classes.Thing @@ -2510,19 +2587,17 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): t1.value = self.MyWidget("bar") s.commit() - eq_( - s.query(Thing.value).scalar().text, "bar" - ) + eq_(s.query(Thing.value).scalar().text, "bar") def test_no_update_no_change(self): Thing = self.classes.Thing s = Session() - s.add(Thing(value=self.MyWidget("foo"), unrelated='unrelated')) + s.add(Thing(value=self.MyWidget("foo"), unrelated="unrelated")) s.commit() t1 = s.query(Thing).first() - t1.unrelated = 'something else' + t1.unrelated = "something else" 
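The UPDATE asserted just below touches only "unrelated"; the "value" column holds an object that cannot be truth-tested, which is what this test class exists to cover. The real MyWidget is defined earlier in the class and does not appear in this hunk, so the stand-in below is an assumption for illustration; it shows the one property that matters, well-defined equality with no boolean coercion.

class Widget(object):
    """Stand-in value object: comparable by value, not truth-testable."""

    def __init__(self, text):
        self.text = text

    def __bool__(self):
        # like a NumPy array, the truth value is ambiguous by design
        raise TypeError("boolean value of Widget is ambiguous")

    __nonzero__ = __bool__  # Python 2 spelling, matching this era of the code

    def __eq__(self, other):
        return isinstance(other, Widget) and other.text == self.text

    def __ne__(self, other):
        return not self.__eq__(other)


w1, w2 = Widget("foo"), Widget("foo")
assert w1 == w2  # equality is all the flush's change detection needs
try:
    bool(w1)  # an "if value:" style check would blow up like this
except TypeError:
    pass
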
self.assert_sql_execution( testing.db, @@ -2530,13 +2605,11 @@ class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): CompiledSQL( "UPDATE test SET unrelated=:unrelated " "WHERE test.id = :test_id", - [{'test_id': 1, 'unrelated': 'something else'}] + [{"test_id": 1, "unrelated": "something else"}], ), ) - eq_( - s.query(Thing.value).scalar().text, "foo" - ) + eq_(s.query(Thing.value).scalar().text, "foo") class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): @@ -2551,37 +2624,47 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): def process_bind_param(self, value, dialect): if value is None: - value = 'nothing' + value = "nothing" return value Table( - 'test', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('evals_null_no_default', EvalsNull()), - Column('evals_null_default', EvalsNull(), default='default_val'), - Column('no_eval_null_no_default', String(50)), - Column('no_eval_null_default', String(50), default='default_val'), + "test", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("evals_null_no_default", EvalsNull()), + Column("evals_null_default", EvalsNull(), default="default_val"), + Column("no_eval_null_no_default", String(50)), + Column("no_eval_null_default", String(50), default="default_val"), Column( - 'builtin_evals_null_no_default', String(50).evaluates_none()), + "builtin_evals_null_no_default", String(50).evaluates_none() + ), Column( - 'builtin_evals_null_default', - String(50).evaluates_none(), default='default_val'), + "builtin_evals_null_default", + String(50).evaluates_none(), + default="default_val", + ), ) Table( - 'test_w_renames', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('evals_null_no_default', EvalsNull()), - Column('evals_null_default', EvalsNull(), default='default_val'), - Column('no_eval_null_no_default', String(50)), - Column('no_eval_null_default', String(50), default='default_val'), + "test_w_renames", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("evals_null_no_default", EvalsNull()), + Column("evals_null_default", EvalsNull(), default="default_val"), + Column("no_eval_null_no_default", String(50)), + Column("no_eval_null_default", String(50), default="default_val"), Column( - 'builtin_evals_null_no_default', String(50).evaluates_none()), + "builtin_evals_null_no_default", String(50).evaluates_none() + ), Column( - 'builtin_evals_null_default', - String(50).evaluates_none(), default='default_val'), + "builtin_evals_null_default", + String(50).evaluates_none(), + default="default_val", + ), ) @classmethod @@ -2631,12 +2714,8 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing s = Session() - s.bulk_insert_mappings( - Thing, [{attr: None}] - ) - s.bulk_insert_mappings( - AltNameThing, [{"_foo_" + attr: None}] - ) + s.bulk_insert_mappings(Thing, [{attr: None}]) + s.bulk_insert_mappings(AltNameThing, [{"_foo_" + attr: None}]) s.commit() self._assert_col(attr, expected) @@ -2659,132 +2738,82 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): Thing, AltNameThing = self.classes.Thing, self.classes.AltNameThing s = Session() - s.bulk_insert_mappings( - Thing, [{}] - ) - s.bulk_insert_mappings( - AltNameThing, [{}] - ) + s.bulk_insert_mappings(Thing, 
[{}]) + s.bulk_insert_mappings(AltNameThing, [{}]) s.commit() self._assert_col(attr, expected) def test_evalnull_nodefault_insert(self): - self._test_insert( - "evals_null_no_default", 'nothing' - ) + self._test_insert("evals_null_no_default", "nothing") def test_evalnull_nodefault_bulk_insert(self): - self._test_bulk_insert( - "evals_null_no_default", 'nothing' - ) + self._test_bulk_insert("evals_null_no_default", "nothing") def test_evalnull_nodefault_insert_novalue(self): - self._test_insert_novalue( - "evals_null_no_default", None - ) + self._test_insert_novalue("evals_null_no_default", None) def test_evalnull_nodefault_bulk_insert_novalue(self): - self._test_bulk_insert_novalue( - "evals_null_no_default", None - ) + self._test_bulk_insert_novalue("evals_null_no_default", None) def test_evalnull_default_insert(self): - self._test_insert( - "evals_null_default", 'nothing' - ) + self._test_insert("evals_null_default", "nothing") def test_evalnull_default_bulk_insert(self): - self._test_bulk_insert( - "evals_null_default", 'nothing' - ) + self._test_bulk_insert("evals_null_default", "nothing") def test_evalnull_default_insert_novalue(self): - self._test_insert_novalue( - "evals_null_default", 'default_val' - ) + self._test_insert_novalue("evals_null_default", "default_val") def test_evalnull_default_bulk_insert_novalue(self): - self._test_bulk_insert_novalue( - "evals_null_default", 'default_val' - ) + self._test_bulk_insert_novalue("evals_null_default", "default_val") def test_no_evalnull_nodefault_insert(self): - self._test_insert( - "no_eval_null_no_default", None - ) + self._test_insert("no_eval_null_no_default", None) def test_no_evalnull_nodefault_bulk_insert(self): - self._test_bulk_insert( - "no_eval_null_no_default", None - ) + self._test_bulk_insert("no_eval_null_no_default", None) def test_no_evalnull_nodefault_insert_novalue(self): - self._test_insert_novalue( - "no_eval_null_no_default", None - ) + self._test_insert_novalue("no_eval_null_no_default", None) def test_no_evalnull_nodefault_bulk_insert_novalue(self): - self._test_bulk_insert_novalue( - "no_eval_null_no_default", None - ) + self._test_bulk_insert_novalue("no_eval_null_no_default", None) def test_no_evalnull_default_insert(self): - self._test_insert( - "no_eval_null_default", 'default_val' - ) + self._test_insert("no_eval_null_default", "default_val") def test_no_evalnull_default_bulk_insert(self): - self._test_bulk_insert( - "no_eval_null_default", 'default_val' - ) + self._test_bulk_insert("no_eval_null_default", "default_val") def test_no_evalnull_default_insert_novalue(self): - self._test_insert_novalue( - "no_eval_null_default", 'default_val' - ) + self._test_insert_novalue("no_eval_null_default", "default_val") def test_no_evalnull_default_bulk_insert_novalue(self): - self._test_bulk_insert_novalue( - "no_eval_null_default", 'default_val' - ) + self._test_bulk_insert_novalue("no_eval_null_default", "default_val") def test_builtin_evalnull_nodefault_insert(self): - self._test_insert( - "builtin_evals_null_no_default", None - ) + self._test_insert("builtin_evals_null_no_default", None) def test_builtin_evalnull_nodefault_bulk_insert(self): - self._test_bulk_insert( - "builtin_evals_null_no_default", None - ) + self._test_bulk_insert("builtin_evals_null_no_default", None) def test_builtin_evalnull_nodefault_insert_novalue(self): - self._test_insert_novalue( - "builtin_evals_null_no_default", None - ) + self._test_insert_novalue("builtin_evals_null_no_default", None) def 
test_builtin_evalnull_nodefault_bulk_insert_novalue(self): - self._test_bulk_insert_novalue( - "builtin_evals_null_no_default", None - ) + self._test_bulk_insert_novalue("builtin_evals_null_no_default", None) def test_builtin_evalnull_default_insert(self): - self._test_insert( - "builtin_evals_null_default", None - ) + self._test_insert("builtin_evals_null_default", None) def test_builtin_evalnull_default_bulk_insert(self): - self._test_bulk_insert( - "builtin_evals_null_default", None - ) + self._test_bulk_insert("builtin_evals_null_default", None) def test_builtin_evalnull_default_insert_novalue(self): - self._test_insert_novalue( - "builtin_evals_null_default", 'default_val' - ) + self._test_insert_novalue("builtin_evals_null_default", "default_val") def test_builtin_evalnull_default_bulk_insert_novalue(self): self._test_bulk_insert_novalue( - "builtin_evals_null_default", 'default_val' + "builtin_evals_null_default", "default_val" ) diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index d1ea22dcc4..1c92091a82 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -1,9 +1,27 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, is_ from sqlalchemy.testing import fixtures -from sqlalchemy import Integer, String, ForeignKey, or_, exc, \ - select, func, Boolean, case, text, column -from sqlalchemy.orm import mapper, relationship, backref, Session, \ - joinedload, synonym, query +from sqlalchemy import ( + Integer, + String, + ForeignKey, + or_, + exc, + select, + func, + Boolean, + case, + text, + column, +) +from sqlalchemy.orm import ( + mapper, + relationship, + backref, + Session, + joinedload, + synonym, + query, +) from sqlalchemy import testing from sqlalchemy.testing import mock @@ -15,15 +33,20 @@ class UpdateDeleteTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(32)), - Column('age_int', Integer)) Table( - "addresses", metadata, - Column('id', Integer, primary_key=True), - Column('user_id', ForeignKey('users.id')) + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(32)), + Column("age_int", Integer), + ) + Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", ForeignKey("users.id")), ) @classmethod @@ -38,12 +61,14 @@ class UpdateDeleteTest(fixtures.MappedTest): def insert_data(cls): users = cls.tables.users - users.insert().execute([ - dict(id=1, name='john', age_int=25), - dict(id=2, name='jack', age_int=47), - dict(id=3, name='jill', age_int=29), - dict(id=4, name='jane', age_int=37), - ]) + users.insert().execute( + [ + dict(id=1, name="john", age_int=25), + dict(id=2, name="jack", age_int=47), + dict(id=3, name="jill", age_int=29), + dict(id=4, name="jane", age_int=37), + ] + ) @classmethod def setup_mappers(cls): @@ -53,10 +78,14 @@ class UpdateDeleteTest(fixtures.MappedTest): Address = cls.classes.Address addresses = cls.tables.addresses - mapper(User, users, properties={ - 'age': users.c.age_int, - 'addresses': relationship(Address) - }) + mapper( + User, + users, + properties={ + "age": users.c.age_int, + "addresses": relationship(Address), + }, + ) mapper(Address, addresses) def test_illegal_eval(self): @@ -68,7 +97,7 @@ class UpdateDeleteTest(fixtures.MappedTest): "are 'evaluate', 'fetch', False", s.query(User).update, {}, - 
synchronize_session="fake" + synchronize_session="fake", ) def test_illegal_operations(self): @@ -84,26 +113,36 @@ class UpdateDeleteTest(fixtures.MappedTest): (s.query(User).order_by(User.id), r"order_by\(\)"), (s.query(User).group_by(User.id), r"group_by\(\)"), (s.query(User).distinct(), r"distinct\(\)"), - (s.query(User).join(User.addresses), - r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), - (s.query(User).outerjoin(User.addresses), - r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), - (s.query(User).select_from(Address), - r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), - (s.query(User).from_self(), - r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)"), + ( + s.query(User).join(User.addresses), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)", + ), + ( + s.query(User).outerjoin(User.addresses), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)", + ), + ( + s.query(User).select_from(Address), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)", + ), + ( + s.query(User).from_self(), + r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)", + ), ): assert_raises_message( exc.InvalidRequestError, r"Can't call Query.update\(\) or Query.delete\(\) when " "%s has been called" % mname, q.update, - {'name': 'ed'}) + {"name": "ed"}, + ) assert_raises_message( exc.InvalidRequestError, r"Can't call Query.update\(\) or Query.delete\(\) when " "%s has been called" % mname, - q.delete) + q.delete, + ) def test_evaluate_clauseelement(self): User = self.classes.User @@ -115,9 +154,9 @@ class UpdateDeleteTest(fixtures.MappedTest): s = Session() jill = s.query(User).get(3) s.query(User).update( - {Thing(): 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.name, 'moonbeam') + {Thing(): "moonbeam"}, synchronize_session="evaluate" + ) + eq_(jill.name, "moonbeam") def test_evaluate_invalid(self): User = self.classes.User @@ -131,8 +170,9 @@ class UpdateDeleteTest(fixtures.MappedTest): assert_raises_message( exc.InvalidRequestError, "Invalid expression type: 5", - s.query(User).update, {Thing(): 'moonbeam'}, - synchronize_session='evaluate' + s.query(User).update, + {Thing(): "moonbeam"}, + synchronize_session="evaluate", ) def test_evaluate_unmapped_col(self): @@ -141,54 +181,54 @@ class UpdateDeleteTest(fixtures.MappedTest): s = Session() jill = s.query(User).get(3) s.query(User).update( - {column('name'): 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.name, 'jill') + {column("name"): "moonbeam"}, synchronize_session="evaluate" + ) + eq_(jill.name, "jill") s.expire(jill) - eq_(jill.name, 'moonbeam') + eq_(jill.name, "moonbeam") def test_evaluate_synonym_string(self): class Foo(object): pass - mapper(Foo, self.tables.users, properties={ - 'uname': synonym("name", ) - }) + + mapper(Foo, self.tables.users, properties={"uname": synonym("name")}) s = Session() jill = s.query(Foo).get(3) s.query(Foo).update( - {'uname': 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.uname, 'moonbeam') + {"uname": "moonbeam"}, synchronize_session="evaluate" + ) + eq_(jill.uname, "moonbeam") def test_evaluate_synonym_attr(self): class Foo(object): pass - mapper(Foo, self.tables.users, properties={ - 'uname': synonym("name", ) - }) + + mapper(Foo, self.tables.users, properties={"uname": synonym("name")}) s = Session() jill = s.query(Foo).get(3) s.query(Foo).update( - {Foo.uname: 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.uname, 'moonbeam') + {Foo.uname: "moonbeam"}, 
synchronize_session="evaluate" + ) + eq_(jill.uname, "moonbeam") def test_evaluate_double_synonym_attr(self): class Foo(object): pass - mapper(Foo, self.tables.users, properties={ - 'uname': synonym("name"), - 'ufoo': synonym('uname') - }) + + mapper( + Foo, + self.tables.users, + properties={"uname": synonym("name"), "ufoo": synonym("uname")}, + ) s = Session() jill = s.query(Foo).get(3) s.query(Foo).update( - {Foo.ufoo: 'moonbeam'}, - synchronize_session='evaluate') - eq_(jill.ufoo, 'moonbeam') + {Foo.ufoo: "moonbeam"}, synchronize_session="evaluate" + ) + eq_(jill.ufoo, "moonbeam") def test_delete(self): User = self.classes.User @@ -197,7 +237,8 @@ class UpdateDeleteTest(fixtures.MappedTest): john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( - or_(User.name == 'john', User.name == 'jill')).delete() + or_(User.name == "john", User.name == "jill") + ).delete() assert john not in sess and jill not in sess @@ -217,8 +258,9 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(text('name = :name')).params( - name='john').delete('fetch') + sess.query(User).filter(text("name = :name")).params( + name="john" + ).delete("fetch") assert john not in sess eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane]) @@ -229,8 +271,8 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( - or_(User.name == 'john', User.name == 'jill')).\ - delete(synchronize_session='evaluate') + or_(User.name == "john", User.name == "jill") + ).delete(synchronize_session="evaluate") assert john not in sess and jill not in sess sess.rollback() assert john in sess and jill in sess @@ -241,8 +283,8 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( - or_(User.name == 'john', User.name == 'jill')).\ - delete(synchronize_session='fetch') + or_(User.name == "john", User.name == "jill") + ).delete(synchronize_session="fetch") assert john not in sess and jill not in sess sess.rollback() assert john in sess and jill in sess @@ -254,8 +296,8 @@ class UpdateDeleteTest(fixtures.MappedTest): john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( - or_(User.name == 'john', User.name == 'jill')).\ - delete(synchronize_session=False) + or_(User.name == "john", User.name == "jill") + ).delete(synchronize_session=False) assert john in sess and jill in sess @@ -268,8 +310,8 @@ class UpdateDeleteTest(fixtures.MappedTest): john, jack, jill, jane = sess.query(User).order_by(User.id).all() sess.query(User).filter( - or_(User.name == 'john', User.name == 'jill')).\ - delete(synchronize_session='fetch') + or_(User.name == "john", User.name == "jill") + ).delete(synchronize_session="fetch") assert john not in sess and jill not in sess @@ -283,15 +325,17 @@ class UpdateDeleteTest(fixtures.MappedTest): john, jack, jill, jane = sess.query(User).order_by(User.id).all() - assert_raises(exc.InvalidRequestError, - sess.query(User). 
- filter( - User.name == select([func.max(User.name)])).delete, - synchronize_session='evaluate' - ) + assert_raises( + exc.InvalidRequestError, + sess.query(User) + .filter(User.name == select([func.max(User.name)])) + .delete, + synchronize_session="evaluate", + ) - sess.query(User).filter(User.name == select([func.max(User.name)])).\ - delete(synchronize_session='fetch') + sess.query(User).filter( + User.name == select([func.max(User.name)]) + ).delete(synchronize_session="fetch") assert john not in sess @@ -303,32 +347,42 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.age > 29).\ - update({'age': User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter(User.age > 29).update( + {"age": User.age - 10}, synchronize_session="evaluate" + ) eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 37, 29, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 37, 29, 27])), + ) - sess.query(User).filter(User.age > 29).\ - update({User.age: User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter(User.age > 29).update( + {User.age: User.age - 10}, synchronize_session="evaluate" + ) eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 27, 29, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 27, 29, 27])), + ) - sess.query(User).filter(User.age > 27).\ - update( - {users.c.age_int: User.age - 10}, - synchronize_session='evaluate') + sess.query(User).filter(User.age > 27).update( + {users.c.age_int: User.age - 10}, synchronize_session="evaluate" + ) eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 27, 19, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 27, 19, 27])), + ) - sess.query(User).filter(User.age == 25).\ - update({User.age: User.age - 10}, synchronize_session='fetch') + sess.query(User).filter(User.age == 25).update( + {User.age: User.age - 10}, synchronize_session="fetch" + ) eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([15, 27, 19, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([15, 27, 19, 27])), + ) def test_update_against_table_col(self): User, users = self.classes.User, self.tables.users @@ -336,10 +390,9 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 37]) - sess.query(User).filter(User.age > 27).\ - update( - {users.c.age_int: User.age - 10}, - synchronize_session='evaluate') + sess.query(User).filter(User.age > 27).update( + {users.c.age_int: User.age - 10}, synchronize_session="evaluate" + ) eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 19, 27]) def test_update_against_metadata(self): @@ -348,9 +401,12 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() sess.query(users).update( - {users.c.age_int: 29}, synchronize_session=False) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([29, 29, 29, 29]))) + {users.c.age_int: 29}, synchronize_session=False + ) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([29, 29, 29, 
29])), + ) def test_update_with_bindparams(self): User = self.classes.User @@ -359,22 +415,28 @@ class UpdateDeleteTest(fixtures.MappedTest): john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(text('age_int > :x')).params(x=29).\ - update({'age': User.age - 10}, synchronize_session='fetch') + sess.query(User).filter(text("age_int > :x")).params(x=29).update( + {"age": User.age - 10}, synchronize_session="fetch" + ) eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 37, 29, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 37, 29, 27])), + ) def test_update_without_load(self): User = self.classes.User sess = Session() - sess.query(User).filter(User.id == 3).\ - update({'age': 44}, synchronize_session='fetch') - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 47, 44, 37]))) + sess.query(User).filter(User.id == 3).update( + {"age": 44}, synchronize_session="fetch" + ) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 47, 44, 37])), + ) def test_update_changes_resets_dirty(self): User = self.classes.User @@ -389,8 +451,9 @@ class UpdateDeleteTest(fixtures.MappedTest): # autoflush is false. therefore our '50' and '37' are getting # blown away by this operation. - sess.query(User).filter(User.age > 29).\ - update({'age': User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter(User.age > 29).update( + {"age": User.age - 10}, synchronize_session="evaluate" + ) for x in (john, jack, jill, jane): assert not sess.is_modified(x) @@ -414,8 +477,9 @@ class UpdateDeleteTest(fixtures.MappedTest): john.age = 50 jack.age = 37 - sess.query(User).filter(User.age > 29).\ - update({'age': User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter(User.age > 29).update( + {"age": User.age - 10}, synchronize_session="evaluate" + ) for x in (john, jack, jill, jane): assert not sess.is_modified(x) @@ -435,12 +499,15 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.age > 29).\ - update({'age': User.age - 10}, synchronize_session='fetch') + sess.query(User).filter(User.age > 29).update( + {"age": User.age - 10}, synchronize_session="fetch" + ) eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 37, 29, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 37, 29, 27])), + ) @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) def test_update_returns_rowcount(self): @@ -448,12 +515,18 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() - rowcount = sess.query(User).filter( - User.age > 29).update({'age': User.age + 0}) + rowcount = ( + sess.query(User) + .filter(User.age > 29) + .update({"age": User.age + 0}) + ) eq_(rowcount, 2) - rowcount = sess.query(User).filter( - User.age > 29).update({'age': User.age - 10}) + rowcount = ( + sess.query(User) + .filter(User.age > 29) + .update({"age": User.age - 10}) + ) eq_(rowcount, 2) @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) @@ -462,8 +535,11 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() - rowcount = sess.query(User).filter(User.age > 26).\ - delete(synchronize_session=False) + rowcount = ( + sess.query(User) + .filter(User.age > 26) + 
.delete(synchronize_session=False) + ) eq_(rowcount, 3) def test_update_all(self): @@ -472,11 +548,13 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).update({'age': 42}, synchronize_session='evaluate') + sess.query(User).update({"age": 42}, synchronize_session="evaluate") eq_([john.age, jack.age, jill.age, jane.age], [42, 42, 42, 42]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([42, 42, 42, 42]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([42, 42, 42, 42])), + ) def test_delete_all(self): User = self.classes.User @@ -484,113 +562,112 @@ class UpdateDeleteTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).delete(synchronize_session='evaluate') + sess.query(User).delete(synchronize_session="evaluate") assert not ( - john in sess or jack in sess or jill in sess or jane in sess) + john in sess or jack in sess or jill in sess or jane in sess + ) eq_(sess.query(User).count(), 0) def test_autoflush_before_evaluate_update(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - john.name = 'j2' + john = sess.query(User).filter_by(name="john").one() + john.name = "j2" - sess.query(User).filter_by(name='j2').\ - update({'age': 42}, - synchronize_session='evaluate') + sess.query(User).filter_by(name="j2").update( + {"age": 42}, synchronize_session="evaluate" + ) eq_(john.age, 42) def test_autoflush_before_fetch_update(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - john.name = 'j2' + john = sess.query(User).filter_by(name="john").one() + john.name = "j2" - sess.query(User).filter_by(name='j2').\ - update({'age': 42}, - synchronize_session='fetch') + sess.query(User).filter_by(name="j2").update( + {"age": 42}, synchronize_session="fetch" + ) eq_(john.age, 42) def test_autoflush_before_evaluate_delete(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - john.name = 'j2' + john = sess.query(User).filter_by(name="john").one() + john.name = "j2" - sess.query(User).filter_by(name='j2').\ - delete( - synchronize_session='evaluate') + sess.query(User).filter_by(name="j2").delete( + synchronize_session="evaluate" + ) assert john not in sess def test_autoflush_before_fetch_delete(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - john.name = 'j2' + john = sess.query(User).filter_by(name="john").one() + john.name = "j2" - sess.query(User).filter_by(name='j2').\ - delete( - synchronize_session='fetch') + sess.query(User).filter_by(name="j2").delete( + synchronize_session="fetch" + ) assert john not in sess def test_evaluate_before_update(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - sess.expire(john, ['age']) + john = sess.query(User).filter_by(name="john").one() + sess.expire(john, ["age"]) # eval must be before the update. 
otherwise # we eval john, age has been expired and doesn't # match the new value coming in - sess.query(User).filter_by(name='john').filter_by(age=25).\ - update({'name': 'j2', 'age': 40}, - synchronize_session='evaluate') - eq_(john.name, 'j2') + sess.query(User).filter_by(name="john").filter_by(age=25).update( + {"name": "j2", "age": 40}, synchronize_session="evaluate" + ) + eq_(john.name, "j2") eq_(john.age, 40) def test_fetch_before_update(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - sess.expire(john, ['age']) + john = sess.query(User).filter_by(name="john").one() + sess.expire(john, ["age"]) - sess.query(User).filter_by(name='john').filter_by(age=25).\ - update({'name': 'j2', 'age': 40}, - synchronize_session='fetch') - eq_(john.name, 'j2') + sess.query(User).filter_by(name="john").filter_by(age=25).update( + {"name": "j2", "age": 40}, synchronize_session="fetch" + ) + eq_(john.name, "j2") eq_(john.age, 40) def test_evaluate_before_delete(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - sess.expire(john, ['age']) + john = sess.query(User).filter_by(name="john").one() + sess.expire(john, ["age"]) - sess.query(User).filter_by(name='john').\ - filter_by(age=25).\ - delete( - synchronize_session='evaluate') + sess.query(User).filter_by(name="john").filter_by(age=25).delete( + synchronize_session="evaluate" + ) assert john not in sess def test_fetch_before_delete(self): User = self.classes.User sess = Session() - john = sess.query(User).filter_by(name='john').one() - sess.expire(john, ['age']) + john = sess.query(User).filter_by(name="john").one() + sess.expire(john, ["age"]) - sess.query(User).filter_by(name='john').\ - filter_by(age=25).\ - delete( - synchronize_session='fetch') + sess.query(User).filter_by(name="john").filter_by(age=25).delete( + synchronize_session="fetch" + ) assert john not in sess def test_update_unordered_dict(self): @@ -601,8 +678,7 @@ class UpdateDeleteTest(fixtures.MappedTest): # are ordered in table order q = session.query(User) with mock.patch.object(q, "_execute_crud") as exec_: - q.filter(User.id == 15).update( - {'name': 'foob', 'id': 123}) + q.filter(User.id == 15).update({"name": "foob", "id": 123}) # Confirm that parameters are a dict instead of tuple or list params_type = type(exec_.mock_calls[0][1][0].parameters) is_(params_type, dict) @@ -615,39 +691,50 @@ class UpdateDeleteTest(fixtures.MappedTest): q = session.query(User) with mock.patch.object(q, "_execute_crud") as exec_: q.filter(User.id == 15).update( - (('id', 123), ('name', 'foob')), - update_args={"preserve_parameter_order": True}) - cols = [c.key - for c in exec_.mock_calls[0][1][0]._parameter_ordering] - eq_(['id', 'name'], cols) + (("id", 123), ("name", "foob")), + update_args={"preserve_parameter_order": True}, + ) + cols = [ + c.key for c in exec_.mock_calls[0][1][0]._parameter_ordering + ] + eq_(["id", "name"], cols) # Now invert the order and use a list instead, and check that order is # also preserved q = session.query(User) with mock.patch.object(q, "_execute_crud") as exec_: q.filter(User.id == 15).update( - [('name', 'foob'), ('id', 123)], - update_args={"preserve_parameter_order": True}) - cols = [c.key - for c in exec_.mock_calls[0][1][0]._parameter_ordering] - eq_(['name', 'id'], cols) + [("name", "foob"), ("id", 123)], + update_args={"preserve_parameter_order": True}, + ) + cols = [ + c.key for c in exec_.mock_calls[0][1][0]._parameter_ordering + ] + 
eq_(["name", "id"], cols) class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(32)), - Column('age', Integer)) - - Table('documents', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('title', String(32))) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(32)), + Column("age", Integer), + ) + + Table( + "documents", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("title", String(32)), + ) @classmethod def setup_classes(cls): @@ -661,33 +748,46 @@ class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): def insert_data(cls): users = cls.tables.users - users.insert().execute([ - dict(id=1, name='john', age=25), - dict(id=2, name='jack', age=47), - dict(id=3, name='jill', age=29), - dict(id=4, name='jane', age=37), - ]) + users.insert().execute( + [ + dict(id=1, name="john", age=25), + dict(id=2, name="jack", age=47), + dict(id=3, name="jill", age=29), + dict(id=4, name="jane", age=37), + ] + ) documents = cls.tables.documents - documents.insert().execute([ - dict(id=1, user_id=1, title='foo'), - dict(id=2, user_id=1, title='bar'), - dict(id=3, user_id=2, title='baz'), - ]) + documents.insert().execute( + [ + dict(id=1, user_id=1, title="foo"), + dict(id=2, user_id=1, title="bar"), + dict(id=3, user_id=2, title="baz"), + ] + ) @classmethod def setup_mappers(cls): - documents, Document, User, users = (cls.tables.documents, - cls.classes.Document, - cls.classes.User, - cls.tables.users) + documents, Document, User, users = ( + cls.tables.documents, + cls.classes.Document, + cls.classes.User, + cls.tables.users, + ) mapper(User, users) - mapper(Document, documents, properties={ - 'user': relationship(User, lazy='joined', - backref=backref('documents', lazy='select')) - }) + mapper( + Document, + documents, + properties={ + "user": relationship( + User, + lazy="joined", + backref=backref("documents", lazy="select"), + ) + }, + ) def test_update_with_eager_relationships(self): Document = self.classes.Document @@ -695,13 +795,16 @@ class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): sess = Session() foo, bar, baz = sess.query(Document).order_by(Document.id).all() - sess.query(Document).filter(Document.user_id == 1).\ - update({'title': Document.title + Document.title}, - synchronize_session='fetch') + sess.query(Document).filter(Document.user_id == 1).update( + {"title": Document.title + Document.title}, + synchronize_session="fetch", + ) - eq_([foo.title, bar.title, baz.title], ['foofoo', 'barbar', 'baz']) - eq_(sess.query(Document.title).order_by(Document.id).all(), - list(zip(['foofoo', 'barbar', 'baz']))) + eq_([foo.title, bar.title, baz.title], ["foofoo", "barbar", "baz"]) + eq_( + sess.query(Document.title).order_by(Document.id).all(), + list(zip(["foofoo", "barbar", "baz"])), + ) def test_update_with_explicit_joinedload(self): User = self.classes.User @@ -709,23 +812,26 @@ class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): sess = Session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).options( - joinedload(User.documents)).filter(User.age > 29).\ - update({'age': User.age - 
10}, synchronize_session='fetch') + sess.query(User).options(joinedload(User.documents)).filter( + User.age > 29 + ).update({"age": User.age - 10}, synchronize_session="fetch") eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) - eq_(sess.query(User.age).order_by( - User.id).all(), list(zip([25, 37, 29, 27]))) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 37, 29, 27])), + ) def test_delete_with_eager_relationships(self): Document = self.classes.Document sess = Session() - sess.query(Document).filter(Document.user_id == 1).\ - delete(synchronize_session=False) + sess.query(Document).filter(Document.user_id == 1).delete( + synchronize_session=False + ) - eq_(sess.query(Document.title).all(), list(zip(['baz']))) + eq_(sess.query(Document.title).all(), list(zip(["baz"]))) class UpdateDeleteFromTest(fixtures.MappedTest): @@ -733,17 +839,21 @@ class UpdateDeleteFromTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('samename', String(10)), - ) - Table('documents', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', None, ForeignKey('users.id')), - Column('title', String(32)), - Column('flag', Boolean), - Column('samename', String(10)), - ) + Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("samename", String(10)), + ) + Table( + "documents", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", None, ForeignKey("users.id")), + Column("title", String(32)), + Column("flag", Boolean), + Column("samename", String(10)), + ) @classmethod def setup_classes(cls): @@ -757,53 +867,66 @@ class UpdateDeleteFromTest(fixtures.MappedTest): def insert_data(cls): users = cls.tables.users - users.insert().execute([ - dict(id=1, ), - dict(id=2, ), - dict(id=3, ), - dict(id=4, ), - ]) + users.insert().execute( + [dict(id=1), dict(id=2), dict(id=3), dict(id=4)] + ) documents = cls.tables.documents - documents.insert().execute([ - dict(id=1, user_id=1, title='foo'), - dict(id=2, user_id=1, title='bar'), - dict(id=3, user_id=2, title='baz'), - dict(id=4, user_id=2, title='hoho'), - dict(id=5, user_id=3, title='lala'), - dict(id=6, user_id=3, title='bleh'), - ]) + documents.insert().execute( + [ + dict(id=1, user_id=1, title="foo"), + dict(id=2, user_id=1, title="bar"), + dict(id=3, user_id=2, title="baz"), + dict(id=4, user_id=2, title="hoho"), + dict(id=5, user_id=3, title="lala"), + dict(id=6, user_id=3, title="bleh"), + ] + ) @classmethod def setup_mappers(cls): - documents, Document, User, users = (cls.tables.documents, - cls.classes.Document, - cls.classes.User, - cls.tables.users) + documents, Document, User, users = ( + cls.tables.documents, + cls.classes.Document, + cls.classes.User, + cls.tables.users, + ) mapper(User, users) - mapper(Document, documents, properties={ - 'user': relationship(User, backref='documents') - }) + mapper( + Document, + documents, + properties={"user": relationship(User, backref="documents")}, + ) @testing.requires.update_from def test_update_from_joined_subq_test(self): Document = self.classes.Document s = Session() - subq = s.query(func.max(Document.title).label('title')).\ - group_by(Document.user_id).subquery() + subq = ( + s.query(func.max(Document.title).label("title")) + .group_by(Document.user_id) + .subquery() + ) - s.query(Document).filter(Document.title == subq.c.title).\ - update({'flag': True}, synchronize_session=False) + 
s.query(Document).filter(Document.title == subq.c.title).update( + {"flag": True}, synchronize_session=False + ) eq_( set(s.query(Document.id, Document.flag)), - set([ - (1, True), (2, None), - (3, None), (4, True), - (5, True), (6, None)]) + set( + [ + (1, True), + (2, None), + (3, None), + (4, True), + (5, True), + (6, None), + ] + ), ) @testing.requires.delete_from @@ -811,18 +934,19 @@ class UpdateDeleteFromTest(fixtures.MappedTest): Document = self.classes.Document s = Session() - subq = s.query(func.max(Document.title).label('title')).\ - group_by(Document.user_id).subquery() + subq = ( + s.query(func.max(Document.title).label("title")) + .group_by(Document.user_id) + .subquery() + ) - s.query(Document).filter(Document.title == subq.c.title).\ - delete(synchronize_session=False) + s.query(Document).filter(Document.title == subq.c.title).delete( + synchronize_session=False + ) eq_( set(s.query(Document.id, Document.flag)), - set([ - (2, None), - (3, None), - (6, None)]) + set([(2, None), (3, None), (6, None)]), ) def test_no_eval_against_multi_table_criteria(self): @@ -836,7 +960,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): exc.InvalidRequestError, "Could not evaluate current criteria in Python.", q.update, - {"name": "ed"} + {"name": "ed"}, ) @testing.requires.update_where_target_in_subquery @@ -844,18 +968,28 @@ class UpdateDeleteFromTest(fixtures.MappedTest): Document = self.classes.Document s = Session() - subq = s.query(func.max(Document.title).label('title')).\ - group_by(Document.user_id).subquery() + subq = ( + s.query(func.max(Document.title).label("title")) + .group_by(Document.user_id) + .subquery() + ) - s.query(Document).filter(Document.title.in_(subq)).\ - update({'flag': True}, synchronize_session=False) + s.query(Document).filter(Document.title.in_(subq)).update( + {"flag": True}, synchronize_session=False + ) eq_( set(s.query(Document.id, Document.flag)), - set([ - (1, True), (2, None), - (3, None), (4, True), - (5, True), (6, None)]) + set( + [ + (1, True), + (2, None), + (3, None), + (4, True), + (5, True), + (6, None), + ] + ), ) @testing.requires.update_where_target_in_subquery @@ -864,60 +998,73 @@ class UpdateDeleteFromTest(fixtures.MappedTest): Document = self.classes.Document s = Session() - subq = s.query(func.max(Document.title).label('title')).\ - group_by(Document.user_id).subquery() + subq = ( + s.query(func.max(Document.title).label("title")) + .group_by(Document.user_id) + .subquery() + ) # this would work with Firebird if you do literal_column('1') # instead case_stmt = case([(Document.title.in_(subq), True)], else_=False) s.query(Document).update( - {'flag': case_stmt}, synchronize_session=False) + {"flag": case_stmt}, synchronize_session=False + ) eq_( set(s.query(Document.id, Document.flag)), - set([ - (1, True), (2, False), - (3, False), (4, True), - (5, True), (6, False)]) + set( + [ + (1, True), + (2, False), + (3, False), + (4, True), + (5, True), + (6, False), + ] + ), ) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_update_from_multitable_same_names(self): Document = self.classes.Document User = self.classes.User s = Session() - s.query(Document).\ - filter(User.id == Document.user_id).\ - filter(User.id == 2).update({ - Document.samename: 'd_samename', - User.samename: 'u_samename' - }, synchronize_session=False) + s.query(Document).filter(User.id == Document.user_id).filter( + User.id == 2 + ).update( + {Document.samename: "d_samename", User.samename: "u_samename"}, + 
synchronize_session=False, + ) eq_( - s.query(User.id, Document.samename, User.samename). - filter(User.id == Document.user_id). - order_by(User.id).all(), + s.query(User.id, Document.samename, User.samename) + .filter(User.id == Document.user_id) + .order_by(User.id) + .all(), [ (1, None, None), (1, None, None), - (2, 'd_samename', 'u_samename'), - (2, 'd_samename', 'u_samename'), + (2, "d_samename", "u_samename"), + (2, "d_samename", "u_samename"), (3, None, None), (3, None, None), - ] + ], ) class ExpressionUpdateTest(fixtures.MappedTest): - @classmethod def define_tables(cls, metadata): - data = Table('data', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('counter', Integer, nullable=False, default=0) - ) + data = Table( + "data", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("counter", Integer, nullable=False, default=0), + ) @classmethod def setup_classes(cls): @@ -927,7 +1074,7 @@ class ExpressionUpdateTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): data = cls.tables.data - mapper(cls.classes.Data, data, properties={'cnt': data.c.counter}) + mapper(cls.classes.Data, data, properties={"cnt": data.c.counter}) @testing.provide_metadata def test_update_attr_names(self): @@ -944,7 +1091,7 @@ class ExpressionUpdateTest(fixtures.MappedTest): eq_(d1.cnt, 1) - sess.query(Data).update({Data.cnt: Data.cnt + 1}, 'fetch') + sess.query(Data).update({Data.cnt: Data.cnt + 1}, "fetch") sess.flush() eq_(d1.cnt, 2) @@ -956,9 +1103,8 @@ class ExpressionUpdateTest(fixtures.MappedTest): update_args = {"mysql_limit": 1} q = session.query(Data) - with testing.mock.patch.object(q, '_execute_crud') as exec_: - q.update({Data.cnt: Data.cnt + 1}, - update_args=update_args) + with testing.mock.patch.object(q, "_execute_crud") as exec_: + q.update({Data.cnt: Data.cnt + 1}, update_args=update_args) eq_(exec_.call_count, 1) args, kwargs = exec_.mock_calls[0][1:3] eq_(len(args), 2) @@ -968,9 +1114,9 @@ class ExpressionUpdateTest(fixtures.MappedTest): class InheritTest(fixtures.DeclarativeMappedTest): - run_inserts = 'each' + run_inserts = "each" - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod @@ -978,33 +1124,39 @@ class InheritTest(fixtures.DeclarativeMappedTest): Base = cls.DeclarativeBasic class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column( - Integer, primary_key=True, test_needs_autoincrement=True) + Integer, primary_key=True, test_needs_autoincrement=True + ) type = Column(String(50)) name = Column(String(50)) class Engineer(Person): - __tablename__ = 'engineer' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) engineer_name = Column(String(50)) class Manager(Person): - __tablename__ = 'manager' - id = Column(Integer, ForeignKey('person.id'), primary_key=True) + __tablename__ = "manager" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) manager_name = Column(String(50)) @classmethod def insert_data(cls): - Engineer, Person, Manager = cls.classes.Engineer, \ - cls.classes.Person, cls.classes.Manager + Engineer, Person, Manager = ( + cls.classes.Engineer, + cls.classes.Person, + cls.classes.Manager, + ) s = Session(testing.db) - s.add_all([ - Engineer(name='e1', engineer_name='e1'), - Manager(name='m1', manager_name='m1'), - Engineer(name='e2', engineer_name='e2'), - Person(name='p1'), - ]) + s.add_all( + [ 
+ Engineer(name="e1", engineer_name="e1"), + Manager(name="m1", manager_name="m1"), + Engineer(name="e2", engineer_name="e2"), + Person(name="p1"), + ] + ) s.commit() def test_illegal_metadata(self): @@ -1016,30 +1168,29 @@ class InheritTest(fixtures.DeclarativeMappedTest): exc.InvalidRequestError, "This operation requires only one Table or entity be " "specified as the target.", - sess.query(person.join(engineer)).update, {} + sess.query(person.join(engineer)).update, + {}, ) def test_update_subtable_only(self): Engineer = self.classes.Engineer s = Session(testing.db) - s.query(Engineer).update({'engineer_name': 'e5'}) + s.query(Engineer).update({"engineer_name": "e5"}) - eq_( - s.query(Engineer.engineer_name).all(), - [('e5', ), ('e5', )] - ) + eq_(s.query(Engineer.engineer_name).all(), [("e5",), ("e5",)]) @testing.requires.update_from def test_update_from(self): Engineer = self.classes.Engineer Person = self.classes.Person s = Session(testing.db) - s.query(Engineer).filter(Engineer.id == Person.id).\ - filter(Person.name == 'e2').update({'engineer_name': 'e5'}) + s.query(Engineer).filter(Engineer.id == Person.id).filter( + Person.name == "e2" + ).update({"engineer_name": "e5"}) eq_( set(s.query(Person.name, Engineer.engineer_name)), - set([('e1', 'e1', ), ('e2', 'e5')]) + set([("e1", "e1"), ("e2", "e5")]), ) @testing.requires.delete_from @@ -1047,24 +1198,25 @@ class InheritTest(fixtures.DeclarativeMappedTest): Engineer = self.classes.Engineer Person = self.classes.Person s = Session(testing.db) - s.query(Engineer).filter(Engineer.id == Person.id).\ - filter(Person.name == 'e2').delete() + s.query(Engineer).filter(Engineer.id == Person.id).filter( + Person.name == "e2" + ).delete() eq_( set(s.query(Person.name, Engineer.engineer_name)), - set([('e1', 'e1', )]) + set([("e1", "e1")]), ) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_update_from_multitable(self): Engineer = self.classes.Engineer Person = self.classes.Person s = Session(testing.db) - s.query(Engineer).filter(Engineer.id == Person.id).\ - filter(Person.name == 'e2').update({Person.name: 'e22', - Engineer.engineer_name: 'e55'}) + s.query(Engineer).filter(Engineer.id == Person.id).filter( + Person.name == "e2" + ).update({Person.name: "e22", Engineer.engineer_name: "e55"}) eq_( set(s.query(Person.name, Engineer.engineer_name)), - set([('e1', 'e1', ), ('e22', 'e55')]) + set([("e1", "e1"), ("e22", "e55")]), ) diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 44161ddcde..d53e99d131 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -19,19 +19,23 @@ from .inheritance import _poly_fixtures class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self, cls, properties={}): - table = Table('point', MetaData(), - Column('id', Integer(), primary_key=True), - Column('x', Integer), - Column('y', Integer)) + table = Table( + "point", + MetaData(), + Column("id", Integer(), primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) mapper(cls, table, properties=properties) return table def test_simple(self): class Point(object): pass + table = self._fixture(Point) alias = aliased(Point) @@ -46,6 +50,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): def test_not_instantiatable(self): class Point(object): pass + table = self._fixture(Point) alias = aliased(Point) @@ -64,9 +69,9 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): # 
TODO: I don't quite understand this # still if util.py2k: - assert not getattr(alias, 'zero') + assert not getattr(alias, "zero") else: - assert getattr(alias, 'zero') + assert getattr(alias, "zero") def test_classmethod(self): class Point(object): @@ -96,7 +101,6 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): assert Point.max_x is alias.max_x def test_descriptors(self): - class descriptor(object): def __init__(self, fn): self.fn = fn @@ -108,7 +112,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): return self def method(self): - return 'method' + return "method" class Point(object): center = (0, 0) @@ -122,13 +126,14 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): assert Point.thing != (0, 0) assert Point().thing == (0, 0) - assert Point.thing.method() == 'method' + assert Point.thing.method() == "method" assert alias.thing != (0, 0) - assert alias.thing.method() == 'method' + assert alias.thing.method() == "method" def _assert_has_table(self, expr, table): from sqlalchemy import Column # override testlib's override + for child in expr.get_children(): if isinstance(child, Column): assert child.table is table @@ -150,7 +155,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): sess.query(alias).filter(alias.left_of(Point)), "SELECT point_1.id AS point_1_id, point_1.x AS point_1_x, " "point_1.y AS point_1_y FROM point AS point_1, point " - "WHERE point_1.x < point.x" + "WHERE point_1.x < point.x", ) def test_hybrid_descriptor_two(self): @@ -176,7 +181,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): sess.query(alias).filter(alias.double_x > Point.x), "SELECT point_1.id AS point_1_id, point_1.x AS point_1_x, " "point_1.y AS point_1_y FROM point AS point_1, point " - "WHERE point_1.x * :x_1 > point.x" + "WHERE point_1.x * :x_1 > point.x", ) def test_hybrid_descriptor_three(self): @@ -203,10 +208,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): eq_(str(alias.x + 1), "point_1.x + :x_1") eq_(str(alias.x_alone + 1), "point_1.x + :x_1") - is_( - Point.x_alone.__clause_element__(), - Point.x.__clause_element__() - ) + is_(Point.x_alone.__clause_element__(), Point.x.__clause_element__()) eq_(str(alias.x_alone == alias.x), "point_1.x = point_1.x") @@ -219,7 +221,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): sess.query(alias).filter(alias.x_alone > Point.x), "SELECT point_1.id AS point_1_id, point_1.x AS point_1_x, " "point_1.y AS point_1_y FROM point AS point_1, point " - "WHERE point_1.x > point.x" + "WHERE point_1.x > point.x", ) def test_proxy_descriptor_one(self): @@ -227,9 +229,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): def __init__(self, x, y): self.x, self.y = x, y - self._fixture(Point, properties={ - 'x_syn': synonym("x") - }) + self._fixture(Point, properties={"x_syn": synonym("x")}) alias = aliased(Point) eq_(str(Point.x_syn), "Point.x_syn") @@ -239,16 +239,14 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( sess.query(alias.x_syn).filter(alias.x_syn > Point.x_syn), "SELECT point_1.x AS point_1_x FROM point AS point_1, point " - "WHERE point_1.x > point.x" + "WHERE point_1.x > point.x", ) def test_parententity_vs_parentmapper(self): class Point(object): pass - self._fixture(Point, properties={ - 'x_syn': synonym("x") - }) + self._fixture(Point, properties={"x_syn": synonym("x")}) pa = aliased(Point) is_(Point.x_syn._parententity, inspect(Point)) @@ -257,17 +255,21 @@ class 
AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): is_(Point.x._parentmapper, inspect(Point)) is_( - Point.x_syn.__clause_element__()._annotations['parententity'], - inspect(Point)) + Point.x_syn.__clause_element__()._annotations["parententity"], + inspect(Point), + ) is_( - Point.x.__clause_element__()._annotations['parententity'], - inspect(Point)) + Point.x.__clause_element__()._annotations["parententity"], + inspect(Point), + ) is_( - Point.x_syn.__clause_element__()._annotations['parentmapper'], - inspect(Point)) + Point.x_syn.__clause_element__()._annotations["parentmapper"], + inspect(Point), + ) is_( - Point.x.__clause_element__()._annotations['parentmapper'], - inspect(Point)) + Point.x.__clause_element__()._annotations["parentmapper"], + inspect(Point), + ) pa = aliased(Point) @@ -277,19 +279,20 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): is_(pa.x._parentmapper, inspect(Point)) is_( - pa.x_syn.__clause_element__()._annotations['parententity'], - inspect(pa) + pa.x_syn.__clause_element__()._annotations["parententity"], + inspect(pa), ) is_( - pa.x.__clause_element__()._annotations['parententity'], - inspect(pa) + pa.x.__clause_element__()._annotations["parententity"], inspect(pa) ) is_( - pa.x_syn.__clause_element__()._annotations['parentmapper'], - inspect(Point)) + pa.x_syn.__clause_element__()._annotations["parentmapper"], + inspect(Point), + ) is_( - pa.x.__clause_element__()._annotations['parentmapper'], - inspect(Point)) + pa.x.__clause_element__()._annotations["parentmapper"], + inspect(Point), + ) class IdentityKeyTest(_fixtures.FixtureTest): @@ -320,7 +323,7 @@ class IdentityKeyTest(_fixtures.FixtureTest): mapper(User, users) s = create_session() - u = User(name='u1') + u = User(name="u1") s.add(u) s.flush() key = orm_util.identity_key(instance=u) @@ -347,7 +350,7 @@ class IdentityKeyTest(_fixtures.FixtureTest): class PathRegistryTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' + run_setup_mappers = "once" run_inserts = None run_deletes = None @@ -357,14 +360,8 @@ class PathRegistryTest(_fixtures.FixtureTest): def test_root_registry(self): umapper = inspect(self.classes.User) - is_( - RootRegistry()[umapper], - umapper._path_registry - ) - eq_( - RootRegistry()[umapper], - PathRegistry.coerce((umapper,)) - ) + is_(RootRegistry()[umapper], umapper._path_registry) + eq_(RootRegistry()[umapper], PathRegistry.coerce((umapper,))) def test_expand(self): umapper = inspect(self.classes.User) @@ -372,10 +369,17 @@ class PathRegistryTest(_fixtures.FixtureTest): path = PathRegistry.coerce((umapper,)) eq_( - path[umapper.attrs.addresses][amapper] - [amapper.attrs.email_address], - PathRegistry.coerce((umapper, umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + path[umapper.attrs.addresses][amapper][ + amapper.attrs.email_address + ], + PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ), ) def test_entity_boolean(self): @@ -398,24 +402,42 @@ class PathRegistryTest(_fixtures.FixtureTest): def test_indexed_entity(self): umapper = inspect(self.classes.User) amapper = inspect(self.classes.Address) - path = PathRegistry.coerce((umapper, umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + path = PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) is_(path[0], umapper) is_(path[2], amapper) def test_indexed_key(self): umapper = inspect(self.classes.User) amapper = inspect(self.classes.Address) - path = 
PathRegistry.coerce((umapper, umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + path = PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) eq_(path[1], umapper.attrs.addresses) eq_(path[3], amapper.attrs.email_address) def test_slice(self): umapper = inspect(self.classes.User) amapper = inspect(self.classes.Address) - path = PathRegistry.coerce((umapper, umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + path = PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) eq_(path[1:3], (umapper.attrs.addresses, amapper)) def test_addition(self): @@ -425,8 +447,14 @@ class PathRegistryTest(_fixtures.FixtureTest): p2 = PathRegistry.coerce((amapper, amapper.attrs.email_address)) eq_( p1 + p2, - PathRegistry.coerce((umapper, umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ), ) def test_length(self): @@ -436,8 +464,14 @@ class PathRegistryTest(_fixtures.FixtureTest): p0 = PathRegistry.coerce((umapper,)) p1 = PathRegistry.coerce((umapper, umapper.attrs.addresses)) p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper)) - p3 = PathRegistry.coerce((umapper, umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + p3 = PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) eq_(len(pneg1), 0) eq_(len(p0), 1) @@ -459,11 +493,19 @@ class PathRegistryTest(_fixtures.FixtureTest): p3 = PathRegistry.coerce((umapper, umapper.attrs.name)) p4 = PathRegistry.coerce((u_alias, umapper.attrs.addresses)) p5 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper)) - p6 = PathRegistry.coerce((amapper, amapper.attrs.user, umapper, - umapper.attrs.addresses)) - p7 = PathRegistry.coerce((amapper, amapper.attrs.user, umapper, - umapper.attrs.addresses, - amapper, amapper.attrs.email_address)) + p6 = PathRegistry.coerce( + (amapper, amapper.attrs.user, umapper, umapper.attrs.addresses) + ) + p7 = PathRegistry.coerce( + ( + amapper, + amapper.attrs.user, + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) is_(p1 == p2, True) is_(p1 == p3, False) @@ -492,15 +534,9 @@ class PathRegistryTest(_fixtures.FixtureTest): p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper)) p3 = PathRegistry.coerce((amapper, amapper.attrs.email_address)) - eq_( - p1.path, (umapper, umapper.attrs.addresses) - ) - eq_( - p2.path, (umapper, umapper.attrs.addresses, amapper) - ) - eq_( - p3.path, (amapper, amapper.attrs.email_address) - ) + eq_(p1.path, (umapper, umapper.attrs.addresses)) + eq_(p2.path, (umapper, umapper.attrs.addresses, amapper)) + eq_(p3.path, (amapper, amapper.attrs.email_address)) def test_registry_set(self): reg = {} @@ -517,10 +553,10 @@ class PathRegistryTest(_fixtures.FixtureTest): eq_( reg, { - ('p1key', p1.path): 'p1value', - ('p2key', p2.path): 'p2value', - ('p3key', p3.path): 'p3value', - } + ("p1key", p1.path): "p1value", + ("p2key", p2.path): "p2value", + ("p3key", p3.path): "p3value", + }, ) def test_registry_get(self): @@ -533,9 +569,9 @@ class PathRegistryTest(_fixtures.FixtureTest): p3 = PathRegistry.coerce((amapper, amapper.attrs.email_address)) reg.update( { - ('p1key', p1.path): 'p1value', - ('p2key', p2.path): 'p2value', - ('p3key', p3.path): 'p3value', + ("p1key", p1.path): "p1value", + 
("p2key", p2.path): "p2value", + ("p3key", p3.path): "p3value", } ) @@ -555,9 +591,9 @@ class PathRegistryTest(_fixtures.FixtureTest): p3 = PathRegistry.coerce((amapper, amapper.attrs.email_address)) reg.update( { - ('p1key', p1.path): 'p1value', - ('p2key', p2.path): 'p2value', - ('p3key', p3.path): 'p3value', + ("p1key", p1.path): "p1value", + ("p2key", p2.path): "p2value", + ("p3key", p3.path): "p3value", } ) assert p1.contains(reg, "p1key") @@ -572,11 +608,7 @@ class PathRegistryTest(_fixtures.FixtureTest): p1 = PathRegistry.coerce((umapper, umapper.attrs.addresses)) p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper)) - reg.update( - { - ('p1key', p1.path): 'p1value', - } - ) + reg.update({("p1key", p1.path): "p1value"}) p1.setdefault(reg, "p1key", "p1newvalue_a") p1.setdefault(reg, "p1key_new", "p1newvalue_b") @@ -584,10 +616,10 @@ class PathRegistryTest(_fixtures.FixtureTest): eq_( reg, { - ('p1key', p1.path): 'p1value', - ('p1key_new', p1.path): 'p1newvalue_b', - ('p2key', p2.path): 'p2newvalue', - } + ("p1key", p1.path): "p1value", + ("p1key_new", p1.path): "p1newvalue_b", + ("p2key", p2.path): "p2newvalue", + }, ) def test_serialize(self): @@ -596,22 +628,19 @@ class PathRegistryTest(_fixtures.FixtureTest): umapper = inspect(self.classes.User) amapper = inspect(self.classes.Address) - p1 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper, - amapper.attrs.email_address)) + p1 = PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper)) p3 = PathRegistry.coerce((umapper, umapper.attrs.addresses)) - eq_( - p1.serialize(), - [(User, "addresses"), (Address, "email_address")] - ) - eq_( - p2.serialize(), - [(User, "addresses"), (Address, None)] - ) - eq_( - p3.serialize(), - [(User, "addresses")] - ) + eq_(p1.serialize(), [(User, "addresses"), (Address, "email_address")]) + eq_(p2.serialize(), [(User, "addresses"), (Address, None)]) + eq_(p3.serialize(), [(User, "addresses")]) def test_deseralize(self): User = self.classes.User @@ -619,28 +648,32 @@ class PathRegistryTest(_fixtures.FixtureTest): umapper = inspect(self.classes.User) amapper = inspect(self.classes.Address) - p1 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper, - amapper.attrs.email_address)) + p1 = PathRegistry.coerce( + ( + umapper, + umapper.attrs.addresses, + amapper, + amapper.attrs.email_address, + ) + ) p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper)) p3 = PathRegistry.coerce((umapper, umapper.attrs.addresses)) eq_( - PathRegistry.deserialize([(User, "addresses"), - (Address, "email_address")]), - p1 + PathRegistry.deserialize( + [(User, "addresses"), (Address, "email_address")] + ), + p1, ) eq_( PathRegistry.deserialize([(User, "addresses"), (Address, None)]), - p2 - ) - eq_( - PathRegistry.deserialize([(User, "addresses")]), - p3 + p2, ) + eq_(PathRegistry.deserialize([(User, "addresses")]), p3) class PathRegistryInhTest(_poly_fixtures._Polymorphic): - run_setup_mappers = 'once' + run_setup_mappers = "once" run_inserts = None run_deletes = None @@ -654,10 +687,7 @@ class PathRegistryInhTest(_poly_fixtures._Polymorphic): # given a mapper and an attribute on a subclass, # the path converts what you get to be against that subclass - eq_( - p1.path, - (emapper, emapper.attrs.machines) - ) + eq_(p1.path, (emapper, emapper.attrs.machines)) def test_plain_compound(self): Company = _poly_fixtures.Company @@ -667,14 +697,20 @@ class 
PathRegistryInhTest(_poly_fixtures._Polymorphic): pmapper = inspect(Person) emapper = inspect(Engineer) - p1 = PathRegistry.coerce((cmapper, cmapper.attrs.employees, - pmapper, emapper.attrs.machines)) + p1 = PathRegistry.coerce( + (cmapper, cmapper.attrs.employees, pmapper, emapper.attrs.machines) + ) # given a mapper and an attribute on a subclass, # the path converts what you get to be against that subclass eq_( p1.path, - (cmapper, cmapper.attrs.employees, emapper, emapper.attrs.machines) + ( + cmapper, + cmapper.attrs.employees, + emapper, + emapper.attrs.machines, + ), ) def test_plain_aliased(self): @@ -688,10 +724,7 @@ class PathRegistryInhTest(_poly_fixtures._Polymorphic): p1 = PathRegistry.coerce((p_alias, emapper.attrs.machines)) # plain AliasedClass - the path keeps that AliasedClass directly # as is in the path - eq_( - p1.path, - (p_alias, emapper.attrs.machines) - ) + eq_(p1.path, (p_alias, emapper.attrs.machines)) def test_plain_aliased_compound(self): Company = _poly_fixtures.Company @@ -706,13 +739,19 @@ class PathRegistryInhTest(_poly_fixtures._Polymorphic): c_alias = inspect(c_alias) p_alias = inspect(p_alias) - p1 = PathRegistry.coerce((c_alias, cmapper.attrs.employees, - p_alias, emapper.attrs.machines)) + p1 = PathRegistry.coerce( + (c_alias, cmapper.attrs.employees, p_alias, emapper.attrs.machines) + ) # plain AliasedClass - the path keeps that AliasedClass directly # as is in the path eq_( p1.path, - (c_alias, cmapper.attrs.employees, p_alias, emapper.attrs.machines) + ( + c_alias, + cmapper.attrs.employees, + p_alias, + emapper.attrs.machines, + ), ) def test_with_poly_sub(self): @@ -728,10 +767,7 @@ class PathRegistryInhTest(_poly_fixtures._Polymorphic): # polymorphic AliasedClass - the path uses _entity_for_mapper() # to get the most specific sub-entity - eq_( - p1.path, - (e_poly, emapper.attrs.machines) - ) + eq_(p1.path, (e_poly, emapper.attrs.machines)) def test_with_poly_base(self): Person = _poly_fixtures.Person @@ -747,10 +783,7 @@ class PathRegistryInhTest(_poly_fixtures._Polymorphic): # polymorphic AliasedClass - because "name" is on Person, # we get Person, not Engineer - eq_( - p1.path, - (p_poly, pmapper.attrs.name) - ) + eq_(p1.path, (p_poly, pmapper.attrs.name)) def test_with_poly_use_mapper(self): Person = _poly_fixtures.Person @@ -764,7 +797,4 @@ class PathRegistryInhTest(_poly_fixtures._Polymorphic): # polymorphic AliasedClass with the "use_mapper_path" flag - # the AliasedClass acts just like the base mapper - eq_( - p1.path, - (emapper, emapper.attrs.machines) - ) + eq_(p1.path, (emapper, emapper.attrs.machines)) diff --git a/test/orm/test_validators.py b/test/orm/test_validators.py index cbbf9f7a8e..923775a7bf 100644 --- a/test/orm/test_validators.py +++ b/test/orm/test_validators.py @@ -1,6 +1,11 @@ from test.orm import _fixtures -from sqlalchemy.testing import fixtures, assert_raises, eq_, ne_, \ - assert_raises_message +from sqlalchemy.testing import ( + fixtures, + assert_raises, + eq_, + ne_, + assert_raises_message, +) from sqlalchemy.orm import mapper, Session, validates, relationship from sqlalchemy.orm import collections from sqlalchemy.testing.mock import Mock, call @@ -13,113 +18,114 @@ class ValidatorTest(_fixtures.FixtureTest): canary = Mock() class User(fixtures.ComparableEntity): - @validates('name') + @validates("name") def validate_name(self, key, name): canary(key, name) - ne_(name, 'fred') - return name + ' modified' + ne_(name, "fred") + return name + " modified" mapper(User, users) sess = Session() - u1 = User(name='ed') 
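(The @validates hook exercised throughout this test module, restated
as a tiny self-contained sketch. The declarative mapping and the
strip/lower rule are illustrative assumptions; the fixtures here use
classical mapper() calls instead. The point shown: the validator fires
on every attribute set, its return value is what actually gets
assigned, and raising rejects the assignment, which is what the
assert_raises(AssertionError, ...) checks in these tests exercise.)

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import validates

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(32))

        @validates("name")
        def validate_name(self, key, value):
            # runs on attribute set; the returned value is assigned,
            # so validators can normalize as well as reject
            if not value:
                raise ValueError("name may not be empty")
            return value.strip().lower()

    u = User(name="  Ed ")
    assert u.name == "ed"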
- eq_(u1.name, 'ed modified') + u1 = User(name="ed") + eq_(u1.name, "ed modified") assert_raises(AssertionError, setattr, u1, "name", "fred") - eq_(u1.name, 'ed modified') - eq_(canary.mock_calls, [call('name', 'ed'), call('name', 'fred')]) + eq_(u1.name, "ed modified") + eq_(canary.mock_calls, [call("name", "ed"), call("name", "fred")]) sess.add(u1) sess.commit() eq_( - sess.query(User).filter_by(name='ed modified').one(), - User(name='ed') + sess.query(User).filter_by(name="ed modified").one(), + User(name="ed"), ) def test_collection(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) canary = Mock() class User(fixtures.ComparableEntity): - @validates('addresses') + @validates("addresses") def validate_address(self, key, ad): canary(key, ad) - assert '@' in ad.email_address + assert "@" in ad.email_address return ad - mapper(User, users, properties={ - 'addresses': relationship(Address)} - ) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) sess = Session() - u1 = User(name='edward') - a0 = Address(email_address='noemail') + u1 = User(name="edward") + a0 = Address(email_address="noemail") assert_raises(AssertionError, u1.addresses.append, a0) - a1 = Address(id=15, email_address='foo@bar.com') + a1 = Address(id=15, email_address="foo@bar.com") u1.addresses.append(a1) - eq_(canary.mock_calls, [call('addresses', a0), call('addresses', a1)]) + eq_(canary.mock_calls, [call("addresses", a0), call("addresses", a1)]) sess.add(u1) sess.commit() eq_( - sess.query(User).filter_by(name='edward').one(), - User(name='edward', addresses=[ - Address(email_address='foo@bar.com')]) + sess.query(User).filter_by(name="edward").one(), + User( + name="edward", addresses=[Address(email_address="foo@bar.com")] + ), ) def test_validators_dict(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) class User(fixtures.ComparableEntity): - - @validates('name') + @validates("name") def validate_name(self, key, name): - ne_(name, 'fred') - return name + ' modified' + ne_(name, "fred") + return name + " modified" - @validates('addresses') + @validates("addresses") def validate_address(self, key, ad): - assert '@' in ad.email_address + assert "@" in ad.email_address return ad def simple_function(self, key, value): return key, value - u_m = mapper(User, users, properties={ - 'addresses': relationship(Address)}) + u_m = mapper( + User, users, properties={"addresses": relationship(Address)} + ) mapper(Address, addresses) eq_( dict((k, v[0].__name__) for k, v in list(u_m.validators.items())), - {'name': 'validate_name', - 'addresses': 'validate_address'} + {"name": "validate_name", "addresses": "validate_address"}, ) def test_validator_w_removes(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) canary = Mock() class User(fixtures.ComparableEntity): - - @validates('name', include_removes=True) + @validates("name", include_removes=True) def validate_name(self, key, item, remove): canary(key, item, remove) return item - @validates('addresses', include_removes=True) + @validates("addresses", 
include_removes=True) def validate_address(self, key, item, remove): canary(key, item, remove) return item - mapper(User, users, properties={ - 'addresses': relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) u1 = User() @@ -132,33 +138,38 @@ class ValidatorTest(_fixtures.FixtureTest): u1.addresses = [a1, a2] u1.addresses = [a2, a3] - eq_(canary.mock_calls, [ - call('name', 'ed', False), - call('name', 'mary', False), - call('name', 'mary', True), - # append a1 - call('addresses', a1, False), - # remove a1 - call('addresses', a1, True), - # set to [a1, a2] - this is two appends - call('addresses', a1, False), call('addresses', a2, False), - # set to [a2, a3] - this is a remove of a1, - # append of a3. the appends are first. - # in 1.2 due to #3896, we also get 'a2' in the - # validates as it is part of the set - call('addresses', a2, False), - call('addresses', a3, False), - call('addresses', a1, True), - ]) + eq_( + canary.mock_calls, + [ + call("name", "ed", False), + call("name", "mary", False), + call("name", "mary", True), + # append a1 + call("addresses", a1, False), + # remove a1 + call("addresses", a1, True), + # set to [a1, a2] - this is two appends + call("addresses", a1, False), + call("addresses", a2, False), + # set to [a2, a3] - this is a remove of a1, + # append of a3. the appends are first. + # in 1.2 due to #3896, we also get 'a2' in the + # validates as it is part of the set + call("addresses", a2, False), + call("addresses", a3, False), + call("addresses", a1, True), + ], + ) def test_validator_bulk_collection_set(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) class User(fixtures.ComparableEntity): - - @validates('addresses', include_removes=True) + @validates("addresses", include_removes=True) def validate_address(self, key, item, remove): if not remove: assert isinstance(item, str) @@ -167,9 +178,7 @@ class ValidatorTest(_fixtures.FixtureTest): item = Address(email_address=item) return item - mapper(User, users, properties={ - 'addresses': relationship(Address) - }) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) u1 = User() @@ -177,22 +186,23 @@ class ValidatorTest(_fixtures.FixtureTest): u1.addresses.append("e2") eq_( u1.addresses, - [Address(email_address="e1"), Address(email_address="e2")] + [Address(email_address="e1"), Address(email_address="e2")], ) u1.addresses = ["e3", "e4"] eq_( u1.addresses, - [Address(email_address="e3"), Address(email_address="e4")] + [Address(email_address="e3"), Address(email_address="e4")], ) def test_validator_bulk_dict_set(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) + users, addresses, Address = ( + self.tables.users, + self.tables.addresses, + self.classes.Address, + ) class User(fixtures.ComparableEntity): - - @validates('addresses', include_removes=True) + @validates("addresses", include_removes=True) def validate_address(self, key, item, remove): if not remove: assert isinstance(item, str) @@ -201,13 +211,18 @@ class ValidatorTest(_fixtures.FixtureTest): item = Address(email_address=item) return item - mapper(User, users, properties={ - 'addresses': relationship( - Address, - collection_class=collections.attribute_mapped_collection( - "email_address") - ) - }) + mapper( + User, + users, 
+ properties={ + "addresses": relationship( + Address, + collection_class=collections.attribute_mapped_collection( + "email_address" + ), + ) + }, + ) mapper(Address, addresses) u1 = User() @@ -217,16 +232,16 @@ class ValidatorTest(_fixtures.FixtureTest): u1.addresses, { "e1": Address(email_address="e1"), - "e2": Address(email_address="e2") - } + "e2": Address(email_address="e2"), + }, ) u1.addresses = {"e3": "e3", "e4": "e4"} eq_( u1.addresses, { "e3": Address(email_address="e3"), - "e4": Address(email_address="e4") - } + "e4": Address(email_address="e4"), + }, ) def test_validator_multi_warning(self): @@ -245,7 +260,9 @@ class ValidatorTest(_fixtures.FixtureTest): exc.InvalidRequestError, "A validation function for mapped attribute " "'name' on mapper Mapper|Foo|users already exists", - mapper, Foo, users + mapper, + Foo, + users, ) class Bar(object): @@ -261,7 +278,9 @@ class ValidatorTest(_fixtures.FixtureTest): exc.InvalidRequestError, "A validation function for mapped attribute " "'name' on mapper Mapper|Bar|users already exists", - mapper, Bar, users + mapper, + Bar, + users, ) def test_validator_wo_backrefs_wo_removes(self): @@ -277,41 +296,57 @@ class ValidatorTest(_fixtures.FixtureTest): self._test_validator_backrefs(True, True) def _test_validator_backrefs(self, include_backrefs, include_removes): - users, addresses = (self.tables.users, - self.tables.addresses) + users, addresses = (self.tables.users, self.tables.addresses) canary = Mock() class User(fixtures.ComparableEntity): if include_removes: - @validates('addresses', include_removes=True, - include_backrefs=include_backrefs) + + @validates( + "addresses", + include_removes=True, + include_backrefs=include_backrefs, + ) def validate_address(self, key, item, remove): canary(key, item, remove) return item + else: - @validates('addresses', include_removes=False, - include_backrefs=include_backrefs) + + @validates( + "addresses", + include_removes=False, + include_backrefs=include_backrefs, + ) def validate_address(self, key, item): canary(key, item) return item class Address(fixtures.ComparableEntity): if include_removes: - @validates('user', include_backrefs=include_backrefs, - include_removes=True) + + @validates( + "user", + include_backrefs=include_backrefs, + include_removes=True, + ) def validate_user(self, key, item, remove): canary(key, item, remove) return item + else: - @validates('user', include_backrefs=include_backrefs) + + @validates("user", include_backrefs=include_backrefs) def validate_user(self, key, item): canary(key, item) return item - mapper(User, users, properties={ - 'addresses': relationship(Address, backref="user") - }) + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) mapper(Address, addresses) u1 = User() @@ -331,66 +366,64 @@ class ValidatorTest(_fixtures.FixtureTest): if include_backrefs: if include_removes: - eq_(calls, + eq_( + calls, [ # append #1 - call('addresses', Address(), False), - + call("addresses", Address(), False), # backref for append - call('user', User(addresses=[]), False), - + call("user", User(addresses=[]), False), # append #2 - call('addresses', Address(user=None), False), - + call("addresses", Address(user=None), False), # backref for append - call('user', User(addresses=[]), False), - + call("user", User(addresses=[]), False), # assign a2.user = u2 - call('user', User(addresses=[]), False), - + call("user", User(addresses=[]), False), # backref for u1.addresses.remove(a2) - call('addresses', Address(user=None), True), - + 
call("addresses", Address(user=None), True), # backref for u2.addresses.append(a2) - call('addresses', Address(user=None), False), - + call("addresses", Address(user=None), False), # del a1.user - call('user', User(addresses=[]), True), - + call("user", User(addresses=[]), True), # backref for u1.addresses.remove(a1) - call('addresses', Address(), True), - + call("addresses", Address(), True), # u2.addresses.remove(a2) - call('addresses', Address(user=None), True), - + call("addresses", Address(user=None), True), # backref for a2.user = None - call('user', None, False) - ]) + call("user", None, False), + ], + ) else: - eq_(calls, + eq_( + calls, [ - call('addresses', Address()), - call('user', User(addresses=[])), - call('addresses', Address(user=None)), - call('user', User(addresses=[])), - call('user', User(addresses=[])), - call('addresses', Address(user=None)), - call('user', None) - ]) + call("addresses", Address()), + call("user", User(addresses=[])), + call("addresses", Address(user=None)), + call("user", User(addresses=[])), + call("user", User(addresses=[])), + call("addresses", Address(user=None)), + call("user", None), + ], + ) else: if include_removes: - eq_(calls, + eq_( + calls, [ - call('addresses', Address(), False), - call('addresses', Address(user=None), False), - call('user', User(addresses=[]), False), - call('user', User(addresses=[]), True), - call('addresses', Address(user=None), True) - ]) + call("addresses", Address(), False), + call("addresses", Address(user=None), False), + call("user", User(addresses=[]), False), + call("user", User(addresses=[]), True), + call("addresses", Address(user=None), True), + ], + ) else: - eq_(calls, + eq_( + calls, [ - call('addresses', Address()), - call('addresses', Address(user=None)), - call('user', User(addresses=[])) - ]) + call("addresses", Address()), + call("addresses", Address(user=None)), + call("user", User(addresses=[])), + ], + ) diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index a1eef5d164..6b13bdc9d8 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -4,13 +4,30 @@ from sqlalchemy.testing import engines, config from sqlalchemy import testing from sqlalchemy.testing.mock import patch from sqlalchemy import ( - Integer, String, Date, ForeignKey, orm, exc, select, TypeDecorator) + Integer, + String, + Date, + ForeignKey, + orm, + exc, + select, + TypeDecorator, +) from sqlalchemy.testing.schema import Table, Column from sqlalchemy.orm import ( - mapper, relationship, Session, create_session, sessionmaker, - exc as orm_exc) + mapper, + relationship, + Session, + create_session, + sessionmaker, + exc as orm_exc, +) from sqlalchemy.testing import ( - eq_, assert_raises, assert_raises_message, fixtures) + eq_, + assert_raises, + assert_raises_message, + fixtures, +) from sqlalchemy.testing.assertsql import CompiledSQL import uuid from sqlalchemy import util @@ -25,11 +42,15 @@ class NullVersionIdTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('version_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer), - Column('value', String(40), nullable=False)) + Table( + "version_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer), + Column("value", String(40), nullable=False), + ) @classmethod def setup_classes(cls): @@ -40,7 +61,8 @@ class NullVersionIdTest(fixtures.MappedTest): Foo, version_table 
= self.classes.Foo, self.tables.version_table mapper( - Foo, version_table, + Foo, + version_table, version_id_col=version_table.c.version_id, version_id_generator=False, ) @@ -52,7 +74,7 @@ class NullVersionIdTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') + f1 = Foo(value="f1") s1.add(f1) # Prior to the fix for #3673, you would have been allowed to insert @@ -69,14 +91,15 @@ class NullVersionIdTest(fixtures.MappedTest): assert_raises_message( sa.orm.exc.FlushError, "Instance does not contain a non-NULL version value", - s1.commit) + s1.commit, + ) @testing.emits_warning(r".*versioning cannot be verified") def test_null_version_id_update(self): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1', version_id=1) + f1 = Foo(value="f1", version_id=1) s1.add(f1) s1.commit() @@ -85,13 +108,14 @@ class NullVersionIdTest(fixtures.MappedTest): # this, post commit: Foo(id=1, value='f1rev2', version_id=None). Now # you should get a FlushError on update. - f1.value = 'f1rev2' + f1.value = "f1rev2" f1.version_id = None assert_raises_message( sa.orm.exc.FlushError, "Instance does not contain a non-NULL version value", - s1.commit) + s1.commit, + ) class VersioningTest(fixtures.MappedTest): @@ -99,11 +123,15 @@ class VersioningTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('version_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=False), - Column('value', String(40), nullable=False)) + Table( + "version_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer, nullable=False), + Column("value", String(40), nullable=False), + ) @classmethod def setup_classes(cls): @@ -125,12 +153,12 @@ class VersioningTest(fixtures.MappedTest): testing.db.dialect.supports_sane_rowcount = False try: s1 = self._fixture() - f1 = Foo(value='f1') - f2 = Foo(value='f2') + f1 = Foo(value="f1") + f2 = Foo(value="f2") s1.add_all((f1, f2)) s1.commit() - f1.value = 'f1rev2' + f1.value = "f1rev2" assert_raises(sa.exc.SAWarning, s1.commit) finally: testing.db.dialect.supports_sane_rowcount = save @@ -141,20 +169,20 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') - f2 = Foo(value='f2') + f1 = Foo(value="f1") + f2 = Foo(value="f2") s1.add_all((f1, f2)) s1.commit() - f1.value = 'f1rev2' + f1.value = "f1rev2" s1.commit() s2 = create_session(autocommit=False) f1_s = s2.query(Foo).get(f1.id) - f1_s.value = 'f1rev3' + f1_s.value = "f1rev3" s2.commit() - f1.value = 'f1rev3mine' + f1.value = "f1rev3mine" # Only dialects with a sane rowcount can detect the # StaleDataError @@ -162,7 +190,9 @@ class VersioningTest(fixtures.MappedTest): assert_raises_message( sa.orm.exc.StaleDataError, r"UPDATE statement on table 'version_table' expected " - r"to update 1 row\(s\); 0 were matched.", s1.commit), + r"to update 1 row\(s\); 0 were matched.", + s1.commit, + ), s1.rollback() else: s1.commit() @@ -171,7 +201,7 @@ class VersioningTest(fixtures.MappedTest): f1 = s1.query(Foo).get(f1.id) f2 = s1.query(Foo).get(f2.id) - f1_s.value = 'f1rev4' + f1_s.value = "f1rev4" s2.commit() s1.delete(f1) @@ -182,7 +212,8 @@ class VersioningTest(fixtures.MappedTest): sa.orm.exc.StaleDataError, r"DELETE statement on table 'version_table' expected " r"to delete 2 row\(s\); 1 were matched.", - s1.commit) + s1.commit, + ) else: s1.commit() @@ -191,18 +222,18 @@ 
class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') - f2 = Foo(value='f2') + f1 = Foo(value="f1") + f2 = Foo(value="f2") s1.add_all((f1, f2)) s1.commit() - f1.value = 'f1rev2' - f2.value = 'f2rev2' + f1.value = "f1rev2" + f2.value = "f2rev2" s1.commit() eq_( s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(), - [(f1.id, 'f1rev2', 2), (f2.id, 'f2rev2', 2)] + [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)], ) @testing.emits_warning(r".*versioning cannot be verified") @@ -211,12 +242,11 @@ class VersioningTest(fixtures.MappedTest): s1 = self._fixture() s1.bulk_insert_mappings( - Foo, - [{"id": 1, "value": "f1"}, {"id": 2, "value": "f2"}] + Foo, [{"id": 1, "value": "f1"}, {"id": 2, "value": "f2"}] ) eq_( s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(), - [(1, 'f1', 1), (2, 'f2', 1)] + [(1, "f1", 1), (2, "f2", 1)], ) @testing.emits_warning(r".*versioning cannot be verified") @@ -224,8 +254,8 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') - f2 = Foo(value='f2') + f1 = Foo(value="f1") + f2 = Foo(value="f2") s1.add_all((f1, f2)) s1.commit() @@ -234,14 +264,13 @@ class VersioningTest(fixtures.MappedTest): [ {"id": f1.id, "value": "f1rev2", "version_id": 1}, {"id": f2.id, "value": "f2rev2", "version_id": 1}, - - ] + ], ) s1.commit() eq_( s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(), - [(f1.id, 'f1rev2', 2), (f2.id, 'f2rev2', 2)] + [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)], ) @testing.emits_warning(r".*versioning cannot be verified") @@ -257,7 +286,7 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') + f1 = Foo(value="f1") s1.add(f1) s1.commit() eq_(f1.version_id, 1) @@ -286,13 +315,13 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1s1 = Foo(value='f1 value') + f1s1 = Foo(value="f1 value") s1.add(f1s1) s1.commit() s2 = create_session(autocommit=False) f1s2 = s2.query(Foo).get(f1s1.id) - f1s2.value = 'f1 new value' + f1s2.value = "f1 new value" s2.commit() # load, version is wrong @@ -300,7 +329,8 @@ class VersioningTest(fixtures.MappedTest): sa.orm.exc.StaleDataError, r"Instance .* has version id '\d+' which does not " r"match database-loaded version id '\d+'", - s1.query(Foo).with_for_update(read=True).get, f1s1.id + s1.query(Foo).with_for_update(read=True).get, + f1s1.id, ) # reload it - this expires the old version first @@ -322,13 +352,13 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1s1 = Foo(value='f1 value') + f1s1 = Foo(value="f1 value") s1.add(f1s1) s1.commit() s2 = create_session(autocommit=False) f1s2 = s2.query(Foo).get(f1s1.id) - f1s2.value = 'f1 new value' + f1s2.value = "f1 new value" s2.commit() # load, version is wrong @@ -336,18 +366,19 @@ class VersioningTest(fixtures.MappedTest): sa.orm.exc.StaleDataError, r"Instance .* has version id '\d+' which does not " r"match database-loaded version id '\d+'", - s1.query(Foo).with_lockmode('read').get, f1s1.id + s1.query(Foo).with_lockmode("read").get, + f1s1.id, ) # reload it - this expires the old version first - s1.refresh(f1s1, lockmode='read') + s1.refresh(f1s1, lockmode="read") # now assert version OK - s1.query(Foo).with_lockmode('read').get(f1s1.id) + s1.query(Foo).with_lockmode("read").get(f1s1.id) # assert brand new load is OK too s1.close() - s1.query(Foo).with_lockmode('read').get(f1s1.id) + 
s1.query(Foo).with_lockmode("read").get(f1s1.id) def test_versioncheck_not_versioned(self): """ensure the versioncheck logic skips if there isn't a @@ -358,10 +389,10 @@ class VersioningTest(fixtures.MappedTest): mapper(Foo, version_table) s1 = Session() - f1s1 = Foo(value='f1 value', version_id=1) + f1s1 = Foo(value="f1 value", version_id=1) s1.add(f1s1) s1.commit() - s1.query(Foo).with_lockmode('read').get(f1s1.id) + s1.query(Foo).with_lockmode("read").get(f1s1.id) @testing.emits_warning(r".*versioning cannot be verified") @engines.close_open_connections @@ -373,7 +404,7 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1s1 = Foo(value='f1 value') + f1s1 = Foo(value="f1 value") s1.add(f1s1) s1.commit() @@ -381,11 +412,10 @@ class VersioningTest(fixtures.MappedTest): f1s2 = s2.query(Foo).get(f1s1.id) # not sure if I like this API s2.refresh(f1s2, with_for_update=True) - f1s2.value = 'f1 new value' + f1s2.value = "f1 new value" assert_raises( - exc.DBAPIError, - s1.refresh, f1s1, lockmode='update_nowait' + exc.DBAPIError, s1.refresh, f1s1, lockmode="update_nowait" ) s1.rollback() @@ -403,23 +433,22 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1s1 = Foo(value='f1 value') + f1s1 = Foo(value="f1 value") s1.add(f1s1) s1.commit() s2 = create_session(autocommit=False) f1s2 = s2.query(Foo).get(f1s1.id) - s2.refresh(f1s2, lockmode='update') - f1s2.value = 'f1 new value' + s2.refresh(f1s2, lockmode="update") + f1s2.value = "f1 new value" assert_raises( - exc.DBAPIError, - s1.refresh, f1s1, lockmode='update_nowait' + exc.DBAPIError, s1.refresh, f1s1, lockmode="update_nowait" ) s1.rollback() s2.commit() - s1.refresh(f1s1, lockmode='update_nowait') + s1.refresh(f1s1, lockmode="update_nowait") assert f1s1.version_id == f1s2.version_id @testing.emits_warning(r".*versioning cannot be verified") @@ -432,22 +461,23 @@ class VersioningTest(fixtures.MappedTest): return self.context.rowcount with patch.object( - config.db.dialect, "supports_sane_multi_rowcount", False): + config.db.dialect, "supports_sane_multi_rowcount", False + ): with patch( - "sqlalchemy.engine.result.ResultProxy.rowcount", - rowcount): + "sqlalchemy.engine.result.ResultProxy.rowcount", rowcount + ): Foo = self.classes.Foo s1 = self._fixture() - f1s1 = Foo(value='f1 value') + f1s1 = Foo(value="f1 value") s1.add(f1s1) s1.commit() - f1s1.value = 'f2 value' + f1s1.value = "f2 value" s1.flush() eq_(f1s1.version_id, 2) - @testing.emits_warning(r'.*does not support updated rowcount') + @testing.emits_warning(r".*does not support updated rowcount") @engines.close_open_connections def test_noversioncheck(self): """test query.with_lockmode works when the mapper has no version id @@ -462,7 +492,7 @@ class VersioningTest(fixtures.MappedTest): s1.commit() s2 = create_session(autocommit=False) - f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id) + f1s2 = s2.query(Foo).with_lockmode("read").get(f1s1.id) assert f1s2.id == f1s1.id assert f1s2.value == f1s1.value @@ -471,14 +501,14 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') + f1 = Foo(value="f1") s1.add(f1) s1.commit() - f1.value = 'f2' + f1.value = "f2" s1.commit() - f2 = Foo(id=f1.id, value='f3') + f2 = Foo(id=f1.id, value="f3") f3 = s1.merge(f2) assert f3 is f1 s1.commit() @@ -489,14 +519,14 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') + f1 = Foo(value="f1") s1.add(f1) 
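(A compact restatement of the optimistic-concurrency behavior under
test in this file; the declarative Foo mapping is an illustrative
assumption, while the fixtures above configure the same thing through
mapper(..., version_id_col=...).)

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Foo(Base):
        __tablename__ = "foo"
        id = Column(Integer, primary_key=True)
        version_id = Column(Integer, nullable=False)
        value = Column(String(40))
        __mapper_args__ = {"version_id_col": version_id}

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    s = Session(engine)
    f = Foo(value="f1")
    s.add(f)
    s.commit()
    assert f.version_id == 1   # default generator starts at 1

    f.value = "f2"
    s.commit()                 # UPDATE ... WHERE version_id = 1
    assert f.version_id == 2   # the counter is bumped on every UPDATE

    # If another session had already moved the row to version 2, the
    # WHERE clause would match zero rows and the flush would raise
    # StaleDataError; that detection needs an accurate rowcount, which
    # is why several tests above toggle supports_sane_rowcount.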
s1.commit() - f1.value = 'f2' + f1.value = "f2" s1.commit() - f2 = Foo(id=f1.id, value='f3', version_id=2) + f2 = Foo(id=f1.id, value="f3", version_id=2) f3 = s1.merge(f2) assert f3 is f1 s1.commit() @@ -507,21 +537,22 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') + f1 = Foo(value="f1") s1.add(f1) s1.commit() - f1.value = 'f2' + f1.value = "f2" s1.commit() - f2 = Foo(id=f1.id, value='f3', version_id=1) + f2 = Foo(id=f1.id, value="f3", version_id=1) assert_raises_message( orm_exc.StaleDataError, "Version id '1' on merged state " " does not match existing version '2'. " "Leave the version attribute unset when " "merging to update the most recent version.", - s1.merge, f2 + s1.merge, + f2, ) @testing.emits_warning(r".*versioning cannot be verified") @@ -529,14 +560,14 @@ class VersioningTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(value='f1') + f1 = Foo(value="f1") s1.add(f1) s1.commit() - f1.value = 'f2' + f1.value = "f2" s1.commit() - f2 = Foo(id=f1.id, value='f3', version_id=1) + f2 = Foo(id=f1.id, value="f3", version_id=1) s1.close() assert_raises_message( @@ -545,7 +576,8 @@ class VersioningTest(fixtures.MappedTest): " does not match existing version '2'. " "Leave the version attribute unset when " "merging to update the most recent version.", - s1.merge, f2 + s1.merge, + f2, ) @@ -555,10 +587,11 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'node', metadata, - Column('id', Integer, primary_key=True), - Column('version_id', Integer), - Column('parent_id', ForeignKey('node.id')) + "node", + metadata, + Column("id", Integer, primary_key=True), + Column("version_id", Integer), + Column("parent_id", ForeignKey("node.id")), ) @classmethod @@ -570,13 +603,18 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): Node = self.classes.Node node = self.tables.node - mapper(Node, node, properties={ - 'related': relationship( - Node, - remote_side=node.c.id if not o2m else node.c.parent_id, - post_update=post_update - ) - }, version_id_col=node.c.version_id) + mapper( + Node, + node, + properties={ + "related": relationship( + Node, + remote_side=node.c.id if not o2m else node.c.parent_id, + post_update=post_update, + ) + }, + version_id_col=node.c.version_id, + ) s = Session() n1 = Node(id=1) @@ -667,7 +705,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): orm_exc.StaleDataError, "UPDATE statement on table 'node' expected to " r"update 1 row\(s\); 0 were matched.", - s.flush + s.flush, ) @testing.requires.sane_rowcount_w_returning @@ -689,7 +727,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): orm_exc.StaleDataError, "UPDATE statement on table 'node' expected to " r"update 1 row\(s\); 0 were matched.", - s.flush + s.flush, ) @@ -699,18 +737,20 @@ class NoBumpOnRelationshipTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'a', metadata, + "a", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer), ) Table( - 'b', metadata, + "b", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('a_id', ForeignKey('a.id')) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("a_id", ForeignKey("a.id")), ) @classmethod @@ -722,7 +762,7 @@ class 
NoBumpOnRelationshipTest(fixtures.MappedTest): pass def _run_test(self, auto_version_counter=True): - A, B = self.classes('A', 'B') + A, B = self.classes("A", "B") s = Session() if auto_version_counter: a1 = A() @@ -740,13 +780,13 @@ class NoBumpOnRelationshipTest(fixtures.MappedTest): eq_(a1.version_id, 1) def test_plain_counter(self): - A, B = self.classes('A', 'B') - a, b = self.tables('a', 'b') + A, B = self.classes("A", "B") + a, b = self.tables("a", "b") mapper( - A, a, properties={ - 'bs': relationship(B, backref='a') - }, + A, + a, + properties={"bs": relationship(B, backref="a")}, version_id_col=a.c.version_id, ) mapper(B, b) @@ -754,30 +794,30 @@ class NoBumpOnRelationshipTest(fixtures.MappedTest): self._run_test() def test_functional_counter(self): - A, B = self.classes('A', 'B') - a, b = self.tables('a', 'b') + A, B = self.classes("A", "B") + a, b = self.tables("a", "b") mapper( - A, a, properties={ - 'bs': relationship(B, backref='a') - }, + A, + a, + properties={"bs": relationship(B, backref="a")}, version_id_col=a.c.version_id, - version_id_generator=lambda num: (num or 0) + 1 + version_id_generator=lambda num: (num or 0) + 1, ) mapper(B, b) self._run_test() def test_no_counter(self): - A, B = self.classes('A', 'B') - a, b = self.tables('a', 'b') + A, B = self.classes("A", "B") + a, b = self.tables("a", "b") mapper( - A, a, properties={ - 'bs': relationship(B, backref='a') - }, + A, + a, + properties={"bs": relationship(B, backref="a")}, version_id_col=a.c.version_id, - version_id_generator=False + version_id_generator=False, ) mapper(B, b) @@ -786,7 +826,7 @@ class NoBumpOnRelationshipTest(fixtures.MappedTest): class ColumnTypeTest(fixtures.MappedTest): __backend__ = True - __requires__ = 'sane_rowcount', + __requires__ = ("sane_rowcount",) @classmethod def define_tables(cls, metadata): @@ -797,10 +837,13 @@ class ColumnTypeTest(fixtures.MappedTest): assert isinstance(value, datetime.date) return value - Table('version_table', metadata, - Column('id', SpecialType, primary_key=True), - Column('version_id', Integer, nullable=False), - Column('value', String(40), nullable=False)) + Table( + "version_table", + metadata, + Column("id", SpecialType, primary_key=True), + Column("version_id", Integer, nullable=False), + Column("value", String(40), nullable=False), + ) @classmethod def setup_classes(cls): @@ -820,11 +863,11 @@ class ColumnTypeTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._fixture() - f1 = Foo(id=datetime.date.today(), value='f1') + f1 = Foo(id=datetime.date.today(), value="f1") s1.add(f1) s1.commit() - f1.value = 'f1rev2' + f1.value = "f1rev2" s1.commit() @@ -834,21 +877,22 @@ class RowSwitchTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'p', metadata, - Column('id', String(10), primary_key=True), - Column('version_id', Integer, default=1, nullable=False), - Column('data', String(50)) + "p", + metadata, + Column("id", String(10), primary_key=True), + Column("version_id", Integer, default=1, nullable=False), + Column("data", String(50)), ) Table( - 'c', metadata, - Column('id', String(10), ForeignKey('p.id'), primary_key=True), - Column('version_id', Integer, default=1, nullable=False), - Column('data', String(50)) + "c", + metadata, + Column("id", String(10), ForeignKey("p.id"), primary_key=True), + Column("version_id", Integer, default=1, nullable=False), + Column("data", String(50)), ) @classmethod def setup_classes(cls): - class P(cls.Basic): pass @@ -860,9 +904,15 @@ class RowSwitchTest(fixtures.MappedTest): p, 
c, C, P = cls.tables.p, cls.tables.c, cls.classes.C, cls.classes.P mapper( - P, p, version_id_col=p.c.version_id, properties={ - 'c': relationship( - C, uselist=False, cascade='all, delete-orphan')}) + P, + p, + version_id_col=p.c.version_id, + properties={ + "c": relationship( + C, uselist=False, cascade="all, delete-orphan" + ) + }, + ) mapper(C, c, version_id_col=c.c.version_id) @testing.emits_warning(r".*versioning cannot be verified") @@ -870,13 +920,13 @@ class RowSwitchTest(fixtures.MappedTest): P = self.classes.P session = sessionmaker()() - session.add(P(id='P1', data='P version 1')) + session.add(P(id="P1", data="P version 1")) session.commit() session.close() p = session.query(P).first() session.delete(p) - session.add(P(id='P1', data="really a row-switch")) + session.add(P(id="P1", data="really a row-switch")) session.commit() @testing.emits_warning(r".*versioning cannot be verified") @@ -886,41 +936,42 @@ class RowSwitchTest(fixtures.MappedTest): assert P.c.property.strategy.use_get session = sessionmaker()() - session.add(P(id='P1', data='P version 1')) + session.add(P(id="P1", data="P version 1")) session.commit() session.close() p = session.query(P).first() - p.c = C(data='child version 1') + p.c = C(data="child version 1") session.commit() p = session.query(P).first() - p.c = C(data='child row-switch') + p.c = C(data="child row-switch") session.commit() class AlternateGeneratorTest(fixtures.MappedTest): __backend__ = True - __requires__ = 'sane_rowcount', + __requires__ = ("sane_rowcount",) @classmethod def define_tables(cls, metadata): Table( - 'p', metadata, - Column('id', String(10), primary_key=True), - Column('version_id', String(32), nullable=False), - Column('data', String(50)) + "p", + metadata, + Column("id", String(10), primary_key=True), + Column("version_id", String(32), nullable=False), + Column("data", String(50)), ) Table( - 'c', metadata, - Column('id', String(10), ForeignKey('p.id'), primary_key=True), - Column('version_id', String(32), nullable=False), - Column('data', String(50)) + "c", + metadata, + Column("id", String(10), ForeignKey("p.id"), primary_key=True), + Column("version_id", String(32), nullable=False), + Column("data", String(50)), ) @classmethod def setup_classes(cls): - class P(cls.Basic): pass @@ -932,14 +983,20 @@ class AlternateGeneratorTest(fixtures.MappedTest): p, c, C, P = cls.tables.p, cls.tables.c, cls.classes.C, cls.classes.P mapper( - P, p, version_id_col=p.c.version_id, + P, + p, + version_id_col=p.c.version_id, version_id_generator=lambda x: make_uuid(), properties={ - 'c': relationship( - C, uselist=False, cascade='all, delete-orphan') - }) + "c": relationship( + C, uselist=False, cascade="all, delete-orphan" + ) + }, + ) mapper( - C, c, version_id_col=c.c.version_id, + C, + c, + version_id_col=c.c.version_id, version_id_generator=lambda x: make_uuid(), ) @@ -948,13 +1005,13 @@ class AlternateGeneratorTest(fixtures.MappedTest): P = self.classes.P session = sessionmaker()() - session.add(P(id='P1', data='P version 1')) + session.add(P(id="P1", data="P version 1")) session.commit() session.close() p = session.query(P).first() session.delete(p) - session.add(P(id='P1', data="really a row-switch")) + session.add(P(id="P1", data="really a row-switch")) session.commit() @testing.emits_warning(r".*versioning cannot be verified") @@ -964,16 +1021,16 @@ class AlternateGeneratorTest(fixtures.MappedTest): assert P.c.property.strategy.use_get session = sessionmaker()() - session.add(P(id='P1', data='P version 1')) + session.add(P(id="P1", 
data="P version 1")) session.commit() session.close() p = session.query(P).first() - p.c = C(data='child version 1') + p.c = C(data="child version 1") session.commit() p = session.query(P).first() - p.c = C(data='child row-switch') + p.c = C(data="child row-switch") session.commit() @testing.emits_warning(r".*versioning cannot be verified") @@ -987,7 +1044,7 @@ class AlternateGeneratorTest(fixtures.MappedTest): # testing exactly what its looking for sess1 = Session() - sess1.add(P(id='P1', data='P version 1')) + sess1.add(P(id="P1", data="P version 1")) sess1.commit() sess1.close() @@ -1000,16 +1057,16 @@ class AlternateGeneratorTest(fixtures.MappedTest): sess1.commit() # this can be removed and it still passes - sess1.add(P(id='P1', data='P version 2')) + sess1.add(P(id="P1", data="P version 2")) sess1.commit() - p2.data = 'P overwritten by concurrent tx' + p2.data = "P overwritten by concurrent tx" if testing.db.dialect.supports_sane_rowcount: assert_raises_message( orm.exc.StaleDataError, r"UPDATE statement on table 'p' expected to update " r"1 row\(s\); 0 were matched.", - sess2.commit + sess2.commit, ) else: sess2.commit @@ -1021,22 +1078,23 @@ class PlainInheritanceTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( - 'base', metadata, + "base", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=True), - Column('data', String(50)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer, nullable=True), + Column("data", String(50)), ) Table( - 'sub', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('sub_data', String(50)) + "sub", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("sub_data", String(50)), ) @classmethod def setup_classes(cls): - class Base(cls.Basic): pass @@ -1046,18 +1104,21 @@ class PlainInheritanceTest(fixtures.MappedTest): @testing.emits_warning(r".*versioning cannot be verified") def test_update_child_table_only(self): Base, sub, base, Sub = ( - self.classes.Base, self.tables.sub, self.tables.base, - self.classes.Sub) + self.classes.Base, + self.tables.sub, + self.tables.base, + self.classes.Sub, + ) mapper(Base, base, version_id_col=base.c.version_id) mapper(Sub, sub, inherits=Base) s = Session() - s1 = Sub(data='b', sub_data='s') + s1 = Sub(data="b", sub_data="s") s.add(s1) s.commit() - s1.sub_data = 's2' + s1.sub_data = "s2" s.commit() eq_(s1.version_id, 2) @@ -1068,28 +1129,30 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): versioning column. 
""" + __backend__ = True @classmethod def define_tables(cls, metadata): Table( - 'base', metadata, + "base", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=True), - Column('data', String(50)) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer, nullable=True), + Column("data", String(50)), ) Table( - 'sub', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('version_id', Integer, nullable=False), - Column('sub_data', String(50)) + "sub", + metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("version_id", Integer, nullable=False), + Column("sub_data", String(50)), ) @classmethod def setup_classes(cls): - class Base(cls.Basic): pass @@ -1098,14 +1161,17 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): def test_base_both(self): Base, sub, base, Sub = ( - self.classes.Base, self.tables.sub, self.tables.base, - self.classes.Sub) + self.classes.Base, + self.tables.sub, + self.tables.base, + self.classes.Sub, + ) mapper(Base, base, version_id_col=base.c.version_id) mapper(Sub, sub, inherits=Base) session = Session() - b1 = Base(data='b1') + b1 = Base(data="b1") session.add(b1) session.commit() eq_(b1.version_id, 1) @@ -1114,14 +1180,17 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): def test_sub_both(self): Base, sub, base, Sub = ( - self.classes.Base, self.tables.sub, self.tables.base, - self.classes.Sub) + self.classes.Base, + self.tables.sub, + self.tables.base, + self.classes.Sub, + ) mapper(Base, base, version_id_col=base.c.version_id) mapper(Sub, sub, inherits=Base) session = Session() - s1 = Sub(data='s1', sub_data='s1') + s1 = Sub(data="s1", sub_data="s1") session.add(s1) session.commit() @@ -1133,14 +1202,17 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): def test_sub_only(self): Base, sub, base, Sub = ( - self.classes.Base, self.tables.sub, self.tables.base, - self.classes.Sub) + self.classes.Base, + self.tables.sub, + self.tables.base, + self.classes.Sub, + ) mapper(Base, base) mapper(Sub, sub, inherits=Base, version_id_col=sub.c.version_id) session = Session() - s1 = Sub(data='s1', sub_data='s1') + s1 = Sub(data="s1", sub_data="s1") session.add(s1) session.commit() @@ -1152,8 +1224,11 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): def test_mismatch_version_col_warning(self): Base, sub, base, Sub = ( - self.classes.Base, self.tables.sub, self.tables.base, - self.classes.Sub) + self.classes.Base, + self.tables.sub, + self.tables.base, + self.classes.Sub, + ) mapper(Base, base, version_id_col=base.c.version_id) @@ -1164,12 +1239,16 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): "automatically populate the inherited versioning column. 
" "version_id_col should only be specified on " "the base-most mapper that includes versioning.", - mapper, Sub, sub, inherits=Base, - version_id_col=sub.c.version_id) + mapper, + Sub, + sub, + inherits=Base, + version_id_col=sub.c.version_id, + ) class ServerVersioningTest(fixtures.MappedTest): - run_define_tables = 'each' + run_define_tables = "each" __backend__ = True @classmethod @@ -1196,18 +1275,23 @@ class ServerVersioningTest(fixtures.MappedTest): return stmt._counter Table( - 'version_table', metadata, + "version_table", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column( - 'version_id', Integer, nullable=False, - default=IncDefault(), onupdate=IncDefault()), - Column('value', String(40), nullable=False)) + "version_id", + Integer, + nullable=False, + default=IncDefault(), + onupdate=IncDefault(), + ), + Column("value", String(40), nullable=False), + ) @classmethod def setup_classes(cls): - class Foo(cls.Basic): pass @@ -1218,9 +1302,11 @@ class ServerVersioningTest(fixtures.MappedTest): Foo, version_table = self.classes.Foo, self.tables.version_table mapper( - Foo, version_table, version_id_col=version_table.c.version_id, + Foo, + version_table, + version_id_col=version_table.c.version_id, version_id_generator=False, - eager_defaults=eager_defaults + eager_defaults=eager_defaults, ) s1 = Session(expire_on_commit=expire_on_commit) @@ -1235,7 +1321,7 @@ class ServerVersioningTest(fixtures.MappedTest): def _test_insert_col(self, **kw): sess = self._fixture(**kw) - f1 = self.classes.Foo(value='f1') + f1 = self.classes.Foo(value="f1") sess.add(f1) statements = [ @@ -1245,7 +1331,7 @@ class ServerVersioningTest(fixtures.MappedTest): CompiledSQL( "INSERT INTO version_table (version_id, value) " "VALUES (1, :value)", - lambda ctx: [{'value': 'f1'}] + lambda ctx: [{"value": "f1"}], ) ] if not testing.db.dialect.implicit_returning: @@ -1256,7 +1342,7 @@ class ServerVersioningTest(fixtures.MappedTest): "SELECT version_table.version_id " "AS version_table_version_id " "FROM version_table WHERE version_table.id = :param_1", - lambda ctx: [{"param_1": 1}] + lambda ctx: [{"param_1": 1}], ) ) self.assert_sql_execution(testing.db, sess.flush, *statements) @@ -1271,11 +1357,11 @@ class ServerVersioningTest(fixtures.MappedTest): def _test_update_col(self, **kw): sess = self._fixture(**kw) - f1 = self.classes.Foo(value='f1') + f1 = self.classes.Foo(value="f1") sess.add(f1) sess.flush() - f1.value = 'f2' + f1.value = "f2" statements = [ # note that the assertsql tests the rule against @@ -1288,7 +1374,10 @@ class ServerVersioningTest(fixtures.MappedTest): lambda ctx: [ { "version_table_id": 1, - "version_table_version_id": 1, "value": "f2"}] + "version_table_version_id": 1, + "value": "f2", + } + ], ) ] if not testing.db.dialect.implicit_returning: @@ -1299,7 +1388,7 @@ class ServerVersioningTest(fixtures.MappedTest): "SELECT version_table.version_id " "AS version_table_version_id " "FROM version_table WHERE version_table.id = :param_1", - lambda ctx: [{"param_1": 1}] + lambda ctx: [{"param_1": 1}], ) ) self.assert_sql_execution(testing.db, sess.flush, *statements) @@ -1309,7 +1398,7 @@ class ServerVersioningTest(fixtures.MappedTest): def test_sql_expr_bump(self): sess = self._fixture() - f1 = self.classes.Foo(value='f1') + f1 = self.classes.Foo(value="f1") sess.add(f1) sess.flush() @@ -1327,7 +1416,7 @@ class ServerVersioningTest(fixtures.MappedTest): def 
test_sql_expr_w_mods_bump(self): sess = self._fixture() - f1 = self.classes.Foo(id=2, value='f1') + f1 = self.classes.Foo(id=2, value="f1") sess.add(f1) sess.flush() @@ -1344,15 +1433,15 @@ class ServerVersioningTest(fixtures.MappedTest): def test_multi_update(self): sess = self._fixture() - f1 = self.classes.Foo(value='f1') - f2 = self.classes.Foo(value='f2') - f3 = self.classes.Foo(value='f3') + f1 = self.classes.Foo(value="f1") + f2 = self.classes.Foo(value="f2") + f3 = self.classes.Foo(value="f3") sess.add_all([f1, f2, f3]) sess.flush() - f1.value = 'f1a' - f2.value = 'f2a' - f3.value = 'f3a' + f1.value = "f1a" + f2.value = "f2a" + f3.value = "f3a" statements = [ # note that the assertsql tests the rule against @@ -1365,7 +1454,10 @@ class ServerVersioningTest(fixtures.MappedTest): lambda ctx: [ { "version_table_id": 1, - "version_table_version_id": 1, "value": "f1a"}] + "version_table_version_id": 1, + "value": "f1a", + } + ], ), CompiledSQL( "UPDATE version_table SET version_id=2, value=:value " @@ -1374,7 +1466,10 @@ class ServerVersioningTest(fixtures.MappedTest): lambda ctx: [ { "version_table_id": 2, - "version_table_version_id": 1, "value": "f2a"}] + "version_table_version_id": 1, + "value": "f2a", + } + ], ), CompiledSQL( "UPDATE version_table SET version_id=2, value=:value " @@ -1383,38 +1478,43 @@ class ServerVersioningTest(fixtures.MappedTest): lambda ctx: [ { "version_table_id": 3, - "version_table_version_id": 1, "value": "f3a"}] - ) + "version_table_version_id": 1, + "value": "f3a", + } + ], + ), ] if not testing.db.dialect.implicit_returning: # DBs without implicit returning, we must immediately # SELECT for the new version id - statements.extend([ - CompiledSQL( - "SELECT version_table.version_id " - "AS version_table_version_id " - "FROM version_table WHERE version_table.id = :param_1", - lambda ctx: [{"param_1": 1}] - ), - CompiledSQL( - "SELECT version_table.version_id " - "AS version_table_version_id " - "FROM version_table WHERE version_table.id = :param_1", - lambda ctx: [{"param_1": 2}] - ), - CompiledSQL( - "SELECT version_table.version_id " - "AS version_table_version_id " - "FROM version_table WHERE version_table.id = :param_1", - lambda ctx: [{"param_1": 3}] - ) - ]) + statements.extend( + [ + CompiledSQL( + "SELECT version_table.version_id " + "AS version_table_version_id " + "FROM version_table WHERE version_table.id = :param_1", + lambda ctx: [{"param_1": 1}], + ), + CompiledSQL( + "SELECT version_table.version_id " + "AS version_table_version_id " + "FROM version_table WHERE version_table.id = :param_1", + lambda ctx: [{"param_1": 2}], + ), + CompiledSQL( + "SELECT version_table.version_id " + "AS version_table_version_id " + "FROM version_table WHERE version_table.id = :param_1", + lambda ctx: [{"param_1": 3}], + ), + ] + ) self.assert_sql_execution(testing.db, sess.flush, *statements) def test_delete_col(self): sess = self._fixture() - f1 = self.classes.Foo(value='f1') + f1 = self.classes.Foo(value="f1") sess.add(f1) sess.flush() @@ -1428,7 +1528,7 @@ class ServerVersioningTest(fixtures.MappedTest): "DELETE FROM version_table " "WHERE version_table.id = :id AND " "version_table.version_id = :version_id", - lambda ctx: [{"id": 1, "version_id": 1}] + lambda ctx: [{"id": 1, "version_id": 1}], ) ] self.assert_sql_execution(testing.db, sess.flush, *statements) @@ -1437,7 +1537,7 @@ class ServerVersioningTest(fixtures.MappedTest): def test_concurrent_mod_err_expire_on_commit(self): sess = self._fixture() - f1 = self.classes.Foo(value='f1') + f1 = 
self.classes.Foo(value="f1") sess.add(f1) sess.commit() @@ -1445,23 +1545,23 @@ class ServerVersioningTest(fixtures.MappedTest): s2 = Session() f2 = s2.query(self.classes.Foo).first() - f2.value = 'f2' + f2.value = "f2" s2.commit() - f1.value = 'f3' + f1.value = "f3" assert_raises_message( orm.exc.StaleDataError, r"UPDATE statement on table 'version_table' expected to " r"update 1 row\(s\); 0 were matched.", - sess.commit + sess.commit, ) @testing.requires.sane_rowcount_w_returning def test_concurrent_mod_err_noexpire_on_commit(self): sess = self._fixture(expire_on_commit=False) - f1 = self.classes.Foo(value='f1') + f1 = self.classes.Foo(value="f1") sess.add(f1) sess.commit() @@ -1472,32 +1572,33 @@ class ServerVersioningTest(fixtures.MappedTest): s2 = Session(expire_on_commit=False) f2 = s2.query(self.classes.Foo).first() - f2.value = 'f2' + f2.value = "f2" s2.commit() - f1.value = 'f3' + f1.value = "f3" assert_raises_message( orm.exc.StaleDataError, r"UPDATE statement on table 'version_table' expected to " r"update 1 row\(s\); 0 were matched.", - sess.commit + sess.commit, ) class ManualVersionTest(fixtures.MappedTest): - run_define_tables = 'each' + run_define_tables = "each" __backend__ = True @classmethod def define_tables(cls, metadata): Table( - "a", metadata, + "a", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('vid', Integer) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("vid", Integer), ) @classmethod @@ -1508,8 +1609,11 @@ class ManualVersionTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): mapper( - cls.classes.A, cls.tables.a, version_id_col=cls.tables.a.c.vid, - version_id_generator=False) + cls.classes.A, + cls.tables.a, + version_id_col=cls.tables.a.c.vid, + version_id_generator=False, + ) def test_insert(self): sess = Session() @@ -1527,12 +1631,12 @@ class ManualVersionTest(fixtures.MappedTest): a1 = self.classes.A() a1.vid = 1 - a1.data = 'd1' + a1.data = "d1" sess.add(a1) sess.commit() a1.vid = 2 - a1.data = 'd2' + a1.data = "d2" sess.commit() @@ -1544,17 +1648,14 @@ class ManualVersionTest(fixtures.MappedTest): a1 = self.classes.A() a1.vid = 1 - a1.data = 'd1' + a1.data = "d1" sess.add(a1) sess.commit() a1.vid = 2 sess.execute(self.tables.a.update().values(vid=3)) - a1.data = 'd2' - assert_raises( - orm_exc.StaleDataError, - sess.commit - ) + a1.data = "d2" + assert_raises(orm_exc.StaleDataError, sess.commit) @testing.emits_warning(r".*versioning cannot be verified") def test_update_version_conditional(self): @@ -1562,18 +1663,18 @@ class ManualVersionTest(fixtures.MappedTest): a1 = self.classes.A() a1.vid = 1 - a1.data = 'd1' + a1.data = "d1" sess.add(a1) sess.commit() # change the data and UPDATE without # incrementing version id - a1.data = 'd2' + a1.data = "d2" sess.commit() eq_(a1.vid, 1) - a1.data = 'd3' + a1.data = "d3" a1.vid = 2 sess.commit() @@ -1581,26 +1682,27 @@ class ManualVersionTest(fixtures.MappedTest): class ManualInheritanceVersionTest(fixtures.MappedTest): - run_define_tables = 'each' + run_define_tables = "each" __backend__ = True - __requires__ = 'sane_rowcount', + __requires__ = ("sane_rowcount",) @classmethod def define_tables(cls, metadata): Table( - "a", metadata, + "a", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(30)), - Column('vid', Integer, nullable=False) + "id", Integer, primary_key=True, 
test_needs_autoincrement=True + ), + Column("data", String(30)), + Column("vid", Integer, nullable=False), ) Table( - "b", metadata, - Column( - 'id', Integer, ForeignKey('a.id'), primary_key=True), - Column('b_data', String(30)), + "b", + metadata, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("b_data", String(30)), ) @classmethod @@ -1614,11 +1716,13 @@ class ManualInheritanceVersionTest(fixtures.MappedTest): @classmethod def setup_mappers(cls): mapper( - cls.classes.A, cls.tables.a, version_id_col=cls.tables.a.c.vid, - version_id_generator=False) + cls.classes.A, + cls.tables.a, + version_id_col=cls.tables.a.c.vid, + version_id_generator=False, + ) - mapper( - cls.classes.B, cls.tables.b, inherits=cls.classes.A) + mapper(cls.classes.B, cls.tables.b, inherits=cls.classes.A) @testing.emits_warning(r".*versioning cannot be verified") def test_no_increment(self): @@ -1626,18 +1730,18 @@ class ManualInheritanceVersionTest(fixtures.MappedTest): b1 = self.classes.B() b1.vid = 1 - b1.data = 'd1' + b1.data = "d1" sess.add(b1) sess.commit() # change col on subtable only without # incrementing version id - b1.b_data = 'bd2' + b1.b_data = "bd2" sess.commit() eq_(b1.vid, 1) - b1.b_data = 'd3' + b1.b_data = "d3" b1.vid = 2 sess.commit() @@ -1651,11 +1755,15 @@ class VersioningMappedSelectTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): - Table('version_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=False), - Column('value', String(40), nullable=False)) + Table( + "version_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("version_id", Integer, nullable=False), + Column("value", String(40), nullable=False), + ) @classmethod def setup_classes(cls): @@ -1665,8 +1773,11 @@ class VersioningMappedSelectTest(fixtures.MappedTest): def _implicit_version_fixture(self): Foo, version_table = self.classes.Foo, self.tables.version_table - current = version_table.select().\ - where(version_table.c.id > 0).alias('current_table') + current = ( + version_table.select() + .where(version_table.c.id > 0) + .alias("current_table") + ) mapper(Foo, current, version_id_col=version_table.c.version_id) s1 = Session() @@ -1675,12 +1786,18 @@ class VersioningMappedSelectTest(fixtures.MappedTest): def _explicit_version_fixture(self): Foo, version_table = self.classes.Foo, self.tables.version_table - current = version_table.select().\ - where(version_table.c.id > 0).alias('current_table') + current = ( + version_table.select() + .where(version_table.c.id > 0) + .alias("current_table") + ) - mapper(Foo, current, - version_id_col=version_table.c.version_id, - version_id_generator=False) + mapper( + Foo, + current, + version_id_col=version_table.c.version_id, + version_id_generator=False, + ) s1 = Session() return s1 @@ -1689,18 +1806,18 @@ class VersioningMappedSelectTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._implicit_version_fixture() - f1 = Foo(value='f1') - f2 = Foo(value='f2') + f1 = Foo(value="f1") + f2 = Foo(value="f2") s1.add_all((f1, f2)) s1.commit() - f1.value = 'f1rev2' - f2.value = 'f2rev2' + f1.value = "f1rev2" + f2.value = "f2rev2" s1.commit() eq_( s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(), - [(f1.id, 'f1rev2', 2), (f2.id, 'f2rev2', 2)] + [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)], ) @testing.emits_warning(r".*versioning cannot be verified") @@ -1708,20 +1825,20 @@ class 
VersioningMappedSelectTest(fixtures.MappedTest): Foo = self.classes.Foo s1 = self._explicit_version_fixture() - f1 = Foo(value='f1', version_id=1) - f2 = Foo(value='f2', version_id=1) + f1 = Foo(value="f1", version_id=1) + f2 = Foo(value="f2", version_id=1) s1.add_all((f1, f2)) s1.flush() # note this requires that the Session was not expired until # we fix #4195 - f1.value = 'f1rev2' + f1.value = "f1rev2" f1.version_id = 2 - f2.value = 'f2rev2' + f2.value = "f2rev2" f2.version_id = 2 s1.flush() eq_( s1.query(Foo.id, Foo.value, Foo.version_id).order_by(Foo.id).all(), - [(f1.id, 'f1rev2', 2), (f2.id, 'f2rev2', 2)] - ) \ No newline at end of file + [(f1.id, "f1rev2", 2), (f2.id, "f2rev2", 2)], + ) diff --git a/test/perf/invalidate_stresstest.py b/test/perf/invalidate_stresstest.py index cbf20e18b2..29fbbb118b 100644 --- a/test/perf/invalidate_stresstest.py +++ b/test/perf/invalidate_stresstest.py @@ -1,9 +1,11 @@ from __future__ import print_function import gevent.monkey + gevent.monkey.patch_all() # noqa import logging + logging.basicConfig() # noqa # logging.getLogger("sqlalchemy.pool").setLevel(logging.INFO) from sqlalchemy import event @@ -12,8 +14,9 @@ import sys from sqlalchemy import create_engine import traceback -engine = create_engine('mysql+pymysql://scott:tiger@localhost/test', - pool_size=50, max_overflow=0) +engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test", pool_size=50, max_overflow=0 +) @event.listens_for(engine, "connect") @@ -32,10 +35,10 @@ def worker(): except Exception: # traceback.print_exc() - sys.stderr.write('X') + sys.stderr.write("X") else: conn.close() - sys.stderr.write('.') + sys.stderr.write(".") def main(): diff --git a/test/perf/orm2010.py b/test/perf/orm2010.py index de63a36c84..fabfe533b4 100644 --- a/test/perf/orm2010.py +++ b/test/perf/orm2010.py @@ -1,17 +1,25 @@ import warnings + warnings.filterwarnings("ignore", r".*Decimal objects natively") # noqa # speed up cdecimal if available try: import cdecimal import sys - sys.modules['decimal'] = cdecimal + + sys.modules["decimal"] = cdecimal except ImportError: pass from sqlalchemy import __version__ -from sqlalchemy import Column, Integer, create_engine, ForeignKey, \ - String, Numeric +from sqlalchemy import ( + Column, + Integer, + create_engine, + ForeignKey, + String, + Numeric, +) from sqlalchemy.orm import Session, relationship @@ -24,43 +32,44 @@ Base = declarative_base() class Employee(Base): - __tablename__ = 'employee' + __tablename__ = "employee" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) type = Column(String(50), nullable=False) - __mapper_args__ = {'polymorphic_on': type} + __mapper_args__ = {"polymorphic_on": type} class Boss(Employee): - __tablename__ = 'boss' + __tablename__ = "boss" - id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + id = Column(Integer, ForeignKey("employee.id"), primary_key=True) golf_average = Column(Numeric) - __mapper_args__ = {'polymorphic_identity': 'boss'} + __mapper_args__ = {"polymorphic_identity": "boss"} class Grunt(Employee): - __tablename__ = 'grunt' + __tablename__ = "grunt" - id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + id = Column(Integer, ForeignKey("employee.id"), primary_key=True) savings = Column(Numeric) - employer_id = Column(Integer, ForeignKey('boss.id')) + employer_id = Column(Integer, ForeignKey("boss.id")) - employer = relationship("Boss", backref="employees", - primaryjoin=Boss.id == employer_id) + employer = relationship( + "Boss", 
backref="employees", primaryjoin=Boss.id == employer_id + ) - __mapper_args__ = {'polymorphic_identity': 'grunt'} + __mapper_args__ = {"polymorphic_identity": "grunt"} -if os.path.exists('orm2010.db'): - os.remove('orm2010.db') +if os.path.exists("orm2010.db"): + os.remove("orm2010.db") # use a file based database so that cursor.execute() has some # palpable overhead. -engine = create_engine('sqlite:///orm2010.db') +engine = create_engine("sqlite:///orm2010.db") Base.metadata.create_all(engine) @@ -72,10 +81,7 @@ def runit(status, factor=1, query_runs=5): num_grunts = num_bosses * 100 bosses = [ - Boss( - name="Boss %d" % i, - golf_average=Decimal(random.randint(40, 150)) - ) + Boss(name="Boss %d" % i, golf_average=Decimal(random.randint(40, 150))) for i in range(num_bosses) ] @@ -85,7 +91,7 @@ def runit(status, factor=1, query_runs=5): grunts = [ Grunt( name="Grunt %d" % i, - savings=Decimal(random.randint(5000000, 15000000) / 100) + savings=Decimal(random.randint(5000000, 15000000) / 100), ) for i in range(num_grunts) ] @@ -115,12 +121,14 @@ def runit(status, factor=1, query_runs=5): # load all the Grunts, print a report with their name, stats, # and their bosses' stats. for grunt in sess.query(Grunt): - report.append(( - grunt.name, - grunt.savings, - grunt.employer.name, - grunt.employer.golf_average - )) + report.append( + ( + grunt.name, + grunt.savings, + grunt.employer.name, + grunt.employer.golf_average, + ) + ) sess.close() # close out the session @@ -128,6 +136,7 @@ def runit(status, factor=1, query_runs=5): def run_with_profile(runsnake=False, dump=False): import cProfile import pstats + filename = "orm2010.profile" if os.path.exists("orm2010.profile"): @@ -136,24 +145,31 @@ def run_with_profile(runsnake=False, dump=False): def status(msg): print(msg) - cProfile.runctx('runit(status)', globals(), locals(), filename) + cProfile.runctx("runit(status)", globals(), locals(), filename) stats = pstats.Stats(filename) - counts_by_methname = dict((key[2], - stats.stats[key][0]) for key in stats.stats) + counts_by_methname = dict( + (key[2], stats.stats[key][0]) for key in stats.stats + ) print("SQLA Version: %s" % __version__) print("Total calls %d" % stats.total_calls) print("Total cpu seconds: %.2f" % stats.total_tt) - print('Total execute calls: %d' - % counts_by_methname[""]) - print('Total executemany calls: %d' - % counts_by_methname.get("", 0)) + print( + "Total execute calls: %d" + % counts_by_methname[ + "" + ] + ) + print( + "Total executemany calls: %d" + % counts_by_methname.get( + "", 0 + ) + ) if dump: - stats.sort_stats('time', 'calls') + stats.sort_stats("time", "calls") stats.print_stats() if runsnake: @@ -162,6 +178,7 @@ def run_with_profile(runsnake=False, dump=False): def run_with_time(): import time + now = time.time() def status(msg): @@ -171,16 +188,25 @@ def run_with_time(): print("Total time: %d" % (time.time() - now)) -if __name__ == '__main__': +if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument('--profile', action='store_true', - help='run shorter test suite w/ cprofilng') - parser.add_argument('--dump', action='store_true', - help='dump full call profile (implies --profile)') - parser.add_argument('--runsnake', action='store_true', - help='invoke runsnakerun (implies --profile)') + parser.add_argument( + "--profile", + action="store_true", + help="run shorter test suite w/ cprofilng", + ) + parser.add_argument( + "--dump", + action="store_true", + help="dump full call profile (implies --profile)", + ) + 
parser.add_argument( + "--runsnake", + action="store_true", + help="invoke runsnakerun (implies --profile)", + ) args = parser.parse_args() diff --git a/test/requirements.py b/test/requirements.py index 6c9c5a7606..cbe4a2be72 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -7,19 +7,20 @@ from sqlalchemy import util import sys from sqlalchemy.testing.requirements import SuiteRequirements from sqlalchemy.testing import exclusions -from sqlalchemy.testing.exclusions import \ - skip, \ - skip_if,\ - only_if,\ - only_on,\ - fails_on_everything_except,\ - fails_on,\ - fails_if,\ - succeeds_if,\ - SpecPredicate,\ - against,\ - LambdaPredicate,\ - requires_tag +from sqlalchemy.testing.exclusions import ( + skip, + skip_if, + only_if, + only_on, + fails_on_everything_except, + fails_on, + fails_if, + succeeds_if, + SpecPredicate, + against, + LambdaPredicate, + requires_tag, +) def no_support(db, reason): @@ -35,11 +36,13 @@ class DefaultRequirements(SuiteRequirements): def deferrable_or_no_constraints(self): """Target database must support deferrable constraints.""" - return skip_if([ - no_support('firebird', 'not supported by database'), - no_support('mysql', 'not supported by database'), - no_support('mssql', 'not supported by database'), - ]) + return skip_if( + [ + no_support("firebird", "not supported by database"), + no_support("mysql", "not supported by database"), + no_support("mssql", "not supported by database"), + ] + ) @property def check_constraints(self): @@ -53,7 +56,7 @@ class DefaultRequirements(SuiteRequirements): return self.check_constraints + fails_on( self._mysql_not_mariadb_102, - "check constraints don't enforce on MySQL, MariaDB<10.2" + "check constraints don't enforce on MySQL, MariaDB<10.2", ) @property @@ -66,17 +69,13 @@ class DefaultRequirements(SuiteRequirements): def implicitly_named_constraints(self): """target database must apply names to unnamed constraints.""" - return skip_if([ - no_support('sqlite', 'not supported by database'), - ]) + return skip_if([no_support("sqlite", "not supported by database")]) @property def foreign_keys(self): """Target database must support foreign keys.""" - return skip_if( - no_support('sqlite', 'not supported by database') - ) + return skip_if(no_support("sqlite", "not supported by database")) @property def on_update_cascade(self): @@ -84,17 +83,18 @@ class DefaultRequirements(SuiteRequirements): foreign keys.""" return skip_if( - ['sqlite', 'oracle'], - 'target backend %(doesnt_support)s ON UPDATE CASCADE' - ) + ["sqlite", "oracle"], + "target backend %(doesnt_support)s ON UPDATE CASCADE", + ) @property def non_updating_cascade(self): """target database must *not* support ON UPDATE..CASCADE behavior in foreign keys.""" - return fails_on_everything_except('sqlite', 'oracle', '+zxjdbc') + \ - skip_if('mssql') + return fails_on_everything_except( + "sqlite", "oracle", "+zxjdbc" + ) + skip_if("mssql") @property def recursive_fk_cascade(self): @@ -107,54 +107,57 @@ class DefaultRequirements(SuiteRequirements): def deferrable_fks(self): """target database must support deferrable fks""" - return only_on(['oracle']) + return only_on(["oracle"]) @property def foreign_key_constraint_option_reflection_ondelete(self): - return only_on(['postgresql', 'mysql', 'sqlite', 'oracle']) + return only_on(["postgresql", "mysql", "sqlite", "oracle"]) @property def foreign_key_constraint_option_reflection_onupdate(self): - return only_on(['postgresql', 'mysql', 'sqlite']) + return only_on(["postgresql", "mysql", "sqlite"]) @property def 
comment_reflection(self): - return only_on(['postgresql', 'mysql', 'oracle']) + return only_on(["postgresql", "mysql", "oracle"]) @property def unbounded_varchar(self): """Target database must support VARCHAR with no length""" - return skip_if([ - "firebird", "oracle", "mysql" - ], "not supported by database" - ) + return skip_if( + ["firebird", "oracle", "mysql"], "not supported by database" + ) @property def boolean_col_expressions(self): """Target database must support boolean expressions as columns""" - return skip_if([ - no_support('firebird', 'not supported by database'), - no_support('oracle', 'not supported by database'), - no_support('mssql', 'not supported by database'), - no_support('sybase', 'not supported by database'), - ]) + return skip_if( + [ + no_support("firebird", "not supported by database"), + no_support("oracle", "not supported by database"), + no_support("mssql", "not supported by database"), + no_support("sybase", "not supported by database"), + ] + ) @property def non_native_boolean_unconstrained(self): """target database is not native boolean and allows arbitrary integers in it's "bool" column""" - return skip_if([ - LambdaPredicate( - lambda config: against(config, "mssql"), - "SQL Server drivers / odbc seem to change their mind on this" - ), - LambdaPredicate( - lambda config: config.db.dialect.supports_native_boolean, - "native boolean dialect" - ) - ]) + return skip_if( + [ + LambdaPredicate( + lambda config: against(config, "mssql"), + "SQL Server drivers / odbc seem to change their mind on this", + ), + LambdaPredicate( + lambda config: config.db.dialect.supports_native_boolean, + "native boolean dialect", + ), + ] + ) @property def standalone_binds(self): @@ -181,15 +184,15 @@ class DefaultRequirements(SuiteRequirements): artifact. """ - return skip_if(["firebird", "oracle", "postgresql", "sybase"], - "not supported by database") + return skip_if( + ["firebird", "oracle", "postgresql", "sybase"], + "not supported by database", + ) @property def temporary_tables(self): """target database supports temporary tables""" - return skip_if( - ["mssql", "firebird"], "not supported (?)" - ) + return skip_if(["mssql", "firebird"], "not supported (?)") @property def temp_table_reflection(self): @@ -204,20 +207,17 @@ class DefaultRequirements(SuiteRequirements): has SERIAL support. FB and Oracle (and sybase?) require the Sequence to be explicitly added, including if the table was reflected. """ - return skip_if(["firebird", "oracle", "sybase"], - "not supported by database") + return skip_if( + ["firebird", "oracle", "sybase"], "not supported by database" + ) @property def insert_from_select(self): - return skip_if( - ["firebird"], "crashes for unknown reason" - ) + return skip_if(["firebird"], "crashes for unknown reason") @property def fetch_rows_post_commit(self): - return skip_if( - ["firebird"], "not supported" - ) + return skip_if(["firebird"], "not supported") @property def non_broken_binary(self): @@ -260,11 +260,7 @@ class DefaultRequirements(SuiteRequirements): """Target must support simultaneous, independent database cursors on a single connection.""" - return skip_if( - [ - "mssql", - "mysql"], "no driver support" - ) + return skip_if(["mssql", "mysql"], "no driver support") @property def independent_connections(self): @@ -274,13 +270,22 @@ class DefaultRequirements(SuiteRequirements): # This is also true of some configurations of UnixODBC and probably # win32 ODBC as well. 
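
[Illustrative aside, not part of the patch.] Every property in this file composes the same exclusions DSL imported at its top: predicates such as skip_if, only_on and fails_on each carry a reason string, and rules add together with "+". A minimal sketch of the shape; the "frobnicate" capability and the reason strings are invented for this sketch:

from sqlalchemy.testing.requirements import SuiteRequirements
from sqlalchemy.testing.exclusions import only_on, skip_if


class SketchRequirements(SuiteRequirements):
    @property
    def frobnicate(self):
        """Target database must support FROBNICATE."""

        # rules compose with "+"; each predicate names its own reason
        return only_on(
            ["postgresql", "sqlite"], "no FROBNICATE support"
        ) + skip_if("postgresql+pg8000", "driver chokes on FROBNICATE")

    @property
    def reliable_rowcount(self):
        # predicates may also be plain lambdas inspecting the config,
        # as the ctes and json_type properties further down do
        return only_on(
            [lambda config: config.db.dialect.supports_sane_rowcount],
            "driver reports unreliable rowcounts",
        )
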
- return skip_if([ - no_support("sqlite", - "independent connections disabled " - "when :memory: connections are used"), - exclude("mssql", "<", (9, 0, 0), + return skip_if( + [ + no_support( + "sqlite", + "independent connections disabled " + "when :memory: connections are used", + ), + exclude( + "mssql", + "<", + (9, 0, 0), "SQL Server 2005+ is required for " - "independent connections")]) + "independent connections", + ), + ] + ) @property def memory_process_intensive(self): @@ -288,75 +293,91 @@ class DefaultRequirements(SuiteRequirements): and iterate through hundreds of connections """ - return skip_if([ - no_support("oracle", "Oracle XE usually can't handle these"), - no_support("mssql+pyodbc", "MS ODBC drivers struggle") - ]) + return skip_if( + [ + no_support("oracle", "Oracle XE usually can't handle these"), + no_support("mssql+pyodbc", "MS ODBC drivers struggle"), + ] + ) @property def updateable_autoincrement_pks(self): """Target must support UPDATE on autoincrement/integer primary key.""" - return skip_if(["mssql", "sybase"], - "IDENTITY columns can't be updated") + return skip_if( + ["mssql", "sybase"], "IDENTITY columns can't be updated" + ) @property def isolation_level(self): return only_on( - ('postgresql', 'sqlite', 'mysql', 'mssql'), - "DBAPI has no isolation level support") \ - + fails_on('postgresql+pypostgresql', - 'pypostgresql bombs on multiple isolation level calls') + ("postgresql", "sqlite", "mysql", "mssql"), + "DBAPI has no isolation level support", + ) + fails_on( + "postgresql+pypostgresql", + "pypostgresql bombs on multiple isolation level calls", + ) @property def autocommit(self): """target dialect supports 'AUTOCOMMIT' as an isolation_level""" return only_on( - ('postgresql', 'mysql', 'mssql+pyodbc', 'mssql+pymssql'), - "dialect does not support AUTOCOMMIT isolation mode") + ("postgresql", "mysql", "mssql+pyodbc", "mssql+pymssql"), + "dialect does not support AUTOCOMMIT isolation mode", + ) @property def row_triggers(self): """Target must support standard statement-running EACH ROW triggers.""" - return skip_if([ - # no access to same table - no_support('mysql', 'requires SUPER priv'), - exclude('mysql', '<', (5, 0, 10), 'not supported by database'), - - # huh? TODO: implement triggers for PG tests, remove this - no_support('postgresql', - 'PG triggers need to be implemented for tests'), - ]) + return skip_if( + [ + # no access to same table + no_support("mysql", "requires SUPER priv"), + exclude("mysql", "<", (5, 0, 10), "not supported by database"), + # huh? TODO: implement triggers for PG tests, remove this + no_support( + "postgresql", + "PG triggers need to be implemented for tests", + ), + ] + ) @property def sequences_as_server_defaults(self): """Target database must support SEQUENCE as a server side default.""" return only_on( - 'postgresql', - "doesn't support sequences as a server side default.") + "postgresql", "doesn't support sequences as a server side default." 
+ ) @property def correlated_outer_joins(self): """Target must support an outer join to a subquery which correlates to the parent.""" - return skip_if("oracle", 'Raises "ORA-01799: a column may not be ' - 'outer-joined to a subquery"') + return skip_if( + "oracle", + 'Raises "ORA-01799: a column may not be ' + 'outer-joined to a subquery"', + ) @property def update_from(self): """Target must support UPDATE..FROM syntax""" - return only_on(['postgresql', 'mssql', 'mysql'], - "Backend does not support UPDATE..FROM") + return only_on( + ["postgresql", "mssql", "mysql"], + "Backend does not support UPDATE..FROM", + ) @property def delete_from(self): """Target must support DELETE FROM..FROM or DELETE..USING syntax""" - return only_on(['postgresql', 'mssql', 'mysql', 'sybase'], - "Backend does not support DELETE..FROM") + return only_on( + ["postgresql", "mssql", "mysql", "sybase"], + "Backend does not support DELETE..FROM", + ) @property def update_where_target_in_subquery(self): @@ -374,44 +395,37 @@ class DefaultRequirements(SuiteRequirements): return fails_if( self._mysql_not_mariadb_103, 'MySQL error 1093 "Cant specify target table ' - 'for update in FROM clause", resolved by MariaDB 10.3') + 'for update in FROM clause", resolved by MariaDB 10.3', + ) @property def savepoints(self): """Target database must support savepoints.""" - return skip_if([ - "sqlite", - "sybase", - ("mysql", "<", (5, 0, 3)), - ], "savepoints not supported") + return skip_if( + ["sqlite", "sybase", ("mysql", "<", (5, 0, 3))], + "savepoints not supported", + ) @property def savepoints_w_release(self): return self.savepoints + skip_if( ["oracle", "mssql"], - "database doesn't support release of savepoint" + "database doesn't support release of savepoint", ) - @property def schemas(self): """Target database must support external schemas, and have one named 'test_schema'.""" - return skip_if([ - "firebird" - ], "no schema support") + return skip_if(["firebird"], "no schema support") @property def cross_schema_fk_reflection(self): """target system must support reflection of inter-schema foreign keys """ - return only_on([ - "postgresql", - "mysql", - "mssql", - ]) + return only_on(["postgresql", "mysql", "mssql"]) @property def implicit_default_schema(self): @@ -421,70 +435,70 @@ class DefaultRequirements(SuiteRequirements): basically, PostgreSQL. 
""" - return only_on([ - "postgresql", - ]) - + return only_on(["postgresql"]) @property def unique_constraint_reflection(self): return fails_on_everything_except( - "postgresql", - "mysql", - "sqlite", - "oracle" - ) + "postgresql", "mysql", "sqlite", "oracle" + ) @property def unique_constraint_reflection_no_index_overlap(self): - return self.unique_constraint_reflection + \ - skip_if("mysql") + skip_if("oracle") - + return ( + self.unique_constraint_reflection + + skip_if("mysql") + + skip_if("oracle") + ) @property def check_constraint_reflection(self): return fails_on_everything_except( - "postgresql", "sqlite", "oracle", - self._mariadb_102 + "postgresql", "sqlite", "oracle", self._mariadb_102 ) @property def temp_table_names(self): """target dialect supports listing of temporary table names""" - return only_on(['sqlite', 'oracle']) + return only_on(["sqlite", "oracle"]) @property def temporary_views(self): """target database supports temporary views""" - return only_on(['sqlite', 'postgresql']) + return only_on(["sqlite", "postgresql"]) @property def update_nowait(self): """Target database must support SELECT...FOR UPDATE NOWAIT""" - return skip_if(["firebird", "mssql", "mysql", "sqlite", "sybase"], - "no FOR UPDATE NOWAIT support") + return skip_if( + ["firebird", "mssql", "mysql", "sqlite", "sybase"], + "no FOR UPDATE NOWAIT support", + ) @property def subqueries(self): """Target database must support subqueries.""" - return skip_if(exclude('mysql', '<', (4, 1, 1)), 'no subquery support') + return skip_if(exclude("mysql", "<", (4, 1, 1)), "no subquery support") @property def ctes(self): """Target database supports CTEs""" - return only_on([ - lambda config: against(config, "mysql") and ( - config.db.dialect._is_mariadb and - config.db.dialect._mariadb_normalized_version_info >= - (10, 2) - ), - "postgresql", - "mssql", - "oracle" - ]) + return only_on( + [ + lambda config: against(config, "mysql") + and ( + config.db.dialect._is_mariadb + and config.db.dialect._mariadb_normalized_version_info + >= (10, 2) + ), + "postgresql", + "mssql", + "oracle", + ] + ) @property def ctes_with_update_delete(self): @@ -492,47 +506,45 @@ class DefaultRequirements(SuiteRequirements): or DELETE statement which refers to the CTE in a correlated subquery. """ - return only_on([ - "postgresql", - "mssql", - # "oracle" - oracle can do this but SQLAlchemy doesn't support - # their syntax yet - ]) + return only_on( + [ + "postgresql", + "mssql", + # "oracle" - oracle can do this but SQLAlchemy doesn't support + # their syntax yet + ] + ) @property def ctes_on_dml(self): """target database supports CTES which consist of INSERT, UPDATE or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)""" - return only_if( - ['postgresql'] - ) + return only_if(["postgresql"]) @property def mod_operator_as_percent_sign(self): """target database must use a plain percent '%' as the 'modulus' operator.""" - return only_if( - ['mysql', 'sqlite', 'postgresql+psycopg2', 'mssql'] - ) + return only_if(["mysql", "sqlite", "postgresql+psycopg2", "mssql"]) @property def intersect(self): """Target database must support INTERSECT or equivalent.""" - return fails_if([ - "firebird", self._mysql_not_mariadb_103, - "sybase", - ], 'no support for INTERSECT') + return fails_if( + ["firebird", self._mysql_not_mariadb_103, "sybase"], + "no support for INTERSECT", + ) @property def except_(self): """Target database must support EXCEPT or equivalent (i.e. 
MINUS).""" - return fails_if([ - "firebird", self._mysql_not_mariadb_103, - "sybase", - ], 'no support for EXCEPT') + return fails_if( + ["firebird", self._mysql_not_mariadb_103, "sybase"], + "no support for EXCEPT", + ) @property def order_by_col_from_union(self): @@ -544,7 +556,7 @@ class DefaultRequirements(SuiteRequirements): Fails on SQL Server """ - return fails_if('mssql') + return fails_if("mssql") @property def parens_in_union_contained_select_w_limit_offset(self): @@ -556,7 +568,7 @@ class DefaultRequirements(SuiteRequirements): This is known to fail on SQLite. """ - return fails_if('sqlite') + return fails_if("sqlite") @property def parens_in_union_contained_select_wo_limit_offset(self): @@ -570,49 +582,58 @@ class DefaultRequirements(SuiteRequirements): creates an additional subquery. """ - return fails_if(['sqlite', 'oracle']) + return fails_if(["sqlite", "oracle"]) @property def offset(self): """Target database must support some method of adding OFFSET or equivalent to a result set.""" - return fails_if([ - "sybase" - ], 'no support for OFFSET or equivalent') + return fails_if(["sybase"], "no support for OFFSET or equivalent") @property def window_functions(self): - return only_if([ - "postgresql>=8.4", "mssql", "oracle" - ], "Backend does not support window functions") + return only_if( + ["postgresql>=8.4", "mssql", "oracle"], + "Backend does not support window functions", + ) @property def two_phase_transactions(self): """Target database must support two-phase transactions.""" - return skip_if([ - no_support('firebird', 'no SA implementation'), - no_support('mssql', 'two-phase xact not supported by drivers'), - no_support('oracle', - 'two-phase xact not implemented in SQLA/oracle'), - no_support('drizzle', 'two-phase xact not supported by database'), - no_support('sqlite', 'two-phase xact not supported by database'), - no_support('sybase', - 'two-phase xact not supported by drivers/SQLA'), - no_support('postgresql+zxjdbc', - 'FIXME: JDBC driver confuses the transaction state, ' - 'may need separate XA implementation'), - no_support('mysql', - 'recent MySQL communiity editions have too many issues ' - '(late 2016), disabling for now')]) + return skip_if( + [ + no_support("firebird", "no SA implementation"), + no_support("mssql", "two-phase xact not supported by drivers"), + no_support( + "oracle", "two-phase xact not implemented in SQLA/oracle" + ), + no_support( + "drizzle", "two-phase xact not supported by database" + ), + no_support( + "sqlite", "two-phase xact not supported by database" + ), + no_support( + "sybase", "two-phase xact not supported by drivers/SQLA" + ), + no_support( + "postgresql+zxjdbc", + "FIXME: JDBC driver confuses the transaction state, " + "may need separate XA implementation", + ), + no_support( + "mysql", + "recent MySQL communiity editions have too many issues " + "(late 2016), disabling for now", + ), + ] + ) @property def two_phase_recovery(self): return self.two_phase_transactions + ( - skip_if( - "mysql", - "crashes on most mariadb and mysql versions" - ) + skip_if("mysql", "crashes on most mariadb and mysql versions") ) @property @@ -627,8 +648,9 @@ class DefaultRequirements(SuiteRequirements): target database can persist/return an empty string with a varchar. 
""" - return fails_if(["oracle"], - 'oracle converts empty strings to a blank space') + return fails_if( + ["oracle"], "oracle converts empty strings to a blank space" + ) @property def empty_strings_text(self): @@ -640,9 +662,7 @@ class DefaultRequirements(SuiteRequirements): @property def unicode_data(self): """target drive must support unicode data stored in columns.""" - return skip_if([ - no_support("sybase", "no unicode driver support") - ]) + return skip_if([no_support("sybase", "no unicode driver support")]) @property def unicode_connections(self): @@ -650,44 +670,44 @@ class DefaultRequirements(SuiteRequirements): Target driver must support some encoding of Unicode across the wire. """ # TODO: expand to exclude MySQLdb versions w/ broken unicode - return skip_if([ - exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), - ]) + return skip_if( + [exclude("mysql", "<", (4, 1, 1), "no unicode connection support")] + ) @property def unicode_ddl(self): """Target driver must support some degree of non-ascii symbol names.""" # TODO: expand to exclude MySQLdb versions w/ broken unicode - return skip_if([ - no_support('oracle', 'FIXME: no support in database?'), - no_support('sybase', 'FIXME: guessing, needs confirmation'), - no_support('mssql+pymssql', 'no FreeTDS support'), - LambdaPredicate( - lambda config: against(config, "mysql+mysqlconnector") and - config.db.dialect._mysqlconnector_version_info > (2, 0) and - util.py2k, - "bug in mysqlconnector 2.0" - ), - exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), - ]) + return skip_if( + [ + no_support("oracle", "FIXME: no support in database?"), + no_support("sybase", "FIXME: guessing, needs confirmation"), + no_support("mssql+pymssql", "no FreeTDS support"), + LambdaPredicate( + lambda config: against(config, "mysql+mysqlconnector") + and config.db.dialect._mysqlconnector_version_info > (2, 0) + and util.py2k, + "bug in mysqlconnector 2.0", + ), + exclude( + "mysql", "<", (4, 1, 1), "no unicode connection support" + ), + ] + ) @property def emulated_lastrowid(self): """"target dialect retrieves cursor.lastrowid or an equivalent after an insert() construct executes. """ - return fails_on_everything_except('mysql', - 'sqlite+pysqlite', - 'sqlite+pysqlcipher', - 'sybase', - 'mssql') + return fails_on_everything_except( + "mysql", "sqlite+pysqlite", "sqlite+pysqlcipher", "sybase", "mssql" + ) @property def implements_get_lastrowid(self): - return skip_if([ - no_support('sybase', 'not supported by database'), - ]) + return skip_if([no_support("sybase", "not supported by database")]) @property def dbapi_lastrowid(self): @@ -695,22 +715,24 @@ class DefaultRequirements(SuiteRequirements): cursor object. 
""" - return skip_if('mssql+pymssql', 'crashes on pymssql') + \ - fails_on_everything_except('mysql', - 'sqlite+pysqlite', - 'sqlite+pysqlcipher') + return skip_if( + "mssql+pymssql", "crashes on pymssql" + ) + fails_on_everything_except( + "mysql", "sqlite+pysqlite", "sqlite+pysqlcipher" + ) @property def nullsordering(self): """Target backends that support nulls ordering.""" - return fails_on_everything_except('postgresql', 'oracle', 'firebird') + return fails_on_everything_except("postgresql", "oracle", "firebird") @property def reflects_pk_names(self): """Target driver reflects the name of primary key constraints.""" - return fails_on_everything_except('postgresql', 'oracle', 'mssql', - 'sybase', 'sqlite') + return fails_on_everything_except( + "postgresql", "oracle", "mssql", "sybase", "sqlite" + ) @property def nested_aggregates(self): @@ -721,37 +743,44 @@ class DefaultRequirements(SuiteRequirements): @property def array_type(self): - return only_on([ - lambda config: against(config, "postgresql") and - not against(config, "+pg8000") and not against(config, "+zxjdbc") - ]) + return only_on( + [ + lambda config: against(config, "postgresql") + and not against(config, "+pg8000") + and not against(config, "+zxjdbc") + ] + ) @property def json_type(self): - return only_on([ - lambda config: - against(config, "mysql") and ( + return only_on( + [ + lambda config: against(config, "mysql") + and ( ( - not config.db.dialect._is_mariadb and - against(config, "mysql >= 5.7") + not config.db.dialect._is_mariadb + and against(config, "mysql >= 5.7") ) or ( - config.db.dialect._mariadb_normalized_version_info >= - (10, 2, 7) + config.db.dialect._mariadb_normalized_version_info + >= (10, 2, 7) ) ), - "postgresql >= 9.3", - "sqlite >= 3.9" - ]) + "postgresql >= 9.3", + "sqlite >= 3.9", + ] + ) @property def reflects_json_type(self): - return only_on([ - lambda config: against(config, "mysql >= 5.7") and - not config.db.dialect._is_mariadb, - "postgresql >= 9.3", - "sqlite >= 3.9" - ]) + return only_on( + [ + lambda config: against(config, "mysql >= 5.7") + and not config.db.dialect._is_mariadb, + "postgresql >= 9.3", + "sqlite >= 3.9", + ] + ) @property def json_array_indexes(self): @@ -778,8 +807,9 @@ class DefaultRequirements(SuiteRequirements): """target dialect supports representation of Python datetime.datetime() with microsecond objects.""" - return skip_if(['mssql', 'mysql', 'firebird', '+zxjdbc', - 'oracle', 'sybase']) + return skip_if( + ["mssql", "mysql", "firebird", "+zxjdbc", "oracle", "sybase"] + ) @property def timestamp_microseconds(self): @@ -787,14 +817,14 @@ class DefaultRequirements(SuiteRequirements): datetime.datetime() with microsecond objects but only if TIMESTAMP is used.""" - return only_on(['oracle']) + return only_on(["oracle"]) @property def datetime_historic(self): """target dialect supports representation of Python datetime.datetime() objects with historic (pre 1900) values.""" - return succeeds_if(['sqlite', 'postgresql', 'firebird']) + return succeeds_if(["sqlite", "postgresql", "firebird"]) @property def date(self): @@ -809,29 +839,30 @@ class DefaultRequirements(SuiteRequirements): of a date column.""" # does not work as of pyodbc 4.0.22 - return fails_on('mysql+mysqlconnector') + skip_if("mssql+pyodbc") + return fails_on("mysql+mysqlconnector") + skip_if("mssql+pyodbc") @property def date_historic(self): """target dialect supports representation of Python datetime.datetime() objects with historic (pre 1900) values.""" - return succeeds_if(['sqlite', 'postgresql', 
'firebird']) + return succeeds_if(["sqlite", "postgresql", "firebird"]) @property def time(self): """target dialect supports representation of Python datetime.time() objects.""" - return skip_if(['oracle']) + return skip_if(["oracle"]) @property def time_microseconds(self): """target dialect supports representation of Python datetime.time() with microsecond objects.""" - return skip_if(['mssql', 'mysql', 'firebird', '+zxjdbc', - 'oracle', 'sybase']) + return skip_if( + ["mssql", "mysql", "firebird", "+zxjdbc", "oracle", "sybase"] + ) @property def precision_numerics_general(self): @@ -853,10 +884,13 @@ class DefaultRequirements(SuiteRequirements): return fails_if( [ - ("sybase+pyodbc", None, None, - "Don't know how to get these values through FreeTDS + Sybase" - ), - ("firebird", None, None, "Precision must be from 1 to 18") + ( + "sybase+pyodbc", + None, + None, + "Don't know how to get these values through FreeTDS + Sybase", + ), + ("firebird", None, None, "Precision must be from 1 to 18"), ] ) @@ -868,13 +902,15 @@ class DefaultRequirements(SuiteRequirements): """ def broken_cx_oracle(config): - return against(config, 'oracle+cx_oracle') and \ - config.db.dialect.cx_oracle_ver <= (6, 0, 2) and \ - config.db.dialect.cx_oracle_ver > (6, ) + return ( + against(config, "oracle+cx_oracle") + and config.db.dialect.cx_oracle_ver <= (6, 0, 2) + and config.db.dialect.cx_oracle_ver > (6,) + ) return fails_if( [ - ('sqlite', None, None, 'TODO'), + ("sqlite", None, None, "TODO"), ("firebird", None, None, "Precision must be from 1 to 18"), ("sybase+pysybase", None, None, "TODO"), ] @@ -889,8 +925,12 @@ class DefaultRequirements(SuiteRequirements): return fails_if( [ ("oracle", None, None, "driver doesn't do this automatically"), - ("firebird", None, None, - "database and/or driver truncates decimal places.") + ( + "firebird", + None, + None, + "database and/or driver truncates decimal places.", + ), ] ) @@ -899,32 +939,65 @@ class DefaultRequirements(SuiteRequirements): """target backend will return native floating point numbers with at least seven decimal places when using the generic Float type.""" - return fails_if([ - ('mysql', None, None, - 'mysql FLOAT type only returns 4 decimals'), - ('firebird', None, None, - "firebird FLOAT type isn't high precision"), - ]) + return fails_if( + [ + ( + "mysql", + None, + None, + "mysql FLOAT type only returns 4 decimals", + ), + ( + "firebird", + None, + None, + "firebird FLOAT type isn't high precision", + ), + ] + ) @property def floats_to_four_decimals(self): - return fails_if([ - ("mysql+oursql", None, None, "Floating point error"), - ("firebird", None, None, - "Firebird still has FP inaccuracy even " - "with only four decimal places"), - ('mssql+pyodbc', None, None, - 'mssql+pyodbc has FP inaccuracy even with ' - 'only four decimal places '), - ('mssql+pymssql', None, None, - 'mssql+pymssql has FP inaccuracy even with ' - 'only four decimal places '), - ('postgresql+pg8000', None, None, - 'postgresql+pg8000 has FP inaccuracy even with ' - 'only four decimal places '), - ('postgresql+psycopg2cffi', None, None, - 'postgresql+psycopg2cffi has FP inaccuracy even with ' - 'only four decimal places ')]) + return fails_if( + [ + ("mysql+oursql", None, None, "Floating point error"), + ( + "firebird", + None, + None, + "Firebird still has FP inaccuracy even " + "with only four decimal places", + ), + ( + "mssql+pyodbc", + None, + None, + "mssql+pyodbc has FP inaccuracy even with " + "only four decimal places ", + ), + ( + "mssql+pymssql", + None, + None, +
"mssql+pymssql has FP inaccuracy even with " + "only four decimal places ", + ), + ( + "postgresql+pg8000", + None, + None, + "postgresql+pg8000 has FP inaccuracy even with " + "only four decimal places ", + ), + ( + "postgresql+psycopg2cffi", + None, + None, + "postgresql+psycopg2cffi has FP inaccuracy even with " + "only four decimal places ", + ), + ] + ) @property def implicit_decimal_binds(self): @@ -953,19 +1026,20 @@ class DefaultRequirements(SuiteRequirements): try: from MySQLdb import converters from decimal import Decimal + return Decimal not in converters.conversions except: return True - return against(config, "mysql+mysqldb") and \ - config.db.dialect._mysql_dbapi_version <= (1, 3, 13) + return against( + config, "mysql+mysqldb" + ) and config.db.dialect._mysql_dbapi_version <= (1, 3, 13) + return exclusions.fails_on(check, "fixed for mysqlclient post 1.3.13") @property def fetch_null_from_numeric(self): - return skip_if( - ("mssql+pyodbc", None, None, "crashes due to bug #351"), - ) + return skip_if(("mssql+pyodbc", None, None, "crashes due to bug #351")) @property def duplicate_key_raises_integrity_error(self): @@ -977,8 +1051,10 @@ class DefaultRequirements(SuiteRequirements): return False count = config.db.scalar( "SELECT count(*) FROM pg_extension " - "WHERE extname='%s'" % name) + "WHERE extname='%s'" % name + ) return bool(count) + return only_if(check, "needs %s extension" % name) @property @@ -993,8 +1069,8 @@ class DefaultRequirements(SuiteRequirements): def range_types(self): def check_range_types(config): if not against( - config, - ["postgresql+psycopg2", "postgresql+psycopg2cffi"]): + config, ["postgresql+psycopg2", "postgresql+psycopg2cffi"] + ): return False try: config.db.scalar("select '[1,2)'::int4range;") @@ -1007,25 +1083,26 @@ class DefaultRequirements(SuiteRequirements): @property def oracle_test_dblink(self): return skip_if( - lambda config: not config.file_config.has_option( - 'sqla_testing', 'oracle_db_link'), - "oracle_db_link option not specified in config" - ) + lambda config: not config.file_config.has_option( + "sqla_testing", "oracle_db_link" + ), + "oracle_db_link option not specified in config", + ) @property def postgresql_test_dblink(self): return skip_if( - lambda config: not config.file_config.has_option( - 'sqla_testing', 'postgres_test_db_link'), - "postgres_test_db_link option not specified in config" - ) + lambda config: not config.file_config.has_option( + "sqla_testing", "postgres_test_db_link" + ), + "postgres_test_db_link option not specified in config", + ) @property def postgresql_jsonb(self): return only_on("postgresql >= 9.4") + skip_if( - lambda config: - config.db.dialect.driver == "pg8000" and - config.db.dialect._dbapi_version <= (1, 10, 1) + lambda config: config.db.dialect.driver == "pg8000" + and config.db.dialect._dbapi_version <= (1, 10, 1) ) @property @@ -1038,15 +1115,16 @@ class DefaultRequirements(SuiteRequirements): @property def psycopg2_compatibility(self): - return only_on( - ["postgresql+psycopg2", "postgresql+psycopg2cffi"] - ) + return only_on(["postgresql+psycopg2", "postgresql+psycopg2cffi"]) @property def psycopg2_or_pg8000_compatibility(self): return only_on( - ["postgresql+psycopg2", "postgresql+psycopg2cffi", - "postgresql+pg8000"] + [ + "postgresql+psycopg2", + "postgresql+psycopg2cffi", + "postgresql+pg8000", + ] ) @property @@ -1054,40 +1132,49 @@ class DefaultRequirements(SuiteRequirements): return skip_if( [ ( - "+psycopg2", None, None, + "+psycopg2", + None, + None, "psycopg2 2.4 no longer 
accepts percent " - "sign in bind placeholders"), + "sign in bind placeholders", + ), ( - "+psycopg2cffi", None, None, + "+psycopg2cffi", + None, + None, "psycopg2cffi does not accept percent signs in " - "bind placeholders"), - ("mysql", None, None, "executemany() doesn't work here") + "bind placeholders", + ), + ("mysql", None, None, "executemany() doesn't work here"), ] ) @property def order_by_label_with_expression(self): - return fails_if([ - ('firebird', None, None, - "kinterbasdb doesn't send full type information"), - ('postgresql', None, None, 'only simple labels allowed'), - ('sybase', None, None, 'only simple labels allowed'), - ('mssql', None, None, 'only simple labels allowed') - ]) + return fails_if( + [ + ( + "firebird", + None, + None, + "kinterbasdb doesn't send full type information", + ), + ("postgresql", None, None, "only simple labels allowed"), + ("sybase", None, None, "only simple labels allowed"), + ("mssql", None, None, "only simple labels allowed"), + ] + ) def get_order_by_collation(self, config): lookup = { - # will raise without quoting "postgresql": "POSIX", - # note MySQL databases need to be created w/ utf8mb3 charset # for the test suite "mysql": "utf8mb3_bin", "sqlite": "NOCASE", - # will raise *with* quoting - "mssql": "Latin1_General_CI_AS" + "mssql": "Latin1_General_CI_AS", } try: return lookup[config.db.name] @@ -1098,8 +1185,9 @@ class DefaultRequirements(SuiteRequirements): def skip_mysql_on_windows(self): """Catchall for a large variety of MySQL on Windows failures""" - return skip_if(self._has_mysql_on_windows, - "Not supported on MySQL + Windows") + return skip_if( + self._has_mysql_on_windows, "Not supported on MySQL + Windows" + ) @property def mssql_freetds(self): @@ -1110,7 +1198,8 @@ class DefaultRequirements(SuiteRequirements): return exclusions.skip_if( ["oracle"], "works, but Oracle just gets tired with " - "this much connection activity") + "this much connection activity", + ) @property def no_mssql_freetds(self): @@ -1125,37 +1214,40 @@ class DefaultRequirements(SuiteRequirements): return False with config.db.connect() as conn: drivername = conn.connection.connection.getinfo( - config.db.dialect.dbapi.SQL_DRIVER_NAME) + config.db.dialect.dbapi.SQL_DRIVER_NAME + ) # on linux this is 'libmsodbcsql-13.1.so.9.2'. 
# don't know what it is on windows return "msodbc" in drivername + return only_if( - has_fastexecutemany, - "only on pyodbc > 4.0.19 w/ msodbc driver") + has_fastexecutemany, "only on pyodbc > 4.0.19 w/ msodbc driver" + ) @property def python_fixed_issue_8743(self): return exclusions.skip_if( lambda: sys.version_info < (2, 7, 8), - "Python issue 8743 fixed in Python 2.7.8" + "Python issue 8743 fixed in Python 2.7.8", ) @property def selectone(self): """target driver must support the literal statement 'select 1'""" - return skip_if(["oracle", "firebird"], - "non-standard SELECT scalar syntax") + return skip_if( + ["oracle", "firebird"], "non-standard SELECT scalar syntax" + ) @property def mysql_for_update(self): return skip_if( "mysql+mysqlconnector", - "lock-sensitive operations crash on mysqlconnector" + "lock-sensitive operations crash on mysqlconnector", ) @property def mysql_fsp(self): - return only_if('mysql >= 5.6.4') + return only_if("mysql >= 5.6.4") @property def mysql_fully_case_sensitive(self): @@ -1164,7 +1256,7 @@ class DefaultRequirements(SuiteRequirements): @property def mysql_zero_date(self): def check(config): - if not against(config, 'mysql'): + if not against(config, "mysql"): return False row = config.db.execute("show variables like 'sql_mode'").first() @@ -1175,7 +1267,7 @@ class DefaultRequirements(SuiteRequirements): @property def mysql_non_strict(self): def check(config): - if not against(config, 'mysql'): + if not against(config, "mysql"): return False row = config.db.execute("show variables like 'sql_mode'").first() @@ -1186,53 +1278,62 @@ class DefaultRequirements(SuiteRequirements): @property def mysql_ngram_fulltext(self): def check(config): - return against(config, "mysql") and \ - not config.db.dialect._is_mariadb and \ - config.db.dialect.server_version_info >= (5, 7) + return ( + against(config, "mysql") + and not config.db.dialect._is_mariadb + and config.db.dialect.server_version_info >= (5, 7) + ) + return only_if(check) def _mariadb_102(self, config): - return against(config, "mysql") and \ - config.db.dialect._is_mariadb and \ - config.db.dialect._mariadb_normalized_version_info > (10, 2) + return ( + against(config, "mysql") + and config.db.dialect._is_mariadb + and config.db.dialect._mariadb_normalized_version_info > (10, 2) + ) def _mysql_not_mariadb_102(self, config): return against(config, "mysql") and ( - not config.db.dialect._is_mariadb or - config.db.dialect._mariadb_normalized_version_info < (10, 2) + not config.db.dialect._is_mariadb + or config.db.dialect._mariadb_normalized_version_info < (10, 2) ) def _mysql_not_mariadb_103(self, config): return against(config, "mysql") and ( - not config.db.dialect._is_mariadb or - config.db.dialect._mariadb_normalized_version_info < (10, 3) + not config.db.dialect._is_mariadb + or config.db.dialect._mariadb_normalized_version_info < (10, 3) ) def _has_mysql_on_windows(self, config): - return against(config, 'mysql') and \ - config.db.dialect._detect_casing(config.db) == 1 + return ( + against(config, "mysql") + and config.db.dialect._detect_casing(config.db) == 1 + ) def _has_mysql_fully_case_sensitive(self, config): - return against(config, 'mysql') and \ - config.db.dialect._detect_casing(config.db) == 0 + return ( + against(config, "mysql") + and config.db.dialect._detect_casing(config.db) == 0 + ) @property def postgresql_utf8_server_encoding(self): return only_if( - lambda config: against(config, 'postgresql') and - config.db.scalar("show server_encoding").lower() == "utf8" + lambda config: 
against(config, "postgresql") + and config.db.scalar("show server_encoding").lower() == "utf8" ) @property def cxoracle6_or_greater(self): return only_if( - lambda config: against(config, "oracle+cx_oracle") and - config.db.dialect.cx_oracle_ver >= (6, ) + lambda config: against(config, "oracle+cx_oracle") + and config.db.dialect.cx_oracle_ver >= (6,) ) @property def oracle5x(self): return only_if( - lambda config: against(config, "oracle+cx_oracle") and - config.db.dialect.cx_oracle_ver < (6, ) + lambda config: against(config, "oracle+cx_oracle") + and config.db.dialect.cx_oracle_ver < (6,) ) diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index d81ff20cdb..181dd79a1f 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -1,51 +1,73 @@ from sqlalchemy.testing import assert_raises, eq_, assert_raises_message from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import ( - testing, exc, case, select, literal_column, text, and_, Integer, cast, - String, Column, Table, MetaData) + testing, + exc, + case, + select, + literal_column, + text, + and_, + Integer, + cast, + String, + Column, + Table, + MetaData, +) from sqlalchemy.sql import table, column info_table = None class CaseTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): metadata = MetaData(testing.db) global info_table info_table = Table( - 'infos', metadata, - Column('pk', Integer, primary_key=True), - Column('info', String(30))) + "infos", + metadata, + Column("pk", Integer, primary_key=True), + Column("info", String(30)), + ) info_table.create() info_table.insert().execute( - {'pk': 1, 'info': 'pk_1_data'}, - {'pk': 2, 'info': 'pk_2_data'}, - {'pk': 3, 'info': 'pk_3_data'}, - {'pk': 4, 'info': 'pk_4_data'}, - {'pk': 5, 'info': 'pk_5_data'}, - {'pk': 6, 'info': 'pk_6_data'}) + {"pk": 1, "info": "pk_1_data"}, + {"pk": 2, "info": "pk_2_data"}, + {"pk": 3, "info": "pk_3_data"}, + {"pk": 4, "info": "pk_4_data"}, + {"pk": 5, "info": "pk_5_data"}, + {"pk": 6, "info": "pk_6_data"}, + ) @classmethod def teardown_class(cls): info_table.drop() - @testing.fails_on('firebird', 'FIXME: unknown') + @testing.fails_on("firebird", "FIXME: unknown") @testing.requires.subqueries def test_case(self): inner = select( [ case( [ - [info_table.c.pk < 3, 'lessthan3'], + [info_table.c.pk < 3, "lessthan3"], [ and_(info_table.c.pk >= 3, info_table.c.pk < 7), - 'gt3']]).label('x'), - info_table.c.pk, info_table.c.info], from_obj=[info_table]) + "gt3", + ], + ] + ).label("x"), + info_table.c.pk, + info_table.c.info, + ], + from_obj=[info_table], + ) inner_result = inner.execute().fetchall() @@ -57,25 +79,25 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): # gt3 5 pk_5_data # gt3 6 pk_6_data assert inner_result == [ - ('lessthan3', 1, 'pk_1_data'), - ('lessthan3', 2, 'pk_2_data'), - ('gt3', 3, 'pk_3_data'), - ('gt3', 4, 'pk_4_data'), - ('gt3', 5, 'pk_5_data'), - ('gt3', 6, 'pk_6_data') + ("lessthan3", 1, "pk_1_data"), + ("lessthan3", 2, "pk_2_data"), + ("gt3", 3, "pk_3_data"), + ("gt3", 4, "pk_4_data"), + ("gt3", 5, "pk_5_data"), + ("gt3", 6, "pk_6_data"), ] - outer = select([inner.alias('q_inner')]) + outer = select([inner.alias("q_inner")]) outer_result = outer.execute().fetchall() assert outer_result == [ - ('lessthan3', 1, 'pk_1_data'), - ('lessthan3', 2, 'pk_2_data'), - ('gt3', 3, 'pk_3_data'), - ('gt3', 4, 'pk_4_data'), - ('gt3', 5, 'pk_5_data'), - ('gt3', 6, 'pk_6_data') + 
("lessthan3", 1, "pk_1_data"), + ("lessthan3", 2, "pk_2_data"), + ("gt3", 3, "pk_3_data"), + ("gt3", 4, "pk_4_data"), + ("gt3", 5, "pk_5_data"), + ("gt3", 6, "pk_6_data"), ] w_else = select( @@ -83,48 +105,54 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): case( [ [info_table.c.pk < 3, cast(3, Integer)], - [ - and_( - info_table.c.pk >= 3, info_table.c.pk < 6), - 6]], - else_=0).label('x'), - info_table.c.pk, info_table.c.info], - from_obj=[info_table]) + [and_(info_table.c.pk >= 3, info_table.c.pk < 6), 6], + ], + else_=0, + ).label("x"), + info_table.c.pk, + info_table.c.info, + ], + from_obj=[info_table], + ) else_result = w_else.execute().fetchall() assert else_result == [ - (3, 1, 'pk_1_data'), - (3, 2, 'pk_2_data'), - (6, 3, 'pk_3_data'), - (6, 4, 'pk_4_data'), - (6, 5, 'pk_5_data'), - (0, 6, 'pk_6_data') + (3, 1, "pk_1_data"), + (3, 2, "pk_2_data"), + (6, 3, "pk_3_data"), + (6, 4, "pk_4_data"), + (6, 5, "pk_5_data"), + (0, 6, "pk_6_data"), ] def test_literal_interpretation_ambiguous(self): assert_raises_message( exc.ArgumentError, r"Ambiguous literal: 'x'. Use the 'text\(\)' function", - case, [("x", "y")] + case, + [("x", "y")], ) def test_literal_interpretation_ambiguous_tuple(self): assert_raises_message( exc.ArgumentError, r"Ambiguous literal: \('x', 'y'\). Use the 'text\(\)' function", - case, [(("x", "y"), "z")] + case, + [(("x", "y"), "z")], ) def test_literal_interpretation(self): - t = table('test', column('col1')) + t = table("test", column("col1")) self.assert_compile( case([("x", "y")], value=t.c.col1), - "CASE test.col1 WHEN :param_1 THEN :param_2 END") + "CASE test.col1 WHEN :param_1 THEN :param_2 END", + ) self.assert_compile( case([(t.c.col1 == 7, "y")], else_="z"), - "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END") + "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END", + ) def test_text_doesnt_explode(self): @@ -132,69 +160,81 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): select( [ case( - [ - ( - info_table.c.info == 'pk_4_data', - text("'yes'"))], - else_=text("'no'")) - ]).order_by(info_table.c.info), - + [(info_table.c.info == "pk_4_data", text("'yes'"))], + else_=text("'no'"), + ) + ] + ).order_by(info_table.c.info), select( [ case( [ ( - info_table.c.info == 'pk_4_data', - literal_column("'yes'"))], - else_=literal_column("'no'") - )] + info_table.c.info == "pk_4_data", + literal_column("'yes'"), + ) + ], + else_=literal_column("'no'"), + ) + ] ).order_by(info_table.c.info), - ]: if testing.against("firebird"): - eq_(s.execute().fetchall(), [ - ('no ', ), ('no ', ), ('no ', ), ('yes', ), - ('no ', ), ('no ', ), - ]) + eq_( + s.execute().fetchall(), + [ + ("no ",), + ("no ",), + ("no ",), + ("yes",), + ("no ",), + ("no ",), + ], + ) else: - eq_(s.execute().fetchall(), [ - ('no', ), ('no', ), ('no', ), ('yes', ), - ('no', ), ('no', ), - ]) + eq_( + s.execute().fetchall(), + [("no",), ("no",), ("no",), ("yes",), ("no",), ("no",)], + ) - @testing.fails_on('firebird', 'FIXME: unknown') + @testing.fails_on("firebird", "FIXME: unknown") def testcase_with_dict(self): query = select( [ case( { - info_table.c.pk < 3: 'lessthan3', - info_table.c.pk >= 3: 'gt3', - }, else_='other'), - info_table.c.pk, info_table.c.info + info_table.c.pk < 3: "lessthan3", + info_table.c.pk >= 3: "gt3", + }, + else_="other", + ), + info_table.c.pk, + info_table.c.info, ], - from_obj=[info_table]) + from_obj=[info_table], + ) assert query.execute().fetchall() == [ - ('lessthan3', 1, 'pk_1_data'), - ('lessthan3', 2, 'pk_2_data'), - ('gt3', 
3, 'pk_3_data'), - ('gt3', 4, 'pk_4_data'), - ('gt3', 5, 'pk_5_data'), - ('gt3', 6, 'pk_6_data') + ("lessthan3", 1, "pk_1_data"), + ("lessthan3", 2, "pk_2_data"), + ("gt3", 3, "pk_3_data"), + ("gt3", 4, "pk_4_data"), + ("gt3", 5, "pk_5_data"), + ("gt3", 6, "pk_6_data"), ] simple_query = select( [ case( - {1: 'one', 2: 'two', }, - value=info_table.c.pk, else_='other'), - info_table.c.pk + {1: "one", 2: "two"}, value=info_table.c.pk, else_="other" + ), + info_table.c.pk, ], whereclause=info_table.c.pk < 4, - from_obj=[info_table]) + from_obj=[info_table], + ) assert simple_query.execute().fetchall() == [ - ('one', 1), - ('two', 2), - ('other', 3), + ("one", 1), + ("two", 2), + ("other", 3), ] diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index f543b86773..f3305743ad 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -10,147 +10,201 @@ styling and coherent test organization. """ -from sqlalchemy.testing import eq_, is_, assert_raises, \ - assert_raises_message, eq_ignore_whitespace +from sqlalchemy.testing import ( + eq_, + is_, + assert_raises, + assert_raises_message, + eq_ignore_whitespace, +) from sqlalchemy import testing from sqlalchemy.testing import fixtures, AssertsCompiledSQL -from sqlalchemy import Integer, String, MetaData, Table, Column, select, \ - func, not_, cast, text, tuple_, exists, update, bindparam,\ - literal, and_, null, type_coerce, alias, or_, literal_column,\ - Float, TIMESTAMP, Numeric, Date, Text, union, except_,\ - intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\ - over, subquery, case, true, CheckConstraint, Sequence +from sqlalchemy import ( + Integer, + String, + MetaData, + Table, + Column, + select, + func, + not_, + cast, + text, + tuple_, + exists, + update, + bindparam, + literal, + and_, + null, + type_coerce, + alias, + or_, + literal_column, + Float, + TIMESTAMP, + Numeric, + Date, + Text, + union, + except_, + intersect, + union_all, + Boolean, + distinct, + join, + outerjoin, + asc, + desc, + over, + subquery, + case, + true, + CheckConstraint, + Sequence, +) import decimal from sqlalchemy.util import u from sqlalchemy import exc, sql, util, types, schema from sqlalchemy.sql import table, column, label from sqlalchemy.sql.expression import ClauseList, _literal_as_text, HasPrefixes from sqlalchemy.engine import default -from sqlalchemy.dialects import mysql, mssql, postgresql, oracle, \ - sqlite, sybase +from sqlalchemy.dialects import ( + mysql, + mssql, + postgresql, + oracle, + sqlite, + sybase, +) from sqlalchemy.dialects.postgresql.base import PGCompiler, PGDialect from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import compiler -table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String), - ) +table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), +) table2 = table( - 'myothertable', - column('otherid', Integer), - column('othername', String), + "myothertable", column("otherid", Integer), column("othername", String) ) table3 = table( - 'thirdtable', - column('userid', Integer), - column('otherstuff', String), + "thirdtable", column("userid", Integer), column("otherstuff", String) ) metadata = MetaData() # table with a schema table4 = Table( - 'remotetable', metadata, - Column('rem_id', Integer, primary_key=True), - Column('datatype_id', Integer), - Column('value', String(20)), - schema='remote_owner' + "remotetable", + metadata, + Column("rem_id", Integer, 
primary_key=True), + Column("datatype_id", Integer), + Column("value", String(20)), + schema="remote_owner", ) # table with a 'multipart' schema table5 = Table( - 'remotetable', metadata, - Column('rem_id', Integer, primary_key=True), - Column('datatype_id', Integer), - Column('value', String(20)), - schema='dbo.remote_owner' + "remotetable", + metadata, + Column("rem_id", Integer, primary_key=True), + Column("datatype_id", Integer), + Column("value", String(20)), + schema="dbo.remote_owner", ) -users = table('users', - column('user_id'), - column('user_name'), - column('password'), - ) +users = table( + "users", column("user_id"), column("user_name"), column("password") +) -addresses = table('addresses', - column('address_id'), - column('user_id'), - column('street'), - column('city'), - column('state'), - column('zip') - ) +addresses = table( + "addresses", + column("address_id"), + column("user_id"), + column("street"), + column("city"), + column("state"), + column("zip"), +) -keyed = Table('keyed', metadata, - Column('x', Integer, key='colx'), - Column('y', Integer, key='coly'), - Column('z', Integer), - ) +keyed = Table( + "keyed", + metadata, + Column("x", Integer, key="colx"), + Column("y", Integer, key="coly"), + Column("z", Integer), +) class SelectTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_attribute_sanity(self): - assert hasattr(table1, 'c') - assert hasattr(table1.select(), 'c') - assert not hasattr(table1.c.myid.self_group(), 'columns') - assert hasattr(table1.select().self_group(), 'columns') - assert not hasattr(table1.c.myid, 'columns') - assert not hasattr(table1.c.myid, 'c') - assert not hasattr(table1.select().c.myid, 'c') - assert not hasattr(table1.select().c.myid, 'columns') - assert not hasattr(table1.alias().c.myid, 'columns') - assert not hasattr(table1.alias().c.myid, 'c') + assert hasattr(table1, "c") + assert hasattr(table1.select(), "c") + assert not hasattr(table1.c.myid.self_group(), "columns") + assert hasattr(table1.select().self_group(), "columns") + assert not hasattr(table1.c.myid, "columns") + assert not hasattr(table1.c.myid, "c") + assert not hasattr(table1.select().c.myid, "c") + assert not hasattr(table1.select().c.myid, "columns") + assert not hasattr(table1.alias().c.myid, "columns") + assert not hasattr(table1.alias().c.myid, "c") if util.compat.py32: assert_raises_message( exc.InvalidRequestError, - 'Scalar Select expression has no ' - 'columns; use this object directly within a ' - 'column-level expression.', + "Scalar Select expression has no " + "columns; use this object directly within a " + "column-level expression.", lambda: hasattr( - select([table1.c.myid]).as_scalar().self_group(), - 'columns')) + select([table1.c.myid]).as_scalar().self_group(), "columns" + ), + ) assert_raises_message( exc.InvalidRequestError, - 'Scalar Select expression has no ' - 'columns; use this object directly within a ' - 'column-level expression.', - lambda: hasattr(select([table1.c.myid]).as_scalar(), - 'columns')) + "Scalar Select expression has no " + "columns; use this object directly within a " + "column-level expression.", + lambda: hasattr( + select([table1.c.myid]).as_scalar(), "columns" + ), + ) else: assert not hasattr( - select([table1.c.myid]).as_scalar().self_group(), - 'columns') - assert not hasattr(select([table1.c.myid]).as_scalar(), 'columns') + select([table1.c.myid]).as_scalar().self_group(), "columns" + ) + assert not hasattr(select([table1.c.myid]).as_scalar(), "columns") def 
test_prefix_constructor(self): class Pref(HasPrefixes): - def _generate(self): return self - assert_raises(exc.ArgumentError, - Pref().prefix_with, - "some prefix", not_a_dialect=True - ) + + assert_raises( + exc.ArgumentError, + Pref().prefix_with, + "some prefix", + not_a_dialect=True, + ) def test_table_select(self): - self.assert_compile(table1.select(), - "SELECT mytable.myid, mytable.name, " - "mytable.description FROM mytable") + self.assert_compile( + table1.select(), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable", + ) self.assert_compile( - select( - [ - table1, - table2]), + select([table1, table2]), "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, myothertable.othername FROM mytable, " - "myothertable") + "myothertable", + ) def test_invalid_col_argument(self): assert_raises(exc.ArgumentError, select, table1) @@ -178,13 +232,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): exp1 = literal_column("Q") exp2 = literal_column("Y") self.assert_compile( - select([1]).limit(exp1).offset(exp2), - "SELECT 1 LIMIT Q OFFSET Y" + select([1]).limit(exp1).offset(exp2), "SELECT 1 LIMIT Q OFFSET Y" ) self.assert_compile( - select([1]).limit(bindparam('x')).offset(bindparam('y')), - "SELECT 1 LIMIT :x OFFSET :y" + select([1]).limit(bindparam("x")).offset(bindparam("y")), + "SELECT 1 LIMIT :x OFFSET :y", ) def test_limit_offset_no_int_coercion_two(self): @@ -196,14 +249,18 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): exc.CompileError, "This SELECT structure does not use a simple integer " "value for limit", - getattr, sel, "_limit" + getattr, + sel, + "_limit", ) assert_raises_message( exc.CompileError, "This SELECT structure does not use a simple integer " "value for offset", - getattr, sel, "_offset" + getattr, + sel, + "_offset", ) def test_limit_offset_no_int_coercion_three(self): @@ -215,37 +272,47 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): exc.CompileError, "This SELECT structure does not use a simple integer " "value for limit", - getattr, sel, "_limit" + getattr, + sel, + "_limit", ) assert_raises_message( exc.CompileError, "This SELECT structure does not use a simple integer " "value for offset", - getattr, sel, "_offset" + getattr, + sel, + "_offset", ) def test_limit_offset(self): for lim, offset, exp, params in [ - (5, 10, "LIMIT :param_1 OFFSET :param_2", - {'param_1': 5, 'param_2': 10}), - (None, 10, "LIMIT -1 OFFSET :param_1", {'param_1': 10}), - (5, None, "LIMIT :param_1", {'param_1': 5}), - (0, 0, "LIMIT :param_1 OFFSET :param_2", - {'param_1': 0, 'param_2': 0}), + ( + 5, + 10, + "LIMIT :param_1 OFFSET :param_2", + {"param_1": 5, "param_2": 10}, + ), + (None, 10, "LIMIT -1 OFFSET :param_1", {"param_1": 10}), + (5, None, "LIMIT :param_1", {"param_1": 5}), + ( + 0, + 0, + "LIMIT :param_1 OFFSET :param_2", + {"param_1": 0, "param_2": 0}, + ), ]: self.assert_compile( select([1]).limit(lim).offset(offset), "SELECT 1 " + exp, - checkparams=params + checkparams=params, ) def test_limit_offset_select_literal_binds(self): stmt = select([1]).limit(5).offset(6) self.assert_compile( - stmt, - "SELECT 1 LIMIT 5 OFFSET 6", - literal_binds=True + stmt, "SELECT 1 LIMIT 5 OFFSET 6", literal_binds=True ) def test_limit_offset_compound_select_literal_binds(self): @@ -253,25 +320,24 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "SELECT 1 UNION SELECT 2 LIMIT 5 OFFSET 6", - literal_binds=True + literal_binds=True, ) def 
test_select_precol_compile_ordering(self): - s1 = select([column('x')]).select_from(text('a')).limit(5).as_scalar() + s1 = select([column("x")]).select_from(text("a")).limit(5).as_scalar() s2 = select([s1]).limit(10) class MyCompiler(compiler.SQLCompiler): - def get_select_precolumns(self, select, **kw): result = "" if select._limit: result += "FIRST %s " % self.process( - literal( - select._limit), **kw) + literal(select._limit), **kw + ) if select._offset: result += "SKIP %s " % self.process( - literal( - select._offset), **kw) + literal(select._offset), **kw + ) return result def limit_clause(self, select, **kw): @@ -279,13 +345,13 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): dialect = default.DefaultDialect() dialect.statement_compiler = MyCompiler - dialect.paramstyle = 'qmark' + dialect.paramstyle = "qmark" dialect.positional = True self.assert_compile( s2, "SELECT FIRST ? (SELECT FIRST ? x FROM a) AS anon_1", checkpositional=(10, 5), - dialect=dialect + dialect=dialect, ) def test_from_subquery(self): @@ -293,16 +359,15 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): another select, for the purposes of selecting from the exported columns of that select.""" - s = select([table1], table1.c.name == 'jack') + s = select([table1], table1.c.name == "jack") self.assert_compile( - select( - [s], - s.c.myid == 7), + select([s], s.c.myid == 7), "SELECT myid, name, description FROM " "(SELECT mytable.myid AS myid, " "mytable.name AS name, mytable.description AS description " "FROM mytable " - "WHERE mytable.name = :name_1) WHERE myid = :myid_1") + "WHERE mytable.name = :name_1) WHERE myid = :myid_1", + ) sq = select([table1]) self.assert_compile( @@ -310,44 +375,42 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT myid, name, description FROM " "(SELECT mytable.myid AS myid, " "mytable.name AS name, mytable.description " - "AS description FROM mytable)" + "AS description FROM mytable)", ) - sq = select( - [table1], - ).alias('sq') + sq = select([table1]).alias("sq") self.assert_compile( sq.select(sq.c.myid == 7), "SELECT sq.myid, sq.name, sq.description FROM " "(SELECT mytable.myid AS myid, mytable.name AS name, " "mytable.description AS description FROM mytable) AS sq " - "WHERE sq.myid = :myid_1" + "WHERE sq.myid = :myid_1", ) sq = select( [table1, table2], and_(table1.c.myid == 7, table2.c.otherid == table1.c.myid), - use_labels=True - ).alias('sq') - - sqstring = "SELECT mytable.myid AS mytable_myid, mytable.name AS "\ - "mytable_name, mytable.description AS mytable_description, "\ - "myothertable.otherid AS myothertable_otherid, "\ - "myothertable.othername AS myothertable_othername FROM "\ - "mytable, myothertable WHERE mytable.myid = :myid_1 AND "\ + use_labels=True, + ).alias("sq") + + sqstring = ( + "SELECT mytable.myid AS mytable_myid, mytable.name AS " + "mytable_name, mytable.description AS mytable_description, " + "myothertable.otherid AS myothertable_otherid, " + "myothertable.othername AS myothertable_othername FROM " + "mytable, myothertable WHERE mytable.myid = :myid_1 AND " "myothertable.otherid = mytable.myid" + ) self.assert_compile( sq.select(), "SELECT sq.mytable_myid, sq.mytable_name, " "sq.mytable_description, sq.myothertable_otherid, " - "sq.myothertable_othername FROM (%s) AS sq" % sqstring) + "sq.myothertable_othername FROM (%s) AS sq" % sqstring, + ) - sq2 = select( - [sq], - use_labels=True - ).alias('sq2') + sq2 = select([sq], use_labels=True).alias("sq2") self.assert_compile( sq2.select(), @@ -359,53 +422,53 @@ class 
SelectTest(fixtures.TestBase, AssertsCompiledSQL): "sq.mytable_description AS sq_mytable_description, " "sq.myothertable_otherid AS sq_myothertable_otherid, " "sq.myothertable_othername AS sq_myothertable_othername " - "FROM (%s) AS sq) AS sq2" % sqstring) + "FROM (%s) AS sq) AS sq2" % sqstring, + ) def test_select_from_clauselist(self): self.assert_compile( - select([ClauseList(column('a'), column('b'))] - ).select_from(text('sometable')), - 'SELECT a, b FROM sometable' + select([ClauseList(column("a"), column("b"))]).select_from( + text("sometable") + ), + "SELECT a, b FROM sometable", ) def test_use_labels(self): self.assert_compile( select([table1.c.myid == 5], use_labels=True), - "SELECT mytable.myid = :myid_1 AS anon_1 FROM mytable" + "SELECT mytable.myid = :myid_1 AS anon_1 FROM mytable", ) self.assert_compile( - select([func.foo()], use_labels=True), - "SELECT foo() AS foo_1" + select([func.foo()], use_labels=True), "SELECT foo() AS foo_1" ) # this is native_boolean=False for default dialect self.assert_compile( select([not_(True)], use_labels=True), - "SELECT :param_1 = 0 AS anon_1" + "SELECT :param_1 = 0 AS anon_1", ) self.assert_compile( select([cast("data", Integer)], use_labels=True), - "SELECT CAST(:param_1 AS INTEGER) AS anon_1" + "SELECT CAST(:param_1 AS INTEGER) AS anon_1", ) self.assert_compile( - select([func.sum( - func.lala(table1.c.myid).label('foo')).label('bar')]), - "SELECT sum(lala(mytable.myid)) AS bar FROM mytable" + select( + [func.sum(func.lala(table1.c.myid).label("foo")).label("bar")] + ), + "SELECT sum(lala(mytable.myid)) AS bar FROM mytable", ) self.assert_compile( - select([keyed]), - "SELECT keyed.x, keyed.y" - ", keyed.z FROM keyed" + select([keyed]), "SELECT keyed.x, keyed.y" ", keyed.z FROM keyed" ) self.assert_compile( select([keyed]).apply_labels(), "SELECT keyed.x AS keyed_x, keyed.y AS " - "keyed_y, keyed.z AS keyed_z FROM keyed" + "keyed_y, keyed.z AS keyed_z FROM keyed", ) def test_paramstyles(self): @@ -414,40 +477,40 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "select ?, ?, ? 
from sometable", - dialect=default.DefaultDialect(paramstyle='qmark') + dialect=default.DefaultDialect(paramstyle="qmark"), ) self.assert_compile( stmt, "select :foo, :bar, :bat from sometable", - dialect=default.DefaultDialect(paramstyle='named') + dialect=default.DefaultDialect(paramstyle="named"), ) self.assert_compile( stmt, "select %s, %s, %s from sometable", - dialect=default.DefaultDialect(paramstyle='format') + dialect=default.DefaultDialect(paramstyle="format"), ) self.assert_compile( stmt, "select :1, :2, :3 from sometable", - dialect=default.DefaultDialect(paramstyle='numeric') + dialect=default.DefaultDialect(paramstyle="numeric"), ) self.assert_compile( stmt, "select %(foo)s, %(bar)s, %(bat)s from sometable", - dialect=default.DefaultDialect(paramstyle='pyformat') + dialect=default.DefaultDialect(paramstyle="pyformat"), ) def test_anon_param_name_on_keys(self): self.assert_compile( keyed.insert(), "INSERT INTO keyed (x, y, z) VALUES (%(colx)s, %(coly)s, %(z)s)", - dialect=default.DefaultDialect(paramstyle='pyformat') + dialect=default.DefaultDialect(paramstyle="pyformat"), ) self.assert_compile( keyed.c.coly == 5, "keyed.y = %(coly_1)s", - checkparams={'coly_1': 5}, - dialect=default.DefaultDialect(paramstyle='pyformat') + checkparams={"coly_1": 5}, + dialect=default.DefaultDialect(paramstyle="pyformat"), ) def test_dupe_columns(self): @@ -455,51 +518,54 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): element identity, not rendered result.""" self.assert_compile( - select([column('a'), column('a'), column('a')]), - "SELECT a, a, a", dialect=default.DefaultDialect() + select([column("a"), column("a"), column("a")]), + "SELECT a, a, a", + dialect=default.DefaultDialect(), ) - c = column('a') + c = column("a") self.assert_compile( - select([c, c, c]), - "SELECT a", dialect=default.DefaultDialect() + select([c, c, c]), "SELECT a", dialect=default.DefaultDialect() ) - a, b = column('a'), column('b') + a, b = column("a"), column("b") self.assert_compile( select([a, b, b, b, a, a]), - "SELECT a, b", dialect=default.DefaultDialect() + "SELECT a, b", + dialect=default.DefaultDialect(), ) # using alternate keys. - a, b, c = Column('a', Integer, key='b'), \ - Column('b', Integer), \ - Column('c', Integer, key='a') + a, b, c = ( + Column("a", Integer, key="b"), + Column("b", Integer), + Column("c", Integer, key="a"), + ) self.assert_compile( select([a, b, c, a, b, c]), - "SELECT a, b, c", dialect=default.DefaultDialect() + "SELECT a, b, c", + dialect=default.DefaultDialect(), ) self.assert_compile( - select([bindparam('a'), bindparam('b'), bindparam('c')]), + select([bindparam("a"), bindparam("b"), bindparam("c")]), "SELECT :a AS anon_1, :b AS anon_2, :c AS anon_3", - dialect=default.DefaultDialect(paramstyle='named') + dialect=default.DefaultDialect(paramstyle="named"), ) self.assert_compile( - select([bindparam('a'), bindparam('b'), bindparam('c')]), + select([bindparam("a"), bindparam("b"), bindparam("c")]), "SELECT ? AS anon_1, ? AS anon_2, ? 
AS anon_3", - dialect=default.DefaultDialect(paramstyle='qmark'), + dialect=default.DefaultDialect(paramstyle="qmark"), ) self.assert_compile( - select([column("a"), column("a"), column("a")]), - "SELECT a, a, a" + select([column("a"), column("a"), column("a")]), "SELECT a, a, a" ) - s = select([bindparam('a'), bindparam('b'), bindparam('c')]) - s = s.compile(dialect=default.DefaultDialect(paramstyle='qmark')) - eq_(s.positiontup, ['a', 'b', 'c']) + s = select([bindparam("a"), bindparam("b"), bindparam("c")]) + s = s.compile(dialect=default.DefaultDialect(paramstyle="qmark")) + eq_(s.positiontup, ["a", "b", "c"]) def test_nested_label_targeting(self): """test nested anonymous label generation. @@ -510,234 +576,275 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): s3 = select([s2], use_labels=True) s4 = s3.alias() s5 = select([s4], use_labels=True) - self.assert_compile(s5, - 'SELECT anon_1.anon_2_myid AS ' - 'anon_1_anon_2_myid, anon_1.anon_2_name AS ' - 'anon_1_anon_2_name, anon_1.anon_2_descript' - 'ion AS anon_1_anon_2_description FROM ' - '(SELECT anon_2.myid AS anon_2_myid, ' - 'anon_2.name AS anon_2_name, ' - 'anon_2.description AS anon_2_description ' - 'FROM (SELECT mytable.myid AS myid, ' - 'mytable.name AS name, mytable.description ' - 'AS description FROM mytable) AS anon_2) ' - 'AS anon_1') + self.assert_compile( + s5, + "SELECT anon_1.anon_2_myid AS " + "anon_1_anon_2_myid, anon_1.anon_2_name AS " + "anon_1_anon_2_name, anon_1.anon_2_descript" + "ion AS anon_1_anon_2_description FROM " + "(SELECT anon_2.myid AS anon_2_myid, " + "anon_2.name AS anon_2_name, " + "anon_2.description AS anon_2_description " + "FROM (SELECT mytable.myid AS myid, " + "mytable.name AS name, mytable.description " + "AS description FROM mytable) AS anon_2) " + "AS anon_1", + ) def test_nested_label_targeting_keyed(self): s1 = keyed.select() s2 = s1.alias() s3 = select([s2], use_labels=True) - self.assert_compile(s3, - "SELECT anon_1.x AS anon_1_x, " - "anon_1.y AS anon_1_y, " - "anon_1.z AS anon_1_z FROM " - "(SELECT keyed.x AS x, keyed.y " - "AS y, keyed.z AS z FROM keyed) AS anon_1") + self.assert_compile( + s3, + "SELECT anon_1.x AS anon_1_x, " + "anon_1.y AS anon_1_y, " + "anon_1.z AS anon_1_z FROM " + "(SELECT keyed.x AS x, keyed.y " + "AS y, keyed.z AS z FROM keyed) AS anon_1", + ) s4 = s3.alias() s5 = select([s4], use_labels=True) - self.assert_compile(s5, - "SELECT anon_1.anon_2_x AS anon_1_anon_2_x, " - "anon_1.anon_2_y AS anon_1_anon_2_y, " - "anon_1.anon_2_z AS anon_1_anon_2_z " - "FROM (SELECT anon_2.x AS anon_2_x, " - "anon_2.y AS anon_2_y, " - "anon_2.z AS anon_2_z FROM " - "(SELECT keyed.x AS x, keyed.y AS y, keyed.z " - "AS z FROM keyed) AS anon_2) AS anon_1" - ) + self.assert_compile( + s5, + "SELECT anon_1.anon_2_x AS anon_1_anon_2_x, " + "anon_1.anon_2_y AS anon_1_anon_2_y, " + "anon_1.anon_2_z AS anon_1_anon_2_z " + "FROM (SELECT anon_2.x AS anon_2_x, " + "anon_2.y AS anon_2_y, " + "anon_2.z AS anon_2_z FROM " + "(SELECT keyed.x AS x, keyed.y AS y, keyed.z " + "AS z FROM keyed) AS anon_2) AS anon_1", + ) def test_exists(self): s = select([table1.c.myid]).where(table1.c.myid == 5) - self.assert_compile(exists(s), - "EXISTS (SELECT mytable.myid FROM mytable " - "WHERE mytable.myid = :myid_1)" - ) - - self.assert_compile(exists(s.as_scalar()), - "EXISTS (SELECT mytable.myid FROM mytable " - "WHERE mytable.myid = :myid_1)" - ) - - self.assert_compile(exists([table1.c.myid], table1.c.myid - == 5).select(), - 'SELECT EXISTS (SELECT mytable.myid FROM ' - 'mytable WHERE 
mytable.myid = :myid_1) AS anon_1', - params={'mytable_myid': 5}) - self.assert_compile(select([table1, exists([1], - from_obj=table2)]), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, EXISTS (SELECT 1 ' - 'FROM myothertable) AS anon_1 FROM mytable', - params={}) - self.assert_compile(select([table1, - exists([1], - from_obj=table2).label('foo')]), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, EXISTS (SELECT 1 ' - 'FROM myothertable) AS foo FROM mytable', - params={}) + self.assert_compile( + exists(s), + "EXISTS (SELECT mytable.myid FROM mytable " + "WHERE mytable.myid = :myid_1)", + ) + + self.assert_compile( + exists(s.as_scalar()), + "EXISTS (SELECT mytable.myid FROM mytable " + "WHERE mytable.myid = :myid_1)", + ) + + self.assert_compile( + exists([table1.c.myid], table1.c.myid == 5).select(), + "SELECT EXISTS (SELECT mytable.myid FROM " + "mytable WHERE mytable.myid = :myid_1) AS anon_1", + params={"mytable_myid": 5}, + ) + self.assert_compile( + select([table1, exists([1], from_obj=table2)]), + "SELECT mytable.myid, mytable.name, " + "mytable.description, EXISTS (SELECT 1 " + "FROM myothertable) AS anon_1 FROM mytable", + params={}, + ) + self.assert_compile( + select([table1, exists([1], from_obj=table2).label("foo")]), + "SELECT mytable.myid, mytable.name, " + "mytable.description, EXISTS (SELECT 1 " + "FROM myothertable) AS foo FROM mytable", + params={}, + ) self.assert_compile( table1.select( - exists().where( - table2.c.otherid == table1.c.myid).correlate(table1)), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'EXISTS (SELECT * FROM myothertable WHERE ' - 'myothertable.otherid = mytable.myid)') + exists() + .where(table2.c.otherid == table1.c.myid) + .correlate(table1) + ), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "EXISTS (SELECT * FROM myothertable WHERE " + "myothertable.otherid = mytable.myid)", + ) self.assert_compile( table1.select( - exists().where( - table2.c.otherid == table1.c.myid).correlate(table1)), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'EXISTS (SELECT * FROM myothertable WHERE ' - 'myothertable.otherid = mytable.myid)') + exists() + .where(table2.c.otherid == table1.c.myid) + .correlate(table1) + ), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "EXISTS (SELECT * FROM myothertable WHERE " + "myothertable.otherid = mytable.myid)", + ) self.assert_compile( table1.select( - exists().where( - table2.c.otherid == table1.c.myid).correlate(table1) - ).replace_selectable( - table2, - table2.alias()), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'EXISTS (SELECT * FROM myothertable AS ' - 'myothertable_1 WHERE myothertable_1.otheri' - 'd = mytable.myid)') + exists() + .where(table2.c.otherid == table1.c.myid) + .correlate(table1) + ).replace_selectable(table2, table2.alias()), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "EXISTS (SELECT * FROM myothertable AS " + "myothertable_1 WHERE myothertable_1.otheri" + "d = mytable.myid)", + ) self.assert_compile( table1.select( - exists().where( - table2.c.otherid == table1.c.myid).correlate(table1)). - select_from( - table1.join( - table2, - table1.c.myid == table2.c.otherid)). 
- replace_selectable( - table2, - table2.alias()), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable JOIN ' - 'myothertable AS myothertable_1 ON ' - 'mytable.myid = myothertable_1.otherid ' - 'WHERE EXISTS (SELECT * FROM myothertable ' - 'AS myothertable_1 WHERE ' - 'myothertable_1.otherid = mytable.myid)') - - self.assert_compile( - select([ - or_( - exists().where(table2.c.otherid == 'foo'), - exists().where(table2.c.otherid == 'bar') - ) - ]), + exists() + .where(table2.c.otherid == table1.c.myid) + .correlate(table1) + ) + .select_from( + table1.join(table2, table1.c.myid == table2.c.otherid) + ) + .replace_selectable(table2, table2.alias()), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable JOIN " + "myothertable AS myothertable_1 ON " + "mytable.myid = myothertable_1.otherid " + "WHERE EXISTS (SELECT * FROM myothertable " + "AS myothertable_1 WHERE " + "myothertable_1.otherid = mytable.myid)", + ) + + self.assert_compile( + select( + [ + or_( + exists().where(table2.c.otherid == "foo"), + exists().where(table2.c.otherid == "bar"), + ) + ] + ), "SELECT (EXISTS (SELECT * FROM myothertable " "WHERE myothertable.otherid = :otherid_1)) " "OR (EXISTS (SELECT * FROM myothertable WHERE " - "myothertable.otherid = :otherid_2)) AS anon_1" + "myothertable.otherid = :otherid_2)) AS anon_1", ) self.assert_compile( - select([exists([1])]), - "SELECT EXISTS (SELECT 1) AS anon_1" + select([exists([1])]), "SELECT EXISTS (SELECT 1) AS anon_1" ) self.assert_compile( - select([~exists([1])]), - "SELECT NOT (EXISTS (SELECT 1)) AS anon_1" + select([~exists([1])]), "SELECT NOT (EXISTS (SELECT 1)) AS anon_1" ) self.assert_compile( select([~(~exists([1]))]), - "SELECT NOT (NOT (EXISTS (SELECT 1))) AS anon_1" + "SELECT NOT (NOT (EXISTS (SELECT 1))) AS anon_1", ) def test_where_subquery(self): - s = select([addresses.c.street], addresses.c.user_id - == users.c.user_id, correlate=True).alias('s') + s = select( + [addresses.c.street], + addresses.c.user_id == users.c.user_id, + correlate=True, + ).alias("s") # don't correlate in a FROM list - self.assert_compile(select([users, s.c.street], from_obj=s), - "SELECT users.user_id, users.user_name, " - "users.password, s.street FROM users, " - "(SELECT addresses.street AS street FROM " - "addresses, users WHERE addresses.user_id = " - "users.user_id) AS s") - self.assert_compile(table1.select( - table1.c.myid == select( - [table1.c.myid], - table1.c.name == 'jack')), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'mytable.myid = (SELECT mytable.myid FROM ' - 'mytable WHERE mytable.name = :name_1)') + self.assert_compile( + select([users, s.c.street], from_obj=s), + "SELECT users.user_id, users.user_name, " + "users.password, s.street FROM users, " + "(SELECT addresses.street AS street FROM " + "addresses, users WHERE addresses.user_id = " + "users.user_id) AS s", + ) + self.assert_compile( + table1.select( + table1.c.myid + == select([table1.c.myid], table1.c.name == "jack") + ), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "mytable.myid = (SELECT mytable.myid FROM " + "mytable WHERE mytable.name = :name_1)", + ) self.assert_compile( table1.select( - table1.c.myid == select( - [table2.c.otherid], - table1.c.name == table2.c.othername + table1.c.myid + == select( + [table2.c.otherid], table1.c.name == table2.c.othername ) ), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'mytable.myid = (SELECT ' - 
'myothertable.otherid FROM myothertable ' - 'WHERE mytable.name = myothertable.othernam' - 'e)') - self.assert_compile(table1.select(exists([1], table2.c.otherid - == table1.c.myid)), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'EXISTS (SELECT 1 FROM myothertable WHERE ' - 'myothertable.otherid = mytable.myid)') - talias = table1.alias('ta') - s = subquery('sq2', [talias], exists([1], table2.c.otherid - == talias.c.myid)) - self.assert_compile(select([s, table1]), - 'SELECT sq2.myid, sq2.name, ' - 'sq2.description, mytable.myid, ' - 'mytable.name, mytable.description FROM ' - '(SELECT ta.myid AS myid, ta.name AS name, ' - 'ta.description AS description FROM ' - 'mytable AS ta WHERE EXISTS (SELECT 1 FROM ' - 'myothertable WHERE myothertable.otherid = ' - 'ta.myid)) AS sq2, mytable') + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "mytable.myid = (SELECT " + "myothertable.otherid FROM myothertable " + "WHERE mytable.name = myothertable.othernam" + "e)", + ) + self.assert_compile( + table1.select(exists([1], table2.c.otherid == table1.c.myid)), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "EXISTS (SELECT 1 FROM myothertable WHERE " + "myothertable.otherid = mytable.myid)", + ) + talias = table1.alias("ta") + s = subquery( + "sq2", [talias], exists([1], table2.c.otherid == talias.c.myid) + ) + self.assert_compile( + select([s, table1]), + "SELECT sq2.myid, sq2.name, " + "sq2.description, mytable.myid, " + "mytable.name, mytable.description FROM " + "(SELECT ta.myid AS myid, ta.name AS name, " + "ta.description AS description FROM " + "mytable AS ta WHERE EXISTS (SELECT 1 FROM " + "myothertable WHERE myothertable.otherid = " + "ta.myid)) AS sq2, mytable", + ) # test constructing the outer query via append_column(), which # occurs in the ORM's Query object - s = select([], exists([1], table2.c.otherid == table1.c.myid), - from_obj=table1) + s = select( + [], exists([1], table2.c.otherid == table1.c.myid), from_obj=table1 + ) s.append_column(table1) - self.assert_compile(s, - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable WHERE ' - 'EXISTS (SELECT 1 FROM myothertable WHERE ' - 'myothertable.otherid = mytable.myid)') + self.assert_compile( + s, + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable WHERE " + "EXISTS (SELECT 1 FROM myothertable WHERE " + "myothertable.otherid = mytable.myid)", + ) def test_orderby_subquery(self): self.assert_compile( table1.select( order_by=[ select( - [ - table2.c.otherid], - table1.c.myid == table2.c.otherid)]), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable ORDER BY ' - '(SELECT myothertable.otherid FROM ' - 'myothertable WHERE mytable.myid = ' - 'myothertable.otherid)') - self.assert_compile(table1.select(order_by=[ - desc(select([table2.c.otherid], - table1.c.myid == table2.c.otherid))]), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description FROM mytable ORDER BY ' - '(SELECT myothertable.otherid FROM ' - 'myothertable WHERE mytable.myid = ' - 'myothertable.otherid) DESC') + [table2.c.otherid], table1.c.myid == table2.c.otherid + ) + ] + ), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable ORDER BY " + "(SELECT myothertable.otherid FROM " + "myothertable WHERE mytable.myid = " + "myothertable.otherid)", + ) + self.assert_compile( + table1.select( + order_by=[ + desc( + select( + [table2.c.otherid], + table1.c.myid == 
table2.c.otherid, + ) + ) + ] + ), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable ORDER BY " + "(SELECT myothertable.otherid FROM " + "myothertable WHERE mytable.myid = " + "myothertable.otherid) DESC", + ) def test_scalar_select(self): assert_raises_message( @@ -745,72 +852,86 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): r"Select objects don't have a type\. Call as_scalar\(\) " r"on this Select object to return a 'scalar' " r"version of this Select\.", - func.coalesce, select([table1.c.myid]) + func.coalesce, + select([table1.c.myid]), ) s = select([table1.c.myid], correlate=False).as_scalar() - self.assert_compile(select([table1, s]), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, (SELECT mytable.myid ' - 'FROM mytable) AS anon_1 FROM mytable') + self.assert_compile( + select([table1, s]), + "SELECT mytable.myid, mytable.name, " + "mytable.description, (SELECT mytable.myid " + "FROM mytable) AS anon_1 FROM mytable", + ) s = select([table1.c.myid]).as_scalar() - self.assert_compile(select([table2, s]), - 'SELECT myothertable.otherid, ' - 'myothertable.othername, (SELECT ' - 'mytable.myid FROM mytable) AS anon_1 FROM ' - 'myothertable') + self.assert_compile( + select([table2, s]), + "SELECT myothertable.otherid, " + "myothertable.othername, (SELECT " + "mytable.myid FROM mytable) AS anon_1 FROM " + "myothertable", + ) s = select([table1.c.myid]).correlate(None).as_scalar() - self.assert_compile(select([table1, s]), - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, (SELECT mytable.myid ' - 'FROM mytable) AS anon_1 FROM mytable') + self.assert_compile( + select([table1, s]), + "SELECT mytable.myid, mytable.name, " + "mytable.description, (SELECT mytable.myid " + "FROM mytable) AS anon_1 FROM mytable", + ) s = select([table1.c.myid]).as_scalar() s2 = s.where(table1.c.myid == 5) self.assert_compile( s2, - "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)" - ) - self.assert_compile( - s, "(SELECT mytable.myid FROM mytable)" + "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)", ) + self.assert_compile(s, "(SELECT mytable.myid FROM mytable)") # test that aliases use as_scalar() when used in an explicitly # scalar context s = select([table1.c.myid]).alias() - self.assert_compile(select([table1.c.myid]).where(table1.c.myid - == s), - 'SELECT mytable.myid FROM mytable WHERE ' - 'mytable.myid = (SELECT mytable.myid FROM ' - 'mytable)') - self.assert_compile(select([table1.c.myid]).where(s - > table1.c.myid), - 'SELECT mytable.myid FROM mytable WHERE ' - 'mytable.myid < (SELECT mytable.myid FROM ' - 'mytable)') + self.assert_compile( + select([table1.c.myid]).where(table1.c.myid == s), + "SELECT mytable.myid FROM mytable WHERE " + "mytable.myid = (SELECT mytable.myid FROM " + "mytable)", + ) + self.assert_compile( + select([table1.c.myid]).where(s > table1.c.myid), + "SELECT mytable.myid FROM mytable WHERE " + "mytable.myid < (SELECT mytable.myid FROM " + "mytable)", + ) s = select([table1.c.myid]).as_scalar() - self.assert_compile(select([table2, s]), - 'SELECT myothertable.otherid, ' - 'myothertable.othername, (SELECT ' - 'mytable.myid FROM mytable) AS anon_1 FROM ' - 'myothertable') + self.assert_compile( + select([table2, s]), + "SELECT myothertable.otherid, " + "myothertable.othername, (SELECT " + "mytable.myid FROM mytable) AS anon_1 FROM " + "myothertable", + ) # test expressions against scalar selects - self.assert_compile(select([s - literal(8)]), - 'SELECT (SELECT mytable.myid FROM 
mytable) ' - '- :param_1 AS anon_1') - self.assert_compile(select([select([table1.c.name]).as_scalar() - + literal('x')]), - 'SELECT (SELECT mytable.name FROM mytable) ' - '|| :param_1 AS anon_1') - self.assert_compile(select([s > literal(8)]), - 'SELECT (SELECT mytable.myid FROM mytable) ' - '> :param_1 AS anon_1') - self.assert_compile(select([select([table1.c.name]).label('foo' - )]), - 'SELECT (SELECT mytable.name FROM mytable) ' - 'AS foo') + self.assert_compile( + select([s - literal(8)]), + "SELECT (SELECT mytable.myid FROM mytable) " + "- :param_1 AS anon_1", + ) + self.assert_compile( + select([select([table1.c.name]).as_scalar() + literal("x")]), + "SELECT (SELECT mytable.name FROM mytable) " + "|| :param_1 AS anon_1", + ) + self.assert_compile( + select([s > literal(8)]), + "SELECT (SELECT mytable.myid FROM mytable) " + "> :param_1 AS anon_1", + ) + self.assert_compile( + select([select([table1.c.name]).label("foo")]), + "SELECT (SELECT mytable.name FROM mytable) " "AS foo", + ) # scalar selects should not have any attributes on their 'c' or # 'columns' attribute @@ -819,101 +940,129 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): try: s.c.foo except exc.InvalidRequestError as err: - assert str(err) \ - == 'Scalar Select expression has no columns; use this '\ - 'object directly within a column-level expression.' + assert ( + str(err) + == "Scalar Select expression has no columns; use this " + "object directly within a column-level expression." + ) try: s.columns.foo except exc.InvalidRequestError as err: - assert str(err) \ - == 'Scalar Select expression has no columns; use this '\ - 'object directly within a column-level expression.' - - zips = table('zips', - column('zipcode'), - column('latitude'), - column('longitude'), - ) - places = table('places', - column('id'), - column('nm') - ) - zip = '12345' - qlat = select([zips.c.latitude], zips.c.zipcode == zip).\ - correlate(None).as_scalar() - qlng = select([zips.c.longitude], zips.c.zipcode == zip).\ - correlate(None).as_scalar() - - q = select([places.c.id, places.c.nm, zips.c.zipcode, - func.latlondist(qlat, qlng).label('dist')], - zips.c.zipcode == zip, - order_by=['dist', places.c.nm] - ) - - self.assert_compile(q, - 'SELECT places.id, places.nm, ' - 'zips.zipcode, latlondist((SELECT ' - 'zips.latitude FROM zips WHERE ' - 'zips.zipcode = :zipcode_1), (SELECT ' - 'zips.longitude FROM zips WHERE ' - 'zips.zipcode = :zipcode_2)) AS dist FROM ' - 'places, zips WHERE zips.zipcode = ' - ':zipcode_3 ORDER BY dist, places.nm') - - zalias = zips.alias('main_zip') - qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode).\ - as_scalar() - qlng = select([zips.c.longitude], zips.c.zipcode == zalias.c.zipcode).\ - as_scalar() - q = select([places.c.id, places.c.nm, zalias.c.zipcode, - func.latlondist(qlat, qlng).label('dist')], - order_by=['dist', places.c.nm]) - self.assert_compile(q, - 'SELECT places.id, places.nm, ' - 'main_zip.zipcode, latlondist((SELECT ' - 'zips.latitude FROM zips WHERE ' - 'zips.zipcode = main_zip.zipcode), (SELECT ' - 'zips.longitude FROM zips WHERE ' - 'zips.zipcode = main_zip.zipcode)) AS dist ' - 'FROM places, zips AS main_zip ORDER BY ' - 'dist, places.nm') - - a1 = table2.alias('t2alias') + assert ( + str(err) + == "Scalar Select expression has no columns; use this " + "object directly within a column-level expression." 
+ ) + + zips = table( + "zips", column("zipcode"), column("latitude"), column("longitude") + ) + places = table("places", column("id"), column("nm")) + zip = "12345" + qlat = ( + select([zips.c.latitude], zips.c.zipcode == zip) + .correlate(None) + .as_scalar() + ) + qlng = ( + select([zips.c.longitude], zips.c.zipcode == zip) + .correlate(None) + .as_scalar() + ) + + q = select( + [ + places.c.id, + places.c.nm, + zips.c.zipcode, + func.latlondist(qlat, qlng).label("dist"), + ], + zips.c.zipcode == zip, + order_by=["dist", places.c.nm], + ) + + self.assert_compile( + q, + "SELECT places.id, places.nm, " + "zips.zipcode, latlondist((SELECT " + "zips.latitude FROM zips WHERE " + "zips.zipcode = :zipcode_1), (SELECT " + "zips.longitude FROM zips WHERE " + "zips.zipcode = :zipcode_2)) AS dist FROM " + "places, zips WHERE zips.zipcode = " + ":zipcode_3 ORDER BY dist, places.nm", + ) + + zalias = zips.alias("main_zip") + qlat = select( + [zips.c.latitude], zips.c.zipcode == zalias.c.zipcode + ).as_scalar() + qlng = select( + [zips.c.longitude], zips.c.zipcode == zalias.c.zipcode + ).as_scalar() + q = select( + [ + places.c.id, + places.c.nm, + zalias.c.zipcode, + func.latlondist(qlat, qlng).label("dist"), + ], + order_by=["dist", places.c.nm], + ) + self.assert_compile( + q, + "SELECT places.id, places.nm, " + "main_zip.zipcode, latlondist((SELECT " + "zips.latitude FROM zips WHERE " + "zips.zipcode = main_zip.zipcode), (SELECT " + "zips.longitude FROM zips WHERE " + "zips.zipcode = main_zip.zipcode)) AS dist " + "FROM places, zips AS main_zip ORDER BY " + "dist, places.nm", + ) + + a1 = table2.alias("t2alias") s1 = select([a1.c.otherid], table1.c.myid == a1.c.otherid).as_scalar() j1 = table1.join(table2, table1.c.myid == table2.c.otherid) s2 = select([table1, s1], from_obj=j1) - self.assert_compile(s2, - 'SELECT mytable.myid, mytable.name, ' - 'mytable.description, (SELECT ' - 't2alias.otherid FROM myothertable AS ' - 't2alias WHERE mytable.myid = ' - 't2alias.otherid) AS anon_1 FROM mytable ' - 'JOIN myothertable ON mytable.myid = ' - 'myothertable.otherid') + self.assert_compile( + s2, + "SELECT mytable.myid, mytable.name, " + "mytable.description, (SELECT " + "t2alias.otherid FROM myothertable AS " + "t2alias WHERE mytable.myid = " + "t2alias.otherid) AS anon_1 FROM mytable " + "JOIN myothertable ON mytable.myid = " + "myothertable.otherid", + ) def test_label_comparison_one(self): - x = func.lala(table1.c.myid).label('foo') - self.assert_compile(select([x], x == 5), - 'SELECT lala(mytable.myid) AS foo FROM ' - 'mytable WHERE lala(mytable.myid) = ' - ':param_1') + x = func.lala(table1.c.myid).label("foo") + self.assert_compile( + select([x], x == 5), + "SELECT lala(mytable.myid) AS foo FROM " + "mytable WHERE lala(mytable.myid) = " + ":param_1", + ) def test_label_comparison_two(self): self.assert_compile( - label('bar', column('foo', type_=String)) + 'foo', - 'foo || :param_1') + label("bar", column("foo", type_=String)) + "foo", + "foo || :param_1", + ) def test_order_by_labels_enabled(self): - lab1 = (table1.c.myid + 12).label('foo') - lab2 = func.somefunc(table1.c.name).label('bar') + lab1 = (table1.c.myid + 12).label("foo") + lab2 = func.somefunc(table1.c.name).label("bar") dialect = default.DefaultDialect() - self.assert_compile(select([lab1, lab2]).order_by(lab1, desc(lab2)), - "SELECT mytable.myid + :myid_1 AS foo, " - "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY foo, bar DESC", - dialect=dialect - ) + self.assert_compile( + select([lab1, lab2]).order_by(lab1, 
desc(lab2)), + "SELECT mytable.myid + :myid_1 AS foo, " + "somefunc(mytable.name) AS bar FROM mytable " + "ORDER BY foo, bar DESC", + dialect=dialect, + ) # the function embedded label renders as the function self.assert_compile( @@ -921,16 +1070,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " "ORDER BY hoho(mytable.myid + :myid_1), bar DESC", - dialect=dialect + dialect=dialect, ) # binary expressions render as the expression without labels - self.assert_compile(select([lab1, lab2]).order_by(lab1 + "test"), - "SELECT mytable.myid + :myid_1 AS foo, " - "somefunc(mytable.name) AS bar FROM mytable " - "ORDER BY mytable.myid + :myid_1 + :param_1", - dialect=dialect - ) + self.assert_compile( + select([lab1, lab2]).order_by(lab1 + "test"), + "SELECT mytable.myid + :myid_1 AS foo, " + "somefunc(mytable.name) AS bar FROM mytable " + "ORDER BY mytable.myid + :myid_1 + :param_1", + dialect=dialect, + ) # labels within functions in the columns clause render # with the expression @@ -939,98 +1089,92 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT mytable.myid + :myid_1 AS foo, " "foo(mytable.myid + :myid_1) AS foo_1 FROM mytable " "ORDER BY foo, foo(mytable.myid + :myid_1)", - dialect=dialect + dialect=dialect, ) - lx = (table1.c.myid + table1.c.myid).label('lx') - ly = (func.lower(table1.c.name) + table1.c.description).label('ly') + lx = (table1.c.myid + table1.c.myid).label("lx") + ly = (func.lower(table1.c.name) + table1.c.description).label("ly") self.assert_compile( select([lx, ly]).order_by(lx, ly.desc()), "SELECT mytable.myid + mytable.myid AS lx, " "lower(mytable.name) || mytable.description AS ly " "FROM mytable ORDER BY lx, ly DESC", - dialect=dialect + dialect=dialect, ) # expression isn't actually the same thing (even though label is) self.assert_compile( select([lab1, lab2]).order_by( - table1.c.myid.label('foo'), - desc(table1.c.name.label('bar')) + table1.c.myid.label("foo"), desc(table1.c.name.label("bar")) ), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " "ORDER BY mytable.myid, mytable.name DESC", - dialect=dialect + dialect=dialect, ) # it's also an exact match, not aliased etc. self.assert_compile( select([lab1, lab2]).order_by( - desc(table1.alias().c.name.label('bar')) + desc(table1.alias().c.name.label("bar")) ), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " "ORDER BY mytable_1.name DESC", - dialect=dialect + dialect=dialect, ) # but! 
it's based on lineage lab2_lineage = lab2.element._clone() self.assert_compile( - select([lab1, lab2]).order_by( - desc(lab2_lineage.label('bar')) - ), + select([lab1, lab2]).order_by(desc(lab2_lineage.label("bar"))), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " "ORDER BY bar DESC", - dialect=dialect + dialect=dialect, ) # here, 'name' is implicitly available, but w/ #3882 we don't # want to render a name that isn't specifically a Label elsewhere # in the query self.assert_compile( - select([table1.c.myid]).order_by(table1.c.name.label('name')), - "SELECT mytable.myid FROM mytable ORDER BY mytable.name" + select([table1.c.myid]).order_by(table1.c.name.label("name")), + "SELECT mytable.myid FROM mytable ORDER BY mytable.name", ) # as well as if it doesn't match self.assert_compile( select([table1.c.myid]).order_by( - func.lower(table1.c.name).label('name')), - "SELECT mytable.myid FROM mytable ORDER BY lower(mytable.name)" + func.lower(table1.c.name).label("name") + ), + "SELECT mytable.myid FROM mytable ORDER BY lower(mytable.name)", ) def test_order_by_labels_disabled(self): - lab1 = (table1.c.myid + 12).label('foo') - lab2 = func.somefunc(table1.c.name).label('bar') + lab1 = (table1.c.myid + 12).label("foo") + lab2 = func.somefunc(table1.c.name).label("bar") dialect = default.DefaultDialect() dialect.supports_simple_order_by_label = False self.assert_compile( - select( - [ - lab1, - lab2]).order_by( - lab1, - desc(lab2)), + select([lab1, lab2]).order_by(lab1, desc(lab2)), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " "ORDER BY mytable.myid + :myid_1, somefunc(mytable.name) DESC", - dialect=dialect) + dialect=dialect, + ) self.assert_compile( select([lab1, lab2]).order_by(func.hoho(lab1), desc(lab2)), "SELECT mytable.myid + :myid_1 AS foo, " "somefunc(mytable.name) AS bar FROM mytable " "ORDER BY hoho(mytable.myid + :myid_1), " "somefunc(mytable.name) DESC", - dialect=dialect + dialect=dialect, ) def test_no_group_by_labels(self): - lab1 = (table1.c.myid + 12).label('foo') - lab2 = func.somefunc(table1.c.name).label('bar') + lab1 = (table1.c.myid + 12).label("foo") + lab2 = func.somefunc(table1.c.name).label("bar") dialect = default.DefaultDialect() self.assert_compile( @@ -1038,140 +1182,140 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT mytable.myid + :myid_1 AS foo, somefunc(mytable.name) " "AS bar FROM mytable GROUP BY mytable.myid + :myid_1, " "somefunc(mytable.name)", - dialect=dialect + dialect=dialect, ) def test_conjunctions(self): - a, b, c = text('a'), text('b'), text('c') + a, b, c = text("a"), text("b"), text("c") x = and_(a, b, c) assert isinstance(x.type, Boolean) - assert str(x) == 'a AND b AND c' + assert str(x) == "a AND b AND c" self.assert_compile( - select([x.label('foo')]), - 'SELECT a AND b AND c AS foo' + select([x.label("foo")]), "SELECT a AND b AND c AS foo" ) self.assert_compile( - and_(table1.c.myid == 12, table1.c.name == 'asdf', - table2.c.othername == 'foo', text("sysdate() = today()")), + and_( + table1.c.myid == 12, + table1.c.name == "asdf", + table2.c.othername == "foo", + text("sysdate() = today()"), + ), "mytable.myid = :myid_1 AND mytable.name = :name_1 " "AND myothertable.othername = " - ":othername_1 AND sysdate() = today()" + ":othername_1 AND sysdate() = today()", ) self.assert_compile( and_( table1.c.myid == 12, - or_(table2.c.othername == 'asdf', - table2.c.othername == 'foo', table2.c.otherid == 9), + or_( + table2.c.othername == "asdf", + 
table2.c.othername == "foo", + table2.c.otherid == 9, + ), text("sysdate() = today()"), ), - 'mytable.myid = :myid_1 AND (myothertable.othername = ' - ':othername_1 OR myothertable.othername = :othername_2 OR ' - 'myothertable.otherid = :otherid_1) AND sysdate() = ' - 'today()', - checkparams={'othername_1': 'asdf', 'othername_2': 'foo', - 'otherid_1': 9, 'myid_1': 12} + "mytable.myid = :myid_1 AND (myothertable.othername = " + ":othername_1 OR myothertable.othername = :othername_2 OR " + "myothertable.otherid = :otherid_1) AND sysdate() = " + "today()", + checkparams={ + "othername_1": "asdf", + "othername_2": "foo", + "otherid_1": 9, + "myid_1": 12, + }, ) # test a generator self.assert_compile( and_( - conj for conj in [ - table1.c.myid == 12, - table1.c.name == 'asdf' - ] + conj for conj in [table1.c.myid == 12, table1.c.name == "asdf"] ), - "mytable.myid = :myid_1 AND mytable.name = :name_1" + "mytable.myid = :myid_1 AND mytable.name = :name_1", ) def test_nested_conjunctions_short_circuit(self): """test that empty or_(), and_() conjunctions are collapsed by an enclosing conjunction.""" - t = table('t', column('x')) + t = table("t", column("x")) self.assert_compile( - select([t]).where(and_(t.c.x == 5, - or_(and_(or_(t.c.x == 7))))), - "SELECT t.x FROM t WHERE t.x = :x_1 AND t.x = :x_2" + select([t]).where(and_(t.c.x == 5, or_(and_(or_(t.c.x == 7))))), + "SELECT t.x FROM t WHERE t.x = :x_1 AND t.x = :x_2", ) self.assert_compile( - select([t]).where(and_(or_(t.c.x == 12, - and_(or_(t.c.x == 8))))), - "SELECT t.x FROM t WHERE t.x = :x_1 OR t.x = :x_2" + select([t]).where(and_(or_(t.c.x == 12, and_(or_(t.c.x == 8))))), + "SELECT t.x FROM t WHERE t.x = :x_1 OR t.x = :x_2", ) self.assert_compile( - select([t]). - where( + select([t]).where( and_( or_( or_(t.c.x == 12), - and_( - or_(), - or_(and_(t.c.x == 8)), - and_() - ) + and_(or_(), or_(and_(t.c.x == 8)), and_()), ) ) ), - "SELECT t.x FROM t WHERE t.x = :x_1 OR t.x = :x_2" + "SELECT t.x FROM t WHERE t.x = :x_1 OR t.x = :x_2", ) def test_true_short_circuit(self): - t = table('t', column('x')) + t = table("t", column("x")) self.assert_compile( select([t]).where(true()), "SELECT t.x FROM t WHERE 1 = 1", - dialect=default.DefaultDialect(supports_native_boolean=False) + dialect=default.DefaultDialect(supports_native_boolean=False), ) self.assert_compile( select([t]).where(true()), "SELECT t.x FROM t WHERE true", - dialect=default.DefaultDialect(supports_native_boolean=True) + dialect=default.DefaultDialect(supports_native_boolean=True), ) self.assert_compile( select([t]), "SELECT t.x FROM t", - dialect=default.DefaultDialect(supports_native_boolean=True) + dialect=default.DefaultDialect(supports_native_boolean=True), ) def test_distinct(self): self.assert_compile( select([table1.c.myid.distinct()]), - "SELECT DISTINCT mytable.myid FROM mytable" + "SELECT DISTINCT mytable.myid FROM mytable", ) self.assert_compile( select([distinct(table1.c.myid)]), - "SELECT DISTINCT mytable.myid FROM mytable" + "SELECT DISTINCT mytable.myid FROM mytable", ) self.assert_compile( select([table1.c.myid]).distinct(), - "SELECT DISTINCT mytable.myid FROM mytable" + "SELECT DISTINCT mytable.myid FROM mytable", ) self.assert_compile( select([func.count(table1.c.myid.distinct())]), - "SELECT count(DISTINCT mytable.myid) AS count_1 FROM mytable" + "SELECT count(DISTINCT mytable.myid) AS count_1 FROM mytable", ) self.assert_compile( select([func.count(distinct(table1.c.myid))]), - "SELECT count(DISTINCT mytable.myid) AS count_1 FROM mytable" + "SELECT count(DISTINCT 
mytable.myid) AS count_1 FROM mytable", ) def test_where_empty(self): self.assert_compile( select([table1.c.myid]).where(and_()), - "SELECT mytable.myid FROM mytable" + "SELECT mytable.myid FROM mytable", ) self.assert_compile( select([table1.c.myid]).where(or_()), - "SELECT mytable.myid FROM mytable" + "SELECT mytable.myid FROM mytable", ) def test_multiple_col_binds(self): @@ -1179,133 +1323,165 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): select( [literal_column("*")], or_( - table1.c.myid == 12, table1.c.myid == 'asdf', - table1.c.myid == 'foo') + table1.c.myid == 12, + table1.c.myid == "asdf", + table1.c.myid == "foo", + ), ), "SELECT * FROM mytable WHERE mytable.myid = :myid_1 " - "OR mytable.myid = :myid_2 OR mytable.myid = :myid_3" + "OR mytable.myid = :myid_2 OR mytable.myid = :myid_3", ) def test_order_by_nulls(self): self.assert_compile( - table2.select(order_by=[table2.c.otherid, - table2.c.othername.desc().nullsfirst()]), + table2.select( + order_by=[ + table2.c.otherid, + table2.c.othername.desc().nullsfirst(), + ] + ), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid, " - "myothertable.othername DESC NULLS FIRST" + "myothertable.othername DESC NULLS FIRST", ) self.assert_compile( - table2.select(order_by=[ - table2.c.otherid, table2.c.othername.desc().nullslast()]), + table2.select( + order_by=[ + table2.c.otherid, + table2.c.othername.desc().nullslast(), + ] + ), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid, " - "myothertable.othername DESC NULLS LAST" + "myothertable.othername DESC NULLS LAST", ) self.assert_compile( - table2.select(order_by=[ - table2.c.otherid.nullslast(), - table2.c.othername.desc().nullsfirst()]), + table2.select( + order_by=[ + table2.c.otherid.nullslast(), + table2.c.othername.desc().nullsfirst(), + ] + ), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid NULLS LAST, " - "myothertable.othername DESC NULLS FIRST" + "myothertable.othername DESC NULLS FIRST", ) self.assert_compile( - table2.select(order_by=[table2.c.otherid.nullsfirst(), - table2.c.othername.desc()]), + table2.select( + order_by=[ + table2.c.otherid.nullsfirst(), + table2.c.othername.desc(), + ] + ), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid NULLS FIRST, " - "myothertable.othername DESC" + "myothertable.othername DESC", ) self.assert_compile( - table2.select(order_by=[table2.c.otherid.nullsfirst(), - table2.c.othername.desc().nullslast()]), + table2.select( + order_by=[ + table2.c.otherid.nullsfirst(), + table2.c.othername.desc().nullslast(), + ] + ), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid NULLS FIRST, " - "myothertable.othername DESC NULLS LAST" + "myothertable.othername DESC NULLS LAST", ) def test_orderby_groupby(self): self.assert_compile( - table2.select(order_by=[table2.c.otherid, - asc(table2.c.othername)]), + table2.select( + order_by=[table2.c.otherid, asc(table2.c.othername)] + ), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid, " - "myothertable.othername ASC" + "myothertable.othername ASC", ) self.assert_compile( - table2.select(order_by=[table2.c.otherid, - table2.c.othername.desc()]), + table2.select( + order_by=[table2.c.otherid, table2.c.othername.desc()] + ), "SELECT myothertable.otherid, myothertable.othername 
FROM " "myothertable ORDER BY myothertable.otherid, " - "myothertable.othername DESC" + "myothertable.othername DESC", ) # generative order_by self.assert_compile( - table2.select().order_by(table2.c.otherid). - order_by(table2.c.othername.desc()), + table2.select() + .order_by(table2.c.otherid) + .order_by(table2.c.othername.desc()), "SELECT myothertable.otherid, myothertable.othername FROM " "myothertable ORDER BY myothertable.otherid, " - "myothertable.othername DESC" + "myothertable.othername DESC", ) self.assert_compile( - table2.select().order_by(table2.c.otherid). - order_by(table2.c.othername.desc() - ).order_by(None), + table2.select() + .order_by(table2.c.otherid) + .order_by(table2.c.othername.desc()) + .order_by(None), "SELECT myothertable.otherid, myothertable.othername " - "FROM myothertable" + "FROM myothertable", ) self.assert_compile( select( [table2.c.othername, func.count(table2.c.otherid)], - group_by=[table2.c.othername]), + group_by=[table2.c.othername], + ), "SELECT myothertable.othername, " "count(myothertable.otherid) AS count_1 " - "FROM myothertable GROUP BY myothertable.othername" + "FROM myothertable GROUP BY myothertable.othername", ) # generative group by self.assert_compile( - select([table2.c.othername, func.count(table2.c.otherid)]). - group_by(table2.c.othername), + select( + [table2.c.othername, func.count(table2.c.otherid)] + ).group_by(table2.c.othername), "SELECT myothertable.othername, " "count(myothertable.otherid) AS count_1 " - "FROM myothertable GROUP BY myothertable.othername" + "FROM myothertable GROUP BY myothertable.othername", ) self.assert_compile( - select([table2.c.othername, func.count(table2.c.otherid)]). - group_by(table2.c.othername).group_by(None), + select([table2.c.othername, func.count(table2.c.otherid)]) + .group_by(table2.c.othername) + .group_by(None), "SELECT myothertable.othername, " "count(myothertable.otherid) AS count_1 " - "FROM myothertable" + "FROM myothertable", ) self.assert_compile( - select([table2.c.othername, func.count(table2.c.otherid)], - group_by=[table2.c.othername], - order_by=[table2.c.othername]), + select( + [table2.c.othername, func.count(table2.c.otherid)], + group_by=[table2.c.othername], + order_by=[table2.c.othername], + ), "SELECT myothertable.othername, " "count(myothertable.otherid) AS count_1 " "FROM myothertable " - "GROUP BY myothertable.othername ORDER BY myothertable.othername" + "GROUP BY myothertable.othername ORDER BY myothertable.othername", ) def test_custom_order_by_clause(self): class CustomCompiler(PGCompiler): def order_by_clause(self, select, **kw): - return super(CustomCompiler, self).\ - order_by_clause(select, **kw) + " CUSTOMIZED" + return ( + super(CustomCompiler, self).order_by_clause(select, **kw) + + " CUSTOMIZED" + ) class CustomDialect(PGDialect): - name = 'custom' + name = "custom" statement_compiler = CustomCompiler stmt = select([table1.c.myid]).order_by(table1.c.myid) @@ -1313,17 +1489,19 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): stmt, "SELECT mytable.myid FROM mytable ORDER BY " "mytable.myid CUSTOMIZED", - dialect=CustomDialect() + dialect=CustomDialect(), ) def test_custom_group_by_clause(self): class CustomCompiler(PGCompiler): def group_by_clause(self, select, **kw): - return super(CustomCompiler, self).\ - group_by_clause(select, **kw) + " CUSTOMIZED" + return ( + super(CustomCompiler, self).group_by_clause(select, **kw) + + " CUSTOMIZED" + ) class CustomDialect(PGDialect): - name = 'custom' + name = "custom" statement_compiler = CustomCompiler 
stmt = select([table1.c.myid]).group_by(table1.c.myid) @@ -1331,44 +1509,51 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): stmt, "SELECT mytable.myid FROM mytable GROUP BY " "mytable.myid CUSTOMIZED", - dialect=CustomDialect() + dialect=CustomDialect(), ) def test_for_update(self): self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", + ) # not supported by dialect, should just use update self.assert_compile( table1.select(table1.c.myid == 7).with_for_update(nowait=True), "SELECT mytable.myid, mytable.name, mytable.description " - "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") + "FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", + ) assert_raises_message( exc.ArgumentError, "Unknown for_update argument: 'unknown_mode'", - table1.select, table1.c.myid == 7, for_update='unknown_mode' + table1.select, + table1.c.myid == 7, + for_update="unknown_mode", ) def test_alias(self): # test the alias for a table1. column names stay the same, # table name "changes" to "foo". self.assert_compile( - select([table1.alias('foo')]), - "SELECT foo.myid, foo.name, foo.description FROM mytable AS foo") + select([table1.alias("foo")]), + "SELECT foo.myid, foo.name, foo.description FROM mytable AS foo", + ) for dialect in (oracle.dialect(),): self.assert_compile( - select([table1.alias('foo')]), + select([table1.alias("foo")]), "SELECT foo.myid, foo.name, foo.description FROM mytable foo", - dialect=dialect) + dialect=dialect, + ) self.assert_compile( select([table1.alias()]), "SELECT mytable_1.myid, mytable_1.name, mytable_1.description " - "FROM mytable AS mytable_1") + "FROM mytable AS mytable_1", + ) # create a select for a join of two tables. use_labels # means the column names will have labels tablename_columnname, @@ -1377,12 +1562,13 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): # from the first table1. q = select( [table1, table2.c.otherid], - table1.c.myid == table2.c.otherid, use_labels=True + table1.c.myid == table2.c.otherid, + use_labels=True, ) # make an alias of the "selectable". column names # stay the same (i.e. the labels), table name "changes" to "t2view". - a = alias(q, 't2view') + a = alias(q, "t2view") # select from that alias, also using labels. two levels of labels # should produce two underscores. @@ -1401,26 +1587,26 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "myothertable_otherid FROM mytable, myothertable " "WHERE mytable.myid = " "myothertable.otherid) AS t2view " - "WHERE t2view.mytable_myid = :mytable_myid_1" + "WHERE t2view.mytable_myid = :mytable_myid_1", ) def test_prefix(self): self.assert_compile( - table1.select().prefix_with("SQL_CALC_FOUND_ROWS"). - prefix_with("SQL_SOME_WEIRD_MYSQL_THING"), + table1.select() + .prefix_with("SQL_CALC_FOUND_ROWS") + .prefix_with("SQL_SOME_WEIRD_MYSQL_THING"), "SELECT SQL_CALC_FOUND_ROWS SQL_SOME_WEIRD_MYSQL_THING " - "mytable.myid, mytable.name, mytable.description FROM mytable" + "mytable.myid, mytable.name, mytable.description FROM mytable", ) def test_prefix_dialect_specific(self): self.assert_compile( - table1.select().prefix_with("SQL_CALC_FOUND_ROWS", - dialect='sqlite'). 
- prefix_with("SQL_SOME_WEIRD_MYSQL_THING", - dialect='mysql'), + table1.select() + .prefix_with("SQL_CALC_FOUND_ROWS", dialect="sqlite") + .prefix_with("SQL_SOME_WEIRD_MYSQL_THING", dialect="mysql"), "SELECT SQL_SOME_WEIRD_MYSQL_THING " "mytable.myid, mytable.name, mytable.description FROM mytable", - dialect=mysql.dialect() + dialect=mysql.dialect(), ) def test_render_binds_as_literal(self): @@ -1431,140 +1617,149 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): class Compiler(dialect.statement_compiler): ansi_bind_rules = True + dialect.statement_compiler = Compiler self.assert_compile( select([literal("someliteral")]), "SELECT 'someliteral' AS anon_1", - dialect=dialect + dialect=dialect, ) self.assert_compile( select([table1.c.myid + 3]), "SELECT mytable.myid + 3 AS anon_1 FROM mytable", - dialect=dialect + dialect=dialect, ) self.assert_compile( select([table1.c.myid.in_([4, 5, 6])]), "SELECT mytable.myid IN (4, 5, 6) AS anon_1 FROM mytable", - dialect=dialect + dialect=dialect, ) self.assert_compile( select([func.mod(table1.c.myid, 5)]), "SELECT mod(mytable.myid, 5) AS mod_1 FROM mytable", - dialect=dialect + dialect=dialect, ) self.assert_compile( select([literal("foo").in_([])]), "SELECT 1 != 1 AS anon_1", - dialect=dialect + dialect=dialect, ) self.assert_compile( select([literal(util.b("foo"))]), "SELECT 'foo' AS anon_1", - dialect=dialect + dialect=dialect, ) # test callable self.assert_compile( select([table1.c.myid == bindparam("foo", callable_=lambda: 5)]), "SELECT mytable.myid = 5 AS anon_1 FROM mytable", - dialect=dialect + dialect=dialect, ) - empty_in_dialect = default.DefaultDialect(empty_in_strategy='dynamic') + empty_in_dialect = default.DefaultDialect(empty_in_strategy="dynamic") empty_in_dialect.statement_compiler = Compiler assert_raises_message( exc.CompileError, "Bind parameter 'foo' without a " "renderable value not allowed here.", - bindparam("foo").in_( - []).compile, - dialect=empty_in_dialect) + bindparam("foo").in_([]).compile, + dialect=empty_in_dialect, + ) def test_collate(self): # columns clause self.assert_compile( - select([column('x').collate('bar')]), - "SELECT x COLLATE bar AS anon_1" + select([column("x").collate("bar")]), + "SELECT x COLLATE bar AS anon_1", ) # WHERE clause self.assert_compile( - select([column('x')]).where(column('x').collate('bar') == 'foo'), - "SELECT x WHERE (x COLLATE bar) = :param_1" + select([column("x")]).where(column("x").collate("bar") == "foo"), + "SELECT x WHERE (x COLLATE bar) = :param_1", ) # ORDER BY clause self.assert_compile( - select([column('x')]).order_by(column('x').collate('bar')), - "SELECT x ORDER BY x COLLATE bar" + select([column("x")]).order_by(column("x").collate("bar")), + "SELECT x ORDER BY x COLLATE bar", ) def test_literal(self): - self.assert_compile(select([literal('foo')]), - "SELECT :param_1 AS anon_1") + self.assert_compile( + select([literal("foo")]), "SELECT :param_1 AS anon_1" + ) self.assert_compile( - select( - [ - literal("foo") + - literal("bar")], - from_obj=[table1]), - "SELECT :param_1 || :param_2 AS anon_1 FROM mytable") + select([literal("foo") + literal("bar")], from_obj=[table1]), + "SELECT :param_1 || :param_2 AS anon_1 FROM mytable", + ) def test_calculated_columns(self): - value_tbl = table('values', - column('id', Integer), - column('val1', Float), - column('val2', Float), - ) + value_tbl = table( + "values", + column("id", Integer), + column("val1", Float), + column("val2", Float), + ) self.assert_compile( - select([value_tbl.c.id, (value_tbl.c.val2 - - 
value_tbl.c.val1) / value_tbl.c.val1]), + select( + [ + value_tbl.c.id, + (value_tbl.c.val2 - value_tbl.c.val1) / value_tbl.c.val1, + ] + ), "SELECT values.id, (values.val2 - values.val1) " - "/ values.val1 AS anon_1 FROM values" + "/ values.val1 AS anon_1 FROM values", ) self.assert_compile( - select([ - value_tbl.c.id], - (value_tbl.c.val2 - value_tbl.c.val1) / - value_tbl.c.val1 > 2.0), + select( + [value_tbl.c.id], + (value_tbl.c.val2 - value_tbl.c.val1) / value_tbl.c.val1 > 2.0, + ), "SELECT values.id FROM values WHERE " - "(values.val2 - values.val1) / values.val1 > :param_1" + "(values.val2 - values.val1) / values.val1 > :param_1", ) self.assert_compile( - select([value_tbl.c.id], value_tbl.c.val1 / - (value_tbl.c.val2 - value_tbl.c.val1) / - value_tbl.c.val1 > 2.0), + select( + [value_tbl.c.id], + value_tbl.c.val1 + / (value_tbl.c.val2 - value_tbl.c.val1) + / value_tbl.c.val1 + > 2.0, + ), "SELECT values.id FROM values WHERE " "(values.val1 / (values.val2 - values.val1)) " - "/ values.val1 > :param_1" + "/ values.val1 > :param_1", ) def test_percent_chars(self): - t = table("table%name", - column("percent%"), - column("%(oneofthese)s"), - column("spaces % more spaces"), - ) + t = table( + "table%name", + column("percent%"), + column("%(oneofthese)s"), + column("spaces % more spaces"), + ) self.assert_compile( t.select(use_labels=True), - '''SELECT "table%name"."percent%" AS "table%name_percent%", ''' - '''"table%name"."%(oneofthese)s" AS ''' - '''"table%name_%(oneofthese)s", ''' - '''"table%name"."spaces % more spaces" AS ''' - '''"table%name_spaces % ''' - '''more spaces" FROM "table%name"''' + """SELECT "table%name"."percent%" AS "table%name_percent%", """ + """"table%name"."%(oneofthese)s" AS """ + """"table%name_%(oneofthese)s", """ + """"table%name"."spaces % more spaces" AS """ + """"table%name_spaces % """ + '''more spaces" FROM "table%name"''', ) def test_joins(self): @@ -1572,22 +1767,31 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): join(table2, table1, table1.c.myid == table2.c.otherid).select(), "SELECT myothertable.otherid, myothertable.othername, " "mytable.myid, mytable.name, mytable.description FROM " - "myothertable JOIN mytable ON mytable.myid = myothertable.otherid" + "myothertable JOIN mytable ON mytable.myid = myothertable.otherid", ) self.assert_compile( select( [table1], - from_obj=[join(table1, table2, table1.c.myid - == table2.c.otherid)] + from_obj=[ + join(table1, table2, table1.c.myid == table2.c.otherid) + ], ), "SELECT mytable.myid, mytable.name, mytable.description FROM " - "mytable JOIN myothertable ON mytable.myid = myothertable.otherid") + "mytable JOIN myothertable ON mytable.myid = myothertable.otherid", + ) self.assert_compile( select( - [join(join(table1, table2, table1.c.myid == table2.c.otherid), - table3, table1.c.myid == table3.c.userid)] + [ + join( + join( + table1, table2, table1.c.myid == table2.c.otherid + ), + table3, + table1.c.myid == table3.c.userid, + ) + ] ), "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, myothertable.othername, " @@ -1595,27 +1799,29 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "thirdtable.otherstuff FROM mytable JOIN myothertable " "ON mytable.myid =" " myothertable.otherid JOIN thirdtable ON " - "mytable.myid = thirdtable.userid" + "mytable.myid = thirdtable.userid", ) self.assert_compile( - join(users, addresses, users.c.user_id == - addresses.c.user_id).select(), + join( + users, addresses, users.c.user_id == addresses.c.user_id + ).select(), 
"SELECT users.user_id, users.user_name, users.password, " "addresses.address_id, addresses.user_id, addresses.street, " "addresses.city, addresses.state, addresses.zip " "FROM users JOIN addresses " - "ON users.user_id = addresses.user_id" + "ON users.user_id = addresses.user_id", ) self.assert_compile( - select([table1, table2, table3], - - from_obj=[join(table1, table2, - table1.c.myid == table2.c.otherid). - outerjoin(table3, - table1.c.myid == table3.c.userid)] - ), + select( + [table1, table2, table3], + from_obj=[ + join( + table1, table2, table1.c.myid == table2.c.otherid + ).outerjoin(table3, table1.c.myid == table3.c.userid) + ], + ), "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, myothertable.othername, " "thirdtable.userid," @@ -1623,15 +1829,21 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "JOIN myothertable ON mytable.myid " "= myothertable.otherid LEFT OUTER JOIN thirdtable " "ON mytable.myid =" - " thirdtable.userid" + " thirdtable.userid", ) self.assert_compile( - select([table1, table2, table3], - from_obj=[outerjoin(table1, - join(table2, table3, table2.c.otherid - == table3.c.userid), - table1.c.myid == table2.c.otherid)] - ), + select( + [table1, table2, table3], + from_obj=[ + outerjoin( + table1, + join( + table2, table3, table2.c.otherid == table3.c.userid + ), + table1.c.myid == table2.c.otherid, + ) + ], + ), "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, myothertable.othername, " "thirdtable.userid," @@ -1639,47 +1851,49 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "(myothertable " "JOIN thirdtable ON myothertable.otherid = " "thirdtable.userid) ON " - "mytable.myid = myothertable.otherid" + "mytable.myid = myothertable.otherid", ) query = select( [table1, table2], or_( - table1.c.name == 'fred', + table1.c.name == "fred", table1.c.myid == 10, - table2.c.othername != 'jack', - text("EXISTS (select yay from foo where boo = lar)") + table2.c.othername != "jack", + text("EXISTS (select yay from foo where boo = lar)"), ), - from_obj=[outerjoin(table1, table2, - table1.c.myid == table2.c.otherid)] + from_obj=[ + outerjoin(table1, table2, table1.c.myid == table2.c.otherid) + ], ) self.assert_compile( - query, "SELECT mytable.myid, mytable.name, mytable.description, " + query, + "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, myothertable.othername " "FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = " "myothertable.otherid WHERE mytable.name = :name_1 OR " "mytable.myid = :myid_1 OR myothertable.othername != :othername_1 " - "OR EXISTS (select yay from foo where boo = lar)", ) + "OR EXISTS (select yay from foo where boo = lar)", + ) def test_full_outer_join(self): for spec in [ join(table1, table2, table1.c.myid == table2.c.otherid, full=True), outerjoin( - table1, table2, - table1.c.myid == table2.c.otherid, full=True), - table1.join( - table2, - table1.c.myid == table2.c.otherid, full=True), + table1, table2, table1.c.myid == table2.c.otherid, full=True + ), + table1.join(table2, table1.c.myid == table2.c.otherid, full=True), table1.outerjoin( - table2, - table1.c.myid == table2.c.otherid, full=True), + table2, table1.c.myid == table2.c.otherid, full=True + ), ]: stmt = select([table1]).select_from(spec) self.assert_compile( stmt, "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable FULL OUTER JOIN myothertable " - "ON mytable.myid = myothertable.otherid") + "ON mytable.myid = myothertable.otherid", + ) def 
test_compound_selects(self): assert_raises_message( @@ -1687,7 +1901,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "All selectables passed to CompoundSelect " "must have identical numbers of columns; " "select #1 has 2 columns, select #2 has 3", - union, table3.select(), table1.select() + union, + table3.select(), + table1.select(), ) x = union( @@ -1697,36 +1913,39 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( - x, "SELECT mytable.myid, mytable.name, " + x, + "SELECT mytable.myid, mytable.name, " "mytable.description " "FROM mytable WHERE " "mytable.myid = :myid_1 UNION " "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = :myid_2 " - "ORDER BY mytable.myid") - - x = union( - select([table1]), - select([table1]) + "ORDER BY mytable.myid", ) + + x = union(select([table1]), select([table1])) x = union(x, select([table1])) self.assert_compile( - x, "(SELECT mytable.myid, mytable.name, mytable.description " + x, + "(SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable UNION SELECT mytable.myid, mytable.name, " "mytable.description FROM mytable) UNION SELECT mytable.myid," - " mytable.name, mytable.description FROM mytable") + " mytable.name, mytable.description FROM mytable", + ) u1 = union( select([table1.c.myid, table1.c.name]), select([table2]), - select([table3]) + select([table3]), ) self.assert_compile( - u1, "SELECT mytable.myid, mytable.name " + u1, + "SELECT mytable.myid, mytable.name " "FROM mytable UNION SELECT myothertable.otherid, " "myothertable.othername FROM myothertable " "UNION SELECT thirdtable.userid, thirdtable.otherstuff " - "FROM thirdtable") + "FROM thirdtable", + ) assert u1.corresponding_column(table2.c.otherid) is u1.c.myid @@ -1734,25 +1953,30 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): union( select([table1.c.myid, table1.c.name]), select([table2]), - order_by=['myid'], + order_by=["myid"], offset=10, - limit=5 + limit=5, ), "SELECT mytable.myid, mytable.name " "FROM mytable UNION SELECT myothertable.otherid, " "myothertable.othername " "FROM myothertable ORDER BY myid " # note table name is omitted "LIMIT :param_1 OFFSET :param_2", - {'param_1': 5, 'param_2': 10} + {"param_1": 5, "param_2": 10}, ) self.assert_compile( union( - select([table1.c.myid, table1.c.name, - func.max(table1.c.description)], - table1.c.name == 'name2', - group_by=[table1.c.myid, table1.c.name]), - table1.select(table1.c.name == 'name1') + select( + [ + table1.c.myid, + table1.c.name, + func.max(table1.c.description), + ], + table1.c.name == "name2", + group_by=[table1.c.myid, table1.c.name], + ), + table1.select(table1.c.name == "name1"), ), "SELECT mytable.myid, mytable.name, " "max(mytable.description) AS max_1 " @@ -1760,183 +1984,155 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "GROUP BY mytable.myid, " "mytable.name UNION SELECT mytable.myid, mytable.name, " "mytable.description " - "FROM mytable WHERE mytable.name = :name_2" + "FROM mytable WHERE mytable.name = :name_2", ) self.assert_compile( union( - select([literal(100).label('value')]), - select([literal(200).label('value')]) + select([literal(100).label("value")]), + select([literal(200).label("value")]), ), - "SELECT :param_1 AS value UNION SELECT :param_2 AS value" + "SELECT :param_1 AS value UNION SELECT :param_2 AS value", ) self.assert_compile( union_all( select([table1.c.myid]), - union( - select([table2.c.otherid]), - select([table3.c.userid]), - ) + 
union(select([table2.c.otherid]), select([table3.c.userid])), ), - "SELECT mytable.myid FROM mytable UNION ALL " "(SELECT myothertable.otherid FROM myothertable UNION " - "SELECT thirdtable.userid FROM thirdtable)" + "SELECT thirdtable.userid FROM thirdtable)", ) - s = select([column('foo'), column('bar')]) + s = select([column("foo"), column("bar")]) self.assert_compile( - union( - s.order_by("foo"), - s.order_by("bar")), + union(s.order_by("foo"), s.order_by("bar")), "(SELECT foo, bar ORDER BY foo) UNION " - "(SELECT foo, bar ORDER BY bar)") + "(SELECT foo, bar ORDER BY bar)", + ) self.assert_compile( - union(s.order_by("foo").self_group(), - s.order_by("bar").limit(10).self_group()), + union( + s.order_by("foo").self_group(), + s.order_by("bar").limit(10).self_group(), + ), "(SELECT foo, bar ORDER BY foo) UNION (SELECT foo, " "bar ORDER BY bar LIMIT :param_1)", - {'param_1': 10} - + {"param_1": 10}, ) def test_compound_grouping(self): - s = select([column('foo'), column('bar')]).select_from(text('bat')) + s = select([column("foo"), column("bar")]).select_from(text("bat")) self.assert_compile( union(union(union(s, s), s), s), "((SELECT foo, bar FROM bat UNION SELECT foo, bar FROM bat) " - "UNION SELECT foo, bar FROM bat) UNION SELECT foo, bar FROM bat" + "UNION SELECT foo, bar FROM bat) UNION SELECT foo, bar FROM bat", ) self.assert_compile( union(s, s, s, s), "SELECT foo, bar FROM bat UNION SELECT foo, bar " "FROM bat UNION SELECT foo, bar FROM bat " - "UNION SELECT foo, bar FROM bat" + "UNION SELECT foo, bar FROM bat", ) self.assert_compile( union(s, union(s, union(s, s))), "SELECT foo, bar FROM bat UNION (SELECT foo, bar FROM bat " "UNION (SELECT foo, bar FROM bat " - "UNION SELECT foo, bar FROM bat))" + "UNION SELECT foo, bar FROM bat))", ) self.assert_compile( select([s.alias()]), - 'SELECT anon_1.foo, anon_1.bar FROM ' - '(SELECT foo, bar FROM bat) AS anon_1' + "SELECT anon_1.foo, anon_1.bar FROM " + "(SELECT foo, bar FROM bat) AS anon_1", ) self.assert_compile( select([union(s, s).alias()]), - 'SELECT anon_1.foo, anon_1.bar FROM ' - '(SELECT foo, bar FROM bat UNION ' - 'SELECT foo, bar FROM bat) AS anon_1' + "SELECT anon_1.foo, anon_1.bar FROM " + "(SELECT foo, bar FROM bat UNION " + "SELECT foo, bar FROM bat) AS anon_1", ) self.assert_compile( select([except_(s, s).alias()]), - 'SELECT anon_1.foo, anon_1.bar FROM ' - '(SELECT foo, bar FROM bat EXCEPT ' - 'SELECT foo, bar FROM bat) AS anon_1' + "SELECT anon_1.foo, anon_1.bar FROM " + "(SELECT foo, bar FROM bat EXCEPT " + "SELECT foo, bar FROM bat) AS anon_1", ) # this query sqlite specifically chokes on self.assert_compile( - union( - except_(s, s), - s - ), + union(except_(s, s), s), "(SELECT foo, bar FROM bat EXCEPT SELECT foo, bar FROM bat) " - "UNION SELECT foo, bar FROM bat" + "UNION SELECT foo, bar FROM bat", ) self.assert_compile( - union( - s, - except_(s, s), - ), + union(s, except_(s, s)), "SELECT foo, bar FROM bat " - "UNION (SELECT foo, bar FROM bat EXCEPT SELECT foo, bar FROM bat)" + "UNION (SELECT foo, bar FROM bat EXCEPT SELECT foo, bar FROM bat)", ) # this solves it self.assert_compile( - union( - except_(s, s).alias().select(), - s - ), + union(except_(s, s).alias().select(), s), "SELECT anon_1.foo, anon_1.bar FROM " "(SELECT foo, bar FROM bat EXCEPT " "SELECT foo, bar FROM bat) AS anon_1 " - "UNION SELECT foo, bar FROM bat" + "UNION SELECT foo, bar FROM bat", ) self.assert_compile( - except_( - union(s, s), - union(s, s) - ), + except_(union(s, s), union(s, s)), "(SELECT foo, bar FROM bat UNION SELECT foo, bar 
FROM bat) " - "EXCEPT (SELECT foo, bar FROM bat UNION SELECT foo, bar FROM bat)" + "EXCEPT (SELECT foo, bar FROM bat UNION SELECT foo, bar FROM bat)", ) s2 = union(s, s) s3 = union(s2, s2) - self.assert_compile(s3, "(SELECT foo, bar FROM bat " - "UNION SELECT foo, bar FROM bat) " - "UNION (SELECT foo, bar FROM bat " - "UNION SELECT foo, bar FROM bat)") + self.assert_compile( + s3, + "(SELECT foo, bar FROM bat " + "UNION SELECT foo, bar FROM bat) " + "UNION (SELECT foo, bar FROM bat " + "UNION SELECT foo, bar FROM bat)", + ) self.assert_compile( - union( - intersect(s, s), - intersect(s, s) - ), + union(intersect(s, s), intersect(s, s)), "(SELECT foo, bar FROM bat INTERSECT SELECT foo, bar FROM bat) " "UNION (SELECT foo, bar FROM bat INTERSECT " - "SELECT foo, bar FROM bat)" + "SELECT foo, bar FROM bat)", ) # tests for [ticket:2528] # sqlite hates all of these. self.assert_compile( - union( - s.limit(1), - s.offset(2) - ), + union(s.limit(1), s.offset(2)), "(SELECT foo, bar FROM bat LIMIT :param_1) " - "UNION (SELECT foo, bar FROM bat LIMIT -1 OFFSET :param_2)" + "UNION (SELECT foo, bar FROM bat LIMIT -1 OFFSET :param_2)", ) self.assert_compile( - union( - s.order_by(column('bar')), - s.offset(2) - ), + union(s.order_by(column("bar")), s.offset(2)), "(SELECT foo, bar FROM bat ORDER BY bar) " - "UNION (SELECT foo, bar FROM bat LIMIT -1 OFFSET :param_1)" + "UNION (SELECT foo, bar FROM bat LIMIT -1 OFFSET :param_1)", ) self.assert_compile( - union( - s.limit(1).alias('a'), - s.limit(2).alias('b') - ), + union(s.limit(1).alias("a"), s.limit(2).alias("b")), "(SELECT foo, bar FROM bat LIMIT :param_1) " - "UNION (SELECT foo, bar FROM bat LIMIT :param_2)" + "UNION (SELECT foo, bar FROM bat LIMIT :param_2)", ) self.assert_compile( - union( - s.limit(1).self_group(), - s.limit(2).self_group() - ), + union(s.limit(1).self_group(), s.limit(2).self_group()), "(SELECT foo, bar FROM bat LIMIT :param_1) " - "UNION (SELECT foo, bar FROM bat LIMIT :param_2)" + "UNION (SELECT foo, bar FROM bat LIMIT :param_2)", ) self.assert_compile( @@ -1944,22 +2140,19 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.foo, anon_1.bar FROM " "((SELECT foo, bar FROM bat LIMIT :param_1) " "UNION (SELECT foo, bar FROM bat LIMIT :param_2 OFFSET :param_3)) " - "AS anon_1" + "AS anon_1", ) # this version works for SQLite self.assert_compile( - union( - s.limit(1).alias().select(), - s.offset(2).alias().select(), - ), + union(s.limit(1).alias().select(), s.offset(2).alias().select()), "SELECT anon_1.foo, anon_1.bar " "FROM (SELECT foo, bar FROM bat" " LIMIT :param_1) AS anon_1 " "UNION SELECT anon_2.foo, anon_2.bar " "FROM (SELECT foo, bar " "FROM bat" - " LIMIT -1 OFFSET :param_2) AS anon_2" + " LIMIT -1 OFFSET :param_2) AS anon_2", ) def test_binds(self): @@ -1971,15 +2164,16 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): expected_default_params_list, test_param_dict, expected_test_params_dict, - expected_test_params_list + expected_test_params_list, ) in [ ( select( [table1, table2], and_( table1.c.myid == table2.c.otherid, - table1.c.name == bindparam('mytablename') - )), + table1.c.name == bindparam("mytablename"), + ), + ), "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, myothertable.othername FROM mytable, " "myothertable WHERE mytable.myid = myothertable.otherid " @@ -1988,55 +2182,80 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "myothertable.otherid, myothertable.othername FROM mytable, " "myothertable WHERE mytable.myid = 
myothertable.otherid AND " "mytable.name = ?", - {'mytablename': None}, [None], - {'mytablename': 5}, {'mytablename': 5}, [5] + {"mytablename": None}, + [None], + {"mytablename": 5}, + {"mytablename": 5}, + [5], ), ( - select([table1], or_(table1.c.myid == bindparam('myid'), - table2.c.otherid == bindparam('myid'))), + select( + [table1], + or_( + table1.c.myid == bindparam("myid"), + table2.c.otherid == bindparam("myid"), + ), + ), "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable, myothertable WHERE mytable.myid = :myid " "OR myothertable.otherid = :myid", "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable, myothertable WHERE mytable.myid = ? " "OR myothertable.otherid = ?", - {'myid': None}, [None, None], - {'myid': 5}, {'myid': 5}, [5, 5] + {"myid": None}, + [None, None], + {"myid": 5}, + {"myid": 5}, + [5, 5], ), ( - text("SELECT mytable.myid, mytable.name, " - "mytable.description FROM " - "mytable, myothertable WHERE mytable.myid = :myid OR " - "myothertable.otherid = :myid"), + text( + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM " + "mytable, myothertable WHERE mytable.myid = :myid OR " + "myothertable.otherid = :myid" + ), "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = :myid OR " "myothertable.otherid = :myid", "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = ? OR " "myothertable.otherid = ?", - {'myid': None}, [None, None], - {'myid': 5}, {'myid': 5}, [5, 5] + {"myid": None}, + [None, None], + {"myid": 5}, + {"myid": 5}, + [5, 5], ), ( - select([table1], or_(table1.c.myid == - bindparam('myid', unique=True), - table2.c.otherid == - bindparam('myid', unique=True))), + select( + [table1], + or_( + table1.c.myid == bindparam("myid", unique=True), + table2.c.otherid == bindparam("myid", unique=True), + ), + ), "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = " ":myid_1 OR myothertable.otherid = :myid_2", "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = ? " "OR myothertable.otherid = ?", - {'myid_1': None, 'myid_2': None}, [None, None], - {'myid_1': 5, 'myid_2': 6}, {'myid_1': 5, 'myid_2': 6}, [5, 6] + {"myid_1": None, "myid_2": None}, + [None, None], + {"myid_1": 5, "myid_2": 6}, + {"myid_1": 5, "myid_2": 6}, + [5, 6], ), ( - bindparam('test', type_=String, required=False) + text("'hi'"), + bindparam("test", type_=String, required=False) + text("'hi'"), ":test || 'hi'", "? || 'hi'", - {'test': None}, [None], - {}, {'test': None}, [None] + {"test": None}, + [None], + {}, + {"test": None}, + [None], ), ( # testing select.params() here - bindparam() objects @@ -2044,89 +2263,125 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): select( [table1], or_( - table1.c.myid == bindparam('myid'), - table2.c.otherid == bindparam('myotherid') - )).params({'myid': 8, 'myotherid': 7}), + table1.c.myid == bindparam("myid"), + table2.c.otherid == bindparam("myotherid"), + ), + ).params({"myid": 8, "myotherid": 7}), "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = " ":myid OR myothertable.otherid = :myotherid", "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = " "? 
OR myothertable.otherid = ?", - {'myid': 8, 'myotherid': 7}, [8, 7], - {'myid': 5}, {'myid': 5, 'myotherid': 7}, [5, 7] + {"myid": 8, "myotherid": 7}, + [8, 7], + {"myid": 5}, + {"myid": 5, "myotherid": 7}, + [5, 7], ), ( - select([table1], or_(table1.c.myid == - bindparam('myid', value=7, unique=True), - table2.c.otherid == - bindparam('myid', value=8, unique=True))), + select( + [table1], + or_( + table1.c.myid + == bindparam("myid", value=7, unique=True), + table2.c.otherid + == bindparam("myid", value=8, unique=True), + ), + ), "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = " ":myid_1 OR myothertable.otherid = :myid_2", "SELECT mytable.myid, mytable.name, mytable.description FROM " "mytable, myothertable WHERE mytable.myid = " "? OR myothertable.otherid = ?", - {'myid_1': 7, 'myid_2': 8}, [7, 8], - {'myid_1': 5, 'myid_2': 6}, {'myid_1': 5, 'myid_2': 6}, [5, 6] + {"myid_1": 7, "myid_2": 8}, + [7, 8], + {"myid_1": 5, "myid_2": 6}, + {"myid_1": 5, "myid_2": 6}, + [5, 6], ), ]: - self.assert_compile(stmt, expected_named_stmt, - params=expected_default_params_dict) - self.assert_compile(stmt, expected_positional_stmt, - dialect=sqlite.dialect()) + self.assert_compile( + stmt, expected_named_stmt, params=expected_default_params_dict + ) + self.assert_compile( + stmt, expected_positional_stmt, dialect=sqlite.dialect() + ) nonpositional = stmt.compile() positional = stmt.compile(dialect=sqlite.dialect()) pp = positional.params - eq_([pp[k] for k in positional.positiontup], - expected_default_params_list) + eq_( + [pp[k] for k in positional.positiontup], + expected_default_params_list, + ) - eq_(nonpositional.construct_params(test_param_dict), - expected_test_params_dict) + eq_( + nonpositional.construct_params(test_param_dict), + expected_test_params_dict, + ) pp = positional.construct_params(test_param_dict) eq_( [pp[k] for k in positional.positiontup], - expected_test_params_list + expected_test_params_list, ) # check that params() doesn't modify original statement - s = select([table1], or_(table1.c.myid == bindparam('myid'), - table2.c.otherid == - bindparam('myotherid'))) - s2 = s.params({'myid': 8, 'myotherid': 7}) - s3 = s2.params({'myid': 9}) - assert s.compile().params == {'myid': None, 'myotherid': None} - assert s2.compile().params == {'myid': 8, 'myotherid': 7} - assert s3.compile().params == {'myid': 9, 'myotherid': 7} + s = select( + [table1], + or_( + table1.c.myid == bindparam("myid"), + table2.c.otherid == bindparam("myotherid"), + ), + ) + s2 = s.params({"myid": 8, "myotherid": 7}) + s3 = s2.params({"myid": 9}) + assert s.compile().params == {"myid": None, "myotherid": None} + assert s2.compile().params == {"myid": 8, "myotherid": 7} + assert s3.compile().params == {"myid": 9, "myotherid": 7} # test using same 'unique' param object twice in one compile s = select([table1.c.myid]).where(table1.c.myid == 12).as_scalar() s2 = select([table1, s], table1.c.myid == s) self.assert_compile( - s2, "SELECT mytable.myid, mytable.name, mytable.description, " + s2, + "SELECT mytable.myid, mytable.name, mytable.description, " "(SELECT mytable.myid FROM mytable WHERE mytable.myid = " ":myid_1) AS anon_1 FROM mytable WHERE mytable.myid = " - "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)") + "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)", + ) positional = s2.compile(dialect=sqlite.dialect()) pp = positional.params assert [pp[k] for k in positional.positiontup] == [12, 12] # check that conflicts 
with "unique" params are caught - s = select([table1], or_(table1.c.myid == 7, - table1.c.myid == bindparam('myid_1'))) - assert_raises_message(exc.CompileError, - "conflicts with unique bind parameter " - "of the same name", - str, s) - - s = select([table1], or_(table1.c.myid == 7, table1.c.myid == 8, - table1.c.myid == bindparam('myid_1'))) - assert_raises_message(exc.CompileError, - "conflicts with unique bind parameter " - "of the same name", - str, s) + s = select( + [table1], + or_(table1.c.myid == 7, table1.c.myid == bindparam("myid_1")), + ) + assert_raises_message( + exc.CompileError, + "conflicts with unique bind parameter " "of the same name", + str, + s, + ) + + s = select( + [table1], + or_( + table1.c.myid == 7, + table1.c.myid == 8, + table1.c.myid == bindparam("myid_1"), + ), + ) + assert_raises_message( + exc.CompileError, + "conflicts with unique bind parameter " "of the same name", + str, + s, + ) def _test_binds_no_hash_collision(self): """test that construct_params doesn't corrupt dict @@ -2134,84 +2389,85 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): total_params = 100000 - in_clause = [':in%d' % i for i in range(total_params)] - params = dict(('in%d' % i, i) for i in range(total_params)) - t = text('text clause %s' % ', '.join(in_clause)) + in_clause = [":in%d" % i for i in range(total_params)] + params = dict(("in%d" % i, i) for i in range(total_params)) + t = text("text clause %s" % ", ".join(in_clause)) eq_(len(t.bindparams), total_params) c = t.compile() pp = c.construct_params(params) - eq_(len(set(pp)), total_params, '%s %s' % (len(set(pp)), len(pp))) + eq_(len(set(pp)), total_params, "%s %s" % (len(set(pp)), len(pp))) eq_(len(set(pp.values())), total_params) def test_bind_as_col(self): - t = table('foo', column('id')) + t = table("foo", column("id")) - s = select([t, literal('lala').label('hoho')]) + s = select([t, literal("lala").label("hoho")]) self.assert_compile(s, "SELECT foo.id, :param_1 AS hoho FROM foo") assert [str(c) for c in s.c] == ["id", "hoho"] def test_bind_callable(self): - expr = column('x') == bindparam("key", callable_=lambda: 12) - self.assert_compile( - expr, - "x = :key", - {'x': 12} - ) + expr = column("x") == bindparam("key", callable_=lambda: 12) + self.assert_compile(expr, "x = :key", {"x": 12}) def test_bind_params_missing(self): assert_raises_message( exc.InvalidRequestError, r"A value is required for bind parameter 'x'", - select( - [table1]).where( + select([table1]) + .where( and_( table1.c.myid == bindparam("x", required=True), - table1.c.name == bindparam("y", required=True) + table1.c.name == bindparam("y", required=True), ) - ).compile().construct_params, - params=dict(y=5) + ) + .compile() + .construct_params, + params=dict(y=5), ) assert_raises_message( exc.InvalidRequestError, r"A value is required for bind parameter 'x'", - select( - [table1]).where( - table1.c.myid == bindparam( - "x", - required=True)).compile().construct_params) + select([table1]) + .where(table1.c.myid == bindparam("x", required=True)) + .compile() + .construct_params, + ) assert_raises_message( exc.InvalidRequestError, r"A value is required for bind parameter 'x', " "in parameter group 2", - select( - [table1]).where( + select([table1]) + .where( and_( table1.c.myid == bindparam("x", required=True), - table1.c.name == bindparam("y", required=True) + table1.c.name == bindparam("y", required=True), ) - ).compile().construct_params, - params=dict(y=5), _group_number=2) + ) + .compile() + .construct_params, + params=dict(y=5), + 
_group_number=2, + ) assert_raises_message( exc.InvalidRequestError, r"A value is required for bind parameter 'x', " "in parameter group 2", - select( - [table1]).where( - table1.c.myid == bindparam( - "x", - required=True)).compile().construct_params, - _group_number=2) + select([table1]) + .where(table1.c.myid == bindparam("x", required=True)) + .compile() + .construct_params, + _group_number=2, + ) def test_tuple(self): self.assert_compile( - tuple_(table1.c.myid, table1.c.name).in_( - [(1, 'foo'), (5, 'bar')]), + tuple_(table1.c.myid, table1.c.name).in_([(1, "foo"), (5, "bar")]), "(mytable.myid, mytable.name) IN " - "((:param_1, :param_2), (:param_3, :param_4))" + "((:param_1, :param_2), (:param_3, :param_4))", ) self.assert_compile( @@ -2219,7 +2475,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): [tuple_(table2.c.otherid, table2.c.othername)] ), "(mytable.myid, mytable.name) IN " - "((myothertable.otherid, myothertable.othername))" + "((myothertable.otherid, myothertable.othername))", ) self.assert_compile( @@ -2227,226 +2483,245 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): select([table2.c.otherid, table2.c.othername]) ), "(mytable.myid, mytable.name) IN (SELECT " - "myothertable.otherid, myothertable.othername FROM myothertable)" + "myothertable.otherid, myothertable.othername FROM myothertable)", ) def test_expanding_parameter(self): self.assert_compile( tuple_(table1.c.myid, table1.c.name).in_( - bindparam('foo', expanding=True)), - "(mytable.myid, mytable.name) IN ([EXPANDING_foo])" + bindparam("foo", expanding=True) + ), + "(mytable.myid, mytable.name) IN ([EXPANDING_foo])", ) self.assert_compile( - table1.c.myid.in_(bindparam('foo', expanding=True)), - "mytable.myid IN ([EXPANDING_foo])" + table1.c.myid.in_(bindparam("foo", expanding=True)), + "mytable.myid IN ([EXPANDING_foo])", ) def test_cast(self): - tbl = table('casttest', - column('id', Integer), - column('v1', Float), - column('v2', Float), - column('ts', TIMESTAMP), - ) + tbl = table( + "casttest", + column("id", Integer), + column("v1", Float), + column("v2", Float), + column("ts", TIMESTAMP), + ) def check_results(dialect, expected_results, literal): - eq_(len(expected_results), 5, - 'Incorrect number of expected results') - eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), - 'CAST(casttest.v1 AS %s)' % expected_results[0]) - eq_(str(tbl.c.v1.cast(Numeric).compile(dialect=dialect)), - 'CAST(casttest.v1 AS %s)' % expected_results[0]) - eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), - 'CAST(casttest.v1 AS %s)' % expected_results[1]) - eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), - 'CAST(casttest.ts AS %s)' % expected_results[2]) - eq_(str(cast(1234, Text).compile(dialect=dialect)), - 'CAST(%s AS %s)' % (literal, expected_results[3])) - eq_(str(cast('test', String(20)).compile(dialect=dialect)), - 'CAST(%s AS %s)' % (literal, expected_results[4])) + eq_( + len(expected_results), + 5, + "Incorrect number of expected results", + ) + eq_( + str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), + "CAST(casttest.v1 AS %s)" % expected_results[0], + ) + eq_( + str(tbl.c.v1.cast(Numeric).compile(dialect=dialect)), + "CAST(casttest.v1 AS %s)" % expected_results[0], + ) + eq_( + str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), + "CAST(casttest.v1 AS %s)" % expected_results[1], + ) + eq_( + str(cast(tbl.c.ts, Date).compile(dialect=dialect)), + "CAST(casttest.ts AS %s)" % expected_results[2], + ) + eq_( + str(cast(1234, Text).compile(dialect=dialect)), 
+ "CAST(%s AS %s)" % (literal, expected_results[3]), + ) + eq_( + str(cast("test", String(20)).compile(dialect=dialect)), + "CAST(%s AS %s)" % (literal, expected_results[4]), + ) # fixme: shoving all of this dialect-specific stuff in one test # is now officially completely ridiculous AND non-obviously omits # coverage on other dialects. sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile( - dialect=dialect) + dialect=dialect + ) if isinstance(dialect, type(mysql.dialect())): - eq_(str(sel), + eq_( + str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, " "casttest.ts, " - "CAST(casttest.v1 AS DECIMAL) AS anon_1 \nFROM casttest") + "CAST(casttest.v1 AS DECIMAL) AS anon_1 \nFROM casttest", + ) else: - eq_(str(sel), + eq_( + str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, " "casttest.ts, CAST(casttest.v1 AS NUMERIC) AS " - "anon_1 \nFROM casttest") + "anon_1 \nFROM casttest", + ) # first test with PostgreSQL engine check_results( - postgresql.dialect(), [ - 'NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], - '%(param_1)s') + postgresql.dialect(), + ["NUMERIC", "NUMERIC(12, 9)", "DATE", "TEXT", "VARCHAR(20)"], + "%(param_1)s", + ) # then the Oracle engine check_results( - oracle.dialect(), [ - 'NUMERIC', 'NUMERIC(12, 9)', 'DATE', - 'CLOB', 'VARCHAR2(20 CHAR)'], - ':param_1') + oracle.dialect(), + ["NUMERIC", "NUMERIC(12, 9)", "DATE", "CLOB", "VARCHAR2(20 CHAR)"], + ":param_1", + ) # then the sqlite engine - check_results(sqlite.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', - 'DATE', 'TEXT', 'VARCHAR(20)'], '?') + check_results( + sqlite.dialect(), + ["NUMERIC", "NUMERIC(12, 9)", "DATE", "TEXT", "VARCHAR(20)"], + "?", + ) # then the MySQL engine - check_results(mysql.dialect(), ['DECIMAL', 'DECIMAL(12, 9)', - 'DATE', 'CHAR', 'CHAR(20)'], '%s') - - self.assert_compile(cast(text('NULL'), Integer), - 'CAST(NULL AS INTEGER)', - dialect=sqlite.dialect()) - self.assert_compile(cast(null(), Integer), - 'CAST(NULL AS INTEGER)', - dialect=sqlite.dialect()) - self.assert_compile(cast(literal_column('NULL'), Integer), - 'CAST(NULL AS INTEGER)', - dialect=sqlite.dialect()) + check_results( + mysql.dialect(), + ["DECIMAL", "DECIMAL(12, 9)", "DATE", "CHAR", "CHAR(20)"], + "%s", + ) - def test_over(self): self.assert_compile( - func.row_number().over(), - "row_number() OVER ()" + cast(text("NULL"), Integer), + "CAST(NULL AS INTEGER)", + dialect=sqlite.dialect(), + ) + self.assert_compile( + cast(null(), Integer), + "CAST(NULL AS INTEGER)", + dialect=sqlite.dialect(), + ) + self.assert_compile( + cast(literal_column("NULL"), Integer), + "CAST(NULL AS INTEGER)", + dialect=sqlite.dialect(), ) + + def test_over(self): + self.assert_compile(func.row_number().over(), "row_number() OVER ()") self.assert_compile( func.row_number().over( order_by=[table1.c.name, table1.c.description] ), - "row_number() OVER (ORDER BY mytable.name, mytable.description)" + "row_number() OVER (ORDER BY mytable.name, mytable.description)", ) self.assert_compile( func.row_number().over( partition_by=[table1.c.name, table1.c.description] ), "row_number() OVER (PARTITION BY mytable.name, " - "mytable.description)" + "mytable.description)", ) self.assert_compile( func.row_number().over( - partition_by=[table1.c.name], - order_by=[table1.c.description] + partition_by=[table1.c.name], order_by=[table1.c.description] ), "row_number() OVER (PARTITION BY mytable.name " - "ORDER BY mytable.description)" + "ORDER BY mytable.description)", ) self.assert_compile( func.row_number().over( - partition_by=table1.c.name, - 
order_by=table1.c.description + partition_by=table1.c.name, order_by=table1.c.description ), "row_number() OVER (PARTITION BY mytable.name " - "ORDER BY mytable.description)" + "ORDER BY mytable.description)", ) self.assert_compile( func.row_number().over( partition_by=table1.c.name, - order_by=[table1.c.name, table1.c.description] + order_by=[table1.c.name, table1.c.description], ), "row_number() OVER (PARTITION BY mytable.name " - "ORDER BY mytable.name, mytable.description)" + "ORDER BY mytable.name, mytable.description)", ) self.assert_compile( func.row_number().over( - partition_by=[], - order_by=[table1.c.name, table1.c.description] + partition_by=[], order_by=[table1.c.name, table1.c.description] ), - "row_number() OVER (ORDER BY mytable.name, mytable.description)" + "row_number() OVER (ORDER BY mytable.name, mytable.description)", ) self.assert_compile( func.row_number().over( - partition_by=[table1.c.name, table1.c.description], - order_by=[] + partition_by=[table1.c.name, table1.c.description], order_by=[] ), "row_number() OVER (PARTITION BY mytable.name, " - "mytable.description)" + "mytable.description)", ) self.assert_compile( - func.row_number().over( - partition_by=[], - order_by=[] - ), - "row_number() OVER ()" + func.row_number().over(partition_by=[], order_by=[]), + "row_number() OVER ()", ) self.assert_compile( - select([func.row_number().over( - order_by=table1.c.description - ).label('foo')]), + select( + [ + func.row_number() + .over(order_by=table1.c.description) + .label("foo") + ] + ), "SELECT row_number() OVER (ORDER BY mytable.description) " - "AS foo FROM mytable" + "AS foo FROM mytable", ) # test from_obj generation. # from func: self.assert_compile( - select([ - func.max(table1.c.name).over( - partition_by=['description'] - ) - ]), + select( + [func.max(table1.c.name).over(partition_by=["description"])] + ), "SELECT max(mytable.name) OVER (PARTITION BY mytable.description) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) # from partition_by self.assert_compile( - select([ - func.row_number().over( - partition_by=[table1.c.name] - ) - ]), + select([func.row_number().over(partition_by=[table1.c.name])]), "SELECT row_number() OVER (PARTITION BY mytable.name) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) # from order_by self.assert_compile( - select([ - func.row_number().over( - order_by=table1.c.name - ) - ]), + select([func.row_number().over(order_by=table1.c.name)]), "SELECT row_number() OVER (ORDER BY mytable.name) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) # this tests that _from_objects # concatenates OK self.assert_compile( select([column("x") + over(func.foo())]), - "SELECT x + foo() OVER () AS anon_1" + "SELECT x + foo() OVER () AS anon_1", ) # test a reference to a label that is in the referenced selectable; # this resolves - expr = (table1.c.myid + 5).label('sum') + expr = (table1.c.myid + 5).label("sum") stmt = select([expr]).alias() self.assert_compile( select([stmt.c.sum, func.row_number().over(order_by=stmt.c.sum)]), "SELECT anon_1.sum, row_number() OVER (ORDER BY anon_1.sum) " "AS anon_2 FROM (SELECT mytable.myid + :myid_1 AS sum " - "FROM mytable) AS anon_1" + "FROM mytable) AS anon_1", ) # test a reference to a label that's at the same level as the OVER # in the columns clause; doesn't resolve - expr = (table1.c.myid + 5).label('sum') + expr = (table1.c.myid + 5).label("sum") self.assert_compile( select([expr, func.row_number().over(order_by=expr)]), "SELECT mytable.myid + :myid_1 AS sum, " "row_number() OVER 
" - "(ORDER BY mytable.myid + :myid_1) AS anon_1 FROM mytable" + "(ORDER BY mytable.myid + :myid_1) AS anon_1 FROM mytable", ) def test_over_framespec(self): @@ -2457,7 +2732,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT row_number() OVER " "(ORDER BY mytable.myid ROWS BETWEEN CURRENT " "ROW AND UNBOUNDED FOLLOWING)" - " AS anon_1 FROM mytable" + " AS anon_1 FROM mytable", ) self.assert_compile( @@ -2465,7 +2740,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT row_number() OVER " "(ORDER BY mytable.myid ROWS BETWEEN UNBOUNDED " "PRECEDING AND UNBOUNDED FOLLOWING)" - " AS anon_1 FROM mytable" + " AS anon_1 FROM mytable", ) self.assert_compile( @@ -2473,7 +2748,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT row_number() OVER " "(ORDER BY mytable.myid RANGE BETWEEN " "UNBOUNDED PRECEDING AND CURRENT ROW)" - " AS anon_1 FROM mytable" + " AS anon_1 FROM mytable", ) self.assert_compile( @@ -2482,7 +2757,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "(ORDER BY mytable.myid RANGE BETWEEN " ":param_1 PRECEDING AND :param_2 FOLLOWING)" " AS anon_1 FROM mytable", - checkparams={'param_1': 5, 'param_2': 10} + checkparams={"param_1": 5, "param_2": 10}, ) self.assert_compile( @@ -2491,7 +2766,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "(ORDER BY mytable.myid RANGE BETWEEN " ":param_1 FOLLOWING AND :param_2 FOLLOWING)" " AS anon_1 FROM mytable", - checkparams={'param_1': 1, 'param_2': 10} + checkparams={"param_1": 1, "param_2": 10}, ) self.assert_compile( @@ -2500,89 +2775,108 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "(ORDER BY mytable.myid RANGE BETWEEN " ":param_1 PRECEDING AND :param_2 PRECEDING)" " AS anon_1 FROM mytable", - checkparams={'param_1': 10, 'param_2': 1} + checkparams={"param_1": 10, "param_2": 1}, ) def test_over_invalid_framespecs(self): assert_raises_message( exc.ArgumentError, "Integer or None expected for range value", - func.row_number().over, range_=("foo", 8) + func.row_number().over, + range_=("foo", 8), ) assert_raises_message( exc.ArgumentError, "Integer or None expected for range value", - func.row_number().over, range_=(-5, "foo") + func.row_number().over, + range_=(-5, "foo"), ) assert_raises_message( exc.ArgumentError, "'range_' and 'rows' are mutually exclusive", - func.row_number().over, range_=(-5, 8), rows=(-2, 5) + func.row_number().over, + range_=(-5, 8), + rows=(-2, 5), ) def test_over_within_group(self): from sqlalchemy import within_group - stmt = select([ - table1.c.myid, - within_group( - func.percentile_cont(0.5), - table1.c.name.desc() - ).over( - range_=(1, 2), - partition_by=table1.c.name, - order_by=table1.c.myid - ) - ]) + + stmt = select( + [ + table1.c.myid, + within_group( + func.percentile_cont(0.5), table1.c.name.desc() + ).over( + range_=(1, 2), + partition_by=table1.c.name, + order_by=table1.c.myid, + ), + ] + ) eq_ignore_whitespace( str(stmt), "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " "WITHIN GROUP (ORDER BY mytable.name DESC) " "OVER (PARTITION BY mytable.name ORDER BY mytable.myid " "RANGE BETWEEN :param_1 FOLLOWING AND :param_2 FOLLOWING) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", + ) + + stmt = select( + [ + table1.c.myid, + within_group( + func.percentile_cont(0.5), table1.c.name.desc() + ).over( + rows=(1, 2), + partition_by=table1.c.name, + order_by=table1.c.myid, + ), + ] ) - - stmt = select([ - table1.c.myid, - within_group( - func.percentile_cont(0.5), - table1.c.name.desc() - ).over( - 
rows=(1, 2), - partition_by=table1.c.name, - order_by=table1.c.myid - ) - ]) eq_ignore_whitespace( str(stmt), "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " "WITHIN GROUP (ORDER BY mytable.name DESC) " "OVER (PARTITION BY mytable.name ORDER BY mytable.myid " "ROWS BETWEEN :param_1 FOLLOWING AND :param_2 FOLLOWING) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) - - def test_date_between(self): import datetime - table = Table('dt', metadata, - Column('date', Date)) + + table = Table("dt", metadata, Column("date", Date)) self.assert_compile( - table.select(table.c.date.between(datetime.date(2006, 6, 1), - datetime.date(2006, 6, 5))), + table.select( + table.c.date.between( + datetime.date(2006, 6, 1), datetime.date(2006, 6, 5) + ) + ), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :date_1 AND :date_2", - checkparams={'date_1': datetime.date(2006, 6, 1), - 'date_2': datetime.date(2006, 6, 5)}) + checkparams={ + "date_1": datetime.date(2006, 6, 1), + "date_2": datetime.date(2006, 6, 5), + }, + ) self.assert_compile( - table.select(sql.between(table.c.date, datetime.date(2006, 6, 1), - datetime.date(2006, 6, 5))), + table.select( + sql.between( + table.c.date, + datetime.date(2006, 6, 1), + datetime.date(2006, 6, 5), + ) + ), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :date_1 AND :date_2", - checkparams={'date_1': datetime.date(2006, 6, 1), - 'date_2': datetime.date(2006, 6, 5)}) + checkparams={ + "date_1": datetime.date(2006, 6, 1), + "date_2": datetime.date(2006, 6, 5), + }, + ) def test_delayed_col_naming(self): my_str = Column(String) @@ -2592,18 +2886,18 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.InvalidRequestError, "Cannot initialize a sub-selectable with this Column", - lambda: sel1.c + lambda: sel1.c, ) # calling label or as_scalar doesn't compile # anything. 
- sel2 = select([func.substr(my_str, 2, 3)]).label('my_substr') + sel2 = select([func.substr(my_str, 2, 3)]).label("my_substr") assert_raises_message( exc.CompileError, "Cannot compile Column object until its 'name' is assigned.", sel2.compile, - dialect=default.DefaultDialect() + dialect=default.DefaultDialect(), ) sel3 = select([my_str]).as_scalar() @@ -2611,24 +2905,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): exc.CompileError, "Cannot compile Column object until its 'name' is assigned.", sel3.compile, - dialect=default.DefaultDialect() + dialect=default.DefaultDialect(), ) - my_str.name = 'foo' + my_str.name = "foo" + self.assert_compile(sel1, "SELECT foo") self.assert_compile( - sel1, - "SELECT foo", - ) - self.assert_compile( - sel2, - '(SELECT substr(foo, :substr_2, :substr_3) AS substr_1)', + sel2, "(SELECT substr(foo, :substr_2, :substr_3) AS substr_1)" ) - self.assert_compile( - sel3, - "(SELECT foo)" - ) + self.assert_compile(sel3, "(SELECT foo)") def test_naming(self): # TODO: the part where we check c.keys() are not "compile" tests, they @@ -2636,36 +2923,46 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): # version of that suite f1 = func.hoho(table1.c.name) - s1 = select([table1.c.myid, table1.c.myid.label('foobar'), - f1, - func.lala(table1.c.name).label('gg')]) - - eq_( - list(s1.c.keys()), - ['myid', 'foobar', str(f1), 'gg'] + s1 = select( + [ + table1.c.myid, + table1.c.myid.label("foobar"), + f1, + func.lala(table1.c.name).label("gg"), + ] ) + eq_(list(s1.c.keys()), ["myid", "foobar", str(f1), "gg"]) + meta = MetaData() - t1 = Table('mytable', meta, Column('col1', Integer)) + t1 = Table("mytable", meta, Column("col1", Integer)) exprs = ( table1.c.myid == 12, func.hoho(table1.c.myid), cast(table1.c.name, Numeric), - literal('x'), + literal("x"), ) for col, key, expr, lbl in ( - (table1.c.name, 'name', 'mytable.name', None), - (exprs[0], str(exprs[0]), 'mytable.myid = :myid_1', 'anon_1'), - (exprs[1], str(exprs[1]), 'hoho(mytable.myid)', 'hoho_1'), - (exprs[2], str(exprs[2]), - 'CAST(mytable.name AS NUMERIC)', 'anon_1'), - (t1.c.col1, 'col1', 'mytable.col1', None), - (column('some wacky thing'), 'some wacky thing', - '"some wacky thing"', ''), - (exprs[3], exprs[3].key, ":param_1", "anon_1") + (table1.c.name, "name", "mytable.name", None), + (exprs[0], str(exprs[0]), "mytable.myid = :myid_1", "anon_1"), + (exprs[1], str(exprs[1]), "hoho(mytable.myid)", "hoho_1"), + ( + exprs[2], + str(exprs[2]), + "CAST(mytable.name AS NUMERIC)", + "anon_1", + ), + (t1.c.col1, "col1", "mytable.col1", None), + ( + column("some wacky thing"), + "some wacky thing", + '"some wacky thing"', + "", + ), + (exprs[3], exprs[3].key, ":param_1", "anon_1"), ): - if getattr(col, 'table', None) is not None: + if getattr(col, "table", None) is not None: t = col.table else: t = table1 @@ -2675,107 +2972,151 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): if lbl: self.assert_compile( - s1, "SELECT %s AS %s FROM mytable" % - (expr, lbl)) + s1, "SELECT %s AS %s FROM mytable" % (expr, lbl) + ) else: self.assert_compile(s1, "SELECT %s FROM mytable" % (expr,)) s1 = select([s1]) if lbl: self.assert_compile( - s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % - (lbl, expr, lbl)) + s1, + "SELECT %s FROM (SELECT %s AS %s FROM mytable)" + % (lbl, expr, lbl), + ) elif col.table is not None: # sqlite rule labels subquery columns self.assert_compile( - s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % - (key, expr, key)) + s1, + "SELECT %s FROM (SELECT %s AS %s FROM 
mytable)" + % (key, expr, key), + ) else: - self.assert_compile(s1, - "SELECT %s FROM (SELECT %s FROM mytable)" % - (expr, expr)) + self.assert_compile( + s1, + "SELECT %s FROM (SELECT %s FROM mytable)" % (expr, expr), + ) def test_hints(self): s = select([table1.c.myid]).with_hint(table1, "test hint %(name)s") - s2 = select([table1.c.myid]).\ - with_hint(table1, "index(%(name)s idx)", 'oracle').\ - with_hint(table1, "WITH HINT INDEX idx", 'sybase') + s2 = ( + select([table1.c.myid]) + .with_hint(table1, "index(%(name)s idx)", "oracle") + .with_hint(table1, "WITH HINT INDEX idx", "sybase") + ) a1 = table1.alias() s3 = select([a1.c.myid]).with_hint(a1, "index(%(name)s hint)") - subs4 = select([ - table1, table2 - ]).select_from( - table1.join(table2, table1.c.myid == table2.c.otherid)).\ - with_hint(table1, 'hint1') + subs4 = ( + select([table1, table2]) + .select_from( + table1.join(table2, table1.c.myid == table2.c.otherid) + ) + .with_hint(table1, "hint1") + ) - s4 = select([table3]).select_from( - table3.join( - subs4, - subs4.c.othername == table3.c.otherstuff + s4 = ( + select([table3]) + .select_from( + table3.join(subs4, subs4.c.othername == table3.c.otherstuff) ) - ).\ - with_hint(table3, 'hint3') + .with_hint(table3, "hint3") + ) - t1 = table('QuotedName', column('col1')) - s6 = select([t1.c.col1]).where(t1.c.col1 > 10).\ - with_hint(t1, '%(name)s idx1') - a2 = t1.alias('SomeName') - s7 = select([a2.c.col1]).where(a2.c.col1 > 10).\ - with_hint(a2, '%(name)s idx1') + t1 = table("QuotedName", column("col1")) + s6 = ( + select([t1.c.col1]) + .where(t1.c.col1 > 10) + .with_hint(t1, "%(name)s idx1") + ) + a2 = t1.alias("SomeName") + s7 = ( + select([a2.c.col1]) + .where(a2.c.col1 > 10) + .with_hint(a2, "%(name)s idx1") + ) - mysql_d, oracle_d, sybase_d = \ - mysql.dialect(), \ - oracle.dialect(), \ - sybase.dialect() + mysql_d, oracle_d, sybase_d = ( + mysql.dialect(), + oracle.dialect(), + sybase.dialect(), + ) for stmt, dialect, expected in [ - (s, mysql_d, - "SELECT mytable.myid FROM mytable test hint mytable"), - (s, oracle_d, - "SELECT /*+ test hint mytable */ mytable.myid FROM mytable"), - (s, sybase_d, - "SELECT mytable.myid FROM mytable test hint mytable"), - (s2, mysql_d, - "SELECT mytable.myid FROM mytable"), - (s2, oracle_d, - "SELECT /*+ index(mytable idx) */ mytable.myid FROM mytable"), - (s2, sybase_d, - "SELECT mytable.myid FROM mytable WITH HINT INDEX idx"), - (s3, mysql_d, + (s, mysql_d, "SELECT mytable.myid FROM mytable test hint mytable"), + ( + s, + oracle_d, + "SELECT /*+ test hint mytable */ mytable.myid FROM mytable", + ), + ( + s, + sybase_d, + "SELECT mytable.myid FROM mytable test hint mytable", + ), + (s2, mysql_d, "SELECT mytable.myid FROM mytable"), + ( + s2, + oracle_d, + "SELECT /*+ index(mytable idx) */ mytable.myid FROM mytable", + ), + ( + s2, + sybase_d, + "SELECT mytable.myid FROM mytable WITH HINT INDEX idx", + ), + ( + s3, + mysql_d, "SELECT mytable_1.myid FROM mytable AS mytable_1 " - "index(mytable_1 hint)"), - (s3, oracle_d, + "index(mytable_1 hint)", + ), + ( + s3, + oracle_d, "SELECT /*+ index(mytable_1 hint) */ mytable_1.myid FROM " - "mytable mytable_1"), - (s3, sybase_d, + "mytable mytable_1", + ), + ( + s3, + sybase_d, "SELECT mytable_1.myid FROM mytable AS mytable_1 " - "index(mytable_1 hint)"), - (s4, mysql_d, + "index(mytable_1 hint)", + ), + ( + s4, + mysql_d, "SELECT thirdtable.userid, thirdtable.otherstuff " "FROM thirdtable " "hint3 INNER JOIN (SELECT mytable.myid, mytable.name, " "mytable.description, myothertable.otherid, " 
"myothertable.othername FROM mytable hint1 INNER " "JOIN myothertable ON mytable.myid = myothertable.otherid) " - "ON othername = thirdtable.otherstuff"), - (s4, sybase_d, + "ON othername = thirdtable.otherstuff", + ), + ( + s4, + sybase_d, "SELECT thirdtable.userid, thirdtable.otherstuff " "FROM thirdtable " "hint3 JOIN (SELECT mytable.myid, mytable.name, " "mytable.description, myothertable.otherid, " "myothertable.othername FROM mytable hint1 " "JOIN myothertable ON mytable.myid = myothertable.otherid) " - "ON othername = thirdtable.otherstuff"), - (s4, oracle_d, + "ON othername = thirdtable.otherstuff", + ), + ( + s4, + oracle_d, "SELECT /*+ hint3 */ thirdtable.userid, thirdtable.otherstuff " "FROM thirdtable JOIN (SELECT /*+ hint1 */ mytable.myid," " mytable.name, mytable.description, myothertable.otherid," " myothertable.othername FROM mytable JOIN myothertable ON" " mytable.myid = myothertable.otherid) ON othername =" - " thirdtable.otherstuff"), + " thirdtable.otherstuff", + ), # TODO: figure out dictionary ordering solution here # (s5, oracle_d, # "SELECT /*+ hint3 */ /*+ hint1 */ thirdtable.userid, " @@ -2785,68 +3126,64 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): # " myothertable.othername FROM mytable JOIN myothertable ON" # " mytable.myid = myothertable.otherid) ON othername =" # " thirdtable.otherstuff"), - (s6, oracle_d, + ( + s6, + oracle_d, """SELECT /*+ "QuotedName" idx1 */ "QuotedName".col1 """ - """FROM "QuotedName" WHERE "QuotedName".col1 > :col1_1"""), - (s7, oracle_d, - """SELECT /*+ "SomeName" idx1 */ "SomeName".col1 FROM """ - """"QuotedName" "SomeName" WHERE "SomeName".col1 > :col1_1"""), + """FROM "QuotedName" WHERE "QuotedName".col1 > :col1_1""", + ), + ( + s7, + oracle_d, + """SELECT /*+ "SomeName" idx1 */ "SomeName".col1 FROM """ + """"QuotedName" "SomeName" WHERE "SomeName".col1 > :col1_1""", + ), ]: - self.assert_compile( - stmt, - expected, - dialect=dialect - ) + self.assert_compile(stmt, expected, dialect=dialect) def test_statement_hints(self): - stmt = select([table1.c.myid]).\ - with_statement_hint("test hint one").\ - with_statement_hint("test hint two", 'mysql') + stmt = ( + select([table1.c.myid]) + .with_statement_hint("test hint one") + .with_statement_hint("test hint two", "mysql") + ) self.assert_compile( - stmt, - "SELECT mytable.myid FROM mytable test hint one", + stmt, "SELECT mytable.myid FROM mytable test hint one" ) self.assert_compile( stmt, "SELECT mytable.myid FROM mytable test hint one test hint two", - dialect='mysql' + dialect="mysql", ) def test_literal_as_text_fromstring(self): - self.assert_compile( - and_(text("a"), text("b")), - "a AND b" - ) + self.assert_compile(and_(text("a"), text("b")), "a AND b") def test_literal_as_text_nonstring_raise(self): - assert_raises(exc.ArgumentError, - and_, ("a",), ("b",) - ) + assert_raises(exc.ArgumentError, and_, ("a",), ("b",)) class UnsupportedTest(fixtures.TestBase): - def test_unsupported_element_str_visit_name(self): from sqlalchemy.sql.expression import ClauseElement class SomeElement(ClauseElement): - __visit_name__ = 'some_element' + __visit_name__ = "some_element" assert_raises_message( exc.UnsupportedCompilationError, r"Compiler ", - SomeElement().compile + SomeElement().compile, ) def test_unsupported_element_meth_visit_name(self): from sqlalchemy.sql.expression import ClauseElement class SomeElement(ClauseElement): - @classmethod def __visit_name__(cls): return "some_element" @@ -2855,7 +3192,7 @@ class UnsupportedTest(fixtures.TestBase): 
exc.UnsupportedCompilationError, r"Compiler <sqlalchemy.sql.compiler.SQLCompiler .*" r"can't render element of type <class '.*SomeElement'>", - SomeElement().compile + SomeElement().compile, ) def test_unsupported_operator(self): @@ -2863,12 +3200,13 @@ class UnsupportedTest(fixtures.TestBase): def myop(x, y): pass + binary = BinaryExpression(column("foo"), column("bar"), myop) assert_raises_message( exc.UnsupportedCompilationError, r"Compiler <sqlalchemy.sql.compiler.SQLCompiler .*" r"can't render element of type <function.*", binary.compile, ) class StringifySpecialTest(fixtures.TestBase): def test_unnamed_column(self): stmt = Column(Integer) == 5 - eq_ignore_whitespace( - str(stmt), - '"<name unknown>" = :param_1' - ) + eq_ignore_whitespace(str(stmt), '"<name unknown>" = :param_1') def test_cte(self): # stringify of these was supported anyway by defaultdialect. @@ -2895,7 +3230,7 @@ class StringifySpecialTest(fixtures.TestBase): eq_ignore_whitespace( str(stmt), "WITH anon_1 AS (SELECT mytable.myid AS myid FROM mytable) " - "SELECT anon_1.myid FROM anon_1" + "SELECT anon_1.myid FROM anon_1", ) def test_next_sequence_value(self): @@ -2906,8 +3241,7 @@ class StringifySpecialTest(fixtures.TestBase): seq = Sequence("my_sequence") eq_ignore_whitespace( - str(seq.next_value()), - "<next sequence value: my_sequence>" + str(seq.next_value()), "<next sequence value: my_sequence>" ) def test_returning(self): @@ -2916,47 +3250,43 @@ class StringifySpecialTest(fixtures.TestBase): eq_ignore_whitespace( str(stmt), "INSERT INTO mytable (myid, name, description) " - "VALUES (:myid, :name, :description) RETURNING mytable.myid" + "VALUES (:myid, :name, :description) RETURNING mytable.myid", ) def test_array_index(self): - stmt = select([column('foo', types.ARRAY(Integer))[5]]) + stmt = select([column("foo", types.ARRAY(Integer))[5]]) - eq_ignore_whitespace( - str(stmt), - "SELECT foo[:foo_1] AS anon_1" - ) + eq_ignore_whitespace(str(stmt), "SELECT foo[:foo_1] AS anon_1") def test_unknown_type(self): class MyType(types.TypeEngine): - __visit_name__ = 'mytype' + __visit_name__ = "mytype" stmt = select([cast(table1.c.myid, MyType)]) eq_ignore_whitespace( str(stmt), - "SELECT CAST(mytable.myid AS MyType) AS anon_1 FROM mytable" + "SELECT CAST(mytable.myid AS MyType) AS anon_1 FROM mytable", ) def test_within_group(self): # stringify of these was supported anyway by defaultdialect. 
from sqlalchemy import within_group - stmt = select([ - table1.c.myid, - within_group( - func.percentile_cont(0.5), - table1.c.name.desc() - ) - ]) + + stmt = select( + [ + table1.c.myid, + within_group(func.percentile_cont(0.5), table1.c.name.desc()), + ] + ) eq_ignore_whitespace( str(stmt), "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " - "WITHIN GROUP (ORDER BY mytable.name DESC) AS anon_1 FROM mytable" + "WITHIN GROUP (ORDER BY mytable.name DESC) AS anon_1 FROM mytable", ) class KwargPropagationTest(fixtures.TestBase): - @classmethod def setup_class(cls): from sqlalchemy.sql.expression import ColumnClause, TableClause @@ -2969,7 +3299,7 @@ class KwargPropagationTest(fixtures.TestBase): cls.column = CatchCol("x") cls.table = CatchTable("y") - cls.criterion = cls.column == CatchCol('y') + cls.criterion = cls.column == CatchCol("y") @compiles(CatchCol) def compile_col(element, compiler, **kw): @@ -2983,16 +3313,18 @@ class KwargPropagationTest(fixtures.TestBase): def _do_test(self, element): d = default.DefaultDialect() - d.statement_compiler(d, element, - compile_kwargs={"canary": True}) + d.statement_compiler(d, element, compile_kwargs={"canary": True}) def test_binary(self): self._do_test(self.column == 5) def test_select(self): - s = select([self.column]).select_from(self.table).\ - where(self.column == self.criterion).\ - order_by(self.column) + s = ( + select([self.column]) + .select_from(self.table) + .where(self.column == self.criterion) + .order_by(self.column) + ) self._do_test(s) def test_case(self): @@ -3029,8 +3361,11 @@ class ExecutionOptionsTest(fixtures.TestBase): def test_embedded_element_true_to_false(self): stmt = table1.insert().cte() eq_(stmt._execution_options, {"autocommit": True}) - s2 = select([table1]).select_from(stmt).\ - execution_options(autocommit=False) + s2 = ( + select([table1]) + .select_from(stmt) + .execution_options(autocommit=False) + ) eq_(s2._execution_options, {"autocommit": False}) compiled = s2.compile() @@ -3038,7 +3373,7 @@ class ExecutionOptionsTest(fixtures.TestBase): class DDLTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _illegal_type_fixture(self): class MyType(types.TypeEngine): @@ -3047,195 +3382,197 @@ class DDLTest(fixtures.TestBase, AssertsCompiledSQL): @compiles(MyType) def compile(element, compiler, **kw): raise exc.CompileError("Couldn't compile type") + return MyType def test_reraise_of_column_spec_issue(self): MyType = self._illegal_type_fixture() - t1 = Table('t', MetaData(), - Column('x', MyType()) - ) + t1 = Table("t", MetaData(), Column("x", MyType())) assert_raises_message( exc.CompileError, r"\(in table 't', column 'x'\): Couldn't compile type", - schema.CreateTable(t1).compile + schema.CreateTable(t1).compile, ) def test_reraise_of_column_spec_issue_unicode(self): MyType = self._illegal_type_fixture() - t1 = Table('t', MetaData(), - Column(u('méil'), MyType()) - ) + t1 = Table("t", MetaData(), Column(u("méil"), MyType())) assert_raises_message( exc.CompileError, u(r"\(in table 't', column 'méil'\): Couldn't compile type"), - schema.CreateTable(t1).compile + schema.CreateTable(t1).compile, ) def test_system_flag(self): m = MetaData() - t = Table('t', m, Column('x', Integer), - Column('y', Integer, system=True), - Column('z', Integer)) + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, system=True), + Column("z", Integer), + ) self.assert_compile( - schema.CreateTable(t), - "CREATE TABLE t (x INTEGER, z INTEGER)" + schema.CreateTable(t), 
"CREATE TABLE t (x INTEGER, z INTEGER)" ) m2 = MetaData() t2 = t.tometadata(m2) self.assert_compile( - schema.CreateTable(t2), - "CREATE TABLE t (x INTEGER, z INTEGER)" + schema.CreateTable(t2), "CREATE TABLE t (x INTEGER, z INTEGER)" ) def test_composite_pk_constraint_autoinc_first_implicit(self): m = MetaData() t = Table( - 't', m, - Column('a', Integer, primary_key=True), - Column('b', Integer, primary_key=True, autoincrement=True) + "t", + m, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True, autoincrement=True), ) self.assert_compile( schema.CreateTable(t), "CREATE TABLE t (" "a INTEGER NOT NULL, " "b INTEGER NOT NULL, " - "PRIMARY KEY (b, a))" + "PRIMARY KEY (b, a))", ) def test_composite_pk_constraint_maintains_order_explicit(self): m = MetaData() t = Table( - 't', m, - Column('a', Integer), - Column('b', Integer, autoincrement=True), - schema.PrimaryKeyConstraint('a', 'b') + "t", + m, + Column("a", Integer), + Column("b", Integer, autoincrement=True), + schema.PrimaryKeyConstraint("a", "b"), ) self.assert_compile( schema.CreateTable(t), "CREATE TABLE t (" "a INTEGER NOT NULL, " "b INTEGER NOT NULL, " - "PRIMARY KEY (a, b))" + "PRIMARY KEY (a, b))", ) def test_create_table_suffix(self): class MyDialect(default.DefaultDialect): class MyCompiler(compiler.DDLCompiler): def create_table_suffix(self, table): - return 'SOME SUFFIX' + return "SOME SUFFIX" ddl_compiler = MyCompiler m = MetaData() - t1 = Table('t1', m, Column('q', Integer)) + t1 = Table("t1", m, Column("q", Integer)) self.assert_compile( schema.CreateTable(t1), "CREATE TABLE t1 SOME SUFFIX (q INTEGER)", - dialect=MyDialect() + dialect=MyDialect(), ) def test_table_no_cols(self): m = MetaData() - t1 = Table('t1', m) - self.assert_compile( - schema.CreateTable(t1), - "CREATE TABLE t1 ()" - ) + t1 = Table("t1", m) + self.assert_compile(schema.CreateTable(t1), "CREATE TABLE t1 ()") def test_table_no_cols_w_constraint(self): m = MetaData() - t1 = Table('t1', m, CheckConstraint('a = 1')) + t1 = Table("t1", m, CheckConstraint("a = 1")) self.assert_compile( - schema.CreateTable(t1), - "CREATE TABLE t1 (CHECK (a = 1))" + schema.CreateTable(t1), "CREATE TABLE t1 (CHECK (a = 1))" ) def test_table_one_col_w_constraint(self): m = MetaData() - t1 = Table('t1', m, Column('q', Integer), CheckConstraint('a = 1')) + t1 = Table("t1", m, Column("q", Integer), CheckConstraint("a = 1")) self.assert_compile( schema.CreateTable(t1), - "CREATE TABLE t1 (q INTEGER, CHECK (a = 1))" + "CREATE TABLE t1 (q INTEGER, CHECK (a = 1))", ) def test_schema_translate_map_table(self): m = MetaData() - t1 = Table('t1', m, Column('q', Integer)) - t2 = Table('t2', m, Column('q', Integer), schema='foo') - t3 = Table('t3', m, Column('q', Integer), schema='bar') + t1 = Table("t1", m, Column("q", Integer)) + t2 = Table("t2", m, Column("q", Integer), schema="foo") + t3 = Table("t3", m, Column("q", Integer), schema="bar") schema_translate_map = {None: "z", "bar": None, "foo": "bat"} self.assert_compile( schema.CreateTable(t1), "CREATE TABLE z.t1 (q INTEGER)", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( schema.CreateTable(t2), "CREATE TABLE bat.t2 (q INTEGER)", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( schema.CreateTable(t3), "CREATE TABLE t3 (q INTEGER)", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) def test_schema_translate_map_sequence(self): - s1 = 
schema.Sequence('s1') - s2 = schema.Sequence('s2', schema='foo') - s3 = schema.Sequence('s3', schema='bar') + s1 = schema.Sequence("s1") + s2 = schema.Sequence("s2", schema="foo") + s3 = schema.Sequence("s3", schema="bar") schema_translate_map = {None: "z", "bar": None, "foo": "bat"} self.assert_compile( schema.CreateSequence(s1), "CREATE SEQUENCE z.s1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( schema.CreateSequence(s2), "CREATE SEQUENCE bat.s2", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( schema.CreateSequence(s3), "CREATE SEQUENCE s3", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_select(self): - self.assert_compile(table4.select(), - "SELECT remote_owner.remotetable.rem_id, " - "remote_owner.remotetable.datatype_id," - " remote_owner.remotetable.value " - "FROM remote_owner.remotetable") + self.assert_compile( + table4.select(), + "SELECT remote_owner.remotetable.rem_id, " + "remote_owner.remotetable.datatype_id," + " remote_owner.remotetable.value " + "FROM remote_owner.remotetable", + ) self.assert_compile( table4.select( - and_( - table4.c.datatype_id == 7, - table4.c.value == 'hi')), + and_(table4.c.datatype_id == 7, table4.c.value == "hi") + ), "SELECT remote_owner.remotetable.rem_id, " "remote_owner.remotetable.datatype_id," " remote_owner.remotetable.value " "FROM remote_owner.remotetable WHERE " "remote_owner.remotetable.datatype_id = :datatype_id_1 AND" - " remote_owner.remotetable.value = :value_1") + " remote_owner.remotetable.value = :value_1", + ) - s = table4.select(and_(table4.c.datatype_id == 7, - table4.c.value == 'hi'), use_labels=True) + s = table4.select( + and_(table4.c.datatype_id == 7, table4.c.value == "hi"), + use_labels=True, + ) self.assert_compile( - s, "SELECT remote_owner.remotetable.rem_id AS" + s, + "SELECT remote_owner.remotetable.rem_id AS" " remote_owner_remotetable_rem_id, " "remote_owner.remotetable.datatype_id AS" " remote_owner_remotetable_datatype_id, " @@ -3243,86 +3580,94 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "AS remote_owner_remotetable_value FROM " "remote_owner.remotetable WHERE " "remote_owner.remotetable.datatype_id = :datatype_id_1 AND " - "remote_owner.remotetable.value = :value_1") + "remote_owner.remotetable.value = :value_1", + ) # multi-part schema name - self.assert_compile(table5.select(), - 'SELECT "dbo.remote_owner".remotetable.rem_id, ' - '"dbo.remote_owner".remotetable.datatype_id, ' - '"dbo.remote_owner".remotetable.value ' - 'FROM "dbo.remote_owner".remotetable' - ) + self.assert_compile( + table5.select(), + 'SELECT "dbo.remote_owner".remotetable.rem_id, ' + '"dbo.remote_owner".remotetable.datatype_id, ' + '"dbo.remote_owner".remotetable.value ' + 'FROM "dbo.remote_owner".remotetable', + ) # multi-part schema name labels - convert '.' 
to '_' - self.assert_compile(table5.select(use_labels=True), - 'SELECT "dbo.remote_owner".remotetable.rem_id AS' - ' dbo_remote_owner_remotetable_rem_id, ' - '"dbo.remote_owner".remotetable.datatype_id' - ' AS dbo_remote_owner_remotetable_datatype_id,' - ' "dbo.remote_owner".remotetable.value AS ' - 'dbo_remote_owner_remotetable_value FROM' - ' "dbo.remote_owner".remotetable' - ) + self.assert_compile( + table5.select(use_labels=True), + 'SELECT "dbo.remote_owner".remotetable.rem_id AS' + " dbo_remote_owner_remotetable_rem_id, " + '"dbo.remote_owner".remotetable.datatype_id' + " AS dbo_remote_owner_remotetable_datatype_id," + ' "dbo.remote_owner".remotetable.value AS ' + "dbo_remote_owner_remotetable_value FROM" + ' "dbo.remote_owner".remotetable', + ) def test_schema_translate_select(self): m = MetaData() table1 = Table( - 'mytable', m, Column('myid', Integer), - Column('name', String), - Column('description', String) + "mytable", + m, + Column("myid", Integer), + Column("name", String), + Column("description", String), ) - schema_translate_map = {"remote_owner": "foob", None: 'bar'} + schema_translate_map = {"remote_owner": "foob", None: "bar"} self.assert_compile( - table1.select().where(table1.c.name == 'hi'), + table1.select().where(table1.c.name == "hi"), "SELECT bar.mytable.myid, bar.mytable.name, " "bar.mytable.description FROM bar.mytable " "WHERE bar.mytable.name = :name_1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( - table4.select().where(table4.c.value == 'hi'), + table4.select().where(table4.c.value == "hi"), "SELECT foob.remotetable.rem_id, foob.remotetable.datatype_id, " "foob.remotetable.value FROM foob.remotetable " "WHERE foob.remotetable.value = :value_1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) schema_translate_map = {"remote_owner": "foob"} self.assert_compile( - select([ - table1, table4 - ]).select_from( + select([table1, table4]).select_from( join(table1, table4, table1.c.myid == table4.c.rem_id) ), "SELECT mytable.myid, mytable.name, mytable.description, " "foob.remotetable.rem_id, foob.remotetable.datatype_id, " "foob.remotetable.value FROM mytable JOIN foob.remotetable " "ON mytable.myid = foob.remotetable.rem_id", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) def test_schema_translate_aliases(self): - schema_translate_map = {None: 'bar'} + schema_translate_map = {None: "bar"} m = MetaData() table1 = Table( - 'mytable', m, Column('myid', Integer), - Column('name', String), - Column('description', String) + "mytable", + m, + Column("myid", Integer), + Column("name", String), + Column("description", String), ) table2 = Table( - 'myothertable', m, Column('otherid', Integer), - Column('othername', String), + "myothertable", + m, + Column("otherid", Integer), + Column("othername", String), ) alias = table1.alias() - stmt = select([ - table2, alias - ]).select_from(table2.join(alias, table2.c.otherid == alias.c.myid)).\ - where(alias.c.name == 'foo') + stmt = ( + select([table2, alias]) + .select_from(table2.join(alias, table2.c.otherid == alias.c.myid)) + .where(alias.c.name == "foo") + ) self.assert_compile( stmt, @@ -3331,109 +3676,122 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "FROM bar.myothertable JOIN bar.mytable AS mytable_1 " "ON bar.myothertable.otherid = mytable_1.myid " "WHERE mytable_1.name = :name_1", - schema_translate_map=schema_translate_map + 
schema_translate_map=schema_translate_map, ) def test_schema_translate_crud(self): - schema_translate_map = {"remote_owner": "foob", None: 'bar'} + schema_translate_map = {"remote_owner": "foob", None: "bar"} m = MetaData() table1 = Table( - 'mytable', m, - Column('myid', Integer), Column('name', String), - Column('description', String) + "mytable", + m, + Column("myid", Integer), + Column("name", String), + Column("description", String), ) self.assert_compile( - table1.insert().values(description='foo'), + table1.insert().values(description="foo"), "INSERT INTO bar.mytable (description) VALUES (:description)", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( - table1.update().where(table1.c.name == 'hi'). - values(description='foo'), + table1.update() + .where(table1.c.name == "hi") + .values(description="foo"), "UPDATE bar.mytable SET description=:description " "WHERE bar.mytable.name = :name_1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( - table1.delete().where(table1.c.name == 'hi'), + table1.delete().where(table1.c.name == "hi"), "DELETE FROM bar.mytable WHERE bar.mytable.name = :name_1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( - table4.insert().values(value='there'), + table4.insert().values(value="there"), "INSERT INTO foob.remotetable (value) VALUES (:value)", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( - table4.update().where(table4.c.value == 'hi'). - values(value='there'), + table4.update() + .where(table4.c.value == "hi") + .values(value="there"), "UPDATE foob.remotetable SET value=:value " "WHERE foob.remotetable.value = :value_1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) self.assert_compile( - table4.delete().where(table4.c.value == 'hi'), + table4.delete().where(table4.c.value == "hi"), "DELETE FROM foob.remotetable WHERE " "foob.remotetable.value = :value_1", - schema_translate_map=schema_translate_map + schema_translate_map=schema_translate_map, ) def test_alias(self): - a = alias(table4, 'remtable') - self.assert_compile(a.select(a.c.datatype_id == 7), - "SELECT remtable.rem_id, remtable.datatype_id, " - "remtable.value FROM" - " remote_owner.remotetable AS remtable " - "WHERE remtable.datatype_id = :datatype_id_1") + a = alias(table4, "remtable") + self.assert_compile( + a.select(a.c.datatype_id == 7), + "SELECT remtable.rem_id, remtable.datatype_id, " + "remtable.value FROM" + " remote_owner.remotetable AS remtable " + "WHERE remtable.datatype_id = :datatype_id_1", + ) def test_update(self): self.assert_compile( - table4.update(table4.c.value == 'test', - values={table4.c.datatype_id: 12}), + table4.update( + table4.c.value == "test", values={table4.c.datatype_id: 12} + ), "UPDATE remote_owner.remotetable SET datatype_id=:datatype_id " - "WHERE remote_owner.remotetable.value = :value_1") + "WHERE remote_owner.remotetable.value = :value_1", + ) def test_insert(self): - self.assert_compile(table4.insert(values=(2, 5, 'test')), - "INSERT INTO remote_owner.remotetable " - "(rem_id, datatype_id, value) VALUES " - "(:rem_id, :datatype_id, :value)") + self.assert_compile( + table4.insert(values=(2, 5, "test")), + "INSERT INTO remote_owner.remotetable " + "(rem_id, datatype_id, value) VALUES " + "(:rem_id, :datatype_id, :value)", + ) class 
CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_dont_overcorrelate(self): - self.assert_compile(select([table1], from_obj=[table1, - table1.select()]), - "SELECT mytable.myid, mytable.name, " - "mytable.description FROM mytable, (SELECT " - "mytable.myid AS myid, mytable.name AS " - "name, mytable.description AS description " - "FROM mytable)") + self.assert_compile( + select([table1], from_obj=[table1, table1.select()]), + "SELECT mytable.myid, mytable.name, " + "mytable.description FROM mytable, (SELECT " + "mytable.myid AS myid, mytable.name AS " + "name, mytable.description AS description " + "FROM mytable)", + ) def _fixture(self): - t1 = table('t1', column('a')) - t2 = table('t2', column('a')) + t1 = table("t1", column("a")) + t2 = table("t2", column("a")) return t1, t2, select([t1]).where(t1.c.a == t2.c.a) def _assert_where_correlated(self, stmt): self.assert_compile( stmt, "SELECT t2.a FROM t2 WHERE t2.a = " - "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)") + "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)", + ) def _assert_where_all_correlated(self, stmt): self.assert_compile( stmt, "SELECT t1.a, t2.a FROM t1, t2 WHERE t2.a = " - "(SELECT t1.a WHERE t1.a = t2.a)") + "(SELECT t1.a WHERE t1.a = t2.a)", + ) # note there's no more "backwards" correlation after # we've done #2746 @@ -3452,171 +3810,197 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "SELECT t2.a, (SELECT t1.a FROM t1 WHERE t1.a = t2.a) " - "AS anon_1 FROM t2") + "AS anon_1 FROM t2", + ) def _assert_column_all_correlated(self, stmt): self.assert_compile( stmt, "SELECT t1.a, t2.a, " - "(SELECT t1.a WHERE t1.a = t2.a) AS anon_1 FROM t1, t2") + "(SELECT t1.a WHERE t1.a = t2.a) AS anon_1 FROM t1, t2", + ) def _assert_having_correlated(self, stmt): - self.assert_compile(stmt, - "SELECT t2.a FROM t2 HAVING t2.a = " - "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)") + self.assert_compile( + stmt, + "SELECT t2.a FROM t2 HAVING t2.a = " + "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)", + ) def _assert_from_uncorrelated(self, stmt): self.assert_compile( stmt, "SELECT t2.a, anon_1.a FROM t2, " - "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1") + "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1", + ) def _assert_from_all_uncorrelated(self, stmt): self.assert_compile( stmt, "SELECT t1.a, t2.a, anon_1.a FROM t1, t2, " - "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1") + "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1", + ) def _assert_where_uncorrelated(self, stmt): - self.assert_compile(stmt, - "SELECT t2.a FROM t2 WHERE t2.a = " - "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)") + self.assert_compile( + stmt, + "SELECT t2.a FROM t2 WHERE t2.a = " + "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)", + ) def _assert_column_uncorrelated(self, stmt): - self.assert_compile(stmt, - "SELECT t2.a, (SELECT t1.a FROM t1, t2 " - "WHERE t1.a = t2.a) AS anon_1 FROM t2") + self.assert_compile( + stmt, + "SELECT t2.a, (SELECT t1.a FROM t1, t2 " + "WHERE t1.a = t2.a) AS anon_1 FROM t2", + ) def _assert_having_uncorrelated(self, stmt): - self.assert_compile(stmt, - "SELECT t2.a FROM t2 HAVING t2.a = " - "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)") + self.assert_compile( + stmt, + "SELECT t2.a FROM t2 HAVING t2.a = " + "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)", + ) def _assert_where_single_full_correlated(self, stmt): - self.assert_compile(stmt, - "SELECT t1.a FROM t1 WHERE t1.a = (SELECT t1.a)") + 
self.assert_compile( + stmt, "SELECT t1.a FROM t1 WHERE t1.a = (SELECT t1.a)" + ) def test_correlate_semiauto_where(self): t1, t2, s1 = self._fixture() self._assert_where_correlated( - select([t2]).where(t2.c.a == s1.correlate(t2))) + select([t2]).where(t2.c.a == s1.correlate(t2)) + ) def test_correlate_semiauto_column(self): t1, t2, s1 = self._fixture() self._assert_column_correlated( - select([t2, s1.correlate(t2).as_scalar()])) + select([t2, s1.correlate(t2).as_scalar()]) + ) def test_correlate_semiauto_from(self): t1, t2, s1 = self._fixture() - self._assert_from_uncorrelated( - select([t2, s1.correlate(t2).alias()])) + self._assert_from_uncorrelated(select([t2, s1.correlate(t2).alias()])) def test_correlate_semiauto_having(self): t1, t2, s1 = self._fixture() self._assert_having_correlated( - select([t2]).having(t2.c.a == s1.correlate(t2))) + select([t2]).having(t2.c.a == s1.correlate(t2)) + ) def test_correlate_except_inclusion_where(self): t1, t2, s1 = self._fixture() self._assert_where_correlated( - select([t2]).where(t2.c.a == s1.correlate_except(t1))) + select([t2]).where(t2.c.a == s1.correlate_except(t1)) + ) def test_correlate_except_exclusion_where(self): t1, t2, s1 = self._fixture() self._assert_where_uncorrelated( - select([t2]).where(t2.c.a == s1.correlate_except(t2))) + select([t2]).where(t2.c.a == s1.correlate_except(t2)) + ) def test_correlate_except_inclusion_column(self): t1, t2, s1 = self._fixture() self._assert_column_correlated( - select([t2, s1.correlate_except(t1).as_scalar()])) + select([t2, s1.correlate_except(t1).as_scalar()]) + ) def test_correlate_except_exclusion_column(self): t1, t2, s1 = self._fixture() self._assert_column_uncorrelated( - select([t2, s1.correlate_except(t2).as_scalar()])) + select([t2, s1.correlate_except(t2).as_scalar()]) + ) def test_correlate_except_inclusion_from(self): t1, t2, s1 = self._fixture() self._assert_from_uncorrelated( - select([t2, s1.correlate_except(t1).alias()])) + select([t2, s1.correlate_except(t1).alias()]) + ) def test_correlate_except_exclusion_from(self): t1, t2, s1 = self._fixture() self._assert_from_uncorrelated( - select([t2, s1.correlate_except(t2).alias()])) + select([t2, s1.correlate_except(t2).alias()]) + ) def test_correlate_except_none(self): t1, t2, s1 = self._fixture() self._assert_where_all_correlated( - select([t1, t2]).where(t2.c.a == s1.correlate_except(None))) + select([t1, t2]).where(t2.c.a == s1.correlate_except(None)) + ) def test_correlate_except_having(self): t1, t2, s1 = self._fixture() self._assert_having_correlated( - select([t2]).having(t2.c.a == s1.correlate_except(t1))) + select([t2]).having(t2.c.a == s1.correlate_except(t1)) + ) def test_correlate_auto_where(self): t1, t2, s1 = self._fixture() - self._assert_where_correlated( - select([t2]).where(t2.c.a == s1)) + self._assert_where_correlated(select([t2]).where(t2.c.a == s1)) def test_correlate_auto_column(self): t1, t2, s1 = self._fixture() - self._assert_column_correlated( - select([t2, s1.as_scalar()])) + self._assert_column_correlated(select([t2, s1.as_scalar()])) def test_correlate_auto_from(self): t1, t2, s1 = self._fixture() - self._assert_from_uncorrelated( - select([t2, s1.alias()])) + self._assert_from_uncorrelated(select([t2, s1.alias()])) def test_correlate_auto_having(self): t1, t2, s1 = self._fixture() - self._assert_having_correlated( - select([t2]).having(t2.c.a == s1)) + self._assert_having_correlated(select([t2]).having(t2.c.a == s1)) def test_correlate_disabled_where(self): t1, t2, s1 = self._fixture() 
self._assert_where_uncorrelated( - select([t2]).where(t2.c.a == s1.correlate(None))) + select([t2]).where(t2.c.a == s1.correlate(None)) + ) def test_correlate_disabled_column(self): t1, t2, s1 = self._fixture() self._assert_column_uncorrelated( - select([t2, s1.correlate(None).as_scalar()])) + select([t2, s1.correlate(None).as_scalar()]) + ) def test_correlate_disabled_from(self): t1, t2, s1 = self._fixture() self._assert_from_uncorrelated( - select([t2, s1.correlate(None).alias()])) + select([t2, s1.correlate(None).alias()]) + ) def test_correlate_disabled_having(self): t1, t2, s1 = self._fixture() self._assert_having_uncorrelated( - select([t2]).having(t2.c.a == s1.correlate(None))) + select([t2]).having(t2.c.a == s1.correlate(None)) + ) def test_correlate_all_where(self): t1, t2, s1 = self._fixture() self._assert_where_all_correlated( - select([t1, t2]).where(t2.c.a == s1.correlate(t1, t2))) + select([t1, t2]).where(t2.c.a == s1.correlate(t1, t2)) + ) def test_correlate_all_column(self): t1, t2, s1 = self._fixture() self._assert_column_all_correlated( - select([t1, t2, s1.correlate(t1, t2).as_scalar()])) + select([t1, t2, s1.correlate(t1, t2).as_scalar()]) + ) def test_correlate_all_from(self): t1, t2, s1 = self._fixture() self._assert_from_all_uncorrelated( - select([t1, t2, s1.correlate(t1, t2).alias()])) + select([t1, t2, s1.correlate(t1, t2).alias()]) + ) def test_correlate_where_all_unintentional(self): t1, t2, s1 = self._fixture() assert_raises_message( exc.InvalidRequestError, "returned no FROM clauses due to auto-correlation", - select([t1, t2]).where(t2.c.a == s1).compile + select([t1, t2]).where(t2.c.a == s1).compile, ) def test_correlate_from_all_ok(self): @@ -3624,16 +4008,16 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([t1, t2, s1]), "SELECT t1.a, t2.a, a FROM t1, t2, " - "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a)" + "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a)", ) def test_correlate_auto_where_singlefrom(self): t1, t2, s1 = self._fixture() s = select([t1.c.a]) s2 = select([t1]).where(t1.c.a == s) - self.assert_compile(s2, - "SELECT t1.a FROM t1 WHERE t1.a = " - "(SELECT t1.a FROM t1)") + self.assert_compile( + s2, "SELECT t1.a FROM t1 WHERE t1.a = " "(SELECT t1.a FROM t1)" + ) def test_correlate_semiauto_where_singlefrom(self): t1, t2, s1 = self._fixture() @@ -3654,89 +4038,103 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_alone_noeffect(self): # new as of #2668 t1, t2, s1 = self._fixture() - self.assert_compile(s1.correlate(t1, t2), - "SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a") + self.assert_compile( + s1.correlate(t1, t2), "SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a" + ) def test_correlate_except_froms(self): # new as of #2748 - t1 = table('t1', column('a')) - t2 = table('t2', column('a'), column('b')) + t1 = table("t1", column("a")) + t2 = table("t2", column("a"), column("b")) s = select([t2.c.b]).where(t1.c.a == t2.c.a) - s = s.correlate_except(t2).alias('s') + s = s.correlate_except(t2).alias("s") s2 = select([func.foo(s.c.b)]).as_scalar() s3 = select([t1], order_by=s2) self.assert_compile( - s3, "SELECT t1.a FROM t1 ORDER BY " + s3, + "SELECT t1.a FROM t1 ORDER BY " "(SELECT foo(s.b) AS foo_1 FROM " - "(SELECT t2.b AS b FROM t2 WHERE t1.a = t2.a) AS s)") + "(SELECT t2.b AS b FROM t2 WHERE t1.a = t2.a) AS s)", + ) def test_multilevel_froms_correlation(self): # new as of #2748 - p = table('parent', column('id')) - c = table('child', column('id'), column('parent_id'), 
column('pos')) + p = table("parent", column("id")) + c = table("child", column("id"), column("parent_id"), column("pos")) - s = c.select().where( - c.c.parent_id == p.c.id).order_by( - c.c.pos).limit(1) + s = ( + c.select() + .where(c.c.parent_id == p.c.id) + .order_by(c.c.pos) + .limit(1) + ) s = s.correlate(p) s = exists().select_from(s).where(s.c.id == 1) s = select([p]).where(s) self.assert_compile( - s, "SELECT parent.id FROM parent WHERE EXISTS (SELECT * " + s, + "SELECT parent.id FROM parent WHERE EXISTS (SELECT * " "FROM (SELECT child.id AS id, child.parent_id AS parent_id, " "child.pos AS pos FROM child WHERE child.parent_id = parent.id " - "ORDER BY child.pos LIMIT :param_1) WHERE id = :id_1)") + "ORDER BY child.pos LIMIT :param_1) WHERE id = :id_1)", + ) def test_no_contextless_correlate_except(self): # new as of #2748 - t1 = table('t1', column('x')) - t2 = table('t2', column('y')) - t3 = table('t3', column('z')) + t1 = table("t1", column("x")) + t2 = table("t2", column("y")) + t3 = table("t3", column("z")) - s = select([t1]).where(t1.c.x == t2.c.y).\ - where(t2.c.y == t3.c.z).correlate_except(t1) + s = ( + select([t1]) + .where(t1.c.x == t2.c.y) + .where(t2.c.y == t3.c.z) + .correlate_except(t1) + ) self.assert_compile( - s, - "SELECT t1.x FROM t1, t2, t3 WHERE t1.x = t2.y AND t2.y = t3.z") + s, "SELECT t1.x FROM t1, t2, t3 WHERE t1.x = t2.y AND t2.y = t3.z" + ) def test_multilevel_implicit_correlation_disabled(self): # test that implicit correlation with multilevel WHERE correlation # behaves like 0.8.1, 0.7 (i.e. doesn't happen) - t1 = table('t1', column('x')) - t2 = table('t2', column('y')) - t3 = table('t3', column('z')) + t1 = table("t1", column("x")) + t2 = table("t2", column("y")) + t3 = table("t3", column("z")) s = select([t1.c.x]).where(t1.c.x == t2.c.y) s2 = select([t3.c.z]).where(t3.c.z == s.as_scalar()) s3 = select([t1]).where(t1.c.x == s2.as_scalar()) - self.assert_compile(s3, - "SELECT t1.x FROM t1 " - "WHERE t1.x = (SELECT t3.z " - "FROM t3 " - "WHERE t3.z = (SELECT t1.x " - "FROM t1, t2 " - "WHERE t1.x = t2.y))" - ) + self.assert_compile( + s3, + "SELECT t1.x FROM t1 " + "WHERE t1.x = (SELECT t3.z " + "FROM t3 " + "WHERE t3.z = (SELECT t1.x " + "FROM t1, t2 " + "WHERE t1.x = t2.y))", + ) def test_from_implicit_correlation_disabled(self): # test that implicit correlation with immediate and # multilevel FROM clauses behaves like 0.8.1 (i.e. 
doesn't happen) - t1 = table('t1', column('x')) - t2 = table('t2', column('y')) s = select([t1.c.x]).where(t1.c.x == t2.c.y) s2 = select([t2, s]) s3 = select([t1, s2]) - self.assert_compile(s3, - "SELECT t1.x, y, x FROM t1, " - "(SELECT t2.y AS y, x FROM t2, " - "(SELECT t1.x AS x FROM t1, t2 WHERE t1.x = t2.y))" - ) + self.assert_compile( + s3, + "SELECT t1.x, y, x FROM t1, " + "(SELECT t2.y AS y, x FROM t2, " + "(SELECT t1.x AS x FROM t1, t2 WHERE t1.x = t2.y))", + ) class CoercionTest(fixtures.TestBase, AssertsCompiledSQL): @@ -3744,28 +4142,27 @@ class CoercionTest(fixtures.TestBase, AssertsCompiledSQL): def _fixture(self): m = MetaData() - return Table('foo', m, - Column('id', Integer)) + return Table("foo", m, Column("id", Integer)) - bool_table = table('t', column('x', Boolean)) + bool_table = table("t", column("x", Boolean)) def test_coerce_bool_where(self): self.assert_compile( select([self.bool_table]).where(self.bool_table.c.x), - "SELECT t.x FROM t WHERE t.x" + "SELECT t.x FROM t WHERE t.x", ) def test_coerce_bool_where_non_native(self): self.assert_compile( select([self.bool_table]).where(self.bool_table.c.x), "SELECT t.x FROM t WHERE t.x = 1", - dialect=default.DefaultDialect(supports_native_boolean=False) + dialect=default.DefaultDialect(supports_native_boolean=False), ) self.assert_compile( select([self.bool_table]).where(~self.bool_table.c.x), "SELECT t.x FROM t WHERE t.x = 0", - dialect=default.DefaultDialect(supports_native_boolean=False) + dialect=default.DefaultDialect(supports_native_boolean=False), ) def test_null_constant(self): @@ -3779,40 +4176,35 @@ class CoercionTest(fixtures.TestBase, AssertsCompiledSQL): def test_val_and_false(self): t = self._fixture() - self.assert_compile(and_(t.c.id == 1, False), - "false") + self.assert_compile(and_(t.c.id == 1, False), "false") def test_val_and_true_coerced(self): t = self._fixture() - self.assert_compile(and_(t.c.id == 1, True), - "foo.id = :id_1") + self.assert_compile(and_(t.c.id == 1, True), "foo.id = :id_1") def test_val_is_null_coerced(self): t = self._fixture() - self.assert_compile(and_(t.c.id == None), # noqa - "foo.id IS NULL") + self.assert_compile(and_(t.c.id == None), "foo.id IS NULL") # noqa def test_val_and_None(self): t = self._fixture() - self.assert_compile(and_(t.c.id == 1, None), - "foo.id = :id_1 AND NULL") + self.assert_compile(and_(t.c.id == 1, None), "foo.id = :id_1 AND NULL") def test_None_and_val(self): t = self._fixture() - self.assert_compile(and_(None, t.c.id == 1), - "NULL AND foo.id = :id_1") + self.assert_compile(and_(None, t.c.id == 1), "NULL AND foo.id = :id_1") def test_None_and_nothing(self): # current convention is None in and_() # returns None. May want # to revise this at some point.
- self.assert_compile( - and_(None), "NULL") + self.assert_compile(and_(None), "NULL") def test_val_and_null(self): t = self._fixture() - self.assert_compile(and_(t.c.id == 1, null()), - "foo.id = :id_1 AND NULL") + self.assert_compile( + and_(t.c.id == 1, null()), "foo.id = :id_1 AND NULL" + ) class ResultMapTest(fixtures.TestBase): @@ -3823,101 +4215,109 @@ class ResultMapTest(fixtures.TestBase): """ def test_compound_populates(self): - t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer)) + t = Table("t", MetaData(), Column("a", Integer), Column("b", Integer)) stmt = select([t]).union(select([t])) comp = stmt.compile() eq_( comp._create_result_map(), - {'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type), - 'b': ('b', (t.c.b, 'b', 'b'), t.c.b.type)} + { + "a": ("a", (t.c.a, "a", "a"), t.c.a.type), + "b": ("b", (t.c.b, "b", "b"), t.c.b.type), + }, ) def test_compound_not_toplevel_doesnt_populate(self): - t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer)) + t = Table("t", MetaData(), Column("a", Integer), Column("b", Integer)) subq = select([t]).union(select([t])) stmt = select([t.c.a]).select_from(t.join(subq, t.c.a == subq.c.a)) comp = stmt.compile() eq_( comp._create_result_map(), - {'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type)} + {"a": ("a", (t.c.a, "a", "a"), t.c.a.type)}, ) def test_compound_only_top_populates(self): - t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer)) + t = Table("t", MetaData(), Column("a", Integer), Column("b", Integer)) stmt = select([t.c.a]).union(select([t.c.b])) comp = stmt.compile() eq_( comp._create_result_map(), - {'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type)}, + {"a": ("a", (t.c.a, "a", "a"), t.c.a.type)}, ) def test_label_plus_element(self): - t = Table('t', MetaData(), Column('a', Integer)) - l1 = t.c.a.label('bar') + t = Table("t", MetaData(), Column("a", Integer)) + l1 = t.c.a.label("bar") tc = type_coerce(t.c.a, String) stmt = select([t.c.a, l1, tc]) comp = stmt.compile() - tc_anon_label = comp._create_result_map()['anon_1'][1][0] + tc_anon_label = comp._create_result_map()["anon_1"][1][0] eq_( comp._create_result_map(), { - 'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type), - 'bar': ('bar', (l1, 'bar'), l1.type), - 'anon_1': ( - '%%(%d anon)s' % id(tc), - (tc_anon_label, 'anon_1', tc), tc.type), + "a": ("a", (t.c.a, "a", "a"), t.c.a.type), + "bar": ("bar", (l1, "bar"), l1.type), + "anon_1": ( + "%%(%d anon)s" % id(tc), + (tc_anon_label, "anon_1", tc), + tc.type, + ), }, ) def test_label_conflict_union(self): - t1 = Table('t1', MetaData(), Column('a', Integer), - Column('b', Integer)) - t2 = Table('t2', MetaData(), Column('t1_a', Integer)) + t1 = Table( + "t1", MetaData(), Column("a", Integer), Column("b", Integer) + ) + t2 = Table("t2", MetaData(), Column("t1_a", Integer)) union = select([t2]).union(select([t2])).alias() t1_alias = t1.alias() - stmt = select([t1, t1_alias]).select_from( - t1.join(union, t1.c.a == union.c.t1_a)).apply_labels() + stmt = ( + select([t1, t1_alias]) + .select_from(t1.join(union, t1.c.a == union.c.t1_a)) + .apply_labels() + ) comp = stmt.compile() eq_( set(comp._create_result_map()), - set(['t1_1_b', 't1_1_a', 't1_a', 't1_b']) - ) - is_( - comp._create_result_map()['t1_a'][1][2], t1.c.a + set(["t1_1_b", "t1_1_a", "t1_a", "t1_b"]), ) + is_(comp._create_result_map()["t1_a"][1][2], t1.c.a) def test_insert_with_select_values(self): - astring = Column('a', String) - aint = Column('a', Integer) + astring = Column("a", String) + aint = Column("a", Integer) m = MetaData() - Table('t1', m, astring) 
- t2 = Table('t2', m, aint) + Table("t1", m, astring) + t2 = Table("t2", m, aint) stmt = t2.insert().values(a=select([astring])).returning(aint) comp = stmt.compile(dialect=postgresql.dialect()) eq_( comp._create_result_map(), - {'a': ('a', (aint, 'a', 'a'), aint.type)} + {"a": ("a", (aint, "a", "a"), aint.type)}, ) def test_insert_from_select(self): - astring = Column('a', String) - aint = Column('a', Integer) + astring = Column("a", String) + aint = Column("a", Integer) m = MetaData() - Table('t1', m, astring) - t2 = Table('t2', m, aint) + Table("t1", m, astring) + t2 = Table("t2", m, aint) - stmt = t2.insert().from_select(['a'], select([astring])).\ - returning(aint) + stmt = ( + t2.insert().from_select(["a"], select([astring])).returning(aint) + ) comp = stmt.compile(dialect=postgresql.dialect()) eq_( comp._create_result_map(), - {'a': ('a', (aint, 'a', 'a'), aint.type)} + {"a": ("a", (aint, "a", "a"), aint.type)}, ) def test_nested_api(self): from sqlalchemy.engine.result import ResultMetaData + stmt2 = select([table2]) stmt1 = select([table1]).select_from(stmt2) @@ -3936,7 +4336,8 @@ class ResultMapTest(fixtures.TestBase): self._add_to_result_map("k1", "k1", (1, 2, 3), int_) else: text = super(MyCompiler, self).visit_select( - stmt, *arg, **kw) + stmt, *arg, **kw + ) self._add_to_result_map("k2", "k2", (3, 4, 5), int_) return text @@ -3945,62 +4346,68 @@ class ResultMapTest(fixtures.TestBase): eq_( ResultMetaData._create_result_map(contexts[stmt2][0]), { - 'otherid': ( - 'otherid', - (table2.c.otherid, 'otherid', 'otherid'), - table2.c.otherid.type), - 'othername': ( - 'othername', - (table2.c.othername, 'othername', 'othername'), - table2.c.othername.type), - 'k1': ('k1', (1, 2, 3), int_) - } + "otherid": ( + "otherid", + (table2.c.otherid, "otherid", "otherid"), + table2.c.otherid.type, + ), + "othername": ( + "othername", + (table2.c.othername, "othername", "othername"), + table2.c.othername.type, + ), + "k1": ("k1", (1, 2, 3), int_), + }, ) eq_( comp._create_result_map(), { - 'myid': ( - 'myid', - (table1.c.myid, 'myid', 'myid'), table1.c.myid.type + "myid": ( + "myid", + (table1.c.myid, "myid", "myid"), + table1.c.myid.type, + ), + "k2": ("k2", (3, 4, 5), int_), + "name": ( + "name", + (table1.c.name, "name", "name"), + table1.c.name.type, ), - 'k2': ('k2', (3, 4, 5), int_), - 'name': ( - 'name', (table1.c.name, 'name', 'name'), - table1.c.name.type), - 'description': ( - 'description', - (table1.c.description, 'description', 'description'), - table1.c.description.type)} + "description": ( + "description", + (table1.c.description, "description", "description"), + table1.c.description.type, + ), + }, ) def test_select_wraps_for_translate_ambiguity(self): # test for issue #3657 - t = table('a', column('x'), column('y'), column('z')) + t = table("a", column("x"), column("y"), column("z")) - l1, l2, l3 = t.c.z.label('a'), t.c.x.label('b'), t.c.x.label('c') + l1, l2, l3 = t.c.z.label("a"), t.c.x.label("b"), t.c.x.label("c") orig = [t.c.x, t.c.y, l1, l2, l3] stmt = select(orig) wrapped = stmt._generate() wrapped = wrapped.column( - func.ROW_NUMBER().over(order_by=t.c.z)).alias() + func.ROW_NUMBER().over(order_by=t.c.z) + ).alias() wrapped_again = select([c for c in wrapped.c]) compiled = wrapped_again.compile( - compile_kwargs={'select_wraps_for': stmt}) + compile_kwargs={"select_wraps_for": stmt} + ) proxied = [obj[0] for (k, n, obj, type_) in compiled._result_columns] - for orig_obj, proxied_obj in zip( - orig, - proxied - ): + for orig_obj, proxied_obj in zip(orig, proxied): 
is_(orig_obj, proxied_obj) def test_select_wraps_for_translate_ambiguity_dupe_cols(self): # test for issue #3657 - t = table('a', column('x'), column('y'), column('z')) + t = table("a", column("x"), column("y"), column("z")) - l1, l2, l3 = t.c.z.label('a'), t.c.x.label('b'), t.c.x.label('c') + l1, l2, l3 = t.c.z.label("a"), t.c.x.label("b"), t.c.x.label("c") orig = [t.c.x, t.c.y, l1, l2, l3] # create the statement with some duplicate columns. right now @@ -4018,7 +4425,8 @@ class ResultMapTest(fixtures.TestBase): wrapped = stmt._generate() wrapped = wrapped.column( - func.ROW_NUMBER().over(order_by=t.c.z)).alias() + func.ROW_NUMBER().over(order_by=t.c.z) + ).alias() # so when we wrap here we're going to have only 5 columns wrapped_again = select([c for c in wrapped.c]) @@ -4027,11 +4435,9 @@ class ResultMapTest(fixtures.TestBase): # "select_wraps_for" can't use inner_columns to match because # these collections are not the same compiled = wrapped_again.compile( - compile_kwargs={'select_wraps_for': stmt}) + compile_kwargs={"select_wraps_for": stmt} + ) proxied = [obj[0] for (k, n, obj, type_) in compiled._result_columns] - for orig_obj, proxied_obj in zip( - orig, - proxied - ): + for orig_obj, proxied_obj in zip(orig, proxied): is_(orig_obj, proxied_obj) diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 3365b3cf0d..a5d2043ce8 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -1,81 +1,109 @@ from sqlalchemy.testing import assert_raises, assert_raises_message -from sqlalchemy import Table, Integer, String, Column, PrimaryKeyConstraint,\ - ForeignKeyConstraint, ForeignKey, UniqueConstraint, Index, MetaData, \ - CheckConstraint, func, text +from sqlalchemy import ( + Table, + Integer, + String, + Column, + PrimaryKeyConstraint, + ForeignKeyConstraint, + ForeignKey, + UniqueConstraint, + Index, + MetaData, + CheckConstraint, + func, + text, +) from sqlalchemy import exc, schema -from sqlalchemy.testing import fixtures, AssertsExecutionResults, \ - AssertsCompiledSQL +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy import testing from sqlalchemy.engine import default from sqlalchemy.testing import engines from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing import eq_ -from sqlalchemy.testing.assertsql import (AllOf, - RegexSQL, - CompiledSQL, - DialectSQL) +from sqlalchemy.testing.assertsql import ( + AllOf, + RegexSQL, + CompiledSQL, + DialectSQL, +) from sqlalchemy.sql import table, column class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): - __dialect__ = 'default' + __dialect__ = "default" __backend__ = True @testing.provide_metadata def test_pk_fk_constraint_create(self): metadata = self.metadata - Table('employees', metadata, - Column('id', Integer), - Column('soc', String(40)), - Column('name', String(30)), - PrimaryKeyConstraint('id', 'soc') - ) - Table('elements', metadata, - Column('id', Integer), - Column('stuff', String(30)), - Column('emp_id', Integer), - Column('emp_soc', String(40)), - PrimaryKeyConstraint('id', name='elements_primkey'), - ForeignKeyConstraint(['emp_id', 'emp_soc'], - ['employees.id', 'employees.soc']) - ) + Table( + "employees", + metadata, + Column("id", Integer), + Column("soc", String(40)), + Column("name", String(30)), + PrimaryKeyConstraint("id", "soc"), + ) + Table( + "elements", + metadata, + Column("id", Integer), + Column("stuff", String(30)), + Column("emp_id", Integer), + 
Column("emp_soc", String(40)), + PrimaryKeyConstraint("id", name="elements_primkey"), + ForeignKeyConstraint( + ["emp_id", "emp_soc"], ["employees.id", "employees.soc"] + ), + ) self.assert_sql_execution( testing.db, lambda: metadata.create_all(checkfirst=False), - CompiledSQL('CREATE TABLE employees (' - 'id INTEGER NOT NULL, ' - 'soc VARCHAR(40) NOT NULL, ' - 'name VARCHAR(30), ' - 'PRIMARY KEY (id, soc)' - ')' - ), - CompiledSQL('CREATE TABLE elements (' - 'id INTEGER NOT NULL, ' - 'stuff VARCHAR(30), ' - 'emp_id INTEGER, ' - 'emp_soc VARCHAR(40), ' - 'CONSTRAINT elements_primkey PRIMARY KEY (id), ' - 'FOREIGN KEY(emp_id, emp_soc) ' - 'REFERENCES employees (id, soc)' - ')' - ) - ) - - @testing.force_drop_names('a', 'b') + CompiledSQL( + "CREATE TABLE employees (" + "id INTEGER NOT NULL, " + "soc VARCHAR(40) NOT NULL, " + "name VARCHAR(30), " + "PRIMARY KEY (id, soc)" + ")" + ), + CompiledSQL( + "CREATE TABLE elements (" + "id INTEGER NOT NULL, " + "stuff VARCHAR(30), " + "emp_id INTEGER, " + "emp_soc VARCHAR(40), " + "CONSTRAINT elements_primkey PRIMARY KEY (id), " + "FOREIGN KEY(emp_id, emp_soc) " + "REFERENCES employees (id, soc)" + ")" + ), + ) + + @testing.force_drop_names("a", "b") def test_fk_cant_drop_cycled_unnamed(self): metadata = MetaData() - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer), - ForeignKeyConstraint(["bid"], ["b.id"]) - ) Table( - "b", metadata, - Column('id', Integer, primary_key=True), + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer), + ForeignKeyConstraint(["bid"], ["b.id"]), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), Column("aid", Integer), - ForeignKeyConstraint(["aid"], ["a.id"])) + ForeignKeyConstraint(["aid"], ["a.id"]), + ) metadata.create_all(testing.db) if testing.db.dialect.supports_alter: assert_raises_message( @@ -85,159 +113,186 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): "that the ForeignKey and ForeignKeyConstraint objects " "involved in the cycle have names so that they can be " "dropped using DROP CONSTRAINT.", - metadata.drop_all, testing.db + metadata.drop_all, + testing.db, ) else: with expect_warnings( - "Can't sort tables for DROP; an unresolvable " - "foreign key dependency "): + "Can't sort tables for DROP; an unresolvable " + "foreign key dependency " + ): with self.sql_execution_asserter() as asserter: metadata.drop_all(testing.db, checkfirst=False) asserter.assert_( - AllOf( - CompiledSQL("DROP TABLE a"), - CompiledSQL("DROP TABLE b") - ) + AllOf(CompiledSQL("DROP TABLE a"), CompiledSQL("DROP TABLE b")) ) @testing.provide_metadata def test_fk_table_auto_alter_constraint_create(self): metadata = self.metadata - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer), - ForeignKeyConstraint(["bid"], ["b.id"]) - ) Table( - "b", metadata, - Column('id', Integer, primary_key=True), + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer), + ForeignKeyConstraint(["bid"], ["b.id"]), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), Column("aid", Integer), - ForeignKeyConstraint(["aid"], ["a.id"], name="bfk")) + ForeignKeyConstraint(["aid"], ["a.id"], name="bfk"), + ) self._assert_cyclic_constraint( - metadata, auto=True, sqlite_warning=True) + metadata, auto=True, sqlite_warning=True + ) @testing.provide_metadata def test_fk_column_auto_alter_inline_constraint_create(self): metadata = self.metadata - Table("a", 
metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey("b.id")), - ) - Table("b", metadata, - Column('id', Integer, primary_key=True), - Column("aid", Integer, - ForeignKey("a.id", name="bfk") - ), - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id", name="bfk")), + ) self._assert_cyclic_constraint( - metadata, auto=True, sqlite_warning=True) + metadata, auto=True, sqlite_warning=True + ) @testing.provide_metadata def test_fk_column_use_alter_inline_constraint_create(self): metadata = self.metadata - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey("b.id")), - ) - Table("b", metadata, - Column('id', Integer, primary_key=True), - Column("aid", Integer, - ForeignKey("a.id", name="bfk", use_alter=True) - ), - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column( + "aid", Integer, ForeignKey("a.id", name="bfk", use_alter=True) + ), + ) self._assert_cyclic_constraint(metadata, auto=False) @testing.provide_metadata def test_fk_table_use_alter_constraint_create(self): metadata = self.metadata - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer), - ForeignKeyConstraint(["bid"], ["b.id"]) - ) Table( - "b", metadata, - Column('id', Integer, primary_key=True), + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer), + ForeignKeyConstraint(["bid"], ["b.id"]), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), Column("aid", Integer), ForeignKeyConstraint( - ["aid"], ["a.id"], use_alter=True, name="bfk")) + ["aid"], ["a.id"], use_alter=True, name="bfk" + ), + ) self._assert_cyclic_constraint(metadata) @testing.provide_metadata def test_fk_column_use_alter_constraint_create(self): metadata = self.metadata - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey("b.id")), - ) - Table("b", metadata, - Column('id', Integer, primary_key=True), - Column("aid", Integer, - ForeignKey("a.id", use_alter=True, name="bfk") - ), - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column( + "aid", Integer, ForeignKey("a.id", use_alter=True, name="bfk") + ), + ) self._assert_cyclic_constraint(metadata, auto=False) def _assert_cyclic_constraint( - self, metadata, auto=False, sqlite_warning=False): + self, metadata, auto=False, sqlite_warning=False + ): if testing.db.dialect.supports_alter: self._assert_cyclic_constraint_supports_alter(metadata, auto=auto) else: self._assert_cyclic_constraint_no_alter( - metadata, auto=auto, sqlite_warning=sqlite_warning) + metadata, auto=auto, sqlite_warning=sqlite_warning + ) def _assert_cyclic_constraint_supports_alter(self, metadata, auto=False): table_assertions = [] if auto: table_assertions = [ - CompiledSQL('CREATE TABLE b (' - 'id INTEGER NOT NULL, ' - 'aid INTEGER, ' - 'PRIMARY KEY (id)' - ')' - ), CompiledSQL( - 'CREATE TABLE a (' - 'id INTEGER NOT NULL, ' - 'bid INTEGER, ' - 'PRIMARY KEY (id)' - ')' - ) + "CREATE TABLE b (" + "id INTEGER NOT NULL, " + "aid INTEGER, " + "PRIMARY KEY 
(id)" + ")" + ), + CompiledSQL( + "CREATE TABLE a (" + "id INTEGER NOT NULL, " + "bid INTEGER, " + "PRIMARY KEY (id)" + ")" + ), ] else: table_assertions = [ - CompiledSQL('CREATE TABLE b (' - 'id INTEGER NOT NULL, ' - 'aid INTEGER, ' - 'PRIMARY KEY (id)' - ')' - ), CompiledSQL( - 'CREATE TABLE a (' - 'id INTEGER NOT NULL, ' - 'bid INTEGER, ' - 'PRIMARY KEY (id), ' - 'FOREIGN KEY(bid) REFERENCES b (id)' - ')' - ) + "CREATE TABLE b (" + "id INTEGER NOT NULL, " + "aid INTEGER, " + "PRIMARY KEY (id)" + ")" + ), + CompiledSQL( + "CREATE TABLE a (" + "id INTEGER NOT NULL, " + "bid INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(bid) REFERENCES b (id)" + ")" + ), ] assertions = [AllOf(*table_assertions)] fk_assertions = [] fk_assertions.append( - CompiledSQL('ALTER TABLE b ADD CONSTRAINT bfk ' - 'FOREIGN KEY(aid) REFERENCES a (id)') + CompiledSQL( + "ALTER TABLE b ADD CONSTRAINT bfk " + "FOREIGN KEY(aid) REFERENCES a (id)" + ) ) if auto: fk_assertions.append( - CompiledSQL('ALTER TABLE a ADD ' - 'FOREIGN KEY(bid) REFERENCES b (id)') + CompiledSQL( + "ALTER TABLE a ADD " "FOREIGN KEY(bid) REFERENCES b (id)" + ) ) assertions.append(AllOf(*fk_assertions)) @@ -246,9 +301,9 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_(*assertions) assertions = [ - CompiledSQL('ALTER TABLE b DROP CONSTRAINT bfk'), + CompiledSQL("ALTER TABLE b DROP CONSTRAINT bfk"), CompiledSQL("DROP TABLE a"), - CompiledSQL("DROP TABLE b") + CompiledSQL("DROP TABLE b"), ] with self.sql_execution_asserter() as asserter: @@ -256,49 +311,50 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_(*assertions) def _assert_cyclic_constraint_no_alter( - self, metadata, auto=False, sqlite_warning=False): + self, metadata, auto=False, sqlite_warning=False + ): table_assertions = [] if auto: table_assertions.append( DialectSQL( - 'CREATE TABLE b (' - 'id INTEGER NOT NULL, ' - 'aid INTEGER, ' - 'PRIMARY KEY (id), ' - 'CONSTRAINT bfk FOREIGN KEY(aid) REFERENCES a (id)' - ')' + "CREATE TABLE b (" + "id INTEGER NOT NULL, " + "aid INTEGER, " + "PRIMARY KEY (id), " + "CONSTRAINT bfk FOREIGN KEY(aid) REFERENCES a (id)" + ")" ) ) table_assertions.append( DialectSQL( - 'CREATE TABLE a (' - 'id INTEGER NOT NULL, ' - 'bid INTEGER, ' - 'PRIMARY KEY (id), ' - 'FOREIGN KEY(bid) REFERENCES b (id)' - ')' + "CREATE TABLE a (" + "id INTEGER NOT NULL, " + "bid INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(bid) REFERENCES b (id)" + ")" ) ) else: table_assertions.append( DialectSQL( - 'CREATE TABLE b (' - 'id INTEGER NOT NULL, ' - 'aid INTEGER, ' - 'PRIMARY KEY (id), ' - 'CONSTRAINT bfk FOREIGN KEY(aid) REFERENCES a (id)' - ')' + "CREATE TABLE b (" + "id INTEGER NOT NULL, " + "aid INTEGER, " + "PRIMARY KEY (id), " + "CONSTRAINT bfk FOREIGN KEY(aid) REFERENCES a (id)" + ")" ) ) table_assertions.append( DialectSQL( - 'CREATE TABLE a (' - 'id INTEGER NOT NULL, ' - 'bid INTEGER, ' - 'PRIMARY KEY (id), ' - 'FOREIGN KEY(bid) REFERENCES b (id)' - ')' + "CREATE TABLE a (" + "id INTEGER NOT NULL, " + "bid INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(bid) REFERENCES b (id)" + ")" ) ) @@ -308,10 +364,9 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): metadata.create_all(checkfirst=False) asserter.assert_(*assertions) - assertions = [AllOf( - CompiledSQL("DROP TABLE a"), - CompiledSQL("DROP TABLE b") - )] + assertions = [ + AllOf(CompiledSQL("DROP TABLE a"), CompiledSQL("DROP TABLE b")) + ] if sqlite_warning: with expect_warnings("Can't sort tables for DROP; "): @@ 
-326,38 +381,44 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): def test_cycle_unnamed_fks(self): metadata = MetaData(testing.db) - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey("b.id")), - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) - Table("b", metadata, - Column('id', Integer, primary_key=True), - Column("aid", Integer, ForeignKey("a.id")), - ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id")), + ) assertions = [ AllOf( CompiledSQL( - 'CREATE TABLE b (' - 'id INTEGER NOT NULL, ' - 'aid INTEGER, ' - 'PRIMARY KEY (id)' - ')' + "CREATE TABLE b (" + "id INTEGER NOT NULL, " + "aid INTEGER, " + "PRIMARY KEY (id)" + ")" ), CompiledSQL( - 'CREATE TABLE a (' - 'id INTEGER NOT NULL, ' - 'bid INTEGER, ' - 'PRIMARY KEY (id)' - ')' - ) + "CREATE TABLE a (" + "id INTEGER NOT NULL, " + "bid INTEGER, " + "PRIMARY KEY (id)" + ")" + ), ), AllOf( - CompiledSQL('ALTER TABLE b ADD ' - 'FOREIGN KEY(aid) REFERENCES a (id)'), - CompiledSQL('ALTER TABLE a ADD ' - 'FOREIGN KEY(bid) REFERENCES b (id)') + CompiledSQL( + "ALTER TABLE b ADD " "FOREIGN KEY(aid) REFERENCES a (id)" + ), + CompiledSQL( + "ALTER TABLE a ADD " "FOREIGN KEY(bid) REFERENCES b (id)" + ), ), ] with self.sql_execution_asserter() as asserter: @@ -374,58 +435,65 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): "ForeignKey and ForeignKeyConstraint objects involved in the " "cycle have names so that they can be dropped using " "DROP CONSTRAINT.", - metadata.drop_all, checkfirst=False + metadata.drop_all, + checkfirst=False, ) else: with expect_warnings( - "Can't sort tables for DROP; an unresolvable " - "foreign key dependency exists between tables"): + "Can't sort tables for DROP; an unresolvable " + "foreign key dependency exists between tables" + ): with self.sql_execution_asserter() as asserter: metadata.drop_all(checkfirst=False) asserter.assert_( - AllOf( - CompiledSQL("DROP TABLE b"), - CompiledSQL("DROP TABLE a"), - ) + AllOf(CompiledSQL("DROP TABLE b"), CompiledSQL("DROP TABLE a")) ) @testing.force_drop_names("a", "b") def test_cycle_named_fks(self): metadata = MetaData(testing.db) - Table("a", metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey("b.id")), - ) + Table( + "a", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) - Table("b", metadata, - Column('id', Integer, primary_key=True), - Column( - "aid", Integer, - ForeignKey("a.id", use_alter=True, name='aidfk')), - ) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column( + "aid", + Integer, + ForeignKey("a.id", use_alter=True, name="aidfk"), + ), + ) assertions = [ AllOf( CompiledSQL( - 'CREATE TABLE b (' - 'id INTEGER NOT NULL, ' - 'aid INTEGER, ' - 'PRIMARY KEY (id)' - ')' + "CREATE TABLE b (" + "id INTEGER NOT NULL, " + "aid INTEGER, " + "PRIMARY KEY (id)" + ")" ), CompiledSQL( - 'CREATE TABLE a (' - 'id INTEGER NOT NULL, ' - 'bid INTEGER, ' - 'PRIMARY KEY (id), ' - 'FOREIGN KEY(bid) REFERENCES b (id)' - ')' - ) + "CREATE TABLE a (" + "id INTEGER NOT NULL, " + "bid INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(bid) REFERENCES b (id)" + ")" + ), + ), + CompiledSQL( + "ALTER TABLE b ADD CONSTRAINT aidfk " + "FOREIGN KEY(aid) REFERENCES a (id)" ), - CompiledSQL('ALTER TABLE b ADD CONSTRAINT aidfk ' - 'FOREIGN KEY(aid) 
REFERENCES a (id)'), ] with self.sql_execution_asserter() as asserter: metadata.create_all(checkfirst=False) @@ -439,19 +507,15 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_( CompiledSQL("ALTER TABLE b DROP CONSTRAINT aidfk"), AllOf( - CompiledSQL("DROP TABLE b"), - CompiledSQL("DROP TABLE a"), - ) + CompiledSQL("DROP TABLE b"), CompiledSQL("DROP TABLE a") + ), ) else: with self.sql_execution_asserter() as asserter: metadata.drop_all(checkfirst=False) asserter.assert_( - AllOf( - CompiledSQL("DROP TABLE b"), - CompiledSQL("DROP TABLE a"), - ), + AllOf(CompiledSQL("DROP TABLE b"), CompiledSQL("DROP TABLE a")) ) @testing.requires.check_constraints @@ -459,89 +523,112 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): def test_check_constraint_create(self): metadata = self.metadata - Table('foo', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - CheckConstraint('x>y')) - Table('bar', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer, CheckConstraint('x>7')), - Column('z', Integer) - ) + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + CheckConstraint("x>y"), + ) + Table( + "bar", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer, CheckConstraint("x>7")), + Column("z", Integer), + ) self.assert_sql_execution( testing.db, lambda: metadata.create_all(checkfirst=False), AllOf( - CompiledSQL('CREATE TABLE foo (' - 'id INTEGER NOT NULL, ' - 'x INTEGER, ' - 'y INTEGER, ' - 'PRIMARY KEY (id), ' - 'CHECK (x>y)' - ')' - ), - CompiledSQL('CREATE TABLE bar (' - 'id INTEGER NOT NULL, ' - 'x INTEGER CHECK (x>7), ' - 'z INTEGER, ' - 'PRIMARY KEY (id)' - ')' - ) - ) + CompiledSQL( + "CREATE TABLE foo (" + "id INTEGER NOT NULL, " + "x INTEGER, " + "y INTEGER, " + "PRIMARY KEY (id), " + "CHECK (x>y)" + ")" + ), + CompiledSQL( + "CREATE TABLE bar (" + "id INTEGER NOT NULL, " + "x INTEGER CHECK (x>7), " + "z INTEGER, " + "PRIMARY KEY (id)" + ")" + ), + ), ) @testing.provide_metadata def test_unique_constraint_create(self): metadata = self.metadata - Table('foo', metadata, - Column('id', Integer, primary_key=True), - Column('value', String(30), unique=True)) - Table('bar', metadata, - Column('id', Integer, primary_key=True), - Column('value', String(30)), - Column('value2', String(30)), - UniqueConstraint('value', 'value2', name='uix1') - ) + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("value", String(30), unique=True), + ) + Table( + "bar", + metadata, + Column("id", Integer, primary_key=True), + Column("value", String(30)), + Column("value2", String(30)), + UniqueConstraint("value", "value2", name="uix1"), + ) self.assert_sql_execution( testing.db, lambda: metadata.create_all(checkfirst=False), AllOf( - CompiledSQL('CREATE TABLE foo (' - 'id INTEGER NOT NULL, ' - 'value VARCHAR(30), ' - 'PRIMARY KEY (id), ' - 'UNIQUE (value)' - ')'), - CompiledSQL('CREATE TABLE bar (' - 'id INTEGER NOT NULL, ' - 'value VARCHAR(30), ' - 'value2 VARCHAR(30), ' - 'PRIMARY KEY (id), ' - 'CONSTRAINT uix1 UNIQUE (value, value2)' - ')') - ) + CompiledSQL( + "CREATE TABLE foo (" + "id INTEGER NOT NULL, " + "value VARCHAR(30), " + "PRIMARY KEY (id), " + "UNIQUE (value)" + ")" + ), + CompiledSQL( + "CREATE TABLE bar (" + "id INTEGER NOT NULL, " + "value VARCHAR(30), " + "value2 VARCHAR(30), " + "PRIMARY KEY (id), " + "CONSTRAINT uix1 UNIQUE (value, value2)" + ")" + 
), + ), ) @testing.provide_metadata def test_index_create(self): metadata = self.metadata - employees = Table('employees', metadata, - Column('id', Integer, primary_key=True), - Column('first_name', String(30)), - Column('last_name', String(30)), - Column('email_address', String(30))) + employees = Table( + "employees", + metadata, + Column("id", Integer, primary_key=True), + Column("first_name", String(30)), + Column("last_name", String(30)), + Column("email_address", String(30)), + ) - i = Index('employee_name_index', - employees.c.last_name, employees.c.first_name) + i = Index( + "employee_name_index", + employees.c.last_name, + employees.c.first_name, + ) assert i in employees.indexes - i2 = Index('employee_email_index', - employees.c.email_address, unique=True) + i2 = Index( + "employee_email_index", employees.c.email_address, unique=True + ) assert i2 in employees.indexes self.assert_sql_execution( @@ -549,11 +636,17 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): lambda: metadata.create_all(checkfirst=False), RegexSQL("^CREATE TABLE"), AllOf( - CompiledSQL('CREATE INDEX employee_name_index ON ' - 'employees (last_name, first_name)', []), - CompiledSQL('CREATE UNIQUE INDEX employee_email_index ON ' - 'employees (email_address)', []) - ) + CompiledSQL( + "CREATE INDEX employee_name_index ON " + "employees (last_name, first_name)", + [], + ), + CompiledSQL( + "CREATE UNIQUE INDEX employee_email_index ON " + "employees (email_address)", + [], + ), + ), ) @testing.provide_metadata @@ -562,27 +655,36 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): metadata = self.metadata - employees = Table('companyEmployees', metadata, - Column('id', Integer, primary_key=True), - Column('firstName', String(30)), - Column('lastName', String(30)), - Column('emailAddress', String(30))) + employees = Table( + "companyEmployees", + metadata, + Column("id", Integer, primary_key=True), + Column("firstName", String(30)), + Column("lastName", String(30)), + Column("emailAddress", String(30)), + ) - Index('employeeNameIndex', - employees.c.lastName, employees.c.firstName) + Index("employeeNameIndex", employees.c.lastName, employees.c.firstName) - Index('employeeEmailIndex', - employees.c.emailAddress, unique=True) + Index("employeeEmailIndex", employees.c.emailAddress, unique=True) self.assert_sql_execution( - testing.db, lambda: metadata.create_all( - checkfirst=False), RegexSQL("^CREATE TABLE"), AllOf( + testing.db, + lambda: metadata.create_all(checkfirst=False), + RegexSQL("^CREATE TABLE"), + AllOf( CompiledSQL( 'CREATE INDEX "employeeNameIndex" ON ' - '"companyEmployees" ("lastName", "firstName")', []), + '"companyEmployees" ("lastName", "firstName")', + [], + ), CompiledSQL( 'CREATE UNIQUE INDEX "employeeEmailIndex" ON ' - '"companyEmployees" ("emailAddress")', []))) + '"companyEmployees" ("emailAddress")', + [], + ), + ), + ) @testing.provide_metadata def test_index_create_inline(self): @@ -590,22 +692,32 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): metadata = self.metadata - events = Table('events', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), index=True, unique=True), - Column('location', String(30), index=True), - Column('sport', String(30)), - Column('announcer', String(30)), - Column('winner', String(30))) + events = Table( + "events", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), index=True, unique=True), + Column("location", String(30), index=True), + 
Column("sport", String(30)), + Column("announcer", String(30)), + Column("winner", String(30)), + ) - Index('sport_announcer', events.c.sport, events.c.announcer, - unique=True) - Index('idx_winners', events.c.winner) + Index( + "sport_announcer", events.c.sport, events.c.announcer, unique=True + ) + Index("idx_winners", events.c.winner) eq_( set(ix.name for ix in events.indexes), - set(['ix_events_name', 'ix_events_location', - 'sport_announcer', 'idx_winners']) + set( + [ + "ix_events_name", + "ix_events_location", + "sport_announcer", + "idx_winners", + ] + ), ) self.assert_sql_execution( @@ -613,72 +725,71 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): lambda: events.create(testing.db), RegexSQL("^CREATE TABLE events"), AllOf( - CompiledSQL('CREATE UNIQUE INDEX ix_events_name ON events ' - '(name)'), - CompiledSQL('CREATE INDEX ix_events_location ON events ' - '(location)'), - CompiledSQL('CREATE UNIQUE INDEX sport_announcer ON events ' - '(sport, announcer)'), - CompiledSQL('CREATE INDEX idx_winners ON events (winner)'), - ) + CompiledSQL( + "CREATE UNIQUE INDEX ix_events_name ON events " "(name)" + ), + CompiledSQL( + "CREATE INDEX ix_events_location ON events " "(location)" + ), + CompiledSQL( + "CREATE UNIQUE INDEX sport_announcer ON events " + "(sport, announcer)" + ), + CompiledSQL("CREATE INDEX idx_winners ON events (winner)"), + ), ) @testing.provide_metadata def test_index_functional_create(self): metadata = self.metadata - t = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - Index('myindex', t.c.data.desc()) + t = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Index("myindex", t.c.data.desc()) self.assert_sql_execution( testing.db, lambda: t.create(testing.db), - CompiledSQL('CREATE TABLE sometable (id INTEGER NOT NULL, ' - 'data VARCHAR(50), PRIMARY KEY (id))'), - CompiledSQL('CREATE INDEX myindex ON sometable (data DESC)') + CompiledSQL( + "CREATE TABLE sometable (id INTEGER NOT NULL, " + "data VARCHAR(50), PRIMARY KEY (id))" + ), + CompiledSQL("CREATE INDEX myindex ON sometable (data DESC)"), ) class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_create_index_plain(self): - t = Table('t', MetaData(), Column('x', Integer)) + t = Table("t", MetaData(), Column("x", Integer)) i = Index("xyz", t.c.x) - self.assert_compile( - schema.CreateIndex(i), - "CREATE INDEX xyz ON t (x)" - ) + self.assert_compile(schema.CreateIndex(i), "CREATE INDEX xyz ON t (x)") def test_drop_index_plain_unattached(self): self.assert_compile( - schema.DropIndex(Index(name="xyz")), - "DROP INDEX xyz" + schema.DropIndex(Index(name="xyz")), "DROP INDEX xyz" ) def test_drop_index_plain(self): self.assert_compile( - schema.DropIndex(Index(name="xyz")), - "DROP INDEX xyz" + schema.DropIndex(Index(name="xyz")), "DROP INDEX xyz" ) def test_create_index_schema(self): - t = Table('t', MetaData(), Column('x', Integer), schema="foo") + t = Table("t", MetaData(), Column("x", Integer), schema="foo") i = Index("xyz", t.c.x) self.assert_compile( - schema.CreateIndex(i), - "CREATE INDEX xyz ON foo.t (x)" + schema.CreateIndex(i), "CREATE INDEX xyz ON foo.t (x)" ) def test_drop_index_schema(self): - t = Table('t', MetaData(), Column('x', Integer), schema="foo") + t = Table("t", MetaData(), Column("x", Integer), schema="foo") i = Index("xyz", t.c.x) - self.assert_compile( - schema.DropIndex(i), - 
"DROP INDEX foo.xyz" - ) + self.assert_compile(schema.DropIndex(i), "DROP INDEX foo.xyz") def test_too_long_index_name(self): dialect = testing.db.dialect.__class__() @@ -688,157 +799,171 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): dialect.max_index_name_length = max_index for tname, cname, exp in [ - ('sometable', 'this_name_is_too_long', 'ix_sometable_t_09aa'), - ('sometable', 'this_name_alsois_long', 'ix_sometable_t_3cf1'), + ("sometable", "this_name_is_too_long", "ix_sometable_t_09aa"), + ("sometable", "this_name_alsois_long", "ix_sometable_t_3cf1"), ]: - t1 = Table(tname, MetaData(), - Column(cname, Integer, index=True), - ) + t1 = Table( + tname, MetaData(), Column(cname, Integer, index=True) + ) ix1 = list(t1.indexes)[0] self.assert_compile( schema.CreateIndex(ix1), - "CREATE INDEX %s " - "ON %s (%s)" % (exp, tname, cname), - dialect=dialect + "CREATE INDEX %s " "ON %s (%s)" % (exp, tname, cname), + dialect=dialect, ) dialect.max_identifier_length = 22 dialect.max_index_name_length = None - t1 = Table('t', MetaData(), Column('c', Integer)) + t1 = Table("t", MetaData(), Column("c", Integer)) assert_raises( exc.IdentifierError, - schema.CreateIndex(Index( - "this_other_name_is_too_long_for_what_were_doing", - t1.c.c)).compile, - dialect=dialect + schema.CreateIndex( + Index( + "this_other_name_is_too_long_for_what_were_doing", t1.c.c + ) + ).compile, + dialect=dialect, ) def test_functional_index(self): metadata = MetaData() - x = Table('x', metadata, - Column('q', String(50)) - ) - idx = Index('y', func.lower(x.c.q)) + x = Table("x", metadata, Column("q", String(50))) + idx = Index("y", func.lower(x.c.q)) self.assert_compile( - schema.CreateIndex(idx), - "CREATE INDEX y ON x (lower(q))" + schema.CreateIndex(idx), "CREATE INDEX y ON x (lower(q))" ) self.assert_compile( schema.CreateIndex(idx), "CREATE INDEX y ON x (lower(q))", - dialect=testing.db.dialect + dialect=testing.db.dialect, ) def test_index_against_text_separate(self): metadata = MetaData() - idx = Index('y', text("some_function(q)")) - t = Table('x', metadata, - Column('q', String(50)) - ) + idx = Index("y", text("some_function(q)")) + t = Table("x", metadata, Column("q", String(50))) t.append_constraint(idx) self.assert_compile( - schema.CreateIndex(idx), - "CREATE INDEX y ON x (some_function(q))" + schema.CreateIndex(idx), "CREATE INDEX y ON x (some_function(q))" ) def test_index_against_text_inline(self): metadata = MetaData() - idx = Index('y', text("some_function(q)")) - x = Table('x', metadata, - Column('q', String(50)), - idx - ) + idx = Index("y", text("some_function(q)")) + x = Table("x", metadata, Column("q", String(50)), idx) self.assert_compile( - schema.CreateIndex(idx), - "CREATE INDEX y ON x (some_function(q))" + schema.CreateIndex(idx), "CREATE INDEX y ON x (some_function(q))" ) def test_index_declaration_inline(self): metadata = MetaData() - t1 = Table('t1', metadata, - Column('x', Integer), - Column('y', Integer), - Index('foo', 'x', 'y') - ) + t1 = Table( + "t1", + metadata, + Column("x", Integer), + Column("y", Integer), + Index("foo", "x", "y"), + ) self.assert_compile( schema.CreateIndex(list(t1.indexes)[0]), - "CREATE INDEX foo ON t1 (x, y)" + "CREATE INDEX foo ON t1 (x, y)", ) def _test_deferrable(self, constraint_factory): dialect = default.DefaultDialect() - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer), - constraint_factory(deferrable=True)) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column("b", Integer), + 
constraint_factory(deferrable=True), + ) sql = str(schema.CreateTable(t).compile(dialect=dialect)) - assert 'DEFERRABLE' in sql, sql - assert 'NOT DEFERRABLE' not in sql, sql + assert "DEFERRABLE" in sql, sql + assert "NOT DEFERRABLE" not in sql, sql - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer), - constraint_factory(deferrable=False)) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column("b", Integer), + constraint_factory(deferrable=False), + ) sql = str(schema.CreateTable(t).compile(dialect=dialect)) - assert 'NOT DEFERRABLE' in sql + assert "NOT DEFERRABLE" in sql - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer), - constraint_factory(deferrable=True, initially='IMMEDIATE')) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column("b", Integer), + constraint_factory(deferrable=True, initially="IMMEDIATE"), + ) sql = str(schema.CreateTable(t).compile(dialect=dialect)) - assert 'NOT DEFERRABLE' not in sql - assert 'INITIALLY IMMEDIATE' in sql + assert "NOT DEFERRABLE" not in sql + assert "INITIALLY IMMEDIATE" in sql - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer), - constraint_factory(deferrable=True, initially='DEFERRED')) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column("b", Integer), + constraint_factory(deferrable=True, initially="DEFERRED"), + ) sql = str(schema.CreateTable(t).compile(dialect=dialect)) - assert 'NOT DEFERRABLE' not in sql - assert 'INITIALLY DEFERRED' in sql + assert "NOT DEFERRABLE" not in sql + assert "INITIALLY DEFERRED" in sql def test_column_level_ck_name(self): t = Table( - 'tbl', + "tbl", MetaData(), Column( - 'a', + "a", Integer, - CheckConstraint( - "a > 5", - name="ck_a_greater_five"))) + CheckConstraint("a > 5", name="ck_a_greater_five"), + ), + ) self.assert_compile( schema.CreateTable(t), "CREATE TABLE tbl (a INTEGER CONSTRAINT " - "ck_a_greater_five CHECK (a > 5))" + "ck_a_greater_five CHECK (a > 5))", ) def test_deferrable_pk(self): - def factory(**kw): return PrimaryKeyConstraint('a', **kw) + def factory(**kw): + return PrimaryKeyConstraint("a", **kw) + self._test_deferrable(factory) def test_deferrable_table_fk(self): - def factory(**kw): return ForeignKeyConstraint(['b'], ['tbl.a'], **kw) + def factory(**kw): + return ForeignKeyConstraint(["b"], ["tbl.a"], **kw) + self._test_deferrable(factory) def test_deferrable_column_fk(self): - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer, - ForeignKey('tbl.a', deferrable=True, - initially='DEFERRED'))) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column( + "b", + Integer, + ForeignKey("tbl.a", deferrable=True, initially="DEFERRED"), + ), + ) self.assert_compile( schema.CreateTable(t), @@ -848,10 +973,12 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_fk_match_clause(self): - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer, - ForeignKey('tbl.a', match="SIMPLE"))) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column("b", Integer, ForeignKey("tbl.a", match="SIMPLE")), + ) self.assert_compile( schema.CreateTable(t), @@ -863,55 +990,64 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.AddConstraint(list(t.foreign_keys)[0].constraint), "ALTER TABLE tbl ADD FOREIGN KEY(b) " - "REFERENCES tbl (a) MATCH SIMPLE" + "REFERENCES tbl (a) MATCH SIMPLE", ) def test_create_table_omit_fks(self): fkcs = [ - 
ForeignKeyConstraint(['a'], ['remote.id'], name='foo'), - ForeignKeyConstraint(['b'], ['remote.id'], name='bar'), - ForeignKeyConstraint(['c'], ['remote.id'], name='bat'), + ForeignKeyConstraint(["a"], ["remote.id"], name="foo"), + ForeignKeyConstraint(["b"], ["remote.id"], name="bar"), + ForeignKeyConstraint(["c"], ["remote.id"], name="bat"), ] m = MetaData() t = Table( - 't', m, - Column('a', Integer), - Column('b', Integer), - Column('c', Integer), + "t", + m, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), *fkcs ) - Table('remote', m, Column('id', Integer, primary_key=True)) + Table("remote", m, Column("id", Integer, primary_key=True)) self.assert_compile( schema.CreateTable(t, include_foreign_key_constraints=[]), - "CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER)" + "CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER)", ) self.assert_compile( schema.CreateTable(t, include_foreign_key_constraints=fkcs[0:2]), "CREATE TABLE t (a INTEGER, b INTEGER, c INTEGER, " "CONSTRAINT foo FOREIGN KEY(a) REFERENCES remote (id), " - "CONSTRAINT bar FOREIGN KEY(b) REFERENCES remote (id))" + "CONSTRAINT bar FOREIGN KEY(b) REFERENCES remote (id))", ) def test_deferrable_unique(self): - def factory(**kw): return UniqueConstraint('b', **kw) + def factory(**kw): + return UniqueConstraint("b", **kw) + self._test_deferrable(factory) def test_deferrable_table_check(self): - def factory(**kw): return CheckConstraint('a < b', **kw) + def factory(**kw): + return CheckConstraint("a < b", **kw) + self._test_deferrable(factory) def test_multiple(self): m = MetaData() - Table("foo", m, - Column('id', Integer, primary_key=True), - Column('bar', Integer, primary_key=True) - ) - tb = Table("some_table", m, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('foo_bar', Integer, ForeignKey('foo.bar')), - ) + Table( + "foo", + m, + Column("id", Integer, primary_key=True), + Column("bar", Integer, primary_key=True), + ) + tb = Table( + "some_table", + m, + Column("id", Integer, primary_key=True), + Column("foo_id", Integer, ForeignKey("foo.id")), + Column("foo_bar", Integer, ForeignKey("foo.bar")), + ) self.assert_compile( schema.CreateTable(tb), "CREATE TABLE some_table (" @@ -920,93 +1056,106 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): "foo_bar INTEGER, " "PRIMARY KEY (id), " "FOREIGN KEY(foo_id) REFERENCES foo (id), " - "FOREIGN KEY(foo_bar) REFERENCES foo (bar))" + "FOREIGN KEY(foo_bar) REFERENCES foo (bar))", ) def test_empty_pkc(self): # test that an empty primary key is ignored metadata = MetaData() - tbl = Table('test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False), - PrimaryKeyConstraint()) - self.assert_compile(schema.CreateTable(tbl), - "CREATE TABLE test (x INTEGER, y INTEGER)" - ) + tbl = Table( + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, autoincrement=False), + PrimaryKeyConstraint(), + ) + self.assert_compile( + schema.CreateTable(tbl), "CREATE TABLE test (x INTEGER, y INTEGER)" + ) def test_empty_uc(self): # test that an empty constraint is ignored metadata = MetaData() - tbl = Table('test', metadata, - Column('x', Integer, autoincrement=False), - Column('y', Integer, autoincrement=False), - UniqueConstraint()) - self.assert_compile(schema.CreateTable(tbl), - "CREATE TABLE test (x INTEGER, y INTEGER)" - ) + tbl = Table( + "test", + metadata, + Column("x", Integer, autoincrement=False), + Column("y", Integer, 
autoincrement=False), + UniqueConstraint(), + ) + self.assert_compile( + schema.CreateTable(tbl), "CREATE TABLE test (x INTEGER, y INTEGER)" + ) def test_deferrable_column_check(self): - t = Table('tbl', MetaData(), - Column('a', Integer), - Column('b', Integer, - CheckConstraint('a < b', - deferrable=True, - initially='DEFERRED'))) + t = Table( + "tbl", + MetaData(), + Column("a", Integer), + Column( + "b", + Integer, + CheckConstraint( + "a < b", deferrable=True, initially="DEFERRED" + ), + ), + ) self.assert_compile( schema.CreateTable(t), "CREATE TABLE tbl (a INTEGER, b INTEGER CHECK (a < b) " - "DEFERRABLE INITIALLY DEFERRED)" + "DEFERRABLE INITIALLY DEFERRED)", ) def test_use_alter(self): m = MetaData() - Table('t', m, - Column('a', Integer), - ) + Table("t", m, Column("a", Integer)) - Table('t2', m, - Column('a', Integer, ForeignKey('t.a', use_alter=True, - name='fk_ta')), - Column('b', Integer, ForeignKey('t.a', name='fk_tb')) - ) + Table( + "t2", + m, + Column( + "a", Integer, ForeignKey("t.a", use_alter=True, name="fk_ta") + ), + Column("b", Integer, ForeignKey("t.a", name="fk_tb")), + ) - e = engines.mock_engine(dialect_name='postgresql') + e = engines.mock_engine(dialect_name="postgresql") m.create_all(e) m.drop_all(e) - e.assert_sql([ - 'CREATE TABLE t (a INTEGER)', - 'CREATE TABLE t2 (a INTEGER, b INTEGER, CONSTRAINT fk_tb ' - 'FOREIGN KEY(b) REFERENCES t (a))', - 'ALTER TABLE t2 ' - 'ADD CONSTRAINT fk_ta FOREIGN KEY(a) REFERENCES t (a)', - 'ALTER TABLE t2 DROP CONSTRAINT fk_ta', - 'DROP TABLE t2', - 'DROP TABLE t' - ]) + e.assert_sql( + [ + "CREATE TABLE t (a INTEGER)", + "CREATE TABLE t2 (a INTEGER, b INTEGER, CONSTRAINT fk_tb " + "FOREIGN KEY(b) REFERENCES t (a))", + "ALTER TABLE t2 " + "ADD CONSTRAINT fk_ta FOREIGN KEY(a) REFERENCES t (a)", + "ALTER TABLE t2 DROP CONSTRAINT fk_ta", + "DROP TABLE t2", + "DROP TABLE t", + ] + ) def _constraint_create_fixture(self): m = MetaData() - t = Table('tbl', m, - Column('a', Integer), - Column('b', Integer) - ) + t = Table("tbl", m, Column("a", Integer), Column("b", Integer)) - t2 = Table('t2', m, - Column('a', Integer), - Column('b', Integer) - ) + t2 = Table("t2", m, Column("a", Integer), Column("b", Integer)) return t, t2 def test_render_ck_constraint_inline(self): t, t2 = self._constraint_create_fixture() - CheckConstraint('a < b', name="my_test_constraint", - deferrable=True, initially='DEFERRED', - table=t) + CheckConstraint( + "a < b", + name="my_test_constraint", + deferrable=True, + initially="DEFERRED", + table=t, + ) # before we create an AddConstraint, # the CONSTRAINT comes out inline @@ -1017,28 +1166,36 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): "b INTEGER, " "CONSTRAINT my_test_constraint CHECK (a < b) " "DEFERRABLE INITIALLY DEFERRED" - ")" + ")", ) def test_render_ck_constraint_external(self): t, t2 = self._constraint_create_fixture() - constraint = CheckConstraint('a < b', name="my_test_constraint", - deferrable=True, initially='DEFERRED', - table=t) + constraint = CheckConstraint( + "a < b", + name="my_test_constraint", + deferrable=True, + initially="DEFERRED", + table=t, + ) self.assert_compile( schema.AddConstraint(constraint), "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint " - "CHECK (a < b) DEFERRABLE INITIALLY DEFERRED" + "CHECK (a < b) DEFERRABLE INITIALLY DEFERRED", ) def test_external_ck_constraint_cancels_internal(self): t, t2 = self._constraint_create_fixture() - constraint = CheckConstraint('a < b', name="my_test_constraint", - deferrable=True, initially='DEFERRED', 
- table=t) + constraint = CheckConstraint( + "a < b", + name="my_test_constraint", + deferrable=True, + initially="DEFERRED", + table=t, + ) schema.AddConstraint(constraint) @@ -1047,34 +1204,39 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): # is disabled self.assert_compile( schema.CreateTable(t), - "CREATE TABLE tbl (" - "a INTEGER, " - "b INTEGER" - ")" + "CREATE TABLE tbl (" "a INTEGER, " "b INTEGER" ")", ) def test_render_drop_constraint(self): t, t2 = self._constraint_create_fixture() - constraint = CheckConstraint('a < b', name="my_test_constraint", - deferrable=True, initially='DEFERRED', - table=t) + constraint = CheckConstraint( + "a < b", + name="my_test_constraint", + deferrable=True, + initially="DEFERRED", + table=t, + ) self.assert_compile( schema.DropConstraint(constraint), - "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint" + "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint", ) def test_render_drop_constraint_cascade(self): t, t2 = self._constraint_create_fixture() - constraint = CheckConstraint('a < b', name="my_test_constraint", - deferrable=True, initially='DEFERRED', - table=t) + constraint = CheckConstraint( + "a < b", + name="my_test_constraint", + deferrable=True, + initially="DEFERRED", + table=t, + ) self.assert_compile( schema.DropConstraint(constraint, cascade=True), - "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint CASCADE" + "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint CASCADE", ) def test_render_add_fk_constraint_stringcol(self): @@ -1084,7 +1246,7 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): t.append_constraint(constraint) self.assert_compile( schema.AddConstraint(constraint), - "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)" + "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)", ) def test_render_add_fk_constraint_realcol(self): @@ -1094,7 +1256,7 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): t.append_constraint(constraint) self.assert_compile( schema.AddConstraint(constraint), - "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)" + "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)", ) def test_render_add_uq_constraint_stringcol(self): @@ -1104,7 +1266,7 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): t2.append_constraint(constraint) self.assert_compile( schema.AddConstraint(constraint), - "ALTER TABLE t2 ADD CONSTRAINT uq_cst UNIQUE (a, b)" + "ALTER TABLE t2 ADD CONSTRAINT uq_cst UNIQUE (a, b)", ) def test_render_add_uq_constraint_realcol(self): @@ -1113,7 +1275,7 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): constraint = UniqueConstraint(t2.c.a, t2.c.b, name="uq_cs2") self.assert_compile( schema.AddConstraint(constraint), - "ALTER TABLE t2 ADD CONSTRAINT uq_cs2 UNIQUE (a, b)" + "ALTER TABLE t2 ADD CONSTRAINT uq_cs2 UNIQUE (a, b)", ) def test_render_add_pk_constraint(self): @@ -1124,7 +1286,7 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): assert t.c.a.primary_key is True self.assert_compile( schema.AddConstraint(constraint), - "ALTER TABLE tbl ADD PRIMARY KEY (a)" + "ALTER TABLE tbl ADD PRIMARY KEY (a)", ) def test_render_check_constraint_sql_literal(self): @@ -1134,7 +1296,7 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.AddConstraint(constraint), - "ALTER TABLE tbl ADD CHECK (a > 5)" + "ALTER TABLE tbl ADD CHECK (a > 5)", ) def test_render_check_constraint_inline_sql_literal(self): @@ -1142,20 
+1304,20 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): m = MetaData() t = Table( - 't', m, - Column('a', Integer, CheckConstraint(Column('a', Integer) > 5))) + "t", + m, + Column("a", Integer, CheckConstraint(Column("a", Integer) > 5)), + ) self.assert_compile( - schema.CreateColumn(t.c.a), - "a INTEGER CHECK (a > 5)" + schema.CreateColumn(t.c.a), "a INTEGER CHECK (a > 5)" ) def test_render_index_sql_literal(self): t, t2 = self._constraint_create_fixture() - constraint = Index('name', t.c.a + 5) + constraint = Index("name", t.c.a + 5) self.assert_compile( - schema.CreateIndex(constraint), - "CREATE INDEX name ON tbl (a + 5)" + schema.CreateIndex(constraint), "CREATE INDEX name ON tbl (a + 5)" ) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index efc4640ed7..74c32387f4 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,7 +1,15 @@ from sqlalchemy.testing import fixtures, eq_ from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message -from sqlalchemy.sql import table, column, select, func, literal, exists, \ - and_, bindparam +from sqlalchemy.sql import ( + table, + column, + select, + func, + literal, + exists, + and_, + bindparam, +) from sqlalchemy.dialects import mssql from sqlalchemy.engine import default from sqlalchemy.exc import CompileError @@ -11,36 +19,49 @@ from sqlalchemy.sql.visitors import cloned_traverse class CTETest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default_enhanced' + __dialect__ = "default_enhanced" def test_nonrecursive(self): - orders = table('orders', - column('region'), - column('amount'), - column('product'), - column('quantity') - ) - - regional_sales = select([ - orders.c.region, - func.sum(orders.c.amount).label('total_sales') - ]).group_by(orders.c.region).cte("regional_sales") - - top_regions = select([regional_sales.c.region]).\ - where( - regional_sales.c.total_sales > select([ - func.sum(regional_sales.c.total_sales) / 10 - ]) - ).cte("top_regions") - - s = select([ - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), - func.sum(orders.c.amount).label("product_sales") - ]).where(orders.c.region.in_( - select([top_regions.c.region]) - )).group_by(orders.c.region, orders.c.product) + orders = table( + "orders", + column("region"), + column("amount"), + column("product"), + column("quantity"), + ) + + regional_sales = ( + select( + [ + orders.c.region, + func.sum(orders.c.amount).label("total_sales"), + ] + ) + .group_by(orders.c.region) + .cte("regional_sales") + ) + + top_regions = ( + select([regional_sales.c.region]) + .where( + regional_sales.c.total_sales + > select([func.sum(regional_sales.c.total_sales) / 10]) + ) + .cte("top_regions") + ) + + s = ( + select( + [ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales"), + ] + ) + .where(orders.c.region.in_(select([top_regions.c.region]))) + .group_by(orders.c.region, orders.c.product) + ) # needs to render regional_sales first as top_regions # refers to it @@ -59,41 +80,51 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "sum(orders.amount) AS product_sales " "FROM orders WHERE orders.region " "IN (SELECT top_regions.region FROM top_regions) " - "GROUP BY orders.region, orders.product" + "GROUP BY orders.region, orders.product", ) def test_recursive(self): - parts = table('parts', - column('part'), - column('sub_part'), - column('quantity'), - ) - - included_parts = select([ - 
parts.c.sub_part, - parts.c.part, - parts.c.quantity]).\ - where(parts.c.part == 'our part').\ - cte(recursive=True) + parts = table( + "parts", column("part"), column("sub_part"), column("quantity") + ) + + included_parts = ( + select([parts.c.sub_part, parts.c.part, parts.c.quantity]) + .where(parts.c.part == "our part") + .cte(recursive=True) + ) incl_alias = included_parts.alias() parts_alias = parts.alias() included_parts = included_parts.union( - select([ - parts_alias.c.sub_part, - parts_alias.c.part, - parts_alias.c.quantity]). - where(parts_alias.c.part == incl_alias.c.sub_part) - ) - - s = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity).label('total_quantity')]).\ - select_from(included_parts.join( - parts, included_parts.c.part == parts.c.part)).\ - group_by(included_parts.c.sub_part) + select( + [ + parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.quantity, + ] + ).where(parts_alias.c.part == incl_alias.c.sub_part) + ) + + s = ( + select( + [ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label( + "total_quantity" + ), + ] + ) + .select_from( + included_parts.join( + parts, included_parts.c.part == parts.c.part + ) + ) + .group_by(included_parts.c.sub_part) + ) self.assert_compile( - s, "WITH RECURSIVE anon_1(sub_part, part, quantity) " + s, + "WITH RECURSIVE anon_1(sub_part, part, quantity) " "AS (SELECT parts.sub_part AS sub_part, parts.part " "AS part, parts.quantity AS quantity FROM parts " "WHERE parts.part = :part_1 UNION " @@ -104,12 +135,14 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.sub_part, " "sum(anon_1.quantity) AS total_quantity FROM anon_1 " "JOIN parts ON anon_1.part = parts.part " - "GROUP BY anon_1.sub_part") + "GROUP BY anon_1.sub_part", + ) # quick check that the "WITH RECURSIVE" varies per # dialect self.assert_compile( - s, "WITH anon_1(sub_part, part, quantity) " + s, + "WITH anon_1(sub_part, part, quantity) " "AS (SELECT parts.sub_part AS sub_part, parts.part " "AS part, parts.quantity AS quantity FROM parts " "WHERE parts.part = :part_1 UNION " @@ -120,40 +153,52 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.sub_part, " "sum(anon_1.quantity) AS total_quantity FROM anon_1 " "JOIN parts ON anon_1.part = parts.part " - "GROUP BY anon_1.sub_part", dialect=mssql.dialect()) + "GROUP BY anon_1.sub_part", + dialect=mssql.dialect(), + ) def test_recursive_inner_cte_unioned_to_alias(self): - parts = table('parts', - column('part'), - column('sub_part'), - column('quantity'), - ) - - included_parts = select([ - parts.c.sub_part, - parts.c.part, - parts.c.quantity]).\ - where(parts.c.part == 'our part').\ - cte(recursive=True) - - incl_alias = included_parts.alias('incl') + parts = table( + "parts", column("part"), column("sub_part"), column("quantity") + ) + + included_parts = ( + select([parts.c.sub_part, parts.c.part, parts.c.quantity]) + .where(parts.c.part == "our part") + .cte(recursive=True) + ) + + incl_alias = included_parts.alias("incl") parts_alias = parts.alias() included_parts = incl_alias.union( - select([ - parts_alias.c.sub_part, - parts_alias.c.part, - parts_alias.c.quantity]). 
- where(parts_alias.c.part == incl_alias.c.sub_part) - ) - - s = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity).label('total_quantity')]).\ - select_from(included_parts.join( - parts, included_parts.c.part == parts.c.part)).\ - group_by(included_parts.c.sub_part) + select( + [ + parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.quantity, + ] + ).where(parts_alias.c.part == incl_alias.c.sub_part) + ) + + s = ( + select( + [ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label( + "total_quantity" + ), + ] + ) + .select_from( + included_parts.join( + parts, included_parts.c.part == parts.c.part + ) + ) + .group_by(included_parts.c.sub_part) + ) self.assert_compile( - s, "WITH RECURSIVE incl(sub_part, part, quantity) " + s, + "WITH RECURSIVE incl(sub_part, part, quantity) " "AS (SELECT parts.sub_part AS sub_part, parts.part " "AS part, parts.quantity AS quantity FROM parts " "WHERE parts.part = :part_1 UNION " @@ -164,37 +209,38 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT incl.sub_part, " "sum(incl.quantity) AS total_quantity FROM incl " "JOIN parts ON incl.part = parts.part " - "GROUP BY incl.sub_part") + "GROUP BY incl.sub_part", + ) def test_recursive_union_no_alias_one(self): s1 = select([literal(0).label("x")]) cte = s1.cte(name="cte", recursive=True) - cte = cte.union_all( - select([cte.c.x + 1]).where(cte.c.x < 10) - ) + cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)) s2 = select([cte]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x UNION ALL " - "SELECT cte.x + :x_1 AS anon_1 " - "FROM cte WHERE cte.x < :x_2) " - "SELECT cte.x FROM cte" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + "SELECT cte.x FROM cte", + ) def test_recursive_union_alias_one(self): s1 = select([literal(0).label("x")]) cte = s1.cte(name="cte", recursive=True) - cte = cte.union_all( - select([cte.c.x + 1]).where(cte.c.x < 10) - ).alias("cr1") + cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias( + "cr1" + ) s2 = select([cte]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x UNION ALL " - "SELECT cte.x + :x_1 AS anon_1 " - "FROM cte WHERE cte.x < :x_2) " - "SELECT cr1.x FROM cte AS cr1" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + "SELECT cr1.x FROM cte AS cr1", + ) def test_recursive_union_no_alias_two(self): """ @@ -216,14 +262,15 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): t = select([func.values(1).label("n")]).cte("t", recursive=True) t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)) s = select([func.sum(t.c.n)]) - self.assert_compile(s, - "WITH RECURSIVE t(n) AS " - "(SELECT values(:values_1) AS n " - "UNION ALL SELECT t.n + :n_1 AS anon_1 " - "FROM t " - "WHERE t.n < :n_2) " - "SELECT sum(t.n) AS sum_1 FROM t" - ) + self.assert_compile( + s, + "WITH RECURSIVE t(n) AS " + "(SELECT values(:values_1) AS n " + "UNION ALL SELECT t.n + :n_1 AS anon_1 " + "FROM t " + "WHERE t.n < :n_2) " + "SELECT sum(t.n) AS sum_1 FROM t", + ) def test_recursive_union_alias_two(self): """ @@ -234,16 +281,17 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): # we're cheating here. also yes we need the SELECT, # sorry PG. 
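# A minimal sketch of the pattern this test exercises, for readers skimming
# the reformatted hunks. Illustrative only (it substitutes literal(1) for the
# fake func.values() call used here) and assumes stock SQLAlchemy 1.x
# imports:
#
#     from sqlalchemy import select, literal
#
#     t = select([literal(1).label("n")]).cte("t", recursive=True)
#     t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)).alias("ta")
#     stmt = select([t.c.n])
#
# Only the outer SELECT uses the alias name ("ta"); the WITH RECURSIVE
# clause keeps the CTE's own name ("t"), as the expected SQL below shows.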
t = select([func.values(1).label("n")]).cte("t", recursive=True) - t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)).alias('ta') + t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)).alias("ta") s = select([func.sum(t.c.n)]) - self.assert_compile(s, - "WITH RECURSIVE t(n) AS " - "(SELECT values(:values_1) AS n " - "UNION ALL SELECT t.n + :n_1 AS anon_1 " - "FROM t " - "WHERE t.n < :n_2) " - "SELECT sum(ta.n) AS sum_1 FROM t AS ta" - ) + self.assert_compile( + s, + "WITH RECURSIVE t(n) AS " + "(SELECT values(:values_1) AS n " + "UNION ALL SELECT t.n + :n_1 AS anon_1 " + "FROM t " + "WHERE t.n < :n_2) " + "SELECT sum(ta.n) AS sum_1 FROM t AS ta", + ) def test_recursive_union_no_alias_three(self): # like test one, but let's refer to the CTE @@ -254,20 +302,19 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): # can't do it here... # bar = select([cte]).cte('bar') - cte = cte.union_all( - select([cte.c.x + 1]).where(cte.c.x < 10) - ) - bar = select([cte]).cte('bar') + cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)) + bar = select([cte]).cte("bar") s2 = select([cte, bar]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x UNION ALL " - "SELECT cte.x + :x_1 AS anon_1 " - "FROM cte WHERE cte.x < :x_2), " - "bar AS (SELECT cte.x AS x FROM cte) " - "SELECT cte.x, bar.x FROM cte, bar" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cte.x, bar.x FROM cte, bar", + ) def test_recursive_union_alias_three(self): # like test one, but let's refer to the CTE @@ -278,20 +325,21 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): # can't do it here... 
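# The rewrite just below is representative of this whole hunk: black leaves
# the expression itself untouched and only re-wraps the chained calls to fit
# 79 columns, e.g. (as the "+" lines show)
#
#     cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias(
#         "cs1"
#     )
#
# so the old and new spellings compile to the same statement.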
# bar = select([cte]).cte('bar') - cte = cte.union_all( - select([cte.c.x + 1]).where(cte.c.x < 10) - ).alias("cs1") - bar = select([cte]).cte('bar').alias("cs2") + cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias( + "cs1" + ) + bar = select([cte]).cte("bar").alias("cs2") s2 = select([cte, bar]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x UNION ALL " - "SELECT cte.x + :x_1 AS anon_1 " - "FROM cte WHERE cte.x < :x_2), " - "bar AS (SELECT cs1.x AS x FROM cte AS cs1) " - "SELECT cs1.x, cs2.x FROM cte AS cs1, bar AS cs2" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cs1.x AS x FROM cte AS cs1) " + "SELECT cs1.x, cs2.x FROM cte AS cs1, bar AS cs2", + ) def test_recursive_union_no_alias_four(self): # like test one and three, but let's refer @@ -302,43 +350,45 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): s1 = select([literal(0).label("x")]) cte = s1.cte(name="cte", recursive=True) - bar = select([cte]).cte('bar') - cte = cte.union_all( - select([cte.c.x + 1]).where(cte.c.x < 10) - ) + bar = select([cte]).cte("bar") + cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)) # outer cte rendered first, then bar, which # includes "inner" cte s2 = select([cte, bar]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x UNION ALL " - "SELECT cte.x + :x_1 AS anon_1 " - "FROM cte WHERE cte.x < :x_2), " - "bar AS (SELECT cte.x AS x FROM cte) " - "SELECT cte.x, bar.x FROM cte, bar" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cte.x, bar.x FROM cte, bar", + ) # bar rendered, only includes "inner" cte, # "outer" cte isn't present s2 = select([bar]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x), " - "bar AS (SELECT cte.x AS x FROM cte) " - "SELECT bar.x FROM bar" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT bar.x FROM bar", + ) # bar rendered, but then the "outer" # cte is rendered. 
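# The three assertions in this test differ only in which CTEs the outer
# SELECT references; WITH entries come out in the order the compiler first
# encounters them. Condensed, and illustrative only:
#
#     select([cte, bar])  # WITH RECURSIVE cte(x) AS (...), bar AS (...)
#     select([bar])       # bar plus only the inner, pre-union cte
#     select([bar, cte])  # bar first, then the full recursive cte
#
# The expected strings above and below encode exactly this ordering.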
s2 = select([bar, cte]) self.assert_compile( - s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " + s2, + "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " "cte(x) AS " "(SELECT :param_1 AS x UNION ALL " "SELECT cte.x + :x_1 AS anon_1 " "FROM cte WHERE cte.x < :x_2) " - "SELECT bar.x, cte.x FROM bar, cte") + "SELECT bar.x, cte.x FROM bar, cte", + ) def test_recursive_union_alias_four(self): # like test one and three, but let's refer @@ -349,222 +399,228 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): s1 = select([literal(0).label("x")]) cte = s1.cte(name="cte", recursive=True) - bar = select([cte]).cte('bar').alias("cs1") - cte = cte.union_all( - select([cte.c.x + 1]).where(cte.c.x < 10) - ).alias("cs2") + bar = select([cte]).cte("bar").alias("cs1") + cte = cte.union_all(select([cte.c.x + 1]).where(cte.c.x < 10)).alias( + "cs2" + ) # outer cte rendered first, then bar, which # includes "inner" cte s2 = select([cte, bar]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x UNION ALL " - "SELECT cte.x + :x_1 AS anon_1 " - "FROM cte WHERE cte.x < :x_2), " - "bar AS (SELECT cte.x AS x FROM cte) " - "SELECT cs2.x, cs1.x FROM cte AS cs2, bar AS cs1" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cs2.x, cs1.x FROM cte AS cs2, bar AS cs1", + ) # bar rendered, only includes "inner" cte, # "outer" cte isn't present s2 = select([bar]) - self.assert_compile(s2, - "WITH RECURSIVE cte(x) AS " - "(SELECT :param_1 AS x), " - "bar AS (SELECT cte.x AS x FROM cte) " - "SELECT cs1.x FROM bar AS cs1" - ) + self.assert_compile( + s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cs1.x FROM bar AS cs1", + ) # bar rendered, but then the "outer" # cte is rendered. 
s2 = select([bar, cte]) self.assert_compile( - s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " + s2, + "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " "cte(x) AS " "(SELECT :param_1 AS x UNION ALL " "SELECT cte.x + :x_1 AS anon_1 " "FROM cte WHERE cte.x < :x_2) " - "SELECT cs1.x, cs2.x FROM bar AS cs1, cte AS cs2") + "SELECT cs1.x, cs2.x FROM bar AS cs1, cte AS cs2", + ) def test_conflicting_names(self): """test a flat out name conflict.""" s1 = select([1]) - c1 = s1.cte(name='cte1', recursive=True) + c1 = s1.cte(name="cte1", recursive=True) s2 = select([1]) - c2 = s2.cte(name='cte1', recursive=True) + c2 = s2.cte(name="cte1", recursive=True) s = select([c1, c2]) assert_raises_message( CompileError, - "Multiple, unrelated CTEs found " - "with the same name: 'cte1'", - s.compile + "Multiple, unrelated CTEs found " "with the same name: 'cte1'", + s.compile, ) def test_union(self): - orders = table('orders', - column('region'), - column('amount'), - ) - - regional_sales = select([ - orders.c.region, - orders.c.amount - ]).cte("regional_sales") - - s = select( - [regional_sales.c.region]).where( + orders = table("orders", column("region"), column("amount")) + + regional_sales = select([orders.c.region, orders.c.amount]).cte( + "regional_sales" + ) + + s = select([regional_sales.c.region]).where( regional_sales.c.amount > 500 ) - self.assert_compile(s, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT regional_sales.region " - "FROM regional_sales WHERE " - "regional_sales.amount > :amount_1") + self.assert_compile( + s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT regional_sales.region " + "FROM regional_sales WHERE " + "regional_sales.amount > :amount_1", + ) s = s.union_all( - select([regional_sales.c.region]). 
- where( + select([regional_sales.c.region]).where( regional_sales.c.amount < 300 ) ) - self.assert_compile(s, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT regional_sales.region FROM regional_sales " - "WHERE regional_sales.amount > :amount_1 " - "UNION ALL SELECT regional_sales.region " - "FROM regional_sales WHERE " - "regional_sales.amount < :amount_2") + self.assert_compile( + s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT regional_sales.region FROM regional_sales " + "WHERE regional_sales.amount > :amount_1 " + "UNION ALL SELECT regional_sales.region " + "FROM regional_sales WHERE " + "regional_sales.amount < :amount_2", + ) def test_union_cte_aliases(self): - orders = table('orders', - column('region'), - column('amount'), - ) - - regional_sales = select([ - orders.c.region, - orders.c.amount - ]).cte("regional_sales").alias("rs") - - s = select( - [regional_sales.c.region]).where( + orders = table("orders", column("region"), column("amount")) + + regional_sales = ( + select([orders.c.region, orders.c.amount]) + .cte("regional_sales") + .alias("rs") + ) + + s = select([regional_sales.c.region]).where( regional_sales.c.amount > 500 ) - self.assert_compile(s, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT rs.region " - "FROM regional_sales AS rs WHERE " - "rs.amount > :amount_1") + self.assert_compile( + s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT rs.region " + "FROM regional_sales AS rs WHERE " + "rs.amount > :amount_1", + ) s = s.union_all( - select([regional_sales.c.region]). 
- where( + select([regional_sales.c.region]).where( regional_sales.c.amount < 300 ) ) - self.assert_compile(s, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT rs.region FROM regional_sales AS rs " - "WHERE rs.amount > :amount_1 " - "UNION ALL SELECT rs.region " - "FROM regional_sales AS rs WHERE " - "rs.amount < :amount_2") + self.assert_compile( + s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT rs.region FROM regional_sales AS rs " + "WHERE rs.amount > :amount_1 " + "UNION ALL SELECT rs.region " + "FROM regional_sales AS rs WHERE " + "rs.amount < :amount_2", + ) cloned = cloned_traverse(s, {}, {}) - self.assert_compile(cloned, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT rs.region FROM regional_sales AS rs " - "WHERE rs.amount > :amount_1 " - "UNION ALL SELECT rs.region " - "FROM regional_sales AS rs WHERE " - "rs.amount < :amount_2") + self.assert_compile( + cloned, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT rs.region FROM regional_sales AS rs " + "WHERE rs.amount > :amount_1 " + "UNION ALL SELECT rs.region " + "FROM regional_sales AS rs WHERE " + "rs.amount < :amount_2", + ) def test_cloned_alias(self): entity = table( - 'entity', column('id'), column('employer_id'), column('name')) - tag = table('tag', column('tag'), column('entity_id')) + "entity", column("id"), column("employer_id"), column("name") + ) + tag = table("tag", column("tag"), column("entity_id")) - tags = select([ - tag.c.entity_id, - func.array_agg(tag.c.tag).label('tags'), - ]).group_by(tag.c.entity_id).cte('unaliased_tags') + tags = ( + select([tag.c.entity_id, func.array_agg(tag.c.tag).label("tags")]) + .group_by(tag.c.entity_id) + .cte("unaliased_tags") + ) - entity_tags = tags.alias(name='entity_tags') - employer_tags = tags.alias(name='employer_tags') + entity_tags = tags.alias(name="entity_tags") + employer_tags = tags.alias(name="employer_tags") q = ( select([entity.c.name]) .select_from( - entity - .outerjoin(entity_tags, tags.c.entity_id == entity.c.id) - .outerjoin(employer_tags, - tags.c.entity_id == entity.c.employer_id) + entity.outerjoin( + entity_tags, tags.c.entity_id == entity.c.id + ).outerjoin( + employer_tags, tags.c.entity_id == entity.c.employer_id + ) ) - .where(entity_tags.c.tags.op('@>')(bindparam('tags'))) - .where(employer_tags.c.tags.op('@>')(bindparam('tags'))) + .where(entity_tags.c.tags.op("@>")(bindparam("tags"))) + .where(employer_tags.c.tags.op("@>")(bindparam("tags"))) ) self.assert_compile( q, - 'WITH unaliased_tags AS ' - '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags ' - 'FROM tag GROUP BY tag.entity_id)' - ' SELECT entity.name ' - 'FROM entity ' - 'LEFT OUTER JOIN unaliased_tags AS entity_tags ON ' - 'unaliased_tags.entity_id = entity.id ' - 'LEFT OUTER JOIN unaliased_tags AS employer_tags ON ' - 'unaliased_tags.entity_id = entity.employer_id ' - 'WHERE (entity_tags.tags @> :tags) AND ' - '(employer_tags.tags @> :tags)' - ) - - cloned = q.params(tags=['tag1', 'tag2']) + "WITH unaliased_tags AS " + "(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags " + "FROM tag GROUP BY tag.entity_id)" + " SELECT entity.name " + "FROM entity " + "LEFT OUTER JOIN unaliased_tags AS entity_tags ON " + "unaliased_tags.entity_id = entity.id " + "LEFT OUTER JOIN unaliased_tags AS 
employer_tags ON " + "unaliased_tags.entity_id = entity.employer_id " + "WHERE (entity_tags.tags @> :tags) AND " + "(employer_tags.tags @> :tags)", + ) + + cloned = q.params(tags=["tag1", "tag2"]) self.assert_compile( cloned, - 'WITH unaliased_tags AS ' - '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags ' - 'FROM tag GROUP BY tag.entity_id)' - ' SELECT entity.name ' - 'FROM entity ' - 'LEFT OUTER JOIN unaliased_tags AS entity_tags ON ' - 'unaliased_tags.entity_id = entity.id ' - 'LEFT OUTER JOIN unaliased_tags AS employer_tags ON ' - 'unaliased_tags.entity_id = entity.employer_id ' - 'WHERE (entity_tags.tags @> :tags) AND ' - '(employer_tags.tags @> :tags)') + "WITH unaliased_tags AS " + "(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags " + "FROM tag GROUP BY tag.entity_id)" + " SELECT entity.name " + "FROM entity " + "LEFT OUTER JOIN unaliased_tags AS entity_tags ON " + "unaliased_tags.entity_id = entity.id " + "LEFT OUTER JOIN unaliased_tags AS employer_tags ON " + "unaliased_tags.entity_id = entity.employer_id " + "WHERE (entity_tags.tags @> :tags) AND " + "(employer_tags.tags @> :tags)", + ) def test_reserved_quote(self): - orders = table('orders', - column('order'), - ) + orders = table("orders", column("order")) s = select([orders.c.order]).cte("regional_sales", recursive=True) s = select([s.c.order]) - self.assert_compile(s, - 'WITH RECURSIVE regional_sales("order") AS ' - '(SELECT orders."order" AS "order" ' - "FROM orders)" - ' SELECT regional_sales."order" ' - "FROM regional_sales" - ) + self.assert_compile( + s, + 'WITH RECURSIVE regional_sales("order") AS ' + '(SELECT orders."order" AS "order" ' + "FROM orders)" + ' SELECT regional_sales."order" ' + "FROM regional_sales", + ) def test_multi_subq_quote(self): - cte = select([literal(1).label("id")]).cte(name='CTE') + cte = select([literal(1).label("id")]).cte(name="CTE") s1 = select([cte.c.id]).alias() s2 = select([cte.c.id]).alias() @@ -573,13 +629,13 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( s, 'WITH "CTE" AS (SELECT :param_1 AS id) ' - 'SELECT anon_1.id, anon_2.id FROM ' + "SELECT anon_1.id, anon_2.id FROM " '(SELECT "CTE".id AS id FROM "CTE") AS anon_1, ' - '(SELECT "CTE".id AS id FROM "CTE") AS anon_2' + '(SELECT "CTE".id AS id FROM "CTE") AS anon_2', ) def test_multi_subq_alias(self): - cte = select([literal(1).label("id")]).cte(name='cte1').alias("aa") + cte = select([literal(1).label("id")]).cte(name="cte1").alias("aa") s1 = select([cte.c.id]).alias() s2 = select([cte.c.id]).alias() @@ -590,37 +646,33 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "WITH cte1 AS (SELECT :param_1 AS id) " "SELECT anon_1.id, anon_2.id FROM " "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_1, " - "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_2" + "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_2", ) def test_cte_refers_to_aliased_cte_twice(self): # test issue #4204 - a = table('a', column('id')) - b = table('b', column('id'), column('fid')) - c = table('c', column('id'), column('fid')) + a = table("a", column("id")) + b = table("b", column("id"), column("fid")) + c = table("c", column("id"), column("fid")) - cte1 = ( - select([a.c.id]) - .cte(name='cte1') - ) + cte1 = select([a.c.id]).cte(name="cte1") - aa = cte1.alias('aa') + aa = cte1.alias("aa") cte2 = ( select([b.c.id]) .select_from(b.join(aa, b.c.fid == aa.c.id)) - .cte(name='cte2') + .cte(name="cte2") ) cte3 = ( select([c.c.id]) .select_from(c.join(aa, c.c.fid == aa.c.id)) - .cte(name='cte3') + .cte(name="cte3") ) - 
stmt = ( - select([cte3.c.id, cte2.c.id]) - .select_from(cte2.join(cte3, cte2.c.id == cte3.c.id)) + stmt = select([cte3.c.id, cte2.c.id]).select_from( + cte2.join(cte3, cte2.c.id == cte3.c.id) ) self.assert_compile( stmt, @@ -629,11 +681,11 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "JOIN cte1 AS aa ON b.fid = aa.id), " "cte3 AS (SELECT c.id AS id FROM c " "JOIN cte1 AS aa ON c.fid = aa.id) " - "SELECT cte3.id, cte2.id FROM cte2 JOIN cte3 ON cte2.id = cte3.id" + "SELECT cte3.id, cte2.id FROM cte2 JOIN cte3 ON cte2.id = cte3.id", ) def test_named_alias_no_quote(self): - cte = select([literal(1).label("id")]).cte(name='CTE') + cte = select([literal(1).label("id")]).cte(name="CTE") s1 = select([cte.c.id]).alias(name="no_quotes") @@ -641,12 +693,12 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( s, 'WITH "CTE" AS (SELECT :param_1 AS id) ' - 'SELECT no_quotes.id FROM ' - '(SELECT "CTE".id AS id FROM "CTE") AS no_quotes' + "SELECT no_quotes.id FROM " + '(SELECT "CTE".id AS id FROM "CTE") AS no_quotes', ) def test_named_alias_quote(self): - cte = select([literal(1).label("id")]).cte(name='CTE') + cte = select([literal(1).label("id")]).cte(name="CTE") s1 = select([cte.c.id]).alias(name="Quotes Required") @@ -655,79 +707,96 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): s, 'WITH "CTE" AS (SELECT :param_1 AS id) ' 'SELECT "Quotes Required".id FROM ' - '(SELECT "CTE".id AS id FROM "CTE") AS "Quotes Required"' + '(SELECT "CTE".id AS id FROM "CTE") AS "Quotes Required"', ) def test_named_alias_disable_quote(self): cte = select([literal(1).label("id")]).cte( - name=quoted_name('CTE', quote=False)) + name=quoted_name("CTE", quote=False) + ) s1 = select([cte.c.id]).alias( - name=quoted_name("DontQuote", quote=False)) + name=quoted_name("DontQuote", quote=False) + ) s = select([s1]) self.assert_compile( s, - 'WITH CTE AS (SELECT :param_1 AS id) ' - 'SELECT DontQuote.id FROM ' - '(SELECT CTE.id AS id FROM CTE) AS DontQuote' + "WITH CTE AS (SELECT :param_1 AS id) " + "SELECT DontQuote.id FROM " + "(SELECT CTE.id AS id FROM CTE) AS DontQuote", ) def test_positional_binds(self): - orders = table('orders', - column('order'), - ) + orders = table("orders", column("order")) s = select([orders.c.order, literal("x")]).cte("regional_sales") s = select([s.c.order, literal("y")]) dialect = default.DefaultDialect() dialect.positional = True - dialect.paramstyle = 'numeric' + dialect.paramstyle = "numeric" self.assert_compile( s, 'WITH regional_sales AS (SELECT orders."order" ' 'AS "order", :1 AS anon_2 FROM orders) SELECT ' 'regional_sales."order", :2 AS anon_1 FROM regional_sales', - checkpositional=( - 'x', - 'y'), - dialect=dialect) + checkpositional=("x", "y"), + dialect=dialect, + ) self.assert_compile( - s.union(s), 'WITH regional_sales AS (SELECT orders."order" ' + s.union(s), + 'WITH regional_sales AS (SELECT orders."order" ' 'AS "order", :1 AS anon_2 FROM orders) SELECT ' 'regional_sales."order", :2 AS anon_1 FROM regional_sales ' 'UNION SELECT regional_sales."order", :3 AS anon_1 ' - 'FROM regional_sales', checkpositional=( - 'x', 'y', 'y'), dialect=dialect) + "FROM regional_sales", + checkpositional=("x", "y", "y"), + dialect=dialect, + ) - s = select([orders.c.order]).\ - where(orders.c.order == 'x').cte("regional_sales") + s = ( + select([orders.c.order]) + .where(orders.c.order == "x") + .cte("regional_sales") + ) s = select([s.c.order]).where(s.c.order == "y") self.assert_compile( - s, 'WITH regional_sales AS (SELECT orders."order" AS ' + s, + 'WITH 
regional_sales AS (SELECT orders."order" AS ' '"order" FROM orders WHERE orders."order" = :1) ' 'SELECT regional_sales."order" FROM regional_sales ' - 'WHERE regional_sales."order" = :2', checkpositional=( - 'x', 'y'), dialect=dialect) + 'WHERE regional_sales."order" = :2', + checkpositional=("x", "y"), + dialect=dialect, + ) def test_positional_binds_2(self): - orders = table('orders', - column('order'), - ) + orders = table("orders", column("order")) s = select([orders.c.order, literal("x")]).cte("regional_sales") s = select([s.c.order, literal("y")]) dialect = default.DefaultDialect() dialect.positional = True - dialect.paramstyle = 'numeric' - s1 = select([orders.c.order]).where(orders.c.order == 'x').\ - cte("regional_sales_1") + dialect.paramstyle = "numeric" + s1 = ( + select([orders.c.order]) + .where(orders.c.order == "x") + .cte("regional_sales_1") + ) s1a = s1.alias() - s2 = select([orders.c.order == 'y', s1a.c.order, - orders.c.order, s1.c.order]).\ - where(orders.c.order == 'z').\ - cte("regional_sales_2") + s2 = ( + select( + [ + orders.c.order == "y", + s1a.c.order, + orders.c.order, + s1.c.order, + ] + ) + .where(orders.c.order == "z") + .cte("regional_sales_2") + ) s3 = select([s2]) @@ -739,52 +808,65 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'anon_2."order" AS "order", ' 'orders."order" AS "order", ' 'regional_sales_1."order" AS "order" FROM orders, ' - 'regional_sales_1 ' - 'AS anon_2, regional_sales_1 ' + "regional_sales_1 " + "AS anon_2, regional_sales_1 " 'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, ' 'regional_sales_2."order" FROM regional_sales_2', - checkpositional=('x', 'y', 'z'), dialect=dialect) + checkpositional=("x", "y", "z"), + dialect=dialect, + ) def test_positional_binds_2_asliteral(self): - orders = table('orders', - column('order'), - ) + orders = table("orders", column("order")) s = select([orders.c.order, literal("x")]).cte("regional_sales") s = select([s.c.order, literal("y")]) dialect = default.DefaultDialect() dialect.positional = True - dialect.paramstyle = 'numeric' - s1 = select([orders.c.order]).where(orders.c.order == 'x').\ - cte("regional_sales_1") + dialect.paramstyle = "numeric" + s1 = ( + select([orders.c.order]) + .where(orders.c.order == "x") + .cte("regional_sales_1") + ) s1a = s1.alias() - s2 = select([orders.c.order == 'y', s1a.c.order, - orders.c.order, s1.c.order]).\ - where(orders.c.order == 'z').\ - cte("regional_sales_2") + s2 = ( + select( + [ + orders.c.order == "y", + s1a.c.order, + orders.c.order, + s1.c.order, + ] + ) + .where(orders.c.order == "z") + .cte("regional_sales_2") + ) s3 = select([s2]) self.assert_compile( s3, - 'WITH regional_sales_1 AS ' + "WITH regional_sales_1 AS " '(SELECT orders."order" AS "order" ' - 'FROM orders ' - 'WHERE orders."order" = \'x\'), ' - 'regional_sales_2 AS ' - '(SELECT orders."order" = \'y\' AS anon_1, ' + "FROM orders " + "WHERE orders.\"order\" = 'x'), " + "regional_sales_2 AS " + "(SELECT orders.\"order\" = 'y' AS anon_1, " 'anon_2."order" AS "order", orders."order" AS "order", ' 'regional_sales_1."order" AS "order" ' - 'FROM orders, regional_sales_1 AS anon_2, regional_sales_1 ' - 'WHERE orders."order" = \'z\') ' + "FROM orders, regional_sales_1 AS anon_2, regional_sales_1 " + "WHERE orders.\"order\" = 'z') " 'SELECT regional_sales_2.anon_1, regional_sales_2."order" ' - 'FROM regional_sales_2', - checkpositional=(), dialect=dialect, - literal_binds=True) + "FROM regional_sales_2", + checkpositional=(), + dialect=dialect, + literal_binds=True, + ) def 
test_all_aliases(self): - orders = table('order', column('order')) + orders = table("order", column("order")) s = select([orders.c.order]).cte("regional_sales") r1 = s.alias() @@ -797,38 +879,36 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'WITH regional_sales AS (SELECT "order"."order" ' 'AS "order" FROM "order") ' 'SELECT anon_1."order", anon_2."order" ' - 'FROM regional_sales AS anon_1, ' - 'regional_sales AS anon_2 WHERE anon_1."order" > anon_2."order"' + "FROM regional_sales AS anon_1, " + 'regional_sales AS anon_2 WHERE anon_1."order" > anon_2."order"', ) - s3 = select( - [orders]).select_from( - orders.join( - r1, - r1.c.order == orders.c.order)) + s3 = select([orders]).select_from( + orders.join(r1, r1.c.order == orders.c.order) + ) self.assert_compile( s3, - 'WITH regional_sales AS ' + "WITH regional_sales AS " '(SELECT "order"."order" AS "order" ' 'FROM "order")' ' SELECT "order"."order" ' 'FROM "order" JOIN regional_sales AS anon_1 ' - 'ON anon_1."order" = "order"."order"' + 'ON anon_1."order" = "order"."order"', ) def test_suffixes(self): - orders = table('order', column('order')) + orders = table("order", column("order")) s = select([orders.c.order]).cte("regional_sales") - s = s.suffix_with("pg suffix", dialect='postgresql') - s = s.suffix_with('oracle suffix', dialect='oracle') + s = s.suffix_with("pg suffix", dialect="postgresql") + s = s.suffix_with("oracle suffix", dialect="oracle") stmt = select([orders]).where(orders.c.order > s.c.order) self.assert_compile( stmt, 'WITH regional_sales AS (SELECT "order"."order" AS "order" ' 'FROM "order") SELECT "order"."order" FROM "order", ' - 'regional_sales WHERE "order"."order" > regional_sales."order"' + 'regional_sales WHERE "order"."order" > regional_sales."order"', ) self.assert_compile( @@ -837,7 +917,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'FROM "order") oracle suffix ' 'SELECT "order"."order" FROM "order", ' 'regional_sales WHERE "order"."order" > regional_sales."order"', - dialect='oracle' + dialect="oracle", ) self.assert_compile( @@ -845,30 +925,36 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'WITH regional_sales AS (SELECT "order"."order" AS "order" ' 'FROM "order") pg suffix SELECT "order"."order" FROM "order", ' 'regional_sales WHERE "order"."order" > regional_sales."order"', - dialect='postgresql' + dialect="postgresql", ) def test_upsert_from_select(self): orders = table( - 'orders', - column('region'), - column('amount'), - column('product'), - column('quantity') + "orders", + column("region"), + column("amount"), + column("product"), + column("quantity"), ) upsert = ( orders.update() - .where(orders.c.region == 'Region1') - .values(amount=1.0, product='Product1', quantity=1) - .returning(*(orders.c._all_columns)).cte('upsert')) + .where(orders.c.region == "Region1") + .values(amount=1.0, product="Product1", quantity=1) + .returning(*(orders.c._all_columns)) + .cte("upsert") + ) insert = orders.insert().from_select( orders.c.keys(), - select([ - literal('Region1'), literal(1.0), - literal('Product1'), literal(1) - ]).where(~exists(upsert.select())) + select( + [ + literal("Region1"), + literal(1.0), + literal("Product1"), + literal(1), + ] + ).where(~exists(upsert.select())), ) self.assert_compile( @@ -882,52 +968,55 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT :param_1 AS anon_1, :param_2 AS anon_2, " ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS " "(SELECT upsert.region, upsert.amount, upsert.product, " - "upsert.quantity FROM upsert))" + 
"upsert.quantity FROM upsert))", ) def test_anon_update_cte(self): - orders = table( - 'orders', - column('region') + orders = table("orders", column("region")) + stmt = ( + orders.update() + .where(orders.c.region == "x") + .values(region="y") + .returning(orders.c.region) + .cte() ) - stmt = orders.update().where(orders.c.region == 'x').\ - values(region='y').\ - returning(orders.c.region).cte() self.assert_compile( stmt.select(), "WITH anon_1 AS (UPDATE orders SET region=:region " "WHERE orders.region = :region_1 RETURNING orders.region) " - "SELECT anon_1.region FROM anon_1" + "SELECT anon_1.region FROM anon_1", ) def test_anon_insert_cte(self): - orders = table( - 'orders', - column('region') + orders = table("orders", column("region")) + stmt = ( + orders.insert().values(region="y").returning(orders.c.region).cte() ) - stmt = orders.insert().\ - values(region='y').\ - returning(orders.c.region).cte() self.assert_compile( stmt.select(), "WITH anon_1 AS (INSERT INTO orders (region) " "VALUES (:region) RETURNING orders.region) " - "SELECT anon_1.region FROM anon_1" + "SELECT anon_1.region FROM anon_1", ) def test_pg_example_one(self): - products = table('products', column('id'), column('date')) - products_log = table('products_log', column('id'), column('date')) + products = table("products", column("id"), column("date")) + products_log = table("products_log", column("id"), column("date")) - moved_rows = products.delete().where(and_( - products.c.date >= 'dateone', - products.c.date < 'datetwo')).returning(*products.c).\ - cte('moved_rows') + moved_rows = ( + products.delete() + .where( + and_(products.c.date >= "dateone", products.c.date < "datetwo") + ) + .returning(*products.c) + .cte("moved_rows") + ) stmt = products_log.insert().from_select( - products_log.c, moved_rows.select()) + products_log.c, moved_rows.select() + ) self.assert_compile( stmt, "WITH moved_rows AS " @@ -935,17 +1024,21 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "AND products.date < :date_2 " "RETURNING products.id, products.date) " "INSERT INTO products_log (id, date) " - "SELECT moved_rows.id, moved_rows.date FROM moved_rows" + "SELECT moved_rows.id, moved_rows.date FROM moved_rows", ) def test_pg_example_two(self): - products = table('products', column('id'), column('price')) + products = table("products", column("id"), column("price")) - t = products.update().values(price='someprice').\ - returning(*products.c).cte('t') + t = ( + products.update() + .values(price="someprice") + .returning(*products.c) + .cte("t") + ) stmt = t.select() - assert 'autocommit' not in stmt._execution_options - eq_(stmt.compile().execution_options['autocommit'], True) + assert "autocommit" not in stmt._execution_options + eq_(stmt.compile().execution_options["autocommit"], True) self.assert_compile( stmt, @@ -953,34 +1046,29 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "(UPDATE products SET price=:price " "RETURNING products.id, products.price) " "SELECT t.id, t.price " - "FROM t" + "FROM t", ) def test_pg_example_three(self): - parts = table( - 'parts', - column('part'), - column('sub_part'), - ) + parts = table("parts", column("part"), column("sub_part")) - included_parts = select([ - parts.c.sub_part, - parts.c.part]).\ - where(parts.c.part == 'our part').\ - cte("included_parts", recursive=True) + included_parts = ( + select([parts.c.sub_part, parts.c.part]) + .where(parts.c.part == "our part") + .cte("included_parts", recursive=True) + ) - pr = included_parts.alias('pr') - p = parts.alias('p') + pr = 
included_parts.alias("pr") + p = parts.alias("p") included_parts = included_parts.union_all( - select([ - p.c.sub_part, - p.c.part]). - where(p.c.part == pr.c.sub_part) + select([p.c.sub_part, p.c.part]).where(p.c.part == pr.c.sub_part) + ) + stmt = ( + parts.delete() + .where(parts.c.part.in_(select([included_parts.c.part]))) + .returning(parts.c.part) ) - stmt = parts.delete().where( - parts.c.part.in_(select([included_parts.c.part]))).returning( - parts.c.part) # the outer RETURNING is a bonus over what PG's docs have self.assert_compile( @@ -994,19 +1082,23 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "WHERE p.part = pr.sub_part) " "DELETE FROM parts WHERE parts.part IN " "(SELECT included_parts.part FROM included_parts) " - "RETURNING parts.part" + "RETURNING parts.part", ) def test_insert_in_the_cte(self): - products = table('products', column('id'), column('price')) + products = table("products", column("id"), column("price")) - cte = products.insert().values(id=1, price=27.0).\ - returning(*products.c).cte('pd') + cte = ( + products.insert() + .values(id=1, price=27.0) + .returning(*products.c) + .cte("pd") + ) stmt = select([cte]) - assert 'autocommit' not in stmt._execution_options - eq_(stmt.compile().execution_options['autocommit'], True) + assert "autocommit" not in stmt._execution_options + eq_(stmt.compile().execution_options["autocommit"], True) self.assert_compile( stmt, @@ -1014,17 +1106,17 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "(INSERT INTO products (id, price) VALUES (:id, :price) " "RETURNING products.id, products.price) " "SELECT pd.id, pd.price " - "FROM pd" + "FROM pd", ) def test_update_pulls_from_cte(self): - products = table('products', column('id'), column('price')) + products = table("products", column("id"), column("price")) - cte = products.select().cte('pd') - assert 'autocommit' not in cte._execution_options + cte = products.select().cte("pd") + assert "autocommit" not in cte._execution_options stmt = products.update().where(products.c.price == cte.c.price) - eq_(stmt.compile().execution_options['autocommit'], True) + eq_(stmt.compile().execution_options["autocommit"], True) self.assert_compile( stmt, @@ -1032,5 +1124,5 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "(SELECT products.id AS id, products.price AS price " "FROM products) " "UPDATE products SET id=:id, price=:price FROM pd " - "WHERE products.price = pd.price" + "WHERE products.price = pd.price", ) diff --git a/test/sql/test_ddlemit.py b/test/sql/test_ddlemit.py index 25f9c595fe..6c949f9ee9 100644 --- a/test/sql/test_ddlemit.py +++ b/test/sql/test_ddlemit.py @@ -6,54 +6,65 @@ from sqlalchemy.testing.mock import Mock class EmitDDLTest(fixtures.TestBase): - def _mock_connection(self, item_exists): def has_item(connection, name, schema): return item_exists(name) - return Mock(dialect=Mock( - supports_sequences=True, - has_table=Mock(side_effect=has_item), - has_sequence=Mock(side_effect=has_item), - supports_comments=True, - inline_comments=False, - ) - ) - - def _mock_create_fixture(self, checkfirst, tables, - item_exists=lambda item: False): + return Mock( + dialect=Mock( + supports_sequences=True, + has_table=Mock(side_effect=has_item), + has_sequence=Mock(side_effect=has_item), + supports_comments=True, + inline_comments=False, + ) + ) + + def _mock_create_fixture( + self, checkfirst, tables, item_exists=lambda item: False + ): connection = self._mock_connection(item_exists) - return SchemaGenerator(connection.dialect, connection, - 
checkfirst=checkfirst, - tables=tables) + return SchemaGenerator( + connection.dialect, + connection, + checkfirst=checkfirst, + tables=tables, + ) - def _mock_drop_fixture(self, checkfirst, tables, - item_exists=lambda item: True): + def _mock_drop_fixture( + self, checkfirst, tables, item_exists=lambda item: True + ): connection = self._mock_connection(item_exists) - return SchemaDropper(connection.dialect, connection, - checkfirst=checkfirst, - tables=tables) + return SchemaDropper( + connection.dialect, + connection, + checkfirst=checkfirst, + tables=tables, + ) def _table_fixture(self): m = MetaData() - return (m, ) + tuple( - Table('t%d' % i, m, Column('x', Integer)) - for i in range(1, 6) + return (m,) + tuple( + Table("t%d" % i, m, Column("x", Integer)) for i in range(1, 6) ) def _use_alter_fixture_one(self): m = MetaData() t1 = Table( - 't1', m, Column('id', Integer, primary_key=True), - Column('t2id', Integer, ForeignKey('t2.id')) + "t1", + m, + Column("id", Integer, primary_key=True), + Column("t2id", Integer, ForeignKey("t2.id")), ) t2 = Table( - 't2', m, Column('id', Integer, primary_key=True), - Column('t1id', Integer, ForeignKey('t1.id')) + "t2", + m, + Column("id", Integer, primary_key=True), + Column("t1id", Integer, ForeignKey("t1.id")), ) return m, t1, t2 @@ -61,33 +72,30 @@ class EmitDDLTest(fixtures.TestBase): m = MetaData() t1 = Table( - 't1', m, Column('id', Integer, primary_key=True), - Column('t2id', Integer, ForeignKey('t2.id')) - ) - t2 = Table( - 't2', m, Column('id', Integer, primary_key=True), + "t1", + m, + Column("id", Integer, primary_key=True), + Column("t2id", Integer, ForeignKey("t2.id")), ) + t2 = Table("t2", m, Column("id", Integer, primary_key=True)) return m, t1, t2 def _table_seq_fixture(self): m = MetaData() - s1 = Sequence('s1') - s2 = Sequence('s2') - t1 = Table('t1', m, Column("x", Integer, s1, primary_key=True)) - t2 = Table('t2', m, Column("x", Integer, s2, primary_key=True)) + s1 = Sequence("s1") + s2 = Sequence("s2") + t1 = Table("t1", m, Column("x", Integer, s1, primary_key=True)) + t2 = Table("t2", m, Column("x", Integer, s2, primary_key=True)) return m, t1, t2, s1, s2 def _table_comment_fixture(self): m = MetaData() - c1 = Column('id', Integer, comment='c1') + c1 = Column("id", Integer, comment="c1") - t1 = Table( - 't1', m, c1, - comment='t1' - ) + t1 = Table("t1", m, c1, comment="t1") return m, t1, c1 @@ -95,154 +103,127 @@ class EmitDDLTest(fixtures.TestBase): m, t1, c1 = self._table_comment_fixture() generator = self._mock_create_fixture( - False, [t1], item_exists=lambda t: t not in ("t1",)) + False, [t1], item_exists=lambda t: t not in ("t1",) + ) self._assert_create_comment([t1, t1, c1], generator, m) def test_create_seq_checkfirst(self): m, t1, t2, s1, s2 = self._table_seq_fixture() generator = self._mock_create_fixture( - True, [ - t1, t2], item_exists=lambda t: t not in ( - "t1", "s1")) + True, [t1, t2], item_exists=lambda t: t not in ("t1", "s1") + ) self._assert_create([t1, s1], generator, m) def test_drop_seq_checkfirst(self): m, t1, t2, s1, s2 = self._table_seq_fixture() generator = self._mock_drop_fixture( - True, [ - t1, t2], item_exists=lambda t: t in ( - "t1", "s1")) + True, [t1, t2], item_exists=lambda t: t in ("t1", "s1") + ) self._assert_drop([t1, s1], generator, m) def test_create_collection_checkfirst(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_create_fixture( - True, [ - t2, t3, t4], item_exists=lambda t: t not in ( - "t2", "t4")) + True, [t2, t3, t4], item_exists=lambda t: t not in 
("t2", "t4") + ) self._assert_create_tables([t2, t4], generator, m) def test_drop_collection_checkfirst(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_drop_fixture( - True, [ - t2, t3, t4], item_exists=lambda t: t in ( - "t2", "t4")) + True, [t2, t3, t4], item_exists=lambda t: t in ("t2", "t4") + ) self._assert_drop_tables([t2, t4], generator, m) def test_create_collection_nocheck(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_create_fixture( - False, [ - t2, t3, t4], item_exists=lambda t: t not in ( - "t2", "t4")) + False, [t2, t3, t4], item_exists=lambda t: t not in ("t2", "t4") + ) self._assert_create_tables([t2, t3, t4], generator, m) def test_create_empty_collection(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_create_fixture( - True, - [], - item_exists=lambda t: t not in ( - "t2", - "t4")) + True, [], item_exists=lambda t: t not in ("t2", "t4") + ) self._assert_create_tables([], generator, m) def test_drop_empty_collection(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_drop_fixture( - True, - [], - item_exists=lambda t: t in ( - "t2", - "t4")) + True, [], item_exists=lambda t: t in ("t2", "t4") + ) self._assert_drop_tables([], generator, m) def test_drop_collection_nocheck(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_drop_fixture( - False, [ - t2, t3, t4], item_exists=lambda t: t in ( - "t2", "t4")) + False, [t2, t3, t4], item_exists=lambda t: t in ("t2", "t4") + ) self._assert_drop_tables([t2, t3, t4], generator, m) def test_create_metadata_checkfirst(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_create_fixture( - True, - None, - item_exists=lambda t: t not in ( - "t2", - "t4")) + True, None, item_exists=lambda t: t not in ("t2", "t4") + ) self._assert_create_tables([t2, t4], generator, m) def test_drop_metadata_checkfirst(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_drop_fixture( - True, - None, - item_exists=lambda t: t in ( - "t2", - "t4")) + True, None, item_exists=lambda t: t in ("t2", "t4") + ) self._assert_drop_tables([t2, t4], generator, m) def test_create_metadata_nocheck(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_create_fixture( - False, - None, - item_exists=lambda t: t not in ( - "t2", - "t4")) + False, None, item_exists=lambda t: t not in ("t2", "t4") + ) self._assert_create_tables([t1, t2, t3, t4, t5], generator, m) def test_drop_metadata_nocheck(self): m, t1, t2, t3, t4, t5 = self._table_fixture() generator = self._mock_drop_fixture( - False, - None, - item_exists=lambda t: t in ( - "t2", - "t4")) + False, None, item_exists=lambda t: t in ("t2", "t4") + ) self._assert_drop_tables([t1, t2, t3, t4, t5], generator, m) def test_create_metadata_auto_alter_fk(self): m, t1, t2 = self._use_alter_fixture_one() - generator = self._mock_create_fixture( - False, [t1, t2] - ) + generator = self._mock_create_fixture(False, [t1, t2]) self._assert_create_w_alter( - [t1, t2] + - list(t1.foreign_key_constraints) + - list(t2.foreign_key_constraints), + [t1, t2] + + list(t1.foreign_key_constraints) + + list(t2.foreign_key_constraints), generator, - m + m, ) def test_create_metadata_inline_fk(self): m, t1, t2 = self._fk_fixture_one() - generator = self._mock_create_fixture( - False, [t1, t2] - ) + generator = self._mock_create_fixture(False, [t1, t2]) self._assert_create_w_alter( - [t1, t2] + - list(t1.foreign_key_constraints) + - 
list(t2.foreign_key_constraints), + [t1, t2] + + list(t1.foreign_key_constraints) + + list(t2.foreign_key_constraints), generator, - m + m, ) def _assert_create_tables(self, elements, generator, argument): @@ -254,38 +235,60 @@ class EmitDDLTest(fixtures.TestBase): def _assert_create(self, elements, generator, argument): self._assert_ddl( (schema.CreateTable, schema.CreateSequence), - elements, generator, argument) + elements, + generator, + argument, + ) def _assert_drop(self, elements, generator, argument): self._assert_ddl( (schema.DropTable, schema.DropSequence), - elements, generator, argument) + elements, + generator, + argument, + ) def _assert_create_w_alter(self, elements, generator, argument): self._assert_ddl( (schema.CreateTable, schema.CreateSequence, schema.AddConstraint), - elements, generator, argument) + elements, + generator, + argument, + ) def _assert_drop_w_alter(self, elements, generator, argument): self._assert_ddl( (schema.DropTable, schema.DropSequence, schema.DropConstraint), - elements, generator, argument) + elements, + generator, + argument, + ) def _assert_create_comment(self, elements, generator, argument): self._assert_ddl( - (schema.CreateTable, schema.SetTableComment, schema.SetColumnComment), - elements, generator, argument) + ( + schema.CreateTable, + schema.SetTableComment, + schema.SetColumnComment, + ), + elements, + generator, + argument, + ) def _assert_ddl(self, ddl_cls, elements, generator, argument): generator.traverse_single(argument) for call_ in generator.connection.execute.mock_calls: c = call_[1][0] assert isinstance(c, ddl_cls) - assert c.element in elements, "element %r was not expected"\ - % c.element + assert c.element in elements, ( + "element %r was not expected" % c.element + ) elements.remove(c.element) - if getattr(c, 'include_foreign_key_constraints', None) is not None: + if getattr(c, "include_foreign_key_constraints", None) is not None: elements[:] = [ - e for e in elements - if e not in set(c.include_foreign_key_constraints)] + e + for e in elements + if e not in set(c.include_foreign_key_constraints) + ] assert not elements, "elements remain in list: %r" % elements diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index c5efdf132e..b518a606f2 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1,5 +1,10 @@ -from sqlalchemy.testing import eq_, assert_raises_message, \ - assert_raises, AssertsCompiledSQL, expect_warnings +from sqlalchemy.testing import ( + eq_, + assert_raises_message, + assert_raises, + AssertsCompiledSQL, + expect_warnings, +) import datetime from sqlalchemy.schema import CreateSequence, DropSequence, CreateTable from sqlalchemy.sql import select, text, literal_column @@ -7,8 +12,19 @@ import sqlalchemy as sa from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy import ( - MetaData, Integer, String, ForeignKey, Boolean, exc, Sequence, func, - literal, Unicode, cast, DateTime) + MetaData, + Integer, + String, + ForeignKey, + Boolean, + exc, + Sequence, + func, + literal, + Unicode, + cast, + DateTime, +) from sqlalchemy.types import TypeDecorator, TypeEngine from sqlalchemy.testing.schema import Table, Column from sqlalchemy.dialects import sqlite @@ -23,96 +39,100 @@ t = f = f2 = ts = currenttime = metadata = default_generator = None class DDLTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_string(self): # note: that the datatype is an Integer here doesn't matter, # the server_default 
is interpreted independently of the # column's datatype. m = MetaData() - t = Table('t', m, Column('x', Integer, server_default='5')) + t = Table("t", m, Column("x", Integer, server_default="5")) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT '5')" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT '5')" ) def test_string_w_quotes(self): m = MetaData() - t = Table('t', m, Column('x', Integer, server_default="5'6")) + t = Table("t", m, Column("x", Integer, server_default="5'6")) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT '5''6')" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT '5''6')" ) def test_text(self): m = MetaData() - t = Table('t', m, Column('x', Integer, server_default=text('5 + 8'))) + t = Table("t", m, Column("x", Integer, server_default=text("5 + 8"))) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT 5 + 8)" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT 5 + 8)" ) def test_text_w_quotes(self): m = MetaData() - t = Table('t', m, Column('x', Integer, server_default=text("5 ' 8"))) + t = Table("t", m, Column("x", Integer, server_default=text("5 ' 8"))) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT 5 ' 8)" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT 5 ' 8)" ) def test_literal_binds_w_quotes(self): m = MetaData() - t = Table('t', m, Column('x', Integer, - server_default=literal("5 ' 8"))) + t = Table( + "t", m, Column("x", Integer, server_default=literal("5 ' 8")) + ) self.assert_compile( - CreateTable(t), - """CREATE TABLE t (x INTEGER DEFAULT '5 '' 8')""" + CreateTable(t), """CREATE TABLE t (x INTEGER DEFAULT '5 '' 8')""" ) def test_text_literal_binds(self): m = MetaData() t = Table( - 't', m, + "t", + m, Column( - 'x', Integer, server_default=text('q + :x1').bindparams(x1=7))) + "x", Integer, server_default=text("q + :x1").bindparams(x1=7) + ), + ) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT q + 7)" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT q + 7)" ) def test_sqlexpr(self): m = MetaData() - t = Table('t', m, Column( - 'x', Integer, - server_default=literal_column('a') + literal_column('b')) + t = Table( + "t", + m, + Column( + "x", + Integer, + server_default=literal_column("a") + literal_column("b"), + ), ) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT a + b)" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT a + b)" ) def test_literal_binds_plain(self): m = MetaData() - t = Table('t', m, Column( - 'x', Integer, - server_default=literal('a') + literal('b')) + t = Table( + "t", + m, + Column("x", Integer, server_default=literal("a") + literal("b")), ) self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x INTEGER DEFAULT 'a' || 'b')" + CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT 'a' || 'b')" ) def test_literal_binds_pgarray(self): from sqlalchemy.dialects.postgresql import ARRAY, array + m = MetaData() - t = Table('t', m, Column( - 'x', ARRAY(Integer), - server_default=array([1, 2, 3])) + t = Table( + "t", + m, + Column("x", ARRAY(Integer), server_default=array([1, 2, 3])), ) self.assert_compile( CreateTable(t), "CREATE TABLE t (x INTEGER[] DEFAULT ARRAY[1, 2, 3])", - dialect='postgresql' + dialect="postgresql", ) @@ -125,30 +145,29 @@ class DefaultTest(fixtures.TestBase): db = testing.db metadata = MetaData(db) - default_generator = {'x': 50} + default_generator = {"x": 50} def mydefault(): - default_generator['x'] += 1 - return default_generator['x'] + 
default_generator["x"] += 1 + return default_generator["x"] def myupdate_with_ctx(ctx): conn = ctx.connection - return conn.execute(sa.select([sa.text('13')])).scalar() + return conn.execute(sa.select([sa.text("13")])).scalar() def mydefault_using_connection(ctx): conn = ctx.connection try: - return conn.execute(sa.select([sa.text('12')])).scalar() + return conn.execute(sa.select([sa.text("12")])).scalar() finally: # ensure a "close()" on this connection does nothing, # since its a "branched" connection conn.close() - use_function_defaults = testing.against('postgresql', 'mssql') - is_oracle = testing.against('oracle') + use_function_defaults = testing.against("postgresql", "mssql") + is_oracle = testing.against("oracle") class MyClass(object): - @classmethod def gen_default(cls, ctx): return "hi" @@ -172,89 +191,92 @@ class DefaultTest(fixtures.TestBase): func.trunc( func.current_timestamp(), sa.literal_column("'DAY'"), - type_=sa.Date)])) + type_=sa.Date, + ) + ] + ) + ) assert isinstance(ts, datetime.date) and not isinstance( - ts, datetime.datetime) - f = sa.select([func.length('abcdef')], bind=db).scalar() - f2 = sa.select([func.length('abcdefghijk')], bind=db).scalar() + ts, datetime.datetime + ) + f = sa.select([func.length("abcdef")], bind=db).scalar() + f2 = sa.select([func.length("abcdefghijk")], bind=db).scalar() # TODO: engine propigation across nested functions not working currenttime = func.trunc( - currenttime, sa.literal_column("'DAY'"), bind=db, - type_=sa.Date) + currenttime, sa.literal_column("'DAY'"), bind=db, type_=sa.Date + ) def1 = currenttime def2 = func.trunc( sa.text("current_timestamp"), - sa.literal_column("'DAY'"), type_=sa.Date) + sa.literal_column("'DAY'"), + type_=sa.Date, + ) deftype = sa.Date elif use_function_defaults: - f = sa.select([func.length('abcdef')], bind=db).scalar() - f2 = sa.select([func.length('abcdefghijk')], bind=db).scalar() + f = sa.select([func.length("abcdef")], bind=db).scalar() + f2 = sa.select([func.length("abcdefghijk")], bind=db).scalar() def1 = currenttime deftype = sa.Date - if testing.against('mssql'): + if testing.against("mssql"): def2 = sa.text("getdate()") else: def2 = sa.text("current_date") ts = db.scalar(func.current_date()) else: - f = len('abcdef') - f2 = len('abcdefghijk') + f = len("abcdef") + f2 = len("abcdefghijk") def1 = def2 = "3" ts = 3 deftype = Integer t = Table( - 'default_test1', metadata, + "default_test1", + metadata, # python function - Column('col1', Integer, primary_key=True, - default=mydefault), - + Column("col1", Integer, primary_key=True, default=mydefault), # python literal - Column('col2', String(20), - default="imthedefault", - onupdate="im the update"), - + Column( + "col2", + String(20), + default="imthedefault", + onupdate="im the update", + ), # preexecute expression - Column('col3', Integer, - default=func.length('abcdef'), - onupdate=func.length('abcdefghijk')), - + Column( + "col3", + Integer, + default=func.length("abcdef"), + onupdate=func.length("abcdefghijk"), + ), # SQL-side default from sql expression - Column('col4', deftype, - server_default=def1), - + Column("col4", deftype, server_default=def1), # SQL-side default from literal expression - Column('col5', deftype, - server_default=def2), - + Column("col5", deftype, server_default=def2), # preexecute + update timestamp - Column('col6', sa.Date, - default=currenttime, - onupdate=currenttime), - - Column('boolcol1', sa.Boolean, default=True), - Column('boolcol2', sa.Boolean, default=False), - + Column("col6", sa.Date, 
default=currenttime, onupdate=currenttime), + Column("boolcol1", sa.Boolean, default=True), + Column("boolcol2", sa.Boolean, default=False), # python function which uses ExecutionContext - Column('col7', Integer, - default=mydefault_using_connection, - onupdate=myupdate_with_ctx), - + Column( + "col7", + Integer, + default=mydefault_using_connection, + onupdate=myupdate_with_ctx, + ), # python builtin - Column('col8', sa.Date, - default=datetime.date.today, - onupdate=datetime.date.today), + Column( + "col8", + sa.Date, + default=datetime.date.today, + onupdate=datetime.date.today, + ), # combo - Column('col9', String(20), - default='py', - server_default='ddl'), - + Column("col9", String(20), default="py", server_default="ddl"), # python method w/ context - Column('col10', String(20), default=MyClass.gen_default), - + Column("col10", String(20), default=MyClass.gen_default), # fixed default w/ type that has bound processor - Column('col11', MyType(), default='foo') + Column("col11", MyType(), default="foo"), ) t.create() @@ -264,12 +286,14 @@ class DefaultTest(fixtures.TestBase): t.drop() def teardown(self): - default_generator['x'] = 50 + default_generator["x"] = 50 t.delete().execute() def test_bad_arg_signature(self): - ex_msg = "ColumnDefault Python function takes zero " \ + ex_msg = ( + "ColumnDefault Python function takes zero " "or one positional arguments" + ) def fn1(x, y): pass @@ -278,22 +302,21 @@ class DefaultTest(fixtures.TestBase): pass class fn3(object): - def __init__(self, x, y): pass class FN4(object): - def __call__(self, x, y): pass + fn4 = FN4() for fn in fn1, fn2, fn3, fn4: assert_raises_message( - sa.exc.ArgumentError, ex_msg, sa.ColumnDefault, fn) + sa.exc.ArgumentError, ex_msg, sa.ColumnDefault, fn + ) def test_arg_signature(self): - def fn1(): pass @@ -305,50 +328,52 @@ class DefaultTest(fixtures.TestBase): def fn4(x=1, y=2, z=3): eq_(x, 1) + fn5 = list class fn6a(object): - def __init__(self, x): eq_(x, "context") class fn6b(object): - def __init__(self, x, y=3): eq_(x, "context") class FN7(object): - def __call__(self, x): eq_(x, "context") + fn7 = FN7() class FN8(object): - def __call__(self, x, y=3): eq_(x, "context") + fn8 = FN8() for fn in fn1, fn2, fn3, fn4, fn5, fn6a, fn6b, fn7, fn8: c = sa.ColumnDefault(fn) c.arg("context") - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on("firebird", "Data type unknown") def test_standalone(self): c = testing.db.engine.contextual_connect() x = c.execute(t.c.col1.default) y = t.c.col2.default.execute() z = c.execute(t.c.col3.default) assert 50 <= x <= 57 - eq_(y, 'imthedefault') + eq_(y, "imthedefault") eq_(z, f) eq_(f2, 11) def test_py_vs_server_default_detection(self): - def has_(name, *wanted): slots = [ - 'default', 'onupdate', 'server_default', 'server_onupdate'] + "default", + "onupdate", + "server_default", + "server_onupdate", + ] col = tbl.c[name] for slot in wanted: slots.remove(slot) @@ -357,94 +382,146 @@ class DefaultTest(fixtures.TestBase): assert getattr(col, slot) is None, getattr(col, slot) tbl = t - has_('col1', 'default') - has_('col2', 'default', 'onupdate') - has_('col3', 'default', 'onupdate') - has_('col4', 'server_default') - has_('col5', 'server_default') - has_('col6', 'default', 'onupdate') - has_('boolcol1', 'default') - has_('boolcol2', 'default') - has_('col7', 'default', 'onupdate') - has_('col8', 'default', 'onupdate') - has_('col9', 'default', 'server_default') + has_("col1", "default") + has_("col2", "default", "onupdate") + has_("col3", "default", "onupdate") + 
has_("col4", "server_default") + has_("col5", "server_default") + has_("col6", "default", "onupdate") + has_("boolcol1", "default") + has_("boolcol2", "default") + has_("col7", "default", "onupdate") + has_("col8", "default", "onupdate") + has_("col9", "default", "server_default") ColumnDefault, DefaultClause = sa.ColumnDefault, sa.DefaultClause - t2 = Table('t2', MetaData(), - Column('col1', Integer, Sequence('foo')), - Column('col2', Integer, - default=Sequence('foo'), - server_default='y'), - Column('col3', Integer, - Sequence('foo'), - server_default='x'), - Column('col4', Integer, - ColumnDefault('x'), - DefaultClause('y')), - Column('col4', Integer, - ColumnDefault('x'), - DefaultClause('y'), - DefaultClause('y', for_update=True)), - Column('col5', Integer, - ColumnDefault('x'), - DefaultClause('y'), - onupdate='z'), - Column('col6', Integer, - ColumnDefault('x'), - server_default='y', - onupdate='z'), - Column('col7', Integer, - default='x', - server_default='y', - onupdate='z'), - Column('col8', Integer, - server_onupdate='u', - default='x', - server_default='y', - onupdate='z')) + t2 = Table( + "t2", + MetaData(), + Column("col1", Integer, Sequence("foo")), + Column( + "col2", Integer, default=Sequence("foo"), server_default="y" + ), + Column("col3", Integer, Sequence("foo"), server_default="x"), + Column("col4", Integer, ColumnDefault("x"), DefaultClause("y")), + Column( + "col4", + Integer, + ColumnDefault("x"), + DefaultClause("y"), + DefaultClause("y", for_update=True), + ), + Column( + "col5", + Integer, + ColumnDefault("x"), + DefaultClause("y"), + onupdate="z", + ), + Column( + "col6", + Integer, + ColumnDefault("x"), + server_default="y", + onupdate="z", + ), + Column( + "col7", Integer, default="x", server_default="y", onupdate="z" + ), + Column( + "col8", + Integer, + server_onupdate="u", + default="x", + server_default="y", + onupdate="z", + ), + ) tbl = t2 - has_('col1', 'default') - has_('col2', 'default', 'server_default') - has_('col3', 'default', 'server_default') - has_('col4', 'default', 'server_default', 'server_onupdate') - has_('col5', 'default', 'server_default', 'onupdate') - has_('col6', 'default', 'server_default', 'onupdate') - has_('col7', 'default', 'server_default', 'onupdate') + has_("col1", "default") + has_("col2", "default", "server_default") + has_("col3", "default", "server_default") + has_("col4", "default", "server_default", "server_onupdate") + has_("col5", "default", "server_default", "onupdate") + has_("col6", "default", "server_default", "onupdate") + has_("col7", "default", "server_default", "onupdate") has_( - 'col8', 'default', 'server_default', 'onupdate', 'server_onupdate') + "col8", "default", "server_default", "onupdate", "server_onupdate" + ) - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on("firebird", "Data type unknown") def test_insert(self): r = t.insert().execute() assert r.lastrow_has_defaults() - eq_(set(r.context.postfetch_cols), - set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])) + eq_( + set(r.context.postfetch_cols), + set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]), + ) r = t.insert(inline=True).execute() assert r.lastrow_has_defaults() - eq_(set(r.context.postfetch_cols), - set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])) + eq_( + set(r.context.postfetch_cols), + set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]), + ) t.insert().execute() ctexec = sa.select( - [currenttime.label('now')], bind=testing.db).scalar() + [currenttime.label("now")], bind=testing.db + ).scalar() result = 
t.select().order_by(t.c.col1).execute() today = datetime.date.today() - eq_(result.fetchall(), [ - (x, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo') - for x in range(51, 54)]) + eq_( + result.fetchall(), + [ + ( + x, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ) + for x in range(51, 54) + ], + ) t.insert().execute(col9=None) assert r.lastrow_has_defaults() - eq_(set(r.context.postfetch_cols), - set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])) + eq_( + set(r.context.postfetch_cols), + set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]), + ) - eq_(t.select(t.c.col1 == 54).execute().fetchall(), - [(54, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, None, 'hi', 'BINDfoo')]) + eq_( + t.select(t.c.col1 == 54).execute().fetchall(), + [ + ( + 54, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + None, + "hi", + "BINDfoo", + ) + ], + ) def test_insertmany(self): t.insert().execute({}, {}, {}) @@ -452,13 +529,56 @@ class DefaultTest(fixtures.TestBase): ctexec = currenttime.scalar() result = t.select().order_by(t.c.col1).execute() today = datetime.date.today() - eq_(result.fetchall(), - [(51, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo'), - (52, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo'), - (53, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo')]) + eq_( + result.fetchall(), + [ + ( + 51, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ), + ( + 52, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ), + ( + 53, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ), + ], + ) @testing.requires.multivalues_inserts def test_insert_multivalues(self): @@ -468,13 +588,56 @@ class DefaultTest(fixtures.TestBase): ctexec = currenttime.scalar() result = t.select().order_by(t.c.col1).execute() today = datetime.date.today() - eq_(result.fetchall(), - [(51, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo'), - (52, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo'), - (53, 'imthedefault', f, ts, ts, ctexec, True, False, - 12, today, 'py', 'hi', 'BINDfoo')]) + eq_( + result.fetchall(), + [ + ( + 51, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ), + ( + 52, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ), + ( + 53, + "imthedefault", + f, + ts, + ts, + ctexec, + True, + False, + 12, + today, + "py", + "hi", + "BINDfoo", + ), + ], + ) def test_no_embed_in_sql(self): """Using a DefaultGenerator, Sequence, DefaultClause @@ -482,25 +645,28 @@ class DefaultTest(fixtures.TestBase): clause of insert, update, raises an informative error""" for const in ( - sa.Sequence('y'), - sa.ColumnDefault('y'), - sa.DefaultClause('y') + sa.Sequence("y"), + sa.ColumnDefault("y"), + sa.DefaultClause("y"), ): assert_raises_message( sa.exc.ArgumentError, "SQL expression object or string expected, got object of type " "<.* 'list'> instead", - t.select, [const] + t.select, + [const], ) assert_raises_message( sa.exc.InvalidRequestError, "cannot be used directly as a column expression.", - str, t.insert().values(col4=const) + str, + 
t.insert().values(col4=const), ) assert_raises_message( sa.exc.InvalidRequestError, "cannot be used directly as a column expression.", - str, t.update().values(col4=const) + str, + t.update().values(col4=const), ) def test_missing_many_param(self): @@ -509,45 +675,89 @@ class DefaultTest(fixtures.TestBase): "A value is required for bind parameter 'col7', in parameter " "group 1", t.insert().execute, - {'col4': 7, 'col7': 12, 'col8': 19}, - {'col4': 7, 'col8': 19}, - {'col4': 7, 'col7': 12, 'col8': 19}, + {"col4": 7, "col7": 12, "col8": 19}, + {"col4": 7, "col8": 19}, + {"col4": 7, "col7": 12, "col8": 19}, ) def test_insert_values(self): - t.insert(values={'col3': 50}).execute() + t.insert(values={"col3": 50}).execute() result = t.select().execute() - eq_(50, result.first()['col3']) + eq_(50, result.first()["col3"]) - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on("firebird", "Data type unknown") def test_updatemany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql+mysqldb') and - testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): + if testing.against( + "mysql+mysqldb" + ) and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2): return t.insert().execute({}, {}, {}) - t.update(t.c.col1 == sa.bindparam('pkval')).execute( - {'pkval': 51, 'col7': None, 'col8': None, 'boolcol1': False}) + t.update(t.c.col1 == sa.bindparam("pkval")).execute( + {"pkval": 51, "col7": None, "col8": None, "boolcol1": False} + ) - t.update(t.c.col1 == sa.bindparam('pkval')).execute( - {'pkval': 51}, - {'pkval': 52}, - {'pkval': 53}) + t.update(t.c.col1 == sa.bindparam("pkval")).execute( + {"pkval": 51}, {"pkval": 52}, {"pkval": 53} + ) result = t.select().execute() ctexec = currenttime.scalar() today = datetime.date.today() - eq_(result.fetchall(), - [(51, 'im the update', f2, ts, ts, ctexec, False, False, - 13, today, 'py', 'hi', 'BINDfoo'), - (52, 'im the update', f2, ts, ts, ctexec, True, False, - 13, today, 'py', 'hi', 'BINDfoo'), - (53, 'im the update', f2, ts, ts, ctexec, True, False, - 13, today, 'py', 'hi', 'BINDfoo')]) - - @testing.fails_on('firebird', 'Data type unknown') + eq_( + result.fetchall(), + [ + ( + 51, + "im the update", + f2, + ts, + ts, + ctexec, + False, + False, + 13, + today, + "py", + "hi", + "BINDfoo", + ), + ( + 52, + "im the update", + f2, + ts, + ts, + ctexec, + True, + False, + 13, + today, + "py", + "hi", + "BINDfoo", + ), + ( + 53, + "im the update", + f2, + ts, + ts, + ctexec, + True, + False, + 13, + today, + "py", + "hi", + "BINDfoo", + ), + ], + ) + + @testing.fails_on("firebird", "Data type unknown") def test_update(self): r = t.insert().execute() pk = r.inserted_primary_key[0] @@ -555,39 +765,56 @@ class DefaultTest(fixtures.TestBase): ctexec = currenttime.scalar() result = t.select(t.c.col1 == pk).execute() result = result.first() - eq_(result, - (pk, 'im the update', f2, None, None, ctexec, True, False, - 13, datetime.date.today(), 'py', 'hi', 'BINDfoo')) + eq_( + result, + ( + pk, + "im the update", + f2, + None, + None, + ctexec, + True, + False, + 13, + datetime.date.today(), + "py", + "hi", + "BINDfoo", + ), + ) eq_(11, f2) - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on("firebird", "Data type unknown") def test_update_values(self): r = t.insert().execute() pk = r.inserted_primary_key[0] - t.update(t.c.col1 == pk, values={'col3': 55}).execute() + t.update(t.c.col1 == pk, values={"col3": 55}).execute() result = t.select(t.c.col1 == pk).execute() result = result.first() - 
eq_(55, result['col3']) + eq_(55, result["col3"]) class CTEDefaultTest(fixtures.TablesTest): - __requires__ = ('ctes', 'returning', 'ctes_on_dml') + __requires__ = ("ctes", "returning", "ctes_on_dml") __backend__ = True @classmethod def define_tables(cls, metadata): Table( - 'q', metadata, - Column('x', Integer, default=2), - Column('y', Integer, onupdate=5), - Column('z', Integer) + "q", + metadata, + Column("x", Integer, default=2), + Column("y", Integer, onupdate=5), + Column("z", Integer), ) Table( - 'p', metadata, - Column('s', Integer), - Column('t', Integer), - Column('u', Integer, onupdate=1) + "p", + metadata, + Column("s", Integer), + Column("t", Integer), + Column("u", Integer, onupdate=1), ) def _test_a_in_b(self, a, b): @@ -595,48 +822,53 @@ class CTEDefaultTest(fixtures.TablesTest): p = self.tables.p with testing.db.connect() as conn: - if a == 'delete': + if a == "delete": conn.execute(q.insert().values(y=10, z=1)) - cte = q.delete().\ - where(q.c.z == 1).returning(q.c.z).cte('c') + cte = q.delete().where(q.c.z == 1).returning(q.c.z).cte("c") expected = None elif a == "insert": - cte = q.insert().values(z=1, y=10).returning(q.c.z).cte('c') + cte = q.insert().values(z=1, y=10).returning(q.c.z).cte("c") expected = (2, 10) elif a == "update": conn.execute(q.insert().values(x=5, y=10, z=1)) - cte = q.update().\ - where(q.c.z == 1).values(x=7).returning(q.c.z).cte('c') + cte = ( + q.update() + .where(q.c.z == 1) + .values(x=7) + .returning(q.c.z) + .cte("c") + ) expected = (7, 5) elif a == "select": conn.execute(q.insert().values(x=5, y=10, z=1)) - cte = sa.select([q.c.z]).cte('c') + cte = sa.select([q.c.z]).cte("c") expected = (5, 10) if b == "select": conn.execute(p.insert().values(s=1)) stmt = select([p.c.s, cte.c.z]) elif b == "insert": - sel = select([1, cte.c.z, ]) - stmt = p.insert().from_select(['s', 't'], sel).returning( - p.c.s, p.c.t) + sel = select([1, cte.c.z]) + stmt = ( + p.insert() + .from_select(["s", "t"], sel) + .returning(p.c.s, p.c.t) + ) elif b == "delete": - stmt = p.insert().values(s=1, t=cte.c.z).returning( - p.c.s, cte.c.z) + stmt = ( + p.insert().values(s=1, t=cte.c.z).returning(p.c.s, cte.c.z) + ) elif b == "update": conn.execute(p.insert().values(s=1)) - stmt = p.update().values(t=5).\ - where(p.c.s == cte.c.z).\ - returning(p.c.u, cte.c.z) - eq_( - conn.execute(stmt).fetchall(), - [(1, 1)] - ) + stmt = ( + p.update() + .values(t=5) + .where(p.c.s == cte.c.z) + .returning(p.c.u, cte.c.z) + ) + eq_(conn.execute(stmt).fetchall(), [(1, 1)]) - eq_( - conn.execute(select([q.c.x, q.c.y])).fetchone(), - expected - ) + eq_(conn.execute(select([q.c.x, q.c.y])).fetchone(), expected) @testing.requires.ctes_on_dml def test_update_in_select(self): @@ -661,27 +893,34 @@ class CTEDefaultTest(fixtures.TablesTest): class PKDefaultTest(fixtures.TablesTest): - __requires__ = ('subqueries',) + __requires__ = ("subqueries",) __backend__ = True @classmethod def define_tables(cls, metadata): - t2 = Table( - 't2', metadata, - Column('nextid', Integer)) + t2 = Table("t2", metadata, Column("nextid", Integer)) Table( - 't1', metadata, + "t1", + metadata, Column( - 'id', Integer, primary_key=True, - default=sa.select([func.max(t2.c.nextid)]).as_scalar()), - Column('data', String(30))) + "id", + Integer, + primary_key=True, + default=sa.select([func.max(t2.c.nextid)]).as_scalar(), + ), + Column("data", String(30)), + ) Table( - 'date_table', metadata, + "date_table", + metadata, Column( - 'date_id', - DateTime, default=text("current_timestamp"), primary_key=True) + "date_id", + 
DateTime, + default=text("current_timestamp"), + primary_key=True, + ), ) @testing.requires.returning @@ -693,21 +932,24 @@ class PKDefaultTest(fixtures.TablesTest): def _test(self, returning): t2, t1, date_table = ( - self.tables.t2, self.tables.t1, self.tables.date_table + self.tables.t2, + self.tables.t1, + self.tables.date_table, ) if not returning and not testing.db.dialect.implicit_returning: engine = testing.db else: engine = engines.testing_engine( - options={'implicit_returning': returning}) + options={"implicit_returning": returning} + ) with engine.begin() as conn: conn.execute(t2.insert(), nextid=1) - r = conn.execute(t1.insert(), data='hi') + r = conn.execute(t1.insert(), data="hi") eq_([1], r.inserted_primary_key) conn.execute(t2.insert(), nextid=2) - r = conn.execute(t1.insert(), data='there') + r = conn.execute(t1.insert(), data="there") eq_([2], r.inserted_primary_key) r = conn.execute(date_table.insert()) @@ -715,19 +957,26 @@ class PKDefaultTest(fixtures.TablesTest): class PKIncrementTest(fixtures.TablesTest): - run_define_tables = 'each' + run_define_tables = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table("aitable", metadata, - Column('id', Integer, Sequence('ai_id_seq', optional=True), - primary_key=True), - Column('int1', Integer), - Column('str1', String(20))) + Table( + "aitable", + metadata, + Column( + "id", + Integer, + Sequence("ai_id_seq", optional=True), + primary_key=True, + ), + Column("int1", Integer), + Column("str1", String(20)), + ) # TODO: add coverage for increment on a secondary column in a key - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on("firebird", "Data type unknown") def _test_autoincrement(self, bind): aitable = self.tables.aitable @@ -738,19 +987,19 @@ class PKIncrementTest(fixtures.TablesTest): self.assert_(last not in ids) ids.add(last) - rs = bind.execute(aitable.insert(), str1='row 2') + rs = bind.execute(aitable.insert(), str1="row 2") last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) - rs = bind.execute(aitable.insert(), int1=3, str1='row 3') + rs = bind.execute(aitable.insert(), int1=3, str1="row 3") last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) - rs = bind.execute(aitable.insert(values={'int1': func.length('four')})) + rs = bind.execute(aitable.insert(values={"int1": func.length("four")})) last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) @@ -758,8 +1007,10 @@ class PKIncrementTest(fixtures.TablesTest): eq_(ids, set([1, 2, 3, 4])) - eq_(list(bind.execute(aitable.select().order_by(aitable.c.id))), - [(1, 1, None), (2, None, 'row 2'), (3, 3, 'row 3'), (4, 4, None)]) + eq_( + list(bind.execute(aitable.select().order_by(aitable.c.id))), + [(1, 1, None), (2, None, "row 2"), (3, 3, "row 3"), (4, 4, None)], + ) def test_autoincrement_autocommit(self): self._test_autoincrement(testing.db) @@ -785,85 +1036,87 @@ class PKIncrementTest(fixtures.TablesTest): class EmptyInsertTest(fixtures.TestBase): __backend__ = True - @testing.exclude('sqlite', '<', (3, 3, 8), 'no empty insert support') - @testing.fails_on('oracle', 'FIXME: unknown') + @testing.exclude("sqlite", "<", (3, 3, 8), "no empty insert support") + @testing.fails_on("oracle", "FIXME: unknown") @testing.provide_metadata def test_empty_insert(self): t1 = Table( - 't1', self.metadata, - Column('is_true', Boolean, server_default=('1'))) + "t1", + self.metadata, + Column("is_true", Boolean, 
server_default=("1")), + ) self.metadata.create_all() t1.insert().execute() - eq_(1, select([func.count(text('*'))], from_obj=t1).scalar()) + eq_(1, select([func.count(text("*"))], from_obj=t1).scalar()) eq_(True, t1.select().scalar()) class AutoIncrementTest(fixtures.TablesTest): - __requires__ = ('identity',) - run_define_tables = 'each' + __requires__ = ("identity",) + run_define_tables = "each" __backend__ = True @classmethod def define_tables(cls, metadata): """Each test manipulates self.metadata individually.""" - @testing.exclude('sqlite', '<', (3, 4), 'no database support') + @testing.exclude("sqlite", "<", (3, 4), "no database support") def test_autoincrement_single_col(self): - single = Table('single', self.metadata, - Column('id', Integer, primary_key=True)) + single = Table( + "single", self.metadata, Column("id", Integer, primary_key=True) + ) single.create() r = single.insert().execute() id_ = r.inserted_primary_key[0] eq_(id_, 1) - eq_(1, sa.select([func.count(sa.text('*'))], from_obj=single).scalar()) + eq_(1, sa.select([func.count(sa.text("*"))], from_obj=single).scalar()) def test_autoincrement_fk(self): nodes = Table( - 'nodes', self.metadata, - Column('id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) + "nodes", + self.metadata, + Column("id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("nodes.id")), + Column("data", String(30)), + ) nodes.create() - r = nodes.insert().execute(data='foo') + r = nodes.insert().execute(data="foo") id_ = r.inserted_primary_key[0] - nodes.insert().execute(data='bar', parent_id=id_) + nodes.insert().execute(data="bar", parent_id=id_) def test_autoinc_detection_no_affinity(self): class MyType(TypeDecorator): impl = TypeEngine assert MyType()._type_affinity is None - t = Table( - 'x', MetaData(), - Column('id', MyType(), primary_key=True) - ) + t = Table("x", MetaData(), Column("id", MyType(), primary_key=True)) assert t._autoincrement_column is None def test_autoincrement_ignore_fk(self): m = MetaData() - Table( - 'y', m, - Column('id', Integer(), primary_key=True) - ) + Table("y", m, Column("id", Integer(), primary_key=True)) x = Table( - 'x', m, + "x", + m, Column( - 'id', Integer(), ForeignKey('y.id'), - autoincrement="ignore_fk", primary_key=True) + "id", + Integer(), + ForeignKey("y.id"), + autoincrement="ignore_fk", + primary_key=True, + ), ) assert x._autoincrement_column is x.c.id def test_autoincrement_fk_disqualifies(self): m = MetaData() - Table( - 'y', m, - Column('id', Integer(), primary_key=True) - ) + Table("y", m, Column("id", Integer(), primary_key=True)) x = Table( - 'x', m, - Column('id', Integer(), ForeignKey('y.id'), primary_key=True) + "x", + m, + Column("id", Integer(), ForeignKey("y.id"), primary_key=True), ) assert x._autoincrement_column is None @@ -871,124 +1124,129 @@ class AutoIncrementTest(fixtures.TablesTest): def test_non_autoincrement(self): # sqlite INT primary keys can be non-unique! 
(only for ints) nonai = Table( - "nonaitest", self.metadata, - Column('id', Integer, autoincrement=False, primary_key=True), - Column('data', String(20))) + "nonaitest", + self.metadata, + Column("id", Integer, autoincrement=False, primary_key=True), + Column("data", String(20)), + ) nonai.create() def go(): # postgresql + mysql strict will fail on first row, # mysql in legacy mode fails on second row - nonai.insert().execute(data='row 1') - nonai.insert().execute(data='row 2') + nonai.insert().execute(data="row 1") + nonai.insert().execute(data="row 2") # just testing SQLite for now, it passes - with expect_warnings( - ".*has no Python-side or server-side default.*", - ): + with expect_warnings(".*has no Python-side or server-side default.*"): go() def test_col_w_sequence_non_autoinc_no_firing(self): metadata = self.metadata # plain autoincrement/PK table in the actual schema - Table( - "x", metadata, - Column("set_id", Integer, primary_key=True) - ) + Table("x", metadata, Column("set_id", Integer, primary_key=True)) metadata.create_all() # for the INSERT use a table with a Sequence # and autoincrement=False. Using a ForeignKey # would have the same effect dataset_no_autoinc = Table( - "x", MetaData(), + "x", + MetaData(), Column( - "set_id", Integer, Sequence("some_seq"), - primary_key=True, autoincrement=False) + "set_id", + Integer, + Sequence("some_seq"), + primary_key=True, + autoincrement=False, + ), ) testing.db.execute(dataset_no_autoinc.insert()) eq_( testing.db.scalar( - select([func.count('*')]).select_from(dataset_no_autoinc)), 1 + select([func.count("*")]).select_from(dataset_no_autoinc) + ), + 1, ) class SequenceDDLTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" __backend__ = True def test_create_drop_ddl(self): self.assert_compile( - CreateSequence(Sequence('foo_seq')), - "CREATE SEQUENCE foo_seq", + CreateSequence(Sequence("foo_seq")), "CREATE SEQUENCE foo_seq" ) self.assert_compile( - CreateSequence(Sequence('foo_seq', start=5)), + CreateSequence(Sequence("foo_seq", start=5)), "CREATE SEQUENCE foo_seq START WITH 5", ) self.assert_compile( - CreateSequence(Sequence('foo_seq', increment=2)), + CreateSequence(Sequence("foo_seq", increment=2)), "CREATE SEQUENCE foo_seq INCREMENT BY 2", ) self.assert_compile( - CreateSequence(Sequence('foo_seq', increment=2, start=5)), + CreateSequence(Sequence("foo_seq", increment=2, start=5)), "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 5", ) self.assert_compile( - CreateSequence(Sequence( - 'foo_seq', increment=2, start=0, minvalue=0)), + CreateSequence( + Sequence("foo_seq", increment=2, start=0, minvalue=0) + ), "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 0 MINVALUE 0", ) self.assert_compile( - CreateSequence(Sequence( - 'foo_seq', increment=2, start=1, maxvalue=5)), + CreateSequence( + Sequence("foo_seq", increment=2, start=1, maxvalue=5) + ), "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 1 MAXVALUE 5", ) self.assert_compile( - CreateSequence(Sequence( - 'foo_seq', increment=2, start=1, nomaxvalue=True)), + CreateSequence( + Sequence("foo_seq", increment=2, start=1, nomaxvalue=True) + ), "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 1 NO MAXVALUE", ) self.assert_compile( - CreateSequence(Sequence( - 'foo_seq', increment=2, start=0, nominvalue=True)), + CreateSequence( + Sequence("foo_seq", increment=2, start=0, nominvalue=True) + ), "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 0 NO MINVALUE", ) self.assert_compile( - CreateSequence(Sequence( - 
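SequenceDDLTest maps each Sequence option onto its DDL keyword (START WITH, INCREMENT BY, MINVALUE, MAXVALUE, CYCLE, CACHE, ORDER). The same rendering can be inspected directly:

    from sqlalchemy import Sequence
    from sqlalchemy.schema import CreateSequence

    print(CreateSequence(Sequence("foo_seq", increment=2, start=5)))
    # CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 5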
'foo_seq', start=1, maxvalue=10, cycle=True)), + CreateSequence( + Sequence("foo_seq", start=1, maxvalue=10, cycle=True) + ), "CREATE SEQUENCE foo_seq START WITH 1 MAXVALUE 10 CYCLE", ) self.assert_compile( - CreateSequence(Sequence( - 'foo_seq', cache=1000, order=True)), + CreateSequence(Sequence("foo_seq", cache=1000, order=True)), "CREATE SEQUENCE foo_seq CACHE 1000 ORDER", ) self.assert_compile( - CreateSequence(Sequence( - 'foo_seq', order=True)), + CreateSequence(Sequence("foo_seq", order=True)), "CREATE SEQUENCE foo_seq ORDER", ) self.assert_compile( - DropSequence(Sequence('foo_seq')), - "DROP SEQUENCE foo_seq", + DropSequence(Sequence("foo_seq")), "DROP SEQUENCE foo_seq" ) class SequenceExecTest(fixtures.TestBase): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True @classmethod @@ -1042,28 +1300,23 @@ class SequenceExecTest(fixtures.TestBase): """test can use next_value() in select column expr""" s = Sequence("my_sequence") - self._assert_seq_result( - testing.db.scalar(select([s.next_value()])) - ) + self._assert_seq_result(testing.db.scalar(select([s.next_value()]))) - @testing.fails_on('oracle', "ORA-02287: sequence number not allowed here") + @testing.fails_on("oracle", "ORA-02287: sequence number not allowed here") @testing.provide_metadata def test_func_embedded_whereclause(self): """test can use next_value() in whereclause""" metadata = self.metadata - t1 = Table( - 't', metadata, - Column('x', Integer) - ) + t1 = Table("t", metadata, Column("x", Integer)) t1.create(testing.db) - testing.db.execute(t1.insert(), [{'x': 1}, {'x': 300}, {'x': 301}]) + testing.db.execute(t1.insert(), [{"x": 1}, {"x": 300}, {"x": 301}]) s = Sequence("my_sequence") eq_( testing.db.execute( t1.select().where(t1.c.x > s.next_value()) ).fetchall(), - [(300, ), (301, )] + [(300,), (301,)], ) @testing.provide_metadata @@ -1071,18 +1324,11 @@ class SequenceExecTest(fixtures.TestBase): """test can use next_value() in values() of _ValuesBase""" metadata = self.metadata - t1 = Table( - 't', metadata, - Column('x', Integer) - ) + t1 = Table("t", metadata, Column("x", Integer)) t1.create(testing.db) s = Sequence("my_sequence") - testing.db.execute( - t1.insert().values(x=s.next_value()) - ) - self._assert_seq_result( - testing.db.scalar(t1.select()) - ) + testing.db.execute(t1.insert().values(x=s.next_value())) + self._assert_seq_result(testing.db.scalar(t1.select())) @testing.provide_metadata def test_inserted_pk_no_returning(self): @@ -1090,13 +1336,10 @@ class SequenceExecTest(fixtures.TestBase): pk_col=next_value(), implicit returning is not used.""" metadata = self.metadata - e = engines.testing_engine(options={'implicit_returning': False}) + e = engines.testing_engine(options={"implicit_returning": False}) s = Sequence("my_sequence") metadata.bind = e - t1 = Table( - 't', metadata, - Column('x', Integer, primary_key=True) - ) + t1 = Table("t", metadata, Column("x", Integer, primary_key=True)) t1.create() r = e.execute(t1.insert().values(x=s.next_value())) eq_(r.inserted_primary_key, [None]) @@ -1108,35 +1351,29 @@ class SequenceExecTest(fixtures.TestBase): pk_col=next_value(), when implicit returning is used.""" metadata = self.metadata - e = engines.testing_engine(options={'implicit_returning': True}) + e = engines.testing_engine(options={"implicit_returning": True}) s = Sequence("my_sequence") metadata.bind = e - t1 = Table( - 't', metadata, - Column('x', Integer, primary_key=True) - ) + t1 = Table("t", metadata, Column("x", Integer, primary_key=True)) t1.create() - r 
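The SequenceExecTest cases around this hunk all revolve around Sequence.next_value(), a SQL function element usable anywhere a column expression fits: as a SELECT column, in a WHERE clause, or in values(). Condensed, assuming a sequence-capable backend, a bound engine, and a table t with an integer column x:

    s = Sequence("my_sequence")
    # as a SELECT column
    value = engine.scalar(select([s.next_value()]))
    # as an INSERT value
    engine.execute(t.insert().values(x=s.next_value()))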
= e.execute( - t1.insert().values(x=s.next_value()) - ) + r = e.execute(t1.insert().values(x=s.next_value())) self._assert_seq_result(r.inserted_primary_key[0]) class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True - @testing.fails_on('firebird', 'no FB support for start/increment') + @testing.fails_on("firebird", "no FB support for start/increment") def test_start_increment(self): for seq in ( - Sequence('foo_seq'), - Sequence('foo_seq', start=8), - Sequence('foo_seq', increment=5)): + Sequence("foo_seq"), + Sequence("foo_seq", start=8), + Sequence("foo_seq", increment=5), + ): seq.create(testing.db) try: - values = [ - testing.db.execute(seq) for i in range(3) - ] + values = [testing.db.execute(seq) for i in range(3)] start = seq.start or 1 inc = seq.increment or 1 assert values == list(range(start, start + inc * 3, inc)) @@ -1153,7 +1390,10 @@ class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): for s in (Sequence("my_seq"), Sequence("my_seq", optional=True)): assert str(s.next_value().compile(dialect=testing.db.dialect)) in ( - "nextval('my_seq')", "gen_id(my_seq, 1)", "my_seq.nextval",) + "nextval('my_seq')", + "gen_id(my_seq, 1)", + "my_seq.nextval", + ) def test_nextval_unsupported(self): """test next_value() used on non-sequence platform @@ -1165,37 +1405,37 @@ class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): NotImplementedError, "Dialect 'sqlite' does not support sequence increments.", s.next_value().compile, - dialect=d + dialect=d, ) def test_checkfirst_sequence(self): s = Sequence("my_sequence") s.create(testing.db, checkfirst=False) - assert self._has_sequence('my_sequence') + assert self._has_sequence("my_sequence") s.create(testing.db, checkfirst=True) s.drop(testing.db, checkfirst=False) - assert not self._has_sequence('my_sequence') + assert not self._has_sequence("my_sequence") s.drop(testing.db, checkfirst=True) def test_checkfirst_metadata(self): m = MetaData() Sequence("my_sequence", metadata=m) m.create_all(testing.db, checkfirst=False) - assert self._has_sequence('my_sequence') + assert self._has_sequence("my_sequence") m.create_all(testing.db, checkfirst=True) m.drop_all(testing.db, checkfirst=False) - assert not self._has_sequence('my_sequence') + assert not self._has_sequence("my_sequence") m.drop_all(testing.db, checkfirst=True) def test_checkfirst_table(self): m = MetaData() s = Sequence("my_sequence") - t = Table('t', m, Column('c', Integer, s, primary_key=True)) + t = Table("t", m, Column("c", Integer, s, primary_key=True)) t.create(testing.db, checkfirst=False) - assert self._has_sequence('my_sequence') + assert self._has_sequence("my_sequence") t.create(testing.db, checkfirst=True) t.drop(testing.db, checkfirst=False) - assert not self._has_sequence('my_sequence') + assert not self._has_sequence("my_sequence") t.drop(testing.db, checkfirst=True) @testing.provide_metadata @@ -1204,9 +1444,7 @@ class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): Sequence("s1", metadata=metadata) s2 = Sequence("s2", metadata=metadata) s3 = Sequence("s3") - t = Table( - 't', metadata, - Column('c', Integer, s3, primary_key=True)) + t = Table("t", metadata, Column("c", Integer, s3, primary_key=True)) assert s3.metadata is metadata t.create(testing.db, checkfirst=True) @@ -1216,29 +1454,33 @@ class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): # re-created since it's linked to 't'. # 's1' and 's2' are, however. 
metadata.create_all(testing.db) - assert self._has_sequence('s1') - assert self._has_sequence('s2') - assert not self._has_sequence('s3') + assert self._has_sequence("s1") + assert self._has_sequence("s2") + assert not self._has_sequence("s3") s2.drop(testing.db) - assert self._has_sequence('s1') - assert not self._has_sequence('s2') + assert self._has_sequence("s1") + assert not self._has_sequence("s2") metadata.drop_all(testing.db) - assert not self._has_sequence('s1') - assert not self._has_sequence('s2') + assert not self._has_sequence("s1") + assert not self._has_sequence("s2") @testing.requires.returning @testing.provide_metadata def test_freestanding_sequence_via_autoinc(self): t = Table( - 'some_table', self.metadata, + "some_table", + self.metadata, Column( - 'id', Integer, + "id", + Integer, autoincrement=True, primary_key=True, default=Sequence( - 'my_sequence', metadata=self.metadata).next_value()) + "my_sequence", metadata=self.metadata + ).next_value(), + ), ) self.metadata.create_all(testing.db) @@ -1250,7 +1492,7 @@ cartitems = sometable = metadata = None class TableBoundSequenceTest(fixtures.TestBase): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True @classmethod @@ -1258,19 +1500,25 @@ class TableBoundSequenceTest(fixtures.TestBase): global cartitems, sometable, metadata metadata = MetaData(testing.db) cartitems = Table( - "cartitems", metadata, + "cartitems", + metadata, Column( - "cart_id", Integer, Sequence('cart_id_seq'), primary_key=True), + "cart_id", Integer, Sequence("cart_id_seq"), primary_key=True + ), Column("description", String(40)), - Column("createdate", sa.DateTime()) + Column("createdate", sa.DateTime()), ) sometable = Table( - 'Manager', metadata, - Column('obj_id', Integer, Sequence('obj_id_seq')), - Column('name', String(128)), + "Manager", + metadata, + Column("obj_id", Integer, Sequence("obj_id_seq")), + Column("name", String(128)), Column( - 'id', Integer, Sequence('Manager_id_seq', optional=True), - primary_key=True), + "id", + Integer, + Sequence("Manager_id_seq", optional=True), + primary_key=True, + ), ) metadata.create_all() @@ -1280,24 +1528,30 @@ class TableBoundSequenceTest(fixtures.TestBase): metadata.drop_all() def test_insert_via_seq(self): - cartitems.insert().execute(description='hi') - cartitems.insert().execute(description='there') - r = cartitems.insert().execute(description='lala') + cartitems.insert().execute(description="hi") + cartitems.insert().execute(description="there") + r = cartitems.insert().execute(description="lala") assert r.inserted_primary_key and r.inserted_primary_key[0] is not None id_ = r.inserted_primary_key[0] - eq_(1, - sa.select([func.count(cartitems.c.cart_id)], - sa.and_(cartitems.c.description == 'lala', - cartitems.c.cart_id == id_)).scalar()) + eq_( + 1, + sa.select( + [func.count(cartitems.c.cart_id)], + sa.and_( + cartitems.c.description == "lala", + cartitems.c.cart_id == id_, + ), + ).scalar(), + ) cartitems.select().execute().fetchall() def test_seq_nonpk(self): """test sequences fire off as defaults on non-pk columns""" - engine = engines.testing_engine(options={'implicit_returning': False}) + engine = engines.testing_engine(options={"implicit_returning": False}) result = engine.execute(sometable.insert(), name="somename") assert set(result.postfetch_cols()) == set([sometable.c.obj_id]) @@ -1305,22 +1559,25 @@ class TableBoundSequenceTest(fixtures.TestBase): result = engine.execute(sometable.insert(), name="someother") assert set(result.postfetch_cols()) == 
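TableBoundSequenceTest's sometable, defined above, uses optional=True, which is easy to misread: an optional Sequence is created and fired only on backends that have no native way to generate integer primary keys (such as Oracle), and is skipped entirely where SERIAL/AUTOINCREMENT-style columns exist. In outline:

    from sqlalchemy import Column, Integer, MetaData, Sequence, Table

    m = MetaData()
    t = Table(
        "managed", m,
        Column(
            "id", Integer,
            # used on Oracle/Firebird; ignored on e.g. PostgreSQL, SQLite
            Sequence("managed_id_seq", optional=True),
            primary_key=True,
        ),
    )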
set([sometable.c.obj_id]) - sometable.insert().execute( - {'name': 'name3'}, - {'name': 'name4'}) - eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(), - [(1, "somename", 1), - (2, "someother", 2), - (3, "name3", 3), - (4, "name4", 4)]) + sometable.insert().execute({"name": "name3"}, {"name": "name4"}) + eq_( + sometable.select().order_by(sometable.c.id).execute().fetchall(), + [ + (1, "somename", 1), + (2, "someother", 2), + (3, "name3", 3), + (4, "name4", 4), + ], + ) class SequenceAsServerDefaultTest( - testing.AssertsExecutionResults, fixtures.TablesTest): - __requires__ = ('sequences_as_server_defaults',) + testing.AssertsExecutionResults, fixtures.TablesTest +): + __requires__ = ("sequences_as_server_defaults",) __backend__ = True - run_create_tables = 'each' + run_create_tables = "each" @classmethod def define_tables(cls, metadata): @@ -1328,16 +1585,18 @@ class SequenceAsServerDefaultTest( s = Sequence("t_seq", metadata=m) Table( - "t_seq_test", m, + "t_seq_test", + m, Column("id", Integer, s, server_default=s.next_value()), - Column("data", String(50)) + Column("data", String(50)), ) s2 = Sequence("t_seq_2", metadata=m) Table( - "t_seq_test_2", m, + "t_seq_test_2", + m, Column("id", Integer, server_default=s2.next_value()), - Column("data", String(50)) + Column("data", String(50)), ) def test_default_textual_w_default(self): @@ -1356,7 +1615,8 @@ class SequenceAsServerDefaultTest( def test_default_textual_server_only(self): with testing.db.connect() as conn: conn.execute( - "insert into t_seq_test_2 (data) values ('some data')") + "insert into t_seq_test_2 (data) values ('some data')" + ) eq_(conn.scalar("select id from t_seq_test_2"), 1) @@ -1372,24 +1632,18 @@ class SequenceAsServerDefaultTest( testing.db, lambda: self.metadata.drop_all(checkfirst=False), AllOf( - CompiledSQL( - "DROP TABLE t_seq_test_2", - {} - ), + CompiledSQL("DROP TABLE t_seq_test_2", {}), EachOf( + CompiledSQL("DROP TABLE t_seq_test", {}), CompiledSQL( - "DROP TABLE t_seq_test", - {} - ), - CompiledSQL( - "DROP SEQUENCE t_seq", # dropped as part of t_seq_test - {} + "DROP SEQUENCE t_seq", # dropped as part of t_seq_test + {}, ), ), ), CompiledSQL( "DROP SEQUENCE t_seq_2", # dropped as part of metadata level - {} + {}, ), ) @@ -1402,6 +1656,7 @@ class SpecialTypePKTest(fixtures.TestBase): column.type._type_affinity, rather than the class of "type" itself. """ + __backend__ = True @classmethod @@ -1424,33 +1679,31 @@ class SpecialTypePKTest(fixtures.TestBase): @testing.provide_metadata def _run_test(self, *arg, **kw): metadata = self.metadata - implicit_returning = kw.pop('implicit_returning', True) - kw['primary_key'] = True - if kw.get('autoincrement', True): - kw['test_needs_autoincrement'] = True + implicit_returning = kw.pop("implicit_returning", True) + kw["primary_key"] = True + if kw.get("autoincrement", True): + kw["test_needs_autoincrement"] = True t = Table( - 'x', metadata, - Column('y', self.MyInteger, *arg, **kw), - Column('data', Integer), - implicit_returning=implicit_returning + "x", + metadata, + Column("y", self.MyInteger, *arg, **kw), + Column("data", Integer), + implicit_returning=implicit_returning, ) t.create() r = t.insert().values(data=5).execute() # we don't pre-fetch 'server_default'. 
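SequenceAsServerDefaultTest, set up above, wires the same Sequence in twice: positionally as the Python-side generator and, via next_value(), as the column's server default, so that plain textual INSERTs issued outside SQLAlchemy also get ids. The pattern, in outline from the fixture:

    from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table

    m = MetaData()
    s = Sequence("t_seq", metadata=m)
    t = Table(
        "t_seq_test", m,
        # renders e.g. id INTEGER DEFAULT nextval('t_seq') on backends
        # with sequences_as_server_defaults support
        Column("id", Integer, s, server_default=s.next_value()),
        Column("data", String(50)),
    )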
- if 'server_default' in kw and ( - not testing.db.dialect.implicit_returning or - not implicit_returning): + if "server_default" in kw and ( + not testing.db.dialect.implicit_returning or not implicit_returning + ): eq_(r.inserted_primary_key, [None]) else: - eq_(r.inserted_primary_key, ['INT_1']) + eq_(r.inserted_primary_key, ["INT_1"]) r.close() - eq_( - t.select().execute().first(), - ('INT_1', 5) - ) + eq_(t.select().execute().first(), ("INT_1", 5)) def test_plain(self): # among other things, tests that autoincrement @@ -1459,7 +1712,8 @@ class SpecialTypePKTest(fixtures.TestBase): def test_literal_default_label(self): self._run_test( - default=literal("INT_1", type_=self.MyInteger).label('foo')) + default=literal("INT_1", type_=self.MyInteger).label("foo") + ) def test_literal_default_no_label(self): self._run_test(default=literal("INT_1", type_=self.MyInteger)) @@ -1468,16 +1722,16 @@ class SpecialTypePKTest(fixtures.TestBase): self._run_test(default=literal_column("1", type_=self.MyInteger)) def test_sequence(self): - self._run_test(Sequence('foo_seq')) + self._run_test(Sequence("foo_seq")) def test_text_clause_default_no_type(self): - self._run_test(default=text('1')) + self._run_test(default=text("1")) def test_server_default(self): - self._run_test(server_default='1',) + self._run_test(server_default="1") def test_server_default_no_autoincrement(self): - self._run_test(server_default='1', autoincrement=False) + self._run_test(server_default="1", autoincrement=False) def test_clause(self): stmt = select([cast("INT_1", type_=self.MyInteger)]).as_scalar() @@ -1489,7 +1743,7 @@ class SpecialTypePKTest(fixtures.TestBase): @testing.requires.returning def test_server_default_no_implicit_returning(self): - self._run_test(server_default='1', autoincrement=False) + self._run_test(server_default="1", autoincrement=False) class ServerDefaultsOnPKTest(fixtures.TestBase): @@ -1508,19 +1762,18 @@ class ServerDefaultsOnPKTest(fixtures.TestBase): metadata = self.metadata t = Table( - 'x', metadata, + "x", + metadata, Column( - 'y', String(10), server_default='key_one', primary_key=True), - Column('data', String(10)), - implicit_returning=False + "y", String(10), server_default="key_one", primary_key=True + ), + Column("data", String(10)), + implicit_returning=False, ) metadata.create_all() - r = t.insert().execute(data='data') + r = t.insert().execute(data="data") eq_(r.inserted_primary_key, [None]) - eq_( - t.select().execute().fetchall(), - [('key_one', 'data')] - ) + eq_(t.select().execute().fetchall(), [("key_one", "data")]) @testing.requires.returning @testing.provide_metadata @@ -1530,103 +1783,91 @@ class ServerDefaultsOnPKTest(fixtures.TestBase): metadata = self.metadata t = Table( - 'x', metadata, + "x", + metadata, Column( - 'y', String(10), server_default='key_one', primary_key=True), - Column('data', String(10)) + "y", String(10), server_default="key_one", primary_key=True + ), + Column("data", String(10)), ) metadata.create_all() - r = t.insert().execute(data='data') - eq_(r.inserted_primary_key, ['key_one']) - eq_( - t.select().execute().fetchall(), - [('key_one', 'data')] - ) + r = t.insert().execute(data="data") + eq_(r.inserted_primary_key, ["key_one"]) + eq_(t.select().execute().fetchall(), [("key_one", "data")]) @testing.provide_metadata def test_int_default_none_on_insert(self): metadata = self.metadata t = Table( - 'x', metadata, - Column('y', Integer, server_default='5', primary_key=True), - Column('data', String(10)), - implicit_returning=False + "x", + metadata, + 
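The _run_test logic above encodes the rule these PK tests rely on: a server-generated primary key reaches inserted_primary_key only when RETURNING (implicit or explicit) is in play; otherwise the result reports [None] and the value must be fetched separately. Sketched against the ServerDefaultsOnPKTest table, with engine standing in for a bound engine:

    r = engine.execute(t.insert(), data="data")
    # RETURNING backend, implicit_returning left on:
    #     r.inserted_primary_key == ["key_one"]
    # implicit_returning=False (or no RETURNING support):
    #     r.inserted_primary_key == [None]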
Column("y", Integer, server_default="5", primary_key=True), + Column("data", String(10)), + implicit_returning=False, ) assert t._autoincrement_column is None metadata.create_all() - r = t.insert().execute(data='data') + r = t.insert().execute(data="data") eq_(r.inserted_primary_key, [None]) - if testing.against('sqlite'): - eq_( - t.select().execute().fetchall(), - [(1, 'data')] - ) + if testing.against("sqlite"): + eq_(t.select().execute().fetchall(), [(1, "data")]) else: - eq_( - t.select().execute().fetchall(), - [(5, 'data')] - ) + eq_(t.select().execute().fetchall(), [(5, "data")]) @testing.provide_metadata def test_autoincrement_reflected_from_server_default(self): metadata = self.metadata t = Table( - 'x', metadata, - Column('y', Integer, server_default='5', primary_key=True), - Column('data', String(10)), - implicit_returning=False + "x", + metadata, + Column("y", Integer, server_default="5", primary_key=True), + Column("data", String(10)), + implicit_returning=False, ) assert t._autoincrement_column is None metadata.create_all() m2 = MetaData(metadata.bind) - t2 = Table('x', m2, autoload=True, implicit_returning=False) + t2 = Table("x", m2, autoload=True, implicit_returning=False) assert t2._autoincrement_column is None @testing.provide_metadata def test_int_default_none_on_insert_reflected(self): metadata = self.metadata Table( - 'x', metadata, - Column('y', Integer, server_default='5', primary_key=True), - Column('data', String(10)), - implicit_returning=False + "x", + metadata, + Column("y", Integer, server_default="5", primary_key=True), + Column("data", String(10)), + implicit_returning=False, ) metadata.create_all() m2 = MetaData(metadata.bind) - t2 = Table('x', m2, autoload=True, implicit_returning=False) + t2 = Table("x", m2, autoload=True, implicit_returning=False) - r = t2.insert().execute(data='data') + r = t2.insert().execute(data="data") eq_(r.inserted_primary_key, [None]) - if testing.against('sqlite'): - eq_( - t2.select().execute().fetchall(), - [(1, 'data')] - ) + if testing.against("sqlite"): + eq_(t2.select().execute().fetchall(), [(1, "data")]) else: - eq_( - t2.select().execute().fetchall(), - [(5, 'data')] - ) + eq_(t2.select().execute().fetchall(), [(5, "data")]) @testing.requires.returning @testing.provide_metadata def test_int_default_on_insert_with_returning(self): metadata = self.metadata t = Table( - 'x', metadata, - Column('y', Integer, server_default='5', primary_key=True), - Column('data', String(10)) + "x", + metadata, + Column("y", Integer, server_default="5", primary_key=True), + Column("data", String(10)), ) metadata.create_all() - r = t.insert().execute(data='data') + r = t.insert().execute(data="data") eq_(r.inserted_primary_key, [5]) - eq_( - t.select().execute().fetchall(), - [(5, 'data')] - ) + eq_(t.select().execute().fetchall(), [(5, "data")]) class UnicodeDefaultsTest(fixtures.TestBase): @@ -1636,18 +1877,19 @@ class UnicodeDefaultsTest(fixtures.TestBase): Column(Unicode(32)) def test_unicode_default(self): - default = u('foo') + default = u("foo") Column(Unicode(32), default=default) def test_nonunicode_default(self): - default = b('foo') + default = b("foo") assert_raises_message( sa.exc.SAWarning, "Unicode column 'foobar' has non-unicode " "default value b?'foo' specified.", Column, - "foobar", Unicode(32), - default=default + "foobar", + Unicode(32), + default=default, ) @@ -1656,34 +1898,34 @@ class InsertFromSelectTest(fixtures.TestBase): def _fixture(self): data = Table( - 'data', self.metadata, - Column('x', Integer), - 
Column('y', Integer) + "data", self.metadata, Column("x", Integer), Column("y", Integer) ) data.create() - testing.db.execute(data.insert(), {'x': 2, 'y': 5}, {'x': 7, 'y': 12}) + testing.db.execute(data.insert(), {"x": 2, "y": 5}, {"x": 7, "y": 12}) return data @testing.provide_metadata def test_insert_from_select_override_defaults(self): data = self._fixture() - table = Table('sometable', self.metadata, - Column('x', Integer), - Column('foo', Integer, default=12), - Column('y', Integer)) + table = Table( + "sometable", + self.metadata, + Column("x", Integer), + Column("foo", Integer, default=12), + Column("y", Integer), + ) table.create() sel = select([data.c.x, data.c.y]) - ins = table.insert().\ - from_select(["x", "y"], sel) + ins = table.insert().from_select(["x", "y"], sel) testing.db.execute(ins) eq_( testing.db.execute(table.select().order_by(table.c.x)).fetchall(), - [(2, 12, 5), (7, 12, 12)] + [(2, 12, 5), (7, 12, 12)], ) @testing.provide_metadata @@ -1695,25 +1937,28 @@ class InsertFromSelectTest(fixtures.TestBase): def foo(ctx): return next(counter) - table = Table('sometable', self.metadata, - Column('x', Integer), - Column('foo', Integer, default=foo), - Column('y', Integer)) + table = Table( + "sometable", + self.metadata, + Column("x", Integer), + Column("foo", Integer, default=foo), + Column("y", Integer), + ) table.create() sel = select([data.c.x, data.c.y]) - ins = table.insert().\ - from_select(["x", "y"], sel) + ins = table.insert().from_select(["x", "y"], sel) testing.db.execute(ins) # counter is only called once! eq_( testing.db.execute(table.select().order_by(table.c.x)).fetchall(), - [(2, 1, 5), (7, 1, 12)] + [(2, 1, 5), (7, 1, 12)], ) + class CurrentParametersTest(fixtures.TablesTest): __backend__ = True @@ -1723,15 +1968,16 @@ class CurrentParametersTest(fixtures.TablesTest): pass Table( - "some_table", metadata, - Column('x', String(50), default=gen_default), - Column('y', String(50)), + "some_table", + metadata, + Column("x", String(50), default=gen_default), + Column("y", String(50)), ) def _fixture(self, fn): - def gen_default(context): fn(context) + some_table = self.tables.some_table some_table.c.x.default.arg = gen_default return fn @@ -1744,12 +1990,12 @@ class CurrentParametersTest(fixtures.TablesTest): collect(context.get_current_parameters()) table = self.tables.some_table - if exec_type in ('multivalues', 'executemany'): + if exec_type in ("multivalues", "executemany"): parameters = [{"y": "h1"}, {"y": "h2"}] else: parameters = [{"y": "hello"}] - if exec_type == 'multivalues': + if exec_type == "multivalues": stmt, params = table.insert().values(parameters), {} else: stmt, params = table.insert(), parameters @@ -1758,7 +2004,7 @@ class CurrentParametersTest(fixtures.TablesTest): conn.execute(stmt, params) eq_( collect.mock_calls, - [mock.call({"y": param['y'], "x": None}) for param in parameters] + [mock.call({"y": param["y"], "x": None}) for param in parameters], ) def test_single_w_attribute(self): diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py index 331a536018..cb5e696f20 100644 --- a/test/sql/test_delete.py +++ b/test/sql/test_delete.py @@ -1,128 +1,153 @@ #! 
coding:utf-8 -from sqlalchemy import Integer, String, ForeignKey, delete, select, and_, \ - or_, exists +from sqlalchemy import ( + Integer, + String, + ForeignKey, + delete, + select, + and_, + or_, + exists, +) from sqlalchemy.dialects import mysql from sqlalchemy.engine import default from sqlalchemy import testing from sqlalchemy import exc -from sqlalchemy.testing import AssertsCompiledSQL, fixtures, eq_, \ - assert_raises_message +from sqlalchemy.testing import ( + AssertsCompiledSQL, + fixtures, + eq_, + assert_raises_message, +) from sqlalchemy.testing.schema import Table, Column class _DeleteTestBase(object): - @classmethod def define_tables(cls, metadata): - Table('mytable', metadata, - Column('myid', Integer), - Column('name', String(30)), - Column('description', String(50))) - Table('myothertable', metadata, - Column('otherid', Integer), - Column('othername', String(30))) + Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(30)), + Column("description", String(50)), + ) + Table( + "myothertable", + metadata, + Column("otherid", Integer), + Column("othername", String(30)), + ) class DeleteTest(_DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_delete_literal_binds(self): table1 = self.tables.mytable - stmt = table1.delete().where(table1.c.name == 'jill') + stmt = table1.delete().where(table1.c.name == "jill") self.assert_compile( stmt, "DELETE FROM mytable WHERE mytable.name = 'jill'", - literal_binds=True) + literal_binds=True, + ) def test_delete(self): table1 = self.tables.mytable self.assert_compile( delete(table1, table1.c.myid == 7), - 'DELETE FROM mytable WHERE mytable.myid = :myid_1') + "DELETE FROM mytable WHERE mytable.myid = :myid_1", + ) self.assert_compile( table1.delete().where(table1.c.myid == 7), - 'DELETE FROM mytable WHERE mytable.myid = :myid_1') + "DELETE FROM mytable WHERE mytable.myid = :myid_1", + ) self.assert_compile( - table1.delete(). - where(table1.c.myid == 7). 
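test_delete_literal_binds above goes through the test helper; the same literal-bound SQL is available on any statement via compile():

    stmt = table1.delete().where(table1.c.name == "jill")
    print(stmt.compile(compile_kwargs={"literal_binds": True}))
    # DELETE FROM mytable WHERE mytable.name = 'jill'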
- where(table1.c.name == 'somename'), - 'DELETE FROM mytable ' - 'WHERE mytable.myid = :myid_1 ' - 'AND mytable.name = :name_1') + table1.delete() + .where(table1.c.myid == 7) + .where(table1.c.name == "somename"), + "DELETE FROM mytable " + "WHERE mytable.myid = :myid_1 " + "AND mytable.name = :name_1", + ) def test_where_empty(self): table1 = self.tables.mytable self.assert_compile( - table1.delete().where(and_()), - "DELETE FROM mytable" + table1.delete().where(and_()), "DELETE FROM mytable" ) self.assert_compile( - table1.delete().where(or_()), - "DELETE FROM mytable" + table1.delete().where(or_()), "DELETE FROM mytable" ) def test_prefix_with(self): table1 = self.tables.mytable - stmt = table1.delete().\ - prefix_with('A', 'B', dialect='mysql').\ - prefix_with('C', 'D') + stmt = ( + table1.delete() + .prefix_with("A", "B", dialect="mysql") + .prefix_with("C", "D") + ) - self.assert_compile(stmt, - 'DELETE C D FROM mytable') + self.assert_compile(stmt, "DELETE C D FROM mytable") - self.assert_compile(stmt, - 'DELETE A B C D FROM mytable', - dialect=mysql.dialect()) + self.assert_compile( + stmt, "DELETE A B C D FROM mytable", dialect=mysql.dialect() + ) def test_alias(self): table1 = self.tables.mytable - talias1 = table1.alias('t1') + talias1 = table1.alias("t1") stmt = delete(talias1).where(talias1.c.myid == 7) self.assert_compile( - stmt, - 'DELETE FROM mytable AS t1 WHERE t1.myid = :myid_1') + stmt, "DELETE FROM mytable AS t1 WHERE t1.myid = :myid_1" + ) def test_correlated(self): table1, table2 = self.tables.mytable, self.tables.myothertable # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) - self.assert_compile(delete(table1, table1.c.name == s), - 'DELETE FROM mytable ' - 'WHERE mytable.name = (' - 'SELECT myothertable.othername ' - 'FROM myothertable ' - 'WHERE myothertable.otherid = :otherid_1' - ')') + self.assert_compile( + delete(table1, table1.c.name == s), + "DELETE FROM mytable " + "WHERE mytable.name = (" + "SELECT myothertable.othername " + "FROM myothertable " + "WHERE myothertable.otherid = :otherid_1" + ")", + ) # test one that is actually correlated... s = select([table2.c.othername], table2.c.otherid == table1.c.myid) - self.assert_compile(table1.delete(table1.c.name == s), - 'DELETE FROM mytable ' - 'WHERE mytable.name = (' - 'SELECT myothertable.othername ' - 'FROM myothertable ' - 'WHERE myothertable.otherid = mytable.myid' - ')') + self.assert_compile( + table1.delete(table1.c.name == s), + "DELETE FROM mytable " + "WHERE mytable.name = (" + "SELECT myothertable.othername " + "FROM myothertable " + "WHERE myothertable.otherid = mytable.myid" + ")", + ) class DeleteFromCompileTest( - _DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL): + _DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL +): # DELETE FROM is also tested by individual dialects since there is no # consistent syntax. here we use the StrSQLcompiler which has a fake # syntax. - __dialect__ = 'default_enhanced' + __dialect__ = "default_enhanced" def test_delete_extra_froms(self): table1, table2 = self.tables.mytable, self.tables.myothertable @@ -137,10 +162,15 @@ class DeleteFromCompileTest( def test_correlation_to_extra(self): table1, table2 = self.tables.mytable, self.tables.myothertable - stmt = table1.delete().where( - table1.c.myid == table2.c.otherid).where( - ~exists().where(table2.c.otherid == table1.c.myid). 
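(The correlation tests in this hunk rely on EXISTS subqueries inside a DELETE. As a side reference, a minimal sketch of the default behavior under test -- with no explicit correlate(), the subquery auto-correlates against the enclosing DELETE's target table. The parent/child tables here are hypothetical.)

from sqlalchemy import Column, Integer, MetaData, Table, exists

metadata = MetaData()
parent = Table("parent", metadata, Column("id", Integer))
child = Table("child", metadata, Column("parent_id", Integer))

# Delete parents that have at least one child; "parent" inside the
# EXISTS correlates back to the DELETE target automatically.
stmt = parent.delete().where(
    exists().where(child.c.parent_id == parent.c.id)
)
# renders roughly:
#   DELETE FROM parent WHERE EXISTS
#     (SELECT * FROM child WHERE child.parent_id = parent.id)
print(stmt)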
- where(table2.c.othername == 'x').correlate(table2) + stmt = ( + table1.delete() + .where(table1.c.myid == table2.c.otherid) + .where( + ~exists() + .where(table2.c.otherid == table1.c.myid) + .where(table2.c.othername == "x") + .correlate(table2) + ) ) self.assert_compile( @@ -154,10 +184,15 @@ class DeleteFromCompileTest( def test_dont_correlate_to_extra(self): table1, table2 = self.tables.mytable, self.tables.myothertable - stmt = table1.delete().where( - table1.c.myid == table2.c.otherid).where( - ~exists().where(table2.c.otherid == table1.c.myid). - where(table2.c.othername == 'x').correlate() + stmt = ( + table1.delete() + .where(table1.c.myid == table2.c.otherid) + .where( + ~exists() + .where(table2.c.otherid == table1.c.myid) + .where(table2.c.othername == "x") + .correlate() + ) ) self.assert_compile( @@ -172,16 +207,21 @@ class DeleteFromCompileTest( def test_autocorrelate_error(self): table1, table2 = self.tables.mytable, self.tables.myothertable - stmt = table1.delete().where( - table1.c.myid == table2.c.otherid).where( - ~exists().where(table2.c.otherid == table1.c.myid). - where(table2.c.othername == 'x') + stmt = ( + table1.delete() + .where(table1.c.myid == table2.c.otherid) + .where( + ~exists() + .where(table2.c.otherid == table1.c.myid) + .where(table2.c.othername == "x") + ) ) assert_raises_message( exc.InvalidRequestError, ".*returned no FROM clauses due to auto-correlation.*", - stmt.compile, dialect=default.StrCompileDialect() + stmt.compile, + dialect=default.StrCompileDialect(), ) @@ -190,56 +230,77 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('mytable', metadata, - Column('myid', Integer), - Column('name', String(30)), - Column('description', String(50))) - Table('myothertable', metadata, - Column('otherid', Integer), - Column('othername', String(30))) - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False)) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('name', String(30), nullable=False), - Column('email_address', String(50), nullable=False)) - Table('dingalings', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', None, ForeignKey('addresses.id')), - Column('data', String(30))) - Table('update_w_default', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('ycol', Integer, key='y'), - Column('data', String(30), onupdate=lambda: "hi")) + Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(30)), + Column("description", String(50)), + ) + Table( + "myothertable", + metadata, + Column("otherid", Integer), + Column("othername", String(30)), + ) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("name", String(30), nullable=False), + Column("email_address", String(50), nullable=False), + ) + Table( + "dingalings", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", None, ForeignKey("addresses.id")), + Column("data", String(30)), + 
) + Table( + "update_w_default", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("ycol", Integer, key="y"), + Column("data", String(30), onupdate=lambda: "hi"), + ) @classmethod def fixtures(cls): return dict( users=( - ('id', 'name'), - (7, 'jack'), - (8, 'ed'), - (9, 'fred'), - (10, 'chuck') + ("id", "name"), + (7, "jack"), + (8, "ed"), + (9, "fred"), + (10, "chuck"), ), addresses=( - ('id', 'user_id', 'name', 'email_address'), - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'ed@wood.com'), - (3, 8, 'x', 'ed@bettyboop.com'), - (4, 8, 'x', 'ed@lala.com'), - (5, 9, 'x', 'fred@fred.com') + ("id", "user_id", "name", "email_address"), + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "ed@wood.com"), + (3, 8, "x", "ed@bettyboop.com"), + (4, 8, "x", "ed@lala.com"), + (5, 9, "x", "fred@fred.com"), ), dingalings=( - ('id', 'address_id', 'data'), - (1, 2, 'ding 1/2'), - (2, 5, 'ding 2/5') + ("id", "address_id", "data"), + (1, 2, "ding 1/2"), + (2, 5, "ding 2/5"), ), ) @@ -252,14 +313,14 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): conn.execute(dingalings.delete()) # fk violation otherwise conn.execute( - addresses.delete(). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed') + addresses.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (5, 9, 'x', 'fred@fred.com') + (1, 7, "x", "jack@bean.com"), + (5, 9, "x", "fred@fred.com"), ] self._assert_table(addresses, expected) @@ -270,14 +331,13 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): dingalings = self.tables.dingalings testing.db.execute( - dingalings.delete(). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed'). - where(addresses.c.id == dingalings.c.address_id)) - - expected = [ - (2, 5, 'ding 2/5') - ] + dingalings.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + .where(addresses.c.id == dingalings.c.address_id) + ) + + expected = [(2, 5, "ding 2/5")] self._assert_table(dingalings, expected) @testing.requires.delete_from @@ -289,16 +349,13 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): conn.execute(dingalings.delete()) # fk violation otherwise a1 = addresses.alias() conn.execute( - addresses.delete(). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed'). - where(a1.c.id == addresses.c.id) + addresses.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + .where(a1.c.id == addresses.c.id) ) - expected = [ - (1, 7, 'x', 'jack@bean.com'), - (5, 9, 'x', 'fred@fred.com') - ] + expected = [(1, 7, "x", "jack@bean.com"), (5, 9, "x", "fred@fred.com")] self._assert_table(addresses, expected) @testing.requires.delete_from @@ -309,14 +366,13 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): d1 = dingalings.alias() testing.db.execute( - delete(d1). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed'). 
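(The round-trip tests above exercise multi-table DELETE criteria. A small sketch of the construct, using hypothetical users/addresses tables: referencing a second table in the WHERE clause yields a multi-table DELETE, whose concrete syntax -- DELETE .. FROM, USING, etc. -- varies per backend, which is why these tests are gated on testing.requires.delete_from.)

from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
users = Table(
    "users", metadata, Column("id", Integer), Column("name", String(30))
)
addresses = Table(
    "addresses", metadata, Column("id", Integer), Column("user_id", Integer)
)

# Deletes only addresses whose joined user is named "ed"; must be
# executed against a dialect that supports extra FROM tables in DELETE.
stmt = (
    addresses.delete()
    .where(users.c.id == addresses.c.user_id)
    .where(users.c.name == "ed")
)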
- where(addresses.c.id == d1.c.address_id)) - - expected = [ - (2, 5, 'ding 2/5') - ] + delete(d1) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + .where(addresses.c.id == d1.c.address_id) + ) + + expected = [(2, 5, "ding 2/5")] self._assert_table(dingalings, expected) def _assert_table(self, table, expected): diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 6fba7519c6..b775df7406 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -1,8 +1,24 @@ from sqlalchemy.testing import eq_, is_ import datetime -from sqlalchemy import func, select, Integer, literal, DateTime, Table, \ - Column, Sequence, MetaData, extract, Date, String, bindparam, \ - literal_column, ARRAY, Numeric, Boolean +from sqlalchemy import ( + func, + select, + Integer, + literal, + DateTime, + Table, + Column, + Sequence, + MetaData, + extract, + Date, + String, + bindparam, + literal_column, + ARRAY, + Numeric, + Boolean, +) from sqlalchemy.sql import table, column from sqlalchemy import sql, util from sqlalchemy.sql.compiler import BIND_TEMPLATES @@ -16,31 +32,35 @@ from sqlalchemy.testing import fixtures, AssertsCompiledSQL, engines from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle from sqlalchemy.testing import assert_raises_message, assert_raises -table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String), - ) +table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), +) class CompileTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def tear_down(self): functions._registry.clear() def test_compile(self): - for dialect in all_dialects(exclude=('sybase', )): + for dialect in all_dialects(exclude=("sybase",)): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self.assert_compile(func.current_timestamp(), - "CURRENT_TIMESTAMP", dialect=dialect) + self.assert_compile( + func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect + ) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) - if dialect.name in ('firebird',): - self.assert_compile(func.nosuchfunction(), - "nosuchfunction", dialect=dialect) + if dialect.name in ("firebird",): + self.assert_compile( + func.nosuchfunction(), "nosuchfunction", dialect=dialect + ) else: - self.assert_compile(func.nosuchfunction(), - "nosuchfunction()", dialect=dialect) + self.assert_compile( + func.nosuchfunction(), "nosuchfunction()", dialect=dialect + ) # test generic function compile class fake_func(GenericFunction): @@ -50,15 +70,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): GenericFunction.__init__(self, arg, **kwargs) self.assert_compile( - fake_func('foo'), - "fake_func(%s)" % - bindtemplate % {'name': 'fake_func_1', 'position': 1}, - dialect=dialect) + fake_func("foo"), + "fake_func(%s)" + % bindtemplate + % {"name": "fake_func_1", "position": 1}, + dialect=dialect, + ) def test_use_labels(self): - self.assert_compile(select([func.foo()], use_labels=True), - "SELECT foo() AS foo_1" - ) + self.assert_compile( + select([func.foo()], use_labels=True), "SELECT foo() AS foo_1" + ) def test_underscores(self): self.assert_compile(func.if_(), "if()") @@ -67,10 +89,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): assert isinstance(func.now().type, sqltypes.DateTime) for ret, dialect in [ - ('CURRENT_TIMESTAMP', sqlite.dialect()), - ('now()', postgresql.dialect()), - ('now()', 
mysql.dialect()), - ('CURRENT_TIMESTAMP', oracle.dialect()) + ("CURRENT_TIMESTAMP", sqlite.dialect()), + ("now()", postgresql.dialect()), + ("now()", mysql.dialect()), + ("CURRENT_TIMESTAMP", oracle.dialect()), ]: self.assert_compile(func.now(), ret, dialect=dialect) @@ -79,54 +101,55 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): assert isinstance(func.random(type_=Integer).type, Integer) for ret, dialect in [ - ('random()', sqlite.dialect()), - ('random()', postgresql.dialect()), - ('rand()', mysql.dialect()), - ('random()', oracle.dialect()) + ("random()", sqlite.dialect()), + ("random()", postgresql.dialect()), + ("rand()", mysql.dialect()), + ("random()", oracle.dialect()), ]: self.assert_compile(func.random(), ret, dialect=dialect) def test_cube_operators(self): - t = table('t', column('value'), - column('x'), column('y'), column('z'), column('q')) + t = table( + "t", + column("value"), + column("x"), + column("y"), + column("z"), + column("q"), + ) stmt = select([func.sum(t.c.value)]) self.assert_compile( stmt.group_by(func.cube(t.c.x, t.c.y)), - "SELECT sum(t.value) AS sum_1 FROM t GROUP BY CUBE(t.x, t.y)" + "SELECT sum(t.value) AS sum_1 FROM t GROUP BY CUBE(t.x, t.y)", ) self.assert_compile( stmt.group_by(func.rollup(t.c.x, t.c.y)), - "SELECT sum(t.value) AS sum_1 FROM t GROUP BY ROLLUP(t.x, t.y)" + "SELECT sum(t.value) AS sum_1 FROM t GROUP BY ROLLUP(t.x, t.y)", ) self.assert_compile( - stmt.group_by( - func.grouping_sets(t.c.x, t.c.y) - ), + stmt.group_by(func.grouping_sets(t.c.x, t.c.y)), "SELECT sum(t.value) AS sum_1 FROM t " - "GROUP BY GROUPING SETS(t.x, t.y)" + "GROUP BY GROUPING SETS(t.x, t.y)", ) self.assert_compile( stmt.group_by( func.grouping_sets( - sql.tuple_(t.c.x, t.c.y), - sql.tuple_(t.c.z, t.c.q), + sql.tuple_(t.c.x, t.c.y), sql.tuple_(t.c.z, t.c.q) ) ), "SELECT sum(t.value) AS sum_1 FROM t GROUP BY " - "GROUPING SETS((t.x, t.y), (t.z, t.q))" + "GROUPING SETS((t.x, t.y), (t.z, t.q))", ) def test_generic_annotation(self): - fn = func.coalesce('x', 'y')._annotate({"foo": "bar"}) - self.assert_compile( - fn, "coalesce(:coalesce_1, :coalesce_2)" - ) + fn = func.coalesce("x", "y")._annotate({"foo": "bar"}) + self.assert_compile(fn, "coalesce(:coalesce_1, :coalesce_2)") def test_custom_default_namespace(self): class myfunc(GenericFunction): @@ -160,6 +183,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def cls1(pk_name): class myfunc(GenericFunction): package = pk_name + return myfunc f1 = cls1("mypackage") @@ -170,15 +194,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_name(self): class MyFunction(GenericFunction): - name = 'my_func' + name = "my_func" def __init__(self, *args): args = args + (3,) super(MyFunction, self).__init__(*args) self.assert_compile( - func.my_func(1, 2), - "my_func(:my_func_1, :my_func_2, :my_func_3)" + func.my_func(1, 2), "my_func(:my_func_1, :my_func_2, :my_func_3)" ) def test_custom_registered_identifier(self): @@ -197,120 +220,109 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): type = Integer identifier = "buf3" - self.assert_compile( - func.geo.buf1(), - "BufferOne()" - ) - self.assert_compile( - func.buf2(), - "BufferTwo()" - ) - self.assert_compile( - func.buf3(), - "BufferThree()" - ) + self.assert_compile(func.geo.buf1(), "BufferOne()") + self.assert_compile(func.buf2(), "BufferTwo()") + self.assert_compile(func.buf3(), "BufferThree()") def test_custom_args(self): class myfunc(GenericFunction): pass self.assert_compile( - myfunc(1, 2, 3), - "myfunc(:myfunc_1, 
:myfunc_2, :myfunc_3)" + myfunc(1, 2, 3), "myfunc(:myfunc_1, :myfunc_2, :myfunc_3)" ) def test_namespacing_conflicts(self): - self.assert_compile(func.text('foo'), 'text(:text_1)') + self.assert_compile(func.text("foo"), "text(:text_1)") def test_generic_count(self): assert isinstance(func.count().type, sqltypes.Integer) - self.assert_compile(func.count(), 'count(*)') - self.assert_compile(func.count(1), 'count(:count_1)') - c = column('abc') - self.assert_compile(func.count(c), 'count(abc)') + self.assert_compile(func.count(), "count(*)") + self.assert_compile(func.count(1), "count(:count_1)") + c = column("abc") + self.assert_compile(func.count(c), "count(abc)") def test_ansi_functions_with_args(self): - ct = func.current_timestamp('somearg') + ct = func.current_timestamp("somearg") self.assert_compile(ct, "CURRENT_TIMESTAMP(:current_timestamp_1)") def test_char_length_fixed_args(self): - assert_raises( - TypeError, - func.char_length, 'a', 'b' - ) - assert_raises( - TypeError, - func.char_length - ) + assert_raises(TypeError, func.char_length, "a", "b") + assert_raises(TypeError, func.char_length) def test_return_type_detection(self): for fn in [func.coalesce, func.max, func.min, func.sum]: for args, type_ in [ - ((datetime.date(2007, 10, 5), - datetime.date(2005, 10, 15)), sqltypes.Date), + ( + (datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), + sqltypes.Date, + ), ((3, 5), sqltypes.Integer), - ((decimal.Decimal(3), decimal.Decimal(5)), - sqltypes.Numeric), + ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric), (("foo", "bar"), sqltypes.String), - ((datetime.datetime(2007, 10, 5, 8, 3, 34), - datetime.datetime(2005, 10, 15, 14, 45, 33)), - sqltypes.DateTime) + ( + ( + datetime.datetime(2007, 10, 5, 8, 3, 34), + datetime.datetime(2005, 10, 15, 14, 45, 33), + ), + sqltypes.DateTime, + ), ]: - assert isinstance(fn(*args).type, type_), \ - "%s / %r != %s" % (fn(), fn(*args).type, type_) + assert isinstance(fn(*args).type, type_), "%s / %r != %s" % ( + fn(), + fn(*args).type, + type_, + ) assert isinstance(func.concat("foo", "bar").type, sqltypes.String) def test_assorted(self): - table1 = table('mytable', - column('myid', Integer), - ) + table1 = table("mytable", column("myid", Integer)) - table2 = table( - 'myothertable', - column('otherid', Integer), - ) + table2 = table("myothertable", column("otherid", Integer)) # test an expression with a function - self.assert_compile(func.lala(3, 4, literal("five"), - table1.c.myid) * table2.c.otherid, - "lala(:lala_1, :lala_2, :param_1, mytable.myid) * " - "myothertable.otherid") + self.assert_compile( + func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, + "lala(:lala_1, :lala_2, :param_1, mytable.myid) * " + "myothertable.otherid", + ) # test it in a SELECT - self.assert_compile(select( - [func.count(table1.c.myid)]), - "SELECT count(mytable.myid) AS count_1 FROM mytable") + self.assert_compile( + select([func.count(table1.c.myid)]), + "SELECT count(mytable.myid) AS count_1 FROM mytable", + ) # test a "dotted" function name - self.assert_compile(select([func.foo.bar.lala( - table1.c.myid)]), - "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable") + self.assert_compile( + select([func.foo.bar.lala(table1.c.myid)]), + "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable", + ) # test the bind parameter name with a "dotted" function name is # only the name (limits the length of the bind param name) - self.assert_compile(select([func.foo.bar.lala(12)]), - "SELECT foo.bar.lala(:lala_2) AS lala_1") + 
self.assert_compile( + select([func.foo.bar.lala(12)]), + "SELECT foo.bar.lala(:lala_2) AS lala_1", + ) # test a dotted func off the engine itself self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)") # test None becomes NULL self.assert_compile( - func.my_func( - 1, - 2, - None, - 3), - "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") + func.my_func(1, 2, None, 3), + "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)", + ) # test pickling self.assert_compile( - util.pickle.loads(util.pickle.dumps( - func.my_func(1, 2, None, 3))), - "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") + util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))), + "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)", + ) # assert func raises AttributeError for __bases__ attribute, since # its not a class fixes pydoc @@ -322,33 +334,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_functions_with_cols(self): users = table( - 'users', - column('id'), - column('name'), - column('fullname')) - calculate = select([column('q'), column('z'), column('r')], from_obj=[ - func.calculate( - bindparam('x', None), bindparam('y', None) - )]) - - self.assert_compile(select([users], users.c.id > calculate.c.z), - "SELECT users.id, users.name, users.fullname " - "FROM users, (SELECT q, z, r " - "FROM calculate(:x, :y)) " - "WHERE users.id > z" - ) - - s = select([users], users.c.id.between( - calculate.alias('c1').unique_params(x=17, y=45).c.z, - calculate.alias('c2').unique_params(x=5, y=12).c.z)) - - self.assert_compile( - s, "SELECT users.id, users.name, users.fullname " + "users", column("id"), column("name"), column("fullname") + ) + calculate = select( + [column("q"), column("z"), column("r")], + from_obj=[ + func.calculate(bindparam("x", None), bindparam("y", None)) + ], + ) + + self.assert_compile( + select([users], users.c.id > calculate.c.z), + "SELECT users.id, users.name, users.fullname " + "FROM users, (SELECT q, z, r " + "FROM calculate(:x, :y)) " + "WHERE users.id > z", + ) + + s = select( + [users], + users.c.id.between( + calculate.alias("c1").unique_params(x=17, y=45).c.z, + calculate.alias("c2").unique_params(x=5, y=12).c.z, + ), + ) + + self.assert_compile( + s, + "SELECT users.id, users.name, users.fullname " "FROM users, (SELECT q, z, r " "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r " "FROM calculate(:x_2, :y_2)) AS c2 " - "WHERE users.id BETWEEN c1.z AND c2.z", checkparams={ - 'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) + "WHERE users.id BETWEEN c1.z AND c2.z", + checkparams={"y_1": 45, "x_1": 17, "y_2": 12, "x_2": 5}, + ) def test_non_functions(self): expr = func.cast("foo", Integer) @@ -359,263 +378,253 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_select_method_one(self): expr = func.rows("foo") - self.assert_compile( - expr.select(), - "SELECT rows(:rows_2) AS rows_1" - ) + self.assert_compile(expr.select(), "SELECT rows(:rows_2) AS rows_1") def test_alias_method_one(self): expr = func.rows("foo") - self.assert_compile( - expr.alias(), - "rows(:rows_1)" - ) + self.assert_compile(expr.alias(), "rows(:rows_1)") def test_select_method_two(self): expr = func.rows("foo") self.assert_compile( - select(['*']).select_from(expr.select()), - "SELECT * FROM (SELECT rows(:rows_2) AS rows_1)" + select(["*"]).select_from(expr.select()), + "SELECT * FROM (SELECT rows(:rows_2) AS rows_1)", ) def test_select_method_three(self): expr = func.rows("foo") self.assert_compile( - select([column('foo')]).select_from(expr), - "SELECT foo FROM 
rows(:rows_1)" + select([column("foo")]).select_from(expr), + "SELECT foo FROM rows(:rows_1)", ) def test_alias_method_two(self): expr = func.rows("foo") self.assert_compile( - select(['*']).select_from(expr.alias('bar')), - "SELECT * FROM rows(:rows_1) AS bar" + select(["*"]).select_from(expr.alias("bar")), + "SELECT * FROM rows(:rows_1) AS bar", ) def test_alias_method_columns(self): - expr = func.rows("foo").alias('bar') + expr = func.rows("foo").alias("bar") # this isn't very useful but is the old behavior # prior to #2974. # testing here that the expression exports its column # list in a way that at least doesn't break. self.assert_compile( - select([expr]), - "SELECT bar.rows_1 FROM rows(:rows_2) AS bar" + select([expr]), "SELECT bar.rows_1 FROM rows(:rows_2) AS bar" ) def test_alias_method_columns_two(self): - expr = func.rows("foo").alias('bar') + expr = func.rows("foo").alias("bar") assert len(expr.c) def test_funcfilter_empty(self): - self.assert_compile( - func.count(1).filter(), - "count(:count_1)" - ) + self.assert_compile(func.count(1).filter(), "count(:count_1)") def test_funcfilter_criterion(self): self.assert_compile( - func.count(1).filter( - table1.c.name != None # noqa - ), - "count(:count_1) FILTER (WHERE mytable.name IS NOT NULL)" + func.count(1).filter(table1.c.name != None), # noqa + "count(:count_1) FILTER (WHERE mytable.name IS NOT NULL)", ) def test_funcfilter_compound_criterion(self): self.assert_compile( func.count(1).filter( - table1.c.name == None, # noqa - table1.c.myid > 0 + table1.c.name == None, table1.c.myid > 0 # noqa ), "count(:count_1) FILTER (WHERE mytable.name IS NULL AND " - "mytable.myid > :myid_1)" + "mytable.myid > :myid_1)", ) def test_funcfilter_label(self): self.assert_compile( - select([func.count(1).filter( - table1.c.description != None # noqa - ).label('foo')]), + select( + [ + func.count(1) + .filter(table1.c.description != None) # noqa + .label("foo") + ] + ), "SELECT count(:count_1) FILTER (WHERE mytable.description " - "IS NOT NULL) AS foo FROM mytable" + "IS NOT NULL) AS foo FROM mytable", ) def test_funcfilter_fromobj_fromfunc(self): # test from_obj generation. 
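(For reference alongside these funcfilter tests, a self-contained sketch of the aggregate FILTER clause they compile -- func.<agg>(...).filter(criterion) emits the SQL-standard FILTER (WHERE ...) modifier on supporting backends. The table here is illustrative.)

from sqlalchemy import Column, Integer, MetaData, String, Table, func, select

metadata = MetaData()
t = Table("t", metadata, Column("x", Integer), Column("name", String(30)))

# Count only the rows where name is non-NULL.
stmt = select([func.count(t.c.x).filter(t.c.name.isnot(None))])
# renders:
#   SELECT count(t.x) FILTER (WHERE t.name IS NOT NULL) AS anon_1 FROM t
print(stmt)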
# from func: self.assert_compile( - select([ - func.max(table1.c.name).filter( - literal_column('description') != None # noqa - ) - ]), + select( + [ + func.max(table1.c.name).filter( + literal_column("description") != None # noqa + ) + ] + ), "SELECT max(mytable.name) FILTER (WHERE description " - "IS NOT NULL) AS anon_1 FROM mytable" + "IS NOT NULL) AS anon_1 FROM mytable", ) def test_funcfilter_fromobj_fromcriterion(self): # from criterion: self.assert_compile( - select([ - func.count(1).filter( - table1.c.name == 'name' - ) - ]), + select([func.count(1).filter(table1.c.name == "name")]), "SELECT count(:count_1) FILTER (WHERE mytable.name = :name_1) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) def test_funcfilter_chaining(self): # test chaining: self.assert_compile( - select([ - func.count(1).filter( - table1.c.name == 'name' - ).filter( - table1.c.description == 'description' - ) - ]), + select( + [ + func.count(1) + .filter(table1.c.name == "name") + .filter(table1.c.description == "description") + ] + ), "SELECT count(:count_1) FILTER (WHERE " "mytable.name = :name_1 AND mytable.description = :description_1) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) def test_funcfilter_windowing_orderby(self): # test filtered windowing: self.assert_compile( - select([ - func.rank().filter( - table1.c.name > 'foo' - ).over( - order_by=table1.c.name - ) - ]), + select( + [ + func.rank() + .filter(table1.c.name > "foo") + .over(order_by=table1.c.name) + ] + ), "SELECT rank() FILTER (WHERE mytable.name > :name_1) " - "OVER (ORDER BY mytable.name) AS anon_1 FROM mytable" + "OVER (ORDER BY mytable.name) AS anon_1 FROM mytable", ) def test_funcfilter_windowing_orderby_partitionby(self): self.assert_compile( - select([ - func.rank().filter( - table1.c.name > 'foo' - ).over( - order_by=table1.c.name, - partition_by=['description'] - ) - ]), + select( + [ + func.rank() + .filter(table1.c.name > "foo") + .over(order_by=table1.c.name, partition_by=["description"]) + ] + ), "SELECT rank() FILTER (WHERE mytable.name > :name_1) " "OVER (PARTITION BY mytable.description ORDER BY mytable.name) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) def test_funcfilter_windowing_range(self): self.assert_compile( - select([ - func.rank().filter( - table1.c.name > 'foo' - ).over( - range_=(1, 5), - partition_by=['description'] - ) - ]), + select( + [ + func.rank() + .filter(table1.c.name > "foo") + .over(range_=(1, 5), partition_by=["description"]) + ] + ), "SELECT rank() FILTER (WHERE mytable.name > :name_1) " "OVER (PARTITION BY mytable.description RANGE BETWEEN :param_1 " "FOLLOWING AND :param_2 FOLLOWING) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) def test_funcfilter_windowing_rows(self): self.assert_compile( - select([ - func.rank().filter( - table1.c.name > 'foo' - ).over( - rows=(1, 5), - partition_by=['description'] - ) - ]), + select( + [ + func.rank() + .filter(table1.c.name > "foo") + .over(rows=(1, 5), partition_by=["description"]) + ] + ), "SELECT rank() FILTER (WHERE mytable.name > :name_1) " "OVER (PARTITION BY mytable.description ROWS BETWEEN :param_1 " "FOLLOWING AND :param_2 FOLLOWING) " - "AS anon_1 FROM mytable" + "AS anon_1 FROM mytable", ) def test_funcfilter_within_group(self): - stmt = select([ - table1.c.myid, - func.percentile_cont(0.5).within_group( - table1.c.name - ) - ]) + stmt = select( + [ + table1.c.myid, + func.percentile_cont(0.5).within_group(table1.c.name), + ] + ) self.assert_compile( stmt, "SELECT mytable.myid, 
percentile_cont(:percentile_cont_1) " "WITHIN GROUP (ORDER BY mytable.name) " "AS anon_1 " "FROM mytable", - {'percentile_cont_1': 0.5} + {"percentile_cont_1": 0.5}, ) def test_funcfilter_within_group_multi(self): - stmt = select([ - table1.c.myid, - func.percentile_cont(0.5).within_group( - table1.c.name, table1.c.description - ) - ]) + stmt = select( + [ + table1.c.myid, + func.percentile_cont(0.5).within_group( + table1.c.name, table1.c.description + ), + ] + ) self.assert_compile( stmt, "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " "WITHIN GROUP (ORDER BY mytable.name, mytable.description) " "AS anon_1 " "FROM mytable", - {'percentile_cont_1': 0.5} + {"percentile_cont_1": 0.5}, ) def test_funcfilter_within_group_desc(self): - stmt = select([ - table1.c.myid, - func.percentile_cont(0.5).within_group( - table1.c.name.desc() - ) - ]) + stmt = select( + [ + table1.c.myid, + func.percentile_cont(0.5).within_group(table1.c.name.desc()), + ] + ) self.assert_compile( stmt, "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " "WITHIN GROUP (ORDER BY mytable.name DESC) " "AS anon_1 " "FROM mytable", - {'percentile_cont_1': 0.5} + {"percentile_cont_1": 0.5}, ) def test_funcfilter_within_group_w_over(self): - stmt = select([ - table1.c.myid, - func.percentile_cont(0.5).within_group( - table1.c.name.desc() - ).over(partition_by=table1.c.description) - ]) + stmt = select( + [ + table1.c.myid, + func.percentile_cont(0.5) + .within_group(table1.c.name.desc()) + .over(partition_by=table1.c.description), + ] + ) self.assert_compile( stmt, "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " "WITHIN GROUP (ORDER BY mytable.name DESC) " "OVER (PARTITION BY mytable.description) AS anon_1 " "FROM mytable", - {'percentile_cont_1': 0.5} + {"percentile_cont_1": 0.5}, ) def test_incorrect_none_type(self): class MissingType(FunctionElement): - name = 'mt' + name = "mt" type = None assert_raises_message( TypeError, "Object None associated with '.type' attribute is " "not a TypeEngine class or object", - MissingType().compile + MissingType().compile, ) def test_as_comparison(self): @@ -624,21 +633,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): is_(fn.type._type_affinity, Boolean) self.assert_compile( - fn.left, ":substring_1", - checkparams={'substring_1': 'foo'}) + fn.left, ":substring_1", checkparams={"substring_1": "foo"} + ) self.assert_compile( - fn.right, ":substring_1", - checkparams={'substring_1': 'foobar'}) + fn.right, ":substring_1", checkparams={"substring_1": "foobar"} + ) self.assert_compile( - fn, "substring(:substring_1, :substring_2)", - checkparams={"substring_1": "foo", "substring_2": "foobar"}) + fn, + "substring(:substring_1, :substring_2)", + checkparams={"substring_1": "foo", "substring_2": "foobar"}, + ) def test_as_comparison_annotate(self): fn = func.foobar("x", "y", "q", "p", "r").as_comparison(2, 5) from sqlalchemy.sql import annotation + fn_annotated = annotation._deep_annotate(fn, {"token": "yes"}) eq_(fn.left._annotations, {}) @@ -646,15 +658,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_as_comparison_many_argument(self): - fn = func.some_comparison("x", "y", "z", "p", "q", "r").as_comparison(2, 5) + fn = func.some_comparison("x", "y", "z", "p", "q", "r").as_comparison( + 2, 5 + ) is_(fn.type._type_affinity, Boolean) self.assert_compile( - fn.left, ":some_comparison_1", - checkparams={"some_comparison_1": "y"}) + fn.left, + ":some_comparison_1", + checkparams={"some_comparison_1": "y"}, + ) self.assert_compile( - 
fn.right, ":some_comparison_1", - checkparams={"some_comparison_1": "q"}) + fn.right, + ":some_comparison_1", + checkparams={"some_comparison_1": "q"}, + ) from sqlalchemy.sql import visitors @@ -667,9 +685,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ":some_comparison_3, " ":some_comparison_4, :some_comparison_5, :some_comparison_6)", checkparams={ - 'some_comparison_1': 'x', 'some_comparison_2': 'y', - 'some_comparison_3': 'z', 'some_comparison_4': 'p', - 'some_comparison_5': 'q', 'some_comparison_6': 'r'}) + "some_comparison_1": "x", + "some_comparison_2": "y", + "some_comparison_3": "z", + "some_comparison_4": "p", + "some_comparison_5": "q", + "some_comparison_6": "r", + }, + ) self.assert_compile( fn_2, @@ -677,27 +700,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ":some_comparison_3, " ":some_comparison_4, ABC, :some_comparison_5)", checkparams={ - 'some_comparison_1': 'x', 'some_comparison_2': 'y', - 'some_comparison_3': 'z', 'some_comparison_4': 'p', - 'some_comparison_5': 'r'} + "some_comparison_1": "x", + "some_comparison_2": "y", + "some_comparison_3": "z", + "some_comparison_4": "p", + "some_comparison_5": "r", + }, ) class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): - def test_array_agg(self): - expr = func.array_agg(column('data', Integer)) + expr = func.array_agg(column("data", Integer)) is_(expr.type._type_affinity, ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_array_agg_array_datatype(self): - expr = func.array_agg(column('data', ARRAY(Integer))) + expr = func.array_agg(column("data", ARRAY(Integer))) is_(expr.type._type_affinity, ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_array_agg_array_literal_implicit_type(self): from sqlalchemy.dialects.postgresql import array, ARRAY as PG_ARRAY - expr = array([column('data', Integer), column('d2', Integer)]) + + expr = array([column("data", Integer), column("d2", Integer)]) assert isinstance(expr.type, PG_ARRAY) @@ -707,54 +733,50 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): is_(agg_expr.type.item_type._type_affinity, Integer) self.assert_compile( - agg_expr, - "array_agg(ARRAY[data, d2])", - dialect="postgresql" + agg_expr, "array_agg(ARRAY[data, d2])", dialect="postgresql" ) def test_array_agg_array_literal_explicit_type(self): from sqlalchemy.dialects.postgresql import array - expr = array([column('data', Integer), column('d2', Integer)]) + + expr = array([column("data", Integer), column("d2", Integer)]) agg_expr = func.array_agg(expr, type_=ARRAY(Integer)) is_(agg_expr.type._type_affinity, ARRAY) is_(agg_expr.type.item_type._type_affinity, Integer) self.assert_compile( - agg_expr, - "array_agg(ARRAY[data, d2])", - dialect="postgresql" + agg_expr, "array_agg(ARRAY[data, d2])", dialect="postgresql" ) def test_mode(self): - expr = func.mode(0.5).within_group( - column('data', Integer).desc()) + expr = func.mode(0.5).within_group(column("data", Integer).desc()) is_(expr.type._type_affinity, Integer) def test_percentile_cont(self): - expr = func.percentile_cont(0.5).within_group(column('data', Integer)) + expr = func.percentile_cont(0.5).within_group(column("data", Integer)) is_(expr.type._type_affinity, Integer) def test_percentile_cont_array(self): expr = func.percentile_cont(0.5, 0.7).within_group( - column('data', Integer)) + column("data", Integer) + ) is_(expr.type._type_affinity, ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_percentile_cont_array_desc(self): expr = func.percentile_cont(0.5, 
0.7).within_group( - column('data', Integer).desc()) + column("data", Integer).desc() + ) is_(expr.type._type_affinity, ARRAY) is_(expr.type.item_type._type_affinity, Integer) def test_cume_dist(self): - expr = func.cume_dist(0.5).within_group( - column('data', Integer).desc()) + expr = func.cume_dist(0.5).within_group(column("data", Integer).desc()) is_(expr.type._type_affinity, Numeric) def test_percent_rank(self): - expr = func.percent_rank(0.5).within_group( - column('data', Integer)) + expr = func.percent_rank(0.5).within_group(column("data", Integer)) is_(expr.type._type_affinity, Numeric) @@ -790,13 +812,13 @@ class ExecuteTest(fixtures.TestBase): f = func.foo() eq_(f._execution_options, {}) - f = f.execution_options(foo='bar') - eq_(f._execution_options, {'foo': 'bar'}) + f = f.execution_options(foo="bar") + eq_(f._execution_options, {"foo": "bar"}) s = f.select() - eq_(s._execution_options, {'foo': 'bar'}) + eq_(s._execution_options, {"foo": "bar"}) - ret = testing.db.execute(func.now().execution_options(foo='bar')) - eq_(ret.context.execution_options, {'foo': 'bar'}) + ret = testing.db.execute(func.now().execution_options(foo="bar")) + eq_(ret.context.execution_options, {"foo": "bar"}) ret.close() @engines.close_first @@ -809,67 +831,83 @@ class ExecuteTest(fixtures.TestBase): """ meta = self.metadata - t = Table('t1', meta, - Column('id', Integer, Sequence('t1idseq', optional=True), - primary_key=True), - Column('value', Integer) - ) - t2 = Table('t2', meta, - Column('id', Integer, Sequence('t2idseq', optional=True), - primary_key=True), - Column('value', Integer, default=7), - Column('stuff', String(20), onupdate="thisisstuff") - ) + t = Table( + "t1", + meta, + Column( + "id", + Integer, + Sequence("t1idseq", optional=True), + primary_key=True, + ), + Column("value", Integer), + ) + t2 = Table( + "t2", + meta, + Column( + "id", + Integer, + Sequence("t2idseq", optional=True), + primary_key=True, + ), + Column("value", Integer, default=7), + Column("stuff", String(20), onupdate="thisisstuff"), + ) meta.create_all() t.insert(values=dict(value=func.length("one"))).execute() - assert t.select().execute().first()['value'] == 3 + assert t.select().execute().first()["value"] == 3 t.update(values=dict(value=func.length("asfda"))).execute() - assert t.select().execute().first()['value'] == 5 + assert t.select().execute().first()["value"] == 5 r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() id = r.inserted_primary_key[0] - assert t.select(t.c.id == id).execute().first()['value'] == 9 + assert t.select(t.c.id == id).execute().first()["value"] == 9 t.update(values={t.c.value: func.length("asdf")}).execute() - assert t.select().execute().first()['value'] == 4 + assert t.select().execute().first()["value"] == 4 t2.insert().execute() t2.insert(values=dict(value=func.length("one"))).execute() - t2.insert(values=dict(value=func.length("asfda") + -19)).\ - execute(stuff="hi") + t2.insert(values=dict(value=func.length("asfda") + -19)).execute( + stuff="hi" + ) res = exec_sorted(select([t2.c.value, t2.c.stuff])) - eq_(res, [(-14, 'hi'), (3, None), (7, None)]) + eq_(res, [(-14, "hi"), (3, None), (7, None)]) - t2.update(values=dict(value=func.length("asdsafasd"))).\ - execute(stuff="some stuff") - assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == \ - [(9, "some stuff"), (9, "some stuff"), - (9, "some stuff")] + t2.update(values=dict(value=func.length("asdsafasd"))).execute( + stuff="some stuff" + ) + assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [ 
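(The ExecuteTest round trips above use SQL functions as INSERT/UPDATE values, evaluated server-side. A minimal runnable sketch of that pattern against an in-memory SQLite database; the table name is hypothetical.)

from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    Table,
    create_engine,
    func,
    select,
)

engine = create_engine("sqlite://")
metadata = MetaData()
t = Table("t", metadata, Column("value", Integer))
metadata.create_all(engine)

with engine.connect() as conn:
    # length("one") is computed by the database, not by Python.
    conn.execute(t.insert().values(value=func.length("one")))
    assert conn.execute(select([t.c.value])).scalar() == 3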
+ (9, "some stuff"), + (9, "some stuff"), + (9, "some stuff"), + ] t2.delete().execute() t2.insert(values=dict(value=func.length("one") + 8)).execute() - assert t2.select().execute().first()['value'] == 11 + assert t2.select().execute().first()["value"] == 11 t2.update(values=dict(value=func.length("asfda"))).execute() eq_( select([t2.c.value, t2.c.stuff]).execute().first(), - (5, "thisisstuff") + (5, "thisisstuff"), ) - t2.update(values={t2.c.value: func.length("asfdaasdf"), - t2.c.stuff: "foo"}).execute() - eq_(select([t2.c.value, t2.c.stuff]).execute().first(), - (9, "foo") - ) + t2.update( + values={t2.c.value: func.length("asfdaasdf"), t2.c.stuff: "foo"} + ).execute() + eq_(select([t2.c.value, t2.c.stuff]).execute().first(), (9, "foo")) - @testing.fails_on_everything_except('postgresql') + @testing.fails_on_everything_except("postgresql") def test_as_from(self): # TODO: shouldn't this work on oracle too ? x = func.current_date(bind=testing.db).execute().scalar() y = func.current_date(bind=testing.db).select().execute().scalar() z = func.current_date(bind=testing.db).scalar() - w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).\ - scalar() + w = select( + ["*"], from_obj=[func.current_date(bind=testing.db)] + ).scalar() assert x == y == z == w @@ -881,28 +919,30 @@ class ExecuteTest(fixtures.TestBase): def execute(field): return testing.db.execute(select([extract(field, date)])).scalar() - assert execute('year') == 2010 - assert execute('month') == 5 - assert execute('day') == 1 + assert execute("year") == 2010 + assert execute("month") == 5 + assert execute("day") == 1 date = datetime.datetime(2010, 5, 1, 12, 11, 10) - assert execute('year') == 2010 - assert execute('month') == 5 - assert execute('day') == 1 + assert execute("year") == 2010 + assert execute("month") == 5 + assert execute("day") == 1 def test_extract_expression(self): meta = MetaData(testing.db) - table = Table('test', meta, - Column('dt', DateTime), - Column('d', Date)) + table = Table("test", meta, Column("dt", DateTime), Column("d", Date)) meta.create_all() try: table.insert().execute( - {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10), - 'd': datetime.date(2010, 5, 1)}) - rs = select([extract('year', table.c.dt), - extract('month', table.c.d)]).execute() + { + "dt": datetime.datetime(2010, 5, 1, 12, 11, 10), + "d": datetime.date(2010, 5, 1), + } + ) + rs = select( + [extract("year", table.c.dt), extract("month", table.c.d)] + ).execute() row = rs.first() assert row[0] == 2010 assert row[1] == 5 @@ -914,5 +954,6 @@ class ExecuteTest(fixtures.TestBase): def exec_sorted(statement, *args, **kw): """Executes a statement and returns a sorted list plain tuple rows.""" - return sorted([tuple(row) - for row in statement.execute(*args, **kw).fetchall()]) + return sorted( + [tuple(row) for row in statement.execute(*args, **kw).fetchall()] + ) diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py index 8b14368798..1d064dd3a1 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_generative.py @@ -1,21 +1,45 @@ from sqlalchemy.sql import table, column, ClauseElement, operators from sqlalchemy.sql.expression import _clone, _from_objects -from sqlalchemy import func, select, Integer, Table, \ - Column, MetaData, extract, String, bindparam, tuple_, and_, union, text,\ - case, ForeignKey, literal_column -from sqlalchemy.testing import fixtures, AssertsExecutionResults, \ - AssertsCompiledSQL +from sqlalchemy import ( + func, + select, + Integer, + Table, + Column, + MetaData, + extract, + 
String, + bindparam, + tuple_, + and_, + union, + text, + case, + ForeignKey, + literal_column, +) +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + AssertsCompiledSQL, +) from sqlalchemy import testing -from sqlalchemy.sql.visitors import ClauseVisitor, CloningVisitor, \ - cloned_traverse, ReplacingCloningVisitor +from sqlalchemy.sql.visitors import ( + ClauseVisitor, + CloningVisitor, + cloned_traverse, + ReplacingCloningVisitor, +) from sqlalchemy.sql import visitors from sqlalchemy import exc from sqlalchemy.sql import util as sql_util -from sqlalchemy.testing import (eq_, - is_, - is_not_, - assert_raises, - assert_raises_message) +from sqlalchemy.testing import ( + eq_, + is_, + is_not_, + assert_raises, + assert_raises_message, +) A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None @@ -33,7 +57,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): # define deep equality semantics as well as deep # identity semantics. class A(ClauseElement): - __visit_name__ = 'a' + __visit_name__ = "a" def __init__(self, expr): self.expr = expr @@ -53,7 +77,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): return "A(%s)" % repr(self.expr) class B(ClauseElement): - __visit_name__ = 'b' + __visit_name__ = "b" def __init__(self, *items): self.items = items @@ -93,8 +117,9 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): a1 = A("expr1") struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) - struct3 = B(a1, A("expr2"), B(A("expr1b"), - A("expr2bmodified")), A("expr3")) + struct3 = B( + a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3") + ) assert a1.is_other(a1) assert struct.is_other(struct) @@ -104,11 +129,11 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): assert not struct.is_other(struct3) def test_clone(self): - struct = B(A("expr1"), A("expr2"), B(A("expr1b"), - A("expr2b")), A("expr3")) + struct = B( + A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3") + ) class Vis(CloningVisitor): - def visit_a(self, a): pass @@ -121,11 +146,11 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): assert not struct.is_other(s2) def test_no_clone(self): - struct = B(A("expr1"), A("expr2"), B(A("expr1b"), - A("expr2b")), A("expr3")) + struct = B( + A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3") + ) class Vis(ClauseVisitor): - def visit_a(self, a): pass @@ -139,7 +164,8 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): def test_clone_anon_label(self): from sqlalchemy.sql.elements import Grouping - c1 = Grouping(literal_column('q')) + + c1 = Grouping(literal_column("q")) s1 = select([c1]) class Vis(CloningVisitor): @@ -151,15 +177,23 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): eq_(list(s2.inner_columns)[0].anon_label, c1.anon_label) def test_change_in_place(self): - struct = B(A("expr1"), A("expr2"), B(A("expr1b"), - A("expr2b")), A("expr3")) - struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), - A("expr2b")), A("expr3")) - struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), - A("expr2bmodified")), A("expr3")) + struct = B( + A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3") + ) + struct2 = B( + A("expr1"), + A("expr2modified"), + B(A("expr1b"), A("expr2b")), + A("expr3"), + ) + struct3 = B( + A("expr1"), + A("expr2"), + B(A("expr1b"), A("expr2bmodified")), + A("expr3"), + ) class Vis(CloningVisitor): - 
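(The traversal tests in this region subclass CloningVisitor. As a side reference, a small sketch of the copy-on-visit contract they depend on: mutations made in a visit_* hook apply to the clone, leaving the original expression untouched. The column and parameter names below are illustrative.)

from sqlalchemy import Integer, bindparam
from sqlalchemy.sql import column
from sqlalchemy.sql.visitors import CloningVisitor

expr = column("x", Integer) == 5

class Rewrite(CloningVisitor):
    # Each element is copied before being visited, so this only
    # changes the traversal's output, not `expr` itself.
    def visit_binary(self, binary):
        binary.right = bindparam("y", unique=True)

new_expr = Rewrite().traverse(expr)
assert str(expr) == "x = :x_1"
assert str(new_expr) == "x = :y_1"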
def visit_a(self, a): if a.expr == "expr2": a.expr = "expr2modified" @@ -174,7 +208,6 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): assert struct2 == s2 class Vis2(CloningVisitor): - def visit_a(self, a): if a.expr == "expr2b": a.expr = "expr2bmodified" @@ -194,9 +227,9 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): class CustomObj(Column): pass - assert CustomObj.__visit_name__ == Column.__visit_name__ == 'column' + assert CustomObj.__visit_name__ == Column.__visit_name__ == "column" - foo, bar = CustomObj('foo', String), CustomObj('bar', String) + foo, bar = CustomObj("foo", String), CustomObj("bar", String) bin = foo == bar set(ClauseVisitor().iterate(bin)) assert set(ClauseVisitor().iterate(bin)) == set([foo, bar, bin]) @@ -212,19 +245,13 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): def visit(binary, l, r): canary.append((binary.operator, l, r)) print(binary.operator, l, r) + sql_util.visit_binary_product(visit, expr) - eq_( - canary, expected - ) + eq_(canary, expected) def test_basic(self): a, b = column("a"), column("b") - self._assert_traversal( - a == b, - [ - (operators.eq, a, b) - ] - ) + self._assert_traversal(a == b, [(operators.eq, a, b)]) def test_with_tuples(self): a, b, c, d, b1, b1a, b1b, e, f = ( @@ -236,11 +263,9 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): column("b1a"), column("b1b"), column("e"), - column("f") + column("f"), ) - expr = tuple_( - a, b, b1 == tuple_(b1a, b1b == d), c - ) > tuple_( + expr = tuple_(a, b, b1 == tuple_(b1a, b1b == d), c) > tuple_( func.go(e + f) ) self._assert_traversal( @@ -253,8 +278,8 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): (operators.eq, b1, b1a), (operators.eq, b1b, d), (operators.gt, c, e), - (operators.gt, c, f) - ] + (operators.gt, c, f), + ], ) def test_composed(self): @@ -267,13 +292,7 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): column("j"), column("r"), ) - expr = and_( - (a + b) == q + func.sum(e + f), - and_( - j == r, - f == q - ) - ) + expr = and_((a + b) == q + func.sum(e + f), and_(j == r, f == q)) self._assert_traversal( expr, [ @@ -285,7 +304,7 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): (operators.eq, b, f), (operators.eq, j, r), (operators.eq, f, q), - ] + ], ) def test_subquery(self): @@ -293,11 +312,7 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): subq = select([c]).where(c == a).as_scalar() expr = and_(a == b, b == subq) self._assert_traversal( - expr, - [ - (operators.eq, a, b), - (operators.eq, b, subq), - ] + expr, [(operators.eq, a, b), (operators.eq, b, subq)] ) @@ -305,36 +320,31 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): """test copy-in-place behavior of various ClauseElements.""" - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): global t1, t2, t3 - t1 = table("table1", - column("col1"), - column("col2"), - column("col3"), - ) - t2 = table("table2", - column("col1"), - column("col2"), - column("col3"), - ) - t3 = Table('table3', MetaData(), - Column('col1', Integer), - Column('col2', Integer) - ) + t1 = table("table1", column("col1"), column("col2"), column("col3")) + t2 = table("table2", column("col1"), column("col2"), column("col3")) + t3 = Table( + "table3", + MetaData(), + Column("col1", Integer), + Column("col2", Integer), + ) def test_binary(self): clause = t1.c.col2 == t2.c.col2 eq_(str(clause), str(CloningVisitor().traverse(clause))) def test_binary_anon_label_quirk(self): - t = table('t1', column('col1')) + t = table("t1", 
column("col1")) f = t.c.col1 * 5 - self.assert_compile(select([f]), - "SELECT t1.col1 * :col1_1 AS anon_1 FROM t1") + self.assert_compile( + select([f]), "SELECT t1.col1 * :col1_1 AS anon_1 FROM t1" + ) f.anon_label @@ -342,9 +352,8 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): f = sql_util.ClauseAdapter(a).traverse(f) self.assert_compile( - select( - [f]), - "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1") + select([f]), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1" + ) def test_join(self): clause = t1.join(t2, t1.c.col2 == t2.c.col2) @@ -352,7 +361,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): assert str(clause) == str(CloningVisitor().traverse(clause)) class Vis(CloningVisitor): - def visit_binary(self, binary): binary.right = t2.c.col3 @@ -368,24 +376,23 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): adapter = sql_util.ColumnAdapter(aliased) - f = select([ - adapter.columns[c] - for c in aliased2.c - ]).select_from(aliased) + f = select([adapter.columns[c] for c in aliased2.c]).select_from( + aliased + ) s = select([aliased2]).select_from(aliased) eq_(str(s), str(f)) - f = select([ - adapter.columns[func.count(aliased2.c.col1)] - ]).select_from(aliased) + f = select([adapter.columns[func.count(aliased2.c.col1)]]).select_from( + aliased + ) eq_( str(select([func.count(aliased2.c.col1)]).select_from(aliased)), - str(f) + str(f), ) def test_aliased_cloned_column_adapt_inner(self): - clause = select([t1.c.col1, func.foo(t1.c.col2).label('foo')]) + clause = select([t1.c.col1, func.foo(t1.c.col2).label("foo")]) aliased1 = select([clause.c.col1, clause.c.foo]) aliased2 = clause @@ -397,20 +404,12 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): # aliased2. corresponding_column checks these # now. adapter = sql_util.ColumnAdapter(aliased1) - f1 = select([ - adapter.columns[c] - for c in aliased2._raw_columns - ]) - f2 = select([ - adapter.columns[c] - for c in aliased3._raw_columns - ]) - eq_( - str(f1), str(f2) - ) + f1 = select([adapter.columns[c] for c in aliased2._raw_columns]) + f2 = select([adapter.columns[c] for c in aliased3._raw_columns]) + eq_(str(f1), str(f2)) def test_aliased_cloned_column_adapt_exported(self): - clause = select([t1.c.col1, func.foo(t1.c.col2).label('foo')]) + clause = select([t1.c.col1, func.foo(t1.c.col2).label("foo")]) aliased1 = select([clause.c.col1, clause.c.foo]) aliased2 = clause @@ -422,20 +421,12 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): # have an _is_clone_of pointer. But we now modified _make_proxy # to assign this. adapter = sql_util.ColumnAdapter(aliased1) - f1 = select([ - adapter.columns[c] - for c in aliased2.c - ]) - f2 = select([ - adapter.columns[c] - for c in aliased3.c - ]) - eq_( - str(f1), str(f2) - ) + f1 = select([adapter.columns[c] for c in aliased2.c]) + f2 = select([adapter.columns[c] for c in aliased3.c]) + eq_(str(f1), str(f2)) def test_aliased_cloned_schema_column_adapt_exported(self): - clause = select([t3.c.col1, func.foo(t3.c.col2).label('foo')]) + clause = select([t3.c.col1, func.foo(t3.c.col2).label("foo")]) aliased1 = select([clause.c.col1, clause.c.foo]) aliased2 = clause @@ -447,20 +438,12 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): # have an _is_clone_of pointer. But we now modified _make_proxy # to assign this. 
adapter = sql_util.ColumnAdapter(aliased1) - f1 = select([ - adapter.columns[c] - for c in aliased2.c - ]) - f2 = select([ - adapter.columns[c] - for c in aliased3.c - ]) - eq_( - str(f1), str(f2) - ) + f1 = select([adapter.columns[c] for c in aliased2.c]) + f2 = select([adapter.columns[c] for c in aliased3.c]) + eq_(str(f1), str(f2)) def test_labeled_expression_adapt(self): - lbl_x = (t3.c.col1 == 1).label('x') + lbl_x = (t3.c.col1 == 1).label("x") t3_alias = t3.alias() adapter = sql_util.ColumnAdapter(t3_alias) @@ -471,13 +454,13 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): lblx_adapted = adapter.traverse(lbl_x) self.assert_compile( select([lblx_adapted.self_group()]), - "SELECT (table3_1.col1 = :col1_1) AS x FROM table3 AS table3_1" + "SELECT (table3_1.col1 = :col1_1) AS x FROM table3 AS table3_1", ) self.assert_compile( select([lblx_adapted.is_(True)]), "SELECT (table3_1.col1 = :col1_1) IS 1 AS anon_1 " - "FROM table3 AS table3_1" + "FROM table3 AS table3_1", ) def test_cte_w_union(self): @@ -486,50 +469,55 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s = select([func.sum(t.c.n)]) from sqlalchemy.sql.visitors import cloned_traverse + cloned = cloned_traverse(s, {}, {}) - self.assert_compile(cloned, - "WITH RECURSIVE t(n) AS " - "(SELECT values(:values_1) AS n " - "UNION ALL SELECT t.n + :n_1 AS anon_1 " - "FROM t " - "WHERE t.n < :n_2) " - "SELECT sum(t.n) AS sum_1 FROM t" - ) + self.assert_compile( + cloned, + "WITH RECURSIVE t(n) AS " + "(SELECT values(:values_1) AS n " + "UNION ALL SELECT t.n + :n_1 AS anon_1 " + "FROM t " + "WHERE t.n < :n_2) " + "SELECT sum(t.n) AS sum_1 FROM t", + ) def test_aliased_cte_w_union(self): - t = select([func.values(1).label("n")]).\ - cte("t", recursive=True).alias('foo') + t = ( + select([func.values(1).label("n")]) + .cte("t", recursive=True) + .alias("foo") + ) t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)) s = select([func.sum(t.c.n)]) from sqlalchemy.sql.visitors import cloned_traverse + cloned = cloned_traverse(s, {}, {}) self.assert_compile( cloned, "WITH RECURSIVE foo(n) AS (SELECT values(:values_1) AS n " "UNION ALL SELECT foo.n + :n_1 AS anon_1 FROM foo " - "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo" + "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo", ) def test_text(self): clause = text( - "select * from table where foo=:bar", - bindparams=[bindparam('bar')]) + "select * from table where foo=:bar", bindparams=[bindparam("bar")] + ) c1 = str(clause) class Vis(CloningVisitor): - def visit_textclause(self, text): text.text = text.text + " SOME MODIFIER=:lala" - text._bindparams['lala'] = bindparam('lala') + text._bindparams["lala"] = bindparam("lala") clause2 = Vis().traverse(clause) assert c1 == str(clause) assert str(clause2) == c1 + " SOME MODIFIER=:lala" - assert list(clause._bindparams.keys()) == ['bar'] - assert set(clause2._bindparams.keys()) == set(['bar', 'lala']) + assert list(clause._bindparams.keys()) == ["bar"] + assert set(clause2._bindparams.keys()) == set(["bar", "lala"]) def test_select(self): s2 = select([t1]) @@ -537,9 +525,9 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s3_assert = str(select([t1], t1.c.col2 == 7)) class Vis(CloningVisitor): - def visit_select(self, select): select.append_whereclause(t1.c.col2 == 7) + s3 = Vis().traverse(s2) assert str(s3) == s3_assert assert str(s2) == s2_assert @@ -547,18 +535,18 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): print(str(s3)) class Vis(ClauseVisitor): - def visit_select(self, select): 
select.append_whereclause(t1.c.col2 == 7) + Vis().traverse(s2) assert str(s2) == s3_assert s4_assert = str(select([t1], and_(t1.c.col2 == 7, t1.c.col3 == 9))) class Vis(CloningVisitor): - def visit_select(self, select): select.append_whereclause(t1.c.col3 == 9) + s4 = Vis().traverse(s3) print(str(s3)) print(str(s4)) @@ -568,11 +556,11 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s5_assert = str(select([t1], and_(t1.c.col2 == 7, t1.c.col1 == 9))) class Vis(CloningVisitor): - def visit_binary(self, binary): if binary.left is t1.c.col3: binary.left = t1.c.col1 binary.right = bindparam("col1", unique=True) + s5 = Vis().traverse(s4) print(str(s4)) print(str(s5)) @@ -591,18 +579,18 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): assert str(u) == str(u2) assert [str(c) for c in u2.c] == cols - s1 = select([t1], t1.c.col1 == bindparam('id_param')) + s1 = select([t1], t1.c.col1 == bindparam("id_param")) s2 = select([t2]) u = union(s1, s2) u2 = u.params(id_param=7) u3 = u.params(id_param=10) assert str(u) == str(u2) == str(u3) - assert u2.compile().params == {'id_param': 7} - assert u3.compile().params == {'id_param': 10} + assert u2.compile().params == {"id_param": 7} + assert u3.compile().params == {"id_param": 10} def test_in(self): - expr = t1.c.col1.in_(['foo', 'bar']) + expr = t1.c.col1.in_(["foo", "bar"]) expr2 = CloningVisitor().traverse(expr) assert str(expr) == str(expr2) @@ -628,7 +616,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): def test_adapt_union(self): u = union( t1.select().where(t1.c.col1 == 4), - t1.select().where(t1.c.col1 == 5) + t1.select().where(t1.c.col1 == 5), ).alias() assert sql_util.ClauseAdapter(u).traverse(t1) is u @@ -642,48 +630,54 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s3 = select([s], s.c.col2 == s2.c.col2) self.assert_compile( - s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM " + s3, + "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM " "(SELECT table1.col1 AS col1, table1.col2 AS col2, " "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :param_1) " "AS anon_1, " "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 " "AS col3 FROM table1 WHERE table1.col1 = :param_2) AS anon_2 " - "WHERE anon_1.col2 = anon_2.col2") + "WHERE anon_1.col2 = anon_2.col2", + ) s = select([t1], t1.c.col1 == 4).alias() s2 = CloningVisitor().traverse(s).alias() s3 = select([s], s.c.col2 == s2.c.col2) self.assert_compile( - s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM " + s3, + "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM " "(SELECT table1.col1 AS col1, table1.col2 AS col2, " "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) " "AS anon_1, " "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 " "AS col3 FROM table1 WHERE table1.col1 = :col1_2) AS anon_2 " - "WHERE anon_1.col2 = anon_2.col2") + "WHERE anon_1.col2 = anon_2.col2", + ) def test_extract(self): - s = select([extract('foo', t1.c.col1).label('col1')]) + s = select([extract("foo", t1.c.col1).label("col1")]) self.assert_compile( - s, - "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1") + s, "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1" + ) s2 = CloningVisitor().traverse(s).alias() s3 = select([s2.c.col1]) self.assert_compile( - s, - "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1") - self.assert_compile(s3, - "SELECT anon_1.col1 FROM (SELECT EXTRACT(foo FROM " - "table1.col1) AS col1 FROM table1) AS anon_1") + s, "SELECT EXTRACT(foo FROM table1.col1) AS col1 FROM table1" + ) + 
self.assert_compile( + s3, + "SELECT anon_1.col1 FROM (SELECT EXTRACT(foo FROM " + "table1.col1) AS col1 FROM table1) AS anon_1", + ) - @testing.emits_warning('.*replaced by another column with the same key') + @testing.emits_warning(".*replaced by another column with the same key") def test_alias(self): - subq = t2.select().alias('subq') - s = select([t1.c.col1, subq.c.col1], - from_obj=[t1, subq, - t1.join(subq, t1.c.col1 == subq.c.col2)] - ) + subq = t2.select().alias("subq") + s = select( + [t1.c.col1, subq.c.col1], + from_obj=[t1, subq, t1.join(subq, t1.c.col1 == subq.c.col2)], + ) orig = str(s) s2 = CloningVisitor().traverse(s) assert orig == str(s) == str(s2) @@ -691,26 +685,26 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s4 = CloningVisitor().traverse(s2) assert orig == str(s) == str(s2) == str(s4) - s3 = sql_util.ClauseAdapter(table('foo')).traverse(s) + s3 = sql_util.ClauseAdapter(table("foo")).traverse(s) assert orig == str(s) == str(s3) - s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3) + s4 = sql_util.ClauseAdapter(table("foo")).traverse(s3) assert orig == str(s) == str(s3) == str(s4) - subq = subq.alias('subq') - s = select([t1.c.col1, subq.c.col1], - from_obj=[t1, subq, - t1.join(subq, t1.c.col1 == subq.c.col2)] - ) + subq = subq.alias("subq") + s = select( + [t1.c.col1, subq.c.col1], + from_obj=[t1, subq, t1.join(subq, t1.c.col1 == subq.c.col2)], + ) s5 = CloningVisitor().traverse(s) assert orig == str(s) == str(s5) def test_correlated_select(self): - s = select([literal_column('*')], t1.c.col1 == t2.c.col1, - from_obj=[t1, t2]).correlate(t2) + s = select( + [literal_column("*")], t1.c.col1 == t2.c.col1, from_obj=[t1, t2] + ).correlate(t2) class Vis(CloningVisitor): - def visit_select(self, select): select.append_whereclause(t1.c.col2 == 7) @@ -719,26 +713,30 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT table2.col1, table2.col2, table2.col3 " "FROM table2 WHERE table2.col1 = " "(SELECT * FROM table1 WHERE table1.col1 = table2.col1 " - "AND table1.col2 = :col2_1)" + "AND table1.col2 = :col2_1)", ) def test_this_thing(self): - s = select([t1]).where(t1.c.col1 == 'foo').alias() + s = select([t1]).where(t1.c.col1 == "foo").alias() s2 = select([s.c.col1]) - self.assert_compile(s2, - 'SELECT anon_1.col1 FROM (SELECT ' - 'table1.col1 AS col1, table1.col2 AS col2, ' - 'table1.col3 AS col3 FROM table1 WHERE ' - 'table1.col1 = :col1_1) AS anon_1') + self.assert_compile( + s2, + "SELECT anon_1.col1 FROM (SELECT " + "table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1 WHERE " + "table1.col1 = :col1_1) AS anon_1", + ) t1a = t1.alias() s2 = sql_util.ClauseAdapter(t1a).traverse(s2) - self.assert_compile(s2, - 'SELECT anon_1.col1 FROM (SELECT ' - 'table1_1.col1 AS col1, table1_1.col2 AS ' - 'col2, table1_1.col3 AS col3 FROM table1 ' - 'AS table1_1 WHERE table1_1.col1 = ' - ':col1_1) AS anon_1') + self.assert_compile( + s2, + "SELECT anon_1.col1 FROM (SELECT " + "table1_1.col1 AS col1, table1_1.col2 AS " + "col2, table1_1.col3 AS col3 FROM table1 " + "AS table1_1 WHERE table1_1.col1 = " + ":col1_1) AS anon_1", + ) def test_select_fromtwice_one(self): t1a = t1.alias() @@ -746,95 +744,91 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1a) s = select([t1]).where(t1.c.col1 == s) self.assert_compile( - s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " + s, + "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " "WHERE table1.col1 
= " "(SELECT 1 FROM table1, table1 AS table1_1 " - "WHERE table1.col1 = table1_1.col1)") + "WHERE table1.col1 = table1_1.col1)", + ) s = CloningVisitor().traverse(s) self.assert_compile( - s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " + s, + "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " "WHERE table1.col1 = " "(SELECT 1 FROM table1, table1 AS table1_1 " - "WHERE table1.col1 = table1_1.col1)") + "WHERE table1.col1 = table1_1.col1)", + ) def test_select_fromtwice_two(self): - s = select([t1]).where(t1.c.col1 == 'foo').alias() + s = select([t1]).where(t1.c.col1 == "foo").alias() s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1) s3 = select([t1]).where(t1.c.col1 == s2) self.assert_compile( - s3, "SELECT table1.col1, table1.col2, table1.col3 " + s3, + "SELECT table1.col1, table1.col2, table1.col3 " "FROM table1 WHERE table1.col1 = " "(SELECT 1 FROM " "(SELECT table1.col1 AS col1, table1.col2 AS col2, " "table1.col3 AS col3 FROM table1 " "WHERE table1.col1 = :col1_1) " - "AS anon_1 WHERE table1.col1 = anon_1.col1)") + "AS anon_1 WHERE table1.col1 = anon_1.col1)", + ) s4 = ReplacingCloningVisitor().traverse(s3) self.assert_compile( - s4, "SELECT table1.col1, table1.col2, table1.col3 " + s4, + "SELECT table1.col1, table1.col2, table1.col3 " "FROM table1 WHERE table1.col1 = " "(SELECT 1 FROM " "(SELECT table1.col1 AS col1, table1.col2 AS col2, " "table1.col3 AS col3 FROM table1 " "WHERE table1.col1 = :col1_1) " - "AS anon_1 WHERE table1.col1 = anon_1.col1)") + "AS anon_1 WHERE table1.col1 = anon_1.col1)", + ) class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): global t1, t2 - t1 = table("table1", - column("col1"), - column("col2"), - column("col3"), - column("col4") - ) - t2 = table("table2", - column("col1"), - column("col2"), - column("col3"), - ) + t1 = table( + "table1", + column("col1"), + column("col2"), + column("col3"), + column("col4"), + ) + t2 = table("table2", column("col1"), column("col2"), column("col3")) def test_traverse_memoizes_w_columns(self): t1a = t1.alias() adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True) - expr = select([t1a.c.col1]).label('x') + expr = select([t1a.c.col1]).label("x") expr_adapted = adapter.traverse(expr) is_not_(expr, expr_adapted) - is_( - adapter.columns[expr], - expr_adapted - ) + is_(adapter.columns[expr], expr_adapted) def test_traverse_memoizes_w_itself(self): t1a = t1.alias() adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True) - expr = select([t1a.c.col1]).label('x') + expr = select([t1a.c.col1]).label("x") expr_adapted = adapter.traverse(expr) is_not_(expr, expr_adapted) - is_( - adapter.traverse(expr), - expr_adapted - ) + is_(adapter.traverse(expr), expr_adapted) def test_columns_memoizes_w_itself(self): t1a = t1.alias() adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True) - expr = select([t1a.c.col1]).label('x') + expr = select([t1a.c.col1]).label("x") expr_adapted = adapter.columns[expr] is_not_(expr, expr_adapted) - is_( - adapter.columns[expr], - expr_adapted - ) + is_(adapter.columns[expr], expr_adapted) def test_wrapping_fallthrough(self): t1a = t1.alias(name="t1a") @@ -850,57 +844,35 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # t1.c.col1 -> s1.c.t1a_col1 # adapted by a2 - is_( - a3.columns[t1.c.col1], s1.c.t1a_col1 - ) - is_( - a4.columns[t1.c.col1], s1.c.t1a_col1 - ) + is_(a3.columns[t1.c.col1], s1.c.t1a_col1) + is_(a4.columns[t1.c.col1], 
s1.c.t1a_col1) # chaining can't fall through because a1 grabs it # first - is_( - a5.columns[t1.c.col1], t1a.c.col1 - ) + is_(a5.columns[t1.c.col1], t1a.c.col1) # t2.c.col1 -> s1.c.t2a_col1 # adapted by a2 - is_( - a3.columns[t2.c.col1], s1.c.t2a_col1 - ) - is_( - a4.columns[t2.c.col1], s1.c.t2a_col1 - ) + is_(a3.columns[t2.c.col1], s1.c.t2a_col1) + is_(a4.columns[t2.c.col1], s1.c.t2a_col1) # chaining, t2 hits s1 - is_( - a5.columns[t2.c.col1], s1.c.t2a_col1 - ) + is_(a5.columns[t2.c.col1], s1.c.t2a_col1) # t1.c.col2 -> t1a.c.col2 # fallthrough to a1 - is_( - a3.columns[t1.c.col2], t1a.c.col2 - ) - is_( - a4.columns[t1.c.col2], t1a.c.col2 - ) + is_(a3.columns[t1.c.col2], t1a.c.col2) + is_(a4.columns[t1.c.col2], t1a.c.col2) # chaining hits a1 - is_( - a5.columns[t1.c.col2], t1a.c.col2 - ) + is_(a5.columns[t1.c.col2], t1a.c.col2) # t2.c.col2 -> t2.c.col2 # fallthrough to no adaption - is_( - a3.columns[t2.c.col2], t2.c.col2 - ) - is_( - a4.columns[t2.c.col2], t2.c.col2 - ) + is_(a3.columns[t2.c.col2], t2.c.col2) + is_(a4.columns[t2.c.col2], t2.c.col2) def test_wrapping_ordering(self): """illustrate an example where order of wrappers matters. @@ -926,25 +898,15 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # in different contexts, order of wrapping matters # t2.c.col1 via a2 is stmt2.c.col1; then ignored by a1 - is_( - a2_to_a1.columns[t2.c.col1], stmt2.c.col1 - ) + is_(a2_to_a1.columns[t2.c.col1], stmt2.c.col1) # t2.c.col1 via a1 is stmt.c.table2_col1; a2 then # sends this to stmt2.c.table2_col1 - is_( - a1_to_a2.columns[t2.c.col1], stmt2.c.table2_col1 - ) + is_(a1_to_a2.columns[t2.c.col1], stmt2.c.table2_col1) # for mutually exclusive columns, order doesn't matter - is_( - a2_to_a1.columns[t1.c.col1], stmt2.c.table1_col1 - ) - is_( - a1_to_a2.columns[t1.c.col1], stmt2.c.table1_col1 - ) - is_( - a2_to_a1.columns[t2.c.col2], stmt2.c.col2 - ) + is_(a2_to_a1.columns[t1.c.col1], stmt2.c.table1_col1) + is_(a1_to_a2.columns[t1.c.col1], stmt2.c.table1_col1) + is_(a2_to_a1.columns[t2.c.col2], stmt2.c.col2) def test_wrapping_multiple(self): """illustrate that wrapping runs both adapters""" @@ -959,7 +921,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( a3.traverse(stmt), - "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a" + "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a", ) # chaining does too because these adapters don't share any @@ -967,7 +929,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): a4 = a2.chain(a1) self.assert_compile( a4.traverse(stmt), - "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a" + "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a", ) def test_wrapping_inclusions(self): @@ -977,13 +939,13 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): t1a = t1.alias(name="t1a") t2a = t2.alias(name="t2a") a1 = sql_util.ColumnAdapter( - t1a, - include_fn=lambda col: "a1" in col._annotations) + t1a, include_fn=lambda col: "a1" in col._annotations + ) s1 = select([t1a, t2a]).apply_labels().alias() a2 = sql_util.ColumnAdapter( - s1, - include_fn=lambda col: "a2" in col._annotations) + s1, include_fn=lambda col: "a2" in col._annotations + ) a3 = a2.wrap(a1) c1a1 = t1.c.col1._annotate(dict(a1=True)) @@ -994,62 +956,45 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): c2a2 = t2.c.col1._annotate(dict(a2=True)) c2aa = t2.c.col1._annotate(dict(a1=True, a2=True)) - is_( - a3.columns[c1a1], t1a.c.col1 - ) - is_( - a3.columns[c1a2], 
s1.c.t1a_col1 - ) - is_( - a3.columns[c1aa], s1.c.t1a_col1 - ) + is_(a3.columns[c1a1], t1a.c.col1) + is_(a3.columns[c1a2], s1.c.t1a_col1) + is_(a3.columns[c1aa], s1.c.t1a_col1) # not covered by a1, accepted by a2 - is_( - a3.columns[c2aa], s1.c.t2a_col1 - ) + is_(a3.columns[c2aa], s1.c.t2a_col1) # not covered by a1, accepted by a2 - is_( - a3.columns[c2a2], s1.c.t2a_col1 - ) + is_(a3.columns[c2a2], s1.c.t2a_col1) # not covered by a1, rejected by a2 - is_( - a3.columns[c2a1], c2a1 - ) + is_(a3.columns[c2a1], c2a1) class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): global t1, t2 - t1 = table("table1", - column("col1"), - column("col2"), - column("col3"), - ) - t2 = table("table2", - column("col1"), - column("col2"), - column("col3"), - ) + t1 = table("table1", column("col1"), column("col2"), column("col3")) + t2 = table("table2", column("col1"), column("col2"), column("col3")) def test_correlation_on_clone(self): - t1alias = t1.alias('t1alias') - t2alias = t2.alias('t2alias') + t1alias = t1.alias("t1alias") + t2alias = t2.alias("t2alias") vis = sql_util.ClauseAdapter(t1alias) - s = select([literal_column('*')], - from_obj=[t1alias, t2alias]).as_scalar() + s = select( + [literal_column("*")], from_obj=[t1alias, t2alias] + ).as_scalar() assert t2alias in s._froms assert t1alias in s._froms - self.assert_compile(select([literal_column('*')], t2alias.c.col1 == s), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't2alias.col1 = (SELECT * FROM table1 AS ' - 't1alias)') + self.assert_compile( + select([literal_column("*")], t2alias.c.col1 == s), + "SELECT * FROM table2 AS t2alias WHERE " + "t2alias.col1 = (SELECT * FROM table1 AS " + "t1alias)", + ) s = vis.traverse(s) assert t2alias not in s._froms # not present because it's been @@ -1060,61 +1005,91 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # correlate list on "s" needs to take into account the full # _cloned_set for each element in _froms when correlating - self.assert_compile(select([literal_column('*')], t2alias.c.col1 == s), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't2alias.col1 = (SELECT * FROM table1 AS ' - 't1alias)') - s = select([literal_column('*')], - from_obj=[t1alias, t2alias]).correlate(t2alias).as_scalar() - self.assert_compile(select([literal_column('*')], t2alias.c.col1 == s), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't2alias.col1 = (SELECT * FROM table1 AS ' - 't1alias)') + self.assert_compile( + select([literal_column("*")], t2alias.c.col1 == s), + "SELECT * FROM table2 AS t2alias WHERE " + "t2alias.col1 = (SELECT * FROM table1 AS " + "t1alias)", + ) + s = ( + select([literal_column("*")], from_obj=[t1alias, t2alias]) + .correlate(t2alias) + .as_scalar() + ) + self.assert_compile( + select([literal_column("*")], t2alias.c.col1 == s), + "SELECT * FROM table2 AS t2alias WHERE " + "t2alias.col1 = (SELECT * FROM table1 AS " + "t1alias)", + ) s = vis.traverse(s) - self.assert_compile(select([literal_column('*')], t2alias.c.col1 == s), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't2alias.col1 = (SELECT * FROM table1 AS ' - 't1alias)') + self.assert_compile( + select([literal_column("*")], t2alias.c.col1 == s), + "SELECT * FROM table2 AS t2alias WHERE " + "t2alias.col1 = (SELECT * FROM table1 AS " + "t1alias)", + ) s = CloningVisitor().traverse(s) - self.assert_compile(select([literal_column('*')], t2alias.c.col1 == s), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't2alias.col1 = (SELECT * FROM table1 AS ' 
- 't1alias)') + self.assert_compile( + select([literal_column("*")], t2alias.c.col1 == s), + "SELECT * FROM table2 AS t2alias WHERE " + "t2alias.col1 = (SELECT * FROM table1 AS " + "t1alias)", + ) - s = select([literal_column('*')]).where(t1.c.col1 == t2.c.col1) \ + s = ( + select([literal_column("*")]) + .where(t1.c.col1 == t2.c.col1) .as_scalar() - self.assert_compile(select([t1.c.col1, s]), - 'SELECT table1.col1, (SELECT * FROM table2 ' - 'WHERE table1.col1 = table2.col1) AS ' - 'anon_1 FROM table1') + ) + self.assert_compile( + select([t1.c.col1, s]), + "SELECT table1.col1, (SELECT * FROM table2 " + "WHERE table1.col1 = table2.col1) AS " + "anon_1 FROM table1", + ) vis = sql_util.ClauseAdapter(t1alias) s = vis.traverse(s) - self.assert_compile(select([t1alias.c.col1, s]), - 'SELECT t1alias.col1, (SELECT * FROM ' - 'table2 WHERE t1alias.col1 = table2.col1) ' - 'AS anon_1 FROM table1 AS t1alias') + self.assert_compile( + select([t1alias.c.col1, s]), + "SELECT t1alias.col1, (SELECT * FROM " + "table2 WHERE t1alias.col1 = table2.col1) " + "AS anon_1 FROM table1 AS t1alias", + ) s = CloningVisitor().traverse(s) - self.assert_compile(select([t1alias.c.col1, s]), - 'SELECT t1alias.col1, (SELECT * FROM ' - 'table2 WHERE t1alias.col1 = table2.col1) ' - 'AS anon_1 FROM table1 AS t1alias') - s = select([literal_column('*')]).where(t1.c.col1 == t2.c.col1) \ - .correlate(t1).as_scalar() - self.assert_compile(select([t1.c.col1, s]), - 'SELECT table1.col1, (SELECT * FROM table2 ' - 'WHERE table1.col1 = table2.col1) AS ' - 'anon_1 FROM table1') + self.assert_compile( + select([t1alias.c.col1, s]), + "SELECT t1alias.col1, (SELECT * FROM " + "table2 WHERE t1alias.col1 = table2.col1) " + "AS anon_1 FROM table1 AS t1alias", + ) + s = ( + select([literal_column("*")]) + .where(t1.c.col1 == t2.c.col1) + .correlate(t1) + .as_scalar() + ) + self.assert_compile( + select([t1.c.col1, s]), + "SELECT table1.col1, (SELECT * FROM table2 " + "WHERE table1.col1 = table2.col1) AS " + "anon_1 FROM table1", + ) vis = sql_util.ClauseAdapter(t1alias) s = vis.traverse(s) - self.assert_compile(select([t1alias.c.col1, s]), - 'SELECT t1alias.col1, (SELECT * FROM ' - 'table2 WHERE t1alias.col1 = table2.col1) ' - 'AS anon_1 FROM table1 AS t1alias') + self.assert_compile( + select([t1alias.c.col1, s]), + "SELECT t1alias.col1, (SELECT * FROM " + "table2 WHERE t1alias.col1 = table2.col1) " + "AS anon_1 FROM table1 AS t1alias", + ) s = CloningVisitor().traverse(s) - self.assert_compile(select([t1alias.c.col1, s]), - 'SELECT t1alias.col1, (SELECT * FROM ' - 'table2 WHERE t1alias.col1 = table2.col1) ' - 'AS anon_1 FROM table1 AS t1alias') + self.assert_compile( + select([t1alias.c.col1, s]), + "SELECT t1alias.col1, (SELECT * FROM " + "table2 WHERE t1alias.col1 = table2.col1) " + "AS anon_1 FROM table1 AS t1alias", + ) @testing.fails_on_everything_except() def test_joins_dont_adapt(self): @@ -1122,284 +1097,319 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # make much sense. ClauseAdapter doesn't make any changes if # it's against a straight join. 
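# For contrast with the straight-join case below, a minimal sketch of
# the ordinary ClauseAdapter behavior (illustrative, not part of the
# patch; assumes the 1.x select([...]) API used in these tests):

from sqlalchemy import column, select, table
from sqlalchemy.sql import util as sql_util

users_t = table("users", column("id"))
users_a = users_t.alias()
stmt = select([users_t.c.id]).where(users_t.c.id == 5)

# traverse() rewrites references to "users" into the alias, rendering:
# SELECT users_1.id FROM users AS users_1 WHERE users_1.id = :id_1
print(sql_util.ClauseAdapter(users_a).traverse(stmt))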
- users = table('users', column('id')) - addresses = table('addresses', column('id'), column('user_id')) + users = table("users", column("id")) + addresses = table("addresses", column("id"), column("user_id")) ualias = users.alias() - s = select([func.count(addresses.c.id)], users.c.id - == addresses.c.user_id).correlate(users) + s = select( + [func.count(addresses.c.id)], users.c.id == addresses.c.user_id + ).correlate(users) s = sql_util.ClauseAdapter(ualias).traverse(s) j1 = addresses.join(ualias, addresses.c.user_id == ualias.c.id) - self.assert_compile(sql_util.ClauseAdapter(j1).traverse(s), - 'SELECT count(addresses.id) AS count_1 ' - 'FROM addresses WHERE users_1.id = ' - 'addresses.user_id') + self.assert_compile( + sql_util.ClauseAdapter(j1).traverse(s), + "SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE users_1.id = " + "addresses.user_id", + ) def test_table_to_alias_1(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - ff = vis.traverse(func.count(t1.c.col1).label('foo')) + ff = vis.traverse(func.count(t1.c.col1).label("foo")) assert list(_from_objects(ff)) == [t1alias] def test_table_to_alias_2(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( - vis.traverse(select([literal_column('*')], from_obj=[t1])), - 'SELECT * FROM table1 AS t1alias') + vis.traverse(select([literal_column("*")], from_obj=[t1])), + "SELECT * FROM table1 AS t1alias", + ) def test_table_to_alias_3(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( - select([literal_column('*')], t1.c.col1 == t2.c.col2), - 'SELECT * FROM table1, table2 WHERE table1.col1 = table2.col2') + select([literal_column("*")], t1.c.col1 == t2.c.col2), + "SELECT * FROM table1, table2 WHERE table1.col1 = table2.col2", + ) def test_table_to_alias_4(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - self.assert_compile(vis.traverse(select([literal_column('*')], - t1.c.col1 == t2.c.col2)), - 'SELECT * FROM table1 AS t1alias, table2 ' - 'WHERE t1alias.col1 = table2.col2') + self.assert_compile( + vis.traverse( + select([literal_column("*")], t1.c.col1 == t2.c.col2) + ), + "SELECT * FROM table1 AS t1alias, table2 " + "WHERE t1alias.col1 = table2.col2", + ) def test_table_to_alias_5(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse( select( - [literal_column('*')], + [literal_column("*")], t1.c.col1 == t2.c.col2, - from_obj=[ - t1, - t2])), - 'SELECT * FROM table1 AS t1alias, table2 ' - 'WHERE t1alias.col1 = table2.col2') + from_obj=[t1, t2], + ) + ), + "SELECT * FROM table1 AS t1alias, table2 " + "WHERE t1alias.col1 = table2.col2", + ) def test_table_to_alias_6(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - self.assert_compile(select([t1alias, t2]).where( - t1alias.c.col1 == vis.traverse( - select([literal_column('*')], - t1.c.col1 == t2.c.col2, from_obj=[t1, t2]).correlate(t1) - ) - ), + self.assert_compile( + select([t1alias, t2]).where( + t1alias.c.col1 + == vis.traverse( + select( + [literal_column("*")], + t1.c.col1 == t2.c.col2, + from_obj=[t1, t2], + ).correlate(t1) + ) + ), "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " "table2.col1, table2.col2, table2.col3 " "FROM table1 AS t1alias, 
table2 WHERE t1alias.col1 = " - "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)" + "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)", ) def test_table_to_alias_7(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( - select([t1alias, t2]). - where(t1alias.c.col1 == vis.traverse( - select([literal_column('*')], - t1.c.col1 == t2.c.col2, from_obj=[t1, t2]). - correlate(t2))), + select([t1alias, t2]).where( + t1alias.c.col1 + == vis.traverse( + select( + [literal_column("*")], + t1.c.col1 == t2.c.col2, + from_obj=[t1, t2], + ).correlate(t2) + ) + ), "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " "table2.col1, table2.col2, table2.col3 " "FROM table1 AS t1alias, table2 " "WHERE t1alias.col1 = " "(SELECT * FROM table1 AS t1alias " - "WHERE t1alias.col1 = table2.col2)") + "WHERE t1alias.col1 = table2.col2)", + ) def test_table_to_alias_8(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse(case([(t1.c.col1 == 5, t1.c.col2)], else_=t1.c.col1)), - 'CASE WHEN (t1alias.col1 = :col1_1) THEN ' - 't1alias.col2 ELSE t1alias.col1 END') + "CASE WHEN (t1alias.col1 = :col1_1) THEN " + "t1alias.col2 ELSE t1alias.col1 END", + ) def test_table_to_alias_9(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse( - case( - [ - (5, - t1.c.col2)], - value=t1.c.col1, - else_=t1.c.col1)), - 'CASE t1alias.col1 WHEN :param_1 THEN ' - 't1alias.col2 ELSE t1alias.col1 END') + case([(5, t1.c.col2)], value=t1.c.col1, else_=t1.c.col1) + ), + "CASE t1alias.col1 WHEN :param_1 THEN " + "t1alias.col2 ELSE t1alias.col1 END", + ) def test_table_to_alias_10(self): - s = select([literal_column('*')], from_obj=[t1]).alias('foo') - self.assert_compile(s.select(), - 'SELECT foo.* FROM (SELECT * FROM table1) ' - 'AS foo') + s = select([literal_column("*")], from_obj=[t1]).alias("foo") + self.assert_compile( + s.select(), "SELECT foo.* FROM (SELECT * FROM table1) " "AS foo" + ) def test_table_to_alias_11(self): - s = select([literal_column('*')], from_obj=[t1]).alias('foo') - t1alias = t1.alias('t1alias') + s = select([literal_column("*")], from_obj=[t1]).alias("foo") + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - self.assert_compile(vis.traverse(s.select()), - 'SELECT foo.* FROM (SELECT * FROM table1 ' - 'AS t1alias) AS foo') + self.assert_compile( + vis.traverse(s.select()), + "SELECT foo.* FROM (SELECT * FROM table1 " "AS t1alias) AS foo", + ) def test_table_to_alias_12(self): - s = select([literal_column('*')], from_obj=[t1]).alias('foo') - self.assert_compile(s.select(), - 'SELECT foo.* FROM (SELECT * FROM table1) ' - 'AS foo') + s = select([literal_column("*")], from_obj=[t1]).alias("foo") + self.assert_compile( + s.select(), "SELECT foo.* FROM (SELECT * FROM table1) " "AS foo" + ) def test_table_to_alias_13(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - ff = vis.traverse(func.count(t1.c.col1).label('foo')) - self.assert_compile(select([ff]), - 'SELECT count(t1alias.col1) AS foo FROM ' - 'table1 AS t1alias') + ff = vis.traverse(func.count(t1.c.col1).label("foo")) + self.assert_compile( + select([ff]), + "SELECT count(t1alias.col1) AS foo FROM " "table1 AS t1alias", + ) assert list(_from_objects(ff)) == [t1alias] # def test_table_to_alias_2(self): - # TODO: 
self.assert_compile(vis.traverse(select([func.count(t1.c - # .col1).l abel('foo')]), clone=True), "SELECT - # count(t1alias.col1) AS foo FROM table1 AS t1alias") + # TODO: self.assert_compile(vis.traverse(select([func.count(t1.c + # .col1).l abel('foo')]), clone=True), "SELECT + # count(t1alias.col1) AS foo FROM table1 AS t1alias") def test_table_to_alias_14(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - t2alias = t2.alias('t2alias') + t2alias = t2.alias("t2alias") vis.chain(sql_util.ClauseAdapter(t2alias)) self.assert_compile( vis.traverse( - select([literal_column('*')], t1.c.col1 == t2.c.col2)), - 'SELECT * FROM table1 AS t1alias, table2 ' - 'AS t2alias WHERE t1alias.col1 = ' - 't2alias.col2') + select([literal_column("*")], t1.c.col1 == t2.c.col2) + ), + "SELECT * FROM table1 AS t1alias, table2 " + "AS t2alias WHERE t1alias.col1 = " + "t2alias.col2", + ) def test_table_to_alias_15(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - t2alias = t2.alias('t2alias') + t2alias = t2.alias("t2alias") vis.chain(sql_util.ClauseAdapter(t2alias)) self.assert_compile( vis.traverse( - select( - ['*'], - t1.c.col1 == t2.c.col2, - from_obj=[ - t1, - t2])), - 'SELECT * FROM table1 AS t1alias, table2 ' - 'AS t2alias WHERE t1alias.col1 = ' - 't2alias.col2') + select(["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2]) + ), + "SELECT * FROM table1 AS t1alias, table2 " + "AS t2alias WHERE t1alias.col1 = " + "t2alias.col2", + ) def test_table_to_alias_16(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - t2alias = t2.alias('t2alias') + t2alias = t2.alias("t2alias") vis.chain(sql_util.ClauseAdapter(t2alias)) self.assert_compile( select([t1alias, t2alias]).where( - t1alias.c.col1 == - vis.traverse(select(['*'], - t1.c.col1 == t2.c.col2, - from_obj=[t1, t2]).correlate(t1)) + t1alias.c.col1 + == vis.traverse( + select( + ["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2] + ).correlate(t1) + ) ), "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " "t2alias.col1, t2alias.col2, t2alias.col3 " "FROM table1 AS t1alias, table2 AS t2alias " "WHERE t1alias.col1 = " "(SELECT * FROM table2 AS t2alias " - "WHERE t1alias.col1 = t2alias.col2)" + "WHERE t1alias.col1 = t2alias.col2)", ) def test_table_to_alias_17(self): - t1alias = t1.alias('t1alias') + t1alias = t1.alias("t1alias") vis = sql_util.ClauseAdapter(t1alias) - t2alias = t2.alias('t2alias') + t2alias = t2.alias("t2alias") vis.chain(sql_util.ClauseAdapter(t2alias)) self.assert_compile( t2alias.select().where( - t2alias.c.col2 == vis.traverse( + t2alias.c.col2 + == vis.traverse( select( - ['*'], - t1.c.col1 == t2.c.col2, - from_obj=[ - t1, - t2]).correlate(t2))), - 'SELECT t2alias.col1, t2alias.col2, t2alias.col3 ' - 'FROM table2 AS t2alias WHERE t2alias.col2 = ' - '(SELECT * FROM table1 AS t1alias WHERE ' - 't1alias.col1 = t2alias.col2)') + ["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2] + ).correlate(t2) + ) + ), + "SELECT t2alias.col1, t2alias.col2, t2alias.col3 " + "FROM table2 AS t2alias WHERE t2alias.col2 = " + "(SELECT * FROM table1 AS t1alias WHERE " + "t1alias.col1 = t2alias.col2)", + ) def test_include_exclude(self): m = MetaData() - a = Table('a', m, - Column('id', Integer, primary_key=True), - Column('xxx_id', Integer, - ForeignKey('a.id', name='adf', use_alter=True) - ) - ) - - e = (a.c.id == a.c.xxx_id) + a = Table( + "a", + m, + Column("id", Integer, primary_key=True), + 
Column( + "xxx_id", + Integer, + ForeignKey("a.id", name="adf", use_alter=True), + ), + ) + + e = a.c.id == a.c.xxx_id assert str(e) == "a.id = a.xxx_id" b = a.alias() - e = sql_util.ClauseAdapter(b, include_fn=lambda x: x in set([a.c.id]), - equivalents={a.c.id: set([a.c.id])} - ).traverse(e) + e = sql_util.ClauseAdapter( + b, + include_fn=lambda x: x in set([a.c.id]), + equivalents={a.c.id: set([a.c.id])}, + ).traverse(e) assert str(e) == "a_1.id = a.xxx_id" def test_recursive_equivalents(self): m = MetaData() - a = Table('a', m, Column('x', Integer), Column('y', Integer)) - b = Table('b', m, Column('x', Integer), Column('y', Integer)) - c = Table('c', m, Column('x', Integer), Column('y', Integer)) + a = Table("a", m, Column("x", Integer), Column("y", Integer)) + b = Table("b", m, Column("x", Integer), Column("y", Integer)) + c = Table("c", m, Column("x", Integer), Column("y", Integer)) # force a recursion overflow, by linking a.c.x<->c.c.x, and # asking for a nonexistent col. corresponding_column should prevent # endless depth. adapt = sql_util.ClauseAdapter( - b, equivalents={a.c.x: set([c.c.x]), c.c.x: set([a.c.x])}) + b, equivalents={a.c.x: set([c.c.x]), c.c.x: set([a.c.x])} + ) assert adapt._corresponding_column(a.c.x, False) is None def test_multilevel_equivalents(self): m = MetaData() - a = Table('a', m, Column('x', Integer), Column('y', Integer)) - b = Table('b', m, Column('x', Integer), Column('y', Integer)) - c = Table('c', m, Column('x', Integer), Column('y', Integer)) + a = Table("a", m, Column("x", Integer), Column("y", Integer)) + b = Table("b", m, Column("x", Integer), Column("y", Integer)) + c = Table("c", m, Column("x", Integer), Column("y", Integer)) alias = select([a]).select_from(a.join(b, a.c.x == b.c.x)).alias() # two levels of indirection from c.x->b.x->a.x, requires recursive # corresponding_column call adapt = sql_util.ClauseAdapter( - alias, equivalents={b.c.x: set([a.c.x]), c.c.x: set([b.c.x])}) + alias, equivalents={b.c.x: set([a.c.x]), c.c.x: set([b.c.x])} + ) assert adapt._corresponding_column(a.c.x, False) is alias.c.x assert adapt._corresponding_column(c.c.x, False) is alias.c.x def test_join_to_alias(self): metadata = MetaData() - a = Table('a', metadata, - Column('id', Integer, primary_key=True)) - b = Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('a.id')), - ) - c = Table('c', metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey('b.id')), - ) - - d = Table('d', metadata, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('a.id')), - ) + a = Table("a", metadata, Column("id", Integer, primary_key=True)) + b = Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id")), + ) + c = Table( + "c", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) + + d = Table( + "d", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id")), + ) j1 = a.outerjoin(b) j2 = select([j1], use_labels=True) @@ -1407,12 +1417,14 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): j3 = c.join(j2, j2.c.b_id == c.c.bid) j4 = j3.outerjoin(d) - self.assert_compile(j4, - 'c JOIN (SELECT a.id AS a_id, b.id AS ' - 'b_id, b.aid AS b_aid FROM a LEFT OUTER ' - 'JOIN b ON a.id = b.aid) ON b_id = c.bid ' - 'LEFT OUTER JOIN d ON a_id = d.aid') - j5 = j3.alias('foo') + self.assert_compile( + j4, + "c JOIN (SELECT a.id AS 
a_id, b.id AS " + "b_id, b.aid AS b_aid FROM a LEFT OUTER " + "JOIN b ON a.id = b.aid) ON b_id = c.bid " + "LEFT OUTER JOIN d ON a_id = d.aid", + ) + j5 = j3.alias("foo") j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0] # this statement takes c join(a join b), wraps it inside an @@ -1420,14 +1432,16 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # right side "left outer join d" stays the same, except "d" # joins against foo.a_id instead of plain "a_id" - self.assert_compile(j6, - '(SELECT c.id AS c_id, c.bid AS c_bid, ' - 'a_id AS a_id, b_id AS b_id, b_aid AS ' - 'b_aid FROM c JOIN (SELECT a.id AS a_id, ' - 'b.id AS b_id, b.aid AS b_aid FROM a LEFT ' - 'OUTER JOIN b ON a.id = b.aid) ON b_id = ' - 'c.bid) AS foo LEFT OUTER JOIN d ON ' - 'foo.a_id = d.aid') + self.assert_compile( + j6, + "(SELECT c.id AS c_id, c.bid AS c_bid, " + "a_id AS a_id, b_id AS b_id, b_aid AS " + "b_aid FROM c JOIN (SELECT a.id AS a_id, " + "b.id AS b_id, b.aid AS b_aid FROM a LEFT " + "OUTER JOIN b ON a.id = b.aid) ON b_id = " + "c.bid) AS foo LEFT OUTER JOIN d ON " + "foo.a_id = d.aid", + ) def test_derived_from(self): assert select([t1]).is_derived_from(t1) @@ -1435,7 +1449,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): assert not t1.is_derived_from(select([t1])) assert t1.alias().is_derived_from(t1) - s1 = select([t1, t2]).alias('foo') + s1 = select([t1, t2]).alias("foo") s2 = select([s1]).limit(5).offset(10).alias() assert s2.is_derived_from(s1) s2 = s2._clone() @@ -1445,111 +1459,117 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # original issue from ticket #904 - s1 = select([t1]).alias('foo') + s1 = select([t1]).alias("foo") s2 = select([s1]).limit(5).offset(10).alias() - self.assert_compile(sql_util.ClauseAdapter(s2).traverse(s1), - 'SELECT foo.col1, foo.col2, foo.col3 FROM ' - '(SELECT table1.col1 AS col1, table1.col2 ' - 'AS col2, table1.col3 AS col3 FROM table1) ' - 'AS foo LIMIT :param_1 OFFSET :param_2', - {'param_1': 5, 'param_2': 10}) + self.assert_compile( + sql_util.ClauseAdapter(s2).traverse(s1), + "SELECT foo.col1, foo.col2, foo.col3 FROM " + "(SELECT table1.col1 AS col1, table1.col2 " + "AS col2, table1.col3 AS col3 FROM table1) " + "AS foo LIMIT :param_1 OFFSET :param_2", + {"param_1": 5, "param_2": 10}, + ) def test_aliasedselect_to_aliasedselect_join(self): - s1 = select([t1]).alias('foo') + s1 = select([t1]).alias("foo") s2 = select([s1]).limit(5).offset(10).alias() j = s1.outerjoin(t2, s1.c.col1 == t2.c.col1) - self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(), - 'SELECT anon_1.col1, anon_1.col2, ' - 'anon_1.col3, table2.col1, table2.col2, ' - 'table2.col3 FROM (SELECT foo.col1 AS ' - 'col1, foo.col2 AS col2, foo.col3 AS col3 ' - 'FROM (SELECT table1.col1 AS col1, ' - 'table1.col2 AS col2, table1.col3 AS col3 ' - 'FROM table1) AS foo LIMIT :param_1 OFFSET ' - ':param_2) AS anon_1 LEFT OUTER JOIN ' - 'table2 ON anon_1.col1 = table2.col1', - {'param_1': 5, 'param_2': 10}) + self.assert_compile( + sql_util.ClauseAdapter(s2).traverse(j).select(), + "SELECT anon_1.col1, anon_1.col2, " + "anon_1.col3, table2.col1, table2.col2, " + "table2.col3 FROM (SELECT foo.col1 AS " + "col1, foo.col2 AS col2, foo.col3 AS col3 " + "FROM (SELECT table1.col1 AS col1, " + "table1.col2 AS col2, table1.col3 AS col3 " + "FROM table1) AS foo LIMIT :param_1 OFFSET " + ":param_2) AS anon_1 LEFT OUTER JOIN " + "table2 ON anon_1.col1 = table2.col1", + {"param_1": 5, "param_2": 10}, + ) def 
test_aliasedselect_to_aliasedselect_join_nested_table(self): - s1 = select([t1]).alias('foo') + s1 = select([t1]).alias("foo") s2 = select([s1]).limit(5).offset(10).alias() - talias = t1.alias('bar') + talias = t1.alias("bar") assert not s2.is_derived_from(talias) j = s1.outerjoin(talias, s1.c.col1 == talias.c.col1) - self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(), - 'SELECT anon_1.col1, anon_1.col2, ' - 'anon_1.col3, bar.col1, bar.col2, bar.col3 ' - 'FROM (SELECT foo.col1 AS col1, foo.col2 ' - 'AS col2, foo.col3 AS col3 FROM (SELECT ' - 'table1.col1 AS col1, table1.col2 AS col2, ' - 'table1.col3 AS col3 FROM table1) AS foo ' - 'LIMIT :param_1 OFFSET :param_2) AS anon_1 ' - 'LEFT OUTER JOIN table1 AS bar ON ' - 'anon_1.col1 = bar.col1', {'param_1': 5, - 'param_2': 10}) + self.assert_compile( + sql_util.ClauseAdapter(s2).traverse(j).select(), + "SELECT anon_1.col1, anon_1.col2, " + "anon_1.col3, bar.col1, bar.col2, bar.col3 " + "FROM (SELECT foo.col1 AS col1, foo.col2 " + "AS col2, foo.col3 AS col3 FROM (SELECT " + "table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1) AS foo " + "LIMIT :param_1 OFFSET :param_2) AS anon_1 " + "LEFT OUTER JOIN table1 AS bar ON " + "anon_1.col1 = bar.col1", + {"param_1": 5, "param_2": 10}, + ) def test_functions(self): self.assert_compile( - sql_util.ClauseAdapter(t1.alias()). - traverse(func.count(t1.c.col1)), - 'count(table1_1.col1)') + sql_util.ClauseAdapter(t1.alias()).traverse(func.count(t1.c.col1)), + "count(table1_1.col1)", + ) s = select([func.count(t1.c.col1)]) - self.assert_compile(sql_util.ClauseAdapter(t1.alias()).traverse(s), - 'SELECT count(table1_1.col1) AS count_1 ' - 'FROM table1 AS table1_1') + self.assert_compile( + sql_util.ClauseAdapter(t1.alias()).traverse(s), + "SELECT count(table1_1.col1) AS count_1 " + "FROM table1 AS table1_1", + ) def test_recursive(self): metadata = MetaData() - a = Table('a', metadata, - Column('id', Integer, primary_key=True)) - b = Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('a.id')), - ) - c = Table('c', metadata, - Column('id', Integer, primary_key=True), - Column('bid', Integer, ForeignKey('b.id')), - ) - - d = Table('d', metadata, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('a.id')), - ) + a = Table("a", metadata, Column("id", Integer, primary_key=True)) + b = Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id")), + ) + c = Table( + "c", + metadata, + Column("id", Integer, primary_key=True), + Column("bid", Integer, ForeignKey("b.id")), + ) + + d = Table( + "d", + metadata, + Column("id", Integer, primary_key=True), + Column("aid", Integer, ForeignKey("a.id")), + ) u = union( a.join(b).select().apply_labels(), - a.join(d).select().apply_labels() + a.join(d).select().apply_labels(), ).alias() self.assert_compile( - sql_util.ClauseAdapter(u). 
- traverse(select([c.c.bid]).where(c.c.bid == u.c.b_aid)), + sql_util.ClauseAdapter(u).traverse( + select([c.c.bid]).where(c.c.bid == u.c.b_aid) + ), "SELECT c.bid " "FROM c, (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid " "FROM a JOIN b ON a.id = b.aid UNION SELECT a.id AS a_id, d.id " "AS d_id, d.aid AS d_aid " "FROM a JOIN d ON a.id = d.aid) AS anon_1 " - "WHERE c.bid = anon_1.b_aid" + "WHERE c.bid = anon_1.b_aid", ) - t1 = table("table1", - column("col1"), - column("col2"), - column("col3"), - ) - t2 = table("table2", - column("col1"), - column("col2"), - column("col3"), - ) + t1 = table("table1", column("col1"), column("col2"), column("col3")) + t2 = table("table2", column("col1"), column("col2"), column("col3")) def test_label_anonymize_one(self): t1a = t1.alias() adapter = sql_util.ClauseAdapter(t1a, anonymize_labels=True) - expr = select([t1.c.col2]).where(t1.c.col3 == 5).label('expr') + expr = select([t1.c.col2]).where(t1.c.col3 == 5).label("expr") expr_adapted = adapter.traverse(expr) stmt = select([expr, expr_adapted]).order_by(expr, expr_adapted) @@ -1560,7 +1580,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "AS expr, " "(SELECT table1_1.col2 FROM table1 AS table1_1 " "WHERE table1_1.col3 = :col3_2) AS anon_1 " - "ORDER BY expr, anon_1" + "ORDER BY expr, anon_1", ) def test_label_anonymize_two(self): @@ -1578,14 +1598,14 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "AS anon_1, " "(SELECT table1_1.col2 FROM table1 AS table1_1 " "WHERE table1_1.col3 = :col3_2) AS anon_2 " - "ORDER BY anon_1, anon_2" + "ORDER BY anon_1, anon_2", ) def test_label_anonymize_three(self): t1a = t1.alias() adapter = sql_util.ColumnAdapter( - t1a, anonymize_labels=True, - allow_label_resolve=False) + t1a, anonymize_labels=True, allow_label_resolve=False + ) expr = select([t1.c.col2]).where(t1.c.col3 == 5).label(None) l1 = expr @@ -1603,236 +1623,235 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): global table1, table2, table3, table4 def _table(name): - return table(name, column('col1'), column('col2'), - column('col3')) + return table(name, column("col1"), column("col2"), column("col3")) table1, table2, table3, table4 = [ - _table(name) for name in ( - 'table1', 'table2', 'table3', 'table4')] + _table(name) for name in ("table1", "table2", "table3", "table4") + ] def test_splice(self): t1, t2, t3, t4 = table1, table2, table1.alias(), table2.alias() - j = t1.join( - t2, - t1.c.col1 == t2.c.col1).join( - t3, - t2.c.col1 == t3.c.col1).join( - t4, - t4.c.col1 == t1.c.col1) + j = ( + t1.join(t2, t1.c.col1 == t2.c.col1) + .join(t3, t2.c.col1 == t3.c.col1) + .join(t4, t4.c.col1 == t1.c.col1) + ) s = select([t1]).where(t1.c.col2 < 5).alias() - self.assert_compile(sql_util.splice_joins(s, j), - '(SELECT table1.col1 AS col1, table1.col2 ' - 'AS col2, table1.col3 AS col3 FROM table1 ' - 'WHERE table1.col2 < :col2_1) AS anon_1 ' - 'JOIN table2 ON anon_1.col1 = table2.col1 ' - 'JOIN table1 AS table1_1 ON table2.col1 = ' - 'table1_1.col1 JOIN table2 AS table2_1 ON ' - 'table2_1.col1 = anon_1.col1') + self.assert_compile( + sql_util.splice_joins(s, j), + "(SELECT table1.col1 AS col1, table1.col2 " + "AS col2, table1.col3 AS col3 FROM table1 " + "WHERE table1.col2 < :col2_1) AS anon_1 " + "JOIN table2 ON anon_1.col1 = table2.col1 " + "JOIN table1 AS table1_1 ON table2.col1 = " + "table1_1.col1 JOIN 
table2 AS table2_1 ON " + "table2_1.col1 = anon_1.col1", + ) def test_stop_on(self): t1, t2, t3 = table1, table2, table3 j1 = t1.join(t2, t1.c.col1 == t2.c.col1) j2 = j1.join(t3, t2.c.col1 == t3.c.col1) s = select([t1]).select_from(j1).alias() - self.assert_compile(sql_util.splice_joins(s, j2), - '(SELECT table1.col1 AS col1, table1.col2 ' - 'AS col2, table1.col3 AS col3 FROM table1 ' - 'JOIN table2 ON table1.col1 = table2.col1) ' - 'AS anon_1 JOIN table2 ON anon_1.col1 = ' - 'table2.col1 JOIN table3 ON table2.col1 = ' - 'table3.col1') - self.assert_compile(sql_util.splice_joins(s, j2, j1), - '(SELECT table1.col1 AS col1, table1.col2 ' - 'AS col2, table1.col3 AS col3 FROM table1 ' - 'JOIN table2 ON table1.col1 = table2.col1) ' - 'AS anon_1 JOIN table3 ON table2.col1 = ' - 'table3.col1') + self.assert_compile( + sql_util.splice_joins(s, j2), + "(SELECT table1.col1 AS col1, table1.col2 " + "AS col2, table1.col3 AS col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col1) " + "AS anon_1 JOIN table2 ON anon_1.col1 = " + "table2.col1 JOIN table3 ON table2.col1 = " + "table3.col1", + ) + self.assert_compile( + sql_util.splice_joins(s, j2, j1), + "(SELECT table1.col1 AS col1, table1.col2 " + "AS col2, table1.col3 AS col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col1) " + "AS anon_1 JOIN table3 ON table2.col1 = " + "table3.col1", + ) def test_splice_2(self): t2a = table2.alias() t3a = table3.alias() - j1 = table1.join( - t2a, - table1.c.col1 == t2a.c.col1).join( - t3a, - t2a.c.col2 == t3a.c.col2) + j1 = table1.join(t2a, table1.c.col1 == t2a.c.col1).join( + t3a, t2a.c.col2 == t3a.c.col2 + ) t2b = table4.alias() j2 = table1.join(t2b, table1.c.col3 == t2b.c.col3) - self.assert_compile(sql_util.splice_joins(table1, j1), - 'table1 JOIN table2 AS table2_1 ON ' - 'table1.col1 = table2_1.col1 JOIN table3 ' - 'AS table3_1 ON table2_1.col2 = ' - 'table3_1.col2') - self.assert_compile(sql_util.splice_joins(table1, j2), - 'table1 JOIN table4 AS table4_1 ON ' - 'table1.col3 = table4_1.col3') - self.assert_compile( - sql_util.splice_joins( - sql_util.splice_joins( - table1, - j1), - j2), - 'table1 JOIN table2 AS table2_1 ON ' - 'table1.col1 = table2_1.col1 JOIN table3 ' - 'AS table3_1 ON table2_1.col2 = ' - 'table3_1.col2 JOIN table4 AS table4_1 ON ' - 'table1.col3 = table4_1.col3') + self.assert_compile( + sql_util.splice_joins(table1, j1), + "table1 JOIN table2 AS table2_1 ON " + "table1.col1 = table2_1.col1 JOIN table3 " + "AS table3_1 ON table2_1.col2 = " + "table3_1.col2", + ) + self.assert_compile( + sql_util.splice_joins(table1, j2), + "table1 JOIN table4 AS table4_1 ON " "table1.col3 = table4_1.col3", + ) + self.assert_compile( + sql_util.splice_joins(sql_util.splice_joins(table1, j1), j2), + "table1 JOIN table2 AS table2_1 ON " + "table1.col1 = table2_1.col1 JOIN table3 " + "AS table3_1 ON table2_1.col2 = " + "table3_1.col2 JOIN table4 AS table4_1 ON " + "table1.col3 = table4_1.col3", + ) class SelectTest(fixtures.TestBase, AssertsCompiledSQL): """tests the generative capability of Select""" - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): global t1, t2 - t1 = table("table1", - column("col1"), - column("col2"), - column("col3"), - ) - t2 = table("table2", - column("col1"), - column("col2"), - column("col3"), - ) + t1 = table("table1", column("col1"), column("col2"), column("col3")) + t2 = table("table2", column("col1"), column("col2"), column("col3")) def test_columns(self): s = t1.select() - self.assert_compile(s, - 'SELECT table1.col1, 
table1.col2, ' - 'table1.col3 FROM table1') - select_copy = s.column(column('yyy')) - self.assert_compile(select_copy, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3, yyy FROM table1') + self.assert_compile( + s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + ) + select_copy = s.column(column("yyy")) + self.assert_compile( + select_copy, + "SELECT table1.col1, table1.col2, " "table1.col3, yyy FROM table1", + ) assert s.columns is not select_copy.columns assert s._columns is not select_copy._columns assert s._raw_columns is not select_copy._raw_columns - self.assert_compile(s, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1') + self.assert_compile( + s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + ) def test_froms(self): s = t1.select() - self.assert_compile(s, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1') + self.assert_compile( + s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + ) select_copy = s.select_from(t2) - self.assert_compile(select_copy, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1, table2') + self.assert_compile( + select_copy, + "SELECT table1.col1, table1.col2, " + "table1.col3 FROM table1, table2", + ) assert s._froms is not select_copy._froms - self.assert_compile(s, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1') + self.assert_compile( + s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + ) def test_prefixes(self): s = t1.select() - self.assert_compile(s, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1') - select_copy = s.prefix_with('FOOBER') - self.assert_compile(select_copy, - 'SELECT FOOBER table1.col1, table1.col2, ' - 'table1.col3 FROM table1') - self.assert_compile(s, - 'SELECT table1.col1, table1.col2, ' - 'table1.col3 FROM table1') + self.assert_compile( + s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + ) + select_copy = s.prefix_with("FOOBER") + self.assert_compile( + select_copy, + "SELECT FOOBER table1.col1, table1.col2, " + "table1.col3 FROM table1", + ) + self.assert_compile( + s, "SELECT table1.col1, table1.col2, " "table1.col3 FROM table1" + ) def test_execution_options(self): - s = select().execution_options(foo='bar') - s2 = s.execution_options(bar='baz') - s3 = s.execution_options(foo='not bar') + s = select().execution_options(foo="bar") + s2 = s.execution_options(bar="baz") + s3 = s.execution_options(foo="not bar") # The original select should not be modified. - assert s._execution_options == dict(foo='bar') + assert s._execution_options == dict(foo="bar") # s2 should have its execution_options based on s, though. - assert s2._execution_options == dict(foo='bar', bar='baz') - assert s3._execution_options == dict(foo='not bar') + assert s2._execution_options == dict(foo="bar", bar="baz") + assert s3._execution_options == dict(foo="not bar") def test_invalid_options(self): assert_raises( - exc.ArgumentError, - select().execution_options, compiled_cache={} + exc.ArgumentError, select().execution_options, compiled_cache={} ) assert_raises( exc.ArgumentError, select().execution_options, - isolation_level='READ_COMMITTED' + isolation_level="READ_COMMITTED", ) # this feature not available yet def _NOTYET_test_execution_options_in_kwargs(self): - s = select(execution_options=dict(foo='bar')) - s2 = s.execution_options(bar='baz') + s = select(execution_options=dict(foo="bar")) + s2 = s.execution_options(bar="baz") # The original select should not be modified. 
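# (sketch, not part of the patch: the same generative contract holds
# for the other Select methods tested above, e.g. prefix_with() --
# each returns a modified copy and leaves the original untouched:
#
#     s = t1.select()
#     s2 = s.prefix_with("FOOBER")
#     assert s is not s2 and "FOOBER" not in str(s)
# )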
- assert s._execution_options == dict(foo='bar') + assert s._execution_options == dict(foo="bar") # s2 should have its execution_options based on s, though. - assert s2._execution_options == dict(foo='bar', bar='baz') + assert s2._execution_options == dict(foo="bar", bar="baz") # this feature not available yet def _NOTYET_test_execution_options_in_text(self): - s = text('select 42', execution_options=dict(foo='bar')) - assert s._execution_options == dict(foo='bar') + s = text("select 42", execution_options=dict(foo="bar")) + assert s._execution_options == dict(foo="bar") class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): """Tests the generative capability of Insert, Update""" - __dialect__ = 'default' + __dialect__ = "default" # fixme: consolidate converage from elsewhere here and expand @classmethod def setup_class(cls): global t1, t2 - t1 = table("table1", - column("col1"), - column("col2"), - column("col3"), - ) - t2 = table("table2", - column("col1"), - column("col2"), - column("col3"), - ) + t1 = table("table1", column("col1"), column("col2"), column("col3")) + t2 = table("table2", column("col1"), column("col2"), column("col3")) def test_prefixes(self): i = t1.insert() - self.assert_compile(i, - "INSERT INTO table1 (col1, col2, col3) " - "VALUES (:col1, :col2, :col3)") + self.assert_compile( + i, + "INSERT INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)", + ) gen = i.prefix_with("foober") - self.assert_compile(gen, - "INSERT foober INTO table1 (col1, col2, col3) " - "VALUES (:col1, :col2, :col3)") + self.assert_compile( + gen, + "INSERT foober INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)", + ) - self.assert_compile(i, - "INSERT INTO table1 (col1, col2, col3) " - "VALUES (:col1, :col2, :col3)") + self.assert_compile( + i, + "INSERT INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)", + ) - i2 = t1.insert(prefixes=['squiznart']) - self.assert_compile(i2, - "INSERT squiznart INTO table1 (col1, col2, col3) " - "VALUES (:col1, :col2, :col3)") + i2 = t1.insert(prefixes=["squiznart"]) + self.assert_compile( + i2, + "INSERT squiznart INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)", + ) gen2 = i2.prefix_with("quux") - self.assert_compile(gen2, - "INSERT squiznart quux INTO " - "table1 (col1, col2, col3) " - "VALUES (:col1, :col2, :col3)") + self.assert_compile( + gen2, + "INSERT squiznart quux INTO " + "table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)", + ) def test_add_kwarg(self): i = t1.insert() @@ -1857,11 +1876,13 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): i = t1.insert() eq_(i.parameters, None) i = i.values([(5, 6, 7), (8, 9, 10)]) - eq_(i.parameters, [ - {"col1": 5, "col2": 6, "col3": 7}, - {"col1": 8, "col2": 9, "col3": 10}, - ] - ) + eq_( + i.parameters, + [ + {"col1": 5, "col2": 6, "col3": 7}, + {"col1": 8, "col2": 9, "col3": 10}, + ], + ) def test_inline_values_single(self): i = t1.insert(values={"col1": 5}) @@ -1895,7 +1916,8 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.InvalidRequestError, "This construct already has multiple parameter sets.", - i.values, col2=7 + i.values, + col2=7, ) def test_cant_mix_single_multi_formats_dict_to_list(self): @@ -1904,7 +1926,8 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): exc.ArgumentError, "Can't mix single-values and multiple values " "formats in one statement", - i.values, [{"col1": 6}] + i.values, + [{"col1": 6}], ) def 
test_cant_mix_single_multi_formats_list_to_dict(self): @@ -1913,7 +1936,8 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): exc.ArgumentError, "Can't mix single-values and multiple values " "formats in one statement", - i.values, {"col1": 5} + i.values, + {"col1": 5}, ) def test_erroneous_multi_args_dicts(self): @@ -1922,7 +1946,9 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): exc.ArgumentError, "Only a single dictionary/tuple or list of " "dictionaries/tuples is accepted positionally.", - i.values, {"col1": 5}, {"col1": 7} + i.values, + {"col1": 5}, + {"col1": 7}, ) def test_erroneous_multi_args_tuples(self): @@ -1931,7 +1957,9 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): exc.ArgumentError, "Only a single dictionary/tuple or list of " "dictionaries/tuples is accepted positionally.", - i.values, (5, 6, 7), (8, 9, 10) + i.values, + (5, 6, 7), + (8, 9, 10), ) def test_erroneous_multi_args_plus_kw(self): @@ -1939,7 +1967,9 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.ArgumentError, "Can't pass kwargs and multiple parameter sets simultaneously", - i.values, [{"col1": 5}], col2=7 + i.values, + [{"col1": 5}], + col2=7, ) def test_update_no_support_multi_values(self): @@ -1947,12 +1977,14 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.InvalidRequestError, "This construct does not support multiple parameter sets.", - u.values, [{"col1": 5}, {"col1": 7}] + u.values, + [{"col1": 5}, {"col1": 7}], ) def test_update_no_support_multi_constructor(self): assert_raises_message( exc.InvalidRequestError, "This construct does not support multiple parameter sets.", - t1.update, values=[{"col1": 5}, {"col1": 7}] + t1.update, + values=[{"col1": 5}, {"col1": 7}], ) diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 729c420c00..3643deabdf 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -1,144 +1,167 @@ #! 
coding:utf-8 -from sqlalchemy import Column, Integer, MetaData, String, Table,\ - bindparam, exc, func, insert, select, column, text, table,\ - Sequence +from sqlalchemy import ( + Column, + Integer, + MetaData, + String, + Table, + bindparam, + exc, + func, + insert, + select, + column, + text, + table, + Sequence, +) from sqlalchemy.dialects import mysql, postgresql from sqlalchemy.engine import default -from sqlalchemy.testing import AssertsCompiledSQL,\ - assert_raises_message, fixtures, eq_, expect_warnings, assert_raises +from sqlalchemy.testing import ( + AssertsCompiledSQL, + assert_raises_message, + fixtures, + eq_, + expect_warnings, + assert_raises, +) from sqlalchemy.sql import crud class _InsertTestBase(object): - @classmethod def define_tables(cls, metadata): - Table('mytable', metadata, - Column('myid', Integer), - Column('name', String(30)), - Column('description', String(30))) - Table('myothertable', metadata, - Column('otherid', Integer, primary_key=True), - Column('othername', String(30))) - Table('table_w_defaults', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer, default=10), - Column('y', Integer, server_default=text('5')), - Column('z', Integer, default=lambda: 10)) + Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(30)), + Column("description", String(30)), + ) + Table( + "myothertable", + metadata, + Column("otherid", Integer, primary_key=True), + Column("othername", String(30)), + ) + Table( + "table_w_defaults", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer, default=10), + Column("y", Integer, server_default=text("5")), + Column("z", Integer, default=lambda: 10), + ) class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_binds_that_match_columns(self): """test bind params named after column names replace the normal SET/VALUES generation.""" - t = table('foo', column('x'), column('y')) + t = table("foo", column("x"), column("y")) - i = t.insert().values(x=3 + bindparam('x')) - self.assert_compile(i, - "INSERT INTO foo (x) VALUES ((:param_1 + :x))") + i = t.insert().values(x=3 + bindparam("x")) + self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x))") self.assert_compile( i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", - params={ - 'x': 1, - 'y': 2}) + params={"x": 1, "y": 2}, + ) - i = t.insert().values(x=bindparam('y')) + i = t.insert().values(x=bindparam("y")) self.assert_compile(i, "INSERT INTO foo (x) VALUES (:y)") - i = t.insert().values(x=bindparam('y'), y=5) + i = t.insert().values(x=bindparam("y"), y=5) assert_raises(exc.CompileError, i.compile) - i = t.insert().values(x=3 + bindparam('y'), y=5) + i = t.insert().values(x=3 + bindparam("y"), y=5) assert_raises(exc.CompileError, i.compile) - i = t.insert().values(x=3 + bindparam('x2')) - self.assert_compile(i, - "INSERT INTO foo (x) VALUES ((:param_1 + :x2))") + i = t.insert().values(x=3 + bindparam("x2")) + self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))") self.assert_compile( - i, - "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", - params={}) + i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", params={} + ) self.assert_compile( i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", - params={ - 'x': 1, - 'y': 2}) + params={"x": 1, "y": 2}, + ) self.assert_compile( i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", - params={ - 'x2': 1, - 'y': 2}) + params={"x2": 1, "y": 2}, 
+ ) def test_insert_literal_binds(self): table1 = self.tables.mytable - stmt = table1.insert().values(myid=3, name='jack') + stmt = table1.insert().values(myid=3, name="jack") self.assert_compile( stmt, "INSERT INTO mytable (myid, name) VALUES (3, 'jack')", - literal_binds=True) + literal_binds=True, + ) def test_insert_literal_binds_sequence_notimplemented(self): - table = Table('x', MetaData(), Column('y', Integer, Sequence('y_seq'))) + table = Table("x", MetaData(), Column("y", Integer, Sequence("y_seq"))) dialect = default.DefaultDialect() dialect.supports_sequences = True - stmt = table.insert().values(myid=3, name='jack') + stmt = table.insert().values(myid=3, name="jack") assert_raises( NotImplementedError, stmt.compile, - compile_kwargs=dict(literal_binds=True), dialect=dialect + compile_kwargs=dict(literal_binds=True), + dialect=dialect, ) def test_inline_defaults(self): m = MetaData() - foo = Table('foo', m, - Column('id', Integer)) + foo = Table("foo", m, Column("id", Integer)) - t = Table('test', m, - Column('col1', Integer, default=func.foo(1)), - Column('col2', Integer, default=select( - [func.coalesce(func.max(foo.c.id))])), - ) + t = Table( + "test", + m, + Column("col1", Integer, default=func.foo(1)), + Column( + "col2", + Integer, + default=select([func.coalesce(func.max(foo.c.id))]), + ), + ) self.assert_compile( - t.insert( - inline=True, values={}), + t.insert(inline=True, values={}), "INSERT INTO test (col1, col2) VALUES (foo(:foo_1), " "(SELECT coalesce(max(foo.id)) AS coalesce_1 FROM " - "foo))") + "foo))", + ) def test_generic_insert_bind_params_all_columns(self): table1 = self.tables.mytable - self.assert_compile(insert(table1), - 'INSERT INTO mytable (myid, name, description) ' - 'VALUES (:myid, :name, :description)') + self.assert_compile( + insert(table1), + "INSERT INTO mytable (myid, name, description) " + "VALUES (:myid, :name, :description)", + ) def test_insert_with_values_dict(self): table1 = self.tables.mytable - checkparams = { - 'myid': 3, - 'name': 'jack' - } + checkparams = {"myid": 3, "name": "jack"} self.assert_compile( - insert( - table1, - dict( - myid=3, - name='jack')), - 'INSERT INTO mytable (myid, name) VALUES (:myid, :name)', - checkparams=checkparams) + insert(table1, dict(myid=3, name="jack")), + "INSERT INTO mytable (myid, name) VALUES (:myid, :name)", + checkparams=checkparams, + ) def test_unconsumed_names_kwargs(self): t = table("t", column("x"), column("y")) @@ -151,136 +174,133 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): def test_bindparam_name_no_consume_error(self): t = table("t", column("x"), column("y")) # bindparam names don't get counted - i = t.insert().values(x=3 + bindparam('x2')) - self.assert_compile( - i, - "INSERT INTO t (x) VALUES ((:param_1 + :x2))" - ) + i = t.insert().values(x=3 + bindparam("x2")) + self.assert_compile(i, "INSERT INTO t (x) VALUES ((:param_1 + :x2))") # even if in the params list - i = t.insert().values(x=3 + bindparam('x2')) + i = t.insert().values(x=3 + bindparam("x2")) self.assert_compile( - i, - "INSERT INTO t (x) VALUES ((:param_1 + :x2))", - params={"x2": 1} + i, "INSERT INTO t (x) VALUES ((:param_1 + :x2))", params={"x2": 1} ) def test_unconsumed_names_values_dict(self): table1 = self.tables.mytable - checkparams = { - 'myid': 3, - 'name': 'jack', - 'unknowncol': 'oops' - } + checkparams = {"myid": 3, "name": "jack", "unknowncol": "oops"} stmt = insert(table1, values=checkparams) assert_raises_message( exc.CompileError, - 'Unconsumed column names: 
unknowncol', + "Unconsumed column names: unknowncol", stmt.compile, - dialect=postgresql.dialect() + dialect=postgresql.dialect(), ) def test_unconsumed_names_multi_values_dict(self): table1 = self.tables.mytable - checkparams = [{ - 'myid': 3, - 'name': 'jack', - 'unknowncol': 'oops' - }, { - 'myid': 4, - 'name': 'someone', - 'unknowncol': 'oops' - }] + checkparams = [ + {"myid": 3, "name": "jack", "unknowncol": "oops"}, + {"myid": 4, "name": "someone", "unknowncol": "oops"}, + ] stmt = insert(table1, values=checkparams) assert_raises_message( exc.CompileError, - 'Unconsumed column names: unknowncol', + "Unconsumed column names: unknowncol", stmt.compile, - dialect=postgresql.dialect() + dialect=postgresql.dialect(), ) def test_insert_with_values_tuple(self): table1 = self.tables.mytable checkparams = { - 'myid': 3, - 'name': 'jack', - 'description': 'mydescription' + "myid": 3, + "name": "jack", + "description": "mydescription", } - self.assert_compile(insert(table1, (3, 'jack', 'mydescription')), - 'INSERT INTO mytable (myid, name, description) ' - 'VALUES (:myid, :name, :description)', - checkparams=checkparams) + self.assert_compile( + insert(table1, (3, "jack", "mydescription")), + "INSERT INTO mytable (myid, name, description) " + "VALUES (:myid, :name, :description)", + checkparams=checkparams, + ) def test_insert_with_values_func(self): table1 = self.tables.mytable - self.assert_compile(insert(table1, values=dict(myid=func.lala())), - 'INSERT INTO mytable (myid) VALUES (lala())') + self.assert_compile( + insert(table1, values=dict(myid=func.lala())), + "INSERT INTO mytable (myid) VALUES (lala())", + ) def test_insert_with_user_supplied_bind_params(self): table1 = self.tables.mytable values = { - table1.c.myid: bindparam('userid'), - table1.c.name: bindparam('username') + table1.c.myid: bindparam("userid"), + table1.c.name: bindparam("username"), } self.assert_compile( - insert( - table1, - values), - 'INSERT INTO mytable (myid, name) VALUES (:userid, :username)') + insert(table1, values), + "INSERT INTO mytable (myid, name) VALUES (:userid, :username)", + ) def test_insert_values(self): table1 = self.tables.mytable - values1 = {table1.c.myid: bindparam('userid')} - values2 = {table1.c.name: bindparam('username')} + values1 = {table1.c.myid: bindparam("userid")} + values2 = {table1.c.name: bindparam("username")} self.assert_compile( - insert( - table1, - values=values1).values(values2), - 'INSERT INTO mytable (myid, name) VALUES (:userid, :username)') + insert(table1, values=values1).values(values2), + "INSERT INTO mytable (myid, name) VALUES (:userid, :username)", + ) def test_prefix_with(self): table1 = self.tables.mytable - stmt = table1.insert().\ - prefix_with('A', 'B', dialect='mysql').\ - prefix_with('C', 'D') + stmt = ( + table1.insert() + .prefix_with("A", "B", dialect="mysql") + .prefix_with("C", "D") + ) self.assert_compile( stmt, - 'INSERT C D INTO mytable (myid, name, description) ' - 'VALUES (:myid, :name, :description)') + "INSERT C D INTO mytable (myid, name, description) " + "VALUES (:myid, :name, :description)", + ) self.assert_compile( stmt, - 'INSERT A B C D INTO mytable (myid, name, description) ' - 'VALUES (%s, %s, %s)', - dialect=mysql.dialect()) + "INSERT A B C D INTO mytable (myid, name, description) " + "VALUES (%s, %s, %s)", + dialect=mysql.dialect(), + ) def test_inline_default(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=func.foobar())) + table = Table( + 
"sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer, default=func.foobar()), + ) - self.assert_compile(table.insert(values={}, inline=True), - 'INSERT INTO sometable (foo) VALUES (foobar())') + self.assert_compile( + table.insert(values={}, inline=True), + "INSERT INTO sometable (foo) VALUES (foobar())", + ) self.assert_compile( - table.insert( - inline=True), - 'INSERT INTO sometable (foo) VALUES (foobar())', - params={}) + table.insert(inline=True), + "INSERT INTO sometable (foo) VALUES (foobar())", + params={}, + ) def test_insert_returning_not_in_default(self): table1 = self.tables.mytable @@ -290,68 +310,75 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): exc.CompileError, "RETURNING is not supported by this dialect's statement compiler.", stmt.compile, - dialect=default.DefaultDialect() + dialect=default.DefaultDialect(), ) def test_insert_from_select_returning(self): table1 = self.tables.mytable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel).returning( - self.tables.myothertable.c.otherid - ) + table1.c.name == "foo" + ) + ins = ( + self.tables.myothertable.insert() + .from_select(("otherid", "othername"), sel) + .returning(self.tables.myothertable.c.otherid) + ) self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable " "WHERE mytable.name = %(name_1)s RETURNING myothertable.otherid", checkparams={"name_1": "foo"}, - dialect="postgresql" + dialect="postgresql", ) def test_insert_from_select_select(self): table1 = self.tables.mytable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel) + table1.c.name == "foo" + ) + ins = self.tables.myothertable.insert().from_select( + ("otherid", "othername"), sel + ) self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable " "WHERE mytable.name = :name_1", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_from_select_seq(self): m = MetaData() t1 = Table( - 't', m, - Column('id', Integer, Sequence('id_seq'), primary_key=True), - Column('data', String) + "t", + m, + Column("id", Integer, Sequence("id_seq"), primary_key=True), + Column("data", String), ) - stmt = t1.insert().from_select(('data', ), select([t1.c.data])) + stmt = t1.insert().from_select(("data",), select([t1.c.data])) self.assert_compile( stmt, "INSERT INTO t (data, id) SELECT t.data, " "nextval('id_seq') AS next_value_1 FROM t", - dialect=postgresql.dialect() + dialect=postgresql.dialect(), ) def test_insert_from_select_cte_one(self): table1 = self.tables.mytable - cte = select([table1.c.name]).where(table1.c.name == 'bar').cte() + cte = select([table1.c.name]).where(table1.c.name == "bar").cte() sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == cte.c.name) + table1.c.name == cte.c.name + ) - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel) + ins = self.tables.myothertable.insert().from_select( + ("otherid", "othername"), sel + ) self.assert_compile( ins, "WITH anon_1 AS " @@ -360,7 +387,7 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable, 
anon_1 " "WHERE mytable.name = anon_1.name", - checkparams={"name_1": "bar"} + checkparams={"name_1": "bar"}, ) def test_insert_from_select_cte_follows_insert_one(self): @@ -369,13 +396,15 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): table1 = self.tables.mytable - cte = select([table1.c.name]).where(table1.c.name == 'bar').cte() + cte = select([table1.c.name]).where(table1.c.name == "bar").cte() sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == cte.c.name) + table1.c.name == cte.c.name + ) - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel) + ins = self.tables.myothertable.insert().from_select( + ("otherid", "othername"), sel + ) self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " @@ -385,7 +414,7 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "SELECT mytable.myid, mytable.name FROM mytable, anon_1 " "WHERE mytable.name = anon_1.name", checkparams={"name_1": "bar"}, - dialect=dialect + dialect=dialect, ) def test_insert_from_select_cte_two(self): @@ -400,7 +429,7 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, " "mytable.description AS description FROM mytable) " "INSERT INTO mytable (myid, name, description) " - "SELECT c.myid, c.name, c.description FROM c" + "SELECT c.myid, c.name, c.description FROM c", ) def test_insert_from_select_cte_follows_insert_two(self): @@ -418,91 +447,101 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, " "mytable.description AS description FROM mytable) " "SELECT c.myid, c.name, c.description FROM c", - dialect=dialect + dialect=dialect, ) def test_insert_from_select_select_alt_ordering(self): table1 = self.tables.mytable sel = select([table1.c.name, table1.c.myid]).where( - table1.c.name == 'foo') - ins = self.tables.myothertable.insert().\ - from_select(("othername", "otherid"), sel) + table1.c.name == "foo" + ) + ins = self.tables.myothertable.insert().from_select( + ("othername", "otherid"), sel + ) self.assert_compile( ins, "INSERT INTO myothertable (othername, otherid) " "SELECT mytable.name, mytable.myid FROM mytable " "WHERE mytable.name = :name_1", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_from_select_no_defaults(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=func.foobar())) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer, default=func.foobar()), + ) table1 = self.tables.mytable - sel = select([table1.c.myid]).where(table1.c.name == 'foo') - ins = table.insert().\ - from_select(["id"], sel, include_defaults=False) + sel = select([table1.c.myid]).where(table1.c.name == "foo") + ins = table.insert().from_select(["id"], sel, include_defaults=False) self.assert_compile( ins, "INSERT INTO sometable (id) SELECT mytable.myid " "FROM mytable WHERE mytable.name = :name_1", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_from_select_with_sql_defaults(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=func.foobar())) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + 
Column("foo", Integer, default=func.foobar()), + ) table1 = self.tables.mytable - sel = select([table1.c.myid]).where(table1.c.name == 'foo') - ins = table.insert().\ - from_select(["id"], sel) + sel = select([table1.c.myid]).where(table1.c.name == "foo") + ins = table.insert().from_select(["id"], sel) self.assert_compile( ins, "INSERT INTO sometable (id, foo) SELECT " "mytable.myid, foobar() AS foobar_1 " "FROM mytable WHERE mytable.name = :name_1", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_from_select_with_python_defaults(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=12)) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer, default=12), + ) table1 = self.tables.mytable - sel = select([table1.c.myid]).where(table1.c.name == 'foo') - ins = table.insert().\ - from_select(["id"], sel) + sel = select([table1.c.myid]).where(table1.c.name == "foo") + ins = table.insert().from_select(["id"], sel) self.assert_compile( ins, "INSERT INTO sometable (id, foo) SELECT " "mytable.myid, :foo AS anon_1 " "FROM mytable WHERE mytable.name = :name_1", # value filled in at execution time - checkparams={"name_1": "foo", "foo": None} + checkparams={"name_1": "foo", "foo": None}, ) def test_insert_from_select_override_defaults(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=12)) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer, default=12), + ) table1 = self.tables.mytable - sel = select( - [table1.c.myid, table1.c.myid.label('q')]).where( - table1.c.name == 'foo') - ins = table.insert().\ - from_select(["id", "foo"], sel) + sel = select([table1.c.myid, table1.c.myid.label("q")]).where( + table1.c.name == "foo" + ) + ins = table.insert().from_select(["id", "foo"], sel) self.assert_compile( ins, "INSERT INTO sometable (id, foo) SELECT " "mytable.myid, mytable.myid AS q " "FROM mytable WHERE mytable.name = :name_1", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_from_select_fn_defaults(self): @@ -511,159 +550,171 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): def foo(ctx): return 12 - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('foo', Integer, default=foo)) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer, default=foo), + ) table1 = self.tables.mytable - sel = select( - [table1.c.myid]).where( - table1.c.name == 'foo') - ins = table.insert().\ - from_select(["id"], sel) + sel = select([table1.c.myid]).where(table1.c.name == "foo") + ins = table.insert().from_select(["id"], sel) self.assert_compile( ins, "INSERT INTO sometable (id, foo) SELECT " "mytable.myid, :foo AS anon_1 " "FROM mytable WHERE mytable.name = :name_1", # value filled in at execution time - checkparams={"name_1": "foo", "foo": None} + checkparams={"name_1": "foo", "foo": None}, ) def test_insert_from_select_dont_mutate_raw_columns(self): # test [ticket:3603] from sqlalchemy import table + table_ = table( - 'mytable', - Column('foo', String), - Column('bar', String, default='baz'), + "mytable", + Column("foo", String), + Column("bar", String, default="baz"), ) stmt = select([table_.c.foo]) - insert = 
table_.insert().from_select(['foo'], stmt) + insert = table_.insert().from_select(["foo"], stmt) self.assert_compile(stmt, "SELECT mytable.foo FROM mytable") self.assert_compile( insert, "INSERT INTO mytable (foo, bar) " - "SELECT mytable.foo, :bar AS anon_1 FROM mytable" + "SELECT mytable.foo, :bar AS anon_1 FROM mytable", ) self.assert_compile(stmt, "SELECT mytable.foo FROM mytable") self.assert_compile( insert, "INSERT INTO mytable (foo, bar) " - "SELECT mytable.foo, :bar AS anon_1 FROM mytable" + "SELECT mytable.foo, :bar AS anon_1 FROM mytable", ) def test_insert_mix_select_values_exception(self): table1 = self.tables.mytable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel) + table1.c.name == "foo" + ) + ins = self.tables.myothertable.insert().from_select( + ("otherid", "othername"), sel + ) assert_raises_message( exc.InvalidRequestError, "This construct already inserts from a SELECT", - ins.values, othername="5" + ins.values, + othername="5", ) def test_insert_mix_values_select_exception(self): table1 = self.tables.mytable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') + table1.c.name == "foo" + ) ins = self.tables.myothertable.insert().values(othername="5") assert_raises_message( exc.InvalidRequestError, "This construct already inserts value expressions", - ins.from_select, ("otherid", "othername"), sel + ins.from_select, + ("otherid", "othername"), + sel, ) def test_insert_from_select_table(self): table1 = self.tables.mytable - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), table1) + ins = self.tables.myothertable.insert().from_select( + ("otherid", "othername"), table1 + ) # note we aren't checking the number of columns right now self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable", - checkparams={} + checkparams={}, ) def test_insert_from_select_union(self): mytable = self.tables.mytable - name = column('name') - description = column('desc') - sel = select( - [name, mytable.c.description], - ).union( + name = column("name") + description = column("desc") + sel = select([name, mytable.c.description]).union( select([name, description]) ) - ins = mytable.insert().\ - from_select( - [mytable.c.name, mytable.c.description], sel) + ins = mytable.insert().from_select( + [mytable.c.name, mytable.c.description], sel + ) self.assert_compile( ins, "INSERT INTO mytable (name, description) " "SELECT name, mytable.description FROM mytable " - 'UNION SELECT name, "desc"' + 'UNION SELECT name, "desc"', ) def test_insert_from_select_col_values(self): table1 = self.tables.mytable table2 = self.tables.myothertable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') - ins = table2.insert().\ - from_select((table2.c.otherid, table2.c.othername), sel) + table1.c.name == "foo" + ) + ins = table2.insert().from_select( + (table2.c.otherid, table2.c.othername), sel + ) self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable " "WHERE mytable.name = :name_1", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_anticipate_no_pk_composite_pk(self): t = Table( - 't', MetaData(), Column('x', Integer, primary_key=True), - Column('y', Integer, primary_key=True) + "t", + MetaData(), + Column("x", Integer, primary_key=True), + 
Column("y", Integer, primary_key=True), ) with expect_warnings( "Column 't.y' is marked as a member.*" - "Note that as of SQLAlchemy 1.1,", + "Note that as of SQLAlchemy 1.1," ): self.assert_compile( - t.insert(), - "INSERT INTO t (x) VALUES (:x)", - params={'x': 5}, + t.insert(), "INSERT INTO t (x) VALUES (:x)", params={"x": 5} ) def test_anticipate_no_pk_composite_pk_implicit_returning(self): t = Table( - 't', MetaData(), Column('x', Integer, primary_key=True), - Column('y', Integer, primary_key=True) + "t", + MetaData(), + Column("x", Integer, primary_key=True), + Column("y", Integer, primary_key=True), ) d = postgresql.dialect() d.implicit_returning = True with expect_warnings( "Column 't.y' is marked as a member.*" - "Note that as of SQLAlchemy 1.1,", + "Note that as of SQLAlchemy 1.1," ): self.assert_compile( t.insert(), "INSERT INTO t (x) VALUES (%(x)s)", params={"x": 5}, - dialect=d + dialect=d, ) def test_anticipate_no_pk_composite_pk_prefetch(self): t = Table( - 't', MetaData(), Column('x', Integer, primary_key=True), - Column('y', Integer, primary_key=True) + "t", + MetaData(), + Column("x", Integer, primary_key=True), + Column("y", Integer, primary_key=True), ) d = postgresql.dialect() d.implicit_returning = False @@ -674,202 +725,193 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): self.assert_compile( t.insert(), "INSERT INTO t (x) VALUES (%(x)s)", - params={'x': 5}, - dialect=d + params={"x": 5}, + dialect=d, ) def test_anticipate_nullable_composite_pk(self): t = Table( - 't', MetaData(), Column('x', Integer, primary_key=True), - Column('y', Integer, primary_key=True, nullable=True) + "t", + MetaData(), + Column("x", Integer, primary_key=True), + Column("y", Integer, primary_key=True, nullable=True), ) self.assert_compile( - t.insert(), - "INSERT INTO t (x) VALUES (:x)", - params={'x': 5}, + t.insert(), "INSERT INTO t (x) VALUES (:x)", params={"x": 5} ) def test_anticipate_no_pk_non_composite_pk(self): t = Table( - 't', MetaData(), - Column('x', Integer, primary_key=True, autoincrement=False), - Column('q', Integer) + "t", + MetaData(), + Column("x", Integer, primary_key=True, autoincrement=False), + Column("q", Integer), ) with expect_warnings( - "Column 't.x' is marked as a member.*" - "may not store NULL.$" + "Column 't.x' is marked as a member.*" "may not store NULL.$" ): self.assert_compile( - t.insert(), - "INSERT INTO t (q) VALUES (:q)", - params={"q": 5} + t.insert(), "INSERT INTO t (q) VALUES (:q)", params={"q": 5} ) def test_anticipate_no_pk_non_composite_pk_implicit_returning(self): t = Table( - 't', MetaData(), - Column('x', Integer, primary_key=True, autoincrement=False), - Column('q', Integer) + "t", + MetaData(), + Column("x", Integer, primary_key=True, autoincrement=False), + Column("q", Integer), ) d = postgresql.dialect() d.implicit_returning = True with expect_warnings( - "Column 't.x' is marked as a member.*" - "may not store NULL.$", + "Column 't.x' is marked as a member.*" "may not store NULL.$" ): self.assert_compile( t.insert(), "INSERT INTO t (q) VALUES (%(q)s)", params={"q": 5}, - dialect=d + dialect=d, ) def test_anticipate_no_pk_non_composite_pk_prefetch(self): t = Table( - 't', MetaData(), - Column('x', Integer, primary_key=True, autoincrement=False), - Column('q', Integer) + "t", + MetaData(), + Column("x", Integer, primary_key=True, autoincrement=False), + Column("q", Integer), ) d = postgresql.dialect() d.implicit_returning = False with expect_warnings( - "Column 't.x' is marked as a member.*" - "may not store 
NULL.$" + "Column 't.x' is marked as a member.*" "may not store NULL.$" ): self.assert_compile( t.insert(), "INSERT INTO t (q) VALUES (%(q)s)", params={"q": 5}, - dialect=d + dialect=d, ) def test_anticipate_no_pk_lower_case_table(self): t = table( - 't', - Column( - 'id', Integer, primary_key=True, autoincrement=False), - Column('notpk', String(10), nullable=True) + "t", + Column("id", Integer, primary_key=True, autoincrement=False), + Column("notpk", String(10), nullable=True), ) with expect_warnings( - "Column 't.id' is marked as a member.*" - "may not store NULL.$" + "Column 't.id' is marked as a member.*" "may not store NULL.$" ): self.assert_compile( - t.insert(), - "INSERT INTO t () VALUES ()", - params={} + t.insert(), "INSERT INTO t () VALUES ()", params={} ) class InsertImplicitReturningTest( - _InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): + _InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL +): __dialect__ = postgresql.dialect(implicit_returning=True) def test_insert_select(self): table1 = self.tables.mytable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel) + table1.c.name == "foo" + ) + ins = self.tables.myothertable.insert().from_select( + ("otherid", "othername"), sel + ) self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable " "WHERE mytable.name = %(name_1)s", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_select_return_defaults(self): table1 = self.tables.mytable sel = select([table1.c.myid, table1.c.name]).where( - table1.c.name == 'foo') - ins = self.tables.myothertable.insert().\ - from_select(("otherid", "othername"), sel).\ - return_defaults(self.tables.myothertable.c.otherid) + table1.c.name == "foo" + ) + ins = ( + self.tables.myothertable.insert() + .from_select(("otherid", "othername"), sel) + .return_defaults(self.tables.myothertable.c.otherid) + ) self.assert_compile( ins, "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable " "WHERE mytable.name = %(name_1)s", - checkparams={"name_1": "foo"} + checkparams={"name_1": "foo"}, ) def test_insert_multiple_values(self): - ins = self.tables.myothertable.insert().values([ - {"othername": "foo"}, - {"othername": "bar"}, - ]) + ins = self.tables.myothertable.insert().values( + [{"othername": "foo"}, {"othername": "bar"}] + ) self.assert_compile( ins, "INSERT INTO myothertable (othername) " "VALUES (%(othername_m0)s), " "(%(othername_m1)s)", - checkparams={ - 'othername_m1': 'bar', - 'othername_m0': 'foo'} + checkparams={"othername_m1": "bar", "othername_m0": "foo"}, ) def test_insert_multiple_values_literal_binds(self): - ins = self.tables.myothertable.insert().values([ - {"othername": "foo"}, - {"othername": "bar"}, - ]) + ins = self.tables.myothertable.insert().values( + [{"othername": "foo"}, {"othername": "bar"}] + ) self.assert_compile( ins, "INSERT INTO myothertable (othername) VALUES ('foo'), ('bar')", checkparams={}, - literal_binds=True + literal_binds=True, ) def test_insert_multiple_values_return_defaults(self): # TODO: not sure if this should raise an # error or what - ins = self.tables.myothertable.insert().values([ - {"othername": "foo"}, - {"othername": "bar"}, - ]).return_defaults(self.tables.myothertable.c.otherid) + ins = ( + self.tables.myothertable.insert() + .values([{"othername": "foo"}, {"othername": "bar"}]) + 
.return_defaults(self.tables.myothertable.c.otherid) + ) self.assert_compile( ins, "INSERT INTO myothertable (othername) " "VALUES (%(othername_m0)s), " "(%(othername_m1)s)", - checkparams={ - 'othername_m1': 'bar', - 'othername_m0': 'foo'} + checkparams={"othername_m1": "bar", "othername_m0": "foo"}, ) def test_insert_single_list_values(self): - ins = self.tables.myothertable.insert().values([ - {"othername": "foo"}, - ]) + ins = self.tables.myothertable.insert().values([{"othername": "foo"}]) self.assert_compile( ins, "INSERT INTO myothertable (othername) " "VALUES (%(othername_m0)s)", - checkparams={'othername_m0': 'foo'} + checkparams={"othername_m0": "foo"}, ) def test_insert_single_element_values(self): - ins = self.tables.myothertable.insert().values( - {"othername": "foo"}, - ) + ins = self.tables.myothertable.insert().values({"othername": "foo"}) self.assert_compile( ins, "INSERT INTO myothertable (othername) " "VALUES (%(othername)s) RETURNING myothertable.otherid", - checkparams={'othername': 'foo'} + checkparams={"othername": "foo"}, ) class EmptyTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_empty_insert_default(self): table1 = self.tables.mytable stmt = table1.insert().values({}) # hide from 2to3 - self.assert_compile(stmt, 'INSERT INTO mytable () VALUES ()') + self.assert_compile(stmt, "INSERT INTO mytable () VALUES ()") def test_supports_empty_insert_true(self): table1 = self.tables.mytable @@ -878,9 +920,9 @@ class EmptyTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect.supports_empty_insert = dialect.supports_default_values = True stmt = table1.insert().values({}) # hide from 2to3 - self.assert_compile(stmt, - 'INSERT INTO mytable DEFAULT VALUES', - dialect=dialect) + self.assert_compile( + stmt, "INSERT INTO mytable DEFAULT VALUES", dialect=dialect + ) def test_supports_empty_insert_false(self): table1 = self.tables.mytable @@ -894,21 +936,24 @@ class EmptyTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "The 'default' dialect with current database version " "settings does not support empty inserts.", stmt.compile, - dialect=dialect) + dialect=dialect, + ) def _test_insert_with_empty_collection_values(self, collection): table1 = self.tables.mytable ins = table1.insert().values(collection) - self.assert_compile(ins, - 'INSERT INTO mytable () VALUES ()', - checkparams={}) + self.assert_compile( + ins, "INSERT INTO mytable () VALUES ()", checkparams={} + ) # empty dict populates on next values call - self.assert_compile(ins.values(myid=3), - 'INSERT INTO mytable (myid) VALUES (:myid)', - checkparams={'myid': 3}) + self.assert_compile( + ins.values(myid=3), + "INSERT INTO mytable (myid) VALUES (:myid)", + checkparams={"myid": 3}, + ) def test_insert_with_empty_list_values(self): self._test_insert_with_empty_collection_values([]) @@ -921,38 +966,40 @@ class EmptyTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_not_supported(self): table1 = self.tables.mytable dialect = default.DefaultDialect() - stmt = table1.insert().values([{'myid': 1}, {'myid': 2}]) + stmt = table1.insert().values([{"myid": 1}, {"myid": 2}]) assert_raises_message( exc.CompileError, "The 'default' dialect with current database version settings " "does not support in-place multirow inserts.", - stmt.compile, dialect=dialect) + stmt.compile, + 
dialect=dialect, + ) def test_named(self): table1 = self.tables.mytable values = [ - {'myid': 1, 'name': 'a', 'description': 'b'}, - {'myid': 2, 'name': 'c', 'description': 'd'}, - {'myid': 3, 'name': 'e', 'description': 'f'} + {"myid": 1, "name": "a", "description": "b"}, + {"myid": 2, "name": "c", "description": "d"}, + {"myid": 3, "name": "e", "description": "f"}, ] checkparams = { - 'myid_m0': 1, - 'myid_m1': 2, - 'myid_m2': 3, - 'name_m0': 'a', - 'name_m1': 'c', - 'name_m2': 'e', - 'description_m0': 'b', - 'description_m1': 'd', - 'description_m2': 'f', + "myid_m0": 1, + "myid_m1": 2, + "myid_m2": 3, + "name_m0": "a", + "name_m1": "c", + "name_m2": "e", + "description_m0": "b", + "description_m1": "d", + "description_m2": "f", } dialect = default.DefaultDialect() @@ -960,32 +1007,33 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): self.assert_compile( table1.insert().values(values), - 'INSERT INTO mytable (myid, name, description) VALUES ' - '(:myid_m0, :name_m0, :description_m0), ' - '(:myid_m1, :name_m1, :description_m1), ' - '(:myid_m2, :name_m2, :description_m2)', + "INSERT INTO mytable (myid, name, description) VALUES " + "(:myid_m0, :name_m0, :description_m0), " + "(:myid_m1, :name_m1, :description_m1), " + "(:myid_m2, :name_m2, :description_m2)", checkparams=checkparams, - dialect=dialect) + dialect=dialect, + ) def test_named_with_column_objects(self): table1 = self.tables.mytable values = [ - {table1.c.myid: 1, table1.c.name: 'a', table1.c.description: 'b'}, - {table1.c.myid: 2, table1.c.name: 'c', table1.c.description: 'd'}, - {table1.c.myid: 3, table1.c.name: 'e', table1.c.description: 'f'}, + {table1.c.myid: 1, table1.c.name: "a", table1.c.description: "b"}, + {table1.c.myid: 2, table1.c.name: "c", table1.c.description: "d"}, + {table1.c.myid: 3, table1.c.name: "e", table1.c.description: "f"}, ] checkparams = { - 'myid_m0': 1, - 'myid_m1': 2, - 'myid_m2': 3, - 'name_m0': 'a', - 'name_m1': 'c', - 'name_m2': 'e', - 'description_m0': 'b', - 'description_m1': 'd', - 'description_m2': 'f', + "myid_m0": 1, + "myid_m1": 2, + "myid_m2": 3, + "name_m0": "a", + "name_m1": "c", + "name_m2": "e", + "description_m0": "b", + "description_m1": "d", + "description_m2": "f", } dialect = default.DefaultDialect() @@ -993,50 +1041,48 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): self.assert_compile( table1.insert().values(values), - 'INSERT INTO mytable (myid, name, description) VALUES ' - '(:myid_m0, :name_m0, :description_m0), ' - '(:myid_m1, :name_m1, :description_m1), ' - '(:myid_m2, :name_m2, :description_m2)', + "INSERT INTO mytable (myid, name, description) VALUES " + "(:myid_m0, :name_m0, :description_m0), " + "(:myid_m1, :name_m1, :description_m1), " + "(:myid_m2, :name_m2, :description_m2)", checkparams=checkparams, - dialect=dialect) + dialect=dialect, + ) def test_positional(self): table1 = self.tables.mytable values = [ - {'myid': 1, 'name': 'a', 'description': 'b'}, - {'myid': 2, 'name': 'c', 'description': 'd'}, - {'myid': 3, 'name': 'e', 'description': 'f'} + {"myid": 1, "name": "a", "description": "b"}, + {"myid": 2, "name": "c", "description": "d"}, + {"myid": 3, "name": "e", "description": "f"}, ] - checkpositional = (1, 'a', 'b', 2, 'c', 'd', 3, 'e', 'f') + checkpositional = (1, "a", "b", 2, "c", "d", 3, "e", "f") dialect = default.DefaultDialect() dialect.supports_multivalues_insert = True - dialect.paramstyle = 'format' + dialect.paramstyle = "format" dialect.positional = True self.assert_compile( 
table1.insert().values(values), - 'INSERT INTO mytable (myid, name, description) VALUES ' - '(%s, %s, %s), (%s, %s, %s), (%s, %s, %s)', + "INSERT INTO mytable (myid, name, description) VALUES " + "(%s, %s, %s), (%s, %s, %s), (%s, %s, %s)", checkpositional=checkpositional, - dialect=dialect) + dialect=dialect, + ) def test_positional_w_defaults(self): table1 = self.tables.table_w_defaults - values = [ - {'id': 1}, - {'id': 2}, - {'id': 3} - ] + values = [{"id": 1}, {"id": 2}, {"id": 3}] checkpositional = (1, None, None, 2, None, None, 3, None, None) dialect = default.DefaultDialect() dialect.supports_multivalues_insert = True - dialect.paramstyle = 'format' + dialect.paramstyle = "format" dialect.positional = True self.assert_compile( @@ -1045,128 +1091,163 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "(%s, %s, %s), (%s, %s, %s), (%s, %s, %s)", checkpositional=checkpositional, check_prefetch=[ - table1.c.x, table1.c.z, + table1.c.x, + table1.c.z, crud._multiparam_column(table1.c.x, 0), crud._multiparam_column(table1.c.z, 0), crud._multiparam_column(table1.c.x, 1), - crud._multiparam_column(table1.c.z, 1) + crud._multiparam_column(table1.c.z, 1), ], - dialect=dialect) + dialect=dialect, + ) def test_inline_default(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, default=func.foobar())) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("foo", Integer, default=func.foobar()), + ) values = [ - {'id': 1, 'data': 'data1'}, - {'id': 2, 'data': 'data2', 'foo': 'plainfoo'}, - {'id': 3, 'data': 'data3'}, + {"id": 1, "data": "data1"}, + {"id": 2, "data": "data2", "foo": "plainfoo"}, + {"id": 3, "data": "data3"}, ] checkparams = { - 'id_m0': 1, - 'id_m1': 2, - 'id_m2': 3, - 'data_m0': 'data1', - 'data_m1': 'data2', - 'data_m2': 'data3', - 'foo_m1': 'plainfoo', + "id_m0": 1, + "id_m1": 2, + "id_m2": 3, + "data_m0": "data1", + "data_m1": "data2", + "data_m2": "data3", + "foo_m1": "plainfoo", } self.assert_compile( table.insert().values(values), - 'INSERT INTO sometable (id, data, foo) VALUES ' - '(%(id_m0)s, %(data_m0)s, foobar()), ' - '(%(id_m1)s, %(data_m1)s, %(foo_m1)s), ' - '(%(id_m2)s, %(data_m2)s, foobar())', + "INSERT INTO sometable (id, data, foo) VALUES " + "(%(id_m0)s, %(data_m0)s, foobar()), " + "(%(id_m1)s, %(data_m1)s, %(foo_m1)s), " + "(%(id_m2)s, %(data_m2)s, foobar())", checkparams=checkparams, - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_python_scalar_default(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, default=10)) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("foo", Integer, default=10), + ) values = [ - {'id': 1, 'data': 'data1'}, - {'id': 2, 'data': 'data2', 'foo': 15}, - {'id': 3, 'data': 'data3'}, + {"id": 1, "data": "data1"}, + {"id": 2, "data": "data2", "foo": 15}, + {"id": 3, "data": "data3"}, ] checkparams = { - 'id_m0': 1, - 'id_m1': 2, - 'id_m2': 3, - 'data_m0': 'data1', - 'data_m1': 'data2', - 'data_m2': 'data3', - 'foo': None, # evaluated later - 'foo_m1': 15, - 'foo_m2': None # evaluated later + "id_m0": 1, + "id_m1": 2, + "id_m2": 3, + "data_m0": "data1", + "data_m1": "data2", + "data_m2": "data3", + "foo": None, # evaluated later + 
"foo_m1": 15, + "foo_m2": None, # evaluated later } stmt = table.insert().values(values) eq_( - dict([ - (k, v.type._type_affinity) - for (k, v) in - stmt.compile(dialect=postgresql.dialect()).binds.items()]), + dict( + [ + (k, v.type._type_affinity) + for (k, v) in stmt.compile( + dialect=postgresql.dialect() + ).binds.items() + ] + ), { - 'foo': Integer, 'data_m2': String, 'id_m0': Integer, - 'id_m2': Integer, 'foo_m1': Integer, 'data_m1': String, - 'id_m1': Integer, 'foo_m2': Integer, 'data_m0': String} + "foo": Integer, + "data_m2": String, + "id_m0": Integer, + "id_m2": Integer, + "foo_m1": Integer, + "data_m1": String, + "id_m1": Integer, + "foo_m2": Integer, + "data_m0": String, + }, ) self.assert_compile( stmt, - 'INSERT INTO sometable (id, data, foo) VALUES ' - '(%(id_m0)s, %(data_m0)s, %(foo)s), ' - '(%(id_m1)s, %(data_m1)s, %(foo_m1)s), ' - '(%(id_m2)s, %(data_m2)s, %(foo_m2)s)', + "INSERT INTO sometable (id, data, foo) VALUES " + "(%(id_m0)s, %(data_m0)s, %(foo)s), " + "(%(id_m1)s, %(data_m1)s, %(foo_m1)s), " + "(%(id_m2)s, %(data_m2)s, %(foo_m2)s)", checkparams=checkparams, - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_python_fn_default(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, default=lambda: 10)) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("foo", Integer, default=lambda: 10), + ) values = [ - {'id': 1, 'data': 'data1'}, - {'id': 2, 'data': 'data2', 'foo': 15}, - {'id': 3, 'data': 'data3'}, + {"id": 1, "data": "data1"}, + {"id": 2, "data": "data2", "foo": 15}, + {"id": 3, "data": "data3"}, ] checkparams = { - 'id_m0': 1, - 'id_m1': 2, - 'id_m2': 3, - 'data_m0': 'data1', - 'data_m1': 'data2', - 'data_m2': 'data3', - 'foo': None, # evaluated later - 'foo_m1': 15, - 'foo_m2': None, # evaluated later + "id_m0": 1, + "id_m1": 2, + "id_m2": 3, + "data_m0": "data1", + "data_m1": "data2", + "data_m2": "data3", + "foo": None, # evaluated later + "foo_m1": 15, + "foo_m2": None, # evaluated later } stmt = table.insert().values(values) eq_( - dict([ - (k, v.type._type_affinity) - for (k, v) in - stmt.compile(dialect=postgresql.dialect()).binds.items()]), + dict( + [ + (k, v.type._type_affinity) + for (k, v) in stmt.compile( + dialect=postgresql.dialect() + ).binds.items() + ] + ), { - 'foo': Integer, 'data_m2': String, 'id_m0': Integer, - 'id_m2': Integer, 'foo_m1': Integer, 'data_m1': String, - 'id_m1': Integer, 'foo_m2': Integer, 'data_m0': String} + "foo": Integer, + "data_m2": String, + "id_m0": Integer, + "id_m2": Integer, + "foo_m1": Integer, + "data_m1": String, + "id_m1": Integer, + "foo_m2": Integer, + "data_m0": String, + }, ) self.assert_compile( @@ -1176,14 +1257,18 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "(%(id_m1)s, %(data_m1)s, %(foo_m1)s), " "(%(id_m2)s, %(data_m2)s, %(foo_m2)s)", checkparams=checkparams, - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_sql_functions(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer)) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("foo", Integer), + ) values = [ {"id": 1, "data": "foo", "foo": func.foob()}, @@ -1193,21 +1278,17 @@ class 
MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): {"id": 5, "data": "bar", "foo": func.foob()}, ] checkparams = { - 'id_m0': 1, - 'data_m0': 'foo', - - 'id_m1': 2, - 'data_m1': 'bar', - - 'id_m2': 3, - 'data_m2': 'bar', - - 'id_m3': 4, - 'data_m3': 'bar', - 'foo_m3': 15, - - 'id_m4': 5, - 'data_m4': 'bar' + "id_m0": 1, + "data_m0": "foo", + "id_m1": 2, + "data_m1": "bar", + "id_m2": 3, + "data_m2": "bar", + "id_m3": 4, + "data_m3": "bar", + "foo_m3": 15, + "id_m4": 5, + "data_m4": "bar", } self.assert_compile( @@ -1219,50 +1300,58 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "(%(id_m3)s, %(data_m3)s, %(foo_m3)s), " "(%(id_m4)s, %(data_m4)s, foob())", checkparams=checkparams, - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_server_default(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, server_default=func.foobar())) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("foo", Integer, server_default=func.foobar()), + ) values = [ - {'id': 1, 'data': 'data1'}, - {'id': 2, 'data': 'data2', 'foo': 'plainfoo'}, - {'id': 3, 'data': 'data3'}, + {"id": 1, "data": "data1"}, + {"id": 2, "data": "data2", "foo": "plainfoo"}, + {"id": 3, "data": "data3"}, ] checkparams = { - 'id_m0': 1, - 'id_m1': 2, - 'id_m2': 3, - 'data_m0': 'data1', - 'data_m1': 'data2', - 'data_m2': 'data3', + "id_m0": 1, + "id_m1": 2, + "id_m2": 3, + "data_m0": "data1", + "data_m1": "data2", + "data_m2": "data3", } self.assert_compile( table.insert().values(values), - 'INSERT INTO sometable (id, data) VALUES ' - '(%(id_m0)s, %(data_m0)s), ' - '(%(id_m1)s, %(data_m1)s), ' - '(%(id_m2)s, %(data_m2)s)', + "INSERT INTO sometable (id, data) VALUES " + "(%(id_m0)s, %(data_m0)s), " + "(%(id_m1)s, %(data_m1)s), " + "(%(id_m2)s, %(data_m2)s)", checkparams=checkparams, - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_server_default_absent_value(self): metadata = MetaData() - table = Table('sometable', metadata, - Column('id', Integer, primary_key=True), - Column('data', String), - Column('foo', Integer, server_default=func.foobar())) + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("foo", Integer, server_default=func.foobar()), + ) values = [ - {'id': 1, 'data': 'data1', 'foo': 'plainfoo'}, - {'id': 2, 'data': 'data2'}, - {'id': 3, 'data': 'data3', 'foo': 'otherfoo'}, + {"id": 1, "data": "data1", "foo": "plainfoo"}, + {"id": 2, "data": "data2"}, + {"id": 3, "data": "data3", "foo": "otherfoo"}, ] assert_raises_message( @@ -1270,5 +1359,5 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "INSERT value for column sometable.foo is explicitly rendered " "as a boundparameter in the VALUES clause; a Python-side value or " "SQL expression is required", - table.insert().values(values).compile + table.insert().values(values).compile, ) diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 502ef69122..7803de75e7 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -2,8 +2,18 @@ from sqlalchemy.testing import eq_, assert_raises_message, is_ from sqlalchemy import testing from sqlalchemy.testing import fixtures, engines from sqlalchemy import ( - exc, sql, String, Integer, MetaData, and_, ForeignKey, - 
VARCHAR, INT, Sequence, func) + exc, + sql, + String, + Integer, + MetaData, + and_, + ForeignKey, + VARCHAR, + INT, + Sequence, + func, +) from sqlalchemy.testing.schema import Table, Column @@ -13,12 +23,13 @@ class InsertExecTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table( - 'users', metadata, + "users", + metadata, Column( - 'user_id', INT, primary_key=True, - test_needs_autoincrement=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True + "user_id", INT, primary_key=True, test_needs_autoincrement=True + ), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, ) @testing.requires.multivalues_inserts @@ -26,15 +37,17 @@ class InsertExecTest(fixtures.TablesTest): users = self.tables.users users.insert( values=[ - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}]).execute() + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + ] + ).execute() rows = users.select().order_by(users.c.user_id).execute().fetchall() - eq_(rows[0], (7, 'jack')) - eq_(rows[1], (8, 'ed')) - users.insert(values=[(9, 'jack'), (10, 'ed')]).execute() + eq_(rows[0], (7, "jack")) + eq_(rows[1], (8, "ed")) + users.insert(values=[(9, "jack"), (10, "ed")]).execute() rows = users.select().order_by(users.c.user_id).execute().fetchall() - eq_(rows[2], (9, 'jack')) - eq_(rows[3], (10, 'ed')) + eq_(rows[2], (9, "jack")) + eq_(rows[3], (10, "ed")) def test_insert_heterogeneous_params(self): """test that executemany parameters are asserted to match the @@ -48,17 +61,15 @@ class InsertExecTest(fixtures.TablesTest): "parameter group 2 " r"\[SQL: u?'INSERT INTO users", users.insert().execute, - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9} + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9}, ) # this succeeds however. We aren't yet doing # a length check on all subsequent parameters. 
users.insert().execute( - {'user_id': 7}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9} + {"user_id": 7}, {"user_id": 8, "user_name": "ed"}, {"user_id": 9} ) def _test_lastrow_accessor(self, table_, values, assertvalues): @@ -76,33 +87,39 @@ class InsertExecTest(fixtures.TablesTest): ins = table_.insert() comp = ins.compile(engine, column_keys=list(values)) if not set(values).issuperset( - c.key for c in table_.primary_key): + c.key for c in table_.primary_key + ): is_(bool(comp.returning), True) result = engine.execute(table_.insert(), **values) ret = values.copy() for col, id in zip( - table_.primary_key, result.inserted_primary_key): + table_.primary_key, result.inserted_primary_key + ): ret[col.key] = id if result.lastrow_has_defaults(): criterion = and_( *[ - col == id for col, id in - zip(table_.primary_key, result.inserted_primary_key)]) + col == id + for col, id in zip( + table_.primary_key, result.inserted_primary_key + ) + ] + ) row = engine.execute(table_.select(criterion)).first() for c in table_.c: ret[c.key] = row[c] return ret - if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): + if testing.against("firebird", "postgresql", "oracle", "mssql"): assert testing.db.dialect.implicit_returning if testing.db.dialect.implicit_returning: test_engines = [ - engines.testing_engine(options={'implicit_returning': False}), - engines.testing_engine(options={'implicit_returning': True}), + engines.testing_engine(options={"implicit_returning": False}), + engines.testing_engine(options={"implicit_returning": True}), ] else: test_engines = [testing.db] @@ -115,47 +132,57 @@ class InsertExecTest(fixtures.TablesTest): finally: table_.drop(bind=engine) - @testing.skip_if('sqlite') + @testing.skip_if("sqlite") def test_lastrow_accessor_one(self): metadata = MetaData() self._test_lastrow_accessor( Table( - "t1", metadata, + "t1", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('foo', String(30), primary_key=True)), - {'foo': 'hi'}, - {'id': 1, 'foo': 'hi'} + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("foo", String(30), primary_key=True), + ), + {"foo": "hi"}, + {"id": 1, "foo": "hi"}, ) - @testing.skip_if('sqlite') + @testing.skip_if("sqlite") def test_lastrow_accessor_two(self): metadata = MetaData() self._test_lastrow_accessor( Table( - "t2", metadata, + "t2", + metadata, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("foo", String(30), primary_key=True), + Column("bar", String(30), server_default="hi"), ), - {'foo': 'hi'}, - {'id': 1, 'foo': 'hi', 'bar': 'hi'} + {"foo": "hi"}, + {"id": 1, "foo": "hi", "bar": "hi"}, ) def test_lastrow_accessor_three(self): metadata = MetaData() self._test_lastrow_accessor( Table( - "t3", metadata, + "t3", + metadata, Column("id", String(40), primary_key=True), - Column('foo', String(30), primary_key=True), - Column("bar", String(30)) + Column("foo", String(30), primary_key=True), + Column("bar", String(30)), ), - {'id': 'hi', 'foo': 'thisisfoo', 'bar': "thisisbar"}, - {'id': 'hi', 'foo': 'thisisfoo', 'bar': "thisisbar"} + {"id": "hi", "foo": "thisisfoo", "bar": "thisisbar"}, + {"id": "hi", "foo": "thisisfoo", "bar": "thisisbar"}, ) @testing.requires.sequences @@ -163,84 +190,105 @@ class InsertExecTest(fixtures.TablesTest): metadata = 
MetaData()
         self._test_lastrow_accessor(
             Table(
-                "t4", metadata,
+                "t4",
+                metadata,
                 Column(
-                    'id', Integer,
-                    Sequence('t4_id_seq', optional=True),
-                    primary_key=True),
-                Column('foo', String(30), primary_key=True),
-                Column('bar', String(30), server_default='hi')
+                    "id",
+                    Integer,
+                    Sequence("t4_id_seq", optional=True),
+                    primary_key=True,
+                ),
+                Column("foo", String(30), primary_key=True),
+                Column("bar", String(30), server_default="hi"),
             ),
-            {'foo': 'hi', 'id': 1},
-            {'id': 1, 'foo': 'hi', 'bar': 'hi'}
+            {"foo": "hi", "id": 1},
+            {"id": 1, "foo": "hi", "bar": "hi"},
         )

     def test_lastrow_accessor_five(self):
         metadata = MetaData()
         self._test_lastrow_accessor(
             Table(
-                "t5", metadata,
-                Column('id', String(10), primary_key=True),
-                Column('bar', String(30), server_default='hi')
+                "t5",
+                metadata,
+                Column("id", String(10), primary_key=True),
+                Column("bar", String(30), server_default="hi"),
             ),
-            {'id': 'id1'},
-            {'id': 'id1', 'bar': 'hi'},
+            {"id": "id1"},
+            {"id": "id1", "bar": "hi"},
         )

-    @testing.skip_if('sqlite')
+    @testing.skip_if("sqlite")
     def test_lastrow_accessor_six(self):
         metadata = MetaData()
         self._test_lastrow_accessor(
             Table(
-                "t6", metadata,
+                "t6",
+                metadata,
                 Column(
-                    'id', Integer, primary_key=True,
-                    test_needs_autoincrement=True),
-                Column('bar', Integer, primary_key=True)
+                    "id",
+                    Integer,
+                    primary_key=True,
+                    test_needs_autoincrement=True,
+                ),
+                Column("bar", Integer, primary_key=True),
             ),
-            {'bar': 0},
-            {'id': 1, 'bar': 0},
+            {"bar": 0},
+            {"id": 1, "bar": 0},
         )

     # TODO: why not in the sqlite suite?
-    @testing.only_on('sqlite+pysqlite')
+    @testing.only_on("sqlite+pysqlite")
     @testing.provide_metadata
     def test_lastrowid_zero(self):
         from sqlalchemy.dialects import sqlite
+
         eng = engines.testing_engine()

         class ExcCtx(sqlite.base.SQLiteExecutionContext):
-
             def get_lastrowid(self):
                 return 0
+
         eng.dialect.execution_ctx_cls = ExcCtx
         t = Table(
-            't', self.metadata, Column('x', Integer, primary_key=True),
-            Column('y', Integer))
+            "t",
+            self.metadata,
+            Column("x", Integer, primary_key=True),
+            Column("y", Integer),
+        )
         t.create(eng)
         r = eng.execute(t.insert().values(y=5))
         eq_(r.inserted_primary_key, [0])

     @testing.fails_on(
-        'sqlite', "sqlite autoincrement doesn't work with composite pks")
+        "sqlite", "sqlite autoincrement doesn't work with composite pks"
+    )
     @testing.provide_metadata
     def test_misordered_lastrow(self):
         metadata = self.metadata

         related = Table(
-            'related', metadata,
-            Column('id', Integer, primary_key=True),
-            mysql_engine='MyISAM'
+            "related",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="MyISAM",
         )
         t6 = Table(
-            "t6", metadata,
+            "t6",
+            metadata,
             Column(
-                'manual_id', Integer, ForeignKey('related.id'),
-                primary_key=True),
+                "manual_id",
+                Integer,
+                ForeignKey("related.id"),
+                primary_key=True,
+            ),
             Column(
-                'auto_id', Integer, primary_key=True,
-                test_needs_autoincrement=True),
-            mysql_engine='MyISAM'
+                "auto_id",
+                Integer,
+                primary_key=True,
+                test_needs_autoincrement=True,
+            ),
+            mysql_engine="MyISAM",
         )

         metadata.create_all()
@@ -255,7 +303,8 @@ class InsertExecTest(fixtures.TablesTest):
         users = self.tables.users
         stmt = users.insert().from_select(
             (users.c.user_id, users.c.user_name),
-            users.select().where(users.c.user_id == 20))
+            users.select().where(users.c.user_id == 20),
+        )

         testing.db.execute(stmt)

@@ -263,7 +312,8 @@ class InsertExecTest(fixtures.TablesTest):
         users = self.tables.users
         stmt = users.insert().from_select(
             ["user_id", "user_name"],
-            users.select().where(users.c.user_id == 20))
+            
users.select().where(users.c.user_id == 20), + ) testing.db.execute(stmt) @@ -271,12 +321,15 @@ class InsertExecTest(fixtures.TablesTest): @testing.requires.returning def test_no_inserted_pk_on_returning(self): users = self.tables.users - result = testing.db.execute(users.insert().returning( - users.c.user_id, users.c.user_name)) + result = testing.db.execute( + users.insert().returning(users.c.user_id, users.c.user_name) + ) assert_raises_message( exc.InvalidRequestError, r"Can't call inserted_primary_key when returning\(\) is used.", - getattr, result, 'inserted_primary_key' + getattr, + result, + "inserted_primary_key", ) @@ -286,27 +339,32 @@ class TableInsertTest(fixtures.TablesTest): regarding the inline=True flag, lower-case 't' tables. """ - run_create_tables = 'each' + + run_create_tables = "each" __backend__ = True @classmethod def define_tables(cls, metadata): Table( - 'foo', metadata, - Column('id', Integer, Sequence('t_id_seq'), primary_key=True), - Column('data', String(50)), - Column('x', Integer) + "foo", + metadata, + Column("id", Integer, Sequence("t_id_seq"), primary_key=True), + Column("data", String(50)), + Column("x", Integer), ) def _fixture(self, types=True): if types: t = sql.table( - 'foo', sql.column('id', Integer), - sql.column('data', String), - sql.column('x', Integer)) + "foo", + sql.column("id", Integer), + sql.column("data", String), + sql.column("x", Integer), + ) else: t = sql.table( - 'foo', sql.column('id'), sql.column('data'), sql.column('x')) + "foo", sql.column("id"), sql.column("data"), sql.column("x") + ) return t def _test(self, stmt, row, returning=None, inserted_primary_key=False): @@ -324,99 +382,104 @@ class TableInsertTest(fixtures.TablesTest): testing.db.execute(stmt, rows) eq_( testing.db.execute( - self.tables.foo.select(). 
- order_by(self.tables.foo.c.id)).fetchall(), - data) + self.tables.foo.select().order_by(self.tables.foo.c.id) + ).fetchall(), + data, + ) @testing.requires.sequences def test_explicit_sequence(self): t = self._fixture() self._test( t.insert().values( - id=func.next_value(Sequence('t_id_seq')), data='data', x=5), - (1, 'data', 5) + id=func.next_value(Sequence("t_id_seq")), data="data", x=5 + ), + (1, "data", 5), ) def test_uppercase(self): t = self.tables.foo self._test( - t.insert().values(id=1, data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[1] + t.insert().values(id=1, data="data", x=5), + (1, "data", 5), + inserted_primary_key=[1], ) def test_uppercase_inline(self): t = self.tables.foo self._test( - t.insert(inline=True).values(id=1, data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[1] + t.insert(inline=True).values(id=1, data="data", x=5), + (1, "data", 5), + inserted_primary_key=[1], ) @testing.crashes( "mssql+pyodbc", - "Pyodbc + SQL Server + Py3K, some decimal handling issue") + "Pyodbc + SQL Server + Py3K, some decimal handling issue", + ) def test_uppercase_inline_implicit(self): t = self.tables.foo self._test( - t.insert(inline=True).values(data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[None] + t.insert(inline=True).values(data="data", x=5), + (1, "data", 5), + inserted_primary_key=[None], ) def test_uppercase_implicit(self): t = self.tables.foo self._test( - t.insert().values(data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[1] + t.insert().values(data="data", x=5), + (1, "data", 5), + inserted_primary_key=[1], ) def test_uppercase_direct_params(self): t = self.tables.foo self._test( - t.insert().values(id=1, data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[1] + t.insert().values(id=1, data="data", x=5), + (1, "data", 5), + inserted_primary_key=[1], ) @testing.requires.returning def test_uppercase_direct_params_returning(self): t = self.tables.foo self._test( - t.insert().values(id=1, data='data', x=5).returning(t.c.id, t.c.x), - (1, 'data', 5), - returning=(1, 5) + t.insert().values(id=1, data="data", x=5).returning(t.c.id, t.c.x), + (1, "data", 5), + returning=(1, 5), ) @testing.fails_on( - 'mssql', "lowercase table doesn't support identity insert disable") + "mssql", "lowercase table doesn't support identity insert disable" + ) def test_direct_params(self): t = self._fixture() self._test( - t.insert().values(id=1, data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[] + t.insert().values(id=1, data="data", x=5), + (1, "data", 5), + inserted_primary_key=[], ) @testing.fails_on( - 'mssql', "lowercase table doesn't support identity insert disable") + "mssql", "lowercase table doesn't support identity insert disable" + ) @testing.requires.returning def test_direct_params_returning(self): t = self._fixture() self._test( - t.insert().values(id=1, data='data', x=5).returning(t.c.id, t.c.x), - (1, 'data', 5), - returning=(1, 5) + t.insert().values(id=1, data="data", x=5).returning(t.c.id, t.c.x), + (1, "data", 5), + returning=(1, 5), ) @testing.requires.emulated_lastrowid def test_implicit_pk(self): t = self._fixture() self._test( - t.insert().values(data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[] + t.insert().values(data="data", x=5), + (1, "data", 5), + inserted_primary_key=[], ) @testing.requires.emulated_lastrowid @@ -425,22 +488,18 @@ class TableInsertTest(fixtures.TablesTest): self._test_multi( t.insert(), [ - {'data': 'd1', 'x': 5}, - {'data': 'd2', 'x': 6}, - {'data': 'd3', 'x': 7}, - ], - [ 
- (1, 'd1', 5), - (2, 'd2', 6), - (3, 'd3', 7) + {"data": "d1", "x": 5}, + {"data": "d2", "x": 6}, + {"data": "d3", "x": 7}, ], + [(1, "d1", 5), (2, "d2", 6), (3, "d3", 7)], ) @testing.requires.emulated_lastrowid def test_implicit_pk_inline(self): t = self._fixture() self._test( - t.insert(inline=True).values(data='data', x=5), - (1, 'data', 5), - inserted_primary_key=[] + t.insert(inline=True).values(data="data", x=5), + (1, "data", 5), + inserted_primary_key=[], ) diff --git a/test/sql/test_inspect.py b/test/sql/test_inspect.py index 7178bc58ac..0e78c06c81 100644 --- a/test/sql/test_inspect.py +++ b/test/sql/test_inspect.py @@ -7,20 +7,15 @@ from sqlalchemy.testing import is_ class TestCoreInspection(fixtures.TestBase): - def test_table(self): - t = Table('t', MetaData(), - Column('x', Integer) - ) + t = Table("t", MetaData(), Column("x", Integer)) is_(inspect(t), t) assert t.is_selectable is_(t.selectable, t) def test_select(self): - t = Table('t', MetaData(), - Column('x', Integer) - ) + t = Table("t", MetaData(), Column("x", Integer)) s = t.select() is_(inspect(s), s) @@ -28,10 +23,10 @@ class TestCoreInspection(fixtures.TestBase): is_(s.selectable, s) def test_column_expr(self): - c = Column('x', Integer) + c = Column("x", Integer) is_(inspect(c), c) assert not c.is_selectable - assert not hasattr(c, 'selectable') + assert not hasattr(c, "selectable") def test_no_clause_element_on_clauseelement(self): # re [ticket:3802], there are in the wild examples @@ -39,5 +34,5 @@ class TestCoreInspection(fixtures.TestBase): # absence of __clause_element__ as a test for "this is the clause # element" must be maintained - x = Column('foo', Integer) - assert not hasattr(x, '__clause_element__') + x = Column("foo", Integer) + assert not hasattr(x, "__clause_element__") diff --git a/test/sql/test_join_rewriting.py b/test/sql/test_join_rewriting.py index c699a5c973..dd74406fd7 100644 --- a/test/sql/test_join_rewriting.py +++ b/test/sql/test_join_rewriting.py @@ -3,8 +3,16 @@ to support SQLite's lack of right-nested joins. SQlite as of version 3.7.16 no longer has this limitation. 
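When a dialect reports supports_right_nested_joins=False, the compiler
wraps the right-nested join in an anonymous subquery; the anon_1 forms
asserted in JoinRewriteTest below show the rewritten SQL, while
JoinPlainTest asserts the plain nested form.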
""" -from sqlalchemy import Table, Column, Integer, MetaData, ForeignKey, \ - select, exists, union +from sqlalchemy import ( + Table, + Column, + Integer, + MetaData, + ForeignKey, + select, + exists, + union, +) from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import util from sqlalchemy.engine import default @@ -14,70 +22,74 @@ from sqlalchemy import testing m = MetaData() -a = Table('a', m, - Column('id', Integer, primary_key=True) - ) - -b = Table('b', m, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id')) - ) - -b_a = Table('b_a', m, - Column('id', Integer, primary_key=True), - ) - -b1 = Table('b1', m, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id')) - ) - -b2 = Table('b2', m, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('a.id')) - ) - -a_to_b = Table('a_to_b', m, - Column('a_id', Integer, ForeignKey('a.id')), - Column('b_id', Integer, ForeignKey('b.id')), - ) - -c = Table('c', m, - Column('id', Integer, primary_key=True), - Column('b_id', Integer, ForeignKey('b.id')) - ) - -d = Table('d', m, - Column('id', Integer, primary_key=True), - Column('c_id', Integer, ForeignKey('c.id')) - ) - -e = Table('e', m, - Column('id', Integer, primary_key=True) - ) - -f = Table('f', m, - Column('id', Integer, primary_key=True), - Column('a_id', ForeignKey('a.id')) - ) - -b_key = Table('b_key', m, - Column('id', Integer, primary_key=True, key='bid'), - ) - -a_to_b_key = Table('a_to_b_key', m, - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b_key.bid')), - ) +a = Table("a", m, Column("id", Integer, primary_key=True)) + +b = Table( + "b", + m, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), +) + +b_a = Table("b_a", m, Column("id", Integer, primary_key=True)) + +b1 = Table( + "b1", + m, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), +) + +b2 = Table( + "b2", + m, + Column("id", Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.id")), +) + +a_to_b = Table( + "a_to_b", + m, + Column("a_id", Integer, ForeignKey("a.id")), + Column("b_id", Integer, ForeignKey("b.id")), +) + +c = Table( + "c", + m, + Column("id", Integer, primary_key=True), + Column("b_id", Integer, ForeignKey("b.id")), +) + +d = Table( + "d", + m, + Column("id", Integer, primary_key=True), + Column("c_id", Integer, ForeignKey("c.id")), +) + +e = Table("e", m, Column("id", Integer, primary_key=True)) + +f = Table( + "f", + m, + Column("id", Integer, primary_key=True), + Column("a_id", ForeignKey("a.id")), +) + +b_key = Table("b_key", m, Column("id", Integer, primary_key=True, key="bid")) + +a_to_b_key = Table( + "a_to_b_key", + m, + Column("aid", Integer, ForeignKey("a.id")), + Column("bid", Integer, ForeignKey("b_key.bid")), +) class _JoinRewriteTestBase(AssertsCompiledSQL): - def _test(self, s, assert_): - self.assert_compile( - s, - assert_ - ) + self.assert_compile(s, assert_) compiled = s.compile(dialect=self.__dialect__) @@ -107,10 +119,13 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): # TODO: do this test also with individual cols, things change # lots based on how you go with this - s = select([a, b, c], use_labels=True).\ - select_from(j2).\ - where(b.c.id == 2).\ - where(c.c.id == 3).order_by(a.c.id, b.c.id, c.c.id) + s = ( + select([a, b, c], use_labels=True) + .select_from(j2) + .where(b.c.id == 2) + .where(c.c.id == 3) + .order_by(a.c.id, b.c.id, c.c.id) 
+ ) self._test(s, self._a_bc) @@ -118,8 +133,7 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): j1 = b_key.join(a_to_b_key) j2 = a.join(j1) - s = select([a, b_key.c.bid], use_labels=True).\ - select_from(j2) + s = select([a, b_key.c.bid], use_labels=True).select_from(j2) self._test(s, self._a_bkeyassoc) @@ -130,8 +144,7 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): j1 = bkey_alias.join(a_to_b_key_alias) j2 = a.join(j1) - s = select([a, bkey_alias.c.bid], use_labels=True).\ - select_from(j2) + s = select([a, bkey_alias.c.bid], use_labels=True).select_from(j2) self._test(s, self._a_bkeyassoc_aliased) @@ -140,18 +153,17 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): j2 = b.join(j1) j3 = a.join(j2) - s = select([a, b, c, d], use_labels=True).\ - select_from(j3).\ - where(b.c.id == 2).\ - where(c.c.id == 3).\ - where(d.c.id == 4).\ - order_by(a.c.id, b.c.id, c.c.id, d.c.id) - - self._test( - s, - self._a__b_dc + s = ( + select([a, b, c, d], use_labels=True) + .select_from(j3) + .where(b.c.id == 2) + .where(c.c.id == 3) + .where(d.c.id == 4) + .order_by(a.c.id, b.c.id, c.c.id, d.c.id) ) + self._test(s, self._a__b_dc) + def test_a_bc_comma_a1_selbc(self): # test here we're emulating is # test.orm.inheritance.test_polymorphic_rel: @@ -162,14 +174,15 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): a_a = a.alias() j4 = a_a.join(j2) - s = select([a, a_a, b, c, j2], use_labels=True).\ - select_from(j3).select_from(j4).order_by(j2.c.b_id) - - self._test( - s, - self._a_bc_comma_a1_selbc + s = ( + select([a, a_a, b, c, j2], use_labels=True) + .select_from(j3) + .select_from(j4) + .order_by(j2.c.b_id) ) + self._test(s, self._a_bc_comma_a1_selbc) + def test_a_atobalias_balias_c_w_exists(self): a_to_b_alias = a_to_b.alias() b_alias = b.alias() @@ -179,19 +192,22 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): # TODO: if we put straight a_to_b_alias here, # it fails to alias the columns clause. - s = select([a, - a_to_b_alias.c.a_id, - a_to_b_alias.c.b_id, - b_alias.c.id, - b_alias.c.a_id, - exists().select_from(c). 
- where(c.c.b_id == b_alias.c.id).label(None)], - use_labels=True).select_from(j2) - - self._test( - s, - self._a_atobalias_balias_c_w_exists - ) + s = select( + [ + a, + a_to_b_alias.c.a_id, + a_to_b_alias.c.b_id, + b_alias.c.id, + b_alias.c.a_id, + exists() + .select_from(c) + .where(c.c.b_id == b_alias.c.id) + .label(None), + ], + use_labels=True, + ).select_from(j2) + + self._test(s, self._a_atobalias_balias_c_w_exists) def test_a_atobalias_balias(self): a_to_b_alias = a_to_b.alias() @@ -202,10 +218,7 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): s = select([a, a_to_b_alias, b_alias], use_labels=True).select_from(j2) - self._test( - s, - self._a_atobalias_balias - ) + self._test(s, self._a_atobalias_balias) def test_b_ab1_union_b_ab2(self): j1 = a.join(b1) @@ -215,14 +228,10 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): b_j2 = b.join(j2) s = union( - select([b_j1], use_labels=True), - select([b_j2], use_labels=True) + select([b_j1], use_labels=True), select([b_j2], use_labels=True) ).select(use_labels=True) - self._test( - s, - self._b_ab1_union_c_ab2 - ) + self._test(s, self._b_ab1_union_c_ab2) def test_b_a_id_double_overlap_annotated(self): # test issue #3057 @@ -231,17 +240,14 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): annot = [ b.c.id._annotate({}), b.c.a_id._annotate({}), - b_a.c.id._annotate({}) + b_a.c.id._annotate({}), ] s = select(annot).select_from(j1).apply_labels().alias() s = select(list(s.c)).apply_labels() - self._test( - s, - self._b_a_id_double_overlap_annotated - ) + self._test(s, self._b_a_id_double_overlap_annotated) def test_f_b1a_where_in_b2a(self): # test issue #3130 @@ -251,20 +257,14 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): s = select([f]).select_from(f.join(b1a)).where(b1.c.id.in_(subq)) s = s.apply_labels() - self._test( - s, - self._f_b1a_where_in_b2a - ) + self._test(s, self._f_b1a_where_in_b2a) def test_anon_scalar_subqueries(self): s1 = select([1]).as_scalar() s2 = select([2]).as_scalar() s = select([s1, s2]).apply_labels() - self._test( - s, - self._anon_scalar_subqueries - ) + self._test(s, self._anon_scalar_subqueries) class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase): @@ -350,7 +350,8 @@ class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase): "FROM (SELECT a_to_b_key.aid AS aid, a_to_b_key.bid AS bid " "FROM a_to_b_key) AS anon_1 " "JOIN b_key ON b_key.id = anon_1.bid) AS anon_2 " - "ON a.id = anon_2.anon_1_aid") + "ON a.id = anon_2.anon_1_aid" + ) _a_atobalias_balias_c_w_exists = ( "SELECT a.id AS a_id, " @@ -363,7 +364,8 @@ class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase): "b_1.a_id AS b_1_a_id " "FROM a_to_b AS a_to_b_1 " "JOIN b AS b_1 ON b_1.id = a_to_b_1.b_id) AS anon_1 " - "ON a.id = anon_1.a_to_b_1_a_id") + "ON a.id = anon_1.a_to_b_1_a_id" + ) _a_atobalias_balias = ( "SELECT a.id AS a_id, anon_1.a_to_b_1_a_id AS a_to_b_1_a_id, " @@ -373,7 +375,8 @@ class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase): "a_to_b_1.b_id AS a_to_b_1_b_id, " "b_1.id AS b_1_id, b_1.a_id AS b_1_a_id FROM a_to_b AS a_to_b_1 " "JOIN b AS b_1 ON b_1.id = a_to_b_1.b_id) AS anon_1 " - "ON a.id = anon_1.a_to_b_1_a_id") + "ON a.id = anon_1.a_to_b_1_a_id" + ) _b_ab1_union_c_ab2 = ( "SELECT b_id AS b_id, b_a_id AS b_a_id, a_id AS a_id, b1_id AS b1_id, " @@ -412,6 +415,7 @@ class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase): class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase): """test rendering of each join with normal nesting.""" + @util.classproperty def __dialect__(cls): dialect = 
default.DefaultDialect() @@ -480,7 +484,8 @@ class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase): "EXISTS (SELECT * FROM c WHERE c.b_id = b_1.id) AS anon_1 " "FROM a LEFT OUTER JOIN " "(a_to_b AS a_to_b_1 JOIN b AS b_1 ON b_1.id = a_to_b_1.b_id) " - "ON a.id = a_to_b_1.a_id") + "ON a.id = a_to_b_1.a_id" + ) _a_atobalias_balias = ( "SELECT a.id AS a_id, a_to_b_1.a_id AS a_to_b_1_a_id, " @@ -500,7 +505,8 @@ class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase): "UNION " "SELECT b.id AS b_id, b.a_id AS b_a_id, a.id AS a_id, b2.id AS b2_id, " "b2.a_id AS b2_a_id FROM b " - "JOIN (a JOIN b2 ON a.id = b2.a_id) ON a.id = b.a_id)") + "JOIN (a JOIN b2 ON a.id = b2.a_id) ON a.id = b.a_id)" + ) _b_a_id_double_overlap_annotated = ( "SELECT anon_1.b_id AS anon_1_b_id, anon_1.b_a_id AS anon_1_b_a_id, " @@ -522,7 +528,6 @@ class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase): class JoinNoUseLabelsTest(_JoinRewriteTestBase, fixtures.TestBase): - @util.classproperty def __dialect__(cls): dialect = default.DefaultDialect() @@ -531,10 +536,7 @@ class JoinNoUseLabelsTest(_JoinRewriteTestBase, fixtures.TestBase): def _test(self, s, assert_): s.use_labels = False - self.assert_compile( - s, - assert_ - ) + self.assert_compile(s, assert_) _a_bkeyselect_bkey = ( "SELECT a.id, b_key.id FROM a JOIN ((SELECT a_to_b_key.aid AS aid, " @@ -639,11 +641,23 @@ class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase): __backend__ = True - _a_bc = _a_bc_comma_a1_selbc = _a__b_dc = _a_bkeyassoc = \ - _a_bkeyassoc_aliased = _a_atobalias_balias_c_w_exists = \ - _a_atobalias_balias = _b_ab1_union_c_ab2 = \ - _b_a_id_double_overlap_annotated = _f_b1a_where_in_b2a = \ - _anon_scalar_subqueries = None + _a_bc = ( + _a_bc_comma_a1_selbc + ) = ( + _a__b_dc + ) = ( + _a_bkeyassoc + ) = ( + _a_bkeyassoc_aliased + ) = ( + _a_atobalias_balias_c_w_exists + ) = ( + _a_atobalias_balias + ) = ( + _b_ab1_union_c_ab2 + ) = ( + _b_a_id_double_overlap_annotated + ) = _f_b1a_where_in_b2a = _anon_scalar_subqueries = None @classmethod def setup_class(cls): @@ -660,8 +674,9 @@ class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase): assert col in result._metadata._keymap @testing.skip_if("oracle", "oracle's cranky") - @testing.skip_if("mssql", "can't query EXISTS in the columns " - "clause w/o subquery") + @testing.skip_if( + "mssql", "can't query EXISTS in the columns " "clause w/o subquery" + ) def test_a_atobalias_balias_c_w_exists(self): super(JoinExecTest, self).test_a_atobalias_balias_c_w_exists() @@ -669,13 +684,13 @@ class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase): "sqlite", "non-standard aliasing rules used at the moment, " "possibly fix this or add another test that uses " - "cross-compatible aliasing") + "cross-compatible aliasing", + ) def test_b_ab1_union_b_ab2(self): super(JoinExecTest, self).test_b_ab1_union_b_ab2() class DialectFlagTest(fixtures.TestBase, AssertsCompiledSQL): - def test_dialect_flag(self): d1 = default.DefaultDialect(supports_right_nested_joins=True) d2 = default.DefaultDialect(supports_right_nested_joins=False) @@ -683,8 +698,7 @@ class DialectFlagTest(fixtures.TestBase, AssertsCompiledSQL): j1 = b.join(c) j2 = a.join(j1) - s = select([a, b, c], use_labels=True).\ - select_from(j2) + s = select([a, b, c], use_labels=True).select_from(j2) self.assert_compile( s, @@ -692,12 +706,16 @@ class DialectFlagTest(fixtures.TestBase, AssertsCompiledSQL): "c.id AS c_id, " "c.b_id AS c_b_id FROM a JOIN (b JOIN c ON b.id = c.b_id) " "ON a.id = b.a_id", - dialect=d1) + dialect=d1, + ) 
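        # under d2 (supports_right_nested_joins=False) the same statement
        # compiles with the right-nested "b JOIN c" wrapped in an anonymous
        # subquery, as the expected string below asserts: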
self.assert_compile( - s, "SELECT a.id AS a_id, anon_1.b_id AS b_id, " + s, + "SELECT a.id AS a_id, anon_1.b_id AS b_id, " "anon_1.b_a_id AS b_a_id, " "anon_1.c_id AS c_id, anon_1.c_b_id AS c_b_id " "FROM a JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, " "c.id AS c_id, " "c.b_id AS c_b_id FROM b JOIN c ON b.id = c.b_id) AS anon_1 " - "ON a.id = anon_1.b_a_id", dialect=d2) + "ON a.id = anon_1.b_a_id", + dialect=d2, + ) diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index 0b279754f7..8dfaa68599 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -1,33 +1,46 @@ -from sqlalchemy import exc as exceptions, select, MetaData, Integer, or_, \ - bindparam +from sqlalchemy import ( + exc as exceptions, + select, + MetaData, + Integer, + or_, + bindparam, +) from sqlalchemy.engine import default from sqlalchemy.sql import table, column from sqlalchemy.sql.elements import _truncated_label -from sqlalchemy.testing import AssertsCompiledSQL, assert_raises, engines,\ - fixtures, eq_ +from sqlalchemy.testing import ( + AssertsCompiledSQL, + assert_raises, + engines, + fixtures, + eq_, +) from sqlalchemy.testing.schema import Table, Column IDENT_LENGTH = 29 class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'DefaultDialect' + __dialect__ = "DefaultDialect" - table1 = table('some_large_named_table', - column('this_is_the_primarykey_column'), - column('this_is_the_data_column') - ) + table1 = table( + "some_large_named_table", + column("this_is_the_primarykey_column"), + column("this_is_the_data_column"), + ) - table2 = table('table_with_exactly_29_characs', - column('this_is_the_primarykey_column'), - column('this_is_the_data_column') - ) + table2 = table( + "table_with_exactly_29_characs", + column("this_is_the_primarykey_column"), + column("this_is_the_data_column"), + ) def _length_fixture(self, length=IDENT_LENGTH, positional=False): dialect = default.DefaultDialect() dialect.max_identifier_length = length if positional: - dialect.paramstyle = 'format' + dialect.paramstyle = "format" dialect.positional = True return dialect @@ -39,14 +52,14 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): def test_table_alias_1(self): self.assert_compile( self.table2.alias().select(), - 'SELECT ' - 'table_with_exactly_29_c_1.' - 'this_is_the_primarykey_column, ' - 'table_with_exactly_29_c_1.this_is_the_data_column ' - 'FROM ' - 'table_with_exactly_29_characs ' - 'AS table_with_exactly_29_c_1', - dialect=self._length_fixture() + "SELECT " + "table_with_exactly_29_c_1." + "this_is_the_primarykey_column, " + "table_with_exactly_29_c_1.this_is_the_data_column " + "FROM " + "table_with_exactly_29_characs " + "AS table_with_exactly_29_c_1", + dialect=self._length_fixture(), ) def test_table_alias_2(self): @@ -55,32 +68,36 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): ta = table2.alias() on = table1.c.this_is_the_data_column == ta.c.this_is_the_data_column self.assert_compile( - select([table1, ta]).select_from(table1.join(ta, on)). 
- where(ta.c.this_is_the_data_column == 'data3'), - 'SELECT ' - 'some_large_named_table.this_is_the_primarykey_column, ' - 'some_large_named_table.this_is_the_data_column, ' - 'table_with_exactly_29_c_1.this_is_the_primarykey_column, ' - 'table_with_exactly_29_c_1.this_is_the_data_column ' - 'FROM ' - 'some_large_named_table ' - 'JOIN ' - 'table_with_exactly_29_characs ' - 'AS ' - 'table_with_exactly_29_c_1 ' - 'ON ' - 'some_large_named_table.this_is_the_data_column = ' - 'table_with_exactly_29_c_1.this_is_the_data_column ' - 'WHERE ' - 'table_with_exactly_29_c_1.this_is_the_data_column = ' - ':this_is_the_data_column_1', - dialect=self._length_fixture() + select([table1, ta]) + .select_from(table1.join(ta, on)) + .where(ta.c.this_is_the_data_column == "data3"), + "SELECT " + "some_large_named_table.this_is_the_primarykey_column, " + "some_large_named_table.this_is_the_data_column, " + "table_with_exactly_29_c_1.this_is_the_primarykey_column, " + "table_with_exactly_29_c_1.this_is_the_data_column " + "FROM " + "some_large_named_table " + "JOIN " + "table_with_exactly_29_characs " + "AS " + "table_with_exactly_29_c_1 " + "ON " + "some_large_named_table.this_is_the_data_column = " + "table_with_exactly_29_c_1.this_is_the_data_column " + "WHERE " + "table_with_exactly_29_c_1.this_is_the_data_column = " + ":this_is_the_data_column_1", + dialect=self._length_fixture(), ) def test_too_long_name_disallowed(self): m = MetaData() - t = Table('this_name_is_too_long_for_what_were_doing_in_this_test', - m, Column('foo', Integer)) + t = Table( + "this_name_is_too_long_for_what_were_doing_in_this_test", + m, + Column("foo", Integer), + ) eng = self._engine_fixture() methods = (t.create, t.drop, m.create_all, m.drop_all) for meth in methods: @@ -91,29 +108,32 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): compiled = s.compile(dialect=self._length_fixture()) assert set( - compiled._create_result_map()['some_large_named_table__2'][1]).\ - issuperset( + compiled._create_result_map()["some_large_named_table__2"][1] + ).issuperset( [ - 'some_large_named_table_this_is_the_data_column', - 'some_large_named_table__2', - table1.c.this_is_the_data_column + "some_large_named_table_this_is_the_data_column", + "some_large_named_table__2", + table1.c.this_is_the_data_column, ] ) assert set( - compiled._create_result_map()['some_large_named_table__1'][1]).\ - issuperset( + compiled._create_result_map()["some_large_named_table__1"][1] + ).issuperset( [ - 'some_large_named_table_this_is_the_primarykey_column', - 'some_large_named_table__1', - table1.c.this_is_the_primarykey_column + "some_large_named_table_this_is_the_primarykey_column", + "some_large_named_table__1", + table1.c.this_is_the_primarykey_column, ] ) def test_result_map_use_labels(self): table1 = self.table1 - s = table1.select().apply_labels().\ - order_by(table1.c.this_is_the_primarykey_column) + s = ( + table1.select() + .apply_labels() + .order_by(table1.c.this_is_the_primarykey_column) + ) self._assert_labeled_table1_select(s) @@ -123,27 +143,30 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): # version) generate a subquery for limits/offsets. 
ensure that the # generated result map corresponds to the selected table, not the # select query - s = table1.select(use_labels=True, - order_by=[table1.c.this_is_the_primarykey_column]).\ - limit(2) + s = table1.select( + use_labels=True, order_by=[table1.c.this_is_the_primarykey_column] + ).limit(2) self._assert_labeled_table1_select(s) def test_result_map_subquery(self): table1 = self.table1 - s = table1.select( - table1.c.this_is_the_primarykey_column == 4).\ - alias('foo') + s = table1.select(table1.c.this_is_the_primarykey_column == 4).alias( + "foo" + ) s2 = select([s]) compiled = s2.compile(dialect=self._length_fixture()) - assert \ - set(compiled._create_result_map()['this_is_the_data_column'][1]).\ - issuperset(['this_is_the_data_column', - s.c.this_is_the_data_column]) assert set( - compiled._create_result_map()['this_is_the_primarykey__1'][1]).\ - issuperset(['this_is_the_primarykey_column', - 'this_is_the_primarykey__1', - s.c.this_is_the_primarykey_column]) + compiled._create_result_map()["this_is_the_data_column"][1] + ).issuperset(["this_is_the_data_column", s.c.this_is_the_data_column]) + assert set( + compiled._create_result_map()["this_is_the_primarykey__1"][1] + ).issuperset( + [ + "this_is_the_primarykey_column", + "this_is_the_primarykey__1", + s.c.this_is_the_primarykey_column, + ] + ) def test_result_map_anon_alias(self): table1 = self.table1 @@ -169,24 +192,28 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): "some_large_named_table.this_is_the_primarykey_column " "= :this_is_the_primarykey__1" ") " - "AS anon_1", dialect=dialect) + "AS anon_1", + dialect=dialect, + ) compiled = s.compile(dialect=dialect) assert set( - compiled._create_result_map()['anon_1_this_is_the_data_3'][1]).\ - issuperset([ - 'anon_1_this_is_the_data_3', - q.corresponding_column( - table1.c.this_is_the_data_column) - ]) + compiled._create_result_map()["anon_1_this_is_the_data_3"][1] + ).issuperset( + [ + "anon_1_this_is_the_data_3", + q.corresponding_column(table1.c.this_is_the_data_column), + ] + ) assert set( - compiled._create_result_map()['anon_1_this_is_the_prim_1'][1]).\ - issuperset([ - 'anon_1_this_is_the_prim_1', - q.corresponding_column( - table1.c.this_is_the_primarykey_column) - ]) + compiled._create_result_map()["anon_1_this_is_the_prim_1"][1] + ).issuperset( + [ + "anon_1_this_is_the_prim_1", + q.corresponding_column(table1.c.this_is_the_primarykey_column), + ] + ) def test_column_bind_labels_1(self): table1 = self.table1 @@ -199,8 +226,8 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): "FROM some_large_named_table WHERE " "some_large_named_table.this_is_the_primarykey_column = " ":this_is_the_primarykey__1", - checkparams={'this_is_the_primarykey__1': 4}, - dialect=self._length_fixture() + checkparams={"this_is_the_primarykey__1": 4}, + dialect=self._length_fixture(), ) self.assert_compile( @@ -210,18 +237,20 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): "FROM some_large_named_table WHERE " "some_large_named_table.this_is_the_primarykey_column = " "%s", - checkpositional=(4, ), - checkparams={'this_is_the_primarykey__1': 4}, - dialect=self._length_fixture(positional=True) + checkpositional=(4,), + checkparams={"this_is_the_primarykey__1": 4}, + dialect=self._length_fixture(positional=True), ) def test_column_bind_labels_2(self): table1 = self.table1 - s = table1.select(or_( - table1.c.this_is_the_primarykey_column == 4, - table1.c.this_is_the_primarykey_column == 2 - )) + s = table1.select( + or_( + 
table1.c.this_is_the_primarykey_column == 4, + table1.c.this_is_the_primarykey_column == 2, + ) + ) self.assert_compile( s, "SELECT some_large_named_table.this_is_the_primarykey_column, " @@ -232,10 +261,10 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): "some_large_named_table.this_is_the_primarykey_column = " ":this_is_the_primarykey__2", checkparams={ - 'this_is_the_primarykey__1': 4, - 'this_is_the_primarykey__2': 2 + "this_is_the_primarykey__1": 4, + "this_is_the_primarykey__2": 2, }, - dialect=self._length_fixture() + dialect=self._length_fixture(), ) self.assert_compile( s, @@ -247,140 +276,155 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL): "some_large_named_table.this_is_the_primarykey_column = " "%s", checkparams={ - 'this_is_the_primarykey__1': 4, - 'this_is_the_primarykey__2': 2 + "this_is_the_primarykey__1": 4, + "this_is_the_primarykey__2": 2, }, checkpositional=(4, 2), - dialect=self._length_fixture(positional=True) + dialect=self._length_fixture(positional=True), ) def test_bind_param_non_truncated(self): table1 = self.table1 stmt = table1.insert().values( this_is_the_data_column=bindparam( - "this_is_the_long_bindparam_name") + "this_is_the_long_bindparam_name" + ) ) compiled = stmt.compile(dialect=self._length_fixture(length=10)) eq_( compiled.construct_params( - params={"this_is_the_long_bindparam_name": 5}), - {'this_is_the_long_bindparam_name': 5} + params={"this_is_the_long_bindparam_name": 5} + ), + {"this_is_the_long_bindparam_name": 5}, ) def test_bind_param_truncated_named(self): table1 = self.table1 bp = bindparam(_truncated_label("this_is_the_long_bindparam_name")) - stmt = table1.insert().values( - this_is_the_data_column=bp - ) + stmt = table1.insert().values(this_is_the_data_column=bp) compiled = stmt.compile(dialect=self._length_fixture(length=10)) eq_( - compiled.construct_params(params={ - "this_is_the_long_bindparam_name": 5}), - {"this_1": 5} + compiled.construct_params( + params={"this_is_the_long_bindparam_name": 5} + ), + {"this_1": 5}, ) def test_bind_param_truncated_positional(self): table1 = self.table1 bp = bindparam(_truncated_label("this_is_the_long_bindparam_name")) - stmt = table1.insert().values( - this_is_the_data_column=bp - ) + stmt = table1.insert().values(this_is_the_data_column=bp) compiled = stmt.compile( - dialect=self._length_fixture(length=10, positional=True)) + dialect=self._length_fixture(length=10, positional=True) + ) eq_( - compiled.construct_params(params={ - "this_is_the_long_bindparam_name": 5}), - {"this_1": 5} + compiled.construct_params( + params={"this_is_the_long_bindparam_name": 5} + ), + {"this_1": 5}, ) class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'DefaultDialect' + __dialect__ = "DefaultDialect" - table1 = table('some_large_named_table', - column('this_is_the_primarykey_column'), - column('this_is_the_data_column') - ) + table1 = table( + "some_large_named_table", + column("this_is_the_primarykey_column"), + column("this_is_the_data_column"), + ) - table2 = table('table_with_exactly_29_characs', - column('this_is_the_primarykey_column'), - column('this_is_the_data_column') - ) + table2 = table( + "table_with_exactly_29_characs", + column("this_is_the_primarykey_column"), + column("this_is_the_data_column"), + ) def test_adjustable_1(self): table1 = self.table1 - q = table1.select( - table1.c.this_is_the_primarykey_column == 4).alias('foo') + q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias( + "foo" + ) x = select([q]) compile_dialect = 
default.DefaultDialect(label_length=10) self.assert_compile( - x, 'SELECT ' - 'foo.this_1, foo.this_2 ' - 'FROM (' - 'SELECT ' - 'some_large_named_table.this_is_the_primarykey_column ' - 'AS this_1, ' - 'some_large_named_table.this_is_the_data_column ' - 'AS this_2 ' - 'FROM ' - 'some_large_named_table ' - 'WHERE ' - 'some_large_named_table.this_is_the_primarykey_column ' - '= :this_1' - ') ' - 'AS foo', dialect=compile_dialect) + x, + "SELECT " + "foo.this_1, foo.this_2 " + "FROM (" + "SELECT " + "some_large_named_table.this_is_the_primarykey_column " + "AS this_1, " + "some_large_named_table.this_is_the_data_column " + "AS this_2 " + "FROM " + "some_large_named_table " + "WHERE " + "some_large_named_table.this_is_the_primarykey_column " + "= :this_1" + ") " + "AS foo", + dialect=compile_dialect, + ) def test_adjustable_2(self): table1 = self.table1 - q = table1.select( - table1.c.this_is_the_primarykey_column == 4).alias('foo') + q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias( + "foo" + ) x = select([q]) compile_dialect = default.DefaultDialect(label_length=10) self.assert_compile( - x, 'SELECT ' - 'foo.this_1, foo.this_2 ' - 'FROM (' - 'SELECT ' - 'some_large_named_table.this_is_the_primarykey_column ' - 'AS this_1, ' - 'some_large_named_table.this_is_the_data_column ' - 'AS this_2 ' - 'FROM ' - 'some_large_named_table ' - 'WHERE ' - 'some_large_named_table.this_is_the_primarykey_column ' - '= :this_1' - ') ' - 'AS foo', dialect=compile_dialect) + x, + "SELECT " + "foo.this_1, foo.this_2 " + "FROM (" + "SELECT " + "some_large_named_table.this_is_the_primarykey_column " + "AS this_1, " + "some_large_named_table.this_is_the_data_column " + "AS this_2 " + "FROM " + "some_large_named_table " + "WHERE " + "some_large_named_table.this_is_the_primarykey_column " + "= :this_1" + ") " + "AS foo", + dialect=compile_dialect, + ) def test_adjustable_3(self): table1 = self.table1 compile_dialect = default.DefaultDialect(label_length=4) - q = table1.select( - table1.c.this_is_the_primarykey_column == 4).alias('foo') + q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias( + "foo" + ) x = select([q]) self.assert_compile( - x, 'SELECT ' - 'foo._1, foo._2 ' - 'FROM (' - 'SELECT ' - 'some_large_named_table.this_is_the_primarykey_column ' - 'AS _1, ' - 'some_large_named_table.this_is_the_data_column ' - 'AS _2 ' - 'FROM ' - 'some_large_named_table ' - 'WHERE ' - 'some_large_named_table.this_is_the_primarykey_column ' - '= :_1' - ') ' - 'AS foo', dialect=compile_dialect) + x, + "SELECT " + "foo._1, foo._2 " + "FROM (" + "SELECT " + "some_large_named_table.this_is_the_primarykey_column " + "AS _1, " + "some_large_named_table.this_is_the_data_column " + "AS _2 " + "FROM " + "some_large_named_table " + "WHERE " + "some_large_named_table.this_is_the_primarykey_column " + "= :_1" + ") " + "AS foo", + dialect=compile_dialect, + ) def test_adjustable_4(self): table1 = self.table1 @@ -390,22 +434,25 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): compile_dialect = default.DefaultDialect(label_length=10) self.assert_compile( - x, 'SELECT ' - 'anon_1.this_2 AS anon_1, ' - 'anon_1.this_4 AS anon_3 ' - 'FROM (' - 'SELECT ' - 'some_large_named_table.this_is_the_primarykey_column ' - 'AS this_2, ' - 'some_large_named_table.this_is_the_data_column ' - 'AS this_4 ' - 'FROM ' - 'some_large_named_table ' - 'WHERE ' - 'some_large_named_table.this_is_the_primarykey_column ' - '= :this_1' - ') ' - 'AS anon_1', dialect=compile_dialect) + x, + "SELECT " + "anon_1.this_2 AS 
anon_1, " + "anon_1.this_4 AS anon_3 " + "FROM (" + "SELECT " + "some_large_named_table.this_is_the_primarykey_column " + "AS this_2, " + "some_large_named_table.this_is_the_data_column " + "AS this_4 " + "FROM " + "some_large_named_table " + "WHERE " + "some_large_named_table.this_is_the_primarykey_column " + "= :this_1" + ") " + "AS anon_1", + dialect=compile_dialect, + ) def test_adjustable_5(self): table1 = self.table1 @@ -414,64 +461,80 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): compile_dialect = default.DefaultDialect(label_length=4) self.assert_compile( - x, 'SELECT ' - '_1._2 AS _1, ' - '_1._4 AS _3 ' - 'FROM (' - 'SELECT ' - 'some_large_named_table.this_is_the_primarykey_column ' - 'AS _2, ' - 'some_large_named_table.this_is_the_data_column ' - 'AS _4 ' - 'FROM ' - 'some_large_named_table ' - 'WHERE ' - 'some_large_named_table.this_is_the_primarykey_column ' - '= :_1' - ') ' - 'AS _1', dialect=compile_dialect) + x, + "SELECT " + "_1._2 AS _1, " + "_1._4 AS _3 " + "FROM (" + "SELECT " + "some_large_named_table.this_is_the_primarykey_column " + "AS _2, " + "some_large_named_table.this_is_the_data_column " + "AS _4 " + "FROM " + "some_large_named_table " + "WHERE " + "some_large_named_table.this_is_the_primarykey_column " + "= :_1" + ") " + "AS _1", + dialect=compile_dialect, + ) def test_adjustable_result_schema_column_1(self): table1 = self.table1 - q = table1.select( - table1.c.this_is_the_primarykey_column == 4).apply_labels().\ - alias('foo') + q = ( + table1.select(table1.c.this_is_the_primarykey_column == 4) + .apply_labels() + .alias("foo") + ) dialect = default.DefaultDialect(label_length=10) compiled = q.compile(dialect=dialect) - assert set(compiled._create_result_map()['some_2'][1]).issuperset([ - table1.c.this_is_the_data_column, - 'some_large_named_table_this_is_the_data_column', - 'some_2' - ]) + assert set(compiled._create_result_map()["some_2"][1]).issuperset( + [ + table1.c.this_is_the_data_column, + "some_large_named_table_this_is_the_data_column", + "some_2", + ] + ) - assert set(compiled._create_result_map()['some_1'][1]).issuperset([ - table1.c.this_is_the_primarykey_column, - 'some_large_named_table_this_is_the_primarykey_column', - 'some_1' - ]) + assert set(compiled._create_result_map()["some_1"][1]).issuperset( + [ + table1.c.this_is_the_primarykey_column, + "some_large_named_table_this_is_the_primarykey_column", + "some_1", + ] + ) def test_adjustable_result_schema_column_2(self): table1 = self.table1 - q = table1.select( - table1.c.this_is_the_primarykey_column == 4).alias('foo') + q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias( + "foo" + ) x = select([q]) dialect = default.DefaultDialect(label_length=10) compiled = x.compile(dialect=dialect) - assert set(compiled._create_result_map()['this_2'][1]).issuperset([ - q.corresponding_column(table1.c.this_is_the_data_column), - 'this_is_the_data_column', - 'this_2']) + assert set(compiled._create_result_map()["this_2"][1]).issuperset( + [ + q.corresponding_column(table1.c.this_is_the_data_column), + "this_is_the_data_column", + "this_2", + ] + ) - assert set(compiled._create_result_map()['this_1'][1]).issuperset([ - q.corresponding_column(table1.c.this_is_the_primarykey_column), - 'this_is_the_primarykey_column', - 'this_1']) + assert set(compiled._create_result_map()["this_1"][1]).issuperset( + [ + q.corresponding_column(table1.c.this_is_the_primarykey_column), + "this_is_the_primarykey_column", + "this_1", + ] + ) def test_table_plus_column_exceeds_length(self): 
"""test that the truncation only occurs when tablename + colname are @@ -480,78 +543,78 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): """ compile_dialect = default.DefaultDialect(label_length=30) - a_table = table( - 'thirty_characters_table_xxxxxx', - column('id') - ) + a_table = table("thirty_characters_table_xxxxxx", column("id")) other_table = table( - 'other_thirty_characters_table_', - column('id'), - column('thirty_characters_table_id') + "other_thirty_characters_table_", + column("id"), + column("thirty_characters_table_id"), ) anon = a_table.alias() j1 = other_table.outerjoin( - anon, - anon.c.id == other_table.c.thirty_characters_table_id) + anon, anon.c.id == other_table.c.thirty_characters_table_id + ) self.assert_compile( - select([other_table, anon]). - select_from(j1).apply_labels(), - 'SELECT ' - 'other_thirty_characters_table_.id ' - 'AS other_thirty_characters__1, ' - 'other_thirty_characters_table_.thirty_characters_table_id ' - 'AS other_thirty_characters__2, ' - 'thirty_characters_table__1.id ' - 'AS thirty_characters_table__3 ' - 'FROM ' - 'other_thirty_characters_table_ ' - 'LEFT OUTER JOIN ' - 'thirty_characters_table_xxxxxx AS thirty_characters_table__1 ' - 'ON thirty_characters_table__1.id = ' - 'other_thirty_characters_table_.thirty_characters_table_id', - dialect=compile_dialect) + select([other_table, anon]).select_from(j1).apply_labels(), + "SELECT " + "other_thirty_characters_table_.id " + "AS other_thirty_characters__1, " + "other_thirty_characters_table_.thirty_characters_table_id " + "AS other_thirty_characters__2, " + "thirty_characters_table__1.id " + "AS thirty_characters_table__3 " + "FROM " + "other_thirty_characters_table_ " + "LEFT OUTER JOIN " + "thirty_characters_table_xxxxxx AS thirty_characters_table__1 " + "ON thirty_characters_table__1.id = " + "other_thirty_characters_table_.thirty_characters_table_id", + dialect=compile_dialect, + ) def test_colnames_longer_than_labels_lowercase(self): - t1 = table('a', column('abcde')) + t1 = table("a", column("abcde")) self._test_colnames_longer_than_labels(t1) def test_colnames_longer_than_labels_uppercase(self): m = MetaData() - t1 = Table('a', m, Column('abcde', Integer)) + t1 = Table("a", m, Column("abcde", Integer)) self._test_colnames_longer_than_labels(t1) def _test_colnames_longer_than_labels(self, t1): dialect = default.DefaultDialect(label_length=4) - a1 = t1.alias(name='asdf') + a1 = t1.alias(name="asdf") # 'abcde' is longer than 4, but rendered as itself # needs to have all characters s = select([a1]) - self.assert_compile(select([a1]), - 'SELECT asdf.abcde FROM a AS asdf', - dialect=dialect) + self.assert_compile( + select([a1]), "SELECT asdf.abcde FROM a AS asdf", dialect=dialect + ) compiled = s.compile(dialect=dialect) - assert set(compiled._create_result_map()['abcde'][1]).issuperset([ - 'abcde', a1.c.abcde, 'abcde']) + assert set(compiled._create_result_map()["abcde"][1]).issuperset( + ["abcde", a1.c.abcde, "abcde"] + ) # column still there, but short label s = select([a1]).apply_labels() - self.assert_compile(s, - 'SELECT asdf.abcde AS _1 FROM a AS asdf', - dialect=dialect) + self.assert_compile( + s, "SELECT asdf.abcde AS _1 FROM a AS asdf", dialect=dialect + ) compiled = s.compile(dialect=dialect) - assert set(compiled._create_result_map()['_1'][1]).issuperset([ - 'asdf_abcde', a1.c.abcde, '_1']) + assert set(compiled._create_result_map()["_1"][1]).issuperset( + ["asdf_abcde", a1.c.abcde, "_1"] + ) def test_label_overlap_unlabeled(self): """test that an anon col 
can't overlap with a fixed name, #3396""" table1 = table( - "tablename", column('columnname_one'), column('columnn_1')) + "tablename", column("columnname_one"), column("columnn_1") + ) stmt = select([table1]).apply_labels() @@ -560,10 +623,10 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): stmt, "SELECT tablename.columnname_one AS tablename_columnn_1, " "tablename.columnn_1 AS tablename_columnn_2 FROM tablename", - dialect=dialect + dialect=dialect, ) compiled = stmt.compile(dialect=dialect) eq_( set(compiled._create_result_map()), - set(['tablename_columnn_1', 'tablename_columnn_2']) + set(["tablename_columnn_1", "tablename_columnn_2"]), ) diff --git a/test/sql/test_lateral.py b/test/sql/test_lateral.py index 785dcd9603..163f636f6f 100644 --- a/test/sql/test_lateral.py +++ b/test/sql/test_lateral.py @@ -16,23 +16,33 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - Table('people', metadata, - Column('people_id', Integer, primary_key=True), - Column('age', Integer), - Column('name', String(30))) - Table('bookcases', metadata, - Column('bookcase_id', Integer, primary_key=True), - Column( - 'bookcase_owner_id', - Integer, ForeignKey('people.people_id')), - Column('bookcase_shelves', Integer), - Column('bookcase_width', Integer)) - Table('books', metadata, - Column('book_id', Integer, primary_key=True), - Column( - 'bookcase_id', Integer, ForeignKey('bookcases.bookcase_id')), - Column('book_owner_id', Integer, ForeignKey('people.people_id')), - Column('book_weight', Integer)) + Table( + "people", + metadata, + Column("people_id", Integer, primary_key=True), + Column("age", Integer), + Column("name", String(30)), + ) + Table( + "bookcases", + metadata, + Column("bookcase_id", Integer, primary_key=True), + Column( + "bookcase_owner_id", Integer, ForeignKey("people.people_id") + ), + Column("bookcase_shelves", Integer), + Column("bookcase_width", Integer), + ) + Table( + "books", + metadata, + Column("book_id", Integer, primary_key=True), + Column( + "bookcase_id", Integer, ForeignKey("bookcases.bookcase_id") + ), + Column("book_owner_id", Integer, ForeignKey("people.people_id")), + Column("book_weight", Integer), + ) def test_standalone(self): table1 = self.tables.people @@ -42,12 +52,12 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): # in the context of a FROM clause self.assert_compile( lateral(subq, name="alias"), - 'LATERAL (SELECT people.people_id FROM people)' + "LATERAL (SELECT people.people_id FROM people)", ) self.assert_compile( subq.lateral(name="alias"), - 'LATERAL (SELECT people.people_id FROM people)' + "LATERAL (SELECT people.people_id FROM people)", ) def test_select_from(self): @@ -56,16 +66,17 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): # in a FROM context, now you get "AS alias" and column labeling self.assert_compile( - select([subq.lateral(name='alias')]), - 'SELECT alias.people_id FROM LATERAL ' - '(SELECT people.people_id AS people_id FROM people) AS alias' + select([subq.lateral(name="alias")]), + "SELECT alias.people_id FROM LATERAL " + "(SELECT people.people_id AS people_id FROM people) AS alias", ) def test_plain_join(self): table1 = self.tables.people table2 = self.tables.books - subq = select([table2.c.book_id]).\ - where(table2.c.book_owner_id == table1.c.people_id) + subq = select([table2.c.book_id]).where( + table2.c.book_owner_id == table1.c.people_id + ) # FROM books, people? isn't this wrong? No! 
Because # this is only a fragment, books isn't in any other FROM clause @@ -73,7 +84,7 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): join(table1, lateral(subq, name="alias"), true()), "people JOIN LATERAL (SELECT books.book_id AS book_id " "FROM books, people WHERE books.book_owner_id = people.people_id) " - "AS alias ON true" + "AS alias ON true", ) # put it in correct context, implicit correlation works fine @@ -84,7 +95,7 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): "SELECT people.people_id, people.age, people.name " "FROM people JOIN LATERAL (SELECT books.book_id AS book_id " "FROM books WHERE books.book_owner_id = people.people_id) " - "AS alias ON true" + "AS alias ON true", ) # explicit correlation @@ -96,25 +107,29 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): "SELECT people.people_id, people.age, people.name " "FROM people JOIN LATERAL (SELECT books.book_id AS book_id " "FROM books WHERE books.book_owner_id = people.people_id) " - "AS alias ON true" + "AS alias ON true", ) def test_join_lateral_w_select_subquery(self): table1 = self.tables.people table2 = self.tables.books - subq = select([table2.c.book_id]).\ - correlate(table1).\ - where(table1.c.people_id == table2.c.book_owner_id).lateral() - stmt = select([table1, subq.c.book_id]).\ - select_from(table1.join(subq, true())) + subq = ( + select([table2.c.book_id]) + .correlate(table1) + .where(table1.c.people_id == table2.c.book_owner_id) + .lateral() + ) + stmt = select([table1, subq.c.book_id]).select_from( + table1.join(subq, true()) + ) self.assert_compile( stmt, "SELECT people.people_id, people.age, people.name, anon_1.book_id " "FROM people JOIN LATERAL (SELECT books.book_id AS book_id " "FROM books " - "WHERE people.people_id = books.book_owner_id) AS anon_1 ON true" + "WHERE people.people_id = books.book_owner_id) AS anon_1 ON true", ) def test_from_function(self): @@ -127,5 +142,5 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): "bookcases.bookcase_shelves, bookcases.bookcase_width " "FROM bookcases JOIN " "LATERAL generate_series(:generate_series_1, " - "bookcases.bookcase_shelves) AS anon_1 ON true" + "bookcases.bookcase_shelves) AS anon_1 ON true", ) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index f030a7e77c..9a28b0c7b7 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -2,12 +2,33 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import emits_warning import pickle -from sqlalchemy import Integer, String, UniqueConstraint, \ - CheckConstraint, ForeignKey, MetaData, Sequence, \ - ForeignKeyConstraint, PrimaryKeyConstraint, ColumnDefault, Index, event,\ - events, Unicode, types as sqltypes, bindparam, \ - Table, Column, Boolean, Enum, func, text, TypeDecorator, \ - BLANK_SCHEMA, ARRAY +from sqlalchemy import ( + Integer, + String, + UniqueConstraint, + CheckConstraint, + ForeignKey, + MetaData, + Sequence, + ForeignKeyConstraint, + PrimaryKeyConstraint, + ColumnDefault, + Index, + event, + events, + Unicode, + types as sqltypes, + bindparam, + Table, + Column, + Boolean, + Enum, + func, + text, + TypeDecorator, + BLANK_SCHEMA, + ARRAY, +) from sqlalchemy import schema, exc from sqlalchemy.engine import default from sqlalchemy.sql import elements, naming @@ -20,14 +41,14 @@ from contextlib import contextmanager from sqlalchemy import util from sqlalchemy.testing import engines -class MetaDataTest(fixtures.TestBase, ComparesTables): 
+class MetaDataTest(fixtures.TestBase, ComparesTables): def test_metadata_contains(self): metadata = MetaData() - t1 = Table('t1', metadata, Column('x', Integer)) - t2 = Table('t2', metadata, Column('x', Integer), schema='foo') - t3 = Table('t2', MetaData(), Column('x', Integer)) - t4 = Table('t1', MetaData(), Column('x', Integer), schema='foo') + t1 = Table("t1", metadata, Column("x", Integer)) + t2 = Table("t2", metadata, Column("x", Integer), schema="foo") + t3 = Table("t2", MetaData(), Column("x", Integer)) + t4 = Table("t1", MetaData(), Column("x", Integer), schema="foo") assert "t1" in metadata assert "foo.t2" in metadata @@ -40,50 +61,68 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): def test_uninitialized_column_copy(self): for col in [ - Column('foo', String(), nullable=False), - Column('baz', String(), unique=True), + Column("foo", String(), nullable=False), + Column("baz", String(), unique=True), Column(Integer(), primary_key=True), - Column('bar', Integer(), Sequence('foo_seq'), primary_key=True, - key='bar'), - Column(Integer(), ForeignKey('bat.blah'), doc="this is a col"), - Column('bar', Integer(), ForeignKey('bat.blah'), primary_key=True, - key='bar'), - Column('bar', Integer(), info={'foo': 'bar'}), + Column( + "bar", + Integer(), + Sequence("foo_seq"), + primary_key=True, + key="bar", + ), + Column(Integer(), ForeignKey("bat.blah"), doc="this is a col"), + Column( + "bar", + Integer(), + ForeignKey("bat.blah"), + primary_key=True, + key="bar", + ), + Column("bar", Integer(), info={"foo": "bar"}), ]: c2 = col.copy() - for attr in ('name', 'type', 'nullable', - 'primary_key', 'key', 'unique', 'info', - 'doc'): + for attr in ( + "name", + "type", + "nullable", + "primary_key", + "key", + "unique", + "info", + "doc", + ): eq_(getattr(col, attr), getattr(c2, attr)) eq_(len(col.foreign_keys), len(c2.foreign_keys)) if col.default: - eq_(c2.default.name, 'foo_seq') + eq_(c2.default.name, "foo_seq") for a1, a2 in zip(col.foreign_keys, c2.foreign_keys): assert a1 is not a2 - eq_(a2._colspec, 'bat.blah') + eq_(a2._colspec, "bat.blah") def test_col_subclass_copy(self): class MyColumn(schema.Column): - def __init__(self, *args, **kw): - self.widget = kw.pop('widget', None) + self.widget = kw.pop("widget", None) super(MyColumn, self).__init__(*args, **kw) def copy(self, *arg, **kw): c = super(MyColumn, self).copy(*arg, **kw) c.widget = self.widget return c - c1 = MyColumn('foo', Integer, widget='x') + + c1 = MyColumn("foo", Integer, widget="x") c2 = c1.copy() assert isinstance(c2, MyColumn) - eq_(c2.widget, 'x') + eq_(c2.widget, "x") def test_uninitialized_column_copy_events(self): msgs = [] def write(c, t): msgs.append("attach %s.%s" % (t.name, c.name)) - c1 = Column('foo', String()) + + c1 = Column("foo", String()) m = MetaData() for i in range(3): cx = c1.copy() @@ -91,39 +130,39 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): # that listeners will be re-established from the # natural construction of things. 
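            # _on_table_attach registers a callback invoked when the column
            # is attached to a Table; each Table() construction below fires
            # one "attach" message into msgs.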
cx._on_table_attach(write) - Table('foo%d' % i, m, cx) - eq_(msgs, ['attach foo0.foo', 'attach foo1.foo', 'attach foo2.foo']) + Table("foo%d" % i, m, cx) + eq_(msgs, ["attach foo0.foo", "attach foo1.foo", "attach foo2.foo"]) def test_schema_collection_add(self): metadata = MetaData() - Table('t1', metadata, Column('x', Integer), schema='foo') - Table('t2', metadata, Column('x', Integer), schema='bar') - Table('t3', metadata, Column('x', Integer)) + Table("t1", metadata, Column("x", Integer), schema="foo") + Table("t2", metadata, Column("x", Integer), schema="bar") + Table("t3", metadata, Column("x", Integer)) - eq_(metadata._schemas, set(['foo', 'bar'])) + eq_(metadata._schemas, set(["foo", "bar"])) eq_(len(metadata.tables), 3) def test_schema_collection_remove(self): metadata = MetaData() - t1 = Table('t1', metadata, Column('x', Integer), schema='foo') - Table('t2', metadata, Column('x', Integer), schema='bar') - t3 = Table('t3', metadata, Column('x', Integer), schema='bar') + t1 = Table("t1", metadata, Column("x", Integer), schema="foo") + Table("t2", metadata, Column("x", Integer), schema="bar") + t3 = Table("t3", metadata, Column("x", Integer), schema="bar") metadata.remove(t3) - eq_(metadata._schemas, set(['foo', 'bar'])) + eq_(metadata._schemas, set(["foo", "bar"])) eq_(len(metadata.tables), 2) metadata.remove(t1) - eq_(metadata._schemas, set(['bar'])) + eq_(metadata._schemas, set(["bar"])) eq_(len(metadata.tables), 1) def test_schema_collection_remove_all(self): metadata = MetaData() - Table('t1', metadata, Column('x', Integer), schema='foo') - Table('t2', metadata, Column('x', Integer), schema='bar') + Table("t1", metadata, Column("x", Integer), schema="foo") + Table("t2", metadata, Column("x", Integer), schema="bar") metadata.clear() eq_(metadata._schemas, set()) @@ -132,46 +171,56 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): def test_metadata_tables_immutable(self): metadata = MetaData() - Table('t1', metadata, Column('x', Integer)) - assert 't1' in metadata.tables + Table("t1", metadata, Column("x", Integer)) + assert "t1" in metadata.tables - assert_raises( - TypeError, - lambda: metadata.tables.pop('t1') - ) + assert_raises(TypeError, lambda: metadata.tables.pop("t1")) @testing.provide_metadata def test_dupe_tables(self): metadata = self.metadata - Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20))) + Table( + "table1", + metadata, + Column("col1", Integer, primary_key=True), + Column("col2", String(20)), + ) metadata.create_all() - Table('table1', metadata, autoload=True) + Table("table1", metadata, autoload=True) def go(): - Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20))) + Table( + "table1", + metadata, + Column("col1", Integer, primary_key=True), + Column("col2", String(20)), + ) + assert_raises_message( tsa.exc.InvalidRequestError, "Table 'table1' is already defined for this " "MetaData instance. 
Specify 'extend_existing=True' " "to redefine options and columns on an existing " "Table object.", - go + go, ) def test_fk_copy(self): - c1 = Column('foo', Integer) - c2 = Column('bar', Integer) + c1 = Column("foo", Integer) + c2 = Column("bar", Integer) m = MetaData() - t1 = Table('t', m, c1, c2) - - kw = dict(onupdate="X", - ondelete="Y", use_alter=True, name='f1', - deferrable="Z", initially="Q", link_to_name=True) + t1 = Table("t", m, c1, c2) + + kw = dict( + onupdate="X", + ondelete="Y", + use_alter=True, + name="f1", + deferrable="Z", + initially="Q", + link_to_name=True, + ) fk1 = ForeignKey(c1, **kw) fk2 = ForeignKeyConstraint((c1,), (c2,), **kw) @@ -185,14 +234,18 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): eq_(getattr(fk2c, k), kw[k]) def test_check_constraint_copy(self): - def r(x): return x - c = CheckConstraint("foo bar", - name='name', - initially=True, - deferrable=True, - _create_rule=r) + def r(x): + return x + + c = CheckConstraint( + "foo bar", + name="name", + initially=True, + deferrable=True, + _create_rule=r, + ) c2 = c.copy() - eq_(c2.name, 'name') + eq_(c2.name, "name") eq_(str(c2.sqltext), "foo bar") eq_(c2.initially, True) eq_(c2.deferrable, True) @@ -200,152 +253,157 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): def test_col_replace_w_constraint(self): m = MetaData() - a = Table('a', m, Column('id', Integer, primary_key=True)) + a = Table("a", m, Column("id", Integer, primary_key=True)) - aid = Column('a_id', ForeignKey('a.id')) - b = Table('b', m, aid) + aid = Column("a_id", ForeignKey("a.id")) + b = Table("b", m, aid) b.append_column(aid) assert b.c.a_id.references(a.c.id) eq_(len(b.constraints), 2) def test_fk_construct(self): - c1 = Column('foo', Integer) - c2 = Column('bar', Integer) + c1 = Column("foo", Integer) + c2 = Column("bar", Integer) m = MetaData() - t1 = Table('t', m, c1, c2) - fk1 = ForeignKeyConstraint(('foo', ), ('bar', ), table=t1) + t1 = Table("t", m, c1, c2) + fk1 = ForeignKeyConstraint(("foo",), ("bar",), table=t1) assert fk1 in t1.constraints def test_fk_constraint_col_collection_w_table(self): - c1 = Column('foo', Integer) - c2 = Column('bar', Integer) + c1 = Column("foo", Integer) + c2 = Column("bar", Integer) m = MetaData() - t1 = Table('t', m, c1, c2) - fk1 = ForeignKeyConstraint(('foo', ), ('bar', ), table=t1) + t1 = Table("t", m, c1, c2) + fk1 = ForeignKeyConstraint(("foo",), ("bar",), table=t1) eq_(dict(fk1.columns), {"foo": c1}) def test_fk_constraint_col_collection_no_table(self): - fk1 = ForeignKeyConstraint(('foo', 'bat'), ('bar', 'hoho')) + fk1 = ForeignKeyConstraint(("foo", "bat"), ("bar", "hoho")) eq_(dict(fk1.columns), {}) - eq_(fk1.column_keys, ['foo', 'bat']) - eq_(fk1._col_description, 'foo, bat') + eq_(fk1.column_keys, ["foo", "bat"]) + eq_(fk1._col_description, "foo, bat") eq_(fk1._elements, {"foo": fk1.elements[0], "bat": fk1.elements[1]}) def test_fk_constraint_col_collection_no_table_real_cols(self): - c1 = Column('foo', Integer) - c2 = Column('bar', Integer) - fk1 = ForeignKeyConstraint((c1, ), (c2, )) + c1 = Column("foo", Integer) + c2 = Column("bar", Integer) + fk1 = ForeignKeyConstraint((c1,), (c2,)) eq_(dict(fk1.columns), {}) - eq_(fk1.column_keys, ['foo']) - eq_(fk1._col_description, 'foo') + eq_(fk1.column_keys, ["foo"]) + eq_(fk1._col_description, "foo") eq_(fk1._elements, {"foo": fk1.elements[0]}) def test_fk_constraint_col_collection_added_to_table(self): - c1 = Column('foo', Integer) + c1 = Column("foo", Integer) m = MetaData() - fk1 = ForeignKeyConstraint(('foo', ), ('bar', )) 
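# A standalone sketch of the lazy resolution asserted above: a
# ForeignKeyConstraint built from string keys has no columns until it is
# attached to a Table (names and the "r.x" target are illustrative).
from sqlalchemy import Column, ForeignKeyConstraint, Integer, MetaData, Table

m = MetaData()
fk = ForeignKeyConstraint(("foo",), ("r.x",))
assert dict(fk.columns) == {}  # nothing resolved yet
t = Table("t", m, Column("foo", Integer), fk)
assert dict(fk.columns) == {"foo": t.c.foo}  # resolved on attach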
- Table('t', m, c1, fk1) + fk1 = ForeignKeyConstraint(("foo",), ("bar",)) + Table("t", m, c1, fk1) eq_(dict(fk1.columns), {"foo": c1}) eq_(fk1._elements, {"foo": fk1.elements[0]}) def test_fk_constraint_col_collection_via_fk(self): - fk = ForeignKey('bar') - c1 = Column('foo', Integer, fk) + fk = ForeignKey("bar") + c1 = Column("foo", Integer, fk) m = MetaData() - t1 = Table('t', m, c1) + t1 = Table("t", m, c1) fk1 = fk.constraint - eq_(fk1.column_keys, ['foo']) + eq_(fk1.column_keys, ["foo"]) assert fk1 in t1.constraints - eq_(fk1.column_keys, ['foo']) + eq_(fk1.column_keys, ["foo"]) eq_(dict(fk1.columns), {"foo": c1}) eq_(fk1._elements, {"foo": fk}) def test_fk_no_such_parent_col_error(self): meta = MetaData() - a = Table('a', meta, Column('a', Integer)) - Table('b', meta, Column('b', Integer)) + a = Table("a", meta, Column("a", Integer)) + Table("b", meta, Column("b", Integer)) def go(): - a.append_constraint( - ForeignKeyConstraint(['x'], ['b.b']) - ) + a.append_constraint(ForeignKeyConstraint(["x"], ["b.b"])) + assert_raises_message( exc.ArgumentError, "Can't create ForeignKeyConstraint on " "table 'a': no column named 'x' is present.", - go + go, ) def test_fk_given_non_col(self): - not_a_col = bindparam('x') + not_a_col = bindparam("x") assert_raises_message( exc.ArgumentError, "String, Column, or Column-bound argument expected, got Bind", - ForeignKey, not_a_col + ForeignKey, + not_a_col, ) def test_fk_given_non_col_clauseelem(self): class Foo(object): - def __clause_element__(self): - return bindparam('x') + return bindparam("x") + assert_raises_message( exc.ArgumentError, "String, Column, or Column-bound argument expected, got Bind", - ForeignKey, Foo() + ForeignKey, + Foo(), ) def test_fk_given_col_non_table(self): - t = Table('t', MetaData(), Column('x', Integer)) + t = Table("t", MetaData(), Column("x", Integer)) xa = t.alias().c.x assert_raises_message( exc.ArgumentError, "ForeignKey received Column not bound to a Table, got: .*Alias", - ForeignKey, xa + ForeignKey, + xa, ) def test_fk_given_col_non_table_clauseelem(self): - t = Table('t', MetaData(), Column('x', Integer)) + t = Table("t", MetaData(), Column("x", Integer)) class Foo(object): - def __clause_element__(self): return t.alias().c.x assert_raises_message( exc.ArgumentError, "ForeignKey received Column not bound to a Table, got: .*Alias", - ForeignKey, Foo() + ForeignKey, + Foo(), ) def test_fk_no_such_target_col_error_upfront(self): meta = MetaData() - a = Table('a', meta, Column('a', Integer)) - Table('b', meta, Column('b', Integer)) + a = Table("a", meta, Column("a", Integer)) + Table("b", meta, Column("b", Integer)) - a.append_constraint(ForeignKeyConstraint(['a'], ['b.x'])) + a.append_constraint(ForeignKeyConstraint(["a"], ["b.x"])) assert_raises_message( exc.NoReferencedColumnError, "Could not initialize target column for ForeignKey 'b.x' on " "table 'a': table 'b' has no column named 'x'", - getattr, list(a.foreign_keys)[0], "column" + getattr, + list(a.foreign_keys)[0], + "column", ) def test_fk_no_such_target_col_error_delayed(self): meta = MetaData() - a = Table('a', meta, Column('a', Integer)) - a.append_constraint( - ForeignKeyConstraint(['a'], ['b.x'])) + a = Table("a", meta, Column("a", Integer)) + a.append_constraint(ForeignKeyConstraint(["a"], ["b.x"])) - b = Table('b', meta, Column('b', Integer)) + b = Table("b", meta, Column("b", Integer)) assert_raises_message( exc.NoReferencedColumnError, "Could not initialize target column for ForeignKey 'b.x' on " "table 'a': table 'b' has no column named 'x'", 
- getattr, list(a.foreign_keys)[0], "column" + getattr, + list(a.foreign_keys)[0], + "column", ) def test_fk_mismatched_local_remote_cols(self): @@ -354,177 +412,202 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): exc.ArgumentError, "ForeignKeyConstraint number of constrained columns must " "match the number of referenced columns.", - ForeignKeyConstraint, ['a'], ['b.a', 'b.b'] + ForeignKeyConstraint, + ["a"], + ["b.a", "b.b"], ) assert_raises_message( exc.ArgumentError, "ForeignKeyConstraint number of constrained columns " "must match the number of referenced columns.", - ForeignKeyConstraint, ['a', 'b'], ['b.a'] + ForeignKeyConstraint, + ["a", "b"], + ["b.a"], ) assert_raises_message( exc.ArgumentError, "ForeignKeyConstraint with duplicate source column " "references are not supported.", - ForeignKeyConstraint, ['a', 'a'], ['b.a', 'b.b'] + ForeignKeyConstraint, + ["a", "a"], + ["b.a", "b.b"], ) def test_pickle_metadata_sequence_restated(self): m1 = MetaData() - Table('a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, Sequence("x_seq"))) + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, Sequence("x_seq")), + ) m2 = pickle.loads(pickle.dumps(m1)) s2 = Sequence("x_seq") - t2 = Table('a', m2, - Column('id', Integer, primary_key=True), - Column('x', Integer, s2), - extend_existing=True) + t2 = Table( + "a", + m2, + Column("id", Integer, primary_key=True), + Column("x", Integer, s2), + extend_existing=True, + ) - assert m2._sequences['x_seq'] is t2.c.x.default - assert m2._sequences['x_seq'] is s2 + assert m2._sequences["x_seq"] is t2.c.x.default + assert m2._sequences["x_seq"] is s2 def test_sequence_restated_replaced(self): """Test restatement of Sequence replaces.""" m1 = MetaData() s1 = Sequence("x_seq") - t = Table('a', m1, - Column('x', Integer, s1) - ) - assert m1._sequences['x_seq'] is s1 - - s2 = Sequence('x_seq') - Table('a', m1, - Column('x', Integer, s2), - extend_existing=True - ) + t = Table("a", m1, Column("x", Integer, s1)) + assert m1._sequences["x_seq"] is s1 + + s2 = Sequence("x_seq") + Table("a", m1, Column("x", Integer, s2), extend_existing=True) assert t.c.x.default is s2 - assert m1._sequences['x_seq'] is s2 + assert m1._sequences["x_seq"] is s2 def test_sequence_attach_to_table(self): m1 = MetaData() s1 = Sequence("s") - t = Table('a', m1, Column('x', Integer, s1)) + t = Table("a", m1, Column("x", Integer, s1)) assert s1.metadata is m1 def test_sequence_attach_to_existing_table(self): m1 = MetaData() s1 = Sequence("s") - t = Table('a', m1, Column('x', Integer)) + t = Table("a", m1, Column("x", Integer)) t.c.x._init_items(s1) assert s1.metadata is m1 def test_pickle_metadata_sequence_implicit(self): m1 = MetaData() - Table('a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, Sequence("x_seq"))) + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, Sequence("x_seq")), + ) m2 = pickle.loads(pickle.dumps(m1)) - t2 = Table('a', m2, extend_existing=True) + t2 = Table("a", m2, extend_existing=True) - eq_(m2._sequences, {'x_seq': t2.c.x.default}) + eq_(m2._sequences, {"x_seq": t2.c.x.default}) def test_pickle_metadata_schema(self): m1 = MetaData() - Table('a', m1, - Column('id', Integer, primary_key=True), - Column('x', Integer, Sequence("x_seq")), - schema='y') + Table( + "a", + m1, + Column("id", Integer, primary_key=True), + Column("x", Integer, Sequence("x_seq")), + schema="y", + ) m2 = pickle.loads(pickle.dumps(m1)) - Table('a', m2, schema='y', - 
extend_existing=True) + Table("a", m2, schema="y", extend_existing=True) eq_(m2._schemas, m1._schemas) def test_metadata_schema_arg(self): - m1 = MetaData(schema='sch1') - m2 = MetaData(schema='sch1', quote_schema=True) - m3 = MetaData(schema='sch1', quote_schema=False) + m1 = MetaData(schema="sch1") + m2 = MetaData(schema="sch1", quote_schema=True) + m3 = MetaData(schema="sch1", quote_schema=False) m4 = MetaData() - for i, (name, metadata, schema, quote_schema, - exp_schema, exp_quote_schema) in enumerate([ - ('t1', m1, None, None, 'sch1', None), - ('t2', m1, 'sch2', None, 'sch2', None), - ('t3', m1, 'sch2', True, 'sch2', True), - ('t4', m1, 'sch1', None, 'sch1', None), - ('t5', m1, BLANK_SCHEMA, None, None, None), - ('t1', m2, None, None, 'sch1', True), - ('t2', m2, 'sch2', None, 'sch2', None), - ('t3', m2, 'sch2', True, 'sch2', True), - ('t4', m2, 'sch1', None, 'sch1', None), - ('t1', m3, None, None, 'sch1', False), - ('t2', m3, 'sch2', None, 'sch2', None), - ('t3', m3, 'sch2', True, 'sch2', True), - ('t4', m3, 'sch1', None, 'sch1', None), - ('t1', m4, None, None, None, None), - ('t2', m4, 'sch2', None, 'sch2', None), - ('t3', m4, 'sch2', True, 'sch2', True), - ('t4', m4, 'sch1', None, 'sch1', None), - ('t5', m4, BLANK_SCHEMA, None, None, None), - ]): + for ( + i, + ( + name, + metadata, + schema, + quote_schema, + exp_schema, + exp_quote_schema, + ), + ) in enumerate( + [ + ("t1", m1, None, None, "sch1", None), + ("t2", m1, "sch2", None, "sch2", None), + ("t3", m1, "sch2", True, "sch2", True), + ("t4", m1, "sch1", None, "sch1", None), + ("t5", m1, BLANK_SCHEMA, None, None, None), + ("t1", m2, None, None, "sch1", True), + ("t2", m2, "sch2", None, "sch2", None), + ("t3", m2, "sch2", True, "sch2", True), + ("t4", m2, "sch1", None, "sch1", None), + ("t1", m3, None, None, "sch1", False), + ("t2", m3, "sch2", None, "sch2", None), + ("t3", m3, "sch2", True, "sch2", True), + ("t4", m3, "sch1", None, "sch1", None), + ("t1", m4, None, None, None, None), + ("t2", m4, "sch2", None, "sch2", None), + ("t3", m4, "sch2", True, "sch2", True), + ("t4", m4, "sch1", None, "sch1", None), + ("t5", m4, BLANK_SCHEMA, None, None, None), + ] + ): kw = {} if schema is not None: - kw['schema'] = schema + kw["schema"] = schema if quote_schema is not None: - kw['quote_schema'] = quote_schema + kw["quote_schema"] = quote_schema t = Table(name, metadata, **kw) eq_(t.schema, exp_schema, "test %d, table schema" % i) - eq_(t.schema.quote if t.schema is not None else None, + eq_( + t.schema.quote if t.schema is not None else None, exp_quote_schema, - "test %d, table quote_schema" % i) + "test %d, table quote_schema" % i, + ) seq = Sequence(name, metadata=metadata, **kw) eq_(seq.schema, exp_schema, "test %d, seq schema" % i) - eq_(seq.schema.quote if seq.schema is not None else None, + eq_( + seq.schema.quote if seq.schema is not None else None, exp_quote_schema, - "test %d, seq quote_schema" % i) + "test %d, seq quote_schema" % i, + ) def test_manual_dependencies(self): meta = MetaData() - a = Table('a', meta, Column('foo', Integer)) - b = Table('b', meta, Column('foo', Integer)) - c = Table('c', meta, Column('foo', Integer)) - d = Table('d', meta, Column('foo', Integer)) - e = Table('e', meta, Column('foo', Integer)) + a = Table("a", meta, Column("foo", Integer)) + b = Table("b", meta, Column("foo", Integer)) + c = Table("c", meta, Column("foo", Integer)) + d = Table("d", meta, Column("foo", Integer)) + e = Table("e", meta, Column("foo", Integer)) e.add_is_dependent_on(c) a.add_is_dependent_on(b) 
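# As a standalone illustration of the ordering asserted below:
# add_is_dependent_on() feeds extra edges into the topological sort
# behind MetaData.sorted_tables (table names are illustrative).
from sqlalchemy import Column, Integer, MetaData, Table

meta = MetaData()
a = Table("a", meta, Column("foo", Integer))
b = Table("b", meta, Column("foo", Integer))
a.add_is_dependent_on(b)  # "a" must be created after "b"
assert meta.sorted_tables == [b, a]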
b.add_is_dependent_on(d) e.add_is_dependent_on(b) c.add_is_dependent_on(a) - eq_( - meta.sorted_tables, - [d, b, a, c, e] - ) + eq_(meta.sorted_tables, [d, b, a, c, e]) def test_deterministic_order(self): meta = MetaData() - a = Table('a', meta, Column('foo', Integer)) - b = Table('b', meta, Column('foo', Integer)) - c = Table('c', meta, Column('foo', Integer)) - d = Table('d', meta, Column('foo', Integer)) - e = Table('e', meta, Column('foo', Integer)) + a = Table("a", meta, Column("foo", Integer)) + b = Table("b", meta, Column("foo", Integer)) + c = Table("c", meta, Column("foo", Integer)) + d = Table("d", meta, Column("foo", Integer)) + e = Table("e", meta, Column("foo", Integer)) e.add_is_dependent_on(c) a.add_is_dependent_on(b) - eq_( - meta.sorted_tables, - [b, c, d, a, e] - ) + eq_(meta.sorted_tables, [b, c, d, a, e]) def test_nonexistent(self): - assert_raises(tsa.exc.NoSuchTableError, Table, - 'fake_table', - MetaData(testing.db), autoload=True) + assert_raises( + tsa.exc.NoSuchTableError, + Table, + "fake_table", + MetaData(testing.db), + autoload=True, + ) def test_assorted_repr(self): t1 = Table("foo", MetaData(), Column("x", Integer)) @@ -532,89 +615,74 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): ck = schema.CheckConstraint("x > y", name="someconstraint") for const, exp in ( - (Sequence("my_seq"), - "Sequence('my_seq')"), - (Sequence("my_seq", start=5), - "Sequence('my_seq', start=5)"), - (Column("foo", Integer), - "Column('foo', Integer(), table=None)"), - (Table("bar", MetaData(), Column("x", String)), + (Sequence("my_seq"), "Sequence('my_seq')"), + (Sequence("my_seq", start=5), "Sequence('my_seq', start=5)"), + (Column("foo", Integer), "Column('foo', Integer(), table=None)"), + ( + Table("bar", MetaData(), Column("x", String)), "Table('bar', MetaData(bind=None), " - "Column('x', String(), table=), schema=None)"), - (schema.DefaultGenerator(for_update=True), - "DefaultGenerator(for_update=True)"), + "Column('x', String(), table=), schema=None)", + ), + ( + schema.DefaultGenerator(for_update=True), + "DefaultGenerator(for_update=True)", + ), (schema.Index("bar", "c"), "Index('bar', 'c')"), (i1, "Index('bar', Column('x', Integer(), table=))"), (schema.FetchedValue(), "FetchedValue()"), - (ck, - "CheckConstraint(" - "%s" - ", name='someconstraint')" % repr(ck.sqltext)), - (ColumnDefault(('foo', 'bar')), "ColumnDefault(('foo', 'bar'))") + ( + ck, + "CheckConstraint(" + "%s" + ", name='someconstraint')" % repr(ck.sqltext), + ), + (ColumnDefault(("foo", "bar")), "ColumnDefault(('foo', 'bar'))"), ): - eq_( - repr(const), - exp - ) + eq_(repr(const), exp) class ToMetaDataTest(fixtures.TestBase, ComparesTables): - @testing.requires.check_constraints def test_copy(self): from sqlalchemy.testing.schema import Table + meta = MetaData() table = Table( - 'mytable', + "mytable", meta, + Column("myid", Integer, Sequence("foo_id_seq"), primary_key=True), + Column("name", String(40), nullable=True), Column( - 'myid', - Integer, - Sequence('foo_id_seq'), - primary_key=True), - Column( - 'name', - String(40), - nullable=True), - Column( - 'foo', + "foo", String(40), nullable=False, - server_default='x', - server_onupdate='q'), + server_default="x", + server_onupdate="q", + ), Column( - 'bar', - String(40), - nullable=False, - default='y', - onupdate='z'), + "bar", String(40), nullable=False, default="y", onupdate="z" + ), Column( - 'description', - String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), - test_needs_fk=True) + "description", String(30), 
CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), + test_needs_fk=True, + ) table2 = Table( - 'othertable', + "othertable", meta, - Column( - 'id', - Integer, - Sequence('foo_seq'), - primary_key=True), - Column( - 'myid', - Integer, - ForeignKey('mytable.myid'), - ), - test_needs_fk=True) + Column("id", Integer, Sequence("foo_seq"), primary_key=True), + Column("myid", Integer, ForeignKey("mytable.myid")), + test_needs_fk=True, + ) table3 = Table( - 'has_comments', meta, - Column('foo', Integer, comment='some column'), - comment='table comment' + "has_comments", + meta, + Column("foo", Integer, comment="some column"), + comment="table comment", ) def test_to_metadata(): @@ -630,47 +698,61 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): assert meta2.bind is None pickle.loads(pickle.dumps(meta2)) return ( - meta2.tables['mytable'], - meta2.tables['othertable'], meta2.tables['has_comments']) + meta2.tables["mytable"], + meta2.tables["othertable"], + meta2.tables["has_comments"], + ) def test_pickle_via_reflect(): # this is the most common use case, pickling the results of a # database reflection meta2 = MetaData(bind=testing.db) - t1 = Table('mytable', meta2, autoload=True) - Table('othertable', meta2, autoload=True) - Table('has_comments', meta2, autoload=True) + t1 = Table("mytable", meta2, autoload=True) + Table("othertable", meta2, autoload=True) + Table("has_comments", meta2, autoload=True) meta3 = pickle.loads(pickle.dumps(meta2)) assert meta3.bind is None - assert meta3.tables['mytable'] is not t1 + assert meta3.tables["mytable"] is not t1 return ( - meta3.tables['mytable'], meta3.tables['othertable'], - meta3.tables['has_comments'] + meta3.tables["mytable"], + meta3.tables["othertable"], + meta3.tables["has_comments"], ) meta.create_all(testing.db) try: - for test, has_constraints, reflect in \ - (test_to_metadata, True, False), \ - (test_pickle, True, False), \ - (test_pickle_via_reflect, False, True): + for test, has_constraints, reflect in ( + (test_to_metadata, True, False), + (test_pickle, True, False), + (test_pickle_via_reflect, False, True), + ): table_c, table2_c, table3_c = test() self.assert_tables_equal(table, table_c) self.assert_tables_equal(table2, table2_c) assert table is not table_c assert table.primary_key is not table_c.primary_key - assert list(table2_c.c.myid.foreign_keys)[0].column \ + assert ( + list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid - assert list(table2_c.c.myid.foreign_keys)[0].column \ + ) + assert ( + list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid - assert 'x' in str(table_c.c.foo.server_default.arg) + ) + assert "x" in str(table_c.c.foo.server_default.arg) if not reflect: assert isinstance(table_c.c.myid.default, Sequence) - assert str(table_c.c.foo.server_onupdate.arg) == 'q' - assert str(table_c.c.bar.default.arg) == 'y' - assert getattr(table_c.c.bar.onupdate.arg, 'arg', - table_c.c.bar.onupdate.arg) == 'z' + assert str(table_c.c.foo.server_onupdate.arg) == "q" + assert str(table_c.c.bar.default.arg) == "y" + assert ( + getattr( + table_c.c.bar.onupdate.arg, + "arg", + table_c.c.bar.onupdate.arg, + ) + == "z" + ) assert isinstance(table2_c.c.id.default, Sequence) # constraints don't get reflected for any dialect right @@ -701,8 +783,8 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): def test_col_key_fk_parent(self): # test #2643 m1 = MetaData() - a = Table('a', m1, Column('x', Integer)) - b = Table('b', m1, Column('x', Integer, ForeignKey('a.x'), key='y')) + a = 
Table("a", m1, Column("x", Integer)) + b = Table("b", m1, Column("x", Integer, ForeignKey("a.x"), key="y")) assert b.c.y.references(a.c.x) m2 = MetaData() @@ -713,120 +795,141 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): def test_change_schema(self): meta = MetaData() - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - Column('name', String(40), nullable=True), - Column('description', String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), - ) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), + Column( + "description", String(30), CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), + ) - table2 = Table('othertable', meta, - Column('id', Integer, primary_key=True), - Column('myid', Integer, ForeignKey('mytable.myid')), - ) + table2 = Table( + "othertable", + meta, + Column("id", Integer, primary_key=True), + Column("myid", Integer, ForeignKey("mytable.myid")), + ) meta2 = MetaData() - table_c = table.tometadata(meta2, schema='someschema') - table2_c = table2.tometadata(meta2, schema='someschema') + table_c = table.tometadata(meta2, schema="someschema") + table2_c = table2.tometadata(meta2, schema="someschema") - eq_(str(table_c.join(table2_c).onclause), str(table_c.c.myid - == table2_c.c.myid)) - eq_(str(table_c.join(table2_c).onclause), - 'someschema.mytable.myid = someschema.othertable.myid') + eq_( + str(table_c.join(table2_c).onclause), + str(table_c.c.myid == table2_c.c.myid), + ) + eq_( + str(table_c.join(table2_c).onclause), + "someschema.mytable.myid = someschema.othertable.myid", + ) def test_retain_table_schema(self): meta = MetaData() - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - Column('name', String(40), nullable=True), - Column('description', String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), - schema='myschema', - ) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), + Column( + "description", String(30), CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), + schema="myschema", + ) table2 = Table( - 'othertable', + "othertable", meta, - Column( - 'id', - Integer, - primary_key=True), - Column( - 'myid', - Integer, - ForeignKey('myschema.mytable.myid')), - schema='myschema', + Column("id", Integer, primary_key=True), + Column("myid", Integer, ForeignKey("myschema.mytable.myid")), + schema="myschema", ) meta2 = MetaData() table_c = table.tometadata(meta2) table2_c = table2.tometadata(meta2) - eq_(str(table_c.join(table2_c).onclause), str(table_c.c.myid - == table2_c.c.myid)) - eq_(str(table_c.join(table2_c).onclause), - 'myschema.mytable.myid = myschema.othertable.myid') + eq_( + str(table_c.join(table2_c).onclause), + str(table_c.c.myid == table2_c.c.myid), + ) + eq_( + str(table_c.join(table2_c).onclause), + "myschema.mytable.myid = myschema.othertable.myid", + ) def test_change_name_retain_metadata(self): meta = MetaData() - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - Column('name', String(40), nullable=True), - Column('description', String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), - schema='myschema', - ) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), + Column( + "description", String(30), 
CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), + schema="myschema", + ) - table2 = table.tometadata(table.metadata, name='newtable') - table3 = table.tometadata(table.metadata, schema='newschema', - name='newtable') + table2 = table.tometadata(table.metadata, name="newtable") + table3 = table.tometadata( + table.metadata, schema="newschema", name="newtable" + ) assert table.metadata is table2.metadata assert table.metadata is table3.metadata - eq_((table.name, table2.name, table3.name), - ('mytable', 'newtable', 'newtable')) - eq_((table.key, table2.key, table3.key), - ('myschema.mytable', 'myschema.newtable', 'newschema.newtable')) + eq_( + (table.name, table2.name, table3.name), + ("mytable", "newtable", "newtable"), + ) + eq_( + (table.key, table2.key, table3.key), + ("myschema.mytable", "myschema.newtable", "newschema.newtable"), + ) def test_change_name_change_metadata(self): meta = MetaData() meta2 = MetaData() - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - Column('name', String(40), nullable=True), - Column('description', String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), - schema='myschema', - ) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), + Column( + "description", String(30), CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), + schema="myschema", + ) - table2 = table.tometadata(meta2, name='newtable') + table2 = table.tometadata(meta2, name="newtable") assert table.metadata is not table2.metadata - eq_((table.name, table2.name), - ('mytable', 'newtable')) - eq_((table.key, table2.key), - ('myschema.mytable', 'myschema.newtable')) + eq_((table.name, table2.name), ("mytable", "newtable")) + eq_((table.key, table2.key), ("myschema.mytable", "myschema.newtable")) def test_change_name_selfref_fk_moves(self): meta = MetaData() - referenced = Table('ref', meta, - Column('id', Integer, primary_key=True), - ) - table = Table('mytable', meta, - Column('id', Integer, primary_key=True), - Column('parent_id', ForeignKey('mytable.id')), - Column('ref_id', ForeignKey('ref.id')) - ) + referenced = Table( + "ref", meta, Column("id", Integer, primary_key=True) + ) + table = Table( + "mytable", + meta, + Column("id", Integer, primary_key=True), + Column("parent_id", ForeignKey("mytable.id")), + Column("ref_id", ForeignKey("ref.id")), + ) - table2 = table.tometadata(table.metadata, name='newtable') + table2 = table.tometadata(table.metadata, name="newtable") assert table.metadata is table2.metadata assert table2.c.ref_id.references(referenced.c.id) assert table2.c.parent_id.references(table2.c.id) @@ -834,18 +937,21 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): def test_change_name_selfref_fk_moves_w_schema(self): meta = MetaData() - referenced = Table('ref', meta, - Column('id', Integer, primary_key=True), - ) - table = Table('mytable', meta, - Column('id', Integer, primary_key=True), - Column('parent_id', ForeignKey('mytable.id')), - Column('ref_id', ForeignKey('ref.id')) - ) + referenced = Table( + "ref", meta, Column("id", Integer, primary_key=True) + ) + table = Table( + "mytable", + meta, + Column("id", Integer, primary_key=True), + Column("parent_id", ForeignKey("mytable.id")), + Column("ref_id", ForeignKey("ref.id")), + ) table2 = table.tometadata( - table.metadata, name='newtable', schema='newschema') - ref2 = referenced.tometadata(table.metadata, schema='newschema') + table.metadata, 
name="newtable", schema="newschema" + ) + ref2 = referenced.tometadata(table.metadata, schema="newschema") assert table.metadata is table2.metadata assert table2.c.ref_id.references(ref2.c.id) assert table2.c.parent_id.references(table2.c.id) @@ -854,8 +960,9 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): m2 = MetaData() existing_schema = t2.schema if schema: - t2c = t2.tometadata(m2, schema=schema, - referred_schema_fn=referred_schema_fn) + t2c = t2.tometadata( + m2, schema=schema, referred_schema_fn=referred_schema_fn + ) eq_(t2c.schema, schema) else: t2c = t2.tometadata(m2, referred_schema_fn=referred_schema_fn) @@ -865,151 +972,164 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): def test_fk_has_schema_string_retain_schema(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, ForeignKey('q.t1.x'))) + t2 = Table("t2", m, Column("y", Integer, ForeignKey("q.t1.x"))) self._assert_fk(t2, None, "q.t1.x") - Table('t1', m, Column('x', Integer), schema='q') + Table("t1", m, Column("x", Integer), schema="q") self._assert_fk(t2, None, "q.t1.x") def test_fk_has_schema_string_new_schema(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, ForeignKey('q.t1.x'))) + t2 = Table("t2", m, Column("y", Integer, ForeignKey("q.t1.x"))) self._assert_fk(t2, "z", "q.t1.x") - Table('t1', m, Column('x', Integer), schema='q') + Table("t1", m, Column("x", Integer), schema="q") self._assert_fk(t2, "z", "q.t1.x") def test_fk_has_schema_col_retain_schema(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), schema='q') - t2 = Table('t2', m, Column('y', Integer, ForeignKey(t1.c.x))) + t1 = Table("t1", m, Column("x", Integer), schema="q") + t2 = Table("t2", m, Column("y", Integer, ForeignKey(t1.c.x))) self._assert_fk(t2, "z", "q.t1.x") def test_fk_has_schema_col_new_schema(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), schema='q') - t2 = Table('t2', m, Column('y', Integer, ForeignKey(t1.c.x))) + t1 = Table("t1", m, Column("x", Integer), schema="q") + t2 = Table("t2", m, Column("y", Integer, ForeignKey(t1.c.x))) self._assert_fk(t2, "z", "q.t1.x") def test_fk_and_referent_has_same_schema_string_retain_schema(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, - ForeignKey('q.t1.x')), schema="q") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey("q.t1.x")), schema="q" + ) self._assert_fk(t2, None, "q.t1.x") - Table('t1', m, Column('x', Integer), schema='q') + Table("t1", m, Column("x", Integer), schema="q") self._assert_fk(t2, None, "q.t1.x") def test_fk_and_referent_has_same_schema_string_new_schema(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, - ForeignKey('q.t1.x')), schema="q") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey("q.t1.x")), schema="q" + ) self._assert_fk(t2, "z", "z.t1.x") - Table('t1', m, Column('x', Integer), schema='q') + Table("t1", m, Column("x", Integer), schema="q") self._assert_fk(t2, "z", "z.t1.x") def test_fk_and_referent_has_same_schema_col_retain_schema(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), schema='q') - t2 = Table('t2', m, Column('y', Integer, - ForeignKey(t1.c.x)), schema='q') + t1 = Table("t1", m, Column("x", Integer), schema="q") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey(t1.c.x)), schema="q" + ) self._assert_fk(t2, None, "q.t1.x") def test_fk_and_referent_has_same_schema_col_new_schema(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), schema='q') - t2 = Table('t2', m, Column('y', Integer, - 
ForeignKey(t1.c.x)), schema='q') - self._assert_fk(t2, 'z', "z.t1.x") + t1 = Table("t1", m, Column("x", Integer), schema="q") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey(t1.c.x)), schema="q" + ) + self._assert_fk(t2, "z", "z.t1.x") def test_fk_and_referent_has_diff_schema_string_retain_schema(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, - ForeignKey('p.t1.x')), schema="q") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey("p.t1.x")), schema="q" + ) self._assert_fk(t2, None, "p.t1.x") - Table('t1', m, Column('x', Integer), schema='p') + Table("t1", m, Column("x", Integer), schema="p") self._assert_fk(t2, None, "p.t1.x") def test_fk_and_referent_has_diff_schema_string_new_schema(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, - ForeignKey('p.t1.x')), schema="q") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey("p.t1.x")), schema="q" + ) self._assert_fk(t2, "z", "p.t1.x") - Table('t1', m, Column('x', Integer), schema='p') + Table("t1", m, Column("x", Integer), schema="p") self._assert_fk(t2, "z", "p.t1.x") def test_fk_and_referent_has_diff_schema_col_retain_schema(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), schema='p') - t2 = Table('t2', m, Column('y', Integer, - ForeignKey(t1.c.x)), schema='q') + t1 = Table("t1", m, Column("x", Integer), schema="p") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey(t1.c.x)), schema="q" + ) self._assert_fk(t2, None, "p.t1.x") def test_fk_and_referent_has_diff_schema_col_new_schema(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), schema='p') - t2 = Table('t2', m, Column('y', Integer, - ForeignKey(t1.c.x)), schema='q') - self._assert_fk(t2, 'z', "p.t1.x") + t1 = Table("t1", m, Column("x", Integer), schema="p") + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey(t1.c.x)), schema="q" + ) + self._assert_fk(t2, "z", "p.t1.x") def test_fk_custom_system(self): m = MetaData() - t2 = Table('t2', m, Column('y', Integer, - ForeignKey('p.t1.x')), schema='q') + t2 = Table( + "t2", m, Column("y", Integer, ForeignKey("p.t1.x")), schema="q" + ) def ref_fn(table, to_schema, constraint, referred_schema): assert table is t2 eq_(to_schema, "z") eq_(referred_schema, "p") return "h" - self._assert_fk(t2, 'z', "h.t1.x", referred_schema_fn=ref_fn) + + self._assert_fk(t2, "z", "h.t1.x", referred_schema_fn=ref_fn) def test_copy_info(self): m = MetaData() - fk = ForeignKey('t2.id') - c = Column('c', Integer, fk) - ck = CheckConstraint('c > 5') - t = Table('t', m, c, ck) - - m.info['minfo'] = True - fk.info['fkinfo'] = True - c.info['cinfo'] = True - ck.info['ckinfo'] = True - t.info['tinfo'] = True - t.primary_key.info['pkinfo'] = True - fkc = [const for const in t.constraints if - isinstance(const, ForeignKeyConstraint)][0] - fkc.info['fkcinfo'] = True + fk = ForeignKey("t2.id") + c = Column("c", Integer, fk) + ck = CheckConstraint("c > 5") + t = Table("t", m, c, ck) + + m.info["minfo"] = True + fk.info["fkinfo"] = True + c.info["cinfo"] = True + ck.info["ckinfo"] = True + t.info["tinfo"] = True + t.primary_key.info["pkinfo"] = True + fkc = [ + const + for const in t.constraints + if isinstance(const, ForeignKeyConstraint) + ][0] + fkc.info["fkcinfo"] = True m2 = MetaData() t2 = t.tometadata(m2) - m.info['minfo'] = False - fk.info['fkinfo'] = False - c.info['cinfo'] = False - ck.info['ckinfo'] = False - t.primary_key.info['pkinfo'] = False - fkc.info['fkcinfo'] = False + m.info["minfo"] = False + fk.info["fkinfo"] = False + c.info["cinfo"] = False + ck.info["ckinfo"] = 
False + t.primary_key.info["pkinfo"] = False + fkc.info["fkcinfo"] = False eq_(m2.info, {}) eq_(t2.info, {"tinfo": True}) @@ -1017,21 +1137,29 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): eq_(list(t2.c.c.foreign_keys)[0].info, {"fkinfo": True}) eq_(t2.primary_key.info, {"pkinfo": True}) - fkc2 = [const for const in t2.constraints - if isinstance(const, ForeignKeyConstraint)][0] + fkc2 = [ + const + for const in t2.constraints + if isinstance(const, ForeignKeyConstraint) + ][0] eq_(fkc2.info, {"fkcinfo": True}) - ck2 = [const for const in - t2.constraints if isinstance(const, CheckConstraint)][0] + ck2 = [ + const + for const in t2.constraints + if isinstance(const, CheckConstraint) + ][0] eq_(ck2.info, {"ckinfo": True}) def test_dialect_kwargs(self): meta = MetaData() - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - mysql_engine='InnoDB', - ) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + mysql_engine="InnoDB", + ) meta2 = MetaData() table_c = table.tometadata(meta2) @@ -1043,72 +1171,82 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): def test_indexes(self): meta = MetaData() - table = Table('mytable', meta, - Column('id', Integer, primary_key=True), - Column('data1', Integer, index=True), - Column('data2', Integer), - Index('text', text('data1 + 1')), - ) - Index('multi', table.c.data1, table.c.data2) - Index('func', func.abs(table.c.data1)) - Index('multi-func', table.c.data1, func.abs(table.c.data2)) + table = Table( + "mytable", + meta, + Column("id", Integer, primary_key=True), + Column("data1", Integer, index=True), + Column("data2", Integer), + Index("text", text("data1 + 1")), + ) + Index("multi", table.c.data1, table.c.data2) + Index("func", func.abs(table.c.data1)) + Index("multi-func", table.c.data1, func.abs(table.c.data2)) meta2 = MetaData() table_c = table.tometadata(meta2) def _get_key(i): - return [i.name, i.unique] + \ - sorted(i.kwargs.items()) + \ - [str(col) for col in i.expressions] + return ( + [i.name, i.unique] + + sorted(i.kwargs.items()) + + [str(col) for col in i.expressions] + ) eq_( sorted([_get_key(i) for i in table.indexes]), - sorted([_get_key(i) for i in table_c.indexes]) + sorted([_get_key(i) for i in table_c.indexes]), ) def test_indexes_with_col_redefine(self): meta = MetaData() - table = Table('mytable', meta, - Column('id', Integer, primary_key=True), - Column('data1', Integer), - Column('data2', Integer), - Index('text', text('data1 + 1')), - ) - Index('multi', table.c.data1, table.c.data2) - Index('func', func.abs(table.c.data1)) - Index('multi-func', table.c.data1, func.abs(table.c.data2)) - - table = Table('mytable', meta, - Column('data1', Integer), - Column('data2', Integer), - extend_existing=True - ) + table = Table( + "mytable", + meta, + Column("id", Integer, primary_key=True), + Column("data1", Integer), + Column("data2", Integer), + Index("text", text("data1 + 1")), + ) + Index("multi", table.c.data1, table.c.data2) + Index("func", func.abs(table.c.data1)) + Index("multi-func", table.c.data1, func.abs(table.c.data2)) + + table = Table( + "mytable", + meta, + Column("data1", Integer), + Column("data2", Integer), + extend_existing=True, + ) meta2 = MetaData() table_c = table.tometadata(meta2) def _get_key(i): - return [i.name, i.unique] + \ - sorted(i.kwargs.items()) + \ - [str(col) for col in i.expressions] + return ( + [i.name, i.unique] + + sorted(i.kwargs.items()) + + [str(col) for col in i.expressions] + ) eq_( sorted([_get_key(i) for i in 
table.indexes]), - sorted([_get_key(i) for i in table_c.indexes]) + sorted([_get_key(i) for i in table_c.indexes]), ) @emits_warning("Table '.+' already exists within the given MetaData") def test_already_exists(self): meta1 = MetaData() - table1 = Table('mytable', meta1, - Column('myid', Integer, primary_key=True), - ) + table1 = Table( + "mytable", meta1, Column("myid", Integer, primary_key=True) + ) meta2 = MetaData() - table2 = Table('mytable', meta2, - Column('yourid', Integer, primary_key=True), - ) + table2 = Table( + "mytable", meta2, Column("yourid", Integer, primary_key=True) + ) table_c = table1.tometadata(meta2) table_d = table2.tometadata(meta2) @@ -1117,86 +1255,97 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables): assert table_c is table_d def test_default_schema_metadata(self): - meta = MetaData(schema='myschema') + meta = MetaData(schema="myschema") table = Table( - 'mytable', + "mytable", meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), Column( - 'myid', - Integer, - primary_key=True), - Column( - 'name', - String(40), - nullable=True), - Column( - 'description', - String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), + "description", String(30), CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), ) table2 = Table( - 'othertable', meta, Column( - 'id', Integer, primary_key=True), Column( - 'myid', Integer, ForeignKey('myschema.mytable.myid')), ) + "othertable", + meta, + Column("id", Integer, primary_key=True), + Column("myid", Integer, ForeignKey("myschema.mytable.myid")), + ) - meta2 = MetaData(schema='someschema') + meta2 = MetaData(schema="someschema") table_c = table.tometadata(meta2, schema=None) table2_c = table2.tometadata(meta2, schema=None) - eq_(str(table_c.join(table2_c).onclause), - str(table_c.c.myid == table2_c.c.myid)) - eq_(str(table_c.join(table2_c).onclause), - "someschema.mytable.myid = someschema.othertable.myid") + eq_( + str(table_c.join(table2_c).onclause), + str(table_c.c.myid == table2_c.c.myid), + ) + eq_( + str(table_c.join(table2_c).onclause), + "someschema.mytable.myid = someschema.othertable.myid", + ) def test_strip_schema(self): meta = MetaData() - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - Column('name', String(40), nullable=True), - Column('description', String(30), - CheckConstraint("description='hi'")), - UniqueConstraint('name'), - ) + table = Table( + "mytable", + meta, + Column("myid", Integer, primary_key=True), + Column("name", String(40), nullable=True), + Column( + "description", String(30), CheckConstraint("description='hi'") + ), + UniqueConstraint("name"), + ) - table2 = Table('othertable', meta, - Column('id', Integer, primary_key=True), - Column('myid', Integer, ForeignKey('mytable.myid')), - ) + table2 = Table( + "othertable", + meta, + Column("id", Integer, primary_key=True), + Column("myid", Integer, ForeignKey("mytable.myid")), + ) meta2 = MetaData() table_c = table.tometadata(meta2, schema=None) table2_c = table2.tometadata(meta2, schema=None) - eq_(str(table_c.join(table2_c).onclause), str(table_c.c.myid - == table2_c.c.myid)) - eq_(str(table_c.join(table2_c).onclause), - 'mytable.myid = othertable.myid') + eq_( + str(table_c.join(table2_c).onclause), + str(table_c.c.myid == table2_c.c.myid), + ) + eq_( + str(table_c.join(table2_c).onclause), + "mytable.myid = othertable.myid", + ) def test_unique_true_flag(self): meta = MetaData() - table = Table('mytable', meta, Column('x', Integer, 
unique=True)) + table = Table("mytable", meta, Column("x", Integer, unique=True)) m2 = MetaData() t2 = table.tometadata(m2) eq_( - len([ - const for const - in t2.constraints - if isinstance(const, UniqueConstraint)]), - 1 + len( + [ + const + for const in t2.constraints + if isinstance(const, UniqueConstraint) + ] + ), + 1, ) def test_index_true_flag(self): meta = MetaData() - table = Table('mytable', meta, Column('x', Integer, index=True)) + table = Table("mytable", meta, Column("x", Integer, index=True)) m2 = MetaData() @@ -1214,168 +1363,161 @@ class InfoTest(fixtures.TestBase): eq_(m1.info, {"foo": "bar"}) def test_foreignkey_constraint_info(self): - fkc = ForeignKeyConstraint(['a'], ['b'], name='bar') + fkc = ForeignKeyConstraint(["a"], ["b"], name="bar") eq_(fkc.info, {}) fkc = ForeignKeyConstraint( - ['a'], ['b'], name='bar', info={"foo": "bar"}) + ["a"], ["b"], name="bar", info={"foo": "bar"} + ) eq_(fkc.info, {"foo": "bar"}) def test_foreignkey_info(self): - fkc = ForeignKey('a') + fkc = ForeignKey("a") eq_(fkc.info, {}) - fkc = ForeignKey('a', info={"foo": "bar"}) + fkc = ForeignKey("a", info={"foo": "bar"}) eq_(fkc.info, {"foo": "bar"}) def test_primarykey_constraint_info(self): - pkc = PrimaryKeyConstraint('a', name='x') + pkc = PrimaryKeyConstraint("a", name="x") eq_(pkc.info, {}) - pkc = PrimaryKeyConstraint('a', name='x', info={'foo': 'bar'}) - eq_(pkc.info, {'foo': 'bar'}) + pkc = PrimaryKeyConstraint("a", name="x", info={"foo": "bar"}) + eq_(pkc.info, {"foo": "bar"}) def test_unique_constraint_info(self): - uc = UniqueConstraint('a', name='x') + uc = UniqueConstraint("a", name="x") eq_(uc.info, {}) - uc = UniqueConstraint('a', name='x', info={'foo': 'bar'}) - eq_(uc.info, {'foo': 'bar'}) + uc = UniqueConstraint("a", name="x", info={"foo": "bar"}) + eq_(uc.info, {"foo": "bar"}) def test_check_constraint_info(self): - cc = CheckConstraint('foo=bar', name='x') + cc = CheckConstraint("foo=bar", name="x") eq_(cc.info, {}) - cc = CheckConstraint('foo=bar', name='x', info={'foo': 'bar'}) - eq_(cc.info, {'foo': 'bar'}) + cc = CheckConstraint("foo=bar", name="x", info={"foo": "bar"}) + eq_(cc.info, {"foo": "bar"}) def test_index_info(self): - ix = Index('x', 'a') + ix = Index("x", "a") eq_(ix.info, {}) - ix = Index('x', 'a', info={'foo': 'bar'}) - eq_(ix.info, {'foo': 'bar'}) + ix = Index("x", "a", info={"foo": "bar"}) + eq_(ix.info, {"foo": "bar"}) def test_column_info(self): - c = Column('x', Integer) + c = Column("x", Integer) eq_(c.info, {}) - c = Column('x', Integer, info={'foo': 'bar'}) - eq_(c.info, {'foo': 'bar'}) + c = Column("x", Integer, info={"foo": "bar"}) + eq_(c.info, {"foo": "bar"}) def test_table_info(self): - t = Table('x', MetaData()) + t = Table("x", MetaData()) eq_(t.info, {}) - t = Table('x', MetaData(), info={'foo': 'bar'}) - eq_(t.info, {'foo': 'bar'}) + t = Table("x", MetaData(), info={"foo": "bar"}) + eq_(t.info, {"foo": "bar"}) class TableTest(fixtures.TestBase, AssertsCompiledSQL): - @testing.requires.temporary_tables - @testing.skip_if('mssql', 'different col format') + @testing.skip_if("mssql", "different col format") def test_prefixes(self): from sqlalchemy import Table - table1 = Table("temporary_table_1", MetaData(), - Column("col1", Integer), - prefixes=["TEMPORARY"]) + + table1 = Table( + "temporary_table_1", + MetaData(), + Column("col1", Integer), + prefixes=["TEMPORARY"], + ) self.assert_compile( schema.CreateTable(table1), - "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)" + "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)", 
) - table2 = Table("temporary_table_2", MetaData(), - Column("col1", Integer), - prefixes=["VIRTUAL"]) + table2 = Table( + "temporary_table_2", + MetaData(), + Column("col1", Integer), + prefixes=["VIRTUAL"], + ) self.assert_compile( schema.CreateTable(table2), - "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" + "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)", ) def test_table_info(self): metadata = MetaData() - t1 = Table('foo', metadata, info={'x': 'y'}) - t2 = Table('bar', metadata, info={}) - t3 = Table('bat', metadata) - assert t1.info == {'x': 'y'} + t1 = Table("foo", metadata, info={"x": "y"}) + t2 = Table("bar", metadata, info={}) + t3 = Table("bat", metadata) + assert t1.info == {"x": "y"} assert t2.info == {} assert t3.info == {} for t in (t1, t2, t3): - t.info['bar'] = 'zip' - assert t.info['bar'] == 'zip' + t.info["bar"] = "zip" + assert t.info["bar"] == "zip" def test_reset_exported_passes(self): m = MetaData() - t = Table('t', m, Column('foo', Integer)) - eq_( - list(t.c), [t.c.foo] - ) + t = Table("t", m, Column("foo", Integer)) + eq_(list(t.c), [t.c.foo]) t._reset_exported() - eq_( - list(t.c), [t.c.foo] - ) + eq_(list(t.c), [t.c.foo]) def test_foreign_key_constraints_collection(self): metadata = MetaData() - t1 = Table('foo', metadata, Column('a', Integer)) + t1 = Table("foo", metadata, Column("a", Integer)) eq_(t1.foreign_key_constraints, set()) - fk1 = ForeignKey('q.id') - fk2 = ForeignKey('j.id') - fk3 = ForeignKeyConstraint(['b', 'c'], ['r.x', 'r.y']) + fk1 = ForeignKey("q.id") + fk2 = ForeignKey("j.id") + fk3 = ForeignKeyConstraint(["b", "c"], ["r.x", "r.y"]) - t1.append_column(Column('b', Integer, fk1)) - eq_( - t1.foreign_key_constraints, - set([fk1.constraint])) + t1.append_column(Column("b", Integer, fk1)) + eq_(t1.foreign_key_constraints, set([fk1.constraint])) - t1.append_column(Column('c', Integer, fk2)) - eq_( - t1.foreign_key_constraints, - set([fk1.constraint, fk2.constraint])) + t1.append_column(Column("c", Integer, fk2)) + eq_(t1.foreign_key_constraints, set([fk1.constraint, fk2.constraint])) t1.append_constraint(fk3) eq_( t1.foreign_key_constraints, - set([fk1.constraint, fk2.constraint, fk3])) + set([fk1.constraint, fk2.constraint, fk3]), + ) def test_c_immutable(self): m = MetaData() - t1 = Table('t', m, Column('x', Integer), Column('y', Integer)) - assert_raises( - TypeError, - t1.c.extend, [Column('z', Integer)] - ) + t1 = Table("t", m, Column("x", Integer), Column("y", Integer)) + assert_raises(TypeError, t1.c.extend, [Column("z", Integer)]) def assign(): - t1.c['z'] = Column('z', Integer) - assert_raises( - TypeError, - assign - ) + t1.c["z"] = Column("z", Integer) + + assert_raises(TypeError, assign) def assign2(): - t1.c.z = Column('z', Integer) - assert_raises( - TypeError, - assign2 - ) + t1.c.z = Column("z", Integer) + + assert_raises(TypeError, assign2) def test_c_mutate_after_unpickle(self): m = MetaData() - y = Column('y', Integer) - t1 = Table('t', m, Column('x', Integer), y) + y = Column("y", Integer) + t1 = Table("t", m, Column("x", Integer), y) t2 = pickle.loads(pickle.dumps(t1)) - z = Column('z', Integer) - g = Column('g', Integer) + z = Column("z", Integer) + g = Column("g", Integer) t2.append_column(z) is_(t1.c.contains_column(y), True) @@ -1389,39 +1531,39 @@ class TableTest(fixtures.TestBase, AssertsCompiledSQL): def test_autoincrement_replace(self): m = MetaData() - t = Table('t', m, - Column('id', Integer, primary_key=True) - ) + t = Table("t", m, Column("id", Integer, primary_key=True)) 
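# For reference, the implicit-autoincrement rule exercised here as a
# standalone sketch; _autoincrement_column is a private attribute,
# mirrored from the assertions in this diff.
from sqlalchemy import Column, Integer, MetaData, Table

t = Table("t", MetaData(), Column("id", Integer, primary_key=True))
# A lone Integer primary key column is picked without any explicit flag.
assert t._autoincrement_column is t.c.id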
is_(t._autoincrement_column, t.c.id) - t = Table('t', m, - Column('id', Integer, primary_key=True), - extend_existing=True - ) + t = Table( + "t", + m, + Column("id", Integer, primary_key=True), + extend_existing=True, + ) is_(t._autoincrement_column, t.c.id) def test_pk_args_standalone(self): m = MetaData() - t = Table('t', m, - Column('x', Integer, primary_key=True), - PrimaryKeyConstraint(mssql_clustered=True) - ) - eq_( - list(t.primary_key), [t.c.x] - ) - eq_( - t.primary_key.dialect_kwargs, {"mssql_clustered": True} + t = Table( + "t", + m, + Column("x", Integer, primary_key=True), + PrimaryKeyConstraint(mssql_clustered=True), ) + eq_(list(t.primary_key), [t.c.x]) + eq_(t.primary_key.dialect_kwargs, {"mssql_clustered": True}) def test_pk_cols_sets_flags(self): m = MetaData() - t = Table('t', m, - Column('x', Integer), - Column('y', Integer), - Column('z', Integer), - PrimaryKeyConstraint('x', 'y') - ) + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + PrimaryKeyConstraint("x", "y"), + ) eq_(t.c.x.primary_key, True) eq_(t.c.y.primary_key, True) eq_(t.c.z.primary_key, False) @@ -1432,10 +1574,12 @@ class TableTest(fixtures.TestBase, AssertsCompiledSQL): exc.SAWarning, "Table 't' specifies columns 'x' as primary_key=True, " "not matching locally specified columns 'q'", - Table, 't', m, - Column('x', Integer, primary_key=True), - Column('q', Integer), - PrimaryKeyConstraint('q') + Table, + "t", + m, + Column("x", Integer, primary_key=True), + Column("q", Integer), + PrimaryKeyConstraint("q"), ) def test_pk_col_mismatch_two(self): @@ -1444,40 +1588,46 @@ class TableTest(fixtures.TestBase, AssertsCompiledSQL): exc.SAWarning, "Table 't' specifies columns 'a', 'b', 'c' as primary_key=True, " "not matching locally specified columns 'b', 'c'", - Table, 't', m, - Column('a', Integer, primary_key=True), - Column('b', Integer, primary_key=True), - Column('c', Integer, primary_key=True), - PrimaryKeyConstraint('b', 'c') + Table, + "t", + m, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True), + Column("c", Integer, primary_key=True), + PrimaryKeyConstraint("b", "c"), ) @testing.emits_warning("Table 't'") def test_pk_col_mismatch_three(self): m = MetaData() - t = Table('t', m, - Column('x', Integer, primary_key=True), - Column('q', Integer), - PrimaryKeyConstraint('q') - ) + t = Table( + "t", + m, + Column("x", Integer, primary_key=True), + Column("q", Integer), + PrimaryKeyConstraint("q"), + ) eq_(list(t.primary_key), [t.c.q]) @testing.emits_warning("Table 't'") def test_pk_col_mismatch_four(self): m = MetaData() - t = Table('t', m, - Column('a', Integer, primary_key=True), - Column('b', Integer, primary_key=True), - Column('c', Integer, primary_key=True), - PrimaryKeyConstraint('b', 'c') - ) + t = Table( + "t", + m, + Column("a", Integer, primary_key=True), + Column("b", Integer, primary_key=True), + Column("c", Integer, primary_key=True), + PrimaryKeyConstraint("b", "c"), + ) eq_(list(t.primary_key), [t.c.b, t.c.c]) def test_pk_always_flips_nullable(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer), PrimaryKeyConstraint('x')) + t1 = Table("t1", m, Column("x", Integer), PrimaryKeyConstraint("x")) - t2 = Table('t2', m, Column('x', Integer, primary_key=True)) + t2 = Table("t2", m, Column("x", Integer, primary_key=True)) eq_(list(t1.primary_key), [t1.c.x]) @@ -1492,67 +1642,58 @@ class TableTest(fixtures.TestBase, AssertsCompiledSQL): class PKAutoIncrementTest(fixtures.TestBase): def 
test_multi_integer_no_autoinc(self): - pk = PrimaryKeyConstraint( - Column('a', Integer), - Column('b', Integer) - ) - t = Table('t', MetaData()) + pk = PrimaryKeyConstraint(Column("a", Integer), Column("b", Integer)) + t = Table("t", MetaData()) t.append_constraint(pk) is_(pk._autoincrement_column, None) def test_multi_integer_multi_autoinc(self): pk = PrimaryKeyConstraint( - Column('a', Integer, autoincrement=True), - Column('b', Integer, autoincrement=True) + Column("a", Integer, autoincrement=True), + Column("b", Integer, autoincrement=True), ) - t = Table('t', MetaData()) + t = Table("t", MetaData()) t.append_constraint(pk) assert_raises_message( exc.ArgumentError, "Only one Column may be marked", - lambda: pk._autoincrement_column + lambda: pk._autoincrement_column, ) def test_single_integer_no_autoinc(self): - pk = PrimaryKeyConstraint( - Column('a', Integer), - ) - t = Table('t', MetaData()) + pk = PrimaryKeyConstraint(Column("a", Integer)) + t = Table("t", MetaData()) t.append_constraint(pk) - is_(pk._autoincrement_column, pk.columns['a']) + is_(pk._autoincrement_column, pk.columns["a"]) def test_single_string_no_autoinc(self): - pk = PrimaryKeyConstraint( - Column('a', String), - ) - t = Table('t', MetaData()) + pk = PrimaryKeyConstraint(Column("a", String)) + t = Table("t", MetaData()) t.append_constraint(pk) is_(pk._autoincrement_column, None) def test_single_string_illegal_autoinc(self): - t = Table('t', MetaData(), Column('a', String, autoincrement=True)) - pk = PrimaryKeyConstraint( - t.c.a - ) + t = Table("t", MetaData(), Column("a", String, autoincrement=True)) + pk = PrimaryKeyConstraint(t.c.a) t.append_constraint(pk) assert_raises_message( exc.ArgumentError, "Column type VARCHAR on column 't.a'", - lambda: pk._autoincrement_column + lambda: pk._autoincrement_column, ) def test_single_integer_default(self): t = Table( - 't', MetaData(), - Column('a', Integer, autoincrement=True, default=lambda: 1)) - pk = PrimaryKeyConstraint( - t.c.a + "t", + MetaData(), + Column("a", Integer, autoincrement=True, default=lambda: 1), ) + pk = PrimaryKeyConstraint(t.c.a) t.append_constraint(pk) is_(pk._autoincrement_column, t.c.a) @@ -1562,47 +1703,45 @@ class PKAutoIncrementTest(fixtures.TestBase): # if the user puts autoincrement=True with a server_default, trust # them on it t = Table( - 't', MetaData(), - Column('a', Integer, - autoincrement=True, server_default=func.magic())) - pk = PrimaryKeyConstraint( - t.c.a + "t", + MetaData(), + Column( + "a", Integer, autoincrement=True, server_default=func.magic() + ), ) + pk = PrimaryKeyConstraint(t.c.a) t.append_constraint(pk) is_(pk._autoincrement_column, t.c.a) def test_implicit_autoinc_but_fks(self): m = MetaData() - Table('t1', m, Column('id', Integer, primary_key=True)) - t2 = Table( - 't2', MetaData(), - Column('a', Integer, ForeignKey('t1.id'))) - pk = PrimaryKeyConstraint( - t2.c.a - ) + Table("t1", m, Column("id", Integer, primary_key=True)) + t2 = Table("t2", MetaData(), Column("a", Integer, ForeignKey("t1.id"))) + pk = PrimaryKeyConstraint(t2.c.a) t2.append_constraint(pk) is_(pk._autoincrement_column, None) def test_explicit_autoinc_but_fks(self): m = MetaData() - Table('t1', m, Column('id', Integer, primary_key=True)) + Table("t1", m, Column("id", Integer, primary_key=True)) t2 = Table( - 't2', MetaData(), - Column('a', Integer, ForeignKey('t1.id'), autoincrement=True)) - pk = PrimaryKeyConstraint( - t2.c.a + "t2", + MetaData(), + Column("a", Integer, ForeignKey("t1.id"), autoincrement=True), ) + pk = PrimaryKeyConstraint(t2.c.a) 
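# Standalone sketch of the foreign-key interaction asserted in these
# tests: a column that is also an FK is never chosen implicitly as the
# autoincrement column (illustrative names; private attribute as above).
from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table

m = MetaData()
Table("t1", m, Column("id", Integer, primary_key=True))
t2 = Table(
    "t2", m, Column("a", Integer, ForeignKey("t1.id"), primary_key=True)
)
assert t2._autoincrement_column is None  # pass autoincrement=True to opt in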
         t2.append_constraint(pk)
         is_(pk._autoincrement_column, t2.c.a)

         t3 = Table(
-            't3', MetaData(),
-            Column('a', Integer,
-                   ForeignKey('t1.id'), autoincrement='ignore_fk'))
-        pk = PrimaryKeyConstraint(
-            t3.c.a
+            "t3",
+            MetaData(),
+            Column(
+                "a", Integer, ForeignKey("t1.id"), autoincrement="ignore_fk"
+            ),
         )
+        pk = PrimaryKeyConstraint(t3.c.a)
         t3.append_constraint(pk)
         is_(pk._autoincrement_column, t3.c.a)
@@ -1622,12 +1761,14 @@ class SchemaTypeTest(fixtures.TestBase):
         def _on_table_create(self, target, bind, **kw):
             super(SchemaTypeTest.TrackEvents, self)._on_table_create(
-                target, bind, **kw)
+                target, bind, **kw
+            )
             self.evt_targets += (target,)

         def _on_metadata_create(self, target, bind, **kw):
             super(SchemaTypeTest.TrackEvents, self)._on_metadata_create(
-                target, bind, **kw)
+                target, bind, **kw
+            )
             self.evt_targets += (target,)

     # TODO: Enum and Boolean put TypeEngine first. Changing that here
@@ -1644,7 +1785,6 @@ class SchemaTypeTest(fixtures.TestBase):
             pass

     class MyTypeWImpl(MyType):
-
         def _gen_dialect_impl(self, dialect):
             return self.adapt(SchemaTypeTest.MyTypeImpl)
@@ -1724,13 +1864,13 @@ class SchemaTypeTest(fixtures.TestBase):
                 orig_set_parent_w_dispatch(parent)
                 canary._set_parent_with_dispatch(parent)

-        with mock.patch.object(evt_target, '_set_parent', _set_parent):
+        with mock.patch.object(evt_target, "_set_parent", _set_parent):
             with mock.patch.object(
-                    evt_target, '_set_parent_with_dispatch',
-                    _set_parent_w_dispatch):
+                evt_target, "_set_parent_with_dispatch", _set_parent_w_dispatch
+            ):
                 event.listen(evt_target, "before_parent_attach", canary.go)

-                c = Column('q', typ)
+                c = Column("q", typ)

                 if double:
                     # no clean way yet to fix this, inner schema type is called
@@ -1741,8 +1881,8 @@ class SchemaTypeTest(fixtures.TestBase):
                             mock.call._set_parent(c),
                             mock.call.go(evt_target, c),
                             mock.call._set_parent(c),
-                            mock.call._set_parent_with_dispatch(c)
-                        ]
+                            mock.call._set_parent_with_dispatch(c),
+                        ],
                     )
                 else:
                     eq_(
@@ -1750,39 +1890,39 @@ class SchemaTypeTest(fixtures.TestBase):
                         [
                             mock.call.go(evt_target, c),
                             mock.call._set_parent(c),
-                            mock.call._set_parent_with_dispatch(c)
-                        ]
+                            mock.call._set_parent_with_dispatch(c),
+                        ],
                     )

     def test_independent_schema(self):
         m = MetaData()
         type_ = self.MyType(schema="q")
-        t1 = Table('x', m, Column("y", type_), schema="z")
+        t1 = Table("x", m, Column("y", type_), schema="z")
         eq_(t1.c.y.type.schema, "q")

     def test_inherit_schema(self):
         m = MetaData()
         type_ = self.MyType(schema="q", inherit_schema=True)
-        t1 = Table('x', m, Column("y", type_), schema="z")
+        t1 = Table("x", m, Column("y", type_), schema="z")
         eq_(t1.c.y.type.schema, "z")

     def test_independent_schema_enum(self):
         m = MetaData()
         type_ = sqltypes.Enum("a", schema="q")
-        t1 = Table('x', m, Column("y", type_), schema="z")
+        t1 = Table("x", m, Column("y", type_), schema="z")
         eq_(t1.c.y.type.schema, "q")

     def test_inherit_schema_enum(self):
         m = MetaData()
         type_ = sqltypes.Enum("a", "b", "c", schema="q", inherit_schema=True)
-        t1 = Table('x', m, Column("y", type_), schema="z")
+        t1 = Table("x", m, Column("y", type_), schema="z")
         eq_(t1.c.y.type.schema, "z")

     def test_tometadata_copy_type(self):
         m1 = MetaData()
         type_ = self.MyType()
-        t1 = Table('x', m1, Column("y", type_))
+        t1 = Table("x", m1, Column("y", type_))

         m2 = MetaData()
         t2 = t1.tometadata(m2)
@@ -1794,14 +1934,13 @@ class SchemaTypeTest(fixtures.TestBase):
         is_(t2.c.y.type.table, t2)

     def test_tometadata_copy_decorated(self):
-
         class MyDecorated(TypeDecorator):
             impl = self.MyType

         m1 = MetaData()
         type_ = MyDecorated(schema="z")
-        t1 = Table('x', m1, Column("y", type_))
+        t1 = Table("x", m1, Column("y", type_))

         m2 = MetaData()
         t2 = t1.tometadata(m2)
@@ -1811,7 +1950,7 @@ class SchemaTypeTest(fixtures.TestBase):
         m1 = MetaData()
         type_ = self.MyType()
-        t1 = Table('x', m1, Column("y", type_))
+        t1 = Table("x", m1, Column("y", type_))

         m2 = MetaData()
         t2 = t1.tometadata(m2, schema="bar")
@@ -1822,7 +1961,7 @@ class SchemaTypeTest(fixtures.TestBase):
         m1 = MetaData()
         type_ = self.MyType(inherit_schema=True)
-        t1 = Table('x', m1, Column("y", type_))
+        t1 = Table("x", m1, Column("y", type_))

         m2 = MetaData()
         t2 = t1.tometadata(m2, schema="bar")
@@ -1834,7 +1973,7 @@ class SchemaTypeTest(fixtures.TestBase):
         m1 = MetaData()
         type_ = self.MyType()
-        t1 = Table('x', m1, Column("y", type_))
+        t1 = Table("x", m1, Column("y", type_))

         m2 = MetaData()
         t2 = t1.tometadata(m2)
@@ -1851,17 +1990,17 @@ class SchemaTypeTest(fixtures.TestBase):
     def test_enum_column_copy_transfers_events(self):
         m = MetaData()

-        type_ = self.WrapEnum('a', 'b', 'c', name='foo')
-        y = Column('y', type_)
+        type_ = self.WrapEnum("a", "b", "c", name="foo")
+        y = Column("y", type_)
         y_copy = y.copy()
-        t1 = Table('x', m, y_copy)
+        t1 = Table("x", m, y_copy)

         is_true(y_copy.type._create_events)

         # for PostgreSQL, this will emit CREATE TYPE
         m.dispatch.before_create(t1, testing.db)
         try:
-            eq_(t1.c.y.type.evt_targets, (t1, ))
+            eq_(t1.c.y.type.evt_targets, (t1,))
         finally:
             # do the drop so that PostgreSQL emits DROP TYPE
             m.dispatch.after_drop(t1, testing.db)

     def test_enum_nonnative_column_copy_transfers_events(self):
         m = MetaData()

-        type_ = self.WrapEnum('a', 'b', 'c', name='foo', native_enum=False)
-        y = Column('y', type_)
+        type_ = self.WrapEnum("a", "b", "c", name="foo", native_enum=False)
+        y = Column("y", type_)
         y_copy = y.copy()
-        t1 = Table('x', m, y_copy)
+        t1 = Table("x", m, y_copy)

         is_true(y_copy.type._create_events)

         m.dispatch.before_create(t1, testing.db)
-        eq_(t1.c.y.type.evt_targets, (t1, ))
+        eq_(t1.c.y.type.evt_targets, (t1,))

     def test_enum_nonnative_column_copy_transfers_constraintpref(self):
         m = MetaData()

         type_ = self.WrapEnum(
-            'a', 'b', 'c', name='foo',
-            native_enum=False, create_constraint=False)
-        y = Column('y', type_)
+            "a",
+            "b",
+            "c",
+            name="foo",
+            native_enum=False,
+            create_constraint=False,
+        )
+        y = Column("y", type_)
         y_copy = y.copy()
-        Table('x', m, y_copy)
+        Table("x", m, y_copy)

         is_false(y_copy.type.create_constraint)
@@ -1895,9 +2039,9 @@ class SchemaTypeTest(fixtures.TestBase):
         m = MetaData()

         type_ = self.WrapBoolean()
-        y = Column('y', type_)
+        y = Column("y", type_)
         y_copy = y.copy()
-        t1 = Table('x', m, y_copy)
+        t1 = Table("x", m, y_copy)

         is_true(y_copy.type._create_events)
@@ -1905,9 +2049,9 @@ class SchemaTypeTest(fixtures.TestBase):
         m = MetaData()

         type_ = self.WrapBoolean(create_constraint=False)
-        y = Column('y', type_)
+        y = Column("y", type_)
         y_copy = y.copy()
-        Table('x', m, y_copy)
+        Table("x", m, y_copy)

         is_false(y_copy.type.create_constraint)
@@ -1915,7 +2059,7 @@ class SchemaTypeTest(fixtures.TestBase):
         m1 = MetaData()
         typ = self.MyType(metadata=m1)
         m1.dispatch.before_create(m1, testing.db)
-        eq_(typ.evt_targets, (m1, ))
+        eq_(typ.evt_targets, (m1,))

         dialect_impl = typ.dialect_impl(testing.db.dialect)
         eq_(dialect_impl.evt_targets, ())
@@ -1924,24 +2068,24 @@ class SchemaTypeTest(fixtures.TestBase):
         m1 = MetaData()
         typ = self.MyTypeWImpl(metadata=m1)
         m1.dispatch.before_create(m1, testing.db)
-        eq_(typ.evt_targets, (m1, ))
+        eq_(typ.evt_targets, (m1,))
         dialect_impl = typ.dialect_impl(testing.db.dialect)
-        eq_(dialect_impl.evt_targets, (m1, ))
+        eq_(dialect_impl.evt_targets, (m1,))

     def test_table_dispatch_decorator_schematype(self):
         m1 = MetaData()
         typ = self.MyTypeDecAndSchema()
-        t1 = Table('t1', m1, Column('x', typ))
+        t1 = Table("t1", m1, Column("x", typ))
         m1.dispatch.before_create(t1, testing.db)
-        eq_(typ.evt_targets, (t1, ))
+        eq_(typ.evt_targets, (t1,))

     def test_table_dispatch_no_new_impl(self):
         m1 = MetaData()
         typ = self.MyType()
-        t1 = Table('t1', m1, Column('x', typ))
+        t1 = Table("t1", m1, Column("x", typ))
         m1.dispatch.before_create(t1, testing.db)
-        eq_(typ.evt_targets, (t1, ))
+        eq_(typ.evt_targets, (t1,))

         dialect_impl = typ.dialect_impl(testing.db.dialect)
         eq_(dialect_impl.evt_targets, ())
@@ -1949,12 +2093,12 @@ class SchemaTypeTest(fixtures.TestBase):
     def test_table_dispatch_new_impl(self):
         m1 = MetaData()
         typ = self.MyTypeWImpl()
-        t1 = Table('t1', m1, Column('x', typ))
+        t1 = Table("t1", m1, Column("x", typ))
         m1.dispatch.before_create(t1, testing.db)
-        eq_(typ.evt_targets, (t1, ))
+        eq_(typ.evt_targets, (t1,))

         dialect_impl = typ.dialect_impl(testing.db.dialect)
-        eq_(dialect_impl.evt_targets, (t1, ))
+        eq_(dialect_impl.evt_targets, (t1,))

     def test_create_metadata_bound_no_crash(self):
         m1 = MetaData()
@@ -1965,127 +2109,103 @@ class SchemaTypeTest(fixtures.TestBase):
     def test_boolean_constraint_type_doesnt_double(self):
         m1 = MetaData()
-        t1 = Table('x', m1, Column("flag", Boolean()))
+        t1 = Table("x", m1, Column("flag", Boolean()))
         eq_(
-            len([
-                c for c in t1.constraints
-                if isinstance(c, CheckConstraint)]),
-            1
+            len([c for c in t1.constraints if isinstance(c, CheckConstraint)]),
+            1,
         )
         m2 = MetaData()
         t2 = t1.tometadata(m2)

         eq_(
-            len([
-                c for c in t2.constraints
-                if isinstance(c, CheckConstraint)]),
-            1
+            len([c for c in t2.constraints if isinstance(c, CheckConstraint)]),
+            1,
         )

     def test_enum_constraint_type_doesnt_double(self):
         m1 = MetaData()
-        t1 = Table('x', m1, Column("flag", Enum('a', 'b', 'c')))
+        t1 = Table("x", m1, Column("flag", Enum("a", "b", "c")))
         eq_(
-            len([
-                c for c in t1.constraints
-                if isinstance(c, CheckConstraint)]),
-            1
+            len([c for c in t1.constraints if isinstance(c, CheckConstraint)]),
+            1,
         )
         m2 = MetaData()
         t2 = t1.tometadata(m2)

         eq_(
-            len([
-                c for c in t2.constraints
-                if isinstance(c, CheckConstraint)]),
-            1
+            len([c for c in t2.constraints if isinstance(c, CheckConstraint)]),
+            1,
         )


 class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
-
     def test_default_schema_metadata_fk(self):
         m = MetaData(schema="foo")
-        t1 = Table('t1', m, Column('x', Integer))
-        t2 = Table('t2', m, Column('x', Integer, ForeignKey('t1.x')))
+        t1 = Table("t1", m, Column("x", Integer))
+        t2 = Table("t2", m, Column("x", Integer, ForeignKey("t1.x")))
         assert t2.c.x.references(t1.c.x)

     def test_ad_hoc_schema_equiv_fk(self):
         m = MetaData()
-        t1 = Table('t1', m, Column('x', Integer), schema="foo")
+        t1 = Table("t1", m, Column("x", Integer), schema="foo")
         t2 = Table(
-            't2',
-            m,
-            Column(
-                'x',
-                Integer,
-                ForeignKey('t1.x')),
-            schema="foo")
+            "t2", m, Column("x", Integer, ForeignKey("t1.x")), schema="foo"
+        )
         assert_raises(
-            exc.NoReferencedTableError,
-            lambda: t2.c.x.references(t1.c.x)
+            exc.NoReferencedTableError, lambda: t2.c.x.references(t1.c.x)
         )

     def test_default_schema_metadata_fk_alt_remote(self):
         m = MetaData(schema="foo")
-        t1 = Table('t1', m, Column('x', Integer))
-        t2 = Table('t2', m, Column('x', Integer, ForeignKey('t1.x')),
-                   schema="bar")
+        t1 = Table("t1", m, Column("x", Integer))
+        t2 = Table(
+            "t2", m, Column("x", Integer, ForeignKey("t1.x")), schema="bar"
+        )
         assert t2.c.x.references(t1.c.x)

     def test_default_schema_metadata_fk_alt_local_raises(self):
         m = MetaData(schema="foo")
-        t1 = Table('t1', m, Column('x', Integer), schema="bar")
-        t2 = Table('t2', m, Column('x', Integer, ForeignKey('t1.x')))
+        t1 = Table("t1", m, Column("x", Integer), schema="bar")
+        t2 = Table("t2", m, Column("x", Integer, ForeignKey("t1.x")))
         assert_raises(
-            exc.NoReferencedTableError,
-            lambda: t2.c.x.references(t1.c.x)
+            exc.NoReferencedTableError, lambda: t2.c.x.references(t1.c.x)
        )

     def test_default_schema_metadata_fk_alt_local(self):
         m = MetaData(schema="foo")
-        t1 = Table('t1', m, Column('x', Integer), schema="bar")
-        t2 = Table('t2', m, Column('x', Integer, ForeignKey('bar.t1.x')))
+        t1 = Table("t1", m, Column("x", Integer), schema="bar")
+        t2 = Table("t2", m, Column("x", Integer, ForeignKey("bar.t1.x")))
         assert t2.c.x.references(t1.c.x)

     def test_create_drop_schema(self):

         self.assert_compile(
-            schema.CreateSchema("sa_schema"),
-            "CREATE SCHEMA sa_schema"
+            schema.CreateSchema("sa_schema"), "CREATE SCHEMA sa_schema"
         )
         self.assert_compile(
-            schema.DropSchema("sa_schema"),
-            "DROP SCHEMA sa_schema"
+            schema.DropSchema("sa_schema"), "DROP SCHEMA sa_schema"
         )
         self.assert_compile(
             schema.DropSchema("sa_schema", cascade=True),
-            "DROP SCHEMA sa_schema CASCADE"
+            "DROP SCHEMA sa_schema CASCADE",
         )

     def test_iteration(self):
         metadata = MetaData()
         table1 = Table(
-            'table1',
+            "table1",
             metadata,
-            Column(
-                'col1',
-                Integer,
-                primary_key=True),
-            schema='someschema')
+            Column("col1", Integer, primary_key=True),
+            schema="someschema",
+        )
         table2 = Table(
-            'table2',
+            "table2",
             metadata,
-            Column(
-                'col1',
-                Integer,
-                primary_key=True),
-            Column(
-                'col2',
-                Integer,
-                ForeignKey('someschema.table1.col1')),
-            schema='someschema')
+            Column("col1", Integer, primary_key=True),
+            Column("col2", Integer, ForeignKey("someschema.table1.col1")),
+            schema="someschema",
+        )

         t1 = str(schema.CreateTable(table1).compile(bind=testing.db))
         t2 = str(schema.CreateTable(table2).compile(bind=testing.db))
@@ -2098,16 +2218,18 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):

 class UseExistingTest(fixtures.TablesTest):
-
     @classmethod
     def define_tables(cls, metadata):
-        Table('users', metadata,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(30)))
+        Table(
+            "users",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(30)),
+        )

     def _useexisting_fixture(self):
         meta2 = MetaData(testing.db)
-        Table('users', meta2, autoload=True)
+        Table("users", meta2, autoload=True)
         return meta2

     def _notexisting_fixture(self):
@@ -2117,21 +2239,23 @@ class UseExistingTest(fixtures.TablesTest):
         meta2 = self._useexisting_fixture()

         def go():
-            Table('users', meta2, Column('name',
-                                         Unicode), autoload=True)
+            Table("users", meta2, Column("name", Unicode), autoload=True)
+
         assert_raises_message(
             exc.InvalidRequestError,
-            "Table 'users' is already defined for this "
-            "MetaData instance.",
-            go
+            "Table 'users' is already defined for this " "MetaData instance.",
+            go,
         )

     def test_keep_plus_existing_raises(self):
         meta2 = self._useexisting_fixture()
         assert_raises(
             exc.ArgumentError,
-            Table, 'users', meta2, keep_existing=True,
-            extend_existing=True
+            Table,
+            "users",
+            meta2,
+            keep_existing=True,
+            extend_existing=True,
         )

     @testing.uses_deprecated()
@@ -2139,118 +2263,152 @@ class UseExistingTest(fixtures.TablesTest):
         meta2 = self._useexisting_fixture()
         assert_raises(
             exc.ArgumentError,
-            Table, 'users', meta2, useexisting=True,
-            extend_existing=True
+            Table,
+            "users",
+            meta2,
+            useexisting=True,
+            extend_existing=True,
         )

     def test_keep_existing_no_dupe_constraints(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2,
-                      Column('id', Integer),
-                      Column('name', Unicode),
-                      UniqueConstraint('name'),
-                      keep_existing=True
-                      )
-        assert 'name' in users.c
-        assert 'id' in users.c
+        users = Table(
+            "users",
+            meta2,
+            Column("id", Integer),
+            Column("name", Unicode),
+            UniqueConstraint("name"),
+            keep_existing=True,
+        )
+        assert "name" in users.c
+        assert "id" in users.c
         eq_(len(users.constraints), 2)

-        u2 = Table('users', meta2,
-                   Column('id', Integer),
-                   Column('name', Unicode),
-                   UniqueConstraint('name'),
-                   keep_existing=True
-                   )
+        u2 = Table(
+            "users",
+            meta2,
+            Column("id", Integer),
+            Column("name", Unicode),
+            UniqueConstraint("name"),
+            keep_existing=True,
+        )
         eq_(len(u2.constraints), 2)

     def test_extend_existing_dupes_constraints(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2,
-                      Column('id', Integer),
-                      Column('name', Unicode),
-                      UniqueConstraint('name'),
-                      extend_existing=True
-                      )
-        assert 'name' in users.c
-        assert 'id' in users.c
+        users = Table(
+            "users",
+            meta2,
+            Column("id", Integer),
+            Column("name", Unicode),
+            UniqueConstraint("name"),
+            extend_existing=True,
+        )
+        assert "name" in users.c
+        assert "id" in users.c
         eq_(len(users.constraints), 2)

-        u2 = Table('users', meta2,
-                   Column('id', Integer),
-                   Column('name', Unicode),
-                   UniqueConstraint('name'),
-                   extend_existing=True
-                   )
+        u2 = Table(
+            "users",
+            meta2,
+            Column("id", Integer),
+            Column("name", Unicode),
+            UniqueConstraint("name"),
+            extend_existing=True,
+        )
         # constraint got duped
         eq_(len(u2.constraints), 3)

     def test_keep_existing_coltype(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2, Column('name', Unicode),
-                      autoload=True, keep_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("name", Unicode),
+            autoload=True,
+            keep_existing=True,
+        )
         assert not isinstance(users.c.name.type, Unicode)

     def test_keep_existing_quote(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2, quote=True, autoload=True,
-                      keep_existing=True)
+        users = Table(
+            "users", meta2, quote=True, autoload=True, keep_existing=True
+        )
         assert not users.name.quote

     def test_keep_existing_add_column(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2,
-                      Column('foo', Integer),
-                      autoload=True,
-                      keep_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("foo", Integer),
+            autoload=True,
+            keep_existing=True,
+        )
         assert "foo" not in users.c

     def test_keep_existing_coltype_no_orig(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2, Column('name', Unicode),
-                      autoload=True, keep_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("name", Unicode),
+            autoload=True,
+            keep_existing=True,
+        )
         assert isinstance(users.c.name.type, Unicode)

     @testing.skip_if(
         lambda: testing.db.dialect.requires_name_normalize,
-        "test depends on lowercase as case insensitive")
+        "test depends on lowercase as case insensitive",
+    )
     def test_keep_existing_quote_no_orig(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2, quote=True,
-                      autoload=True,
-                      keep_existing=True)
+        users = Table(
+            "users", meta2, quote=True, autoload=True, keep_existing=True
+        )
         assert users.name.quote

     def test_keep_existing_add_column_no_orig(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2,
-                      Column('foo', Integer),
-                      autoload=True,
-                      keep_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("foo", Integer),
+            autoload=True,
+            keep_existing=True,
+        )
         assert "foo" in users.c

     def test_keep_existing_coltype_no_reflection(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2, Column('name', Unicode),
-                      keep_existing=True)
+        users = Table(
+            "users", meta2, Column("name", Unicode), keep_existing=True
+        )
         assert not isinstance(users.c.name.type, Unicode)

     def test_keep_existing_quote_no_reflection(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2, quote=True,
-                      keep_existing=True)
+        users = Table("users", meta2, quote=True, keep_existing=True)
         assert not users.name.quote

     def test_keep_existing_add_column_no_reflection(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2,
-                      Column('foo', Integer),
-                      keep_existing=True)
+        users = Table(
+            "users", meta2, Column("foo", Integer), keep_existing=True
+        )
         assert "foo" not in users.c

     def test_extend_existing_coltype(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2, Column('name', Unicode),
-                      autoload=True, extend_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("name", Unicode),
+            autoload=True,
+            extend_existing=True,
+        )
         assert isinstance(users.c.name.type, Unicode)

     def test_extend_existing_quote(self):
@@ -2258,46 +2416,63 @@ class UseExistingTest(fixtures.TablesTest):
         assert_raises_message(
             tsa.exc.ArgumentError,
             "Can't redefine 'quote' or 'quote_schema' arguments",
-            Table, 'users', meta2, quote=True, autoload=True,
-            extend_existing=True
+            Table,
+            "users",
+            meta2,
+            quote=True,
+            autoload=True,
+            extend_existing=True,
         )

     def test_extend_existing_add_column(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2,
-                      Column('foo', Integer),
-                      autoload=True,
-                      extend_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("foo", Integer),
+            autoload=True,
+            extend_existing=True,
+        )
         assert "foo" in users.c

     def test_extend_existing_coltype_no_orig(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2, Column('name', Unicode),
-                      autoload=True, extend_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("name", Unicode),
+            autoload=True,
+            extend_existing=True,
+        )
         assert isinstance(users.c.name.type, Unicode)

     @testing.skip_if(
         lambda: testing.db.dialect.requires_name_normalize,
-        "test depends on lowercase as case insensitive")
+        "test depends on lowercase as case insensitive",
+    )
     def test_extend_existing_quote_no_orig(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2, quote=True,
-                      autoload=True,
-                      extend_existing=True)
+        users = Table(
+            "users", meta2, quote=True, autoload=True, extend_existing=True
+        )
         assert users.name.quote

     def test_extend_existing_add_column_no_orig(self):
         meta2 = self._notexisting_fixture()
-        users = Table('users', meta2,
-                      Column('foo', Integer),
-                      autoload=True,
-                      extend_existing=True)
+        users = Table(
+            "users",
+            meta2,
+            Column("foo", Integer),
+            autoload=True,
+            extend_existing=True,
+        )
         assert "foo" in users.c

     def test_extend_existing_coltype_no_reflection(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2, Column('name', Unicode),
-                      extend_existing=True)
+        users = Table(
+            "users", meta2, Column("name", Unicode), extend_existing=True
+        )
         assert isinstance(users.c.name.type, Unicode)

     def test_extend_existing_quote_no_reflection(self):
@@ -2305,35 +2480,30 @@ class UseExistingTest(fixtures.TablesTest):
         assert_raises_message(
             tsa.exc.ArgumentError,
             "Can't redefine 'quote' or 'quote_schema' arguments",
-            Table, 'users', meta2, quote=True,
-            extend_existing=True
+            Table,
+            "users",
+            meta2,
+            quote=True,
+            extend_existing=True,
        )

     def test_extend_existing_add_column_no_reflection(self):
         meta2 = self._useexisting_fixture()
-        users = Table('users', meta2,
-                      Column('foo', Integer),
-                      extend_existing=True)
+        users = Table(
+            "users", meta2, Column("foo", Integer), extend_existing=True
+        )
         assert "foo" in users.c


 class ConstraintTest(fixtures.TestBase):
-
     def _single_fixture(self):
         m = MetaData()
-        t1 = Table('t1', m,
-                   Column('a', Integer),
-                   Column('b', Integer)
-                   )
+        t1 = Table("t1", m, Column("a", Integer), Column("b", Integer))

-        t2 = Table('t2', m,
-                   Column('a', Integer, ForeignKey('t1.a'))
-                   )
+        t2 = Table("t2", m, Column("a", Integer, ForeignKey("t1.a")))

-        t3 = Table('t3', m,
-                   Column('a', Integer)
-                   )
+        t3 = Table("t3", m, Column("a", Integer))
         return t1, t2, t3

     def _assert_index_col_x(self, t, i, columns=True):
@@ -2346,106 +2516,108 @@ class ConstraintTest(fixtures.TestBase):
     def test_separate_decl_columns(self):
         m = MetaData()
-        t = Table('t', m, Column('x', Integer))
-        i = Index('i', t.c.x)
+        t = Table("t", m, Column("x", Integer))
+        i = Index("i", t.c.x)
         self._assert_index_col_x(t, i)

     def test_separate_decl_columns_functional(self):
         m = MetaData()
-        t = Table('t', m, Column('x', Integer))
-        i = Index('i', func.foo(t.c.x))
+        t = Table("t", m, Column("x", Integer))
+        i = Index("i", func.foo(t.c.x))
         self._assert_index_col_x(t, i)

     def test_index_no_cols_private_table_arg(self):
         m = MetaData()
-        t = Table('t', m, Column('x', Integer))
-        i = Index('i', _table=t)
+        t = Table("t", m, Column("x", Integer))
+        i = Index("i", _table=t)
         is_(i.table, t)
         eq_(list(i.columns), [])

     def test_index_w_cols_private_table_arg(self):
         m = MetaData()
-        t = Table('t', m, Column('x', Integer))
-        i = Index('i', t.c.x, _table=t)
+        t = Table("t", m, Column("x", Integer))
+        i = Index("i", t.c.x, _table=t)
         is_(i.table, t)
         eq_(i.columns, [t.c.x])

     def test_inline_decl_columns(self):
         m = MetaData()
-        c = Column('x', Integer)
-        i = Index('i', c)
-        t = Table('t', m, c, i)
+        c = Column("x", Integer)
+        i = Index("i", c)
+        t = Table("t", m, c, i)
         self._assert_index_col_x(t, i)

     def test_inline_decl_columns_functional(self):
         m = MetaData()
-        c = Column('x', Integer)
-        i = Index('i', func.foo(c))
-        t = Table('t', m, c, i)
+        c = Column("x", Integer)
+        i = Index("i", func.foo(c))
+        t = Table("t", m, c, i)
         self._assert_index_col_x(t, i)

     def test_inline_decl_string(self):
         m = MetaData()
-        i = Index('i', "x")
-        t = Table('t', m, Column('x', Integer), i)
+        i = Index("i", "x")
+        t = Table("t", m, Column("x", Integer), i)
         self._assert_index_col_x(t, i)

     def test_inline_decl_textonly(self):
         m = MetaData()
-        i = Index('i', text("foobar(x)"))
-        t = Table('t', m, Column('x', Integer), i)
+        i = Index("i", text("foobar(x)"))
+        t = Table("t", m, Column("x", Integer), i)
         self._assert_index_col_x(t, i, columns=False)

     def test_separate_decl_textonly(self):
         m = MetaData()
-        i = Index('i', text("foobar(x)"))
-        t = Table('t', m, Column('x', Integer))
+        i = Index("i", text("foobar(x)"))
+        t = Table("t", m, Column("x", Integer))
         t.append_constraint(i)
         self._assert_index_col_x(t, i, columns=False)

     def test_unnamed_column_exception(self):
         # this can occur in some declarative situations
         c = Column(Integer)
-        idx = Index('q', c)
+        idx = Index("q", c)

         m = MetaData()
-        t = Table('t', m, Column('q'))
+        t = Table("t", m, Column("q"))
         assert_raises_message(
             exc.ArgumentError,
             "Can't add unnamed column to column collection",
-            t.append_constraint, idx
+            t.append_constraint,
+            idx,
         )

     def test_column_associated_w_lowercase_table(self):
         from sqlalchemy import table
-        c = Column('x', Integer)
-        table('foo', c)
-        idx = Index('q', c)
+
+        c = Column("x", Integer)
+        table("foo", c)
+        idx = Index("q", c)
         is_(idx.table, None)  # lower-case-T table doesn't have indexes

     def test_clauseelement_extraction_one(self):
-        t = Table('t', MetaData(), Column('x', Integer), Column('y', Integer))
+        t = Table("t", MetaData(), Column("x", Integer), Column("y", Integer))

         class MyThing(object):
             def __clause_element__(self):
                 return t.c.x + 5

-        idx = Index('foo', MyThing())
+        idx = Index("foo", MyThing())
         self._assert_index_col_x(t, idx)

     def test_clauseelement_extraction_two(self):
-        t = Table('t', MetaData(), Column('x', Integer), Column('y', Integer))
+        t = Table("t", MetaData(), Column("x", Integer), Column("y", Integer))

         class MyThing(object):
             def __clause_element__(self):
                 return t.c.x + 5

-        idx = Index('bar', MyThing(), t.c.y)
+        idx = Index("bar", MyThing(), t.c.y)

         eq_(set(t.indexes), set([idx]))

     def test_clauseelement_extraction_three(self):
-        t = Table('t', MetaData(), Column('x', Integer), Column('y', Integer))
+        t = Table("t", MetaData(), Column("x", Integer), Column("y", Integer))

         expr1 = t.c.x + 5
@@ -2453,7 +2625,7 @@ class ConstraintTest(fixtures.TestBase):
             def __clause_element__(self):
                 return expr1

-        idx = Index('bar', MyThing(), t.c.y)
+        idx = Index("bar", MyThing(), t.c.y)

         is_(idx.expressions[0], expr1)
         is_(idx.expressions[1], t.c.y)
@@ -2493,59 +2665,52 @@ class ConstraintTest(fixtures.TestBase):
         is_(fkc.referred_table, t1)

     def test_referred_table_accessor_not_available(self):
-        t1 = Table('t', MetaData(), Column('x', ForeignKey('q.id')))
+        t1 = Table("t", MetaData(), Column("x", ForeignKey("q.id")))
         fkc = list(t1.foreign_key_constraints)[0]
         assert_raises_message(
             exc.InvalidRequestError,
             "Foreign key associated with column 't.x' could not find "
             "table 'q' with which to generate a foreign key to target "
             "column 'id'",
-            getattr, fkc, "referred_table"
+            getattr,
+            fkc,
+            "referred_table",
        )

     def test_related_column_not_present_atfirst_ok(self):
         m = MetaData()
-        base_table = Table("base", m,
-                           Column("id", Integer, primary_key=True)
-                           )
-        fk = ForeignKey('base.q')
-        derived_table = Table("derived", m,
-                              Column("id", None, fk,
-                                     primary_key=True),
-                              )
-
-        base_table.append_column(Column('q', Integer))
+        base_table = Table("base", m, Column("id", Integer, primary_key=True))
+        fk = ForeignKey("base.q")
+        derived_table = Table(
+            "derived", m, Column("id", None, fk, primary_key=True)
+        )
+
+        base_table.append_column(Column("q", Integer))
         assert fk.column is base_table.c.q
         assert isinstance(derived_table.c.id.type, Integer)

     def test_related_column_not_present_atfirst_ok_onname(self):
         m = MetaData()
-        base_table = Table("base", m,
-                           Column("id", Integer, primary_key=True)
-                           )
-        fk = ForeignKey('base.q', link_to_name=True)
-        derived_table = Table("derived", m,
-                              Column("id", None, fk,
-                                     primary_key=True),
-                              )
-
-        base_table.append_column(Column('q', Integer, key='zz'))
+        base_table = Table("base", m, Column("id", Integer, primary_key=True))
+        fk = ForeignKey("base.q", link_to_name=True)
+        derived_table = Table(
+            "derived", m, Column("id", None, fk, primary_key=True)
+        )
+
+        base_table.append_column(Column("q", Integer, key="zz"))
         assert fk.column is base_table.c.zz
         assert isinstance(derived_table.c.id.type, Integer)

     def test_related_column_not_present_atfirst_ok_linktoname_conflict(self):
         m = MetaData()
-        base_table = Table("base", m,
-                           Column("id", Integer, primary_key=True)
-                           )
-        fk = ForeignKey('base.q', link_to_name=True)
-        derived_table = Table("derived", m,
-                              Column("id", None, fk,
-                                     primary_key=True),
-                              )
-
-        base_table.append_column(Column('zz', Integer, key='q'))
-        base_table.append_column(Column('q', Integer, key='zz'))
+        base_table = Table("base", m, Column("id", Integer, primary_key=True))
+        fk = ForeignKey("base.q", link_to_name=True)
+        derived_table = Table(
+            "derived", m, Column("id", None, fk, primary_key=True)
+        )
+
+        base_table.append_column(Column("zz", Integer, key="q"))
+        base_table.append_column(Column("q", Integer, key="zz"))
         assert fk.column is base_table.c.zz
         assert isinstance(derived_table.c.id.type, Integer)
@@ -2557,134 +2722,136 @@ class ConstraintTest(fixtures.TestBase):
             r"ForeignKeyConstraint on t1\(x, y\) refers to "
             "multiple remote tables: t2 and t3",
             Table,
-            't1', m, Column('x', Integer), Column('y', Integer),
-            ForeignKeyConstraint(['x', 'y'], ['t2.x', 't3.y'])
+            "t1",
+            m,
+            Column("x", Integer),
+            Column("y", Integer),
+            ForeignKeyConstraint(["x", "y"], ["t2.x", "t3.y"]),
        )

     def test_invalid_composite_fk_check_columns(self):
         m = MetaData()
-        t2 = Table('t2', m, Column('x', Integer))
-        t3 = Table('t3', m, Column('y', Integer))
+        t2 = Table("t2", m, Column("x", Integer))
+        t3 = Table("t3", m, Column("y", Integer))

         assert_raises_message(
             exc.ArgumentError,
             r"ForeignKeyConstraint on t1\(x, y\) refers to "
             "multiple remote tables: t2 and t3",
             Table,
-            't1', m, Column('x', Integer), Column('y', Integer),
-            ForeignKeyConstraint(['x', 'y'], [t2.c.x, t3.c.y])
+            "t1",
+            m,
+            Column("x", Integer),
+            Column("y", Integer),
+            ForeignKeyConstraint(["x", "y"], [t2.c.x, t3.c.y]),
        )

     def test_invalid_composite_fk_check_columns_notattached(self):
         m = MetaData()
-        x = Column('x', Integer)
-        y = Column('y', Integer)
+        x = Column("x", Integer)
+        y = Column("y", Integer)

         # no error is raised for this one right now.
         # which is a minor bug.
-        Table('t1', m, Column('x', Integer), Column('y', Integer),
-              ForeignKeyConstraint(['x', 'y'], [x, y])
-              )
+        Table(
+            "t1",
+            m,
+            Column("x", Integer),
+            Column("y", Integer),
+            ForeignKeyConstraint(["x", "y"], [x, y]),
+        )

-        Table('t2', m, x)
-        Table('t3', m, y)
+        Table("t2", m, x)
+        Table("t3", m, y)

     def test_constraint_copied_to_proxy_ok(self):
         m = MetaData()
-        Table('t1', m, Column('id', Integer, primary_key=True))
-        t2 = Table('t2', m, Column('id', Integer, ForeignKey('t1.id'),
-                                   primary_key=True))
+        Table("t1", m, Column("id", Integer, primary_key=True))
+        t2 = Table(
+            "t2",
+            m,
+            Column("id", Integer, ForeignKey("t1.id"), primary_key=True),
+        )

         s = tsa.select([t2])
         t2fk = list(t2.c.id.foreign_keys)[0]
         sfk = list(s.c.id.foreign_keys)[0]

         # the two FKs share the ForeignKeyConstraint
-        is_(
-            t2fk.constraint,
-            sfk.constraint
-        )
+        is_(t2fk.constraint, sfk.constraint)

         # but the ForeignKeyConstraint isn't
         # aware of the select's FK
-        eq_(
-            t2fk.constraint.elements,
-            [t2fk]
-        )
+        eq_(t2fk.constraint.elements, [t2fk])

     def test_type_propagate_composite_fk_string(self):
         metadata = MetaData()
         Table(
-            'a', metadata,
-            Column('key1', Integer, primary_key=True),
-            Column('key2', String(40), primary_key=True))
-
-        b = Table('b', metadata,
-                  Column('a_key1', None),
-                  Column('a_key2', None),
-                  Column('id', Integer, primary_key=True),
-                  ForeignKeyConstraint(['a_key1', 'a_key2'],
-                                       ['a.key1', 'a.key2'])
-                  )
+            "a",
+            metadata,
+            Column("key1", Integer, primary_key=True),
+            Column("key2", String(40), primary_key=True),
+        )
+
+        b = Table(
+            "b",
+            metadata,
+            Column("a_key1", None),
+            Column("a_key2", None),
+            Column("id", Integer, primary_key=True),
+            ForeignKeyConstraint(["a_key1", "a_key2"], ["a.key1", "a.key2"]),
+        )

         assert isinstance(b.c.a_key1.type, Integer)
         assert isinstance(b.c.a_key2.type, String)

     def test_type_propagate_composite_fk_col(self):
         metadata = MetaData()
-        a = Table('a', metadata,
-                  Column('key1', Integer, primary_key=True),
-                  Column('key2', String(40), primary_key=True))
-
-        b = Table('b', metadata,
-                  Column('a_key1', None),
-                  Column('a_key2', None),
-                  Column('id', Integer, primary_key=True),
-                  ForeignKeyConstraint(['a_key1', 'a_key2'],
-                                       [a.c.key1, a.c.key2])
-                  )
+        a = Table(
+            "a",
+            metadata,
+            Column("key1", Integer, primary_key=True),
+            Column("key2", String(40), primary_key=True),
+        )
+
+        b = Table(
+            "b",
+            metadata,
+            Column("a_key1", None),
+            Column("a_key2", None),
+            Column("id", Integer, primary_key=True),
+            ForeignKeyConstraint(["a_key1", "a_key2"], [a.c.key1, a.c.key2]),
+        )

         assert isinstance(b.c.a_key1.type, Integer)
         assert isinstance(b.c.a_key2.type, String)

     def test_type_propagate_standalone_fk_string(self):
         metadata = MetaData()
-        Table(
-            'a', metadata,
-            Column('key1', Integer, primary_key=True))
+        Table("a", metadata, Column("key1", Integer, primary_key=True))

-        b = Table('b', metadata,
-                  Column('a_key1', None, ForeignKey("a.key1")),
-                  )
+        b = Table("b", metadata, Column("a_key1", None, ForeignKey("a.key1")))

         assert isinstance(b.c.a_key1.type, Integer)

     def test_type_propagate_standalone_fk_col(self):
         metadata = MetaData()
-        a = Table('a', metadata,
-                  Column('key1', Integer, primary_key=True))
+        a = Table("a", metadata, Column("key1", Integer, primary_key=True))

-        b = Table('b', metadata,
-                  Column('a_key1', None, ForeignKey(a.c.key1)),
-                  )
+        b = Table("b", metadata, Column("a_key1", None, ForeignKey(a.c.key1)))

         assert isinstance(b.c.a_key1.type, Integer)

     def test_type_propagate_chained_string_source_first(self):
         metadata = MetaData()
-        Table(
-            'a', metadata,
-            Column('key1', Integer, primary_key=True)
-        )
+        Table("a", metadata, Column("key1", Integer, primary_key=True))

-        b = Table('b', metadata,
-                  Column('a_key1', None, ForeignKey("a.key1")),
-                  )
+        b = Table("b", metadata, Column("a_key1", None, ForeignKey("a.key1")))

-        c = Table('c', metadata,
-                  Column('b_key1', None, ForeignKey("b.a_key1")),
-                  )
+        c = Table(
+            "c", metadata, Column("b_key1", None, ForeignKey("b.a_key1"))
+        )

         assert isinstance(b.c.a_key1.type, Integer)
         assert isinstance(c.c.b_key1.type, Integer)
@@ -2692,17 +2859,13 @@ class ConstraintTest(fixtures.TestBase):
     def test_type_propagate_chained_string_source_last(self):
         metadata = MetaData()

-        b = Table('b', metadata,
-                  Column('a_key1', None, ForeignKey("a.key1")),
-                  )
+        b = Table("b", metadata, Column("a_key1", None, ForeignKey("a.key1")))

-        c = Table('c', metadata,
-                  Column('b_key1', None, ForeignKey("b.a_key1")),
-                  )
+        c = Table(
+            "c", metadata, Column("b_key1", None, ForeignKey("b.a_key1"))
+        )

-        Table(
-            'a', metadata,
-            Column('key1', Integer, primary_key=True))
+        Table("a", metadata, Column("key1", Integer, primary_key=True))

         assert isinstance(b.c.a_key1.type, Integer)
         assert isinstance(c.c.b_key1.type, Integer)
@@ -2710,21 +2873,31 @@ class ConstraintTest(fixtures.TestBase):
     def test_type_propagate_chained_string_source_last_onname(self):
         metadata = MetaData()

-        b = Table('b', metadata,
-                  Column(
-                      'a_key1', None,
-                      ForeignKey("a.key1", link_to_name=True), key="ak1"),
-                  )
+        b = Table(
+            "b",
+            metadata,
+            Column(
+                "a_key1",
+                None,
+                ForeignKey("a.key1", link_to_name=True),
+                key="ak1",
+            ),
+        )

-        c = Table('c', metadata,
-                  Column(
-                      'b_key1', None,
-                      ForeignKey("b.a_key1", link_to_name=True), key="bk1"),
-                  )
+        c = Table(
+            "c",
+            metadata,
+            Column(
+                "b_key1",
+                None,
+                ForeignKey("b.a_key1", link_to_name=True),
+                key="bk1",
+            ),
+        )

         Table(
-            'a', metadata,
-            Column('key1', Integer, primary_key=True, key='ak1'))
+            "a", metadata, Column("key1", Integer, primary_key=True, key="ak1")
+        )

         assert isinstance(b.c.ak1.type, Integer)
         assert isinstance(c.c.bk1.type, Integer)
@@ -2732,34 +2905,41 @@ class ConstraintTest(fixtures.TestBase):
     def test_type_propagate_chained_string_source_last_onname_conflict(self):
         metadata = MetaData()

-        b = Table('b', metadata,
-                  # b.c.key1 -> a.c.key1 -> String
-                  Column(
-                      'ak1', None,
-                      ForeignKey("a.key1", link_to_name=False), key="key1"),
-                  # b.c.ak1 -> a.c.ak1 -> Integer
-                  Column(
-                      'a_key1', None,
-                      ForeignKey("a.key1", link_to_name=True), key="ak1"),
-                  )
-
-        c = Table('c', metadata,
-                  # c.c.b_key1 -> b.c.ak1 -> Integer
-                  Column(
-                      'b_key1', None,
-                      ForeignKey("b.ak1", link_to_name=False)),
-                  # c.c.b_ak1 -> b.c.ak1
-                  Column(
-                      'b_ak1', None,
-                      ForeignKey("b.ak1", link_to_name=True)),
-                  )
+        b = Table(
+            "b",
+            metadata,
+            # b.c.key1 -> a.c.key1 -> String
+            Column(
+                "ak1",
+                None,
+                ForeignKey("a.key1", link_to_name=False),
+                key="key1",
+            ),
+            # b.c.ak1 -> a.c.ak1 -> Integer
+            Column(
+                "a_key1",
+                None,
+                ForeignKey("a.key1", link_to_name=True),
+                key="ak1",
+            ),
+        )
+
+        c = Table(
+            "c",
+            metadata,
+            # c.c.b_key1 -> b.c.ak1 -> Integer
+            Column("b_key1", None, ForeignKey("b.ak1", link_to_name=False)),
+            # c.c.b_ak1 -> b.c.ak1
+            Column("b_ak1", None, ForeignKey("b.ak1", link_to_name=True)),
+        )

         Table(
-            'a', metadata,
+            "a",
+            metadata,
             # a.c.key1
-            Column('ak1', String, key="key1"),
+            Column("ak1", String, key="key1"),
             # a.c.ak1
-            Column('key1', Integer, primary_key=True, key='ak1'),
         )

         assert isinstance(b.c.key1.type, String)
@@ -2770,30 +2950,26 @@ class ConstraintTest(fixtures.TestBase):
     def test_type_propagate_chained_col_orig_first(self):
         metadata = MetaData()
-        a = Table('a', metadata,
-                  Column('key1', Integer, primary_key=True))
+        a = Table("a", metadata, Column("key1", Integer, primary_key=True))

-        b = Table('b', metadata,
-                  Column('a_key1', None, ForeignKey(a.c.key1)),
-                  )
+        b = Table("b", metadata, Column("a_key1", None, ForeignKey(a.c.key1)))

-        c = Table('c', metadata,
-                  Column('b_key1', None, ForeignKey(b.c.a_key1)),
-                  )
+        c = Table(
+            "c", metadata, Column("b_key1", None, ForeignKey(b.c.a_key1))
+        )

         assert isinstance(b.c.a_key1.type, Integer)
         assert isinstance(c.c.b_key1.type, Integer)

     def test_column_accessor_col(self):
-        c1 = Column('x', Integer)
+        c1 = Column("x", Integer)
         fk = ForeignKey(c1)
         is_(fk.column, c1)

     def test_column_accessor_clause_element(self):
-        c1 = Column('x', Integer)
+        c1 = Column("x", Integer)

         class CThing(object):
-
             def __init__(self, c):
                 self.c = c
@@ -2809,50 +2985,58 @@ class ConstraintTest(fixtures.TestBase):
             exc.InvalidRequestError,
             "this ForeignKey object does not yet have a parent "
             "Column associated with it.",
-            getattr, fk, "column"
+            getattr,
+            fk,
+            "column",
        )

     def test_column_accessor_string_no_parent_table(self):
         fk = ForeignKey("sometable.somecol")
-        Column('x', fk)
+        Column("x", fk)
         assert_raises_message(
             exc.InvalidRequestError,
             "this ForeignKey's parent column is not yet "
             "associated with a Table.",
-            getattr, fk, "column"
+            getattr,
+            fk,
+            "column",
        )

     def test_column_accessor_string_no_target_table(self):
         fk = ForeignKey("sometable.somecol")
-        c1 = Column('x', fk)
-        Table('t', MetaData(), c1)
+        c1 = Column("x", fk)
+        Table("t", MetaData(), c1)
         assert_raises_message(
             exc.NoReferencedTableError,
             "Foreign key associated with column 't.x' could not find "
             "table 'sometable' with which to generate a "
             "foreign key to target column 'somecol'",
-            getattr, fk, "column"
+            getattr,
+            fk,
+            "column",
        )

     def test_column_accessor_string_no_target_column(self):
         fk = ForeignKey("sometable.somecol")
-        c1 = Column('x', fk)
+        c1 = Column("x", fk)
         m = MetaData()
-        Table('t', m, c1)
-        Table("sometable", m, Column('notsomecol', Integer))
+        Table("t", m, c1)
+        Table("sometable", m, Column("notsomecol", Integer))
         assert_raises_message(
             exc.NoReferencedColumnError,
             "Could not initialize target column for ForeignKey "
             "'sometable.somecol' on table 't': "
             "table 'sometable' has no column named 'somecol'",
-            getattr, fk, "column"
+            getattr,
+            fk,
+            "column",
        )

     def test_remove_table_fk_bookkeeping(self):
         metadata = MetaData()
-        fk = ForeignKey('t1.x')
-        t2 = Table('t2', metadata, Column('y', Integer, fk))
-        t3 = Table('t3', metadata, Column('y', Integer, ForeignKey('t1.x')))
+        fk = ForeignKey("t1.x")
+        t2 = Table("t2", metadata, Column("y", Integer, fk))
+        t3 = Table("t3", metadata, Column("y", Integer, ForeignKey("t1.x")))

         assert t2.key in metadata.tables
         assert ("t1", "x") in metadata._fk_memos
@@ -2869,13 +3053,15 @@ class ConstraintTest(fixtures.TestBase):
         assert fk not in metadata._fk_memos[("t1", "x")]

         # make the referenced table
-        t1 = Table('t1', metadata, Column('x', Integer))
+        t1 = Table("t1", metadata, Column("x", Integer))

         # t2 tells us exactly what's wrong
         assert_raises_message(
             exc.InvalidRequestError,
             "Table t2 is no longer associated with its parent MetaData",
-            getattr, fk, "column"
+            getattr,
+            fk,
+            "column",
        )

         # t3 is unaffected
@@ -2885,63 +3071,51 @@ class ConstraintTest(fixtures.TestBase):
         metadata.remove(t2)

     def test_double_fk_usage_raises(self):
-        f = ForeignKey('b.id')
+        f = ForeignKey("b.id")

-        Column('x', Integer, f)
+        Column("x", Integer, f)
         assert_raises(exc.InvalidRequestError, Column, "y", Integer, f)

     def test_auto_append_constraint(self):
         m = MetaData()
-        t = Table('tbl', m,
-                  Column('a', Integer),
-                  Column('b', Integer)
-                  )
+        t = Table("tbl", m, Column("a", Integer), Column("b", Integer))

-        t2 = Table('t2', m,
-                   Column('a', Integer),
-                   Column('b', Integer)
-                   )
+        t2 = Table("t2", m, Column("a", Integer), Column("b", Integer))

         for c in (
             UniqueConstraint(t.c.a),
             CheckConstraint(t.c.a > 5),
             ForeignKeyConstraint([t.c.a], [t2.c.a]),
-            PrimaryKeyConstraint(t.c.a)
+            PrimaryKeyConstraint(t.c.a),
         ):
             assert c in t.constraints
             t.append_constraint(c)
             assert c in t.constraints

-        c = Index('foo', t.c.a)
+        c = Index("foo", t.c.a)
         assert c in t.indexes

     def test_auto_append_lowercase_table(self):
         from sqlalchemy import table, column

-        t = table('t', column('a'))
-        t2 = table('t2', column('a'))
+        t = table("t", column("a"))
+        t2 = table("t2", column("a"))
         for c in (
             UniqueConstraint(t.c.a),
             CheckConstraint(t.c.a > 5),
             ForeignKeyConstraint([t.c.a], [t2.c.a]),
             PrimaryKeyConstraint(t.c.a),
-            Index('foo', t.c.a)
+            Index("foo", t.c.a),
         ):
             assert True

     def test_tometadata_ok(self):
         m = MetaData()
-        t = Table('tbl', m,
-                  Column('a', Integer),
-                  Column('b', Integer)
-                  )
+        t = Table("tbl", m, Column("a", Integer), Column("b", Integer))

-        t2 = Table('t2', m,
-                   Column('a', Integer),
-                   Column('b', Integer)
-                   )
+        t2 = Table("t2", m, Column("a", Integer), Column("b", Integer))

         UniqueConstraint(t.c.a)
         CheckConstraint(t.c.a > 5)
@@ -2959,10 +3133,7 @@ class ConstraintTest(fixtures.TestBase):
     def test_check_constraint_copy(self):
         m = MetaData()
-        t = Table('tbl', m,
-                  Column('a', Integer),
-                  Column('b', Integer)
-                  )
+        t = Table("tbl", m, Column("a", Integer), Column("b", Integer))
         ck = CheckConstraint(t.c.a > 5)
         ck2 = ck.copy()
         assert ck in t.constraints
@@ -2971,15 +3142,9 @@ class ConstraintTest(fixtures.TestBase):
     def test_ambig_check_constraint_auto_append(self):
         m = MetaData()
-        t = Table('tbl', m,
-                  Column('a', Integer),
-                  Column('b', Integer)
-                  )
+        t = Table("tbl", m, Column("a", Integer), Column("b", Integer))

-        t2 = Table('t2', m,
-                   Column('a', Integer),
-                   Column('b', Integer)
-                   )
+        t2 = Table("t2", m, Column("a", Integer), Column("b", Integer))
         c = CheckConstraint(t.c.a > t2.c.b)
         assert c not in t.constraints
         assert c not in t2.constraints

     def test_auto_append_ck_on_col_attach_one(self):
         m = MetaData()

-        a = Column('a', Integer)
-        b = Column('b', Integer)
+        a = Column("a", Integer)
+        b = Column("b", Integer)

         ck = CheckConstraint(a > b)

-        t = Table('tbl', m, a, b)
+        t = Table("tbl", m, a, b)
         assert ck in t.constraints

     def test_auto_append_ck_on_col_attach_two(self):
         m = MetaData()

-        a = Column('a', Integer)
-        b = Column('b', Integer)
-        c = Column('c', Integer)
+        a = Column("a", Integer)
+        b = Column("b", Integer)
+        c = Column("c", Integer)

         ck = CheckConstraint(a > b + c)

-        t = Table('tbl', m, a)
+        t = Table("tbl", m, a)
         assert ck not in t.constraints

         t.append_column(b)
@@ -3014,18 +3179,18 @@ class ConstraintTest(fixtures.TestBase):
     def test_auto_append_ck_on_col_attach_three(self):
         m = MetaData()

-        a = Column('a', Integer)
-        b = Column('b', Integer)
-        c = Column('c', Integer)
+        a = Column("a", Integer)
+        b = Column("b", Integer)
+        c = Column("c", Integer)

         ck = CheckConstraint(a > b + c)

-        t = Table('tbl', m, a)
+        t = Table("tbl", m, a)
         assert ck not in t.constraints

         t.append_column(b)
assert ck not in t.constraints - t2 = Table('t2', m) + t2 = Table("t2", m) t2.append_column(c) # two different tables, so CheckConstraint does nothing. @@ -3034,22 +3199,22 @@ class ConstraintTest(fixtures.TestBase): def test_auto_append_uq_on_col_attach_one(self): m = MetaData() - a = Column('a', Integer) - b = Column('b', Integer) + a = Column("a", Integer) + b = Column("b", Integer) uq = UniqueConstraint(a, b) - t = Table('tbl', m, a, b) + t = Table("tbl", m, a, b) assert uq in t.constraints def test_auto_append_uq_on_col_attach_two(self): m = MetaData() - a = Column('a', Integer) - b = Column('b', Integer) - c = Column('c', Integer) + a = Column("a", Integer) + b = Column("b", Integer) + c = Column("c", Integer) uq = UniqueConstraint(a, b, c) - t = Table('tbl', m, a) + t = Table("tbl", m, a) assert uq not in t.constraints t.append_column(b) @@ -3061,24 +3226,25 @@ class ConstraintTest(fixtures.TestBase): def test_auto_append_uq_on_col_attach_three(self): m = MetaData() - a = Column('a', Integer) - b = Column('b', Integer) - c = Column('c', Integer) + a = Column("a", Integer) + b = Column("b", Integer) + c = Column("c", Integer) uq = UniqueConstraint(a, b, c) - t = Table('tbl', m, a) + t = Table("tbl", m, a) assert uq not in t.constraints t.append_column(b) assert uq not in t.constraints - t2 = Table('t2', m) + t2 = Table("t2", m) # two different tables, so UniqueConstraint raises assert_raises_message( exc.ArgumentError, r"Column\(s\) 't2\.c' are not part of table 'tbl'\.", - t2.append_column, c + t2.append_column, + c, ) def test_auto_append_uq_on_col_attach_four(self): @@ -3088,12 +3254,12 @@ class ConstraintTest(fixtures.TestBase): """ m = MetaData() - a = Column('a', Integer) - b = Column('b', Integer) - c = Column('c', Integer) - uq = UniqueConstraint(a, 'b', 'c') + a = Column("a", Integer) + b = Column("b", Integer) + c = Column("c", Integer) + uq = UniqueConstraint(a, "b", "c") - t = Table('tbl', m, a) + t = Table("tbl", m, a) assert uq not in t.constraints t.append_column(b) @@ -3111,7 +3277,7 @@ class ConstraintTest(fixtures.TestBase): eq_( [cn for cn in t.constraints if isinstance(cn, UniqueConstraint)], - [uq] + [uq], ) def test_auto_append_uq_on_col_attach_five(self): @@ -3121,13 +3287,13 @@ class ConstraintTest(fixtures.TestBase): """ m = MetaData() - a = Column('a', Integer) - b = Column('b', Integer) - c = Column('c', Integer) + a = Column("a", Integer) + b = Column("b", Integer) + c = Column("c", Integer) - t = Table('tbl', m, a, c, b) + t = Table("tbl", m, a, c, b) - uq = UniqueConstraint(a, 'b', 'c') + uq = UniqueConstraint(a, "b", "c") assert uq in t.constraints @@ -3137,38 +3303,36 @@ class ConstraintTest(fixtures.TestBase): eq_( [cn for cn in t.constraints if isinstance(cn, UniqueConstraint)], - [uq] + [uq], ) def test_index_asserts_cols_standalone(self): metadata = MetaData() - t1 = Table('t1', metadata, - Column('x', Integer) - ) - t2 = Table('t2', metadata, - Column('y', Integer) - ) + t1 = Table("t1", metadata, Column("x", Integer)) + t2 = Table("t2", metadata, Column("y", Integer)) assert_raises_message( exc.ArgumentError, r"Column\(s\) 't2.y' are not part of table 't1'.", Index, - "bar", t1.c.x, t2.c.y + "bar", + t1.c.x, + t2.c.y, ) def test_index_asserts_cols_inline(self): metadata = MetaData() - t1 = Table('t1', metadata, - Column('x', Integer) - ) + t1 = Table("t1", metadata, Column("x", Integer)) assert_raises_message( exc.ArgumentError, "Index 'bar' is against table 't1', and " "cannot be associated with table 't2'.", - Table, 't2', metadata, - 
Column('y', Integer), - Index('bar', t1.c.x) + Table, + "t2", + metadata, + Column("y", Integer), + Index("bar", t1.c.x), ) def test_raise_index_nonexistent_name(self): @@ -3176,28 +3340,25 @@ class ConstraintTest(fixtures.TestBase): # the KeyError isn't ideal here, a nicer message # perhaps assert_raises( - KeyError, - Table, 't', m, Column('x', Integer), Index("foo", "q") + KeyError, Table, "t", m, Column("x", Integer), Index("foo", "q") ) def test_raise_not_a_column(self): - assert_raises( - exc.ArgumentError, - Index, "foo", 5 - ) + assert_raises(exc.ArgumentError, Index, "foo", 5) def test_raise_expr_no_column(self): - idx = Index('foo', func.lower(5)) + idx = Index("foo", func.lower(5)) assert_raises_message( exc.CompileError, "Index 'foo' is not associated with any table.", - schema.CreateIndex(idx).compile, dialect=testing.db.dialect + schema.CreateIndex(idx).compile, + dialect=testing.db.dialect, ) assert_raises_message( exc.CompileError, "Index 'foo' is not associated with any table.", - schema.CreateIndex(idx).compile + schema.CreateIndex(idx).compile, ) def test_no_warning_w_no_columns(self): @@ -3206,26 +3367,29 @@ class ConstraintTest(fixtures.TestBase): assert_raises_message( exc.CompileError, "Index 'foo' is not associated with any table.", - schema.CreateIndex(idx).compile, dialect=testing.db.dialect + schema.CreateIndex(idx).compile, + dialect=testing.db.dialect, ) assert_raises_message( exc.CompileError, "Index 'foo' is not associated with any table.", - schema.CreateIndex(idx).compile + schema.CreateIndex(idx).compile, ) def test_raise_clauseelement_not_a_column(self): m = MetaData() - t2 = Table('t2', m, Column('x', Integer)) + t2 = Table("t2", m, Column("x", Integer)) class SomeClass(object): - def __clause_element__(self): return t2 + assert_raises_message( exc.ArgumentError, r"Element Table\('t2', .* is not a string name or column element", - Index, "foo", SomeClass() + Index, + "foo", + SomeClass(), ) @@ -3233,28 +3397,30 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): """Test Column() construction.""" - __dialect__ = 'default' + __dialect__ = "default" def columns(self): - return [Column(Integer), - Column('b', Integer), - Column(Integer), - Column('d', Integer), - Column(Integer, name='e'), - Column(type_=Integer), - Column(Integer()), - Column('h', Integer()), - Column(type_=Integer())] + return [ + Column(Integer), + Column("b", Integer), + Column(Integer), + Column("d", Integer), + Column(Integer, name="e"), + Column(type_=Integer), + Column(Integer()), + Column("h", Integer()), + Column(type_=Integer()), + ] def test_basic(self): c = self.columns() - for i, v in ((0, 'a'), (2, 'c'), (5, 'f'), (6, 'g'), (8, 'i')): + for i, v in ((0, "a"), (2, "c"), (5, "f"), (6, "g"), (8, "i")): c[i].name = v c[i].key = v del i, v - tbl = Table('table', MetaData(), *c) + tbl = Table("table", MetaData(), *c) for i, col in enumerate(tbl.c): assert col.name == c[i].name @@ -3266,35 +3432,47 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): exc.ArgumentError, "Column must be constructed with a non-blank name or assign a " "non-blank .name ", - Table, 't', MetaData(), c) + Table, + "t", + MetaData(), + c, + ) def test_name_blank(self): - c = Column('', Integer) + c = Column("", Integer) assert_raises_message( exc.ArgumentError, "Column must be constructed with a non-blank name or assign a " "non-blank .name ", - Table, 't', MetaData(), c) + Table, + "t", + MetaData(), + c, + ) def test_dupe_column(self): - c = Column('x', Integer) - 
Table('t', MetaData(), c) + c = Column("x", Integer) + Table("t", MetaData(), c) assert_raises_message( exc.ArgumentError, "Column object 'x' already assigned to Table 't'", - Table, 'q', MetaData(), c) + Table, + "q", + MetaData(), + c, + ) def test_incomplete_key(self): c = Column(Integer) assert c.name is None assert c.key is None - c.name = 'named' - Table('t', MetaData(), c) + c.name = "named" + Table("t", MetaData(), c) - assert c.name == 'named' + assert c.name == "named" assert c.name == c.key def test_unique_index_flags_default_to_none(self): @@ -3302,28 +3480,29 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): eq_(c.unique, None) eq_(c.index, None) - c = Column('c', Integer, index=True) + c = Column("c", Integer, index=True) eq_(c.unique, None) eq_(c.index, True) - t = Table('t', MetaData(), c) + t = Table("t", MetaData(), c) eq_(list(t.indexes)[0].unique, False) c = Column(Integer, unique=True) eq_(c.unique, True) eq_(c.index, None) - c = Column('c', Integer, index=True, unique=True) + c = Column("c", Integer, index=True, unique=True) eq_(c.unique, True) eq_(c.index, True) - t = Table('t', MetaData(), c) + t = Table("t", MetaData(), c) eq_(list(t.indexes)[0].unique, True) def test_bogus(self): - assert_raises(exc.ArgumentError, Column, 'foo', name='bar') - assert_raises(exc.ArgumentError, Column, 'foo', Integer, - type_=Integer()) + assert_raises(exc.ArgumentError, Column, "foo", name="bar") + assert_raises( + exc.ArgumentError, Column, "foo", Integer, type_=Integer() + ) def test_custom_subclass_proxy(self): """test proxy generation of a Column subclass, can be compiled.""" @@ -3333,9 +3512,8 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): from sqlalchemy.sql import select class MyColumn(Column): - def _constructor(self, name, type, **kw): - kw['name'] = name + kw["name"] = name return MyColumn(type, **kw) def __init__(self, type, **kw): @@ -3350,13 +3528,10 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): return s + "-" id = MyColumn(Integer, primary_key=True) - id.name = 'id' + id.name = "id" name = MyColumn(String) - name.name = 'name' - t1 = Table('foo', MetaData(), - id, - name - ) + name.name = "name" + t1 = Table("foo", MetaData(), id, name) # goofy thing eq_(t1.c.name.my_goofy_thing(), "hi") @@ -3380,24 +3555,22 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): from sqlalchemy.sql import select class MyColumn(Column): - def __init__(self, type, **kw): Column.__init__(self, type, **kw) id = MyColumn(Integer, primary_key=True) - id.name = 'id' + id.name = "id" name = MyColumn(String) - name.name = 'name' - t1 = Table('foo', MetaData(), - id, - name - ) + name.name = "name" + t1 = Table("foo", MetaData(), id, name) assert_raises_message( TypeError, "Could not create a copy of this " "object. 
Ensure the class includes a _constructor()", - getattr, select([t1.select().alias()]), 'c' + getattr, + select([t1.select().alias()]), + "c", ) def test_custom_create(self): @@ -3412,7 +3585,7 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): text = "%s SPECIAL DIRECTIVE %s" % ( column.name, - compiler.type_compiler.process(column.type) + compiler.type_compiler.process(column.type), ) default = compiler.get_column_default_string(column) if default is not None: @@ -3423,23 +3596,23 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): if column.constraints: text += " ".join( - compiler.process(const) - for const in column.constraints) + compiler.process(const) for const in column.constraints + ) return text t = Table( - 'mytable', MetaData(), - Column('x', Integer, info={ - "special": True}, primary_key=True), - Column('y', String(50)), - Column('z', String(20), info={ - "special": True})) + "mytable", + MetaData(), + Column("x", Integer, info={"special": True}, primary_key=True), + Column("y", String(50)), + Column("z", String(20), info={"special": True}), + ) self.assert_compile( schema.CreateTable(t), "CREATE TABLE mytable (x SPECIAL DIRECTIVE INTEGER " "NOT NULL, y VARCHAR(50), " - "z SPECIAL DIRECTIVE VARCHAR(20), PRIMARY KEY (x))" + "z SPECIAL DIRECTIVE VARCHAR(20), PRIMARY KEY (x))", ) deregister(schema.CreateColumn) @@ -3450,127 +3623,126 @@ class ColumnDefaultsTest(fixtures.TestBase): """test assignment of default fixures to columns""" def _fixture(self, *arg, **kw): - return Column('x', Integer, *arg, **kw) + return Column("x", Integer, *arg, **kw) def test_server_default_positional(self): - target = schema.DefaultClause('y') + target = schema.DefaultClause("y") c = self._fixture(target) assert c.server_default is target assert target.column is c def test_onupdate_default_not_server_default_one(self): - target1 = schema.DefaultClause('y') - target2 = schema.DefaultClause('z') + target1 = schema.DefaultClause("y") + target2 = schema.DefaultClause("z") c = self._fixture(server_default=target1, server_onupdate=target2) - eq_(c.server_default.arg, 'y') - eq_(c.server_onupdate.arg, 'z') + eq_(c.server_default.arg, "y") + eq_(c.server_onupdate.arg, "z") def test_onupdate_default_not_server_default_two(self): - target1 = schema.DefaultClause('y', for_update=True) - target2 = schema.DefaultClause('z', for_update=True) + target1 = schema.DefaultClause("y", for_update=True) + target2 = schema.DefaultClause("z", for_update=True) c = self._fixture(server_default=target1, server_onupdate=target2) - eq_(c.server_default.arg, 'y') - eq_(c.server_onupdate.arg, 'z') + eq_(c.server_default.arg, "y") + eq_(c.server_onupdate.arg, "z") def test_onupdate_default_not_server_default_three(self): - target1 = schema.DefaultClause('y', for_update=False) - target2 = schema.DefaultClause('z', for_update=True) + target1 = schema.DefaultClause("y", for_update=False) + target2 = schema.DefaultClause("z", for_update=True) c = self._fixture(target1, target2) - eq_(c.server_default.arg, 'y') - eq_(c.server_onupdate.arg, 'z') + eq_(c.server_default.arg, "y") + eq_(c.server_onupdate.arg, "z") def test_onupdate_default_not_server_default_four(self): - target1 = schema.DefaultClause('y', for_update=False) + target1 = schema.DefaultClause("y", for_update=False) c = self._fixture(server_onupdate=target1) is_(c.server_default, None) - eq_(c.server_onupdate.arg, 'y') + eq_(c.server_onupdate.arg, "y") def test_server_default_keyword_as_schemaitem(self): - target = 
schema.DefaultClause('y') + target = schema.DefaultClause("y") c = self._fixture(server_default=target) assert c.server_default is target assert target.column is c def test_server_default_keyword_as_clause(self): - target = 'y' + target = "y" c = self._fixture(server_default=target) assert c.server_default.arg == target assert c.server_default.column is c def test_server_default_onupdate_positional(self): - target = schema.DefaultClause('y', for_update=True) + target = schema.DefaultClause("y", for_update=True) c = self._fixture(target) assert c.server_onupdate is target assert target.column is c def test_server_default_onupdate_keyword_as_schemaitem(self): - target = schema.DefaultClause('y', for_update=True) + target = schema.DefaultClause("y", for_update=True) c = self._fixture(server_onupdate=target) assert c.server_onupdate is target assert target.column is c def test_server_default_onupdate_keyword_as_clause(self): - target = 'y' + target = "y" c = self._fixture(server_onupdate=target) assert c.server_onupdate.arg == target assert c.server_onupdate.column is c def test_column_default_positional(self): - target = schema.ColumnDefault('y') + target = schema.ColumnDefault("y") c = self._fixture(target) assert c.default is target assert target.column is c def test_column_default_keyword_as_schemaitem(self): - target = schema.ColumnDefault('y') + target = schema.ColumnDefault("y") c = self._fixture(default=target) assert c.default is target assert target.column is c def test_column_default_keyword_as_clause(self): - target = 'y' + target = "y" c = self._fixture(default=target) assert c.default.arg == target assert c.default.column is c def test_column_default_onupdate_positional(self): - target = schema.ColumnDefault('y', for_update=True) + target = schema.ColumnDefault("y", for_update=True) c = self._fixture(target) assert c.onupdate is target assert target.column is c def test_column_default_onupdate_keyword_as_schemaitem(self): - target = schema.ColumnDefault('y', for_update=True) + target = schema.ColumnDefault("y", for_update=True) c = self._fixture(onupdate=target) assert c.onupdate is target assert target.column is c def test_column_default_onupdate_keyword_as_clause(self): - target = 'y' + target = "y" c = self._fixture(onupdate=target) assert c.onupdate.arg == target assert c.onupdate.column is c class ColumnOptionsTest(fixtures.TestBase): - def test_default_generators(self): - g1, g2 = Sequence('foo_id_seq'), ColumnDefault('f5') + g1, g2 = Sequence("foo_id_seq"), ColumnDefault("f5") assert Column(String, default=g1).default is g1 assert Column(String, onupdate=g1).onupdate is g1 assert Column(String, default=g2).default is g2 assert Column(String, onupdate=g2).onupdate is g2 def _null_type_error(self, col): - t = Table('t', MetaData(), col) + t = Table("t", MetaData(), col) assert_raises_message( exc.CompileError, r"\(in table 't', column 'foo'\): Can't generate DDL for NullType", - schema.CreateTable(t).compile + schema.CreateTable(t).compile, ) def _no_name_error(self, col): @@ -3578,13 +3750,16 @@ class ColumnOptionsTest(fixtures.TestBase): exc.ArgumentError, "Column must be constructed with a non-blank name or " "assign a non-blank .name", - Table, 't', MetaData(), col + Table, + "t", + MetaData(), + col, ) def _no_error(self, col): m = MetaData() - b = Table('bar', m, Column('id', Integer)) - t = Table('t', m, col) + b = Table("bar", m, Column("id", Integer)) + t = Table("t", m, col) schema.CreateTable(t).compile() def test_argument_signatures(self): @@ -3597,71 +3772,82 @@ 
@@ -3597,71 +3772,82 @@ class ColumnOptionsTest(fixtures.TestBase):
         self._null_type_error(Column("foo", Sequence("a")))
 
-        self._no_name_error(Column(ForeignKey('bar.id')))
+        self._no_name_error(Column(ForeignKey("bar.id")))
 
-        self._no_error(Column("foo", ForeignKey('bar.id')))
+        self._no_error(Column("foo", ForeignKey("bar.id")))
 
-        self._no_name_error(Column(ForeignKey('bar.id'), default="foo"))
+        self._no_name_error(Column(ForeignKey("bar.id"), default="foo"))
 
-        self._no_name_error(Column(ForeignKey('bar.id'), Sequence("a")))
-        self._no_error(Column("foo", ForeignKey('bar.id'), default="foo"))
-        self._no_error(Column("foo", ForeignKey('bar.id'), Sequence("a")))
+        self._no_name_error(Column(ForeignKey("bar.id"), Sequence("a")))
+        self._no_error(Column("foo", ForeignKey("bar.id"), default="foo"))
+        self._no_error(Column("foo", ForeignKey("bar.id"), Sequence("a")))
 
     def test_column_info(self):
-        c1 = Column('foo', String, info={'x': 'y'})
-        c2 = Column('bar', String, info={})
-        c3 = Column('bat', String)
-        assert c1.info == {'x': 'y'}
+        c1 = Column("foo", String, info={"x": "y"})
+        c2 = Column("bar", String, info={})
+        c3 = Column("bat", String)
+        assert c1.info == {"x": "y"}
         assert c2.info == {}
         assert c3.info == {}
 
         for c in (c1, c2, c3):
-            c.info['bar'] = 'zip'
-            assert c.info['bar'] == 'zip'
+            c.info["bar"] = "zip"
+            assert c.info["bar"] == "zip"
 
 
 class CatchAllEventsTest(fixtures.RemovesEvents, fixtures.TestBase):
-
     def test_all_events(self):
         canary = []
 
         def before_attach(obj, parent):
-            canary.append("%s->%s" % (obj.__class__.__name__,
-                                      parent.__class__.__name__))
+            canary.append(
+                "%s->%s" % (obj.__class__.__name__, parent.__class__.__name__)
+            )
 
         def after_attach(obj, parent):
             canary.append("%s->%s" % (obj.__class__.__name__, parent))
 
         self.event_listen(
-            schema.SchemaItem,
-            "before_parent_attach",
-            before_attach)
+            schema.SchemaItem, "before_parent_attach", before_attach
+        )
         self.event_listen(
-            schema.SchemaItem,
-            "after_parent_attach",
-            after_attach)
+            schema.SchemaItem, "after_parent_attach", after_attach
+        )
 
         m = MetaData()
-        Table('t1', m,
-              Column('id', Integer, Sequence('foo_id'), primary_key=True),
-              Column('bar', String, ForeignKey('t2.id'))
-              )
-        Table('t2', m,
-              Column('id', Integer, primary_key=True),
-              )
+        Table(
+            "t1",
+            m,
+            Column("id", Integer, Sequence("foo_id"), primary_key=True),
+            Column("bar", String, ForeignKey("t2.id")),
+        )
+        Table("t2", m, Column("id", Integer, primary_key=True))
 
         eq_(
             canary,
-            ['Sequence->Column', 'Sequence->id', 'ForeignKey->Column',
-             'ForeignKey->bar', 'Table->MetaData',
-             'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t1',
-             'Column->Table', 'Column->t1', 'Column->Table',
-             'Column->t1', 'ForeignKeyConstraint->Table',
-             'ForeignKeyConstraint->t1', 'Table->MetaData(bind=None)',
-             'Table->MetaData', 'PrimaryKeyConstraint->Table',
-             'PrimaryKeyConstraint->t2', 'Column->Table', 'Column->t2',
-             'Table->MetaData(bind=None)']
+            [
+                "Sequence->Column",
+                "Sequence->id",
+                "ForeignKey->Column",
+                "ForeignKey->bar",
+                "Table->MetaData",
+                "PrimaryKeyConstraint->Table",
+                "PrimaryKeyConstraint->t1",
+                "Column->Table",
+                "Column->t1",
+                "Column->Table",
+                "Column->t1",
+                "ForeignKeyConstraint->Table",
+                "ForeignKeyConstraint->t1",
+                "Table->MetaData(bind=None)",
+                "Table->MetaData",
+                "PrimaryKeyConstraint->Table",
+                "PrimaryKeyConstraint->t2",
+                "Column->Table",
+                "Column->t2",
+                "Table->MetaData(bind=None)",
+            ],
         )
 
     def test_events_per_constraint(self):
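The before_parent_attach/after_parent_attach hooks exercised above are ordinary schema events; outside the test harness they are registered through sqlalchemy.event.listen. A rough standalone sketch, where the listener body is illustrative:

    from sqlalchemy import Column, Integer, MetaData, Table, event
    from sqlalchemy.schema import SchemaItem

    def before_attach(obj, parent):
        # fires for each schema construct as it is attached to its parent
        print("%s->%s" % (obj.__class__.__name__, parent.__class__.__name__))

    event.listen(SchemaItem, "before_parent_attach", before_attach)

    m = MetaData()
    Table("t", m, Column("id", Integer, primary_key=True))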
@@ -3669,80 +3855,82 @@ class CatchAllEventsTest(fixtures.RemovesEvents, fixtures.TestBase):
         def evt(target):
             def before_attach(obj, parent):
-                canary.append("%s->%s" % (target.__name__,
-                                          parent.__class__.__name__))
+                canary.append(
+                    "%s->%s" % (target.__name__, parent.__class__.__name__)
+                )
 
             def after_attach(obj, parent):
-                assert hasattr(obj, 'name')  # so we can change it
+                assert hasattr(obj, "name")  # so we can change it
                 canary.append("%s->%s" % (target.__name__, parent))
+
             self.event_listen(target, "before_parent_attach", before_attach)
             self.event_listen(target, "after_parent_attach", after_attach)
 
         for target in [
-            schema.ForeignKeyConstraint, schema.PrimaryKeyConstraint,
+            schema.ForeignKeyConstraint,
+            schema.PrimaryKeyConstraint,
             schema.UniqueConstraint,
             schema.CheckConstraint,
-            schema.Index
+            schema.Index,
         ]:
             evt(target)
 
         m = MetaData()
-        Table('t1', m,
-              Column('id', Integer, Sequence('foo_id'), primary_key=True),
-              Column('bar', String, ForeignKey('t2.id'), index=True),
-              Column('bat', Integer, unique=True),
-              )
-        Table('t2', m,
-              Column('id', Integer, primary_key=True),
-              Column('bar', Integer),
-              Column('bat', Integer),
-              CheckConstraint("bar>5"),
-              UniqueConstraint('bar', 'bat'),
-              Index(None, 'bar', 'bat')
-              )
+        Table(
+            "t1",
+            m,
+            Column("id", Integer, Sequence("foo_id"), primary_key=True),
+            Column("bar", String, ForeignKey("t2.id"), index=True),
+            Column("bat", Integer, unique=True),
+        )
+        Table(
+            "t2",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("bar", Integer),
+            Column("bat", Integer),
+            CheckConstraint("bar>5"),
+            UniqueConstraint("bar", "bat"),
+            Index(None, "bar", "bat"),
+        )
 
         eq_(
             canary,
             [
-                'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t1',
-                'Index->Table', 'Index->t1',
-                'ForeignKeyConstraint->Table', 'ForeignKeyConstraint->t1',
-                'UniqueConstraint->Table', 'UniqueConstraint->t1',
-                'PrimaryKeyConstraint->Table', 'PrimaryKeyConstraint->t2',
-                'CheckConstraint->Table', 'CheckConstraint->t2',
-                'UniqueConstraint->Table', 'UniqueConstraint->t2',
-                'Index->Table', 'Index->t2'
-            ]
+                "PrimaryKeyConstraint->Table",
+                "PrimaryKeyConstraint->t1",
+                "Index->Table",
+                "Index->t1",
+                "ForeignKeyConstraint->Table",
+                "ForeignKeyConstraint->t1",
+                "UniqueConstraint->Table",
+                "UniqueConstraint->t1",
+                "PrimaryKeyConstraint->Table",
+                "PrimaryKeyConstraint->t2",
+                "CheckConstraint->Table",
+                "CheckConstraint->t2",
+                "UniqueConstraint->Table",
+                "UniqueConstraint->t2",
+                "Index->Table",
+                "Index->t2",
+            ],
         )
 
 
 class DialectKWArgTest(fixtures.TestBase):
-
     @contextmanager
     def _fixture(self):
         from sqlalchemy.engine.default import DefaultDialect
 
         class ParticipatingDialect(DefaultDialect):
             construct_arguments = [
-                (schema.Index, {
-                    "x": 5,
-                    "y": False,
-                    "z_one": None
-                }),
-                (schema.ForeignKeyConstraint, {
-                    "foobar": False
-                })
+                (schema.Index, {"x": 5, "y": False, "z_one": None}),
+                (schema.ForeignKeyConstraint, {"foobar": False}),
             ]
 
         class ParticipatingDialect2(DefaultDialect):
             construct_arguments = [
-                (schema.Index, {
-                    "x": 9,
-                    "y": True,
-                    "pp": "default"
-                }),
-                (schema.Table, {
-                    "*": None
-                })
+                (schema.Index, {"x": 9, "y": True, "pp": "default"}),
+                (schema.Table, {"*": None}),
             ]
 
         class NonParticipatingDialect(DefaultDialect):
@@ -3757,6 +3945,7 @@ class DialectKWArgTest(fixtures.TestBase):
                 return NonParticipatingDialect
             else:
                 raise exc.NoSuchModuleError("no dialect %r" % dialect_name)
+
         with mock.patch("sqlalchemy.dialects.registry.load", load):
             yield
 
@@ -3765,32 +3954,21 @@ class DialectKWArgTest(fixtures.TestBase):
     def test_participating(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c', participating_y=True)
+            idx = Index("a", "b", "c", participating_y=True)
             eq_(
                 idx.dialect_options,
-                {"participating": {"x": 5, "y": True, "z_one": None}}
-            )
-            eq_(
-                idx.dialect_kwargs,
-                {
-                    'participating_y': True,
-                }
+                {"participating": {"x": 5, "y": True, "z_one": None}},
             )
+            eq_(idx.dialect_kwargs, {"participating_y": True})
 
     def test_nonparticipating(self):
         with self._fixture():
             idx = Index(
-                'a',
-                'b',
-                'c',
-                nonparticipating_y=True,
-                nonparticipating_q=5)
+                "a", "b", "c", nonparticipating_y=True, nonparticipating_q=5
+            )
             eq_(
                 idx.dialect_kwargs,
-                {
-                    'nonparticipating_y': True,
-                    'nonparticipating_q': 5
-                }
+                {"nonparticipating_y": True, "nonparticipating_q": 5},
             )
 
     def test_bad_kwarg_raise(self):
@@ -3799,7 +3977,11 @@ class DialectKWArgTest(fixtures.TestBase):
                 TypeError,
                 "Additional arguments should be named "
                 "<dialectname>_<argument>, got 'foobar'",
-                Index, 'a', 'b', 'c', foobar=True
+                Index,
+                "a",
+                "b",
+                "c",
+                foobar=True,
             )
 
     def test_unknown_dialect_warning(self):
@@ -3808,7 +3990,11 @@ class DialectKWArgTest(fixtures.TestBase):
                 exc.SAWarning,
                 "Can't validate argument 'unknown_y'; can't locate "
                 "any SQLAlchemy dialect named 'unknown'",
-                Index, 'a', 'b', 'c', unknown_y=True
+                Index,
+                "a",
+                "b",
+                "c",
+                unknown_y=True,
             )
 
     def test_participating_bad_kw(self):
@@ -3818,7 +4004,11 @@ class DialectKWArgTest(fixtures.TestBase):
                 "Argument 'participating_q_p_x' is not accepted by dialect "
                 "'participating' on behalf of "
                 "<class 'sqlalchemy.sql.schema.Index'>",
-                Index, 'a', 'b', 'c', participating_q_p_x=8
+                Index,
+                "a",
+                "b",
+                "c",
+                participating_q_p_x=8,
             )
 
     def test_participating_unknown_schema_item(self):
@@ -3830,310 +4020,328 @@ class DialectKWArgTest(fixtures.TestBase):
                 "Argument 'participating_q_p_x' is not accepted by dialect "
                 "'participating' on behalf of "
                 "<class 'sqlalchemy.sql.schema.UniqueConstraint'>",
-                UniqueConstraint, 'a', 'b', participating_q_p_x=8
+                UniqueConstraint,
+                "a",
+                "b",
+                participating_q_p_x=8,
             )
 
     @testing.emits_warning("Can't validate")
     def test_unknown_dialect_warning_still_populates(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c', unknown_y=True)
+            idx = Index("a", "b", "c", unknown_y=True)
             eq_(idx.dialect_kwargs, {"unknown_y": True})  # still populates
 
     @testing.emits_warning("Can't validate")
     def test_unknown_dialect_warning_still_populates_multiple(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c', unknown_y=True, unknown_z=5,
-                        otherunknown_foo='bar', participating_y=8)
+            idx = Index(
+                "a",
+                "b",
+                "c",
+                unknown_y=True,
+                unknown_z=5,
+                otherunknown_foo="bar",
+                participating_y=8,
+            )
             eq_(
                 idx.dialect_options,
                 {
-                    "unknown": {'y': True, 'z': 5, '*': None},
-                    "otherunknown": {'foo': 'bar', '*': None},
-                    "participating": {'x': 5, 'y': 8, 'z_one': None}
-                }
+                    "unknown": {"y": True, "z": 5, "*": None},
+                    "otherunknown": {"foo": "bar", "*": None},
+                    "participating": {"x": 5, "y": 8, "z_one": None},
+                },
             )
-            eq_(idx.dialect_kwargs,
-                {'unknown_z': 5, 'participating_y': 8,
-                 'unknown_y': True,
-                 'otherunknown_foo': 'bar'}
-                )  # still populates
+            eq_(
+                idx.dialect_kwargs,
+                {
+                    "unknown_z": 5,
+                    "participating_y": 8,
+                    "unknown_y": True,
+                    "otherunknown_foo": "bar",
+                },
+            )  # still populates
 
     def test_runs_safekwarg(self):
-        with mock.patch("sqlalchemy.util.safe_kwarg",
-                        lambda arg: "goofy_%s" % arg):
+        with mock.patch(
+            "sqlalchemy.util.safe_kwarg", lambda arg: "goofy_%s" % arg
+        ):
             with self._fixture():
-                idx = Index('a', 'b')
-                idx.kwargs[util.u('participating_x')] = 7
+                idx = Index("a", "b")
+                idx.kwargs[util.u("participating_x")] = 7
 
-                eq_(
-                    list(idx.dialect_kwargs),
-                    ['goofy_participating_x']
-                )
+                eq_(list(idx.dialect_kwargs), ["goofy_participating_x"])
     def test_combined(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c', participating_x=7,
-                        nonparticipating_y=True)
+            idx = Index(
+                "a", "b", "c", participating_x=7, nonparticipating_y=True
+            )
 
             eq_(
                 idx.dialect_options,
                 {
-                    'participating': {'y': False, 'x': 7, 'z_one': None},
-                    'nonparticipating': {'y': True, '*': None}
-                }
+                    "participating": {"y": False, "x": 7, "z_one": None},
+                    "nonparticipating": {"y": True, "*": None},
+                },
             )
             eq_(
                 idx.dialect_kwargs,
-                {
-                    'participating_x': 7,
-                    'nonparticipating_y': True,
-                }
+                {"participating_x": 7, "nonparticipating_y": True},
             )
 
     def test_multiple_participating(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c',
-                        participating_x=7,
-                        participating2_x=15,
-                        participating2_y="lazy"
-                        )
+            idx = Index(
+                "a",
+                "b",
+                "c",
+                participating_x=7,
+                participating2_x=15,
+                participating2_y="lazy",
+            )
             eq_(
                 idx.dialect_options,
                 {
-                    "participating": {'x': 7, 'y': False, 'z_one': None},
-                    "participating2": {'x': 15, 'y': 'lazy', 'pp': 'default'},
-                }
+                    "participating": {"x": 7, "y": False, "z_one": None},
+                    "participating2": {"x": 15, "y": "lazy", "pp": "default"},
+                },
            )
             eq_(
                 idx.dialect_kwargs,
                 {
-                    'participating_x': 7,
-                    'participating2_x': 15,
-                    'participating2_y': 'lazy'
-                }
+                    "participating_x": 7,
+                    "participating2_x": 15,
+                    "participating2_y": "lazy",
+                },
             )
 
     def test_foreign_key_propagate(self):
         with self._fixture():
             m = MetaData()
-            fk = ForeignKey('t2.id', participating_foobar=True)
-            t = Table('t', m, Column('id', Integer, fk))
+            fk = ForeignKey("t2.id", participating_foobar=True)
+            t = Table("t", m, Column("id", Integer, fk))
             fkc = [
-                c for c in t.constraints if isinstance(
-                    c,
-                    ForeignKeyConstraint)][0]
-            eq_(
-                fkc.dialect_kwargs,
-                {'participating_foobar': True}
-            )
+                c for c in t.constraints if isinstance(c, ForeignKeyConstraint)
+            ][0]
+            eq_(fkc.dialect_kwargs, {"participating_foobar": True})
 
     def test_foreign_key_propagate_exceptions_delayed(self):
         with self._fixture():
             m = MetaData()
-            fk = ForeignKey('t2.id', participating_fake=True)
-            c1 = Column('id', Integer, fk)
+            fk = ForeignKey("t2.id", participating_fake=True)
+            c1 = Column("id", Integer, fk)
             assert_raises_message(
                 exc.ArgumentError,
                 "Argument 'participating_fake' is not accepted by "
                 "dialect 'participating' on behalf of "
                 "<class 'sqlalchemy.sql.schema.Table'>",
-                Table, 't', m, c1
+                Table,
+                "t",
+                m,
+                c1,
             )
 
     def test_wildcard(self):
         with self._fixture():
             m = MetaData()
-            t = Table('x', m, Column('x', Integer),
-                      participating2_xyz='foo',
-                      participating2_engine='InnoDB',
-                      )
+            t = Table(
+                "x",
+                m,
+                Column("x", Integer),
+                participating2_xyz="foo",
+                participating2_engine="InnoDB",
+            )
             eq_(
                 t.dialect_kwargs,
                 {
-                    'participating2_xyz': 'foo',
-                    'participating2_engine': 'InnoDB'
-                }
+                    "participating2_xyz": "foo",
+                    "participating2_engine": "InnoDB",
+                },
             )
 
     def test_uninit_wildcard(self):
         with self._fixture():
             m = MetaData()
-            t = Table('x', m, Column('x', Integer))
-            eq_(
-                t.dialect_options['participating2'], {'*': None}
-            )
-            eq_(
-                t.dialect_kwargs, {}
-            )
+            t = Table("x", m, Column("x", Integer))
+            eq_(t.dialect_options["participating2"], {"*": None})
+            eq_(t.dialect_kwargs, {})
 
     def test_not_contains_wildcard(self):
         with self._fixture():
             m = MetaData()
-            t = Table('x', m, Column('x', Integer))
-            assert 'foobar' not in t.dialect_options['participating2']
+            t = Table("x", m, Column("x", Integer))
+            assert "foobar" not in t.dialect_options["participating2"]
 
     def test_contains_wildcard(self):
         with self._fixture():
             m = MetaData()
-            t = Table('x', m, Column('x', Integer), participating2_foobar=5)
-            assert 'foobar' in t.dialect_options['participating2']
+            t = Table("x", m, Column("x", Integer), participating2_foobar=5)
+            assert "foobar" in t.dialect_options["participating2"]
 
     def test_update(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c', participating_x=20)
-            eq_(idx.dialect_kwargs, {
-                "participating_x": 20,
-            })
-            idx._validate_dialect_kwargs({
-                "participating_x": 25,
-                "participating_z_one": "default"})
-            eq_(idx.dialect_options, {
-                "participating": {"x": 25, "y": False, "z_one": "default"}
-            })
-            eq_(idx.dialect_kwargs, {
-                "participating_x": 25,
-                'participating_z_one': "default"
-            })
-
-            idx._validate_dialect_kwargs({
-                "participating_x": 25,
-                "participating_z_one": "default"})
-
-            eq_(idx.dialect_options, {
-                "participating": {"x": 25, "y": False, "z_one": "default"}
-            })
-            eq_(idx.dialect_kwargs, {
-                "participating_x": 25,
-                'participating_z_one': "default"
-            })
-
-            idx._validate_dialect_kwargs({
-                "participating_y": True,
-                'participating2_y': "p2y"})
-            eq_(idx.dialect_options, {
-                "participating": {"x": 25, "y": True, "z_one": "default"},
-                "participating2": {"y": "p2y", "pp": "default", "x": 9}
-            })
-            eq_(idx.dialect_kwargs, {
-                "participating_x": 25,
-                "participating_y": True,
-                'participating2_y': "p2y",
-                "participating_z_one": "default"})
+            idx = Index("a", "b", "c", participating_x=20)
+            eq_(idx.dialect_kwargs, {"participating_x": 20})
+            idx._validate_dialect_kwargs(
+                {"participating_x": 25, "participating_z_one": "default"}
+            )
+            eq_(
+                idx.dialect_options,
+                {"participating": {"x": 25, "y": False, "z_one": "default"}},
+            )
+            eq_(
+                idx.dialect_kwargs,
+                {"participating_x": 25, "participating_z_one": "default"},
+            )
+
+            idx._validate_dialect_kwargs(
+                {"participating_x": 25, "participating_z_one": "default"}
+            )
+
+            eq_(
+                idx.dialect_options,
+                {"participating": {"x": 25, "y": False, "z_one": "default"}},
+            )
+            eq_(
+                idx.dialect_kwargs,
+                {"participating_x": 25, "participating_z_one": "default"},
+            )
+
+            idx._validate_dialect_kwargs(
+                {"participating_y": True, "participating2_y": "p2y"}
+            )
+            eq_(
+                idx.dialect_options,
+                {
+                    "participating": {"x": 25, "y": True, "z_one": "default"},
+                    "participating2": {"y": "p2y", "pp": "default", "x": 9},
+                },
+            )
+            eq_(
+                idx.dialect_kwargs,
+                {
+                    "participating_x": 25,
+                    "participating_y": True,
+                    "participating2_y": "p2y",
+                    "participating_z_one": "default",
+                },
+            )
 
     def test_key_error_kwargs_no_dialect(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c')
-            assert_raises(
-                KeyError,
-                idx.kwargs.__getitem__, 'foo_bar'
-            )
+            idx = Index("a", "b", "c")
+            assert_raises(KeyError, idx.kwargs.__getitem__, "foo_bar")
 
     def test_key_error_kwargs_no_underscore(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c')
-            assert_raises(
-                KeyError,
-                idx.kwargs.__getitem__, 'foobar'
-            )
+            idx = Index("a", "b", "c")
+            assert_raises(KeyError, idx.kwargs.__getitem__, "foobar")
 
     def test_key_error_kwargs_no_argument(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c')
+            idx = Index("a", "b", "c")
             assert_raises(
-                KeyError,
-                idx.kwargs.__getitem__, 'participating_asdmfq34098'
+                KeyError, idx.kwargs.__getitem__, "participating_asdmfq34098"
             )
 
             assert_raises(
                 KeyError,
-                idx.kwargs.__getitem__, 'nonparticipating_asdmfq34098'
+                idx.kwargs.__getitem__,
+                "nonparticipating_asdmfq34098",
             )
 
     def test_key_error_dialect_options(self):
         with self._fixture():
-            idx = Index('a', 'b', 'c')
+            idx = Index("a", "b", "c")
             assert_raises(
                 KeyError,
-                idx.dialect_options['participating'].__getitem__, 'asdfaso890'
idx.dialect_options["participating"].__getitem__, + "asdfaso890", ) assert_raises( KeyError, - idx.dialect_options['nonparticipating'].__getitem__, - 'asdfaso890') + idx.dialect_options["nonparticipating"].__getitem__, + "asdfaso890", + ) def test_ad_hoc_participating_via_opt(self): with self._fixture(): - idx = Index('a', 'b', 'c') - idx.dialect_options['participating']['foobar'] = 5 + idx = Index("a", "b", "c") + idx.dialect_options["participating"]["foobar"] = 5 - eq_(idx.dialect_options['participating']['foobar'], 5) - eq_(idx.kwargs['participating_foobar'], 5) + eq_(idx.dialect_options["participating"]["foobar"], 5) + eq_(idx.kwargs["participating_foobar"], 5) def test_ad_hoc_nonparticipating_via_opt(self): with self._fixture(): - idx = Index('a', 'b', 'c') - idx.dialect_options['nonparticipating']['foobar'] = 5 + idx = Index("a", "b", "c") + idx.dialect_options["nonparticipating"]["foobar"] = 5 - eq_(idx.dialect_options['nonparticipating']['foobar'], 5) - eq_(idx.kwargs['nonparticipating_foobar'], 5) + eq_(idx.dialect_options["nonparticipating"]["foobar"], 5) + eq_(idx.kwargs["nonparticipating_foobar"], 5) def test_ad_hoc_participating_via_kwargs(self): with self._fixture(): - idx = Index('a', 'b', 'c') - idx.kwargs['participating_foobar'] = 5 + idx = Index("a", "b", "c") + idx.kwargs["participating_foobar"] = 5 - eq_(idx.dialect_options['participating']['foobar'], 5) - eq_(idx.kwargs['participating_foobar'], 5) + eq_(idx.dialect_options["participating"]["foobar"], 5) + eq_(idx.kwargs["participating_foobar"], 5) def test_ad_hoc_nonparticipating_via_kwargs(self): with self._fixture(): - idx = Index('a', 'b', 'c') - idx.kwargs['nonparticipating_foobar'] = 5 + idx = Index("a", "b", "c") + idx.kwargs["nonparticipating_foobar"] = 5 - eq_(idx.dialect_options['nonparticipating']['foobar'], 5) - eq_(idx.kwargs['nonparticipating_foobar'], 5) + eq_(idx.dialect_options["nonparticipating"]["foobar"], 5) + eq_(idx.kwargs["nonparticipating_foobar"], 5) def test_ad_hoc_via_kwargs_invalid_key(self): with self._fixture(): - idx = Index('a', 'b', 'c') + idx = Index("a", "b", "c") assert_raises_message( exc.ArgumentError, "Keys must be of the form _", - idx.kwargs.__setitem__, "foobar", 5 + idx.kwargs.__setitem__, + "foobar", + 5, ) def test_ad_hoc_via_kwargs_invalid_dialect(self): with self._fixture(): - idx = Index('a', 'b', 'c') + idx = Index("a", "b", "c") assert_raises_message( exc.ArgumentError, "no dialect 'nonexistent'", - idx.kwargs.__setitem__, "nonexistent_foobar", 5 + idx.kwargs.__setitem__, + "nonexistent_foobar", + 5, ) def test_add_new_arguments_participating(self): with self._fixture(): Index.argument_for("participating", "xyzqpr", False) - idx = Index('a', 'b', 'c', participating_xyzqpr=True) + idx = Index("a", "b", "c", participating_xyzqpr=True) - eq_(idx.kwargs['participating_xyzqpr'], True) + eq_(idx.kwargs["participating_xyzqpr"], True) - idx = Index('a', 'b', 'c') - eq_(idx.dialect_options['participating']['xyzqpr'], False) + idx = Index("a", "b", "c") + eq_(idx.dialect_options["participating"]["xyzqpr"], False) def test_add_new_arguments_participating_no_existing(self): with self._fixture(): PrimaryKeyConstraint.argument_for("participating", "xyzqpr", False) - pk = PrimaryKeyConstraint('a', 'b', 'c', participating_xyzqpr=True) + pk = PrimaryKeyConstraint("a", "b", "c", participating_xyzqpr=True) - eq_(pk.kwargs['participating_xyzqpr'], True) + eq_(pk.kwargs["participating_xyzqpr"], True) - pk = PrimaryKeyConstraint('a', 'b', 'c') - 
@@ -4141,7 +4349,10 @@ class DialectKWArgTest(fixtures.TestBase):
                 exc.ArgumentError,
                 "Dialect 'nonparticipating' does have keyword-argument "
                 "validation and defaults enabled configured",
-                Index.argument_for, "nonparticipating", "xyzqpr", False
+                Index.argument_for,
+                "nonparticipating",
+                "xyzqpr",
+                False,
             )
 
     def test_add_new_arguments_invalid_dialect(self):
@@ -4149,43 +4360,48 @@ class DialectKWArgTest(fixtures.TestBase):
             assert_raises_message(
                 exc.ArgumentError,
                 "no dialect 'nonexistent'",
-                Index.argument_for, "nonexistent", "foobar", 5
+                Index.argument_for,
+                "nonexistent",
+                "foobar",
+                5,
            )
 
 
 class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"
 
     def _fixture(self, naming_convention, table_schema=None):
         m1 = MetaData(naming_convention=naming_convention)
-        u1 = Table('user', m1,
-                   Column('id', Integer, primary_key=True),
-                   Column('version', Integer, primary_key=True),
-                   Column('data', String(30)),
-                   Column('Data2', String(30), key="data2"),
-                   Column('Data3', String(30), key="data3"),
-                   schema=table_schema
-                   )
+        u1 = Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("version", Integer, primary_key=True),
+            Column("data", String(30)),
+            Column("Data2", String(30), key="data2"),
+            Column("Data3", String(30), key="data3"),
+            schema=table_schema,
+        )
 
         return u1
 
     def test_uq_name(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"}
+        )
         uq = UniqueConstraint(u1.c.data)
         eq_(uq.name, "uq_user_data")
 
     def test_uq_conv_name(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"}
        )
         uq = UniqueConstraint(u1.c.data, name=naming.conv("myname"))
 
         self.assert_compile(
             schema.AddConstraint(uq),
             'ALTER TABLE "user" ADD CONSTRAINT myname UNIQUE (data)',
-            dialect="default"
+            dialect="default",
         )
 
     def test_uq_defer_name_no_convention(self):
@@ -4194,59 +4410,59 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             schema.AddConstraint(uq),
             'ALTER TABLE "user" ADD CONSTRAINT myname UNIQUE (data)',
-            dialect="default"
+            dialect="default",
         )
 
     def test_uq_defer_name_convention(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"}
        )
         uq = UniqueConstraint(u1.c.data, name=naming._defer_name("myname"))
 
         self.assert_compile(
             schema.AddConstraint(uq),
             'ALTER TABLE "user" ADD CONSTRAINT uq_user_data UNIQUE (data)',
-            dialect="default"
+            dialect="default",
         )
 
     def test_uq_key(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0_key)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0_key)s"}
        )
         uq = UniqueConstraint(u1.c.data, u1.c.data2)
         eq_(uq.name, "uq_user_data")
 
     def test_uq_label(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0_label)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0_label)s"}
        )
         uq = UniqueConstraint(u1.c.data, u1.c.data2)
         eq_(uq.name, "uq_user_user_data")
 
     def test_uq_allcols_underscore_name(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0_N_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0_N_name)s"}
        )
         uq = UniqueConstraint(u1.c.data, u1.c.data2, u1.c.data3)
         eq_(uq.name, "uq_user_data_Data2_Data3")
 
     def test_uq_allcols_merged_name(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0N_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0N_name)s"}
        )
         uq = UniqueConstraint(u1.c.data, u1.c.data2, u1.c.data3)
         eq_(uq.name, "uq_user_dataData2Data3")
 
     def test_uq_allcols_merged_key(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0N_key)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0N_key)s"}
        )
         uq = UniqueConstraint(u1.c.data, u1.c.data2, u1.c.data3)
         eq_(uq.name, "uq_user_datadata2data3")
 
     def test_uq_allcols_truncated_name(self):
-        u1 = self._fixture(naming_convention={
-            "uq": "uq_%(table_name)s_%(column_0N_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={"uq": "uq_%(table_name)s_%(column_0N_name)s"}
        )
         uq = UniqueConstraint(u1.c.data, u1.c.data2, u1.c.data3)
 
         dialect = default.DefaultDialect()
@@ -4255,7 +4471,7 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
             'ALTER TABLE "user" ADD '
             'CONSTRAINT "uq_user_dataData2Data3" '
             'UNIQUE (data, "Data2", "Data3")',
-            dialect=dialect
+            dialect=dialect,
         )
 
         dialect.max_identifier_length = 15
@@ -4263,188 +4479,222 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
             schema.AddConstraint(uq),
             'ALTER TABLE "user" ADD '
             'CONSTRAINT uq_user_2769 UNIQUE (data, "Data2", "Data3")',
-            dialect=dialect
+            dialect=dialect,
         )
 
     def test_fk_allcols_underscore_name(self):
-        u1 = self._fixture(naming_convention={
-            "fk": "fk_%(table_name)s_%(column_0_N_name)s_"
-            "%(referred_table_name)s_%(referred_column_0_N_name)s"})
+        u1 = self._fixture(
+            naming_convention={
+                "fk": "fk_%(table_name)s_%(column_0_N_name)s_"
+                "%(referred_table_name)s_%(referred_column_0_N_name)s"
+            }
+        )
 
         m1 = u1.metadata
-        a1 = Table('address', m1,
-                   Column('id', Integer, primary_key=True),
-                   Column('UserData', String(30), key="user_data"),
-                   Column('UserData2', String(30), key="user_data2"),
-                   Column('UserData3', String(30), key="user_data3")
-                   )
-        fk = ForeignKeyConstraint(['user_data', 'user_data2', 'user_data3'],
-                                  ['user.data', 'user.data2', 'user.data3'])
+        a1 = Table(
+            "address",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("UserData", String(30), key="user_data"),
+            Column("UserData2", String(30), key="user_data2"),
+            Column("UserData3", String(30), key="user_data3"),
+        )
+        fk = ForeignKeyConstraint(
+            ["user_data", "user_data2", "user_data3"],
+            ["user.data", "user.data2", "user.data3"],
+        )
         a1.append_constraint(fk)
         self.assert_compile(
             schema.AddConstraint(fk),
-            'ALTER TABLE address ADD CONSTRAINT '
+            "ALTER TABLE address ADD CONSTRAINT "
             '"fk_address_UserData_UserData2_UserData3_user_data_Data2_Data3" '
             'FOREIGN KEY("UserData", "UserData2", "UserData3") '
             'REFERENCES "user" (data, "Data2", "Data3")',
-            dialect=default.DefaultDialect()
+            dialect=default.DefaultDialect(),
        )
 
     def test_fk_allcols_merged_name(self):
-        u1 = self._fixture(naming_convention={
-            "fk": "fk_%(table_name)s_%(column_0N_name)s_"
-            "%(referred_table_name)s_%(referred_column_0N_name)s"})
"fk_%(table_name)s_%(column_0N_name)s_" + "%(referred_table_name)s_%(referred_column_0N_name)s" + } + ) m1 = u1.metadata - a1 = Table('address', m1, - Column('id', Integer, primary_key=True), - Column('UserData', String(30), key="user_data"), - Column('UserData2', String(30), key="user_data2"), - Column('UserData3', String(30), key="user_data3") - ) - fk = ForeignKeyConstraint(['user_data', 'user_data2', 'user_data3'], - ['user.data', 'user.data2', 'user.data3']) + a1 = Table( + "address", + m1, + Column("id", Integer, primary_key=True), + Column("UserData", String(30), key="user_data"), + Column("UserData2", String(30), key="user_data2"), + Column("UserData3", String(30), key="user_data3"), + ) + fk = ForeignKeyConstraint( + ["user_data", "user_data2", "user_data3"], + ["user.data", "user.data2", "user.data3"], + ) a1.append_constraint(fk) self.assert_compile( schema.AddConstraint(fk), - 'ALTER TABLE address ADD CONSTRAINT ' + "ALTER TABLE address ADD CONSTRAINT " '"fk_address_UserDataUserData2UserData3_user_dataData2Data3" ' 'FOREIGN KEY("UserData", "UserData2", "UserData3") ' 'REFERENCES "user" (data, "Data2", "Data3")', - dialect=default.DefaultDialect() + dialect=default.DefaultDialect(), ) def test_fk_allcols_truncated_name(self): - u1 = self._fixture(naming_convention={ - "fk": "fk_%(table_name)s_%(column_0N_name)s_" - "%(referred_table_name)s_%(referred_column_0N_name)s"}) + u1 = self._fixture( + naming_convention={ + "fk": "fk_%(table_name)s_%(column_0N_name)s_" + "%(referred_table_name)s_%(referred_column_0N_name)s" + } + ) m1 = u1.metadata - a1 = Table('address', m1, - Column('id', Integer, primary_key=True), - Column('UserData', String(30), key="user_data"), - Column('UserData2', String(30), key="user_data2"), - Column('UserData3', String(30), key="user_data3") - ) - fk = ForeignKeyConstraint(['user_data', 'user_data2', 'user_data3'], - ['user.data', 'user.data2', 'user.data3']) + a1 = Table( + "address", + m1, + Column("id", Integer, primary_key=True), + Column("UserData", String(30), key="user_data"), + Column("UserData2", String(30), key="user_data2"), + Column("UserData3", String(30), key="user_data3"), + ) + fk = ForeignKeyConstraint( + ["user_data", "user_data2", "user_data3"], + ["user.data", "user.data2", "user.data3"], + ) a1.append_constraint(fk) dialect = default.DefaultDialect() dialect.max_identifier_length = 15 self.assert_compile( schema.AddConstraint(fk), - 'ALTER TABLE address ADD CONSTRAINT ' - 'fk_addr_f9ff ' + "ALTER TABLE address ADD CONSTRAINT " + "fk_addr_f9ff " 'FOREIGN KEY("UserData", "UserData2", "UserData3") ' 'REFERENCES "user" (data, "Data2", "Data3")', - dialect=dialect + dialect=dialect, ) def test_ix_allcols_truncation(self): - u1 = self._fixture(naming_convention={ - "ix": "ix_%(table_name)s_%(column_0N_name)s" - }) + u1 = self._fixture( + naming_convention={"ix": "ix_%(table_name)s_%(column_0N_name)s"} + ) ix = Index(None, u1.c.data, u1.c.data2, u1.c.data3) dialect = default.DefaultDialect() dialect.max_identifier_length = 15 self.assert_compile( schema.CreateIndex(ix), - 'CREATE INDEX ix_user_2de9 ON ' - '"user" (data, "Data2", "Data3")', - dialect=dialect + "CREATE INDEX ix_user_2de9 ON " '"user" (data, "Data2", "Data3")', + dialect=dialect, ) def test_ix_name(self): - u1 = self._fixture(naming_convention={ - "ix": "ix_%(table_name)s_%(column_0_name)s" - }) + u1 = self._fixture( + naming_convention={"ix": "ix_%(table_name)s_%(column_0_name)s"} + ) ix = Index(None, u1.c.data) eq_(ix.name, "ix_user_data") def test_ck_name_required(self): - u1 = 
-        u1 = self._fixture(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
-        ck = CheckConstraint(u1.c.data == 'x', name='mycheck')
+        u1 = self._fixture(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
+        ck = CheckConstraint(u1.c.data == "x", name="mycheck")
         eq_(ck.name, "ck_user_mycheck")
 
         assert_raises_message(
             exc.InvalidRequestError,
             r"Naming convention including %\(constraint_name\)s token "
             "requires that constraint is explicitly named.",
-            CheckConstraint, u1.c.data == 'x'
+            CheckConstraint,
+            u1.c.data == "x",
        )
 
     def test_ck_name_deferred_required(self):
-        u1 = self._fixture(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
-        ck = CheckConstraint(u1.c.data == 'x', name=elements._defer_name(None))
+        u1 = self._fixture(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
+        ck = CheckConstraint(u1.c.data == "x", name=elements._defer_name(None))
 
         assert_raises_message(
             exc.InvalidRequestError,
             r"Naming convention including %\(constraint_name\)s token "
             "requires that constraint is explicitly named.",
-            schema.AddConstraint(ck).compile
+            schema.AddConstraint(ck).compile,
        )
 
     def test_column_attached_ck_name(self):
-        m = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
-        ck = CheckConstraint('x > 5', name='x1')
-        Table('t', m, Column('x', ck))
+        m = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
+        ck = CheckConstraint("x > 5", name="x1")
+        Table("t", m, Column("x", ck))
         eq_(ck.name, "ck_t_x1")
 
     def test_table_attached_ck_name(self):
-        m = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
-        ck = CheckConstraint('x > 5', name='x1')
-        Table('t', m, Column('x', Integer), ck)
+        m = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
+        ck = CheckConstraint("x > 5", name="x1")
+        Table("t", m, Column("x", Integer), ck)
         eq_(ck.name, "ck_t_x1")
 
     def test_uq_name_already_conv(self):
-        m = MetaData(naming_convention={
-            "uq": "uq_%(constraint_name)s_%(column_0_name)s"
-        })
+        m = MetaData(
+            naming_convention={
+                "uq": "uq_%(constraint_name)s_%(column_0_name)s"
+            }
+        )
 
-        t = Table('mytable', m)
-        uq = UniqueConstraint(name=naming.conv('my_special_key'))
+        t = Table("mytable", m)
+        uq = UniqueConstraint(name=naming.conv("my_special_key"))
 
         t.append_constraint(uq)
         eq_(uq.name, "my_special_key")
 
     def test_fk_name_schema(self):
-        u1 = self._fixture(naming_convention={
-            "fk": "fk_%(table_name)s_%(column_0_name)s_"
-            "%(referred_table_name)s_%(referred_column_0_name)s"
-        }, table_schema="foo")
+        u1 = self._fixture(
+            naming_convention={
+                "fk": "fk_%(table_name)s_%(column_0_name)s_"
+                "%(referred_table_name)s_%(referred_column_0_name)s"
+            },
+            table_schema="foo",
+        )
         m1 = u1.metadata
-        a1 = Table('address', m1,
-                   Column('id', Integer, primary_key=True),
-                   Column('user_id', Integer),
-                   Column('user_version_id', Integer)
-                   )
-        fk = ForeignKeyConstraint(['user_id', 'user_version_id'],
-                                  ['foo.user.id', 'foo.user.version'])
+        a1 = Table(
+            "address",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("user_id", Integer),
+            Column("user_version_id", Integer),
+        )
+        fk = ForeignKeyConstraint(
+            ["user_id", "user_version_id"], ["foo.user.id", "foo.user.version"]
+        )
         a1.append_constraint(fk)
         eq_(fk.name, "fk_address_user_id_user_id")
 
     def test_fk_attrs(self):
-        u1 = self._fixture(naming_convention={
-            "fk": "fk_%(table_name)s_%(column_0_name)s_"
-            "%(referred_table_name)s_%(referred_column_0_name)s"
-        })
+        u1 = self._fixture(
+            naming_convention={
+                "fk": "fk_%(table_name)s_%(column_0_name)s_"
+                "%(referred_table_name)s_%(referred_column_0_name)s"
+            }
+        )
         m1 = u1.metadata
-        a1 = Table('address', m1,
-                   Column('id', Integer, primary_key=True),
-                   Column('user_id', Integer),
-                   Column('user_version_id', Integer)
-                   )
-        fk = ForeignKeyConstraint(['user_id', 'user_version_id'],
-                                  ['user.id', 'user.version'])
+        a1 = Table(
+            "address",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("user_id", Integer),
+            Column("user_version_id", Integer),
+        )
+        fk = ForeignKeyConstraint(
+            ["user_id", "user_version_id"], ["user.id", "user.version"]
+        )
         a1.append_constraint(fk)
         eq_(fk.name, "fk_address_user_id_user_id")
 
@@ -4452,32 +4702,38 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
         def key_hash(const, table):
             return "HASH_%s" % table.name
 
-        u1 = self._fixture(naming_convention={
-            "fk": "fk_%(table_name)s_%(key_hash)s",
-            "key_hash": key_hash
-        })
+        u1 = self._fixture(
+            naming_convention={
+                "fk": "fk_%(table_name)s_%(key_hash)s",
+                "key_hash": key_hash,
+            }
+        )
         m1 = u1.metadata
-        a1 = Table('address', m1,
-                   Column('id', Integer, primary_key=True),
-                   Column('user_id', Integer),
-                   Column('user_version_id', Integer)
-                   )
-        fk = ForeignKeyConstraint(['user_id', 'user_version_id'],
-                                  ['user.id', 'user.version'])
+        a1 = Table(
+            "address",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("user_id", Integer),
+            Column("user_version_id", Integer),
+        )
+        fk = ForeignKeyConstraint(
+            ["user_id", "user_version_id"], ["user.id", "user.version"]
+        )
         a1.append_constraint(fk)
         eq_(fk.name, "fk_address_HASH_address")
 
     def test_schematype_ck_name_boolean(self):
-        m1 = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"})
+        m1 = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
 
-        u1 = Table('user', m1,
-                   Column('x', Boolean(name='foo'))
-                   )
+        u1 = Table("user", m1, Column("x", Boolean(name="foo")))
         # constraint is not hit
         eq_(
-            [c for c in u1.constraints
-             if isinstance(c, CheckConstraint)][0].name, "foo"
+            [c for c in u1.constraints if isinstance(c, CheckConstraint)][
+                0
+            ].name,
+            "foo",
        )
         # but is hit at compile time
         self.assert_compile(
@@ -4485,20 +4741,21 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
             'CREATE TABLE "user" ('
             "x BOOLEAN, "
             "CONSTRAINT ck_user_foo CHECK (x IN (0, 1))"
-            ")"
+            ")",
        )
 
     def test_schematype_ck_name_boolean_not_on_name(self):
-        m1 = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(column_0_name)s"})
+        m1 = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(column_0_name)s"}
        )
 
-        u1 = Table('user', m1,
-                   Column('x', Boolean())
-                   )
+        u1 = Table("user", m1, Column("x", Boolean()))
         # constraint is not hit
         eq_(
-            [c for c in u1.constraints
-             if isinstance(c, CheckConstraint)][0].name, "_unnamed_"
+            [c for c in u1.constraints if isinstance(c, CheckConstraint)][
+                0
+            ].name,
+            "_unnamed_",
        )
         # but is hit at compile time
         self.assert_compile(
@@ -4506,19 +4763,20 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
             'CREATE TABLE "user" ('
             "x BOOLEAN, "
             "CONSTRAINT ck_user_x CHECK (x IN (0, 1))"
-            ")"
+            ")",
        )
 
     def test_schematype_ck_name_enum(self):
-        m1 = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"})
+        m1 = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
 
-        u1 = Table('user', m1,
-                   Column('x', Enum('a', 'b', name='foo'))
-                   )
+        u1 = Table("user", m1, Column("x", Enum("a", "b", name="foo")))
         eq_(
-            [c for c in u1.constraints
-             if isinstance(c, CheckConstraint)][0].name, "foo"
+            [c for c in u1.constraints if isinstance(c, CheckConstraint)][
+                0
+            ].name,
+            "foo",
        )
         # but is hit at compile time
         self.assert_compile(
@@ -4526,19 +4784,22 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
             'CREATE TABLE "user" ('
             "x VARCHAR(1), "
             "CONSTRAINT ck_user_foo CHECK (x IN ('a', 'b'))"
-            ")"
+            ")",
        )
 
     def test_schematype_ck_name_propagate_conv(self):
-        m1 = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"})
+        m1 = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
 
-        u1 = Table('user', m1,
-                   Column('x', Enum('a', 'b', name=naming.conv('foo')))
-                   )
+        u1 = Table(
+            "user", m1, Column("x", Enum("a", "b", name=naming.conv("foo")))
+        )
         eq_(
-            [c for c in u1.constraints
-             if isinstance(c, CheckConstraint)][0].name, "foo"
+            [c for c in u1.constraints if isinstance(c, CheckConstraint)][
+                0
+            ].name,
+            "foo",
        )
         # but is hit at compile time
         self.assert_compile(
@@ -4546,62 +4807,60 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
             'CREATE TABLE "user" ('
             "x VARCHAR(1), "
             "CONSTRAINT foo CHECK (x IN ('a', 'b'))"
-            ")"
+            ")",
        )
 
     def test_schematype_ck_name_boolean_no_name(self):
-        m1 = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
-
-        u1 = Table(
-            'user', m1,
-            Column('x', Boolean())
+        m1 = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
+
+        u1 = Table("user", m1, Column("x", Boolean()))
 
         # constraint gets special _defer_none_name
         eq_(
-            [c for c in u1.constraints
-             if isinstance(c, CheckConstraint)][0].name, "_unnamed_"
+            [c for c in u1.constraints if isinstance(c, CheckConstraint)][
+                0
+            ].name,
+            "_unnamed_",
        )
         # no issue with native boolean
         self.assert_compile(
             schema.CreateTable(u1),
-            'CREATE TABLE "user" ('
-            "x BOOLEAN"
-            ")",
-            dialect='postgresql'
+            'CREATE TABLE "user" (' "x BOOLEAN" ")",
+            dialect="postgresql",
        )
 
         assert_raises_message(
             exc.InvalidRequestError,
             r"Naming convention including \%\(constraint_name\)s token "
             r"requires that constraint is explicitly named.",
-            schema.CreateTable(u1).compile, dialect=default.DefaultDialect()
+            schema.CreateTable(u1).compile,
+            dialect=default.DefaultDialect(),
        )
 
     def test_schematype_no_ck_name_boolean_no_name(self):
         m1 = MetaData()  # no naming convention
 
-        u1 = Table(
-            'user', m1,
-            Column('x', Boolean())
-        )
+        u1 = Table("user", m1, Column("x", Boolean()))
 
         # constraint gets special _defer_none_name
         eq_(
-            [c for c in u1.constraints
-             if isinstance(c, CheckConstraint)][0].name, "_unnamed_"
+            [c for c in u1.constraints if isinstance(c, CheckConstraint)][
+                0
+            ].name,
+            "_unnamed_",
        )
 
         self.assert_compile(
             schema.CreateTable(u1),
-            'CREATE TABLE "user" (x BOOLEAN, CHECK (x IN (0, 1)))'
+            'CREATE TABLE "user" (x BOOLEAN, CHECK (x IN (0, 1)))',
        )
 
     def test_ck_constraint_redundant_event(self):
-        u1 = self._fixture(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"})
+        u1 = self._fixture(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
        )
 
-        ck1 = CheckConstraint(u1.c.version > 3, name='foo')
+        ck1 = CheckConstraint(u1.c.version > 3, name="foo")
         u1.append_constraint(ck1)
         u1.append_constraint(ck1)
         u1.append_constraint(ck1)
@@ -4615,8 +4874,8 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
 
         eq_(m2.naming_convention, {"pk": "%(table_name)s_pk"})
 
-        t2a = Table('t2', m, Column('id', Integer, primary_key=True))
-        t2b = Table('t2', m2, Column('id', Integer, primary_key=True))
+        t2a = Table("t2", m, Column("id", Integer, primary_key=True))
+        t2b = Table("t2", m2, Column("id", Integer, primary_key=True))
 
         eq_(t2a.primary_key.name, t2b.primary_key.name)
         eq_(t2b.primary_key.name, "t2_pk")
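The behavior pinned down by NamingConventionTest is driven entirely by the MetaData-level naming_convention dictionary, which fills in constraint names from tokens such as %(table_name)s at attach time. A minimal sketch of the public usage, with illustrative names:

    from sqlalchemy import Column, Integer, MetaData, Table, UniqueConstraint

    convention = {"uq": "uq_%(table_name)s_%(column_0_name)s"}
    metadata = MetaData(naming_convention=convention)

    user = Table(
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", Integer),
    )

    # the constraint receives its name from the convention when attached
    uq = UniqueConstraint(user.c.data)
    print(uq.name)  # uq_user_data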
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index 00961a2e84..2ce842cafc 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -2,12 +2,29 @@ from sqlalchemy.testing import fixtures, eq_, is_, is_not_
 from sqlalchemy import testing
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import expect_warnings
-from sqlalchemy.sql import column, desc, asc, literal, collate, null, \
-    true, false, any_, all_
+from sqlalchemy.sql import (
+    column,
+    desc,
+    asc,
+    literal,
+    collate,
+    null,
+    true,
+    false,
+    any_,
+    all_,
+)
 from sqlalchemy.sql import sqltypes
-from sqlalchemy.sql.expression import BinaryExpression, \
-    ClauseList, Grouping, \
-    UnaryExpression, select, union, func, tuple_
+from sqlalchemy.sql.expression import (
+    BinaryExpression,
+    ClauseList,
+    Grouping,
+    UnaryExpression,
+    select,
+    union,
+    func,
+    tuple_,
+)
 from sqlalchemy.sql import operators, table
 import operator
 from sqlalchemy import String, Integer, LargeBinary
@@ -16,11 +33,26 @@ from sqlalchemy.engine import default
 from sqlalchemy.sql.elements import _literal_as_text, Label
 from sqlalchemy.schema import Column, Table, MetaData
 from sqlalchemy.sql import compiler
-from sqlalchemy.types import TypeEngine, TypeDecorator, UserDefinedType, \
-    Boolean, MatchType, Indexable, Concatenable, ARRAY, JSON, \
-    DateTime
+from sqlalchemy.types import (
+    TypeEngine,
+    TypeDecorator,
+    UserDefinedType,
+    Boolean,
+    MatchType,
+    Indexable,
+    Concatenable,
+    ARRAY,
+    JSON,
+    DateTime,
+)
-from sqlalchemy.dialects import mysql, firebird, postgresql, oracle, \
-    sqlite, mssql
+from sqlalchemy.dialects import (
+    mysql,
+    firebird,
+    postgresql,
+    oracle,
+    sqlite,
+    mssql,
+)
 from sqlalchemy import util
 import datetime
 import collections
@@ -29,53 +61,42 @@ from sqlalchemy import and_, not_, between, or_
 
 
 class LoopOperate(operators.ColumnOperators):
-
     def operate(self, op, *other, **kwargs):
         return op
 
 
 class DefaultColumnComparatorTest(fixtures.TestBase):
-
     def _do_scalar_test(self, operator, compare_to):
-        left = column('left')
-        assert left.comparator.operate(operator).compare(
-            compare_to(left)
-        )
+        left = column("left")
+        assert left.comparator.operate(operator).compare(compare_to(left))
         self._loop_test(operator)
 
-    def _do_operate_test(self, operator, right=column('right')):
-        left = column('left')
+    def _do_operate_test(self, operator, right=column("right")):
+        left = column("left")
 
-        assert left.comparator.operate(
-            operator,
-            right).compare(
+        assert left.comparator.operate(operator, right).compare(
             BinaryExpression(
-                _literal_as_text(left),
-                _literal_as_text(right),
-                operator))
+                _literal_as_text(left), _literal_as_text(right), operator
+            )
+        )
 
-        assert operator(
-            left,
-            right).compare(
+        assert operator(left, right).compare(
             BinaryExpression(
-                _literal_as_text(left),
-                _literal_as_text(right),
-                operator))
+                _literal_as_text(left), _literal_as_text(right), operator
+            )
+        )
 
         self._loop_test(operator, right)
 
         if operators.is_comparison(operator):
             is_(
                 left.comparator.operate(operator, right).type,
-                sqltypes.BOOLEANTYPE
+                sqltypes.BOOLEANTYPE,
            )
 
     def _loop_test(self, operator, *arg):
         loop = LoopOperate()
-        is_(
-            operator(loop, *arg),
-            operator
-        )
+        is_(operator(loop, *arg), operator)
     def test_desc(self):
         self._do_scalar_test(operators.desc_op, desc)
@@ -153,77 +174,75 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         assert_raises_message(
             NotImplementedError,
             "Operator 'getitem' is not supported on this expression",
-            self._do_operate_test, operators.getitem
+            self._do_operate_test,
+            operators.getitem,
        )
 
         assert_raises_message(
             NotImplementedError,
             "Operator 'getitem' is not supported on this expression",
-            lambda: column('left')[3]
+            lambda: column("left")[3],
        )
 
     def test_in(self):
-        left = column('left')
+        left = column("left")
         assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare(
             BinaryExpression(
                 left,
-                Grouping(ClauseList(
-                    literal(1), literal(2), literal(3)
-                )),
-                operators.in_op
+                Grouping(ClauseList(literal(1), literal(2), literal(3))),
+                operators.in_op,
            )
        )
 
         self._loop_test(operators.in_op, [1, 2, 3])
 
     def test_notin(self):
-        left = column('left')
+        left = column("left")
         assert left.comparator.operate(operators.notin_op, [1, 2, 3]).compare(
             BinaryExpression(
                 left,
-                Grouping(ClauseList(
-                    literal(1), literal(2), literal(3)
-                )),
-                operators.notin_op
+                Grouping(ClauseList(literal(1), literal(2), literal(3))),
+                operators.notin_op,
            )
        )
 
         self._loop_test(operators.notin_op, [1, 2, 3])
 
     def test_in_no_accept_list_of_non_column_element(self):
-        left = column('left')
+        left = column("left")
         foo = ClauseList()
         assert_raises_message(
             exc.InvalidRequestError,
             r"in_\(\) accepts either a list of expressions, a selectable",
-            left.in_, [foo]
+            left.in_,
+            [foo],
        )
 
     def test_in_no_accept_non_list_non_selectable(self):
-        left = column('left')
-        right = column('right')
+        left = column("left")
+        right = column("right")
         assert_raises_message(
             exc.InvalidRequestError,
             r"in_\(\) accepts either a list of expressions, a selectable",
-            left.in_, right
+            left.in_,
+            right,
        )
 
     def test_in_no_accept_non_list_thing_with_getitem(self):
         # test [ticket:2726]
         class HasGetitem(String):
-
             class comparator_factory(String.Comparator):
-
                 def __getitem__(self, value):
                     return value
 
-        left = column('left')
-        right = column('right', HasGetitem)
+        left = column("left")
+        right = column("right", HasGetitem)
         assert_raises_message(
             exc.InvalidRequestError,
             r"in_\(\) accepts either a list of expressions, a selectable",
-            left.in_, right
+            left.in_,
+            right,
        )
 
     def test_collate(self):
-        left = column('left')
+        left = column("left")
         right = "some collation"
         left.comparator.operate(operators.collate, right).compare(
             collate(left, right)
@@ -239,10 +258,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         class TypeTwo(TypeEngine):
             pass
 
-        expr = column('x', TypeOne()) - column('y', TypeTwo())
-        is_(
-            expr.type._type_affinity, TypeOne
-        )
+        expr = column("x", TypeOne()) - column("y", TypeTwo())
+        is_(expr.type._type_affinity, TypeOne)
 
     def test_concatenable_adapt(self):
         class TypeOne(Concatenable, TypeEngine):
@@ -254,116 +271,89 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         class TypeThree(TypeEngine):
             pass
 
-        expr = column('x', TypeOne()) - column('y', TypeTwo())
-        is_(
-            expr.type._type_affinity, TypeOne
-        )
-        is_(
-            expr.operator, operator.sub
-        )
+        expr = column("x", TypeOne()) - column("y", TypeTwo())
+        is_(expr.type._type_affinity, TypeOne)
+        is_(expr.operator, operator.sub)
 
-        expr = column('x', TypeOne()) + column('y', TypeTwo())
-        is_(
-            expr.type._type_affinity, TypeOne
-        )
-        is_(
-            expr.operator, operators.concat_op
-        )
+        expr = column("x", TypeOne()) + column("y", TypeTwo())
+        is_(expr.type._type_affinity, TypeOne)
+        is_(expr.operator, operators.concat_op)
-        expr = column('x', TypeOne()) - column('y', TypeThree())
-        is_(
-            expr.type._type_affinity, TypeOne
-        )
-        is_(
-            expr.operator, operator.sub
-        )
+        expr = column("x", TypeOne()) - column("y", TypeThree())
+        is_(expr.type._type_affinity, TypeOne)
+        is_(expr.operator, operator.sub)
 
-        expr = column('x', TypeOne()) + column('y', TypeThree())
-        is_(
-            expr.type._type_affinity, TypeOne
-        )
-        is_(
-            expr.operator, operator.add
-        )
+        expr = column("x", TypeOne()) + column("y", TypeThree())
+        is_(expr.type._type_affinity, TypeOne)
+        is_(expr.operator, operator.add)
 
     def test_contains_override_raises(self):
         for col in [
-            Column('x', String),
-            Column('x', Integer),
-            Column('x', DateTime)
+            Column("x", String),
+            Column("x", Integer),
+            Column("x", DateTime),
        ]:
             assert_raises_message(
                 NotImplementedError,
                 "Operator 'contains' is not supported on this expression",
-                lambda: 'foo' in col
+                lambda: "foo" in col,
            )
 
 
 class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"
 
     def _factorial_fixture(self):
         class MyInteger(Integer):
-
             class comparator_factory(Integer.Comparator):
-
                 def factorial(self):
-                    return UnaryExpression(self.expr,
-                                           modifier=operators.custom_op("!"),
-                                           type_=MyInteger)
+                    return UnaryExpression(
+                        self.expr,
+                        modifier=operators.custom_op("!"),
+                        type_=MyInteger,
+                    )
 
                 def factorial_prefix(self):
-                    return UnaryExpression(self.expr,
-                                           operator=operators.custom_op("!!"),
-                                           type_=MyInteger)
+                    return UnaryExpression(
+                        self.expr,
+                        operator=operators.custom_op("!!"),
+                        type_=MyInteger,
+                    )
 
                 def __invert__(self):
-                    return UnaryExpression(self.expr,
-                                           operator=operators.custom_op("!!!"),
-                                           type_=MyInteger)
+                    return UnaryExpression(
+                        self.expr,
+                        operator=operators.custom_op("!!!"),
+                        type_=MyInteger,
+                    )
 
         return MyInteger
 
     def test_factorial(self):
-        col = column('somecol', self._factorial_fixture())
-        self.assert_compile(
-            col.factorial(),
-            "somecol !"
-        )
+        col = column("somecol", self._factorial_fixture())
+        self.assert_compile(col.factorial(), "somecol !")
 
     def test_double_factorial(self):
-        col = column('somecol', self._factorial_fixture())
-        self.assert_compile(
-            col.factorial().factorial(),
-            "somecol ! !"
-        )
+        col = column("somecol", self._factorial_fixture())
+        self.assert_compile(col.factorial().factorial(), "somecol ! !")
 
     def test_factorial_prefix(self):
-        col = column('somecol', self._factorial_fixture())
-        self.assert_compile(
-            col.factorial_prefix(),
-            "!! somecol"
-        )
+        col = column("somecol", self._factorial_fixture())
+        self.assert_compile(col.factorial_prefix(), "!! somecol")
 
     def test_factorial_invert(self):
-        col = column('somecol', self._factorial_fixture())
-        self.assert_compile(
-            ~col,
-            "!!! somecol"
-        )
+        col = column("somecol", self._factorial_fixture())
+        self.assert_compile(~col, "!!! somecol")
 
     def test_double_factorial_invert(self):
-        col = column('somecol', self._factorial_fixture())
-        self.assert_compile(
-            ~(~col),
-            "!!! (!!! somecol)"
-        )
+        col = column("somecol", self._factorial_fixture())
+        self.assert_compile(~(~col), "!!! (!!! somecol)")
 
     def test_unary_no_ops(self):
         assert_raises_message(
             exc.CompileError,
             "Unary expression has no operator or modifier",
-            UnaryExpression(literal("x")).compile
+            UnaryExpression(literal("x")).compile,
        )
 
     def test_unary_both_ops(self):
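The factorial fixture above relies on public pieces only: a type's comparator_factory plus UnaryExpression with operators.custom_op. A compact standalone sketch of a postfix operator; the type and method names are illustrative:

    from sqlalchemy import Integer
    from sqlalchemy.sql import column, operators
    from sqlalchemy.sql.expression import UnaryExpression

    class FactorialInteger(Integer):
        class comparator_factory(Integer.Comparator):
            def factorial(self):
                # modifier= places the operator after the operand
                return UnaryExpression(
                    self.expr,
                    modifier=operators.custom_op("!"),
                    type_=FactorialInteger,
                )

    col = column("somecol", FactorialInteger())
    print(col.factorial())  # renders as: somecol !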
somecol)") def test_unary_no_ops(self): assert_raises_message( exc.CompileError, "Unary expression has no operator or modifier", - UnaryExpression(literal("x")).compile + UnaryExpression(literal("x")).compile, ) def test_unary_both_ops(self): @@ -371,83 +361,68 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): exc.CompileError, "Unary expression does not support operator and " "modifier simultaneously", - UnaryExpression(literal("x"), - operator=operators.custom_op("x"), - modifier=operators.custom_op("y")).compile + UnaryExpression( + literal("x"), + operator=operators.custom_op("x"), + modifier=operators.custom_op("y"), + ).compile, ) class _CustomComparatorTests(object): - def test_override_builtin(self): - c1 = Column('foo', self._add_override_factory()) + c1 = Column("foo", self._add_override_factory()) self._assert_add_override(c1) def test_column_proxy(self): - t = Table('t', MetaData(), - Column('foo', self._add_override_factory()) - ) + t = Table("t", MetaData(), Column("foo", self._add_override_factory())) proxied = t.select().c.foo self._assert_add_override(proxied) self._assert_and_override(proxied) def test_alias_proxy(self): - t = Table('t', MetaData(), - Column('foo', self._add_override_factory()) - ) + t = Table("t", MetaData(), Column("foo", self._add_override_factory())) proxied = t.alias().c.foo self._assert_add_override(proxied) self._assert_and_override(proxied) def test_binary_propagate(self): - c1 = Column('foo', self._add_override_factory()) + c1 = Column("foo", self._add_override_factory()) self._assert_add_override(c1 - 6) self._assert_and_override(c1 - 6) def test_reverse_binary_propagate(self): - c1 = Column('foo', self._add_override_factory()) + c1 = Column("foo", self._add_override_factory()) self._assert_add_override(6 - c1) self._assert_and_override(6 - c1) def test_binary_multi_propagate(self): - c1 = Column('foo', self._add_override_factory()) + c1 = Column("foo", self._add_override_factory()) self._assert_add_override((c1 - 6) + 5) self._assert_and_override((c1 - 6) + 5) def test_no_boolean_propagate(self): - c1 = Column('foo', self._add_override_factory()) + c1 = Column("foo", self._add_override_factory()) self._assert_not_add_override(c1 == 56) self._assert_not_and_override(c1 == 56) def _assert_and_override(self, expr): - assert (expr & text("5")).compare( - expr.op("goofy_and")(text("5")) - ) + assert (expr & text("5")).compare(expr.op("goofy_and")(text("5"))) def _assert_add_override(self, expr): - assert (expr + 5).compare( - expr.op("goofy")(5) - ) + assert (expr + 5).compare(expr.op("goofy")(5)) def _assert_not_add_override(self, expr): - assert not (expr + 5).compare( - expr.op("goofy")(5) - ) + assert not (expr + 5).compare(expr.op("goofy")(5)) def _assert_not_and_override(self, expr): - assert not (expr & text("5")).compare( - expr.op("goofy_and")(text("5")) - ) + assert not (expr & text("5")).compare(expr.op("goofy_and")(text("5"))) class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self): - class MyInteger(Integer): - class comparator_factory(TypeEngine.Comparator): - def __init__(self, expr): super(MyInteger.comparator_factory, self).__init__(expr) @@ -461,14 +436,11 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): - def _add_override_factory(self): - class MyInteger(TypeDecorator): impl = Integer class comparator_factory(TypeDecorator.Comparator): - def 
@@ -461,14 +436,11 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
 
 
 class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
-
     def _add_override_factory(self):
-
         class MyInteger(TypeDecorator):
             impl = Integer
 
             class comparator_factory(TypeDecorator.Comparator):
-
                 def __init__(self, expr):
                     super(MyInteger.comparator_factory, self).__init__(expr)
@@ -482,15 +454,13 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
 
 
 class TypeDecoratorTypeDecoratorComparatorTest(
-        _CustomComparatorTests, fixtures.TestBase):
-
+    _CustomComparatorTests, fixtures.TestBase
+):
     def _add_override_factory(self):
-
         class MyIntegerOne(TypeDecorator):
             impl = Integer
 
             class comparator_factory(TypeDecorator.Comparator):
-
                 def __init__(self, expr):
                     super(MyIntegerOne.comparator_factory, self).__init__(expr)
@@ -507,19 +477,15 @@ class TypeDecoratorTypeDecoratorComparatorTest(
 
 
 class TypeDecoratorWVariantComparatorTest(
-        _CustomComparatorTests,
-        fixtures.TestBase):
-
+    _CustomComparatorTests, fixtures.TestBase
+):
     def _add_override_factory(self):
-
         class SomeOtherInteger(Integer):
-
             class comparator_factory(TypeEngine.Comparator):
-
                 def __init__(self, expr):
-                    super(
-                        SomeOtherInteger.comparator_factory,
-                        self).__init__(expr)
+                    super(SomeOtherInteger.comparator_factory, self).__init__(
+                        expr
+                    )
 
                 def __add__(self, other):
                     return self.expr.op("not goofy")(other)
@@ -531,7 +497,6 @@ class TypeDecoratorWVariantComparatorTest(
             impl = Integer
 
             class comparator_factory(TypeDecorator.Comparator):
-
                 def __init__(self, expr):
                     super(MyInteger.comparator_factory, self).__init__(expr)
@@ -545,14 +510,11 @@ class TypeDecoratorWVariantComparatorTest(
 
 
 class CustomEmbeddedinTypeDecoratorTest(
-        _CustomComparatorTests,
-        fixtures.TestBase):
-
+    _CustomComparatorTests, fixtures.TestBase
+):
     def _add_override_factory(self):
         class MyInteger(Integer):
-
             class comparator_factory(TypeEngine.Comparator):
-
                 def __init__(self, expr):
                     super(MyInteger.comparator_factory, self).__init__(expr)
@@ -569,23 +531,19 @@ class CustomEmbeddedinTypeDecoratorTest(
 
 
 class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
-
     def _add_override_factory(self):
         class MyInteger(Integer):
-
             class comparator_factory(TypeEngine.Comparator):
-
                 def __init__(self, expr):
                     super(MyInteger.comparator_factory, self).__init__(expr)
 
                 def foob(self, other):
                     return self.expr.op("foob")(other)
+
         return MyInteger
 
     def _assert_add_override(self, expr):
-        assert (expr.foob(5)).compare(
-            expr.op("foob")(5)
-        )
+        assert (expr.foob(5)).compare(expr.op("foob")(5))
 
     def _assert_not_add_override(self, expr):
         assert not hasattr(expr, "foob")
@@ -598,71 +556,49 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
 
 
 class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"
 
     def test_contains(self):
         class MyType(UserDefinedType):
-
             class comparator_factory(UserDefinedType.Comparator):
-
                 def contains(self, other, **kw):
                     return self.op("->")(other)
 
-        self.assert_compile(
-            Column('x', MyType()).contains(5),
-            "x -> :x_1"
-        )
+        self.assert_compile(Column("x", MyType()).contains(5), "x -> :x_1")
 
     def test_getitem(self):
         class MyType(UserDefinedType):
-
             class comparator_factory(UserDefinedType.Comparator):
-
                 def __getitem__(self, index):
                     return self.op("->")(index)
 
-        self.assert_compile(
-            Column('x', MyType())[5],
-            "x -> :x_1"
-        )
+        self.assert_compile(Column("x", MyType())[5], "x -> :x_1")
 
     def test_op_not_an_iterator(self):
         # see [ticket:2726]
         class MyType(UserDefinedType):
-
             class comparator_factory(UserDefinedType.Comparator):
-
                 def __getitem__(self, index):
                     return self.op("->")(index)
 
-        col = Column('x', MyType())
+        col = Column("x", MyType())
         assert not isinstance(col, util.collections_abc.Iterable)
 
     def test_lshift(self):
MyType(UserDefinedType): - class comparator_factory(UserDefinedType.Comparator): - def __lshift__(self, other): return self.op("->")(other) - self.assert_compile( - Column('x', MyType()) << 5, - "x -> :x_1" - ) + self.assert_compile(Column("x", MyType()) << 5, "x -> :x_1") def test_rshift(self): class MyType(UserDefinedType): - class comparator_factory(UserDefinedType.Comparator): - def __rshift__(self, other): return self.op("->")(other) - self.assert_compile( - Column('x', MyType()) >> 5, - "x -> :x_1" - ) + self.assert_compile(Column("x", MyType()) >> 5, "x -> :x_1") class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): @@ -675,14 +611,14 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): return "MYOTHERTYPE" class MyCompiler(compiler.SQLCompiler): - def visit_json_getitem_op_binary(self, binary, operator, **kw): return self._generate_generic_binary( binary, " -> ", eager_grouping=True, **kw ) def visit_json_path_getitem_op_binary( - self, binary, operator, **kw): + self, binary, operator, **kw + ): return self._generate_generic_binary( binary, " #> ", eager_grouping=True, **kw ) @@ -695,7 +631,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): type_compiler = MyTypeCompiler class MyType(JSON): - __visit_name__ = 'mytype' + __visit_name__ = "mytype" pass @@ -703,109 +639,83 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.__dialect__ = MyDialect() def test_setup_getitem(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) - is_( - col[5].type._type_affinity, JSON - ) - is_( - col[5]['foo'].type._type_affinity, JSON - ) - is_( - col[('a', 'b', 'c')].type._type_affinity, JSON - ) + is_(col[5].type._type_affinity, JSON) + is_(col[5]["foo"].type._type_affinity, JSON) + is_(col[("a", "b", "c")].type._type_affinity, JSON) def test_getindex_literal_integer(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) - self.assert_compile( - col[5], - "x -> :x_1", - checkparams={'x_1': 5} - ) + self.assert_compile(col[5], "x -> :x_1", checkparams={"x_1": 5}) def test_getindex_literal_string(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) self.assert_compile( - col['foo'], - "x -> :x_1", - checkparams={'x_1': 'foo'} + col["foo"], "x -> :x_1", checkparams={"x_1": "foo"} ) def test_path_getindex_literal(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) self.assert_compile( - col[('a', 'b', 3, 4, 'd')], + col[("a", "b", 3, 4, "d")], "x #> :x_1", - checkparams={'x_1': ('a', 'b', 3, 4, 'd')} + checkparams={"x_1": ("a", "b", 3, 4, "d")}, ) def test_getindex_sqlexpr(self): - col = Column('x', self.MyType()) - col2 = Column('y', Integer()) + col = Column("x", self.MyType()) + col2 = Column("y", Integer()) - self.assert_compile( - col[col2], - "x -> y", - checkparams={} - ) + self.assert_compile(col[col2], "x -> y", checkparams={}) def test_getindex_sqlexpr_right_grouping(self): - col = Column('x', self.MyType()) - col2 = Column('y', Integer()) + col = Column("x", self.MyType()) + col2 = Column("y", Integer()) self.assert_compile( - col[col2 + 8], - "x -> (y + :y_1)", - checkparams={'y_1': 8} + col[col2 + 8], "x -> (y + :y_1)", checkparams={"y_1": 8} ) def test_getindex_sqlexpr_left_grouping(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) - self.assert_compile( - col[8] != None, # noqa - "(x -> :x_1) IS NOT NULL" - ) + self.assert_compile(col[8] != None, "(x -> :x_1) IS NOT NULL") # noqa def 
test_getindex_sqlexpr_both_grouping(self): - col = Column('x', self.MyType()) - col2 = Column('y', Integer()) + col = Column("x", self.MyType()) + col2 = Column("y", Integer()) self.assert_compile( col[col2 + 8] != None, # noqa "(x -> (y + :y_1)) IS NOT NULL", - checkparams={'y_1': 8} + checkparams={"y_1": 8}, ) def test_override_operators(self): - special_index_op = operators.custom_op('$$>') + special_index_op = operators.custom_op("$$>") class MyOtherType(JSON, TypeEngine): - __visit_name__ = 'myothertype' + __visit_name__ = "myothertype" class Comparator(TypeEngine.Comparator): - def _adapt_expression(self, op, other_comparator): return special_index_op, MyOtherType() comparator_factory = Comparator - col = Column('x', MyOtherType()) - self.assert_compile( - col[5], - "x $$> :x_1", - checkparams={'x_1': 5} - ) + col = Column("x", MyOtherType()) + self.assert_compile(col[5], "x $$> :x_1", checkparams={"x_1": 5}) class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): @@ -827,7 +737,7 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) class MyDialect(default.DefaultDialect): @@ -835,7 +745,7 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): type_compiler = MyTypeCompiler class MyType(ARRAY): - __visit_name__ = 'mytype' + __visit_name__ = "mytype" def __init__(self, zero_indexes=False, dimensions=1): if zero_indexes: @@ -850,148 +760,109 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): """test the behavior of the _setup_getitem() method given a simple 'dimensions' scheme - this is identical to postgresql.ARRAY.""" - col = Column('x', self.MyType(dimensions=3)) + col = Column("x", self.MyType(dimensions=3)) - is_( - col[5].type._type_affinity, ARRAY - ) - eq_( - col[5].type.dimensions, 2 - ) - is_( - col[5][6].type._type_affinity, ARRAY - ) - eq_( - col[5][6].type.dimensions, 1 - ) - is_( - col[5][6][7].type._type_affinity, Integer - ) + is_(col[5].type._type_affinity, ARRAY) + eq_(col[5].type.dimensions, 2) + is_(col[5][6].type._type_affinity, ARRAY) + eq_(col[5][6].type.dimensions, 1) + is_(col[5][6][7].type._type_affinity, Integer) def test_getindex_literal(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) - self.assert_compile( - col[5], - "x[:x_1]", - checkparams={'x_1': 5} - ) + self.assert_compile(col[5], "x[:x_1]", checkparams={"x_1": 5}) def test_contains_override_raises(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) assert_raises_message( NotImplementedError, "Operator 'contains' is not supported on this expression", - lambda: 'foo' in col + lambda: "foo" in col, ) def test_getindex_sqlexpr(self): - col = Column('x', self.MyType()) - col2 = Column('y', Integer()) + col = Column("x", self.MyType()) + col2 = Column("y", Integer()) - self.assert_compile( - col[col2], - "x[y]", - checkparams={} - ) + self.assert_compile(col[col2], "x[y]", checkparams={}) self.assert_compile( - col[col2 + 8], - "x[(y + :y_1)]", - checkparams={'y_1': 8} + col[col2 + 8], "x[(y + :y_1)]", checkparams={"y_1": 8} ) def test_getslice_literal(self): - col = Column('x', self.MyType()) + col = Column("x", self.MyType()) self.assert_compile( - col[5:6], - "x[:x_1::x_2]", - checkparams={'x_1': 5, 'x_2': 6} + col[5:6], "x[:x_1::x_2]", checkparams={"x_1": 5, "x_2": 6} ) def 
test_getslice_sqlexpr(self): - col = Column('x', self.MyType()) - col2 = Column('y', Integer()) + col = Column("x", self.MyType()) + col2 = Column("y", Integer()) self.assert_compile( - col[col2:col2 + 5], - "x[y:y + :y_1]", - checkparams={'y_1': 5} + col[col2 : col2 + 5], "x[y:y + :y_1]", checkparams={"y_1": 5} ) def test_getindex_literal_zeroind(self): - col = Column('x', self.MyType(zero_indexes=True)) + col = Column("x", self.MyType(zero_indexes=True)) - self.assert_compile( - col[5], - "x[:x_1]", - checkparams={'x_1': 6} - ) + self.assert_compile(col[5], "x[:x_1]", checkparams={"x_1": 6}) def test_getindex_sqlexpr_zeroind(self): - col = Column('x', self.MyType(zero_indexes=True)) - col2 = Column('y', Integer()) + col = Column("x", self.MyType(zero_indexes=True)) + col2 = Column("y", Integer()) - self.assert_compile( - col[col2], - "x[(y + :y_1)]", - checkparams={'y_1': 1} - ) + self.assert_compile(col[col2], "x[(y + :y_1)]", checkparams={"y_1": 1}) self.assert_compile( col[col2 + 8], "x[(y + :y_1 + :param_1)]", - checkparams={'y_1': 8, 'param_1': 1} + checkparams={"y_1": 8, "param_1": 1}, ) def test_getslice_literal_zeroind(self): - col = Column('x', self.MyType(zero_indexes=True)) + col = Column("x", self.MyType(zero_indexes=True)) self.assert_compile( - col[5:6], - "x[:x_1::x_2]", - checkparams={'x_1': 6, 'x_2': 7} + col[5:6], "x[:x_1::x_2]", checkparams={"x_1": 6, "x_2": 7} ) def test_getslice_sqlexpr_zeroind(self): - col = Column('x', self.MyType(zero_indexes=True)) - col2 = Column('y', Integer()) + col = Column("x", self.MyType(zero_indexes=True)) + col2 = Column("y", Integer()) self.assert_compile( - col[col2:col2 + 5], + col[col2 : col2 + 5], "x[y + :y_1:y + :y_2 + :param_1]", - checkparams={'y_1': 1, 'y_2': 5, 'param_1': 1} + checkparams={"y_1": 1, "y_2": 5, "param_1": 1}, ) def test_override_operators(self): - special_index_op = operators.custom_op('->') + special_index_op = operators.custom_op("->") class MyOtherType(Indexable, TypeEngine): - __visit_name__ = 'myothertype' + __visit_name__ = "myothertype" class Comparator(TypeEngine.Comparator): - def _adapt_expression(self, op, other_comparator): return special_index_op, MyOtherType() comparator_factory = Comparator - col = Column('x', MyOtherType()) - self.assert_compile( - col[5], - "x -> :x_1", - checkparams={'x_1': 5} - ) + col = Column("x", MyOtherType()) + self.assert_compile(col[5], "x -> :x_1", checkparams={"x_1": 5}) class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): @@ -1005,115 +876,101 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): return d def test_one(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c]).where(c), "SELECT x WHERE x", - dialect=self._dialect(True) + dialect=self._dialect(True), ) def test_two_a(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c]).where(c), "SELECT x WHERE x = 1", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_two_b(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c], whereclause=c), "SELECT x WHERE x = 1", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_three_a(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c]).where(~c), "SELECT x WHERE x = 0", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_three_b(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c], 
whereclause=~c), "SELECT x WHERE x = 0", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_four(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c]).where(~c), "SELECT x WHERE NOT x", - dialect=self._dialect(True) + dialect=self._dialect(True), ) def test_five_a(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c]).having(c), "SELECT x HAVING x = 1", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_five_b(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( select([c], having=c), "SELECT x HAVING x = 1", - dialect=self._dialect(False) + dialect=self._dialect(False), ) def test_six(self): self.assert_compile( - or_(false(), true()), - "1 = 1", - dialect=self._dialect(False) + or_(false(), true()), "1 = 1", dialect=self._dialect(False) ) def test_eight(self): self.assert_compile( - and_(false(), true()), - "false", - dialect=self._dialect(True) + and_(false(), true()), "false", dialect=self._dialect(True) ) def test_nine(self): self.assert_compile( - and_(false(), true()), - "0 = 1", - dialect=self._dialect(False) + and_(false(), true()), "0 = 1", dialect=self._dialect(False) ) def test_ten(self): - c = column('x', Boolean) - self.assert_compile( - c == 1, - "x = :x_1", - dialect=self._dialect(False) - ) + c = column("x", Boolean) + self.assert_compile(c == 1, "x = :x_1", dialect=self._dialect(False)) def test_eleven(self): - c = column('x', Boolean) + c = column("x", Boolean) self.assert_compile( - c.is_(true()), - "x IS true", - dialect=self._dialect(True) + c.is_(true()), "x IS true", dialect=self._dialect(True) ) def test_twelve(self): - c = column('x', Boolean) + c = column("x", Boolean) # I don't have a solution for this one yet, # other than adding some heavy-handed conditionals # into compiler self.assert_compile( - c.is_(true()), - "x IS 1", - dialect=self._dialect(False) + c.is_(true()), "x IS 1", dialect=self._dialect(False) ) @@ -1121,6 +978,7 @@ class ConjunctionTest(fixtures.TestBase, testing.AssertsCompiledSQL): """test interaction of and_()/or_() with boolean , null constants """ + __dialect__ = default.DefaultDialect(supports_native_boolean=True) def test_one(self): @@ -1133,17 +991,14 @@ class ConjunctionTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile(or_(and_()), "") def test_four(self): - x = column('x') + x = column("x") self.assert_compile( - and_(or_(x == 5), or_(x == 7)), - "x = :x_1 AND x = :x_2") + and_(or_(x == 5), or_(x == 7)), "x = :x_1 AND x = :x_2" + ) def test_five(self): x = column("x") - self.assert_compile( - and_(true()._ifnone(None), x == 7), - "x = :x_1" - ) + self.assert_compile(and_(true()._ifnone(None), x == 7), "x = :x_1") def test_six(self): x = column("x") @@ -1153,74 +1008,59 @@ class ConjunctionTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_six_pt_five(self): x = column("x") - self.assert_compile(select([x]).where(or_(x == 7, true())), - "SELECT x WHERE true") + self.assert_compile( + select([x]).where(or_(x == 7, true())), "SELECT x WHERE true" + ) self.assert_compile( - select( - [x]).where( - or_( - x == 7, - true())), + select([x]).where(or_(x == 7, true())), "SELECT x WHERE 1 = 1", - dialect=default.DefaultDialect( - supports_native_boolean=False)) + dialect=default.DefaultDialect(supports_native_boolean=False), + ) def test_seven(self): x = column("x") self.assert_compile( - and_(true(), x == 7, true(), x == 9), - "x = :x_1 AND x = :x_2") + 
and_(true(), x == 7, true(), x == 9), "x = :x_1 AND x = :x_2" + ) def test_eight(self): x = column("x") self.assert_compile( - or_(false(), x == 7, false(), x == 9), - "x = :x_1 OR x = :x_2") + or_(false(), x == 7, false(), x == 9), "x = :x_1 OR x = :x_2" + ) def test_nine(self): x = column("x") - self.assert_compile( - and_(x == 7, x == 9, false(), x == 5), - "false" - ) - self.assert_compile( - ~and_(x == 7, x == 9, false(), x == 5), - "true" - ) + self.assert_compile(and_(x == 7, x == 9, false(), x == 5), "false") + self.assert_compile(~and_(x == 7, x == 9, false(), x == 5), "true") def test_ten(self): - self.assert_compile( - and_(None, None), - "NULL AND NULL" - ) + self.assert_compile(and_(None, None), "NULL AND NULL") def test_eleven(self): x = column("x") self.assert_compile( - select([x]).where(None).where(None), - "SELECT x WHERE NULL AND NULL" + select([x]).where(None).where(None), "SELECT x WHERE NULL AND NULL" ) def test_twelve(self): x = column("x") self.assert_compile( - select([x]).where(and_(None, None)), - "SELECT x WHERE NULL AND NULL" + select([x]).where(and_(None, None)), "SELECT x WHERE NULL AND NULL" ) def test_thirteen(self): x = column("x") self.assert_compile( select([x]).where(~and_(None, None)), - "SELECT x WHERE NOT (NULL AND NULL)" + "SELECT x WHERE NOT (NULL AND NULL)", ) def test_fourteen(self): x = column("x") self.assert_compile( - select([x]).where(~null()), - "SELECT x WHERE NOT NULL" + select([x]).where(~null()), "SELECT x WHERE NOT NULL" ) def test_constant_non_singleton(self): @@ -1230,537 +1070,560 @@ class ConjunctionTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_constant_render_distinct(self): self.assert_compile( - select([null(), null()]), - "SELECT NULL AS anon_1, NULL AS anon_2" + select([null(), null()]), "SELECT NULL AS anon_1, NULL AS anon_2" ) self.assert_compile( - select([true(), true()]), - "SELECT true AS anon_1, true AS anon_2" + select([true(), true()]), "SELECT true AS anon_1, true AS anon_2" ) self.assert_compile( select([false(), false()]), - "SELECT false AS anon_1, false AS anon_2" + "SELECT false AS anon_1, false AS anon_2", ) def test_is_true_literal(self): - c = column('x', Boolean) - self.assert_compile( - c.is_(True), - "x IS true" - ) + c = column("x", Boolean) + self.assert_compile(c.is_(True), "x IS true") def test_is_false_literal(self): - c = column('x', Boolean) - self.assert_compile( - c.is_(False), - "x IS false" - ) + c = column("x", Boolean) + self.assert_compile(c.is_(False), "x IS false") def test_and_false_literal_leading(self): - self.assert_compile( - and_(False, True), - "false" - ) + self.assert_compile(and_(False, True), "false") - self.assert_compile( - and_(False, False), - "false" - ) + self.assert_compile(and_(False, False), "false") def test_and_true_literal_leading(self): - self.assert_compile( - and_(True, True), - "true" - ) + self.assert_compile(and_(True, True), "true") - self.assert_compile( - and_(True, False), - "false" - ) + self.assert_compile(and_(True, False), "false") def test_or_false_literal_leading(self): - self.assert_compile( - or_(False, True), - "true" - ) + self.assert_compile(or_(False, True), "true") - self.assert_compile( - or_(False, False), - "false" - ) + self.assert_compile(or_(False, False), "false") def test_or_true_literal_leading(self): - self.assert_compile( - or_(True, True), - "true" - ) + self.assert_compile(or_(True, True), "true") - self.assert_compile( - or_(True, False), - "true" - ) + self.assert_compile(or_(True, False), "true") class 
OperatorPrecedenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String), - ) + table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), + ) - table2 = table('op', column('field')) + table2 = table("op", column("field")) def test_operator_precedence_1(self): self.assert_compile( self.table2.select((self.table2.c.field == 5) == None), # noqa - "SELECT op.field FROM op WHERE (op.field = :field_1) IS NULL") + "SELECT op.field FROM op WHERE (op.field = :field_1) IS NULL", + ) def test_operator_precedence_2(self): self.assert_compile( self.table2.select( - (self.table2.c.field + 5) == self.table2.c.field), - "SELECT op.field FROM op WHERE op.field + :field_1 = op.field") + (self.table2.c.field + 5) == self.table2.c.field + ), + "SELECT op.field FROM op WHERE op.field + :field_1 = op.field", + ) def test_operator_precedence_3(self): self.assert_compile( self.table2.select((self.table2.c.field + 5) * 6), - "SELECT op.field FROM op WHERE (op.field + :field_1) * :param_1") + "SELECT op.field FROM op WHERE (op.field + :field_1) * :param_1", + ) def test_operator_precedence_4(self): self.assert_compile( - self.table2.select( - (self.table2.c.field * 5) + 6), - "SELECT op.field FROM op WHERE op.field * :field_1 + :param_1") + self.table2.select((self.table2.c.field * 5) + 6), + "SELECT op.field FROM op WHERE op.field * :field_1 + :param_1", + ) def test_operator_precedence_5(self): - self.assert_compile(self.table2.select( - 5 + self.table2.c.field.in_([5, 6])), - "SELECT op.field FROM op WHERE :param_1 + " - "(op.field IN (:field_1, :field_2))") + self.assert_compile( + self.table2.select(5 + self.table2.c.field.in_([5, 6])), + "SELECT op.field FROM op WHERE :param_1 + " + "(op.field IN (:field_1, :field_2))", + ) def test_operator_precedence_6(self): - self.assert_compile(self.table2.select( - (5 + self.table2.c.field).in_([5, 6])), + self.assert_compile( + self.table2.select((5 + self.table2.c.field).in_([5, 6])), "SELECT op.field FROM op WHERE :field_1 + op.field " - "IN (:param_1, :param_2)") + "IN (:param_1, :param_2)", + ) def test_operator_precedence_7(self): - self.assert_compile(self.table2.select( - not_(and_(self.table2.c.field == 5, - self.table2.c.field == 7))), + self.assert_compile( + self.table2.select( + not_(and_(self.table2.c.field == 5, self.table2.c.field == 7)) + ), "SELECT op.field FROM op WHERE NOT " - "(op.field = :field_1 AND op.field = :field_2)") + "(op.field = :field_1 AND op.field = :field_2)", + ) def test_operator_precedence_8(self): self.assert_compile( - self.table2.select( - not_( - self.table2.c.field == 5)), - "SELECT op.field FROM op WHERE op.field != :field_1") + self.table2.select(not_(self.table2.c.field == 5)), + "SELECT op.field FROM op WHERE op.field != :field_1", + ) def test_operator_precedence_9(self): - self.assert_compile(self.table2.select( - not_(self.table2.c.field.between(5, 6))), + self.assert_compile( + self.table2.select(not_(self.table2.c.field.between(5, 6))), "SELECT op.field FROM op WHERE " - "op.field NOT BETWEEN :field_1 AND :field_2") + "op.field NOT BETWEEN :field_1 AND :field_2", + ) def test_operator_precedence_10(self): self.assert_compile( - self.table2.select( - not_( - self.table2.c.field) == 5), - "SELECT op.field FROM op WHERE (NOT op.field) = :param_1") + self.table2.select(not_(self.table2.c.field) == 5), + 
"SELECT op.field FROM op WHERE (NOT op.field) = :param_1", + ) def test_operator_precedence_11(self): - self.assert_compile(self.table2.select( - (self.table2.c.field == self.table2.c.field). - between(False, True)), + self.assert_compile( + self.table2.select( + (self.table2.c.field == self.table2.c.field).between( + False, True + ) + ), "SELECT op.field FROM op WHERE (op.field = op.field) " - "BETWEEN :param_1 AND :param_2") + "BETWEEN :param_1 AND :param_2", + ) def test_operator_precedence_12(self): - self.assert_compile(self.table2.select( - between((self.table2.c.field == self.table2.c.field), - False, True)), + self.assert_compile( + self.table2.select( + between( + (self.table2.c.field == self.table2.c.field), False, True + ) + ), "SELECT op.field FROM op WHERE (op.field = op.field) " - "BETWEEN :param_1 AND :param_2") + "BETWEEN :param_1 AND :param_2", + ) def test_operator_precedence_13(self): self.assert_compile( self.table2.select( - self.table2.c.field.match( - self.table2.c.field).is_(None)), - "SELECT op.field FROM op WHERE (op.field MATCH op.field) IS NULL") + self.table2.c.field.match(self.table2.c.field).is_(None) + ), + "SELECT op.field FROM op WHERE (op.field MATCH op.field) IS NULL", + ) def test_operator_precedence_collate_1(self): self.assert_compile( - self.table1.c.name == literal('foo').collate('utf-8'), - 'mytable.name = (:param_1 COLLATE "utf-8")' + self.table1.c.name == literal("foo").collate("utf-8"), + 'mytable.name = (:param_1 COLLATE "utf-8")', ) def test_operator_precedence_collate_2(self): self.assert_compile( - (self.table1.c.name == literal('foo')).collate('utf-8'), - 'mytable.name = :param_1 COLLATE "utf-8"' + (self.table1.c.name == literal("foo")).collate("utf-8"), + 'mytable.name = :param_1 COLLATE "utf-8"', ) def test_operator_precedence_collate_3(self): self.assert_compile( - self.table1.c.name.collate('utf-8') == 'foo', - '(mytable.name COLLATE "utf-8") = :param_1' + self.table1.c.name.collate("utf-8") == "foo", + '(mytable.name COLLATE "utf-8") = :param_1', ) def test_operator_precedence_collate_4(self): self.assert_compile( and_( - (self.table1.c.name == literal('foo')).collate('utf-8'), - (self.table2.c.field == literal('bar')).collate('utf-8'), + (self.table1.c.name == literal("foo")).collate("utf-8"), + (self.table2.c.field == literal("bar")).collate("utf-8"), ), 'mytable.name = :param_1 COLLATE "utf-8" ' - 'AND op.field = :param_2 COLLATE "utf-8"' + 'AND op.field = :param_2 COLLATE "utf-8"', ) def test_operator_precedence_collate_5(self): self.assert_compile( select([self.table1.c.name]).order_by( - self.table1.c.name.collate('utf-8').desc()), + self.table1.c.name.collate("utf-8").desc() + ), "SELECT mytable.name FROM mytable " - 'ORDER BY mytable.name COLLATE "utf-8" DESC' + 'ORDER BY mytable.name COLLATE "utf-8" DESC', ) def test_operator_precedence_collate_6(self): self.assert_compile( select([self.table1.c.name]).order_by( - self.table1.c.name.collate('utf-8').desc().nullslast()), + self.table1.c.name.collate("utf-8").desc().nullslast() + ), "SELECT mytable.name FROM mytable " - 'ORDER BY mytable.name COLLATE "utf-8" DESC NULLS LAST' + 'ORDER BY mytable.name COLLATE "utf-8" DESC NULLS LAST', ) def test_operator_precedence_collate_7(self): self.assert_compile( select([self.table1.c.name]).order_by( - self.table1.c.name.collate('utf-8').asc()), + self.table1.c.name.collate("utf-8").asc() + ), "SELECT mytable.name FROM mytable " - 'ORDER BY mytable.name COLLATE "utf-8" ASC' + 'ORDER BY mytable.name COLLATE "utf-8" ASC', ) def 
test_commutative_operators(self): self.assert_compile( literal("a") + literal("b") * literal("c"), - ":param_1 || :param_2 * :param_3" + ":param_1 || :param_2 * :param_3", ) def test_op_operators(self): self.assert_compile( - self.table1.select(self.table1.c.myid.op('hoho')(12) == 14), + self.table1.select(self.table1.c.myid.op("hoho")(12) == 14), "SELECT mytable.myid, mytable.name, mytable.description FROM " - "mytable WHERE (mytable.myid hoho :myid_1) = :param_1" + "mytable WHERE (mytable.myid hoho :myid_1) = :param_1", ) def test_op_operators_comma_precedence(self): self.assert_compile( - func.foo(self.table1.c.myid.op('hoho')(12)), - "foo(mytable.myid hoho :myid_1)" + func.foo(self.table1.c.myid.op("hoho")(12)), + "foo(mytable.myid hoho :myid_1)", ) def test_op_operators_comparison_precedence(self): self.assert_compile( - self.table1.c.myid.op('hoho')(12) == 5, - "(mytable.myid hoho :myid_1) = :param_1" + self.table1.c.myid.op("hoho")(12) == 5, + "(mytable.myid hoho :myid_1) = :param_1", ) def test_op_operators_custom_precedence(self): - op1 = self.table1.c.myid.op('hoho', precedence=5) - op2 = op1(5).op('lala', precedence=4)(4) - op3 = op1(5).op('lala', precedence=6)(4) + op1 = self.table1.c.myid.op("hoho", precedence=5) + op2 = op1(5).op("lala", precedence=4)(4) + op3 = op1(5).op("lala", precedence=6)(4) self.assert_compile(op2, "mytable.myid hoho :myid_1 lala :param_1") self.assert_compile(op3, "(mytable.myid hoho :myid_1) lala :param_1") def test_is_eq_precedence_flat(self): self.assert_compile( - (self.table1.c.name == null()) != - (self.table1.c.description == null()), + (self.table1.c.name == null()) + != (self.table1.c.description == null()), "(mytable.name IS NULL) != (mytable.description IS NULL)", ) class OperatorAssociativityTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_associativity_1(self): - f = column('f') + f = column("f") self.assert_compile(f - f, "f - f") def test_associativity_2(self): - f = column('f') + f = column("f") self.assert_compile(f - f - f, "(f - f) - f") def test_associativity_3(self): - f = column('f') + f = column("f") self.assert_compile((f - f) - f, "(f - f) - f") def test_associativity_4(self): - f = column('f') - self.assert_compile((f - f).label('foo') - f, "(f - f) - f") + f = column("f") + self.assert_compile((f - f).label("foo") - f, "(f - f) - f") def test_associativity_5(self): - f = column('f') + f = column("f") self.assert_compile(f - (f - f), "f - (f - f)") def test_associativity_6(self): - f = column('f') - self.assert_compile(f - (f - f).label('foo'), "f - (f - f)") + f = column("f") + self.assert_compile(f - (f - f).label("foo"), "f - (f - f)") def test_associativity_7(self): - f = column('f') + f = column("f") # because - less precedent than / self.assert_compile(f / (f - f), "f / (f - f)") def test_associativity_8(self): - f = column('f') - self.assert_compile(f / (f - f).label('foo'), "f / (f - f)") + f = column("f") + self.assert_compile(f / (f - f).label("foo"), "f / (f - f)") def test_associativity_9(self): - f = column('f') + f = column("f") self.assert_compile(f / f - f, "f / f - f") def test_associativity_10(self): - f = column('f') + f = column("f") self.assert_compile((f / f) - f, "f / f - f") def test_associativity_11(self): - f = column('f') - self.assert_compile((f / f).label('foo') - f, "f / f - f") + f = column("f") + self.assert_compile((f / f).label("foo") - f, "f / f - f") def test_associativity_12(self): - f = column('f') + f = column("f") # because 
/ more precedent than - self.assert_compile(f - (f / f), "f - f / f") def test_associativity_13(self): - f = column('f') - self.assert_compile(f - (f / f).label('foo'), "f - f / f") + f = column("f") + self.assert_compile(f - (f / f).label("foo"), "f - f / f") def test_associativity_14(self): - f = column('f') + f = column("f") self.assert_compile(f - f / f, "f - f / f") def test_associativity_15(self): - f = column('f') + f = column("f") self.assert_compile((f - f) / f, "(f - f) / f") def test_associativity_16(self): - f = column('f') + f = column("f") self.assert_compile(((f - f) / f) - f, "(f - f) / f - f") def test_associativity_17(self): - f = column('f') + f = column("f") # - lower precedence than / self.assert_compile((f - f) / (f - f), "(f - f) / (f - f)") def test_associativity_18(self): - f = column('f') + f = column("f") # / higher precedence than - self.assert_compile((f / f) - (f / f), "f / f - f / f") def test_associativity_19(self): - f = column('f') + f = column("f") self.assert_compile((f / f) - (f - f), "f / f - (f - f)") def test_associativity_20(self): - f = column('f') + f = column("f") self.assert_compile((f / f) / (f - f), "(f / f) / (f - f)") def test_associativity_21(self): - f = column('f') + f = column("f") self.assert_compile(f / (f / (f - f)), "f / (f / (f - f))") def test_associativity_22(self): - f = column('f') - self.assert_compile((f == f) == f, '(f = f) = f') + f = column("f") + self.assert_compile((f == f) == f, "(f = f) = f") def test_associativity_23(self): - f = column('f') - self.assert_compile((f != f) != f, '(f != f) != f') + f = column("f") + self.assert_compile((f != f) != f, "(f != f) != f") class IsDistinctFromTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - ) + table1 = table("mytable", column("myid", Integer)) def test_is_distinct_from(self): - self.assert_compile(self.table1.c.myid.is_distinct_from(1), - "mytable.myid IS DISTINCT FROM :myid_1") + self.assert_compile( + self.table1.c.myid.is_distinct_from(1), + "mytable.myid IS DISTINCT FROM :myid_1", + ) def test_is_distinct_from_sqlite(self): - self.assert_compile(self.table1.c.myid.is_distinct_from(1), - "mytable.myid IS NOT ?", - dialect=sqlite.dialect()) + self.assert_compile( + self.table1.c.myid.is_distinct_from(1), + "mytable.myid IS NOT ?", + dialect=sqlite.dialect(), + ) def test_is_distinct_from_postgresql(self): - self.assert_compile(self.table1.c.myid.is_distinct_from(1), - "mytable.myid IS DISTINCT FROM %(myid_1)s", - dialect=postgresql.dialect()) + self.assert_compile( + self.table1.c.myid.is_distinct_from(1), + "mytable.myid IS DISTINCT FROM %(myid_1)s", + dialect=postgresql.dialect(), + ) def test_not_is_distinct_from(self): - self.assert_compile(~self.table1.c.myid.is_distinct_from(1), - "mytable.myid IS NOT DISTINCT FROM :myid_1") + self.assert_compile( + ~self.table1.c.myid.is_distinct_from(1), + "mytable.myid IS NOT DISTINCT FROM :myid_1", + ) def test_not_is_distinct_from_postgresql(self): - self.assert_compile(~self.table1.c.myid.is_distinct_from(1), - "mytable.myid IS NOT DISTINCT FROM %(myid_1)s", - dialect=postgresql.dialect()) + self.assert_compile( + ~self.table1.c.myid.is_distinct_from(1), + "mytable.myid IS NOT DISTINCT FROM %(myid_1)s", + dialect=postgresql.dialect(), + ) def test_isnot_distinct_from(self): - self.assert_compile(self.table1.c.myid.isnot_distinct_from(1), - "mytable.myid IS NOT DISTINCT FROM :myid_1") + self.assert_compile( + 
self.table1.c.myid.isnot_distinct_from(1), + "mytable.myid IS NOT DISTINCT FROM :myid_1", + ) def test_isnot_distinct_from_sqlite(self): - self.assert_compile(self.table1.c.myid.isnot_distinct_from(1), - "mytable.myid IS ?", - dialect=sqlite.dialect()) + self.assert_compile( + self.table1.c.myid.isnot_distinct_from(1), + "mytable.myid IS ?", + dialect=sqlite.dialect(), + ) def test_isnot_distinct_from_postgresql(self): - self.assert_compile(self.table1.c.myid.isnot_distinct_from(1), - "mytable.myid IS NOT DISTINCT FROM %(myid_1)s", - dialect=postgresql.dialect()) + self.assert_compile( + self.table1.c.myid.isnot_distinct_from(1), + "mytable.myid IS NOT DISTINCT FROM %(myid_1)s", + dialect=postgresql.dialect(), + ) def test_not_isnot_distinct_from(self): - self.assert_compile(~self.table1.c.myid.isnot_distinct_from(1), - "mytable.myid IS DISTINCT FROM :myid_1") + self.assert_compile( + ~self.table1.c.myid.isnot_distinct_from(1), + "mytable.myid IS DISTINCT FROM :myid_1", + ) def test_not_isnot_distinct_from_postgresql(self): - self.assert_compile(~self.table1.c.myid.isnot_distinct_from(1), - "mytable.myid IS DISTINCT FROM %(myid_1)s", - dialect=postgresql.dialect()) + self.assert_compile( + ~self.table1.c.myid.isnot_distinct_from(1), + "mytable.myid IS DISTINCT FROM %(myid_1)s", + dialect=postgresql.dialect(), + ) class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - ) + table1 = table("mytable", column("myid", Integer)) table2 = table( - 'myothertable', - column('otherid', Integer), - column('othername', String) + "myothertable", column("otherid", Integer), column("othername", String) ) def _dialect(self, empty_in_strategy="static"): - return default.DefaultDialect( - empty_in_strategy=empty_in_strategy - ) + return default.DefaultDialect(empty_in_strategy=empty_in_strategy) def test_in_1(self): - self.assert_compile(self.table1.c.myid.in_(['a']), - "mytable.myid IN (:myid_1)") + self.assert_compile( + self.table1.c.myid.in_(["a"]), "mytable.myid IN (:myid_1)" + ) def test_in_2(self): - self.assert_compile(~self.table1.c.myid.in_(['a']), - "mytable.myid NOT IN (:myid_1)") + self.assert_compile( + ~self.table1.c.myid.in_(["a"]), "mytable.myid NOT IN (:myid_1)" + ) def test_in_3(self): - self.assert_compile(self.table1.c.myid.in_(['a', 'b']), - "mytable.myid IN (:myid_1, :myid_2)") + self.assert_compile( + self.table1.c.myid.in_(["a", "b"]), + "mytable.myid IN (:myid_1, :myid_2)", + ) def test_in_4(self): - self.assert_compile(self.table1.c.myid.in_(iter(['a', 'b'])), - "mytable.myid IN (:myid_1, :myid_2)") + self.assert_compile( + self.table1.c.myid.in_(iter(["a", "b"])), + "mytable.myid IN (:myid_1, :myid_2)", + ) def test_in_5(self): - self.assert_compile(self.table1.c.myid.in_([literal('a')]), - "mytable.myid IN (:param_1)") + self.assert_compile( + self.table1.c.myid.in_([literal("a")]), + "mytable.myid IN (:param_1)", + ) def test_in_6(self): - self.assert_compile(self.table1.c.myid.in_([literal('a'), 'b']), - "mytable.myid IN (:param_1, :myid_1)") + self.assert_compile( + self.table1.c.myid.in_([literal("a"), "b"]), + "mytable.myid IN (:param_1, :myid_1)", + ) def test_in_7(self): self.assert_compile( - self.table1.c.myid.in_([literal('a'), literal('b')]), - "mytable.myid IN (:param_1, :param_2)") + self.table1.c.myid.in_([literal("a"), literal("b")]), + "mytable.myid IN (:param_1, :param_2)", + ) def test_in_8(self): - self.assert_compile(self.table1.c.myid.in_(['a', 
literal('b')]), - "mytable.myid IN (:myid_1, :param_1)") + self.assert_compile( + self.table1.c.myid.in_(["a", literal("b")]), + "mytable.myid IN (:myid_1, :param_1)", + ) def test_in_9(self): - self.assert_compile(self.table1.c.myid.in_([literal(1) + 'a']), - "mytable.myid IN (:param_1 + :param_2)") + self.assert_compile( + self.table1.c.myid.in_([literal(1) + "a"]), + "mytable.myid IN (:param_1 + :param_2)", + ) def test_in_10(self): - self.assert_compile(self.table1.c.myid.in_([literal('a') + 'a', 'b']), - "mytable.myid IN (:param_1 || :param_2, :myid_1)") + self.assert_compile( + self.table1.c.myid.in_([literal("a") + "a", "b"]), + "mytable.myid IN (:param_1 || :param_2, :myid_1)", + ) def test_in_11(self): self.assert_compile( self.table1.c.myid.in_( - [ - literal('a') + - literal('a'), - literal('b')]), - "mytable.myid IN (:param_1 || :param_2, :param_3)") + [literal("a") + literal("a"), literal("b")] + ), + "mytable.myid IN (:param_1 || :param_2, :param_3)", + ) def test_in_12(self): - self.assert_compile(self.table1.c.myid.in_([1, literal(3) + 4]), - "mytable.myid IN (:myid_1, :param_1 + :param_2)") + self.assert_compile( + self.table1.c.myid.in_([1, literal(3) + 4]), + "mytable.myid IN (:myid_1, :param_1 + :param_2)", + ) def test_in_13(self): - self.assert_compile(self.table1.c.myid.in_([literal('a') < 'b']), - "mytable.myid IN (:param_1 < :param_2)") + self.assert_compile( + self.table1.c.myid.in_([literal("a") < "b"]), + "mytable.myid IN (:param_1 < :param_2)", + ) def test_in_14(self): - self.assert_compile(self.table1.c.myid.in_([self.table1.c.myid]), - "mytable.myid IN (mytable.myid)") + self.assert_compile( + self.table1.c.myid.in_([self.table1.c.myid]), + "mytable.myid IN (mytable.myid)", + ) def test_in_15(self): - self.assert_compile(self.table1.c.myid.in_(['a', self.table1.c.myid]), - "mytable.myid IN (:myid_1, mytable.myid)") + self.assert_compile( + self.table1.c.myid.in_(["a", self.table1.c.myid]), + "mytable.myid IN (:myid_1, mytable.myid)", + ) def test_in_16(self): - self.assert_compile(self.table1.c.myid.in_([literal('a'), - self.table1.c.myid]), - "mytable.myid IN (:param_1, mytable.myid)") + self.assert_compile( + self.table1.c.myid.in_([literal("a"), self.table1.c.myid]), + "mytable.myid IN (:param_1, mytable.myid)", + ) def test_in_17(self): self.assert_compile( - self.table1.c.myid.in_( - [ - literal('a'), - self.table1.c.myid + - 'a']), - "mytable.myid IN (:param_1, mytable.myid + :myid_1)") + self.table1.c.myid.in_([literal("a"), self.table1.c.myid + "a"]), + "mytable.myid IN (:param_1, mytable.myid + :myid_1)", + ) def test_in_18(self): self.assert_compile( - self.table1.c.myid.in_( - [ - literal(1), - 'a' + - self.table1.c.myid]), - "mytable.myid IN (:param_1, :myid_1 + mytable.myid)") + self.table1.c.myid.in_([literal(1), "a" + self.table1.c.myid]), + "mytable.myid IN (:param_1, :myid_1 + mytable.myid)", + ) def test_in_19(self): - self.assert_compile(self.table1.c.myid.in_([1, 2, 3]), - "mytable.myid IN (:myid_1, :myid_2, :myid_3)") + self.assert_compile( + self.table1.c.myid.in_([1, 2, 3]), + "mytable.myid IN (:myid_1, :myid_2, :myid_3)", + ) def test_in_20(self): - self.assert_compile(self.table1.c.myid.in_( - select([self.table2.c.otherid])), - "mytable.myid IN (SELECT myothertable.otherid FROM myothertable)") + self.assert_compile( + self.table1.c.myid.in_(select([self.table2.c.otherid])), + "mytable.myid IN (SELECT myothertable.otherid FROM myothertable)", + ) def test_in_21(self): - self.assert_compile(~self.table1.c.myid.in_( - 
select([self.table2.c.otherid])), + self.assert_compile( + ~self.table1.c.myid.in_(select([self.table2.c.otherid])), "mytable.myid NOT IN " - "(SELECT myothertable.otherid FROM myothertable)") + "(SELECT myothertable.otherid FROM myothertable)", + ) def test_in_22(self): self.assert_compile( @@ -1768,49 +1631,64 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): text("SELECT myothertable.otherid FROM myothertable") ), "mytable.myid IN (SELECT myothertable.otherid " - "FROM myothertable)" + "FROM myothertable)", ) def test_in_24(self): self.assert_compile( select([self.table1.c.myid.in_(select([self.table2.c.otherid]))]), "SELECT mytable.myid IN (SELECT myothertable.otherid " - "FROM myothertable) AS anon_1 FROM mytable" + "FROM myothertable) AS anon_1 FROM mytable", ) def test_in_25(self): self.assert_compile( - select([self.table1.c.myid.in_( - select([self.table2.c.otherid]).as_scalar())]), + select( + [ + self.table1.c.myid.in_( + select([self.table2.c.otherid]).as_scalar() + ) + ] + ), "SELECT mytable.myid IN (SELECT myothertable.otherid " - "FROM myothertable) AS anon_1 FROM mytable" + "FROM myothertable) AS anon_1 FROM mytable", ) def test_in_26(self): - self.assert_compile(self.table1.c.myid.in_( - union( - select([self.table1.c.myid], self.table1.c.myid == 5), - select([self.table1.c.myid], self.table1.c.myid == 12), - ) - ), "mytable.myid IN (" + self.assert_compile( + self.table1.c.myid.in_( + union( + select([self.table1.c.myid], self.table1.c.myid == 5), + select([self.table1.c.myid], self.table1.c.myid == 12), + ) + ), + "mytable.myid IN (" "SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1 " "UNION SELECT mytable.myid FROM mytable " - "WHERE mytable.myid = :myid_2)") + "WHERE mytable.myid = :myid_2)", + ) def test_in_27(self): # test that putting a select in an IN clause does not # blow away its ORDER BY clause self.assert_compile( - select([self.table1, self.table2], - self.table2.c.otherid.in_( - select([self.table2.c.otherid], - order_by=[self.table2.c.othername], - limit=10, correlate=False) - ), - from_obj=[self.table1.join( - self.table2, - self.table1.c.myid == self.table2.c.otherid)], - order_by=[self.table1.c.myid] + select( + [self.table1, self.table2], + self.table2.c.otherid.in_( + select( + [self.table2.c.otherid], + order_by=[self.table2.c.othername], + limit=10, + correlate=False, + ) + ), + from_obj=[ + self.table1.join( + self.table2, + self.table1.c.myid == self.table2.c.otherid, + ) + ], + order_by=[self.table1.c.myid], ), "SELECT mytable.myid, " "myothertable.otherid, myothertable.othername FROM mytable " @@ -1818,115 +1696,127 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): "WHERE myothertable.otherid IN (SELECT myothertable.otherid " "FROM myothertable ORDER BY myothertable.othername " "LIMIT :param_1) ORDER BY mytable.myid", - {'param_1': 10} + {"param_1": 10}, ) def test_in_28(self): self.assert_compile( - self.table1.c.myid.in_([None]), - "mytable.myid IN (NULL)" + self.table1.c.myid.in_([None]), "mytable.myid IN (NULL)" ) def test_empty_in_dynamic_1(self): - self.assert_compile(self.table1.c.myid.in_([]), - "mytable.myid != mytable.myid", - dialect=self._dialect("dynamic")) + self.assert_compile( + self.table1.c.myid.in_([]), + "mytable.myid != mytable.myid", + dialect=self._dialect("dynamic"), + ) def test_empty_in_dynamic_2(self): - self.assert_compile(self.table1.c.myid.notin_([]), - "mytable.myid = mytable.myid", - dialect=self._dialect("dynamic")) + self.assert_compile( + self.table1.c.myid.notin_([]), + 
"mytable.myid = mytable.myid", + dialect=self._dialect("dynamic"), + ) def test_empty_in_dynamic_3(self): - self.assert_compile(~self.table1.c.myid.in_([]), - "mytable.myid = mytable.myid", - dialect=self._dialect("dynamic")) + self.assert_compile( + ~self.table1.c.myid.in_([]), + "mytable.myid = mytable.myid", + dialect=self._dialect("dynamic"), + ) def test_empty_in_dynamic_warn_1(self): with testing.expect_warnings( - "The IN-predicate was invoked with an empty sequence."): - self.assert_compile(self.table1.c.myid.in_([]), - "mytable.myid != mytable.myid", - dialect=self._dialect("dynamic_warn")) + "The IN-predicate was invoked with an empty sequence." + ): + self.assert_compile( + self.table1.c.myid.in_([]), + "mytable.myid != mytable.myid", + dialect=self._dialect("dynamic_warn"), + ) def test_empty_in_dynamic_warn_2(self): with testing.expect_warnings( - "The IN-predicate was invoked with an empty sequence."): - self.assert_compile(self.table1.c.myid.notin_([]), - "mytable.myid = mytable.myid", - dialect=self._dialect("dynamic_warn")) + "The IN-predicate was invoked with an empty sequence." + ): + self.assert_compile( + self.table1.c.myid.notin_([]), + "mytable.myid = mytable.myid", + dialect=self._dialect("dynamic_warn"), + ) def test_empty_in_dynamic_warn_3(self): with testing.expect_warnings( - "The IN-predicate was invoked with an empty sequence."): - self.assert_compile(~self.table1.c.myid.in_([]), - "mytable.myid = mytable.myid", - dialect=self._dialect("dynamic_warn")) + "The IN-predicate was invoked with an empty sequence." + ): + self.assert_compile( + ~self.table1.c.myid.in_([]), + "mytable.myid = mytable.myid", + dialect=self._dialect("dynamic_warn"), + ) def test_empty_in_static_1(self): - self.assert_compile(self.table1.c.myid.in_([]), - "1 != 1") + self.assert_compile(self.table1.c.myid.in_([]), "1 != 1") def test_empty_in_static_2(self): - self.assert_compile(self.table1.c.myid.notin_([]), - "1 = 1") + self.assert_compile(self.table1.c.myid.notin_([]), "1 = 1") def test_empty_in_static_3(self): - self.assert_compile(~self.table1.c.myid.in_([]), - "1 = 1") + self.assert_compile(~self.table1.c.myid.in_([]), "1 = 1") class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - ) + table1 = table("mytable", column("myid", Integer)) def _test_math_op(self, py_op, sql_op): for (lhs, rhs, res) in ( - (5, self.table1.c.myid, ':myid_1 %s mytable.myid'), - (5, literal(5), ':param_1 %s :param_2'), - (self.table1.c.myid, 'b', 'mytable.myid %s :myid_1'), - (self.table1.c.myid, literal(2.7), 'mytable.myid %s :param_1'), - (self.table1.c.myid, self.table1.c.myid, - 'mytable.myid %s mytable.myid'), - (literal(5), 8, ':param_1 %s :param_2'), - (literal(6), self.table1.c.myid, ':param_1 %s mytable.myid'), - (literal(7), literal(5.5), ':param_1 %s :param_2'), + (5, self.table1.c.myid, ":myid_1 %s mytable.myid"), + (5, literal(5), ":param_1 %s :param_2"), + (self.table1.c.myid, "b", "mytable.myid %s :myid_1"), + (self.table1.c.myid, literal(2.7), "mytable.myid %s :param_1"), + ( + self.table1.c.myid, + self.table1.c.myid, + "mytable.myid %s mytable.myid", + ), + (literal(5), 8, ":param_1 %s :param_2"), + (literal(6), self.table1.c.myid, ":param_1 %s mytable.myid"), + (literal(7), literal(5.5), ":param_1 %s :param_2"), ): self.assert_compile(py_op(lhs, rhs), res % sql_op) def test_math_op_add(self): - self._test_math_op(operator.add, '+') + self._test_math_op(operator.add, 
"+") def test_math_op_mul(self): - self._test_math_op(operator.mul, '*') + self._test_math_op(operator.mul, "*") def test_math_op_sub(self): - self._test_math_op(operator.sub, '-') + self._test_math_op(operator.sub, "-") def test_math_op_div(self): if util.py3k: - self._test_math_op(operator.truediv, '/') + self._test_math_op(operator.truediv, "/") else: - self._test_math_op(operator.div, '/') + self._test_math_op(operator.div, "/") def test_math_op_mod(self): - self._test_math_op(operator.mod, '%') + self._test_math_op(operator.mod, "%") class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - ) + table1 = table("mytable", column("myid", Integer)) def test_pickle_operators_one(self): - clause = (self.table1.c.myid == 12) & \ - self.table1.c.myid.between(15, 20) & \ - self.table1.c.myid.like('hoho') + clause = ( + (self.table1.c.myid == 12) + & self.table1.c.myid.between(15, 20) + & self.table1.c.myid.like("hoho") + ) eq_(str(clause), str(util.pickle.loads(util.pickle.dumps(clause)))) def test_pickle_operators_two(self): @@ -1936,17 +1826,21 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): def _test_comparison_op(self, py_op, fwd_op, rev_op): dt = datetime.datetime(2012, 5, 10, 15, 27, 18) for (lhs, rhs, l_sql, r_sql) in ( - ('a', self.table1.c.myid, ':myid_1', 'mytable.myid'), - ('a', literal('b'), ':param_2', ':param_1'), # note swap! - (self.table1.c.myid, 'b', 'mytable.myid', ':myid_1'), - (self.table1.c.myid, literal('b'), 'mytable.myid', ':param_1'), - (self.table1.c.myid, self.table1.c.myid, - 'mytable.myid', 'mytable.myid'), - (literal('a'), 'b', ':param_1', ':param_2'), - (literal('a'), self.table1.c.myid, ':param_1', 'mytable.myid'), - (literal('a'), literal('b'), ':param_1', ':param_2'), - (dt, literal('b'), ':param_2', ':param_1'), - (literal('b'), dt, ':param_1', ':param_2'), + ("a", self.table1.c.myid, ":myid_1", "mytable.myid"), + ("a", literal("b"), ":param_2", ":param_1"), # note swap! 
+ (self.table1.c.myid, "b", "mytable.myid", ":myid_1"), + (self.table1.c.myid, literal("b"), "mytable.myid", ":param_1"), + ( + self.table1.c.myid, + self.table1.c.myid, + "mytable.myid", + "mytable.myid", + ), + (literal("a"), "b", ":param_1", ":param_2"), + (literal("a"), self.table1.c.myid, ":param_1", "mytable.myid"), + (literal("a"), literal("b"), ":param_1", ":param_2"), + (dt, literal("b"), ":param_2", ":param_1"), + (literal("b"), dt, ":param_1", ":param_2"), ): # the compiled clause should match either (e.g.): @@ -1955,36 +1849,43 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) - self.assert_(compiled == fwd_sql or compiled == rev_sql, - "\n'" + compiled + "'\n does not match\n'" + - fwd_sql + "'\n or\n'" + rev_sql + "'") + self.assert_( + compiled == fwd_sql or compiled == rev_sql, + "\n'" + + compiled + + "'\n does not match\n'" + + fwd_sql + + "'\n or\n'" + + rev_sql + + "'", + ) def test_comparison_operators_lt(self): - self._test_comparison_op(operator.lt, '<', '>'), + self._test_comparison_op(operator.lt, "<", ">"), def test_comparison_operators_gt(self): - self._test_comparison_op(operator.gt, '>', '<') + self._test_comparison_op(operator.gt, ">", "<") def test_comparison_operators_eq(self): - self._test_comparison_op(operator.eq, '=', '=') + self._test_comparison_op(operator.eq, "=", "=") def test_comparison_operators_ne(self): - self._test_comparison_op(operator.ne, '!=', '!=') + self._test_comparison_op(operator.ne, "!=", "!=") def test_comparison_operators_le(self): - self._test_comparison_op(operator.le, '<=', '>=') + self._test_comparison_op(operator.le, "<=", ">=") def test_comparison_operators_ge(self): - self._test_comparison_op(operator.ge, '>=', '<=') + self._test_comparison_op(operator.ge, ">=", "<=") class NonZeroTest(fixtures.TestBase): - def _raises(self, expr): assert_raises_message( TypeError, "Boolean value of this clause is not defined", - bool, expr + bool, + expr, ) def _assert_true(self, expr): @@ -1994,57 +1895,51 @@ class NonZeroTest(fixtures.TestBase): is_(bool(expr), False) def test_column_identity_eq(self): - c1 = column('c1') + c1 = column("c1") self._assert_true(c1 == c1) def test_column_identity_gt(self): - c1 = column('c1') + c1 = column("c1") self._raises(c1 > c1) def test_column_compare_eq(self): - c1, c2 = column('c1'), column('c2') + c1, c2 = column("c1"), column("c2") self._assert_false(c1 == c2) def test_column_compare_gt(self): - c1, c2 = column('c1'), column('c2') + c1, c2 = column("c1"), column("c2") self._raises(c1 > c2) def test_binary_identity_eq(self): - c1 = column('c1') + c1 = column("c1") expr = c1 > 5 self._assert_true(expr == expr) def test_labeled_binary_identity_eq(self): - c1 = column('c1') + c1 = column("c1") expr = (c1 > 5).label(None) self._assert_true(expr == expr) def test_annotated_binary_identity_eq(self): - c1 = column('c1') - expr1 = (c1 > 5) + c1 = column("c1") + expr1 = c1 > 5 expr2 = expr1._annotate({"foo": "bar"}) self._assert_true(expr1 == expr2) def test_labeled_binary_compare_gt(self): - c1 = column('c1') + c1 = column("c1") expr1 = (c1 > 5).label(None) expr2 = (c1 > 5).label(None) self._assert_false(expr1 == expr2) class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - column('name', String), - ) + table1 = table("mytable", column("myid", Integer), column("name", 
String)) def test_negate_operators_1(self): - for (py_op, op) in ( - (operator.neg, '-'), - (operator.inv, 'NOT '), - ): + for (py_op, op) in ((operator.neg, "-"), (operator.inv, "NOT ")): for expr, expected in ( (self.table1.c.myid, "mytable.myid"), (literal("foo"), ":param_1"), @@ -2053,72 +1948,80 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_negate_operators_2(self): self.assert_compile( - self.table1.select((self.table1.c.myid != 12) & - ~(self.table1.c.name == 'john')), + self.table1.select( + (self.table1.c.myid != 12) & ~(self.table1.c.name == "john") + ), "SELECT mytable.myid, mytable.name FROM " "mytable WHERE mytable.myid != :myid_1 " - "AND mytable.name != :name_1" + "AND mytable.name != :name_1", ) def test_negate_operators_3(self): self.assert_compile( - self.table1.select((self.table1.c.myid != 12) & - ~(self.table1.c.name.between('jack', 'john'))), + self.table1.select( + (self.table1.c.myid != 12) + & ~(self.table1.c.name.between("jack", "john")) + ), "SELECT mytable.myid, mytable.name FROM " "mytable WHERE mytable.myid != :myid_1 AND " - "mytable.name NOT BETWEEN :name_1 AND :name_2" + "mytable.name NOT BETWEEN :name_1 AND :name_2", ) def test_negate_operators_4(self): self.assert_compile( - self.table1.select((self.table1.c.myid != 12) & - ~and_(self.table1.c.name == 'john', - self.table1.c.name == 'ed', - self.table1.c.name == 'fred')), + self.table1.select( + (self.table1.c.myid != 12) + & ~and_( + self.table1.c.name == "john", + self.table1.c.name == "ed", + self.table1.c.name == "fred", + ) + ), "SELECT mytable.myid, mytable.name FROM " "mytable WHERE mytable.myid != :myid_1 AND " "NOT (mytable.name = :name_1 AND mytable.name = :name_2 " - "AND mytable.name = :name_3)" + "AND mytable.name = :name_3)", ) def test_negate_operators_5(self): self.assert_compile( self.table1.select( - (self.table1.c.myid != 12) & ~self.table1.c.name), + (self.table1.c.myid != 12) & ~self.table1.c.name + ), "SELECT mytable.myid, mytable.name FROM " - "mytable WHERE mytable.myid != :myid_1 AND NOT mytable.name") + "mytable WHERE mytable.myid != :myid_1 AND NOT mytable.name", + ) def test_negate_operator_type(self): - is_( - (-self.table1.c.myid).type, - self.table1.c.myid.type, - ) + is_((-self.table1.c.myid).type, self.table1.c.myid.type) def test_negate_operator_label(self): orig_expr = or_( - self.table1.c.myid == 1, self.table1.c.myid == 2).label('foo') + self.table1.c.myid == 1, self.table1.c.myid == 2 + ).label("foo") expr = not_(orig_expr) isinstance(expr, Label) - eq_(expr.name, 'foo') + eq_(expr.name, "foo") is_not_(expr, orig_expr) is_(expr._element.operator, operator.inv) # e.g. 
and not false_ self.assert_compile( expr, "NOT (mytable.myid = :myid_1 OR mytable.myid = :myid_2)", - dialect=default.DefaultDialect(supports_native_boolean=False) + dialect=default.DefaultDialect(supports_native_boolean=False), ) def test_negate_operator_self_group(self): orig_expr = or_( - self.table1.c.myid == 1, self.table1.c.myid == 2).self_group() + self.table1.c.myid == 1, self.table1.c.myid == 2 + ).self_group() expr = not_(orig_expr) is_not_(expr, orig_expr) self.assert_compile( expr, "NOT (mytable.myid = :myid_1 OR mytable.myid = :myid_2)", - dialect=default.DefaultDialect(supports_native_boolean=False) + dialect=default.DefaultDialect(supports_native_boolean=False), ) def test_implicitly_boolean(self): @@ -2127,569 +2030,585 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert not self.table1.c.myid._is_implicitly_boolean assert (self.table1.c.myid == 5)._is_implicitly_boolean assert (self.table1.c.myid == 5).self_group()._is_implicitly_boolean - assert (self.table1.c.myid == 5).label('x')._is_implicitly_boolean + assert (self.table1.c.myid == 5).label("x")._is_implicitly_boolean assert not_(self.table1.c.myid == 5)._is_implicitly_boolean assert or_( self.table1.c.myid == 5, self.table1.c.myid == 7 )._is_implicitly_boolean - assert not column('x', Boolean)._is_implicitly_boolean + assert not column("x", Boolean)._is_implicitly_boolean assert not (self.table1.c.myid + 5)._is_implicitly_boolean - assert not not_(column('x', Boolean))._is_implicitly_boolean - assert not select([self.table1.c.myid]).\ - as_scalar()._is_implicitly_boolean + assert not not_(column("x", Boolean))._is_implicitly_boolean + assert ( + not select([self.table1.c.myid]).as_scalar()._is_implicitly_boolean + ) assert not text("x = y")._is_implicitly_boolean assert not literal_column("x = y")._is_implicitly_boolean class LikeTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - column('name', String), - ) + table1 = table("mytable", column("myid", Integer), column("name", String)) def test_like_1(self): self.assert_compile( - self.table1.c.myid.like('somstr'), - "mytable.myid LIKE :myid_1") + self.table1.c.myid.like("somstr"), "mytable.myid LIKE :myid_1" + ) def test_like_2(self): self.assert_compile( - ~self.table1.c.myid.like('somstr'), - "mytable.myid NOT LIKE :myid_1") + ~self.table1.c.myid.like("somstr"), "mytable.myid NOT LIKE :myid_1" + ) def test_like_3(self): self.assert_compile( - self.table1.c.myid.like('somstr', escape='\\'), - "mytable.myid LIKE :myid_1 ESCAPE '\\'") + self.table1.c.myid.like("somstr", escape="\\"), + "mytable.myid LIKE :myid_1 ESCAPE '\\'", + ) def test_like_4(self): self.assert_compile( - ~self.table1.c.myid.like('somstr', escape='\\'), - "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'") + ~self.table1.c.myid.like("somstr", escape="\\"), + "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'", + ) def test_like_5(self): self.assert_compile( - self.table1.c.myid.ilike('somstr', escape='\\'), - "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'") + self.table1.c.myid.ilike("somstr", escape="\\"), + "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'", + ) def test_like_6(self): self.assert_compile( - ~self.table1.c.myid.ilike('somstr', escape='\\'), - "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'") + ~self.table1.c.myid.ilike("somstr", escape="\\"), + "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'", + ) def test_like_7(self): 
self.assert_compile( - self.table1.c.myid.ilike('somstr', escape='\\'), + self.table1.c.myid.ilike("somstr", escape="\\"), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\\\'", - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_like_8(self): self.assert_compile( - ~self.table1.c.myid.ilike('somstr', escape='\\'), + ~self.table1.c.myid.ilike("somstr", escape="\\"), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\\\'", - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_like_9(self): self.assert_compile( - self.table1.c.name.ilike('%something%'), - "lower(mytable.name) LIKE lower(:name_1)") + self.table1.c.name.ilike("%something%"), + "lower(mytable.name) LIKE lower(:name_1)", + ) def test_like_10(self): self.assert_compile( - self.table1.c.name.ilike('%something%'), + self.table1.c.name.ilike("%something%"), "mytable.name ILIKE %(name_1)s", - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_like_11(self): self.assert_compile( - ~self.table1.c.name.ilike('%something%'), - "lower(mytable.name) NOT LIKE lower(:name_1)") + ~self.table1.c.name.ilike("%something%"), + "lower(mytable.name) NOT LIKE lower(:name_1)", + ) def test_like_12(self): self.assert_compile( - ~self.table1.c.name.ilike('%something%'), + ~self.table1.c.name.ilike("%something%"), "mytable.name NOT ILIKE %(name_1)s", - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) class BetweenTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - column('name', String), - ) + table1 = table("mytable", column("myid", Integer), column("name", String)) def test_between_1(self): self.assert_compile( self.table1.c.myid.between(1, 2), - "mytable.myid BETWEEN :myid_1 AND :myid_2") + "mytable.myid BETWEEN :myid_1 AND :myid_2", + ) def test_between_2(self): self.assert_compile( ~self.table1.c.myid.between(1, 2), - "mytable.myid NOT BETWEEN :myid_1 AND :myid_2") + "mytable.myid NOT BETWEEN :myid_1 AND :myid_2", + ) def test_between_3(self): self.assert_compile( self.table1.c.myid.between(1, 2, symmetric=True), - "mytable.myid BETWEEN SYMMETRIC :myid_1 AND :myid_2") + "mytable.myid BETWEEN SYMMETRIC :myid_1 AND :myid_2", + ) def test_between_4(self): self.assert_compile( ~self.table1.c.myid.between(1, 2, symmetric=True), - "mytable.myid NOT BETWEEN SYMMETRIC :myid_1 AND :myid_2") + "mytable.myid NOT BETWEEN SYMMETRIC :myid_1 AND :myid_2", + ) def test_between_5(self): self.assert_compile( between(self.table1.c.myid, 1, 2, symmetric=True), - "mytable.myid BETWEEN SYMMETRIC :myid_1 AND :myid_2") + "mytable.myid BETWEEN SYMMETRIC :myid_1 AND :myid_2", + ) def test_between_6(self): self.assert_compile( ~between(self.table1.c.myid, 1, 2, symmetric=True), - "mytable.myid NOT BETWEEN SYMMETRIC :myid_1 AND :myid_2") + "mytable.myid NOT BETWEEN SYMMETRIC :myid_1 AND :myid_2", + ) class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" - table1 = table('mytable', - column('myid', Integer), - column('name', String), - ) + table1 = table("mytable", column("myid", Integer), column("name", String)) def test_match_1(self): - self.assert_compile(self.table1.c.myid.match('somstr'), - "mytable.myid MATCH ?", - dialect=sqlite.dialect()) + self.assert_compile( + self.table1.c.myid.match("somstr"), + "mytable.myid MATCH ?", + dialect=sqlite.dialect(), + ) def test_match_2(self): self.assert_compile( - 
self.table1.c.myid.match('somstr'), + self.table1.c.myid.match("somstr"), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", - dialect=mysql.dialect()) + dialect=mysql.dialect(), + ) def test_match_3(self): - self.assert_compile(self.table1.c.myid.match('somstr'), - "CONTAINS (mytable.myid, :myid_1)", - dialect=mssql.dialect()) + self.assert_compile( + self.table1.c.myid.match("somstr"), + "CONTAINS (mytable.myid, :myid_1)", + dialect=mssql.dialect(), + ) def test_match_4(self): - self.assert_compile(self.table1.c.myid.match('somstr'), - "mytable.myid @@ to_tsquery(%(myid_1)s)", - dialect=postgresql.dialect()) + self.assert_compile( + self.table1.c.myid.match("somstr"), + "mytable.myid @@ to_tsquery(%(myid_1)s)", + dialect=postgresql.dialect(), + ) def test_match_5(self): - self.assert_compile(self.table1.c.myid.match('somstr'), - "CONTAINS (mytable.myid, :myid_1)", - dialect=oracle.dialect()) + self.assert_compile( + self.table1.c.myid.match("somstr"), + "CONTAINS (mytable.myid, :myid_1)", + dialect=oracle.dialect(), + ) def test_match_is_now_matchtype(self): - expr = self.table1.c.myid.match('somstr') + expr = self.table1.c.myid.match("somstr") assert expr.type._type_affinity is MatchType()._type_affinity assert isinstance(expr.type, MatchType) def test_boolean_inversion_postgresql(self): self.assert_compile( - ~self.table1.c.myid.match('somstr'), + ~self.table1.c.myid.match("somstr"), "NOT mytable.myid @@ to_tsquery(%(myid_1)s)", - dialect=postgresql.dialect()) + dialect=postgresql.dialect(), + ) def test_boolean_inversion_mysql(self): # because mysql doesn't have native boolean self.assert_compile( - ~self.table1.c.myid.match('somstr'), + ~self.table1.c.myid.match("somstr"), "NOT MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", - dialect=mysql.dialect()) + dialect=mysql.dialect(), + ) def test_boolean_inversion_mssql(self): # because mssql doesn't have native boolean self.assert_compile( - ~self.table1.c.myid.match('somstr'), + ~self.table1.c.myid.match("somstr"), "NOT CONTAINS (mytable.myid, :myid_1)", - dialect=mssql.dialect()) + dialect=mssql.dialect(), + ) class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_contains(self): self.assert_compile( - column('x').contains('y'), + column("x").contains("y"), "x LIKE '%' || :x_1 || '%'", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_contains_escape(self): self.assert_compile( - column('x').contains('a%b_c', escape='\\'), + column("x").contains("a%b_c", escape="\\"), "x LIKE '%' || :x_1 || '%' ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_contains_autoescape(self): self.assert_compile( - column('x').contains('a%b_c/d', autoescape=True), + column("x").contains("a%b_c/d", autoescape=True), "x LIKE '%' || :x_1 || '%' ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_contains_literal(self): self.assert_compile( - column('x').contains(literal_column('y')), + column("x").contains(literal_column("y")), "x LIKE '%' || y || '%'", - checkparams={} + checkparams={}, ) def test_contains_text(self): self.assert_compile( - column('x').contains(text('y')), + column("x").contains(text("y")), "x LIKE '%' || y || '%'", - checkparams={} + checkparams={}, ) def test_not_contains(self): self.assert_compile( - ~column('x').contains('y'), + ~column("x").contains("y"), "x NOT LIKE '%' || :x_1 || '%'", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def
test_not_contains_escape(self): self.assert_compile( - ~column('x').contains('a%b_c', escape='\\'), + ~column("x").contains("a%b_c", escape="\\"), "x NOT LIKE '%' || :x_1 || '%' ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_not_contains_autoescape(self): self.assert_compile( - ~column('x').contains('a%b_c/d', autoescape=True), + ~column("x").contains("a%b_c/d", autoescape=True), "x NOT LIKE '%' || :x_1 || '%' ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_contains_concat(self): self.assert_compile( - column('x').contains('y'), + column("x").contains("y"), "x LIKE concat(concat('%%', %s), '%%')", - checkparams={'x_1': 'y'}, - dialect=mysql.dialect() + checkparams={"x_1": "y"}, + dialect=mysql.dialect(), ) def test_not_contains_concat(self): self.assert_compile( - ~column('x').contains('y'), + ~column("x").contains("y"), "x NOT LIKE concat(concat('%%', %s), '%%')", - checkparams={'x_1': 'y'}, - dialect=mysql.dialect() + checkparams={"x_1": "y"}, + dialect=mysql.dialect(), ) def test_contains_literal_concat(self): self.assert_compile( - column('x').contains(literal_column('y')), + column("x").contains(literal_column("y")), "x LIKE concat(concat('%%', y), '%%')", checkparams={}, - dialect=mysql.dialect() + dialect=mysql.dialect(), ) def test_contains_text_concat(self): self.assert_compile( - column('x').contains(text('y')), + column("x").contains(text("y")), "x LIKE concat(concat('%%', y), '%%')", checkparams={}, - dialect=mysql.dialect() + dialect=mysql.dialect(), ) def test_like(self): self.assert_compile( - column('x').like('y'), - "x LIKE :x_1", - checkparams={'x_1': 'y'} + column("x").like("y"), "x LIKE :x_1", checkparams={"x_1": "y"} ) def test_like_escape(self): self.assert_compile( - column('x').like('a%b_c', escape='\\'), + column("x").like("a%b_c", escape="\\"), "x LIKE :x_1 ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_ilike(self): self.assert_compile( - column('x').ilike('y'), + column("x").ilike("y"), "lower(x) LIKE lower(:x_1)", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_ilike_escape(self): self.assert_compile( - column('x').ilike('a%b_c', escape='\\'), + column("x").ilike("a%b_c", escape="\\"), "lower(x) LIKE lower(:x_1) ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_notlike(self): self.assert_compile( - column('x').notlike('y'), + column("x").notlike("y"), "x NOT LIKE :x_1", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_notlike_escape(self): self.assert_compile( - column('x').notlike('a%b_c', escape='\\'), + column("x").notlike("a%b_c", escape="\\"), "x NOT LIKE :x_1 ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_notilike(self): self.assert_compile( - column('x').notilike('y'), + column("x").notilike("y"), "lower(x) NOT LIKE lower(:x_1)", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_notilike_escape(self): self.assert_compile( - column('x').notilike('a%b_c', escape='\\'), + column("x").notilike("a%b_c", escape="\\"), "lower(x) NOT LIKE lower(:x_1) ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_startswith(self): self.assert_compile( - column('x').startswith('y'), + column("x").startswith("y"), "x LIKE :x_1 || '%'", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_startswith_escape(self): self.assert_compile( - column('x').startswith('a%b_c', 
escape='\\'), + column("x").startswith("a%b_c", escape="\\"), "x LIKE :x_1 || '%' ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_startswith_autoescape(self): self.assert_compile( - column('x').startswith('a%b_c/d', autoescape=True), + column("x").startswith("a%b_c/d", autoescape=True), "x LIKE :x_1 || '%' ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_startswith_autoescape_custom_escape(self): self.assert_compile( - column('x').startswith('a%b_c/d^e', autoescape=True, escape='^'), + column("x").startswith("a%b_c/d^e", autoescape=True, escape="^"), "x LIKE :x_1 || '%' ESCAPE '^'", - checkparams={'x_1': 'a^%b^_c/d^^e'} + checkparams={"x_1": "a^%b^_c/d^^e"}, ) def test_not_startswith(self): self.assert_compile( - ~column('x').startswith('y'), + ~column("x").startswith("y"), "x NOT LIKE :x_1 || '%'", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_not_startswith_escape(self): self.assert_compile( - ~column('x').startswith('a%b_c', escape='\\'), + ~column("x").startswith("a%b_c", escape="\\"), "x NOT LIKE :x_1 || '%' ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_not_startswith_autoescape(self): self.assert_compile( - ~column('x').startswith('a%b_c/d', autoescape=True), + ~column("x").startswith("a%b_c/d", autoescape=True), "x NOT LIKE :x_1 || '%' ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_startswith_literal(self): self.assert_compile( - column('x').startswith(literal_column('y')), + column("x").startswith(literal_column("y")), "x LIKE y || '%'", - checkparams={} + checkparams={}, ) def test_startswith_text(self): self.assert_compile( - column('x').startswith(text('y')), + column("x").startswith(text("y")), "x LIKE y || '%'", - checkparams={} + checkparams={}, ) def test_startswith_concat(self): self.assert_compile( - column('x').startswith('y'), + column("x").startswith("y"), "x LIKE concat(%s, '%%')", - checkparams={'x_1': 'y'}, - dialect=mysql.dialect() + checkparams={"x_1": "y"}, + dialect=mysql.dialect(), ) def test_not_startswith_concat(self): self.assert_compile( - ~column('x').startswith('y'), + ~column("x").startswith("y"), "x NOT LIKE concat(%s, '%%')", - checkparams={'x_1': 'y'}, - dialect=mysql.dialect() + checkparams={"x_1": "y"}, + dialect=mysql.dialect(), ) def test_startswith_firebird(self): self.assert_compile( - column('x').startswith('y'), + column("x").startswith("y"), "x STARTING WITH :x_1", - checkparams={'x_1': 'y'}, - dialect=firebird.dialect() + checkparams={"x_1": "y"}, + dialect=firebird.dialect(), ) def test_not_startswith_firebird(self): self.assert_compile( - ~column('x').startswith('y'), + ~column("x").startswith("y"), "x NOT STARTING WITH :x_1", - checkparams={'x_1': 'y'}, - dialect=firebird.dialect() + checkparams={"x_1": "y"}, + dialect=firebird.dialect(), ) def test_startswith_literal_mysql(self): self.assert_compile( - column('x').startswith(literal_column('y')), + column("x").startswith(literal_column("y")), "x LIKE concat(y, '%%')", checkparams={}, - dialect=mysql.dialect() + dialect=mysql.dialect(), ) def test_startswith_text_mysql(self): self.assert_compile( - column('x').startswith(text('y')), + column("x").startswith(text("y")), "x LIKE concat(y, '%%')", checkparams={}, - dialect=mysql.dialect() + dialect=mysql.dialect(), ) def test_endswith(self): self.assert_compile( - column('x').endswith('y'), + column("x").endswith("y"), "x LIKE '%' || :x_1", - 
checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_endswith_escape(self): self.assert_compile( - column('x').endswith('a%b_c', escape='\\'), + column("x").endswith("a%b_c", escape="\\"), "x LIKE '%' || :x_1 ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_endswith_autoescape(self): self.assert_compile( - column('x').endswith('a%b_c/d', autoescape=True), + column("x").endswith("a%b_c/d", autoescape=True), "x LIKE '%' || :x_1 ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_endswith_autoescape_custom_escape(self): self.assert_compile( - column('x').endswith('a%b_c/d^e', autoescape=True, escape="^"), + column("x").endswith("a%b_c/d^e", autoescape=True, escape="^"), "x LIKE '%' || :x_1 ESCAPE '^'", - checkparams={'x_1': 'a^%b^_c/d^^e'} + checkparams={"x_1": "a^%b^_c/d^^e"}, ) def test_endswith_autoescape_warning(self): with expect_warnings("The autoescape parameter is now a simple"): self.assert_compile( - column('x').endswith('a%b_c/d', autoescape='P'), + column("x").endswith("a%b_c/d", autoescape="P"), "x LIKE '%' || :x_1 ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_endswith_autoescape_nosqlexpr(self): assert_raises_message( TypeError, "String value expected when autoescape=True", - column('x').endswith, - literal_column("'a%b_c/d'"), autoescape=True + column("x").endswith, + literal_column("'a%b_c/d'"), + autoescape=True, ) def test_not_endswith(self): self.assert_compile( - ~column('x').endswith('y'), + ~column("x").endswith("y"), "x NOT LIKE '%' || :x_1", - checkparams={'x_1': 'y'} + checkparams={"x_1": "y"}, ) def test_not_endswith_escape(self): self.assert_compile( - ~column('x').endswith('a%b_c', escape='\\'), + ~column("x").endswith("a%b_c", escape="\\"), "x NOT LIKE '%' || :x_1 ESCAPE '\\'", - checkparams={'x_1': 'a%b_c'} + checkparams={"x_1": "a%b_c"}, ) def test_not_endswith_autoescape(self): self.assert_compile( - ~column('x').endswith('a%b_c/d', autoescape=True), + ~column("x").endswith("a%b_c/d", autoescape=True), "x NOT LIKE '%' || :x_1 ESCAPE '/'", - checkparams={'x_1': 'a/%b/_c//d'} + checkparams={"x_1": "a/%b/_c//d"}, ) def test_endswith_literal(self): self.assert_compile( - column('x').endswith(literal_column('y')), + column("x").endswith(literal_column("y")), "x LIKE '%' || y", - checkparams={} + checkparams={}, ) def test_endswith_text(self): self.assert_compile( - column('x').endswith(text('y')), - "x LIKE '%' || y", - checkparams={} + column("x").endswith(text("y")), "x LIKE '%' || y", checkparams={} ) def test_endswith_mysql(self): self.assert_compile( - column('x').endswith('y'), + column("x").endswith("y"), "x LIKE concat('%%', %s)", - checkparams={'x_1': 'y'}, - dialect=mysql.dialect() + checkparams={"x_1": "y"}, + dialect=mysql.dialect(), ) def test_not_endswith_mysql(self): self.assert_compile( - ~column('x').endswith('y'), + ~column("x").endswith("y"), "x NOT LIKE concat('%%', %s)", - checkparams={'x_1': 'y'}, - dialect=mysql.dialect() + checkparams={"x_1": "y"}, + dialect=mysql.dialect(), ) def test_endswith_literal_mysql(self): self.assert_compile( - column('x').endswith(literal_column('y')), + column("x").endswith(literal_column("y")), "x LIKE concat('%%', y)", checkparams={}, - dialect=mysql.dialect() + dialect=mysql.dialect(), ) def test_endswith_text_mysql(self): self.assert_compile( - column('x').endswith(text('y')), + column("x").endswith(text("y")), "x LIKE concat('%%', y)", checkparams={}, - 
dialect=mysql.dialect() + dialect=mysql.dialect(), ) class CustomOpTest(fixtures.TestBase): - def test_is_comparison(self): - c = column('x') - c2 = column('y') - op1 = c.op('$', is_comparison=True)(c2).operator - op2 = c.op('$', is_comparison=False)(c2).operator + c = column("x") + c2 = column("y") + op1 = c.op("$", is_comparison=True)(c2).operator + op2 = c.op("$", is_comparison=False)(c2).operator assert operators.is_comparison(op1) assert not operators.is_comparison(op2) @@ -2708,63 +2627,66 @@ class CustomOpTest(fixtures.TestBase): postgresql.ARRAY(Integer), sqltypes.Numeric(5, 2), ]: - c = column('x', typ) - expr = c.op('$', is_comparison=True)(None) + c = column("x", typ) + expr = c.op("$", is_comparison=True)(None) is_(expr.type, sqltypes.BOOLEANTYPE) - c = column('x', typ) - expr = c.bool_op('$')(None) + c = column("x", typ) + expr = c.bool_op("$")(None) is_(expr.type, sqltypes.BOOLEANTYPE) - expr = c.op('$')(None) + expr = c.op("$")(None) is_(expr.type, typ) - expr = c.op('$', return_type=some_return_type)(None) + expr = c.op("$", return_type=some_return_type)(None) is_(expr.type, some_return_type) - expr = c.op( - '$', is_comparison=True, return_type=some_return_type)(None) + expr = c.op("$", is_comparison=True, return_type=some_return_type)( + None + ) is_(expr.type, some_return_type) class TupleTypingTest(fixtures.TestBase): - def _assert_types(self, expr): eq_(expr.clauses[0].type._type_affinity, Integer) eq_(expr.clauses[1].type._type_affinity, String) eq_(expr.clauses[2].type._type_affinity, LargeBinary()._type_affinity) def test_type_coersion_on_eq(self): - a, b, c = column( - 'a', Integer), column( - 'b', String), column( - 'c', LargeBinary) + a, b, c = ( + column("a", Integer), + column("b", String), + column("c", LargeBinary), + ) t1 = tuple_(a, b, c) - expr = t1 == (3, 'hi', 'there') + expr = t1 == (3, "hi", "there") self._assert_types(expr.right) def test_type_coersion_on_in(self): - a, b, c = column( - 'a', Integer), column( - 'b', String), column( - 'c', LargeBinary) + a, b, c = ( + column("a", Integer), + column("b", String), + column("c", LargeBinary), + ) t1 = tuple_(a, b, c) - expr = t1.in_([(3, 'hi', 'there'), (4, 'Q', 'P')]) + expr = t1.in_([(3, "hi", "there"), (4, "Q", "P")]) eq_(len(expr.right.clauses), 2) for elem in expr.right.clauses: self._assert_types(elem) class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _fixture(self): m = MetaData() t = Table( - 'tab1', m, - Column('arrval', ARRAY(Integer)), - Column('data', Integer) + "tab1", + m, + Column("arrval", ARRAY(Integer)), + Column("data", Integer), ) return t @@ -2774,7 +2696,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( 5 == any_(t.c.arrval), ":param_1 = ANY (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_any_array_method(self): @@ -2783,7 +2705,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( 5 == t.c.arrval.any_(), ":param_1 = ANY (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_all_array(self): @@ -2792,7 +2714,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( 5 == all_(t.c.arrval), ":param_1 = ALL (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_all_array_method(self): @@ -2801,7 +2723,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( 5 == 
t.c.arrval.all_(), ":param_1 = ALL (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_any_comparator_array(self): @@ -2810,7 +2732,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( 5 > any_(t.c.arrval), ":param_1 > ANY (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_all_comparator_array(self): @@ -2819,7 +2741,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( 5 > all_(t.c.arrval), ":param_1 > ALL (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_any_comparator_array_wexpr(self): @@ -2828,7 +2750,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.data > any_(t.c.arrval), "tab1.data > ANY (tab1.arrval)", - checkparams={} + checkparams={}, ) def test_all_comparator_array_wexpr(self): @@ -2837,7 +2759,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.data > all_(t.c.arrval), "tab1.data > ALL (tab1.arrval)", - checkparams={} + checkparams={}, ) def test_illegal_ops(self): @@ -2846,7 +2768,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert_raises_message( exc.ArgumentError, "Only comparison operators may be used with ANY/ALL", - lambda: 5 + all_(t.c.arrval) + lambda: 5 + all_(t.c.arrval), ) # TODO: @@ -2854,8 +2776,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): # as the left-hand side just does its thing. Types # would need to reject their right-hand side. self.assert_compile( - t.c.data + all_(t.c.arrval), - "tab1.data + ALL (tab1.arrval)" + t.c.data + all_(t.c.arrval), "tab1.data + ALL (tab1.arrval)" ) def test_any_array_comparator_accessor(self): @@ -2864,7 +2785,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.any(5, operator.gt), ":param_1 > ANY (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_all_array_comparator_accessor(self): @@ -2873,7 +2794,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.all(5, operator.gt), ":param_1 > ALL (tab1.arrval)", - checkparams={"param_1": 5} + checkparams={"param_1": 5}, ) def test_any_array_expression(self): @@ -2884,9 +2805,13 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): "%(param_1)s = ANY (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || " "ARRAY[%(param_2)s, %(param_3)s])", checkparams={ - 'arrval_2': 6, 'param_1': 5, 'param_3': 4, - 'arrval_1': 5, 'param_2': 3}, - dialect='postgresql' + "arrval_2": 6, + "param_1": 5, + "param_3": 4, + "arrval_1": 5, + "param_2": 3, + }, + dialect="postgresql", ) def test_all_array_expression(self): @@ -2897,9 +2822,13 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): "%(param_1)s = ALL (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || " "ARRAY[%(param_2)s, %(param_3)s])", checkparams={ - 'arrval_2': 6, 'param_1': 5, 'param_3': 4, - 'arrval_1': 5, 'param_2': 3}, - dialect='postgresql' + "arrval_2": 6, + "param_1": 5, + "param_3": 4, + "arrval_1": 5, + "param_2": 3, + }, + dialect="postgresql", ) def test_any_subq(self): @@ -2909,7 +2838,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): 5 == any_(select([t.c.data]).where(t.c.data < 10)), ":param_1 = ANY (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={'data_1': 10, 'param_1': 5} + checkparams={"data_1": 10, "param_1": 5}, ) def 
test_any_subq_method(self): @@ -2919,7 +2848,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): 5 == select([t.c.data]).where(t.c.data < 10).as_scalar().any_(), ":param_1 = ANY (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={'data_1': 10, 'param_1': 5} + checkparams={"data_1": 10, "param_1": 5}, ) def test_all_subq(self): @@ -2929,7 +2858,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): 5 == all_(select([t.c.data]).where(t.c.data < 10)), ":param_1 = ALL (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={'data_1': 10, 'param_1': 5} + checkparams={"data_1": 10, "param_1": 5}, ) def test_all_subq_method(self): @@ -2939,5 +2868,5 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): 5 == select([t.c.data]).where(t.c.data < 10).as_scalar().all_(), ":param_1 = ALL (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={'data_1': 10, 'param_1': 5} + checkparams={"data_1": 10, "param_1": 5}, ) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 175b69c4f2..13f3b01ff2 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1,12 +1,41 @@ -from sqlalchemy.testing import eq_, assert_raises_message, assert_raises, \ - is_, in_, not_in_ +from sqlalchemy.testing import ( + eq_, + assert_raises_message, + assert_raises, + is_, + in_, + not_in_, +) from sqlalchemy import testing from sqlalchemy.testing import fixtures, engines from sqlalchemy import ( - exc, sql, func, select, String, Integer, MetaData, and_, ForeignKey, - union, intersect, except_, union_all, VARCHAR, INT, text, - bindparam, literal, not_, literal_column, desc, asc, - TypeDecorator, or_, cast, tuple_) + exc, + sql, + func, + select, + String, + Integer, + MetaData, + and_, + ForeignKey, + union, + intersect, + except_, + union_all, + VARCHAR, + INT, + text, + bindparam, + literal, + not_, + literal_column, + desc, + asc, + TypeDecorator, + or_, + cast, + tuple_, +) from sqlalchemy.engine import default from sqlalchemy.testing.schema import Table, Column @@ -25,28 +54,34 @@ class QueryTest(fixtures.TestBase): global users, users2, addresses, metadata metadata = MetaData(testing.db) users = Table( - 'query_users', metadata, + "query_users", + metadata, Column( - 'user_id', INT, primary_key=True, - test_needs_autoincrement=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True + "user_id", INT, primary_key=True, test_needs_autoincrement=True + ), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, ) addresses = Table( - 'query_addresses', metadata, + "query_addresses", + metadata, Column( - 'address_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('query_users.user_id')), - Column('address', String(30)), - test_needs_acid=True + "address_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("user_id", Integer, ForeignKey("query_users.user_id")), + Column("address", String(30)), + test_needs_acid=True, ) users2 = Table( - 'u2', metadata, - Column('user_id', INT, primary_key=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True + "u2", + metadata, + Column("user_id", INT, primary_key=True), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, ) metadata.create_all() @@ -62,7 +97,8 @@ class QueryTest(fixtures.TestBase): metadata.drop_all() @testing.fails_on( - 'firebird', "kinterbasdb doesn't send full type information") + "firebird", "kinterbasdb doesn't send full type 
information" + ) def test_order_by_label(self): """test that a label within an ORDER BY works on each backend. @@ -73,41 +109,43 @@ class QueryTest(fixtures.TestBase): """ users.insert().execute( - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': 'fred'}, + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, ) - concat = ("test: " + users.c.user_name).label('thedata') + concat = ("test: " + users.c.user_name).label("thedata") eq_( select([concat]).order_by("thedata").execute().fetchall(), - [("test: ed",), ("test: fred",), ("test: jack",)] + [("test: ed",), ("test: fred",), ("test: jack",)], ) eq_( select([concat]).order_by("thedata").execute().fetchall(), - [("test: ed",), ("test: fred",), ("test: jack",)] + [("test: ed",), ("test: fred",), ("test: jack",)], ) - concat = ("test: " + users.c.user_name).label('thedata') + concat = ("test: " + users.c.user_name).label("thedata") eq_( - select([concat]).order_by(desc('thedata')).execute().fetchall(), - [("test: jack",), ("test: fred",), ("test: ed",)] + select([concat]).order_by(desc("thedata")).execute().fetchall(), + [("test: jack",), ("test: fred",), ("test: ed",)], ) @testing.requires.order_by_label_with_expression def test_order_by_label_compound(self): users.insert().execute( - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': 'fred'}, + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, ) - concat = ("test: " + users.c.user_name).label('thedata') + concat = ("test: " + users.c.user_name).label("thedata") eq_( - select([concat]).order_by(literal_column('thedata') + "x"). - execute().fetchall(), - [("test: ed",), ("test: fred",), ("test: jack",)] + select([concat]) + .order_by(literal_column("thedata") + "x") + .execute() + .fetchall(), + [("test: ed",), ("test: fred",), ("test: jack",)], ) @testing.requires.boolean_col_expressions @@ -120,93 +158,124 @@ class QueryTest(fixtures.TestBase): eq_(testing.db.execute(select([or_(false, false)])).scalar(), False) eq_( testing.db.execute(select([not_(or_(false, false))])).scalar(), - True) + True, + ) row = testing.db.execute( select( - [or_(false, false).label("x"), - and_(true, false).label("y")])).first() + [or_(false, false).label("x"), and_(true, false).label("y")] + ) + ).first() assert row.x == False # noqa assert row.y == False # noqa row = testing.db.execute( - select( - [or_(true, false).label("x"), - and_(true, false).label("y")])).first() + select([or_(true, false).label("x"), and_(true, false).label("y")]) + ).first() assert row.x == True # noqa assert row.y == False # noqa def test_like_ops(self): users.insert().execute( - {'user_id': 1, 'user_name': 'apples'}, - {'user_id': 2, 'user_name': 'oranges'}, - {'user_id': 3, 'user_name': 'bananas'}, - {'user_id': 4, 'user_name': 'legumes'}, - {'user_id': 5, 'user_name': 'hi % there'}, + {"user_id": 1, "user_name": "apples"}, + {"user_id": 2, "user_name": "oranges"}, + {"user_id": 3, "user_name": "bananas"}, + {"user_id": 4, "user_name": "legumes"}, + {"user_id": 5, "user_name": "hi % there"}, ) for expr, result in ( - (select([users.c.user_id]). - where(users.c.user_name.startswith('apple')), [(1,)]), - (select([users.c.user_id]). - where(users.c.user_name.contains('i % t')), [(5,)]), - (select([users.c.user_id]). - where(users.c.user_name.endswith('anas')), [(3,)]), - (select([users.c.user_id]). 
- where(users.c.user_name.contains('i % t', escape='&')), - [(5,)]), + ( + select([users.c.user_id]).where( + users.c.user_name.startswith("apple") + ), + [(1,)], + ), + ( + select([users.c.user_id]).where( + users.c.user_name.contains("i % t") + ), + [(5,)], + ), + ( + select([users.c.user_id]).where( + users.c.user_name.endswith("anas") + ), + [(3,)], + ), + ( + select([users.c.user_id]).where( + users.c.user_name.contains("i % t", escape="&") + ), + [(5,)], + ), ): eq_(expr.execute().fetchall(), result) @testing.requires.mod_operator_as_percent_sign - @testing.emits_warning('.*now automatically escapes.*') + @testing.emits_warning(".*now automatically escapes.*") def test_percents_in_text(self): for expr, result in ( (text("select 6 % 10"), 6), (text("select 17 % 10"), 7), - (text("select '%'"), '%'), - (text("select '%%'"), '%%'), - (text("select '%%%'"), '%%%'), - (text("select 'hello % world'"), "hello % world") + (text("select '%'"), "%"), + (text("select '%%'"), "%%"), + (text("select '%%%'"), "%%%"), + (text("select 'hello % world'"), "hello % world"), ): eq_(testing.db.scalar(expr), result) def test_ilike(self): users.insert().execute( - {'user_id': 1, 'user_name': 'one'}, - {'user_id': 2, 'user_name': 'TwO'}, - {'user_id': 3, 'user_name': 'ONE'}, - {'user_id': 4, 'user_name': 'OnE'}, + {"user_id": 1, "user_name": "one"}, + {"user_id": 2, "user_name": "TwO"}, + {"user_id": 3, "user_name": "ONE"}, + {"user_id": 4, "user_name": "OnE"}, ) eq_( - select([users.c.user_id]).where(users.c.user_name.ilike('one')). - execute().fetchall(), [(1, ), (3, ), (4, )]) + select([users.c.user_id]) + .where(users.c.user_name.ilike("one")) + .execute() + .fetchall(), + [(1,), (3,), (4,)], + ) eq_( - select([users.c.user_id]).where(users.c.user_name.ilike('TWO')). - execute().fetchall(), [(2, )]) + select([users.c.user_id]) + .where(users.c.user_name.ilike("TWO")) + .execute() + .fetchall(), + [(2,)], + ) - if testing.against('postgresql'): + if testing.against("postgresql"): eq_( - select([users.c.user_id]). - where(users.c.user_name.like('one')).execute().fetchall(), - [(1, )]) + select([users.c.user_id]) + .where(users.c.user_name.like("one")) + .execute() + .fetchall(), + [(1,)], + ) eq_( - select([users.c.user_id]). - where(users.c.user_name.like('TWO')).execute().fetchall(), []) + select([users.c.user_id]) + .where(users.c.user_name.like("TWO")) + .execute() + .fetchall(), + [], + ) def test_compiled_execute(self): - users.insert().execute(user_id=7, user_name='jack') - s = select([users], users.c.user_id == bindparam('id')).compile() + users.insert().execute(user_id=7, user_name="jack") + s = select([users], users.c.user_id == bindparam("id")).compile() c = testing.db.connect() - assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7 + assert c.execute(s, id=7).fetchall()[0]["user_id"] == 7 def test_compiled_insert_execute(self): - users.insert().compile().execute(user_id=7, user_name='jack') - s = select([users], users.c.user_id == bindparam('id')).compile() + users.insert().compile().execute(user_id=7, user_name="jack") + s = select([users], users.c.user_id == bindparam("id")).compile() c = testing.db.connect() - assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7 + assert c.execute(s, id=7).fetchall()[0]["user_id"] == 7 def test_repeated_bindparams(self): """Tests that a BindParam can be used more than once. @@ -215,18 +284,19 @@ class QueryTest(fixtures.TestBase): paramstyles. 
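
        For illustration, the shape under test is roughly (a sketch
        using this module's `users` table; one value satisfies both
        occurrences of the parameter):

            u = bindparam("userid")
            s = users.select(
                and_(users.c.user_name == u, users.c.user_name == u)
            )
            r = s.execute(userid="fred").fetchall()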
""" - users.insert().execute(user_id=7, user_name='jack') - users.insert().execute(user_id=8, user_name='fred') + users.insert().execute(user_id=7, user_name="jack") + users.insert().execute(user_id=8, user_name="fred") - u = bindparam('userid') + u = bindparam("userid") s = users.select(and_(users.c.user_name == u, users.c.user_name == u)) - r = s.execute(userid='fred').fetchall() + r = s.execute(userid="fred").fetchall() assert len(r) == 1 def test_bindparam_detection(self): - dialect = default.DefaultDialect(paramstyle='qmark') + dialect = default.DefaultDialect(paramstyle="qmark") - def prep(q): return str(sql.text(q).compile(dialect=dialect)) + def prep(q): + return str(sql.text(q).compile(dialect=dialect)) def a_eq(got, wanted): if got != wanted: @@ -234,7 +304,7 @@ class QueryTest(fixtures.TestBase): print("Received %s" % got) self.assert_(got == wanted, got) - a_eq(prep('select foo'), 'select foo') + a_eq(prep("select foo"), "select foo") a_eq(prep("time='12:30:00'"), "time='12:30:00'") a_eq(prep("time='12:30:00'"), "time='12:30:00'") a_eq(prep(":this:that"), ":this:that") @@ -249,7 +319,7 @@ class QueryTest(fixtures.TestBase): a_eq(prep("(:that$:other)"), "(:that$:other)") a_eq(prep(".:that$ :other."), ".? ?.") - a_eq(prep(r'select \foo'), r'select \foo') + a_eq(prep(r"select \foo"), r"select \foo") a_eq(prep(r"time='12\:30:00'"), r"time='12\:30:00'") a_eq(prep(r":this \:that"), "? :that") a_eq(prep(r"(\:that$other)"), "(:that$other)") @@ -271,12 +341,13 @@ class QueryTest(fixtures.TestBase): eq_( testing.db.scalar(select([cast("INT_5", type_=MyInteger)])), - "INT_5" + "INT_5", ) eq_( testing.db.scalar( - select([cast("INT_5", type_=MyInteger).label('foo')])), - "INT_5" + select([cast("INT_5", type_=MyInteger).label("foo")]) + ), + "INT_5", ) def test_order_by(self): @@ -285,61 +356,96 @@ class QueryTest(fixtures.TestBase): Tests simple, compound, aliased and DESC clauses. 
""" - users.insert().execute(user_id=1, user_name='c') - users.insert().execute(user_id=2, user_name='b') - users.insert().execute(user_id=3, user_name='a') + users.insert().execute(user_id=1, user_name="c") + users.insert().execute(user_id=2, user_name="b") + users.insert().execute(user_id=3, user_name="a") def a_eq(executable, wanted): got = list(executable.execute()) eq_(got, wanted) for labels in False, True: - a_eq(users.select(order_by=[users.c.user_id], - use_labels=labels), - [(1, 'c'), (2, 'b'), (3, 'a')]) - - a_eq(users.select(order_by=[users.c.user_name, users.c.user_id], - use_labels=labels), - [(3, 'a'), (2, 'b'), (1, 'c')]) - - a_eq(select([users.c.user_id.label('foo')], - use_labels=labels, - order_by=[users.c.user_id]), - [(1,), (2,), (3,)]) - - a_eq(select([users.c.user_id.label('foo'), users.c.user_name], - use_labels=labels, - order_by=[users.c.user_name, users.c.user_id]), - [(3, 'a'), (2, 'b'), (1, 'c')]) - - a_eq(users.select(distinct=True, - use_labels=labels, - order_by=[users.c.user_id]), - [(1, 'c'), (2, 'b'), (3, 'a')]) - - a_eq(select([users.c.user_id.label('foo')], - distinct=True, - use_labels=labels, - order_by=[users.c.user_id]), - [(1,), (2,), (3,)]) - - a_eq(select([users.c.user_id.label('a'), - users.c.user_id.label('b'), - users.c.user_name], - use_labels=labels, - order_by=[users.c.user_id]), - [(1, 1, 'c'), (2, 2, 'b'), (3, 3, 'a')]) - - a_eq(users.select(distinct=True, - use_labels=labels, - order_by=[desc(users.c.user_id)]), - [(3, 'a'), (2, 'b'), (1, 'c')]) - - a_eq(select([users.c.user_id.label('foo')], - distinct=True, - use_labels=labels, - order_by=[users.c.user_id.desc()]), - [(3,), (2,), (1,)]) + a_eq( + users.select(order_by=[users.c.user_id], use_labels=labels), + [(1, "c"), (2, "b"), (3, "a")], + ) + + a_eq( + users.select( + order_by=[users.c.user_name, users.c.user_id], + use_labels=labels, + ), + [(3, "a"), (2, "b"), (1, "c")], + ) + + a_eq( + select( + [users.c.user_id.label("foo")], + use_labels=labels, + order_by=[users.c.user_id], + ), + [(1,), (2,), (3,)], + ) + + a_eq( + select( + [users.c.user_id.label("foo"), users.c.user_name], + use_labels=labels, + order_by=[users.c.user_name, users.c.user_id], + ), + [(3, "a"), (2, "b"), (1, "c")], + ) + + a_eq( + users.select( + distinct=True, + use_labels=labels, + order_by=[users.c.user_id], + ), + [(1, "c"), (2, "b"), (3, "a")], + ) + + a_eq( + select( + [users.c.user_id.label("foo")], + distinct=True, + use_labels=labels, + order_by=[users.c.user_id], + ), + [(1,), (2,), (3,)], + ) + + a_eq( + select( + [ + users.c.user_id.label("a"), + users.c.user_id.label("b"), + users.c.user_name, + ], + use_labels=labels, + order_by=[users.c.user_id], + ), + [(1, 1, "c"), (2, 2, "b"), (3, 3, "a")], + ) + + a_eq( + users.select( + distinct=True, + use_labels=labels, + order_by=[desc(users.c.user_id)], + ), + [(3, "a"), (2, "b"), (1, "c")], + ) + + a_eq( + select( + [users.c.user_id.label("foo")], + distinct=True, + use_labels=labels, + order_by=[users.c.user_id.desc()], + ), + [(3,), (2,), (1,)], + ) @testing.requires.nullsordering def test_order_by_nulls(self): @@ -349,67 +455,98 @@ class QueryTest(fixtures.TestBase): """ users.insert().execute(user_id=1) - users.insert().execute(user_id=2, user_name='b') - users.insert().execute(user_id=3, user_name='a') + users.insert().execute(user_id=2, user_name="b") + users.insert().execute(user_id=3, user_name="a") def a_eq(executable, wanted): got = list(executable.execute()) eq_(got, wanted) for labels in False, True: - 
a_eq(users.select(order_by=[users.c.user_name.nullsfirst()], - use_labels=labels), - [(1, None), (3, 'a'), (2, 'b')]) + a_eq( + users.select( + order_by=[users.c.user_name.nullsfirst()], + use_labels=labels, + ), + [(1, None), (3, "a"), (2, "b")], + ) - a_eq(users.select(order_by=[users.c.user_name.nullslast()], - use_labels=labels), - [(3, 'a'), (2, 'b'), (1, None)]) + a_eq( + users.select( + order_by=[users.c.user_name.nullslast()], use_labels=labels + ), + [(3, "a"), (2, "b"), (1, None)], + ) - a_eq(users.select(order_by=[asc(users.c.user_name).nullsfirst()], - use_labels=labels), - [(1, None), (3, 'a'), (2, 'b')]) + a_eq( + users.select( + order_by=[asc(users.c.user_name).nullsfirst()], + use_labels=labels, + ), + [(1, None), (3, "a"), (2, "b")], + ) - a_eq(users.select(order_by=[asc(users.c.user_name).nullslast()], - use_labels=labels), - [(3, 'a'), (2, 'b'), (1, None)]) + a_eq( + users.select( + order_by=[asc(users.c.user_name).nullslast()], + use_labels=labels, + ), + [(3, "a"), (2, "b"), (1, None)], + ) - a_eq(users.select(order_by=[users.c.user_name.desc().nullsfirst()], - use_labels=labels), - [(1, None), (2, 'b'), (3, 'a')]) + a_eq( + users.select( + order_by=[users.c.user_name.desc().nullsfirst()], + use_labels=labels, + ), + [(1, None), (2, "b"), (3, "a")], + ) a_eq( users.select( order_by=[users.c.user_name.desc().nullslast()], - use_labels=labels), - [(2, 'b'), (3, 'a'), (1, None)]) + use_labels=labels, + ), + [(2, "b"), (3, "a"), (1, None)], + ) a_eq( users.select( order_by=[desc(users.c.user_name).nullsfirst()], - use_labels=labels), - [(1, None), (2, 'b'), (3, 'a')]) + use_labels=labels, + ), + [(1, None), (2, "b"), (3, "a")], + ) - a_eq(users.select(order_by=[desc(users.c.user_name).nullslast()], - use_labels=labels), - [(2, 'b'), (3, 'a'), (1, None)]) + a_eq( + users.select( + order_by=[desc(users.c.user_name).nullslast()], + use_labels=labels, + ), + [(2, "b"), (3, "a"), (1, None)], + ) a_eq( users.select( order_by=[users.c.user_name.nullsfirst(), users.c.user_id], - use_labels=labels), - [(1, None), (3, 'a'), (2, 'b')]) + use_labels=labels, + ), + [(1, None), (3, "a"), (2, "b")], + ) a_eq( users.select( order_by=[users.c.user_name.nullslast(), users.c.user_id], - use_labels=labels), - [(3, 'a'), (2, 'b'), (1, None)]) + use_labels=labels, + ), + [(3, "a"), (2, "b"), (1, None)], + ) def test_in_filtering(self): """test the behavior of the in_() function.""" - users.insert().execute(user_id=7, user_name='jack') - users.insert().execute(user_id=8, user_name='fred') + users.insert().execute(user_id=7, user_name="jack") + users.insert().execute(user_id=8, user_name="fred") users.insert().execute(user_id=9, user_name=None) s = users.select(users.c.user_name.in_([])) @@ -421,11 +558,11 @@ class QueryTest(fixtures.TestBase): r = s.execute().fetchall() assert len(r) == 3 - s = users.select(users.c.user_name.in_(['jack', 'fred'])) + s = users.select(users.c.user_name.in_(["jack", "fred"])) r = s.execute().fetchall() assert len(r) == 2 - s = users.select(not_(users.c.user_name.in_(['jack', 'fred']))) + s = users.select(not_(users.c.user_name.in_(["jack", "fred"]))) r = s.execute().fetchall() # Null values are not outside any set assert len(r) == 0 @@ -434,39 +571,41 @@ class QueryTest(fixtures.TestBase): testing.db.execute( users.insert(), [ - dict(user_id=7, user_name='jack'), - dict(user_id=8, user_name='fred'), - dict(user_id=9, user_name=None) - ] + dict(user_id=7, user_name="jack"), + dict(user_id=8, user_name="fred"), + dict(user_id=9, user_name=None), + ], ) with 
testing.db.connect() as conn: - stmt = select([users]).where( - users.c.user_name.in_(bindparam('uname', expanding=True)) - ).order_by(users.c.user_id) - - eq_( - conn.execute(stmt, {"uname": ['jack']}).fetchall(), - [(7, 'jack')] + stmt = ( + select([users]) + .where( + users.c.user_name.in_(bindparam("uname", expanding=True)) + ) + .order_by(users.c.user_id) ) eq_( - conn.execute(stmt, {"uname": ['jack', 'fred']}).fetchall(), - [(7, 'jack'), (8, 'fred')] + conn.execute(stmt, {"uname": ["jack"]}).fetchall(), + [(7, "jack")], ) eq_( - conn.execute(stmt, {"uname": []}).fetchall(), - [] + conn.execute(stmt, {"uname": ["jack", "fred"]}).fetchall(), + [(7, "jack"), (8, "fred")], ) + eq_(conn.execute(stmt, {"uname": []}).fetchall(), []) + assert_raises_message( exc.StatementError, "'expanding' parameters can't be used with executemany()", conn.execute, users.update().where( - users.c.user_name.in_(bindparam('uname', expanding=True)) - ), [{"uname": ['fred']}, {"uname": ['ed']}] + users.c.user_name.in_(bindparam("uname", expanding=True)) + ), + [{"uname": ["fred"]}, {"uname": ["ed"]}], ) @testing.requires.no_quoting_special_bind_names @@ -474,94 +613,113 @@ class QueryTest(fixtures.TestBase): testing.db.execute( users.insert(), [ - dict(user_id=7, user_name='jack'), - dict(user_id=8, user_name='fred'), - ] + dict(user_id=7, user_name="jack"), + dict(user_id=8, user_name="fred"), + ], ) with testing.db.connect() as conn: - stmt = select([users]).where( - users.c.user_name.in_(bindparam('u35', expanding=True)) - ).where( - users.c.user_id == bindparam("u46") - ).order_by(users.c.user_id) + stmt = ( + select([users]) + .where(users.c.user_name.in_(bindparam("u35", expanding=True))) + .where(users.c.user_id == bindparam("u46")) + .order_by(users.c.user_id) + ) eq_( conn.execute( - stmt, {"u35": ['jack', 'fred'], "u46": 7}).fetchall(), - [(7, 'jack')] + stmt, {"u35": ["jack", "fred"], "u46": 7} + ).fetchall(), + [(7, "jack")], ) - stmt = select([users]).where( - users.c.user_name.in_(bindparam('u.35', expanding=True)) - ).where( - users.c.user_id == bindparam("u.46") - ).order_by(users.c.user_id) + stmt = ( + select([users]) + .where( + users.c.user_name.in_(bindparam("u.35", expanding=True)) + ) + .where(users.c.user_id == bindparam("u.46")) + .order_by(users.c.user_id) + ) eq_( conn.execute( - stmt, {"u.35": ['jack', 'fred'], "u.46": 7}).fetchall(), - [(7, 'jack')] + stmt, {"u.35": ["jack", "fred"], "u.46": 7} + ).fetchall(), + [(7, "jack")], ) def test_expanding_in_multiple(self): testing.db.execute( users.insert(), [ - dict(user_id=7, user_name='jack'), - dict(user_id=8, user_name='fred'), - dict(user_id=9, user_name='ed') - ] + dict(user_id=7, user_name="jack"), + dict(user_id=8, user_name="fred"), + dict(user_id=9, user_name="ed"), + ], ) with testing.db.connect() as conn: - stmt = select([users]).where( - users.c.user_name.in_(bindparam('uname', expanding=True)) - ).where( - users.c.user_id.in_(bindparam('userid', expanding=True)) - ).order_by(users.c.user_id) + stmt = ( + select([users]) + .where( + users.c.user_name.in_(bindparam("uname", expanding=True)) + ) + .where( + users.c.user_id.in_(bindparam("userid", expanding=True)) + ) + .order_by(users.c.user_id) + ) eq_( conn.execute( - stmt, - {"uname": ['jack', 'fred', 'ed'], "userid": [8, 9]} + stmt, {"uname": ["jack", "fred", "ed"], "userid": [8, 9]} ).fetchall(), - [(8, 'fred'), (9, 'ed')] + [(8, "fred"), (9, "ed")], ) def test_expanding_in_repeated(self): testing.db.execute( users.insert(), [ - dict(user_id=7, user_name='jack'), - 
dict(user_id=8, user_name='fred'), - dict(user_id=9, user_name='ed') - ] + dict(user_id=7, user_name="jack"), + dict(user_id=8, user_name="fred"), + dict(user_id=9, user_name="ed"), + ], ) with testing.db.connect() as conn: - stmt = select([users]).where( - users.c.user_name.in_( - bindparam('uname', expanding=True) - ) | users.c.user_name.in_(bindparam('uname2', expanding=True)) - ).where(users.c.user_id == 8) + stmt = ( + select([users]) + .where( + users.c.user_name.in_(bindparam("uname", expanding=True)) + | users.c.user_name.in_( + bindparam("uname2", expanding=True) + ) + ) + .where(users.c.user_id == 8) + ) stmt = stmt.union( - select([users]).where( - users.c.user_name.in_( - bindparam('uname', expanding=True) - ) | users.c.user_name.in_( - bindparam('uname2', expanding=True)) - ).where(users.c.user_id == 9) + select([users]) + .where( + users.c.user_name.in_(bindparam("uname", expanding=True)) + | users.c.user_name.in_( + bindparam("uname2", expanding=True) + ) + ) + .where(users.c.user_id == 9) ).order_by(stmt.c.user_id) eq_( conn.execute( stmt, { - "uname": ['jack', 'fred'], - "uname2": ['ed'], "userid": [8, 9]} + "uname": ["jack", "fred"], + "uname2": ["ed"], + "userid": [8, 9], + }, ).fetchall(), - [(8, 'fred'), (9, 'ed')] + [(8, "fred"), (9, "ed")], ) @testing.requires.tuple_in @@ -569,34 +727,38 @@ class QueryTest(fixtures.TestBase): testing.db.execute( users.insert(), [ - dict(user_id=7, user_name='jack'), - dict(user_id=8, user_name='fred'), - dict(user_id=9, user_name=None) - ] + dict(user_id=7, user_name="jack"), + dict(user_id=8, user_name="fred"), + dict(user_id=9, user_name=None), + ], ) with testing.db.connect() as conn: - stmt = select([users]).where( - tuple_( - users.c.user_id, - users.c.user_name - ).in_(bindparam('uname', expanding=True)) - ).order_by(users.c.user_id) + stmt = ( + select([users]) + .where( + tuple_(users.c.user_id, users.c.user_name).in_( + bindparam("uname", expanding=True) + ) + ) + .order_by(users.c.user_id) + ) eq_( - conn.execute(stmt, {"uname": [(7, 'jack')]}).fetchall(), - [(7, 'jack')] + conn.execute(stmt, {"uname": [(7, "jack")]}).fetchall(), + [(7, "jack")], ) eq_( - conn.execute(stmt, {"uname": [(7, 'jack'), (8, 'fred')]}).fetchall(), - [(7, 'jack'), (8, 'fred')] + conn.execute( + stmt, {"uname": [(7, "jack"), (8, "fred")]} + ).fetchall(), + [(7, "jack"), (8, "fred")], ) - - @testing.fails_on('firebird', "uses sql-92 rules") - @testing.fails_on('sybase', "uses sql-92 rules") - @testing.skip_if(['mssql']) + @testing.fails_on("firebird", "uses sql-92 rules") + @testing.fails_on("sybase", "uses sql-92 rules") + @testing.skip_if(["mssql"]) def test_bind_in(self): """test calling IN against a bind parameter. 
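
        A sketch of the pattern (using this module's `users` table); the
        IN is applied to the bound parameter itself, not a column, so
        once negated it matches every row, as the assertions below check:

            u = bindparam("search_key")
            s = users.select(not_(u.in_([])))
            r = s.execute(search_key="john").fetchall()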
@@ -605,24 +767,24 @@ class QueryTest(fixtures.TestBase): """ - users.insert().execute(user_id=7, user_name='jack') - users.insert().execute(user_id=8, user_name='fred') + users.insert().execute(user_id=7, user_name="jack") + users.insert().execute(user_id=8, user_name="fred") users.insert().execute(user_id=9, user_name=None) - u = bindparam('search_key') + u = bindparam("search_key") s = users.select(not_(u.in_([]))) - r = s.execute(search_key='john').fetchall() + r = s.execute(search_key="john").fetchall() assert len(r) == 3 r = s.execute(search_key=None).fetchall() assert len(r) == 3 - @testing.emits_warning('.*empty sequence.*') + @testing.emits_warning(".*empty sequence.*") def test_literal_in(self): """similar to test_bind_in but use a bind with a value.""" - users.insert().execute(user_id=7, user_name='jack') - users.insert().execute(user_id=8, user_name='fred') + users.insert().execute(user_id=7, user_name="jack") + users.insert().execute(user_id=8, user_name="fred") users.insert().execute(user_id=9, user_name=None) s = users.select(not_(literal("john").in_([]))) @@ -641,10 +803,10 @@ class QueryTest(fixtures.TestBase): conn.execute( users.insert(), [ - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': None} - ] + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": None}, + ], ) s = users.select(users.c.user_name.in_([]) == True) # noqa @@ -666,7 +828,8 @@ class QueryTest(fixtures.TestBase): """ engine = engines.testing_engine( - options={"empty_in_strategy": "dynamic"}) + options={"empty_in_strategy": "dynamic"} + ) with engine.connect() as conn: users.create(engine, checkfirst=True) @@ -674,10 +837,10 @@ class QueryTest(fixtures.TestBase): conn.execute( users.insert(), [ - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': None} - ] + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": None}, + ], ) s = users.select(users.c.user_name.in_([]) == True) # noqa @@ -698,57 +861,65 @@ class RequiredBindTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table( - 'foo', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('x', Integer) + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("x", Integer), ) def _assert_raises(self, stmt, params): assert_raises_message( exc.StatementError, "A value is required for bind parameter 'x'", - testing.db.execute, stmt, **params) + testing.db.execute, + stmt, + **params + ) assert_raises_message( exc.StatementError, "A value is required for bind parameter 'x'", - testing.db.execute, stmt, params) + testing.db.execute, + stmt, + params, + ) def test_insert(self): stmt = self.tables.foo.insert().values( - x=bindparam('x'), data=bindparam('data')) - self._assert_raises(stmt, {'data': 'data'}) + x=bindparam("x"), data=bindparam("data") + ) + self._assert_raises(stmt, {"data": "data"}) def test_select_where(self): - stmt = select([self.tables.foo]). \ - where(self.tables.foo.c.data == bindparam('data')). 
\ - where(self.tables.foo.c.x == bindparam('x')) - self._assert_raises(stmt, {'data': 'data'}) + stmt = ( + select([self.tables.foo]) + .where(self.tables.foo.c.data == bindparam("data")) + .where(self.tables.foo.c.x == bindparam("x")) + ) + self._assert_raises(stmt, {"data": "data"}) @testing.requires.standalone_binds def test_select_columns(self): - stmt = select([bindparam('data'), bindparam('x')]) - self._assert_raises( - stmt, {'data': 'data'} - ) + stmt = select([bindparam("data"), bindparam("x")]) + self._assert_raises(stmt, {"data": "data"}) def test_text(self): stmt = text("select * from foo where x=:x and data=:data1") - self._assert_raises( - stmt, {'data1': 'data'} - ) + self._assert_raises(stmt, {"data1": "data"}) def test_required_flag(self): - is_(bindparam('foo').required, True) - is_(bindparam('foo', required=False).required, False) - is_(bindparam('foo', 'bar').required, False) - is_(bindparam('foo', 'bar', required=True).required, True) + is_(bindparam("foo").required, True) + is_(bindparam("foo", required=False).required, False) + is_(bindparam("foo", "bar").required, False) + is_(bindparam("foo", "bar", required=True).required, True) - def c(): return None - is_(bindparam('foo', callable_=c, required=True).required, True) - is_(bindparam('foo', callable_=c).required, False) - is_(bindparam('foo', callable_=c, required=False).required, False) + def c(): + return None + + is_(bindparam("foo", callable_=c, required=True).required, True) + is_(bindparam("foo", callable_=c).required, False) + is_(bindparam("foo", callable_=c, required=False).required, False) class LimitTest(fixtures.TestBase): @@ -759,70 +930,97 @@ class LimitTest(fixtures.TestBase): global users, addresses, metadata metadata = MetaData(testing.db) users = Table( - 'query_users', metadata, - Column('user_id', INT, primary_key=True), - Column('user_name', VARCHAR(20)), + "query_users", + metadata, + Column("user_id", INT, primary_key=True), + Column("user_name", VARCHAR(20)), ) addresses = Table( - 'query_addresses', metadata, - Column('address_id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('query_users.user_id')), - Column('address', String(30))) + "query_addresses", + metadata, + Column("address_id", Integer, primary_key=True), + Column("user_id", Integer, ForeignKey("query_users.user_id")), + Column("address", String(30)), + ) metadata.create_all() - users.insert().execute(user_id=1, user_name='john') - addresses.insert().execute(address_id=1, user_id=1, address='addr1') - users.insert().execute(user_id=2, user_name='jack') - addresses.insert().execute(address_id=2, user_id=2, address='addr1') - users.insert().execute(user_id=3, user_name='ed') - addresses.insert().execute(address_id=3, user_id=3, address='addr2') - users.insert().execute(user_id=4, user_name='wendy') - addresses.insert().execute(address_id=4, user_id=4, address='addr3') - users.insert().execute(user_id=5, user_name='laura') - addresses.insert().execute(address_id=5, user_id=5, address='addr4') - users.insert().execute(user_id=6, user_name='ralph') - addresses.insert().execute(address_id=6, user_id=6, address='addr5') - users.insert().execute(user_id=7, user_name='fido') - addresses.insert().execute(address_id=7, user_id=7, address='addr5') + users.insert().execute(user_id=1, user_name="john") + addresses.insert().execute(address_id=1, user_id=1, address="addr1") + users.insert().execute(user_id=2, user_name="jack") + addresses.insert().execute(address_id=2, user_id=2, address="addr1") + 
users.insert().execute(user_id=3, user_name="ed") + addresses.insert().execute(address_id=3, user_id=3, address="addr2") + users.insert().execute(user_id=4, user_name="wendy") + addresses.insert().execute(address_id=4, user_id=4, address="addr3") + users.insert().execute(user_id=5, user_name="laura") + addresses.insert().execute(address_id=5, user_id=5, address="addr4") + users.insert().execute(user_id=6, user_name="ralph") + addresses.insert().execute(address_id=6, user_id=6, address="addr5") + users.insert().execute(user_id=7, user_name="fido") + addresses.insert().execute(address_id=7, user_id=7, address="addr5") @classmethod def teardown_class(cls): metadata.drop_all() def test_select_limit(self): - r = users.select(limit=3, order_by=[users.c.user_id]).execute(). \ - fetchall() - self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r)) + r = ( + users.select(limit=3, order_by=[users.c.user_id]) + .execute() + .fetchall() + ) + self.assert_(r == [(1, "john"), (2, "jack"), (3, "ed")], repr(r)) @testing.requires.offset def test_select_limit_offset(self): """Test the interaction between limit and offset""" - r = users.select(limit=3, offset=2, order_by=[users.c.user_id]). \ - execute().fetchall() - self.assert_(r == [(3, 'ed'), (4, 'wendy'), (5, 'laura')]) - r = users.select(offset=5, order_by=[users.c.user_id]).execute(). \ - fetchall() - self.assert_(r == [(6, 'ralph'), (7, 'fido')]) + r = ( + users.select(limit=3, offset=2, order_by=[users.c.user_id]) + .execute() + .fetchall() + ) + self.assert_(r == [(3, "ed"), (4, "wendy"), (5, "laura")]) + r = ( + users.select(offset=5, order_by=[users.c.user_id]) + .execute() + .fetchall() + ) + self.assert_(r == [(6, "ralph"), (7, "fido")]) def test_select_distinct_limit(self): """Test the interaction between limit and distinct""" r = sorted( - [x[0] for x in select([addresses.c.address]).distinct(). - limit(3).order_by(addresses.c.address).execute().fetchall()]) + [ + x[0] + for x in select([addresses.c.address]) + .distinct() + .limit(3) + .order_by(addresses.c.address) + .execute() + .fetchall() + ] + ) self.assert_(len(r) == 3, repr(r)) self.assert_(r[0] != r[1] and r[1] != r[2], repr(r)) @testing.requires.offset - @testing.fails_on('mssql', 'FIXME: unknown') + @testing.fails_on("mssql", "FIXME: unknown") def test_select_distinct_offset(self): """Test the interaction between distinct and offset""" r = sorted( - [x[0] for x in select([addresses.c.address]).distinct(). - offset(1).order_by(addresses.c.address). - execute().fetchall()]) + [ + x[0] + for x in select([addresses.c.address]) + .distinct() + .offset(1) + .order_by(addresses.c.address) + .execute() + .fetchall() + ] + ) eq_(len(r), 4) self.assert_(r[0] != r[1] and r[1] != r[2] and r[2] != [3], repr(r)) @@ -830,8 +1028,15 @@ class LimitTest(fixtures.TestBase): def test_select_distinct_limit_offset(self): """Test the interaction between limit and limit/offset""" - r = select([addresses.c.address]).order_by(addresses.c.address). 
\ - distinct().offset(2).limit(3).execute().fetchall() + r = ( + select([addresses.c.address]) + .order_by(addresses.c.address) + .distinct() + .offset(2) + .limit(3) + .execute() + .fetchall() + ) self.assert_(len(r) == 3, repr(r)) self.assert_(r[0] != r[1] and r[1] != r[2], repr(r)) @@ -848,46 +1053,67 @@ class CompoundTest(fixtures.TestBase): global metadata, t1, t2, t3 metadata = MetaData(testing.db) t1 = Table( - 't1', metadata, + "t1", + metadata, Column( - 'col1', Integer, test_needs_autoincrement=True, - primary_key=True), - Column('col2', String(30)), - Column('col3', String(40)), - Column('col4', String(30))) + "col1", + Integer, + test_needs_autoincrement=True, + primary_key=True, + ), + Column("col2", String(30)), + Column("col3", String(40)), + Column("col4", String(30)), + ) t2 = Table( - 't2', metadata, + "t2", + metadata, Column( - 'col1', Integer, test_needs_autoincrement=True, - primary_key=True), - Column('col2', String(30)), - Column('col3', String(40)), - Column('col4', String(30))) + "col1", + Integer, + test_needs_autoincrement=True, + primary_key=True, + ), + Column("col2", String(30)), + Column("col3", String(40)), + Column("col4", String(30)), + ) t3 = Table( - 't3', metadata, + "t3", + metadata, Column( - 'col1', Integer, test_needs_autoincrement=True, - primary_key=True), - Column('col2', String(30)), - Column('col3', String(40)), - Column('col4', String(30))) + "col1", + Integer, + test_needs_autoincrement=True, + primary_key=True, + ), + Column("col2", String(30)), + Column("col3", String(40)), + Column("col4", String(30)), + ) metadata.create_all() - t1.insert().execute([ - dict(col2="t1col2r1", col3="aaa", col4="aaa"), - dict(col2="t1col2r2", col3="bbb", col4="bbb"), - dict(col2="t1col2r3", col3="ccc", col4="ccc"), - ]) - t2.insert().execute([ - dict(col2="t2col2r1", col3="aaa", col4="bbb"), - dict(col2="t2col2r2", col3="bbb", col4="ccc"), - dict(col2="t2col2r3", col3="ccc", col4="aaa"), - ]) - t3.insert().execute([ - dict(col2="t3col2r1", col3="aaa", col4="ccc"), - dict(col2="t3col2r2", col3="bbb", col4="aaa"), - dict(col2="t3col2r3", col3="ccc", col4="bbb"), - ]) + t1.insert().execute( + [ + dict(col2="t1col2r1", col3="aaa", col4="aaa"), + dict(col2="t1col2r2", col3="bbb", col4="bbb"), + dict(col2="t1col2r3", col3="ccc", col4="ccc"), + ] + ) + t2.insert().execute( + [ + dict(col2="t2col2r1", col3="aaa", col4="bbb"), + dict(col2="t2col2r2", col3="bbb", col4="ccc"), + dict(col2="t2col2r3", col3="ccc", col4="aaa"), + ] + ) + t3.insert().execute( + [ + dict(col2="t3col2r1", col3="aaa", col4="ccc"), + dict(col2="t3col2r2", col3="bbb", col4="aaa"), + dict(col2="t3col2r3", col3="ccc", col4="bbb"), + ] + ) @engines.close_first def teardown(self): @@ -903,70 +1129,92 @@ class CompoundTest(fixtures.TestBase): @testing.requires.subqueries def test_union(self): (s1, s2) = ( - select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], - t1.c.col2.in_(["t1col2r1", "t1col2r2"])), - select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], - t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + select( + [t1.c.col3.label("col3"), t1.c.col4.label("col4")], + t1.c.col2.in_(["t1col2r1", "t1col2r2"]), + ), + select( + [t2.c.col3.label("col3"), t2.c.col4.label("col4")], + t2.c.col2.in_(["t2col2r2", "t2col2r3"]), + ), ) u = union(s1, s2) - wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), - ('ccc', 'aaa')] + wanted = [ + ("aaa", "aaa"), + ("bbb", "bbb"), + ("bbb", "ccc"), + ("ccc", "aaa"), + ] found1 = self._fetchall_sorted(u.execute()) eq_(found1, wanted) - found2 = 
self._fetchall_sorted(u.alias('bar').select().execute()) + found2 = self._fetchall_sorted(u.alias("bar").select().execute()) eq_(found2, wanted) - @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") + @testing.fails_on("firebird", "doesn't like ORDER BY with UNIONs") def test_union_ordered(self): (s1, s2) = ( - select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], - t1.c.col2.in_(["t1col2r1", "t1col2r2"])), - select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], - t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + select( + [t1.c.col3.label("col3"), t1.c.col4.label("col4")], + t1.c.col2.in_(["t1col2r1", "t1col2r2"]), + ), + select( + [t2.c.col3.label("col3"), t2.c.col4.label("col4")], + t2.c.col2.in_(["t2col2r2", "t2col2r3"]), + ), ) - u = union(s1, s2, order_by=['col3', 'col4']) - - wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), - ('ccc', 'aaa')] + u = union(s1, s2, order_by=["col3", "col4"]) + + wanted = [ + ("aaa", "aaa"), + ("bbb", "bbb"), + ("bbb", "ccc"), + ("ccc", "aaa"), + ] eq_(u.execute().fetchall(), wanted) - @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") + @testing.fails_on("firebird", "doesn't like ORDER BY with UNIONs") @testing.requires.subqueries def test_union_ordered_alias(self): (s1, s2) = ( - select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], - t1.c.col2.in_(["t1col2r1", "t1col2r2"])), - select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], - t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + select( + [t1.c.col3.label("col3"), t1.c.col4.label("col4")], + t1.c.col2.in_(["t1col2r1", "t1col2r2"]), + ), + select( + [t2.c.col3.label("col3"), t2.c.col4.label("col4")], + t2.c.col2.in_(["t2col2r2", "t2col2r3"]), + ), ) - u = union(s1, s2, order_by=['col3', 'col4']) + u = union(s1, s2, order_by=["col3", "col4"]) - wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), - ('ccc', 'aaa')] - eq_(u.alias('bar').select().execute().fetchall(), wanted) + wanted = [ + ("aaa", "aaa"), + ("bbb", "bbb"), + ("bbb", "ccc"), + ("ccc", "aaa"), + ] + eq_(u.alias("bar").select().execute().fetchall(), wanted) - @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') + @testing.crashes("oracle", "FIXME: unknown, verify not fails_on") @testing.fails_on( - 'firebird', - "has trouble extracting anonymous column from union subquery") - @testing.fails_on('mysql', 'FIXME: unknown') - @testing.fails_on('sqlite', 'FIXME: unknown') + "firebird", + "has trouble extracting anonymous column from union subquery", + ) + @testing.fails_on("mysql", "FIXME: unknown") + @testing.fails_on("sqlite", "FIXME: unknown") def test_union_all(self): e = union_all( select([t1.c.col3]), - union( - select([t1.c.col3]), - select([t1.c.col3]), - ) + union(select([t1.c.col3]), select([t1.c.col3])), ) - wanted = [('aaa',), ('aaa',), ('bbb',), ('bbb',), ('ccc',), ('ccc',)] + wanted = [("aaa",), ("aaa",), ("bbb",), ("bbb",), ("ccc",), ("ccc",)] found1 = self._fetchall_sorted(e.execute()) eq_(found1, wanted) - found2 = self._fetchall_sorted(e.alias('foo').select().execute()) + found2 = self._fetchall_sorted(e.alias("foo").select().execute()) eq_(found2, wanted) def test_union_all_lightweight(self): @@ -976,49 +1224,52 @@ class CompoundTest(fixtures.TestBase): """ - u = union( - select([t1.c.col3]), - select([t1.c.col3]), - ).alias() + u = union(select([t1.c.col3]), select([t1.c.col3])).alias() - e = union_all( - select([t1.c.col3]), - select([u.c.col3]) - ) + e = union_all(select([t1.c.col3]), select([u.c.col3])) - wanted = [('aaa',), ('aaa',), ('bbb',), ('bbb',), 
('ccc',), ('ccc',)] + wanted = [("aaa",), ("aaa",), ("bbb",), ("bbb",), ("ccc",), ("ccc",)] found1 = self._fetchall_sorted(e.execute()) eq_(found1, wanted) - found2 = self._fetchall_sorted(e.alias('foo').select().execute()) + found2 = self._fetchall_sorted(e.alias("foo").select().execute()) eq_(found2, wanted) @testing.requires.intersect def test_intersect(self): i = intersect( select([t2.c.col3, t2.c.col4]), - select([t2.c.col3, t2.c.col4], t2.c.col4 == t3.c.col3) + select([t2.c.col3, t2.c.col4], t2.c.col4 == t3.c.col3), ) - wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + wanted = [("aaa", "bbb"), ("bbb", "ccc"), ("ccc", "aaa")] found1 = self._fetchall_sorted(i.execute()) eq_(found1, wanted) - found2 = self._fetchall_sorted(i.alias('bar').select().execute()) + found2 = self._fetchall_sorted(i.alias("bar").select().execute()) eq_(found2, wanted) @testing.requires.except_ - @testing.fails_on('sqlite', "Can't handle this style of nesting") + @testing.fails_on("sqlite", "Can't handle this style of nesting") def test_except_style1(self): - e = except_(union( - select([t1.c.col3, t1.c.col4]), + e = except_( + union( + select([t1.c.col3, t1.c.col4]), + select([t2.c.col3, t2.c.col4]), + select([t3.c.col3, t3.c.col4]), + ), select([t2.c.col3, t2.c.col4]), - select([t3.c.col3, t3.c.col4]), - ), select([t2.c.col3, t2.c.col4])) + ) - wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), - ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + wanted = [ + ("aaa", "aaa"), + ("aaa", "ccc"), + ("bbb", "aaa"), + ("bbb", "bbb"), + ("ccc", "bbb"), + ("ccc", "ccc"), + ] found = self._fetchall_sorted(e.alias().select().execute()) eq_(found, wanted) @@ -1028,14 +1279,25 @@ class CompoundTest(fixtures.TestBase): # same as style1, but add alias().select() to the except_(). # sqlite can handle it now. 
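As the comment above says, "style2" differs from "style1" only by the added .alias().select(). A minimal standalone sketch, separate from the test suite and assuming SQLAlchemy 1.x's legacy select([...]) API with pared-down stand-ins for the test's t1/t2, shows the two nestings side by side; only the aliased form is parseable by SQLite.

    from sqlalchemy import (
        Column,
        MetaData,
        String,
        Table,
        except_,
        select,
        union,
    )

    metadata = MetaData()
    t1 = Table(
        "t1",
        metadata,
        Column("col3", String(40)),
        Column("col4", String(30)),
    )
    t2 = Table(
        "t2",
        metadata,
        Column("col3", String(40)),
        Column("col4", String(30)),
    )

    inner = union(
        select([t1.c.col3, t1.c.col4]), select([t2.c.col3, t2.c.col4])
    )

    # style1: EXCEPT applied directly to the nested UNION; SQLite cannot
    # parse this form.
    style1 = except_(inner, select([t2.c.col3, t2.c.col4]))

    # style2: the same UNION wrapped as a named subquery first, which
    # SQLite accepts.
    style2 = except_(inner.alias().select(), select([t2.c.col3, t2.c.col4]))

    print(style1)
    print(style2)
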
- e = except_(union( - select([t1.c.col3, t1.c.col4]), + e = except_( + union( + select([t1.c.col3, t1.c.col4]), + select([t2.c.col3, t2.c.col4]), + select([t3.c.col3, t3.c.col4]), + ) + .alias() + .select(), select([t2.c.col3, t2.c.col4]), - select([t3.c.col3, t3.c.col4]), - ).alias().select(), select([t2.c.col3, t2.c.col4])) + ) - wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), - ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + wanted = [ + ("aaa", "aaa"), + ("aaa", "ccc"), + ("bbb", "aaa"), + ("bbb", "bbb"), + ("ccc", "bbb"), + ("ccc", "ccc"), + ] found1 = self._fetchall_sorted(e.execute()) eq_(found1, wanted) @@ -1044,7 +1306,8 @@ class CompoundTest(fixtures.TestBase): eq_(found2, wanted) @testing.fails_on( - ['sqlite', 'mysql'], "Can't handle this style of nesting") + ["sqlite", "mysql"], "Can't handle this style of nesting" + ) @testing.requires.except_ def test_except_style3(self): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc @@ -1052,11 +1315,11 @@ class CompoundTest(fixtures.TestBase): select([t1.c.col3]), # aaa, bbb, ccc except_( select([t2.c.col3]), # aaa, bbb, ccc - select([t3.c.col3], t3.c.col3 == 'ccc'), # ccc - ) + select([t3.c.col3], t3.c.col3 == "ccc"), # ccc + ), ) - eq_(e.execute().fetchall(), [('ccc',)]) - eq_(e.alias('foo').select().execute().fetchall(), [('ccc',)]) + eq_(e.execute().fetchall(), [("ccc",)]) + eq_(e.alias("foo").select().execute().fetchall(), [("ccc",)]) @testing.requires.except_ def test_except_style4(self): @@ -1065,31 +1328,31 @@ class CompoundTest(fixtures.TestBase): select([t1.c.col3]), # aaa, bbb, ccc except_( select([t2.c.col3]), # aaa, bbb, ccc - select([t3.c.col3], t3.c.col3 == 'ccc'), # ccc - ).alias().select() + select([t3.c.col3], t3.c.col3 == "ccc"), # ccc + ) + .alias() + .select(), ) - eq_(e.execute().fetchall(), [('ccc',)]) - eq_( - e.alias().select().execute().fetchall(), - [('ccc',)] - ) + eq_(e.execute().fetchall(), [("ccc",)]) + eq_(e.alias().select().execute().fetchall(), [("ccc",)]) @testing.requires.intersect - @testing.fails_on(['sqlite', 'mysql'], - "sqlite can't handle leading parenthesis") + @testing.fails_on( + ["sqlite", "mysql"], "sqlite can't handle leading parenthesis" + ) def test_intersect_unions(self): u = intersect( union( - select([t1.c.col3, t1.c.col4]), - select([t3.c.col3, t3.c.col4]), + select([t1.c.col3, t1.c.col4]), select([t3.c.col3, t3.c.col4]) ), union( - select([t2.c.col3, t2.c.col4]), - select([t3.c.col3, t3.c.col4]), - ).alias().select() + select([t2.c.col3, t2.c.col4]), select([t3.c.col3, t3.c.col4]) + ) + .alias() + .select(), ) - wanted = [('aaa', 'ccc'), ('bbb', 'aaa'), ('ccc', 'bbb')] + wanted = [("aaa", "ccc"), ("bbb", "aaa"), ("ccc", "bbb")] found = self._fetchall_sorted(u.execute()) eq_(found, wanted) @@ -1098,15 +1361,17 @@ class CompoundTest(fixtures.TestBase): def test_intersect_unions_2(self): u = intersect( union( - select([t1.c.col3, t1.c.col4]), - select([t3.c.col3, t3.c.col4]), - ).alias().select(), + select([t1.c.col3, t1.c.col4]), select([t3.c.col3, t3.c.col4]) + ) + .alias() + .select(), union( - select([t2.c.col3, t2.c.col4]), - select([t3.c.col3, t3.c.col4]), - ).alias().select() + select([t2.c.col3, t2.c.col4]), select([t3.c.col3, t3.c.col4]) + ) + .alias() + .select(), ) - wanted = [('aaa', 'ccc'), ('bbb', 'aaa'), ('ccc', 'bbb')] + wanted = [("aaa", "ccc"), ("bbb", "aaa"), ("ccc", "bbb")] found = self._fetchall_sorted(u.execute()) eq_(found, wanted) @@ -1119,9 +1384,11 @@ class CompoundTest(fixtures.TestBase): select([t1.c.col3, t1.c.col4]), select([t2.c.col3, 
t2.c.col4]), select([t3.c.col3, t3.c.col4]), - ).alias().select() + ) + .alias() + .select(), ) - wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + wanted = [("aaa", "bbb"), ("bbb", "ccc"), ("ccc", "aaa")] found = self._fetchall_sorted(u.execute()) eq_(found, wanted) @@ -1134,10 +1401,12 @@ class CompoundTest(fixtures.TestBase): select([t1.c.col3, t1.c.col4]), select([t2.c.col3, t2.c.col4]), select([t3.c.col3, t3.c.col4]), - ).alias().select() + ) + .alias() + .select(), ).alias() - wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + wanted = [("aaa", "bbb"), ("bbb", "ccc"), ("ccc", "aaa")] found = self._fetchall_sorted(ua.select().execute()) eq_(found, wanted) @@ -1155,6 +1424,7 @@ class JoinTest(fixtures.TestBase): `JOIN rhs ON lhs.col=rhs.col` vs `rhs.col=lhs.col`. At least one database seems to be sensitive to this. """ + __backend__ = True @classmethod @@ -1163,29 +1433,42 @@ class JoinTest(fixtures.TestBase): global t1, t2, t3 metadata = MetaData(testing.db) - t1 = Table('t1', metadata, - Column('t1_id', Integer, primary_key=True), - Column('name', String(32))) - t2 = Table('t2', metadata, - Column('t2_id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.t1_id')), - Column('name', String(32))) - t3 = Table('t3', metadata, - Column('t3_id', Integer, primary_key=True), - Column('t2_id', Integer, ForeignKey('t2.t2_id')), - Column('name', String(32))) + t1 = Table( + "t1", + metadata, + Column("t1_id", Integer, primary_key=True), + Column("name", String(32)), + ) + t2 = Table( + "t2", + metadata, + Column("t2_id", Integer, primary_key=True), + Column("t1_id", Integer, ForeignKey("t1.t1_id")), + Column("name", String(32)), + ) + t3 = Table( + "t3", + metadata, + Column("t3_id", Integer, primary_key=True), + Column("t2_id", Integer, ForeignKey("t2.t2_id")), + Column("name", String(32)), + ) metadata.drop_all() metadata.create_all() # t1.10 -> t2.20 -> t3.30 # t1.11 -> t2.21 # t1.12 - t1.insert().execute({'t1_id': 10, 'name': 't1 #10'}, - {'t1_id': 11, 'name': 't1 #11'}, - {'t1_id': 12, 'name': 't1 #12'}) - t2.insert().execute({'t2_id': 20, 't1_id': 10, 'name': 't2 #20'}, - {'t2_id': 21, 't1_id': 11, 'name': 't2 #21'}) - t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'}) + t1.insert().execute( + {"t1_id": 10, "name": "t1 #10"}, + {"t1_id": 11, "name": "t1 #11"}, + {"t1_id": 12, "name": "t1 #12"}, + ) + t2.insert().execute( + {"t2_id": 20, "t1_id": 10, "name": "t2 #20"}, + {"t2_id": 21, "t1_id": 11, "name": "t2 #21"}, + ) + t3.insert().execute({"t3_id": 30, "t2_id": 20, "name": "t3 #30"}) @classmethod def teardown_class(cls): @@ -1194,8 +1477,7 @@ class JoinTest(fixtures.TestBase): def assertRows(self, statement, expected): """Execute a statement and assert that rows returned equal expected.""" - found = sorted([tuple(row) - for row in statement.execute().fetchall()]) + found = sorted([tuple(row) for row in statement.execute().fetchall()]) eq_(found, sorted(expected)) @@ -1204,8 +1486,8 @@ class JoinTest(fixtures.TestBase): for criteria in (t1.c.t1_id == t2.c.t1_id, t2.c.t1_id == t1.c.t1_id): expr = select( - [t1.c.t1_id, t2.c.t2_id], - from_obj=[t1.join(t2, criteria)]) + [t1.c.t1_id, t2.c.t2_id], from_obj=[t1.join(t2, criteria)] + ) self.assertRows(expr, [(10, 20), (11, 21)]) def test_join_x2(self): @@ -1213,8 +1495,8 @@ class JoinTest(fixtures.TestBase): for criteria in (t1.c.t1_id == t2.c.t1_id, t2.c.t1_id == t1.c.t1_id): expr = select( - [t1.c.t1_id, t2.c.t2_id], - from_obj=[t1.join(t2, criteria)]) + [t1.c.t1_id, t2.c.t2_id], 
from_obj=[t1.join(t2, criteria)] + ) self.assertRows(expr, [(10, 20), (11, 21)]) def test_outerjoin_x1(self): @@ -1223,7 +1505,8 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id], - from_obj=[t1.join(t2).join(t3, criteria)]) + from_obj=[t1.join(t2).join(t3, criteria)], + ) self.assertRows(expr, [(10, 20)]) def test_outerjoin_x2(self): @@ -1232,10 +1515,15 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - from_obj=[t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria)]) + from_obj=[ + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ], + ) self.assertRows( - expr, [(10, 20, 30), (11, 21, None), (12, None, None)]) + expr, [(10, 20, 30), (11, 21, None), (12, None, None)] + ) def test_outerjoin_where_x2_t1(self): """Outer joins t1->t2,t3, where on t1.""" @@ -1243,16 +1531,28 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - t1.c.name == 't1 #10', - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + t1.c.name == "t1 #10", + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], t1.c.t1_id < 12, - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) def test_outerjoin_where_x2_t2(self): @@ -1261,16 +1561,28 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - t2.c.name == 't2 #20', - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + t2.c.name == "t2 #20", + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], t2.c.t2_id < 29, - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) def test_outerjoin_where_x2_t3(self): @@ -1279,16 +1591,28 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - t3.c.name == 't3 #30', - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + t3.c.name == "t3 #30", + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], t3.c.t3_id < 39, - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). 
- outerjoin(t3, criteria))]) + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) def test_outerjoin_where_x2_t1t3(self): @@ -1297,16 +1621,28 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - and_(t1.c.name == 't1 #10', t3.c.name == 't3 #30'), - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + and_(t1.c.name == "t1 #10", t3.c.name == "t3 #30"), + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], and_(t1.c.t1_id < 19, t3.c.t3_id < 39), - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) def test_outerjoin_where_x2_t1t2(self): @@ -1315,16 +1651,28 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'), - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + and_(t1.c.name == "t1 #10", t2.c.name == "t2 #20"), + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], and_(t1.c.t1_id < 12, t2.c.t2_id < 39), - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) def test_outerjoin_where_x2_t1t2t3(self): @@ -1333,19 +1681,32 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - and_(t1.c.name == 't1 #10', - t2.c.name == 't2 #20', - t3.c.name == 't3 #30'), - from_obj=[(t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). - outerjoin(t3, criteria))]) + and_( + t1.c.name == "t1 #10", + t2.c.name == "t2 #20", + t3.c.name == "t3 #30", + ), + from_obj=[ + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], and_(t1.c.t1_id < 19, t2.c.t2_id < 29, t3.c.t3_id < 39), from_obj=[ - (t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id). 
- outerjoin(t3, criteria))]) + ( + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria + ) + ) + ], + ) self.assertRows(expr, [(10, 20, 30)]) def test_mixed(self): @@ -1354,7 +1715,8 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) print(expr) self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) @@ -1364,40 +1726,48 @@ class JoinTest(fixtures.TestBase): for criteria in (t2.c.t2_id == t3.c.t2_id, t3.c.t2_id == t2.c.t2_id): expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - t1.c.name == 't1 #10', - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + t1.c.name == "t1 #10", + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - t2.c.name == 't2 #20', - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + t2.c.name == "t2 #20", + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - t3.c.name == 't3 #30', - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + t3.c.name == "t3 #30", + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'), - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + and_(t1.c.name == "t1 #10", t2.c.name == "t2 #20"), + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - and_(t2.c.name == 't2 #20', t3.c.name == 't3 #30'), - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + and_(t2.c.name == "t2 #20", t3.c.name == "t3 #30"), + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) self.assertRows(expr, [(10, 20, 30)]) expr = select( [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], - and_(t1.c.name == 't1 #10', - t2.c.name == 't2 #20', - t3.c.name == 't3 #30'), - from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + and_( + t1.c.name == "t1 #10", + t2.c.name == "t2 #20", + t3.c.name == "t3 #30", + ), + from_obj=[(t1.join(t2).outerjoin(t3, criteria))], + ) self.assertRows(expr, [(10, 20, 30)]) @@ -1412,19 +1782,22 @@ class OperatorTest(fixtures.TestBase): global metadata, flds metadata = MetaData(testing.db) flds = Table( - 'flds', metadata, + "flds", + metadata, Column( - 'idcol', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('intcol', Integer), - Column('strcol', String(50)), + "idcol", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("intcol", Integer), + Column("strcol", String(50)), ) metadata.create_all() - flds.insert().execute([ - dict(intcol=5, strcol='foo'), - dict(intcol=13, strcol='bar') - ]) + flds.insert().execute( + [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")] + ) @classmethod def teardown_class(cls): @@ -1433,16 +1806,19 @@ class OperatorTest(fixtures.TestBase): # TODO: seems like more tests warranted for this setup. 
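The two tests that follow exercise the SQL modulo operator and a ROW_NUMBER window function. As a hedged illustration, separate from the suite and assuming a pared-down copy of the flds table, this sketch prints roughly what the default dialect compiles those expressions to.

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        func,
        select,
    )

    metadata = MetaData()
    flds = Table(
        "flds",
        metadata,
        Column("intcol", Integer),
        Column("strcol", String(50)),
    )

    # The Python "%" operator on a column renders as SQL modulo, roughly:
    #   SELECT flds.intcol % :intcol_1 AS anon_1 FROM flds
    print(select([flds.c.intcol % 3]))

    # func.row_number().over(...) renders a window function, roughly:
    #   SELECT row_number() OVER (ORDER BY flds.strcol) AS anon_1 FROM flds
    print(select([func.row_number().over(order_by=flds.c.strcol)]))
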
def test_modulo(self): eq_( - select([flds.c.intcol % 3], - order_by=flds.c.idcol).execute().fetchall(), - [(2,), (1,)] + select([flds.c.intcol % 3], order_by=flds.c.idcol) + .execute() + .fetchall(), + [(2,), (1,)], ) @testing.requires.window_functions def test_over(self): eq_( - select([ - flds.c.intcol, func.row_number().over(order_by=flds.c.strcol) - ]).execute().fetchall(), - [(13, 1), (5, 2)] + select( + [flds.c.intcol, func.row_number().over(order_by=flds.c.strcol)] + ) + .execute() + .fetchall(), + [(13, 1), (5, 2)], ) diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index a51e14244e..4cc185298e 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -1,5 +1,15 @@ -from sqlalchemy import MetaData, Table, Column, Integer, select, \ - ForeignKey, Index, CheckConstraint, inspect, column +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + select, + ForeignKey, + Index, + CheckConstraint, + inspect, + column, +) from sqlalchemy import sql, schema, types as sqltypes from sqlalchemy.sql import compiler from sqlalchemy.testing import fixtures, AssertsCompiledSQL, eq_ @@ -21,15 +31,21 @@ class QuoteExecTest(fixtures.TestBase): global table1, table2 metadata = MetaData(testing.db) - table1 = Table('WorstCase1', metadata, - Column('lowercase', Integer, primary_key=True), - Column('UPPERCASE', Integer), - Column('MixedCase', Integer), - Column('ASC', Integer, key='a123')) - table2 = Table('WorstCase2', metadata, - Column('desc', Integer, primary_key=True, key='d123'), - Column('Union', Integer, key='u123'), - Column('MixedCase', Integer)) + table1 = Table( + "WorstCase1", + metadata, + Column("lowercase", Integer, primary_key=True), + Column("UPPERCASE", Integer), + Column("MixedCase", Integer), + Column("ASC", Integer, key="a123"), + ) + table2 = Table( + "WorstCase2", + metadata, + Column("desc", Integer, primary_key=True, key="d123"), + Column("Union", Integer, key="u123"), + Column("MixedCase", Integer), + ) table1.create() table2.create() @@ -45,8 +61,8 @@ class QuoteExecTest(fixtures.TestBase): def test_reflect(self): meta2 = MetaData(testing.db) - t2 = Table('WorstCase1', meta2, autoload=True, quote=True) - assert 'lowercase' in t2.c + t2 = Table("WorstCase1", meta2, autoload=True, quote=True) + assert "lowercase" in t2.c # indicates the DB returns unquoted names as # UPPERCASE, which we then assume are unquoted and go to @@ -54,16 +70,16 @@ class QuoteExecTest(fixtures.TestBase): # names from a "name normalize" backend, as they cannot be # distinguished from case-insensitive/unquoted names. if testing.db.dialect.requires_name_normalize: - assert 'uppercase' in t2.c + assert "uppercase" in t2.c else: - assert 'UPPERCASE' in t2.c + assert "UPPERCASE" in t2.c # ASC OTOH is a reserved word, which is always quoted, so # with that name we keep the quotes on and it stays uppercase # regardless. Seems a little weird, though. 
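A brief aside on the branch above, before the assertions resume: which spelling reflection reports hinges on the dialect's requires_name_normalize flag, the same attribute the test consults. A tiny sketch, using an in-memory SQLite engine as an example of a backend that does not normalize names:

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")
    # "Name normalize" backends such as Oracle report unquoted names as
    # UPPERCASE, and SQLAlchemy folds them to lowercase on reflection;
    # SQLite stores names as given, so no folding happens.
    print(engine.dialect.requires_name_normalize)  # False
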
- assert 'ASC' in t2.c + assert "ASC" in t2.c - assert 'MixedCase' in t2.c + assert "MixedCase" in t2.c @testing.provide_metadata def test_has_table_case_sensitive(self): @@ -72,97 +88,101 @@ class QuoteExecTest(fixtures.TestBase): testing.db.execute("CREATE TABLE TAB1 (id INTEGER)") else: testing.db.execute("CREATE TABLE tab1 (id INTEGER)") - testing.db.execute('CREATE TABLE %s (id INTEGER)' % - preparer.quote_identifier("tab2")) - testing.db.execute('CREATE TABLE %s (id INTEGER)' % - preparer.quote_identifier("TAB3")) - testing.db.execute('CREATE TABLE %s (id INTEGER)' % - preparer.quote_identifier("TAB4")) - - t1 = Table('tab1', self.metadata, - Column('id', Integer, primary_key=True), - ) - t2 = Table('tab2', self.metadata, - Column('id', Integer, primary_key=True), - quote=True - ) - t3 = Table('TAB3', self.metadata, - Column('id', Integer, primary_key=True), - ) - t4 = Table('TAB4', self.metadata, - Column('id', Integer, primary_key=True), - quote=True) + testing.db.execute( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("tab2") + ) + testing.db.execute( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB3") + ) + testing.db.execute( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB4") + ) + + t1 = Table( + "tab1", self.metadata, Column("id", Integer, primary_key=True) + ) + t2 = Table( + "tab2", + self.metadata, + Column("id", Integer, primary_key=True), + quote=True, + ) + t3 = Table( + "TAB3", self.metadata, Column("id", Integer, primary_key=True) + ) + t4 = Table( + "TAB4", + self.metadata, + Column("id", Integer, primary_key=True), + quote=True, + ) insp = inspect(testing.db) assert testing.db.has_table(t1.name) - eq_([c['name'] for c in insp.get_columns(t1.name)], ['id']) + eq_([c["name"] for c in insp.get_columns(t1.name)], ["id"]) assert testing.db.has_table(t2.name) - eq_([c['name'] for c in insp.get_columns(t2.name)], ['id']) + eq_([c["name"] for c in insp.get_columns(t2.name)], ["id"]) assert testing.db.has_table(t3.name) - eq_([c['name'] for c in insp.get_columns(t3.name)], ['id']) + eq_([c["name"] for c in insp.get_columns(t3.name)], ["id"]) assert testing.db.has_table(t4.name) - eq_([c['name'] for c in insp.get_columns(t4.name)], ['id']) + eq_([c["name"] for c in insp.get_columns(t4.name)], ["id"]) def test_basic(self): table1.insert().execute( - {'lowercase': 1, 'UPPERCASE': 2, 'MixedCase': 3, 'a123': 4}, - {'lowercase': 2, 'UPPERCASE': 2, 'MixedCase': 3, 'a123': 4}, - {'lowercase': 4, 'UPPERCASE': 3, 'MixedCase': 2, 'a123': 1}) + {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + ) table2.insert().execute( - {'d123': 1, 'u123': 2, 'MixedCase': 3}, - {'d123': 2, 'u123': 2, 'MixedCase': 3}, - {'d123': 4, 'u123': 3, 'MixedCase': 2}) + {"d123": 1, "u123": 2, "MixedCase": 3}, + {"d123": 2, "u123": 2, "MixedCase": 3}, + {"d123": 4, "u123": 3, "MixedCase": 2}, + ) columns = [ table1.c.lowercase, table1.c.UPPERCASE, table1.c.MixedCase, - table1.c.a123 + table1.c.a123, ] result = select(columns).execute().fetchall() - assert(result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)]) + assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)] - columns = [ - table2.c.d123, - table2.c.u123, - table2.c.MixedCase - ] + columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase] result = select(columns).execute().fetchall() - assert(result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)]) + assert result == [(1, 2, 
3), (2, 2, 3), (4, 3, 2)] def test_use_labels(self): table1.insert().execute( - {'lowercase': 1, 'UPPERCASE': 2, 'MixedCase': 3, 'a123': 4}, - {'lowercase': 2, 'UPPERCASE': 2, 'MixedCase': 3, 'a123': 4}, - {'lowercase': 4, 'UPPERCASE': 3, 'MixedCase': 2, 'a123': 1}) + {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + ) table2.insert().execute( - {'d123': 1, 'u123': 2, 'MixedCase': 3}, - {'d123': 2, 'u123': 2, 'MixedCase': 3}, - {'d123': 4, 'u123': 3, 'MixedCase': 2}) + {"d123": 1, "u123": 2, "MixedCase": 3}, + {"d123": 2, "u123": 2, "MixedCase": 3}, + {"d123": 4, "u123": 3, "MixedCase": 2}, + ) columns = [ table1.c.lowercase, table1.c.UPPERCASE, table1.c.MixedCase, - table1.c.a123 + table1.c.a123, ] result = select(columns, use_labels=True).execute().fetchall() - assert(result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)]) + assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)] - columns = [ - table2.c.d123, - table2.c.u123, - table2.c.MixedCase - ] + columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase] result = select(columns, use_labels=True).execute().fetchall() - assert(result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)]) + assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)] class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @classmethod def setup_class(cls): @@ -173,17 +193,23 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): global table1, table2 metadata = MetaData(testing.db) - table1 = Table('WorstCase1', metadata, - Column('lowercase', Integer, primary_key=True), - Column('UPPERCASE', Integer), - Column('MixedCase', Integer), - Column('ASC', Integer, key='a123')) - table2 = Table('WorstCase2', metadata, - Column('desc', Integer, primary_key=True, key='d123'), - Column('Union', Integer, key='u123'), - Column('MixedCase', Integer)) - - @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') + table1 = Table( + "WorstCase1", + metadata, + Column("lowercase", Integer, primary_key=True), + Column("UPPERCASE", Integer), + Column("MixedCase", Integer), + Column("ASC", Integer, key="a123"), + ) + table2 = Table( + "WorstCase2", + metadata, + Column("desc", Integer, primary_key=True, key="d123"), + Column("Union", Integer, key="u123"), + Column("MixedCase", Integer), + ) + + @testing.crashes("oracle", "FIXME: unknown, verify not fails_on") @testing.requires.subqueries def test_labels(self): """test the quoting of labels. 
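Before the label-quoting assertions in the next hunk, here is a small sketch, separate from the commit and using only the public IdentifierPreparer API this file already exercises, of the rule those cases encode: legal, unreserved, all-lowercase names stay bare, and everything else is delimited.

    from sqlalchemy.engine import default
    from sqlalchemy.sql import compiler

    preparer = compiler.IdentifierPreparer(default.DefaultDialect())
    print(preparer.quote("lowercase"))  # lowercase   (unreserved, left bare)
    print(preparer.quote("MixedCase"))  # "MixedCase" (case-sensitive)
    print(preparer.quote("ASC"))        # "ASC"       (reserved word)
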
@@ -206,38 +232,41 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): """ self.assert_compile( - table1.select(distinct=True).alias('LaLa').select(), - 'SELECT ' + table1.select(distinct=True).alias("LaLa").select(), + "SELECT " '"LaLa".lowercase, ' '"LaLa"."UPPERCASE", ' '"LaLa"."MixedCase", ' '"LaLa"."ASC" ' - 'FROM (' - 'SELECT DISTINCT ' + "FROM (" + "SELECT DISTINCT " '"WorstCase1".lowercase AS lowercase, ' '"WorstCase1"."UPPERCASE" AS "UPPERCASE", ' '"WorstCase1"."MixedCase" AS "MixedCase", ' '"WorstCase1"."ASC" AS "ASC" ' 'FROM "WorstCase1"' - ') AS "LaLa"' + ') AS "LaLa"', ) def test_lower_case_names(self): # Create table with quote defaults metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer), - schema='foo') + t1 = Table("t1", metadata, Column("col1", Integer), schema="foo") # Note that the names are not quoted b/c they are all lower case - result = 'CREATE TABLE foo.t1 (col1 INTEGER)' + result = "CREATE TABLE foo.t1 (col1 INTEGER)" self.assert_compile(schema.CreateTable(t1), result) # Create the same table with quotes set to True now metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer, quote=True), - schema='foo', quote=True, quote_schema=True) + t1 = Table( + "t1", + metadata, + Column("col1", Integer, quote=True), + schema="foo", + quote=True, + quote_schema=True, + ) # Note that the names are now quoted result = 'CREATE TABLE "foo"."t1" ("col1" INTEGER)' @@ -246,9 +275,7 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): def test_upper_case_names(self): # Create table with quote defaults metadata = MetaData() - t1 = Table('TABLE1', metadata, - Column('COL1', Integer), - schema='FOO') + t1 = Table("TABLE1", metadata, Column("COL1", Integer), schema="FOO") # Note that the names are quoted b/c they are not all lower case result = 'CREATE TABLE "FOO"."TABLE1" ("COL1" INTEGER)' @@ -256,20 +283,23 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): # Create the same table with quotes set to False now metadata = MetaData() - t1 = Table('TABLE1', metadata, - Column('COL1', Integer, quote=False), - schema='FOO', quote=False, quote_schema=False) + t1 = Table( + "TABLE1", + metadata, + Column("COL1", Integer, quote=False), + schema="FOO", + quote=False, + quote_schema=False, + ) # Note that the names are now unquoted - result = 'CREATE TABLE FOO.TABLE1 (COL1 INTEGER)' + result = "CREATE TABLE FOO.TABLE1 (COL1 INTEGER)" self.assert_compile(schema.CreateTable(t1), result) def test_mixed_case_names(self): # Create table with quote defaults metadata = MetaData() - t1 = Table('Table1', metadata, - Column('Col1', Integer), - schema='Foo') + t1 = Table("Table1", metadata, Column("Col1", Integer), schema="Foo") # Note that the names are quoted b/c they are not all lower case result = 'CREATE TABLE "Foo"."Table1" ("Col1" INTEGER)' @@ -277,20 +307,25 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): # Create the same table with quotes set to False now metadata = MetaData() - t1 = Table('Table1', metadata, - Column('Col1', Integer, quote=False), - schema='Foo', quote=False, quote_schema=False) + t1 = Table( + "Table1", + metadata, + Column("Col1", Integer, quote=False), + schema="Foo", + quote=False, + quote_schema=False, + ) # Note that the names are now unquoted - result = 'CREATE TABLE Foo.Table1 (Col1 INTEGER)' + result = "CREATE TABLE Foo.Table1 (Col1 INTEGER)" self.assert_compile(schema.CreateTable(t1), result) def test_numeric_initial_char(self): # Create table with quote defaults metadata = MetaData() - t1 = 
Table('35table', metadata, - Column('25column', Integer), - schema='45schema') + t1 = Table( + "35table", metadata, Column("25column", Integer), schema="45schema" + ) # Note that the names are quoted b/c the initial # character is in ['$','0', '1' ... '9'] @@ -299,20 +334,25 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): # Create the same table with quotes set to False now metadata = MetaData() - t1 = Table('35table', metadata, - Column('25column', Integer, quote=False), - schema='45schema', quote=False, quote_schema=False) + t1 = Table( + "35table", + metadata, + Column("25column", Integer, quote=False), + schema="45schema", + quote=False, + quote_schema=False, + ) # Note that the names are now unquoted - result = 'CREATE TABLE 45schema.35table (25column INTEGER)' + result = "CREATE TABLE 45schema.35table (25column INTEGER)" self.assert_compile(schema.CreateTable(t1), result) def test_illegal_initial_char(self): # Create table with quote defaults metadata = MetaData() - t1 = Table('$table', metadata, - Column('$column', Integer), - schema='$schema') + t1 = Table( + "$table", metadata, Column("$column", Integer), schema="$schema" + ) # Note that the names are quoted b/c the initial # character is in ['$','0', '1' ... '9'] @@ -321,384 +361,400 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): # Create the same table with quotes set to False now metadata = MetaData() - t1 = Table('$table', metadata, - Column('$column', Integer, quote=False), - schema='$schema', quote=False, quote_schema=False) + t1 = Table( + "$table", + metadata, + Column("$column", Integer, quote=False), + schema="$schema", + quote=False, + quote_schema=False, + ) # Note that the names are now unquoted - result = 'CREATE TABLE $schema.$table ($column INTEGER)' + result = "CREATE TABLE $schema.$table ($column INTEGER)" self.assert_compile(schema.CreateTable(t1), result) def test_reserved_words(self): # Create table with quote defaults metadata = MetaData() - table = Table('foreign', metadata, - Column('col1', Integer), - Column('from', Integer), - Column('order', Integer), - schema='create') + table = Table( + "foreign", + metadata, + Column("col1", Integer), + Column("from", Integer), + Column("order", Integer), + schema="create", + ) # Note that the names are quoted b/c they are reserved words - x = select([table.c.col1, table.c['from'], table.c.order]) - self.assert_compile(x, - 'SELECT ' - '"create"."foreign".col1, ' - '"create"."foreign"."from", ' - '"create"."foreign"."order" ' - 'FROM "create"."foreign"' - ) + x = select([table.c.col1, table.c["from"], table.c.order]) + self.assert_compile( + x, + "SELECT " + '"create"."foreign".col1, ' + '"create"."foreign"."from", ' + '"create"."foreign"."order" ' + 'FROM "create"."foreign"', + ) # Create the same table with quotes set to False now metadata = MetaData() - table = Table('foreign', metadata, - Column('col1', Integer), - Column('from', Integer, quote=False), - Column('order', Integer, quote=False), - schema='create', quote=False, quote_schema=False) + table = Table( + "foreign", + metadata, + Column("col1", Integer), + Column("from", Integer, quote=False), + Column("order", Integer, quote=False), + schema="create", + quote=False, + quote_schema=False, + ) # Note that the names are now unquoted - x = select([table.c.col1, table.c['from'], table.c.order]) - self.assert_compile(x, - 'SELECT ' - 'create.foreign.col1, ' - 'create.foreign.from, ' - 'create.foreign.order ' - 'FROM create.foreign' - ) + x = select([table.c.col1, table.c["from"], 
table.c.order]) + self.assert_compile( + x, + "SELECT " + "create.foreign.col1, " + "create.foreign.from, " + "create.foreign.order " + "FROM create.foreign", + ) def test_subquery_one(self): # Lower case names, should not quote metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer), - schema='foo') - a = t1.select().alias('anon') + t1 = Table("t1", metadata, Column("col1", Integer), schema="foo") + a = t1.select().alias("anon") b = select([1], a.c.col1 == 2, from_obj=a) - self.assert_compile(b, - 'SELECT 1 ' - 'FROM (' - 'SELECT ' - 'foo.t1.col1 AS col1 ' - 'FROM ' - 'foo.t1' - ') AS anon ' - 'WHERE anon.col1 = :col1_1' - ) + self.assert_compile( + b, + "SELECT 1 " + "FROM (" + "SELECT " + "foo.t1.col1 AS col1 " + "FROM " + "foo.t1" + ") AS anon " + "WHERE anon.col1 = :col1_1", + ) def test_subquery_two(self): # Lower case names, quotes on, should quote metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer, quote=True), - schema='foo', quote=True, quote_schema=True) - a = t1.select().alias('anon') + t1 = Table( + "t1", + metadata, + Column("col1", Integer, quote=True), + schema="foo", + quote=True, + quote_schema=True, + ) + a = t1.select().alias("anon") b = select([1], a.c.col1 == 2, from_obj=a) - self.assert_compile(b, - 'SELECT 1 ' - 'FROM (' - 'SELECT ' - '"foo"."t1"."col1" AS "col1" ' - 'FROM ' - '"foo"."t1"' - ') AS anon ' - 'WHERE anon."col1" = :col1_1' - ) + self.assert_compile( + b, + "SELECT 1 " + "FROM (" + "SELECT " + '"foo"."t1"."col1" AS "col1" ' + "FROM " + '"foo"."t1"' + ") AS anon " + 'WHERE anon."col1" = :col1_1', + ) def test_subquery_three(self): # Not lower case names, should quote metadata = MetaData() - t1 = Table('T1', metadata, - Column('Col1', Integer), - schema='Foo') - a = t1.select().alias('Anon') + t1 = Table("T1", metadata, Column("Col1", Integer), schema="Foo") + a = t1.select().alias("Anon") b = select([1], a.c.Col1 == 2, from_obj=a) - self.assert_compile(b, - 'SELECT 1 ' - 'FROM (' - 'SELECT ' - '"Foo"."T1"."Col1" AS "Col1" ' - 'FROM ' - '"Foo"."T1"' - ') AS "Anon" ' - 'WHERE ' - '"Anon"."Col1" = :Col1_1' - ) + self.assert_compile( + b, + "SELECT 1 " + "FROM (" + "SELECT " + '"Foo"."T1"."Col1" AS "Col1" ' + "FROM " + '"Foo"."T1"' + ') AS "Anon" ' + "WHERE " + '"Anon"."Col1" = :Col1_1', + ) def test_subquery_four(self): # Not lower case names, quotes off, should not quote metadata = MetaData() - t1 = Table('T1', metadata, - Column('Col1', Integer, quote=False), - schema='Foo', quote=False, quote_schema=False) - a = t1.select().alias('Anon') + t1 = Table( + "T1", + metadata, + Column("Col1", Integer, quote=False), + schema="Foo", + quote=False, + quote_schema=False, + ) + a = t1.select().alias("Anon") b = select([1], a.c.Col1 == 2, from_obj=a) - self.assert_compile(b, - 'SELECT 1 ' - 'FROM (' - 'SELECT ' - 'Foo.T1.Col1 AS Col1 ' - 'FROM ' - 'Foo.T1' - ') AS "Anon" ' - 'WHERE ' - '"Anon".Col1 = :Col1_1' - ) + self.assert_compile( + b, + "SELECT 1 " + "FROM (" + "SELECT " + "Foo.T1.Col1 AS Col1 " + "FROM " + "Foo.T1" + ') AS "Anon" ' + "WHERE " + '"Anon".Col1 = :Col1_1', + ) def test_simple_order_by_label(self): m = MetaData() - t1 = Table('t1', m, Column('col1', Integer)) - cl = t1.c.col1.label('ShouldQuote') + t1 = Table("t1", m, Column("col1", Integer)) + cl = t1.c.col1.label("ShouldQuote") self.assert_compile( select([cl]).order_by(cl), - 'SELECT t1.col1 AS "ShouldQuote" FROM t1 ORDER BY "ShouldQuote"' + 'SELECT t1.col1 AS "ShouldQuote" FROM t1 ORDER BY "ShouldQuote"', ) def test_collate(self): - 
self.assert_compile( - column('foo').collate('utf8'), - "foo COLLATE utf8" - ) + self.assert_compile(column("foo").collate("utf8"), "foo COLLATE utf8") self.assert_compile( - column('foo').collate('fr_FR'), + column("foo").collate("fr_FR"), 'foo COLLATE "fr_FR"', - dialect="postgresql" + dialect="postgresql", ) self.assert_compile( - column('foo').collate('utf8_GERMAN_ci'), - 'foo COLLATE `utf8_GERMAN_ci`', - dialect="mysql" + column("foo").collate("utf8_GERMAN_ci"), + "foo COLLATE `utf8_GERMAN_ci`", + dialect="mysql", ) self.assert_compile( - column('foo').collate('SQL_Latin1_General_CP1_CI_AS'), - 'foo COLLATE SQL_Latin1_General_CP1_CI_AS', - dialect="mssql" + column("foo").collate("SQL_Latin1_General_CP1_CI_AS"), + "foo COLLATE SQL_Latin1_General_CP1_CI_AS", + dialect="mssql", ) def test_join(self): # Lower case names, should not quote metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer)) - t2 = Table('t2', metadata, - Column('col1', Integer), - Column('t1col1', Integer, ForeignKey('t1.col1'))) - self.assert_compile(t2.join(t1).select(), - 'SELECT ' - 't2.col1, t2.t1col1, t1.col1 ' - 'FROM ' - 't2 ' - 'JOIN ' - 't1 ON t1.col1 = t2.t1col1' - ) + t1 = Table("t1", metadata, Column("col1", Integer)) + t2 = Table( + "t2", + metadata, + Column("col1", Integer), + Column("t1col1", Integer, ForeignKey("t1.col1")), + ) + self.assert_compile( + t2.join(t1).select(), + "SELECT " + "t2.col1, t2.t1col1, t1.col1 " + "FROM " + "t2 " + "JOIN " + "t1 ON t1.col1 = t2.t1col1", + ) # Lower case names, quotes on, should quote metadata = MetaData() - t1 = Table('t1', metadata, - Column('col1', Integer, quote=True), - quote=True) + t1 = Table( + "t1", metadata, Column("col1", Integer, quote=True), quote=True + ) t2 = Table( - 't2', + "t2", metadata, - Column( - 'col1', - Integer, - quote=True), - Column( - 't1col1', - Integer, - ForeignKey('t1.col1'), - quote=True), - quote=True) - self.assert_compile(t2.join(t1).select(), - 'SELECT ' - '"t2"."col1", "t2"."t1col1", "t1"."col1" ' - 'FROM ' - '"t2" ' - 'JOIN ' - '"t1" ON "t1"."col1" = "t2"."t1col1"' - ) + Column("col1", Integer, quote=True), + Column("t1col1", Integer, ForeignKey("t1.col1"), quote=True), + quote=True, + ) + self.assert_compile( + t2.join(t1).select(), + "SELECT " + '"t2"."col1", "t2"."t1col1", "t1"."col1" ' + "FROM " + '"t2" ' + "JOIN " + '"t1" ON "t1"."col1" = "t2"."t1col1"', + ) # Not lower case names, should quote metadata = MetaData() - t1 = Table('T1', metadata, - Column('Col1', Integer)) - t2 = Table('T2', metadata, - Column('Col1', Integer), - Column('T1Col1', Integer, ForeignKey('T1.Col1'))) - self.assert_compile(t2.join(t1).select(), - 'SELECT ' - '"T2"."Col1", "T2"."T1Col1", "T1"."Col1" ' - 'FROM ' - '"T2" ' - 'JOIN ' - '"T1" ON "T1"."Col1" = "T2"."T1Col1"' - ) + t1 = Table("T1", metadata, Column("Col1", Integer)) + t2 = Table( + "T2", + metadata, + Column("Col1", Integer), + Column("T1Col1", Integer, ForeignKey("T1.Col1")), + ) + self.assert_compile( + t2.join(t1).select(), + "SELECT " + '"T2"."Col1", "T2"."T1Col1", "T1"."Col1" ' + "FROM " + '"T2" ' + "JOIN " + '"T1" ON "T1"."Col1" = "T2"."T1Col1"', + ) # Not lower case names, quotes off, should not quote metadata = MetaData() - t1 = Table('T1', metadata, - Column('Col1', Integer, quote=False), - quote=False) + t1 = Table( + "T1", metadata, Column("Col1", Integer, quote=False), quote=False + ) t2 = Table( - 'T2', + "T2", metadata, - Column( - 'Col1', - Integer, - quote=False), - Column( - 'T1Col1', - Integer, - ForeignKey('T1.Col1'), - quote=False), - 
quote=False) - self.assert_compile(t2.join(t1).select(), - 'SELECT ' - 'T2.Col1, T2.T1Col1, T1.Col1 ' - 'FROM ' - 'T2 ' - 'JOIN ' - 'T1 ON T1.Col1 = T2.T1Col1' - ) + Column("Col1", Integer, quote=False), + Column("T1Col1", Integer, ForeignKey("T1.Col1"), quote=False), + quote=False, + ) + self.assert_compile( + t2.join(t1).select(), + "SELECT " + "T2.Col1, T2.T1Col1, T1.Col1 " + "FROM " + "T2 " + "JOIN " + "T1 ON T1.Col1 = T2.T1Col1", + ) def test_label_and_alias(self): # Lower case names, should not quote metadata = MetaData() - table = Table('t1', metadata, - Column('col1', Integer)) - x = select([table.c.col1.label('label1')]).alias('alias1') - self.assert_compile(select([x.c.label1]), - 'SELECT ' - 'alias1.label1 ' - 'FROM (' - 'SELECT ' - 't1.col1 AS label1 ' - 'FROM t1' - ') AS alias1' - ) + table = Table("t1", metadata, Column("col1", Integer)) + x = select([table.c.col1.label("label1")]).alias("alias1") + self.assert_compile( + select([x.c.label1]), + "SELECT " + "alias1.label1 " + "FROM (" + "SELECT " + "t1.col1 AS label1 " + "FROM t1" + ") AS alias1", + ) # Not lower case names, should quote metadata = MetaData() - table = Table('T1', metadata, - Column('Col1', Integer)) - x = select([table.c.Col1.label('Label1')]).alias('Alias1') - self.assert_compile(select([x.c.Label1]), - 'SELECT ' - '"Alias1"."Label1" ' - 'FROM (' - 'SELECT ' - '"T1"."Col1" AS "Label1" ' - 'FROM "T1"' - ') AS "Alias1"' - ) + table = Table("T1", metadata, Column("Col1", Integer)) + x = select([table.c.Col1.label("Label1")]).alias("Alias1") + self.assert_compile( + select([x.c.Label1]), + "SELECT " + '"Alias1"."Label1" ' + "FROM (" + "SELECT " + '"T1"."Col1" AS "Label1" ' + 'FROM "T1"' + ') AS "Alias1"', + ) def test_literal_column_already_with_quotes(self): # Lower case names metadata = MetaData() - table = Table('t1', metadata, - Column('col1', Integer)) + table = Table("t1", metadata, Column("col1", Integer)) # Note that 'col1' is already quoted (literal_column) - columns = [sql.literal_column("'col1'").label('label1')] - x = select(columns, from_obj=[table]).alias('alias1') + columns = [sql.literal_column("'col1'").label("label1")] + x = select(columns, from_obj=[table]).alias("alias1") x = x.select() - self.assert_compile(x, - 'SELECT ' - 'alias1.label1 ' - 'FROM (' - 'SELECT ' - '\'col1\' AS label1 ' - 'FROM t1' - ') AS alias1' - ) + self.assert_compile( + x, + "SELECT " + "alias1.label1 " + "FROM (" + "SELECT " + "'col1' AS label1 " + "FROM t1" + ") AS alias1", + ) # Not lower case names metadata = MetaData() - table = Table('T1', metadata, - Column('Col1', Integer)) + table = Table("T1", metadata, Column("Col1", Integer)) # Note that 'Col1' is already quoted (literal_column) - columns = [sql.literal_column("'Col1'").label('Label1')] - x = select(columns, from_obj=[table]).alias('Alias1') + columns = [sql.literal_column("'Col1'").label("Label1")] + x = select(columns, from_obj=[table]).alias("Alias1") x = x.select() - self.assert_compile(x, - 'SELECT ' - '"Alias1"."Label1" ' - 'FROM (' - 'SELECT ' - '\'Col1\' AS "Label1" ' - 'FROM "T1"' - ') AS "Alias1"' - ) + self.assert_compile( + x, + "SELECT " + '"Alias1"."Label1" ' + "FROM (" + "SELECT " + "'Col1' AS \"Label1\" " + 'FROM "T1"' + ') AS "Alias1"', + ) def test_apply_labels_should_quote(self): # Not lower case names, should quote metadata = MetaData() - t1 = Table('T1', metadata, - Column('Col1', Integer), - schema='Foo') + t1 = Table("T1", metadata, Column("Col1", Integer), schema="Foo") - self.assert_compile(t1.select().apply_labels(), - 'SELECT ' - 
'"Foo"."T1"."Col1" AS "Foo_T1_Col1" ' - 'FROM ' - '"Foo"."T1"' - ) + self.assert_compile( + t1.select().apply_labels(), + "SELECT " + '"Foo"."T1"."Col1" AS "Foo_T1_Col1" ' + "FROM " + '"Foo"."T1"', + ) def test_apply_labels_shouldnt_quote(self): # Not lower case names, quotes off metadata = MetaData() - t1 = Table('T1', metadata, - Column('Col1', Integer, quote=False), - schema='Foo', quote=False, quote_schema=False) + t1 = Table( + "T1", + metadata, + Column("Col1", Integer, quote=False), + schema="Foo", + quote=False, + quote_schema=False, + ) # TODO: is this what we really want here ? # what if table/schema *are* quoted? - self.assert_compile(t1.select().apply_labels(), - 'SELECT ' - 'Foo.T1.Col1 AS Foo_T1_Col1 ' - 'FROM ' - 'Foo.T1' - ) + self.assert_compile( + t1.select().apply_labels(), + "SELECT " "Foo.T1.Col1 AS Foo_T1_Col1 " "FROM " "Foo.T1", + ) def test_quote_flag_propagate_check_constraint(self): m = MetaData() - t = Table('t', m, Column('x', Integer, quote=True)) + t = Table("t", m, Column("x", Integer, quote=True)) CheckConstraint(t.c.x > 5) self.assert_compile( schema.CreateTable(t), - "CREATE TABLE t (" - '"x" INTEGER, ' - 'CHECK ("x" > 5)' - ")" + "CREATE TABLE t (" '"x" INTEGER, ' 'CHECK ("x" > 5)' ")", ) def test_quote_flag_propagate_index(self): m = MetaData() - t = Table('t', m, Column('x', Integer, quote=True)) + t = Table("t", m, Column("x", Integer, quote=True)) idx = Index("foo", t.c.x) self.assert_compile( - schema.CreateIndex(idx), - 'CREATE INDEX foo ON t ("x")' + schema.CreateIndex(idx), 'CREATE INDEX foo ON t ("x")' ) def test_quote_flag_propagate_anon_label(self): m = MetaData() - t = Table('t', m, Column('x', Integer, quote=True)) + t = Table("t", m, Column("x", Integer, quote=True)) self.assert_compile( select([t.alias()]).apply_labels(), - 'SELECT t_1."x" AS "t_1_x" FROM t AS t_1' + 'SELECT t_1."x" AS "t_1_x" FROM t AS t_1', ) - t2 = Table('t2', m, Column('x', Integer), quote=True) + t2 = Table("t2", m, Column("x", Integer), quote=True) self.assert_compile( select([t2.c.x]).apply_labels(), - 'SELECT "t2".x AS "t2_x" FROM "t2"' + 'SELECT "t2".x AS "t2_x" FROM "t2"', ) @@ -716,28 +772,27 @@ class PreparerTest(fixtures.TestBase): print("Received %s" % have) self.assert_(have == want) - a_eq(unformat('foo'), ['foo']) - a_eq(unformat('"foo"'), ['foo']) + a_eq(unformat("foo"), ["foo"]) + a_eq(unformat('"foo"'), ["foo"]) a_eq(unformat("'foo'"), ["'foo'"]) - a_eq(unformat('foo.bar'), ['foo', 'bar']) - a_eq(unformat('"foo"."bar"'), ['foo', 'bar']) - a_eq(unformat('foo."bar"'), ['foo', 'bar']) - a_eq(unformat('"foo".bar'), ['foo', 'bar']) - a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz']) + a_eq(unformat("foo.bar"), ["foo", "bar"]) + a_eq(unformat('"foo"."bar"'), ["foo", "bar"]) + a_eq(unformat('foo."bar"'), ["foo", "bar"]) + a_eq(unformat('"foo".bar'), ["foo", "bar"]) + a_eq(unformat('"foo"."b""a""r"."baz"'), ["foo", 'b"a"r', "baz"]) def test_unformat_custom(self): - class Custom(compiler.IdentifierPreparer): - def __init__(self, dialect): super(Custom, self).__init__( - dialect, initial_quote='`', final_quote='`') + dialect, initial_quote="`", final_quote="`" + ) def _escape_identifier(self, value): - return value.replace('`', '``') + return value.replace("`", "``") def _unescape_identifier(self, value): - return value.replace('``', '`') + return value.replace("``", "`") prep = Custom(default.DefaultDialect()) unformat = prep.unformat_identifiers @@ -748,18 +803,17 @@ class PreparerTest(fixtures.TestBase): print("Received %s" % have) 
self.assert_(have == want) - a_eq(unformat('foo'), ['foo']) - a_eq(unformat('`foo`'), ['foo']) - a_eq(unformat(repr('foo')), ["'foo'"]) - a_eq(unformat('foo.bar'), ['foo', 'bar']) - a_eq(unformat('`foo`.`bar`'), ['foo', 'bar']) - a_eq(unformat('foo.`bar`'), ['foo', 'bar']) - a_eq(unformat('`foo`.bar'), ['foo', 'bar']) - a_eq(unformat('`foo`.`b``a``r`.`baz`'), ['foo', 'b`a`r', 'baz']) + a_eq(unformat("foo"), ["foo"]) + a_eq(unformat("`foo`"), ["foo"]) + a_eq(unformat(repr("foo")), ["'foo'"]) + a_eq(unformat("foo.bar"), ["foo", "bar"]) + a_eq(unformat("`foo`.`bar`"), ["foo", "bar"]) + a_eq(unformat("foo.`bar`"), ["foo", "bar"]) + a_eq(unformat("`foo`.bar"), ["foo", "bar"]) + a_eq(unformat("`foo`.`b``a``r`.`baz`"), ["foo", "b`a`r", "baz"]) class QuotedIdentTest(fixtures.TestBase): - def test_concat_quotetrue(self): q1 = quoted_name("x", True) self._assert_not_quoted("y" + q1) @@ -819,13 +873,13 @@ class QuotedIdentTest(fixtures.TestBase): def test_apply_map_quoted(self): q1 = _anonymous_label(quoted_name("x%s", True)) - q2 = q1.apply_map(('bar')) + q2 = q1.apply_map(("bar")) eq_(q2, "xbar") eq_(q2.quote, True) def test_apply_map_plain(self): q1 = _anonymous_label(quoted_name("x%s", None)) - q2 = q1.apply_map(('bar')) + q2 = q1.apply_map(("bar")) eq_(q2, "xbar") self._assert_not_quoted(q2) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 8e6279708e..c5e1efef60 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1,12 +1,36 @@ -from sqlalchemy.testing import eq_, assert_raises_message, assert_raises, \ - in_, not_in_, is_, ne_, le_ +from sqlalchemy.testing import ( + eq_, + assert_raises_message, + assert_raises, + in_, + not_in_, + is_, + ne_, + le_, +) from sqlalchemy import testing from sqlalchemy.testing import fixtures, engines from sqlalchemy import util from sqlalchemy import ( - exc, sql, func, select, String, Integer, MetaData, ForeignKey, - VARCHAR, INT, CHAR, text, type_coerce, literal_column, - TypeDecorator, table, column, literal) + exc, + sql, + func, + select, + String, + Integer, + MetaData, + ForeignKey, + VARCHAR, + INT, + CHAR, + text, + type_coerce, + literal_column, + TypeDecorator, + table, + column, + literal, +) from sqlalchemy.engine import result as _result from sqlalchemy.testing.schema import Table, Column import operator @@ -23,37 +47,43 @@ class ResultProxyTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table( - 'users', metadata, + "users", + metadata, Column( - 'user_id', INT, primary_key=True, - test_needs_autoincrement=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True + "user_id", INT, primary_key=True, test_needs_autoincrement=True + ), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, ) Table( - 'addresses', metadata, + "addresses", + metadata, Column( - 'address_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', Integer, ForeignKey('users.user_id')), - Column('address', String(30)), - test_needs_acid=True + "address_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("user_id", Integer, ForeignKey("users.user_id")), + Column("address", String(30)), + test_needs_acid=True, ) Table( - 'users2', metadata, - Column('user_id', INT, primary_key=True), - Column('user_name', VARCHAR(20)), - test_needs_acid=True + "users2", + metadata, + Column("user_id", INT, primary_key=True), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, ) def test_row_iteration(self): users = self.tables.users 
users.insert().execute( - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': 'fred'}, + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, ) r = users.select().execute() rows = [] @@ -65,15 +95,15 @@ class ResultProxyTest(fixtures.TablesTest): users = self.tables.users users.insert().execute( - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': 'fred'}, + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, ) r = users.select().execute() rows = [] while True: - row = next(r, 'foo') - if row == 'foo': + row = next(r, "foo") + if row == "foo": break rows.append(row) eq_(len(rows), 3) @@ -83,27 +113,30 @@ class ResultProxyTest(fixtures.TablesTest): users = self.tables.users users.insert().execute( - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': 'fred'}, + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, ) - sel = select([users.c.user_id]).where(users.c.user_name == 'jack'). \ - as_scalar() + sel = ( + select([users.c.user_id]) + .where(users.c.user_name == "jack") + .as_scalar() + ) for row in select([sel + 1, sel + 3], bind=users.bind).execute(): - eq_(row['anon_1'], 8) - eq_(row['anon_2'], 10) + eq_(row["anon_1"], 8) + eq_(row["anon_2"], 10) def test_row_comparison(self): users = self.tables.users - users.insert().execute(user_id=7, user_name='jack') + users.insert().execute(user_id=7, user_name="jack") rp = users.select().execute().first() eq_(rp, rp) - is_(not(rp != rp), True) + is_(not (rp != rp), True) - equal = (7, 'jack') + equal = (7, "jack") eq_(rp, equal) eq_(equal, rp) @@ -113,15 +146,20 @@ class ResultProxyTest(fixtures.TablesTest): def endless(): while True: yield 1 + ne_(rp, endless()) ne_(endless(), rp) # test that everything compares the same # as it would against a tuple - for compare in [False, 8, endless(), 'xyz', (7, 'jack')]: + for compare in [False, 8, endless(), "xyz", (7, "jack")]: for op in [ - operator.eq, operator.ne, operator.gt, - operator.lt, operator.ge, operator.le + operator.eq, + operator.ne, + operator.gt, + operator.lt, + operator.ge, + operator.le, ]: try: @@ -142,94 +180,94 @@ class ResultProxyTest(fixtures.TablesTest): @testing.provide_metadata def test_column_label_overlap_fallback(self): - content = Table( - 'content', self.metadata, - Column('type', String(30)), - ) - bar = Table( - 'bar', self.metadata, - Column('content_type', String(30)) - ) + content = Table("content", self.metadata, Column("type", String(30))) + bar = Table("bar", self.metadata, Column("content_type", String(30))) self.metadata.create_all(testing.db) testing.db.execute(content.insert().values(type="t1")) row = testing.db.execute(content.select(use_labels=True)).first() in_(content.c.type, row) not_in_(bar.c.content_type, row) - in_(sql.column('content_type'), row) + in_(sql.column("content_type"), row) row = testing.db.execute( - select([content.c.type.label("content_type")])).first() + select([content.c.type.label("content_type")]) + ).first() in_(content.c.type, row) not_in_(bar.c.content_type, row) - in_(sql.column('content_type'), row) + in_(sql.column("content_type"), row) - row = testing.db.execute(select([func.now().label("content_type")])). 
\ - first() + row = testing.db.execute( + select([func.now().label("content_type")]) + ).first() not_in_(content.c.type, row) not_in_(bar.c.content_type, row) - in_(sql.column('content_type'), row) + in_(sql.column("content_type"), row) def test_pickled_rows(self): users = self.tables.users addresses = self.tables.addresses users.insert().execute( - {'user_id': 7, 'user_name': 'jack'}, - {'user_id': 8, 'user_name': 'ed'}, - {'user_id': 9, 'user_name': 'fred'}, + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, ) for pickle in False, True: for use_labels in False, True: - result = users.select(use_labels=use_labels).order_by( - users.c.user_id).execute().fetchall() + result = ( + users.select(use_labels=use_labels) + .order_by(users.c.user_id) + .execute() + .fetchall() + ) if pickle: result = util.pickle.loads(util.pickle.dumps(result)) - eq_( - result, - [(7, "jack"), (8, "ed"), (9, "fred")] - ) + eq_(result, [(7, "jack"), (8, "ed"), (9, "fred")]) if use_labels: - eq_(result[0]['users_user_id'], 7) + eq_(result[0]["users_user_id"], 7) eq_( list(result[0].keys()), - ["users_user_id", "users_user_name"]) + ["users_user_id", "users_user_name"], + ) else: - eq_(result[0]['user_id'], 7) + eq_(result[0]["user_id"], 7) eq_(list(result[0].keys()), ["user_id", "user_name"]) eq_(result[0][0], 7) eq_(result[0][users.c.user_id], 7) - eq_(result[0][users.c.user_name], 'jack') + eq_(result[0][users.c.user_name], "jack") if not pickle or use_labels: assert_raises( exc.NoSuchColumnError, - lambda: result[0][addresses.c.user_id]) + lambda: result[0][addresses.c.user_id], + ) else: # test with a different table. name resolution is # causing 'user_id' to match when use_labels wasn't used. eq_(result[0][addresses.c.user_id], 7) assert_raises( - exc.NoSuchColumnError, lambda: result[0]['fake key']) + exc.NoSuchColumnError, lambda: result[0]["fake key"] + ) assert_raises( exc.NoSuchColumnError, - lambda: result[0][addresses.c.address_id]) + lambda: result[0][addresses.c.address_id], + ) def test_column_error_printing(self): result = testing.db.execute(select([1])) row = result.first() class unprintable(object): - def __str__(self): raise ValueError("nope") @@ -242,25 +280,21 @@ class ResultProxyTest(fixtures.TablesTest): (unprintable(), "unprintable element.*"), ]: assert_raises_message( - exc.NoSuchColumnError, - msg % repl, - result._getter, accessor + exc.NoSuchColumnError, msg % repl, result._getter, accessor ) is_(result._getter(accessor, False), None) assert_raises_message( - exc.NoSuchColumnError, - msg % repl, - lambda: row[accessor] + exc.NoSuchColumnError, msg % repl, lambda: row[accessor] ) def test_fetchmany(self): users = self.tables.users - users.insert().execute(user_id=7, user_name='jack') - users.insert().execute(user_id=8, user_name='ed') - users.insert().execute(user_id=9, user_name='fred') + users.insert().execute(user_id=7, user_name="jack") + users.insert().execute(user_id=8, user_name="ed") + users.insert().execute(user_id=9, user_name="fred") r = users.select().execute() rows = [] for row in r.fetchmany(size=2): @@ -271,82 +305,80 @@ class ResultProxyTest(fixtures.TablesTest): users = self.tables.users addresses = self.tables.addresses - users.insert().execute(user_id=1, user_name='john') - users.insert().execute(user_id=2, user_name='jack') + users.insert().execute(user_id=1, user_name="john") + users.insert().execute(user_id=2, user_name="jack") addresses.insert().execute( - address_id=1, user_id=2, address='foo@bar.com') 
+ address_id=1, user_id=2, address="foo@bar.com" + ) - r = text( - "select * from addresses", bind=testing.db).execute().first() + r = text("select * from addresses", bind=testing.db).execute().first() eq_(r[0:1], (1,)) - eq_(r[1:], (2, 'foo@bar.com')) + eq_(r[1:], (2, "foo@bar.com")) eq_(r[:-1], (1, 2)) def test_column_accessor_basic_compiled(self): users = self.tables.users users.insert().execute( - dict(user_id=1, user_name='john'), - dict(user_id=2, user_name='jack') + dict(user_id=1, user_name="john"), + dict(user_id=2, user_name="jack"), ) r = users.select(users.c.user_id == 2).execute().first() eq_(r.user_id, 2) - eq_(r['user_id'], 2) + eq_(r["user_id"], 2) eq_(r[users.c.user_id], 2) - eq_(r.user_name, 'jack') - eq_(r['user_name'], 'jack') - eq_(r[users.c.user_name], 'jack') + eq_(r.user_name, "jack") + eq_(r["user_name"], "jack") + eq_(r[users.c.user_name], "jack") def test_column_accessor_basic_text(self): users = self.tables.users users.insert().execute( - dict(user_id=1, user_name='john'), - dict(user_id=2, user_name='jack') + dict(user_id=1, user_name="john"), + dict(user_id=2, user_name="jack"), ) r = testing.db.execute( - text("select * from users where user_id=2")).first() + text("select * from users where user_id=2") + ).first() eq_(r.user_id, 2) - eq_(r['user_id'], 2) + eq_(r["user_id"], 2) eq_(r[users.c.user_id], 2) - eq_(r.user_name, 'jack') - eq_(r['user_name'], 'jack') - eq_(r[users.c.user_name], 'jack') + eq_(r.user_name, "jack") + eq_(r["user_name"], "jack") + eq_(r[users.c.user_name], "jack") def test_column_accessor_textual_select(self): users = self.tables.users users.insert().execute( - dict(user_id=1, user_name='john'), - dict(user_id=2, user_name='jack') + dict(user_id=1, user_name="john"), + dict(user_id=2, user_name="jack"), ) # this will create column() objects inside # the select(), these need to match on name anyway r = testing.db.execute( - select([ - column('user_id'), column('user_name') - ]).select_from(table('users')). - where(text('user_id=2')) + select([column("user_id"), column("user_name")]) + .select_from(table("users")) + .where(text("user_id=2")) ).first() eq_(r.user_id, 2) - eq_(r['user_id'], 2) + eq_(r["user_id"], 2) eq_(r[users.c.user_id], 2) - eq_(r.user_name, 'jack') - eq_(r['user_name'], 'jack') - eq_(r[users.c.user_name], 'jack') + eq_(r.user_name, "jack") + eq_(r["user_name"], "jack") + eq_(r[users.c.user_name], "jack") def test_column_accessor_dotted_union(self): users = self.tables.users - users.insert().execute( - dict(user_id=1, user_name='john'), - ) + users.insert().execute(dict(user_id=1, user_name="john")) # test a little sqlite < 3.10.0 weirdness - with the UNION, # cols come back as "users.user_id" in cursor.description @@ -358,106 +390,120 @@ class ResultProxyTest(fixtures.TablesTest): "users.user_name from users" ) ).first() - eq_(r['user_id'], 1) - eq_(r['user_name'], "john") + eq_(r["user_id"], 1) + eq_(r["user_name"], "john") eq_(list(r.keys()), ["user_id", "user_name"]) def test_column_accessor_sqlite_raw(self): users = self.tables.users - users.insert().execute( - dict(user_id=1, user_name='john'), - ) + users.insert().execute(dict(user_id=1, user_name="john")) - r = text( - "select users.user_id, users.user_name " - "from users " - "UNION select users.user_id, " - "users.user_name from users", - bind=testing.db).execution_options(sqlite_raw_colnames=True). 
\ - execute().first() + r = ( + text( + "select users.user_id, users.user_name " + "from users " + "UNION select users.user_id, " + "users.user_name from users", + bind=testing.db, + ) + .execution_options(sqlite_raw_colnames=True) + .execute() + .first() + ) if testing.against("sqlite < 3.10.0"): - not_in_('user_id', r) - not_in_('user_name', r) - eq_(r['users.user_id'], 1) - eq_(r['users.user_name'], "john") + not_in_("user_id", r) + not_in_("user_name", r) + eq_(r["users.user_id"], 1) + eq_(r["users.user_name"], "john") eq_(list(r.keys()), ["users.user_id", "users.user_name"]) else: - not_in_('users.user_id', r) - not_in_('users.user_name', r) - eq_(r['user_id'], 1) - eq_(r['user_name'], "john") + not_in_("users.user_id", r) + not_in_("users.user_name", r) + eq_(r["user_id"], 1) + eq_(r["user_name"], "john") eq_(list(r.keys()), ["user_id", "user_name"]) def test_column_accessor_sqlite_translated(self): users = self.tables.users - users.insert().execute( - dict(user_id=1, user_name='john'), - ) + users.insert().execute(dict(user_id=1, user_name="john")) - r = text( - "select users.user_id, users.user_name " - "from users " - "UNION select users.user_id, " - "users.user_name from users", - bind=testing.db).execute().first() - eq_(r['user_id'], 1) - eq_(r['user_name'], "john") + r = ( + text( + "select users.user_id, users.user_name " + "from users " + "UNION select users.user_id, " + "users.user_name from users", + bind=testing.db, + ) + .execute() + .first() + ) + eq_(r["user_id"], 1) + eq_(r["user_name"], "john") if testing.against("sqlite < 3.10.0"): - eq_(r['users.user_id'], 1) - eq_(r['users.user_name'], "john") + eq_(r["users.user_id"], 1) + eq_(r["users.user_name"], "john") else: - not_in_('users.user_id', r) - not_in_('users.user_name', r) + not_in_("users.user_id", r) + not_in_("users.user_name", r) eq_(list(r.keys()), ["user_id", "user_name"]) def test_column_accessor_labels_w_dots(self): users = self.tables.users - users.insert().execute( - dict(user_id=1, user_name='john'), - ) + users.insert().execute(dict(user_id=1, user_name="john")) # test using literal tablename.colname - r = text( - 'select users.user_id AS "users.user_id", ' - 'users.user_name AS "users.user_name" ' - 'from users', bind=testing.db).\ - execution_options(sqlite_raw_colnames=True).execute().first() - eq_(r['users.user_id'], 1) - eq_(r['users.user_name'], "john") + r = ( + text( + 'select users.user_id AS "users.user_id", ' + 'users.user_name AS "users.user_name" ' + "from users", + bind=testing.db, + ) + .execution_options(sqlite_raw_colnames=True) + .execute() + .first() + ) + eq_(r["users.user_id"], 1) + eq_(r["users.user_name"], "john") not_in_("user_name", r) eq_(list(r.keys()), ["users.user_id", "users.user_name"]) def test_column_accessor_unary(self): users = self.tables.users - users.insert().execute( - dict(user_id=1, user_name='john'), - ) + users.insert().execute(dict(user_id=1, user_name="john")) # unary expressions - r = select([users.c.user_name.distinct()]).order_by( - users.c.user_name).execute().first() - eq_(r[users.c.user_name], 'john') - eq_(r.user_name, 'john') + r = ( + select([users.c.user_name.distinct()]) + .order_by(users.c.user_name) + .execute() + .first() + ) + eq_(r[users.c.user_name], "john") + eq_(r.user_name, "john") def test_column_accessor_err(self): r = testing.db.execute(select([1])).first() assert_raises_message( AttributeError, "Could not locate column in row for column 'foo'", - getattr, r, "foo" + getattr, + r, + "foo", ) assert_raises_message( KeyError, "Could not 
locate column in row for column 'foo'", - lambda: r['foo'] + lambda: r["foo"], ) def test_graceful_fetch_on_non_rows(self): @@ -481,8 +527,8 @@ class ResultProxyTest(fixtures.TablesTest): lambda r: r.first(), lambda r: r.scalar(), lambda r: r.fetchmany(), - lambda r: r._getter('user'), - lambda r: r._has_key('user'), + lambda r: r._getter("user"), + lambda r: r._has_key("user"), ]: trans = conn.begin() result = conn.execute(users.insert(), user_id=1) @@ -490,7 +536,8 @@ class ResultProxyTest(fixtures.TablesTest): exc.ResourceClosedError, "This result object does not return rows. " "It has been closed automatically.", - meth, result, + meth, + result, ) trans.rollback() @@ -503,19 +550,17 @@ class ResultProxyTest(fixtures.TablesTest): assert_raises_message( exc.ResourceClosedError, "This result object is closed.", - result.fetchone + result.fetchone, ) def test_connectionless_autoclose_rows_exhausted(self): users = self.tables.users - users.insert().execute( - dict(user_id=1, user_name='john'), - ) + users.insert().execute(dict(user_id=1, user_name="john")) result = testing.db.execute("select * from users") connection = result.connection assert not connection.closed - eq_(result.fetchone(), (1, 'john')) + eq_(result.fetchone(), (1, "john")) assert not connection.closed eq_(result.fetchone(), None) assert connection.closed @@ -523,12 +568,15 @@ class ResultProxyTest(fixtures.TablesTest): @testing.requires.returning def test_connectionless_autoclose_crud_rows_exhausted(self): users = self.tables.users - stmt = users.insert().values(user_id=1, user_name='john').\ - returning(users.c.user_id) + stmt = ( + users.insert() + .values(user_id=1, user_name="john") + .returning(users.c.user_id) + ) result = testing.db.execute(stmt) connection = result.connection assert not connection.closed - eq_(result.fetchone(), (1, )) + eq_(result.fetchone(), (1,)) assert not connection.closed eq_(result.fetchone(), None) assert connection.closed @@ -548,15 +596,17 @@ class ResultProxyTest(fixtures.TablesTest): assert_raises_message( exc.ResourceClosedError, "This result object does not return rows.", - result.fetchone + result.fetchone, ) def test_row_case_sensitive(self): row = testing.db.execute( - select([ - literal_column("1").label("case_insensitive"), - literal_column("2").label("CaseSensitive") - ]) + select( + [ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive"), + ] + ) ).first() eq_(list(row.keys()), ["case_insensitive", "CaseSensitive"]) @@ -568,28 +618,25 @@ class ResultProxyTest(fixtures.TablesTest): eq_(row["case_insensitive"], 1) eq_(row["CaseSensitive"], 2) - assert_raises( - KeyError, - lambda: row["Case_insensitive"] - ) - assert_raises( - KeyError, - lambda: row["casesensitive"] - ) + assert_raises(KeyError, lambda: row["Case_insensitive"]) + assert_raises(KeyError, lambda: row["casesensitive"]) def test_row_case_sensitive_unoptimized(self): ins_db = engines.testing_engine(options={"case_sensitive": True}) row = ins_db.execute( - select([ - literal_column("1").label("case_insensitive"), - literal_column("2").label("CaseSensitive"), - text("3 AS screw_up_the_cols") - ]) + select( + [ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive"), + text("3 AS screw_up_the_cols"), + ] + ) ).first() eq_( list(row.keys()), - ["case_insensitive", "CaseSensitive", "screw_up_the_cols"]) + ["case_insensitive", "CaseSensitive", "screw_up_the_cols"], + ) in_("case_insensitive", row._keymap) in_("CaseSensitive", row._keymap) @@ 
-606,10 +653,12 @@ class ResultProxyTest(fixtures.TablesTest): def test_row_case_insensitive(self): ins_db = engines.testing_engine(options={"case_sensitive": False}) row = ins_db.execute( - select([ - literal_column("1").label("case_insensitive"), - literal_column("2").label("CaseSensitive") - ]) + select( + [ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive"), + ] + ) ).first() eq_(list(row.keys()), ["case_insensitive", "CaseSensitive"]) @@ -626,16 +675,19 @@ class ResultProxyTest(fixtures.TablesTest): def test_row_case_insensitive_unoptimized(self): ins_db = engines.testing_engine(options={"case_sensitive": False}) row = ins_db.execute( - select([ - literal_column("1").label("case_insensitive"), - literal_column("2").label("CaseSensitive"), - text("3 AS screw_up_the_cols") - ]) + select( + [ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive"), + text("3 AS screw_up_the_cols"), + ] + ) ).first() eq_( list(row.keys()), - ["case_insensitive", "CaseSensitive", "screw_up_the_cols"]) + ["case_insensitive", "CaseSensitive", "screw_up_the_cols"], + ) in_("case_insensitive", row._keymap) in_("CaseSensitive", row._keymap) @@ -651,24 +703,27 @@ class ResultProxyTest(fixtures.TablesTest): def test_row_as_args(self): users = self.tables.users - users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=1, user_name="john") r = users.select(users.c.user_id == 1).execute().first() users.delete().execute() users.insert().execute(r) - eq_(users.select().execute().fetchall(), [(1, 'john')]) + eq_(users.select().execute().fetchall(), [(1, "john")]) def test_result_as_args(self): users = self.tables.users users2 = self.tables.users2 - users.insert().execute([ - dict(user_id=1, user_name='john'), - dict(user_id=2, user_name='ed')]) + users.insert().execute( + [ + dict(user_id=1, user_name="john"), + dict(user_id=2, user_name="ed"), + ] + ) r = users.select().execute() users2.insert().execute(list(r)) eq_( users2.select().order_by(users2.c.user_id).execute().fetchall(), - [(1, 'john'), (2, 'ed')] + [(1, "john"), (2, "ed")], ) users2.delete().execute() @@ -676,7 +731,7 @@ class ResultProxyTest(fixtures.TablesTest): users2.insert().execute(*list(r)) eq_( users2.select().order_by(users2.c.user_id).execute().fetchall(), - [(1, 'john'), (2, 'ed')] + [(1, "john"), (2, "ed")], ) @testing.requires.duplicate_names_in_cursor_description @@ -684,14 +739,14 @@ class ResultProxyTest(fixtures.TablesTest): users = self.tables.users addresses = self.tables.addresses - users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=1, user_name="john") result = users.outerjoin(addresses).select().execute() r = result.first() assert_raises_message( exc.InvalidRequestError, "Ambiguous column name", - lambda: r['user_id'] + lambda: r["user_id"], ) # pure positional targeting; users.c.user_id @@ -702,18 +757,18 @@ class ResultProxyTest(fixtures.TablesTest): # try to trick it - fake_table isn't in the result! 
# we get the correct error - fake_table = Table('fake', MetaData(), Column('user_id', Integer)) + fake_table = Table("fake", MetaData(), Column("user_id", Integer)) assert_raises_message( exc.InvalidRequestError, "Could not locate column in row for column 'fake.user_id'", - lambda: r[fake_table.c.user_id] + lambda: r[fake_table.c.user_id], ) r = util.pickle.loads(util.pickle.dumps(r)) assert_raises_message( exc.InvalidRequestError, "Ambiguous column name", - lambda: r['user_id'] + lambda: r["user_id"], ) result = users.outerjoin(addresses).select().execute() @@ -723,14 +778,14 @@ class ResultProxyTest(fixtures.TablesTest): assert_raises_message( exc.InvalidRequestError, "Ambiguous column name", - lambda: r['user_id'] + lambda: r["user_id"], ) @testing.requires.duplicate_names_in_cursor_description def test_ambiguous_column_by_col(self): users = self.tables.users - users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=1, user_name="john") ua = users.alias() u2 = users.alias() result = select([users.c.user_id, ua.c.user_id]).execute() @@ -747,22 +802,26 @@ class ResultProxyTest(fixtures.TablesTest): assert_raises_message( exc.InvalidRequestError, "Could not locate column in row", - lambda: row[u2.c.user_id] + lambda: row[u2.c.user_id], ) @testing.requires.duplicate_names_in_cursor_description def test_ambiguous_column_case_sensitive(self): eng = engines.testing_engine(options=dict(case_sensitive=False)) - row = eng.execute(select([ - literal_column('1').label('SOMECOL'), - literal_column('1').label('SOMECOL'), - ])).first() + row = eng.execute( + select( + [ + literal_column("1").label("SOMECOL"), + literal_column("1").label("SOMECOL"), + ] + ) + ).first() assert_raises_message( exc.InvalidRequestError, "Ambiguous column name", - lambda: row['somecol'] + lambda: row["somecol"], ) @testing.requires.duplicate_names_in_cursor_description @@ -773,47 +832,45 @@ class ResultProxyTest(fixtures.TablesTest): # ticket 2702. in 0.7 we'd get True, False. # in 0.8, both columns are present so it's True; # but when they're fetched you'll get the ambiguous error. 
- users.insert().execute(user_id=1, user_name='john') - result = select([users.c.user_id, addresses.c.user_id]).\ - select_from(users.outerjoin(addresses)).execute() + users.insert().execute(user_id=1, user_name="john") + result = ( + select([users.c.user_id, addresses.c.user_id]) + .select_from(users.outerjoin(addresses)) + .execute() + ) row = result.first() eq_( set([users.c.user_id in row, addresses.c.user_id in row]), - set([True]) + set([True]), ) def test_ambiguous_column_by_col_plus_label(self): users = self.tables.users - users.insert().execute(user_id=1, user_name='john') + users.insert().execute(user_id=1, user_name="john") result = select( - [users.c.user_id, - type_coerce(users.c.user_id, Integer).label('foo')]).execute() + [ + users.c.user_id, + type_coerce(users.c.user_id, Integer).label("foo"), + ] + ).execute() row = result.first() - eq_( - row[users.c.user_id], 1 - ) - eq_( - row[1], 1 - ) + eq_(row[users.c.user_id], 1) + eq_(row[1], 1) def test_fetch_partial_result_map(self): users = self.tables.users - users.insert().execute(user_id=7, user_name='ed') + users.insert().execute(user_id=7, user_name="ed") - t = text("select * from users").columns( - user_name=String() - ) - eq_( - testing.db.execute(t).fetchall(), [(7, 'ed')] - ) + t = text("select * from users").columns(user_name=String()) + eq_(testing.db.execute(t).fetchall(), [(7, "ed")]) def test_fetch_unordered_result_map(self): users = self.tables.users - users.insert().execute(user_id=7, user_name='ed') + users.insert().execute(user_id=7, user_name="ed") class Goofy1(TypeDecorator): impl = String @@ -835,166 +892,155 @@ class ResultProxyTest(fixtures.TablesTest): t = text( "select user_name as a, user_name as b, " - "user_name as c from users").columns( - a=Goofy1(), b=Goofy2(), c=Goofy3() - ) - eq_( - testing.db.execute(t).fetchall(), [ - ('eda', 'edb', 'edc') - ] - ) + "user_name as c from users" + ).columns(a=Goofy1(), b=Goofy2(), c=Goofy3()) + eq_(testing.db.execute(t).fetchall(), [("eda", "edb", "edc")]) @testing.requires.subqueries def test_column_label_targeting(self): users = self.tables.users - users.insert().execute(user_id=7, user_name='ed') + users.insert().execute(user_id=7, user_name="ed") for s in ( - users.select().alias('foo'), + users.select().alias("foo"), users.select().alias(users.name), ): row = s.select(use_labels=True).execute().first() eq_(row[s.c.user_id], 7) - eq_(row[s.c.user_name], 'ed') + eq_(row[s.c.user_name], "ed") def test_keys(self): users = self.tables.users - users.insert().execute(user_id=1, user_name='foo') + users.insert().execute(user_id=1, user_name="foo") result = users.select().execute() - eq_( - result.keys(), - ['user_id', 'user_name'] - ) + eq_(result.keys(), ["user_id", "user_name"]) row = result.first() - eq_( - row.keys(), - ['user_id', 'user_name'] - ) + eq_(row.keys(), ["user_id", "user_name"]) def test_keys_anon_labels(self): """test [ticket:3483]""" users = self.tables.users - users.insert().execute(user_id=1, user_name='foo') + users.insert().execute(user_id=1, user_name="foo") result = testing.db.execute( - select([ - users.c.user_id, - users.c.user_name.label(None), - func.count(literal_column('1'))]). 
- group_by(users.c.user_id, users.c.user_name) + select( + [ + users.c.user_id, + users.c.user_name.label(None), + func.count(literal_column("1")), + ] + ).group_by(users.c.user_id, users.c.user_name) ) - eq_( - result.keys(), - ['user_id', 'user_name_1', 'count_1'] - ) + eq_(result.keys(), ["user_id", "user_name_1", "count_1"]) row = result.first() - eq_( - row.keys(), - ['user_id', 'user_name_1', 'count_1'] - ) + eq_(row.keys(), ["user_id", "user_name_1", "count_1"]) def test_items(self): users = self.tables.users - users.insert().execute(user_id=1, user_name='foo') + users.insert().execute(user_id=1, user_name="foo") r = users.select().execute().first() eq_( [(x[0].lower(), x[1]) for x in list(r.items())], - [('user_id', 1), ('user_name', 'foo')]) + [("user_id", 1), ("user_name", "foo")], + ) def test_len(self): users = self.tables.users - users.insert().execute(user_id=1, user_name='foo') + users.insert().execute(user_id=1, user_name="foo") r = users.select().execute().first() eq_(len(r), 2) - r = testing.db.execute('select user_name, user_id from users'). \ - first() + r = testing.db.execute("select user_name, user_id from users").first() eq_(len(r), 2) - r = testing.db.execute('select user_name from users').first() + r = testing.db.execute("select user_name from users").first() eq_(len(r), 1) def test_sorting_in_python(self): users = self.tables.users users.insert().execute( - dict(user_id=1, user_name='foo'), - dict(user_id=2, user_name='bar'), - dict(user_id=3, user_name='def'), + dict(user_id=1, user_name="foo"), + dict(user_id=2, user_name="bar"), + dict(user_id=3, user_name="def"), ) rows = users.select().order_by(users.c.user_name).execute().fetchall() - eq_(rows, [(2, 'bar'), (3, 'def'), (1, 'foo')]) + eq_(rows, [(2, "bar"), (3, "def"), (1, "foo")]) - eq_(sorted(rows), [(1, 'foo'), (2, 'bar'), (3, 'def')]) + eq_(sorted(rows), [(1, "foo"), (2, "bar"), (3, "def")]) def test_column_order_with_simple_query(self): # should return values in column definition order users = self.tables.users - users.insert().execute(user_id=1, user_name='foo') + users.insert().execute(user_id=1, user_name="foo") r = users.select(users.c.user_id == 1).execute().first() eq_(r[0], 1) - eq_(r[1], 'foo') - eq_([x.lower() for x in list(r.keys())], ['user_id', 'user_name']) - eq_(list(r.values()), [1, 'foo']) + eq_(r[1], "foo") + eq_([x.lower() for x in list(r.keys())], ["user_id", "user_name"]) + eq_(list(r.values()), [1, "foo"]) def test_column_order_with_text_query(self): # should return values in query order users = self.tables.users - users.insert().execute(user_id=1, user_name='foo') - r = testing.db.execute('select user_name, user_id from users'). 
\ - first() - eq_(r[0], 'foo') + users.insert().execute(user_id=1, user_name="foo") + r = testing.db.execute("select user_name, user_id from users").first() + eq_(r[0], "foo") eq_(r[1], 1) - eq_([x.lower() for x in list(r.keys())], ['user_name', 'user_id']) - eq_(list(r.values()), ['foo', 1]) + eq_([x.lower() for x in list(r.keys())], ["user_name", "user_id"]) + eq_(list(r.values()), ["foo", 1]) - @testing.crashes('oracle', 'FIXME: unknown, varify not fails_on()') - @testing.crashes('firebird', 'An identifier must begin with a letter') + @testing.crashes("oracle", "FIXME: unknown, varify not fails_on()") + @testing.crashes("firebird", "An identifier must begin with a letter") @testing.provide_metadata def test_column_accessor_shadow(self): shadowed = Table( - 'test_shadowed', self.metadata, - Column('shadow_id', INT, primary_key=True), - Column('shadow_name', VARCHAR(20)), - Column('parent', VARCHAR(20)), - Column('row', VARCHAR(40)), - Column('_parent', VARCHAR(20)), - Column('_row', VARCHAR(20)), + "test_shadowed", + self.metadata, + Column("shadow_id", INT, primary_key=True), + Column("shadow_name", VARCHAR(20)), + Column("parent", VARCHAR(20)), + Column("row", VARCHAR(40)), + Column("_parent", VARCHAR(20)), + Column("_row", VARCHAR(20)), ) self.metadata.create_all() shadowed.insert().execute( - shadow_id=1, shadow_name='The Shadow', parent='The Light', - row='Without light there is no shadow', - _parent='Hidden parent', _row='Hidden row') + shadow_id=1, + shadow_name="The Shadow", + parent="The Light", + row="Without light there is no shadow", + _parent="Hidden parent", + _row="Hidden row", + ) r = shadowed.select(shadowed.c.shadow_id == 1).execute().first() eq_(r.shadow_id, 1) - eq_(r['shadow_id'], 1) + eq_(r["shadow_id"], 1) eq_(r[shadowed.c.shadow_id], 1) - eq_(r.shadow_name, 'The Shadow') - eq_(r['shadow_name'], 'The Shadow') - eq_(r[shadowed.c.shadow_name], 'The Shadow') + eq_(r.shadow_name, "The Shadow") + eq_(r["shadow_name"], "The Shadow") + eq_(r[shadowed.c.shadow_name], "The Shadow") - eq_(r.parent, 'The Light') - eq_(r['parent'], 'The Light') - eq_(r[shadowed.c.parent], 'The Light') + eq_(r.parent, "The Light") + eq_(r["parent"], "The Light") + eq_(r[shadowed.c.parent], "The Light") - eq_(r.row, 'Without light there is no shadow') - eq_(r['row'], 'Without light there is no shadow') - eq_(r[shadowed.c.row], 'Without light there is no shadow') + eq_(r.row, "Without light there is no shadow") + eq_(r["row"], "Without light there is no shadow") + eq_(r[shadowed.c.row], "Without light there is no shadow") - eq_(r['_parent'], 'Hidden parent') - eq_(r['_row'], 'Hidden row') + eq_(r["_parent"], "Hidden parent") + eq_(r["_row"], "Hidden row") def test_nontuple_row(self): """ensure the C version of BaseRowProxy handles @@ -1003,7 +1049,6 @@ class ResultProxyTest(fixtures.TablesTest): from sqlalchemy.engine import RowProxy class MyList(object): - def __init__(self, data): self.internal_list = data @@ -1013,11 +1058,15 @@ class ResultProxyTest(fixtures.TablesTest): def __getitem__(self, i): return list.__getitem__(self.internal_list, i) - proxy = RowProxy(object(), MyList(['value']), [None], { - 'key': (None, None, 0), 0: (None, None, 0)}) - eq_(list(proxy), ['value']) - eq_(proxy[0], 'value') - eq_(proxy['key'], 'value') + proxy = RowProxy( + object(), + MyList(["value"]), + [None], + {"key": (None, None, 0), 0: (None, None, 0)}, + ) + eq_(list(proxy), ["value"]) + eq_(proxy[0], "value") + eq_(proxy["key"], "value") @testing.provide_metadata def 
test_no_rowcount_on_selects_inserts(self): @@ -1033,28 +1082,26 @@ class ResultProxyTest(fixtures.TablesTest): engine = engines.testing_engine() - t = Table('t1', metadata, - Column('data', String(10)) - ) + t = Table("t1", metadata, Column("data", String(10))) metadata.create_all(engine) with patch.object( - engine.dialect.execution_ctx_cls, "rowcount") as mock_rowcount: + engine.dialect.execution_ctx_cls, "rowcount" + ) as mock_rowcount: mock_rowcount.__get__ = Mock() - engine.execute(t.insert(), - {'data': 'd1'}, - {'data': 'd2'}, - {'data': 'd3'}) + engine.execute( + t.insert(), {"data": "d1"}, {"data": "d2"}, {"data": "d3"} + ) eq_(len(mock_rowcount.__get__.mock_calls), 0) eq_( engine.execute(t.select()).fetchall(), - [('d1', ), ('d2', ), ('d3', )] + [("d1",), ("d2",), ("d3",)], ) eq_(len(mock_rowcount.__get__.mock_calls), 0) - engine.execute(t.update(), {'data': 'd4'}) + engine.execute(t.update(), {"data": "d4"}) eq_(len(mock_rowcount.__get__.mock_calls), 1) @@ -1066,58 +1113,66 @@ class ResultProxyTest(fixtures.TablesTest): from sqlalchemy.engine import RowProxy row = RowProxy( - object(), ['value'], [None], - {'key': (None, None, 0), 0: (None, None, 0)}) + object(), + ["value"], + [None], + {"key": (None, None, 0), 0: (None, None, 0)}, + ) assert isinstance(row, collections_abc.Sequence) @testing.provide_metadata def test_rowproxy_getitem_indexes_compiled(self): - values = Table('rp', self.metadata, - Column('key', String(10), primary_key=True), - Column('value', String(10))) + values = Table( + "rp", + self.metadata, + Column("key", String(10), primary_key=True), + Column("value", String(10)), + ) values.create() - testing.db.execute(values.insert(), dict(key='One', value='Uno')) + testing.db.execute(values.insert(), dict(key="One", value="Uno")) row = testing.db.execute(values.select()).first() - eq_(row['key'], 'One') - eq_(row['value'], 'Uno') - eq_(row[0], 'One') - eq_(row[1], 'Uno') - eq_(row[-2], 'One') - eq_(row[-1], 'Uno') - eq_(row[1:0:-1], ('Uno',)) + eq_(row["key"], "One") + eq_(row["value"], "Uno") + eq_(row[0], "One") + eq_(row[1], "Uno") + eq_(row[-2], "One") + eq_(row[-1], "Uno") + eq_(row[1:0:-1], ("Uno",)) @testing.only_on("sqlite") def test_rowproxy_getitem_indexes_raw(self): row = testing.db.execute("select 'One' as key, 'Uno' as value").first() - eq_(row['key'], 'One') - eq_(row['value'], 'Uno') - eq_(row[0], 'One') - eq_(row[1], 'Uno') - eq_(row[-2], 'One') - eq_(row[-1], 'Uno') - eq_(row[1:0:-1], ('Uno',)) + eq_(row["key"], "One") + eq_(row["value"], "Uno") + eq_(row[0], "One") + eq_(row[1], "Uno") + eq_(row[-2], "One") + eq_(row[-1], "Uno") + eq_(row[1:0:-1], ("Uno",)) @testing.requires.cextensions def test_row_c_sequence_check(self): import csv metadata = MetaData() - metadata.bind = 'sqlite://' - users = Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(40)), - ) + metadata.bind = "sqlite://" + users = Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(40)), + ) users.create() - users.insert().execute(name='Test') + users.insert().execute(name="Test") row = users.select().execute().fetchone() s = util.StringIO() writer = csv.writer(s) # csv performs PySequenceCheck call writer.writerow(row) - assert s.getvalue().strip() == '1,Test' + assert s.getvalue().strip() == "1,Test" @testing.requires.selectone def test_empty_accessors(self): @@ -1129,33 +1184,28 @@ class ResultProxyTest(fixtures.TablesTest): lambda r: r.last_updated_params(), lambda r: r.prefetch_cols(), lambda 
r: r.postfetch_cols(), - lambda r: r.inserted_primary_key + lambda r: r.inserted_primary_key, ], - "Statement is not a compiled expression construct." + "Statement is not a compiled expression construct.", ), ( select([1]), [ lambda r: r.last_inserted_params(), - lambda r: r.inserted_primary_key + lambda r: r.inserted_primary_key, ], - r"Statement is not an insert\(\) expression construct." + r"Statement is not an insert\(\) expression construct.", ), ( select([1]), - [ - lambda r: r.last_updated_params(), - ], - r"Statement is not an update\(\) expression construct." + [lambda r: r.last_updated_params()], + r"Statement is not an update\(\) expression construct.", ), ( select([1]), - [ - lambda r: r.prefetch_cols(), - lambda r: r.postfetch_cols() - ], + [lambda r: r.prefetch_cols(), lambda r: r.postfetch_cols()], r"Statement is not an insert\(\) " - r"or update\(\) expression construct." + r"or update\(\) expression construct.", ), ] @@ -1164,9 +1214,7 @@ class ResultProxyTest(fixtures.TablesTest): try: for meth in meths: assert_raises_message( - sa_exc.InvalidRequestError, - msg, - meth, r + sa_exc.InvalidRequestError, msg, meth, r ) finally: @@ -1174,28 +1222,31 @@ class ResultProxyTest(fixtures.TablesTest): class KeyTargetingTest(fixtures.TablesTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None __backend__ = True @classmethod def define_tables(cls, metadata): Table( - 'keyed1', metadata, Column("a", CHAR(2), key="b"), - Column("c", CHAR(2), key="q") + "keyed1", + metadata, + Column("a", CHAR(2), key="b"), + Column("c", CHAR(2), key="q"), ) - Table('keyed2', metadata, Column("a", CHAR(2)), Column("b", CHAR(2))) - Table('keyed3', metadata, Column("a", CHAR(2)), Column("d", CHAR(2))) - Table('keyed4', metadata, Column("b", CHAR(2)), Column("q", CHAR(2))) - Table('content', metadata, Column('t', String(30), key="type")) - Table('bar', metadata, Column('ctype', String(30), key="content_type")) + Table("keyed2", metadata, Column("a", CHAR(2)), Column("b", CHAR(2))) + Table("keyed3", metadata, Column("a", CHAR(2)), Column("d", CHAR(2))) + Table("keyed4", metadata, Column("b", CHAR(2)), Column("q", CHAR(2))) + Table("content", metadata, Column("t", String(30), key="type")) + Table("bar", metadata, Column("ctype", String(30), key="content_type")) if testing.requires.schemas.enabled: Table( - 'wschema', metadata, + "wschema", + metadata, Column("a", CHAR(2), key="b"), Column("c", CHAR(2), key="q"), - schema=testing.config.test_schema + schema=testing.config.test_schema, ) @classmethod @@ -1208,12 +1259,12 @@ class KeyTargetingTest(fixtures.TablesTest): if testing.requires.schemas.enabled: cls.tables[ - '%s.wschema' % testing.config.test_schema].insert().execute( - dict(b="a1", q="c1")) + "%s.wschema" % testing.config.test_schema + ].insert().execute(dict(b="a1", q="c1")) @testing.requires.schemas def test_keyed_accessor_wschema(self): - keyed1 = self.tables['%s.wschema' % testing.config.test_schema] + keyed1 = self.tables["%s.wschema" % testing.config.test_schema] row = testing.db.execute(keyed1.select()).first() eq_(row.b, "a1") @@ -1249,9 +1300,7 @@ class KeyTargetingTest(fixtures.TablesTest): eq_(row.b, "b2") # row.a is ambiguous assert_raises_message( - exc.InvalidRequestError, - "Ambig", - getattr, row, "a" + exc.InvalidRequestError, "Ambig", getattr, row, "a" ) def test_keyed_accessor_composite_names_precedent(self): @@ -1274,12 +1323,16 @@ class KeyTargetingTest(fixtures.TablesTest): assert_raises_message( exc.InvalidRequestError, "Ambiguous column name 'a'", - 
getattr, row, "b" + getattr, + row, + "b", ) assert_raises_message( exc.InvalidRequestError, "Ambiguous column name 'a'", - getattr, row, "a" + getattr, + row, + "a", ) eq_(row.d, "d3") @@ -1287,39 +1340,42 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed2 = self.tables.keyed2 - row = testing.db.execute(select([keyed1, keyed2]).apply_labels()). \ - first() + row = testing.db.execute( + select([keyed1, keyed2]).apply_labels() + ).first() eq_(row.keyed1_b, "a1") eq_(row.keyed1_a, "a1") eq_(row.keyed1_q, "c1") eq_(row.keyed1_c, "c1") eq_(row.keyed2_a, "a2") eq_(row.keyed2_b, "b2") - assert_raises(KeyError, lambda: row['keyed2_c']) - assert_raises(KeyError, lambda: row['keyed2_q']) + assert_raises(KeyError, lambda: row["keyed2_c"]) + assert_raises(KeyError, lambda: row["keyed2_q"]) def test_column_label_overlap_fallback(self): content, bar = self.tables.content, self.tables.bar row = testing.db.execute( - select([content.c.type.label("content_type")])).first() + select([content.c.type.label("content_type")]) + ).first() not_in_(content.c.type, row) not_in_(bar.c.content_type, row) - in_(sql.column('content_type'), row) + in_(sql.column("content_type"), row) - row = testing.db.execute(select([func.now().label("content_type")])). \ - first() + row = testing.db.execute( + select([func.now().label("content_type")]) + ).first() not_in_(content.c.type, row) not_in_(bar.c.content_type, row) - in_(sql.column('content_type'), row) + in_(sql.column("content_type"), row) def test_column_label_overlap_fallback_2(self): content, bar = self.tables.content, self.tables.bar row = testing.db.execute(content.select(use_labels=True)).first() in_(content.c.type, row) not_in_(bar.c.content_type, row) - not_in_(sql.column('content_type'), row) + not_in_(sql.column("content_type"), row) def test_columnclause_schema_column_one(self): keyed2 = self.tables.keyed2 @@ -1328,7 +1384,7 @@ class KeyTargetingTest(fixtures.TablesTest): # ColumnClause._compare_name_for_result allows the # columns which the statement is against to be lightweight # cols, which results in a more liberal comparison scheme - a, b = sql.column('a'), sql.column('b') + a, b = sql.column("a"), sql.column("b") stmt = select([a, b]).select_from(table("keyed2")) row = testing.db.execute(stmt).first() @@ -1340,7 +1396,7 @@ class KeyTargetingTest(fixtures.TablesTest): def test_columnclause_schema_column_two(self): keyed2 = self.tables.keyed2 - a, b = sql.column('a'), sql.column('b') + a, b = sql.column("a"), sql.column("b") stmt = select([keyed2.c.a, keyed2.c.b]) row = testing.db.execute(stmt).first() @@ -1354,7 +1410,7 @@ class KeyTargetingTest(fixtures.TablesTest): # this is also addressed by [ticket:2932] - a, b = sql.column('a'), sql.column('b') + a, b = sql.column("a"), sql.column("b") stmt = text("select a, b from keyed2").columns(a=CHAR, b=CHAR) row = testing.db.execute(stmt).first() @@ -1370,9 +1426,10 @@ class KeyTargetingTest(fixtures.TablesTest): # this is also addressed by [ticket:2932] - a, b = sql.column('keyed2_a'), sql.column('keyed2_b') + a, b = sql.column("keyed2_a"), sql.column("keyed2_b") stmt = text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns( - a, b) + a, b + ) row = testing.db.execute(stmt).first() in_(keyed2.c.a, row) @@ -1388,7 +1445,8 @@ class KeyTargetingTest(fixtures.TablesTest): # this is also addressed by [ticket:2932] stmt = text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns( - keyed2_a=CHAR, keyed2_b=CHAR) + keyed2_a=CHAR, keyed2_b=CHAR + ) row = 
testing.db.execute(stmt).first() in_(keyed2.c.a, row) @@ -1398,29 +1456,29 @@ class KeyTargetingTest(fixtures.TablesTest): class PositionalTextTest(fixtures.TablesTest): - run_inserts = 'once' + run_inserts = "once" run_deletes = None __backend__ = True @classmethod def define_tables(cls, metadata): Table( - 'text1', + "text1", metadata, Column("a", CHAR(2)), Column("b", CHAR(2)), Column("c", CHAR(2)), - Column("d", CHAR(2)) + Column("d", CHAR(2)), ) @classmethod def insert_data(cls): - cls.tables.text1.insert().execute([ - dict(a="a1", b="b1", c="c1", d="d1"), - ]) + cls.tables.text1.insert().execute( + [dict(a="a1", b="b1", c="c1", d="d1")] + ) def test_via_column(self): - c1, c2, c3, c4 = column('q'), column('p'), column('r'), column('d') + c1, c2, c3, c4 = column("q"), column("p"), column("r"), column("d") stmt = text("select a, b, c, d from text1").columns(c1, c2, c3, c4) result = testing.db.execute(stmt) @@ -1435,7 +1493,7 @@ class PositionalTextTest(fixtures.TablesTest): eq_(row["d"], "d1") def test_fewer_cols_than_sql_positional(self): - c1, c2 = column('q'), column('p') + c1, c2 = column("q"), column("p") stmt = text("select a, b, c, d from text1").columns(c1, c2) # no warning as this can be similar for non-positional @@ -1446,7 +1504,7 @@ class PositionalTextTest(fixtures.TablesTest): eq_(row["c"], "c1") def test_fewer_cols_than_sql_non_positional(self): - c1, c2 = column('a'), column('p') + c1, c2 = column("a"), column("p") stmt = text("select a, b, c, d from text1").columns(c2, c1, d=CHAR) # no warning as this can be similar for non-positional @@ -1459,38 +1517,36 @@ class PositionalTextTest(fixtures.TablesTest): # c2 name does not match, doesn't locate assert_raises_message( - exc.NoSuchColumnError, - "in row for column 'p'", - lambda: row[c2] + exc.NoSuchColumnError, "in row for column 'p'", lambda: row[c2] ) def test_more_cols_than_sql(self): - c1, c2, c3, c4 = column('q'), column('p'), column('r'), column('d') + c1, c2, c3, c4 = column("q"), column("p"), column("r"), column("d") stmt = text("select a, b from text1").columns(c1, c2, c3, c4) with assertions.expect_warnings( - r"Number of columns in textual SQL \(4\) is " - r"smaller than number of columns requested \(2\)"): + r"Number of columns in textual SQL \(4\) is " + r"smaller than number of columns requested \(2\)" + ): result = testing.db.execute(stmt) row = result.first() eq_(row[c2], "b1") assert_raises_message( - exc.NoSuchColumnError, - "in row for column 'r'", - lambda: row[c3] + exc.NoSuchColumnError, "in row for column 'r'", lambda: row[c3] ) def test_dupe_col_obj(self): - c1, c2, c3 = column('q'), column('p'), column('r') + c1, c2, c3 = column("q"), column("p"), column("r") stmt = text("select a, b, c, d from text1").columns(c1, c2, c3, c2) assert_raises_message( exc.InvalidRequestError, "Duplicate column expression requested in " "textual SQL: <.*.ColumnClause.*; p>", - testing.db.execute, stmt + testing.db.execute, + stmt, ) def test_anon_aliased_unique(self): @@ -1523,7 +1579,7 @@ class PositionalTextTest(fixtures.TablesTest): assert_raises_message( exc.NoSuchColumnError, "Could not locate column in row for column 'text1.b'", - lambda: row[text1.c.b] + lambda: row[text1.c.b], ) def test_anon_aliased_overlapping(self): @@ -1558,7 +1614,8 @@ class PositionalTextTest(fixtures.TablesTest): # all cols are named "a". if we are positional, we don't care. 
# this is new logic in 1.1 stmt = text("select a, b as a, c as a, d as a from text1").columns( - c1, c2, c3, c4) + c1, c2, c3, c4 + ) result = testing.db.execute(stmt) row = result.first() @@ -1572,42 +1629,45 @@ class PositionalTextTest(fixtures.TablesTest): assert_raises_message( exc.NoSuchColumnError, "Could not locate column in row for column 'text1.a'", - lambda: row[text1.c.a] + lambda: row[text1.c.a], ) class AlternateResultProxyTest(fixtures.TablesTest): - __requires__ = ('sqlite', ) + __requires__ = ("sqlite",) @classmethod def setup_bind(cls): - cls.engine = engine = engines.testing_engine('sqlite://') + cls.engine = engine = engines.testing_engine("sqlite://") return engine @classmethod def define_tables(cls, metadata): Table( - 'test', metadata, - Column('x', Integer, primary_key=True), - Column('y', String(50, convert_unicode='force')) + "test", + metadata, + Column("x", Integer, primary_key=True), + Column("y", String(50, convert_unicode="force")), ) @classmethod def insert_data(cls): - cls.engine.execute(cls.tables.test.insert(), [ - {'x': i, 'y': "t_%d" % i} for i in range(1, 12) - ]) + cls.engine.execute( + cls.tables.test.insert(), + [{"x": i, "y": "t_%d" % i} for i in range(1, 12)], + ) @contextmanager def _proxy_fixture(self, cls): self.table = self.tables.test class ExcCtx(default.DefaultExecutionContext): - def get_result_proxy(self): return cls(self) + self.patcher = patch.object( - self.engine.dialect, "execution_ctx_cls", ExcCtx) + self.engine.dialect, "execution_ctx_cls", ExcCtx + ) with self.patcher: yield @@ -1664,21 +1724,15 @@ class AlternateResultProxyTest(fixtures.TablesTest): def _assert_result_closed(self, r): assert_raises_message( - sa_exc.ResourceClosedError, - "object is closed", - r.fetchone + sa_exc.ResourceClosedError, "object is closed", r.fetchone ) assert_raises_message( - sa_exc.ResourceClosedError, - "object is closed", - r.fetchmany, 2 + sa_exc.ResourceClosedError, "object is closed", r.fetchmany, 2 ) assert_raises_message( - sa_exc.ResourceClosedError, - "object is closed", - r.fetchall + sa_exc.ResourceClosedError, "object is closed", r.fetchall ) def test_basic_plain(self): @@ -1738,35 +1792,28 @@ class AlternateResultProxyTest(fixtures.TablesTest): def test_buffered_row_growth(self): with self._proxy_fixture(_result.BufferedRowResultProxy): with self.engine.connect() as conn: - conn.execute(self.table.insert(), [ - {'x': i, 'y': "t_%d" % i} for i in range(15, 1200) - ]) + conn.execute( + self.table.insert(), + [{"x": i, "y": "t_%d" % i} for i in range(15, 1200)], + ) result = conn.execute(self.table.select()) - checks = { - 0: 5, 1: 10, 9: 20, 135: 250, 274: 500, - 1351: 1000 - } + checks = {0: 5, 1: 10, 9: 20, 135: 250, 274: 500, 1351: 1000} for idx, row in enumerate(result, 0): if idx in checks: eq_(result._bufsize, checks[idx]) - le_( - len(result._BufferedRowResultProxy__rowbuffer), - 1000 - ) + le_(len(result._BufferedRowResultProxy__rowbuffer), 1000) def test_max_row_buffer_option(self): with self._proxy_fixture(_result.BufferedRowResultProxy): with self.engine.connect() as conn: - conn.execute(self.table.insert(), [ - {'x': i, 'y': "t_%d" % i} for i in range(15, 1200) - ]) + conn.execute( + self.table.insert(), + [{"x": i, "y": "t_%d" % i} for i in range(15, 1200)], + ) result = conn.execution_options(max_row_buffer=27).execute( self.table.select() ) for idx, row in enumerate(result, 0): if idx in (16, 70, 150, 250): eq_(result._bufsize, 27) - le_( - len(result._BufferedRowResultProxy__rowbuffer), - 27 - ) + 
le_(len(result._BufferedRowResultProxy__rowbuffer), 27) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index f8d183b714..e298a2d552 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -2,18 +2,29 @@ from sqlalchemy.testing import eq_ from sqlalchemy import testing from sqlalchemy.testing.schema import Table, Column from sqlalchemy.types import TypeDecorator -from sqlalchemy.testing import fixtures, AssertsExecutionResults, engines, \ - assert_raises_message +from sqlalchemy.testing import ( + fixtures, + AssertsExecutionResults, + engines, + assert_raises_message, +) from sqlalchemy import exc as sa_exc -from sqlalchemy import MetaData, String, Integer, Boolean, func, select, \ - Sequence +from sqlalchemy import ( + MetaData, + String, + Integer, + Boolean, + func, + select, + Sequence, +) import itertools table = GoofyType = seq = None class ReturningTest(fixtures.TestBase, AssertsExecutionResults): - __requires__ = 'returning', + __requires__ = ("returning",) __backend__ = True def setup(self): @@ -34,106 +45,134 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): return value + "BAR" table = Table( - 'tables', meta, + "tables", + meta, Column( - 'id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('persons', Integer), - Column('full', Boolean), - Column('goofy', GoofyType(50))) + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("persons", Integer), + Column("full", Boolean), + Column("goofy", GoofyType(50)), + ) table.create(checkfirst=True) def teardown(self): table.drop() def test_column_targeting(self): - result = table.insert().returning( - table.c.id, table.c.full).execute({'persons': 1, 'full': False}) + result = ( + table.insert() + .returning(table.c.id, table.c.full) + .execute({"persons": 1, "full": False}) + ) row = result.first() - assert row[table.c.id] == row['id'] == 1 - assert row[table.c.full] == row['full'] - assert row['full'] is False - - result = table.insert().values( - persons=5, full=True, goofy="somegoofy").\ - returning(table.c.persons, table.c.full, table.c.goofy).execute() + assert row[table.c.id] == row["id"] == 1 + assert row[table.c.full] == row["full"] + assert row["full"] is False + + result = ( + table.insert() + .values(persons=5, full=True, goofy="somegoofy") + .returning(table.c.persons, table.c.full, table.c.goofy) + .execute() + ) row = result.first() - assert row[table.c.persons] == row['persons'] == 5 - assert row[table.c.full] == row['full'] + assert row[table.c.persons] == row["persons"] == 5 + assert row[table.c.full] == row["full"] - eq_(row[table.c.goofy], row['goofy']) - eq_(row['goofy'], "FOOsomegoofyBAR") + eq_(row[table.c.goofy], row["goofy"]) + eq_(row["goofy"], "FOOsomegoofyBAR") - @testing.fails_on('firebird', "fb can't handle returning x AS y") + @testing.fails_on("firebird", "fb can't handle returning x AS y") def test_labeling(self): - result = table.insert().values(persons=6).\ - returning(table.c.persons.label('lala')).execute() + result = ( + table.insert() + .values(persons=6) + .returning(table.c.persons.label("lala")) + .execute() + ) row = result.first() - assert row['lala'] == 6 + assert row["lala"] == 6 @testing.fails_on( - 'firebird', - "fb/kintersbasdb can't handle the bind params") - @testing.fails_on('oracle+zxjdbc', "JDBC driver bug") + "firebird", "fb/kintersbasdb can't handle the bind params" + ) + @testing.fails_on("oracle+zxjdbc", "JDBC driver bug") def test_anon_expressions(self): - result = 
table.insert().values(goofy="someOTHERgoofy").\ - returning(func.lower(table.c.goofy, type_=GoofyType)).execute() + result = ( + table.insert() + .values(goofy="someOTHERgoofy") + .returning(func.lower(table.c.goofy, type_=GoofyType)) + .execute() + ) row = result.first() eq_(row[0], "foosomeothergoofyBAR") - result = table.insert().values(persons=12).\ - returning(table.c.persons + 18).execute() + result = ( + table.insert() + .values(persons=12) + .returning(table.c.persons + 18) + .execute() + ) row = result.first() eq_(row[0], 30) def test_update_returning(self): table.insert().execute( - [{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + [{"persons": 5, "full": False}, {"persons": 3, "full": False}] + ) - result = table.update( - table.c.persons > 4, dict( - full=True)).returning( - table.c.id).execute() + result = ( + table.update(table.c.persons > 4, dict(full=True)) + .returning(table.c.id) + .execute() + ) eq_(result.fetchall(), [(1,)]) - result2 = select([table.c.id, table.c.full]).order_by( - table.c.id).execute() + result2 = ( + select([table.c.id, table.c.full]).order_by(table.c.id).execute() + ) eq_(result2.fetchall(), [(1, True), (2, False)]) def test_insert_returning(self): - result = table.insert().returning( - table.c.id).execute({'persons': 1, 'full': False}) + result = ( + table.insert() + .returning(table.c.id) + .execute({"persons": 1, "full": False}) + ) eq_(result.fetchall(), [(1,)]) @testing.requires.multivalues_inserts def test_multirow_returning(self): - ins = table.insert().returning(table.c.id, table.c.persons).values( - [ - {'persons': 1, 'full': False}, - {'persons': 2, 'full': True}, - {'persons': 3, 'full': False}, - ] + ins = ( + table.insert() + .returning(table.c.id, table.c.persons) + .values( + [ + {"persons": 1, "full": False}, + {"persons": 2, "full": True}, + {"persons": 3, "full": False}, + ] + ) ) result = testing.db.execute(ins) - eq_( - result.fetchall(), - [(1, 1), (2, 2), (3, 3)] - ) + eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) def test_no_ipk_on_returning(self): result = testing.db.execute( - table.insert().returning(table.c.id), - {'persons': 1, 'full': False} + table.insert().returning(table.c.id), {"persons": 1, "full": False} ) assert_raises_message( sa_exc.InvalidRequestError, r"Can't call inserted_primary_key when returning\(\) is used.", - getattr, result, "inserted_primary_key" + getattr, + result, + "inserted_primary_key", ) - @testing.fails_on_everything_except('postgresql', 'firebird') + @testing.fails_on_everything_except("postgresql", "firebird") def test_literal_returning(self): if testing.against("postgresql"): literal_true = "true" @@ -142,26 +181,28 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): result4 = testing.db.execute( 'insert into tables (id, persons, "full") ' - 'values (5, 10, %s) returning persons' % - literal_true) - eq_([dict(row) for row in result4], [{'persons': 10}]) + "values (5, 10, %s) returning persons" % literal_true + ) + eq_([dict(row) for row in result4], [{"persons": 10}]) def test_delete_returning(self): table.insert().execute( - [{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + [{"persons": 5, "full": False}, {"persons": 3, "full": False}] + ) - result = table.delete( - table.c.persons > 4).returning( - table.c.id).execute() + result = ( + table.delete(table.c.persons > 4).returning(table.c.id).execute() + ) eq_(result.fetchall(), [(1,)]) - result2 = select([table.c.id, table.c.full]).order_by( - table.c.id).execute() - eq_(result2.fetchall(), 
[(2, False), ]) + result2 = ( + select([table.c.id, table.c.full]).order_by(table.c.id).execute() + ) + eq_(result2.fetchall(), [(2, False)]) class CompositeStatementTest(fixtures.TestBase): - __requires__ = 'returning', + __requires__ = ("returning",) __backend__ = True @testing.provide_metadata @@ -172,47 +213,46 @@ class CompositeStatementTest(fixtures.TestBase): def process_result_value(self, value, dialect): raise Exception("I have not been selected") - t1 = Table( - 't1', self.metadata, - Column('x', MyType()) - ) + t1 = Table("t1", self.metadata, Column("x", MyType())) - t2 = Table( - 't2', self.metadata, - Column('x', Integer) - ) + t2 = Table("t2", self.metadata, Column("x", Integer)) self.metadata.create_all(testing.db) with testing.db.connect() as conn: conn.execute(t1.insert().values(x=5)) - stmt = t2.insert().values( - x=select([t1.c.x]).as_scalar()).returning(t2.c.x) + stmt = ( + t2.insert() + .values(x=select([t1.c.x]).as_scalar()) + .returning(t2.c.x) + ) result = conn.execute(stmt) eq_(result.scalar(), 5) class SequenceReturningTest(fixtures.TestBase): - __requires__ = 'returning', 'sequences' + __requires__ = "returning", "sequences" __backend__ = True def setup(self): meta = MetaData(testing.db) global table, seq - seq = Sequence('tid_seq') - table = Table('tables', meta, - Column('id', Integer, seq, primary_key=True), - Column('data', String(50)) - ) + seq = Sequence("tid_seq") + table = Table( + "tables", + meta, + Column("id", Integer, seq, primary_key=True), + Column("data", String(50)), + ) table.create(checkfirst=True) def teardown(self): table.drop() def test_insert(self): - r = table.insert().values(data='hi').returning(table.c.id).execute() - assert r.first() == (1, ) + r = table.insert().values(data="hi").returning(table.c.id).execute() + assert r.first() == (1,) assert seq.execute() == 2 @@ -220,7 +260,7 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults): """test returning() works with columns that define 'key'.""" - __requires__ = 'returning', + __requires__ = ("returning",) __backend__ = True def setup(self): @@ -228,39 +268,38 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults): global table table = Table( - 'tables', + "tables", meta, Column( - 'id', + "id", Integer, primary_key=True, - key='foo_id', - test_needs_autoincrement=True), - Column( - 'data', - String(20)), + key="foo_id", + test_needs_autoincrement=True, + ), + Column("data", String(20)), ) table.create(checkfirst=True) def teardown(self): table.drop() - @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') - @testing.exclude('postgresql', '<', (8, 2), '8.2+ feature') + @testing.exclude("firebird", "<", (2, 0), "2.0+ feature") + @testing.exclude("postgresql", "<", (8, 2), "8.2+ feature") def test_insert(self): - result = table.insert().returning( - table.c.foo_id).execute( - data='somedata') + result = ( + table.insert().returning(table.c.foo_id).execute(data="somedata") + ) row = result.first() - assert row[table.c.foo_id] == row['id'] == 1 + assert row[table.c.foo_id] == row["id"] == 1 result = table.select().execute().first() - assert row[table.c.foo_id] == row['id'] == 1 + assert row[table.c.foo_id] == row["id"] == 1 class ReturnDefaultsTest(fixtures.TablesTest): - __requires__ = ('returning', ) - run_define_tables = 'each' + __requires__ = ("returning",) + run_define_tables = "each" __backend__ = True @classmethod @@ -278,13 +317,15 @@ class ReturnDefaultsTest(fixtures.TablesTest): return str(next(counter)) Table( - "t1", metadata, + "t1", + 
metadata, Column( - "id", Integer, primary_key=True, - test_needs_autoincrement=True), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), Column("insdef", Integer, default=IncDefault()), - Column("upddef", Integer, onupdate=IncDefault())) + Column("upddef", Integer, onupdate=IncDefault()), + ) def test_chained_insert_pk(self): t1 = self.tables.t1 @@ -293,7 +334,7 @@ class ReturnDefaultsTest(fixtures.TablesTest): ) eq_( [result.returned_defaults[k] for k in (t1.c.id, t1.c.insdef)], - [1, 0] + [1, 0], ) def test_arg_insert_pk(self): @@ -303,32 +344,24 @@ class ReturnDefaultsTest(fixtures.TablesTest): ) eq_( [result.returned_defaults[k] for k in (t1.c.id, t1.c.insdef)], - [1, 0] + [1, 0], ) def test_chained_update_pk(self): t1 = self.tables.t1 - testing.db.execute( - t1.insert().values(upddef=1) - ) - result = testing.db.execute(t1.update().values(data='d1'). - return_defaults(t1.c.upddef)) - eq_( - [result.returned_defaults[k] for k in (t1.c.upddef,)], - [1] + testing.db.execute(t1.insert().values(upddef=1)) + result = testing.db.execute( + t1.update().values(data="d1").return_defaults(t1.c.upddef) ) + eq_([result.returned_defaults[k] for k in (t1.c.upddef,)], [1]) def test_arg_update_pk(self): t1 = self.tables.t1 - testing.db.execute( - t1.insert().values(upddef=1) - ) - result = testing.db.execute(t1.update(return_defaults=[t1.c.upddef]). - values(data='d1')) - eq_( - [result.returned_defaults[k] for k in (t1.c.upddef,)], - [1] + testing.db.execute(t1.insert().values(upddef=1)) + result = testing.db.execute( + t1.update(return_defaults=[t1.c.upddef]).values(data="d1") ) + eq_([result.returned_defaults[k] for k in (t1.c.upddef,)], [1]) def test_insert_non_default(self): """test that a column not marked at all as a @@ -339,8 +372,8 @@ class ReturnDefaultsTest(fixtures.TablesTest): t1.insert().values(upddef=1).return_defaults(t1.c.data) ) eq_( - [result.returned_defaults[k] for k in (t1.c.id, t1.c.data,)], - [1, None] + [result.returned_defaults[k] for k in (t1.c.id, t1.c.data)], + [1, None], ) def test_update_non_default(self): @@ -348,42 +381,33 @@ class ReturnDefaultsTest(fixtures.TablesTest): default works with this feature.""" t1 = self.tables.t1 - testing.db.execute( - t1.insert().values(upddef=1) - ) + testing.db.execute(t1.insert().values(upddef=1)) result = testing.db.execute( - t1.update(). values( - upddef=2).return_defaults( - t1.c.data)) - eq_( - [result.returned_defaults[k] for k in (t1.c.data,)], - [None] + t1.update().values(upddef=2).return_defaults(t1.c.data) ) + eq_([result.returned_defaults[k] for k in (t1.c.data,)], [None]) def test_insert_non_default_plus_default(self): t1 = self.tables.t1 result = testing.db.execute( - t1.insert().values(upddef=1).return_defaults( - t1.c.data, t1.c.insdef) + t1.insert() + .values(upddef=1) + .return_defaults(t1.c.data, t1.c.insdef) ) eq_( dict(result.returned_defaults), - {"id": 1, "data": None, "insdef": 0} + {"id": 1, "data": None, "insdef": 0}, ) def test_update_non_default_plus_default(self): t1 = self.tables.t1 - testing.db.execute( - t1.insert().values(upddef=1) - ) + testing.db.execute(t1.insert().values(upddef=1)) result = testing.db.execute( - t1.update(). 
- values(insdef=2).return_defaults( - t1.c.data, t1.c.upddef)) - eq_( - dict(result.returned_defaults), - {"data": None, 'upddef': 1} + t1.update() + .values(insdef=2) + .return_defaults(t1.c.data, t1.c.upddef) ) + eq_(dict(result.returned_defaults), {"data": None, "upddef": 1}) def test_insert_all(self): t1 = self.tables.t1 @@ -392,36 +416,30 @@ class ReturnDefaultsTest(fixtures.TablesTest): ) eq_( dict(result.returned_defaults), - {"id": 1, "data": None, "insdef": 0} + {"id": 1, "data": None, "insdef": 0}, ) def test_update_all(self): t1 = self.tables.t1 - testing.db.execute( - t1.insert().values(upddef=1) - ) + testing.db.execute(t1.insert().values(upddef=1)) result = testing.db.execute( - t1.update(). - values(insdef=2).return_defaults() - ) - eq_( - dict(result.returned_defaults), - {'upddef': 1} + t1.update().values(insdef=2).return_defaults() ) + eq_(dict(result.returned_defaults), {"upddef": 1}) class ImplicitReturningFlag(fixtures.TestBase): __backend__ = True def test_flag_turned_off(self): - e = engines.testing_engine(options={'implicit_returning': False}) + e = engines.testing_engine(options={"implicit_returning": False}) assert e.dialect.implicit_returning is False c = e.connect() c.close() assert e.dialect.implicit_returning is False def test_flag_turned_on(self): - e = engines.testing_engine(options={'implicit_returning': True}) + e = engines.testing_engine(options={"implicit_returning": True}) assert e.dialect.implicit_returning is True c = e.connect() c.close() @@ -432,6 +450,7 @@ class ImplicitReturningFlag(fixtures.TestBase): def go(): supports[0] = True + testing.requires.returning(go)() e = engines.testing_engine() diff --git a/test/sql/test_rowcount.py b/test/sql/test_rowcount.py index ea29bcf7ea..126e1f0cd8 100644 --- a/test/sql/test_rowcount.py +++ b/test/sql/test_rowcount.py @@ -8,7 +8,7 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): """tests rowcount functionality""" - __requires__ = ('sane_rowcount', ) + __requires__ = ("sane_rowcount",) __backend__ = True @classmethod @@ -17,28 +17,35 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) employees_table = Table( - 'employees', metadata, + "employees", + metadata, Column( - 'employee_id', Integer, - Sequence('employee_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - Column('department', String(1))) + "employee_id", + Integer, + Sequence("employee_id_seq", optional=True), + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + ) metadata.create_all() def setup(self): global data - data = [('Angela', 'A'), - ('Andrew', 'A'), - ('Anand', 'A'), - ('Bob', 'B'), - ('Bobette', 'B'), - ('Buffy', 'B'), - ('Charlie', 'C'), - ('Cynthia', 'C'), - ('Chris', 'C')] + data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] i = employees_table.insert() - i.execute(*[{'name': n, 'department': d} for n, d in data]) + i.execute(*[{"name": n, "department": d} for n, d in data]) def teardown(self): employees_table.delete().execute() @@ -56,23 +63,26 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): def test_update_rowcount1(self): # WHERE matches 3, 3 rows changed department = employees_table.c.department - r = employees_table.update(department == 'C').execute(department='Z') + r = employees_table.update(department == "C").execute(department="Z") assert r.rowcount == 
3 def test_update_rowcount2(self): # WHERE matches 3, 0 rows changed department = employees_table.c.department - r = employees_table.update(department == 'C').execute(department='C') + r = employees_table.update(department == "C").execute(department="C") assert r.rowcount == 3 @testing.skip_if( - testing.requires.oracle5x, - "unknown DBAPI error fixed in later version") + testing.requires.oracle5x, "unknown DBAPI error fixed in later version" + ) @testing.requires.sane_rowcount_w_returning def test_update_rowcount_return_defaults(self): department = employees_table.c.department - stmt = employees_table.update(department == 'C').values( - name=employees_table.c.department + 'Z').return_defaults() + stmt = ( + employees_table.update(department == "C") + .values(name=employees_table.c.department + "Z") + .return_defaults() + ) r = stmt.execute() assert r.rowcount == 3 @@ -81,7 +91,8 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): # test issue #3622, make sure eager rowcount is called for text with testing.db.connect() as conn: result = conn.execute( - "update employees set department='Z' where department='C'") + "update employees set department='Z' where department='C'" + ) eq_(result.rowcount, 3) def test_text_rowcount(self): @@ -90,43 +101,49 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): result = conn.execute( text( "update employees set department='Z' " - "where department='C'")) + "where department='C'" + ) + ) eq_(result.rowcount, 3) def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department - r = employees_table.delete(department == 'C').execute() + r = employees_table.delete(department == "C").execute() assert r.rowcount == 3 @testing.requires.sane_multi_rowcount def test_multi_update_rowcount(self): - stmt = employees_table.update().\ - where(employees_table.c.name == bindparam('emp_name')).\ - values(department="C") + stmt = ( + employees_table.update() + .where(employees_table.c.name == bindparam("emp_name")) + .values(department="C") + ) r = testing.db.execute( stmt, - [{"emp_name": "Bob"}, {"emp_name": "Cynthia"}, - {"emp_name": "nonexistent"}] + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], ) - eq_( - r.rowcount, 2 - ) + eq_(r.rowcount, 2) @testing.requires.sane_multi_rowcount def test_multi_delete_rowcount(self): - stmt = employees_table.delete().\ - where(employees_table.c.name == bindparam('emp_name')) + stmt = employees_table.delete().where( + employees_table.c.name == bindparam("emp_name") + ) r = testing.db.execute( stmt, - [{"emp_name": "Bob"}, {"emp_name": "Cynthia"}, - {"emp_name": "nonexistent"}] - ) - - eq_( - r.rowcount, 2 + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], ) + eq_(r.rowcount, 2) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 4b92e3e3ec..023a0bc61b 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1,10 +1,12 @@ """Test various algorithmic properties of selectables.""" -from sqlalchemy.testing import eq_, assert_raises, \ - assert_raises_message, is_ +from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, is_ from sqlalchemy import * -from sqlalchemy.testing import fixtures, AssertsCompiledSQL, \ - AssertsExecutionResults +from sqlalchemy.testing import ( + fixtures, + AssertsCompiledSQL, + AssertsExecutionResults, +) from sqlalchemy.sql import elements from sqlalchemy import testing from sqlalchemy.sql 
import util as sql_util, visitors, expression @@ -14,33 +16,37 @@ from sqlalchemy import util from sqlalchemy.schema import Column, Table, MetaData metadata = MetaData() -table1 = Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20)), - Column('col3', Integer), - Column('colx', Integer), - - ) - -table2 = Table('table2', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', Integer, ForeignKey('table1.col1')), - Column('col3', String(20)), - Column('coly', Integer), - ) - -keyed = Table('keyed', metadata, - Column('x', Integer, key='colx'), - Column('y', Integer, key='coly'), - Column('z', Integer), - ) +table1 = Table( + "table1", + metadata, + Column("col1", Integer, primary_key=True), + Column("col2", String(20)), + Column("col3", Integer), + Column("colx", Integer), +) + +table2 = Table( + "table2", + metadata, + Column("col1", Integer, primary_key=True), + Column("col2", Integer, ForeignKey("table1.col1")), + Column("col3", String(20)), + Column("coly", Integer), +) + +keyed = Table( + "keyed", + metadata, + Column("x", Integer, key="colx"), + Column("y", Integer, key="coly"), + Column("z", Integer), +) class SelectableTest( - fixtures.TestBase, - AssertsExecutionResults, - AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL +): + __dialect__ = "default" def test_indirect_correspondence_on_labels(self): # this test depends upon 'distance' to @@ -48,8 +54,13 @@ class SelectableTest( # same column three times - s = select([table1.c.col1.label('c2'), table1.c.col1, - table1.c.col1.label('c1')]) + s = select( + [ + table1.c.col1.label("c2"), + table1.c.col1, + table1.c.col1.label("c1"), + ] + ) # this tests the same thing as # test_direct_correspondence_on_labels below - @@ -60,25 +71,25 @@ class SelectableTest( assert s.corresponding_column(s.c.c1) is s.c.c1 def test_labeled_subquery_twice(self): - scalar_select = select([table1.c.col1]).label('foo') + scalar_select = select([table1.c.col1]).label("foo") s1 = select([scalar_select]) s2 = select([scalar_select, scalar_select]) eq_( s1.c.foo.proxy_set, - set([s1.c.foo, scalar_select, scalar_select.element]) + set([s1.c.foo, scalar_select, scalar_select.element]), ) eq_( s2.c.foo.proxy_set, - set([s2.c.foo, scalar_select, scalar_select.element]) + set([s2.c.foo, scalar_select, scalar_select.element]), ) assert s1.corresponding_column(scalar_select) is s1.c.foo assert s2.corresponding_column(scalar_select) is s2.c.foo def test_label_grouped_still_corresponds(self): - label = select([table1.c.col1]).label('foo') + label = select([table1.c.col1]).label("foo") label2 = label.self_group() s1 = select([label]) @@ -90,14 +101,14 @@ class SelectableTest( # this test depends on labels being part # of the proxy set to get the right result - l1, l2 = table1.c.col1.label('foo'), table1.c.col1.label('bar') + l1, l2 = table1.c.col1.label("foo"), table1.c.col1.label("bar") sel = select([l1, l2]) sel2 = sel.alias() assert sel2.corresponding_column(l1) is sel2.c.foo assert sel2.corresponding_column(l2) is sel2.c.bar - sel2 = select([table1.c.col1.label('foo'), table1.c.col2.label('bar')]) + sel2 = select([table1.c.col1.label("foo"), table1.c.col2.label("bar")]) sel3 = sel.union(sel2).alias() assert sel3.corresponding_column(l1) is sel3.c.foo @@ -105,9 +116,9 @@ class SelectableTest( def test_keyed_gen(self): s = select([keyed]) - eq_(s.c.colx.key, 'colx') + eq_(s.c.colx.key, "colx") - eq_(s.c.colx.name, 'x') + eq_(s.c.colx.name, "x") 
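         # key="colx" decouples the Python-side attribute from the SQL name
         # "x"; corresponding_column() matches through the column's proxy
         # set rather than by name, so the keyed lookups below still resolve: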
assert s.corresponding_column(keyed.c.colx) is s.c.colx assert s.corresponding_column(keyed.c.coly) is s.c.coly @@ -131,26 +142,26 @@ class SelectableTest( assert sel2.corresponding_column(keyed.c.z) is sel2.c.keyed_z def test_keyed_c_collection_upper(self): - c = Column('foo', Integer, key='bar') - t = Table('t', MetaData(), c) + c = Column("foo", Integer, key="bar") + t = Table("t", MetaData(), c) is_(t.c.bar, c) def test_keyed_c_collection_lower(self): - c = column('foo') - c.key = 'bar' - t = table('t', c) + c = column("foo") + c.key = "bar" + t = table("t", c) is_(t.c.bar, c) def test_clone_c_proxy_key_upper(self): - c = Column('foo', Integer, key='bar') - t = Table('t', MetaData(), c) + c = Column("foo", Integer, key="bar") + t = Table("t", MetaData(), c) s = select([t])._clone() assert c in s.c.bar.proxy_set def test_clone_c_proxy_key_lower(self): - c = column('foo') - c.key = 'bar' - t = table('t', c) + c = column("foo") + c.key = "bar" + t = table("t", c) s = select([t])._clone() assert c in s.c.bar.proxy_set @@ -160,19 +171,16 @@ class SelectableTest( def myop(x, y): pass - t = table('t', column('x'), column('y')) + t = table("t", column("x"), column("y")) expr = BinaryExpression(t.c.x, t.c.y, myop) s = select([t, expr]) - eq_( - s.c.keys(), - ['x', 'y', expr.anon_label] - ) + eq_(s.c.keys(), ["x", "y", expr.anon_label]) def test_cloned_intersection(self): - t1 = table('t1', column('x')) - t2 = table('t2', column('x')) + t1 = table("t1", column("x")) + t2 = table("t2", column("x")) s1 = t1.select() s2 = t2.select() @@ -184,15 +192,13 @@ class SelectableTest( s3c1 = s3._clone() eq_( - expression._cloned_intersection( - [s1c1, s3c1], [s2c1, s1c2] - ), - set([s1c1]) + expression._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]), + set([s1c1]), ) def test_cloned_difference(self): - t1 = table('t1', column('x')) - t2 = table('t2', column('x')) + t1 = table("t1", column("x")) + t2 = table("t2", column("x")) s1 = t1.select() s2 = t2.select() @@ -205,75 +211,70 @@ class SelectableTest( s3c1 = s3._clone() eq_( - expression._cloned_difference( - [s1c1, s2c1, s3c1], [s2c1, s1c2] - ), - set([s3c1]) + expression._cloned_difference([s1c1, s2c1, s3c1], [s2c1, s1c2]), + set([s3c1]), ) def test_distance_on_aliases(self): - a1 = table1.alias('a1') - for s in (select([a1, table1], use_labels=True), - select([table1, a1], use_labels=True)): - assert s.corresponding_column(table1.c.col1) \ - is s.c.table1_col1 + a1 = table1.alias("a1") + for s in ( + select([a1, table1], use_labels=True), + select([table1, a1], use_labels=True), + ): + assert s.corresponding_column(table1.c.col1) is s.c.table1_col1 assert s.corresponding_column(a1.c.col1) is s.c.a1_col1 def test_join_against_self(self): - jj = select([table1.c.col1.label('bar_col1')]) + jj = select([table1.c.col1.label("bar_col1")]) jjj = join(table1, jj, table1.c.col1 == jj.c.bar_col1) # test column directly against itself - assert jjj.corresponding_column(jjj.c.table1_col1) \ - is jjj.c.table1_col1 + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 # test alias of the join - j2 = jjj.alias('foo') - assert j2.corresponding_column(table1.c.col1) \ - is j2.c.table1_col1 + j2 = jjj.alias("foo") + assert j2.corresponding_column(table1.c.col1) is j2.c.table1_col1 def test_clone_append_column(self): - sel = select([literal_column('1').label('a')]) - eq_(list(sel.c.keys()), ['a']) + sel = select([literal_column("1").label("a")]) + eq_(list(sel.c.keys()), ["a"]) cloned = 
visitors.ReplacingCloningVisitor().traverse(sel) - cloned.append_column(literal_column('2').label('b')) + cloned.append_column(literal_column("2").label("b")) cloned.append_column(func.foo()) - eq_(list(cloned.c.keys()), ['a', 'b', 'foo()']) + eq_(list(cloned.c.keys()), ["a", "b", "foo()"]) def test_append_column_after_replace_selectable(self): - basesel = select([literal_column('1').label('a')]) - tojoin = select([ - literal_column('1').label('a'), - literal_column('2').label('b') - ]) - basefrom = basesel.alias('basefrom') - joinfrom = tojoin.alias('joinfrom') + basesel = select([literal_column("1").label("a")]) + tojoin = select( + [literal_column("1").label("a"), literal_column("2").label("b")] + ) + basefrom = basesel.alias("basefrom") + joinfrom = tojoin.alias("joinfrom") sel = select([basefrom.c.a]) replaced = sel.replace_selectable( - basefrom, - basefrom.join(joinfrom, basefrom.c.a == joinfrom.c.a) + basefrom, basefrom.join(joinfrom, basefrom.c.a == joinfrom.c.a) ) self.assert_compile( replaced, "SELECT basefrom.a FROM (SELECT 1 AS a) AS basefrom " "JOIN (SELECT 1 AS a, 2 AS b) AS joinfrom " - "ON basefrom.a = joinfrom.a" + "ON basefrom.a = joinfrom.a", ) replaced.append_column(joinfrom.c.b) self.assert_compile( replaced, "SELECT basefrom.a, joinfrom.b FROM (SELECT 1 AS a) AS basefrom " "JOIN (SELECT 1 AS a, 2 AS b) AS joinfrom " - "ON basefrom.a = joinfrom.a" + "ON basefrom.a = joinfrom.a", ) def test_against_cloned_non_table(self): # test that corresponding column digs across # clone boundaries with anonymous labeled elements - col = func.count().label('foo') + col = func.count().label("foo") sel = select([col]) sel2 = visitors.ReplacingCloningVisitor().traverse(sel) @@ -287,14 +288,14 @@ class SelectableTest( self.assert_compile( s1.with_only_columns([s1]), "SELECT (SELECT table1.col1, table1.col2, " - "table1.col3, table1.colx FROM table1) AS anon_1" + "table1.col3, table1.colx FROM table1) AS anon_1", ) def test_type_coerce_preserve_subq(self): class MyType(TypeDecorator): impl = Integer - stmt = select([type_coerce(column('x'), MyType).label('foo')]) + stmt = select([type_coerce(column("x"), MyType).label("foo")]) stmt2 = stmt.select() assert isinstance(stmt._raw_columns[0].type, MyType) assert isinstance(stmt.c.foo.type, MyType) @@ -303,30 +304,32 @@ class SelectableTest( def test_select_on_table(self): sel = select([table1, table2], use_labels=True) - assert sel.corresponding_column(table1.c.col1) \ + assert sel.corresponding_column(table1.c.col1) is sel.c.table1_col1 + assert ( + sel.corresponding_column(table1.c.col1, require_embedded=True) is sel.c.table1_col1 - assert sel.corresponding_column( - table1.c.col1, - require_embedded=True) is sel.c.table1_col1 - assert table1.corresponding_column(sel.c.table1_col1) \ - is table1.c.col1 - assert table1.corresponding_column(sel.c.table1_col1, - require_embedded=True) is None + ) + assert table1.corresponding_column(sel.c.table1_col1) is table1.c.col1 + assert ( + table1.corresponding_column( + sel.c.table1_col1, require_embedded=True + ) + is None + ) def test_join_against_join(self): j = outerjoin(table1, table2, table1.c.col1 == table2.c.col2) - jj = select([table1.c.col1.label('bar_col1')], - from_obj=[j]).alias('foo') + jj = select([table1.c.col1.label("bar_col1")], from_obj=[j]).alias( + "foo" + ) jjj = join(table1, jj, table1.c.col1 == jj.c.bar_col1) - assert jjj.corresponding_column(jjj.c.table1_col1) \ - is jjj.c.table1_col1 - j2 = jjj.alias('foo') - assert j2.corresponding_column(jjj.c.table1_col1) \ - is 
j2.c.table1_col1 + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + j2 = jjj.alias("foo") + assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 def test_table_alias(self): - a = table1.alias('a') + a = table1.alias("a") j = join(a, table2) @@ -338,13 +341,13 @@ class SelectableTest( # prominent w/ PostgreSQL's tuple functions stmt = select([table1.c.col1, table1.c.col2]) - a = stmt.alias('a') + a = stmt.alias("a") self.assert_compile( select([func.foo(a)]), "SELECT foo(SELECT table1.col1, table1.col2 FROM table1) " "AS foo_1 FROM " "(SELECT table1.col1 AS col1, table1.col2 AS col2 FROM table1) " - "AS a" + "AS a", ) def test_union(self): @@ -353,15 +356,25 @@ class SelectableTest( # with a certain Table, against a column in a Union where one of # its underlying Selects matches to that same Table - u = select([table1.c.col1, - table1.c.col2, - table1.c.col3, - table1.c.colx, - null().label('coly')]).union(select([table2.c.col1, - table2.c.col2, - table2.c.col3, - null().label('colx'), - table2.c.coly])) + u = select( + [ + table1.c.col1, + table1.c.col2, + table1.c.col3, + table1.c.colx, + null().label("coly"), + ] + ).union( + select( + [ + table2.c.col1, + table2.c.col2, + table2.c.col3, + null().label("colx"), + table2.c.coly, + ] + ) + ) s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) @@ -388,8 +401,10 @@ class SelectableTest( assert u1.corresponding_column(table1.c.col3) is u1.c.col1 def test_singular_union(self): - u = union(select([table1.c.col1, table1.c.col2, table1.c.col3]), - select([table1.c.col1, table1.c.col2, table1.c.col3])) + u = union( + select([table1.c.col1, table1.c.col2, table1.c.col3]), + select([table1.c.col1, table1.c.col2, table1.c.col3]), + ) u = union(select([table1.c.col1, table1.c.col2, table1.c.col3])) assert u.c.col1 is not None assert u.c.col2 is not None @@ -399,14 +414,29 @@ class SelectableTest( # same as testunion, except its an alias of the union - u = select([table1.c.col1, + u = ( + select( + [ + table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, - null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, - null().label('colx'), table2.c.coly]) - ).alias('analias') + null().label("coly"), + ] + ) + .union( + select( + [ + table2.c.col1, + table2.c.col2, + table2.c.col3, + null().label("colx"), + table2.c.coly, + ] + ) + ) + .alias("analias") + ) s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 @@ -429,7 +459,8 @@ class SelectableTest( def test_union_of_text(self): s1 = select([table1.c.col1, table1.c.col2]) s2 = text("select col1, col2 from foo").columns( - column('col1'), column('col2')) + column("col1"), column("col2") + ) u1 = union(s1, s2) assert u1.corresponding_column(s1.c.col1) is u1.c.col1 @@ -445,8 +476,10 @@ class SelectableTest( s2 = select([table2.c.col1, table2.c.col2, table2.c.col3]) u1 = union(s1, s2) - assert u1.corresponding_column( - s1.c._all_columns[0]) is u1.c._all_columns[0] + assert ( + u1.corresponding_column(s1.c._all_columns[0]) + is u1.c._all_columns[0] + ) assert u1.corresponding_column(s2.c.col1) is u1.c._all_columns[0] assert u1.corresponding_column(s1.c.col2) is u1.c.col2 assert u1.corresponding_column(s2.c.col2) is u1.c.col2 @@ -462,8 +495,10 @@ class SelectableTest( s2 = select([table2.c.col1, table2.c.col2, table2.c.col3]) u1 = union(s1, s2) - assert 
u1.corresponding_column( - s1.c._all_columns[0]) is u1.c._all_columns[0] + assert ( + u1.corresponding_column(s1.c._all_columns[0]) + is u1.c._all_columns[0] + ) assert u1.corresponding_column(s2.c.col1) is u1.c._all_columns[0] assert u1.corresponding_column(s1.c.col2) is u1.c.col2 assert u1.corresponding_column(s2.c.col2) is u1.c.col2 @@ -477,13 +512,18 @@ class SelectableTest( @testing.emits_warning("Column 'col1'") def test_union_alias_dupe_keys_grouped(self): - s1 = select([table1.c.col1, table1.c.col2, table2.c.col1]).\ - limit(1).alias() + s1 = ( + select([table1.c.col1, table1.c.col2, table2.c.col1]) + .limit(1) + .alias() + ) s2 = select([table2.c.col1, table2.c.col2, table2.c.col3]).limit(1) u1 = union(s1, s2) - assert u1.corresponding_column( - s1.c._all_columns[0]) is u1.c._all_columns[0] + assert ( + u1.corresponding_column(s1.c._all_columns[0]) + is u1.c._all_columns[0] + ) assert u1.corresponding_column(s2.c.col1) is u1.c._all_columns[0] assert u1.corresponding_column(s1.c.col2) is u1.c.col2 assert u1.corresponding_column(s2.c.col2) is u1.c.col2 @@ -499,14 +539,29 @@ class SelectableTest( # like testaliasunion, but off a Select off the union. - u = select([table1.c.col1, + u = ( + select( + [ + table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, - null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, - null().label('colx'), table2.c.coly]) - ).alias('analias') + null().label("coly"), + ] + ) + .union( + select( + [ + table2.c.col1, + table2.c.col2, + table2.c.col3, + null().label("colx"), + table2.c.coly, + ] + ) + ) + .alias("analias") + ) s = select([u]) s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) @@ -517,14 +572,29 @@ class SelectableTest( # same as testunion, except its an alias of the union - u = select([table1.c.col1, + u = ( + select( + [ + table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, - null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, - null().label('colx'), table2.c.coly]) - ).alias('analias') + null().label("coly"), + ] + ) + .union( + select( + [ + table2.c.col1, + table2.c.col2, + table2.c.col3, + null().label("colx"), + table2.c.coly, + ] + ) + ) + .alias("analias") + ) j1 = table1.join(table2) assert u.corresponding_column(j1.c.table1_colx) is u.c.colx assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx @@ -532,14 +602,14 @@ class SelectableTest( def test_join(self): a = join(table1, table2) print(str(a.select(use_labels=True))) - b = table2.alias('b') + b = table2.alias("b") j = join(a, b) print(str(j)) criterion = a.c.table1_col1 == b.c.col2 self.assert_(criterion.compare(j.onclause)) def test_select_alias(self): - a = table1.select().alias('a') + a = table1.select().alias("a") j = join(a, table2) criterion = a.c.col1 == table2.c.col2 @@ -562,15 +632,19 @@ class SelectableTest( is_(expr2.left, sel2) def test_column_labels(self): - a = select([table1.c.col1.label('acol1'), - table1.c.col2.label('acol2'), - table1.c.col3.label('acol3')]) + a = select( + [ + table1.c.col1.label("acol1"), + table1.c.col2.label("acol2"), + table1.c.col3.label("acol3"), + ] + ) j = join(a, table2) criterion = a.c.acol1 == table2.c.col2 self.assert_(criterion.compare(j.onclause)) def test_labeled_select_correspoinding(self): - l1 = select([func.max(table1.c.col1)]).label('foo') + l1 = select([func.max(table1.c.col1)]).label("foo") s = select([l1]) eq_(s.corresponding_column(l1), s.c.foo) @@ -579,7 +653,7 @@ class SelectableTest( eq_(s.corresponding_column(l1), 
s.c.foo) def test_select_alias_labels(self): - a = table2.select(use_labels=True).alias('a') + a = table2.select(use_labels=True).alias("a") j = join(a, table1) criterion = table1.c.col1 == a.c.table2_col2 @@ -587,14 +661,13 @@ class SelectableTest( def test_table_joined_to_select_of_table(self): metadata = MetaData() - a = Table('a', metadata, - Column('id', Integer, primary_key=True)) + a = Table("a", metadata, Column("id", Integer, primary_key=True)) - j2 = select([a.c.id.label('aid')]).alias('bar') + j2 = select([a.c.id.label("aid")]).alias("bar") j3 = a.join(j2, j2.c.aid == a.c.id) - j4 = select([j3]).alias('foo') + j4 = select([j3]).alias("foo") assert j4.corresponding_column(j2.c.aid) is j4.c.aid assert j4.corresponding_column(a.c.id) is j4.c.id @@ -602,9 +675,9 @@ class SelectableTest( m = MetaData() m2 = MetaData() - t1 = Table('t1', m, Column('id', Integer), Column('id2', Integer)) - t2 = Table('t2', m, Column('id', Integer, ForeignKey('t1.id'))) - t3 = Table('t3', m2, Column('id', Integer, ForeignKey('t1.id2'))) + t1 = Table("t1", m, Column("id", Integer), Column("id2", Integer)) + t2 = Table("t2", m, Column("id", Integer, ForeignKey("t1.id"))) + t3 = Table("t3", m2, Column("id", Integer, ForeignKey("t1.id2"))) s = select([t2, t3], use_labels=True) @@ -612,24 +685,24 @@ class SelectableTest( def test_multi_label_chain_naming_col(self): # See [ticket:2167] for this one. - l1 = table1.c.col1.label('a') - l2 = select([l1]).label('b') + l1 = table1.c.col1.label("a") + l2 = select([l1]).label("b") s = select([l2]) assert s.c.b is not None self.assert_compile( s.select(), - "SELECT b FROM (SELECT (SELECT table1.col1 AS a FROM table1) AS b)" + "SELECT b FROM (SELECT (SELECT table1.col1 AS a FROM table1) AS b)", ) - s2 = select([s.label('c')]) + s2 = select([s.label("c")]) self.assert_compile( s2.select(), "SELECT c FROM (SELECT (SELECT (" - "SELECT table1.col1 AS a FROM table1) AS b) AS c)" + "SELECT table1.col1 AS a FROM table1) AS b) AS c)", ) def test_self_referential_select_raises(self): - t = table('t', column('x')) + t = table("t", column("x")) s = select([t]) @@ -637,140 +710,110 @@ class SelectableTest( assert_raises_message( exc.InvalidRequestError, r"select\(\) construct refers to itself as a FROM", - s.compile + s.compile, ) def test_unusual_column_elements_text(self): """test that .c excludes text().""" s = select([table1.c.col1, text("foo")]) - eq_( - list(s.c), - [s.c.col1] - ) + eq_(list(s.c), [s.c.col1]) def test_unusual_column_elements_clauselist(self): """Test that raw ClauseList is expanded into .c.""" from sqlalchemy.sql.expression import ClauseList + s = select([table1.c.col1, ClauseList(table1.c.col2, table1.c.col3)]) - eq_( - list(s.c), - [s.c.col1, s.c.col2, s.c.col3] - ) + eq_(list(s.c), [s.c.col1, s.c.col2, s.c.col3]) def test_unusual_column_elements_boolean_clauselist(self): """test that BooleanClauseList is placed as single element in .c.""" c2 = and_(table1.c.col2 == 5, table1.c.col3 == 4) s = select([table1.c.col1, c2]) - eq_( - list(s.c), - [s.c.col1, s.corresponding_column(c2)] - ) + eq_(list(s.c), [s.c.col1, s.corresponding_column(c2)]) def test_from_list_deferred_constructor(self): - c1 = Column('c1', Integer) - c2 = Column('c2', Integer) + c1 = Column("c1", Integer) + c2 = Column("c2", Integer) s = select([c1]) - t = Table('t', MetaData(), c1, c2) + t = Table("t", MetaData(), c1, c2) eq_(c1._from_objects, [t]) eq_(c2._from_objects, [t]) - self.assert_compile(select([c1]), - "SELECT t.c1 FROM t") - self.assert_compile(select([c2]), - "SELECT t.c2 FROM 
t") + self.assert_compile(select([c1]), "SELECT t.c1 FROM t") + self.assert_compile(select([c2]), "SELECT t.c2 FROM t") def test_from_list_deferred_whereclause(self): - c1 = Column('c1', Integer) - c2 = Column('c2', Integer) + c1 = Column("c1", Integer) + c2 = Column("c2", Integer) s = select([c1]).where(c1 == 5) - t = Table('t', MetaData(), c1, c2) + t = Table("t", MetaData(), c1, c2) eq_(c1._from_objects, [t]) eq_(c2._from_objects, [t]) - self.assert_compile(select([c1]), - "SELECT t.c1 FROM t") - self.assert_compile(select([c2]), - "SELECT t.c2 FROM t") + self.assert_compile(select([c1]), "SELECT t.c1 FROM t") + self.assert_compile(select([c2]), "SELECT t.c2 FROM t") def test_from_list_deferred_fromlist(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer)) + t1 = Table("t1", m, Column("x", Integer)) - c1 = Column('c1', Integer) + c1 = Column("c1", Integer) s = select([c1]).where(c1 == 5).select_from(t1) - t2 = Table('t2', MetaData(), c1) + t2 = Table("t2", MetaData(), c1) eq_(c1._from_objects, [t2]) - self.assert_compile(select([c1]), - "SELECT t2.c1 FROM t2") + self.assert_compile(select([c1]), "SELECT t2.c1 FROM t2") def test_from_list_deferred_cloning(self): - c1 = Column('c1', Integer) - c2 = Column('c2', Integer) + c1 = Column("c1", Integer) + c2 = Column("c2", Integer) s = select([c1]) s2 = select([c2]) s3 = sql_util.ClauseAdapter(s).traverse(s2) - Table('t', MetaData(), c1, c2) + Table("t", MetaData(), c1, c2) - self.assert_compile( - s3, - "SELECT t.c2 FROM t" - ) + self.assert_compile(s3, "SELECT t.c2 FROM t") def test_from_list_with_columns(self): - table1 = table('t1', column('a')) - table2 = table('t2', column('b')) + table1 = table("t1", column("a")) + table2 = table("t2", column("b")) s1 = select([table1.c.a, table2.c.b]) - self.assert_compile(s1, - "SELECT t1.a, t2.b FROM t1, t2" - ) + self.assert_compile(s1, "SELECT t1.a, t2.b FROM t1, t2") s2 = s1.with_only_columns([table2.c.b]) - self.assert_compile(s2, - "SELECT t2.b FROM t2" - ) + self.assert_compile(s2, "SELECT t2.b FROM t2") s3 = sql_util.ClauseAdapter(table1).traverse(s1) - self.assert_compile(s3, - "SELECT t1.a, t2.b FROM t1, t2" - ) + self.assert_compile(s3, "SELECT t1.a, t2.b FROM t1, t2") s4 = s3.with_only_columns([table2.c.b]) - self.assert_compile(s4, - "SELECT t2.b FROM t2" - ) + self.assert_compile(s4, "SELECT t2.b FROM t2") def test_from_list_warning_against_existing(self): - c1 = Column('c1', Integer) + c1 = Column("c1", Integer) s = select([c1]) # force a compile. 
- self.assert_compile( - s, - "SELECT c1" - ) + self.assert_compile(s, "SELECT c1") - Table('t', MetaData(), c1) + Table("t", MetaData(), c1) - self.assert_compile( - s, - "SELECT t.c1 FROM t" - ) + self.assert_compile(s, "SELECT t.c1 FROM t") def test_from_list_recovers_after_warning(self): - c1 = Column('c1', Integer) - c2 = Column('c2', Integer) + c1 = Column("c1", Integer) + c2 = Column("c2", Integer) s = select([c1]) @@ -779,7 +822,8 @@ class SelectableTest( @testing.emits_warning() def go(): - return Table('t', MetaData(), c1, c2) + return Table("t", MetaData(), c1, c2) + t = go() eq_(c1._from_objects, [t]) @@ -793,175 +837,175 @@ class SelectableTest( self.assert_compile(select([c2]), "SELECT t.c2 FROM t") def test_label_gen_resets_on_table(self): - c1 = Column('c1', Integer) + c1 = Column("c1", Integer) eq_(c1._label, "c1") - Table('t1', MetaData(), c1) + Table("t1", MetaData(), c1) eq_(c1._label, "t1_c1") class RefreshForNewColTest(fixtures.TestBase): - def test_join_uninit(self): - a = table('a', column('x')) - b = table('b', column('y')) + a = table("a", column("x")) + b = table("b", column("y")) j = a.join(b, a.c.x == b.c.y) - q = column('q') + q = column("q") b.append_column(q) j._refresh_for_new_column(q) assert j.c.b_q is q def test_join_init(self): - a = table('a', column('x')) - b = table('b', column('y')) + a = table("a", column("x")) + b = table("b", column("y")) j = a.join(b, a.c.x == b.c.y) j.c - q = column('q') + q = column("q") b.append_column(q) j._refresh_for_new_column(q) assert j.c.b_q is q def test_join_samename_init(self): - a = table('a', column('x')) - b = table('b', column('y')) + a = table("a", column("x")) + b = table("b", column("y")) j = a.join(b, a.c.x == b.c.y) j.c - q = column('x') + q = column("x") b.append_column(q) j._refresh_for_new_column(q) assert j.c.b_x is q def test_select_samename_init(self): - a = table('a', column('x')) - b = table('b', column('y')) + a = table("a", column("x")) + b = table("b", column("y")) s = select([a, b]).apply_labels() s.c - q = column('x') + q = column("x") b.append_column(q) s._refresh_for_new_column(q) assert q in s.c.b_x.proxy_set def test_aliased_select_samename_uninit(self): - a = table('a', column('x')) - b = table('b', column('y')) + a = table("a", column("x")) + b = table("b", column("y")) s = select([a, b]).apply_labels().alias() - q = column('x') + q = column("x") b.append_column(q) s._refresh_for_new_column(q) assert q in s.c.b_x.proxy_set def test_aliased_select_samename_init(self): - a = table('a', column('x')) - b = table('b', column('y')) + a = table("a", column("x")) + b = table("b", column("y")) s = select([a, b]).apply_labels().alias() s.c - q = column('x') + q = column("x") b.append_column(q) s._refresh_for_new_column(q) assert q in s.c.b_x.proxy_set def test_aliased_select_irrelevant(self): - a = table('a', column('x')) - b = table('b', column('y')) - c = table('c', column('z')) + a = table("a", column("x")) + b = table("b", column("y")) + c = table("c", column("z")) s = select([a, b]).apply_labels().alias() s.c - q = column('x') + q = column("x") c.append_column(q) s._refresh_for_new_column(q) - assert 'c_x' not in s.c + assert "c_x" not in s.c def test_aliased_select_no_cols_clause(self): - a = table('a', column('x')) + a = table("a", column("x")) s = select([a.c.x]).apply_labels().alias() s.c - q = column('q') + q = column("q") a.append_column(q) s._refresh_for_new_column(q) - assert 'a_q' not in s.c + assert "a_q" not in s.c def test_union_uninit(self): - a = table('a', column('x')) + a = 
table("a", column("x")) s1 = select([a]) s2 = select([a]) s3 = s1.union(s2) - q = column('q') + q = column("q") a.append_column(q) s3._refresh_for_new_column(q) assert a.c.q in s3.c.q.proxy_set def test_union_init_raises(self): - a = table('a', column('x')) + a = table("a", column("x")) s1 = select([a]) s2 = select([a]) s3 = s1.union(s2) s3.c - q = column('q') + q = column("q") a.append_column(q) assert_raises_message( NotImplementedError, "CompoundSelect constructs don't support addition of " "columns to underlying selectables", - s3._refresh_for_new_column, q + s3._refresh_for_new_column, + q, ) def test_nested_join_uninit(self): - a = table('a', column('x')) - b = table('b', column('y')) - c = table('c', column('z')) + a = table("a", column("x")) + b = table("b", column("y")) + c = table("c", column("z")) j = a.join(b, a.c.x == b.c.y).join(c, b.c.y == c.c.z) - q = column('q') + q = column("q") b.append_column(q) j._refresh_for_new_column(q) assert j.c.b_q is q def test_nested_join_init(self): - a = table('a', column('x')) - b = table('b', column('y')) - c = table('c', column('z')) + a = table("a", column("x")) + b = table("b", column("y")) + c = table("c", column("z")) j = a.join(b, a.c.x == b.c.y).join(c, b.c.y == c.c.z) j.c - q = column('q') + q = column("q") b.append_column(q) j._refresh_for_new_column(q) assert j.c.b_q is q def test_fk_table(self): m = MetaData() - fk = ForeignKey('x.id') - Table('x', m, Column('id', Integer)) - a = Table('a', m, Column('x', Integer, fk)) + fk = ForeignKey("x.id") + Table("x", m, Column("id", Integer)) + a = Table("a", m, Column("x", Integer, fk)) a.c - q = Column('q', Integer) + q = Column("q", Integer) a.append_column(q) a._refresh_for_new_column(q) eq_(a.foreign_keys, set([fk])) - fk2 = ForeignKey('g.id') - p = Column('p', Integer, fk2) + fk2 = ForeignKey("g.id") + p = Column("p", Integer, fk2) a.append_column(p) a._refresh_for_new_column(p) eq_(a.foreign_keys, set([fk, fk2])) def test_fk_join(self): m = MetaData() - fk = ForeignKey('x.id') - Table('x', m, Column('id', Integer)) - a = Table('a', m, Column('x', Integer, fk)) - b = Table('b', m, Column('y', Integer)) + fk = ForeignKey("x.id") + Table("x", m, Column("id", Integer)) + a = Table("a", m, Column("x", Integer, fk)) + b = Table("b", m, Column("y", Integer)) j = a.join(b, a.c.x == b.c.y) j.c - q = Column('q', Integer) + q = Column("q", Integer) b.append_column(q) j._refresh_for_new_column(q) eq_(j.foreign_keys, set([fk])) - fk2 = ForeignKey('g.id') - p = Column('p', Integer, fk2) + fk2 = ForeignKey("g.id") + p = Column("p", Integer, fk2) b.append_column(p) j._refresh_for_new_column(p) eq_(j.foreign_keys, set([fk, fk2])) @@ -972,18 +1016,18 @@ class AnonLabelTest(fixtures.TestBase): """Test behaviors fixed by [ticket:2168].""" def test_anon_labels_named_column(self): - c1 = column('x') + c1 = column("x") assert c1.label(None) is not c1 eq_(str(select([c1.label(None)])), "SELECT x AS x_1") def test_anon_labels_literal_column(self): - c1 = literal_column('x') + c1 = literal_column("x") assert c1.label(None) is not c1 eq_(str(select([c1.label(None)])), "SELECT x AS x_1") def test_anon_labels_func(self): - c1 = func.count('*') + c1 = func.count("*") assert c1.label(None) is not c1 eq_(str(select([c1])), "SELECT count(:count_2) AS count_1") @@ -992,76 +1036,76 @@ class AnonLabelTest(fixtures.TestBase): eq_(str(select([c1.label(None)])), "SELECT count(:count_2) AS count_1") def test_named_labels_named_column(self): - c1 = column('x') - eq_(str(select([c1.label('y')])), "SELECT x AS y") + c1 = 
column("x") + eq_(str(select([c1.label("y")])), "SELECT x AS y") def test_named_labels_literal_column(self): - c1 = literal_column('x') - eq_(str(select([c1.label('y')])), "SELECT x AS y") + c1 = literal_column("x") + eq_(str(select([c1.label("y")])), "SELECT x AS y") class JoinAliasingTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_flat_ok_on_non_join(self): - a = table('a', column('a')) + a = table("a", column("a")) s = a.select() self.assert_compile( s.alias(flat=True).select(), - "SELECT anon_1.a FROM (SELECT a.a AS a FROM a) AS anon_1" + "SELECT anon_1.a FROM (SELECT a.a AS a FROM a) AS anon_1", ) def test_join_alias(self): - a = table('a', column('a')) - b = table('b', column('b')) + a = table("a", column("a")) + b = table("b", column("b")) self.assert_compile( a.join(b, a.c.a == b.c.b).alias(), - "SELECT a.a AS a_a, b.b AS b_b FROM a JOIN b ON a.a = b.b" + "SELECT a.a AS a_a, b.b AS b_b FROM a JOIN b ON a.a = b.b", ) def test_join_standalone_alias(self): - a = table('a', column('a')) - b = table('b', column('b')) + a = table("a", column("a")) + b = table("b", column("b")) self.assert_compile( alias(a.join(b, a.c.a == b.c.b)), - "SELECT a.a AS a_a, b.b AS b_b FROM a JOIN b ON a.a = b.b" + "SELECT a.a AS a_a, b.b AS b_b FROM a JOIN b ON a.a = b.b", ) def test_join_alias_flat(self): - a = table('a', column('a')) - b = table('b', column('b')) + a = table("a", column("a")) + b = table("b", column("b")) self.assert_compile( a.join(b, a.c.a == b.c.b).alias(flat=True), - "a AS a_1 JOIN b AS b_1 ON a_1.a = b_1.b" + "a AS a_1 JOIN b AS b_1 ON a_1.a = b_1.b", ) def test_join_standalone_alias_flat(self): - a = table('a', column('a')) - b = table('b', column('b')) + a = table("a", column("a")) + b = table("b", column("b")) self.assert_compile( alias(a.join(b, a.c.a == b.c.b), flat=True), - "a AS a_1 JOIN b AS b_1 ON a_1.a = b_1.b" + "a AS a_1 JOIN b AS b_1 ON a_1.a = b_1.b", ) def test_composed_join_alias_flat(self): - a = table('a', column('a')) - b = table('b', column('b')) - c = table('c', column('c')) - d = table('d', column('d')) + a = table("a", column("a")) + b = table("b", column("b")) + c = table("c", column("c")) + d = table("d", column("d")) j1 = a.join(b, a.c.a == b.c.b) j2 = c.join(d, c.c.c == d.c.d) self.assert_compile( j1.join(j2, b.c.b == c.c.c).alias(flat=True), "a AS a_1 JOIN b AS b_1 ON a_1.a = b_1.b JOIN " - "(c AS c_1 JOIN d AS d_1 ON c_1.c = d_1.d) ON b_1.b = c_1.c" + "(c AS c_1 JOIN d AS d_1 ON c_1.c = d_1.d) ON b_1.b = c_1.c", ) def test_composed_join_alias(self): - a = table('a', column('a')) - b = table('b', column('b')) - c = table('c', column('c')) - d = table('d', column('d')) + a = table("a", column("a")) + b = table("b", column("b")) + c = table("c", column("c")) + d = table("d", column("d")) j1 = a.join(b, a.c.a == b.c.b) j2 = c.join(d, c.c.c == d.c.d) @@ -1070,29 +1114,35 @@ class JoinAliasingTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.a_a, anon_1.b_b, anon_1.c_c, anon_1.d_d " "FROM (SELECT a.a AS a_a, b.b AS b_b, c.c AS c_c, d.d AS d_d " "FROM a JOIN b ON a.a = b.b " - "JOIN (c JOIN d ON c.c = d.d) ON b.b = c.c) AS anon_1" + "JOIN (c JOIN d ON c.c = d.d) ON b.b = c.c) AS anon_1", ) class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_join_condition(self): m = MetaData() - t1 = Table('t1', m, Column('id', Integer)) - t2 = Table('t2', m, - Column('id', Integer), - Column('t1id', ForeignKey('t1.id'))) - t3 = Table('t3', 
m, - Column('id', Integer), - Column('t1id', ForeignKey('t1.id')), - Column('t2id', ForeignKey('t2.id'))) - t4 = Table('t4', m, Column('id', Integer), - Column('t2id', ForeignKey('t2.id'))) - t5 = Table('t5', m, - Column('t1id1', ForeignKey('t1.id')), - Column('t1id2', ForeignKey('t1.id')), - ) + t1 = Table("t1", m, Column("id", Integer)) + t2 = Table( + "t2", m, Column("id", Integer), Column("t1id", ForeignKey("t1.id")) + ) + t3 = Table( + "t3", + m, + Column("id", Integer), + Column("t1id", ForeignKey("t1.id")), + Column("t2id", ForeignKey("t2.id")), + ) + t4 = Table( + "t4", m, Column("id", Integer), Column("t2id", ForeignKey("t2.id")) + ) + t5 = Table( + "t5", + m, + Column("t1id1", ForeignKey("t1.id")), + Column("t1id2", ForeignKey("t1.id")), + ) t1t2 = t1.join(t2) t2t3 = t2.join(t3) @@ -1108,10 +1158,8 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): (t1t2, t2t3, t2, t1t2.c.t2_id == t2t3.c.t3_t2id), ]: assert expected.compare( - sql_util.join_condition( - left, - right, - a_subset=a_subset)) + sql_util.join_condition(left, right, a_subset=a_subset) + ) # these are ambiguous, or have no joins for left, right, a_subset in [ @@ -1120,12 +1168,14 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): (t1, t4, None), (t1t2, t2t3, None), (t5, t1, None), - (t5.select(use_labels=True), t1, None) + (t5.select(use_labels=True), t1, None), ]: assert_raises( exc.ArgumentError, sql_util.join_condition, - left, right, a_subset=a_subset + left, + right, + a_subset=a_subset, ) als = t2t3.alias() @@ -1138,41 +1188,38 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): (t2t3, t4, t2t3.c.t2_id == t4.c.t2id), (t2t3.join(t1), t4, t2t3.c.t2_id == t4.c.t2id), (t2t3.join(t1), t4, t2t3.c.t2_id == t4.c.t2id), - (t1t2, als, t1t2.c.t2_id == als.c.t3_t2id) + (t1t2, als, t1t2.c.t2_id == als.c.t3_t2id), ]: - assert expected.compare( - left.join(right).onclause - ) + assert expected.compare(left.join(right).onclause) # these are right-nested joins j = t1t2.join(t2t3) assert j.onclause.compare(t2.c.id == t3.c.t2id) self.assert_compile( - j, "t1 JOIN t2 ON t1.id = t2.t1id JOIN " - "(t2 JOIN t3 ON t2.id = t3.t2id) ON t2.id = t3.t2id") + j, + "t1 JOIN t2 ON t1.id = t2.t1id JOIN " + "(t2 JOIN t3 ON t2.id = t3.t2id) ON t2.id = t3.t2id", + ) st2t3 = t2t3.select(use_labels=True) j = t1t2.join(st2t3) assert j.onclause.compare(t2.c.id == st2t3.c.t3_t2id) self.assert_compile( - j, "t1 JOIN t2 ON t1.id = t2.t1id JOIN " + j, + "t1 JOIN t2 ON t1.id = t2.t1id JOIN " "(SELECT t2.id AS t2_id, t2.t1id AS t2_t1id, " "t3.id AS t3_id, t3.t1id AS t3_t1id, t3.t2id AS t3_t2id " - "FROM t2 JOIN t3 ON t2.id = t3.t2id) ON t2.id = t3_t2id") + "FROM t2 JOIN t3 ON t2.id = t3.t2id) ON t2.id = t3_t2id", + ) def test_join_multiple_equiv_fks(self): m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True) - ) + t1 = Table("t1", m, Column("id", Integer, primary_key=True)) t2 = Table( - 't2', + "t2", m, - Column( - 't1id', - Integer, - ForeignKey('t1.id'), - ForeignKey('t1.id'))) + Column("t1id", Integer, ForeignKey("t1.id"), ForeignKey("t1.id")), + ) assert sql_util.join_condition(t1, t2).compare(t1.c.id == t2.c.t1id) @@ -1181,35 +1228,43 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): # bounding the "good" column with two "bad" ones is so to # try to get coverage to get the "continue" statements # in the loop... 
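         # t1.y and t1.q reference "t22", which is never created in this
         # MetaData, so join_condition() has to skip both and settle on
         # t1.x == t2.id in either direction.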
- t1 = Table('t1', m, - Column('y', Integer, ForeignKey('t22.id')), - Column('x', Integer, ForeignKey('t2.id')), - Column('q', Integer, ForeignKey('t22.id')), - ) - t2 = Table('t2', m, Column('id', Integer)) + t1 = Table( + "t1", + m, + Column("y", Integer, ForeignKey("t22.id")), + Column("x", Integer, ForeignKey("t2.id")), + Column("q", Integer, ForeignKey("t22.id")), + ) + t2 = Table("t2", m, Column("id", Integer)) assert sql_util.join_condition(t1, t2).compare(t1.c.x == t2.c.id) assert sql_util.join_condition(t2, t1).compare(t1.c.x == t2.c.id) def test_join_cond_no_such_unrelated_column(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, ForeignKey('t2.id')), - Column('y', Integer, ForeignKey('t3.q'))) - t2 = Table('t2', m, Column('id', Integer)) - Table('t3', m, Column('id', Integer)) + t1 = Table( + "t1", + m, + Column("x", Integer, ForeignKey("t2.id")), + Column("y", Integer, ForeignKey("t3.q")), + ) + t2 = Table("t2", m, Column("id", Integer)) + Table("t3", m, Column("id", Integer)) assert sql_util.join_condition(t1, t2).compare(t1.c.x == t2.c.id) assert sql_util.join_condition(t2, t1).compare(t1.c.x == t2.c.id) def test_join_cond_no_such_related_table(self): m1 = MetaData() m2 = MetaData() - t1 = Table('t1', m1, Column('x', Integer, ForeignKey('t2.id'))) - t2 = Table('t2', m2, Column('id', Integer)) + t1 = Table("t1", m1, Column("x", Integer, ForeignKey("t2.id"))) + t2 = Table("t2", m2, Column("id", Integer)) assert_raises_message( exc.NoReferencedTableError, "Foreign key associated with column 't1.x' could not find " "table 't2' with which to generate a foreign key to " "target column 'id'", - sql_util.join_condition, t1, t2 + sql_util.join_condition, + t1, + t2, ) assert_raises_message( @@ -1217,19 +1272,23 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): "Foreign key associated with column 't1.x' could not find " "table 't2' with which to generate a foreign key to " "target column 'id'", - sql_util.join_condition, t2, t1 + sql_util.join_condition, + t2, + t1, ) def test_join_cond_no_such_related_column(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, ForeignKey('t2.q'))) - t2 = Table('t2', m, Column('id', Integer)) + t1 = Table("t1", m, Column("x", Integer, ForeignKey("t2.q"))) + t2 = Table("t2", m, Column("id", Integer)) assert_raises_message( exc.NoReferencedColumnError, "Could not initialize target column for " "ForeignKey 't2.q' on table 't1': " "table 't2' has no column named 'q'", - sql_util.join_condition, t1, t2 + sql_util.join_condition, + t1, + t2, ) assert_raises_message( @@ -1237,25 +1296,35 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL): "Could not initialize target column for " "ForeignKey 't2.q' on table 't1': " "table 't2' has no column named 'q'", - sql_util.join_condition, t2, t1 + sql_util.join_condition, + t2, + t1, ) class PrimaryKeyTest(fixtures.TestBase, AssertsExecutionResults): - def test_join_pk_collapse_implicit(self): """test that redundant columns in a join get 'collapsed' into a minimal primary key, which is the root column along a chain of foreign key relationships.""" meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True)) - b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), - primary_key=True)) - c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), - primary_key=True)) - d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), - primary_key=True)) + a = Table("a", meta, Column("id", Integer, primary_key=True)) + b = Table( + 
"b", + meta, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + ) + c = Table( + "c", + meta, + Column("id", Integer, ForeignKey("b.id"), primary_key=True), + ) + d = Table( + "d", + meta, + Column("id", Integer, ForeignKey("c.id"), primary_key=True), + ) assert c.c.id.references(b.c.id) assert not d.c.id.references(a.c.id) assert list(a.join(b).primary_key) == [a.c.id] @@ -1271,43 +1340,60 @@ class PrimaryKeyTest(fixtures.TestBase, AssertsExecutionResults): explicit join conditions.""" meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True), - Column('x', Integer)) - b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), - primary_key=True), Column('x', Integer)) - c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), - primary_key=True), Column('x', Integer)) - d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), - primary_key=True), Column('x', Integer)) + a = Table( + "a", + meta, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) + b = Table( + "b", + meta, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("x", Integer), + ) + c = Table( + "c", + meta, + Column("id", Integer, ForeignKey("b.id"), primary_key=True), + Column("x", Integer), + ) + d = Table( + "d", + meta, + Column("id", Integer, ForeignKey("c.id"), primary_key=True), + Column("x", Integer), + ) print(list(a.join(b, a.c.x == b.c.id).primary_key)) assert list(a.join(b, a.c.x == b.c.id).primary_key) == [a.c.id] assert list(b.join(c, b.c.x == c.c.id).primary_key) == [b.c.id] - assert list(a.join(b).join(c, c.c.id == b.c.x).primary_key) \ - == [a.c.id] - assert list(b.join(c, c.c.x == b.c.id).join(d).primary_key) \ - == [b.c.id] - assert list(b.join(c, c.c.id == b.c.x).join(d).primary_key) \ - == [b.c.id] + assert list(a.join(b).join(c, c.c.id == b.c.x).primary_key) == [a.c.id] + assert list(b.join(c, c.c.x == b.c.id).join(d).primary_key) == [b.c.id] + assert list(b.join(c, c.c.id == b.c.x).join(d).primary_key) == [b.c.id] + assert list( + d.join(b, d.c.id == b.c.id).join(c, b.c.id == c.c.x).primary_key + ) == [b.c.id] assert list( - d.join( - b, - d.c.id == b.c.id).join( - c, - b.c.id == c.c.x).primary_key) == [ - b.c.id] - assert list(a.join(b).join(c, c.c.id - == b.c.x).join(d).primary_key) == [a.c.id] - assert list(a.join(b, and_(a.c.id == b.c.id, a.c.x - == b.c.id)).primary_key) == [a.c.id] + a.join(b).join(c, c.c.id == b.c.x).join(d).primary_key + ) == [a.c.id] + assert list( + a.join(b, and_(a.c.id == b.c.id, a.c.x == b.c.id)).primary_key + ) == [a.c.id] def test_init_doesnt_blowitaway(self): meta = MetaData() - a = Table('a', meta, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - b = Table('b', meta, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('x', Integer)) + a = Table( + "a", + meta, + Column("id", Integer, primary_key=True), + Column("x", Integer), + ) + b = Table( + "b", + meta, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("x", Integer), + ) j = a.join(b) assert list(j.primary_key) == [a.c.id] @@ -1317,12 +1403,18 @@ class PrimaryKeyTest(fixtures.TestBase, AssertsExecutionResults): def test_non_column_clause(self): meta = MetaData() - a = Table('a', meta, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - b = Table('b', meta, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('x', Integer, primary_key=True)) + a = Table( + "a", + meta, + Column("id", Integer, primary_key=True), + Column("x", 
Integer), + ) + b = Table( + "b", + meta, + Column("id", Integer, ForeignKey("a.id"), primary_key=True), + Column("x", Integer, primary_key=True), + ) j = a.join(b, and_(a.c.id == b.c.id, b.c.x == 5)) assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :x_1", str(j) @@ -1331,123 +1423,156 @@ class PrimaryKeyTest(fixtures.TestBase, AssertsExecutionResults): def test_onclause_direction(self): metadata = MetaData() - employee = Table('Employee', metadata, - Column('name', String(100)), - Column('id', Integer, primary_key=True), - ) + employee = Table( + "Employee", + metadata, + Column("name", String(100)), + Column("id", Integer, primary_key=True), + ) - engineer = Table('Engineer', metadata, - Column('id', Integer, - ForeignKey('Employee.id'), primary_key=True)) + engineer = Table( + "Engineer", + metadata, + Column("id", Integer, ForeignKey("Employee.id"), primary_key=True), + ) - eq_(util.column_set(employee.join(engineer, employee.c.id - == engineer.c.id).primary_key), - util.column_set([employee.c.id])) - eq_(util.column_set(employee.join(engineer, engineer.c.id - == employee.c.id).primary_key), - util.column_set([employee.c.id])) + eq_( + util.column_set( + employee.join( + engineer, employee.c.id == engineer.c.id + ).primary_key + ), + util.column_set([employee.c.id]), + ) + eq_( + util.column_set( + employee.join( + engineer, engineer.c.id == employee.c.id + ).primary_key + ), + util.column_set([employee.c.id]), + ) class ReduceTest(fixtures.TestBase, AssertsExecutionResults): - def test_reduce(self): meta = MetaData() - t1 = Table('t1', meta, - Column('t1id', Integer, primary_key=True), - Column('t1data', String(30))) + t1 = Table( + "t1", + meta, + Column("t1id", Integer, primary_key=True), + Column("t1data", String(30)), + ) t2 = Table( - 't2', + "t2", meta, - Column( - 't2id', - Integer, - ForeignKey('t1.t1id'), - primary_key=True), - Column( - 't2data', - String(30))) + Column("t2id", Integer, ForeignKey("t1.t1id"), primary_key=True), + Column("t2data", String(30)), + ) t3 = Table( - 't3', + "t3", meta, - Column( - 't3id', - Integer, - ForeignKey('t2.t2id'), - primary_key=True), - Column( - 't3data', - String(30))) - - eq_(util.column_set(sql_util.reduce_columns([ - t1.c.t1id, - t1.c.t1data, - t2.c.t2id, - t2.c.t2data, - t3.c.t3id, - t3.c.t3data, - ])), util.column_set([t1.c.t1id, t1.c.t1data, t2.c.t2data, - t3.c.t3data])) + Column("t3id", Integer, ForeignKey("t2.t2id"), primary_key=True), + Column("t3data", String(30)), + ) + + eq_( + util.column_set( + sql_util.reduce_columns( + [ + t1.c.t1id, + t1.c.t1data, + t2.c.t2id, + t2.c.t2data, + t3.c.t3id, + t3.c.t3data, + ] + ) + ), + util.column_set( + [t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data] + ), + ) def test_reduce_selectable(self): metadata = MetaData() - engineers = Table('engineers', metadata, - Column('engineer_id', Integer, primary_key=True), - Column('engineer_name', String(50))) - managers = Table('managers', metadata, - Column('manager_id', Integer, primary_key=True), - Column('manager_name', String(50))) - s = select([engineers, - managers]).where(engineers.c.engineer_name - == managers.c.manager_name) - eq_(util.column_set(sql_util.reduce_columns(list(s.c), s)), - util.column_set([s.c.engineer_id, s.c.engineer_name, - s.c.manager_id])) + engineers = Table( + "engineers", + metadata, + Column("engineer_id", Integer, primary_key=True), + Column("engineer_name", String(50)), + ) + managers = Table( + "managers", + metadata, + Column("manager_id", Integer, primary_key=True), + Column("manager_name", String(50)), + 
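            # reduce_columns below keeps manager_id but drops manager_name,
            # since the WHERE clause equates it with engineer_name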
) + s = select([engineers, managers]).where( + engineers.c.engineer_name == managers.c.manager_name + ) + eq_( + util.column_set(sql_util.reduce_columns(list(s.c), s)), + util.column_set( + [s.c.engineer_id, s.c.engineer_name, s.c.manager_id] + ), + ) def test_reduce_generation(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, primary_key=True), - Column('y', Integer)) - t2 = Table('t2', m, Column('z', Integer, ForeignKey('t1.x')), - Column('q', Integer)) + t1 = Table( + "t1", + m, + Column("x", Integer, primary_key=True), + Column("y", Integer), + ) + t2 = Table( + "t2", + m, + Column("z", Integer, ForeignKey("t1.x")), + Column("q", Integer), + ) s1 = select([t1, t2]) s2 = s1.reduce_columns(only_synonyms=False) - eq_( - set(s2.inner_columns), - set([t1.c.x, t1.c.y, t2.c.q]) - ) + eq_(set(s2.inner_columns), set([t1.c.x, t1.c.y, t2.c.q])) s2 = s1.reduce_columns() - eq_( - set(s2.inner_columns), - set([t1.c.x, t1.c.y, t2.c.z, t2.c.q]) - ) + eq_(set(s2.inner_columns), set([t1.c.x, t1.c.y, t2.c.z, t2.c.q])) def test_reduce_only_synonym_fk(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, primary_key=True), - Column('y', Integer)) - t2 = Table('t2', m, Column('x', Integer, ForeignKey('t1.x')), - Column('q', Integer, ForeignKey('t1.y'))) + t1 = Table( + "t1", + m, + Column("x", Integer, primary_key=True), + Column("y", Integer), + ) + t2 = Table( + "t2", + m, + Column("x", Integer, ForeignKey("t1.x")), + Column("q", Integer, ForeignKey("t1.y")), + ) s1 = select([t1, t2]) s1 = s1.reduce_columns(only_synonyms=True) - eq_( - set(s1.c), - set([s1.c.x, s1.c.y, s1.c.q]) - ) + eq_(set(s1.c), set([s1.c.x, s1.c.y, s1.c.q])) def test_reduce_only_synonym_lineage(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, primary_key=True), - Column('y', Integer), - Column('z', Integer) - ) + t1 = Table( + "t1", + m, + Column("x", Integer, primary_key=True), + Column("y", Integer), + Column("z", Integer), + ) # test that the first appearance in the columns clause # wins - t1 is first, t1.c.x wins s1 = select([t1]) s2 = select([t1, s1]).where(t1.c.x == s1.c.x).where(s1.c.y == t1.c.z) eq_( set(s2.reduce_columns().inner_columns), - set([t1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z]) + set([t1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z]), ) # reverse order, s1.c.x wins @@ -1455,91 +1580,129 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): s2 = select([s1, t1]).where(t1.c.x == s1.c.x).where(s1.c.y == t1.c.z) eq_( set(s2.reduce_columns().inner_columns), - set([s1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z]) + set([s1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z]), ) def test_reduce_aliased_join(self): metadata = MetaData() people = Table( - 'people', metadata, Column( - 'person_id', Integer, Sequence( - 'person_id_seq', optional=True), primary_key=True), Column( - 'name', String(50)), Column( - 'type', String(30))) + "people", + metadata, + Column( + "person_id", + Integer, + Sequence("person_id_seq", optional=True), + primary_key=True, + ), + Column("name", String(50)), + Column("type", String(30)), + ) engineers = Table( - 'engineers', + "engineers", metadata, - Column('person_id', Integer, ForeignKey('people.person_id' - ), primary_key=True), - Column('status', String(30)), - Column('engineer_name', String(50)), - Column('primary_language', String(50)), + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("engineer_name", String(50)), + Column("primary_language", String(50)), ) managers = Table( - 
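            # the three person_id columns selected from pjoin below reduce
            # to people_person_id, the root of the foreign key chain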
'managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), - primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50))) - pjoin = \ - people.outerjoin(engineers).outerjoin(managers).\ - select(use_labels=True).alias('pjoin' - ) - eq_(util.column_set(sql_util.reduce_columns( - [pjoin.c.people_person_id, pjoin.c.engineers_person_id, - pjoin.c.managers_person_id])), - util.column_set([pjoin.c.people_person_id])) + "managers", + metadata, + Column( + "person_id", + Integer, + ForeignKey("people.person_id"), + primary_key=True, + ), + Column("status", String(30)), + Column("manager_name", String(50)), + ) + pjoin = ( + people.outerjoin(engineers) + .outerjoin(managers) + .select(use_labels=True) + .alias("pjoin") + ) + eq_( + util.column_set( + sql_util.reduce_columns( + [ + pjoin.c.people_person_id, + pjoin.c.engineers_person_id, + pjoin.c.managers_person_id, + ] + ) + ), + util.column_set([pjoin.c.people_person_id]), + ) def test_reduce_aliased_union(self): metadata = MetaData() item_table = Table( - 'item', + "item", metadata, Column( - 'id', - Integer, - ForeignKey('base_item.id'), - primary_key=True), - Column( - 'dummy', - Integer, - default=0)) + "id", Integer, ForeignKey("base_item.id"), primary_key=True + ), + Column("dummy", Integer, default=0), + ) base_item_table = Table( - 'base_item', metadata, Column( - 'id', Integer, primary_key=True), Column( - 'child_name', String(255), default=None)) + "base_item", + metadata, + Column("id", Integer, primary_key=True), + Column("child_name", String(255), default=None), + ) from sqlalchemy.orm.util import polymorphic_union - item_join = polymorphic_union({ - 'BaseItem': - base_item_table.select( - base_item_table.c.child_name - == 'BaseItem'), - 'Item': base_item_table.join(item_table)}, - None, 'item_join') - eq_(util.column_set(sql_util.reduce_columns([item_join.c.id, - item_join.c.dummy, - item_join.c.child_name])), - util.column_set([item_join.c.id, - item_join.c.dummy, - item_join.c.child_name])) + + item_join = polymorphic_union( + { + "BaseItem": base_item_table.select( + base_item_table.c.child_name == "BaseItem" + ), + "Item": base_item_table.join(item_table), + }, + None, + "item_join", + ) + eq_( + util.column_set( + sql_util.reduce_columns( + [item_join.c.id, item_join.c.dummy, item_join.c.child_name] + ) + ), + util.column_set( + [item_join.c.id, item_join.c.dummy, item_join.c.child_name] + ), + ) def test_reduce_aliased_union_2(self): metadata = MetaData() - page_table = Table('page', metadata, Column('id', Integer, - primary_key=True)) - magazine_page_table = Table('magazine_page', metadata, - Column('page_id', Integer, - ForeignKey('page.id'), - primary_key=True)) + page_table = Table( + "page", metadata, Column("id", Integer, primary_key=True) + ) + magazine_page_table = Table( + "magazine_page", + metadata, + Column( + "page_id", Integer, ForeignKey("page.id"), primary_key=True + ), + ) classified_page_table = Table( - 'classified_page', + "classified_page", metadata, Column( - 'magazine_page_id', + "magazine_page_id", Integer, - ForeignKey('magazine_page.page_id'), - primary_key=True)) + ForeignKey("magazine_page.page_id"), + primary_key=True, + ), + ) # this is essentially the union formed by the ORM's # polymorphic_union function. 
we define two versions with @@ -1549,25 +1712,33 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): # classified_page.magazine_page_id pjoin = union( - select([ - page_table.c.id, - magazine_page_table.c.page_id, - classified_page_table.c.magazine_page_id - ]). - select_from( - page_table.join(magazine_page_table). - join(classified_page_table)), - - select([ - page_table.c.id, - magazine_page_table.c.page_id, - cast(null(), Integer).label('magazine_page_id') - ]). - select_from(page_table.join(magazine_page_table)) - ).alias('pjoin') - eq_(util.column_set(sql_util.reduce_columns( - [pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), - util.column_set([pjoin.c.id])) + select( + [ + page_table.c.id, + magazine_page_table.c.page_id, + classified_page_table.c.magazine_page_id, + ] + ).select_from( + page_table.join(magazine_page_table).join( + classified_page_table + ) + ), + select( + [ + page_table.c.id, + magazine_page_table.c.page_id, + cast(null(), Integer).label("magazine_page_id"), + ] + ).select_from(page_table.join(magazine_page_table)), + ).alias("pjoin") + eq_( + util.column_set( + sql_util.reduce_columns( + [pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id] + ) + ), + util.column_set([pjoin.c.id]), + ) # the first selectable has a CAST, which is a placeholder for # classified_page.magazine_page_id in the second selectable. @@ -1576,45 +1747,70 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): # currently makes the external column look like that of the # first selectable only. - pjoin = union(select([ - page_table.c.id, - magazine_page_table.c.page_id, - cast(null(), Integer).label('magazine_page_id') - ]). - select_from(page_table.join(magazine_page_table)), - - select([ - page_table.c.id, - magazine_page_table.c.page_id, - classified_page_table.c.magazine_page_id - ]). - select_from(page_table.join(magazine_page_table). 
- join(classified_page_table)) - ).alias('pjoin') - eq_(util.column_set(sql_util.reduce_columns( - [pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), - util.column_set([pjoin.c.id])) + pjoin = union( + select( + [ + page_table.c.id, + magazine_page_table.c.page_id, + cast(null(), Integer).label("magazine_page_id"), + ] + ).select_from(page_table.join(magazine_page_table)), + select( + [ + page_table.c.id, + magazine_page_table.c.page_id, + classified_page_table.c.magazine_page_id, + ] + ).select_from( + page_table.join(magazine_page_table).join( + classified_page_table + ) + ), + ).alias("pjoin") + eq_( + util.column_set( + sql_util.reduce_columns( + [pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id] + ) + ), + util.column_set([pjoin.c.id]), + ) class DerivedTest(fixtures.TestBase, AssertsExecutionResults): - def test_table(self): meta = MetaData() - t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), - Column('c2', String(30))) - t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), - Column('c2', String(30))) + t1 = Table( + "t1", + meta, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) + t2 = Table( + "t2", + meta, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) assert t1.is_derived_from(t1) assert not t2.is_derived_from(t1) def test_alias(self): meta = MetaData() - t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), - Column('c2', String(30))) - t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), - Column('c2', String(30))) + t1 = Table( + "t1", + meta, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) + t2 = Table( + "t2", + meta, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) assert t1.alias().is_derived_from(t1) assert not t2.alias().is_derived_from(t1) @@ -1624,44 +1820,43 @@ class DerivedTest(fixtures.TestBase, AssertsExecutionResults): def test_select(self): meta = MetaData() - t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), - Column('c2', String(30))) - t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), - Column('c2', String(30))) + t1 = Table( + "t1", + meta, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) + t2 = Table( + "t2", + meta, + Column("c1", Integer, primary_key=True), + Column("c2", String(30)), + ) assert t1.select().is_derived_from(t1) assert not t2.select().is_derived_from(t1) assert select([t1, t2]).is_derived_from(t1) - assert t1.select().alias('foo').is_derived_from(t1) - assert select([t1, t2]).alias('foo').is_derived_from(t1) - assert not t2.select().alias('foo').is_derived_from(t1) + assert t1.select().alias("foo").is_derived_from(t1) + assert select([t1, t2]).alias("foo").is_derived_from(t1) + assert not t2.select().alias("foo").is_derived_from(t1) class AnnotationsTest(fixtures.TestBase): - def test_hashing(self): - t = table('t', column('x')) + t = table("t", column("x")) a = t.alias() s = t.select() s2 = a.select() - for obj in [ - t, - t.c.x, - a, - s, - s2, - t.c.x > 1, - (t.c.x > 1).label(None) - ]: + for obj in [t, t.c.x, a, s, s2, t.c.x > 1, (t.c.x > 1).label(None)]: annot = obj._annotate({}) eq_(set([obj]), set([annot])) def test_compare(self): - t = table('t', column('x'), column('y')) + t = table("t", column("x"), column("y")) x_a = t.c.x._annotate({}) assert t.c.x.compare(x_a) assert x_a.compare(t.c.x) @@ -1681,40 +1876,44 @@ class AnnotationsTest(fixtures.TestBase): def test_late_name_add(self): from sqlalchemy.schema import 
Column + c1 = Column(Integer) c1_a = c1._annotate({"foo": "bar"}) - c1.name = 'somename' - eq_(c1_a.name, 'somename') + c1.name = "somename" + eq_(c1_a.name, "somename") def test_late_table_add(self): c1 = Column("foo", Integer) c1_a = c1._annotate({"foo": "bar"}) - t = Table('t', MetaData(), c1) + t = Table("t", MetaData(), c1) is_(c1_a.table, t) def test_basic_attrs(self): - t = Table('t', MetaData(), - Column('x', Integer, info={'q': 'p'}), - Column('y', Integer, key='q')) + t = Table( + "t", + MetaData(), + Column("x", Integer, info={"q": "p"}), + Column("y", Integer, key="q"), + ) x_a = t.c.x._annotate({}) y_a = t.c.q._annotate({}) - t.c.x.info['z'] = 'h' + t.c.x.info["z"] = "h" - eq_(y_a.key, 'q') + eq_(y_a.key, "q") is_(x_a.table, t) - eq_(x_a.info, {'q': 'p', 'z': 'h'}) + eq_(x_a.info, {"q": "p", "z": "h"}) eq_(t.c.x.anon_label, x_a.anon_label) def test_custom_constructions(self): from sqlalchemy.schema import Column class MyColumn(Column): - def __init__(self): - Column.__init__(self, 'foo', Integer) + Column.__init__(self, "foo", Integer) + _constructor = Column - t1 = Table('t1', MetaData(), MyColumn()) + t1 = Table("t1", MetaData(), MyColumn()) s1 = t1.select() assert isinstance(t1.c.foo, MyColumn) assert isinstance(s1.c.foo, Column) @@ -1734,8 +1933,9 @@ class AnnotationsTest(fixtures.TestBase): pass assert isinstance( - MyColumn('x', Integer)._annotate({"foo": "bar"}), - AnnotatedColumnElement) + MyColumn("x", Integer)._annotate({"foo": "bar"}), + AnnotatedColumnElement, + ) def test_custom_construction_correct_anno_expr(self): # [ticket:2918] @@ -1744,14 +1944,14 @@ class AnnotationsTest(fixtures.TestBase): class MyColumn(Column): pass - col = MyColumn('x', Integer) + col = MyColumn("x", Integer) binary_1 = col == 5 - col_anno = MyColumn('x', Integer)._annotate({"foo": "bar"}) + col_anno = MyColumn("x", Integer)._annotate({"foo": "bar"}) binary_2 = col_anno == 5 eq_(binary_2.left._annotations, {"foo": "bar"}) def test_annotated_corresponding_column(self): - table1 = table('table1', column("col1")) + table1 = table("table1", column("col1")) s1 = select([table1.c.col1]) t1 = s1._annotate({}) @@ -1766,93 +1966,107 @@ class AnnotationsTest(fixtures.TestBase): inner = select([s1]) - assert inner.corresponding_column( - t2.c.col1, - require_embedded=False) is inner.corresponding_column( - t2.c.col1, - require_embedded=True) is inner.c.col1 - assert inner.corresponding_column( - t1.c.col1, - require_embedded=False) is inner.corresponding_column( - t1.c.col1, - require_embedded=True) is inner.c.col1 + assert ( + inner.corresponding_column(t2.c.col1, require_embedded=False) + is inner.corresponding_column(t2.c.col1, require_embedded=True) + is inner.c.col1 + ) + assert ( + inner.corresponding_column(t1.c.col1, require_embedded=False) + is inner.corresponding_column(t1.c.col1, require_embedded=True) + is inner.c.col1 + ) def test_annotated_visit(self): - table1 = table('table1', column("col1"), column("col2")) + table1 = table("table1", column("col1"), column("col2")) - bin = table1.c.col1 == bindparam('foo', value=None) + bin = table1.c.col1 == bindparam("foo", value=None) assert str(bin) == "table1.col1 = :foo" def visit_binary(b): b.right = table1.c.col2 - b2 = visitors.cloned_traverse(bin, {}, {'binary': visit_binary}) + b2 = visitors.cloned_traverse(bin, {}, {"binary": visit_binary}) assert str(b2) == "table1.col1 = table1.col2" - b3 = visitors.cloned_traverse(bin._annotate({}), {}, {'binary': - visit_binary}) - assert str(b3) == 'table1.col1 = table1.col2' + b3 = 
visitors.cloned_traverse( + bin._annotate({}), {}, {"binary": visit_binary} + ) + assert str(b3) == "table1.col1 = table1.col2" def visit_binary(b): - b.left = bindparam('bar') + b.left = bindparam("bar") - b4 = visitors.cloned_traverse(b2, {}, {'binary': visit_binary}) + b4 = visitors.cloned_traverse(b2, {}, {"binary": visit_binary}) assert str(b4) == ":bar = table1.col2" - b5 = visitors.cloned_traverse(b3, {}, {'binary': visit_binary}) + b5 = visitors.cloned_traverse(b3, {}, {"binary": visit_binary}) assert str(b5) == ":bar = table1.col2" def test_label_accessors(self): - t1 = table('t1', column('c1')) + t1 = table("t1", column("c1")) l1 = t1.c.c1.label(None) is_(l1._order_by_label_element, l1) l1a = l1._annotate({"foo": "bar"}) is_(l1a._order_by_label_element, l1a) def test_annotate_aliased(self): - t1 = table('t1', column('c1')) - s = select([(t1.c.c1 + 3).label('bat')]) + t1 = table("t1", column("c1")) + s = select([(t1.c.c1 + 3).label("bat")]) a = s.alias() - a = sql_util._deep_annotate(a, {'foo': 'bar'}) - eq_(a._annotations['foo'], 'bar') - eq_(a.element._annotations['foo'], 'bar') + a = sql_util._deep_annotate(a, {"foo": "bar"}) + eq_(a._annotations["foo"], "bar") + eq_(a.element._annotations["foo"], "bar") def test_annotate_expressions(self): - table1 = table('table1', column('col1'), column('col2')) - for expr, expected in [(table1.c.col1, 'table1.col1'), - (table1.c.col1 == 5, - 'table1.col1 = :col1_1'), - (table1.c.col1.in_([2, 3, 4]), - 'table1.col1 IN (:col1_1, :col1_2, ' - ':col1_3)')]: + table1 = table("table1", column("col1"), column("col2")) + for expr, expected in [ + (table1.c.col1, "table1.col1"), + (table1.c.col1 == 5, "table1.col1 = :col1_1"), + ( + table1.c.col1.in_([2, 3, 4]), + "table1.col1 IN (:col1_1, :col1_2, " ":col1_3)", + ), + ]: eq_(str(expr), expected) eq_(str(expr._annotate({})), expected) eq_(str(sql_util._deep_annotate(expr, {})), expected) - eq_(str(sql_util._deep_annotate( - expr, {}, exclude=[table1.c.col1])), expected) + eq_( + str( + sql_util._deep_annotate(expr, {}, exclude=[table1.c.col1]) + ), + expected, + ) def test_deannotate(self): - table1 = table('table1', column("col1"), column("col2")) + table1 = table("table1", column("col1"), column("col2")) - bin = table1.c.col1 == bindparam('foo', value=None) + bin = table1.c.col1 == bindparam("foo", value=None) - b2 = sql_util._deep_annotate(bin, {'_orm_adapt': True}) + b2 = sql_util._deep_annotate(bin, {"_orm_adapt": True}) b3 = sql_util._deep_deannotate(b2) b4 = sql_util._deep_deannotate(bin) for elem in (b2._annotations, b2.left._annotations): - assert '_orm_adapt' in elem + assert "_orm_adapt" in elem - for elem in b3._annotations, b3.left._annotations, \ - b4._annotations, b4.left._annotations: + for elem in ( + b3._annotations, + b3.left._annotations, + b4._annotations, + b4.left._annotations, + ): assert elem == {} assert b2.left is not bin.left assert b3.left is not b2.left and b2.left is not bin.left assert b4.left is bin.left # since column is immutable # deannotate copies the element - assert bin.right is not b2.right and b2.right is not b3.right \ + assert ( + bin.right is not b2.right + and b2.right is not b3.right and b3.right is not b4.right + ) def test_annotate_unique_traversal(self): """test that items are copied only once during @@ -1864,16 +2078,14 @@ class AnnotationsTest(fixtures.TestBase): case now, as deannotate is making clones again in some cases. 
""" - table1 = table('table1', column('x')) - table2 = table('table2', column('y')) + table1 = table("table1", column("x")) + table2 = table("table2", column("y")) a1 = table1.alias() - s = select([a1.c.x]).select_from( - a1.join(table2, a1.c.x == table2.c.y) - ) + s = select([a1.c.x]).select_from(a1.join(table2, a1.c.x == table2.c.y)) for sel in ( sql_util._deep_deannotate(s), visitors.cloned_traverse(s, {}, {}), - visitors.replacement_traverse(s, {}, lambda x: None) + visitors.replacement_traverse(s, {}, lambda x: None), ): # the columns clause isn't changed at all assert sel._raw_columns[0].table is a1 @@ -1886,7 +2098,7 @@ class AnnotationsTest(fixtures.TestBase): # when encountered. for sel in ( sql_util._deep_deannotate(s, {"foo": "bar"}), - sql_util._deep_annotate(s, {'foo': 'bar'}), + sql_util._deep_annotate(s, {"foo": "bar"}), ): assert sel._froms[0] is not sel._froms[1].left @@ -1899,7 +2111,7 @@ class AnnotationsTest(fixtures.TestBase): preserving them when deep_annotate is run on them. """ - t1 = table('table1', column("col1"), column("col2")) + t1 = table("table1", column("col1"), column("col2")) s = select([t1.c.col1._annotate({"foo": "bar"})]) s2 = select([t1.c.col1._annotate({"bat": "hoho"})]) s3 = s.union(s2) @@ -1907,42 +2119,40 @@ class AnnotationsTest(fixtures.TestBase): eq_( sel.selects[0]._raw_columns[0]._annotations, - {"foo": "bar", "new": "thing"} + {"foo": "bar", "new": "thing"}, ) eq_( sel.selects[1]._raw_columns[0]._annotations, - {"bat": "hoho", "new": "thing"} + {"bat": "hoho", "new": "thing"}, ) def test_deannotate_2(self): - table1 = table('table1', column("col1"), column("col2")) - j = table1.c.col1._annotate({"remote": True}) == \ - table1.c.col2._annotate({"local": True}) + table1 = table("table1", column("col1"), column("col2")) + j = table1.c.col1._annotate( + {"remote": True} + ) == table1.c.col2._annotate({"local": True}) j2 = sql_util._deep_deannotate(j) - eq_( - j.left._annotations, {"remote": True} - ) - eq_( - j2.left._annotations, {} - ) + eq_(j.left._annotations, {"remote": True}) + eq_(j2.left._annotations, {}) def test_deannotate_3(self): - table1 = table('table1', column("col1"), column("col2"), - column("col3"), column("col4")) + table1 = table( + "table1", + column("col1"), + column("col2"), + column("col3"), + column("col4"), + ) j = and_( - table1.c.col1._annotate({"remote": True}) == - table1.c.col2._annotate({"local": True}), - table1.c.col3._annotate({"remote": True}) == - table1.c.col4._annotate({"local": True}) + table1.c.col1._annotate({"remote": True}) + == table1.c.col2._annotate({"local": True}), + table1.c.col3._annotate({"remote": True}) + == table1.c.col4._annotate({"local": True}), ) j2 = sql_util._deep_deannotate(j) - eq_( - j.clauses[0].left._annotations, {"remote": True} - ) - eq_( - j2.clauses[0].left._annotations, {} - ) + eq_(j.clauses[0].left._annotations, {"remote": True}) + eq_(j2.clauses[0].left._annotations, {}) def test_annotate_fromlist_preservation(self): """test the FROM list in select still works @@ -1952,37 +2162,34 @@ class AnnotationsTest(fixtures.TestBase): #2453, continued """ - table1 = table('table1', column('x')) - table2 = table('table2', column('y')) + table1 = table("table1", column("x")) + table2 = table("table2", column("y")) a1 = table1.alias() - s = select([a1.c.x]).select_from( - a1.join(table2, a1.c.x == table2.c.y) - ) + s = select([a1.c.x]).select_from(a1.join(table2, a1.c.x == table2.c.y)) assert_s = select([select([s])]) for fn in ( sql_util._deep_deannotate, - lambda s: 
sql_util._deep_annotate(s, {'foo': 'bar'}), + lambda s: sql_util._deep_annotate(s, {"foo": "bar"}), lambda s: visitors.cloned_traverse(s, {}, {}), - lambda s: visitors.replacement_traverse(s, {}, lambda x: None) + lambda s: visitors.replacement_traverse(s, {}, lambda x: None), ): sel = fn(select([fn(select([fn(s)]))])) eq_(str(assert_s), str(sel)) def test_bind_unique_test(self): - table('t', column('a'), column('b')) + table("t", column("a"), column("b")) b = bindparam("bind", value="x", unique=True) # the annotation of "b" should render the # same. The "unique" test in compiler should # also pass, [ticket:2425] - eq_(str(or_(b, b._annotate({"foo": "bar"}))), - ":bind_1 OR :bind_1") + eq_(str(or_(b, b._annotate({"foo": "bar"}))), ":bind_1 OR :bind_1") def test_comparators_cleaned_out_construction(self): - c = column('a') + c = column("a") comp1 = c.comparator @@ -1991,7 +2198,7 @@ class AnnotationsTest(fixtures.TestBase): assert comp1 is not comp2 def test_comparators_cleaned_out_reannotate(self): - c = column('a') + c = column("a") c1 = c._annotate({"foo": "bar"}) comp1 = c1.comparator @@ -2002,7 +2209,7 @@ class AnnotationsTest(fixtures.TestBase): assert comp1 is not comp2 def test_comparator_cleanout_integration(self): - c = column('a') + c = column("a") c1 = c._annotate({"foo": "bar"}) comp1 = c1.comparator @@ -2018,8 +2225,8 @@ class ReprTest(fixtures.TestBase): for obj in [ elements.Cast(1, 2), elements.TypeClause(String()), - elements.ColumnClause('x'), - elements.BindParameter('q'), + elements.ColumnClause("x"), + elements.BindParameter("q"), elements.Null(), elements.True_(), elements.False_(), @@ -2027,22 +2234,21 @@ class ReprTest(fixtures.TestBase): elements.BooleanClauseList.and_(), elements.Tuple(), elements.Case([]), - elements.Extract('foo', column('x')), - elements.UnaryExpression(column('x')), - elements.Grouping(column('x')), + elements.Extract("foo", column("x")), + elements.UnaryExpression(column("x")), + elements.Grouping(column("x")), elements.Over(func.foo()), - elements.Label('q', column('x')), + elements.Label("q", column("x")), ]: repr(obj) class WithLabelsTest(fixtures.TestBase): - def _assert_labels_warning(self, s): assert_raises_message( exc.SAWarning, r"replaced by Column.*, which has the same key", - lambda: s.c + lambda: s.c, ) def _assert_result_keys(self, s, keys): @@ -2055,147 +2261,128 @@ class WithLabelsTest(fixtures.TestBase): def _names_overlap(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer)) - t2 = Table('t2', m, Column('x', Integer)) + t1 = Table("t1", m, Column("x", Integer)) + t2 = Table("t2", m, Column("x", Integer)) return select([t1, t2]) def test_names_overlap_nolabel(self): sel = self._names_overlap() self._assert_labels_warning(sel) - self._assert_result_keys(sel, ['x']) + self._assert_result_keys(sel, ["x"]) def test_names_overlap_label(self): sel = self._names_overlap().apply_labels() - eq_( - list(sel.c.keys()), - ['t1_x', 't2_x'] - ) - self._assert_result_keys(sel, ['t1_x', 't2_x']) + eq_(list(sel.c.keys()), ["t1_x", "t2_x"]) + self._assert_result_keys(sel, ["t1_x", "t2_x"]) def _names_overlap_keys_dont(self): m = MetaData() - t1 = Table('t1', m, Column('x', Integer, key='a')) - t2 = Table('t2', m, Column('x', Integer, key='b')) + t1 = Table("t1", m, Column("x", Integer, key="a")) + t2 = Table("t2", m, Column("x", Integer, key="b")) return select([t1, t2]) def test_names_overlap_keys_dont_nolabel(self): sel = self._names_overlap_keys_dont() - eq_( - list(sel.c.keys()), - ['a', 'b'] - ) - self._assert_result_keys(sel, 
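            # .c keys follow Column.key ("a", "b"); both columns still
            # render under their shared SQL name "x" in the result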
['x']) + eq_(list(sel.c.keys()), ["a", "b"]) + self._assert_result_keys(sel, ["x"]) def test_names_overlap_keys_dont_label(self): sel = self._names_overlap_keys_dont().apply_labels() - eq_( - list(sel.c.keys()), - ['t1_a', 't2_b'] - ) - self._assert_result_keys(sel, ['t1_x', 't2_x']) + eq_(list(sel.c.keys()), ["t1_a", "t2_b"]) + self._assert_result_keys(sel, ["t1_x", "t2_x"]) def _labels_overlap(self): m = MetaData() - t1 = Table('t', m, Column('x_id', Integer)) - t2 = Table('t_x', m, Column('id', Integer)) + t1 = Table("t", m, Column("x_id", Integer)) + t2 = Table("t_x", m, Column("id", Integer)) return select([t1, t2]) def test_labels_overlap_nolabel(self): sel = self._labels_overlap() - eq_( - list(sel.c.keys()), - ['x_id', 'id'] - ) - self._assert_result_keys(sel, ['x_id', 'id']) + eq_(list(sel.c.keys()), ["x_id", "id"]) + self._assert_result_keys(sel, ["x_id", "id"]) def test_labels_overlap_label(self): sel = self._labels_overlap().apply_labels() t2 = sel.froms[1] - eq_( - list(sel.c.keys()), - ['t_x_id', t2.c.id.anon_label] - ) - self._assert_result_keys(sel, ['t_x_id', 'id_1']) - self._assert_subq_result_keys(sel, ['t_x_id', 'id_1']) + eq_(list(sel.c.keys()), ["t_x_id", t2.c.id.anon_label]) + self._assert_result_keys(sel, ["t_x_id", "id_1"]) + self._assert_subq_result_keys(sel, ["t_x_id", "id_1"]) def _labels_overlap_keylabels_dont(self): m = MetaData() - t1 = Table('t', m, Column('x_id', Integer, key='a')) - t2 = Table('t_x', m, Column('id', Integer, key='b')) + t1 = Table("t", m, Column("x_id", Integer, key="a")) + t2 = Table("t_x", m, Column("id", Integer, key="b")) return select([t1, t2]) def test_labels_overlap_keylabels_dont_nolabel(self): sel = self._labels_overlap_keylabels_dont() - eq_(list(sel.c.keys()), ['a', 'b']) - self._assert_result_keys(sel, ['x_id', 'id']) + eq_(list(sel.c.keys()), ["a", "b"]) + self._assert_result_keys(sel, ["x_id", "id"]) def test_labels_overlap_keylabels_dont_label(self): sel = self._labels_overlap_keylabels_dont().apply_labels() - eq_(list(sel.c.keys()), ['t_a', 't_x_b']) - self._assert_result_keys(sel, ['t_x_id', 'id_1']) + eq_(list(sel.c.keys()), ["t_a", "t_x_b"]) + self._assert_result_keys(sel, ["t_x_id", "id_1"]) def _keylabels_overlap_labels_dont(self): m = MetaData() - t1 = Table('t', m, Column('a', Integer, key='x_id')) - t2 = Table('t_x', m, Column('b', Integer, key='id')) + t1 = Table("t", m, Column("a", Integer, key="x_id")) + t2 = Table("t_x", m, Column("b", Integer, key="id")) return select([t1, t2]) def test_keylabels_overlap_labels_dont_nolabel(self): sel = self._keylabels_overlap_labels_dont() - eq_(list(sel.c.keys()), ['x_id', 'id']) - self._assert_result_keys(sel, ['a', 'b']) + eq_(list(sel.c.keys()), ["x_id", "id"]) + self._assert_result_keys(sel, ["a", "b"]) def test_keylabels_overlap_labels_dont_label(self): sel = self._keylabels_overlap_labels_dont().apply_labels() t2 = sel.froms[1] - eq_(list(sel.c.keys()), ['t_x_id', t2.c.id.anon_label]) - self._assert_result_keys(sel, ['t_a', 't_x_b']) - self._assert_subq_result_keys(sel, ['t_a', 't_x_b']) + eq_(list(sel.c.keys()), ["t_x_id", t2.c.id.anon_label]) + self._assert_result_keys(sel, ["t_a", "t_x_b"]) + self._assert_subq_result_keys(sel, ["t_a", "t_x_b"]) def _keylabels_overlap_labels_overlap(self): m = MetaData() - t1 = Table('t', m, Column('x_id', Integer, key='x_a')) - t2 = Table('t_x', m, Column('id', Integer, key='a')) + t1 = Table("t", m, Column("x_id", Integer, key="x_a")) + t2 = Table("t_x", m, Column("id", Integer, key="a")) return select([t1, t2]) def 
test_keylabels_overlap_labels_overlap_nolabel(self): sel = self._keylabels_overlap_labels_overlap() - eq_(list(sel.c.keys()), ['x_a', 'a']) - self._assert_result_keys(sel, ['x_id', 'id']) - self._assert_subq_result_keys(sel, ['x_id', 'id']) + eq_(list(sel.c.keys()), ["x_a", "a"]) + self._assert_result_keys(sel, ["x_id", "id"]) + self._assert_subq_result_keys(sel, ["x_id", "id"]) def test_keylabels_overlap_labels_overlap_label(self): sel = self._keylabels_overlap_labels_overlap().apply_labels() t2 = sel.froms[1] - eq_(list(sel.c.keys()), ['t_x_a', t2.c.a.anon_label]) - self._assert_result_keys(sel, ['t_x_id', 'id_1']) - self._assert_subq_result_keys(sel, ['t_x_id', 'id_1']) + eq_(list(sel.c.keys()), ["t_x_a", t2.c.a.anon_label]) + self._assert_result_keys(sel, ["t_x_id", "id_1"]) + self._assert_subq_result_keys(sel, ["t_x_id", "id_1"]) def _keys_overlap_names_dont(self): m = MetaData() - t1 = Table('t1', m, Column('a', Integer, key='x')) - t2 = Table('t2', m, Column('b', Integer, key='x')) + t1 = Table("t1", m, Column("a", Integer, key="x")) + t2 = Table("t2", m, Column("b", Integer, key="x")) return select([t1, t2]) def test_keys_overlap_names_dont_nolabel(self): sel = self._keys_overlap_names_dont() self._assert_labels_warning(sel) - self._assert_result_keys(sel, ['a', 'b']) + self._assert_result_keys(sel, ["a", "b"]) def test_keys_overlap_names_dont_label(self): sel = self._keys_overlap_names_dont().apply_labels() - eq_( - list(sel.c.keys()), - ['t1_x', 't2_x'] - ) - self._assert_result_keys(sel, ['t1_a', 't2_b']) + eq_(list(sel.c.keys()), ["t1_x", "t2_x"]) + self._assert_result_keys(sel, ["t1_a", "t2_b"]) class ResultMapTest(fixtures.TestBase): - def _fixture(self): m = MetaData() - t = Table('t', m, Column('x', Integer), Column('y', Integer)) + t = Table("t", m, Column("x", Integer), Column("y", Integer)) return t def _mapping(self, stmt): @@ -2208,7 +2395,7 @@ class ResultMapTest(fixtures.TestBase): def test_select_label_alt_name(self): t = self._fixture() - l1, l2 = t.c.x.label('a'), t.c.y.label('b') + l1, l2 = t.c.x.label("a"), t.c.y.label("b") s = select([l1, l2]) mapping = self._mapping(s) assert l1 in mapping @@ -2217,7 +2404,7 @@ class ResultMapTest(fixtures.TestBase): def test_select_alias_label_alt_name(self): t = self._fixture() - l1, l2 = t.c.x.label('a'), t.c.y.label('b') + l1, l2 = t.c.x.label("a"), t.c.y.label("b") s = select([l1, l2]).alias() mapping = self._mapping(s) assert l1 in mapping @@ -2253,7 +2440,7 @@ class ResultMapTest(fixtures.TestBase): x, y = t.c.x, t.c.y ta = t.alias() - l1, l2 = ta.c.x.label('a'), ta.c.y.label('b') + l1, l2 = ta.c.x.label("a"), ta.c.y.label("b") s = select([l1, l2]) mapping = self._mapping(s) @@ -2268,34 +2455,40 @@ class ResultMapTest(fixtures.TestBase): assert t.c.x not in mapping eq_( [type(entry[-1]) for entry in s.compile()._result_columns], - [Boolean] + [Boolean], ) def test_plain_exists(self): expr = exists([1]) eq_(type(expr.type), Boolean) eq_( - [type(entry[-1]) for - entry in select([expr]).compile()._result_columns], - [Boolean] + [ + type(entry[-1]) + for entry in select([expr]).compile()._result_columns + ], + [Boolean], ) def test_plain_exists_negate(self): expr = ~exists([1]) eq_(type(expr.type), Boolean) eq_( - [type(entry[-1]) for - entry in select([expr]).compile()._result_columns], - [Boolean] + [ + type(entry[-1]) + for entry in select([expr]).compile()._result_columns + ], + [Boolean], ) def test_plain_exists_double_negate(self): expr = ~(~exists([1])) eq_(type(expr.type), Boolean) eq_( - [type(entry[-1]) for - 
entry in select([expr]).compile()._result_columns], - [Boolean] + [ + type(entry[-1]) + for entry in select([expr]).compile()._result_columns + ], + [Boolean], ) def test_column_subquery_plain(self): @@ -2307,7 +2500,7 @@ class ResultMapTest(fixtures.TestBase): assert s1 in mapping eq_( [type(entry[-1]) for entry in s2.compile()._result_columns], - [Integer] + [Integer], ) def test_unary_boolean(self): @@ -2315,7 +2508,7 @@ class ResultMapTest(fixtures.TestBase): s1 = select([not_(True)], use_labels=True) eq_( [type(entry[-1]) for entry in s1.compile()._result_columns], - [Boolean] + [Boolean], ) @@ -2323,19 +2516,15 @@ class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" def _assert_legacy(self, leg, read=False, nowait=False): - t = table('t', column('c')) + t = table("t", column("c")) s1 = select([t], for_update=leg) if leg is False: assert s1._for_update_arg is None assert s1.for_update is None else: - eq_( - s1._for_update_arg.read, read - ) - eq_( - s1._for_update_arg.nowait, nowait - ) + eq_(s1._for_update_arg.read, read) + eq_(s1._for_update_arg.nowait, nowait) eq_(s1.for_update, leg) def test_false_legacy(self): @@ -2354,28 +2543,30 @@ class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL): self._assert_legacy("read_nowait", read=True, nowait=True) def test_legacy_setter(self): - t = table('t', column('c')) + t = table("t", column("c")) s = select([t]) - s.for_update = 'nowait' + s.for_update = "nowait" eq_(s._for_update_arg.nowait, True) def test_basic_clone(self): - t = table('t', column('c')) + t = table("t", column("c")) s = select([t]).with_for_update(read=True, of=t.c.c) s2 = visitors.ReplacingCloningVisitor().traverse(s) assert s2._for_update_arg is not s._for_update_arg eq_(s2._for_update_arg.read, True) eq_(s2._for_update_arg.of, [t.c.c]) - self.assert_compile(s2, - "SELECT t.c FROM t FOR SHARE OF t", - dialect="postgresql") + self.assert_compile( + s2, "SELECT t.c FROM t FOR SHARE OF t", dialect="postgresql" + ) def test_adapt(self): - t = table('t', column('c')) + t = table("t", column("c")) s = select([t]).with_for_update(read=True, of=t.c.c) a = t.alias() s2 = sql_util.ClauseAdapter(a).traverse(s) eq_(s2._for_update_arg.of, [a.c.c]) - self.assert_compile(s2, - "SELECT t_1.c FROM t AS t_1 FOR SHARE OF t_1", - dialect="postgresql") + self.assert_compile( + s2, + "SELECT t_1.c FROM t AS t_1 FOR SHARE OF t_1", + dialect="postgresql", + ) diff --git a/test/sql/test_tablesample.py b/test/sql/test_tablesample.py index 879e83182d..712450d9f2 100644 --- a/test/sql/test_tablesample.py +++ b/test/sql/test_tablesample.py @@ -16,10 +16,13 @@ class TableSampleTest(fixtures.TablesTest, AssertsCompiledSQL): @classmethod def define_tables(cls, metadata): - Table('people', metadata, - Column('people_id', Integer, primary_key=True), - Column('age', Integer), - Column('name', String(30))) + Table( + "people", + metadata, + Column("people_id", Integer, primary_key=True), + Column("age", Integer), + Column("name", String(30)), + ) def test_standalone(self): table1 = self.tables.people @@ -27,27 +30,28 @@ class TableSampleTest(fixtures.TablesTest, AssertsCompiledSQL): # no special alias handling even though clause is not in the # context of a FROM clause self.assert_compile( - tablesample(table1, 1, name='alias'), - 'people AS alias TABLESAMPLE system(:system_1)' + tablesample(table1, 1, name="alias"), + "people AS alias TABLESAMPLE system(:system_1)", ) self.assert_compile( - table1.tablesample(1, name='alias'), - 'people AS alias TABLESAMPLE 
system(:system_1)' + table1.tablesample(1, name="alias"), + "people AS alias TABLESAMPLE system(:system_1)", ) self.assert_compile( - tablesample(table1, func.bernoulli(1), name='alias', - seed=func.random()), - 'people AS alias TABLESAMPLE bernoulli(:bernoulli_1) ' - 'REPEATABLE (random())' + tablesample( + table1, func.bernoulli(1), name="alias", seed=func.random() + ), + "people AS alias TABLESAMPLE bernoulli(:bernoulli_1) " + "REPEATABLE (random())", ) def test_select_from(self): table1 = self.tables.people self.assert_compile( - select([table1.tablesample(text('1'), name='alias').c.people_id]), - 'SELECT alias.people_id FROM ' - 'people AS alias TABLESAMPLE system(1)' + select([table1.tablesample(text("1"), name="alias").c.people_id]), + "SELECT alias.people_id FROM " + "people AS alias TABLESAMPLE system(1)", ) diff --git a/test/sql/test_text.py b/test/sql/test_text.py index c31c22853e..34415600e8 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -1,34 +1,56 @@ """Test the TextClause and related constructs.""" -from sqlalchemy.testing import fixtures, AssertsCompiledSQL, eq_, \ - assert_raises_message, expect_warnings, assert_warnings -from sqlalchemy import text, select, Integer, String, Float, \ - bindparam, and_, func, literal_column, exc, MetaData, Table, Column,\ - asc, func, desc, union, literal +from sqlalchemy.testing import ( + fixtures, + AssertsCompiledSQL, + eq_, + assert_raises_message, + expect_warnings, + assert_warnings, +) +from sqlalchemy import ( + text, + select, + Integer, + String, + Float, + bindparam, + and_, + func, + literal_column, + exc, + MetaData, + Table, + Column, + asc, + func, + desc, + union, + literal, +) from sqlalchemy.types import NullType from sqlalchemy.sql import table, column, util as sql_util from sqlalchemy import util -table1 = table('mytable', - column('myid', Integer), - column('name', String), - column('description', String), - ) +table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), +) table2 = table( - 'myothertable', - column('otherid', Integer), - column('othername', String), + "myothertable", column("otherid", Integer), column("othername", String) ) class CompileTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_basic(self): self.assert_compile( text("select * from foo where lala = bar"), - "select * from foo where lala = bar" + "select * from foo where lala = bar", ) @@ -37,21 +59,24 @@ class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL): """test the usage of text() implicit within the select() construct when strings are passed.""" - __dialect__ = 'default' + __dialect__ = "default" def test_select_composition_one(self): - self.assert_compile(select( - [ - literal_column("foobar(a)"), - literal_column("pk_foo_bar(syslaal)") - ], - text("a = 12"), - from_obj=[ - text("foobar left outer join lala on foobar.foo = lala.foo") - ] - ), + self.assert_compile( + select( + [ + literal_column("foobar(a)"), + literal_column("pk_foo_bar(syslaal)"), + ], + text("a = 12"), + from_obj=[ + text( + "foobar left outer join lala on foobar.foo = lala.foo" + ) + ], + ), "SELECT foobar(a), pk_foo_bar(syslaal) FROM foobar " - "left outer join lala on foobar.foo = lala.foo WHERE a = 12" + "left outer join lala on foobar.foo = lala.foo WHERE a = 12", ) def test_select_composition_two(self): @@ -62,40 +87,54 @@ class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL): 
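        # each append_whereclause() call AND-s its text() fragment onto the
        # WHERE clause, as the compiled "column1=12 AND column2=19" shows;
        # append_from() likewise appends to the FROM list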
s.append_whereclause(text("column2=19")) s = s.order_by("column1") s.append_from(text("table1")) - self.assert_compile(s, "SELECT column1, column2 FROM table1 WHERE " - "column1=12 AND column2=19 ORDER BY column1") + self.assert_compile( + s, + "SELECT column1, column2 FROM table1 WHERE " + "column1=12 AND column2=19 ORDER BY column1", + ) def test_select_composition_three(self): self.assert_compile( - select([column("column1"), column("column2")], - from_obj=table1).alias('somealias').select(), + select([column("column1"), column("column2")], from_obj=table1) + .alias("somealias") + .select(), "SELECT somealias.column1, somealias.column2 FROM " - "(SELECT column1, column2 FROM mytable) AS somealias" + "(SELECT column1, column2 FROM mytable) AS somealias", ) def test_select_composition_four(self): # test that use_labels doesn't interfere with literal columns self.assert_compile( - select([ - text("column1"), column("column2"), - column("column3").label("bar"), table1.c.myid], + select( + [ + text("column1"), + column("column2"), + column("column3").label("bar"), + table1.c.myid, + ], from_obj=table1, - use_labels=True), + use_labels=True, + ), "SELECT column1, column2, column3 AS bar, " "mytable.myid AS mytable_myid " - "FROM mytable" + "FROM mytable", ) def test_select_composition_five(self): # test that use_labels doesn't interfere # with literal columns that have textual labels self.assert_compile( - select([ - text("column1 AS foobar"), text("column2 AS hoho"), - table1.c.myid], - from_obj=table1, use_labels=True), + select( + [ + text("column1 AS foobar"), + text("column2 AS hoho"), + table1.c.myid, + ], + from_obj=table1, + use_labels=True, + ), "SELECT column1 AS foobar, column2 AS hoho, " - "mytable.myid AS mytable_myid FROM mytable" + "mytable.myid AS mytable_myid FROM mytable", ) def test_select_composition_six(self): @@ -103,70 +142,84 @@ class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL): # doesn't interfere with literal columns, # exported columns don't get quoted self.assert_compile( - select([ - literal_column("column1 AS foobar"), - literal_column("column2 AS hoho"), table1.c.myid], - from_obj=[table1]).select(), + select( + [ + literal_column("column1 AS foobar"), + literal_column("column2 AS hoho"), + table1.c.myid, + ], + from_obj=[table1], + ).select(), "SELECT column1 AS foobar, column2 AS hoho, myid FROM " "(SELECT column1 AS foobar, column2 AS hoho, " - "mytable.myid AS myid FROM mytable)" + "mytable.myid AS myid FROM mytable)", ) def test_select_composition_seven(self): self.assert_compile( - select([ - literal_column('col1'), - literal_column('col2') - ], from_obj=table('tablename')).alias('myalias'), - "SELECT col1, col2 FROM tablename" + select( + [literal_column("col1"), literal_column("col2")], + from_obj=table("tablename"), + ).alias("myalias"), + "SELECT col1, col2 FROM tablename", ) def test_select_composition_eight(self): - self.assert_compile(select( - [table1.alias('t'), text("foo.f")], - text("foo.f = t.id"), - from_obj=[text("(select f from bar where lala=heyhey) foo")] - ), + self.assert_compile( + select( + [table1.alias("t"), text("foo.f")], + text("foo.f = t.id"), + from_obj=[text("(select f from bar where lala=heyhey) foo")], + ), "SELECT t.myid, t.name, t.description, foo.f FROM mytable AS t, " - "(select f from bar where lala=heyhey) foo WHERE foo.f = t.id") + "(select f from bar where lala=heyhey) foo WHERE foo.f = t.id", + ) def test_select_bundle_columns(self): - self.assert_compile(select( - [table1, table2.c.otherid, - 
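            # text() fragments mix freely with Table and Column objects in
            # both the columns list and the WHERE criterion of select()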
text("sysdate()"), text("foo, bar, lala")], - and_( - text("foo.id = foofoo(lala)"), - text("datetime(foo) = Today"), - table1.c.myid == table2.c.otherid, - ) - ), + self.assert_compile( + select( + [ + table1, + table2.c.otherid, + text("sysdate()"), + text("foo, bar, lala"), + ], + and_( + text("foo.id = foofoo(lala)"), + text("datetime(foo) = Today"), + table1.c.myid == table2.c.otherid, + ), + ), "SELECT mytable.myid, mytable.name, mytable.description, " "myothertable.otherid, sysdate(), foo, bar, lala " "FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND " - "datetime(foo) = Today AND mytable.myid = myothertable.otherid") + "datetime(foo) = Today AND mytable.myid = myothertable.otherid", + ) class BindParamTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_legacy(self): - t = text("select * from foo where lala=:bar and hoho=:whee", - bindparams=[bindparam('bar', 4), bindparam('whee', 7)]) + t = text( + "select * from foo where lala=:bar and hoho=:whee", + bindparams=[bindparam("bar", 4), bindparam("whee", 7)], + ) self.assert_compile( t, "select * from foo where lala=:bar and hoho=:whee", - checkparams={'bar': 4, 'whee': 7}, + checkparams={"bar": 4, "whee": 7}, ) def test_positional(self): t = text("select * from foo where lala=:bar and hoho=:whee") - t = t.bindparams(bindparam('bar', 4), bindparam('whee', 7)) + t = t.bindparams(bindparam("bar", 4), bindparam("whee", 7)) self.assert_compile( t, "select * from foo where lala=:bar and hoho=:whee", - checkparams={'bar': 4, 'whee': 7}, + checkparams={"bar": 4, "whee": 7}, ) def test_kw(self): @@ -176,78 +229,78 @@ class BindParamTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( t, "select * from foo where lala=:bar and hoho=:whee", - checkparams={'bar': 4, 'whee': 7}, + checkparams={"bar": 4, "whee": 7}, ) def test_positional_plus_kw(self): t = text("select * from foo where lala=:bar and hoho=:whee") - t = t.bindparams(bindparam('bar', 4), whee=7) + t = t.bindparams(bindparam("bar", 4), whee=7) self.assert_compile( t, "select * from foo where lala=:bar and hoho=:whee", - checkparams={'bar': 4, 'whee': 7}, + checkparams={"bar": 4, "whee": 7}, ) def test_literal_binds(self): t = text("select * from foo where lala=:bar and hoho=:whee") - t = t.bindparams(bindparam('bar', 4), whee='whee') + t = t.bindparams(bindparam("bar", 4), whee="whee") self.assert_compile( t, "select * from foo where lala=4 and hoho='whee'", checkparams={}, - literal_binds=True + literal_binds=True, ) def _assert_type_map(self, t, compare): - map_ = dict( - (b.key, b.type) for b in t._bindparams.values() - ) + map_ = dict((b.key, b.type) for b in t._bindparams.values()) for k in compare: assert compare[k]._type_affinity is map_[k]._type_affinity def test_typing_construction(self): t = text("select * from table :foo :bar :bat") - self._assert_type_map(t, {"foo": NullType(), - "bar": NullType(), - "bat": NullType()}) + self._assert_type_map( + t, {"foo": NullType(), "bar": NullType(), "bat": NullType()} + ) - t = t.bindparams(bindparam('foo', type_=String)) + t = t.bindparams(bindparam("foo", type_=String)) - self._assert_type_map(t, {"foo": String(), - "bar": NullType(), - "bat": NullType()}) + self._assert_type_map( + t, {"foo": String(), "bar": NullType(), "bat": NullType()} + ) - t = t.bindparams(bindparam('bar', type_=Integer)) + t = t.bindparams(bindparam("bar", type_=Integer)) - self._assert_type_map(t, {"foo": String(), - "bar": Integer(), - "bat": NullType()}) + self._assert_type_map( + 
t, {"foo": String(), "bar": Integer(), "bat": NullType()} + ) t = t.bindparams(bat=45.564) - self._assert_type_map(t, {"foo": String(), - "bar": Integer(), - "bat": Float()}) + self._assert_type_map( + t, {"foo": String(), "bar": Integer(), "bat": Float()} + ) def test_binds_compiled_named(self): self.assert_compile( - text("select * from foo where lala=:bar and hoho=:whee"). - bindparams(bar=4, whee=7), + text( + "select * from foo where lala=:bar and hoho=:whee" + ).bindparams(bar=4, whee=7), "select * from foo where lala=%(bar)s and hoho=%(whee)s", - checkparams={'bar': 4, 'whee': 7}, - dialect="postgresql" + checkparams={"bar": 4, "whee": 7}, + dialect="postgresql", ) def test_binds_compiled_positional(self): self.assert_compile( - text("select * from foo where lala=:bar and hoho=:whee"). - bindparams(bar=4, whee=7), + text( + "select * from foo where lala=:bar and hoho=:whee" + ).bindparams(bar=4, whee=7), "select * from foo where lala=? and hoho=?", - checkparams={'bar': 4, 'whee': 7}, - dialect="sqlite" + checkparams={"bar": 4, "whee": 7}, + dialect="sqlite", ) def test_missing_bind_kw(self): @@ -257,7 +310,8 @@ class BindParamTest(fixtures.TestBase, AssertsCompiledSQL): r"a bound parameter named 'bar'", text(":foo").bindparams, foo=5, - bar=7) + bar=7, + ) def test_missing_bind_posn(self): assert_raises_message( @@ -265,70 +319,66 @@ class BindParamTest(fixtures.TestBase, AssertsCompiledSQL): r"This text\(\) construct doesn't define " r"a bound parameter named 'bar'", text(":foo").bindparams, - bindparam( - 'foo', - value=5), - bindparam( - 'bar', - value=7)) + bindparam("foo", value=5), + bindparam("bar", value=7), + ) def test_escaping_colons(self): # test escaping out text() params with a backslash self.assert_compile( - text(r"select * from foo where clock='05:06:07' " - r"and mork='\:mindy'"), + text( + r"select * from foo where clock='05:06:07' " + r"and mork='\:mindy'" + ), "select * from foo where clock='05:06:07' and mork=':mindy'", checkparams={}, params={}, - dialect="postgresql" + dialect="postgresql", ) def test_escaping_double_colons(self): self.assert_compile( text( r"SELECT * FROM pg_attribute WHERE " - r"attrelid = :tab\:\:regclass"), - "SELECT * FROM pg_attribute WHERE " - "attrelid = %(tab)s::regclass", - params={'tab': None}, - dialect="postgresql" + r"attrelid = :tab\:\:regclass" + ), + "SELECT * FROM pg_attribute WHERE " "attrelid = %(tab)s::regclass", + params={"tab": None}, + dialect="postgresql", ) def test_text_in_select_nonfrom(self): - generate_series = text("generate_series(:x, :y, :z) as s(a)").\ - bindparams(x=None, y=None, z=None) + generate_series = text( + "generate_series(:x, :y, :z) as s(a)" + ).bindparams(x=None, y=None, z=None) - s = select([ - (func.current_date() + literal_column("s.a")).label("dates") - ]).select_from(generate_series) + s = select( + [(func.current_date() + literal_column("s.a")).label("dates")] + ).select_from(generate_series) self.assert_compile( s, "SELECT CURRENT_DATE + s.a AS dates FROM " "generate_series(:x, :y, :z) as s(a)", - checkparams={'y': None, 'x': None, 'z': None} + checkparams={"y": None, "x": None, "z": None}, ) self.assert_compile( s.params(x=5, y=6, z=7), "SELECT CURRENT_DATE + s.a AS dates FROM " "generate_series(:x, :y, :z) as s(a)", - checkparams={'y': 6, 'x': 5, 'z': 7} + checkparams={"y": 6, "x": 5, "z": 7}, ) def test_escaping_percent_signs(self): stmt = text("select '%' where foo like '%bar%'") self.assert_compile( - stmt, - "select '%' where foo like '%bar%'", - dialect="sqlite" + stmt, "select '%' 
where foo like '%bar%'", dialect="sqlite" ) self.assert_compile( - stmt, - "select '%%' where foo like '%%bar%%'", - dialect="mysql" + stmt, "select '%%' where foo like '%%bar%%'", dialect="mysql" ) def test_percent_signs_literal_binds(self): @@ -337,88 +387,90 @@ class BindParamTest(fixtures.TestBase, AssertsCompiledSQL): stmt, "SELECT 'percent % signs %%' AS anon_1", dialect="sqlite", - literal_binds=True + literal_binds=True, ) self.assert_compile( stmt, "SELECT 'percent %% signs %%%%' AS anon_1", dialect="mysql", - literal_binds=True + literal_binds=True, ) class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def test_basic_toplevel_resultmap_positional(self): t = text("select id, name from user").columns( - column('id', Integer), - column('name') + column("id", Integer), column("name") ) compiled = t.compile() - eq_(compiled._create_result_map(), - {'id': ('id', - (t.c.id._proxies[0], - 'id', - 'id'), - t.c.id.type), - 'name': ('name', - (t.c.name._proxies[0], - 'name', - 'name'), - t.c.name.type)}) + eq_( + compiled._create_result_map(), + { + "id": ("id", (t.c.id._proxies[0], "id", "id"), t.c.id.type), + "name": ( + "name", + (t.c.name._proxies[0], "name", "name"), + t.c.name.type, + ), + }, + ) def test_basic_toplevel_resultmap(self): t = text("select id, name from user").columns(id=Integer, name=String) compiled = t.compile() - eq_(compiled._create_result_map(), - {'id': ('id', - (t.c.id._proxies[0], - 'id', - 'id'), - t.c.id.type), - 'name': ('name', - (t.c.name._proxies[0], - 'name', - 'name'), - t.c.name.type)}) + eq_( + compiled._create_result_map(), + { + "id": ("id", (t.c.id._proxies[0], "id", "id"), t.c.id.type), + "name": ( + "name", + (t.c.name._proxies[0], "name", "name"), + t.c.name.type, + ), + }, + ) def test_basic_subquery_resultmap(self): t = text("select id, name from user").columns(id=Integer, name=String) stmt = select([table1.c.myid]).select_from( - table1.join(t, table1.c.myid == t.c.id)) + table1.join(t, table1.c.myid == t.c.id) + ) compiled = stmt.compile() eq_( compiled._create_result_map(), { - "myid": ("myid", - (table1.c.myid, "myid", "myid"), table1.c.myid.type), - } + "myid": ( + "myid", + (table1.c.myid, "myid", "myid"), + table1.c.myid.type, + ) + }, ) def test_column_collection_ordered(self): - t = text("select a, b, c from foo").columns(column('a'), - column('b'), column('c')) - eq_(t.c.keys(), ['a', 'b', 'c']) + t = text("select a, b, c from foo").columns( + column("a"), column("b"), column("c") + ) + eq_(t.c.keys(), ["a", "b", "c"]) def test_column_collection_pos_plus_bykey(self): # overlapping positional names + type names t = text("select a, b, c from foo").columns( - column('a'), - column('b'), - b=Integer, - c=String) - eq_(t.c.keys(), ['a', 'b', 'c']) + column("a"), column("b"), b=Integer, c=String + ) + eq_(t.c.keys(), ["a", "b", "c"]) eq_(t.c.b.type._type_affinity, Integer) eq_(t.c.c.type._type_affinity, String) def _xy_table_fixture(self): m = MetaData() - t = Table('t', m, Column('x', Integer), Column('y', Integer)) + t = Table("t", m, Column("x", Integer), Column("y", Integer)) return t def _mapping(self, stmt): @@ -431,7 +483,7 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): def test_select_label_alt_name(self): t = self._xy_table_fixture() - l1, l2 = t.c.x.label('a'), t.c.y.label('b') + l1, l2 = t.c.x.label("a"), t.c.y.label("b") s = text("select x AS a, y AS b FROM t").columns(l1, l2) mapping = self._mapping(s) assert l1 in mapping @@ -440,7 +492,7 @@ class 
 class TextWarningsTest(fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"

     def _test(self, fn, arg, offending_clause, expected):
         with expect_warnings("Textual "):
@@ -546,75 +599,74 @@ class TextWarningsTest(fixtures.TestBase, AssertsCompiledSQL):
         assert_raises_message(
             exc.SAWarning,
             r"Textual (?:SQL|column|SQL FROM) expression %(stmt)r should be "
-            r"explicitly declared (?:with|as) text\(%(stmt)r\)" % {
-                "stmt": util.ellipses_string(offending_clause),
-            },
-            fn, arg
+            r"explicitly declared (?:with|as) text\(%(stmt)r\)"
+            % {"stmt": util.ellipses_string(offending_clause)},
+            fn,
+            arg,
         )

     def test_where(self):
         self._test(
-            select([table1.c.myid]).where, "myid == 5", "myid == 5",
-            "SELECT mytable.myid FROM mytable WHERE myid == 5"
+            select([table1.c.myid]).where,
+            "myid == 5",
+            "myid == 5",
+            "SELECT mytable.myid FROM mytable WHERE myid == 5",
         )

     def test_column(self):
-        self._test(
-            select, ["myid"], "myid",
-            "SELECT myid"
-        )
+        self._test(select, ["myid"], "myid", "SELECT myid")

     def test_having(self):
         self._test(
-            select([table1.c.myid]).having, "myid == 5", "myid == 5",
-            "SELECT mytable.myid FROM mytable HAVING myid == 5"
+            select([table1.c.myid]).having,
+            "myid == 5",
+            "myid == 5",
+            "SELECT mytable.myid FROM mytable HAVING myid == 5",
         )

     def test_from(self):
         self._test(
-            select([table1.c.myid]).select_from, "mytable", "mytable",
-            "SELECT mytable.myid FROM mytable, mytable"  # two FROMs
+            select([table1.c.myid]).select_from,
+            "mytable",
+            "mytable",
+            "SELECT mytable.myid FROM mytable, mytable",  # two FROMs
         )

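OrderByLabelResolutionTest, which follows, checks how string names handed to order_by() and group_by() are resolved against the labels of the enclosing SELECT. A quick sketch of the behavior under test:

    from sqlalchemy import Column, Integer, MetaData, String, Table, desc, select

    mytable = Table(
        "mytable",
        MetaData(),
        Column("myid", Integer),
        Column("name", String(50)),
    )
    stmt = select([mytable.c.myid.label("foo")]).order_by(desc("foo"))
    # SELECT mytable.myid AS foo FROM mytable ORDER BY foo DESC
    print(stmt)
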
"myid == 5", - "SELECT mytable.myid FROM mytable HAVING myid == 5" + select([table1.c.myid]).having, + "myid == 5", + "myid == 5", + "SELECT mytable.myid FROM mytable HAVING myid == 5", ) def test_from(self): self._test( - select([table1.c.myid]).select_from, "mytable", "mytable", - "SELECT mytable.myid FROM mytable, mytable" # two FROMs + select([table1.c.myid]).select_from, + "mytable", + "mytable", + "SELECT mytable.myid FROM mytable, mytable", # two FROMs ) class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def _test_warning(self, stmt, offending_clause, expected): with expect_warnings( - "Can't resolve label reference %r;" % offending_clause): - self.assert_compile( - stmt, - expected - ) + "Can't resolve label reference %r;" % offending_clause + ): + self.assert_compile(stmt, expected) assert_raises_message( exc.SAWarning, - "Can't resolve label reference %r; converting to text" % - offending_clause, - stmt.compile + "Can't resolve label reference %r; converting to text" + % offending_clause, + stmt.compile, ) def test_order_by_label(self): - stmt = select([table1.c.myid.label('foo')]).order_by('foo') + stmt = select([table1.c.myid.label("foo")]).order_by("foo") self.assert_compile( - stmt, - "SELECT mytable.myid AS foo FROM mytable ORDER BY foo" + stmt, "SELECT mytable.myid AS foo FROM mytable ORDER BY foo" ) def test_order_by_colname(self): - stmt = select([table1.c.myid]).order_by('name') + stmt = select([table1.c.myid]).order_by("name") self.assert_compile( - stmt, - "SELECT mytable.myid FROM mytable ORDER BY mytable.name" + stmt, "SELECT mytable.myid FROM mytable ORDER BY mytable.name" ) def test_order_by_alias_colname(self): t1 = table1.alias() - stmt = select([t1.c.myid]).apply_labels().order_by('name') + stmt = select([t1.c.myid]).apply_labels().order_by("name") self.assert_compile( stmt, "SELECT mytable_1.myid AS mytable_1_myid " - "FROM mytable AS mytable_1 ORDER BY mytable_1.name" + "FROM mytable AS mytable_1 ORDER BY mytable_1.name", ) def test_order_by_named_label_from_anon_label(self): @@ -623,7 +675,7 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "SELECT mytable.myid AS foo, mytable.name " - "FROM mytable ORDER BY foo" + "FROM mytable ORDER BY foo", ) def test_order_by_outermost_label(self): @@ -637,43 +689,39 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL): stmt, "SELECT anon_1.name, bar() AS foo FROM " "(SELECT mytable.myid AS foo, mytable.name AS name " - "FROM mytable) AS anon_1 ORDER BY foo" + "FROM mytable) AS anon_1 ORDER BY foo", ) def test_unresolvable_warning_order_by(self): - stmt = select([table1.c.myid]).order_by('foobar') + stmt = select([table1.c.myid]).order_by("foobar") self._test_warning( - stmt, "foobar", - "SELECT mytable.myid FROM mytable ORDER BY foobar" + stmt, "foobar", "SELECT mytable.myid FROM mytable ORDER BY foobar" ) def test_group_by_label(self): - stmt = select([table1.c.myid.label('foo')]).group_by('foo') + stmt = select([table1.c.myid.label("foo")]).group_by("foo") self.assert_compile( - stmt, - "SELECT mytable.myid AS foo FROM mytable GROUP BY foo" + stmt, "SELECT mytable.myid AS foo FROM mytable GROUP BY foo" ) def test_group_by_colname(self): - stmt = select([table1.c.myid]).group_by('name') + stmt = select([table1.c.myid]).group_by("name") self.assert_compile( - stmt, - "SELECT mytable.myid FROM mytable GROUP BY mytable.name" + stmt, "SELECT mytable.myid FROM mytable GROUP BY 
mytable.name" ) def test_unresolvable_warning_group_by(self): - stmt = select([table1.c.myid]).group_by('foobar') + stmt = select([table1.c.myid]).group_by("foobar") self._test_warning( - stmt, "foobar", - "SELECT mytable.myid FROM mytable GROUP BY foobar" + stmt, "foobar", "SELECT mytable.myid FROM mytable GROUP BY foobar" ) def test_asc(self): - stmt = select([table1.c.myid]).order_by(asc('name'), 'description') + stmt = select([table1.c.myid]).order_by(asc("name"), "description") self.assert_compile( stmt, "SELECT mytable.myid FROM mytable " - "ORDER BY mytable.name ASC, mytable.description" + "ORDER BY mytable.name ASC, mytable.description", ) def test_group_by_subquery(self): @@ -685,39 +733,39 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL): "anon_1.description AS anon_1_description FROM " "(SELECT mytable.myid AS myid, mytable.name AS name, " "mytable.description AS description FROM mytable) AS anon_1 " - "GROUP BY anon_1.myid" + "GROUP BY anon_1.myid", ) def test_order_by_func_label_desc(self): - stmt = select([func.foo('bar').label('fb'), table1]).\ - order_by(desc('fb')) + stmt = select([func.foo("bar").label("fb"), table1]).order_by( + desc("fb") + ) self.assert_compile( stmt, "SELECT foo(:foo_1) AS fb, mytable.myid, mytable.name, " - "mytable.description FROM mytable ORDER BY fb DESC" + "mytable.description FROM mytable ORDER BY fb DESC", ) def test_pg_distinct(self): - stmt = select([table1]).distinct('name') + stmt = select([table1]).distinct("name") self.assert_compile( stmt, "SELECT DISTINCT ON (mytable.name) mytable.myid, " "mytable.name, mytable.description FROM mytable", - dialect="postgresql" + dialect="postgresql", ) def test_over(self): stmt = select([column("foo"), column("bar")]) stmt = select( - [func.row_number(). 
-             over(order_by='foo', partition_by='bar')]
+            [func.row_number().over(order_by="foo", partition_by="bar")]
         ).select_from(stmt)

         self.assert_compile(
             stmt,
             "SELECT row_number() OVER (PARTITION BY bar ORDER BY foo) "
-            "AS anon_1 FROM (SELECT foo, bar)"
+            "AS anon_1 FROM (SELECT foo, bar)",
         )

     def test_union_column(self):
@@ -728,23 +776,20 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL):
             stmt,
             "SELECT mytable.myid, mytable.name, mytable.description FROM "
             "mytable UNION SELECT mytable.myid, mytable.name, "
-            "mytable.description FROM mytable ORDER BY name"
+            "mytable.description FROM mytable ORDER BY name",
         )

     def test_union_label(self):
-        s1 = select([func.foo("hoho").label('x')])
-        s2 = select([func.foo("Bar").label('y')])
+        s1 = select([func.foo("hoho").label("x")])
+        s2 = select([func.foo("Bar").label("y")])
         stmt = union(s1, s2).order_by("x")
         self.assert_compile(
             stmt,
-            "SELECT foo(:foo_1) AS x UNION SELECT foo(:foo_2) AS y ORDER BY x"
+            "SELECT foo(:foo_1) AS x UNION SELECT foo(:foo_2) AS y ORDER BY x",
         )

     def test_standalone_units_stringable(self):
-        self.assert_compile(
-            desc("somelabel"),
-            "somelabel DESC"
-        )
+        self.assert_compile(desc("somelabel"), "somelabel DESC")

     def test_columnadapter_anonymized(self):
         """test issue #3148

@@ -755,14 +800,18 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL):
         """
         exprs = [
             table1.c.myid,
-            table1.c.name.label('t1name'),
-            func.foo("hoho").label('x')]
+            table1.c.name.label("t1name"),
+            func.foo("hoho").label("x"),
+        ]

         ta = table1.alias()
         adapter = sql_util.ColumnAdapter(ta, anonymize_labels=True)

-        s1 = select([adapter.columns[expr] for expr in exprs]).\
-            apply_labels().order_by("myid", "t1name", "x")
+        s1 = (
+            select([adapter.columns[expr] for expr in exprs])
+            .apply_labels()
+            .order_by("myid", "t1name", "x")
+        )

         def go():
             # the labels here are anonymized, so label naming
@@ -771,13 +820,17 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL):
                 s1,
                 "SELECT mytable_1.myid AS mytable_1_myid, "
                 "mytable_1.name AS name_1, foo(:foo_2) AS foo_1 "
-                "FROM mytable AS mytable_1 ORDER BY mytable_1.myid, t1name, x"
+                "FROM mytable AS mytable_1 ORDER BY mytable_1.myid, t1name, x",
             )

         assert_warnings(
             go,
-            ["Can't resolve label reference 't1name'",
-             "Can't resolve label reference 'x'"], regex=True)
+            [
+                "Can't resolve label reference 't1name'",
+                "Can't resolve label reference 'x'",
+            ],
+            regex=True,
+        )

     def test_columnadapter_non_anonymized(self):
         """test issue #3148

@@ -788,19 +841,23 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL):
         """
         exprs = [
             table1.c.myid,
-            table1.c.name.label('t1name'),
-            func.foo("hoho").label('x')]
+            table1.c.name.label("t1name"),
+            func.foo("hoho").label("x"),
+        ]

         ta = table1.alias()
         adapter = sql_util.ColumnAdapter(ta)

-        s1 = select([adapter.columns[expr] for expr in exprs]).\
-            apply_labels().order_by("myid", "t1name", "x")
+        s1 = (
+            select([adapter.columns[expr] for expr in exprs])
+            .apply_labels()
+            .order_by("myid", "t1name", "x")
+        )

         # labels are maintained
         self.assert_compile(
             s1,
             "SELECT mytable_1.myid AS mytable_1_myid, "
             "mytable_1.name AS t1name, foo(:foo_1) AS x "
-            "FROM mytable AS mytable_1 ORDER BY mytable_1.myid, t1name, x"
+            "FROM mytable AS mytable_1 ORDER BY mytable_1.myid, t1name, x",
         )
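The next file, test_type_expressions.py, is built around the bind_expression() and column_expression() hooks, which let a type wrap bound parameters and result columns in SQL functions. A rough standalone sketch of that pattern, with illustrative names:

    from sqlalchemy import Column, MetaData, String, Table, func, select

    class LowerInUpperOut(String):
        # SQL applied to bound parameters going in
        def bind_expression(self, bindvalue):
            return func.lower(bindvalue)

        # SQL applied to the column in result rows coming out
        def column_expression(self, col):
            return func.upper(col)

    t = Table("demo", MetaData(), Column("val", LowerInUpperOut(50)))
    # renders roughly: SELECT upper(demo.val) AS val FROM demo
    print(select([t]))
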
diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py
index 71229bfaeb..8c5fb1f071 100644
--- a/test/sql/test_type_expressions.py
+++ b/test/sql/test_type_expressions.py
@@ -1,22 +1,22 @@
-from sqlalchemy import (Table,
-                        Column,
-                        String,
-                        func,
-                        MetaData,
-                        select,
-                        TypeDecorator,
-                        cast)
+from sqlalchemy import (
+    Table,
+    Column,
+    String,
+    func,
+    MetaData,
+    select,
+    TypeDecorator,
+    cast,
+)
 from sqlalchemy.testing import fixtures, AssertsCompiledSQL
 from sqlalchemy import testing
 from sqlalchemy.testing import eq_


 class _ExprFixture(object):
-
     def _test_table(self, type_):
         test_table = Table(
-            'test_table',
-            MetaData(), Column('x', String), Column('y', type_)
+            "test_table", MetaData(), Column("x", String), Column("y", type_)
         )
         return test_table

@@ -92,7 +92,6 @@ class _ExprFixture(object):
         return self._test_table(variant)

     def _dialect_level_fixture(self):
-
         class ImplString(String):
             def bind_expression(self, bindvalue):
                 return func.dialect_bind(bindvalue)

@@ -101,27 +100,28 @@ class _ExprFixture(object):
                 return func.dialect_colexpr(col)

         from sqlalchemy.engine import default
+
         dialect = default.DefaultDialect()
         dialect.colspecs = {String: ImplString}
         return dialect


 class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"

     def test_select_cols(self):
         table = self._fixture()

         self.assert_compile(
             select([table]),
-            "SELECT test_table.x, lower(test_table.y) AS y FROM test_table"
+            "SELECT test_table.x, lower(test_table.y) AS y FROM test_table",
         )

     def test_anonymous_expr(self):
         table = self._fixture()
         self.assert_compile(
             select([cast(table.c.y, String)]),
-            "SELECT CAST(test_table.y AS VARCHAR) AS anon_1 FROM test_table"
+            "SELECT CAST(test_table.y AS VARCHAR) AS anon_1 FROM test_table",
         )

     def test_select_cols_use_labels(self):
@@ -130,28 +130,27 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             select([table]).apply_labels(),
             "SELECT test_table.x AS test_table_x, "
-            "lower(test_table.y) AS test_table_y FROM test_table"
+            "lower(test_table.y) AS test_table_y FROM test_table",
         )

     def test_select_cols_use_labels_result_map_targeting(self):
         table = self._fixture()

         compiled = select([table]).apply_labels().compile()
-        assert table.c.y in compiled._create_result_map()['test_table_y'][1]
-        assert table.c.x in compiled._create_result_map()['test_table_x'][1]
+        assert table.c.y in compiled._create_result_map()["test_table_y"][1]
+        assert table.c.x in compiled._create_result_map()["test_table_x"][1]

         # the lower() function goes into the result_map, we don't really
         # need this but it's fine
         self.assert_compile(
-            compiled._create_result_map()['test_table_y'][1][3],
-            "lower(test_table.y)"
+            compiled._create_result_map()["test_table_y"][1][3],
+            "lower(test_table.y)",
         )

         # then the original column gets put in there as well.
         # as of 1.1 it's important that it is first as this is
         # taken as significant by the result processor.
         self.assert_compile(
-            compiled._create_result_map()['test_table_y'][1][0],
-            "test_table.y"
+            compiled._create_result_map()["test_table_y"][1][0], "test_table.y"
         )

     def test_insert_binds(self):
@@ -159,12 +158,12 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):

         self.assert_compile(
             table.insert(),
-            "INSERT INTO test_table (x, y) VALUES (:x, lower(:y))"
+            "INSERT INTO test_table (x, y) VALUES (:x, lower(:y))",
         )

         self.assert_compile(
             table.insert().values(y="hi"),
-            "INSERT INTO test_table (y) VALUES (lower(:y))"
+            "INSERT INTO test_table (y) VALUES (lower(:y))",
         )

     def test_select_binds(self):
@@ -172,7 +171,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             select([table]).where(table.c.y == "hi"),
             "SELECT test_table.x, lower(test_table.y) AS y FROM "
-            "test_table WHERE test_table.y = lower(:y_1)"
+            "test_table WHERE test_table.y = lower(:y_1)",
         )

     def test_dialect(self):
@@ -184,7 +183,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             select([table.c.x]).where(table.c.x == "hi"),
             "SELECT dialect_colexpr(test_table.x) AS x "
             "FROM test_table WHERE test_table.x = dialect_bind(:x_1)",
-            dialect=dialect
+            dialect=dialect,
         )

     def test_type_decorator_inner(self):
@@ -193,7 +192,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             select([table]).where(table.c.y == "hi"),
             "SELECT test_table.x, inside_colexpr(test_table.y) AS y "
-            "FROM test_table WHERE test_table.y = inside_bind(:y_1)"
+            "FROM test_table WHERE test_table.y = inside_bind(:y_1)",
         )

     def test_type_decorator_inner_plus_dialect(self):
@@ -209,7 +208,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "SELECT dialect_colexpr(test_table.x) AS x, "
             "dialect_colexpr(test_table.y) AS y FROM test_table "
             "WHERE test_table.y = dialect_bind(:y_1)",
-            dialect=dialect
+            dialect=dialect,
         )

     def test_type_decorator_outer(self):
@@ -218,7 +217,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             select([table]).where(table.c.y == "hi"),
             "SELECT test_table.x, outside_colexpr(test_table.y) AS y "
-            "FROM test_table WHERE test_table.y = outside_bind(:y_1)"
+            "FROM test_table WHERE test_table.y = outside_bind(:y_1)",
         )

     def test_type_decorator_outer_plus_dialect(self):
@@ -232,7 +231,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "SELECT dialect_colexpr(test_table.x) AS x, "
             "outside_colexpr(test_table.y) AS y "
             "FROM test_table WHERE test_table.y = outside_bind(:y_1)",
-            dialect=dialect
+            dialect=dialect,
         )

     def test_type_decorator_both(self):
@@ -243,7 +242,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "SELECT test_table.x, "
             "outside_colexpr(inside_colexpr(test_table.y)) AS y "
             "FROM test_table WHERE "
-            "test_table.y = outside_bind(inside_bind(:y_1))"
+            "test_table.y = outside_bind(inside_bind(:y_1))",
         )

     def test_type_decorator_both_plus_dialect(self):
@@ -260,7 +259,7 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "outside_colexpr(dialect_colexpr(test_table.y)) AS y "
             "FROM test_table WHERE "
             "test_table.y = outside_bind(dialect_bind(:y_1))",
-            dialect=dialect
+            dialect=dialect,
         )

     def test_type_decorator_both_w_variant(self):
@@ -271,19 +270,19 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "SELECT test_table.x, "
             "outside_colexpr(inside_colexpr(test_table.y)) AS y "
             "FROM test_table WHERE "
-            "test_table.y = outside_bind(inside_bind(:y_1))"
+            "test_table.y = outside_bind(inside_bind(:y_1))",
         )


 class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = 'default'
+    __dialect__ = "default"

     def test_select_from_select(self):
         table = self._fixture()
         self.assert_compile(
             table.select().select(),
             "SELECT x, lower(y) AS y FROM (SELECT test_table.x "
-            "AS x, test_table.y AS y FROM test_table)"
+            "AS x, test_table.y AS y FROM test_table)",
         )

     def test_select_from_alias(self):
@@ -292,7 +291,7 @@ class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             table.select().alias().select(),
             "SELECT anon_1.x, lower(anon_1.y) AS y FROM (SELECT "
             "test_table.x AS x, test_table.y AS y "
-            "FROM test_table) AS anon_1"
+            "FROM test_table) AS anon_1",
         )

     def test_select_from_aliased_join(self):
@@ -302,16 +301,17 @@ class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
         j = s1.join(s2, s1.c.x == s2.c.x)
         s3 = j.select()
         self.assert_compile(
-            s3, "SELECT anon_1.x, lower(anon_1.y) AS y, anon_2.x, "
+            s3,
+            "SELECT anon_1.x, lower(anon_1.y) AS y, anon_2.x, "
             "lower(anon_2.y) AS y "
             "FROM (SELECT test_table.x AS x, test_table.y AS y "
             "FROM test_table) AS anon_1 JOIN (SELECT "
             "test_table.x AS x, test_table.y AS y "
-            "FROM test_table) AS anon_2 ON anon_1.x = anon_2.x")
+            "FROM test_table) AS anon_2 ON anon_1.x = anon_2.x",
+        )


 class RoundTripTestBase(object):
-
     def test_round_trip(self):
         testing.db.execute(
             self.tables.test_table.insert(),
@@ -323,83 +323,63 @@ class RoundTripTestBase(object):
         # test insert coercion alone
         eq_(
             testing.db.execute(
-                "select * from test_table order by y").fetchall(),
-            [
-                ("X1", "y1"),
-                ("X2", "y2"),
-                ("X3", "y3"),
-            ]
+                "select * from test_table order by y"
+            ).fetchall(),
+            [("X1", "y1"), ("X2", "y2"), ("X3", "y3")],
         )

         # conversion back to upper
         eq_(
             testing.db.execute(
-                select([self.tables.test_table]).
-                order_by(self.tables.test_table.c.y)
+                select([self.tables.test_table]).order_by(
+                    self.tables.test_table.c.y
+                )
             ).fetchall(),
-            [
-                ("X1", "Y1"),
-                ("X2", "Y2"),
-                ("X3", "Y3"),
-            ]
+            [("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")],
         )

     def test_targeting_no_labels(self):
         testing.db.execute(
-            self.tables.test_table.insert(),
-            {"x": "X1", "y": "Y1"},
+            self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
         row = testing.db.execute(select([self.tables.test_table])).first()
-        eq_(
-            row[self.tables.test_table.c.y],
-            "Y1"
-        )
+        eq_(row[self.tables.test_table.c.y], "Y1")

     def test_targeting_by_string(self):
         testing.db.execute(
-            self.tables.test_table.insert(),
-            {"x": "X1", "y": "Y1"},
+            self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
         row = testing.db.execute(select([self.tables.test_table])).first()
-        eq_(
-            row["y"],
-            "Y1"
-        )
+        eq_(row["y"], "Y1")

     def test_targeting_apply_labels(self):
         testing.db.execute(
-            self.tables.test_table.insert(),
-            {"x": "X1", "y": "Y1"},
-        )
-        row = testing.db.execute(select([self.tables.test_table]).
-                                 apply_labels()).first()
-        eq_(
-            row[self.tables.test_table.c.y],
-            "Y1"
+            self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
+        row = testing.db.execute(
+            select([self.tables.test_table]).apply_labels()
+        ).first()
+        eq_(row[self.tables.test_table.c.y], "Y1")

     def test_targeting_individual_labels(self):
         testing.db.execute(
-            self.tables.test_table.insert(),
-            {"x": "X1", "y": "Y1"},
-        )
-        row = testing.db.execute(select([
-            self.tables.test_table.c.x.label('xbar'),
-            self.tables.test_table.c.y.label('ybar')
-        ])).first()
-        eq_(
-            row[self.tables.test_table.c.y],
-            "Y1"
+            self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
-
+        row = testing.db.execute(
+            select(
+                [
+                    self.tables.test_table.c.x.label("xbar"),
+                    self.tables.test_table.c.y.label("ybar"),
+                ]
+            )
+        ).first()
+        eq_(row[self.tables.test_table.c.y], "Y1")


 class StringRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
-
     @classmethod
     def define_tables(cls, metadata):
         class MyString(String):
-
             def bind_expression(self, bindvalue):
                 return func.lower(bindvalue)

@@ -407,16 +387,14 @@ class StringRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
                 return func.upper(col)

         Table(
-            'test_table',
+            "test_table",
             metadata,
-            Column('x', String(50)),
-            Column('y', MyString(50)
-                   )
+            Column("x", String(50)),
+            Column("y", MyString(50)),
         )


 class TypeDecRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
-
     @classmethod
     def define_tables(cls, metadata):
         class MyString(TypeDecorator):
@@ -429,38 +407,33 @@ class TypeDecRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
                 return func.upper(col)

         Table(
-            'test_table',
+            "test_table",
             metadata,
-            Column('x', String(50)),
-            Column('y', MyString(50)
-                   )
+            Column("x", String(50)),
+            Column("y", MyString(50)),
         )


 class ReturningTest(fixtures.TablesTest):
-    __requires__ = 'returning',
+    __requires__ = ("returning",)

     @classmethod
     def define_tables(cls, metadata):
         class MyString(String):
-
             def column_expression(self, col):
                 return func.lower(col)

         Table(
-            'test_table',
-            metadata, Column('x', String(50)),
-            Column('y', MyString(50), server_default="YVALUE")
+            "test_table",
+            metadata,
+            Column("x", String(50)),
+            Column("y", MyString(50), server_default="YVALUE"),
         )

     @testing.provide_metadata
     def test_insert_returning(self):
         table = self.tables.test_table
         result = testing.db.execute(
-            table.insert().returning(table.c.y),
-            {"x": "xvalue"}
-        )
-        eq_(
-            result.first(),
-            ("yvalue",)
+            table.insert().returning(table.c.y), {"x": "xvalue"}
         )
+        eq_(result.first(), ("yvalue",))
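test_types.py, which the rest of this diff reformats, leans on custom types throughout; the most common fixture shape is a TypeDecorator that tags values in both directions. In sketch form, under the same naming convention the fixtures use:

    from sqlalchemy import String, TypeDecorator

    class TaggedString(TypeDecorator):
        impl = String

        # runs on the way into the database
        def process_bind_param(self, value, dialect):
            return "BIND_IN" + value

        # runs on the way back out
        def process_result_value(self, value, dialect):
            return value + "BIND_OUT"
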
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 7b464f8c0b..179f1050c9 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1,16 +1,63 @@
 # coding: utf-8
-from sqlalchemy.testing import eq_, is_, is_not_, assert_raises, \
-    assert_raises_message, expect_warnings
+from sqlalchemy.testing import (
+    eq_,
+    is_,
+    is_not_,
+    assert_raises,
+    assert_raises_message,
+    expect_warnings,
+)
 import decimal
 import datetime
 import os
 from sqlalchemy import (
-    Unicode, MetaData, PickleType, Boolean, TypeDecorator, Integer,
-    Interval, Float, Numeric, Text, CHAR, String, distinct, select, bindparam,
-    and_, func, Date, LargeBinary, literal, cast, text, Enum,
-    type_coerce, VARCHAR, Time, DateTime, BigInteger, SmallInteger, BOOLEAN,
-    BLOB, NCHAR, NVARCHAR, CLOB, TIME, DATE, DATETIME, TIMESTAMP, SMALLINT,
-    INTEGER, DECIMAL, NUMERIC, FLOAT, REAL, ARRAY, JSON)
+    Unicode,
+    MetaData,
+    PickleType,
+    Boolean,
+    TypeDecorator,
+    Integer,
+    Interval,
+    Float,
+    Numeric,
+    Text,
+    CHAR,
+    String,
+    distinct,
+    select,
+    bindparam,
+    and_,
+    func,
+    Date,
+    LargeBinary,
+    literal,
+    cast,
+    text,
+    Enum,
+    type_coerce,
+    VARCHAR,
+    Time,
+    DateTime,
+    BigInteger,
+    SmallInteger,
+    BOOLEAN,
+    BLOB,
+    NCHAR,
+    NVARCHAR,
+    CLOB,
+    TIME,
+    DATE,
+    DATETIME,
+    TIMESTAMP,
+    SMALLINT,
+    INTEGER,
+    DECIMAL,
+    NUMERIC,
+    FLOAT,
+    REAL,
+    ARRAY,
+    JSON,
+)
 from sqlalchemy.sql import ddl
 from sqlalchemy.sql import visitors
 from sqlalchemy import inspection
@@ -21,8 +68,12 @@ from sqlalchemy.schema import CheckConstraint, AddConstraint
 from sqlalchemy.engine import default
 from sqlalchemy.testing.schema import Table, Column
 from sqlalchemy import testing
-from sqlalchemy.testing import AssertsCompiledSQL, AssertsExecutionResults, \
-    engines, pickleable
+from sqlalchemy.testing import (
+    AssertsCompiledSQL,
+    AssertsExecutionResults,
+    engines,
+    pickleable,
+)
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.testing.util import round_decimal
 from sqlalchemy.testing import fixtures
@@ -33,22 +84,22 @@ import importlib


 class AdaptTest(fixtures.TestBase):
-
     def _all_dialect_modules(self):
         return [
             importlib.import_module("sqlalchemy.dialects.%s" % d)
-            for d in dialects.__all__ if not d.startswith("_")
+            for d in dialects.__all__
+            if not d.startswith("_")
         ]

     def _all_dialects(self):
-        return [d.base.dialect() for d in
-                self._all_dialect_modules()]
+        return [d.base.dialect() for d in self._all_dialect_modules()]

     def _types_for_mod(self, mod):
         for key in dir(mod):
             typ = getattr(mod, key)
-            if not isinstance(typ, type) or \
-                    not issubclass(typ, types.TypeEngine):
+            if not isinstance(typ, type) or not issubclass(
+                typ, types.TypeEngine
+            ):
                 continue
             yield typ

@@ -61,6 +112,7 @@ class AdaptTest(fixtures.TestBase):

     def test_uppercase_importable(self):
         import sqlalchemy as sa
+
         for typ in self._types_for_mod(types):
             if typ.__name__ == typ.__name__.upper():
                 assert getattr(sa, typ.__name__) is typ
@@ -91,29 +143,34 @@ class AdaptTest(fixtures.TestBase):
                 (TIME, ("TIME", "TIME WITHOUT TIME ZONE")),
                 (CLOB, "CLOB"),
                 (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")),
-                (NVARCHAR(10), (
-                    "NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)")),
+                (
+                    NVARCHAR(10),
+                    ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"),
+                ),
                 (CHAR, "CHAR"),
                 (NCHAR, ("NCHAR", "NATIONAL CHAR")),
                 (BLOB, ("BLOB", "BLOB SUB_TYPE 0")),
-                (BOOLEAN, ("BOOLEAN", "BOOL", "INTEGER"))
+                (BOOLEAN, ("BOOLEAN", "BOOL", "INTEGER")),
             ):
                 if isinstance(expected, str):
-                    expected = (expected, )
+                    expected = (expected,)

                 try:
-                    compiled = types.to_instance(type_).\
-                        compile(dialect=dialect)
+                    compiled = types.to_instance(type_).compile(
+                        dialect=dialect
+                    )
                 except NotImplementedError:
                     continue

-                assert compiled in expected, \
-                    "%r matches none of %r for dialect %s" % \
-                    (compiled, expected, dialect.name)
+                assert compiled in expected, (
+                    "%r matches none of %r for dialect %s"
+                    % (compiled, expected, dialect.name)
+                )

-                assert str(types.to_instance(type_)) in expected, \
-                    "default str() of type %r not expected, %r" % \
-                    (type_, expected)
+                assert str(types.to_instance(type_)) in expected, (
+                    "default str() of type %r not expected, %r"
+                    % (type_, expected)
+                )

     @testing.uses_deprecated()
     def test_adapt_method(self):
@@ -134,11 +191,15 @@ class AdaptTest(fixtures.TestBase):
             up_adaptions = [typ] + typ.__subclasses__()
             yield False, typ, up_adaptions
             for subcl in typ.__subclasses__():
-                if subcl is not typ and typ is not TypeDecorator and \
-                        "sqlalchemy" in subcl.__module__:
+                if (
+                    subcl is not typ
+                    and typ is not TypeDecorator
+                    and "sqlalchemy" in subcl.__module__
+                ):
                     yield True, subcl, [typ]

         from sqlalchemy.sql import sqltypes
+
         for is_down_adaption, typ, target_adaptions in adaptions():
             if typ in (types.TypeDecorator, types.TypeEngine, types.Variant):
                 continue
@@ -148,10 +209,9 @@ class AdaptTest(fixtures.TestBase):
             t1 = typ()
             for cls in target_adaptions:
                 if (
-                    (is_down_adaption and
-                        issubclass(typ, sqltypes.Emulated)) or
-                    (not is_down_adaption and
-                        issubclass(cls, sqltypes.Emulated))
+                    is_down_adaption and issubclass(typ, sqltypes.Emulated)
+                ) or (
+                    not is_down_adaption and issubclass(cls, sqltypes.Emulated)
                 ):
                     continue

@@ -167,18 +227,24 @@ class AdaptTest(fixtures.TestBase):

                 for k in t1.__dict__:
                     if k in (
-                            'impl', '_is_oracle_number',
-                            '_create_events', 'create_constraint',
-                            'inherit_schema', 'schema', 'metadata',
-                            'name', ):
+                        "impl",
+                        "_is_oracle_number",
+                        "_create_events",
+                        "create_constraint",
+                        "inherit_schema",
+                        "schema",
+                        "metadata",
+                        "name",
+                    ):
                         continue
                     # assert each value was copied, or that
                     # the adapted type has a more specific
                     # value than the original (i.e. SQL Server
                     # applies precision=24 for REAL)
-                    assert \
-                        getattr(t2, k) == t1.__dict__[k] or \
-                        t1.__dict__[k] is None
+                    assert (
+                        getattr(t2, k) == t1.__dict__[k]
+                        or t1.__dict__[k] is None
+                    )

     def test_python_type(self):
         eq_(types.Integer().python_type, int)
@@ -192,11 +258,10 @@ class AdaptTest(fixtures.TestBase):
         eq_(types.String().python_type, str)
         eq_(types.Unicode().python_type, util.text_type)
         eq_(types.String(convert_unicode=True).python_type, util.text_type)
-        eq_(types.Enum('one', 'two', 'three').python_type, str)
+        eq_(types.Enum("one", "two", "three").python_type, str)

         assert_raises(
-            NotImplementedError,
-            lambda: types.TypeEngine().python_type
+            NotImplementedError, lambda: types.TypeEngine().python_type
         )

     @testing.uses_deprecated()
@@ -219,22 +284,17 @@ class AdaptTest(fixtures.TestBase):
         """
         t1 = String(length=50, convert_unicode=False)
         t2 = t1.adapt(Text, convert_unicode=True)
-        eq_(
-            t2.length, 50
-        )
-        eq_(
-            t2.convert_unicode, True
-        )
+        eq_(t2.length, 50)
+        eq_(t2.convert_unicode, True)


 class TypeAffinityTest(fixtures.TestBase):
-
     def test_type_affinity(self):
         for type_, affin in [
             (String(), String),
             (VARCHAR(), String),
             (Date(), Date),
-            (LargeBinary(), types._Binary)
+            (LargeBinary(), types._Binary),
         ]:
             eq_(type_._type_affinity, affin)

@@ -258,7 +318,7 @@ class TypeAffinityTest(fixtures.TestBase):
             impl = CHAR

             def load_dialect_impl(self, dialect):
-                if dialect.name == 'postgresql':
+                if dialect.name == "postgresql":
                     return dialect.type_descriptor(postgresql.UUID())
                 else:
                     return dialect.type_descriptor(CHAR(32))
@@ -270,29 +330,28 @@ class TypeAffinityTest(fixtures.TestBase):


 class PickleTypesTest(fixtures.TestBase):
-
     def test_pickle_types(self):
         for loads, dumps in picklers():
             column_types = [
-                Column('Boo', Boolean()),
-                Column('Str', String()),
-                Column('Tex', Text()),
-                Column('Uni', Unicode()),
-                Column('Int', Integer()),
-                Column('Sma', SmallInteger()),
-                Column('Big', BigInteger()),
-                Column('Num', Numeric()),
-                Column('Flo', Float()),
-                Column('Dat', DateTime()),
-                Column('Dat', Date()),
-                Column('Tim', Time()),
-                Column('Lar', LargeBinary()),
-                Column('Pic', PickleType()),
-                Column('Int', Interval()),
+                Column("Boo", Boolean()),
+                Column("Str", String()),
+                Column("Tex", Text()),
+                Column("Uni", Unicode()),
+                Column("Int", Integer()),
+                Column("Sma", SmallInteger()),
+                Column("Big", BigInteger()),
+                Column("Num", Numeric()),
+                Column("Flo", Float()),
+                Column("Dat", DateTime()),
+                Column("Dat", Date()),
+                Column("Tim", Time()),
+                Column("Lar", LargeBinary()),
+                Column("Pic", PickleType()),
+                Column("Int", Interval()),
             ]
             for column_type in column_types:
                 meta = MetaData()
-                Table('foo', meta, column_type)
+                Table("foo", meta, column_type)
                 loads(dumps(column_type))
                 loads(dumps(meta))

@@ -301,18 +360,19 @@ class _UserDefinedTypeFixture(object):
     @classmethod
     def define_tables(cls, metadata):
         class MyType(types.UserDefinedType):
-
             def get_col_spec(self):
                 return "VARCHAR(100)"

             def bind_processor(self, dialect):
                 def process(value):
                     return "BIND_IN" + value
+
                 return process

             def result_processor(self, dialect, coltype):
                 def process(value):
                     return value + "BIND_OUT"
+
                 return process

             def adapt(self, typeobj):
@@ -322,19 +382,23 @@ class _UserDefinedTypeFixture(object):
             impl = String

             def bind_processor(self, dialect):
-                impl_processor = super(MyDecoratedType, self).\
-                    bind_processor(dialect) or (lambda value: value)
+                impl_processor = super(MyDecoratedType, self).bind_processor(
+                    dialect
+                ) or (lambda value: value)

                 def process(value):
                     return "BIND_IN" + impl_processor(value)
+
                 return process

             def result_processor(self, dialect, coltype):
-                impl_processor = super(MyDecoratedType, self).\
-                    result_processor(dialect, coltype) or (lambda value: value)
+                impl_processor = super(MyDecoratedType, self).result_processor(
+                    dialect, coltype
+                ) or (lambda value: value)

                 def process(value):
                     return impl_processor(value) + "BIND_OUT"
+
                 return process

             def copy(self):
@@ -365,7 +429,6 @@ class _UserDefinedTypeFixture(object):
                 return MyNewIntType()

         class MyNewIntSubClass(MyNewIntType):
-
             def process_result_value(self, value, dialect):
                 return value * 15

@@ -376,54 +439,85 @@ class _UserDefinedTypeFixture(object):
             impl = Unicode

             def bind_processor(self, dialect):
-                impl_processor = super(MyUnicodeType, self).\
-                    bind_processor(dialect) or (lambda value: value)
+                impl_processor = super(MyUnicodeType, self).bind_processor(
+                    dialect
+                ) or (lambda value: value)

                 def process(value):
                     return "BIND_IN" + impl_processor(value)
+
                 return process

             def result_processor(self, dialect, coltype):
-                impl_processor = super(MyUnicodeType, self).\
-                    result_processor(dialect, coltype) or (lambda value: value)
+                impl_processor = super(MyUnicodeType, self).result_processor(
+                    dialect, coltype
+                ) or (lambda value: value)

                 def process(value):
                     return impl_processor(value) + "BIND_OUT"
+
                 return process

             def copy(self):
                 return MyUnicodeType(self.impl.length)

         Table(
-            'users', metadata,
-            Column('user_id', Integer, primary_key=True),
+            "users",
+            metadata,
+            Column("user_id", Integer, primary_key=True),
             # totally custom type
-            Column('goofy', MyType, nullable=False),
-
+            Column("goofy", MyType, nullable=False),
             # decorated type with an argument, so it's a String
-            Column('goofy2', MyDecoratedType(50), nullable=False),
-
-            Column('goofy4', MyUnicodeType(50), nullable=False),
-            Column('goofy7', MyNewUnicodeType(50), nullable=False),
-            Column('goofy8', MyNewIntType, nullable=False),
-            Column('goofy9', MyNewIntSubClass, nullable=False),
+            Column("goofy2", MyDecoratedType(50), nullable=False),
+            Column("goofy4", MyUnicodeType(50), nullable=False),
+            Column("goofy7", MyNewUnicodeType(50), nullable=False),
+            Column("goofy8", MyNewIntType, nullable=False),
+            Column("goofy9", MyNewIntSubClass, nullable=False),
         )


 class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
     __backend__ = True

     def _data_fixture(self):
         users = self.tables.users
         with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(
-                user_id=2, goofy='jack', goofy2='jack', goofy4=util.u('jack'),
-                goofy7=util.u('jack'), goofy8=12, goofy9=12))
-            conn.execute(users.insert(), dict(
-                user_id=3, goofy='lala', goofy2='lala', goofy4=util.u('lala'),
-                goofy7=util.u('lala'), goofy8=15, goofy9=15))
-            conn.execute(users.insert(), dict(
-                user_id=4, goofy='fred', goofy2='fred', goofy4=util.u('fred'),
-                goofy7=util.u('fred'), goofy8=9, goofy9=9))
+            conn.execute(
+                users.insert(),
+                dict(
+                    user_id=2,
+                    goofy="jack",
+                    goofy2="jack",
+                    goofy4=util.u("jack"),
+                    goofy7=util.u("jack"),
+                    goofy8=12,
+                    goofy9=12,
+                ),
+            )
+            conn.execute(
+                users.insert(),
+                dict(
+                    user_id=3,
+                    goofy="lala",
+                    goofy2="lala",
+                    goofy4=util.u("lala"),
+                    goofy7=util.u("lala"),
+                    goofy8=15,
+                    goofy9=15,
+                ),
+            )
+            conn.execute(
+                users.insert(),
+                dict(
+                    user_id=4,
+                    goofy="fred",
+                    goofy2="fred",
+                    goofy4=util.u("fred"),
+                    goofy7=util.u("fred"),
+                    goofy8=9,
+                    goofy9=9,
+                ),
+            )

     def test_processing(self):
         users = self.tables.users
@@ -432,11 +526,13 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         result = users.select().order_by(users.c.user_id).execute().fetchall()
         for assertstr, assertint, assertint2, row in zip(
             [
-                "BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT",
-                "BIND_INfredBIND_OUT"],
+                "BIND_INjackBIND_OUT",
+                "BIND_INlalaBIND_OUT",
+                "BIND_INfredBIND_OUT",
+            ],
             [1200, 1500, 900],
             [1800, 2250, 1350],
-            result
+            result,
         ):
             for col in list(row)[1:5]:
                 eq_(col, assertstr)
@@ -449,9 +545,11 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         users = self.tables.users
         self._data_fixture()

-        stmt = select([users.c.user_id, users.c.goofy8]).where(
-            users.c.goofy8.in_([15, 9])
-        ).order_by(users.c.user_id)
+        stmt = (
+            select([users.c.user_id, users.c.goofy8])
+            .where(users.c.goofy8.in_([15, 9]))
+            .order_by(users.c.user_id)
+        )

         result = testing.db.execute(stmt, {"goofy": [15, 9]})
         eq_(result.fetchall(), [(3, 1500), (4, 900)])
@@ -459,15 +557,18 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         users = self.tables.users
         self._data_fixture()

-        stmt = select([users.c.user_id, users.c.goofy8]).where(
-            users.c.goofy8.in_(bindparam("goofy", expanding=True))
-        ).order_by(users.c.user_id)
+        stmt = (
+            select([users.c.user_id, users.c.goofy8])
+            .where(users.c.goofy8.in_(bindparam("goofy", expanding=True)))
+            .order_by(users.c.user_id)
+        )

         result = testing.db.execute(stmt, {"goofy": [15, 9]})
         eq_(result.fetchall(), [(3, 1500), (4, 900)])


 class UserDefinedTest(
-        _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL):
+    _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL
+):
     run_create_tables = None
     run_inserts = None

@@ -485,32 +586,25 @@ class UserDefinedTest(
         self.assert_compile(
             select([literal("test", MyType)]),
             "SELECT 'HI->test<-THERE' AS anon_1",
-            dialect='default',
-            literal_binds=True
+            dialect="default",
+            literal_binds=True,
         )

     def test_kw_colspec(self):
         class MyType(types.UserDefinedType):
             def get_col_spec(self, **kw):
-                return "FOOB %s" % kw['type_expression'].name
+                return "FOOB %s" % kw["type_expression"].name

         class MyOtherType(types.UserDefinedType):
             def get_col_spec(self):
                 return "BAR"

-        t = Table('t', MetaData(), Column('bar', MyType, nullable=False))
+        t = Table("t", MetaData(), Column("bar", MyType, nullable=False))

-        self.assert_compile(
-            ddl.CreateColumn(t.c.bar),
-            "bar FOOB bar NOT NULL"
-        )
+        self.assert_compile(ddl.CreateColumn(t.c.bar), "bar FOOB bar NOT NULL")

-        t = Table('t', MetaData(),
-                  Column('bar', MyOtherType, nullable=False))
-        self.assert_compile(
-            ddl.CreateColumn(t.c.bar),
-            "bar BAR NOT NULL"
-        )
+        t = Table("t", MetaData(), Column("bar", MyOtherType, nullable=False))
+        self.assert_compile(ddl.CreateColumn(t.c.bar), "bar BAR NOT NULL")

     def test_typedecorator_literal_render_fallback_bound(self):
         # fall back to process_bind_param for literal
@@ -524,19 +618,22 @@ class UserDefinedTest(
         self.assert_compile(
             select([literal("test", MyType)]),
             "SELECT 'HI->test<-THERE' AS anon_1",
-            dialect='default',
-            literal_binds=True
+            dialect="default",
+            literal_binds=True,
         )

     def test_typedecorator_impl(self):
         for impl_, exp, kw in [
             (Float, "FLOAT", {}),
-            (Float, "FLOAT(2)", {'precision': 2}),
-            (Float(2), "FLOAT(2)", {'precision': 4}),
+            (Float, "FLOAT(2)", {"precision": 2}),
+            (Float(2), "FLOAT(2)", {"precision": 4}),
             (Numeric(19, 2), "NUMERIC(19, 2)", {}),
         ]:
             for dialect_ in (
-                    dialects.postgresql, dialects.mssql, dialects.mysql):
+                dialects.postgresql,
+                dialects.mssql,
+                dialects.mysql,
+            ):
                 dialect_ = dialect_.dialect()

                 raw_impl = types.to_instance(impl_, **kw)
@@ -552,21 +649,17 @@ class UserDefinedTest(
                 dec_dialect_impl = dec_type.dialect_impl(dialect_)
                 eq_(dec_dialect_impl.__class__, MyType)
                 eq_(
-                    raw_dialect_impl.__class__,
-                    dec_dialect_impl.impl.__class__)
-
-                self.assert_compile(
-                    MyType(**kw),
-                    exp,
-                    dialect=dialect_
+                    raw_dialect_impl.__class__, dec_dialect_impl.impl.__class__
                 )

+                self.assert_compile(MyType(**kw), exp, dialect=dialect_)
+
     def test_user_defined_typedec_impl(self):
         class MyType(types.TypeDecorator):
             impl = Float

             def load_dialect_impl(self, dialect):
-                if dialect.name == 'sqlite':
+                if dialect.name == "sqlite":
                     return String(50)
                 else:
                     return super(MyType, self).load_dialect_impl(dialect)
@@ -578,11 +671,11 @@ class UserDefinedTest(
         self.assert_compile(t, "FLOAT", dialect=pg)
         eq_(
             t.dialect_impl(dialect=sl).impl.__class__,
-            String().dialect_impl(dialect=sl).__class__
+            String().dialect_impl(dialect=sl).__class__,
         )
         eq_(
             t.dialect_impl(dialect=pg).impl.__class__,
-            Float().dialect_impl(pg).__class__
+            Float().dialect_impl(pg).__class__,
         )

     def test_type_decorator_repr(self):
@@ -593,60 +686,54 @@ class UserDefinedTest(

     def test_user_defined_typedec_impl_bind(self):
         class TypeOne(types.TypeEngine):
-
             def bind_processor(self, dialect):
                 def go(value):
                     return value + " ONE"
+
                 return go

         class TypeTwo(types.TypeEngine):
-
             def bind_processor(self, dialect):
                 def go(value):
                     return value + " TWO"
+
                 return go

         class MyType(types.TypeDecorator):
             impl = TypeOne

             def load_dialect_impl(self, dialect):
-                if dialect.name == 'sqlite':
+                if dialect.name == "sqlite":
                     return TypeOne()
                 else:
                     return TypeTwo()

             def process_bind_param(self, value, dialect):
                 return "MYTYPE " + value
+
         sl = dialects.sqlite.dialect()
         pg = dialects.postgresql.dialect()
         t = MyType()
-        eq_(
-            t._cached_bind_processor(sl)('foo'),
-            "MYTYPE foo ONE"
-        )
-        eq_(
-            t._cached_bind_processor(pg)('foo'),
-            "MYTYPE foo TWO"
-        )
+        eq_(t._cached_bind_processor(sl)("foo"), "MYTYPE foo ONE")
+        eq_(t._cached_bind_processor(pg)("foo"), "MYTYPE foo TWO")

     def test_user_defined_dialect_specific_args(self):
         class MyType(types.UserDefinedType):
-
-            def __init__(self, foo='foo', **kwargs):
+            def __init__(self, foo="foo", **kwargs):
                 super(MyType, self).__init__()
                 self.foo = foo
                 self.dialect_specific_args = kwargs

             def adapt(self, cls):
                 return cls(foo=self.foo, **self.dialect_specific_args)
-        t = MyType(bar='bar')
+
+        t = MyType(bar="bar")
         a = t.dialect_impl(testing.db.dialect)
-        eq_(a.foo, 'foo')
-        eq_(a.dialect_specific_args['bar'], 'bar')
+        eq_(a.foo, "foo")
+        eq_(a.dialect_specific_args["bar"], "bar")

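TypeCoerceCastTest, next, contrasts cast() with type_coerce(): type_coerce() applies the target type's Python-side bind processing, while cast() only renders a SQL CAST and leaves the bound value alone. Sketched with a hypothetical tagged type like the fixture above:

    from sqlalchemy import String, TypeDecorator, cast, select, type_coerce

    class MyType(TypeDecorator):
        impl = String(50)

        def process_bind_param(self, value, dialect):
            return "BIND_IN" + value

    # bound value is processed to "BIND_INd1" at execution time
    print(select([type_coerce("d1", MyType)]))
    # renders SELECT CAST(:param_1 AS VARCHAR(50)) AS anon_1
    print(select([cast("d1", MyType)]))
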
 class TypeCoerceCastTest(fixtures.TablesTest):
-
     @classmethod
     def define_tables(cls, metadata):
         class MyType(types.TypeDecorator):
@@ -660,12 +747,12 @@ class TypeCoerceCastTest(fixtures.TablesTest):

         cls.MyType = MyType

-        Table('t', metadata, Column('data', String(50)))
+        Table("t", metadata, Column("data", String(50)))

     @testing.fails_on(
-        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT"
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_insert_round_trip_cast(self):
         self._test_insert_round_trip(cast)

@@ -676,18 +763,19 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

         eq_(
             select([coerce_fn(t.c.data, MyType)]).execute().fetchall(),
-            [('BIND_INd1BIND_OUT', )]
+            [("BIND_INd1BIND_OUT",)],
         )

     @testing.fails_on(
-        "oracle", "ORA-00906: missing left parenthesis - "
-        "seems to be CAST(:param AS type)")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle",
+        "ORA-00906: missing left parenthesis - "
+        "seems to be CAST(:param AS type)",
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_coerce_from_nulltype_cast(self):
         self._test_coerce_from_nulltype(cast)

@@ -700,7 +788,6 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         # test coerce from nulltype - e.g. use an object that
         # doesn't match to a known type
         class MyObj(object):
-
             def __str__(self):
                 return "THISISMYOBJ"

@@ -710,13 +797,13 @@ class TypeCoerceCastTest(fixtures.TablesTest):

         eq_(
             select([coerce_fn(t.c.data, MyType)]).execute().fetchall(),
-            [('BIND_INTHISISMYOBJBIND_OUT',)]
+            [("BIND_INTHISISMYOBJBIND_OUT",)],
         )

     @testing.fails_on(
-        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT"
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_vs_non_coerced_cast(self):
         self._test_vs_non_coerced(cast)

@@ -727,18 +814,19 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

         eq_(
-            select(
-                [t.c.data, coerce_fn(t.c.data, MyType)]).execute().fetchall(),
-            [('BIND_INd1', 'BIND_INd1BIND_OUT')]
+            select([t.c.data, coerce_fn(t.c.data, MyType)])
+            .execute()
+            .fetchall(),
+            [("BIND_INd1", "BIND_INd1BIND_OUT")],
         )

     @testing.fails_on(
-        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT"
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_vs_non_coerced_alias_cast(self):
         self._test_vs_non_coerced_alias(cast)

@@ -749,18 +837,21 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

         eq_(
-            select([t.c.data, coerce_fn(t.c.data, MyType)]).
-            alias().select().execute().fetchall(),
-            [('BIND_INd1', 'BIND_INd1BIND_OUT')]
+            select([t.c.data, coerce_fn(t.c.data, MyType)])
+            .alias()
+            .select()
+            .execute()
+            .fetchall(),
+            [("BIND_INd1", "BIND_INd1BIND_OUT")],
         )

     @testing.fails_on(
-        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT"
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_vs_non_coerced_where_cast(self):
         self._test_vs_non_coerced_where(cast)

@@ -771,26 +862,30 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

         # coerce on left side
         eq_(
-            select([t.c.data, coerce_fn(t.c.data, MyType)]).
-            where(coerce_fn(t.c.data, MyType) == 'd1').execute().fetchall(),
-            [('BIND_INd1', 'BIND_INd1BIND_OUT')]
+            select([t.c.data, coerce_fn(t.c.data, MyType)])
+            .where(coerce_fn(t.c.data, MyType) == "d1")
+            .execute()
+            .fetchall(),
+            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

         # coerce on right side
         eq_(
-            select([t.c.data, coerce_fn(t.c.data, MyType)]).
-            where(t.c.data == coerce_fn('d1', MyType)).execute().fetchall(),
-            [('BIND_INd1', 'BIND_INd1BIND_OUT')]
+            select([t.c.data, coerce_fn(t.c.data, MyType)])
+            .where(t.c.data == coerce_fn("d1", MyType))
+            .execute()
+            .fetchall(),
+            [("BIND_INd1", "BIND_INd1BIND_OUT")],
        )

     @testing.fails_on(
-        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT"
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_coerce_none_cast(self):
         self._test_coerce_none(cast)

@@ -801,24 +896,27 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()
         eq_(
-            select([t.c.data, coerce_fn(t.c.data, MyType)]).
-            where(t.c.data == coerce_fn(None, MyType)).execute().fetchall(),
-            []
+            select([t.c.data, coerce_fn(t.c.data, MyType)])
+            .where(t.c.data == coerce_fn(None, MyType))
+            .execute()
+            .fetchall(),
+            [],
         )

         eq_(
-            select([t.c.data, coerce_fn(t.c.data, MyType)]).
-            where(coerce_fn(t.c.data, MyType) == None).  # noqa
-            execute().fetchall(),
-            []
+            select([t.c.data, coerce_fn(t.c.data, MyType)])
+            .where(coerce_fn(t.c.data, MyType) == None)
+            .execute()  # noqa
+            .fetchall(),
+            [],
        )

     @testing.fails_on(
-        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle", "oracle doesn't like CAST in the VALUES of an INSERT"
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_resolve_clause_element_cast(self):
         self._test_resolve_clause_element(cast)

@@ -829,10 +927,9 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

         class MyFoob(object):
-
             def __clause_element__(self):
                 return t.c.data

@@ -840,7 +937,7 @@ class TypeCoerceCastTest(fixtures.TablesTest):
             testing.db.execute(
                 select([t.c.data, coerce_fn(MyFoob(), MyType)])
             ).fetchall(),
-            [('BIND_INd1', 'BIND_INd1BIND_OUT')]
+            [("BIND_INd1", "BIND_INd1BIND_OUT")],
         )

     def test_cast_replace_col_w_bind(self):
@@ -853,7 +950,7 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

         stmt = select([t.c.data, coerce_fn(t.c.data, MyType)])

@@ -871,15 +968,16 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         # original statement
         eq_(
             testing.db.execute(stmt).fetchall(),
-            [('BIND_INd1', 'BIND_INd1BIND_OUT')]
+            [("BIND_INd1", "BIND_INd1BIND_OUT")],
         )

         # replaced with binds; CAST can't affect the bound parameter
         # on the way in here
         eq_(
             testing.db.execute(new_stmt).fetchall(),
-            [('x', 'BIND_INxBIND_OUT')] if coerce_fn is type_coerce
-            else [('x', 'xBIND_OUT')]
+            [("x", "BIND_INxBIND_OUT")]
+            if coerce_fn is type_coerce
+            else [("x", "xBIND_OUT")],
         )

     def test_cast_bind(self):
@@ -892,24 +990,30 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         MyType = self.MyType
         t = self.tables.t

-        t.insert().values(data=coerce_fn('d1', MyType)).execute()
+        t.insert().values(data=coerce_fn("d1", MyType)).execute()

-        stmt = select([
-            bindparam(None, "x", String(50), unique=True),
-            coerce_fn(bindparam(None, "x", String(50), unique=True), MyType)
-        ])
+        stmt = select(
+            [
+                bindparam(None, "x", String(50), unique=True),
+                coerce_fn(
+                    bindparam(None, "x", String(50), unique=True), MyType
+                ),
+            ]
+        )

         eq_(
             testing.db.execute(stmt).fetchall(),
-            [('x', 'BIND_INxBIND_OUT')] if coerce_fn is type_coerce
-            else [('x', 'xBIND_OUT')]
+            [("x", "BIND_INxBIND_OUT")]
+            if coerce_fn is type_coerce
+            else [("x", "xBIND_OUT")],
         )

     @testing.fails_on(
-        "oracle", "ORA-00906: missing left parenthesis - "
-        "seems to be CAST(:param AS type)")
-    @testing.fails_on(
-        "mysql", "mysql dialect warns on skipped CAST")
+        "oracle",
+        "ORA-00906: missing left parenthesis - "
+        "seems to be CAST(:param AS type)",
+    )
+    @testing.fails_on("mysql", "mysql dialect warns on skipped CAST")
     def test_cast_existing_typed(self):
         MyType = self.MyType
         coerce_fn = cast

         # when cast() is given an already typed value,
         # the type does not take effect on the value itself.
         eq_(
-            testing.db.scalar(
-                select([coerce_fn(literal('d1'), MyType)])
-            ),
-            'd1BIND_OUT'
+            testing.db.scalar(select([coerce_fn(literal("d1"), MyType)])),
+            "d1BIND_OUT",
         )

     def test_type_coerce_existing_typed(self):
@@ -931,38 +1033,37 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         # type_coerce does upgrade the given expression to the
         # given type.

-        t.insert().values(data=coerce_fn(literal('d1'), MyType)).execute()
+        t.insert().values(data=coerce_fn(literal("d1"), MyType)).execute()

         eq_(
             select([coerce_fn(t.c.data, MyType)]).execute().fetchall(),
-            [('BIND_INd1BIND_OUT', )])
+            [("BIND_INd1BIND_OUT",)],
+        )

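VariantTest, below, covers with_variant(), which swaps in a different type for particular dialects. For instance, a plausible use outside the test suite:

    from sqlalchemy import BigInteger, Integer

    # BigInteger everywhere, except SQLite gets a plain INTEGER
    id_type = BigInteger().with_variant(Integer(), "sqlite")
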
 class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
-
     def setup(self):
         class UTypeOne(types.UserDefinedType):
-
             def get_col_spec(self):
                 return "UTYPEONE"

             def bind_processor(self, dialect):
                 def process(value):
                     return value + "UONE"
+
                 return process

         class UTypeTwo(types.UserDefinedType):
-
             def get_col_spec(self):
                 return "UTYPETWO"

             def bind_processor(self, dialect):
                 def process(value):
                     return value + "UTWO"
+
                 return process

         class UTypeThree(types.UserDefinedType):
-
             def get_col_spec(self):
                 return "UTYPETHREE"

@@ -970,142 +1071,127 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
         self.UTypeTwo = UTypeTwo
         self.UTypeThree = UTypeThree
         self.variant = self.UTypeOne().with_variant(
-            self.UTypeTwo(), 'postgresql')
-        self.composite = self.variant.with_variant(self.UTypeThree(), 'mysql')
+            self.UTypeTwo(), "postgresql"
+        )
+        self.composite = self.variant.with_variant(self.UTypeThree(), "mysql")

     def test_illegal_dupe(self):
-        v = self.UTypeOne().with_variant(
-            self.UTypeTwo(), 'postgresql'
-        )
+        v = self.UTypeOne().with_variant(self.UTypeTwo(), "postgresql")
         assert_raises_message(
             exc.ArgumentError,
             "Dialect 'postgresql' is already present "
             "in the mapping for this Variant",
-            lambda: v.with_variant(self.UTypeThree(), 'postgresql')
+            lambda: v.with_variant(self.UTypeThree(), "postgresql"),
         )

     def test_compile(self):
+        self.assert_compile(self.variant, "UTYPEONE", use_default_dialect=True)
         self.assert_compile(
-            self.variant,
-            "UTYPEONE",
-            use_default_dialect=True
+            self.variant, "UTYPEONE", dialect=dialects.mysql.dialect()
         )
         self.assert_compile(
-            self.variant,
-            "UTYPEONE",
-            dialect=dialects.mysql.dialect()
-        )
-        self.assert_compile(
-            self.variant,
-            "UTYPETWO",
-            dialect=dialects.postgresql.dialect()
+            self.variant, "UTYPETWO", dialect=dialects.postgresql.dialect()
         )

     def test_to_instance(self):
         self.assert_compile(
             self.UTypeOne().with_variant(self.UTypeTwo, "postgresql"),
             "UTYPETWO",
-            dialect=dialects.postgresql.dialect()
+            dialect=dialects.postgresql.dialect(),
         )

     def test_compile_composite(self):
         self.assert_compile(
-            self.composite,
-            "UTYPEONE",
-            use_default_dialect=True
+            self.composite, "UTYPEONE", use_default_dialect=True
         )
         self.assert_compile(
-            self.composite,
-            "UTYPETHREE",
-            dialect=dialects.mysql.dialect()
+            self.composite, "UTYPETHREE", dialect=dialects.mysql.dialect()
         )
         self.assert_compile(
-            self.composite,
-            "UTYPETWO",
-            dialect=dialects.postgresql.dialect()
+            self.composite, "UTYPETWO", dialect=dialects.postgresql.dialect()
         )

     def test_bind_process(self):
         eq_(
-            self.variant._cached_bind_processor(
-                dialects.mysql.dialect())('foo'),
-            'fooUONE'
+            self.variant._cached_bind_processor(dialects.mysql.dialect())(
+                "foo"
+            ),
+            "fooUONE",
         )
         eq_(
-            self.variant._cached_bind_processor(
-                default.DefaultDialect())('foo'),
-            'fooUONE'
+            self.variant._cached_bind_processor(default.DefaultDialect())(
+                "foo"
+            ),
+            "fooUONE",
         )
         eq_(
-            self.variant._cached_bind_processor(
-                dialects.postgresql.dialect())('foo'),
-            'fooUTWO'
+            self.variant._cached_bind_processor(dialects.postgresql.dialect())(
+                "foo"
+            ),
+            "fooUTWO",
         )

     def test_bind_process_composite(self):
-        assert self.composite._cached_bind_processor(
-            dialects.mysql.dialect()) is None
+        assert (
+            self.composite._cached_bind_processor(dialects.mysql.dialect())
+            is None
+        )
         eq_(
-            self.composite._cached_bind_processor(
-                default.DefaultDialect())('foo'),
-            'fooUONE'
+            self.composite._cached_bind_processor(default.DefaultDialect())(
+                "foo"
+            ),
+            "fooUONE",
        )
        eq_(
            self.composite._cached_bind_processor(
-                dialects.postgresql.dialect())('foo'),
-            'fooUTWO'
+                dialects.postgresql.dialect()
+            )("foo"),
+            "fooUTWO",
        )

     def test_comparator_variant(self):
-        expr = column('x', self.variant) == "bar"
-        is_(
-            expr.right.type, self.variant
-        )
+        expr = column("x", self.variant) == "bar"
+        is_(expr.right.type, self.variant)

     @testing.only_on("sqlite")
     @testing.provide_metadata
     def test_round_trip(self):
-        variant = self.UTypeOne().with_variant(
-            self.UTypeTwo(), 'sqlite')
+        variant = self.UTypeOne().with_variant(self.UTypeTwo(), "sqlite")

-        t = Table('t', self.metadata,
-                  Column('x', variant)
-                  )
+        t = Table("t", self.metadata, Column("x", variant))
         with testing.db.connect() as conn:
             t.create(conn)

-            conn.execute(
-                t.insert(),
-                x='foo'
-            )
+            conn.execute(t.insert(), x="foo")

-            eq_(
-                conn.scalar(select([t.c.x]).where(t.c.x == 'foo')),
-                'fooUTWO'
-            )
+            eq_(conn.scalar(select([t.c.x]).where(t.c.x == "foo")), "fooUTWO")

     @testing.only_on("sqlite")
     @testing.provide_metadata
     def test_round_trip_sqlite_datetime(self):
         variant = DateTime().with_variant(
-            dialects.sqlite.DATETIME(truncate_microseconds=True), 'sqlite')
-
-        t = Table('t', self.metadata,
-                  Column('x', variant)
+            dialects.sqlite.DATETIME(truncate_microseconds=True), "sqlite"
         )
+
+        t = Table("t", self.metadata, Column("x", variant))
         with testing.db.connect() as conn:
             t.create(conn)

             conn.execute(
-                t.insert(),
-                x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
+                t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
             )

             eq_(
-                conn.scalar(select([t.c.x]).where(t.c.x == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059))),
-                datetime.datetime(2015, 4, 18, 10, 15, 17)
+                conn.scalar(
+                    select([t.c.x]).where(
+                        t.c.x
+                        == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059)
+                    )
+                ),
+                datetime.datetime(2015, 4, 18, 10, 15, 17),
             )


 class UnicodeTest(fixtures.TestBase):
     """Exercise the Unicode and related types.

     Note:  unicode round trip tests are now in
     sqlalchemy/testing/suite/test_types.py.

     """
+
     __backend__ = True

     data = util.u(
         "Alors vous imaginez ma surprise, au lever du jour, quand "
         "une drôle de petite voix m’a réveillé. "
-        "Elle disait: « S’il vous plaît… dessine-moi un mouton! »")
+        "Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
+    )

     def test_unicode_warnings_typelevel_native_unicode(self):

@@ -1129,10 +1217,10 @@ class UnicodeTest(fixtures.TestBase):
         dialect.supports_unicode_binds = True
         uni = u.dialect_impl(dialect).bind_processor(dialect)
         if util.py3k:
-            assert_raises(exc.SAWarning, uni, b'x')
+            assert_raises(exc.SAWarning, uni, b"x")
             assert isinstance(uni(unicodedata), str)
         else:
-            assert_raises(exc.SAWarning, uni, 'x')
+            assert_raises(exc.SAWarning, uni, "x")
             assert isinstance(uni(unicodedata), unicode)  # noqa

     def test_unicode_warnings_typelevel_sqla_unicode(self):
@@ -1141,10 +1229,10 @@ class UnicodeTest(fixtures.TestBase):
         dialect = default.DefaultDialect()
         dialect.supports_unicode_binds = False
         uni = u.dialect_impl(dialect).bind_processor(dialect)
-        assert_raises(exc.SAWarning, uni, util.b('x'))
+        assert_raises(exc.SAWarning, uni, util.b("x"))
         assert isinstance(uni(unicodedata), util.binary_type)

-        eq_(uni(unicodedata), unicodedata.encode('utf-8'))
+        eq_(uni(unicodedata), unicodedata.encode("utf-8"))

     def test_unicode_warnings_totally_wrong_type(self):
         u = Unicode()
@@ -1152,7 +1240,8 @@ class UnicodeTest(fixtures.TestBase):
         dialect.supports_unicode_binds = False
         uni = u.dialect_impl(dialect).bind_processor(dialect)
         with expect_warnings(
-                "Unicode type received non-unicode bind param value 5."):
+            "Unicode type received non-unicode bind param value 5."
+        ):
             eq_(uni(5), 5)

     def test_unicode_warnings_dialectlevel(self):
@@ -1165,10 +1254,10 @@ class UnicodeTest(fixtures.TestBase):

         s = String()
         uni = s.dialect_impl(dialect).bind_processor(dialect)

-        uni(util.b('x'))
+        uni(util.b("x"))
         assert isinstance(uni(unicodedata), util.binary_type)

-        eq_(uni(unicodedata), unicodedata.encode('utf-8'))
+        eq_(uni(unicodedata), unicodedata.encode("utf-8"))

     def test_ignoring_unicode_error(self):
         """checks String(unicode_error='ignore') is passed to
         underlying codec."""

         unicodedata = self.data

-        type_ = String(248, convert_unicode='force', unicode_error='ignore')
-        dialect = default.DefaultDialect(encoding='ascii')
+        type_ = String(248, convert_unicode="force", unicode_error="ignore")
+        dialect = default.DefaultDialect(encoding="ascii")
         proc = type_.result_processor(dialect, 10)

-        utfdata = unicodedata.encode('utf8')
-        eq_(
-            proc(utfdata),
-            unicodedata.encode('ascii', 'ignore').decode()
-        )
+        utfdata = unicodedata.encode("utf8")
+        eq_(proc(utfdata), unicodedata.encode("ascii", "ignore").decode())

@classmethod def define_tables(cls, metadata): Table( - 'enum_table', metadata, Column("id", Integer, primary_key=True), - Column('someenum', Enum('one', 'two', 'three', name='myenum')) + "enum_table", + metadata, + Column("id", Integer, primary_key=True), + Column("someenum", Enum("one", "two", "three", name="myenum")), ) Table( - 'non_native_enum_table', metadata, + "non_native_enum_table", + metadata, Column("id", Integer, primary_key=True, autoincrement=False), - Column('someenum', Enum('one', 'two', 'three', native_enum=False)), - Column('someotherenum', - Enum('one', 'two', 'three', - create_constraint=False, native_enum=False, - validate_strings=True)), + Column("someenum", Enum("one", "two", "three", native_enum=False)), + Column( + "someotherenum", + Enum( + "one", + "two", + "three", + create_constraint=False, + native_enum=False, + validate_strings=True, + ), + ), ) Table( - 'stdlib_enum_table', metadata, + "stdlib_enum_table", + metadata, Column("id", Integer, primary_key=True), - Column('someenum', Enum(cls.SomeEnum)) + Column("someenum", Enum(cls.SomeEnum)), ) Table( - 'stdlib_enum_table2', metadata, + "stdlib_enum_table2", + metadata, Column("id", Integer, primary_key=True), - Column('someotherenum', - Enum(cls.SomeOtherEnum, - values_callable=EnumTest.get_enum_string_values)) + Column( + "someotherenum", + Enum( + cls.SomeOtherEnum, + values_callable=EnumTest.get_enum_string_values, + ), + ), ) def test_python_type(self): @@ -1261,12 +1363,12 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): SomeEnum = self.SomeEnum for loads, dumps in picklers(): column_types = [ - Column('Enu', Enum('x', 'y', 'z', name="somename")), - Column('En2', Enum(self.SomeEnum)), + Column("Enu", Enum("x", "y", "z", name="somename")), + Column("En2", Enum(self.SomeEnum)), ] for column_type in column_types: meta = MetaData() - Table('foo', meta, column_type) + Table("foo", meta, column_type) loads(dumps(column_type)) loads(dumps(meta)) @@ -1276,34 +1378,39 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): bind_processor = type_.bind_processor(testing.db.dialect) bind_processor_validates = validate_type.bind_processor( - testing.db.dialect) - eq_(bind_processor('one'), "one") + testing.db.dialect + ) + eq_(bind_processor("one"), "one") eq_(bind_processor(self.one), "one") eq_(bind_processor("foo"), "foo") assert_raises_message( LookupError, '"5" is not among the defined enum values', - bind_processor, 5 + bind_processor, + 5, ) assert_raises_message( LookupError, '"foo" is not among the defined enum values', - bind_processor_validates, "foo" + bind_processor_validates, + "foo", ) result_processor = type_.result_processor(testing.db.dialect, None) - eq_(result_processor('one'), self.one) + eq_(result_processor("one"), self.one) assert_raises_message( LookupError, '"foo" is not among the defined enum values', - result_processor, "foo" + result_processor, + "foo", ) literal_processor = type_.literal_processor(testing.db.dialect) validate_literal_processor = validate_type.literal_processor( - testing.db.dialect) + testing.db.dialect + ) eq_(literal_processor("one"), "'one'") eq_(literal_processor("foo"), "'foo'") @@ -1311,13 +1418,15 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): assert_raises_message( LookupError, '"5" is not among the defined enum values', - literal_processor, 5 + literal_processor, + 5, ) assert_raises_message( LookupError, '"foo" is not among the defined enum values', - validate_literal_processor, "foo" + validate_literal_processor, + "foo", ) def 
test_validators_plain(self): @@ -1326,105 +1435,112 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): bind_processor = type_.bind_processor(testing.db.dialect) bind_processor_validates = validate_type.bind_processor( - testing.db.dialect) - eq_(bind_processor('one'), "one") - eq_(bind_processor('foo'), "foo") + testing.db.dialect + ) + eq_(bind_processor("one"), "one") + eq_(bind_processor("foo"), "foo") assert_raises_message( LookupError, '"5" is not among the defined enum values', - bind_processor, 5 + bind_processor, + 5, ) assert_raises_message( LookupError, '"foo" is not among the defined enum values', - bind_processor_validates, "foo" + bind_processor_validates, + "foo", ) result_processor = type_.result_processor(testing.db.dialect, None) - eq_(result_processor('one'), "one") + eq_(result_processor("one"), "one") assert_raises_message( LookupError, '"foo" is not among the defined enum values', - result_processor, "foo" + result_processor, + "foo", ) literal_processor = type_.literal_processor(testing.db.dialect) validate_literal_processor = validate_type.literal_processor( - testing.db.dialect) + testing.db.dialect + ) eq_(literal_processor("one"), "'one'") eq_(literal_processor("foo"), "'foo'") assert_raises_message( LookupError, '"5" is not among the defined enum values', - literal_processor, 5 + literal_processor, + 5, ) assert_raises_message( LookupError, '"foo" is not among the defined enum values', - validate_literal_processor, "foo" + validate_literal_processor, + "foo", ) def test_validators_not_in_like_roundtrip(self): - enum_table = self.tables['non_native_enum_table'] - - enum_table.insert().execute([ - {'id': 1, 'someenum': 'two'}, - {'id': 2, 'someenum': 'two'}, - {'id': 3, 'someenum': 'one'}, - ]) + enum_table = self.tables["non_native_enum_table"] - eq_( - enum_table.select(). - where(enum_table.c.someenum.like('%wo%')). - order_by(enum_table.c.id).execute().fetchall(), + enum_table.insert().execute( [ - (1, 'two', None), - (2, 'two', None), + {"id": 1, "someenum": "two"}, + {"id": 2, "someenum": "two"}, + {"id": 3, "someenum": "one"}, ] ) - def test_validators_not_in_concatenate_roundtrip(self): - enum_table = self.tables['non_native_enum_table'] + eq_( + enum_table.select() + .where(enum_table.c.someenum.like("%wo%")) + .order_by(enum_table.c.id) + .execute() + .fetchall(), + [(1, "two", None), (2, "two", None)], + ) - enum_table.insert().execute([ - {'id': 1, 'someenum': 'two'}, - {'id': 2, 'someenum': 'two'}, - {'id': 3, 'someenum': 'one'}, - ]) + def test_validators_not_in_concatenate_roundtrip(self): + enum_table = self.tables["non_native_enum_table"] - eq_( - select(['foo' + enum_table.c.someenum]). 
- order_by(enum_table.c.id).execute().fetchall(), + enum_table.insert().execute( [ - ('footwo', ), - ('footwo', ), - ('fooone', ) + {"id": 1, "someenum": "two"}, + {"id": 2, "someenum": "two"}, + {"id": 3, "someenum": "one"}, ] ) + eq_( + select(["foo" + enum_table.c.someenum]) + .order_by(enum_table.c.id) + .execute() + .fetchall(), + [("footwo",), ("footwo",), ("fooone",)], + ) + @testing.fails_on( - 'postgresql+zxjdbc', + "postgresql+zxjdbc", 'zxjdbc fails on ENUM: column "XXX" is of type XXX ' - 'but expression is of type character varying') + "but expression is of type character varying", + ) def test_round_trip(self): - enum_table = self.tables['enum_table'] + enum_table = self.tables["enum_table"] - enum_table.insert().execute([ - {'id': 1, 'someenum': 'two'}, - {'id': 2, 'someenum': 'two'}, - {'id': 3, 'someenum': 'one'}, - ]) + enum_table.insert().execute( + [ + {"id": 1, "someenum": "two"}, + {"id": 2, "someenum": "two"}, + {"id": 3, "someenum": "one"}, + ] + ) eq_( enum_table.select().order_by(enum_table.c.id).execute().fetchall(), - [ - (1, 'two'), - (2, 'two'), - (3, 'one'), - ] + [(1, "two"), (2, "two"), (3, "one")], ) def test_null_round_trip(self): @@ -1437,50 +1553,57 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): with testing.db.connect() as conn: conn.execute( - non_native_enum_table.insert(), {"id": 1, "someenum": None}) + non_native_enum_table.insert(), {"id": 1, "someenum": None} + ) eq_(conn.scalar(select([non_native_enum_table.c.someenum])), None) @testing.requires.enforces_check_constraints def test_check_constraint(self): assert_raises( ( - exc.IntegrityError, exc.ProgrammingError, + exc.IntegrityError, + exc.ProgrammingError, exc.OperationalError, # PyMySQL raising InternalError until # https://github.com/PyMySQL/PyMySQL/issues/607 is resolved - exc.InternalError), + exc.InternalError, + ), testing.db.execute, "insert into non_native_enum_table " - "(id, someenum) values(1, 'four')") + "(id, someenum) values(1, 'four')", + ) @testing.requires.enforces_check_constraints @testing.provide_metadata def test_variant_we_are_default(self): # test that the "variant" does not create a constraint t = Table( - 'my_table', self.metadata, + "my_table", + self.metadata, Column( - 'data', Enum("one", "two", "three", - native_enum=False, name="e1").with_variant( - Enum("four", "five", "six", native_enum=False, - name="e2"), "some_other_db" - ) + "data", + Enum( + "one", "two", "three", native_enum=False, name="e1" + ).with_variant( + Enum("four", "five", "six", native_enum=False, name="e2"), + "some_other_db", + ), ), - mysql_engine='InnoDB' + mysql_engine="InnoDB", ) eq_( len([c for c in t.constraints if isinstance(c, CheckConstraint)]), - 2 + 2, ) with testing.db.connect() as conn: self.metadata.create_all(conn) assert_raises( - (exc.DBAPIError, ), + (exc.DBAPIError,), conn.execute, - "insert into my_table " - "(data) values('four')") + "insert into my_table " "(data) values('four')", + ) conn.execute("insert into my_table (data) values ('two')") @testing.requires.enforces_check_constraints @@ -1488,29 +1611,32 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): def test_variant_we_are_not_default(self): # test that the "variant" does not create a constraint t = Table( - 'my_table', self.metadata, + "my_table", + self.metadata, Column( - 'data', Enum("one", "two", "three", native_enum=False, - name="e1").with_variant( + "data", + Enum( + "one", "two", "three", native_enum=False, name="e1" + ).with_variant( Enum("four", "five", "six", native_enum=False, 
name="e2"), - testing.db.dialect.name - ) - ) + testing.db.dialect.name, + ), + ), ) # ensure Variant isn't exploding the constraints eq_( len([c for c in t.constraints if isinstance(c, CheckConstraint)]), - 2 + 2, ) with testing.db.connect() as conn: self.metadata.create_all(conn) assert_raises( - (exc.DBAPIError, ), + (exc.DBAPIError,), conn.execute, - "insert into my_table " - "(data) values('two')") + "insert into my_table " "(data) values('two')", + ) conn.execute("insert into my_table (data) values ('four')") def test_skip_check_constraint(self): @@ -1521,51 +1647,56 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): ) eq_( conn.scalar("select someotherenum from non_native_enum_table"), - "four") + "four", + ) assert_raises_message( LookupError, '"four" is not among the defined enum values', conn.scalar, - select([self.tables.non_native_enum_table.c.someotherenum]) + select([self.tables.non_native_enum_table.c.someotherenum]), ) def test_non_native_round_trip(self): - non_native_enum_table = self.tables['non_native_enum_table'] + non_native_enum_table = self.tables["non_native_enum_table"] - non_native_enum_table.insert().execute([ - {'id': 1, 'someenum': 'two'}, - {'id': 2, 'someenum': 'two'}, - {'id': 3, 'someenum': 'one'}, - ]) - - eq_( - select([ - non_native_enum_table.c.id, - non_native_enum_table.c.someenum]). - order_by(non_native_enum_table.c.id).execute().fetchall(), + non_native_enum_table.insert().execute( [ - (1, 'two'), - (2, 'two'), - (3, 'one'), + {"id": 1, "someenum": "two"}, + {"id": 2, "someenum": "two"}, + {"id": 3, "someenum": "one"}, ] ) + eq_( + select( + [non_native_enum_table.c.id, non_native_enum_table.c.someenum] + ) + .order_by(non_native_enum_table.c.id) + .execute() + .fetchall(), + [(1, "two"), (2, "two"), (3, "one")], + ) + def test_pep435_enum_round_trip(self): - stdlib_enum_table = self.tables['stdlib_enum_table'] - - stdlib_enum_table.insert().execute([ - {'id': 1, 'someenum': self.SomeEnum.two}, - {'id': 2, 'someenum': self.SomeEnum.two}, - {'id': 3, 'someenum': self.SomeEnum.one}, - {'id': 4, 'someenum': self.SomeEnum.three}, - {'id': 5, 'someenum': self.SomeEnum.four}, - {'id': 6, 'someenum': 'three'}, - {'id': 7, 'someenum': 'four'}, - ]) + stdlib_enum_table = self.tables["stdlib_enum_table"] + + stdlib_enum_table.insert().execute( + [ + {"id": 1, "someenum": self.SomeEnum.two}, + {"id": 2, "someenum": self.SomeEnum.two}, + {"id": 3, "someenum": self.SomeEnum.one}, + {"id": 4, "someenum": self.SomeEnum.three}, + {"id": 5, "someenum": self.SomeEnum.four}, + {"id": 6, "someenum": "three"}, + {"id": 7, "someenum": "four"}, + ] + ) eq_( - stdlib_enum_table.select(). 
- order_by(stdlib_enum_table.c.id).execute().fetchall(), + stdlib_enum_table.select() + .order_by(stdlib_enum_table.c.id) + .execute() + .fetchall(), [ (1, self.SomeEnum.two), (2, self.SomeEnum.two), @@ -1574,108 +1705,118 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): (5, self.SomeEnum.three), (6, self.SomeEnum.three), (7, self.SomeEnum.three), - ] + ], ) def test_pep435_enum_values_callable_round_trip(self): - stdlib_enum_table_custom_values =\ - self.tables['stdlib_enum_table2'] + stdlib_enum_table_custom_values = self.tables["stdlib_enum_table2"] - stdlib_enum_table_custom_values.insert().execute([ - {'id': 1, 'someotherenum': self.SomeOtherEnum.AMember}, - {'id': 2, 'someotherenum': self.SomeOtherEnum.BMember}, - {'id': 3, 'someotherenum': self.SomeOtherEnum.AMember} - ]) + stdlib_enum_table_custom_values.insert().execute( + [ + {"id": 1, "someotherenum": self.SomeOtherEnum.AMember}, + {"id": 2, "someotherenum": self.SomeOtherEnum.BMember}, + {"id": 3, "someotherenum": self.SomeOtherEnum.AMember}, + ] + ) eq_( - stdlib_enum_table_custom_values.select(). - order_by(stdlib_enum_table_custom_values.c.id).execute(). - fetchall(), + stdlib_enum_table_custom_values.select() + .order_by(stdlib_enum_table_custom_values.c.id) + .execute() + .fetchall(), [ (1, self.SomeOtherEnum.AMember), (2, self.SomeOtherEnum.BMember), - (3, self.SomeOtherEnum.AMember) - ] + (3, self.SomeOtherEnum.AMember), + ], ) def test_pep435_enum_expanding_in(self): - stdlib_enum_table_custom_values =\ - self.tables['stdlib_enum_table2'] - - stdlib_enum_table_custom_values.insert().execute([ - {'id': 1, 'someotherenum': self.SomeOtherEnum.one}, - {'id': 2, 'someotherenum': self.SomeOtherEnum.two}, - {'id': 3, 'someotherenum': self.SomeOtherEnum.three} - ]) - - stmt = stdlib_enum_table_custom_values.select().where( - stdlib_enum_table_custom_values.c.someotherenum.in_( - bindparam("member", expanding=True) + stdlib_enum_table_custom_values = self.tables["stdlib_enum_table2"] + + stdlib_enum_table_custom_values.insert().execute( + [ + {"id": 1, "someotherenum": self.SomeOtherEnum.one}, + {"id": 2, "someotherenum": self.SomeOtherEnum.two}, + {"id": 3, "someotherenum": self.SomeOtherEnum.three}, + ] + ) + + stmt = ( + stdlib_enum_table_custom_values.select() + .where( + stdlib_enum_table_custom_values.c.someotherenum.in_( + bindparam("member", expanding=True) + ) ) - ).order_by(stdlib_enum_table_custom_values.c.id) + .order_by(stdlib_enum_table_custom_values.c.id) + ) eq_( testing.db.execute( stmt, - {"member": [ - self.SomeOtherEnum.one, - self.SomeOtherEnum.three]} + {"member": [self.SomeOtherEnum.one, self.SomeOtherEnum.three]}, ).fetchall(), - [ - (1, self.SomeOtherEnum.one), - (3, self.SomeOtherEnum.three) - ] + [(1, self.SomeOtherEnum.one), (3, self.SomeOtherEnum.three)], ) def test_adapt(self): from sqlalchemy.dialects.postgresql import ENUM - e1 = Enum('one', 'two', 'three', native_enum=False) + + e1 = Enum("one", "two", "three", native_enum=False) false_adapt = e1.adapt(ENUM) eq_(false_adapt.native_enum, False) assert not isinstance(false_adapt, ENUM) - e1 = Enum('one', 'two', 'three', native_enum=True) + e1 = Enum("one", "two", "three", native_enum=True) true_adapt = e1.adapt(ENUM) eq_(true_adapt.native_enum, True) assert isinstance(true_adapt, ENUM) - e1 = Enum('one', 'two', 'three', name='foo', - schema='bar', metadata=MetaData()) - eq_(e1.adapt(ENUM).name, 'foo') - eq_(e1.adapt(ENUM).schema, 'bar') + e1 = Enum( + "one", + "two", + "three", + name="foo", + schema="bar", + metadata=MetaData(), + ) + 
eq_(e1.adapt(ENUM).name, "foo") + eq_(e1.adapt(ENUM).schema, "bar") is_(e1.adapt(ENUM).metadata, e1.metadata) - eq_(e1.adapt(Enum).name, 'foo') - eq_(e1.adapt(Enum).schema, 'bar') + eq_(e1.adapt(Enum).name, "foo") + eq_(e1.adapt(Enum).schema, "bar") is_(e1.adapt(Enum).metadata, e1.metadata) e1 = Enum(self.SomeEnum) - eq_(e1.adapt(ENUM).name, 'someenum') - eq_(e1.adapt(ENUM).enums, - ['one', 'two', 'three', 'four', 'AMember', 'BMember']) + eq_(e1.adapt(ENUM).name, "someenum") + eq_( + e1.adapt(ENUM).enums, + ["one", "two", "three", "four", "AMember", "BMember"], + ) - e1_vc = Enum(self.SomeOtherEnum, - values_callable=EnumTest.get_enum_string_values) - eq_(e1_vc.adapt(ENUM).name, 'someotherenum') - eq_(e1_vc.adapt(ENUM).enums, ['1', '2', '3', 'a', 'b']) + e1_vc = Enum( + self.SomeOtherEnum, values_callable=EnumTest.get_enum_string_values + ) + eq_(e1_vc.adapt(ENUM).name, "someotherenum") + eq_(e1_vc.adapt(ENUM).enums, ["1", "2", "3", "a", "b"]) @testing.provide_metadata def test_create_metadata_bound_no_crash(self): m1 = self.metadata - Enum('a', 'b', 'c', metadata=m1, name='ncenum') + Enum("a", "b", "c", metadata=m1, name="ncenum") m1.create_all(testing.db) def test_non_native_constraint_custom_type(self): class Foob(object): - def __init__(self, name): self.name = name class MyEnum(TypeDecorator): - def __init__(self, values): self.impl = Enum( - *[v.name for v in values], name="myenum", - native_enum=False) + *[v.name for v in values], name="myenum", native_enum=False + ) # future method def process_literal_param(self, value, dialect): @@ -1685,21 +1826,22 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): return value.name m = MetaData() - t1 = Table('t', m, Column('x', MyEnum([Foob('a'), Foob('b')]))) - const = [ - c for c in t1.constraints if isinstance(c, CheckConstraint)][0] + t1 = Table("t", m, Column("x", MyEnum([Foob("a"), Foob("b")]))) + const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][ + 0 + ] self.assert_compile( AddConstraint(const), "ALTER TABLE t ADD CONSTRAINT myenum CHECK (x IN ('a', 'b'))", - dialect="default" + dialect="default", ) def test_lookup_failure(self): assert_raises( exc.StatementError, - self.tables['non_native_enum_table'].insert().execute, - {'id': 4, 'someotherenum': 'four'} + self.tables["non_native_enum_table"].insert().execute, + {"id": 4, "someotherenum": "four"}, ) def test_mock_engine_no_prob(self): @@ -1707,7 +1849,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): are created with checkfirst=False""" e = engines.mock_engine() - t = Table('t1', MetaData(), Column('x', Enum("x", "y", name="pge"))) + t = Table("t1", MetaData(), Column("x", Enum("x", "y", name="pge"))) t.create(e, checkfirst=False) # basically looking for the start of # the constraint, or the ENUM def itself, @@ -1716,12 +1858,19 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): def test_repr(self): e = Enum( - "x", "y", name="somename", convert_unicode=True, quote=True, - inherit_schema=True, native_enum=False) + "x", + "y", + name="somename", + convert_unicode=True, + quote=True, + inherit_schema=True, + native_enum=False, + ) eq_( repr(e), "Enum('x', 'y', name='somename', " - "inherit_schema=True, native_enum=False)") + "inherit_schema=True, native_enum=False)", + ) binary_table = MyPickleType = metadata = None @@ -1739,25 +1888,29 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): def process_bind_param(self, value, dialect): if value: - value.stuff = 'this is modified stuff' + value.stuff = "this is modified stuff" 
return value def process_result_value(self, value, dialect): if value: - value.stuff = 'this is the right stuff' + value.stuff = "this is the right stuff" return value metadata = MetaData(testing.db) binary_table = Table( - 'binary_table', metadata, + "binary_table", + metadata, Column( - 'primary_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', LargeBinary), - Column('data_slice', LargeBinary(100)), - Column('misc', String(30)), - Column('pickled', PickleType), - Column('mypickle', MyPickleType) + "primary_id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("data", LargeBinary), + Column("data_slice", LargeBinary(100)), + Column("misc", String(30)), + Column("pickled", PickleType), + Column("mypickle", MyPickleType), ) metadata.create_all() @@ -1771,57 +1924,79 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.non_broken_binary def test_round_trip(self): - testobj1 = pickleable.Foo('im foo 1') - testobj2 = pickleable.Foo('im foo 2') - testobj3 = pickleable.Foo('im foo 3') + testobj1 = pickleable.Foo("im foo 1") + testobj2 = pickleable.Foo("im foo 2") + testobj3 = pickleable.Foo("im foo 3") - stream1 = self.load_stream('binary_data_one.dat') - stream2 = self.load_stream('binary_data_two.dat') + stream1 = self.load_stream("binary_data_one.dat") + stream2 = self.load_stream("binary_data_two.dat") binary_table.insert().execute( - primary_id=1, misc='binary_data_one.dat', data=stream1, - data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) + primary_id=1, + misc="binary_data_one.dat", + data=stream1, + data_slice=stream1[0:100], + pickled=testobj1, + mypickle=testobj3, + ) binary_table.insert().execute( - primary_id=2, misc='binary_data_two.dat', data=stream2, - data_slice=stream2[0:99], pickled=testobj2) + primary_id=2, + misc="binary_data_two.dat", + data=stream2, + data_slice=stream2[0:99], + pickled=testobj2, + ) binary_table.insert().execute( - primary_id=3, misc='binary_data_two.dat', data=None, - data_slice=stream2[0:99], pickled=None) + primary_id=3, + misc="binary_data_two.dat", + data=None, + data_slice=stream2[0:99], + pickled=None, + ) for stmt in ( binary_table.select(order_by=binary_table.c.primary_id), text( "select * from binary_table order by binary_table.primary_id", typemap={ - 'pickled': PickleType, 'mypickle': MyPickleType, - 'data': LargeBinary, 'data_slice': LargeBinary}, - bind=testing.db) + "pickled": PickleType, + "mypickle": MyPickleType, + "data": LargeBinary, + "data_slice": LargeBinary, + }, + bind=testing.db, + ), ): result = stmt.execute().fetchall() - eq_(stream1, result[0]['data']) - eq_(stream1[0:100], result[0]['data_slice']) - eq_(stream2, result[1]['data']) - eq_(testobj1, result[0]['pickled']) - eq_(testobj2, result[1]['pickled']) - eq_(testobj3.moredata, result[0]['mypickle'].moredata) - eq_(result[0]['mypickle'].stuff, 'this is the right stuff') + eq_(stream1, result[0]["data"]) + eq_(stream1[0:100], result[0]["data_slice"]) + eq_(stream2, result[1]["data"]) + eq_(testobj1, result[0]["pickled"]) + eq_(testobj2, result[1]["pickled"]) + eq_(testobj3.moredata, result[0]["mypickle"].moredata) + eq_(result[0]["mypickle"].stuff, "this is the right stuff") @testing.requires.binary_comparisons def test_comparison(self): """test that type coercion occurs on comparison for binary""" - expr = binary_table.c.data == 'foo' + expr = binary_table.c.data == "foo" assert isinstance(expr.right.type, LargeBinary) data = os.urandom(32) 
binary_table.insert().execute(data=data) eq_( - select([func.count('*')]).select_from(binary_table). - where(binary_table.c.data == data).scalar(), 1) + select([func.count("*")]) + .select_from(binary_table) + .where(binary_table.c.data == data) + .scalar(), + 1, + ) @testing.requires.binary_literals def test_literal_roundtrip(self): compiled = select([cast(literal(util.b("foo")), LargeBinary)]).compile( - dialect=testing.db.dialect, compile_kwargs={"literal_binds": True}) + dialect=testing.db.dialect, compile_kwargs={"literal_binds": True} + ) result = testing.db.execute(compiled) eq_(result.scalar(), util.b("foo")) @@ -1831,18 +2006,19 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): def load_stream(self, name): f = os.path.join(os.path.dirname(__file__), "..", name) - with open(f, mode='rb') as o: + with open(f, mode="rb") as o: return o.read() class JSONTest(fixtures.TestBase): - def setup(self): metadata = MetaData() - self.test_table = Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('test_column', JSON), - ) + self.test_table = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("test_column", JSON), + ) self.jsoncol = self.test_table.c.test_column self.dialect = default.DefaultDialect() @@ -1851,63 +2027,50 @@ class JSONTest(fixtures.TestBase): def test_bind_serialize_default(self): proc = self.test_table.c.test_column.type._cached_bind_processor( - self.dialect) + self.dialect + ) eq_( proc({"A": [1, 2, 3, True, False]}), - '{"A": [1, 2, 3, true, false]}' + '{"A": [1, 2, 3, true, false]}', ) def test_bind_serialize_None(self): proc = self.test_table.c.test_column.type._cached_bind_processor( - self.dialect) - eq_( - proc(None), - 'null' + self.dialect ) + eq_(proc(None), "null") def test_bind_serialize_none_as_null(self): - proc = JSON(none_as_null=True)._cached_bind_processor( - self.dialect) - eq_( - proc(None), - None - ) - eq_( - proc(null()), - None - ) + proc = JSON(none_as_null=True)._cached_bind_processor(self.dialect) + eq_(proc(None), None) + eq_(proc(null()), None) def test_bind_serialize_null(self): proc = self.test_table.c.test_column.type._cached_bind_processor( - self.dialect) - eq_( - proc(null()), - None + self.dialect ) + eq_(proc(null()), None) def test_result_deserialize_default(self): proc = self.test_table.c.test_column.type._cached_result_processor( - self.dialect, None) + self.dialect, None + ) eq_( proc('{"A": [1, 2, 3, true, false]}'), - {"A": [1, 2, 3, True, False]} + {"A": [1, 2, 3, True, False]}, ) def test_result_deserialize_null(self): proc = self.test_table.c.test_column.type._cached_result_processor( - self.dialect, None) - eq_( - proc('null'), - None + self.dialect, None ) + eq_(proc("null"), None) def test_result_deserialize_None(self): proc = self.test_table.c.test_column.type._cached_result_processor( - self.dialect, None) - eq_( - proc(None), - None + self.dialect, None ) + eq_(proc(None), None) def _dialect_index_fixture(self, int_processor, str_processor): class MyInt(Integer): @@ -1958,19 +2121,19 @@ class JSONTest(fixtures.TestBase): eq_(bindproc(expr.right.value), "5") def test_index_bind_proc_str(self): - expr = self.test_table.c.test_column['five'] + expr = self.test_table.c.test_column["five"] str_dialect = self._dialect_index_fixture(True, True) non_str_dialect = self._dialect_index_fixture(False, False) bindproc = expr.right.type._cached_bind_processor(str_dialect) - eq_(bindproc(expr.right.value), 'five10') + eq_(bindproc(expr.right.value), "five10") 
bindproc = expr.right.type._cached_bind_processor(non_str_dialect) - eq_(bindproc(expr.right.value), 'five') + eq_(bindproc(expr.right.value), "five") def test_index_literal_proc_str(self): - expr = self.test_table.c.test_column['five'] + expr = self.test_table.c.test_column["five"] str_dialect = self._dialect_index_fixture(True, True) non_str_dialect = self._dialect_index_fixture(False, False) @@ -1983,36 +2146,27 @@ class JSONTest(fixtures.TestBase): class ArrayTest(fixtures.TestBase): - def _myarray_fixture(self): class MyArray(ARRAY): pass + return MyArray def test_array_index_map_dimensions(self): - col = column('x', ARRAY(Integer, dimensions=3)) - is_( - col[5].type._type_affinity, ARRAY - ) - eq_( - col[5].type.dimensions, 2 - ) - is_( - col[5][6].type._type_affinity, ARRAY - ) - eq_( - col[5][6].type.dimensions, 1 - ) - is_( - col[5][6][7].type._type_affinity, Integer - ) + col = column("x", ARRAY(Integer, dimensions=3)) + is_(col[5].type._type_affinity, ARRAY) + eq_(col[5].type.dimensions, 2) + is_(col[5][6].type._type_affinity, ARRAY) + eq_(col[5][6].type.dimensions, 1) + is_(col[5][6][7].type._type_affinity, Integer) def test_array_getitem_single_type(self): m = MetaData() arrtable = Table( - 'arrtable', m, - Column('intarr', ARRAY(Integer)), - Column('strarr', ARRAY(String)), + "arrtable", + m, + Column("intarr", ARRAY(Integer)), + Column("strarr", ARRAY(String)), ) is_(arrtable.c.intarr[1].type._type_affinity, Integer) is_(arrtable.c.strarr[1].type._type_affinity, String) @@ -2020,9 +2174,10 @@ class ArrayTest(fixtures.TestBase): def test_array_getitem_slice_type(self): m = MetaData() arrtable = Table( - 'arrtable', m, - Column('intarr', ARRAY(Integer)), - Column('strarr', ARRAY(String)), + "arrtable", + m, + Column("intarr", ARRAY(Integer)), + Column("strarr", ARRAY(String)), ) is_(arrtable.c.intarr[1:3].type._type_affinity, ARRAY) is_(arrtable.c.strarr[1:3].type._type_affinity, ARRAY) @@ -2031,9 +2186,10 @@ class ArrayTest(fixtures.TestBase): MyArray = self._myarray_fixture() m = MetaData() arrtable = Table( - 'arrtable', m, - Column('intarr', MyArray(Integer)), - Column('strarr', MyArray(String)), + "arrtable", + m, + Column("intarr", MyArray(Integer)), + Column("strarr", MyArray(String)), ) is_(arrtable.c.intarr[1:3].type._type_affinity, ARRAY) is_(arrtable.c.strarr[1:3].type._type_affinity, ARRAY) @@ -2047,34 +2203,36 @@ test_table = meta = MyCustomType = MyTypeDec = None class ExpressionTest( - fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): - __dialect__ = 'default' + fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL +): + __dialect__ = "default" @classmethod def setup_class(cls): global test_table, meta, MyCustomType, MyTypeDec class MyCustomType(types.UserDefinedType): - def get_col_spec(self): return "INT" def bind_processor(self, dialect): def process(value): return value * 10 + return process def result_processor(self, dialect, coltype): def process(value): return value / 10 + return process class MyOldCustomType(MyCustomType): - def adapt_operator(self, op): return { operators.add: operators.sub, - operators.sub: operators.add}.get(op, op) + operators.sub: operators.add, + }.get(op, op) class MyTypeDec(types.TypeDecorator): impl = String @@ -2087,20 +2245,26 @@ class ExpressionTest( meta = MetaData(testing.db) test_table = Table( - 'test', meta, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('atimestamp', Date), - Column('avalue', MyCustomType), - Column('bvalue', MyTypeDec(50)), + "test", + meta, + 
Column("id", Integer, primary_key=True), + Column("data", String(30)), + Column("atimestamp", Date), + Column("avalue", MyCustomType), + Column("bvalue", MyTypeDec(50)), ) meta.create_all() - test_table.insert().execute({ - 'id': 1, 'data': 'somedata', - 'atimestamp': datetime.date(2007, 10, 15), 'avalue': 25, - 'bvalue': 'foo'}) + test_table.insert().execute( + { + "id": 1, + "data": "somedata", + "atimestamp": datetime.date(2007, 10, 15), + "avalue": 25, + "bvalue": "foo", + } + ) @classmethod def teardown_class(cls): @@ -2111,8 +2275,15 @@ class ExpressionTest( eq_( test_table.select().execute().fetchall(), - [(1, 'somedata', datetime.date(2007, 10, 15), 25, - 'BIND_INfooBIND_OUT')] + [ + ( + 1, + "somedata", + datetime.date(2007, 10, 15), + 25, + "BIND_INfooBIND_OUT", + ) + ], ) def test_bind_adapt(self): @@ -2122,11 +2293,16 @@ class ExpressionTest( eq_( testing.db.execute( - select([ - test_table.c.id, test_table.c.data, - test_table.c.atimestamp]).where(expr), - {"thedate": datetime.date(2007, 10, 15)}).fetchall(), [ - (1, 'somedata', datetime.date(2007, 10, 15))] + select( + [ + test_table.c.id, + test_table.c.data, + test_table.c.atimestamp, + ] + ).where(expr), + {"thedate": datetime.date(2007, 10, 15)}, + ).fetchall(), + [(1, "somedata", datetime.date(2007, 10, 15))], ) expr = test_table.c.avalue == bindparam("somevalue") @@ -2134,10 +2310,17 @@ class ExpressionTest( eq_( testing.db.execute( - test_table.select().where(expr), {'somevalue': 25} - ).fetchall(), [( - 1, 'somedata', datetime.date(2007, 10, 15), 25, - 'BIND_INfooBIND_OUT')] + test_table.select().where(expr), {"somevalue": 25} + ).fetchall(), + [ + ( + 1, + "somedata", + datetime.date(2007, 10, 15), + 25, + "BIND_INfooBIND_OUT", + ) + ], ) expr = test_table.c.bvalue == bindparam("somevalue") @@ -2146,9 +2329,16 @@ class ExpressionTest( eq_( testing.db.execute( test_table.select().where(expr), {"somevalue": "foo"} - ).fetchall(), [( - 1, 'somedata', datetime.date(2007, 10, 15), 25, - 'BIND_INfooBIND_OUT')] + ).fetchall(), + [ + ( + 1, + "somedata", + datetime.date(2007, 10, 15), + 25, + "BIND_INfooBIND_OUT", + ) + ], ) def test_bind_adapt_update(self): @@ -2156,14 +2346,14 @@ class ExpressionTest( stmt = test_table.update().values(avalue=bp) compiled = stmt.compile() eq_(bp.type._type_affinity, types.NullType) - eq_(compiled.binds['somevalue'].type._type_affinity, MyCustomType) + eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType) def test_bind_adapt_insert(self): bp = bindparam("somevalue") stmt = test_table.insert().values(avalue=bp) compiled = stmt.compile() eq_(bp.type._type_affinity, types.NullType) - eq_(compiled.binds['somevalue'].type._type_affinity, MyCustomType) + eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType) def test_bind_adapt_expression(self): bp = bindparam("somevalue") @@ -2175,16 +2365,16 @@ class ExpressionTest( # literals get typed based on the types dictionary, unless # compatible with the left side type - expr = column('foo', String) == 5 + expr = column("foo", String) == 5 eq_(expr.right.type._type_affinity, Integer) - expr = column('foo', String) == "asdf" + expr = column("foo", String) == "asdf" eq_(expr.right.type._type_affinity, String) - expr = column('foo', CHAR) == 5 + expr = column("foo", CHAR) == 5 eq_(expr.right.type._type_affinity, Integer) - expr = column('foo', CHAR) == "asdf" + expr = column("foo", CHAR) == "asdf" eq_(expr.right.type.__class__, CHAR) def test_actual_literal_adapters(self): @@ -2197,12 +2387,9 @@ class ExpressionTest( 
(datetime.time(10, 15, 20), Time), (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime), (datetime.timedelta(seconds=5), Interval), - (None, types.NullType) + (None, types.NullType), ]: - is_( - literal(data).type.__class__, - expected - ) + is_(literal(data).type.__class__, expected) def test_typedec_operator_adapt(self): expr = test_table.c.bvalue + "hi" @@ -2211,8 +2398,8 @@ class ExpressionTest( assert expr.right.type.__class__ is MyTypeDec eq_( - testing.db.execute(select([expr.label('foo')])).scalar(), - "BIND_INfooBIND_INhiBIND_OUT" + testing.db.execute(select([expr.label("foo")])).scalar(), + "BIND_INfooBIND_INhiBIND_OUT", ) def test_typedec_is_adapt(self): @@ -2221,35 +2408,35 @@ class ExpressionTest( impl = Integer class CoerceBool(TypeDecorator): - coerce_to_is_types = (bool, ) + coerce_to_is_types = (bool,) impl = Boolean class CoerceNone(TypeDecorator): coerce_to_is_types = (type(None),) impl = Integer - c1 = column('x', CoerceNothing()) - c2 = column('x', CoerceBool()) - c3 = column('x', CoerceNone()) + c1 = column("x", CoerceNothing()) + c2 = column("x", CoerceBool()) + c3 = column("x", CoerceNone()) self.assert_compile( and_(c1 == None, c2 == None, c3 == None), # noqa - "x = :x_1 AND x = :x_2 AND x IS NULL" + "x = :x_1 AND x = :x_2 AND x IS NULL", ) self.assert_compile( and_(c1 == True, c2 == True, c3 == True), # noqa "x = :x_1 AND x = true AND x = :x_2", - dialect=default.DefaultDialect(supports_native_boolean=True) + dialect=default.DefaultDialect(supports_native_boolean=True), ) self.assert_compile( and_(c1 == 3, c2 == 3, c3 == 3), "x = :x_1 AND x = :x_2 AND x = :x_3", - dialect=default.DefaultDialect(supports_native_boolean=True) + dialect=default.DefaultDialect(supports_native_boolean=True), ) self.assert_compile( and_(c1.is_(True), c2.is_(True), c3.is_(True)), "x IS :x_1 AND x IS true AND x IS :x_2", - dialect=default.DefaultDialect(supports_native_boolean=True) + dialect=default.DefaultDialect(supports_native_boolean=True), ) def test_typedec_righthand_coercion(self): @@ -2262,21 +2449,19 @@ class ExpressionTest( def process_result_value(self, value, dialect): return value + "BIND_OUT" - tab = table('test', column('bvalue', MyTypeDec)) + tab = table("test", column("bvalue", MyTypeDec)) expr = tab.c.bvalue + 6 self.assert_compile( - expr, - "test.bvalue || :bvalue_1", - use_default_dialect=True + expr, "test.bvalue || :bvalue_1", use_default_dialect=True ) is_(expr.right.type.__class__, MyTypeDec) is_(expr.type.__class__, MyTypeDec) eq_( - testing.db.execute(select([expr.label('foo')])).scalar(), - "BIND_INfooBIND_IN6BIND_OUT" + testing.db.execute(select([expr.label("foo")])).scalar(), + "BIND_INfooBIND_IN6BIND_OUT", ) def test_variant_righthand_coercion_honors_wrapped(self): @@ -2284,16 +2469,16 @@ class ExpressionTest( my_json_variant = JSON().with_variant(String(), "sqlite") tab = table( - 'test', - column('avalue', my_json_normal), - column('bvalue', my_json_variant) + "test", + column("avalue", my_json_normal), + column("bvalue", my_json_variant), ) - expr = tab.c.avalue['foo'] == 'bar' + expr = tab.c.avalue["foo"] == "bar" is_(expr.right.type._type_affinity, String) is_not_(expr.right.type, my_json_normal) - expr = tab.c.bvalue['foo'] == 'bar' + expr = tab.c.bvalue["foo"] == "bar" is_(expr.right.type._type_affinity, String) is_not_(expr.right.type, my_json_variant) @@ -2301,12 +2486,13 @@ class ExpressionTest( def test_variant_righthand_coercion_returns_self(self): my_datetime_normal = DateTime() my_datetime_variant = DateTime().with_variant( - 
dialects.sqlite.DATETIME(truncate_microseconds=False), "sqlite") + dialects.sqlite.DATETIME(truncate_microseconds=False), "sqlite" + ) tab = table( - 'test', - column('avalue', my_datetime_normal), - column('bvalue', my_datetime_variant) + "test", + column("avalue", my_datetime_normal), + column("bvalue", my_datetime_variant), ) expr = tab.c.avalue == datetime.datetime(2015, 10, 14, 15, 17, 18) @@ -2353,72 +2539,69 @@ class ExpressionTest( assert expr.right.type._type_affinity is MyFoobarType def test_date_coercion(self): - expr = column('bar', types.NULLTYPE) - column('foo', types.TIMESTAMP) + expr = column("bar", types.NULLTYPE) - column("foo", types.TIMESTAMP) eq_(expr.type._type_affinity, types.NullType) - expr = func.sysdate() - column('foo', types.TIMESTAMP) + expr = func.sysdate() - column("foo", types.TIMESTAMP) eq_(expr.type._type_affinity, types.Interval) - expr = func.current_date() - column('foo', types.TIMESTAMP) + expr = func.current_date() - column("foo", types.TIMESTAMP) eq_(expr.type._type_affinity, types.Interval) def test_interval_coercion(self): - expr = column('bar', types.Interval) + column('foo', types.Date) + expr = column("bar", types.Interval) + column("foo", types.Date) eq_(expr.type._type_affinity, types.DateTime) - expr = column('bar', types.Interval) * column('foo', types.Numeric) + expr = column("bar", types.Interval) * column("foo", types.Numeric) eq_(expr.type._type_affinity, types.Interval) - def test_numerics_coercion(self): for op in (operator.add, operator.mul, operator.truediv, operator.sub): for other in (Numeric(10, 2), Integer): expr = op( - column('bar', types.Numeric(10, 2)), - column('foo', other) + column("bar", types.Numeric(10, 2)), column("foo", other) ) assert isinstance(expr.type, types.Numeric) expr = op( - column('foo', other), - column('bar', types.Numeric(10, 2)) + column("foo", other), column("bar", types.Numeric(10, 2)) ) assert isinstance(expr.type, types.Numeric) def test_asdecimal_int_to_numeric(self): - expr = column('a', Integer) * column('b', Numeric(asdecimal=False)) + expr = column("a", Integer) * column("b", Numeric(asdecimal=False)) is_(expr.type.asdecimal, False) - expr = column('a', Integer) * column('b', Numeric()) + expr = column("a", Integer) * column("b", Numeric()) is_(expr.type.asdecimal, True) - expr = column('a', Integer) * column('b', Float()) + expr = column("a", Integer) * column("b", Float()) is_(expr.type.asdecimal, False) assert isinstance(expr.type, Float) def test_asdecimal_numeric_to_int(self): - expr = column('a', Numeric(asdecimal=False)) * column('b', Integer) + expr = column("a", Numeric(asdecimal=False)) * column("b", Integer) is_(expr.type.asdecimal, False) - expr = column('a', Numeric()) * column('b', Integer) + expr = column("a", Numeric()) * column("b", Integer) is_(expr.type.asdecimal, True) - expr = column('a', Float()) * column('b', Integer) + expr = column("a", Float()) * column("b", Integer) is_(expr.type.asdecimal, False) assert isinstance(expr.type, Float) def test_null_comparison(self): eq_( - str(column('a', types.NullType()) + column('b', types.NullType())), - "a + b" + str(column("a", types.NullType()) + column("b", types.NullType())), + "a + b", ) def test_expression_typing(self): - expr = column('bar', Integer) - 3 + expr = column("bar", Integer) - 3 eq_(expr.type._type_affinity, Integer) - expr = bindparam('bar') + bindparam('foo') + expr = bindparam("bar") + bindparam("foo") eq_(expr.type, types.NULLTYPE) def test_distinct(self): @@ -2443,24 +2626,18 @@ class ExpressionTest( 
assert_raises_message( exc.ArgumentError, r"Object some_sqla_thing\(\) is not legal as a SQL literal value", - lambda: column('a', String) == SomeSQLAThing() + lambda: column("a", String) == SomeSQLAThing(), ) - is_( - bindparam('x', SomeOtherThing()).type, - types.NULLTYPE - ) + is_(bindparam("x", SomeOtherThing()).type, types.NULLTYPE) def test_detect_coercion_not_fooled_by_mock(self): m1 = mock.Mock() - is_( - bindparam('x', m1).type, - types.NULLTYPE - ) + is_(bindparam("x", m1).type, types.NULLTYPE) class CompileTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" @testing.requires.unbounded_varchar def test_string_plain(self): @@ -2471,7 +2648,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_string_collation(self): self.assert_compile( - String(50, collation="FOO"), 'VARCHAR(50) COLLATE "FOO"') + String(50, collation="FOO"), 'VARCHAR(50) COLLATE "FOO"' + ) def test_char_plain(self): self.assert_compile(CHAR(), "CHAR") @@ -2481,7 +2659,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_char_collation(self): self.assert_compile( - CHAR(50, collation="FOO"), 'CHAR(50) COLLATE "FOO"') + CHAR(50, collation="FOO"), 'CHAR(50) COLLATE "FOO"' + ) def test_text_plain(self): self.assert_compile(Text(), "TEXT") @@ -2490,39 +2669,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(Text(50), "TEXT(50)") def test_text_collation(self): - self.assert_compile( - Text(collation="FOO"), 'TEXT COLLATE "FOO"') + self.assert_compile(Text(collation="FOO"), 'TEXT COLLATE "FOO"') def test_default_compile_pg_inet(self): self.assert_compile( - dialects.postgresql.INET(), "INET", allow_dialect_select=True) + dialects.postgresql.INET(), "INET", allow_dialect_select=True + ) def test_default_compile_pg_float(self): self.assert_compile( - dialects.postgresql.FLOAT(), "FLOAT", allow_dialect_select=True) + dialects.postgresql.FLOAT(), "FLOAT", allow_dialect_select=True + ) def test_default_compile_mysql_integer(self): self.assert_compile( - dialects.mysql.INTEGER(display_width=5), "INTEGER(5)", - allow_dialect_select=True) + dialects.mysql.INTEGER(display_width=5), + "INTEGER(5)", + allow_dialect_select=True, + ) def test_numeric_plain(self): - self.assert_compile(types.NUMERIC(), 'NUMERIC') + self.assert_compile(types.NUMERIC(), "NUMERIC") def test_numeric_precision(self): - self.assert_compile(types.NUMERIC(2), 'NUMERIC(2)') + self.assert_compile(types.NUMERIC(2), "NUMERIC(2)") def test_numeric_scale(self): - self.assert_compile(types.NUMERIC(2, 4), 'NUMERIC(2, 4)') + self.assert_compile(types.NUMERIC(2, 4), "NUMERIC(2, 4)") def test_decimal_plain(self): - self.assert_compile(types.DECIMAL(), 'DECIMAL') + self.assert_compile(types.DECIMAL(), "DECIMAL") def test_decimal_precision(self): - self.assert_compile(types.DECIMAL(2), 'DECIMAL(2)') + self.assert_compile(types.DECIMAL(2), "DECIMAL(2)") def test_decimal_scale(self): - self.assert_compile(types.DECIMAL(2, 4), 'DECIMAL(2, 4)') + self.assert_compile(types.DECIMAL(2, 4), "DECIMAL(2, 4)") def test_kwarg_legacy_typecompiler(self): from sqlalchemy.sql import compiler @@ -2534,19 +2716,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): # not affected def visit_INTEGER(self, type_, **kw): - return "MYINTEGER %s" % kw['type_expression'].name + return "MYINTEGER %s" % kw["type_expression"].name dialect = default.DefaultDialect() dialect.type_compiler = SomeTypeCompiler(dialect) self.assert_compile( - ddl.CreateColumn(Column('bar', 
VARCHAR(50))), + ddl.CreateColumn(Column("bar", VARCHAR(50))), "bar MYVARCHAR", - dialect=dialect + dialect=dialect, ) self.assert_compile( - ddl.CreateColumn(Column('bar', INTEGER)), + ddl.CreateColumn(Column("bar", INTEGER)), "bar MYINTEGER bar", - dialect=dialect + dialect=dialect, ) @@ -2558,14 +2740,11 @@ class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase): class MyType(types.UserDefinedType): def get_col_spec(self, **kw): - return "FOOB %s" % kw['type_expression'].name + return "FOOB %s" % kw["type_expression"].name m = MetaData() - t = Table('t', m, Column('bar', MyType, nullable=False)) - self.assert_compile( - ddl.CreateColumn(t.c.bar), - "bar FOOB bar NOT NULL" - ) + t = Table("t", m, Column("bar", MyType, nullable=False)) + self.assert_compile(ddl.CreateColumn(t.c.bar), "bar FOOB bar NOT NULL") class NumericRawSQLTest(fixtures.TestBase): @@ -2576,11 +2755,11 @@ class NumericRawSQLTest(fixtures.TestBase): """ def _fixture(self, metadata, type, data): - t = Table('t', metadata, Column("val", type)) + t = Table("t", metadata, Column("val", type)) metadata.create_all() t.insert().execute(val=data) - @testing.fails_on('sqlite', "Doesn't provide Decimal results natively") + @testing.fails_on("sqlite", "Doesn't provide Decimal results natively") @testing.provide_metadata def test_decimal_fp(self): metadata = self.metadata @@ -2589,7 +2768,7 @@ class NumericRawSQLTest(fixtures.TestBase): assert isinstance(val, decimal.Decimal) eq_(val, decimal.Decimal("45.5")) - @testing.fails_on('sqlite', "Doesn't provide Decimal results natively") + @testing.fails_on("sqlite", "Doesn't provide Decimal results natively") @testing.provide_metadata def test_decimal_int(self): metadata = self.metadata @@ -2614,7 +2793,7 @@ class NumericRawSQLTest(fixtures.TestBase): assert isinstance(val, float) # some DBAPIs have unusual float handling - if testing.against('oracle+cx_oracle', 'mysql+oursql', 'firebird'): + if testing.against("oracle+cx_oracle", "mysql+oursql", "firebird"): eq_(round_decimal(val, 3), 46.583) else: eq_(val, 46.583) @@ -2624,22 +2803,22 @@ interval_table = metadata = None class IntervalTest(fixtures.TestBase, AssertsExecutionResults): - @classmethod def setup_class(cls): global interval_table, metadata metadata = MetaData(testing.db) interval_table = Table( - "intervaltable", metadata, + "intervaltable", + metadata, Column( - "id", Integer, primary_key=True, - test_needs_autoincrement=True), + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("native_interval", Interval()), Column( "native_interval_args", - Interval(day_precision=3, second_precision=6)), - Column( - "non_native_interval", Interval(native=False)), + Interval(day_precision=3, second_precision=6), + ), + Column("non_native_interval", Interval(native=False)), ) metadata.create_all() @@ -2662,59 +2841,66 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults): small_delta = datetime.timedelta(days=15, seconds=5874) delta = datetime.timedelta(414) interval_table.insert().execute( - native_interval=small_delta, native_interval_args=delta, - non_native_interval=delta) + native_interval=small_delta, + native_interval_args=delta, + non_native_interval=delta, + ) row = interval_table.select().execute().first() - eq_(row['native_interval'], small_delta) - eq_(row['native_interval_args'], delta) - eq_(row['non_native_interval'], delta) + eq_(row["native_interval"], small_delta) + eq_(row["native_interval_args"], delta) + eq_(row["non_native_interval"], delta) def test_null(self): 
         interval_table.insert().execute(
-            id=1, native_interval=None, non_native_interval=None)
+            id=1, native_interval=None, non_native_interval=None
+        )
         row = interval_table.select().execute().first()
-        eq_(row['native_interval'], None)
-        eq_(row['native_interval_args'], None)
-        eq_(row['non_native_interval'], None)
+        eq_(row["native_interval"], None)
+        eq_(row["native_interval_args"], None)
+        eq_(row["non_native_interval"], None)
 
 
 class BooleanTest(
-        fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL):
+    fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL
+):
     """test edge cases for booleans.
 
     Note that the main boolean test suite is now in
     testing/suite/test_types.py
 
     """
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
-            'boolean_table', metadata,
-            Column('id', Integer, primary_key=True, autoincrement=False),
-            Column('value', Boolean),
-            Column('unconstrained_value', Boolean(create_constraint=False)),
+            "boolean_table",
+            metadata,
+            Column("id", Integer, primary_key=True, autoincrement=False),
+            Column("value", Boolean),
+            Column("unconstrained_value", Boolean(create_constraint=False)),
         )
 
     @testing.fails_on(
-        'mysql',
-        "The CHECK clause is parsed but ignored by all storage engines.")
-    @testing.fails_on(
-        'mssql', "FIXME: MS-SQL 2005 doesn't honor CHECK ?!?")
+        "mysql",
+        "The CHECK clause is parsed but ignored by all storage engines.",
+    )
+    @testing.fails_on("mssql", "FIXME: MS-SQL 2005 doesn't honor CHECK ?!?")
     @testing.skip_if(lambda: testing.db.dialect.supports_native_boolean)
     def test_constraint(self):
         assert_raises(
             (exc.IntegrityError, exc.ProgrammingError),
             testing.db.execute,
-            "insert into boolean_table (id, value) values(1, 5)")
+            "insert into boolean_table (id, value) values(1, 5)",
+        )
 
     @testing.skip_if(lambda: testing.db.dialect.supports_native_boolean)
     def test_unconstrained(self):
         testing.db.execute(
             "insert into boolean_table (id, unconstrained_value)"
-            "values (1, 5)")
+            "values (1, 5)"
+        )
 
     def test_non_native_constraint_custom_type(self):
         class Foob(object):
-
             def __init__(self, value):
                 self.value = value
 
@@ -2729,14 +2915,15 @@ class BooleanTest(
                 return value.value
 
         m = MetaData()
-        t1 = Table('t', m, Column('x', MyBool()))
-        const = [
-            c for c in t1.constraints if isinstance(c, CheckConstraint)][0]
+        t1 = Table("t", m, Column("x", MyBool()))
+        const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][
+            0
+        ]
 
         self.assert_compile(
             AddConstraint(const),
             "ALTER TABLE t ADD CHECK (x IN (0, 1))",
-            dialect="sqlite"
+            dialect="sqlite",
         )
 
     @testing.skip_if(lambda: testing.db.dialect.supports_native_boolean)
@@ -2748,10 +2935,9 @@ class BooleanTest(
                 "Value 5 is not None, True, or False",
                 conn.execute,
                 boolean_table.insert(),
-                {"id": 1, "unconstrained_value": 5}
+                {"id": 1, "unconstrained_value": 5},
             )
 
-
     @testing.requires.non_native_boolean_unconstrained
     def test_nonnative_processor_coerces_integer_to_boolean(self):
         boolean_table = self.tables.boolean_table
@@ -2762,194 +2948,204 @@ class BooleanTest(
             )
 
             eq_(
-                conn.scalar("select unconstrained_value from boolean_table"),
-                5
+                conn.scalar("select unconstrained_value from boolean_table"), 5
             )
 
             eq_(
                 conn.scalar(select([boolean_table.c.unconstrained_value])),
-                True
+                True,
             )
 
     def test_bind_processor_coercion_native_true(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         is_(proc(True), True)
 
     def test_bind_processor_coercion_native_false(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         is_(proc(False), False)
 
     def test_bind_processor_coercion_native_none(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         is_(proc(None), None)
 
     def test_bind_processor_coercion_native_0(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         is_(proc(0), False)
 
     def test_bind_processor_coercion_native_1(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         is_(proc(1), True)
 
     def test_bind_processor_coercion_native_str(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         assert_raises_message(
-            TypeError,
-            "Not a boolean value: 'foo'",
-            proc, "foo"
+            TypeError, "Not a boolean value: 'foo'", proc, "foo"
         )
 
     def test_bind_processor_coercion_native_int_out_of_range(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=True))
+            mock.Mock(supports_native_boolean=True)
+        )
         assert_raises_message(
-            ValueError,
-            "Value 15 is not None, True, or False",
-            proc, 15
+            ValueError, "Value 15 is not None, True, or False", proc, 15
        )
 
     def test_bind_processor_coercion_nonnative_true(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         eq_(proc(True), 1)
 
     def test_bind_processor_coercion_nonnative_false(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         eq_(proc(False), 0)
 
     def test_bind_processor_coercion_nonnative_none(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         is_(proc(None), None)
 
     def test_bind_processor_coercion_nonnative_0(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         eq_(proc(0), 0)
 
     def test_bind_processor_coercion_nonnative_1(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         eq_(proc(1), 1)
 
     def test_bind_processor_coercion_nonnative_str(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         assert_raises_message(
-            TypeError,
-            "Not a boolean value: 'foo'",
-            proc, "foo"
+            TypeError, "Not a boolean value: 'foo'", proc, "foo"
         )
 
     def test_bind_processor_coercion_nonnative_int_out_of_range(self):
         proc = Boolean().bind_processor(
-            mock.Mock(supports_native_boolean=False))
+            mock.Mock(supports_native_boolean=False)
+        )
         assert_raises_message(
-            ValueError,
-            "Value 15 is not None, True, or False",
-            proc, 15
+            ValueError, "Value 15 is not None, True, or False", proc, 15
         )
 
     def test_literal_processor_coercion_native_true(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=True))
+            default.DefaultDialect(supports_native_boolean=True)
+        )
         eq_(proc(True), "true")
 
     def test_literal_processor_coercion_native_false(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=True))
+            default.DefaultDialect(supports_native_boolean=True)
+        )
         eq_(proc(False), "false")
 
     def test_literal_processor_coercion_native_1(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=True))
+            default.DefaultDialect(supports_native_boolean=True)
+        )
         eq_(proc(1), "true")
 
     def test_literal_processor_coercion_native_0(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=True))
+            default.DefaultDialect(supports_native_boolean=True)
+        )
         eq_(proc(0), "false")
 
     def test_literal_processor_coercion_native_str(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=True))
+            default.DefaultDialect(supports_native_boolean=True)
+        )
         assert_raises_message(
-            TypeError,
-            "Not a boolean value: 'foo'",
-            proc, "foo"
+            TypeError, "Not a boolean value: 'foo'", proc, "foo"
        )
 
     def test_literal_processor_coercion_native_int_out_of_range(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=True))
+            default.DefaultDialect(supports_native_boolean=True)
+        )
         assert_raises_message(
-            ValueError,
-            "Value 15 is not None, True, or False",
-            proc, 15
+            ValueError, "Value 15 is not None, True, or False", proc, 15
         )
 
     def test_literal_processor_coercion_nonnative_true(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=False))
+            default.DefaultDialect(supports_native_boolean=False)
+        )
         eq_(proc(True), "1")
 
     def test_literal_processor_coercion_nonnative_false(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=False))
+            default.DefaultDialect(supports_native_boolean=False)
+        )
         eq_(proc(False), "0")
 
     def test_literal_processor_coercion_nonnative_1(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=False))
+            default.DefaultDialect(supports_native_boolean=False)
+        )
         eq_(proc(1), "1")
 
     def test_literal_processor_coercion_nonnative_0(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=False))
+            default.DefaultDialect(supports_native_boolean=False)
+        )
         eq_(proc(0), "0")
 
     def test_literal_processor_coercion_nonnative_str(self):
         proc = Boolean().literal_processor(
-            default.DefaultDialect(supports_native_boolean=False))
+            default.DefaultDialect(supports_native_boolean=False)
+        )
         assert_raises_message(
-            TypeError,
-            "Not a boolean value: 'foo'",
-            proc, "foo"
+            TypeError, "Not a boolean value: 'foo'", proc, "foo"
         )
 
 
-
 class PickleTest(fixtures.TestBase):
-
     def test_eq_comparison(self):
         p1 = PickleType()
 
         for obj in (
-            {'1': '2'},
+            {"1": "2"},
             pickleable.Bar(5, 6),
-            pickleable.OldSchool(10, 11)
+            pickleable.OldSchool(10, 11),
         ):
             assert p1.compare_values(p1.copy_value(obj), obj)
 
         assert_raises(
-            NotImplementedError, p1.compare_values,
-            pickleable.BrokenComparable('foo'),
-            pickleable.BrokenComparable('foo'))
+            NotImplementedError,
+            p1.compare_values,
+            pickleable.BrokenComparable("foo"),
+            pickleable.BrokenComparable("foo"),
+        )
 
     def test_nonmutable_comparison(self):
         p1 = PickleType()
 
         for obj in (
-            {'1': '2'},
+            {"1": "2"},
             pickleable.Bar(5, 6),
-            pickleable.OldSchool(10, 11)
+            pickleable.OldSchool(10, 11),
         ):
             assert p1.compare_values(p1.copy_value(obj), obj)
 
@@ -2958,7 +3154,6 @@ meta = None
 
 
 class CallableTest(fixtures.TestBase):
-
     @classmethod
     def setup_class(cls):
         global meta
@@ -2971,9 +3166,7 @@ class CallableTest(fixtures.TestBase):
 
     def test_callable_as_arg(self):
         ucode = util.partial(Unicode)
 
-        thing_table = Table(
-            'thing', meta, Column('name', ucode(20))
-        )
+        thing_table = Table("thing", meta, Column("name", ucode(20)))
         assert isinstance(thing_table.c.name.type, Unicode)
         thing_table.create()
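The BooleanTest methods above pin down a small coercion contract for Boolean's
bind and literal processors. A minimal self-contained sketch of that contract
follows; it is illustrative only and not part of the reformat diff, and it
assumes the same 1.3-era APIs the tests themselves import:

    # Illustrative sketch, not part of the diff: the coercion contract
    # exercised by the BooleanTest processor tests above.
    from sqlalchemy import Boolean
    from sqlalchemy.engine import default

    native = default.DefaultDialect(supports_native_boolean=True)
    non_native = default.DefaultDialect(supports_native_boolean=False)

    # bind processors coerce Python values on their way to the driver;
    # a non-native dialect stores booleans as 0/1 integers
    bind = Boolean().bind_processor(non_native)
    assert bind(True) == 1 and bind(0) == 0 and bind(None) is None

    # literal processors instead render inline SQL literals
    assert Boolean().literal_processor(native)(True) == "true"
    assert Boolean().literal_processor(non_native)(True) == "1"
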
@@ -2981,7 +3174,7 @@ class CallableTest(fixtures.TestBase): ucode = util.partial(Unicode) thang_table = Table( - 'thang', meta, Column('name', type_=ucode(20), primary_key=True) + "thang", meta, Column("name", type_=ucode(20), primary_key=True) ) assert isinstance(thang_table.c.name.type, Unicode) thang_table.create() diff --git a/test/sql/test_unicode.py b/test/sql/test_unicode.py index e29aea54fb..fafdb03a3a 100644 --- a/test/sql/test_unicode.py +++ b/test/sql/test_unicode.py @@ -9,7 +9,7 @@ from sqlalchemy.util import u, ue class UnicodeSchemaTest(fixtures.TestBase): - __requires__ = ('unicode_ddl',) + __requires__ = ("unicode_ddl",) __backend__ = True @classmethod @@ -17,53 +17,67 @@ class UnicodeSchemaTest(fixtures.TestBase): global metadata, t1, t2, t3 metadata = MetaData(testing.db) - t1 = Table(u('unitable1'), metadata, - Column(u('méil'), Integer, primary_key=True), - Column(ue('\u6e2c\u8a66'), Integer), - test_needs_fk=True, - ) + t1 = Table( + u("unitable1"), + metadata, + Column(u("méil"), Integer, primary_key=True), + Column(ue("\u6e2c\u8a66"), Integer), + test_needs_fk=True, + ) t2 = Table( - u('Unitéble2'), + u("Unitéble2"), metadata, + Column(u("méil"), Integer, primary_key=True, key="a"), Column( - u('méil'), - Integer, - primary_key=True, - key="a"), - Column( - ue('\u6e2c\u8a66'), + ue("\u6e2c\u8a66"), Integer, - ForeignKey( - u('unitable1.méil')), - key="b"), + ForeignKey(u("unitable1.méil")), + key="b", + ), test_needs_fk=True, ) # Few DBs support Unicode foreign keys - if testing.against('sqlite'): - t3 = Table(ue('\u6e2c\u8a66'), metadata, - Column(ue('\u6e2c\u8a66_id'), Integer, primary_key=True, - autoincrement=False), - Column(ue('unitable1_\u6e2c\u8a66'), Integer, - ForeignKey(ue('unitable1.\u6e2c\u8a66')) - ), - Column(u('Unitéble2_b'), Integer, - ForeignKey(u('Unitéble2.b')) - ), - Column(ue('\u6e2c\u8a66_self'), Integer, - ForeignKey(ue('\u6e2c\u8a66.\u6e2c\u8a66_id')) - ), - test_needs_fk=True, - ) + if testing.against("sqlite"): + t3 = Table( + ue("\u6e2c\u8a66"), + metadata, + Column( + ue("\u6e2c\u8a66_id"), + Integer, + primary_key=True, + autoincrement=False, + ), + Column( + ue("unitable1_\u6e2c\u8a66"), + Integer, + ForeignKey(ue("unitable1.\u6e2c\u8a66")), + ), + Column( + u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b")) + ), + Column( + ue("\u6e2c\u8a66_self"), + Integer, + ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")), + ), + test_needs_fk=True, + ) else: - t3 = Table(ue('\u6e2c\u8a66'), metadata, - Column(ue('\u6e2c\u8a66_id'), Integer, primary_key=True, - autoincrement=False), - Column(ue('unitable1_\u6e2c\u8a66'), Integer), - Column(u('Unitéble2_b'), Integer), - Column(ue('\u6e2c\u8a66_self'), Integer), - test_needs_fk=True, - ) + t3 = Table( + ue("\u6e2c\u8a66"), + metadata, + Column( + ue("\u6e2c\u8a66_id"), + Integer, + primary_key=True, + autoincrement=False, + ), + Column(ue("unitable1_\u6e2c\u8a66"), Integer), + Column(u("Unitéble2_b"), Integer), + Column(ue("\u6e2c\u8a66_self"), Integer), + test_needs_fk=True, + ) metadata.create_all() @engines.close_first @@ -78,86 +92,104 @@ class UnicodeSchemaTest(fixtures.TestBase): metadata.drop_all() def test_insert(self): - t1.insert().execute({u('méil'): 1, ue('\u6e2c\u8a66'): 5}) - t2.insert().execute({u('a'): 1, u('b'): 1}) - t3.insert().execute({ue('\u6e2c\u8a66_id'): 1, - ue('unitable1_\u6e2c\u8a66'): 5, - u('Unitéble2_b'): 1, - ue('\u6e2c\u8a66_self'): 1}) + t1.insert().execute({u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + t2.insert().execute({u("a"): 1, u("b"): 1}) + t3.insert().execute( 
+ { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + } + ) assert t1.select().execute().fetchall() == [(1, 5)] assert t2.select().execute().fetchall() == [(1, 1)] assert t3.select().execute().fetchall() == [(1, 5, 1, 1)] def test_col_targeting(self): - t1.insert().execute({u('méil'): 1, ue('\u6e2c\u8a66'): 5}) - t2.insert().execute({u('a'): 1, u('b'): 1}) - t3.insert().execute({ue('\u6e2c\u8a66_id'): 1, - ue('unitable1_\u6e2c\u8a66'): 5, - u('Unitéble2_b'): 1, - ue('\u6e2c\u8a66_self'): 1}) + t1.insert().execute({u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + t2.insert().execute({u("a"): 1, u("b"): 1}) + t3.insert().execute( + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + } + ) row = t1.select().execute().first() - eq_(row[t1.c[u('méil')]], 1) - eq_(row[t1.c[ue('\u6e2c\u8a66')]], 5) + eq_(row[t1.c[u("méil")]], 1) + eq_(row[t1.c[ue("\u6e2c\u8a66")]], 5) row = t2.select().execute().first() - eq_(row[t2.c[u('a')]], 1) - eq_(row[t2.c[u('b')]], 1) + eq_(row[t2.c[u("a")]], 1) + eq_(row[t2.c[u("b")]], 1) row = t3.select().execute().first() - eq_(row[t3.c[ue('\u6e2c\u8a66_id')]], 1) - eq_(row[t3.c[ue('unitable1_\u6e2c\u8a66')]], 5) - eq_(row[t3.c[u('Unitéble2_b')]], 1) - eq_(row[t3.c[ue('\u6e2c\u8a66_self')]], 1) + eq_(row[t3.c[ue("\u6e2c\u8a66_id")]], 1) + eq_(row[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5) + eq_(row[t3.c[u("Unitéble2_b")]], 1) + eq_(row[t3.c[ue("\u6e2c\u8a66_self")]], 1) def test_reflect(self): - t1.insert().execute({u('méil'): 2, ue('\u6e2c\u8a66'): 7}) - t2.insert().execute({u('a'): 2, u('b'): 2}) - t3.insert().execute({ue('\u6e2c\u8a66_id'): 2, - ue('unitable1_\u6e2c\u8a66'): 7, - u('Unitéble2_b'): 2, - ue('\u6e2c\u8a66_self'): 2}) + t1.insert().execute({u("méil"): 2, ue("\u6e2c\u8a66"): 7}) + t2.insert().execute({u("a"): 2, u("b"): 2}) + t3.insert().execute( + { + ue("\u6e2c\u8a66_id"): 2, + ue("unitable1_\u6e2c\u8a66"): 7, + u("Unitéble2_b"): 2, + ue("\u6e2c\u8a66_self"): 2, + } + ) meta = MetaData(testing.db) tt1 = Table(t1.name, meta, autoload=True) tt2 = Table(t2.name, meta, autoload=True) tt3 = Table(t3.name, meta, autoload=True) - tt1.insert().execute({u('méil'): 1, ue('\u6e2c\u8a66'): 5}) - tt2.insert().execute({u('méil'): 1, ue('\u6e2c\u8a66'): 1}) - tt3.insert().execute({ue('\u6e2c\u8a66_id'): 1, - ue('unitable1_\u6e2c\u8a66'): 5, - u('Unitéble2_b'): 1, - ue('\u6e2c\u8a66_self'): 1}) + tt1.insert().execute({u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + tt2.insert().execute({u("méil"): 1, ue("\u6e2c\u8a66"): 1}) + tt3.insert().execute( + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + } + ) self.assert_( - tt1.select( - order_by=desc( - u('méil'))).execute().fetchall() == [ - (2, 7), (1, 5)]) + tt1.select(order_by=desc(u("méil"))).execute().fetchall() + == [(2, 7), (1, 5)] + ) + self.assert_( + tt2.select(order_by=desc(u("méil"))).execute().fetchall() + == [(2, 2), (1, 1)] + ) self.assert_( - tt2.select( - order_by=desc( - u('méil'))).execute().fetchall() == [ - (2, 2), (1, 1)]) - self.assert_(tt3.select(order_by=desc(ue('\u6e2c\u8a66_id'))). 
- execute().fetchall() == - [(2, 7, 2, 2), (1, 5, 1, 1)]) + tt3.select(order_by=desc(ue("\u6e2c\u8a66_id"))) + .execute() + .fetchall() + == [(2, 7, 2, 2), (1, 5, 1, 1)] + ) def test_repr(self): m = MetaData() t = Table( - ue('\u6e2c\u8a66'), - m, - Column( - ue('\u6e2c\u8a66_id'), - Integer)) + ue("\u6e2c\u8a66"), m, Column(ue("\u6e2c\u8a66_id"), Integer) + ) # I hardly understand what's going on with the backslashes in # this one on py2k vs. py3k - eq_(repr(t), - ("Table('\\u6e2c\\u8a66', MetaData(bind=None), " - "Column('\\u6e2c\\u8a66_id', Integer(), table=<\u6e2c\u8a66>), " - "schema=None)")) + eq_( + repr(t), + ( + "Table('\\u6e2c\\u8a66', MetaData(bind=None), " + "Column('\\u6e2c\\u8a66_id', Integer(), table=<\u6e2c\u8a66>), " + "schema=None)" + ), + ) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 56d12d9272..191ecb4d6b 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1,88 +1,130 @@ -from sqlalchemy import Integer, String, ForeignKey, and_, or_, func, \ - literal, update, table, bindparam, column, select, exc, exists, text, \ - MetaData +from sqlalchemy import ( + Integer, + String, + ForeignKey, + and_, + or_, + func, + literal, + update, + table, + bindparam, + column, + select, + exc, + exists, + text, + MetaData, +) from sqlalchemy import testing from sqlalchemy.dialects import mysql from sqlalchemy.engine import default -from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures, \ - assert_raises_message, assert_raises +from sqlalchemy.testing import ( + AssertsCompiledSQL, + eq_, + fixtures, + assert_raises_message, + assert_raises, +) from sqlalchemy.testing.schema import Table, Column from sqlalchemy import util class _UpdateFromTestBase(object): - @classmethod def define_tables(cls, metadata): - Table('mytable', metadata, - Column('myid', Integer), - Column('name', String(30)), - Column('description', String(50))) - Table('myothertable', metadata, - Column('otherid', Integer), - Column('othername', String(30))) - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False)) - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('name', String(30), nullable=False), - Column('email_address', String(50), nullable=False)) - Table('dingalings', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', None, ForeignKey('addresses.id')), - Column('data', String(30))) - Table('update_w_default', metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('ycol', Integer, key='y'), - Column('data', String(30), onupdate=lambda: "hi")) + Table( + "mytable", + metadata, + Column("myid", Integer), + Column("name", String(30)), + Column("description", String(50)), + ) + Table( + "myothertable", + metadata, + Column("otherid", Integer), + Column("othername", String(30)), + ) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("name", String(30), nullable=False), + Column("email_address", String(50), nullable=False), + ) + Table( + "dingalings", + metadata, + Column( + "id", Integer, 
primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", None, ForeignKey("addresses.id")), + Column("data", String(30)), + ) + Table( + "update_w_default", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("ycol", Integer, key="y"), + Column("data", String(30), onupdate=lambda: "hi"), + ) @classmethod def fixtures(cls): return dict( users=( - ('id', 'name'), - (7, 'jack'), - (8, 'ed'), - (9, 'fred'), - (10, 'chuck') + ("id", "name"), + (7, "jack"), + (8, "ed"), + (9, "fred"), + (10, "chuck"), ), addresses=( - ('id', 'user_id', 'name', 'email_address'), - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'ed@wood.com'), - (3, 8, 'x', 'ed@bettyboop.com'), - (4, 8, 'x', 'ed@lala.com'), - (5, 9, 'x', 'fred@fred.com') + ("id", "user_id", "name", "email_address"), + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "ed@wood.com"), + (3, 8, "x", "ed@bettyboop.com"), + (4, 8, "x", "ed@lala.com"), + (5, 9, "x", "fred@fred.com"), ), dingalings=( - ('id', 'address_id', 'data'), - (1, 2, 'ding 1/2'), - (2, 5, 'ding 2/5') + ("id", "address_id", "data"), + (1, 2, "ding 1/2"), + (2, 5, "ding 2/5"), ), ) class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = 'default_enhanced' + __dialect__ = "default_enhanced" def test_update_literal_binds(self): table1 = self.tables.mytable table1 = self.tables.mytable - stmt = table1.update().values(name='jack').\ - where(table1.c.name == 'jill') + stmt = ( + table1.update().values(name="jack").where(table1.c.name == "jill") + ) self.assert_compile( stmt, "UPDATE mytable SET name='jack' WHERE mytable.name = 'jill'", - literal_binds=True) + literal_binds=True, + ) def test_correlated_update_one(self): table1 = self.tables.mytable @@ -91,27 +133,33 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): u = update( table1, values={ - table1.c.name: - text("(select name from mytable where id=mytable.id)") - } + table1.c.name: text( + "(select name from mytable where id=mytable.id)" + ) + }, ) self.assert_compile( u, "UPDATE mytable SET name=(select name from mytable " - "where id=mytable.id)") + "where id=mytable.id)", + ) def test_correlated_update_two(self): table1 = self.tables.mytable mt = table1.alias() - u = update(table1, values={ - table1.c.name: - select([mt.c.name], mt.c.myid == table1.c.myid) - }) + u = update( + table1, + values={ + table1.c.name: select([mt.c.name], mt.c.myid == table1.c.myid) + }, + ) self.assert_compile( - u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM " + u, + "UPDATE mytable SET name=(SELECT mytable_1.name FROM " "mytable AS mytable_1 WHERE " - "mytable_1.myid = mytable.myid)") + "mytable_1.myid = mytable.myid)", + ) def test_correlated_update_three(self): table1 = self.tables.mytable @@ -119,12 +167,14 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # test against a regular constructed subquery s = select([table2], table2.c.otherid == table1.c.myid) - u = update(table1, table1.c.name == 'jack', values={table1.c.name: s}) + u = update(table1, table1.c.name == "jack", values={table1.c.name: s}) self.assert_compile( - u, "UPDATE mytable SET name=(SELECT myothertable.otherid, " + u, + "UPDATE mytable SET name=(SELECT myothertable.otherid, " "myothertable.othername FROM myothertable WHERE " "myothertable.otherid = mytable.myid) " - "WHERE mytable.name = :name_1") + "WHERE mytable.name = :name_1", + ) def test_correlated_update_four(self): table1 = self.tables.mytable @@ -133,11 +183,13 
@@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = update(table1, table1.c.name == s) - self.assert_compile(u, - "UPDATE mytable SET myid=:myid, name=:name, " - "description=:description WHERE mytable.name = " - "(SELECT myothertable.othername FROM myothertable " - "WHERE myothertable.otherid = :otherid_1)") + self.assert_compile( + u, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description WHERE mytable.name = " + "(SELECT myothertable.othername FROM myothertable " + "WHERE myothertable.otherid = :otherid_1)", + ) def test_correlated_update_five(self): table1 = self.tables.mytable @@ -146,44 +198,56 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # test one that is actually correlated... s = select([table2.c.othername], table2.c.otherid == table1.c.myid) u = table1.update(table1.c.name == s) - self.assert_compile(u, - "UPDATE mytable SET myid=:myid, name=:name, " - "description=:description WHERE mytable.name = " - "(SELECT myothertable.othername FROM myothertable " - "WHERE myothertable.otherid = mytable.myid)") + self.assert_compile( + u, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description WHERE mytable.name = " + "(SELECT myothertable.othername FROM myothertable " + "WHERE myothertable.otherid = mytable.myid)", + ) def test_correlated_update_six(self): table1 = self.tables.mytable table2 = self.tables.myothertable # test correlated FROM implicit in WHERE and SET clauses - u = table1.update().values(name=table2.c.othername)\ - .where(table2.c.otherid == table1.c.myid) + u = ( + table1.update() + .values(name=table2.c.othername) + .where(table2.c.otherid == table1.c.myid) + ) self.assert_compile( - u, "UPDATE mytable SET name=myothertable.othername " - "FROM myothertable WHERE myothertable.otherid = mytable.myid") + u, + "UPDATE mytable SET name=myothertable.othername " + "FROM myothertable WHERE myothertable.otherid = mytable.myid", + ) def test_correlated_update_seven(self): table1 = self.tables.mytable table2 = self.tables.myothertable - u = table1.update().values(name='foo')\ - .where(table2.c.otherid == table1.c.myid) + u = ( + table1.update() + .values(name="foo") + .where(table2.c.otherid == table1.c.myid) + ) # this is the "default_enhanced" compiler. there's no UPDATE FROM # in the base compiler. # See also test/dialect/mssql/test_compiler->test_update_from(). 
self.assert_compile( - u, "UPDATE mytable SET name=:name " - "FROM myothertable WHERE myothertable.otherid = mytable.myid") + u, + "UPDATE mytable SET name=:name " + "FROM myothertable WHERE myothertable.otherid = mytable.myid", + ) def test_binds_that_match_columns(self): """test bind params named after column names replace the normal SET/VALUES generation.""" - t = table('foo', column('x'), column('y')) + t = table("foo", column("x"), column("y")) - u = t.update().where(t.c.x == bindparam('x')) + u = t.update().where(t.c.x == bindparam("x")) assert_raises(exc.CompileError, u.compile) @@ -191,186 +255,205 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): assert_raises(exc.CompileError, u.values(x=7).compile) - self.assert_compile(u.values(y=7), - "UPDATE foo SET y=:y WHERE foo.x = :x") + self.assert_compile( + u.values(y=7), "UPDATE foo SET y=:y WHERE foo.x = :x" + ) - assert_raises(exc.CompileError, - u.values(x=7).compile, column_keys=['x', 'y']) - assert_raises(exc.CompileError, u.compile, column_keys=['x', 'y']) + assert_raises( + exc.CompileError, u.values(x=7).compile, column_keys=["x", "y"] + ) + assert_raises(exc.CompileError, u.compile, column_keys=["x", "y"]) self.assert_compile( - u.values( - x=3 + - bindparam('x')), - "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x") + u.values(x=3 + bindparam("x")), + "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", + ) self.assert_compile( - u.values( - x=3 + - bindparam('x')), + u.values(x=3 + bindparam("x")), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", - params={ - 'x': 1}) + params={"x": 1}, + ) self.assert_compile( - u.values( - x=3 + - bindparam('x')), + u.values(x=3 + bindparam("x")), "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", - params={ - 'x': 1, - 'y': 2}) + params={"x": 1, "y": 2}, + ) def test_labels_no_collision(self): - t = table('foo', column('id'), column('foo_id')) + t = table("foo", column("id"), column("foo_id")) self.assert_compile( t.update().where(t.c.id == 5), - "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :id_1" + "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :id_1", ) self.assert_compile( t.update().where(t.c.id == bindparam(key=t.c.id._label)), - "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :foo_id_1" + "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :foo_id_1", ) def test_inline_defaults(self): m = MetaData() - foo = Table('foo', m, - Column('id', Integer)) - - t = Table('test', m, - Column('col1', Integer, onupdate=func.foo(1)), - Column('col2', Integer, onupdate=select( - [func.coalesce(func.max(foo.c.id))])), - Column('col3', String(30)) - ) + foo = Table("foo", m, Column("id", Integer)) + + t = Table( + "test", + m, + Column("col1", Integer, onupdate=func.foo(1)), + Column( + "col2", + Integer, + onupdate=select([func.coalesce(func.max(foo.c.id))]), + ), + Column("col3", String(30)), + ) - self.assert_compile(t.update(inline=True, values={'col3': 'foo'}), - "UPDATE test SET col1=foo(:foo_1), col2=(SELECT " - "coalesce(max(foo.id)) AS coalesce_1 FROM foo), " - "col3=:col3") + self.assert_compile( + t.update(inline=True, values={"col3": "foo"}), + "UPDATE test SET col1=foo(:foo_1), col2=(SELECT " + "coalesce(max(foo.id)) AS coalesce_1 FROM foo), " + "col3=:col3", + ) def test_update_1(self): table1 = self.tables.mytable self.assert_compile( update(table1, table1.c.myid == 7), - 'UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1', - params={table1.c.name: 'fred'}) + "UPDATE mytable SET name=:name WHERE 
mytable.myid = :myid_1", + params={table1.c.name: "fred"}, + ) def test_update_2(self): table1 = self.tables.mytable self.assert_compile( - table1.update(). - where(table1.c.myid == 7). - values({table1.c.myid: 5}), - 'UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1', - checkparams={'myid': 5, 'myid_1': 7}) + table1.update() + .where(table1.c.myid == 7) + .values({table1.c.myid: 5}), + "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", + checkparams={"myid": 5, "myid_1": 7}, + ) def test_update_3(self): table1 = self.tables.mytable self.assert_compile( update(table1, table1.c.myid == 7), - 'UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1', - params={'name': 'fred'}) + "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", + params={"name": "fred"}, + ) def test_update_4(self): table1 = self.tables.mytable self.assert_compile( update(table1, values={table1.c.name: table1.c.myid}), - 'UPDATE mytable SET name=mytable.myid') + "UPDATE mytable SET name=mytable.myid", + ) def test_update_5(self): table1 = self.tables.mytable self.assert_compile( - update(table1, - whereclause=table1.c.name == bindparam('crit'), - values={table1.c.name: 'hi'}), - 'UPDATE mytable SET name=:name WHERE mytable.name = :crit', - params={'crit': 'notthere'}, - checkparams={'crit': 'notthere', 'name': 'hi'}) + update( + table1, + whereclause=table1.c.name == bindparam("crit"), + values={table1.c.name: "hi"}, + ), + "UPDATE mytable SET name=:name WHERE mytable.name = :crit", + params={"crit": "notthere"}, + checkparams={"crit": "notthere", "name": "hi"}, + ) def test_update_6(self): table1 = self.tables.mytable self.assert_compile( - update(table1, - table1.c.myid == 12, - values={table1.c.name: table1.c.myid}), - 'UPDATE mytable ' - 'SET name=mytable.myid, description=:description ' - 'WHERE mytable.myid = :myid_1', - params={'description': 'test'}, - checkparams={'description': 'test', 'myid_1': 12}) + update( + table1, + table1.c.myid == 12, + values={table1.c.name: table1.c.myid}, + ), + "UPDATE mytable " + "SET name=mytable.myid, description=:description " + "WHERE mytable.myid = :myid_1", + params={"description": "test"}, + checkparams={"description": "test", "myid_1": 12}, + ) def test_update_7(self): table1 = self.tables.mytable self.assert_compile( update(table1, table1.c.myid == 12, values={table1.c.myid: 9}), - 'UPDATE mytable ' - 'SET myid=:myid, description=:description ' - 'WHERE mytable.myid = :myid_1', - params={'myid_1': 12, 'myid': 9, 'description': 'test'}) + "UPDATE mytable " + "SET myid=:myid, description=:description " + "WHERE mytable.myid = :myid_1", + params={"myid_1": 12, "myid": 9, "description": "test"}, + ) def test_update_8(self): table1 = self.tables.mytable self.assert_compile( update(table1, table1.c.myid == 12), - 'UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1', - params={'myid': 18}, checkparams={'myid': 18, 'myid_1': 12}) + "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", + params={"myid": 18}, + checkparams={"myid": 18, "myid_1": 12}, + ) def test_update_9(self): table1 = self.tables.mytable - s = table1.update(table1.c.myid == 12, values={table1.c.name: 'lala'}) - c = s.compile(column_keys=['id', 'name']) + s = table1.update(table1.c.myid == 12, values={table1.c.name: "lala"}) + c = s.compile(column_keys=["id", "name"]) eq_(str(s), str(c)) def test_update_10(self): table1 = self.tables.mytable v1 = {table1.c.name: table1.c.myid} - v2 = {table1.c.name: table1.c.name + 'foo'} + v2 = {table1.c.name: table1.c.name + "foo"} 
self.assert_compile( update(table1, table1.c.myid == 12, values=v1).values(v2), - 'UPDATE mytable ' - 'SET ' - 'name=(mytable.name || :name_1), ' - 'description=:description ' - 'WHERE mytable.myid = :myid_1', - params={'description': 'test'}) + "UPDATE mytable " + "SET " + "name=(mytable.name || :name_1), " + "description=:description " + "WHERE mytable.myid = :myid_1", + params={"description": "test"}, + ) def test_update_11(self): table1 = self.tables.mytable values = { - table1.c.name: table1.c.name + 'lala', - table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho')) + table1.c.name: table1.c.name + "lala", + table1.c.myid: func.do_stuff(table1.c.myid, literal("hoho")), } self.assert_compile( update( table1, - (table1.c.myid == func.hoho(4)) & ( - table1.c.name == literal('foo') + - table1.c.name + - literal('lala')), - values=values), - 'UPDATE mytable ' - 'SET ' - 'myid=do_stuff(mytable.myid, :param_1), ' - 'name=(mytable.name || :name_1) ' - 'WHERE ' - 'mytable.myid = hoho(:hoho_1) AND ' - 'mytable.name = :param_2 || mytable.name || :param_3') + (table1.c.myid == func.hoho(4)) + & ( + table1.c.name + == literal("foo") + table1.c.name + literal("lala") + ), + values=values, + ), + "UPDATE mytable " + "SET " + "myid=do_stuff(mytable.myid, :param_1), " + "name=(mytable.name || :name_1) " + "WHERE " + "mytable.myid = hoho(:hoho_1) AND " + "mytable.name = :param_2 || mytable.name || :param_3", + ) def test_unconsumed_names_kwargs(self): t = table("t", column("x"), column("y")) @@ -388,8 +471,11 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): assert_raises_message( exc.CompileError, "Unconsumed column names: j", - t.update().values(x=5, j=7).values({t2.c.z: 5}). - where(t.c.x == t2.c.q).compile, + t.update() + .values(x=5, j=7) + .values({t2.c.z: 5}) + .where(t.c.x == t2.c.q) + .compile, ) def test_unconsumed_names_kwargs_w_keys(self): @@ -399,7 +485,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): exc.CompileError, "Unconsumed column names: j", t.update().values(x=5, j=7).compile, - column_keys=['j'] + column_keys=["j"], ) def test_update_ordered_parameters_1(self): @@ -408,25 +494,28 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # Confirm that we can pass values as list value pairs # note these are ordered *differently* from table.c values = [ - (table1.c.name, table1.c.name + 'lala'), - (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))), + (table1.c.name, table1.c.name + "lala"), + (table1.c.myid, func.do_stuff(table1.c.myid, literal("hoho"))), ] self.assert_compile( update( table1, - (table1.c.myid == func.hoho(4)) & ( - table1.c.name == literal('foo') + - table1.c.name + - literal('lala')), + (table1.c.myid == func.hoho(4)) + & ( + table1.c.name + == literal("foo") + table1.c.name + literal("lala") + ), preserve_parameter_order=True, - values=values), - 'UPDATE mytable ' - 'SET ' - 'name=(mytable.name || :name_1), ' - 'myid=do_stuff(mytable.myid, :param_1) ' - 'WHERE ' - 'mytable.myid = hoho(:hoho_1) AND ' - 'mytable.name = :param_2 || mytable.name || :param_3') + values=values, + ), + "UPDATE mytable " + "SET " + "name=(mytable.name || :name_1), " + "myid=do_stuff(mytable.myid, :param_1) " + "WHERE " + "mytable.myid = hoho(:hoho_1) AND " + "mytable.name = :param_2 || mytable.name || :param_3", + ) def test_update_ordered_parameters_2(self): table1 = self.tables.mytable @@ -434,39 +523,39 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, 
AssertsCompiledSQL): # Confirm that we can pass values as list value pairs # note these are ordered *differently* from table.c values = [ - (table1.c.name, table1.c.name + 'lala'), - ('description', 'some desc'), - (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))) + (table1.c.name, table1.c.name + "lala"), + ("description", "some desc"), + (table1.c.myid, func.do_stuff(table1.c.myid, literal("hoho"))), ] self.assert_compile( update( table1, - (table1.c.myid == func.hoho(4)) & ( - table1.c.name == literal('foo') + - table1.c.name + - literal('lala')), - preserve_parameter_order=True).values(values), - 'UPDATE mytable ' - 'SET ' - 'name=(mytable.name || :name_1), ' - 'description=:description, ' - 'myid=do_stuff(mytable.myid, :param_1) ' - 'WHERE ' - 'mytable.myid = hoho(:hoho_1) AND ' - 'mytable.name = :param_2 || mytable.name || :param_3') + (table1.c.myid == func.hoho(4)) + & ( + table1.c.name + == literal("foo") + table1.c.name + literal("lala") + ), + preserve_parameter_order=True, + ).values(values), + "UPDATE mytable " + "SET " + "name=(mytable.name || :name_1), " + "description=:description, " + "myid=do_stuff(mytable.myid, :param_1) " + "WHERE " + "mytable.myid = hoho(:hoho_1) AND " + "mytable.name = :param_2 || mytable.name || :param_3", + ) def test_update_ordered_parameters_fire_onupdate(self): table = self.tables.update_w_default - values = [ - (table.c.y, table.c.x + 5), - ('x', 10) - ] + values = [(table.c.y, table.c.x + 5), ("x", 10)] self.assert_compile( table.update(preserve_parameter_order=True).values(values), "UPDATE update_w_default SET ycol=(update_w_default.x + :x_1), " - "x=:x, data=:data" + "x=:x, data=:data", ) def test_update_ordered_parameters_override_onupdate(self): @@ -475,13 +564,13 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): values = [ (table.c.y, table.c.x + 5), (table.c.data, table.c.x + 10), - ('x', 10) + ("x", 10), ] self.assert_compile( table.update(preserve_parameter_order=True).values(values), "UPDATE update_w_default SET ycol=(update_w_default.x + :x_1), " - "data=(update_w_default.x + :x_2), x=:x" + "data=(update_w_default.x + :x_2), x=:x", ) def test_update_preserve_order_reqs_listtups(self): @@ -491,7 +580,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): r"When preserve_parameter_order is True, values\(\) " r"only accepts a list of 2-tuples", table1.update(preserve_parameter_order=True).values, - {"description": "foo", "name": "bar"} + {"description": "foo", "name": "bar"}, ) def test_update_ordereddict(self): @@ -499,53 +588,65 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # Confirm that ordered dicts are treated as normal dicts, # columns sorted in table order - values = util.OrderedDict(( - (table1.c.name, table1.c.name + 'lala'), - (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))))) + values = util.OrderedDict( + ( + (table1.c.name, table1.c.name + "lala"), + (table1.c.myid, func.do_stuff(table1.c.myid, literal("hoho"))), + ) + ) self.assert_compile( update( table1, - (table1.c.myid == func.hoho(4)) & ( - table1.c.name == literal('foo') + - table1.c.name + - literal('lala')), - values=values), - 'UPDATE mytable ' - 'SET ' - 'myid=do_stuff(mytable.myid, :param_1), ' - 'name=(mytable.name || :name_1) ' - 'WHERE ' - 'mytable.myid = hoho(:hoho_1) AND ' - 'mytable.name = :param_2 || mytable.name || :param_3') + (table1.c.myid == func.hoho(4)) + & ( + table1.c.name + == literal("foo") + table1.c.name + 
literal("lala") + ), + values=values, + ), + "UPDATE mytable " + "SET " + "myid=do_stuff(mytable.myid, :param_1), " + "name=(mytable.name || :name_1) " + "WHERE " + "mytable.myid = hoho(:hoho_1) AND " + "mytable.name = :param_2 || mytable.name || :param_3", + ) def test_where_empty(self): table1 = self.tables.mytable self.assert_compile( table1.update().where(and_()), "UPDATE mytable SET myid=:myid, name=:name, " - "description=:description") + "description=:description", + ) self.assert_compile( - table1.update().where( - or_()), + table1.update().where(or_()), "UPDATE mytable SET myid=:myid, name=:name, " - "description=:description") + "description=:description", + ) def test_prefix_with(self): table1 = self.tables.mytable - stmt = table1.update().\ - prefix_with('A', 'B', dialect='mysql').\ - prefix_with('C', 'D') + stmt = ( + table1.update() + .prefix_with("A", "B", dialect="mysql") + .prefix_with("C", "D") + ) - self.assert_compile(stmt, - 'UPDATE C D mytable SET myid=:myid, name=:name, ' - 'description=:description') + self.assert_compile( + stmt, + "UPDATE C D mytable SET myid=:myid, name=:name, " + "description=:description", + ) self.assert_compile( stmt, - 'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s', - dialect=mysql.dialect()) + "UPDATE A B C D mytable SET myid=%s, name=%s, description=%s", + dialect=mysql.dialect(), + ) def test_update_to_expression(self): """test update from an expression. @@ -558,8 +659,10 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): table1 = self.tables.mytable expr = func.foo(table1.c.myid) eq_(expr.key, None) - self.assert_compile(table1.update().values({expr: 'bar'}), - 'UPDATE mytable SET foo(myid)=:param_1') + self.assert_compile( + table1.update().values({expr: "bar"}), + "UPDATE mytable SET foo(myid)=:param_1", + ) def test_update_bound_ordering(self): """test that bound parameters between the UPDATE and FROM clauses @@ -569,9 +672,11 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): table1 = self.tables.mytable table2 = self.tables.myothertable sel = select([table2]).where(table2.c.otherid == 5).alias() - upd = table1.update().\ - where(table1.c.name == sel.c.othername).\ - values(name='foo') + upd = ( + table1.update() + .where(table1.c.name == sel.c.othername) + .values(name="foo") + ) dialect = default.StrCompileDialect() dialect.positional = True @@ -583,8 +688,8 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): "FROM myothertable " "WHERE myothertable.otherid = :otherid_1) AS anon_1 " "WHERE mytable.name = anon_1.othername", - checkpositional=('foo', 5), - dialect=dialect + checkpositional=("foo", 5), + dialect=dialect, ) self.assert_compile( @@ -594,73 +699,79 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): "FROM myothertable " "WHERE myothertable.otherid = %s) AS anon_1 SET mytable.name=%s " "WHERE mytable.name = anon_1.othername", - checkpositional=(5, 'foo'), - dialect=mysql.dialect() + checkpositional=(5, "foo"), + dialect=mysql.dialect(), ) -class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, - AssertsCompiledSQL): - __dialect__ = 'default_enhanced' +class UpdateFromCompileTest( + _UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL +): + __dialect__ = "default_enhanced" run_create_tables = run_inserts = run_deletes = None def test_alias_one(self): table1 = self.tables.mytable - talias1 = table1.alias('t1') + talias1 = table1.alias("t1") # this case is 
nonsensical. the UPDATE is entirely # against the alias, but we name the table-bound column # in values. The behavior here isn't really defined self.assert_compile( - update(talias1, talias1.c.myid == 7). - values({table1.c.name: "fred"}), - 'UPDATE mytable AS t1 ' - 'SET name=:name ' - 'WHERE t1.myid = :myid_1') + update(talias1, talias1.c.myid == 7).values( + {table1.c.name: "fred"} + ), + "UPDATE mytable AS t1 " + "SET name=:name " + "WHERE t1.myid = :myid_1", + ) def test_alias_two(self): table1 = self.tables.mytable - talias1 = table1.alias('t1') + talias1 = table1.alias("t1") # Here, compared to # test_alias_one(), here we actually have UPDATE..FROM, # which is causing the "table1.c.name" param to be handled # as an "extra table", hence we see the full table name rendered. self.assert_compile( - update(talias1, table1.c.myid == 7). - values({table1.c.name: 'fred'}), - 'UPDATE mytable AS t1 ' - 'SET name=:mytable_name ' - 'FROM mytable ' - 'WHERE mytable.myid = :myid_1', - checkparams={'mytable_name': 'fred', 'myid_1': 7}, + update(talias1, table1.c.myid == 7).values( + {table1.c.name: "fred"} + ), + "UPDATE mytable AS t1 " + "SET name=:mytable_name " + "FROM mytable " + "WHERE mytable.myid = :myid_1", + checkparams={"mytable_name": "fred", "myid_1": 7}, ) def test_alias_two_mysql(self): table1 = self.tables.mytable - talias1 = table1.alias('t1') + talias1 = table1.alias("t1") self.assert_compile( - update(talias1, table1.c.myid == 7). - values({table1.c.name: 'fred'}), + update(talias1, table1.c.myid == 7).values( + {table1.c.name: "fred"} + ), "UPDATE mytable AS t1, mytable SET mytable.name=%s " "WHERE mytable.myid = %s", - checkparams={'mytable_name': 'fred', 'myid_1': 7}, - dialect='mysql') + checkparams={"mytable_name": "fred", "myid_1": 7}, + dialect="mysql", + ) def test_update_from_multitable_same_name_mysql(self): users, addresses = self.tables.users, self.tables.addresses self.assert_compile( - users.update(). - values(name='newname'). - values({addresses.c.name: "new address"}). - where(users.c.id == addresses.c.user_id), + users.update() + .values(name="newname") + .values({addresses.c.name: "new address"}) + .where(users.c.id == addresses.c.user_id), "UPDATE users, addresses SET addresses.name=%s, " "users.name=%s WHERE users.id = addresses.user_id", - checkparams={'addresses_name': 'new address', 'name': 'newname'}, - dialect='mysql' + checkparams={"addresses_name": "new address", "name": "newname"}, + dialect="mysql", ) def test_update_from_join_mysql(self): @@ -668,120 +779,117 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, j = users.join(addresses) self.assert_compile( - update(j). - values(name='newname'). - where(addresses.c.email_address == 'e1'), + update(j) + .values(name="newname") + .where(addresses.c.email_address == "e1"), "" - 'UPDATE users ' - 'INNER JOIN addresses ON users.id = addresses.user_id ' - 'SET users.name=%s ' - 'WHERE ' - 'addresses.email_address = %s', - checkparams={'email_address_1': 'e1', 'name': 'newname'}, - dialect=mysql.dialect()) + "UPDATE users " + "INNER JOIN addresses ON users.id = addresses.user_id " + "SET users.name=%s " + "WHERE " + "addresses.email_address = %s", + checkparams={"email_address_1": "e1", "name": "newname"}, + dialect=mysql.dialect(), + ) def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses self.assert_compile( - users.update(). - values(name='newname'). - where(users.c.id == addresses.c.user_id). 
- where(addresses.c.email_address == 'e1'), - 'UPDATE users ' - 'SET name=:name FROM addresses ' - 'WHERE ' - 'users.id = addresses.user_id AND ' - 'addresses.email_address = :email_address_1', - checkparams={'email_address_1': 'e1', 'name': 'newname'}) + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.email_address == "e1"), + "UPDATE users " + "SET name=:name FROM addresses " + "WHERE " + "users.id = addresses.user_id AND " + "addresses.email_address = :email_address_1", + checkparams={"email_address_1": "e1", "name": "newname"}, + ) def test_render_multi_table(self): users = self.tables.users addresses = self.tables.addresses dingalings = self.tables.dingalings - checkparams = { - 'email_address_1': 'e1', - 'id_1': 2, - 'name': 'newname' - } + checkparams = {"email_address_1": "e1", "id_1": 2, "name": "newname"} self.assert_compile( - users.update(). - values(name='newname'). - where(users.c.id == addresses.c.user_id). - where(addresses.c.email_address == 'e1'). - where(addresses.c.id == dingalings.c.address_id). - where(dingalings.c.id == 2), - 'UPDATE users ' - 'SET name=:name ' - 'FROM addresses, dingalings ' - 'WHERE ' - 'users.id = addresses.user_id AND ' - 'addresses.email_address = :email_address_1 AND ' - 'addresses.id = dingalings.address_id AND ' - 'dingalings.id = :id_1', - checkparams=checkparams) + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.email_address == "e1") + .where(addresses.c.id == dingalings.c.address_id) + .where(dingalings.c.id == 2), + "UPDATE users " + "SET name=:name " + "FROM addresses, dingalings " + "WHERE " + "users.id = addresses.user_id AND " + "addresses.email_address = :email_address_1 AND " + "addresses.id = dingalings.address_id AND " + "dingalings.id = :id_1", + checkparams=checkparams, + ) def test_render_table_mysql(self): users, addresses = self.tables.users, self.tables.addresses self.assert_compile( - users.update(). - values(name='newname'). - where(users.c.id == addresses.c.user_id). - where(addresses.c.email_address == 'e1'), - 'UPDATE users, addresses ' - 'SET users.name=%s ' - 'WHERE ' - 'users.id = addresses.user_id AND ' - 'addresses.email_address = %s', - checkparams={'email_address_1': 'e1', 'name': 'newname'}, - dialect=mysql.dialect()) + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.email_address == "e1"), + "UPDATE users, addresses " + "SET users.name=%s " + "WHERE " + "users.id = addresses.user_id AND " + "addresses.email_address = %s", + checkparams={"email_address_1": "e1", "name": "newname"}, + dialect=mysql.dialect(), + ) def test_render_subquery(self): users, addresses = self.tables.users, self.tables.addresses - checkparams = { - 'email_address_1': 'e1', - 'id_1': 7, - 'name': 'newname' - } + checkparams = {"email_address_1": "e1", "id_1": 7, "name": "newname"} - cols = [ - addresses.c.id, - addresses.c.user_id, - addresses.c.email_address - ] + cols = [addresses.c.id, addresses.c.user_id, addresses.c.email_address] subq = select(cols).where(addresses.c.id == 7).alias() self.assert_compile( - users.update(). - values(name='newname'). - where(users.c.id == subq.c.user_id). 
- where(subq.c.email_address == 'e1'), - 'UPDATE users ' - 'SET name=:name FROM (' - 'SELECT ' - 'addresses.id AS id, ' - 'addresses.user_id AS user_id, ' - 'addresses.email_address AS email_address ' - 'FROM addresses ' - 'WHERE addresses.id = :id_1' - ') AS anon_1 ' - 'WHERE users.id = anon_1.user_id ' - 'AND anon_1.email_address = :email_address_1', - checkparams=checkparams) + users.update() + .values(name="newname") + .where(users.c.id == subq.c.user_id) + .where(subq.c.email_address == "e1"), + "UPDATE users " + "SET name=:name FROM (" + "SELECT " + "addresses.id AS id, " + "addresses.user_id AS user_id, " + "addresses.email_address AS email_address " + "FROM addresses " + "WHERE addresses.id = :id_1" + ") AS anon_1 " + "WHERE users.id = anon_1.user_id " + "AND anon_1.email_address = :email_address_1", + checkparams=checkparams, + ) def test_correlation_to_extra(self): users, addresses = self.tables.users, self.tables.addresses - stmt = users.update().values(name="newname").where( - users.c.id == addresses.c.user_id - ).where( - ~exists().where( - addresses.c.user_id == users.c.id - ).where(addresses.c.email_address == 'foo').correlate(addresses) + stmt = ( + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where( + ~exists() + .where(addresses.c.user_id == users.c.id) + .where(addresses.c.email_address == "foo") + .correlate(addresses) + ) ) self.assert_compile( @@ -789,18 +897,22 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, "UPDATE users SET name=:name FROM addresses WHERE " "users.id = addresses.user_id AND NOT " "(EXISTS (SELECT * FROM users WHERE addresses.user_id = users.id " - "AND addresses.email_address = :email_address_1))" + "AND addresses.email_address = :email_address_1))", ) def test_dont_correlate_to_extra(self): users, addresses = self.tables.users, self.tables.addresses - stmt = users.update().values(name="newname").where( - users.c.id == addresses.c.user_id - ).where( - ~exists().where( - addresses.c.user_id == users.c.id - ).where(addresses.c.email_address == 'foo').correlate() + stmt = ( + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where( + ~exists() + .where(addresses.c.user_id == users.c.id) + .where(addresses.c.email_address == "foo") + .correlate() + ) ) self.assert_compile( @@ -809,24 +921,28 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, "users.id = addresses.user_id AND NOT " "(EXISTS (SELECT * FROM addresses, users " "WHERE addresses.user_id = users.id " - "AND addresses.email_address = :email_address_1))" + "AND addresses.email_address = :email_address_1))", ) def test_autocorrelate_error(self): users, addresses = self.tables.users, self.tables.addresses - stmt = users.update().values(name="newname").where( - users.c.id == addresses.c.user_id - ).where( - ~exists().where( - addresses.c.user_id == users.c.id - ).where(addresses.c.email_address == 'foo') + stmt = ( + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where( + ~exists() + .where(addresses.c.user_id == users.c.id) + .where(addresses.c.email_address == "foo") + ) ) assert_raises_message( exc.InvalidRequestError, ".*returned no FROM clauses due to auto-correlation.*", - stmt.compile, dialect=default.StrCompileDialect() + stmt.compile, + dialect=default.StrCompileDialect(), ) @@ -838,17 +954,19 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): users, addresses = self.tables.users, 
self.tables.addresses testing.db.execute( - addresses.update(). - values(email_address=users.c.name). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed')) + addresses.update() + .values(email_address=users.c.name) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'ed'), - (3, 8, 'x', 'ed'), - (4, 8, 'x', 'ed'), - (5, 9, 'x', 'fred@fred.com')] + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "ed"), + (3, 8, "x", "ed"), + (4, 8, "x", "ed"), + (5, 9, "x", "fred@fred.com"), + ] self._assert_addresses(addresses, expected) @testing.requires.update_from @@ -857,19 +975,20 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): a1 = addresses.alias() testing.db.execute( - addresses.update(). - values(email_address=users.c.name). - where(users.c.id == a1.c.user_id). - where(users.c.name == 'ed'). - where(a1.c.id == addresses.c.id) + addresses.update() + .values(email_address=users.c.name) + .where(users.c.id == a1.c.user_id) + .where(users.c.name == "ed") + .where(a1.c.id == addresses.c.id) ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'ed'), - (3, 8, 'x', 'ed'), - (4, 8, 'x', 'ed'), - (5, 9, 'x', 'fred@fred.com')] + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "ed"), + (3, 8, "x", "ed"), + (4, 8, "x", "ed"), + (5, 9, "x", "fred@fred.com"), + ] self._assert_addresses(addresses, expected) @testing.requires.update_from @@ -879,108 +998,95 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): dingalings = self.tables.dingalings testing.db.execute( - addresses.update(). - values(email_address=users.c.name). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed'). - where(addresses.c.id == dingalings.c.address_id). - where(dingalings.c.id == 1)) + addresses.update() + .values(email_address=users.c.name) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + .where(addresses.c.id == dingalings.c.address_id) + .where(dingalings.c.id == 1) + ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'ed'), - (3, 8, 'x', 'ed@bettyboop.com'), - (4, 8, 'x', 'ed@lala.com'), - (5, 9, 'x', 'fred@fred.com')] + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "ed"), + (3, 8, "x", "ed@bettyboop.com"), + (4, 8, "x", "ed@lala.com"), + (5, 9, "x", "fred@fred.com"), + ] self._assert_addresses(addresses, expected) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_exec_multitable(self): users, addresses = self.tables.users, self.tables.addresses - values = { - addresses.c.email_address: 'updated', - users.c.name: 'ed2' - } + values = {addresses.c.email_address: "updated", users.c.name: "ed2"} testing.db.execute( - addresses.update(). - values(values). - where(users.c.id == addresses.c.user_id). 
- where(users.c.name == 'ed')) + addresses.update() + .values(values) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'updated'), - (3, 8, 'x', 'updated'), - (4, 8, 'x', 'updated'), - (5, 9, 'x', 'fred@fred.com')] + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "updated"), + (3, 8, "x", "updated"), + (4, 8, "x", "updated"), + (5, 9, "x", "fred@fred.com"), + ] self._assert_addresses(addresses, expected) - expected = [ - (7, 'jack'), - (8, 'ed2'), - (9, 'fred'), - (10, 'chuck')] + expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] self._assert_users(users, expected) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_exec_join_multitable(self): users, addresses = self.tables.users, self.tables.addresses - values = { - addresses.c.email_address: 'updated', - users.c.name: 'ed2' - } + values = {addresses.c.email_address: "updated", users.c.name: "ed2"} testing.db.execute( - update(users.join(addresses)). - values(values). - where(users.c.name == 'ed')) + update(users.join(addresses)) + .values(values) + .where(users.c.name == "ed") + ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'x', 'updated'), - (3, 8, 'x', 'updated'), - (4, 8, 'x', 'updated'), - (5, 9, 'x', 'fred@fred.com')] + (1, 7, "x", "jack@bean.com"), + (2, 8, "x", "updated"), + (3, 8, "x", "updated"), + (4, 8, "x", "updated"), + (5, 9, "x", "fred@fred.com"), + ] self._assert_addresses(addresses, expected) - expected = [ - (7, 'jack'), - (8, 'ed2'), - (9, 'fred'), - (10, 'chuck')] + expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] self._assert_users(users, expected) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_exec_multitable_same_name(self): users, addresses = self.tables.users, self.tables.addresses - values = { - addresses.c.name: 'ad_ed2', - users.c.name: 'ed2' - } + values = {addresses.c.name: "ad_ed2", users.c.name: "ed2"} testing.db.execute( - addresses.update(). - values(values). - where(users.c.id == addresses.c.user_id). 
- where(users.c.name == 'ed')) + addresses.update() + .values(values) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) expected = [ - (1, 7, 'x', 'jack@bean.com'), - (2, 8, 'ad_ed2', 'ed@wood.com'), - (3, 8, 'ad_ed2', 'ed@bettyboop.com'), - (4, 8, 'ad_ed2', 'ed@lala.com'), - (5, 9, 'x', 'fred@fred.com')] + (1, 7, "x", "jack@bean.com"), + (2, 8, "ad_ed2", "ed@wood.com"), + (3, 8, "ad_ed2", "ed@bettyboop.com"), + (4, 8, "ad_ed2", "ed@lala.com"), + (5, 9, "x", "fred@fred.com"), + ] self._assert_addresses(addresses, expected) - expected = [ - (7, 'jack'), - (8, 'ed2'), - (9, 'fred'), - (10, 'chuck')] + expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] self._assert_users(users, expected) def _assert_addresses(self, addresses, expected): @@ -992,136 +1098,137 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): eq_(testing.db.execute(stmt).fetchall(), expected) -class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, - fixtures.TablesTest): +class UpdateFromMultiTableUpdateDefaultsTest( + _UpdateFromTestBase, fixtures.TablesTest +): __backend__ = True @classmethod def define_tables(cls, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30), nullable=False), - Column('some_update', String(30), onupdate='im the update')) - - Table('addresses', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False), - ) - - Table('foobar', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('user_id', None, ForeignKey('users.id')), - Column('data', String(30)), - Column('some_update', String(30), onupdate='im the other update') - ) + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + Column("some_update", String(30), onupdate="im the update"), + ) + + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + ) + + Table( + "foobar", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("data", String(30)), + Column("some_update", String(30), onupdate="im the other update"), + ) @classmethod def fixtures(cls): return dict( users=( - ('id', 'name', 'some_update'), - (8, 'ed', 'value'), - (9, 'fred', 'value'), + ("id", "name", "some_update"), + (8, "ed", "value"), + (9, "fred", "value"), ), addresses=( - ('id', 'user_id', 'email_address'), - (2, 8, 'ed@wood.com'), - (3, 8, 'ed@bettyboop.com'), - (4, 9, 'fred@fred.com') + ("id", "user_id", "email_address"), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 9, "fred@fred.com"), ), foobar=( - ('id', 'user_id', 'data'), - (2, 8, 'd1'), - (3, 8, 'd2'), - (4, 9, 'd3') - ) + ("id", "user_id", "data"), + (2, 8, "d1"), + (3, 8, "d2"), + (4, 9, "d3"), + ), ) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_defaults_second_table(self): users, addresses = self.tables.users, self.tables.addresses - values = { - addresses.c.email_address: 'updated', - users.c.name: 'ed2' - } + values = 
{addresses.c.email_address: "updated", users.c.name: "ed2"} ret = testing.db.execute( - addresses.update(). - values(values). - where(users.c.id == addresses.c.user_id). - where(users.c.name == 'ed')) + addresses.update() + .values(values) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) eq_(set(ret.prefetch_cols()), set([users.c.some_update])) expected = [ - (2, 8, 'updated'), - (3, 8, 'updated'), - (4, 9, 'fred@fred.com')] + (2, 8, "updated"), + (3, 8, "updated"), + (4, 9, "fred@fred.com"), + ] self._assert_addresses(addresses, expected) - expected = [ - (8, 'ed2', 'im the update'), - (9, 'fred', 'value')] + expected = [(8, "ed2", "im the update"), (9, "fred", "value")] self._assert_users(users, expected) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_defaults_second_table_same_name(self): users, foobar = self.tables.users, self.tables.foobar - values = { - foobar.c.data: foobar.c.data + 'a', - users.c.name: 'ed2' - } + values = {foobar.c.data: foobar.c.data + "a", users.c.name: "ed2"} ret = testing.db.execute( - users.update(). - values(values). - where(users.c.id == foobar.c.user_id). - where(users.c.name == 'ed')) + users.update() + .values(values) + .where(users.c.id == foobar.c.user_id) + .where(users.c.name == "ed") + ) eq_( set(ret.prefetch_cols()), - set([users.c.some_update, foobar.c.some_update]) + set([users.c.some_update, foobar.c.some_update]), ) expected = [ - (2, 8, 'd1a', 'im the other update'), - (3, 8, 'd2a', 'im the other update'), - (4, 9, 'd3', None)] + (2, 8, "d1a", "im the other update"), + (3, 8, "d2a", "im the other update"), + (4, 9, "d3", None), + ] self._assert_foobar(foobar, expected) - expected = [ - (8, 'ed2', 'im the update'), - (9, 'fred', 'value')] + expected = [(8, "ed2", "im the update"), (9, "fred", "value")] self._assert_users(users, expected) - @testing.only_on('mysql', 'Multi table update') + @testing.only_on("mysql", "Multi table update") def test_no_defaults_second_table(self): users, addresses = self.tables.users, self.tables.addresses ret = testing.db.execute( - addresses.update(). - values({'email_address': users.c.name}). - where(users.c.id == addresses.c.user_id). 
- where(users.c.name == 'ed')) + addresses.update() + .values({"email_address": users.c.name}) + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) eq_(ret.prefetch_cols(), []) - expected = [ - (2, 8, 'ed'), - (3, 8, 'ed'), - (4, 9, 'fred@fred.com')] + expected = [(2, 8, "ed"), (3, 8, "ed"), (4, 9, "fred@fred.com")] self._assert_addresses(addresses, expected) # users table not actually updated, so no onupdate - expected = [ - (8, 'ed', 'value'), - (9, 'fred', 'value')] + expected = [(8, "ed", "value"), (9, "fred", "value")] self._assert_users(users, expected) def _assert_foobar(self, foobar, expected): diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index bd8368cd25..bc124ada97 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -9,34 +9,17 @@ from sqlalchemy.sql import util as sql_util class CompareClausesTest(fixtures.TestBase): def setup(self): m = MetaData() - self.a = Table( - 'a', m, - Column('x', Integer), - Column('y', Integer) - ) + self.a = Table("a", m, Column("x", Integer), Column("y", Integer)) - self.b = Table( - 'b', m, - Column('y', Integer), - Column('z', Integer) - ) + self.b = Table("b", m, Column("y", Integer), Column("z", Integer)) def test_compare_clauselist_associative(self): - l1 = and_( - self.a.c.x == self.b.c.y, - self.a.c.y == self.b.c.z - ) + l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z) - l2 = and_( - self.a.c.y == self.b.c.z, - self.a.c.x == self.b.c.y, - ) + l2 = and_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y) - l3 = and_( - self.a.c.x == self.b.c.z, - self.a.c.y == self.b.c.y - ) + l3 = and_(self.a.c.x == self.b.c.z, self.a.c.y == self.b.c.y) is_true(l1.compare(l1)) is_true(l1.compare(l2)) @@ -45,35 +28,33 @@ class CompareClausesTest(fixtures.TestBase): def test_compare_clauselist_not_associative(self): l1 = ClauseList( - self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub) + self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub + ) l2 = ClauseList( - self.b.c.y, self.a.c.x, self.a.c.y, operator=operators.sub) + self.b.c.y, self.a.c.x, self.a.c.y, operator=operators.sub + ) is_true(l1.compare(l1)) is_false(l1.compare(l2)) def test_compare_clauselist_assoc_different_operator(self): - l1 = and_( - self.a.c.x == self.b.c.y, - self.a.c.y == self.b.c.z - ) + l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z) - l2 = or_( - self.a.c.y == self.b.c.z, - self.a.c.x == self.b.c.y, - ) + l2 = or_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y) is_false(l1.compare(l2)) def test_compare_clauselist_not_assoc_different_operator(self): l1 = ClauseList( - self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub) + self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub + ) l2 = ClauseList( - self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.div) + self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.div + ) is_false(l1.compare(l2)) @@ -83,9 +64,11 @@ class CompareClausesTest(fixtures.TestBase): b3 = bindparam("bar", type_=Integer()) b4 = bindparam("foo", type_=String()) - def c1(): return 5 + def c1(): + return 5 - def c2(): return 6 + def c2(): + return 6 b5 = bindparam("foo", type_=Integer(), callable_=c1) b6 = bindparam("foo", type_=Integer(), callable_=c2) @@ -114,7 +97,4 @@ class MiscTest(fixtures.TestBase): class MyElement(ColumnElement): pass - eq_( - sql_util.find_tables(MyElement(), check_columns=True), - [] - ) + eq_(sql_util.find_tables(MyElement(), check_columns=True), [])
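
The CompareClausesTest hunks directly above also document a behavior worth
spelling out: clause comparison treats and_()/or_() as associative, so operand
order does not matter, while a different boolean operator makes the comparison
fail. A self-contained sketch, illustrative only and assuming the same
1.3-era API as the diff:

    # Illustrative sketch: clause comparison semantics asserted by
    # CompareClausesTest above.
    from sqlalchemy import Column, Integer, MetaData, Table, and_, or_

    m = MetaData()
    a = Table("a", m, Column("x", Integer), Column("y", Integer))
    b = Table("b", m, Column("y", Integer), Column("z", Integer))

    l1 = and_(a.c.x == b.c.y, a.c.y == b.c.z)
    l2 = and_(a.c.y == b.c.z, a.c.x == b.c.y)  # same terms, swapped order

    assert l1.compare(l2)  # AND compares associatively
    assert not l1.compare(or_(a.c.y == b.c.z, a.c.x == b.c.y))  # OR != AND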
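
Finally, much of the test_update.py portion of this diff compile-checks
UPDATE..FROM forms. A short sketch of the construct those tests target,
again illustrative rather than part of the diff; the lightweight
table()/column() objects here are ad-hoc stand-ins for the table fixtures
used above:

    # Illustrative sketch: the UPDATE..FROM construction exercised by
    # UpdateFromCompileTest above.
    from sqlalchemy import table, column

    users = table("users", column("id"), column("name"))
    addresses = table(
        "addresses", column("user_id"), column("email_address")
    )

    stmt = (
        users.update()
        .values(name="newname")
        .where(users.c.id == addresses.c.user_id)
        .where(addresses.c.email_address == "e1")
    )
    # On a dialect supporting UPDATE..FROM this compiles to roughly:
    #   UPDATE users SET name=:name FROM addresses
    #   WHERE users.id = addresses.user_id
    #   AND addresses.email_address = :email_address_1
    # MySQL instead renders the multi-table form:
    #   UPDATE users, addresses SET users.name=%s WHERE ...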